Skip to content

Commit 00c68ea

Browse files
committed
refactor dependent module checking
1 parent 4b62381 commit 00c68ea

3 files changed

Lines changed: 27 additions & 31 deletions

File tree

modules/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from .helpers import *
2-
from .basic_module import BasicModule
31
from .aaf import AAF
42
from .awb import AWB
53
from .bcc import BCC

modules/basic_module.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,32 @@
44
# Author: Qiu Jueqin (qiujueqin@gmail.com)
55

66

7-
def register_dependent_modules(dependent_module):
8-
""" A decorator to register dependent ISP modules """
9-
10-
if not isinstance(dependent_module, (list, tuple)):
11-
dependent_module = [dependent_module]
7+
MODULE_DEPENDENCIES = {}
128

13-
def _register_dependent_modules(cls):
14-
orig_init = cls.__init__
159

16-
def override_init(self, *args, **kws):
17-
orig_init(self, *args, **kws)
18-
self.dependent_modules = tuple(dependent_module)
10+
def register_dependent_modules(dependent_module_names):
11+
""" A decorator to register dependent ISP modules """
12+
if not isinstance(dependent_module_names, (list, tuple)):
13+
dependent_module_names = tuple([dependent_module_names])
1914

20-
cls.__init__ = override_init
15+
def wrapper(cls):
16+
MODULE_DEPENDENCIES[cls.__name__] = dependent_module_names
2117
return cls
2218

23-
return _register_dependent_modules
19+
return wrapper
2420

2521

2622
class BasicModule:
2723
def __init__(self, cfg):
2824
self.cfg = cfg
29-
3025
module_name = self.__class__.__name__.lower()
3126
self.params = cfg[module_name] if module_name in cfg else None
3227

3328
def execute(self, data):
3429
"""
35-
:param data: a dict containing data flow in the pipeline, as well as other intermediate results,
36-
e.g., YCbCr image from color space conversion module, edge map from edge enhancement module.
37-
Instead of returning a processed result, the execute() method in each module will in-place
38-
modify the data dict
30+
:param data: a dict containing data flow in the pipeline, as well as other intermediate
31+
results, e.g., YCbCr image from color space conversion module, edge map from edge
32+
enhancement module. Instead of returning a processed result, the execute() method in
33+
each module will in-place modify the data dict
3934
"""
4035
raise NotImplemented

pipeline.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717

1818
from utils.yacs import Config
19+
from modules.basic_module import MODULE_DEPENDENCIES
1920

2021

2122
class Pipeline:
@@ -47,11 +48,12 @@ def get_saturation_values(self):
4748

4849
# Saturation values should be carefully calculated if BLC module is activated
4950
if 'blc' in self.cfg.module_enable_status:
50-
hdr_r_max_value = raw_max_value - self.cfg.blc.bl_r
51-
hdr_b_max_value = raw_max_value - self.cfg.blc.bl_b
52-
hdr_gr_max_value = int(raw_max_value - self.cfg.blc.bl_gr + hdr_r_max_value * self.cfg.blc.alpha / 1024)
53-
hdr_gb_max_value = int(raw_max_value - self.cfg.blc.bl_gb + hdr_b_max_value * self.cfg.blc.beta / 1024)
54-
hdr_max_value = max([hdr_r_max_value, hdr_b_max_value, hdr_gr_max_value, hdr_gb_max_value])
51+
blc = self.cfg.blc
52+
hdr_max_r = raw_max_value - blc.bl_r
53+
hdr_max_b = raw_max_value - blc.bl_b
54+
hdr_max_gr = int(raw_max_value - blc.bl_gr + hdr_max_r * blc.alpha / 1024)
55+
hdr_max_gb = int(raw_max_value - blc.bl_gb + hdr_max_b * blc.beta / 1024)
56+
hdr_max_value = max(hdr_max_r, hdr_max_b, hdr_max_gr, hdr_max_gb)
5557
else:
5658
hdr_max_value = raw_max_value
5759

@@ -73,10 +75,11 @@ def get_modules(self):
7375
module_cls = getattr(package, module_name.upper())
7476
module = module_cls(self.cfg)
7577

76-
if hasattr(module, 'dependent_modules'):
77-
for m in module.dependent_modules:
78-
if m not in enabled_modules:
79-
raise RuntimeError('{} is available only if {} is activated'.format(module_name, m))
78+
for m in MODULE_DEPENDENCIES.get(module_cls.__name__, []):
79+
if m not in enabled_modules:
80+
raise RuntimeError(
81+
'{} is unavailable when {} is deactivated'.format(module_name, m)
82+
)
8083

8184
modules[module_name] = module
8285

@@ -145,7 +148,7 @@ def run(self, raw_path, save_dir, load_raw_fn, suffix=''):
145148
"""
146149
A higher level API that writes ISP result into disk
147150
:param raw_path: path to the raw file to be processed
148-
:param save_dir: directory to save the output (the output will share the filename with input)
151+
:param save_dir: directory to save the output (shares the same filename as the input)
149152
:param load_raw_fn: function to load the Bayer array from the raw_path
150153
:param suffix: suffix to added to the output filename
151154
"""
@@ -164,8 +167,8 @@ def batch_run(self, raw_paths, save_dirs, load_raw_fn, suffixes='', num_processe
164167
"""
165168
Batch running with multiprocessing
166169
:param raw_paths: list of paths to the raw files to be executed
167-
:param save_dirs: list of directories to save the outputs. If given a string, it will be copied
168-
to a N-element list, where N is the number of paths in raw_paths
170+
:param save_dirs: list of directories to save the outputs. If given a string, it will be
171+
copied to a N-element list, where N is the number of paths in raw_paths
169172
:param load_raw_fn: function to load the Bayer array from the raw_path
170173
:param suffixes: a list of suffixes to added to the output filenames
171174
:param num_processes: number of processes in multiprocessing

0 commit comments

Comments
 (0)