Skip to content

Commit 6875797

Browse files
committed
add Contrast Enhancement module with CLAHE algorithm
1 parent e545b2f commit 6875797

3 files changed

Lines changed: 155 additions & 0 deletions

File tree

modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .blc import BLC
77
from .bnf import BNF
88
from .ccm import CCM
9+
from .ceh import CEH
910
from .cfa import CFA
1011
from .cnf import CNF
1112
from .csc import CSC

modules/ceh.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# File: ceh.py
2+
# Description: Contrast Enhancement (contrast limited adaptive histogram equalization)
3+
# Created: 2021/11/12 21:46
4+
# Author: Qiu Jueqin (qiujueqin@gmail.com)
5+
6+
7+
import math
8+
import numpy as np
9+
10+
from .basic_module import BasicModule, register_dependent_modules
11+
from .helpers import pad, crop
12+
13+
14+
@register_dependent_modules('csc')
15+
class CEH(BasicModule):
16+
def __init__(self, cfg):
17+
super().__init__(cfg)
18+
self.y_tiles, self.x_tiles = self.params.tiles
19+
assert self.y_tiles >= 2 and self.x_tiles >= 2, 'only tiles >= 2 is supported'
20+
21+
self.tile_height = math.ceil(cfg.hardware.raw_height / self.y_tiles)
22+
self.tile_width = math.ceil(cfg.hardware.raw_width / self.x_tiles)
23+
24+
y_pads = self.tile_height * self.y_tiles - cfg.hardware.raw_height
25+
x_pads = self.tile_width * self.x_tiles - cfg.hardware.raw_width
26+
self.pads = (y_pads // 2, y_pads - y_pads // 2, x_pads // 2, x_pads - x_pads // 2)
27+
28+
# Weights for LUTs interpolation
29+
self.left_lut_weights = np.linspace(1024, 0, self.tile_width, dtype=np.int32).reshape((1, -1)) # x1024
30+
self.top_lut_weights = np.linspace(1024, 0, self.tile_height, dtype=np.int32).reshape((-1, 1)) # x1024
31+
32+
self.luts = np.empty(shape=(self.y_tiles, self.x_tiles, 256), dtype=np.uint8) # LUTs w.r.t. tiles
33+
34+
def execute(self, data):
35+
y_image = data['y_image'].astype(np.int32)
36+
y_image = pad(y_image, pads=self.pads)
37+
38+
# ---------- Generate tile-wise look-up tables ----------
39+
for ty in range(self.y_tiles):
40+
for tx in range(self.x_tiles):
41+
y_tile = y_image[ty * self.tile_height: (ty + 1) * self.tile_height,
42+
tx * self.tile_width: (tx + 1) * self.tile_width]
43+
self.luts[ty, tx] = self._get_tile_lut(y_tile)
44+
45+
# ---------- Interpolate and apply LUTs for different image blocks ----------
46+
ceh_y_image = np.empty_like(y_image).astype(np.uint8)
47+
for iy in range(self.y_tiles + 1):
48+
for ix in range(self.x_tiles + 1):
49+
y0 = iy * self.tile_height - self.tile_height // 2
50+
y1 = min(y0 + self.tile_height, y_image.shape[0])
51+
x0 = ix * self.tile_width - self.tile_width // 2
52+
x1 = min(x0 + self.tile_width, y_image.shape[1])
53+
y0 = max(y0, 0)
54+
x0 = max(x0, 0)
55+
56+
y_block = y_image[y0:y1, x0:x1]
57+
58+
if self._is_corner_block(ix, iy):
59+
lut_y_idx = 0 if iy == 0 else self.y_tiles - 1
60+
lut_x_idx = 0 if ix == 0 else self.x_tiles - 1
61+
lut = self.luts[lut_y_idx, lut_x_idx]
62+
ceh_y_image[y0:y1, x0:x1] = lut[y_block]
63+
64+
elif self._is_top_or_bottom_block(ix, iy):
65+
lut_y_idx = 0 if iy == 0 else self.y_tiles - 1
66+
left_lut = self.luts[lut_y_idx, ix - 1]
67+
right_lut = self.luts[lut_y_idx, ix]
68+
ceh_y_image[y0:y1, x0:x1] = self._interp_top_bottom_block(y_block, left_lut, right_lut)
69+
70+
elif self._is_left_or_right_block(ix, iy):
71+
lut_x_idx = 0 if ix == 0 else self.x_tiles - 1
72+
top_lut = self.luts[iy - 1, lut_x_idx]
73+
bottom_lut = self.luts[iy, lut_x_idx]
74+
ceh_y_image[y0:y1, x0:x1] = self._interp_left_right_block(y_block, top_lut, bottom_lut)
75+
76+
else:
77+
tl_lut = self.luts[iy - 1, ix - 1]
78+
tr_lut = self.luts[iy - 1, ix]
79+
bl_lut = self.luts[iy, ix - 1]
80+
br_lut = self.luts[iy, ix]
81+
ceh_y_image[y0:y1, x0:x1] = self._interp_neighbor_block(y_block, tl_lut, tr_lut, bl_lut, br_lut)
82+
83+
data['y_image'] = crop(ceh_y_image, self.pads)
84+
85+
def _get_tile_lut(self, tiled_array):
86+
hist = np.histogram(tiled_array, bins=256, range=(0, self.cfg.saturation_values.sdr))[0]
87+
clipped_hist = np.clip(hist, 0, self.params.clip_limit * max(hist))
88+
89+
num_clipped_pixels = (hist - clipped_hist).sum()
90+
91+
hist = clipped_hist + num_clipped_pixels / 256
92+
pdf = hist / hist.sum()
93+
cdf = np.cumsum(pdf)
94+
95+
lut = (cdf * self.cfg.saturation_values.sdr).astype(np.uint8)
96+
return lut
97+
98+
def _interp_top_bottom_block(self, block, left_lut, right_lut):
99+
return np.right_shift(
100+
self.left_lut_weights * left_lut[block].astype(np.int32) +
101+
(1024 - self.left_lut_weights) * right_lut[block].astype(np.int32), 10
102+
).astype(np.uint8)
103+
104+
def _interp_left_right_block(self, block, top_lut, bottom_lut):
105+
return np.right_shift(
106+
self.top_lut_weights * top_lut[block].astype(np.int32) +
107+
(1024 - self.top_lut_weights) * bottom_lut[block].astype(np.int32), 10
108+
).astype(np.uint8)
109+
110+
def _interp_neighbor_block(self, block, tl_lut, tr_lut, bl_lut, br_lut):
111+
top_block = self._interp_top_bottom_block(block, tl_lut, tr_lut).astype(np.int32)
112+
bottom_block = self._interp_top_bottom_block(block, bl_lut, br_lut).astype(np.int32)
113+
return np.right_shift(
114+
self.top_lut_weights * top_block + (1024 - self.top_lut_weights) * bottom_block, 10
115+
).astype(np.uint8)
116+
117+
def _is_corner_block(self, ix, iy):
118+
""" Determine if the current image block is locating in a corner region """
119+
return ((iy == 0 and ix == 0) or
120+
(iy == 0 and ix == self.x_tiles) or
121+
(iy == self.y_tiles and ix == 0) or
122+
(iy == self.y_tiles and ix == self.x_tiles))
123+
124+
def _is_top_or_bottom_block(self, ix, iy):
125+
return (iy == 0 or iy == self.y_tiles) and not self._is_corner_block(ix, iy)
126+
127+
def _is_left_or_right_block(self, ix, iy):
128+
return (ix == 0 or ix == self.x_tiles) and not self._is_corner_block(ix, iy)

modules/helpers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,32 @@ def pad(array, pads, mode='reflect'):
8383
return np.pad(array, pads, mode)
8484

8585

86+
def crop(array, crops):
87+
"""
88+
Crop an array by given margins
89+
:param array: np.ndarray(H, W, ...)
90+
:param crops: {int, sequence}
91+
if int, crops top, bottom, left, and right directions with the same margin
92+
if 2-element sequence: (y-direction crop, x-direction crop)
93+
if 4-element sequence: (top crop, bottom crop, left crop, right crop)
94+
:return: cropped array: np.ndarray(H', W', ...)
95+
"""
96+
97+
if isinstance(crops, (list, tuple, np.ndarray)):
98+
if len(crops) == 2:
99+
top_crop = bottom_crop = crops[0]
100+
left_crop = right_crop = crops[1]
101+
elif len(crops) == 4:
102+
top_crop, bottom_crop, left_crop, right_crop = crops
103+
else:
104+
raise NotImplementedError
105+
else:
106+
top_crop = bottom_crop = left_crop = right_crop = crops
107+
108+
height, width = array.shape
109+
return array[top_crop: height - bottom_crop, left_crop: width - right_crop, ...]
110+
111+
86112
def shift_array(padded_array, window_size):
87113
"""
88114
Shift an array within a window and generate window_size**2 shifted arrays

0 commit comments

Comments
 (0)