diff --git a/sdks/python/apache_beam/dataframe/__init__.py b/sdks/python/apache_beam/dataframe/__init__.py
index 427fee1be71c..9071a88193de 100644
--- a/sdks/python/apache_beam/dataframe/__init__.py
+++ b/sdks/python/apache_beam/dataframe/__init__.py
@@ -13,3 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import absolute_import
+
+from apache_beam.dataframe.expressions import allow_non_parallel_operations
diff --git a/sdks/python/apache_beam/dataframe/doctests.py b/sdks/python/apache_beam/dataframe/doctests.py
index 2d135ea3baae..93f33d1aaab5 100644
--- a/sdks/python/apache_beam/dataframe/doctests.py
+++ b/sdks/python/apache_beam/dataframe/doctests.py
@@ -361,6 +361,8 @@ def _run_patched(func, *args, **kwargs):
     original_doc_test_runner = doctest.DocTestRunner
     doctest.DocTestRunner = lambda **kwargs: BeamDataframeDoctestRunner(
         env, use_beam=use_beam, skip=skip, **kwargs)
-    return func(*args, extraglobs=extraglobs, optionflags=optionflags, **kwargs)
+    with expressions.allow_non_parallel_operations():
+      return func(
+          *args, extraglobs=extraglobs, optionflags=optionflags, **kwargs)
   finally:
     doctest.DocTestRunner = original_doc_test_runner
diff --git a/sdks/python/apache_beam/dataframe/expressions.py b/sdks/python/apache_beam/dataframe/expressions.py
index 376efb870478..34e01ccaa32e 100644
--- a/sdks/python/apache_beam/dataframe/expressions.py
+++ b/sdks/python/apache_beam/dataframe/expressions.py
@@ -16,6 +16,9 @@
 
 from __future__ import absolute_import
 
+import contextlib
+import threading
+
 from typing import Any
 from typing import Callable
 from typing import Iterable
