diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index aefddbcf8..7233a8eec 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -216,7 +216,7 @@ def __init__(self, collection_id: str): self.collection_id = collection_id def _pb_args(self): - return [Value(string_value=self.collection_id)] + return [Value(reference_value=""), Value(string_value=self.collection_id)] class Database(Stage): diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index c5e6a7b7f..3dd7a453e 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -21,9 +21,10 @@ from __future__ import annotations import abc +import itertools from abc import ABC -from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterable from google.api_core import gapic_v1 from google.api_core import retry as retries @@ -33,6 +34,10 @@ from google.cloud.firestore_v1.types import ( StructuredAggregationQuery, ) +from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction +from google.cloud.firestore_v1.pipeline_expressions import Count +from google.cloud.firestore_v1.pipeline_expressions import AliasedExpr +from google.cloud.firestore_v1.pipeline_expressions import Field # Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER @@ -66,6 +71,9 @@ def __init__(self, alias: str, value: float, read_time=None): def __repr__(self): return f"" + def _to_dict(self): + return {self.alias: self.value} + class BaseAggregation(ABC): def __init__(self, alias: str | None = None): @@ -75,6 +83,27 @@ def __init__(self, alias: str | None = None): def _to_protobuf(self): """Convert this instance to the protobuf representation""" + @abc.abstractmethod + def _to_pipeline_expr( + 
self, autoindexer: Iterable[int] + ) -> AliasedExpr[AggregateFunction]: + """ + Convert this instance to a pipeline expression for use with pipeline.aggregate() + + Args: + autoindexer: If an alias isn't supplied, one should be created with the format "field_n" + The autoindexer is an iterable that provides the `n` value to use for each expression + """ + + def _pipeline_alias(self, autoindexer): + """ + Helper to build the alias for the pipeline expression + """ + if self.alias is not None: + return self.alias + else: + return f"field_{next(autoindexer)}" + class CountAggregation(BaseAggregation): def __init__(self, alias: str | None = None): @@ -88,6 +117,9 @@ def _to_protobuf(self): aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Count().as_(self._pipeline_alias(autoindexer)) + class SumAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): @@ -107,6 +139,9 @@ def _to_protobuf(self): aggregation_pb.sum.field.field_path = self.field_ref return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Field.of(self.field_ref).sum().as_(self._pipeline_alias(autoindexer)) + class AvgAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): @@ -126,6 +161,9 @@ def _to_protobuf(self): aggregation_pb.avg.field.field_path = self.field_ref return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer)) + def _query_response_to_result( response_pb, @@ -317,3 +355,20 @@ def stream( StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]: A generator of the query results. 
""" + + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - ValueError: raised if Query wasn't created with an associated client + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + # use autoindexer to keep track of which field number to use for un-aliased fields + autoindexer = itertools.count(start=1) + exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations] + return self._nested_query.pipeline().aggregate(*exprs) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 1b1ef0411..a4cc2b1b7 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -602,6 +602,19 @@ def find_nearest( distance_threshold=distance_threshold, ) + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + return self._query().pipeline() + def _auto_id() -> str: """Generate a "random" automatically generated ID. 
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 2de95b79a..797572b1b 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -59,6 +59,7 @@ query, ) from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1 import pipeline_expressions if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -1128,6 +1129,74 @@ def recursive(self: QueryType) -> QueryType: return copied + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - ValueError: raised if Query wasn't created with an associated client + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + if not self._client: + raise ValueError("Query does not have an associated client") + if self._all_descendants: + ppl = self._client.pipeline().collection_group(self._parent.id) + else: + ppl = self._client.pipeline().collection(self._parent._path) + + # Filters + for filter_ in self._field_filters: + ppl = ppl.where( + pipeline_expressions.BooleanExpr._from_query_filter_pb( + filter_, self._client + ) + ) + + # Projections + if self._projection and self._projection.fields: + ppl = ppl.select(*[field.field_path for field in self._projection.fields]) + + # Orders + orders = self._normalize_orders() + if orders: + exists = [] + orderings = [] + for order in orders: + field = pipeline_expressions.Field.of(order.field.field_path) + exists.append(field.exists()) + direction = ( + "ascending" + if order.direction == StructuredQuery.Direction.ASCENDING + else "descending" + ) + orderings.append(pipeline_expressions.Ordering(field, direction)) + + # Add exists filters to match Query's implicit orderby semantics. 
+ if len(exists) == 1: + ppl = ppl.where(exists[0]) + else: + ppl = ppl.where(pipeline_expressions.And(*exists)) + + # Add sort orderings + ppl = ppl.sort(*orderings) + + # Cursors, Limit and Offset + if self._start_at or self._end_at or self._limit_to_last: + raise NotImplementedError( + "Query to Pipeline conversion: cursors and limit_to_last is not supported yet." + ) + else: # Limit & Offset without cursors + if self._offset: + ppl = ppl.offset(self._offset) + if self._limit: + ppl = ppl.limit(self._limit) + + return ppl + def _comparator(self, doc1, doc2) -> int: _orders = self._orders diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index ef57f5b72..4639e0f7d 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -587,6 +587,18 @@ def is_nan(self) -> "BooleanExpr": """ return BooleanExpr("is_nan", [self]) + @expose_as_static + def is_null(self) -> "BooleanExpr": + """Creates an expression that checks if this expression evaluates to 'Null'. + + Example: + >>> Field.of("value").is_null() + + Returns: + A new `Expr` representing the 'isNull' check. + """ + return BooleanExpr("is_null", [self]) + @expose_as_static def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. @@ -627,6 +639,7 @@ def average(self) -> "Expr": """ return AggregateFunction("average", [self]) + @expose_as_static def count(self) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. 
@@ -1312,9 +1325,9 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: return And(field.exists(), Not(field.is_nan())) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: - return And(field.exists(), field.equal(None)) + return And(field.exists(), field.is_null()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.equal(None))) + return And(field.exists(), Not(field.is_null())) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): @@ -1361,7 +1374,7 @@ class And(BooleanExpr): Example: >>> # Check if the 'age' field is greater than 18 AND the 'city' field is "London" AND >>> # the 'status' field is "active" - >>> Expr.And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) + >>> And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) Args: *conditions: The filter conditions to 'AND' together. @@ -1377,7 +1390,7 @@ class Not(BooleanExpr): Example: >>> # Find documents where the 'completed' field is NOT true - >>> Expr.Not(Field.of("completed").equal(True)) + >>> Not(Field.of("completed").equal(True)) Args: condition: The filter condition to negate. @@ -1394,7 +1407,7 @@ class Or(BooleanExpr): Example: >>> # Check if the 'age' field is greater than 18 OR the 'city' field is "London" OR >>> # the 'status' field is "active" - >>> Expr.Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) + >>> Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) Args: *conditions: The filter conditions to 'OR' together. 
@@ -1411,7 +1424,7 @@ class Xor(BooleanExpr): Example: >>> # Check if only one of the conditions is true: 'age' greater than 18, 'city' is "London", >>> # or 'status' is "active". - >>> Expr.Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) + >>> Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) Args: *conditions: The filter conditions to 'XOR' together. @@ -1428,7 +1441,7 @@ class Conditional(BooleanExpr): Example: >>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor". - >>> Expr.conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor")); + >>> Conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor")); Args: condition: The condition to evaluate. @@ -1440,3 +1453,24 @@ def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr): super().__init__( "conditional", [condition, then_expr, else_expr], use_infix_repr=False ) + +class Count(AggregateFunction): + """ + Represents an aggregation that counts the number of stage inputs with valid evaluations of the + expression or field. + + Example: + >>> # Count the total number of products + >>> Field.of("productId").count().as_("totalProducts") + >>> Count(Field.of("productId")) + >>> Count().as_("count") + + Args: + expression: The expression or field to count. If None, counts all stage inputs. 
+ """ + + def __init__(self, expression: Expr | None = None): + expression_list = [expression] if expression else [] + super().__init__( + "count", expression_list, use_infix_repr=bool(expression_list) + ) diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py index c146a5763..5a93a869e 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -20,3 +20,4 @@ # run all tests against default database, and a named database # TODO: add enterprise mode when GA (RunQuery not currently supported) TEST_DATABASES = [None, FIRESTORE_OTHER_DB] +TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB] diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 9909fb05e..a8f94e2ba 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -42,7 +42,9 @@ MISSING_DOCUMENT, RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, + ENTERPRISE_MODE_ERROR, TEST_DATABASES, + TEST_DATABASES_W_ENTERPRISE, ) @@ -80,6 +82,58 @@ def cleanup(): operation() +def verify_pipeline(query): + """ + This function ensures a pipeline produces the same + results as the query it is derived from + + It can be attached to existing query tests to check both + modalities at the same time + """ + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + def _clean_results(results): + if isinstance(results, dict): + return {k: _clean_results(v) for k, v in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list(itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in query.get()] + )) + ) + else: + # other queries return a simple 
list of results + query_results = _clean_results([s.to_dict() for s in query.get()]) + except Exception as e: + # if we expect the query to fail, capture the exception + query_exception = e + pipeline = query.pipeline() + if query_exception: + # ensure that the pipeline uses same error as query + with pytest.raises(query_exception.__class__): + pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results([s.data() for s in pipeline.execute()]) + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collections(client, database): collections = list(client.collections()) @@ -1231,7 +1285,7 @@ def query(collection): return collection.where(filter=FieldFilter("a", "==", 1)) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs @@ -1245,9 +1299,10 @@ def test_query_stream_legacy_where(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1256,9 +1311,10 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", 
TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1267,9 +1323,10 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1279,9 +1336,10 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_not_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) @@ -1301,9 +1359,10 @@ def test_query_stream_w_not_eq_op(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_not_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1313,9 +1372,10 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} assert len(values) == 22 + verify_pipeline(query) 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1327,9 +1387,10 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1341,9 +1402,10 @@ def test_query_stream_w_order_by(query_docs, database): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. 
assert sorted(b_vals, reverse=True) == b_vals + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1363,6 +1425,7 @@ def test_query_stream_w_field_path(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1381,13 +1444,14 @@ def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = list(query.stream()) assert len(values) == 0 + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1407,7 +1471,7 @@ def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1425,9 +1489,10 @@ def test_query_stream_w_multiple_filters(query_docs, database): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", 
TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1441,13 +1506,14 @@ def test_query_stream_w_offset(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 + verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1463,7 +1529,7 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): # is called with pytest.raises(QueryExplainError, match="explain_options not set on query"): results.get_explain_metrics() - + verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." @@ -1646,7 +1712,7 @@ def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1672,6 +1738,7 @@ def test_query_unary(client, cleanup, database): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} + verify_pipeline(query0) # 1. Query for a NAN. 
query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1682,6 +1749,7 @@ def test_query_unary(client, cleanup, database): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) + verify_pipeline(query1) # 2. Query for not null query2 = collection.where(filter=FieldFilter(field_name, "!=", None)) @@ -1701,7 +1769,7 @@ def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1731,7 +1799,8 @@ def test_collection_group_queries(client, cleanup, database): snapshots = list(query.stream()) found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] - assert found == expected + assert set(found) == set(expected) + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1777,7 +1846,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1820,6 +1889,7 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1841,6 +1911,7 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) 
assert found == set(["cg-doc2"]) + verify_pipeline(query) @pytest.mark.skipif( @@ -2129,7 +2200,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_watch_query(client, cleanup, database): db = client collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) @@ -2150,6 +2221,7 @@ def on_snapshot(docs, changes, read_time): query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada")) query_ran = query_ran_query.stream() assert len(docs) == len([i for i in query_ran]) + verify_pipeline(query_ran_query) on_snapshot.called_count = 0 @@ -2490,7 +2562,7 @@ def test_chunked_and_recursive(client, cleanup, database): assert [doc.id for doc in next(iter)] == page_3_ids -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_watch_query_order(client, cleanup, database): db = client collection_ref = db.collection("users") @@ -2527,6 +2599,7 @@ def on_snapshot(docs, changes, read_time): ), "expect the sort order to match, born" on_snapshot.called_count += 1 on_snapshot.last_doc_count = len(docs) + verify_pipeline(query_ref) except Exception as e: on_snapshot.failed = e @@ -2566,7 +2639,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_repro_429(client, cleanup, database): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) @@ -2592,6 +2665,7 @@ def test_repro_429(client, cleanup, database): for snapshot in query2.stream(): print(f"id: {snapshot.id}") + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -3160,7 +3234,7 @@ def 
test_aggregation_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_and_composite_filter(collection, database): and_filter = And( filters=[ @@ -3173,9 +3247,10 @@ def test_query_with_and_composite_filter(collection, database): for result in query.stream(): assert result.get("stats.product") > 5 assert result.get("stats.product") < 10 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_or_composite_filter(collection, database): or_filter = Or( filters=[ @@ -3196,9 +3271,10 @@ def test_query_with_or_composite_filter(collection, database): assert gt_5 > 0 assert lt_10 > 0 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) @@ -3243,9 +3319,10 @@ def test_aggregation_queries_with_read_time( assert len(old_result) == 1 for r in old_result[0]: assert r.value == expected_value + verify_pipeline(aggregation_query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_complex_composite_filter(collection, database): field_filter = FieldFilter("b", "==", 0) or_filter = Or( @@ -3266,6 +3343,7 @@ def test_query_with_complex_composite_filter(collection, database): assert sum_0 > 0 assert sum_4 > 0 + verify_pipeline(query) # b == 3 || (stats.sum == 4 && a == 4) comp_filter = Or( @@ -3288,13 +3366,14 @@ def test_query_with_complex_composite_filter(collection, database): assert b_3 is True assert 
b_not_3 is True + verify_pipeline(query) @pytest.mark.parametrize( "aggregation_type,aggregation_args,expected", [("count", (), 3), ("sum", ("b"), 12), ("avg", ("b"), 4)], ) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_aggregation_query_in_transaction( client, cleanup, @@ -3335,13 +3414,14 @@ def in_transaction(transaction): assert len(result[0]) == 1 assert result[0][0].value == expected inner_fn_ran = True + verify_pipeline(aggregation_query) in_transaction(transaction) # make sure we didn't skip assertions in inner function assert inner_fn_ran is True -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_or_query_in_transaction(client, cleanup, database): """ Test running or query inside a transaction. Should pass transaction id along with request @@ -3380,6 +3460,7 @@ def in_transaction(transaction): result[0].get("b") == 2 and result[1].get("b") == 1 ) inner_fn_ran = True + verify_pipeline(query) in_transaction(transaction) # make sure we didn't skip assertions in inner function diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index bc79ee2df..b78a77786 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -53,7 +53,9 @@ MISSING_DOCUMENT, RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, + ENTERPRISE_MODE_ERROR, TEST_DATABASES, + TEST_DATABASES_W_ENTERPRISE, ) RETRIES = retries.AsyncRetry( @@ -160,6 +162,61 @@ async def cleanup(): await operation() +async def verify_pipeline(query): + """ + This function ensures a pipeline produces the same + results as the query it is derived from + + It can be attached to existing query tests to check both + modalities at the same time + """ + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + def _clean_results(results): + 
if isinstance(results, dict): + return {k: _clean_results(v) for k, v in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in await query.get()] + ) + ) + ) + else: + # other queries return a simple list of results + query_results = _clean_results([s.to_dict() for s in await query.get()]) + except Exception as e: + # if we expect the query to fail, capture the exception + query_exception = e + pipeline = query.pipeline() + if query_exception: + # ensure that the pipeline uses same error as query + with pytest.raises(query_exception.__class__): + await pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results([s.data() async for s in pipeline.stream()]) + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + + @pytest.fixture(scope="module") def event_loop(): """Change event_loop fixture to module level.""" @@ -1203,7 +1260,7 @@ async def async_query(collection): return collection.where(filter=FieldFilter("a", "==", 1)) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value, and shows UserWarning""" collection, stored, allowed_vals = query_docs @@ -1217,9 +1274,10 @@ async def test_query_stream_legacy_where(query_docs, database): for key, value 
in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1228,9 +1286,10 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1239,9 +1298,10 @@ async def test_query_stream_w_simple_field_array_contains_op(query_docs, databas for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1251,9 +1311,10 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = 
query_docs num_vals = len(allowed_vals) @@ -1265,9 +1326,10 @@ async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, dat for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1279,9 +1341,10 @@ async def test_query_stream_w_order_by(query_docs, database): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. assert sorted(b_vals, reverse=True) == b_vals + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1301,6 +1364,7 @@ async def test_query_stream_w_field_path(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + await verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1319,13 +1383,14 @@ async def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = [i async for i in query.stream()] assert len(values) == 0 + await verify_pipeline(query) 
@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1345,7 +1410,7 @@ async def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1363,9 +1428,10 @@ async def test_query_stream_w_multiple_filters(query_docs, database): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1379,13 +1445,14 @@ async def test_query_stream_w_offset(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 + await verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1404,6 +1471,7 @@ async def test_query_stream_or_get_w_no_explain_options(query_docs, database, me # is called with pytest.raises(QueryExplainError, match="explain_options not set on query"): await results.get_explain_metrics() + await verify_pipeline(query) @pytest.mark.skipif( @@ -1570,7 +1638,7 @@ async def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1596,6 +1664,7 @@ async def test_query_unary(client, cleanup, database): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} + await verify_pipeline(query0) # 1. Query for a NAN. query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1606,6 +1675,7 @@ async def test_query_unary(client, cleanup, database): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) + await verify_pipeline(query1) # 2. 
Query for not null query2 = collection.where(filter=FieldFilter(field_name, "!=", None)) @@ -1625,7 +1695,7 @@ async def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1655,7 +1725,8 @@ async def test_collection_group_queries(client, cleanup, database): snapshots = [i async for i in query.stream()] found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] - assert found == expected + assert set(found) == set(expected) + await verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1701,7 +1772,7 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database) assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1743,6 +1814,7 @@ async def test_collection_group_queries_filters(client, cleanup, database): snapshots = [i async for i in query.stream()] found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + await verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1764,6 +1836,7 @@ async def test_collection_group_queries_filters(client, cleanup, database): snapshots = [i async for i in query.stream()] found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + await verify_pipeline(query) @pytest.mark.skipif( diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 
69ca69ec7..5064e87ae 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -20,6 +20,7 @@ from google.cloud.firestore_v1.base_aggregation import ( AggregationResult, AvgAggregation, + BaseAggregation, CountAggregation, SumAggregation, ) @@ -27,6 +28,7 @@ from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator from google.cloud.firestore_v1.types import RunAggregationQueryResponse +from google.cloud.firestore_v1.field_path import FieldPath from google.protobuf.timestamp_pb2 import Timestamp from tests.unit.v1._test_helpers import ( make_aggregation_query, @@ -121,6 +123,65 @@ def test_avg_aggregation_no_alias_to_pb(): assert got_pb.alias == "" +@pytest.mark.parametrize( + "in_alias,expected_alias", [("total", "total"), (None, "field_1")] +) +def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): + from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import Count + + count_aggregation = CountAggregation(alias=in_alias) + got = count_aggregation._to_pipeline_expr(iter([1])) + assert isinstance(got, ExprWithAlias) + assert got.alias == expected_alias + assert isinstance(got.expr, Count) + assert len(got.expr.params) == 0 + + +@pytest.mark.parametrize( + "in_alias,expected_path,expected_alias", + [("total", "path", "total"), (None, "some_ref", "field_1")], +) +def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): + from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import Sum + + count_aggregation = SumAggregation(expected_path, alias=in_alias) + got = count_aggregation._to_pipeline_expr(iter([1])) + assert isinstance(got, ExprWithAlias) + assert got.alias == expected_alias + assert isinstance(got.expr, Sum) + assert got.expr.params[0].path == expected_path + 
+ +@pytest.mark.parametrize( + "in_alias,expected_path,expected_alias", + [("total", "path", "total"), (None, "some_ref", "field_1")], +) +def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): + from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import Avg + + count_aggregation = AvgAggregation(expected_path, alias=in_alias) + got = count_aggregation._to_pipeline_expr(iter([1])) + assert isinstance(got, ExprWithAlias) + assert got.alias == expected_alias + assert isinstance(got.expr, Avg) + assert got.expr.params[0].path == expected_path + + +def test_aggregation__pipeline_alias_increment(): + """ + BaseAggregation.__pipeline_alias should pull from an autoindexer to populate field numbers + """ + autoindex = iter(range(10)) + mock_instance = mock.Mock() + mock_instance.alias = None + for i in range(10): + got_name = BaseAggregation._pipeline_alias(mock_instance, autoindex) + assert got_name == f"field_{i}" + + def test_aggregation_query_constructor(): client = make_client() parent = client.collection("dee") @@ -894,6 +955,16 @@ def test_aggregation_query_stream_w_explain_options_analyze_false(): _aggregation_query_stream_helper(explain_options=ExplainOptions(analyze=False)) +def test_aggretgation__to_dict(): + expected_alias = "alias" + expected_value = "value" + instance = AggregationResult(alias=expected_alias, value=expected_value) + dict_result = instance._to_dict() + assert len(dict_result) == 1 + assert next(iter(dict_result)) == expected_alias + assert dict_result[expected_alias] == expected_value + + def test_aggregation_from_query(): from google.cloud.firestore_v1 import _helpers @@ -952,3 +1023,147 @@ def test_aggregation_from_query(): metadata=client._rpc_metadata, **kwargs, ) + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + 
], +) +def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Sum + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.sum(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Avg + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.avg(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert 
isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "in_alias,out_alias", + [ + (None, "field_1"), + ("overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_count(in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Count + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.count(alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Count) + assert aggregate_stage.accumulators[0].alias == out_alias + + +def test_aggreation_to_pipeline_count_increment(): + """ + When aliases aren't given, should assign incrementing field_n values + """ + from google.cloud.firestore_v1.pipeline_expressions import Count + + n = 100 + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + for _ in range(n): + aggregation_query.count() + pipeline = aggregation_query.pipeline() + aggregate_stage = pipeline.stages[1] + assert len(aggregate_stage.accumulators) == n + for i in range(n): + assert 
isinstance(aggregate_stage.accumulators[i].expr, Count) + assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + + +def test_aggreation_to_pipeline_complex(): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select + from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + + client = make_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + aggregation_query = make_aggregation_query(query) + aggregation_query.sum("field", alias="alias") + aggregation_query.count() + aggregation_query.avg("other") + aggregation_query.sum("final") + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 3 + assert isinstance(pipeline.stages[0], Collection) + assert isinstance(pipeline.stages[1], Select) + assert isinstance(pipeline.stages[2], Aggregate) + aggregate_stage = pipeline.stages[2] + assert len(aggregate_stage.accumulators) == 4 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert aggregate_stage.accumulators[0].alias == "alias" + assert isinstance(aggregate_stage.accumulators[1].expr, Count) + assert aggregate_stage.accumulators[1].alias == "field_1" + assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert aggregate_stage.accumulators[2].alias == "field_2" + assert isinstance(aggregate_stage.accumulators[3].expr, Sum) + assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 9140f53e8..fdd4a1450 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -31,6 +31,7 @@ from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.query_profile import ExplainMetrics, QueryExplainError from google.cloud.firestore_v1.query_results import 
QueryResultsList +from google.cloud.firestore_v1.field_path import FieldPath _PROJECT = "PROJECT" @@ -696,3 +697,147 @@ async def test_aggregation_query_stream_w_explain_options_analyze_false(): explain_options = ExplainOptions(analyze=False) await _async_aggregation_query_stream_helper(explain_options=explain_options) + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Sum + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.sum(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages 
import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Avg + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.avg(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "in_alias,out_alias", + [ + (None, "field_1"), + ("overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_count(in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Count + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.count(alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Count) + assert 
aggregate_stage.accumulators[0].alias == out_alias + + +def test_aggreation_to_pipeline_count_increment(): + """ + When aliases aren't given, should assign incrementing field_n values + """ + from google.cloud.firestore_v1.pipeline_expressions import Count + + n = 100 + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + for _ in range(n): + aggregation_query.count() + pipeline = aggregation_query.pipeline() + aggregate_stage = pipeline.stages[1] + assert len(aggregate_stage.accumulators) == n + for i in range(n): + assert isinstance(aggregate_stage.accumulators[i].expr, Count) + assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + + +def test_async_aggreation_to_pipeline_complex(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select + from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + + client = make_async_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.sum("field", alias="alias") + aggregation_query.count() + aggregation_query.avg("other") + aggregation_query.sum("final") + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 3 + assert isinstance(pipeline.stages[0], Collection) + assert isinstance(pipeline.stages[1], Select) + assert isinstance(pipeline.stages[2], Aggregate) + aggregate_stage = pipeline.stages[2] + assert len(aggregate_stage.accumulators) == 4 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert aggregate_stage.accumulators[0].alias == "alias" + assert isinstance(aggregate_stage.accumulators[1].expr, Count) + assert aggregate_stage.accumulators[1].alias == "field_1" + assert 
isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert aggregate_stage.accumulators[2].alias == "field_2" + assert isinstance(aggregate_stage.accumulators[3].expr, Sum) + assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index a0194ace5..353997b8e 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -601,3 +601,23 @@ def test_asynccollectionreference_recursive(): col = _make_async_collection_reference("collection") assert isinstance(col.recursive(), AsyncQuery) + + +def test_asynccollectionreference_pipeline(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection + + client = make_async_client() + collection = _make_async_collection_reference("collection", client=client) + pipeline = collection.pipeline() + assert isinstance(pipeline, AsyncPipeline) + # should have single "Collection" stage + assert len(pipeline.stages) == 1 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/collection" + + +def test_asynccollectionreference_pipeline_no_client(): + collection = _make_async_collection_reference("collection") + with pytest.raises(ValueError, match="client"): + collection.pipeline() diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 54c80e5ad..dc5eb9e8a 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -909,3 +909,22 @@ async def test_asynccollectiongroup_get_partitions_w_offset(): query = _make_async_collection_group(parent).offset(10) with pytest.raises(ValueError): [i async for i in query.get_partitions(2)] + + +def test_asyncquery_collection_pipeline_type(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = make_async_client() + parent = client.collection("test") + query = 
parent._query() + ppl = query.pipeline() + assert isinstance(ppl, AsyncPipeline) + + +def test_asyncquery_collectiongroup_pipeline_type(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = make_async_client() + query = client.collection_group("test") + ppl = query.pipeline() + assert isinstance(ppl, AsyncPipeline) diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index 22baa0c5f..7f7be9c07 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -422,6 +422,20 @@ def test_basecollectionreference_end_at(mock_query): assert query == mock_query.end_at.return_value +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_pipeline(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = _make_base_collection_reference("collection") + pipeline = collection.pipeline() + + mock_query.pipeline.assert_called_once_with() + assert pipeline == mock_query.pipeline.return_value + + @mock.patch("random.choice") def test__auto_id(mock_rand_choice): from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS, _auto_id diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 7804b0430..9bb3e61f8 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -18,6 +18,7 @@ import pytest from tests.unit.v1._test_helpers import make_client +from google.cloud.firestore_v1 import _pipeline_stages as stages def _make_base_query(*args, **kwargs): @@ -1993,6 +1994,175 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.update_time == response_pb._pb.document.update_time +def test__query_pipeline_no_client(): + mock_parent = mock.Mock() + mock_parent._client = None + query 
= _make_base_query(mock_parent) + with pytest.raises(ValueError, match="client"): + query.pipeline() + + +def test__query_pipeline_decendants(): + client = make_client() + query = client.collection_group("my_col") + pipeline = query.pipeline() + + assert len(pipeline.stages) == 1 + stage = pipeline.stages[0] + assert isinstance(stage, stages.CollectionGroup) + assert stage.collection_id == "my_col" + + +@pytest.mark.parametrize( + "in_path,out_path", + [ + ("my_col/doc/", "/my_col/doc/"), + ("/my_col/doc", "/my_col/doc"), + ("my_col/doc/sub_col", "/my_col/doc/sub_col"), + ], +) +def test__query_pipeline_no_decendants(in_path, out_path): + client = make_client() + collection = client.collection(in_path) + query = collection._query() + pipeline = query.pipeline() + + assert len(pipeline.stages) == 1 + stage = pipeline.stages[0] + assert isinstance(stage, stages.Collection) + assert stage.path == out_path + + +def test__query_pipeline_composite_filter(): + from google.cloud.firestore_v1 import FieldFilter + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + in_filter = FieldFilter("field_a", "==", "value_a") + query = client.collection("my_col").where(filter=in_filter) + with mock.patch.object( + expr.FilterCondition, "_from_query_filter_pb" + ) as convert_mock: + pipeline = query.pipeline() + convert_mock.assert_called_once_with(in_filter._to_pb(), client) + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Where) + assert stage.condition == convert_mock.return_value + + +def test__query_pipeline_projections(): + client = make_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Select) + assert len(stage.projections) == 2 + assert stage.projections[0].path == "field_a" + assert stage.projections[1].path == "field_b.c" + + 
+def test__query_pipeline_order_exists_multiple(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query = client.collection("my_col").order_by("field_a").order_by("field_b") + pipeline = query.pipeline() + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline.stages) == 3 + where_stage = pipeline.stages[1] + assert isinstance(where_stage, stages.Where) + # should have and with both orderings + assert isinstance(where_stage.condition, expr.And) + assert len(where_stage.condition.params) == 2 + operands = [p for p in where_stage.condition.params] + assert isinstance(operands[0], expr.Exists) + assert operands[0].params[0].path == "field_a" + assert isinstance(operands[1], expr.Exists) + assert operands[1].params[0].path == "field_b" + + +def test__query_pipeline_order_exists_single(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query_single = client.collection("my_col").order_by("field_c") + pipeline_single = query_single.pipeline() + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline_single.stages) == 3 + where_stage_single = pipeline_single.stages[1] + assert isinstance(where_stage_single, stages.Where) + assert isinstance(where_stage_single.condition, expr.Exists) + assert where_stage_single.condition.params[0].path == "field_c" + + +def test__query_pipeline_order_sorts(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1.base_query import BaseQuery + + client = make_client() + query = ( + client.collection("my_col") + .order_by("field_a", direction=BaseQuery.ASCENDING) + .order_by("field_b", direction=BaseQuery.DESCENDING) + ) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 3 + sort_stage = pipeline.stages[2] + assert isinstance(sort_stage, stages.Sort) + assert len(sort_stage.orders) == 2 + assert 
isinstance(sort_stage.orders[0], expr.Ordering) + assert sort_stage.orders[0].expr.path == "field_a" + assert sort_stage.orders[0].order_dir == expr.Ordering.Direction.ASCENDING + assert isinstance(sort_stage.orders[1], expr.Ordering) + assert sort_stage.orders[1].expr.path == "field_b" + assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING + + +def test__query_pipeline_unsupported(): + client = make_client() + query_start = client.collection("my_col").start_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_start.pipeline() + + query_end = client.collection("my_col").end_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_end.pipeline() + + query_limit_last = client.collection("my_col").limit_to_last(10) + with pytest.raises(NotImplementedError, match="limit_to_last"): + query_limit_last.pipeline() + + +def test__query_pipeline_limit(): + client = make_client() + query = client.collection("my_col").limit(15) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Limit) + assert stage.limit == 15 + + +def test__query_pipeline_offset(): + client = make_client() + query = client.collection("my_col").offset(5) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Offset) + assert stage.offset == 5 + + def _make_order_pb(field_path, direction): from google.cloud.firestore_v1.types import query diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index da91651b9..9e615541a 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -15,6 +15,7 @@ import types import mock +import pytest from datetime import datetime, timezone from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT @@ -510,6 +511,27 @@ def test_stream_w_read_time(query_class): ) +def 
test_collectionreference_pipeline(): + from tests.unit.v1 import _test_helpers + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection + + client = _test_helpers.make_client() + collection = _make_collection_reference("collection", client=client) + pipeline = collection.pipeline() + assert isinstance(pipeline, Pipeline) + # should have single "Collection" stage + assert len(pipeline.stages) == 1 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/collection" + + +def test_collectionreference_pipeline_no_client(): + collection = _make_collection_reference("collection") + with pytest.raises(ValueError, match="client"): + collection.pipeline() + + @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True) def test_on_snapshot(watch): collection = _make_collection_reference("collection") diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index c5329df33..9f06c47b8 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -834,6 +834,15 @@ def test_is_nan(self): infix_instance = arg1.is_nan() assert infix_instance == instance + def test_is_null(self): + arg1 = self._make_arg("Value") + instance = Expr.is_null(arg1) + assert instance.name == "is_null" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_null()" + infix_instance = arg1.is_null() + assert infix_instance == instance + def test_not(self): arg1 = self._make_arg("Condition") instance = expr.Not(arg1) @@ -1179,6 +1188,12 @@ def test_count(self): infix_instance = arg1.count() assert infix_instance == instance + def test_base_count(self): + instance = expr.Count() + assert instance.name == "count" + assert instance.params == [] + assert repr(instance) == "Count()" + def test_minimum(self): arg1 = self._make_arg("Value") instance = Expr.minimum(arg1) diff --git 
a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index d5b36e56c..fadea7e12 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -185,8 +185,9 @@ def test_to_pb(self): instance = self._make_one(input_arg) result = instance._to_pb() assert result.name == "collection_group" - assert len(result.args) == 1 - assert result.args[0].string_value == "test" + assert len(result.args) == 2 + assert result.args[0].reference_value == "" + assert result.args[1].string_value == "test" assert len(result.options) == 0 diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index b8c37cf84..8b1217370 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -1046,3 +1046,22 @@ def test_collection_group_get_partitions_w_offset(database): query = _make_collection_group(parent).offset(10) with pytest.raises(ValueError): list(query.get_partitions(2)) + + +def test_asyncquery_collection_pipeline_type(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = make_client() + parent = client.collection("test") + query = parent._query() + ppl = query.pipeline() + assert isinstance(ppl, Pipeline) + + +def test_asyncquery_collectiongroup_pipeline_type(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = make_client() + query = client.collection_group("test") + ppl = query.pipeline() + assert isinstance(ppl, Pipeline)