-
-
Notifications
You must be signed in to change notification settings - Fork 345
Expand file tree
/
Copy pathprocess.py
More file actions
963 lines (774 loc) · 31.1 KB
/
process.py
File metadata and controls
963 lines (774 loc) · 31.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
from __future__ import annotations
import asyncio
import functools
import gc
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, List, NewType, Sequence, Union
from sanic.log import logger
from api import (
BaseInput,
BaseOutput,
Collector,
ExecutionOptions,
InputId,
Iterator,
Lazy,
NodeContext,
NodeData,
NodeId,
OutputId,
SettingsParser,
registry,
)
from chain.cache import CacheStrategy, OutputCache, StaticCaching, get_cache_strategies
from chain.chain import Chain, CollectorNode, FunctionNode, NewIteratorNode, Node
from chain.input import EdgeInput, Input, InputMap
from events import EventConsumer, InputsDict
from progress_controller import Aborted, ProgressController, ProgressToken
from util import timed_supplier
# A node's output values as a positional list, index-aligned with `NodeData.outputs`.
Output = List[object]
def collect_input_information(
    node: NodeData,
    inputs: list[object | Lazy[object]],
    enforced: bool = True,
) -> InputsDict:
    """
    Collects per-input display/error values for the given node, for use in error messages.

    `enforced=True` means the values have already been enforced; with `enforced=False`
    each value is enforced here on a best-effort basis first.

    This function must never raise: every failure is logged and degrades gracefully
    (worst case, an empty dict is returned).
    """
    try:
        input_dict: InputsDict = {}
        for value, node_input in zip(inputs, node.inputs):
            # unwrap lazies that have already been computed
            if isinstance(value, Lazy) and value.has_value:
                value = value.value  # noqa: PLW2901
            if isinstance(value, Lazy):
                # the value hasn't been computed yet, so we won't do so here
                input_dict[node_input.id] = {"type": "pending"}
                continue
            if not enforced:
                try:
                    value = node_input.enforce_(value)  # noqa
                except Exception:
                    logger.error(
                        f"Error enforcing input {node_input.label} (id {node_input.id})",
                        exc_info=True,
                    )
                    # We'll just try using the un-enforced value. Maybe it'll work.
            try:
                input_dict[node_input.id] = node_input.get_error_value(value)
            except Exception:
                logger.error(
                    f"Error getting error value for input {node_input.label} (id {node_input.id})",
                    exc_info=True,
                )
        return input_dict
    except Exception:
        # this method must not throw
        logger.error("Error collecting input information.", exc_info=True)
        return {}
def enforce_inputs(
    inputs: list[object],
    node: NodeData,
    node_id: NodeId,
    ignored_inputs: list[InputId],
) -> list[object]:
    """
    Enforces the given raw input values against the node's input definitions.

    Inputs listed in `ignored_inputs` are replaced with `None`. Lazy inputs stay
    lazy (enforcement is deferred until the value is computed); eager inputs force
    any lazy value and enforce it immediately.

    Raises `NodeExecutionError` (wrapping the original error and carrying collected
    input information) if any input fails enforcement.
    """

    def enforce(i: BaseInput, value: object) -> object:
        if i.id in ignored_inputs:
            return None
        # we generally assume that enforcing a value is cheap, so we do it as soon as possible
        if i.lazy:
            if isinstance(value, Lazy):
                return Lazy(lambda: i.enforce_(value.value))
            return Lazy.ready(i.enforce_(value))
        if isinstance(value, Lazy):
            value = value.value  # compute lazy value
        return i.enforce_(value)

    try:
        # pair each value with its input definition directly instead of
        # indexing `node.inputs` by position
        return [enforce(i, value) for i, value in zip(node.inputs, inputs)]
    except Exception as e:
        input_dict = collect_input_information(node, inputs, enforced=False)
        raise NodeExecutionError(node_id, node, str(e), input_dict) from e
