Skip to content

Commit d08c483

Browse files
committed
add multiprocessing support
1 parent 24972ed commit d08c483

1 file changed

Lines changed: 91 additions & 14 deletions

File tree

pipeline.py

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
# Author: Qiu Jueqin (qiujueqin@gmail.com)
55

66

7+
import os.path as op
8+
import sys
79
import time
810
import copy
11+
import math
912
import importlib
1013
from collections import OrderedDict
14+
from multiprocessing import Process
1115

1216
import numpy as np
1317

@@ -22,20 +26,7 @@ def __init__(self, cfg):
2226
with self.cfg.unfreeze():
2327
self.cfg.saturation_values = saturation_values
2428

25-
enabled_modules = tuple(m for m, en in self.cfg.module_enable_status.items() if en)
26-
27-
self.modules = OrderedDict()
28-
for module_name in enabled_modules:
29-
package = importlib.import_module('modules.{}'.format(module_name))
30-
module_cls = getattr(package, module_name.upper())
31-
module = module_cls(self.cfg)
32-
33-
if hasattr(module, 'dependent_modules'):
34-
for m in module.dependent_modules:
35-
if m not in enabled_modules:
36-
raise RuntimeError('{} is available only if {} is activated'.format(module_name, m))
37-
38-
self.modules[module_name] = module
29+
self.modules = self.get_modules()
3930

4031
def get_saturation_values(self):
4132
"""
@@ -63,6 +54,29 @@ def get_saturation_values(self):
6354
'hdr': hdr_max_value,
6455
'sdr': sdr_max_value})
6556

57+
def get_modules(self):
58+
""" Get activated ISP modules according to the configuration """
59+
60+
if op.dirname(__file__) not in sys.path:
61+
sys.path.insert(0, op.dirname(__file__))
62+
63+
enabled_modules = tuple(m for m, en in self.cfg.module_enable_status.items() if en)
64+
65+
modules = OrderedDict()
66+
for module_name in enabled_modules:
67+
package = importlib.import_module('modules.{}'.format(module_name))
68+
module_cls = getattr(package, module_name.upper())
69+
module = module_cls(self.cfg)
70+
71+
if hasattr(module, 'dependent_modules'):
72+
for m in module.dependent_modules:
73+
if m not in enabled_modules:
74+
raise RuntimeError('{} is available only if {} is activated'.format(module_name, m))
75+
76+
modules[module_name] = module
77+
78+
return modules
79+
6680
def execute(self, bayer, save_intermediates=False, verbose=True):
6781
"""
6882
ISP pipeline execution
@@ -124,6 +138,69 @@ def get_output(self, data):
124138

125139
return output
126140

141+
def run(self, raw_path, save_dir, load_raw_fn, suffix=''):
142+
"""
143+
A higher level API that write ISP result into disk
144+
:param raw_path: path to the raw file to be executed
145+
:param save_dir: directory to save the ISP output (the output will share the filename with input)
146+
:param load_raw_fn: function to load the Bayer array from the raw_path
147+
:param suffix: suffix to added to the output filename
148+
"""
149+
150+
import cv2
151+
152+
bayer = load_raw_fn(raw_path)
153+
data, _ = self.execute(bayer, save_intermediates=False, verbose=False)
154+
output = cv2.cvtColor(data['output'], cv2.COLOR_RGB2BGR)
155+
156+
filename = op.splitext(op.basename(raw_path))[0]
157+
save_path = op.join(save_dir, '{}.png'.format(filename + suffix))
158+
cv2.imwrite(save_path, output)
159+
160+
def batch_run(self, raw_paths, save_dirs, load_raw_fn, suffixes='', num_processes=1):
161+
"""
162+
Batch execution with multiprocessing
163+
:param raw_paths: list of paths to the raw files to be executed
164+
:param save_dirs: list of directories to save the outputs. If given a string, it will be copied
165+
to a N-element list, where N is the number of paths in raw_paths
166+
:param load_raw_fn: function to load the Bayer array from the raw_path
167+
:param suffixes: a list of suffixes to added to the output filenames
168+
:param num_processes: number of processes in multiprocessing
169+
"""
170+
171+
num_files = len(raw_paths)
172+
num_batches = math.ceil(num_files / num_processes)
173+
174+
if not isinstance(save_dirs, (list, tuple)):
175+
save_dirs = [save_dirs for _ in range(num_files)]
176+
if not isinstance(suffixes, (list, tuple)):
177+
suffixes = [suffixes for _ in range(num_files)]
178+
179+
for batch_id in range(num_batches):
180+
indices = [batch_id * num_processes + rank for rank in range(num_processes)]
181+
indices = [i for i in indices if i < num_files]
182+
batch_size = len(indices)
183+
184+
raw_paths_batch = [raw_paths[i] for i in indices]
185+
save_dirs_batch = [save_dirs[i] for i in indices]
186+
suffixes_batch = [suffixes[i] for i in indices]
187+
188+
pool = []
189+
for rank in range(batch_size):
190+
pool.append(
191+
Process(target=self.run,
192+
kwargs={'raw_path': raw_paths_batch[rank],
193+
'save_dir': save_dirs_batch[rank],
194+
'load_raw_fn': load_raw_fn,
195+
'suffix': suffixes_batch[rank]})
196+
)
197+
198+
for p in pool:
199+
p.start()
200+
201+
for p in pool:
202+
p.join()
203+
127204

128205
def ycbcr_to_rgb(ycbcr_array):
129206
""" Convert YCbCr 3-channel array into sRGB array """

0 commit comments

Comments
 (0)