diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaMapToFieldsTransformProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaMapToFieldsTransformProvider.java index 2e2042aef05d..4a75c73c46f5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaMapToFieldsTransformProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaMapToFieldsTransformProvider.java @@ -156,13 +156,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } for (Map.Entry entry : configuration.getFields().entrySet()) { - if (!"java".equals(configuration.getLanguage())) { - String expr = entry.getValue().getExpression(); - if (expr == null || !inputSchema.hasField(expr)) { - throw new IllegalArgumentException( - "Unknown field or missing language specification for '" + entry.getKey() + "'"); - } - } try { JavaRowUdf udf = new JavaRowUdf(entry.getValue(), inputSchema); udfs.add(udf); diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 8d0037d4dd9f..9d7c3cb35e54 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -57,6 +57,7 @@ 'MapToFields-java': 'MapToFields-java' 'MapToFields-generic': 'MapToFields-java' 'Filter-java': 'Filter-java' + 'Filter-generic': 'Filter-java' 'Explode': 'Explode' config: mappings: @@ -75,6 +76,10 @@ drop: 'drop' fields: 'fields' error_handling: 'error_handling' + 'Filter-generic': + language: 'language' + keep: 'keep' + error_handling: 'error_handling' 'Filter-java': language: 'language' keep: 'keep' diff --git a/sdks/python/apache_beam/yaml/tests/map.yaml b/sdks/python/apache_beam/yaml/tests/map.yaml new file mode 100644 index 000000000000..b676966ad6bd --- /dev/null +++ b/sdks/python/apache_beam/yaml/tests/map.yaml @@ -0,0 +1,46 @@ +# +# 
Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +pipelines: + - pipeline: + type: chain + transforms: + - type: Create + config: + elements: + - 100 + - 200 + - 300 + + - type: MapToFields + config: + append: true + fields: + named_field: element + literal_int: 10 + literal_float: 1.5 + literal_str: '"abc"' + + - type: Filter + config: + keep: "named_field < 250" + + - type: AssertEqual + config: + elements: + - {element: 100, named_field: 100, literal_int: 10, literal_float: 1.5, literal_str: "abc"} + - {element: 200, named_field: 200, literal_int: 10, literal_float: 1.5, literal_str: "abc"} diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 8401364cde60..377bcac0e31a 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -19,6 +19,7 @@ import functools import inspect import itertools +import re from collections import abc from typing import Any from typing import Callable @@ -56,6 +57,12 @@ js2py = None JsObjectWrapper = object +_str_expression_fields = { + 'AssignTimestamps': 'timestamp', + 'Filter': 'keep', + 'Partition': 'by', +} + def normalize_mapping(spec): """ @@ -65,9 +72,87 @@ def normalize_mapping(spec): 
config = spec.get('config') if isinstance(config.get('drop'), str): config['drop'] = [config['drop']] + for field, value in list(config.get('fields', {}).items()): + if isinstance(value, (str, int, float)): + config['fields'][field] = {'expression': str(value)} + + elif spec['type'] in _str_expression_fields: + param = _str_expression_fields[spec['type']] + config = spec.get('config', {}) + if isinstance(config.get(param), (str, int, float)): + config[param] = {'expression': str(config.get(param))} + return spec +def is_literal(expr: str) -> bool: + # Some languages have limited integer literal ranges. + if re.fullmatch(r'-?\d+?', expr) and -1 << 31 < int(expr) < 1 << 31: + return True + elif re.fullmatch(r'-?\d+\.\d*', expr): + return True + elif re.fullmatch(r'"[^\\"]*"', expr): + return True + else: + return False + + +def validate_generic_expression( + expr_dict: dict, + input_fields: Collection[str], + allow_cmp: bool, + error_field: str) -> None: + if not isinstance(expr_dict, dict): + raise ValueError( + f"Ambiguous expression type (perhaps missing quoting?): {expr_dict}") + if len(expr_dict) != 1 or 'expression' not in expr_dict: + raise ValueError( + "Missing language specification. " + "Must specify a language when using a map with custom logic for %s" % + error_field) + expr = str(expr_dict['expression']) + + def is_atomic(expr: str): + return is_literal(expr) or expr in input_fields + + if is_atomic(expr): + return + + if allow_cmp: + maybe_cmp = re.fullmatch('(.*?)([<>=!]+)(.*)', expr) + if maybe_cmp: + left, cmp, right = maybe_cmp.groups() + if (is_atomic(left.strip()) and is_atomic(right.strip()) and + cmp in {'==', '<=', '>=', '<', '>', '!='}): + return + + raise ValueError( + "Missing language specification, unknown input fields, " + f"or invalid generic expression: {expr}. 
" + "See https://beam.apache.org/documentation/sdks/yaml-udf/#generic") + + +def validate_generic_expressions(base_type, config, input_pcolls) -> None: + if not input_pcolls: + return + try: + input_fields = [ + name for (name, _) in named_fields_from_element_type( + next(iter(input_pcolls)).element_type) + ] + except (TypeError, ValueError): + input_fields = [] + + if base_type == 'MapToFields': + for field, value in list(config.get('fields', {}).items()): + validate_generic_expression(value, input_fields, True, field) + + elif base_type in _str_expression_fields: + param = _str_expression_fields[base_type] + validate_generic_expression( + config.get(param), input_fields, base_type == 'Filter', param) + + def _check_mapping_arguments( transform_name, expression=None, callable=None, name=None, path=None): # Argument checking @@ -282,16 +367,16 @@ def _as_callable_for_pcoll( def _as_callable(original_fields, expr, transform_name, language, input_schema): + if isinstance(expr, str): + expr = {'expression': expr} # Extract original type from upstream pcoll when doing simple mappings - original_type = input_schema.get(str(expr), None) + original_type = input_schema.get(expr.get('expression'), None) if expr in original_fields: language = "python" # TODO(yaml): support an imports parameter # TODO(yaml): support a requirements parameter (possibly at a higher level) - if isinstance(expr, str): - expr = {'expression': expr} if not isinstance(expr, dict): raise ValueError( f"Ambiguous expression type (perhaps missing quoting?): {expr}") @@ -300,7 +385,7 @@ def _as_callable(original_fields, expr, transform_name, language, input_schema): if language == "javascript": func = _expand_javascript_mapping_func(original_fields, **expr) - elif language == "python": + elif language in ("python", "generic", None): func = _expand_python_mapping_func(original_fields, **expr) else: raise ValueError( @@ -323,13 +408,9 @@ def checking_func(row): return checking_func elif original_type: - - 
@beam.typehints.with_output_types(convert_to_beam_type(original_type)) - def checking_func(row): - result = func(row) - return result - - return checking_func + return beam.typehints.with_output_types( + convert_to_beam_type(original_type))( + func) else: return func @@ -498,7 +579,7 @@ def _PyJsFilter( See more complete documentation on [YAML Filtering](https://beam.apache.org/documentation/sdks/yaml-udf/#filtering). """ # pylint: disable=line-too-long - keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language) + keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language or 'generic') return pcoll | beam.Filter(keep_fn) @@ -530,17 +611,6 @@ def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'): f'Redefinition of field "{name}". ' 'Cannot append a field that already exists in original input.') - if language == 'generic': - for expr in fields.values(): - if not isinstance(expr, str): - raise ValueError( - "Missing language specification. " - "Must specify a language when using a map with custom logic.") - missing = set(fields.values()) - set(input_schema.keys()) - if missing: - raise ValueError( - f"Missing language specification or unknown input fields: {missing}") - if append: return input_schema, { **{name: f'`{name}`' if language in ['sql', 'calcite'] else name @@ -720,6 +790,7 @@ def create_mapping_providers(): 'Explode': _Explode, 'Filter-python': _PyJsFilter, 'Filter-javascript': _PyJsFilter, + 'Filter-generic': _PyJsFilter, 'MapToFields-python': _PyJsMapToFields, 'MapToFields-javascript': _PyJsMapToFields, 'MapToFields-generic': _PyJsMapToFields, diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index fd265c42cf73..3fb5bb7a28b8 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -38,6 +38,7 @@ from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_combine import normalize_combine from 
apache_beam.yaml.yaml_mapping import normalize_mapping +from apache_beam.yaml.yaml_mapping import validate_generic_expressions __all__ = ["YamlTransform"] @@ -384,6 +385,12 @@ def create_ptransform(self, spec, input_pcolls): f'Missing inputs for transform at {identify_object(spec)}') try: + if spec['type'].endswith('-generic'): + # Centralize the validation rather than require every implementation + # to do it. + validate_generic_expressions( + spec['type'].rsplit('-', 1)[0], config, input_pcolls) + # pylint: disable=undefined-loop-variable ptransform = provider.create_transform( spec['type'], config, self.create_ptransform) diff --git a/website/www/site/content/en/documentation/sdks/yaml-udf.md b/website/www/site/content/en/documentation/sdks/yaml-udf.md index 8bf9a8de26ad..ded40de8b85e 100644 --- a/website/www/site/content/en/documentation/sdks/yaml-udf.md +++ b/website/www/site/content/en/documentation/sdks/yaml-udf.md @@ -199,6 +199,22 @@ If one wanted to select a field that collides with a [reserved SQL keyword](http **Note**: the field mapping tags and fields defined in `drop` do not need to be escaped. Only the UDF itself needs to be a valid SQL statement. + +### Generic + +If a language is not specified, the set of expressions is limited to pre-existing +fields and integer, floating point, or string literals. For example: + +``` +- type: MapToFields + config: + fields: + new_col: col1 + int_literal: 389 + float_literal: 1.90216 + str_literal: '"example"' # note the double quoting +``` + ## FlatMap Sometimes it may be desirable to emit more (or less) than one record for each @@ -269,10 +285,19 @@ criteria. This can be accomplished with a `Filter` transform, e.g. ``` - type: Filter config: - language: python keep: "col2 > 0" ``` +For anything more complicated than a simple comparison between existing +fields and numeric literals, a `language` parameter must be provided, e.g. 
+ +``` +- type: Filter + config: + language: python + keep: "col2 + col3 > 0" +``` + For more complicated filtering functions, one can provide a full Python callable that takes the row as an argument to do more complex mappings (see [PythonCallableSource](https://beam.apache.org/releases/pydoc/current/apache_beam.utils.python_callable.html#apache_beam.utils.python_callable.PythonCallableWithSource)