def enforce_output(raw_output: object, node: NodeData) -> RegularOutput:
    """
    Validates a regular node's raw return value and wraps it in a `RegularOutput`.

    Nodes with no outputs must return `None`; single-output nodes return the value
    directly; multi-output nodes return a tuple/list with exactly one value per
    output. Each value is then enforced by its output definition.
    """
    expected = len(node.outputs)

    output: Output
    if expected == 0:
        assert raw_output is None, f"Expected all {node.name} nodes to return None."
        output = []
    elif expected == 1:
        output = [raw_output]
    else:
        assert isinstance(raw_output, (tuple, list))
        output = list(raw_output)
        assert (
            len(output) == expected
        ), f"Expected all {node.name} nodes to have {expected} output(s) but found {len(output)}."

    # output-specific validations
    output = [o.enforce(value) for o, value in zip(node.outputs, output)]

    return RegularOutput(output)
def enforce_iterator_output(raw_output: object, node: NodeData) -> IteratorOutput:
    """
    Validates an iterator node's raw return value and wraps it in an `IteratorOutput`.

    If all of the node's outputs are iterator outputs, the raw value must be the
    `Iterator` itself. Otherwise the raw value must be a tuple/list whose first
    element is the `Iterator`, followed by one value per non-iterator output in
    declaration order; those values are enforced into `partial_output` (iterator
    output slots stay `None`).
    """
    l = len(node.outputs)
    iterator_output = node.single_iterator_output

    # one slot per output; non-iterator outputs are filled in below
    partial: list[object] = [None] * l

    if l == len(iterator_output.outputs):
        assert isinstance(raw_output, Iterator), "Expected the output to be an iterator"
        return IteratorOutput(iterator=raw_output, partial_output=partial)

    assert l > len(iterator_output.outputs)
    assert isinstance(raw_output, (tuple, list))

    iterator, *rest = raw_output
    assert isinstance(
        iterator, Iterator
    ), "Expected the first tuple element to be an iterator"
    assert len(rest) == l - len(iterator_output.outputs)

    # output-specific validations
    # consume `rest` front-to-back, matching non-iterator outputs in order
    for i, o in enumerate(node.outputs):
        if o.id not in iterator_output.outputs:
            partial[i] = o.enforce(rest.pop(0))

    return IteratorOutput(iterator=iterator, partial_output=partial)
def run_node(
    node: NodeData, context: NodeContext, inputs: list[object], node_id: NodeId
) -> NodeOutput | CollectorOutput:
    """
    Enforces the inputs, runs the node, and enforces its output.

    The return type depends on `node.kind`: `CollectorOutput` for collectors,
    `IteratorOutput` for iterators, `RegularOutput` for regular nodes.

    `Aborted` and `NodeExecutionError` propagate unchanged; any other failure is
    wrapped in a `NodeExecutionError` carrying collected input information.
    """
    if node.kind == "collector":
        # iterator-driven inputs are fed in later via `Collector.on_iterate`,
        # so they are skipped during enforcement here
        ignored_inputs = node.single_iterator_input.inputs
    else:
        ignored_inputs = []

    enforced_inputs = enforce_inputs(inputs, node, node_id, ignored_inputs)

    try:
        # only pass the context if the node declared that it wants one
        if node.node_context:
            raw_output = node.run(context, *enforced_inputs)
        else:
            raw_output = node.run(*enforced_inputs)

        if node.kind == "collector":
            assert isinstance(raw_output, Collector)
            return CollectorOutput(raw_output)
        if node.kind == "newIterator":
            return enforce_iterator_output(raw_output, node)
        assert node.kind == "regularNode"
        return enforce_output(raw_output, node)
    except Aborted:
        raise
    except NodeExecutionError:
        raise
    except Exception as e:
        # collect information to provide good error messages
        input_dict = collect_input_information(node, enforced_inputs)
        raise NodeExecutionError(node_id, node, str(e), input_dict) from e
