Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

from apache_beam.dataframe.expressions import allow_non_parallel_operations
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/dataframe/doctests.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ def _run_patched(func, *args, **kwargs):
original_doc_test_runner = doctest.DocTestRunner
doctest.DocTestRunner = lambda **kwargs: BeamDataframeDoctestRunner(
env, use_beam=use_beam, skip=skip, **kwargs)
return func(*args, extraglobs=extraglobs, optionflags=optionflags, **kwargs)
with expressions.allow_non_parallel_operations():
return func(
*args, extraglobs=extraglobs, optionflags=optionflags, **kwargs)
finally:
doctest.DocTestRunner = original_doc_test_runner
27 changes: 27 additions & 0 deletions sdks/python/apache_beam/dataframe/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

from __future__ import absolute_import

import contextlib
import threading

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was merged without running the Jenkins tests, and it introduced a linter failure.

from typing import Any
from typing import Callable
from typing import Iterable
Expand Down Expand Up @@ -203,6 +206,11 @@ def __init__(
be partitioned by index whenever all of its inputs are partitioned by
index.
"""
if (not _get_allow_non_parallel() and
requires_partition_by == partitionings.Singleton()):
raise NonParallelOperation(
"Using non-parallel form of %s "
"outside of allow_non_parallel_operations block." % name)
args = tuple(args)
if proxy is None:
proxy = func(*(arg.proxy() for arg in args))
Expand Down Expand Up @@ -236,3 +244,22 @@ def elementwise_expression(name, func, args):
args,
requires_partition_by=partitionings.Nothing(),
preserves_partition_by=partitionings.Singleton())


# Thread-local flag recording whether non-parallel (i.e. Singleton-partitioned)
# operations may currently be constructed.  Defaults to False (disallowed).
_ALLOW_NON_PARALLEL = threading.local()
_ALLOW_NON_PARALLEL.value = False


def _get_allow_non_parallel():
  """Returns whether non-parallel operations are allowed in this thread."""
  # threading.local attributes are visible only to the thread that set them,
  # so threads other than the one that imported this module need the
  # explicit False default rather than raising AttributeError.
  return getattr(_ALLOW_NON_PARALLEL, 'value', False)


@contextlib.contextmanager
def allow_non_parallel_operations(allow=True):
  """Context manager within which non-parallel operations are (dis)allowed.

  Args:
    allow: whether to allow (True, the default) or explicitly disallow
      (False) non-parallel operations inside the block.  Blocks nest; on
      exit the enclosing setting is restored.
  """
  old_value = _get_allow_non_parallel()
  _ALLOW_NON_PARALLEL.value = allow
  try:
    yield
  finally:
    # Restore the previous setting even if the body raises, so an exception
    # inside the block cannot leak the (dis)allowed state to later code on
    # this thread.
    _ALLOW_NON_PARALLEL.value = old_value


class NonParallelOperation(Exception):
  """Raised when a non-parallel operation is constructed outside of an
  allow_non_parallel_operations() block."""
31 changes: 31 additions & 0 deletions sdks/python/apache_beam/dataframe/frames_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import pandas as pd

import apache_beam as beam
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frame_base
from apache_beam.dataframe import frames # pylint: disable=unused-import
Expand Down Expand Up @@ -81,5 +82,35 @@ def test_loc(self):
self._run_test(lambda df: df.loc[lambda df: df.A > 10], df)


class AllowNonParallelTest(unittest.TestCase):
  """Tests for the allow_non_parallel_operations context manager."""

  def _trigger_non_parallel_op(self):
    # replace(..., limit=...) requires Singleton partitioning, making it a
    # non-parallel operation; constructing it exercises the guard.
    series = pd.Series([1, 2, 3])
    deferred = frame_base.DeferredFrame.wrap(
        expressions.PlaceholderExpression(series))
    _ = deferred.replace('a', 'b', limit=1)

  def test_disallow_non_parallel(self):
    # Disallowed by default, outside any context.
    with self.assertRaises(expressions.NonParallelOperation):
      self._trigger_non_parallel_op()

  def test_allow_non_parallel_in_context(self):
    with beam.dataframe.allow_non_parallel_operations():
      self._trigger_non_parallel_op()

  def test_allow_non_parallel_nesting(self):
    with beam.dataframe.allow_non_parallel_operations():
      # Allowed inside the outer context.
      self._trigger_non_parallel_op()
      with beam.dataframe.allow_non_parallel_operations(False):
        # Explicitly disallowed again by the inner context.
        with self.assertRaises(expressions.NonParallelOperation):
          self._trigger_non_parallel_op()
      # Leaving the inner context restores the outer (allowed) setting.
      self._trigger_non_parallel_op()
    # Outside all contexts the default (disallowed) applies once more.
    with self.assertRaises(expressions.NonParallelOperation):
      self._trigger_non_parallel_op()


if __name__ == '__main__':
unittest.main()
29 changes: 16 additions & 13 deletions sdks/python/apache_beam/dataframe/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def test_sum_mean(self):
'Animal': ['Falcon', 'Falcon', 'Parrot', 'Parrot'],
'Speed': [380., 370., 24., 26.]
})
self.run_scenario(df, lambda df: df.groupby('Animal').sum())
self.run_scenario(df, lambda df: df.groupby('Animal').mean())
with expressions.allow_non_parallel_operations():
self.run_scenario(df, lambda df: df.groupby('Animal').sum())
self.run_scenario(df, lambda df: df.groupby('Animal').mean())

def test_filter(self):
df = pd.DataFrame({
Expand All @@ -95,19 +96,21 @@ def test_filter(self):
df, lambda df: df.set_index('Animal').filter(regex='F.*', axis='index'))

def test_aggregate(self):
  """Whole-frame aggregations require the non-parallel allowance."""
  frame = pd.DataFrame({'col': [1, 2, 3]})
  with expressions.allow_non_parallel_operations():
    self.run_scenario(frame, lambda a: a.agg(sum))
    self.run_scenario(frame, lambda a: a.agg(['mean', 'min', 'max']))

def test_scalar(self):
  """Scalar-producing aggregations, standalone and as downstream inputs."""
  with expressions.allow_non_parallel_operations():
    series = pd.Series([1, 2, 6])
    self.run_scenario(series, lambda a: a.agg(sum))
    self.run_scenario(series, lambda a: a / a.agg(sum))

    # A scalar result consumed as input by a downstream stage.
    df = pd.DataFrame({'key': ['a', 'a', 'b'], 'val': [1, 2, 6]})
    self.run_scenario(
        df, lambda df: df.groupby('key').sum().val / df.val.agg(sum))

def test_input_output_polymorphism(self):
one_series = pd.Series([1])
Expand Down