diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py
index 7d72a02f8874..a38e05d66cbe 100644
--- a/sdks/python/apache_beam/transforms/sideinputs.py
+++ b/sdks/python/apache_beam/transforms/sideinputs.py
@@ -60,7 +60,8 @@ def default_window_mapping_fn(
   def map_via_end(source_window: window.BoundedWindow) -> window.BoundedWindow:
     return list(
         target_window_fn.assign(
-            window.WindowFn.AssignContext(source_window.max_timestamp())))[-1]
+            window.WindowFn.AssignContext(
+                source_window.max_timestamp(), window=source_window)))[-1]

   return map_via_end

diff --git a/sdks/python/apache_beam/transforms/sideinputs_test.py b/sdks/python/apache_beam/transforms/sideinputs_test.py
index 5f3cf761e1eb..9b79b9d1fa8d 100644
--- a/sdks/python/apache_beam/transforms/sideinputs_test.py
+++ b/sdks/python/apache_beam/transforms/sideinputs_test.py
@@ -39,6 +39,7 @@
 from apache_beam.testing.util import equal_to
 from apache_beam.testing.util import equal_to_per_window
 from apache_beam.transforms import Map
+from apache_beam.transforms import sideinputs
 from apache_beam.transforms import trigger
 from apache_beam.transforms import window
 from apache_beam.utils.timestamp import Timestamp
@@ -489,6 +490,40 @@ def process(
     assert_that(results, equal_to([(num_records, expected_fingerprint)]))
     pipeline.run()

+  def test_default_window_mapping_fn_source_window(self):
+    """Test that the default window mapping function will propagate the
+    source window when attempting to assign context.
+    """
+    class StringIDWindow(window.BoundedWindow):
+      """A window defined by an arbitrary string ID."""
+      def __init__(self, window_id: str):
+        super().__init__(self._getTimestampFromProto())
+        self.id = window_id
+
+      @staticmethod
+      def _getTimestampFromProto() -> Timestamp:
+        return Timestamp(micros=0)
+
+    class StringIDWindows(window.NonMergingWindowFn):
+      """ A windowing function that assigns each element a window with ID."""
+      def assign(
+          self, assign_context: window.WindowFn.AssignContext
+      ) -> Iterable[window.BoundedWindow]:
+        if assign_context.element is None:
+          assert assign_context.window is not None
+          return [assign_context.window]
+        return [StringIDWindow(str(assign_context.element))]
+
+      def get_window_coder(self):
+        return None
+
+    mapping_fn = sideinputs.default_window_mapping_fn(StringIDWindows())
+    source_window = StringIDWindows().assign(
+        window.WindowFn.AssignContext(Timestamp(10), element='element'))[0]
+    bounded_window = mapping_fn(source_window)
+    assert bounded_window is not None
+    assert bounded_window.id == 'element'
+

 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.DEBUG)