def run_collector_iterate(
    node: CollectorNode, inputs: list[object], collector: Collector
) -> None:
    """
    Feeds one iteration's values into the given collector.

    `inputs` contains only the values for the node's iterator-driven inputs, in
    declaration order. They are enforced first, then passed to
    `collector.on_iterate` (a single value, or a tuple when there are several).

    Raises `NodeExecutionError` if enforcement or `on_iterate` fails.
    """
    iterator_input = node.data.single_iterator_input

    def get_partial_inputs(values: list[object]) -> list[object]:
        # spread the iterator-driven values into a full-length input list
        # (non-iterator inputs become None) for error reporting
        partial_inputs: list[object] = []
        index = 0
        for i in node.data.inputs:
            if i.id in iterator_input.inputs:
                partial_inputs.append(values[index])
                index += 1
            else:
                partial_inputs.append(None)
        return partial_inputs

    enforced_inputs: list[object] = []
    try:
        for i in node.data.inputs:
            if i.id in iterator_input.inputs:
                # len(enforced_inputs) doubles as the running index into `inputs`
                enforced_inputs.append(i.enforce_(inputs[len(enforced_inputs)]))
    except Exception as e:
        input_dict = collect_input_information(
            node.data, get_partial_inputs(inputs), enforced=False
        )
        raise NodeExecutionError(node.id, node.data, str(e), input_dict) from e

    input_value = (
        enforced_inputs[0] if len(enforced_inputs) == 1 else tuple(enforced_inputs)
    )

    try:
        raw_output = collector.on_iterate(input_value)
        assert raw_output is None
    except Exception as e:
        input_dict = collect_input_information(
            node.data, get_partial_inputs(enforced_inputs)
        )
        raise NodeExecutionError(node.id, node.data, str(e), input_dict) from e
class _Timer:
    """Accumulates elapsed wall-clock time across any number of measured sections."""

    def __init__(self) -> None:
        # total measured time in seconds
        self.duration: float = 0

    @contextmanager
    def run(self):
        """Context manager that adds the time spent inside the `with` block."""
        begin = time.monotonic()
        try:
            yield None
        finally:
            self.add_since(begin)

    def add_since(self, start: float):
        """Adds the time elapsed since `start` (a `time.monotonic()` timestamp)."""
        self.duration += time.monotonic() - start
class _IterationTimer:
    """Records the duration of each iteration, excluding time spent paused."""

    def __init__(self, progress: ProgressController) -> None:
        # duration of every completed iteration (pause time excluded)
        self.times: list[float] = []
        self.progress = progress

        started = time.monotonic()
        self._start_time = started
        self._start_paused = progress.time_paused
        self._last_time = started
        self._last_paused = self._start_paused

    @property
    def iterations(self) -> int:
        """Number of iterations recorded so far."""
        return len(self.times)

    def get_time_since_start(self) -> float:
        """Unpaused wall-clock time since this timer was created."""
        paused_since_start = max(0, self.progress.time_paused - self._start_paused)
        return time.monotonic() - self._start_time - paused_since_start

    def add(self):
        """Marks the end of one iteration and records its unpaused duration."""
        now = time.monotonic()
        paused_total = self.progress.time_paused
        paused_since_last = max(0, paused_total - self._last_paused)
        self.times.append(now - self._last_time - paused_since_last)

        self._last_time = now
        self._last_paused = paused_total
def compute_broadcast(output: Output, node_outputs: Iterable[BaseOutput]):
    """
    Computes the broadcast data and broadcast types for the given output values.

    Outputs whose value is `None` are skipped. A failure for one output is logged
    and does not prevent the remaining outputs from being processed.

    Returns a `(data, types)` pair of dicts keyed by output id.
    """
    data: dict[OutputId, object] = {}
    types: dict[OutputId, object] = {}
    for value, node_output in zip(output, node_outputs):
        try:
            if value is None:
                continue
            data[node_output.id] = node_output.get_broadcast_data(value)
            types[node_output.id] = node_output.get_broadcast_type(value)
        except Exception as e:
            logger.error(f"Error broadcasting output: {e}")
    return data, types
class NodeExecutionError(Exception):
    """
    Raised when a node fails to execute.

    Carries the failing node's id, its definition, and the collected input
    values so callers can build a detailed error report. The exception message
    is the original cause.
    """

    def __init__(
        self,
        node_id: NodeId,
        node_data: NodeData,
        cause: str,
        inputs: InputsDict,
    ):
        self.node_id: NodeId = node_id
        self.node_data: NodeData = node_data
        self.inputs: InputsDict = inputs
        super().__init__(cause)