@@ -203,6 +206,11 @@ def __init__(
         be partitioned by index whenever all of its inputs are
         partitioned by index.
     """
+    if (not _get_allow_non_parallel() and
+        requires_partition_by == partitionings.Singleton()):
+      raise NonParallelOperation(
+          "Using non-parallel form of %s "
+          "outside of allow_non_parallel_operations block." % name)
     args = tuple(args)
     if proxy is None:
       proxy = func(*(arg.proxy() for arg in args))
@@ -236,3 +244,31 @@ def elementwise_expression(name, func, args):
       args,
       requires_partition_by=partitionings.Nothing(),
       preserves_partition_by=partitionings.Singleton())
+
+
+_ALLOW_NON_PARALLEL = threading.local()
+_ALLOW_NON_PARALLEL.value = False
+
+
+def _get_allow_non_parallel():
+  # The module-level assignment above runs only in the importing thread; any
+  # other thread sees a fresh threading.local, so fall back to False instead
+  # of raising AttributeError.
+  return getattr(_ALLOW_NON_PARALLEL, 'value', False)
+
+
+@contextlib.contextmanager
+def allow_non_parallel_operations(allow=True):
+  # Read through the getter so threads that never entered this context get
+  # the default rather than an AttributeError on the thread-local.
+  old_value = _get_allow_non_parallel()
+  _ALLOW_NON_PARALLEL.value = allow
+  try:
+    yield
+  finally:
+    # Restore even when the body raises, so an error inside the block does
+    # not leak the (non-)parallel setting to subsequent operations.
+    _ALLOW_NON_PARALLEL.value = old_value
+
+
+class NonParallelOperation(Exception):
+  pass
diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py
index 1383a529f5c7..773b3bafd559 100644
--- a/sdks/python/apache_beam/dataframe/frames_test.py
+++ b/sdks/python/apache_beam/dataframe/frames_test.py
@@ -21,6 +21,7 @@
 import numpy as np
 import pandas as pd
 
+import apache_beam as beam
 from apache_beam.dataframe import expressions
 from apache_beam.dataframe import frame_base
 from apache_beam.dataframe import frames  # pylint: disable=unused-import
@@ -81,5 +82,35 @@ def test_loc(self):
     self._run_test(lambda df: df.loc[lambda df: df.A > 10], df)
 
 
+class AllowNonParallelTest(unittest.TestCase):
+  def _use_non_parallel_operation(self):
+    _ = frame_base.DeferredFrame.wrap(
+        expressions.PlaceholderExpression(pd.Series([1, 2, 3]))).replace(
+            'a', 'b', limit=1)
+
+  def test_disallow_non_parallel(self):
+    with self.assertRaises(expressions.NonParallelOperation):
+      self._use_non_parallel_operation()
+
+  def test_allow_non_parallel_in_context(self):
+    with beam.dataframe.allow_non_parallel_operations():
+      self._use_non_parallel_operation()
+
+  def test_allow_non_parallel_nesting(self):
+    # disallowed
+    with beam.dataframe.allow_non_parallel_operations():
+      # allowed
+      self._use_non_parallel_operation()
+      with beam.dataframe.allow_non_parallel_operations(False):
+        # disallowed again
+        with self.assertRaises(expressions.NonParallelOperation):
+          self._use_non_parallel_operation()
+      # allowed
+      self._use_non_parallel_operation()
+    # disallowed
+    with self.assertRaises(expressions.NonParallelOperation):
+      self._use_non_parallel_operation()
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/dataframe/transforms_test.py b/sdks/python/apache_beam/dataframe/transforms_test.py
index 81917cfff54f..e010b714133b 100644
--- a/sdks/python/apache_beam/dataframe/transforms_test.py
+++ b/sdks/python/apache_beam/dataframe/transforms_test.py
@@ -81,8 +81,9 @@ def test_sum_mean(self):
         'Animal': ['Falcon', 'Falcon', 'Parrot', 'Parrot'],
         'Speed': [380., 370., 24., 26.]
     })
-    self.run_scenario(df, lambda df: df.groupby('Animal').sum())
-    self.run_scenario(df, lambda df: df.groupby('Animal').mean())
+    with expressions.allow_non_parallel_operations():
+      self.run_scenario(df, lambda df: df.groupby('Animal').sum())
+      self.run_scenario(df, lambda df: df.groupby('Animal').mean())
 
   def test_filter(self):
     df = pd.DataFrame({
@@ -95,20 +96,22 @@ def test_filter(self):
     self.run_scenario(
         df, lambda df: df.set_index('Animal').filter(regex='F.*', axis='index'))
 
   def test_aggregate(self):
-    a = pd.DataFrame({'col': [1, 2, 3]})
-    self.run_scenario(a, lambda a: a.agg(sum))
-    self.run_scenario(a, lambda a: a.agg(['mean', 'min', 'max']))
+    with expressions.allow_non_parallel_operations():
+      a = pd.DataFrame({'col': [1, 2, 3]})
+      self.run_scenario(a, lambda a: a.agg(sum))
+      self.run_scenario(a, lambda a: a.agg(['mean', 'min', 'max']))
 
   def test_scalar(self):
-    a = pd.Series([1, 2, 6])
-    self.run_scenario(a, lambda a: a.agg(sum))
-    self.run_scenario(a, lambda a: a / a.agg(sum))
-
-    # Tests scalar being used as an input to a downstream stage.
-    df = pd.DataFrame({'key': ['a', 'a', 'b'], 'val': [1, 2, 6]})
-    self.run_scenario(
-        df, lambda df: df.groupby('key').sum().val / df.val.agg(sum))
+    with expressions.allow_non_parallel_operations():
+      a = pd.Series([1, 2, 6])
+      self.run_scenario(a, lambda a: a.agg(sum))
+      self.run_scenario(a, lambda a: a / a.agg(sum))
+
+      # Tests scalar being used as an input to a downstream stage.
+      df = pd.DataFrame({'key': ['a', 'a', 'b'], 'val': [1, 2, 6]})
+      self.run_scenario(
+          df, lambda df: df.groupby('key').sum().val / df.val.agg(sum))
 
   def test_input_output_polymorphism(self):
     one_series = pd.Series([1])