Skip to content

Commit 93b4997

Browse files
committed
update EEH module: replace bilateral filter with gaussian filter
1 parent a577689 commit 93b4997

2 files changed

Lines changed: 37 additions & 19 deletions

File tree

modules/eeh.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,52 +7,44 @@
77
import numpy as np
88

99
from .basic_module import BasicModule, register_dependent_modules
10-
from .helpers import bilateral_filter, gen_gaussian_kernel
10+
from .helpers import gaussian_filter, gen_gaussian_kernel
1111

1212

1313
@register_dependent_modules('csc')
1414
class EEH(BasicModule):
1515
def __init__(self, cfg):
1616
super().__init__(cfg)
1717

18-
self.intensity_weights_lut = self.get_intensity_weights_lut(intensity_sigma=1.0) # x1024
19-
spatial_weights = gen_gaussian_kernel(kernel_size=5, sigma=1.0)
20-
self.spatial_weights = (1024 * spatial_weights / spatial_weights.max()).astype(np.int32) # x1024
18+
kernel = gen_gaussian_kernel(kernel_size=7, sigma=5.0)
19+
self.kernel = (1024 * kernel / kernel.max()).astype(np.int32) # x1024
2120

2221
flat_slope = self.params.middle_threshold / (self.params.middle_threshold - self.params.flat_threshold + 1E-6)
2322
edge_slope = self.params.edge_gain / 256
2423

2524
self.flat_slope = np.array(256 * flat_slope, dtype=np.int32) # x256
2625
self.edge_slope = np.array(256 * edge_slope, dtype=np.int32) # x256
27-
self.flat_intercept = np.array(-flat_slope * self.params.flat_threshold, dtype=np.int32)
28-
self.edge_intercept = np.array((1 - edge_slope) * self.params.edge_threshold, dtype=np.int32)
26+
self.flat_intercept = -np.array(256 * flat_slope * self.params.flat_threshold, dtype=np.int32) # x256
27+
self.edge_intercept = np.array(256 * (1 - edge_slope) * self.params.edge_threshold, dtype=np.int32) # x256
2928

3029
def execute(self, data):
3130
y_image = data['y_image'].astype(np.int32)
3231

33-
bf_y_image = bilateral_filter(y_image, self.spatial_weights, self.intensity_weights_lut, right_shift=10)
32+
gf_y_image = gaussian_filter(y_image, self.kernel)
3433

35-
delta = y_image - bf_y_image
34+
delta = y_image - gf_y_image
3635
sign_map = np.sign(delta)
3736
abs_delta = np.abs(delta)
3837

39-
flat_delta = np.right_shift(self.flat_slope * abs_delta, 8) + self.flat_intercept
40-
edge_delta = np.right_shift(self.edge_slope * abs_delta, 8) + self.edge_intercept
38+
flat_delta = np.right_shift(self.flat_slope * abs_delta + self.flat_intercept, 8)
39+
edge_delta = np.right_shift(self.edge_slope * abs_delta + self.edge_intercept, 8)
4140
enhanced_delta = sign_map * (
4241
(abs_delta > self.params.flat_threshold) * (abs_delta <= self.params.middle_threshold) * flat_delta +
4342
(abs_delta > self.params.middle_threshold) * (abs_delta <= self.params.edge_threshold) * abs_delta +
4443
(abs_delta > self.params.edge_threshold) * edge_delta
4544
)
4645
enhanced_delta = np.clip(enhanced_delta, -self.params.delta_threshold, self.params.delta_threshold)
4746

48-
eeh_y_image = np.clip(bf_y_image + enhanced_delta, 0, self.cfg.saturation_values.sdr)
47+
eeh_y_image = np.clip(gf_y_image + enhanced_delta, 0, self.cfg.saturation_values.sdr)
4948

5049
data['y_image'] = eeh_y_image.astype(np.uint8)
5150
data['edge_map'] = delta
52-
53-
@staticmethod
54-
def get_intensity_weights_lut(intensity_sigma):
55-
intensity_diff = np.arange(255 ** 2)
56-
exp_lut = 1024 * np.exp(-intensity_diff / (2.0 * (255 * intensity_sigma) ** 2))
57-
return exp_lut.astype(np.int32) # x1024
58-

modules/helpers.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,32 @@ def mean_filter(array, filter_size=3):
139139
return (sum(shifted_arrays) / filter_size ** 2).astype(array.dtype)
140140

141141

142+
def gaussian_filter(array, kernel):
143+
"""
144+
A faster reimplementation of the bilateral-filter
145+
:param array: array to be filter: np.ndarray(H, W, ...), must be np.int dtype
146+
:param kernel: np.ndarray(h, w)
147+
:return: filtered array: np.ndarray(H, W, ...)
148+
"""
149+
150+
kh, kw = kernel.shape[:2]
151+
kernel = kernel.flatten()
152+
153+
padded_array = pad(array, pads=(kh // 2, kw // 2))
154+
shifted_arrays = shift_array(padded_array, window_size=(kh, kw))
155+
156+
gf_array = np.zeros_like(array)
157+
weights = np.zeros_like(array)
158+
159+
for i, shifted_array in enumerate(shifted_arrays):
160+
gf_array += kernel[i] * shifted_array
161+
weights += kernel[i]
162+
163+
gf_array = (gf_array / weights).astype(array.dtype)
164+
165+
return gf_array
166+
167+
142168
def bilateral_filter(array, spatial_weights, intensity_weights_lut, right_shift=0):
143169
"""
144170
A faster reimplementation of the bilateral-filter
@@ -170,4 +196,4 @@ def bilateral_filter(array, spatial_weights, intensity_weights_lut, right_shift=
170196

171197
bf_array = (bf_array / weights).astype(array.dtype)
172198

173-
return bf_array
199+
return bf_array

0 commit comments

Comments
 (0)