@dataclass(frozen=True)
class RegularOutput:
    """The computed output values of a regular node."""

    # one value per output, index-aligned with the node's outputs
    output: Output
@dataclass(frozen=True)
class IteratorOutput:
    """The output of an iterator node: the iterator plus its non-iterated values."""

    # the iterator that will supply the per-iteration values
    iterator: Iterator
    # values for the non-iterator outputs; iterator output slots are None
    partial_output: Output
@dataclass(frozen=True)
class CollectorOutput:
    """The result of running a collector node: the collector to be fed iterations."""

    collector: Collector
# A cacheable node result (collector results are never cached, see Executor.__process).
NodeOutput = Union[RegularOutput, IteratorOutput]
# Uniquely identifies a single execution run of a chain.
ExecutionId = NewType("ExecutionId", str)
class _ExecutorNodeContext(NodeContext):
    """The `NodeContext` implementation handed to nodes while the executor runs them."""

    def __init__(
        self, progress: ProgressToken, settings: SettingsParser, storage_dir: Path
    ) -> None:
        super().__init__()

        self.progress = progress
        # name-mangled to keep node code from touching the settings directly;
        # exposed read-only through the `settings` property below
        self.__settings = settings
        self._storage_dir = storage_dir
        # functions registered via `add_cleanup`; the executor runs them after the chain
        self.cleanup_fns: set[Callable[[], None]] = set()

    @property
    def aborted(self) -> bool:
        return self.progress.aborted

    @property
    def paused(self) -> bool:
        # NOTE(review): the tiny sleep appears to be a deliberate yield/throttle
        # for worker threads polling this flag in a loop — confirm before removing
        time.sleep(0.001)
        return self.progress.paused

    def set_progress(self, progress: float) -> None:
        self.check_aborted()

        # TODO: send progress event

    @property
    def settings(self) -> SettingsParser:
        """
        Returns the settings of the current node execution.
        """
        return self.__settings

    @property
    def storage_dir(self) -> Path:
        return self._storage_dir

    def add_cleanup(self, fn: Callable[[], None]) -> None:
        self.cleanup_fns.add(fn)
class Executor:
    """
    Class for executing chaiNNer's processing logic
    """

    def __init__(
        self,
        id: ExecutionId,
        chain: Chain,
        send_broadcast_data: bool,
        options: ExecutionOptions,
        loop: asyncio.AbstractEventLoop,
        queue: EventConsumer,
        pool: ThreadPoolExecutor,
        storage_dir: Path,
        parent_cache: OutputCache[NodeOutput] | None = None,
    ):
        self.id: ExecutionId = id
        self.chain = chain
        self.inputs: InputMap = InputMap.from_chain(chain)
        self.send_broadcast_data: bool = send_broadcast_data
        self.options: ExecutionOptions = options
        # node outputs are cached here; `parent_cache` allows sharing with a parent run
        self.cache: OutputCache[NodeOutput] = OutputCache(parent=parent_cache)
        # in-flight broadcast tasks; awaited at the end of __process_nodes
        self.__broadcast_tasks: list[asyncio.Task[None]] = []
        # one context per schema id, reused across nodes of the same schema
        self.__context_cache: dict[str, _ExecutorNodeContext] = {}
        self.progress = ProgressController()
        self.loop: asyncio.AbstractEventLoop = loop
        self.queue: EventConsumer = queue
        self.pool: ThreadPoolExecutor = pool
        self.cache_strategy: dict[NodeId, CacheStrategy] = get_cache_strategies(chain)
        self._storage_dir = storage_dir

    async def process(self, node_id: NodeId) -> NodeOutput | CollectorOutput:
        """Processes the given node, returning its cached output if available."""
        # Return cached output value from an already-run node if that cached output exists
        cached = self.cache.get(node_id)
        if cached is not None:
            return cached

        node = self.chain.nodes[node_id]
        try:
            return await self.__process(node)
        except Aborted:
            raise
        except NodeExecutionError:
            raise
        except Exception as e:
            raise NodeExecutionError(node.id, node.data, str(e), {}) from e

    async def process_regular_node(self, node: FunctionNode) -> RegularOutput:
        """
        Processes the given regular node.

        This will run all necessary node events.
        """
        result = await self.process(node.id)
        assert isinstance(result, RegularOutput)
        return result

    async def process_iterator_node(self, node: NewIteratorNode) -> IteratorOutput:
        """
        Processes the given iterator node.

        This will **not** iterate the returned iterator. Only `node-start` and
        `node-broadcast` events will be sent.
        """
        result = await self.process(node.id)
        assert isinstance(result, IteratorOutput)
        return result

    async def process_collector_node(self, node: CollectorNode) -> CollectorOutput:
        """
        Processes the given iterator node.

        This will **not** iterate the returned collector. Only a `node-start` event
        will be sent.
        """
        result = await self.process(node.id)
        assert isinstance(result, CollectorOutput)
        return result

    async def __get_node_output(self, node_id: NodeId, output_index: int) -> object:
        """
        Returns the output value of the given node.

        Note: `output_index` is NOT an output ID.
        """
        # Recursively get the value of the input
        output = await self.process(node_id)

        if isinstance(output, CollectorOutput):
            # this generally shouldn't be possible
            raise ValueError("A collector was not run before another node needed it.")

        if isinstance(output, IteratorOutput):
            value = output.partial_output[output_index]
            assert value is not None, "An iterator output was not assigned correctly"
            return value

        assert isinstance(output, RegularOutput)
        return output.output[output_index]

    async def __resolve_node_input(self, node_input: Input) -> object:
        """Resolves one assigned input: either an upstream node's output or a constant."""
        if isinstance(node_input, EdgeInput):
            # If input is a dict indicating another node, use that node's output value
            # Recursively get the value of the input
            return await self.__get_node_output(node_input.id, node_input.index)
        else:
            # Otherwise, just use the given input (number, string, etc)
            return node_input.value

    async def __gather_inputs(self, node: Node) -> list[object]:
        """
        Returns the list of input values for the given node.
        """
        # we want to ignore some inputs if we are running a collector node
        ignore: set[int] = set()
        if isinstance(node, CollectorNode):
            iterator_input = node.data.single_iterator_input
            for input_index, i in enumerate(node.data.inputs):
                if i.id in iterator_input.inputs:
                    ignore.add(input_index)

        # some inputs are lazy, so we want to lazily resolve them
        lazy: set[int] = set()
        for input_index, i in enumerate(node.data.inputs):
            if i.lazy:
                lazy.add(input_index)

        assigned_inputs = self.inputs.get(node.id)
        assert len(assigned_inputs) == len(node.data.inputs)

        async def get_input_value(input_index: int, node_input: Input):
            if input_index in ignore:
                return None
            if input_index in lazy:
                # defer resolution until (and unless) the node forces the value
                return Lazy.from_coroutine(
                    self.__resolve_node_input(assigned_inputs[input_index]), self.loop
                )
            return await self.__resolve_node_input(node_input)

        inputs = []
        for input_index, node_input in enumerate(assigned_inputs):
            inputs.append(await get_input_value(input_index, node_input))

        return inputs

    async def __gather_collector_inputs(self, node: CollectorNode) -> list[object]:
        """
        Returns the input values to be consumed by `Collector.on_iterate`.
        """
        iterator_input = node.data.single_iterator_input

        assigned_inputs = self.inputs.get(node.id)
        assert len(assigned_inputs) == len(node.data.inputs)

        # only the iterator-driven inputs are gathered; everything else was
        # already passed to the collector when it was constructed
        inputs = []
        for input_index, node_input in enumerate(assigned_inputs):
            i = node.data.inputs[input_index]
            if i.id in iterator_input.inputs:
                inputs.append(await self.__resolve_node_input(node_input))

        return inputs

    def __get_node_context(self, node: Node) -> _ExecutorNodeContext:
        """Returns the (cached, per-schema) execution context for the given node."""
        context = self.__context_cache.get(node.data.schema_id, None)
        if context is None:
            package_id = registry.get_package(node.data.schema_id).id
            settings = self.options.get_package_settings(package_id)
            context = _ExecutorNodeContext(self.progress, settings, self._storage_dir)
            self.__context_cache[node.data.schema_id] = context
        return context

    async def __process(self, node: Node) -> NodeOutput | CollectorOutput:
        """
        Process a single node.

        In the case of iterators and collectors, it will only run the node itself,
        not the actual iteration or collection.
        """
        logger.debug(f"node: {node}")
        logger.debug(f"Running node {node.id}")

        inputs = await self.__gather_inputs(node)
        context = self.__get_node_context(node)

        await self.progress.suspend()
        await self.__send_node_start(node)
        await self.progress.suspend()

        # the node itself runs on the thread pool so the event loop stays responsive
        output, execution_time = await self.loop.run_in_executor(
            self.pool,
            timed_supplier(
                functools.partial(run_node, node.data, context, inputs, node.id)
            ),
        )
        await self.progress.suspend()

        if isinstance(output, RegularOutput):
            await self.__send_node_broadcast(node, output.output)
            await self.__send_node_finish(node, execution_time)
        elif isinstance(output, IteratorOutput):
            await self.__send_node_broadcast(node, output.partial_output)
            # TODO: execution time

        # Cache the output of the node
        if not isinstance(output, CollectorOutput):
            self.cache.set(node.id, output, self.cache_strategy[node.id])

        await self.progress.suspend()

        return output

    def __get_iterated_nodes(
        self, node: NewIteratorNode
    ) -> tuple[set[CollectorNode], set[FunctionNode], set[Node]]:
        """
        Returns all collector and output nodes iterated by the given iterator node
        """
        collectors: set[CollectorNode] = set()
        output_nodes: set[FunctionNode] = set()

        seen: set[Node] = {node}

        def visit(n: Node):
            if n in seen:
                return
            seen.add(n)

            if isinstance(n, CollectorNode):
                # collectors terminate the downstream traversal
                collectors.add(n)
            elif isinstance(n, NewIteratorNode):
                raise ValueError("Nested iterators are not supported")
            else:
                assert isinstance(n, FunctionNode)

                if n.has_side_effects():
                    output_nodes.add(n)

                # follow edges
                for edge in self.chain.edges_from(n.id):
                    target_node = self.chain.nodes[edge.target.id]
                    visit(target_node)

        iterator_output = node.data.single_iterator_output
        for edge in self.chain.edges_from(node.id):
            # only follow iterator outputs
            if edge.source.output_id in iterator_output.outputs:
                target_node = self.chain.nodes[edge.target.id]
                visit(target_node)

        return collectors, output_nodes, seen

    def __iterator_fill_partial_output(
        self, node: NewIteratorNode, partial_output: Output, values: object
    ) -> Output:
        """Returns a copy of `partial_output` with the iterator output slots filled from `values`."""
        iterator_output = node.data.single_iterator_output

        # normalize: a single iterator output yields a bare value, several yield a tuple/list
        values_list: list[object] = []
        if len(iterator_output.outputs) == 1:
            values_list.append(values)
        else:
            assert isinstance(values, (tuple, list))
            values_list.extend(values)

        assert len(values_list) == len(iterator_output.outputs)

        output: Output = partial_output.copy()
        for index, o in enumerate(node.data.outputs):
            if o.id in iterator_output.outputs:
                output[index] = o.enforce(values_list.pop(0))

        return output

    async def __iterate_iterator_node(self, node: NewIteratorNode):
        """Runs the given iterator node and drives all downstream output/collector nodes."""
        await self.progress.suspend()

        # run the iterator node itself before anything else
        iterator_output = await self.process_iterator_node(node)

        collector_nodes, output_nodes, all_iterated_nodes = self.__get_iterated_nodes(
            node
        )
        all_iterated_nodes = {n.id for n in all_iterated_nodes}

        if len(collector_nodes) == 0 and len(output_nodes) == 0:
            # unusual, but this can happen
            # since we don't need to actually iterate the iterator, we can stop here
            return

        def fill_partial_output(values: object) -> RegularOutput:
            return RegularOutput(
                self.__iterator_fill_partial_output(
                    node, iterator_output.partial_output, values
                )
            )

        # run each of the collector nodes
        collectors: list[tuple[Collector, _Timer, CollectorNode]] = []
        for collector_node in collector_nodes:
            await self.progress.suspend()
            timer = _Timer()
            with timer.run():
                collector_output = await self.process_collector_node(collector_node)
            assert isinstance(collector_output, CollectorOutput)
            collectors.append((collector_output.collector, timer, collector_node))

        # timing iterations
        iter_times = _IterationTimer(self.progress)

        expected_length = iterator_output.iterator.expected_length

        async def update_progress():
            iter_times.add()
            iterations = iter_times.iterations
            await self.__send_node_progress(
                node,
                iter_times.times,
                iterations,
                # the expected length may be wrong, so never report total < index
                max(expected_length, iterations),
            )

        # iterate
        await self.__send_node_progress(node, [], 0, expected_length)

        deferred_errors: list[str] = []
        for values in iterator_output.iterator.iter_supplier():
            try:
                # the supplier yields exceptions instead of raising them
                if isinstance(values, Exception):
                    raise values

                # write current values to cache
                iter_output = fill_partial_output(values)
                self.cache.set(node.id, iter_output, StaticCaching)

                # broadcast
                await self.__send_node_broadcast(node, iter_output.output)

                # run each of the output nodes
                for output_node in output_nodes:
                    await self.process_regular_node(output_node)

                # run each of the collector nodes
                for collector, timer, collector_node in collectors:
                    await self.progress.suspend()
                    iterate_inputs = await self.__gather_collector_inputs(
                        collector_node
                    )
                    await self.progress.suspend()
                    with timer.run():
                        run_collector_iterate(collector_node, iterate_inputs, collector)

                # clear cache for next iteration
                self.cache.delete_many(all_iterated_nodes)

                await self.progress.suspend()
                await update_progress()

                # cooperative yield so the event loop can run
                # https://stackoverflow.com/questions/36647825/cooperative-yield-in-asyncio
                await asyncio.sleep(0)
                await self.progress.suspend()
            except Aborted:
                raise
            except Exception as e:
                if iterator_output.iterator.fail_fast:
                    raise e
                else:
                    # collect and report after the iteration has finished
                    deferred_errors.append(str(e))

        # reset cached value
        self.cache.delete_many(all_iterated_nodes)
        self.cache.set(node.id, iterator_output, self.cache_strategy[node.id])

        # re-broadcast final value
        # TODO: Why?
        await self.__send_node_broadcast(node, iterator_output.partial_output)

        # finish iterator
        await self.__send_node_progress_done(node, iter_times.iterations)
        await self.__send_node_finish(node, iter_times.get_time_since_start())

        # finalize collectors
        for collector, timer, collector_node in collectors:
            await self.progress.suspend()
            with timer.run():
                collector_output = enforce_output(
                    collector.on_complete(), collector_node.data
                )

            await self.__send_node_broadcast(collector_node, collector_output.output)
            # TODO: execution time
            await self.__send_node_finish(collector_node, timer.duration)

            self.cache.set(
                collector_node.id,
                collector_output,
                self.cache_strategy[collector_node.id],
            )

        if len(deferred_errors) > 0:
            error_string = "- " + "\n- ".join(deferred_errors)
            raise Exception(f"Errors occurred during iteration:\n{error_string}")

    async def __process_nodes(self):
        """Runs the whole chain: iterators first, then non-iterated output nodes."""
        await self.__send_chain_start()

        # we first need to run iterator nodes in topological order
        for node_id in self.chain.topological_order():
            node = self.chain.nodes[node_id]
            if isinstance(node, NewIteratorNode):
                await self.__iterate_iterator_node(node)

        # now the output nodes outside of iterators
        # Now run everything that is not in an iterator lineage
        non_iterator_output_nodes = [
            node
            for node, iter_node in self.chain.get_parent_iterator_map().items()
            if iter_node is None and node.has_side_effects()
        ]
        for output_node in non_iterator_output_nodes:
            await self.progress.suspend()
            await self.process_regular_node(output_node)

        # clear cache after the chain is done
        self.cache.clear()

        # Run cleanup functions
        for context in self.__context_cache.values():
            for fn in context.cleanup_fns:
                try:
                    fn()
                except Exception as e:
                    logger.error(f"Error running cleanup function: {e}")

        # await all broadcasts
        tasks = self.__broadcast_tasks
        self.__broadcast_tasks = []
        for task in tasks:
            await task

    async def run(self):
        """Executes the chain; always collects garbage afterwards."""
        logger.debug(f"Running executor {self.id}")
        try:
            await self.__process_nodes()
        finally:
            gc.collect()

    def resume(self):
        """Resumes a paused execution."""
        logger.debug(f"Resuming executor {self.id}")
        self.progress.resume()

    def pause(self):
        """Pauses the execution and frees memory while idle."""
        logger.debug(f"Pausing executor {self.id}")
        self.progress.pause()
        gc.collect()

    def kill(self):
        """Aborts the execution."""
        logger.debug(f"Killing executor {self.id}")
        self.progress.abort()

    # events

    async def __send_chain_start(self):
        """Sends the `chain-start` event listing the nodes that will actually run."""
        # all nodes except the cached ones
        nodes = set(self.chain.nodes.keys())
        nodes.difference_update(self.cache.keys())

        await self.queue.put(
            {
                "event": "chain-start",
                "data": {
                    "nodes": list(nodes),
                },
            }
        )

    async def __send_node_start(self, node: Node):
        await self.queue.put(
            {
                "event": "node-start",
                "data": {
                    "nodeId": node.id,
                },
            }
        )

    async def __send_node_progress(
        self, node: Node, times: Sequence[float], index: int, length: int
    ):
        """Sends a `node-progress` event with an ETA estimated from recent iteration times."""

        def get_eta(times: Sequence[float]) -> float:
            avg_time = 0
            if len(times) > 0:
                # only consider the last 100
                times = times[-100:]

                # use a weighted average
                weights = [max(1 / i, 0.9**i) for i in range(len(times), 0, -1)]
                avg_time = sum(t * w for t, w in zip(times, weights)) / sum(weights)

            remaining = max(0, length - index)
            return avg_time * remaining

        await self.queue.put(
            {
                "event": "node-progress",
                "data": {
                    "nodeId": node.id,
                    "progress": 1 if length == 0 else index / length,
                    "index": index,
                    "total": length,
                    "eta": get_eta(times),
                },
            }
        )

    async def __send_node_progress_done(self, node: Node, length: int):
        await self.queue.put(
            {
                "event": "node-progress",
                "data": {
                    "nodeId": node.id,
                    "progress": 1,
                    "index": length,
                    "total": length,
                    "eta": 0,
                },
            }
        )

    async def __send_node_broadcast(
        self,
        node: Node,
        output: Output,
    ):
        """Schedules a `node-broadcast` event; the data is computed on the thread pool."""

        def compute_broadcast_data():
            if self.progress.aborted:
                # abort the broadcast if the chain was aborted
                return None
            return compute_broadcast(output, node.data.outputs)

        async def send_broadcast():
            # TODO: Add the time it takes to compute the broadcast data to the execution time
            result = await self.loop.run_in_executor(self.pool, compute_broadcast_data)
            if result is None or self.progress.aborted:
                return

            data, types = result
            await self.queue.put(
                {
                    "event": "node-broadcast",
                    "data": {
                        "nodeId": node.id,
                        "data": data,
                        "types": types,
                    },
                }
            )

        # Only broadcast the output if the node has outputs
        if self.send_broadcast_data and len(node.data.outputs) > 0:
            # broadcasts are done is parallel, so don't wait
            self.__broadcast_tasks.append(self.loop.create_task(send_broadcast()))

    async def __send_node_finish(
        self,
        node: Node,
        execution_time: float,
    ):
        await self.queue.put(
            {
                "event": "node-finish",
                "data": {
                    "nodeId": node.id,
                    "executionTime": execution_time,
                },
            }
        )