From f4d89bd33e6abb1cf239324c226f06f71ef5e041 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 15 Jun 2020 22:16:27 -0500 Subject: [PATCH 01/47] feat: add async tests for AsyncClient --- noxfile.py | 28 +- tests/unit/v1/async/__init__.py | 13 + tests/unit/v1/async/test_async_client.py | 754 +++++++++++++++++++++++ 3 files changed, 788 insertions(+), 7 deletions(-) create mode 100644 tests/unit/v1/async/__init__.py create mode 100644 tests/unit/v1/async/test_async_client.py diff --git a/noxfile.py b/noxfile.py index facb0bb995..cafa9785c2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -63,14 +63,13 @@ def lint_setup_py(session): session.run("python", "setup.py", "check", "--restructuredtext", "--strict") -def default(session): +def default(session, test_dir, ignore_dir): # Install all test dependencies, then install this package in-place. session.install("mock", "pytest", "pytest-cov") session.install("-e", ".") # Run py.test against the unit tests. - session.run( - "py.test", + args = [ "--quiet", "--cov=google.cloud.firestore", "--cov=google.cloud", @@ -79,15 +78,30 @@ def default(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", - os.path.join("tests", "unit"), + test_dir, *session.posargs, - ) + ] + + if ignore_dir: + args.insert(0, f"--ignore={ignore_dir}") + + session.run("py.test", *args) @nox.session(python=["2.7", "3.5", "3.6", "3.7", "3.8"]) def unit(session): - """Run the unit test suite.""" - default(session) + """Run the unit test suite for sync tests.""" + default( + session, + os.path.join("tests", "unit"), + os.path.join("tests", "unit", "v1", "async"), + ) + + +@nox.session(python=["3.7", "3.8"]) +def unit_async(session): + """Run the unit test suite for async tests.""" + default(session, os.path.join("tests", "unit", "v1", "async"), None) @nox.session(python=["2.7", "3.7"]) diff --git a/tests/unit/v1/async/__init__.py b/tests/unit/v1/async/__init__.py new file mode 100644 index 0000000000..ab67290952 --- /dev/null +++ b/tests/unit/v1/async/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py new file mode 100644 index 0000000000..476ab501c0 --- /dev/null +++ b/tests/unit/v1/async/test_async_client.py @@ -0,0 +1,754 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime +import types +import unittest + +import mock + + +class TestAsyncClient(unittest.TestCase): + + PROJECT = "my-prahjekt" + + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_client import AsyncClient + + return AsyncClient + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def _make_default_one(self): + credentials = _make_credentials() + return self._make_one(project=self.PROJECT, credentials=credentials) + + def test_constructor(self): + from google.cloud.firestore_v1.async_client import _CLIENT_INFO + from google.cloud.firestore_v1.async_client import DEFAULT_DATABASE + + credentials = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=credentials) + self.assertEqual(client.project, self.PROJECT) + self.assertEqual(client._credentials, credentials) + self.assertEqual(client._database, DEFAULT_DATABASE) + self.assertIs(client._client_info, _CLIENT_INFO) + self.assertIsNone(client._emulator_host) + + def test_constructor_with_emulator_host(self): + from google.cloud.firestore_v1.async_client import _FIRESTORE_EMULATOR_HOST + + credentials = _make_credentials() + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + client = self._make_one(project=self.PROJECT, credentials=credentials) + self.assertEqual(client._emulator_host, emulator_host) + getenv.assert_called_once_with(_FIRESTORE_EMULATOR_HOST) + + def test_constructor_explicit(self): + credentials = _make_credentials() + database = "now-db" + client_info = mock.Mock() + client_options = mock.Mock() + client = self._make_one( + project=self.PROJECT, + credentials=credentials, + database=database, + client_info=client_info, + client_options=client_options, + ) + self.assertEqual(client.project, self.PROJECT) + self.assertEqual(client._credentials, credentials) + self.assertEqual(client._database, database) + self.assertIs(client._client_info, client_info) + self.assertIs(client._client_options, client_options) + + def test_constructor_w_client_options(self): + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, + credentials=credentials, + client_options={"api_endpoint": "foo-firestore.googleapis.com"}, + ) + self.assertEqual(client._target, "foo-firestore.googleapis.com") + + @mock.patch( + "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", + autospec=True, + return_value=mock.sentinel.firestore_api, + ) + def test__firestore_api_property(self, mock_client): + mock_client.SERVICE_ADDRESS = "endpoint" + client = self._make_default_one() + client_info = client._client_info = mock.Mock() + self.assertIsNone(client._firestore_api_internal) + firestore_api = client._firestore_api + self.assertIs(firestore_api, mock_client.return_value) + self.assertIs(firestore_api, client._firestore_api_internal) + mock_client.assert_called_once_with( + transport=client._transport, client_info=client_info + ) + + # Call again to show that it is cached, but call count is still 1. + self.assertIs(client._firestore_api, mock_client.return_value) + self.assertEqual(mock_client.call_count, 1) + + @mock.patch( + "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", + autospec=True, + return_value=mock.sentinel.firestore_api, + ) + @mock.patch( + "google.cloud.firestore_v1.gapic.transports.firestore_grpc_transport.firestore_pb2_grpc.grpc.insecure_channel", + autospec=True, + ) + def test__firestore_api_property_with_emulator( + self, mock_insecure_channel, mock_client + ): + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + client = self._make_default_one() + + self.assertIsNone(client._firestore_api_internal) + firestore_api = client._firestore_api + self.assertIs(firestore_api, mock_client.return_value) + self.assertIs(firestore_api, client._firestore_api_internal) + + mock_insecure_channel.assert_called_once_with(emulator_host) + + # Call again to show that it is cached, but call count is still 1. + self.assertIs(client._firestore_api, mock_client.return_value) + self.assertEqual(mock_client.call_count, 1) + + def test___database_string_property(self): + credentials = _make_credentials() + database = "cheeeeez" + client = self._make_one( + project=self.PROJECT, credentials=credentials, database=database + ) + self.assertIsNone(client._database_string_internal) + database_string = client._database_string + expected = "projects/{}/databases/{}".format(client.project, client._database) + self.assertEqual(database_string, expected) + self.assertIs(database_string, client._database_string_internal) + + # Swap it out with a unique value to verify it is cached. + client._database_string_internal = mock.sentinel.cached + self.assertIs(client._database_string, mock.sentinel.cached) + + def test___rpc_metadata_property(self): + credentials = _make_credentials() + database = "quanta" + client = self._make_one( + project=self.PROJECT, credentials=credentials, database=database + ) + + self.assertEqual( + client._rpc_metadata, + [("google-cloud-resource-prefix", client._database_string)], + ) + + def test__rpc_metadata_property_with_emulator(self): + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + + credentials = _make_credentials() + database = "quanta" + client = self._make_one( + project=self.PROJECT, credentials=credentials, database=database + ) + + self.assertEqual( + client._rpc_metadata, + [ + ("google-cloud-resource-prefix", client._database_string), + ("authorization", "Bearer owner"), + ], + ) + + def test_collection_factory(self): + from google.cloud.firestore_v1.collection import CollectionReference + + collection_id = "users" + client = self._make_default_one() + collection = client.collection(collection_id) + + self.assertEqual(collection._path, (collection_id,)) + self.assertIs(collection._client, client) + self.assertIsInstance(collection, CollectionReference) + + def test_collection_factory_nested(self): + from google.cloud.firestore_v1.collection import CollectionReference + + client = self._make_default_one() + parts = ("users", "alovelace", "beep") + collection_path = "/".join(parts) + collection1 = client.collection(collection_path) + + self.assertEqual(collection1._path, parts) + self.assertIs(collection1._client, client) + self.assertIsInstance(collection1, CollectionReference) + + # Make sure using segments gives the same result. + collection2 = client.collection(*parts) + self.assertEqual(collection2._path, parts) + self.assertIs(collection2._client, client) + self.assertIsInstance(collection2, CollectionReference) + + def test_collection_group(self): + client = self._make_default_one() + query = client.collection_group("collectionId").where("foo", "==", u"bar") + + assert query._all_descendants + assert query._field_filters[0].field.field_path == "foo" + assert query._field_filters[0].value.string_value == u"bar" + assert query._field_filters[0].op == query._field_filters[0].EQUAL + assert query._parent.id == "collectionId" + + def test_collection_group_no_slashes(self): + client = self._make_default_one() + with self.assertRaises(ValueError): + client.collection_group("foo/bar") + + def test_document_factory(self): + from google.cloud.firestore_v1.document import DocumentReference + + parts = ("rooms", "roomA") + client = self._make_default_one() + doc_path = "/".join(parts) + document1 = client.document(doc_path) + + self.assertEqual(document1._path, parts) + self.assertIs(document1._client, client) + self.assertIsInstance(document1, DocumentReference) + + # Make sure using segments gives the same result. + document2 = client.document(*parts) + self.assertEqual(document2._path, parts) + self.assertIs(document2._client, client) + self.assertIsInstance(document2, DocumentReference) + + def test_document_factory_w_absolute_path(self): + from google.cloud.firestore_v1.document import DocumentReference + + parts = ("rooms", "roomA") + client = self._make_default_one() + doc_path = "/".join(parts) + to_match = client.document(doc_path) + document1 = client.document(to_match._document_path) + + self.assertEqual(document1._path, parts) + self.assertIs(document1._client, client) + self.assertIsInstance(document1, DocumentReference) + + def test_document_factory_w_nested_path(self): + from google.cloud.firestore_v1.document import DocumentReference + + client = self._make_default_one() + parts = ("rooms", "roomA", "shoes", "dressy") + doc_path = "/".join(parts) + document1 = client.document(doc_path) + + self.assertEqual(document1._path, parts) + self.assertIs(document1._client, client) + self.assertIsInstance(document1, DocumentReference) + + # Make sure using segments gives the same result. + document2 = client.document(*parts) + self.assertEqual(document2._path, parts) + self.assertIs(document2._client, client) + self.assertIsInstance(document2, DocumentReference) + + def test_field_path(self): + klass = self._get_target_class() + self.assertEqual(klass.field_path("a", "b", "c"), "a.b.c") + + def test_write_option_last_update(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import LastUpdateOption + + timestamp = timestamp_pb2.Timestamp(seconds=1299767599, nanos=811111097) + + klass = self._get_target_class() + option = klass.write_option(last_update_time=timestamp) + self.assertIsInstance(option, LastUpdateOption) + self.assertEqual(option._last_update_time, timestamp) + + def test_write_option_exists(self): + from google.cloud.firestore_v1._helpers import ExistsOption + + klass = self._get_target_class() + + option1 = klass.write_option(exists=False) + self.assertIsInstance(option1, ExistsOption) + self.assertFalse(option1._exists) + + option2 = klass.write_option(exists=True) + self.assertIsInstance(option2, ExistsOption) + self.assertTrue(option2._exists) + + def test_write_open_neither_arg(self): + from google.cloud.firestore_v1.async_client import _BAD_OPTION_ERR + + klass = self._get_target_class() + with self.assertRaises(TypeError) as exc_info: + klass.write_option() + + self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) + + def test_write_multiple_args(self): + from google.cloud.firestore_v1.async_client import _BAD_OPTION_ERR + + klass = self._get_target_class() + with self.assertRaises(TypeError) as exc_info: + klass.write_option(exists=False, last_update_time=mock.sentinel.timestamp) + + self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) + + def test_write_bad_arg(self): + from google.cloud.firestore_v1.async_client import _BAD_OPTION_ERR + + klass = self._get_target_class() + with self.assertRaises(TypeError) as exc_info: + klass.write_option(spinach="popeye") + + extra = "{!r} was provided".format("spinach") + self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR, extra)) + + def test_collections(self): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.collection import CollectionReference + + collection_ids = ["users", "projects"] + client = self._make_default_one() + firestore_api = mock.Mock(spec=["list_collection_ids"]) + client._firestore_api_internal = firestore_api + + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + iterator = _Iterator(pages=[collection_ids]) + firestore_api.list_collection_ids.return_value = iterator + + collections = list(asyncio.run(client.collections())) + + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, CollectionReference) + self.assertEqual(collection.parent, None) + self.assertEqual(collection.id, collection_id) + + base_path = client._database_string + "/documents" + firestore_api.list_collection_ids.assert_called_once_with( + base_path, metadata=client._rpc_metadata + ) + + async def _get_all_helper(self, client, references, document_pbs, **kwargs): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["batch_get_documents"]) + response_iterator = iter(document_pbs) + firestore_api.batch_get_documents.return_value = response_iterator + + # Attach the fake GAPIC to a real client. + client._firestore_api_internal = firestore_api + + # Actually call get_all(). + snapshots = client.get_all(references, **kwargs) + self.assertIsInstance(snapshots, types.AsyncGeneratorType) + + return [s async for s in snapshots] + + def _info_for_get_all(self, data1, data2): + client = self._make_default_one() + document1 = client.document("pineapple", "lamp1") + document2 = client.document("pineapple", "lamp2") + + # Make response protobufs. + document_pb1, read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=read_time) + + document_pb2, read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document_pb2, read_time=read_time) + + return client, document1, document2, response1, response2 + + def test_get_all(self): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.document import DocumentSnapshot + + data1 = {"a": u"cheese"} + data2 = {"b": True, "c": 18} + info = self._info_for_get_all(data1, data2) + client, document1, document2, response1, response2 = info + + # Exercise the mocked ``batch_get_documents``. + field_paths = ["a", "b"] + snapshots = asyncio.run( + self._get_all_helper( + client, + [document1, document2], + [response1, response2], + field_paths=field_paths, + ) + ) + self.assertEqual(len(snapshots), 2) + + snapshot1 = snapshots[0] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document1) + self.assertEqual(snapshot1._data, data1) + + snapshot2 = snapshots[1] + self.assertIsInstance(snapshot2, DocumentSnapshot) + self.assertIs(snapshot2._reference, document2) + self.assertEqual(snapshot2._data, data2) + + # Verify the call to the mock. + doc_paths = [document1._document_path, document2._document_path] + mask = common_pb2.DocumentMask(field_paths=field_paths) + client._firestore_api.batch_get_documents.assert_called_once_with( + client._database_string, + doc_paths, + mask, + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_get_all_with_transaction(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + + data = {"so-much": 484} + info = self._info_for_get_all(data, {}) + client, document, _, response, _ = info + transaction = client.transaction() + txn_id = b"the-man-is-non-stop" + transaction._id = txn_id + + # Exercise the mocked ``batch_get_documents``. + snapshots = asyncio.run( + self._get_all_helper( + client, [document], [response], transaction=transaction + ) + ) + self.assertEqual(len(snapshots), 1) + + snapshot = snapshots[0] + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertIs(snapshot._reference, document) + self.assertEqual(snapshot._data, data) + + # Verify the call to the mock. + doc_paths = [document._document_path] + client._firestore_api.batch_get_documents.assert_called_once_with( + client._database_string, + doc_paths, + None, + transaction=txn_id, + metadata=client._rpc_metadata, + ) + + def test_get_all_unknown_result(self): + from google.cloud.firestore_v1.async_client import _BAD_DOC_TEMPLATE + + info = self._info_for_get_all({"z": 28.5}, {}) + client, document, _, _, response = info + + # Exercise the mocked ``batch_get_documents``. + with self.assertRaises(ValueError) as exc_info: + asyncio.run(self._get_all_helper(client, [document], [response])) + + err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + # Verify the call to the mock. + doc_paths = [document._document_path] + client._firestore_api.batch_get_documents.assert_called_once_with( + client._database_string, + doc_paths, + None, + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_get_all_wrong_order(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + + data1 = {"up": 10} + data2 = {"down": -10} + info = self._info_for_get_all(data1, data2) + client, document1, document2, response1, response2 = info + document3 = client.document("pineapple", "lamp3") + response3 = _make_batch_response(missing=document3._document_path) + + # Exercise the mocked ``batch_get_documents``. + snapshots = asyncio.run( + self._get_all_helper( + client, + [document1, document2, document3], + [response2, response1, response3], + ) + ) + + self.assertEqual(len(snapshots), 3) + + snapshot1 = snapshots[0] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document2) + self.assertEqual(snapshot1._data, data2) + + snapshot2 = snapshots[1] + self.assertIsInstance(snapshot2, DocumentSnapshot) + self.assertIs(snapshot2._reference, document1) + self.assertEqual(snapshot2._data, data1) + + self.assertFalse(snapshots[2].exists) + + # Verify the call to the mock. + doc_paths = [ + document1._document_path, + document2._document_path, + document3._document_path, + ] + client._firestore_api.batch_get_documents.assert_called_once_with( + client._database_string, + doc_paths, + None, + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_batch(self): + from google.cloud.firestore_v1.batch import WriteBatch + + client = self._make_default_one() + batch = client.batch() + self.assertIsInstance(batch, WriteBatch) + self.assertIs(batch._client, client) + self.assertEqual(batch._write_pbs, []) + + def test_transaction(self): + from google.cloud.firestore_v1.transaction import Transaction + + client = self._make_default_one() + transaction = client.transaction(max_attempts=3, read_only=True) + self.assertIsInstance(transaction, Transaction) + self.assertEqual(transaction._write_pbs, []) + self.assertEqual(transaction._max_attempts, 3) + self.assertTrue(transaction._read_only) + self.assertIsNone(transaction._id) + + +class Test__reference_info(unittest.TestCase): + @staticmethod + def _call_fut(references): + from google.cloud.firestore_v1.async_client import _reference_info + + return _reference_info(references) + + def test_it(self): + from google.cloud.firestore_v1.async_client import AsyncClient + + credentials = _make_credentials() + client = AsyncClient(project="hi-projject", credentials=credentials) + + reference1 = client.document("a", "b") + reference2 = client.document("a", "b", "c", "d") + reference3 = client.document("a", "b") + reference4 = client.document("f", "g") + + doc_path1 = reference1._document_path + doc_path2 = reference2._document_path + doc_path3 = reference3._document_path + doc_path4 = reference4._document_path + self.assertEqual(doc_path1, doc_path3) + + document_paths, reference_map = self._call_fut( + [reference1, reference2, reference3, reference4] + ) + self.assertEqual(document_paths, [doc_path1, doc_path2, doc_path3, doc_path4]) + # reference3 over-rides reference1. + expected_map = { + doc_path2: reference2, + doc_path3: reference3, + doc_path4: reference4, + } + self.assertEqual(reference_map, expected_map) + + +class Test__get_reference(unittest.TestCase): + @staticmethod + def _call_fut(document_path, reference_map): + from google.cloud.firestore_v1.async_client import _get_reference + + return _get_reference(document_path, reference_map) + + def test_success(self): + doc_path = "a/b/c" + reference_map = {doc_path: mock.sentinel.reference} + self.assertIs(self._call_fut(doc_path, reference_map), mock.sentinel.reference) + + def test_failure(self): + from google.cloud.firestore_v1.async_client import _BAD_DOC_TEMPLATE + + doc_path = "1/888/call-now" + with self.assertRaises(ValueError) as exc_info: + self._call_fut(doc_path, {}) + + err_msg = _BAD_DOC_TEMPLATE.format(doc_path) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + +class Test__parse_batch_get(unittest.TestCase): + @staticmethod + def _call_fut(get_doc_response, reference_map, client=mock.sentinel.client): + from google.cloud.firestore_v1.async_client import _parse_batch_get + + return _parse_batch_get(get_doc_response, reference_map, client) + + @staticmethod + def _dummy_ref_string(): + from google.cloud.firestore_v1.async_client import DEFAULT_DATABASE + + project = u"bazzzz" + collection_id = u"fizz" + document_id = u"buzz" + return u"projects/{}/databases/{}/documents/{}/{}".format( + project, DEFAULT_DATABASE, collection_id, document_id + ) + + def test_found(self): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1.document import DocumentSnapshot + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + + ref_string = self._dummy_ref_string() + document_pb = document_pb2.Document( + name=ref_string, + fields={ + "foo": document_pb2.Value(double_value=1.5), + "bar": document_pb2.Value(string_value=u"skillz"), + }, + create_time=create_time, + update_time=update_time, + ) + response_pb = _make_batch_response(found=document_pb, read_time=read_time) + + reference_map = {ref_string: mock.sentinel.reference} + snapshot = self._call_fut(response_pb, reference_map) + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertIs(snapshot._reference, mock.sentinel.reference) + self.assertEqual(snapshot._data, {"foo": 1.5, "bar": u"skillz"}) + self.assertTrue(snapshot._exists) + self.assertEqual(snapshot.read_time, read_time) + self.assertEqual(snapshot.create_time, create_time) + self.assertEqual(snapshot.update_time, update_time) + + def test_missing(self): + from google.cloud.firestore_v1.document import DocumentReference + + ref_string = self._dummy_ref_string() + response_pb = _make_batch_response(missing=ref_string) + document = DocumentReference("fizz", "bazz", client=mock.sentinel.client) + reference_map = {ref_string: document} + snapshot = self._call_fut(response_pb, reference_map) + self.assertFalse(snapshot.exists) + self.assertEqual(snapshot.id, "bazz") + self.assertIsNone(snapshot._data) + + def test_unset_result_type(self): + response_pb = _make_batch_response() + with self.assertRaises(ValueError): + self._call_fut(response_pb, {}) + + def test_unknown_result_type(self): + response_pb = mock.Mock(spec=["WhichOneof"]) + response_pb.WhichOneof.return_value = "zoob_value" + + with self.assertRaises(ValueError): + self._call_fut(response_pb, {}) + + response_pb.WhichOneof.assert_called_once_with("result") + + +class Test__get_doc_mask(unittest.TestCase): + @staticmethod + def _call_fut(field_paths): + from google.cloud.firestore_v1.async_client import _get_doc_mask + + return _get_doc_mask(field_paths) + + def test_none(self): + self.assertIsNone(self._call_fut(None)) + + def test_paths(self): + from google.cloud.firestore_v1.proto import common_pb2 + + field_paths = ["a.b", "c"] + result = self._call_fut(field_paths) + expected = common_pb2.DocumentMask(field_paths=field_paths) + self.assertEqual(result, expected) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_batch_response(**kwargs): + from google.cloud.firestore_v1.proto import firestore_pb2 + + return firestore_pb2.BatchGetDocumentsResponse(**kwargs) + + +def _doc_get_info(ref_string, values): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + + document_pb = document_pb2.Document( + name=ref_string, + fields=_helpers.encode_dict(values), + create_time=create_time, + update_time=update_time, + ) + + return document_pb, read_time From d986ed893c34908479959ac6f0fe3235255e8479 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 15 Jun 2020 22:16:50 -0500 Subject: [PATCH 02/47] feat: add AsyncClient implementation --- google/cloud/firestore_v1/async_client.py | 619 ++++++++++++++++++++++ 1 file changed, 619 insertions(+) create mode 100644 google/cloud/firestore_v1/async_client.py diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py new file mode 100644 index 0000000000..952b412ce5 --- /dev/null +++ b/google/cloud/firestore_v1/async_client.py @@ -0,0 +1,619 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client for interacting with the Google Cloud Firestore API. + +This is the base from which all interactions with the API occur. + +In the hierarchy of API concepts + +* a :class:`~google.cloud.firestore_v1.client.Client` owns a + :class:`~google.cloud.firestore_v1.collection.CollectionReference` +* a :class:`~google.cloud.firestore_v1.client.Client` owns a + :class:`~google.cloud.firestore_v1.document.DocumentReference` +""" +import os + +import google.api_core.client_options +from google.api_core.gapic_v1 import client_info +from google.cloud.client import ClientWithProject + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import __version__ +from google.cloud.firestore_v1 import query +from google.cloud.firestore_v1 import types +from google.cloud.firestore_v1.batch import WriteBatch +from google.cloud.firestore_v1.collection import CollectionReference +from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.document import DocumentSnapshot +from google.cloud.firestore_v1.field_path import render_field_path +from google.cloud.firestore_v1.gapic import firestore_client +from google.cloud.firestore_v1.gapic.transports import firestore_grpc_transport +from google.cloud.firestore_v1.transaction import Transaction + + +DEFAULT_DATABASE = "(default)" +"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" +_BAD_OPTION_ERR = ( + "Exactly one of ``last_update_time`` or ``exists`` " "must be provided." +) +_BAD_DOC_TEMPLATE = ( + "Document {!r} appeared in response but was not present among references" +) +_ACTIVE_TXN = "There is already an active transaction." +_INACTIVE_TXN = "There is no active transaction." +_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) +_FIRESTORE_EMULATOR_HOST = "FIRESTORE_EMULATOR_HOST" + + +class AsyncClient(ClientWithProject): + """Client for interacting with Google Cloud Firestore API. + + .. note:: + + Since the Cloud Firestore API requires the gRPC transport, no + ``_http`` argument is accepted by this class. + + Args: + project (Optional[str]): The project which the client acts on behalf + of. If not passed, falls back to the default inferred + from the environment. + credentials (Optional[~google.auth.credentials.Credentials]): The + OAuth2 Credentials to use for this client. If not passed, falls + back to the default inferred from the environment. + database (Optional[str]): The database name that the client targets. + For now, :attr:`DEFAULT_DATABASE` (the default value) is the + only valid database. + client_info (Optional[google.api_core.gapic_v1.client_info.ClientInfo]): + The client info used to send a user-agent string along with API + requests. If ``None``, then default info will be used. Generally, + you only need to set this if you're developing your own library + or partner tool. + client_options (Union[dict, google.api_core.client_options.ClientOptions]): + Client options used to set user options on the client. API Endpoint + should be set through client_options. + """ + + SCOPE = ( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/datastore", + ) + """The scopes required for authenticating with the Firestore service.""" + + _firestore_api_internal = None + _database_string_internal = None + _rpc_metadata_internal = None + + def __init__( + self, + project=None, + credentials=None, + database=DEFAULT_DATABASE, + client_info=_CLIENT_INFO, + client_options=None, + ): + # NOTE: This API has no use for the _http argument, but sending it + # will have no impact since the _http() @property only lazily + # creates a working HTTP object. + super(AsyncClient, self).__init__( + project=project, credentials=credentials, _http=None + ) + self._client_info = client_info + if client_options: + if type(client_options) == dict: + client_options = google.api_core.client_options.from_dict( + client_options + ) + self._client_options = client_options + + self._database = database + self._emulator_host = os.getenv(_FIRESTORE_EMULATOR_HOST) + + @property + def _firestore_api(self): + """Lazy-loading getter GAPIC Firestore API. + + Returns: + :class:`~google.cloud.gapic.firestore.v1`.firestore_client.FirestoreClient: + >> client.collection('top') + + For a sub-collection: + + .. code-block:: python + + >>> client.collection('mydocs/doc/subcol') + >>> # is the same as + >>> client.collection('mydocs', 'doc', 'subcol') + + Sub-collections can be nested deeper in a similar fashion. + + Args: + collection_path (Tuple[str, ...]): Can either be + + * A single ``/``-delimited path to a collection + * A tuple of collection path segments + + Returns: + :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + A reference to a collection in the Firestore database. + """ + if len(collection_path) == 1: + path = collection_path[0].split(_helpers.DOCUMENT_PATH_DELIMITER) + else: + path = collection_path + + return CollectionReference(*path, client=self) + + def collection_group(self, collection_id): + """ + Creates and returns a new Query that includes all documents in the + database that are contained in a collection or subcollection with the + given collection_id. + + .. code-block:: python + + >>> query = client.collection_group('mygroup') + + @param {string} collectionId Identifies the collections to query over. + Every collection or subcollection with this ID as the last segment of its + path will be included. Cannot contain a slash. + @returns {Query} The created Query. + """ + if "/" in collection_id: + raise ValueError( + "Invalid collection_id " + + collection_id + + ". Collection IDs must not contain '/'." + ) + + collection = self.collection(collection_id) + return query.Query(collection, all_descendants=True) + + def document(self, *document_path): + """Get a reference to a document in a collection. + + For a top-level document: + + .. code-block:: python + + >>> client.document('collek/shun') + >>> # is the same as + >>> client.document('collek', 'shun') + + For a document in a sub-collection: + + .. code-block:: python + + >>> client.document('mydocs/doc/subcol/child') + >>> # is the same as + >>> client.document('mydocs', 'doc', 'subcol', 'child') + + Documents in sub-collections can be nested deeper in a similar fashion. + + Args: + document_path (Tuple[str, ...]): Can either be + + * A single ``/``-delimited path to a document + * A tuple of document path segments + + Returns: + :class:`~google.cloud.firestore_v1.document.DocumentReference`: + A reference to a document in a collection. + """ + if len(document_path) == 1: + path = document_path[0].split(_helpers.DOCUMENT_PATH_DELIMITER) + else: + path = document_path + + # DocumentReference takes a relative path. Strip the database string if present. + base_path = self._database_string + "/documents/" + joined_path = _helpers.DOCUMENT_PATH_DELIMITER.join(path) + if joined_path.startswith(base_path): + joined_path = joined_path[len(base_path) :] + path = joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) + + return DocumentReference(*path, client=self) + + @staticmethod + def field_path(*field_names): + """Create a **field path** from a list of nested field names. + + A **field path** is a ``.``-delimited concatenation of the field + names. It is used to represent a nested field. For example, + in the data + + .. code-block:: python + + data = { + 'aa': { + 'bb': { + 'cc': 10, + }, + }, + } + + the field path ``'aa.bb.cc'`` represents the data stored in + ``data['aa']['bb']['cc']``. + + Args: + field_names (Tuple[str, ...]): The list of field names. + + Returns: + str: The ``.``-delimited field path. + """ + return render_field_path(field_names) + + @staticmethod + def write_option(**kwargs): + """Create a write option for write operations. + + Write operations include :meth:`~google.cloud.DocumentReference.set`, + :meth:`~google.cloud.DocumentReference.update` and + :meth:`~google.cloud.DocumentReference.delete`. + + One of the following keyword arguments must be provided: + + * ``last_update_time`` (:class:`google.protobuf.timestamp_pb2.\ + Timestamp`): A timestamp. When set, the target document must + exist and have been last updated at that time. Protobuf + ``update_time`` timestamps are typically returned from methods + that perform write operations as part of a "write result" + protobuf or directly. + * ``exists`` (:class:`bool`): Indicates if the document being modified + should already exist. + + Providing no argument would make the option have no effect (so + it is not allowed). Providing multiple would be an apparent + contradiction, since ``last_update_time`` assumes that the + document **was** updated (it can't have been updated if it + doesn't exist) and ``exists`` indicate that it is unknown if the + document exists or not. + + Args: + kwargs (Dict[str, Any]): The keyword arguments described above. + + Raises: + TypeError: If anything other than exactly one argument is + provided by the caller. + + Returns: + :class:`~google.cloud.firestore_v1.client.WriteOption`: + The option to be used to configure a write message. + """ + if len(kwargs) != 1: + raise TypeError(_BAD_OPTION_ERR) + + name, value = kwargs.popitem() + if name == "last_update_time": + return _helpers.LastUpdateOption(value) + elif name == "exists": + return _helpers.ExistsOption(value) + else: + extra = "{!r} was provided".format(name) + raise TypeError(_BAD_OPTION_ERR, extra) + + async def get_all(self, references, field_paths=None, transaction=None): + """Retrieve a batch of documents. + + .. note:: + + Documents returned by this method are not guaranteed to be + returned in the same order that they are given in ``references``. + + .. note:: + + If multiple ``references`` refer to the same document, the server + will only return one result. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + references (List[.DocumentReference, ...]): Iterable of document + references to be retrieved. + field_paths (Optional[Iterable[str, ...]]): An iterable of field + paths (``.``-delimited list of field names) to use as a + projection of document fields in the returned results. If + no value is provided, all fields will be returned. + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that these ``references`` will be + retrieved in. + + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + document_paths, reference_map = _reference_info(references) + mask = _get_doc_mask(field_paths) + response_iterator = self._firestore_api.batch_get_documents( + self._database_string, + document_paths, + mask, + transaction=_helpers.get_transaction_id(transaction), + metadata=self._rpc_metadata, + ) + + for get_doc_response in response_iterator: + yield _parse_batch_get(get_doc_response, reference_map, self) + + async def collections(self): + """List top-level collections of the client's database. + + Returns: + Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: + iterator of subcollections of the current document. + """ + iterator = self._firestore_api.list_collection_ids( + "{}/documents".format(self._database_string), metadata=self._rpc_metadata + ) + iterator.client = self + iterator.item_to_value = _item_to_collection_ref + return iterator + + def batch(self): + """Get a batch instance from this client. + + Returns: + :class:`~google.cloud.firestore_v1.batch.WriteBatch`: + A "write" batch to be used for accumulating document changes and + sending the changes all at once. + """ + return WriteBatch(self) + + def transaction(self, **kwargs): + """Get a transaction that uses this client. + + See :class:`~google.cloud.firestore_v1.transaction.Transaction` for + more information on transactions and the constructor arguments. + + Args: + kwargs (Dict[str, Any]): The keyword arguments (other than + ``client``) to pass along to the + :class:`~google.cloud.firestore_v1.transaction.Transaction` + constructor. + + Returns: + :class:`~google.cloud.firestore_v1.transaction.Transaction`: + A transaction attached to this client. + """ + return Transaction(self, **kwargs) + + +def _reference_info(references): + """Get information about document references. + + Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`. + + Args: + references (List[.DocumentReference, ...]): Iterable of document + references. + + Returns: + Tuple[List[str, ...], Dict[str, .DocumentReference]]: A two-tuple of + + * fully-qualified documents paths for each reference in ``references`` + * a mapping from the paths to the original reference. (If multiple + ``references`` contains multiple references to the same document, + that key will be overwritten in the result.) + """ + document_paths = [] + reference_map = {} + for reference in references: + doc_path = reference._document_path + document_paths.append(doc_path) + reference_map[doc_path] = reference + + return document_paths, reference_map + + +def _get_reference(document_path, reference_map): + """Get a document reference from a dictionary. + + This just wraps a simple dictionary look-up with a helpful error that is + specific to :meth:`~google.cloud.firestore.client.Client.get_all`, the + **public** caller of this function. + + Args: + document_path (str): A fully-qualified document path. + reference_map (Dict[str, .DocumentReference]): A mapping (produced + by :func:`_reference_info`) of fully-qualified document paths to + document references. + + Returns: + .DocumentReference: The matching reference. + + Raises: + ValueError: If ``document_path`` has not been encountered. + """ + try: + return reference_map[document_path] + except KeyError: + msg = _BAD_DOC_TEMPLATE.format(document_path) + raise ValueError(msg) + + +def _parse_batch_get(get_doc_response, reference_map, client): + """Parse a `BatchGetDocumentsResponse` protobuf. + + Args: + get_doc_response (~google.cloud.proto.firestore.v1.\ + firestore_pb2.BatchGetDocumentsResponse): A single response (from + a stream) containing the "get" response for a document. + reference_map (Dict[str, .DocumentReference]): A mapping (produced + by :func:`_reference_info`) of fully-qualified document paths to + document references. + client (:class:`~google.cloud.firestore_v1.client.Client`): + A client that has a document factory. + + Returns: + [.DocumentSnapshot]: The retrieved snapshot. + + Raises: + ValueError: If the response has a ``result`` field (a oneof) other + than ``found`` or ``missing``. + """ + result_type = get_doc_response.WhichOneof("result") + if result_type == "found": + reference = _get_reference(get_doc_response.found.name, reference_map) + data = _helpers.decode_dict(get_doc_response.found.fields, client) + snapshot = DocumentSnapshot( + reference, + data, + exists=True, + read_time=get_doc_response.read_time, + create_time=get_doc_response.found.create_time, + update_time=get_doc_response.found.update_time, + ) + elif result_type == "missing": + reference = _get_reference(get_doc_response.missing, reference_map) + snapshot = DocumentSnapshot( + reference, + None, + exists=False, + read_time=get_doc_response.read_time, + create_time=None, + update_time=None, + ) + else: + raise ValueError( + "`BatchGetDocumentsResponse.result` (a oneof) had a field other " + "than `found` or `missing` set, or was unset" + ) + return snapshot + + +def _get_doc_mask(field_paths): + """Get a document mask if field paths are provided. + + Args: + field_paths (Optional[Iterable[str, ...]]): An iterable of field + paths (``.``-delimited list of field names) to use as a + projection of document fields in the returned results. + + Returns: + Optional[google.cloud.firestore_v1.types.DocumentMask]: A mask + to project documents to a restricted set of field paths. + """ + if field_paths is None: + return None + else: + return types.DocumentMask(field_paths=field_paths) + + +def _item_to_collection_ref(iterator, item): + """Convert collection ID to collection ref. + + Args: + iterator (google.api_core.page_iterator.GRPCIterator): + iterator response + item (str): ID of the collection + """ + return iterator.client.collection(item) From 54f0289cc95dafd76aade3b42e772d3f8fa5c4f9 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 15 Jun 2020 22:37:21 -0500 Subject: [PATCH 03/47] feat: add AsyncDocument implementation --- google/cloud/firestore_v1/async_document.py | 788 ++++++++++++++++++++ 1 file changed, 788 insertions(+) create mode 100644 google/cloud/firestore_v1/async_document.py diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py new file mode 100644 index 0000000000..41c26a03f2 --- /dev/null +++ b/google/cloud/firestore_v1/async_document.py @@ -0,0 +1,788 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing documents for the Google Cloud Firestore API.""" + +import copy + +import six + +from google.api_core import exceptions +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import field_path as field_path_module +from google.cloud.firestore_v1.proto import common_pb2 +from google.cloud.firestore_v1.watch import Watch + + +class AsyncDocumentReference(object): + """A reference to a document in a Firestore database. + + The document may already exist or can be created by this class. + + Args: + path (Tuple[str, ...]): The components in the document path. + This is a series of strings representing each collection and + sub-collection ID, as well as the document IDs for any documents + that contain a sub-collection (as well as the base document). + kwargs (dict): The keyword arguments for the constructor. The only + supported keyword is ``client`` and it must be a + :class:`~google.cloud.firestore_v1.client.Client`. It represents + the client that created this document reference. + + Raises: + ValueError: if + + * the ``path`` is empty + * there are an even number of elements + * a collection ID in ``path`` is not a string + * a document ID in ``path`` is not a string + TypeError: If a keyword other than ``client`` is used. + """ + + _document_path_internal = None + + def __init__(self, *path, **kwargs): + _helpers.verify_path(path, is_collection=False) + self._path = path + self._client = kwargs.pop("client", None) + if kwargs: + raise TypeError( + "Received unexpected arguments", kwargs, "Only `client` is supported" + ) + + def __copy__(self): + """Shallow copy the instance. + + We leave the client "as-is" but tuple-unpack the path. + + Returns: + .AsyncDocumentReference: A copy of the current document. + """ + result = self.__class__(*self._path, client=self._client) + result._document_path_internal = self._document_path_internal + return result + + def __deepcopy__(self, unused_memo): + """Deep copy the instance. + + This isn't a true deep copy, wee leave the client "as-is" but + tuple-unpack the path. + + Returns: + .AsyncDocumentReference: A copy of the current document. + """ + return self.__copy__() + + def __eq__(self, other): + """Equality check against another instance. + + Args: + other (Any): A value to compare against. + + Returns: + Union[bool, NotImplementedType]: Indicating if the values are + equal. + """ + if isinstance(other, AsyncDocumentReference): + return self._client == other._client and self._path == other._path + else: + return NotImplemented + + def __hash__(self): + return hash(self._path) + hash(self._client) + + def __ne__(self, other): + """Inequality check against another instance. + + Args: + other (Any): A value to compare against. + + Returns: + Union[bool, NotImplementedType]: Indicating if the values are + not equal. + """ + if isinstance(other, AsyncDocumentReference): + return self._client != other._client or self._path != other._path + else: + return NotImplemented + + @property + def path(self): + """Database-relative for this document. + + Returns: + str: The document's relative path. + """ + return "/".join(self._path) + + @property + def _document_path(self): + """Create and cache the full path for this document. + + Of the form: + + ``projects/{project_id}/databases/{database_id}/... + documents/{document_path}`` + + Returns: + str: The full document path. + + Raises: + ValueError: If the current document reference has no ``client``. + """ + if self._document_path_internal is None: + if self._client is None: + raise ValueError("A document reference requires a `client`.") + self._document_path_internal = _get_document_path(self._client, self._path) + + return self._document_path_internal + + @property + def id(self): + """The document identifier (within its collection). + + Returns: + str: The last component of the path. + """ + return self._path[-1] + + @property + def parent(self): + """Collection that owns the current document. + + Returns: + :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + The parent collection. + """ + parent_path = self._path[:-1] + return self._client.collection(*parent_path) + + def collection(self, collection_id): + """Create a sub-collection underneath the current document. + + Args: + collection_id (str): The sub-collection identifier (sometimes + referred to as the "kind"). + + Returns: + :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + The child collection. + """ + child_path = self._path + (collection_id,) + return self._client.collection(*child_path) + + async def create(self, document_data): + """Create the current document in the Firestore database. + + Args: + document_data (dict): Property names and values to use for + creating a document. + + Returns: + :class:`~google.cloud.firestore_v1.types.WriteResult`: + The write result corresponding to the committed document. + A write result contains an ``update_time`` field. + + Raises: + :class:`~google.cloud.exceptions.Conflict`: + If the document already exists. + """ + batch = self._client.batch() + batch.create(self, document_data) + write_results = batch.commit() + return _first_write_result(write_results) + + async def set(self, document_data, merge=False): + """Replace the current document in the Firestore database. + + A write ``option`` can be specified to indicate preconditions of + the "set" operation. If no ``option`` is specified and this document + doesn't exist yet, this method will create it. + + Overwrites all content for the document with the fields in + ``document_data``. This method performs almost the same functionality + as :meth:`create`. The only difference is that this method doesn't + make any requirements on the existence of the document (unless + ``option`` is used), whereas as :meth:`create` will fail if the + document already exists. + + Args: + document_data (dict): Property names and values to use for + replacing a document. + merge (Optional[bool] or Optional[List]): + If True, apply merging instead of overwriting the state + of the document. + + Returns: + :class:`~google.cloud.firestore_v1.types.WriteResult`: + The write result corresponding to the committed document. A write + result contains an ``update_time`` field. + """ + batch = self._client.batch() + batch.set(self, document_data, merge=merge) + write_results = batch.commit() + return _first_write_result(write_results) + + async def update(self, field_updates, option=None): + """Update an existing document in the Firestore database. + + By default, this method verifies that the document exists on the + server before making updates. A write ``option`` can be specified to + override these preconditions. + + Each key in ``field_updates`` can either be a field name or a + **field path** (For more information on **field paths**, see + :meth:`~google.cloud.firestore_v1.client.Client.field_path`.) To + illustrate this, consider a document with + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + }, + 'other': True, + } + + stored on the server. If the field name is used in the update: + + .. code-block:: python + + >>> field_updates = { + ... 'foo': { + ... 'quux': 800, + ... }, + ... } + >>> document.update(field_updates) + + then all of ``foo`` will be overwritten on the server and the new + value will be + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'quux': 800, + }, + 'other': True, + } + + On the other hand, if a ``.``-delimited **field path** is used in the + update: + + .. code-block:: python + + >>> field_updates = { + ... 'foo.quux': 800, + ... } + >>> document.update(field_updates) + + then only ``foo.quux`` will be updated on the server and the + field ``foo.bar`` will remain intact: + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + 'quux': 800, + }, + 'other': True, + } + + .. warning:: + + A **field path** can only be used as a top-level key in + ``field_updates``. + + To delete / remove a field from an existing document, use the + :attr:`~google.cloud.firestore_v1.transforms.DELETE_FIELD` sentinel. + So with the example above, sending + + .. code-block:: python + + >>> field_updates = { + ... 'other': firestore.DELETE_FIELD, + ... } + >>> document.update(field_updates) + + would update the value on the server to: + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + }, + } + + To set a field to the current time on the server when the + update is received, use the + :attr:`~google.cloud.firestore_v1.transforms.SERVER_TIMESTAMP` + sentinel. + Sending + + .. code-block:: python + + >>> field_updates = { + ... 'foo.now': firestore.SERVER_TIMESTAMP, + ... } + >>> document.update(field_updates) + + would update the value on the server to: + + .. code-block:: python + + >>> snapshot = document.get() + >>> snapshot.to_dict() + { + 'foo': { + 'bar': 'baz', + 'now': datetime.datetime(2012, ...), + }, + 'other': True, + } + + Args: + field_updates (dict): Field names or paths to update and values + to update with. + option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + + Returns: + :class:`~google.cloud.firestore_v1.types.WriteResult`: + The write result corresponding to the updated document. A write + result contains an ``update_time`` field. + + Raises: + ~google.cloud.exceptions.NotFound: If the document does not exist. + """ + batch = self._client.batch() + batch.update(self, field_updates, option=option) + write_results = batch.commit() + return _first_write_result(write_results) + + async def delete(self, option=None): + """Delete the current document in the Firestore database. + + Args: + option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + + Returns: + :class:`google.protobuf.timestamp_pb2.Timestamp`: + The time that the delete request was received by the server. + If the document did not exist when the delete was sent (i.e. + nothing was deleted), this method will still succeed and will + still return the time that the request was received by the server. + """ + write_pb = _helpers.pb_for_delete(self._document_path, option) + commit_response = self._client._firestore_api.commit( + self._client._database_string, + [write_pb], + transaction=None, + metadata=self._client._rpc_metadata, + ) + + return commit_response.commit_time + + async def get(self, field_paths=None, transaction=None): + """Retrieve a snapshot of the current document. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + field_paths (Optional[Iterable[str, ...]]): An iterable of field + paths (``.``-delimited list of field names) to use as a + projection of document fields in the returned results. If + no value is provided, all fields will be returned. + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this reference + will be retrieved in. + + Returns: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + A snapshot of the current document. If the document does not + exist at the time of the snapshot is taken, the snapshot's + :attr:`reference`, :attr:`data`, :attr:`update_time`, and + :attr:`create_time` attributes will all be ``None`` and + its :attr:`exists` attribute will be ``False``. + """ + if isinstance(field_paths, six.string_types): + raise ValueError("'field_paths' must be a sequence of paths, not a string.") + + if field_paths is not None: + mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None + + firestore_api = self._client._firestore_api + try: + document_pb = firestore_api.get_document( + self._document_path, + mask=mask, + transaction=_helpers.get_transaction_id(transaction), + metadata=self._client._rpc_metadata, + ) + except exceptions.NotFound: + data = None + exists = False + create_time = None + update_time = None + else: + data = _helpers.decode_dict(document_pb.fields, self._client) + exists = True + create_time = document_pb.create_time + update_time = document_pb.update_time + + return DocumentSnapshot( + reference=self, + data=data, + exists=exists, + read_time=None, # No server read_time available + create_time=create_time, + update_time=update_time, + ) + + async def collections(self, page_size=None): + """List subcollections of the current document. + + Args: + page_size (Optional[int]]): The maximum number of collections + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + + Returns: + Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: + iterator of subcollections of the current document. If the + document does not exist at the time of `snapshot`, the + iterator will be empty + """ + iterator = self._client._firestore_api.list_collection_ids( + self._document_path, + page_size=page_size, + metadata=self._client._rpc_metadata, + ) + iterator.document = self + iterator.item_to_value = _item_to_collection_ref + return iterator + + def on_snapshot(self, callback): + """Watch this document. + + This starts a watch on this document using a background thread. The + provided callback is run on the snapshot. + + Args: + callback(Callable[[:class:`~google.cloud.firestore.document.DocumentSnapshot`], NoneType]): + a callback to run when a change occurs + + Example: + + .. code-block:: python + + from google.cloud import firestore_v1 + + db = firestore_v1.Client() + collection_ref = db.collection(u'users') + + def on_snapshot(document_snapshot, changes, read_time): + doc = document_snapshot + print(u'{} => {}'.format(doc.id, doc.to_dict())) + + doc_ref = db.collection(u'users').document( + u'alovelace' + unique_resource_id()) + + # Watch this document + doc_watch = doc_ref.on_snapshot(on_snapshot) + + # Terminate this watch + doc_watch.unsubscribe() + """ + return Watch.for_document( + self, callback, DocumentSnapshot, AsyncDocumentReference + ) + + +class DocumentSnapshot(object): + """A snapshot of document data in a Firestore database. + + This represents data retrieved at a specific time and may not contain + all fields stored for the document (i.e. a hand-picked selection of + fields may have been retrieved). + + Instances of this class are not intended to be constructed by hand, + rather they'll be returned as responses to various methods, such as + :meth:`~google.cloud.AsyncDocumentReference.get`. + + Args: + reference (:class:`~google.cloud.firestore_v1.document.AsyncDocumentReference`): + A document reference corresponding to the document that contains + the data in this snapshot. + data (Dict[str, Any]): + The data retrieved in the snapshot. + exists (bool): + Indicates if the document existed at the time the snapshot was + retrieved. + read_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): + The time that this snapshot was read from the server. + create_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): + The time that this document was created. + update_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): + The time that this document was last updated. + """ + + def __init__(self, reference, data, exists, read_time, create_time, update_time): + self._reference = reference + # We want immutable data, so callers can't modify this value + # out from under us. + self._data = copy.deepcopy(data) + self._exists = exists + self.read_time = read_time + self.create_time = create_time + self.update_time = update_time + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self._reference == other._reference and self._data == other._data + + def __hash__(self): + seconds = self.update_time.seconds + nanos = self.update_time.nanos + return hash(self._reference) + hash(seconds) + hash(nanos) + + @property + def _client(self): + """The client that owns the document reference for this snapshot. + + Returns: + :class:`~google.cloud.firestore_v1.client.Client`: + The client that owns this document. + """ + return self._reference._client + + @property + def exists(self): + """Existence flag. + + Indicates if the document existed at the time this snapshot + was retrieved. + + Returns: + bool: The existence flag. + """ + return self._exists + + @property + def id(self): + """The document identifier (within its collection). + + Returns: + str: The last component of the path of the document. + """ + return self._reference.id + + @property + def reference(self): + """Document reference corresponding to document that owns this data. + + Returns: + :class:`~google.cloud.firestore_v1.document.AsyncDocumentReference`: + A document reference corresponding to this document. + """ + return self._reference + + def get(self, field_path): + """Get a value from the snapshot data. + + If the data is nested, for example: + + .. code-block:: python + + >>> snapshot.to_dict() + { + 'top1': { + 'middle2': { + 'bottom3': 20, + 'bottom4': 22, + }, + 'middle5': True, + }, + 'top6': b'\x00\x01 foo', + } + + a **field path** can be used to access the nested data. For + example: + + .. code-block:: python + + >>> snapshot.get('top1') + { + 'middle2': { + 'bottom3': 20, + 'bottom4': 22, + }, + 'middle5': True, + } + >>> snapshot.get('top1.middle2') + { + 'bottom3': 20, + 'bottom4': 22, + } + >>> snapshot.get('top1.middle2.bottom3') + 20 + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + A copy is returned since the data may contain mutable values, + but the data stored in the snapshot must remain immutable. + + Args: + field_path (str): A field path (``.``-delimited list of + field names). + + Returns: + Any or None: + (A copy of) the value stored for the ``field_path`` or + None if snapshot document does not exist. + + Raises: + KeyError: If the ``field_path`` does not match nested data + in the snapshot. + """ + if not self._exists: + return None + nested_data = field_path_module.get_nested_value(field_path, self._data) + return copy.deepcopy(nested_data) + + def to_dict(self): + """Retrieve the data contained in this snapshot. + + A copy is returned since the data may contain mutable values, + but the data stored in the snapshot must remain immutable. + + Returns: + Dict[str, Any] or None: + The data in the snapshot. Returns None if reference + does not exist. + """ + if not self._exists: + return None + return copy.deepcopy(self._data) + + +def _get_document_path(client, path): + """Convert a path tuple into a full path string. + + Of the form: + + ``projects/{project_id}/databases/{database_id}/... + documents/{document_path}`` + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + The client that holds configuration details and a GAPIC client + object. + path (Tuple[str, ...]): The components in a document path. + + Returns: + str: The fully-qualified document path. + """ + parts = (client._database_string, "documents") + path + return _helpers.DOCUMENT_PATH_DELIMITER.join(parts) + + +def _consume_single_get(response_iterator): + """Consume a gRPC stream that should contain a single response. + + The stream will correspond to a ``BatchGetDocuments`` request made + for a single document. + + Args: + response_iterator (~google.cloud.exceptions.GrpcRendezvous): A + streaming iterator returned from a ``BatchGetDocuments`` + request. + + Returns: + ~google.cloud.proto.firestore.v1.\ + firestore_pb2.BatchGetDocumentsResponse: The single "get" + response in the batch. + + Raises: + ValueError: If anything other than exactly one response is returned. + """ + # Calling ``list()`` consumes the entire iterator. + all_responses = list(response_iterator) + if len(all_responses) != 1: + raise ValueError( + "Unexpected response from `BatchGetDocumentsResponse`", + all_responses, + "Expected only one result", + ) + + return all_responses[0] + + +def _first_write_result(write_results): + """Get first write result from list. + + For cases where ``len(write_results) > 1``, this assumes the writes + occurred at the same time (e.g. if an update and transform are sent + at the same time). + + Args: + write_results (List[google.cloud.proto.firestore.v1.\ + write_pb2.WriteResult, ...]: The write results from a + ``CommitResponse``. + + Returns: + google.cloud.firestore_v1.types.WriteResult: The + lone write result from ``write_results``. + + Raises: + ValueError: If there are zero write results. This is likely to + **never** occur, since the backend should be stable. + """ + if not write_results: + raise ValueError("Expected at least one write result") + + return write_results[0] + + +def _item_to_collection_ref(iterator, item): + """Convert collection ID to collection ref. + + Args: + iterator (google.api_core.page_iterator.GRPCIterator): + iterator response + item (str): ID of the collection + """ + return iterator.document.collection(item) From 2879c7082ae954f859c5621bb255308fa2af9f02 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 15 Jun 2020 22:37:37 -0500 Subject: [PATCH 04/47] feat: add AsyncDocument support to AsyncClient --- google/cloud/firestore_v1/async_client.py | 30 +++++++++++------------ tests/unit/v1/async/test_async_client.py | 28 ++++++++++----------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 952b412ce5..a5de26f827 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -21,7 +21,7 @@ * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.collection.CollectionReference` * a :class:`~google.cloud.firestore_v1.client.Client` owns a - :class:`~google.cloud.firestore_v1.document.DocumentReference` + :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ import os @@ -35,8 +35,8 @@ from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference -from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1.document import DocumentSnapshot +from google.cloud.firestore_v1.async_document import AsyncDocumentReference +from google.cloud.firestore_v1.async_document import DocumentSnapshot from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.gapic import firestore_client from google.cloud.firestore_v1.gapic.transports import firestore_grpc_transport @@ -301,7 +301,7 @@ def document(self, *document_path): * A tuple of document path segments Returns: - :class:`~google.cloud.firestore_v1.document.DocumentReference`: + :class:`~google.cloud.firestore_v1.document.AsyncDocumentReference`: A reference to a document in a collection. """ if len(document_path) == 1: @@ -309,14 +309,14 @@ def document(self, *document_path): else: path = document_path - # DocumentReference takes a relative path. Strip the database string if present. + # AsyncDocumentReference takes a relative path. Strip the database string if present. base_path = self._database_string + "/documents/" joined_path = _helpers.DOCUMENT_PATH_DELIMITER.join(path) if joined_path.startswith(base_path): joined_path = joined_path[len(base_path) :] path = joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) - return DocumentReference(*path, client=self) + return AsyncDocumentReference(*path, client=self) @staticmethod def field_path(*field_names): @@ -351,9 +351,9 @@ def field_path(*field_names): def write_option(**kwargs): """Create a write option for write operations. - Write operations include :meth:`~google.cloud.DocumentReference.set`, - :meth:`~google.cloud.DocumentReference.update` and - :meth:`~google.cloud.DocumentReference.delete`. + Write operations include :meth:`~google.cloud.AsyncDocumentReference.set`, + :meth:`~google.cloud.AsyncDocumentReference.update` and + :meth:`~google.cloud.AsyncDocumentReference.delete`. One of the following keyword arguments must be provided: @@ -417,7 +417,7 @@ async def get_all(self, references, field_paths=None, transaction=None): allowed). Args: - references (List[.DocumentReference, ...]): Iterable of document + references (List[.AsyncDocumentReference, ...]): Iterable of document references to be retrieved. field_paths (Optional[Iterable[str, ...]]): An iterable of field paths (``.``-delimited list of field names) to use as a @@ -493,11 +493,11 @@ def _reference_info(references): Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`. Args: - references (List[.DocumentReference, ...]): Iterable of document + references (List[.AsyncDocumentReference, ...]): Iterable of document references. Returns: - Tuple[List[str, ...], Dict[str, .DocumentReference]]: A two-tuple of + Tuple[List[str, ...], Dict[str, .AsyncDocumentReference]]: A two-tuple of * fully-qualified documents paths for each reference in ``references`` * a mapping from the paths to the original reference. (If multiple @@ -523,12 +523,12 @@ def _get_reference(document_path, reference_map): Args: document_path (str): A fully-qualified document path. - reference_map (Dict[str, .DocumentReference]): A mapping (produced + reference_map (Dict[str, .AsyncDocumentReference]): A mapping (produced by :func:`_reference_info`) of fully-qualified document paths to document references. Returns: - .DocumentReference: The matching reference. + .AsyncDocumentReference: The matching reference. Raises: ValueError: If ``document_path`` has not been encountered. @@ -547,7 +547,7 @@ def _parse_batch_get(get_doc_response, reference_map, client): get_doc_response (~google.cloud.proto.firestore.v1.\ firestore_pb2.BatchGetDocumentsResponse): A single response (from a stream) containing the "get" response for a document. - reference_map (Dict[str, .DocumentReference]): A mapping (produced + reference_map (Dict[str, .AsyncDocumentReference]): A mapping (produced by :func:`_reference_info`) of fully-qualified document paths to document references. client (:class:`~google.cloud.firestore_v1.client.Client`): diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index 476ab501c0..4c36f8d278 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -229,7 +229,7 @@ def test_collection_group_no_slashes(self): client.collection_group("foo/bar") def test_document_factory(self): - from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.async_document import AsyncDocumentReference parts = ("rooms", "roomA") client = self._make_default_one() @@ -238,16 +238,16 @@ def test_document_factory(self): self.assertEqual(document1._path, parts) self.assertIs(document1._client, client) - self.assertIsInstance(document1, DocumentReference) + self.assertIsInstance(document1, AsyncDocumentReference) # Make sure using segments gives the same result. document2 = client.document(*parts) self.assertEqual(document2._path, parts) self.assertIs(document2._client, client) - self.assertIsInstance(document2, DocumentReference) + self.assertIsInstance(document2, AsyncDocumentReference) def test_document_factory_w_absolute_path(self): - from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.async_document import AsyncDocumentReference parts = ("rooms", "roomA") client = self._make_default_one() @@ -257,10 +257,10 @@ def test_document_factory_w_absolute_path(self): self.assertEqual(document1._path, parts) self.assertIs(document1._client, client) - self.assertIsInstance(document1, DocumentReference) + self.assertIsInstance(document1, AsyncDocumentReference) def test_document_factory_w_nested_path(self): - from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.async_document import AsyncDocumentReference client = self._make_default_one() parts = ("rooms", "roomA", "shoes", "dressy") @@ -269,13 +269,13 @@ def test_document_factory_w_nested_path(self): self.assertEqual(document1._path, parts) self.assertIs(document1._client, client) - self.assertIsInstance(document1, DocumentReference) + self.assertIsInstance(document1, AsyncDocumentReference) # Make sure using segments gives the same result. document2 = client.document(*parts) self.assertEqual(document2._path, parts) self.assertIs(document2._client, client) - self.assertIsInstance(document2, DocumentReference) + self.assertIsInstance(document2, AsyncDocumentReference) def test_field_path(self): klass = self._get_target_class() @@ -400,7 +400,7 @@ def _info_for_get_all(self, data1, data2): def test_get_all(self): from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.async_document import DocumentSnapshot data1 = {"a": u"cheese"} data2 = {"b": True, "c": 18} @@ -441,7 +441,7 @@ def test_get_all(self): ) def test_get_all_with_transaction(self): - from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.async_document import DocumentSnapshot data = {"so-much": 484} info = self._info_for_get_all(data, {}) @@ -497,7 +497,7 @@ def test_get_all_unknown_result(self): ) def test_get_all_wrong_order(self): - from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.async_document import DocumentSnapshot data1 = {"up": 10} data2 = {"down": -10} @@ -645,7 +645,7 @@ def _dummy_ref_string(): def test_found(self): from google.cloud.firestore_v1.proto import document_pb2 from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.async_document import DocumentSnapshot now = datetime.datetime.utcnow() read_time = _datetime_to_pb_timestamp(now) @@ -676,11 +676,11 @@ def test_found(self): self.assertEqual(snapshot.update_time, update_time) def test_missing(self): - from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.async_document import AsyncDocumentReference ref_string = self._dummy_ref_string() response_pb = _make_batch_response(missing=ref_string) - document = DocumentReference("fizz", "bazz", client=mock.sentinel.client) + document = AsyncDocumentReference("fizz", "bazz", client=mock.sentinel.client) reference_map = {ref_string: document} snapshot = self._call_fut(response_pb, reference_map) self.assertFalse(snapshot.exists) From f9935f7c30f223b3c0bb62cdbf142daca1341414 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 15 Jun 2020 22:59:46 -0500 Subject: [PATCH 05/47] feat: add AsyncDocument tests Note: tests relying on Collection will fail in this commit --- tests/unit/v1/async/test_async_document.py | 826 +++++++++++++++++++++ 1 file changed, 826 insertions(+) create mode 100644 tests/unit/v1/async/test_async_document.py diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py new file mode 100644 index 0000000000..d9bdea96aa --- /dev/null +++ b/tests/unit/v1/async/test_async_document.py @@ -0,0 +1,826 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import collections +import unittest + +import mock + + +class TestAsyncDocumentReference(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + return AsyncDocumentReference + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + collection_id1 = "users" + document_id1 = "alovelace" + collection_id2 = "platform" + document_id2 = "*nix" + client = mock.MagicMock() + client.__hash__.return_value = 1234 + + document = self._make_one( + collection_id1, document_id1, collection_id2, document_id2, client=client + ) + self.assertIs(document._client, client) + expected_path = "/".join( + (collection_id1, document_id1, collection_id2, document_id2) + ) + self.assertEqual(document.path, expected_path) + + def test_constructor_invalid_path(self): + with self.assertRaises(ValueError): + self._make_one() + with self.assertRaises(ValueError): + self._make_one(None, "before", "bad-collection-id", "fifteen") + with self.assertRaises(ValueError): + self._make_one("bad-document-ID", None) + with self.assertRaises(ValueError): + self._make_one("Just", "A-Collection", "Sub") + + def test_constructor_invalid_kwarg(self): + with self.assertRaises(TypeError): + self._make_one("Coh-lek-shun", "Dahk-yu-mehnt", burger=18.75) + + def test___copy__(self): + client = _make_client("rain") + document = self._make_one("a", "b", client=client) + # Access the document path so it is copied. + doc_path = document._document_path + self.assertEqual(doc_path, document._document_path_internal) + + new_document = document.__copy__() + self.assertIsNot(new_document, document) + self.assertIs(new_document._client, document._client) + self.assertEqual(new_document._path, document._path) + self.assertEqual( + new_document._document_path_internal, document._document_path_internal + ) + + def test___deepcopy__calls_copy(self): + client = mock.sentinel.client + document = self._make_one("a", "b", client=client) + document.__copy__ = mock.Mock(return_value=mock.sentinel.new_doc, spec=[]) + + unused_memo = {} + new_document = document.__deepcopy__(unused_memo) + self.assertIs(new_document, mock.sentinel.new_doc) + document.__copy__.assert_called_once_with() + + def test__eq__same_type(self): + document1 = self._make_one("X", "YY", client=mock.sentinel.client) + document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) + document3 = self._make_one("X", "YY", client=mock.sentinel.client2) + document4 = self._make_one("X", "YY", client=mock.sentinel.client) + + pairs = ((document1, document2), (document1, document3), (document2, document3)) + for candidate1, candidate2 in pairs: + # We use == explicitly since assertNotEqual would use !=. + equality_val = candidate1 == candidate2 + self.assertFalse(equality_val) + + # Check the only equal one. + self.assertEqual(document1, document4) + self.assertIsNot(document1, document4) + + def test__eq__other_type(self): + document = self._make_one("X", "YY", client=mock.sentinel.client) + other = object() + equality_val = document == other + self.assertFalse(equality_val) + self.assertIs(document.__eq__(other), NotImplemented) + + def test___hash__(self): + client = mock.MagicMock() + client.__hash__.return_value = 234566789 + document = self._make_one("X", "YY", client=client) + self.assertEqual(hash(document), hash(("X", "YY")) + hash(client)) + + def test__ne__same_type(self): + document1 = self._make_one("X", "YY", client=mock.sentinel.client) + document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) + document3 = self._make_one("X", "YY", client=mock.sentinel.client2) + document4 = self._make_one("X", "YY", client=mock.sentinel.client) + + self.assertNotEqual(document1, document2) + self.assertNotEqual(document1, document3) + self.assertNotEqual(document2, document3) + + # We use != explicitly since assertEqual would use ==. + inequality_val = document1 != document4 + self.assertFalse(inequality_val) + self.assertIsNot(document1, document4) + + def test__ne__other_type(self): + document = self._make_one("X", "YY", client=mock.sentinel.client) + other = object() + self.assertNotEqual(document, other) + self.assertIs(document.__ne__(other), NotImplemented) + + def test__document_path_property(self): + project = "hi-its-me-ok-bye" + client = _make_client(project=project) + + collection_id = "then" + document_id = "090909iii" + document = self._make_one(collection_id, document_id, client=client) + doc_path = document._document_path + expected = "projects/{}/databases/{}/documents/{}/{}".format( + project, client._database, collection_id, document_id + ) + self.assertEqual(doc_path, expected) + self.assertIs(document._document_path_internal, doc_path) + + # Make sure value is cached. + document._document_path_internal = mock.sentinel.cached + self.assertIs(document._document_path, mock.sentinel.cached) + + def test__document_path_property_no_client(self): + document = self._make_one("hi", "bye") + self.assertIsNone(document._client) + with self.assertRaises(ValueError): + getattr(document, "_document_path") + + self.assertIsNone(document._document_path_internal) + + def test_id_property(self): + document_id = "867-5309" + document = self._make_one("Co-lek-shun", document_id) + self.assertEqual(document.id, document_id) + + def test_parent_property(self): + from google.cloud.firestore_v1.collection import CollectionReference + + collection_id = "grocery-store" + document_id = "market" + client = _make_client() + document = self._make_one(collection_id, document_id, client=client) + + parent = document.parent + self.assertIsInstance(parent, CollectionReference) + self.assertIs(parent._client, client) + self.assertEqual(parent._path, (collection_id,)) + + def test_collection_factory(self): + from google.cloud.firestore_v1.collection import CollectionReference + + collection_id = "grocery-store" + document_id = "market" + new_collection = "fruits" + client = _make_client() + document = self._make_one(collection_id, document_id, client=client) + + child = document.collection(new_collection) + self.assertIsInstance(child, CollectionReference) + self.assertIs(child._client, client) + self.assertEqual(child._path, (collection_id, document_id, new_collection)) + + @staticmethod + def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1 import _helpers + + return write_pb2.Write( + update=document_pb2.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common_pb2.Precondition(exists=False), + ) + + @staticmethod + def _make_commit_repsonse(write_results=None): + from google.cloud.firestore_v1.proto import firestore_pb2 + + response = mock.create_autospec(firestore_pb2.CommitResponse) + response.write_results = write_results or [mock.sentinel.write_result] + response.commit_time = mock.sentinel.commit_time + return response + + def test_create(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("dignity") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("foo", "twelve", client=client) + document_data = {"hello": "goodbye", "count": 99} + write_result = asyncio.run(document.create(document_data)) + + # Verify the response and the mocks. + self.assertIs(write_result, mock.sentinel.write_result) + write_pb = self._write_pb_for_create(document._document_path, document_data) + firestore_api.commit.assert_called_once_with( + client._database_string, + [write_pb], + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_create_empty(self): + # Create a minimal fake GAPIC with a dummy response. + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + firestore_api = mock.Mock(spec=["commit"]) + document_reference = mock.create_autospec(AsyncDocumentReference) + snapshot = mock.create_autospec(DocumentSnapshot) + snapshot.exists = True + document_reference.get.return_value = snapshot + firestore_api.commit.return_value = self._make_commit_repsonse( + write_results=[document_reference] + ) + + # Attach the fake GAPIC to a real client. + client = _make_client("dignity") + client._firestore_api_internal = firestore_api + client.get_all = mock.MagicMock() + client.get_all.exists.return_value = True + + # Actually make a document and call create(). + document = self._make_one("foo", "twelve", client=client) + document_data = {} + write_result = asyncio.run(document.create(document_data)) + self.assertTrue(asyncio.run(write_result.get()).exists) + + @staticmethod + def _write_pb_for_set(document_path, document_data, merge): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1 import _helpers + + write_pbs = write_pb2.Write( + update=document_pb2.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ) + ) + if merge: + field_paths = [ + field_path + for field_path, value in _helpers.extract_fields( + document_data, _helpers.FieldPath() + ) + ] + field_paths = [ + field_path.to_api_repr() for field_path in sorted(field_paths) + ] + mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) + write_pbs.update_mask.CopyFrom(mask) + return write_pbs + + def _set_helper(self, merge=False, **option_kwargs): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("db-dee-bee") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("User", "Interface", client=client) + document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} + write_result = asyncio.run(document.set(document_data, merge)) + + # Verify the response and the mocks. + self.assertIs(write_result, mock.sentinel.write_result) + write_pb = self._write_pb_for_set(document._document_path, document_data, merge) + + firestore_api.commit.assert_called_once_with( + client._database_string, + [write_pb], + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_set(self): + self._set_helper() + + def test_set_merge(self): + self._set_helper(merge=True) + + @staticmethod + def _write_pb_for_update(document_path, update_values, field_paths): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1 import _helpers + + return write_pb2.Write( + update=document_pb2.Document( + name=document_path, fields=_helpers.encode_dict(update_values) + ), + update_mask=common_pb2.DocumentMask(field_paths=field_paths), + current_document=common_pb2.Precondition(exists=True), + ) + + def _update_helper(self, **option_kwargs): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("potato-chip") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("baked", "Alaska", client=client) + # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. + field_updates = collections.OrderedDict( + (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) + ) + if option_kwargs: + option = client.write_option(**option_kwargs) + write_result = asyncio.run(document.update(field_updates, option=option)) + else: + option = None + write_result = asyncio.run(document.update(field_updates)) + + # Verify the response and the mocks. + self.assertIs(write_result, mock.sentinel.write_result) + update_values = { + "hello": field_updates["hello"], + "then": {"do": field_updates["then.do"]}, + } + field_paths = list(field_updates.keys()) + write_pb = self._write_pb_for_update( + document._document_path, update_values, sorted(field_paths) + ) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + client._database_string, + [write_pb], + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_update_with_exists(self): + with self.assertRaises(ValueError): + self._update_helper(exists=True) + + def test_update(self): + self._update_helper() + + def test_update_with_precondition(self): + from google.protobuf import timestamp_pb2 + + timestamp = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + self._update_helper(last_update_time=timestamp) + + def test_empty_update(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("potato-chip") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = self._make_one("baked", "Alaska", client=client) + # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. + field_updates = {} + with self.assertRaises(ValueError): + asyncio.run(document.update(field_updates)) + + def _delete_helper(self, **option_kwargs): + from google.cloud.firestore_v1.proto import write_pb2 + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = self._make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("donut-base") + client._firestore_api_internal = firestore_api + + # Actually make a document and call delete(). + document = self._make_one("where", "we-are", client=client) + if option_kwargs: + option = client.write_option(**option_kwargs) + delete_time = asyncio.run(document.delete(option=option)) + else: + option = None + delete_time = asyncio.run(document.delete()) + + # Verify the response and the mocks. + self.assertIs(delete_time, mock.sentinel.commit_time) + write_pb = write_pb2.Write(delete=document._document_path) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + client._database_string, + [write_pb], + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_delete(self): + self._delete_helper() + + def test_delete_with_option(self): + from google.protobuf import timestamp_pb2 + + timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + self._delete_helper(last_update_time=timestamp_pb) + + def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): + from google.api_core.exceptions import NotFound + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.transaction import Transaction + + # Create a minimal fake GAPIC with a dummy response. + create_time = 123 + update_time = 234 + firestore_api = mock.Mock(spec=["get_document"]) + response = mock.create_autospec(document_pb2.Document) + response.fields = {} + response.create_time = create_time + response.update_time = update_time + + if not_found: + firestore_api.get_document.side_effect = NotFound("testing") + else: + firestore_api.get_document.return_value = response + + client = _make_client("donut-base") + client._firestore_api_internal = firestore_api + + document = self._make_one("where", "we-are", client=client) + + if use_transaction: + transaction = Transaction(client) + transaction_id = transaction._id = b"asking-me-2" + else: + transaction = None + + snapshot = asyncio.run(document.get(field_paths=field_paths, transaction=transaction)) + + self.assertIs(snapshot.reference, document) + if not_found: + self.assertIsNone(snapshot._data) + self.assertFalse(snapshot.exists) + self.assertIsNone(snapshot.read_time) + self.assertIsNone(snapshot.create_time) + self.assertIsNone(snapshot.update_time) + else: + self.assertEqual(snapshot.to_dict(), {}) + self.assertTrue(snapshot.exists) + self.assertIsNone(snapshot.read_time) + self.assertIs(snapshot.create_time, create_time) + self.assertIs(snapshot.update_time, update_time) + + # Verify the request made to the API + if field_paths is not None: + mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None + + if use_transaction: + expected_transaction_id = transaction_id + else: + expected_transaction_id = None + + firestore_api.get_document.assert_called_once_with( + document._document_path, + mask=mask, + transaction=expected_transaction_id, + metadata=client._rpc_metadata, + ) + + def test_get_not_found(self): + self._get_helper(not_found=True) + + def test_get_default(self): + self._get_helper() + + def test_get_w_string_field_path(self): + with self.assertRaises(ValueError): + self._get_helper(field_paths="foo") + + def test_get_with_field_path(self): + self._get_helper(field_paths=["foo"]) + + def test_get_with_multiple_field_paths(self): + self._get_helper(field_paths=["foo", "bar.baz"]) + + def test_get_with_transaction(self): + self._get_helper(use_transaction=True) + + def _collections_helper(self, page_size=None): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.gapic.firestore_client import FirestoreClient + + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + collection_ids = ["coll-1", "coll-2"] + iterator = _Iterator(pages=[collection_ids]) + api_client = mock.create_autospec(FirestoreClient) + api_client.list_collection_ids.return_value = iterator + + client = _make_client() + client._firestore_api_internal = api_client + + # Actually make a document and call delete(). + document = self._make_one("where", "we-are", client=client) + if page_size is not None: + collections = list(asyncio.run(document.collections(page_size=page_size))) + else: + collections = list(asyncio.run(document.collections())) + + # Verify the response and the mocks. + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, CollectionReference) + self.assertEqual(collection.parent, document) + self.assertEqual(collection.id, collection_id) + + api_client.list_collection_ids.assert_called_once_with( + document._document_path, page_size=page_size, metadata=client._rpc_metadata + ) + + def test_collections_wo_page_size(self): + self._collections_helper() + + def test_collections_w_page_size(self): + self._collections_helper(page_size=10) + + @mock.patch("google.cloud.firestore_v1.async_document.Watch", autospec=True) + def test_on_snapshot(self, watch): + client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) + document = self._make_one("yellow", "mellow", client=client) + document.on_snapshot(None) + watch.for_document.assert_called_once() + + +class TestDocumentSnapshot(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + return DocumentSnapshot + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def _make_reference(self, *args, **kwargs): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + return AsyncDocumentReference(*args, **kwargs) + + def _make_w_ref(self, ref_path=("a", "b"), data={}, exists=True): + client = mock.sentinel.client + reference = self._make_reference(*ref_path, client=client) + return self._make_one( + reference, + data, + exists, + mock.sentinel.read_time, + mock.sentinel.create_time, + mock.sentinel.update_time, + ) + + def test_constructor(self): + client = mock.sentinel.client + reference = self._make_reference("hi", "bye", client=client) + data = {"zoop": 83} + snapshot = self._make_one( + reference, + data, + True, + mock.sentinel.read_time, + mock.sentinel.create_time, + mock.sentinel.update_time, + ) + self.assertIs(snapshot._reference, reference) + self.assertEqual(snapshot._data, data) + self.assertIsNot(snapshot._data, data) # Make sure copied. + self.assertTrue(snapshot._exists) + self.assertIs(snapshot.read_time, mock.sentinel.read_time) + self.assertIs(snapshot.create_time, mock.sentinel.create_time) + self.assertIs(snapshot.update_time, mock.sentinel.update_time) + + def test___eq___other_type(self): + snapshot = self._make_w_ref() + other = object() + self.assertFalse(snapshot == other) + + def test___eq___different_reference_same_data(self): + snapshot = self._make_w_ref(("a", "b")) + other = self._make_w_ref(("c", "d")) + self.assertFalse(snapshot == other) + + def test___eq___same_reference_different_data(self): + snapshot = self._make_w_ref(("a", "b")) + other = self._make_w_ref(("a", "b"), {"foo": "bar"}) + self.assertFalse(snapshot == other) + + def test___eq___same_reference_same_data(self): + snapshot = self._make_w_ref(("a", "b"), {"foo": "bar"}) + other = self._make_w_ref(("a", "b"), {"foo": "bar"}) + self.assertTrue(snapshot == other) + + def test___hash__(self): + from google.protobuf import timestamp_pb2 + + client = mock.MagicMock() + client.__hash__.return_value = 234566789 + reference = self._make_reference("hi", "bye", client=client) + data = {"zoop": 83} + update_time = timestamp_pb2.Timestamp(seconds=123456, nanos=123456789) + snapshot = self._make_one( + reference, data, True, None, mock.sentinel.create_time, update_time + ) + self.assertEqual( + hash(snapshot), hash(reference) + hash(123456) + hash(123456789) + ) + + def test__client_property(self): + reference = self._make_reference( + "ok", "fine", "now", "fore", client=mock.sentinel.client + ) + snapshot = self._make_one(reference, {}, False, None, None, None) + self.assertIs(snapshot._client, mock.sentinel.client) + + def test_exists_property(self): + reference = mock.sentinel.reference + + snapshot1 = self._make_one(reference, {}, False, None, None, None) + self.assertFalse(snapshot1.exists) + snapshot2 = self._make_one(reference, {}, True, None, None, None) + self.assertTrue(snapshot2.exists) + + def test_id_property(self): + document_id = "around" + reference = self._make_reference( + "look", document_id, client=mock.sentinel.client + ) + snapshot = self._make_one(reference, {}, True, None, None, None) + self.assertEqual(snapshot.id, document_id) + self.assertEqual(reference.id, document_id) + + def test_reference_property(self): + snapshot = self._make_one(mock.sentinel.reference, {}, True, None, None, None) + self.assertIs(snapshot.reference, mock.sentinel.reference) + + def test_get(self): + data = {"one": {"bold": "move"}} + snapshot = self._make_one(None, data, True, None, None, None) + + first_read = snapshot.get("one") + second_read = snapshot.get("one") + self.assertEqual(first_read, data.get("one")) + self.assertIsNot(first_read, data.get("one")) + self.assertEqual(first_read, second_read) + self.assertIsNot(first_read, second_read) + + with self.assertRaises(KeyError): + snapshot.get("two") + + def test_nonexistent_snapshot(self): + snapshot = self._make_one(None, None, False, None, None, None) + self.assertIsNone(snapshot.get("one")) + + def test_to_dict(self): + data = {"a": 10, "b": ["definitely", "mutable"], "c": {"45": 50}} + snapshot = self._make_one(None, data, True, None, None, None) + as_dict = snapshot.to_dict() + self.assertEqual(as_dict, data) + self.assertIsNot(as_dict, data) + # Check that the data remains unchanged. + as_dict["b"].append("hi") + self.assertEqual(data, snapshot.to_dict()) + self.assertNotEqual(data, as_dict) + + def test_non_existent(self): + snapshot = self._make_one(None, None, False, None, None, None) + as_dict = snapshot.to_dict() + self.assertIsNone(as_dict) + + +class Test__get_document_path(unittest.TestCase): + @staticmethod + def _call_fut(client, path): + from google.cloud.firestore_v1.async_document import _get_document_path + + return _get_document_path(client, path) + + def test_it(self): + project = "prah-jekt" + client = _make_client(project=project) + path = ("Some", "Document", "Child", "Shockument") + document_path = self._call_fut(client, path) + + expected = "projects/{}/databases/{}/documents/{}".format( + project, client._database, "/".join(path) + ) + self.assertEqual(document_path, expected) + + +class Test__consume_single_get(unittest.TestCase): + @staticmethod + def _call_fut(response_iterator): + from google.cloud.firestore_v1.async_document import _consume_single_get + + return _consume_single_get(response_iterator) + + def test_success(self): + response_iterator = iter([mock.sentinel.result]) + result = self._call_fut(response_iterator) + self.assertIs(result, mock.sentinel.result) + + def test_failure_not_enough(self): + response_iterator = iter([]) + with self.assertRaises(ValueError): + self._call_fut(response_iterator) + + def test_failure_too_many(self): + response_iterator = iter([None, None]) + with self.assertRaises(ValueError): + self._call_fut(response_iterator) + + +class Test__first_write_result(unittest.TestCase): + @staticmethod + def _call_fut(write_results): + from google.cloud.firestore_v1.async_document import _first_write_result + + return _first_write_result(write_results) + + def test_success(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + single_result = write_pb2.WriteResult( + update_time=timestamp_pb2.Timestamp(seconds=1368767504, nanos=458000123) + ) + write_results = [single_result] + result = self._call_fut(write_results) + self.assertIs(result, single_result) + + def test_failure_not_enough(self): + write_results = [] + with self.assertRaises(ValueError): + self._call_fut(write_results) + + def test_more_than_one(self): + from google.cloud.firestore_v1.proto import write_pb2 + + result1 = write_pb2.WriteResult() + result2 = write_pb2.WriteResult() + write_results = [result1, result2] + result = self._call_fut(write_results) + self.assertIs(result, result1) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="project-project"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) From 1cc47ad92ec7f82394b67971cf411e3d48928e12 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 16 Jun 2020 18:26:50 -0500 Subject: [PATCH 06/47] feat: add AsyncCollectionReference class --- google/cloud/firestore_v1/async_collection.py | 469 ++++++++++++++++++ 1 file changed, 469 insertions(+) create mode 100644 google/cloud/firestore_v1/async_collection.py diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py new file mode 100644 index 0000000000..a9fbcb805e --- /dev/null +++ b/google/cloud/firestore_v1/async_collection.py @@ -0,0 +1,469 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing collections for the Google Cloud Firestore API.""" +import random +import warnings + +import six + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import query as query_mod +from google.cloud.firestore_v1.watch import Watch +from google.cloud.firestore_v1 import async_document + +_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + + +class AsyncCollectionReference(object): + """A reference to a collection in a Firestore database. + + The collection may already exist or this class can facilitate creation + of documents within the collection. + + Args: + path (Tuple[str, ...]): The components in the collection path. + This is a series of strings representing each collection and + sub-collection ID, as well as the document IDs for any documents + that contain a sub-collection. + kwargs (dict): The keyword arguments for the constructor. The only + supported keyword is ``client`` and it must be a + :class:`~google.cloud.firestore_v1.client.Client` if provided. It + represents the client that created this collection reference. + + Raises: + ValueError: if + + * the ``path`` is empty + * there are an even number of elements + * a collection ID in ``path`` is not a string + * a document ID in ``path`` is not a string + TypeError: If a keyword other than ``client`` is used. + """ + + def __init__(self, *path, **kwargs): + _helpers.verify_path(path, is_collection=True) + self._path = path + self._client = kwargs.pop("client", None) + if kwargs: + raise TypeError( + "Received unexpected arguments", kwargs, "Only `client` is supported" + ) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self._path == other._path and self._client == other._client + + @property + def id(self): + """The collection identifier. + + Returns: + str: The last component of the path. + """ + return self._path[-1] + + @property + def parent(self): + """Document that owns the current collection. + + Returns: + Optional[:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`]: + The parent document, if the current collection is not a + top-level collection. + """ + if len(self._path) == 1: + return None + else: + parent_path = self._path[:-1] + return self._client.document(*parent_path) + + def document(self, document_id=None): + """Create a sub-document underneath the current collection. + + Args: + document_id (Optional[str]): The document identifier + within the current collection. If not provided, will default + to a random 20 character string composed of digits, + uppercase and lowercase and letters. + + Returns: + :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`: + The child document. + """ + if document_id is None: + document_id = _auto_id() + + child_path = self._path + (document_id,) + return self._client.document(*child_path) + + def _parent_info(self): + """Get fully-qualified parent path and prefix for this collection. + + Returns: + Tuple[str, str]: Pair of + + * the fully-qualified (with database and project) path to the + parent of this collection (will either be the database path + or a document path). + * the prefix to a document in this collection. + """ + parent_doc = self.parent + if parent_doc is None: + parent_path = _helpers.DOCUMENT_PATH_DELIMITER.join( + (self._client._database_string, "documents") + ) + else: + parent_path = parent_doc._document_path + + expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) + return parent_path, expected_prefix + + async def add(self, document_data, document_id=None): + """Create a document in the Firestore database with the provided data. + + Args: + document_data (dict): Property names and values to use for + creating the document. + document_id (Optional[str]): The document identifier within the + current collection. If not provided, an ID will be + automatically assigned by the server (the assigned ID will be + a random 20 character string composed of digits, + uppercase and lowercase letters). + + Returns: + Tuple[:class:`google.protobuf.timestamp_pb2.Timestamp`, \ + :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`]: + Pair of + + * The ``update_time`` when the document was created/overwritten. + * A document reference for the created document. + + Raises: + ~google.cloud.exceptions.Conflict: If ``document_id`` is provided + and the document already exists. + """ + if document_id is None: + document_id = _auto_id() + + document_ref = self.document(document_id) + write_result = document_ref.create(document_data) + return write_result.update_time, document_ref + + async def list_documents(self, page_size=None): + """List all subdocuments of the current collection. + + Args: + page_size (Optional[int]]): The maximum number of documents + in each page of results from this request. Non-positive values + are ignored. Defaults to a sensible value set by the API. + + Returns: + Sequence[:class:`~google.cloud.firestore_v1.collection.DocumentReference`]: + iterator of subdocuments of the current collection. If the + collection does not exist at the time of `snapshot`, the + iterator will be empty + """ + parent, _ = self._parent_info() + + iterator = self._client._firestore_api.list_documents( + parent, + self.id, + page_size=page_size, + show_missing=True, + metadata=self._client._rpc_metadata, + ) + iterator.collection = self + iterator.item_to_value = _item_to_document_ref + return iterator + + async def select(self, field_paths): + """Create a "select" query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.select` for + more information on this method. + + Args: + field_paths (Iterable[str, ...]): An iterable of field paths + (``.``-delimited list of field names) to use as a projection + of document fields in the query results. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A "projected" query. + """ + query = query_mod.Query(self) + return query.select(field_paths) + + def where(self, field_path, op_string, value): + """Create a "where" query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.where` for + more information on this method. + + Args: + field_path (str): A field path (``.``-delimited list of + field names) for the field to filter on. + op_string (str): A comparison operation in the form of a string. + Acceptable values are ``<``, ``<=``, ``==``, ``>=`` + and ``>``. + value (Any): The value to compare the field against in the filter. + If ``value`` is :data:`None` or a NaN, then ``==`` is the only + allowed operation. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A filtered query. + """ + query = query_mod.Query(self) + return query.where(field_path, op_string, value) + + def order_by(self, field_path, **kwargs): + """Create an "order by" query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.order_by` for + more information on this method. + + Args: + field_path (str): A field path (``.``-delimited list of + field names) on which to order the query results. + kwargs (Dict[str, Any]): The keyword arguments to pass along + to the query. The only supported keyword is ``direction``, + see :meth:`~google.cloud.firestore_v1.query.Query.order_by` + for more information. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An "order by" query. + """ + query = query_mod.Query(self) + return query.order_by(field_path, **kwargs) + + def limit(self, count): + """Create a limited query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.limit` for + more information on this method. + + Args: + count (int): Maximum number of documents to return that match + the query. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A limited query. + """ + query = query_mod.Query(self) + return query.limit(count) + + def offset(self, num_to_skip): + """Skip to an offset in a query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.offset` for + more information on this method. + + Args: + num_to_skip (int): The number of results to skip at the beginning + of query results. (Must be non-negative.) + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An offset query. + """ + query = query_mod.Query(self) + return query.offset(num_to_skip) + + def start_at(self, document_fields): + """Start query at a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.start_at` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = query_mod.Query(self) + return query.start_at(document_fields) + + def start_after(self, document_fields): + """Start query after a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.start_after` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = query_mod.Query(self) + return query.start_after(document_fields) + + def end_before(self, document_fields): + """End query before a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.end_before` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = query_mod.Query(self) + return query.end_before(document_fields) + + def end_at(self, document_fields): + """End query at a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.end_at` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = query_mod.Query(self) + return query.end_at(document_fields) + + async def get(self, transaction=None): + """Deprecated alias for :meth:`stream`.""" + warnings.warn( + "'Collection.get' is deprecated: please use 'Collection.stream' instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.stream(transaction=transaction) + + def stream(self, transaction=None): + """Read the documents in this collection. + + This sends a ``RunQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + .. note:: + + The underlying stream of responses will time out after + the ``max_rpc_timeout_millis`` value set in the GAPIC + client configuration for the ``RunQuery`` API. Snapshots + not consumed from the iterator before that point will be lost. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ + Transaction`]): + An existing transaction that the query will run in. + + Yields: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + The next document that fulfills the query. + """ + query = query_mod.Query(self) + return query.stream(transaction=transaction) + + def on_snapshot(self, callback): + """Monitor the documents in this collection. + + This starts a watch on this collection using a background thread. The + provided callback is run on the snapshot of the documents. + + Args: + callback (Callable[[:class:`~google.cloud.firestore.collection.CollectionSnapshot`], NoneType]): + a callback to run when a change occurs. + + Example: + from google.cloud import firestore_v1 + + db = firestore_v1.Client() + collection_ref = db.collection(u'users') + + def on_snapshot(collection_snapshot, changes, read_time): + for doc in collection_snapshot.documents: + print(u'{} => {}'.format(doc.id, doc.to_dict())) + + # Watch this collection + collection_watch = collection_ref.on_snapshot(on_snapshot) + + # Terminate this watch + collection_watch.unsubscribe() + """ + return Watch.for_query( + query_mod.Query(self), + callback, + async_document.DocumentSnapshot, + async_document.AsyncDocumentReference, + ) + + +def _auto_id(): + """Generate a "random" automatically generated ID. + + Returns: + str: A 20 character string composed of digits, uppercase and + lowercase and letters. + """ + return "".join(random.choice(_AUTO_ID_CHARS) for _ in six.moves.xrange(20)) + + +def _item_to_document_ref(iterator, item): + """Convert Document resource to document ref. + + Args: + iterator (google.api_core.page_iterator.GRPCIterator): + iterator response + item (dict): document resource + """ + document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] + return iterator.collection.document(document_id) From 281f8048a516fe4fb7b4d143e70ecd91ec1f10a9 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 16 Jun 2020 18:27:04 -0500 Subject: [PATCH 07/47] feat: integrate AsyncCollectionReference --- google/cloud/firestore_v1/async_client.py | 10 +++++----- google/cloud/firestore_v1/async_document.py | 6 +++--- tests/unit/v1/async/test_async_client.py | 14 +++++++------- tests/unit/v1/async/test_async_document.py | 20 +++++++++++--------- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index a5de26f827..1d8cf4f2aa 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -19,7 +19,7 @@ In the hierarchy of API concepts * a :class:`~google.cloud.firestore_v1.client.Client` owns a - :class:`~google.cloud.firestore_v1.collection.CollectionReference` + :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference` * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ @@ -34,7 +34,7 @@ from google.cloud.firestore_v1 import query from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.batch import WriteBatch -from google.cloud.firestore_v1.collection import CollectionReference +from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_document import DocumentSnapshot from google.cloud.firestore_v1.field_path import render_field_path @@ -238,7 +238,7 @@ def collection(self, *collection_path): * A tuple of collection path segments Returns: - :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`: A reference to a collection in the Firestore database. """ if len(collection_path) == 1: @@ -246,7 +246,7 @@ def collection(self, *collection_path): else: path = collection_path - return CollectionReference(*path, client=self) + return AsyncCollectionReference(*path, client=self) def collection_group(self, collection_id): """ @@ -448,7 +448,7 @@ async def collections(self): """List top-level collections of the client's database. Returns: - Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: + Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: iterator of subcollections of the current document. """ iterator = self._firestore_api.list_collection_ids( diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 41c26a03f2..65be7697ec 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -162,7 +162,7 @@ def parent(self): """Collection that owns the current document. Returns: - :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`: The parent collection. """ parent_path = self._path[:-1] @@ -176,7 +176,7 @@ def collection(self, collection_id): referred to as the "kind"). Returns: - :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`: The child collection. """ child_path = self._path + (collection_id,) @@ -479,7 +479,7 @@ async def collections(self, page_size=None): are ignored. Defaults to a sensible value set by the API. Returns: - Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: + Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: iterator of subcollections of the current document. If the document does not exist at the time of `snapshot`, the iterator will be empty diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index 4c36f8d278..8a0b51f48a 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -185,7 +185,7 @@ def test__rpc_metadata_property_with_emulator(self): ) def test_collection_factory(self): - from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference collection_id = "users" client = self._make_default_one() @@ -193,10 +193,10 @@ def test_collection_factory(self): self.assertEqual(collection._path, (collection_id,)) self.assertIs(collection._client, client) - self.assertIsInstance(collection, CollectionReference) + self.assertIsInstance(collection, AsyncCollectionReference) def test_collection_factory_nested(self): - from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference client = self._make_default_one() parts = ("users", "alovelace", "beep") @@ -205,13 +205,13 @@ def test_collection_factory_nested(self): self.assertEqual(collection1._path, parts) self.assertIs(collection1._client, client) - self.assertIsInstance(collection1, CollectionReference) + self.assertIsInstance(collection1, AsyncCollectionReference) # Make sure using segments gives the same result. collection2 = client.collection(*parts) self.assertEqual(collection2._path, parts) self.assertIs(collection2._client, client) - self.assertIsInstance(collection2, CollectionReference) + self.assertIsInstance(collection2, AsyncCollectionReference) def test_collection_group(self): client = self._make_default_one() @@ -336,7 +336,7 @@ def test_write_bad_arg(self): def test_collections(self): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page - from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference collection_ids = ["users", "projects"] client = self._make_default_one() @@ -360,7 +360,7 @@ def _next_page(self): self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, CollectionReference) + self.assertIsInstance(collection, AsyncCollectionReference) self.assertEqual(collection.parent, None) self.assertEqual(collection.id, collection_id) diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py index d9bdea96aa..e3a04918d6 100644 --- a/tests/unit/v1/async/test_async_document.py +++ b/tests/unit/v1/async/test_async_document.py @@ -168,7 +168,7 @@ def test_id_property(self): self.assertEqual(document.id, document_id) def test_parent_property(self): - from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference collection_id = "grocery-store" document_id = "market" @@ -176,12 +176,12 @@ def test_parent_property(self): document = self._make_one(collection_id, document_id, client=client) parent = document.parent - self.assertIsInstance(parent, CollectionReference) + self.assertIsInstance(parent, AsyncCollectionReference) self.assertIs(parent._client, client) self.assertEqual(parent._path, (collection_id,)) def test_collection_factory(self): - from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference collection_id = "grocery-store" document_id = "market" @@ -190,7 +190,7 @@ def test_collection_factory(self): document = self._make_one(collection_id, document_id, client=client) child = document.collection(new_collection) - self.assertIsInstance(child, CollectionReference) + self.assertIsInstance(child, AsyncCollectionReference) self.assertIs(child._client, client) self.assertEqual(child._path, (collection_id, document_id, new_collection)) @@ -483,7 +483,9 @@ def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): else: transaction = None - snapshot = asyncio.run(document.get(field_paths=field_paths, transaction=transaction)) + snapshot = asyncio.run( + document.get(field_paths=field_paths, transaction=transaction) + ) self.assertIs(snapshot.reference, document) if not_found: @@ -539,7 +541,7 @@ def test_get_with_transaction(self): def _collections_helper(self, page_size=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page - from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1.gapic.firestore_client import FirestoreClient class _Iterator(Iterator): @@ -570,7 +572,7 @@ def _next_page(self): # Verify the response and the mocks. self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, CollectionReference) + self.assertIsInstance(collection, AsyncCollectionReference) self.assertEqual(collection.parent, document) self.assertEqual(collection.id, collection_id) @@ -820,7 +822,7 @@ def _make_credentials(): def _make_client(project="project-project"): - from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient credentials = _make_credentials() - return Client(project=project, credentials=credentials) + return AsyncClient(project=project, credentials=credentials) From 59ae8a63b8e722b0dc97a873e9557b1dba3475a0 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 16 Jun 2020 19:05:20 -0500 Subject: [PATCH 08/47] feat: add async_collection tests --- tests/unit/v1/async/test_async_collection.py | 583 +++++++++++++++++++ 1 file changed, 583 insertions(+) create mode 100644 tests/unit/v1/async/test_async_collection.py diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py new file mode 100644 index 0000000000..6688cbe80b --- /dev/null +++ b/tests/unit/v1/async/test_async_collection.py @@ -0,0 +1,583 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import types +import unittest + +import mock +import six + + +class TestAsyncCollectionReference(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + return AsyncCollectionReference + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + @staticmethod + def _get_public_methods(klass): + return set( + name + for name, value in six.iteritems(klass.__dict__) + if (not name.startswith("_") and isinstance(value, types.FunctionType)) + ) + + def test_query_method_matching(self): + from google.cloud.firestore_v1.query import Query + + query_methods = self._get_public_methods(Query) + klass = self._get_target_class() + collection_methods = self._get_public_methods(klass) + # Make sure every query method is present on + # ``AsyncCollectionReference``. + self.assertLessEqual(query_methods, collection_methods) + + def test_constructor(self): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = self._make_one( + collection_id1, document_id, collection_id2, client=client + ) + self.assertIs(collection._client, client) + expected_path = (collection_id1, document_id, collection_id2) + self.assertEqual(collection._path, expected_path) + + def test_constructor_invalid_path(self): + with self.assertRaises(ValueError): + self._make_one() + with self.assertRaises(ValueError): + self._make_one(99, "doc", "bad-collection-id") + with self.assertRaises(ValueError): + self._make_one("bad-document-ID", None, "sub-collection") + with self.assertRaises(ValueError): + self._make_one("Just", "A-Document") + + def test_constructor_invalid_kwarg(self): + with self.assertRaises(TypeError): + self._make_one("Coh-lek-shun", donut=True) + + def test___eq___other_type(self): + client = mock.sentinel.client + collection = self._make_one("name", client=client) + other = object() + self.assertFalse(collection == other) + + def test___eq___different_path_same_client(self): + client = mock.sentinel.client + collection = self._make_one("name", client=client) + other = self._make_one("other", client=client) + self.assertFalse(collection == other) + + def test___eq___same_path_different_client(self): + client = mock.sentinel.client + other_client = mock.sentinel.other_client + collection = self._make_one("name", client=client) + other = self._make_one("name", client=other_client) + self.assertFalse(collection == other) + + def test___eq___same_path_same_client(self): + client = mock.sentinel.client + collection = self._make_one("name", client=client) + other = self._make_one("name", client=client) + self.assertTrue(collection == other) + + def test_id_property(self): + collection_id = "hi-bob" + collection = self._make_one(collection_id) + self.assertEqual(collection.id, collection_id) + + def test_parent_property(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + collection_id1 = "grocery-store" + document_id = "market" + collection_id2 = "darth" + client = _make_client() + collection = self._make_one( + collection_id1, document_id, collection_id2, client=client + ) + + parent = collection.parent + self.assertIsInstance(parent, AsyncDocumentReference) + self.assertIs(parent._client, client) + self.assertEqual(parent._path, (collection_id1, document_id)) + + def test_parent_property_top_level(self): + collection = self._make_one("tahp-leh-vull") + self.assertIsNone(collection.parent) + + def test_document_factory_explicit_id(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + collection_id = "grocery-store" + document_id = "market" + client = _make_client() + collection = self._make_one(collection_id, client=client) + + child = collection.document(document_id) + self.assertIsInstance(child, AsyncDocumentReference) + self.assertIs(child._client, client) + self.assertEqual(child._path, (collection_id, document_id)) + + @mock.patch( + "google.cloud.firestore_v1.async_collection._auto_id", + return_value="zorpzorpthreezorp012", + ) + def test_document_factory_auto_id(self, mock_auto_id): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + collection_name = "space-town" + client = _make_client() + collection = self._make_one(collection_name, client=client) + + child = collection.document() + self.assertIsInstance(child, AsyncDocumentReference) + self.assertIs(child._client, client) + self.assertEqual(child._path, (collection_name, mock_auto_id.return_value)) + + mock_auto_id.assert_called_once_with() + + def test__parent_info_top_level(self): + client = _make_client() + collection_id = "soap" + collection = self._make_one(collection_id, client=client) + + parent_path, expected_prefix = collection._parent_info() + + expected_path = "projects/{}/databases/{}/documents".format( + client.project, client._database + ) + self.assertEqual(parent_path, expected_path) + prefix = "{}/{}".format(expected_path, collection_id) + self.assertEqual(expected_prefix, prefix) + + def test__parent_info_nested(self): + collection_id1 = "bar" + document_id = "baz" + collection_id2 = "chunk" + client = _make_client() + collection = self._make_one( + collection_id1, document_id, collection_id2, client=client + ) + + parent_path, expected_prefix = collection._parent_info() + + expected_path = "projects/{}/databases/{}/documents/{}/{}".format( + client.project, client._database, collection_id1, document_id + ) + self.assertEqual(parent_path, expected_path) + prefix = "{}/{}".format(expected_path, collection_id2) + self.assertEqual(expected_prefix, prefix) + + def test_add_auto_assigned(self): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import pbs_for_create + + # Create a minimal fake GAPIC add attach it to a real client. + firestore_api = mock.Mock(spec=["create_document", "commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + firestore_api.commit.return_value = commit_response + create_doc_response = document_pb2.Document() + firestore_api.create_document.return_value = create_doc_response + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection. + collection = self._make_one("grand-parent", "parent", "child", client=client) + + # Actually call add() on our collection; include a transform to make + # sure transforms during adds work. + document_data = {"been": "here", "now": SERVER_TIMESTAMP} + + patch = mock.patch("google.cloud.firestore_v1.async_collection._auto_id") + random_doc_id = "DEADBEEF" + with patch as patched: + patched.return_value = random_doc_id + update_time, document_ref = asyncio.run(collection.add(document_data)) + + # Verify the response and the mocks. + self.assertIs(update_time, mock.sentinel.update_time) + self.assertIsInstance(document_ref, AsyncDocumentReference) + self.assertIs(document_ref._client, client) + expected_path = collection._path + (random_doc_id,) + self.assertEqual(document_ref._path, expected_path) + + write_pbs = pbs_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + client._database_string, + write_pbs, + transaction=None, + metadata=client._rpc_metadata, + ) + # Since we generate the ID locally, we don't call 'create_document'. + firestore_api.create_document.assert_not_called() + + @staticmethod + def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1 import _helpers + + return write_pb2.Write( + update=document_pb2.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common_pb2.Precondition(exists=False), + ) + + def test_add_explicit_id(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection and call add(). + collection = self._make_one("parent", client=client) + document_data = {"zorp": 208.75, "i-did-not": b"know that"} + doc_id = "child" + update_time, document_ref = asyncio.run( + collection.add(document_data, document_id=doc_id) + ) + + # Verify the response and the mocks. + self.assertIs(update_time, mock.sentinel.update_time) + self.assertIsInstance(document_ref, AsyncDocumentReference) + self.assertIs(document_ref._client, client) + self.assertEqual(document_ref._path, (collection.id, doc_id)) + + write_pb = self._write_pb_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + client._database_string, + [write_pb], + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_select(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + field_paths = ["a", "b"] + query = collection.select(field_paths) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + projection_paths = [ + field_ref.field_path for field_ref in query._projection.fields + ] + self.assertEqual(projection_paths, field_paths) + + @staticmethod + def _make_field_filter_pb(field_path, op_string, value): + from google.cloud.firestore_v1.proto import query_pb2 + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.query import _enum_from_op_string + + return query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=_enum_from_op_string(op_string), + value=_helpers.encode_value(value), + ) + + def test_where(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + field_path = "foo" + op_string = "==" + value = 45 + query = collection.where(field_path, op_string, value) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(len(query._field_filters), 1) + field_filter_pb = query._field_filters[0] + self.assertEqual( + field_filter_pb, self._make_field_filter_pb(field_path, op_string, value) + ) + + @staticmethod + def _make_order_pb(field_path, direction): + from google.cloud.firestore_v1.proto import query_pb2 + from google.cloud.firestore_v1.query import _enum_from_direction + + return query_pb2.StructuredQuery.Order( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + direction=_enum_from_direction(direction), + ) + + def test_order_by(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + field_path = "foo" + direction = Query.DESCENDING + query = collection.order_by(field_path, direction=direction) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(len(query._orders), 1) + order_pb = query._orders[0] + self.assertEqual(order_pb, self._make_order_pb(field_path, direction)) + + def test_limit(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + limit = 15 + query = collection.limit(limit) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(query._limit, limit) + + def test_offset(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + offset = 113 + query = collection.offset(offset) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(query._offset, offset) + + def test_start_at(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + doc_fields = {"a": "b"} + query = collection.start_at(doc_fields) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(query._start_at, (doc_fields, True)) + + def test_start_after(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + doc_fields = {"d": "foo", "e": 10} + query = collection.start_after(doc_fields) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(query._start_at, (doc_fields, False)) + + def test_end_before(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + doc_fields = {"bar": 10.5} + query = collection.end_before(doc_fields) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(query._end_at, (doc_fields, True)) + + def test_end_at(self): + from google.cloud.firestore_v1.query import Query + + collection = self._make_one("collection") + doc_fields = {"opportunity": True, "reason": 9} + query = collection.end_at(doc_fields) + + self.assertIsInstance(query, Query) + self.assertIs(query._parent, collection) + self.assertEqual(query._end_at, (doc_fields, False)) + + def _list_documents_helper(self, page_size=None): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.gapic.firestore_client import FirestoreClient + from google.cloud.firestore_v1.proto.document_pb2 import Document + + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + client = _make_client() + template = client._database_string + "/documents/{}" + document_ids = ["doc-1", "doc-2"] + documents = [ + Document(name=template.format(document_id)) for document_id in document_ids + ] + iterator = _Iterator(pages=[documents]) + api_client = mock.create_autospec(FirestoreClient) + api_client.list_documents.return_value = iterator + client._firestore_api_internal = api_client + collection = self._make_one("collection", client=client) + + if page_size is not None: + documents = list( + asyncio.run(collection.list_documents(page_size=page_size)) + ) + else: + documents = list(asyncio.run(collection.list_documents())) + + # Verify the response and the mocks. + self.assertEqual(len(documents), len(document_ids)) + for document, document_id in zip(documents, document_ids): + self.assertIsInstance(document, AsyncDocumentReference) + self.assertEqual(document.parent, collection) + self.assertEqual(document.id, document_id) + + parent, _ = collection._parent_info() + api_client.list_documents.assert_called_once_with( + parent, + collection.id, + page_size=page_size, + show_missing=True, + metadata=client._rpc_metadata, + ) + + def test_list_documents_wo_page_size(self): + self._list_documents_helper() + + def test_list_documents_w_page_size(self): + self._list_documents_helper(page_size=25) + + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + def test_get(self, query_class): + import warnings + + collection = self._make_one("collection") + with warnings.catch_warnings(record=True) as warned: + get_response = asyncio.run(collection.get()) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + self.assertIs(get_response, query_instance.stream.return_value) + query_instance.stream.assert_called_once_with(transaction=None) + + # Verify the deprecation + self.assertEqual(len(warned), 1) + self.assertIs(warned[0].category, DeprecationWarning) + + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + def test_get_with_transaction(self, query_class): + import warnings + + collection = self._make_one("collection") + transaction = mock.sentinel.txn + with warnings.catch_warnings(record=True) as warned: + get_response = asyncio.run(collection.get(transaction=transaction)) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + self.assertIs(get_response, query_instance.stream.return_value) + query_instance.stream.assert_called_once_with(transaction=transaction) + + # Verify the deprecation + self.assertEqual(len(warned), 1) + self.assertIs(warned[0].category, DeprecationWarning) + + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + def test_stream(self, query_class): + collection = self._make_one("collection") + stream_response = asyncio.run(collection.stream()) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + self.assertIs(stream_response, query_instance.stream.return_value) + query_instance.stream.assert_called_once_with(transaction=None) + + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + def test_stream_with_transaction(self, query_class): + collection = self._make_one("collection") + transaction = mock.sentinel.txn + stream_response = asyncio.run(collection.stream(transaction=transaction)) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + self.assertIs(stream_response, query_instance.stream.return_value) + query_instance.stream.assert_called_once_with(transaction=transaction) + + @mock.patch("google.cloud.firestore_v1.async_collection.Watch", autospec=True) + def test_on_snapshot(self, watch): + collection = self._make_one("collection") + collection.on_snapshot(None) + watch.for_query.assert_called_once() + + +class Test__auto_id(unittest.TestCase): + @staticmethod + def _call_fut(): + from google.cloud.firestore_v1.async_collection import _auto_id + + return _auto_id() + + @mock.patch("random.choice") + def test_it(self, mock_rand_choice): + from google.cloud.firestore_v1.async_collection import _AUTO_ID_CHARS + + mock_result = "0123456789abcdefghij" + mock_rand_choice.side_effect = list(mock_result) + result = self._call_fut() + self.assertEqual(result, mock_result) + + mock_calls = [mock.call(_AUTO_ID_CHARS)] * 20 + self.assertEqual(mock_rand_choice.mock_calls, mock_calls) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(): + from google.cloud.firestore_v1.async_client import AsyncClient + + credentials = _make_credentials() + return AsyncClient(project="project-project", credentials=credentials) From 7ab12a09c78967e71c325d84d22729a2da4e9c0c Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 16 Jun 2020 19:06:19 -0500 Subject: [PATCH 09/47] fix: swap coroutine/function declaration in async_collection --- google/cloud/firestore_v1/async_collection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index a9fbcb805e..030a4ceb49 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -159,7 +159,7 @@ async def add(self, document_data, document_id=None): document_id = _auto_id() document_ref = self.document(document_id) - write_result = document_ref.create(document_data) + write_result = await document_ref.create(document_data) return write_result.update_time, document_ref async def list_documents(self, page_size=None): @@ -189,7 +189,7 @@ async def list_documents(self, page_size=None): iterator.item_to_value = _item_to_document_ref return iterator - async def select(self, field_paths): + def select(self, field_paths): """Create a "select" query with this collection as parent. See @@ -381,9 +381,9 @@ async def get(self, transaction=None): DeprecationWarning, stacklevel=2, ) - return self.stream(transaction=transaction) + return await self.stream(transaction=transaction) - def stream(self, transaction=None): + async def stream(self, transaction=None): """Read the documents in this collection. This sends a ``RunQuery`` RPC and then returns an iterator which From b700c7ba0aef59cad3da334d892ab233dec119d2 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 16 Jun 2020 19:19:38 -0500 Subject: [PATCH 10/47] feat: add async_batch implementation --- google/cloud/firestore_v1/async_batch.py | 160 +++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 google/cloud/firestore_v1/async_batch.py diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py new file mode 100644 index 0000000000..eed0bacbfa --- /dev/null +++ b/google/cloud/firestore_v1/async_batch.py @@ -0,0 +1,160 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for batch requests to the Google Cloud Firestore API.""" + + +from google.cloud.firestore_v1 import _helpers + + +class AsyncWriteBatch(object): + """Accumulate write operations to be sent in a batch. + + This has the same set of methods for write operations that + :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` does, + e.g. :meth:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference.create`. + + Args: + client (:class:`~google.cloud.firestore_v1.async_client.AsyncClient`): + The client that created this batch. + """ + + def __init__(self, client): + self._client = client + self._write_pbs = [] + self.write_results = None + self.commit_time = None + + def _add_write_pbs(self, write_pbs): + """Add `Write`` protobufs to this transaction. + + This method intended to be over-ridden by subclasses. + + Args: + write_pbs (List[google.cloud.proto.firestore.v1.\ + write_pb2.Write]): A list of write protobufs to be added. + """ + self._write_pbs.extend(write_pbs) + + def create(self, reference, document_data): + """Add a "change" to this batch to create a document. + + If the document given by ``reference`` already exists, then this + batch will fail when :meth:`commit`-ed. + + Args: + reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): + A document reference to be created in this batch. + document_data (dict): Property names and values to use for + creating a document. + """ + write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) + self._add_write_pbs(write_pbs) + + def set(self, reference, document_data, merge=False): + """Add a "change" to replace a document. + + See + :meth:`google.cloud.firestore_v1.async_document.AsyncDocumentReference.set` for + more information on how ``option`` determines how the change is + applied. + + Args: + reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): + A document reference that will have values set in this batch. + document_data (dict): + Property names and values to use for replacing a document. + merge (Optional[bool] or Optional[List]): + If True, apply merging instead of overwriting the state + of the document. + """ + if merge is not False: + write_pbs = _helpers.pbs_for_set_with_merge( + reference._document_path, document_data, merge + ) + else: + write_pbs = _helpers.pbs_for_set_no_merge( + reference._document_path, document_data + ) + + self._add_write_pbs(write_pbs) + + def update(self, reference, field_updates, option=None): + """Add a "change" to update a document. + + See + :meth:`google.cloud.firestore_v1.async_document.AsyncDocumentReference.update` + for more information on ``field_updates`` and ``option``. + + Args: + reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): + A document reference that will be updated in this batch. + field_updates (dict): + Field names or paths to update and values to update with. + option (Optional[:class:`~google.cloud.firestore_v1.async_client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + """ + if option.__class__.__name__ == "ExistsOption": + raise ValueError("you must not pass an explicit write option to " "update.") + write_pbs = _helpers.pbs_for_update( + reference._document_path, field_updates, option + ) + self._add_write_pbs(write_pbs) + + def delete(self, reference, option=None): + """Add a "change" to delete a document. + + See + :meth:`google.cloud.firestore_v1.async_document.AsyncDocumentReference.delete` + for more information on how ``option`` determines how the change is + applied. + + Args: + reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): + A document reference that will be deleted in this batch. + option (Optional[:class:`~google.cloud.firestore_v1.async_client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + """ + write_pb = _helpers.pb_for_delete(reference._document_path, option) + self._add_write_pbs([write_pb]) + + async def commit(self): + """Commit the changes accumulated in this batch. + + Returns: + List[:class:`google.cloud.proto.firestore.v1.write_pb2.WriteResult`, ...]: + The write results corresponding to the changes committed, returned + in the same order as the changes were applied to this batch. A + write result contains an ``update_time`` field. + """ + commit_response = self._client._firestore_api.commit( + self._client._database_string, + self._write_pbs, + transaction=None, + metadata=self._client._rpc_metadata, + ) + + self._write_pbs = [] + self.write_results = results = list(commit_response.write_results) + self.commit_time = commit_response.commit_time + return results + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_type is None: + await self.commit() From a6a948f241fb4912d8d5051e979ae0181a5e1fed Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 16 Jun 2020 19:19:49 -0500 Subject: [PATCH 11/47] feat: integrate async_batch --- google/cloud/firestore_v1/async_client.py | 6 +++--- google/cloud/firestore_v1/async_document.py | 6 +++--- tests/unit/v1/async/test_async_client.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 1d8cf4f2aa..362e9b8d69 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -33,7 +33,7 @@ from google.cloud.firestore_v1 import __version__ from google.cloud.firestore_v1 import query from google.cloud.firestore_v1 import types -from google.cloud.firestore_v1.batch import WriteBatch +from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_document import DocumentSnapshot @@ -462,11 +462,11 @@ def batch(self): """Get a batch instance from this client. Returns: - :class:`~google.cloud.firestore_v1.batch.WriteBatch`: + :class:`~google.cloud.firestore_v1.async_batch.AsyncWriteBatch`: A "write" batch to be used for accumulating document changes and sending the changes all at once. """ - return WriteBatch(self) + return AsyncWriteBatch(self) def transaction(self, **kwargs): """Get a transaction that uses this client. diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 65be7697ec..1371f865fc 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -200,7 +200,7 @@ async def create(self, document_data): """ batch = self._client.batch() batch.create(self, document_data) - write_results = batch.commit() + write_results = await batch.commit() return _first_write_result(write_results) async def set(self, document_data, merge=False): @@ -231,7 +231,7 @@ async def set(self, document_data, merge=False): """ batch = self._client.batch() batch.set(self, document_data, merge=merge) - write_results = batch.commit() + write_results = await batch.commit() return _first_write_result(write_results) async def update(self, field_updates, option=None): @@ -379,7 +379,7 @@ async def update(self, field_updates, option=None): """ batch = self._client.batch() batch.update(self, field_updates, option=option) - write_results = batch.commit() + write_results = await batch.commit() return _first_write_result(write_results) async def delete(self, option=None): diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index 8a0b51f48a..f7c92976b9 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -544,11 +544,11 @@ def test_get_all_wrong_order(self): ) def test_batch(self): - from google.cloud.firestore_v1.batch import WriteBatch + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch client = self._make_default_one() batch = client.batch() - self.assertIsInstance(batch, WriteBatch) + self.assertIsInstance(batch, AsyncWriteBatch) self.assertIs(batch._client, client) self.assertEqual(batch._write_pbs, []) From e0fb873048da317978afe7438fb66f0b74ff72c3 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 16 Jun 2020 20:30:10 -0500 Subject: [PATCH 12/47] feat: add async_batch tests --- tests/unit/v1/async/test_async_batch.py | 280 ++++++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 tests/unit/v1/async/test_async_batch.py diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py new file mode 100644 index 0000000000..0816c2d32f --- /dev/null +++ b/tests/unit/v1/async/test_async_batch.py @@ -0,0 +1,280 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import unittest + +import mock + + +class TestAsyncWriteBatch(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch + + return AsyncWriteBatch + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + batch = self._make_one(mock.sentinel.client) + self.assertIs(batch._client, mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + self.assertIsNone(batch.write_results) + self.assertIsNone(batch.commit_time) + + def test__add_write_pbs(self): + batch = self._make_one(mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + batch._add_write_pbs([mock.sentinel.write1, mock.sentinel.write2]) + self.assertEqual(batch._write_pbs, [mock.sentinel.write1, mock.sentinel.write2]) + + def test_create(self): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("this", "one") + document_data = {"a": 10, "b": 2.5} + ret_val = batch.create(reference, document_data) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={ + "a": _value_pb(integer_value=document_data["a"]), + "b": _value_pb(double_value=document_data["b"]), + }, + ), + current_document=common_pb2.Precondition(exists=False), + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_set(self): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("another", "one") + field = "zapzap" + value = u"meadows and flowers" + document_data = {field: value} + ret_val = batch.set(reference, document_data) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={field: _value_pb(string_value=value)}, + ) + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_set_merge(self): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("another", "one") + field = "zapzap" + value = u"meadows and flowers" + document_data = {field: value} + ret_val = batch.set(reference, document_data, merge=True) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={field: _value_pb(string_value=value)}, + ), + update_mask={"field_paths": [field]}, + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_update(self): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("cats", "cradle") + field_path = "head.foot" + value = u"knees toes shoulders" + field_updates = {field_path: value} + + ret_val = batch.update(reference, field_updates) + self.assertIsNone(ret_val) + + map_pb = document_pb2.MapValue(fields={"foot": _value_pb(string_value=value)}) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={"head": _value_pb(map_value=map_pb)}, + ), + update_mask=common_pb2.DocumentMask(field_paths=[field_path]), + current_document=common_pb2.Precondition(exists=True), + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_delete(self): + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("early", "mornin", "dawn", "now") + ret_val = batch.delete(reference) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write(delete=reference._document_path) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_commit(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.Mock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore_pb2.CommitResponse( + write_results=[write_pb2.WriteResult(), write_pb2.WriteResult()], + commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client("grand") + client._firestore_api_internal = firestore_api + + # Actually make a batch with some mutations and call commit(). + batch = self._make_one(client) + document1 = client.document("a", "b") + batch.create(document1, {"ten": 10, "buck": u"ets"}) + document2 = client.document("c", "d", "e", "f") + batch.delete(document2) + write_pbs = batch._write_pbs[::] + + write_results = asyncio.run(batch.commit()) + self.assertEqual(write_results, list(commit_response.write_results)) + self.assertEqual(batch.write_results, write_results) + self.assertEqual(batch.commit_time, timestamp) + # Make sure batch has no more "changes". + self.assertEqual(batch._write_pbs, []) + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + client._database_string, + write_pbs, + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_as_context_mgr_wo_error(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + firestore_api = mock.Mock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore_pb2.CommitResponse( + write_results=[write_pb2.WriteResult(), write_pb2.WriteResult()], + commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + client = _make_client() + client._firestore_api_internal = firestore_api + batch = self._make_one(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + write_pbs = asyncio.run(self._as_context_mgr_wo_error_helper(batch, document1, document2)) + + self.assertEqual(batch.write_results, list(commit_response.write_results)) + self.assertEqual(batch.commit_time, timestamp) + # Make sure batch has no more "changes". + self.assertEqual(batch._write_pbs, []) + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + client._database_string, + write_pbs, + transaction=None, + metadata=client._rpc_metadata, + ) + + async def _as_context_mgr_wo_error_helper(self, batch, document1, document2): + async with batch as ctx_mgr: + self.assertIs(ctx_mgr, batch) + ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.delete(document2) + write_pbs = batch._write_pbs[::] + return write_pbs + + def test_as_context_mgr_w_error(self): + firestore_api = mock.Mock(spec=["commit"]) + client = _make_client() + client._firestore_api_internal = firestore_api + batch = self._make_one(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + asyncio.run(self._as_context_mgr_w_error_helper(batch, document1, document2)) + + self.assertIsNone(batch.write_results) + self.assertIsNone(batch.commit_time) + # batch still has its changes + self.assertEqual(len(batch._write_pbs), 2) + + firestore_api.commit.assert_not_called() + + async def _as_context_mgr_w_error_helper(self, batch, document1, document2): + with self.assertRaises(RuntimeError): + async with batch as ctx_mgr: + ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.delete(document2) + raise RuntimeError("testing") + + + +def _value_pb(**kwargs): + from google.cloud.firestore_v1.proto.document_pb2 import Value + + return Value(**kwargs) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="seventy-nine"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) From 7e993268a4f77b18db545300c0af95fac43d7496 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Thu, 18 Jun 2020 13:15:09 -0500 Subject: [PATCH 13/47] feat: add async_query implementation --- google/cloud/firestore_v1/async_query.py | 1041 ++++++++++++++++++++++ 1 file changed, 1041 insertions(+) create mode 100644 google/cloud/firestore_v1/async_query.py diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py new file mode 100644 index 0000000000..4061902db8 --- /dev/null +++ b/google/cloud/firestore_v1/async_query.py @@ -0,0 +1,1041 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing queries for the Google Cloud Firestore API. + +A :class:`~google.cloud.firestore_v1.query.Query` can be created directly from +a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be +a more common way to create a query than direct usage of the constructor. +""" +import copy +import math +import warnings + +from google.protobuf import wrappers_pb2 +import six + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import document +from google.cloud.firestore_v1 import field_path as field_path_module +from google.cloud.firestore_v1 import transforms +from google.cloud.firestore_v1.gapic import enums +from google.cloud.firestore_v1.proto import query_pb2 +from google.cloud.firestore_v1.order import Order +from google.cloud.firestore_v1.watch import Watch + +_EQ_OP = "==" +_operator_enum = enums.StructuredQuery.FieldFilter.Operator +_COMPARISON_OPERATORS = { + "<": _operator_enum.LESS_THAN, + "<=": _operator_enum.LESS_THAN_OR_EQUAL, + _EQ_OP: _operator_enum.EQUAL, + ">=": _operator_enum.GREATER_THAN_OR_EQUAL, + ">": _operator_enum.GREATER_THAN, + "array_contains": _operator_enum.ARRAY_CONTAINS, + "in": _operator_enum.IN, + "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, +} +_BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." +_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' +_INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values." +_BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}." +_INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values." +_MISSING_ORDER_BY = ( + 'The "order by" field path {!r} is not present in the cursor data {!r}. ' + "All fields sent to ``order_by()`` must be present in the fields " + "if passed to one of ``start_at()`` / ``start_after()`` / " + "``end_before()`` / ``end_at()`` to define a cursor." +) +_NO_ORDERS_FOR_CURSOR = ( + "Attempting to create a cursor with no fields to order on. " + "When defining a cursor with one of ``start_at()`` / ``start_after()`` / " + "``end_before()`` / ``end_at()``, all fields in the cursor must " + "come from fields set in ``order_by()``." +) +_MISMATCH_CURSOR_W_ORDER_BY = "The cursor {!r} does not match the order fields {!r}." + + +class AsyncQuery(object): + """Represents a query to the Firestore API. + + Instances of this class are considered immutable: all methods that + would modify an instance instead return a new instance. + + Args: + parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + The collection that this query applies to. + projection (Optional[:class:`google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.Projection`]): + A projection of document fields to limit the query results to. + field_filters (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.FieldFilter`, ...]]): + The filters to be applied in the query. + orders (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.Order`, ...]]): + The "order by" entries to use in the query. + limit (Optional[int]): + The maximum number of documents the query is allowed to return. + offset (Optional[int]): + The number of results to skip. + start_at (Optional[Tuple[dict, bool]]): + Two-tuple of : + + * a mapping of fields. Any field that is present in this mapping + must also be present in ``orders`` + * an ``after`` flag + + The fields and the flag combine to form a cursor used as + a starting point in a query result set. If the ``after`` + flag is :data:`True`, the results will start just after any + documents which have fields matching the cursor, otherwise + any matching documents will be included in the result set. + When the query is formed, the document values + will be used in the order given by ``orders``. + end_at (Optional[Tuple[dict, bool]]): + Two-tuple of: + + * a mapping of fields. Any field that is present in this mapping + must also be present in ``orders`` + * a ``before`` flag + + The fields and the flag combine to form a cursor used as + an ending point in a query result set. If the ``before`` + flag is :data:`True`, the results will end just before any + documents which have fields matching the cursor, otherwise + any matching documents will be included in the result set. + When the query is formed, the document values + will be used in the order given by ``orders``. + all_descendants (Optional[bool]): + When false, selects only collections that are immediate children + of the `parent` specified in the containing `RunQueryRequest`. + When true, selects all descendant collections. + """ + + ASCENDING = "ASCENDING" + """str: Sort query results in ascending order on a field.""" + DESCENDING = "DESCENDING" + """str: Sort query results in descending order on a field.""" + + def __init__( + self, + parent, + projection=None, + field_filters=(), + orders=(), + limit=None, + offset=None, + start_at=None, + end_at=None, + all_descendants=False, + ): + self._parent = parent + self._projection = projection + self._field_filters = field_filters + self._orders = orders + self._limit = limit + self._offset = offset + self._start_at = start_at + self._end_at = end_at + self._all_descendants = all_descendants + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return ( + self._parent == other._parent + and self._projection == other._projection + and self._field_filters == other._field_filters + and self._orders == other._orders + and self._limit == other._limit + and self._offset == other._offset + and self._start_at == other._start_at + and self._end_at == other._end_at + and self._all_descendants == other._all_descendants + ) + + @property + def _client(self): + """The client of the parent collection. + + Returns: + :class:`~google.cloud.firestore_v1.client.Client`: + The client that owns this query. + """ + return self._parent._client + + def select(self, field_paths): + """Project documents matching query to a limited set of fields. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + If the current query already has a projection set (i.e. has already + called :meth:`~google.cloud.firestore_v1.query.Query.select`), this + will overwrite it. + + Args: + field_paths (Iterable[str, ...]): An iterable of field paths + (``.``-delimited list of field names) to use as a projection + of document fields in the query results. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A "projected" query. Acts as a copy of the current query, + modified with the newly added projection. + Raises: + ValueError: If any ``field_path`` is invalid. + """ + field_paths = list(field_paths) + for field_path in field_paths: + field_path_module.split_field_path(field_path) # raises + + new_projection = query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths + ] + ) + return self.__class__( + self._parent, + projection=new_projection, + field_filters=self._field_filters, + orders=self._orders, + limit=self._limit, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def where(self, field_path, op_string, value): + """Filter the query on a field. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + Returns a new :class:`~google.cloud.firestore_v1.query.Query` that + filters on a specific field path, according to an operation (e.g. + ``==`` or "equals") and a particular value to be paired with that + operation. + + Args: + field_path (str): A field path (``.``-delimited list of + field names) for the field to filter on. + op_string (str): A comparison operation in the form of a string. + Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``, + ``in``, ``array_contains`` and ``array_contains_any``. + value (Any): The value to compare the field against in the filter. + If ``value`` is :data:`None` or a NaN, then ``==`` is the only + allowed operation. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A filtered query. Acts as a copy of the current query, + modified with the newly added filter. + + Raises: + ValueError: If ``field_path`` is invalid. + ValueError: If ``value`` is a NaN or :data:`None` and + ``op_string`` is not ``==``. + """ + field_path_module.split_field_path(field_path) # raises + + if value is None: + if op_string != _EQ_OP: + raise ValueError(_BAD_OP_NAN_NULL) + filter_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, + ) + elif _isnan(value): + if op_string != _EQ_OP: + raise ValueError(_BAD_OP_NAN_NULL) + filter_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=enums.StructuredQuery.UnaryFilter.Operator.IS_NAN, + ) + elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): + raise ValueError(_INVALID_WHERE_TRANSFORM) + else: + filter_pb = query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=_enum_from_op_string(op_string), + value=_helpers.encode_value(value), + ) + + new_filters = self._field_filters + (filter_pb,) + return self.__class__( + self._parent, + projection=self._projection, + field_filters=new_filters, + orders=self._orders, + limit=self._limit, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + @staticmethod + def _make_order(field_path, direction): + """Helper for :meth:`order_by`.""" + return query_pb2.StructuredQuery.Order( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + direction=_enum_from_direction(direction), + ) + + def order_by(self, field_path, direction=ASCENDING): + """Modify the query to add an order clause on a specific field. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + Successive :meth:`~google.cloud.firestore_v1.query.Query.order_by` + calls will further refine the ordering of results returned by the query + (i.e. the new "order by" fields will be added to existing ones). + + Args: + field_path (str): A field path (``.``-delimited list of + field names) on which to order the query results. + direction (Optional[str]): The direction to order by. Must be one + of :attr:`ASCENDING` or :attr:`DESCENDING`, defaults to + :attr:`ASCENDING`. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An ordered query. Acts as a copy of the current query, modified + with the newly added "order by" constraint. + + Raises: + ValueError: If ``field_path`` is invalid. + ValueError: If ``direction`` is not one of :attr:`ASCENDING` or + :attr:`DESCENDING`. + """ + field_path_module.split_field_path(field_path) # raises + + order_pb = self._make_order(field_path, direction) + + new_orders = self._orders + (order_pb,) + return self.__class__( + self._parent, + projection=self._projection, + field_filters=self._field_filters, + orders=new_orders, + limit=self._limit, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def limit(self, count): + """Limit a query to return a fixed number of results. + + If the current query already has a limit set, this will overwrite it. + + Args: + count (int): Maximum number of documents to return that match + the query. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A limited query. Acts as a copy of the current query, modified + with the newly added "limit" filter. + """ + return self.__class__( + self._parent, + projection=self._projection, + field_filters=self._field_filters, + orders=self._orders, + limit=count, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def offset(self, num_to_skip): + """Skip to an offset in a query. + + If the current query already has specified an offset, this will + overwrite it. + + Args: + num_to_skip (int): The number of results to skip at the beginning + of query results. (Must be non-negative.) + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An offset query. Acts as a copy of the current query, modified + with the newly added "offset" field. + """ + return self.__class__( + self._parent, + projection=self._projection, + field_filters=self._field_filters, + orders=self._orders, + limit=self._limit, + offset=num_to_skip, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def _check_snapshot(self, document_fields): + """Validate local snapshots for non-collection-group queries. + + Raises: + ValueError: for non-collection-group queries, if the snapshot + is from a different collection. + """ + if self._all_descendants: + return + + if document_fields.reference._path[:-1] != self._parent._path: + raise ValueError("Cannot use snapshot from another collection as a cursor.") + + def _cursor_helper(self, document_fields, before, start): + """Set values to be used for a ``start_at`` or ``end_at`` cursor. + + The values will later be used in a query protobuf. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + before (bool): Flag indicating if the document in + ``document_fields`` should (:data:`False`) or + shouldn't (:data:`True`) be included in the result set. + start (Optional[bool]): determines if the cursor is a ``start_at`` + cursor (:data:`True`) or an ``end_at`` cursor (:data:`False`). + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "start at" cursor. + """ + if isinstance(document_fields, tuple): + document_fields = list(document_fields) + elif isinstance(document_fields, document.DocumentSnapshot): + self._check_snapshot(document_fields) + else: + # NOTE: We copy so that the caller can't modify after calling. + document_fields = copy.deepcopy(document_fields) + + cursor_pair = document_fields, before + query_kwargs = { + "projection": self._projection, + "field_filters": self._field_filters, + "orders": self._orders, + "limit": self._limit, + "offset": self._offset, + "all_descendants": self._all_descendants, + } + if start: + query_kwargs["start_at"] = cursor_pair + query_kwargs["end_at"] = self._end_at + else: + query_kwargs["start_at"] = self._start_at + query_kwargs["end_at"] = cursor_pair + + return self.__class__(self._parent, **query_kwargs) + + def start_at(self, document_fields): + """Start query results at a particular document value. + + The result set will **include** the document specified by + ``document_fields``. + + If the current query already has specified a start cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.start_after` -- this + will overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as + a copy of the current query, modified with the newly added + "start at" cursor. + """ + return self._cursor_helper(document_fields, before=True, start=True) + + def start_after(self, document_fields): + """Start query results after a particular document value. + + The result set will **exclude** the document specified by + ``document_fields``. + + If the current query already has specified a start cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.start_at` -- this will + overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "start after" cursor. + """ + return self._cursor_helper(document_fields, before=False, start=True) + + def end_before(self, document_fields): + """End query results before a particular document value. + + The result set will **exclude** the document specified by + ``document_fields``. + + If the current query already has specified an end cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.end_at` -- this will + overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "end before" cursor. + """ + return self._cursor_helper(document_fields, before=True, start=False) + + def end_at(self, document_fields): + """End query results at a particular document value. + + The result set will **include** the document specified by + ``document_fields``. + + If the current query already has specified an end cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.end_before` -- this will + overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "end at" cursor. + """ + return self._cursor_helper(document_fields, before=False, start=False) + + def _filters_pb(self): + """Convert all the filters into a single generic Filter protobuf. + + This may be a lone field filter or unary filter, may be a composite + filter or may be :data:`None`. + + Returns: + :class:`google.cloud.firestore_v1.types.StructuredQuery.Filter`: + A "generic" filter representing the current query's filters. + """ + num_filters = len(self._field_filters) + if num_filters == 0: + return None + elif num_filters == 1: + return _filter_pb(self._field_filters[0]) + else: + composite_filter = query_pb2.StructuredQuery.CompositeFilter( + op=enums.StructuredQuery.CompositeFilter.Operator.AND, + filters=[_filter_pb(filter_) for filter_ in self._field_filters], + ) + return query_pb2.StructuredQuery.Filter(composite_filter=composite_filter) + + @staticmethod + def _normalize_projection(projection): + """Helper: convert field paths to message.""" + if projection is not None: + + fields = list(projection.fields) + + if not fields: + field_ref = query_pb2.StructuredQuery.FieldReference( + field_path="__name__" + ) + return query_pb2.StructuredQuery.Projection(fields=[field_ref]) + + return projection + + def _normalize_orders(self): + """Helper: adjust orders based on cursors, where clauses.""" + orders = list(self._orders) + _has_snapshot_cursor = False + + if self._start_at: + if isinstance(self._start_at[0], document.DocumentSnapshot): + _has_snapshot_cursor = True + + if self._end_at: + if isinstance(self._end_at[0], document.DocumentSnapshot): + _has_snapshot_cursor = True + + if _has_snapshot_cursor: + should_order = [ + _enum_from_op_string(key) + for key in _COMPARISON_OPERATORS + if key not in (_EQ_OP, "array_contains") + ] + order_keys = [order.field.field_path for order in orders] + for filter_ in self._field_filters: + field = filter_.field.field_path + if filter_.op in should_order and field not in order_keys: + orders.append(self._make_order(field, "ASCENDING")) + if not orders: + orders.append(self._make_order("__name__", "ASCENDING")) + else: + order_keys = [order.field.field_path for order in orders] + if "__name__" not in order_keys: + direction = orders[-1].direction # enum? + orders.append(self._make_order("__name__", direction)) + + return orders + + def _normalize_cursor(self, cursor, orders): + """Helper: convert cursor to a list of values based on orders.""" + if cursor is None: + return + + if not orders: + raise ValueError(_NO_ORDERS_FOR_CURSOR) + + document_fields, before = cursor + + order_keys = [order.field.field_path for order in orders] + + if isinstance(document_fields, document.DocumentSnapshot): + snapshot = document_fields + document_fields = snapshot.to_dict() + document_fields["__name__"] = snapshot.reference + + if isinstance(document_fields, dict): + # Transform to list using orders + values = [] + data = document_fields + for order_key in order_keys: + try: + if order_key in data: + values.append(data[order_key]) + else: + values.append( + field_path_module.get_nested_value(order_key, data) + ) + except KeyError: + msg = _MISSING_ORDER_BY.format(order_key, data) + raise ValueError(msg) + document_fields = values + + if len(document_fields) != len(orders): + msg = _MISMATCH_CURSOR_W_ORDER_BY.format(document_fields, order_keys) + raise ValueError(msg) + + _transform_bases = (transforms.Sentinel, transforms._ValueList) + + for index, key_field in enumerate(zip(order_keys, document_fields)): + key, field = key_field + + if isinstance(field, _transform_bases): + msg = _INVALID_CURSOR_TRANSFORM + raise ValueError(msg) + + if key == "__name__" and isinstance(field, six.string_types): + document_fields[index] = self._parent.document(field) + + return document_fields, before + + def _to_protobuf(self): + """Convert the current query into the equivalent protobuf. + + Returns: + :class:`google.cloud.firestore_v1.types.StructuredQuery`: + The query protobuf. + """ + projection = self._normalize_projection(self._projection) + orders = self._normalize_orders() + start_at = self._normalize_cursor(self._start_at, orders) + end_at = self._normalize_cursor(self._end_at, orders) + + query_kwargs = { + "select": projection, + "from": [ + query_pb2.StructuredQuery.CollectionSelector( + collection_id=self._parent.id, all_descendants=self._all_descendants + ) + ], + "where": self._filters_pb(), + "order_by": orders, + "start_at": _cursor_pb(start_at), + "end_at": _cursor_pb(end_at), + } + if self._offset is not None: + query_kwargs["offset"] = self._offset + if self._limit is not None: + query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) + + return query_pb2.StructuredQuery(**query_kwargs) + + async def get(self, transaction=None): + """Deprecated alias for :meth:`stream`.""" + warnings.warn( + "'Query.get' is deprecated: please use 'Query.stream' instead.", + DeprecationWarning, + stacklevel=2, + ) + return await self.stream(transaction=transaction) + + async def stream(self, transaction=None): + """Read the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + .. note:: + + The underlying stream of responses will time out after + the ``max_rpc_timeout_millis`` value set in the GAPIC + client configuration for the ``RunQuery`` API. Snapshots + not consumed from the iterator before that point will be lost. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + + Yields: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + The next document that fulfills the query. + """ + parent_path, expected_prefix = self._parent._parent_info() + response_iterator = self._client._firestore_api.run_query( + parent_path, + self._to_protobuf(), + transaction=_helpers.get_transaction_id(transaction), + metadata=self._client._rpc_metadata, + ) + + for response in response_iterator: + if self._all_descendants: + snapshot = _collection_group_query_response_to_snapshot( + response, self._parent + ) + else: + snapshot = _query_response_to_snapshot( + response, self._parent, expected_prefix + ) + if snapshot is not None: + yield snapshot + + def on_snapshot(self, callback): + """Monitor the documents in this collection that match this query. + + This starts a watch on this query using a background thread. The + provided callback is run on the snapshot of the documents. + + Args: + callback(Callable[[:class:`~google.cloud.firestore.query.QuerySnapshot`], NoneType]): + a callback to run when a change occurs. + + Example: + + .. code-block:: python + + from google.cloud import firestore_v1 + + db = firestore_v1.Client() + query_ref = db.collection(u'users').where("user", "==", u'Ada') + + def on_snapshot(docs, changes, read_time): + for doc in docs: + print(u'{} => {}'.format(doc.id, doc.to_dict())) + + # Watch this query + query_watch = query_ref.on_snapshot(on_snapshot) + + # Terminate this watch + query_watch.unsubscribe() + """ + return Watch.for_query( + self, callback, document.DocumentSnapshot, document.DocumentReference + ) + + def _comparator(self, doc1, doc2): + _orders = self._orders + + # Add implicit sorting by name, using the last specified direction. + if len(_orders) == 0: + lastDirection = Query.ASCENDING + else: + if _orders[-1].direction == 1: + lastDirection = Query.ASCENDING + else: + lastDirection = Query.DESCENDING + + orderBys = list(_orders) + + order_pb = query_pb2.StructuredQuery.Order( + field=query_pb2.StructuredQuery.FieldReference(field_path="id"), + direction=_enum_from_direction(lastDirection), + ) + orderBys.append(order_pb) + + for orderBy in orderBys: + if orderBy.field.field_path == "id": + # If ordering by docuent id, compare resource paths. + comp = Order()._compare_to(doc1.reference._path, doc2.reference._path) + else: + if ( + orderBy.field.field_path not in doc1._data + or orderBy.field.field_path not in doc2._data + ): + raise ValueError( + "Can only compare fields that exist in the " + "DocumentSnapshot. Please include the fields you are " + "ordering on in your select() call." + ) + v1 = doc1._data[orderBy.field.field_path] + v2 = doc2._data[orderBy.field.field_path] + encoded_v1 = _helpers.encode_value(v1) + encoded_v2 = _helpers.encode_value(v2) + comp = Order().compare(encoded_v1, encoded_v2) + + if comp != 0: + # 1 == Ascending, -1 == Descending + return orderBy.direction * comp + + return 0 + + +def _enum_from_op_string(op_string): + """Convert a string representation of a binary operator to an enum. + + These enums come from the protobuf message definition + ``StructuredQuery.FieldFilter.Operator``. + + Args: + op_string (str): A comparison operation in the form of a string. + Acceptable values are ``<``, ``<=``, ``==``, ``>=`` + and ``>``. + + Returns: + int: The enum corresponding to ``op_string``. + + Raises: + ValueError: If ``op_string`` is not a valid operator. + """ + try: + return _COMPARISON_OPERATORS[op_string] + except KeyError: + choices = ", ".join(sorted(_COMPARISON_OPERATORS.keys())) + msg = _BAD_OP_STRING.format(op_string, choices) + raise ValueError(msg) + + +def _isnan(value): + """Check if a value is NaN. + + This differs from ``math.isnan`` in that **any** input type is + allowed. + + Args: + value (Any): A value to check for NaN-ness. + + Returns: + bool: Indicates if the value is the NaN float. + """ + if isinstance(value, float): + return math.isnan(value) + else: + return False + + +def _enum_from_direction(direction): + """Convert a string representation of a direction to an enum. + + Args: + direction (str): A direction to order by. Must be one of + :attr:`~google.cloud.firestore.Query.ASCENDING` or + :attr:`~google.cloud.firestore.Query.DESCENDING`. + + Returns: + int: The enum corresponding to ``direction``. + + Raises: + ValueError: If ``direction`` is not a valid direction. + """ + if isinstance(direction, int): + return direction + + if direction == AsyncQuery.ASCENDING: + return enums.StructuredQuery.Direction.ASCENDING + elif direction == AsyncQuery.DESCENDING: + return enums.StructuredQuery.Direction.DESCENDING + else: + msg = _BAD_DIR_STRING.format(direction, Query.ASCENDING, Query.DESCENDING) + raise ValueError(msg) + + +def _filter_pb(field_or_unary): + """Convert a specific protobuf filter to the generic filter type. + + Args: + field_or_unary (Union[google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.FieldFilter, google.cloud.proto.\ + firestore.v1.query_pb2.StructuredQuery.FieldFilter]): A + field or unary filter to convert to a generic filter. + + Returns: + google.cloud.firestore_v1.types.\ + StructuredQuery.Filter: A "generic" filter. + + Raises: + ValueError: If ``field_or_unary`` is not a field or unary filter. + """ + if isinstance(field_or_unary, query_pb2.StructuredQuery.FieldFilter): + return query_pb2.StructuredQuery.Filter(field_filter=field_or_unary) + elif isinstance(field_or_unary, query_pb2.StructuredQuery.UnaryFilter): + return query_pb2.StructuredQuery.Filter(unary_filter=field_or_unary) + else: + raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) + + +def _cursor_pb(cursor_pair): + """Convert a cursor pair to a protobuf. + + If ``cursor_pair`` is :data:`None`, just returns :data:`None`. + + Args: + cursor_pair (Optional[Tuple[list, bool]]): Two-tuple of + + * a list of field values. + * a ``before`` flag + + Returns: + Optional[google.cloud.firestore_v1.types.Cursor]: A + protobuf cursor corresponding to the values. + """ + if cursor_pair is not None: + data, before = cursor_pair + value_pbs = [_helpers.encode_value(value) for value in data] + return query_pb2.Cursor(values=value_pbs, before=before) + + +def _query_response_to_snapshot(response_pb, collection, expected_prefix): + """Parse a query response protobuf to a document snapshot. + + Args: + response_pb (google.cloud.proto.firestore.v1.\ + firestore_pb2.RunQueryResponse): A + collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + A reference to the collection that initiated the query. + expected_prefix (str): The expected prefix for fully-qualified + document names returned in the query results. This can be computed + directly from ``collection`` via :meth:`_parent_info`. + + Returns: + Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: + A snapshot of the data returned in the query. If + ``response_pb.document`` is not set, the snapshot will be :data:`None`. + """ + if not response_pb.HasField("document"): + return None + + document_id = _helpers.get_doc_id(response_pb.document, expected_prefix) + reference = collection.document(document_id) + data = _helpers.decode_dict(response_pb.document.fields, collection._client) + snapshot = document.DocumentSnapshot( + reference, + data, + exists=True, + read_time=response_pb.read_time, + create_time=response_pb.document.create_time, + update_time=response_pb.document.update_time, + ) + return snapshot + + +def _collection_group_query_response_to_snapshot(response_pb, collection): + """Parse a query response protobuf to a document snapshot. + + Args: + response_pb (google.cloud.proto.firestore.v1.\ + firestore_pb2.RunQueryResponse): A + collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + A reference to the collection that initiated the query. + + Returns: + Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: + A snapshot of the data returned in the query. If + ``response_pb.document`` is not set, the snapshot will be :data:`None`. + """ + if not response_pb.HasField("document"): + return None + reference = collection._client.document(response_pb.document.name) + data = _helpers.decode_dict(response_pb.document.fields, collection._client) + snapshot = document.DocumentSnapshot( + reference, + data, + exists=True, + read_time=response_pb.read_time, + create_time=response_pb.document.create_time, + update_time=response_pb.document.update_time, + ) + return snapshot From ec21bdc393a8d356cc1ba0ec3aa0e72e5d83c11b Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Thu, 18 Jun 2020 13:15:16 -0500 Subject: [PATCH 14/47] feat: add async_query integration --- google/cloud/firestore_v1/async_client.py | 8 +-- google/cloud/firestore_v1/async_collection.py | 68 ++++++++++--------- tests/unit/v1/async/test_async_batch.py | 5 +- tests/unit/v1/async/test_async_collection.py | 63 +++++++++-------- 4 files changed, 76 insertions(+), 68 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 362e9b8d69..e482cae3ea 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -31,7 +31,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import __version__ -from google.cloud.firestore_v1 import query +from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -250,7 +250,7 @@ def collection(self, *collection_path): def collection_group(self, collection_id): """ - Creates and returns a new Query that includes all documents in the + Creates and returns a new AsyncQuery that includes all documents in the database that are contained in a collection or subcollection with the given collection_id. @@ -261,7 +261,7 @@ def collection_group(self, collection_id): @param {string} collectionId Identifies the collections to query over. Every collection or subcollection with this ID as the last segment of its path will be included. Cannot contain a slash. - @returns {Query} The created Query. + @returns {AsyncQuery} The created AsyncQuery. """ if "/" in collection_id: raise ValueError( @@ -271,7 +271,7 @@ def collection_group(self, collection_id): ) collection = self.collection(collection_id) - return query.Query(collection, all_descendants=True) + return AsyncQuery(collection, all_descendants=True) def document(self, *document_path): """Get a reference to a document in a collection. diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 030a4ceb49..c6a0fea3cb 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -19,7 +19,7 @@ import six from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import query as query_mod +from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import async_document @@ -193,7 +193,7 @@ def select(self, field_paths): """Create a "select" query with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.select` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.select` for more information on this method. Args: @@ -202,17 +202,17 @@ def select(self, field_paths): of document fields in the query results. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A "projected" query. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.select(field_paths) def where(self, field_path, op_string, value): """Create a "where" query with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.where` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.where` for more information on this method. Args: @@ -226,17 +226,17 @@ def where(self, field_path, op_string, value): allowed operation. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A filtered query. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.where(field_path, op_string, value) def order_by(self, field_path, **kwargs): """Create an "order by" query with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.order_by` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.order_by` for more information on this method. Args: @@ -244,21 +244,21 @@ def order_by(self, field_path, **kwargs): field names) on which to order the query results. kwargs (Dict[str, Any]): The keyword arguments to pass along to the query. The only supported keyword is ``direction``, - see :meth:`~google.cloud.firestore_v1.query.Query.order_by` + see :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.order_by` for more information. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: An "order by" query. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.order_by(field_path, **kwargs) def limit(self, count): """Create a limited query with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.limit` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.limit` for more information on this method. Args: @@ -266,17 +266,17 @@ def limit(self, count): the query. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A limited query. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.limit(count) def offset(self, num_to_skip): """Skip to an offset in a query with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.offset` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.offset` for more information on this method. Args: @@ -284,17 +284,17 @@ def offset(self, num_to_skip): of query results. (Must be non-negative.) Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: An offset query. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.offset(num_to_skip) def start_at(self, document_fields): """Start query at a cursor with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.start_at` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.start_at` for more information on this method. Args: @@ -305,17 +305,17 @@ def start_at(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.start_at(document_fields) def start_after(self, document_fields): """Start query after a cursor with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.start_after` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.start_after` for more information on this method. Args: @@ -326,17 +326,17 @@ def start_after(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.start_after(document_fields) def end_before(self, document_fields): """End query before a cursor with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.end_before` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.end_before` for more information on this method. Args: @@ -347,17 +347,17 @@ def end_before(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.end_before(document_fields) def end_at(self, document_fields): """End query at a cursor with this collection as parent. See - :meth:`~google.cloud.firestore_v1.query.Query.end_at` for + :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.end_at` for more information on this method. Args: @@ -368,10 +368,10 @@ def end_at(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = query_mod.Query(self) + query = AsyncQuery(self) return query.end_at(document_fields) async def get(self, transaction=None): @@ -381,7 +381,8 @@ async def get(self, transaction=None): DeprecationWarning, stacklevel=2, ) - return await self.stream(transaction=transaction) + async for d in self.stream(transaction=transaction): + yield d async def stream(self, transaction=None): """Read the documents in this collection. @@ -410,8 +411,9 @@ async def stream(self, transaction=None): :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: The next document that fulfills the query. """ - query = query_mod.Query(self) - return query.stream(transaction=transaction) + query = AsyncQuery(self) + async for d in query.stream(transaction=transaction): + yield d def on_snapshot(self, callback): """Monitor the documents in this collection. @@ -440,7 +442,7 @@ def on_snapshot(collection_snapshot, changes, read_time): collection_watch.unsubscribe() """ return Watch.for_query( - query_mod.Query(self), + AsyncQuery(self), callback, async_document.DocumentSnapshot, async_document.AsyncDocumentReference, diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py index 0816c2d32f..7df76a6dae 100644 --- a/tests/unit/v1/async/test_async_batch.py +++ b/tests/unit/v1/async/test_async_batch.py @@ -212,7 +212,9 @@ def test_as_context_mgr_wo_error(self): document1 = client.document("a", "b") document2 = client.document("c", "d", "e", "f") - write_pbs = asyncio.run(self._as_context_mgr_wo_error_helper(batch, document1, document2)) + write_pbs = asyncio.run( + self._as_context_mgr_wo_error_helper(batch, document1, document2) + ) self.assertEqual(batch.write_results, list(commit_response.write_results)) self.assertEqual(batch.commit_time, timestamp) @@ -260,7 +262,6 @@ async def _as_context_mgr_w_error_helper(self, batch, document1, document2): raise RuntimeError("testing") - def _value_pb(**kwargs): from google.cloud.firestore_v1.proto.document_pb2 import Value diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 6688cbe80b..c9afe7486e 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import pytest import types import unittest @@ -40,9 +41,9 @@ def _get_public_methods(klass): ) def test_query_method_matching(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery - query_methods = self._get_public_methods(Query) + query_methods = self._get_public_methods(AsyncQuery) klass = self._get_target_class() collection_methods = self._get_public_methods(klass) # Make sure every query method is present on @@ -297,13 +298,13 @@ def test_add_explicit_id(self): ) def test_select(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") field_paths = ["a", "b"] query = collection.select(field_paths) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) projection_paths = [ field_ref.field_path for field_ref in query._projection.fields @@ -323,7 +324,7 @@ def _make_field_filter_pb(field_path, op_string, value): ) def test_where(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") field_path = "foo" @@ -331,7 +332,7 @@ def test_where(self): value = 45 query = collection.where(field_path, op_string, value) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(len(query._field_filters), 1) field_filter_pb = query._field_filters[0] @@ -350,82 +351,82 @@ def _make_order_pb(field_path, direction): ) def test_order_by(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") field_path = "foo" - direction = Query.DESCENDING + direction = AsyncQuery.DESCENDING query = collection.order_by(field_path, direction=direction) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(len(query._orders), 1) order_pb = query._orders[0] self.assertEqual(order_pb, self._make_order_pb(field_path, direction)) def test_limit(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") limit = 15 query = collection.limit(limit) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(query._limit, limit) def test_offset(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") offset = 113 query = collection.offset(offset) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(query._offset, offset) def test_start_at(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") doc_fields = {"a": "b"} query = collection.start_at(doc_fields) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(query._start_at, (doc_fields, True)) def test_start_after(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") doc_fields = {"d": "foo", "e": 10} query = collection.start_after(doc_fields) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(query._start_at, (doc_fields, False)) def test_end_before(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") doc_fields = {"bar": 10.5} query = collection.end_before(doc_fields) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(query._end_at, (doc_fields, True)) def test_end_at(self): - from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.async_query import AsyncQuery collection = self._make_one("collection") doc_fields = {"opportunity": True, "reason": 9} query = collection.end_at(doc_fields) - self.assertIsInstance(query, Query) + self.assertIsInstance(query, AsyncQuery) self.assertIs(query._parent, collection) self.assertEqual(query._end_at, (doc_fields, False)) @@ -487,13 +488,14 @@ def test_list_documents_wo_page_size(self): def test_list_documents_w_page_size(self): self._list_documents_helper(page_size=25) - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + @pytest.mark.skip(reason="no way of currently testing this") + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) def test_get(self, query_class): import warnings collection = self._make_one("collection") with warnings.catch_warnings(record=True) as warned: - get_response = asyncio.run(collection.get()) + get_response = collection.get() query_class.assert_called_once_with(collection) query_instance = query_class.return_value @@ -504,14 +506,15 @@ def test_get(self, query_class): self.assertEqual(len(warned), 1) self.assertIs(warned[0].category, DeprecationWarning) - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + @pytest.mark.skip(reason="no way of currently testing this") + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) def test_get_with_transaction(self, query_class): import warnings collection = self._make_one("collection") transaction = mock.sentinel.txn with warnings.catch_warnings(record=True) as warned: - get_response = asyncio.run(collection.get(transaction=transaction)) + get_response = collection.get(transaction=transaction) query_class.assert_called_once_with(collection) query_instance = query_class.return_value @@ -522,21 +525,23 @@ def test_get_with_transaction(self, query_class): self.assertEqual(len(warned), 1) self.assertIs(warned[0].category, DeprecationWarning) - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + @pytest.mark.skip(reason="no way of currently testing this") + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) def test_stream(self, query_class): collection = self._make_one("collection") - stream_response = asyncio.run(collection.stream()) + stream_response = collection.stream() query_class.assert_called_once_with(collection) query_instance = query_class.return_value self.assertIs(stream_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=None) - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) + @pytest.mark.skip(reason="no way of currently testing this") + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) def test_stream_with_transaction(self, query_class): collection = self._make_one("collection") transaction = mock.sentinel.txn - stream_response = asyncio.run(collection.stream(transaction=transaction)) + stream_response = collection.stream(transaction=transaction) query_class.assert_called_once_with(collection) query_instance = query_class.return_value From 42595bcd3097306df76c6fc114da6d638344251a Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Fri, 19 Jun 2020 11:10:15 -0500 Subject: [PATCH 15/47] feat: add async_query tests --- tests/unit/v1/async/test_async_query.py | 1771 +++++++++++++++++++++++ 1 file changed, 1771 insertions(+) create mode 100644 tests/unit/v1/async/test_async_query.py diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py new file mode 100644 index 0000000000..9b47641522 --- /dev/null +++ b/tests/unit/v1/async/test_async_query.py @@ -0,0 +1,1771 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime +import types +import unittest + +import mock +import six + + +class TestAsyncQuery(unittest.TestCase): + + if six.PY2: + assertRaisesRegex = unittest.TestCase.assertRaisesRegexp + + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_query import AsyncQuery + + return AsyncQuery + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor_defaults(self): + query = self._make_one(mock.sentinel.parent) + self.assertIs(query._parent, mock.sentinel.parent) + self.assertIsNone(query._projection) + self.assertEqual(query._field_filters, ()) + self.assertEqual(query._orders, ()) + self.assertIsNone(query._limit) + self.assertIsNone(query._offset) + self.assertIsNone(query._start_at) + self.assertIsNone(query._end_at) + self.assertFalse(query._all_descendants) + + def _make_one_all_fields( + self, limit=9876, offset=12, skip_fields=(), parent=None, all_descendants=True + ): + kwargs = { + "projection": mock.sentinel.projection, + "field_filters": mock.sentinel.filters, + "orders": mock.sentinel.orders, + "limit": limit, + "offset": offset, + "start_at": mock.sentinel.start_at, + "end_at": mock.sentinel.end_at, + "all_descendants": all_descendants, + } + for field in skip_fields: + kwargs.pop(field) + if parent is None: + parent = mock.sentinel.parent + return self._make_one(parent, **kwargs) + + def test_constructor_explicit(self): + limit = 234 + offset = 56 + query = self._make_one_all_fields(limit=limit, offset=offset) + self.assertIs(query._parent, mock.sentinel.parent) + self.assertIs(query._projection, mock.sentinel.projection) + self.assertIs(query._field_filters, mock.sentinel.filters) + self.assertEqual(query._orders, mock.sentinel.orders) + self.assertEqual(query._limit, limit) + self.assertEqual(query._offset, offset) + self.assertIs(query._start_at, mock.sentinel.start_at) + self.assertIs(query._end_at, mock.sentinel.end_at) + self.assertTrue(query._all_descendants) + + def test__client_property(self): + parent = mock.Mock(_client=mock.sentinel.client, spec=["_client"]) + query = self._make_one(parent) + self.assertIs(query._client, mock.sentinel.client) + + def test___eq___other_type(self): + query = self._make_one_all_fields() + other = object() + self.assertFalse(query == other) + + def test___eq___different_parent(self): + parent = mock.sentinel.parent + other_parent = mock.sentinel.other_parent + query = self._make_one_all_fields(parent=parent) + other = self._make_one_all_fields(parent=other_parent) + self.assertFalse(query == other) + + def test___eq___different_projection(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) + query._projection = mock.sentinel.projection + other = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) + other._projection = mock.sentinel.other_projection + self.assertFalse(query == other) + + def test___eq___different_field_filters(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) + query._field_filters = mock.sentinel.field_filters + other = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) + other._field_filters = mock.sentinel.other_field_filters + self.assertFalse(query == other) + + def test___eq___different_orders(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) + query._orders = mock.sentinel.orders + other = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) + other._orders = mock.sentinel.other_orders + self.assertFalse(query == other) + + def test___eq___different_limit(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, limit=10) + other = self._make_one_all_fields(parent=parent, limit=20) + self.assertFalse(query == other) + + def test___eq___different_offset(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, offset=10) + other = self._make_one_all_fields(parent=parent, offset=20) + self.assertFalse(query == other) + + def test___eq___different_start_at(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) + query._start_at = mock.sentinel.start_at + other = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) + other._start_at = mock.sentinel.other_start_at + self.assertFalse(query == other) + + def test___eq___different_end_at(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) + query._end_at = mock.sentinel.end_at + other = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) + other._end_at = mock.sentinel.other_end_at + self.assertFalse(query == other) + + def test___eq___different_all_descendants(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, all_descendants=True) + other = self._make_one_all_fields(parent=parent, all_descendants=False) + self.assertFalse(query == other) + + def test___eq___hit(self): + query = self._make_one_all_fields() + other = self._make_one_all_fields() + self.assertTrue(query == other) + + def _compare_queries(self, query1, query2, attr_name): + attrs1 = query1.__dict__.copy() + attrs2 = query2.__dict__.copy() + + attrs1.pop(attr_name) + attrs2.pop(attr_name) + + # The only different should be in ``attr_name``. + self.assertEqual(len(attrs1), len(attrs2)) + for key, value in attrs1.items(): + self.assertIs(value, attrs2[key]) + + @staticmethod + def _make_projection_for_select(field_paths): + from google.cloud.firestore_v1.proto import query_pb2 + + return query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths + ] + ) + + def test_select_invalid_path(self): + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query.select(["*"]) + + def test_select(self): + query1 = self._make_one_all_fields(all_descendants=True) + + field_paths2 = ["foo", "bar"] + query2 = query1.select(field_paths2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, self._get_target_class()) + self.assertEqual( + query2._projection, self._make_projection_for_select(field_paths2) + ) + self._compare_queries(query1, query2, "_projection") + + # Make sure it overrides. + field_paths3 = ["foo.baz"] + query3 = query2.select(field_paths3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual( + query3._projection, self._make_projection_for_select(field_paths3) + ) + self._compare_queries(query2, query3, "_projection") + + def test_where_invalid_path(self): + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query.where("*", "==", 1) + + def test_where(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + query = self._make_one_all_fields( + skip_fields=("field_filters",), all_descendants=True + ) + new_query = query.where("power.level", ">", 9000) + + self.assertIsNot(query, new_query) + self.assertIsInstance(new_query, self._get_target_class()) + self.assertEqual(len(new_query._field_filters), 1) + + field_pb = new_query._field_filters[0] + expected_pb = query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="power.level"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(integer_value=9000), + ) + self.assertEqual(field_pb, expected_pb) + self._compare_queries(query, new_query, "_field_filters") + + def _where_unary_helper(self, value, op_enum, op_string="=="): + from google.cloud.firestore_v1.proto import query_pb2 + + query = self._make_one_all_fields(skip_fields=("field_filters",)) + field_path = "feeeld" + new_query = query.where(field_path, op_string, value) + + self.assertIsNot(query, new_query) + self.assertIsInstance(new_query, self._get_target_class()) + self.assertEqual(len(new_query._field_filters), 1) + + field_pb = new_query._field_filters[0] + expected_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=op_enum, + ) + self.assertEqual(field_pb, expected_pb) + self._compare_queries(query, new_query, "_field_filters") + + def test_where_eq_null(self): + from google.cloud.firestore_v1.gapic import enums + + op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NULL + self._where_unary_helper(None, op_enum) + + def test_where_gt_null(self): + with self.assertRaises(ValueError): + self._where_unary_helper(None, 0, op_string=">") + + def test_where_eq_nan(self): + from google.cloud.firestore_v1.gapic import enums + + op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NAN + self._where_unary_helper(float("nan"), op_enum) + + def test_where_le_nan(self): + with self.assertRaises(ValueError): + self._where_unary_helper(float("nan"), 0, op_string="<=") + + def test_where_w_delete(self): + from google.cloud.firestore_v1 import DELETE_FIELD + + with self.assertRaises(ValueError): + self._where_unary_helper(DELETE_FIELD, 0) + + def test_where_w_server_timestamp(self): + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + + with self.assertRaises(ValueError): + self._where_unary_helper(SERVER_TIMESTAMP, 0) + + def test_where_w_array_remove(self): + from google.cloud.firestore_v1 import ArrayRemove + + with self.assertRaises(ValueError): + self._where_unary_helper(ArrayRemove([1, 3, 5]), 0) + + def test_where_w_array_union(self): + from google.cloud.firestore_v1 import ArrayUnion + + with self.assertRaises(ValueError): + self._where_unary_helper(ArrayUnion([2, 4, 8]), 0) + + def test_order_by_invalid_path(self): + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query.order_by("*") + + def test_order_by(self): + from google.cloud.firestore_v1.gapic import enums + + klass = self._get_target_class() + query1 = self._make_one_all_fields( + skip_fields=("orders",), all_descendants=True + ) + + field_path2 = "a" + query2 = query1.order_by(field_path2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, klass) + order_pb2 = _make_order_pb( + field_path2, enums.StructuredQuery.Direction.ASCENDING + ) + self.assertEqual(query2._orders, (order_pb2,)) + self._compare_queries(query1, query2, "_orders") + + # Make sure it appends to the orders. + field_path3 = "b" + query3 = query2.order_by(field_path3, direction=klass.DESCENDING) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, klass) + order_pb3 = _make_order_pb( + field_path3, enums.StructuredQuery.Direction.DESCENDING + ) + self.assertEqual(query3._orders, (order_pb2, order_pb3)) + self._compare_queries(query2, query3, "_orders") + + def test_limit(self): + query1 = self._make_one_all_fields(all_descendants=True) + + limit2 = 100 + query2 = query1.limit(limit2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, self._get_target_class()) + self.assertEqual(query2._limit, limit2) + self._compare_queries(query1, query2, "_limit") + + # Make sure it overrides. + limit3 = 10 + query3 = query2.limit(limit3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._limit, limit3) + self._compare_queries(query2, query3, "_limit") + + def test_offset(self): + query1 = self._make_one_all_fields(all_descendants=True) + + offset2 = 23 + query2 = query1.offset(offset2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, self._get_target_class()) + self.assertEqual(query2._offset, offset2) + self._compare_queries(query1, query2, "_offset") + + # Make sure it overrides. + offset3 = 35 + query3 = query2.offset(offset3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._offset, offset3) + self._compare_queries(query2, query3, "_offset") + + @staticmethod + def _make_collection(*path, **kw): + from google.cloud.firestore_v1 import collection + + return collection.CollectionReference(*path, **kw) + + @staticmethod + def _make_docref(*path, **kw): + from google.cloud.firestore_v1 import document + + return document.DocumentReference(*path, **kw) + + @staticmethod + def _make_snapshot(docref, values): + from google.cloud.firestore_v1 import document + + return document.DocumentSnapshot(docref, values, True, None, None, None) + + def test__cursor_helper_w_dict(self): + values = {"a": 7, "b": "foo"} + query1 = self._make_one(mock.sentinel.parent) + query1._all_descendants = True + query2 = query1._cursor_helper(values, True, True) + + self.assertIs(query2._parent, mock.sentinel.parent) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, query1._orders) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._end_at) + self.assertTrue(query2._all_descendants) + + cursor, before = query2._start_at + + self.assertEqual(cursor, values) + self.assertTrue(before) + + def test__cursor_helper_w_tuple(self): + values = (7, "foo") + query1 = self._make_one(mock.sentinel.parent) + query2 = query1._cursor_helper(values, False, True) + + self.assertIs(query2._parent, mock.sentinel.parent) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, query1._orders) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._end_at) + + cursor, before = query2._start_at + + self.assertEqual(cursor, list(values)) + self.assertFalse(before) + + def test__cursor_helper_w_list(self): + values = [7, "foo"] + query1 = self._make_one(mock.sentinel.parent) + query2 = query1._cursor_helper(values, True, False) + + self.assertIs(query2._parent, mock.sentinel.parent) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, query1._orders) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._start_at) + + cursor, before = query2._end_at + + self.assertEqual(cursor, values) + self.assertIsNot(cursor, values) + self.assertTrue(before) + + def test__cursor_helper_w_snapshot_wrong_collection(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("there", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = self._make_one(collection) + + with self.assertRaises(ValueError): + query._cursor_helper(snapshot, False, False) + + def test__cursor_helper_w_snapshot_other_collection_all_descendants(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("there", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query1 = self._make_one(collection, all_descendants=True) + + query2 = query1._cursor_helper(snapshot, False, False) + + self.assertIs(query2._parent, collection) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, ()) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._start_at) + + cursor, before = query2._end_at + + self.assertIs(cursor, snapshot) + self.assertFalse(before) + + def test__cursor_helper_w_snapshot(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query1 = self._make_one(collection) + + query2 = query1._cursor_helper(snapshot, False, False) + + self.assertIs(query2._parent, collection) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, ()) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._start_at) + + cursor, before = query2._end_at + + self.assertIs(cursor, snapshot) + self.assertFalse(before) + + def test_start_at(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields( + parent=collection, skip_fields=("orders",), all_descendants=True + ) + query2 = query1.order_by("hi") + + document_fields3 = {"hi": "mom"} + query3 = query2.start_at(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._start_at, (document_fields3, True)) + self._compare_queries(query2, query3, "_start_at") + + # Make sure it overrides. + query4 = query3.order_by("bye") + values5 = {"hi": "zap", "bye": 88} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.start_at(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._start_at, (document_fields5, True)) + self._compare_queries(query4, query5, "_start_at") + + def test_start_after(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("down") + + document_fields3 = {"down": 99.75} + query3 = query2.start_after(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._start_at, (document_fields3, False)) + self._compare_queries(query2, query3, "_start_at") + + # Make sure it overrides. + query4 = query3.order_by("out") + values5 = {"down": 100.25, "out": b"\x00\x01"} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.start_after(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._start_at, (document_fields5, False)) + self._compare_queries(query4, query5, "_start_at") + + def test_end_before(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("down") + + document_fields3 = {"down": 99.75} + query3 = query2.end_before(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._end_at, (document_fields3, True)) + self._compare_queries(query2, query3, "_end_at") + + # Make sure it overrides. + query4 = query3.order_by("out") + values5 = {"down": 100.25, "out": b"\x00\x01"} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.end_before(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._end_at, (document_fields5, True)) + self._compare_queries(query4, query5, "_end_at") + self._compare_queries(query4, query5, "_end_at") + + def test_end_at(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("hi") + + document_fields3 = {"hi": "mom"} + query3 = query2.end_at(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._end_at, (document_fields3, False)) + self._compare_queries(query2, query3, "_end_at") + + # Make sure it overrides. + query4 = query3.order_by("bye") + values5 = {"hi": "zap", "bye": 88} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.end_at(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._end_at, (document_fields5, False)) + self._compare_queries(query4, query5, "_end_at") + + def test__filters_pb_empty(self): + query = self._make_one(mock.sentinel.parent) + self.assertEqual(len(query._field_filters), 0) + self.assertIsNone(query._filters_pb()) + + def test__filters_pb_single(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + query1 = self._make_one(mock.sentinel.parent) + query2 = query1.where("x.y", ">", 50.5) + filter_pb = query2._filters_pb() + expected_pb = query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="x.y"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(double_value=50.5), + ) + ) + self.assertEqual(filter_pb, expected_pb) + + def test__filters_pb_multi(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + query1 = self._make_one(mock.sentinel.parent) + query2 = query1.where("x.y", ">", 50.5) + query3 = query2.where("ABC", "==", 123) + + filter_pb = query3._filters_pb() + op_class = enums.StructuredQuery.FieldFilter.Operator + expected_pb = query_pb2.StructuredQuery.Filter( + composite_filter=query_pb2.StructuredQuery.CompositeFilter( + op=enums.StructuredQuery.CompositeFilter.Operator.AND, + filters=[ + query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference( + field_path="x.y" + ), + op=op_class.GREATER_THAN, + value=document_pb2.Value(double_value=50.5), + ) + ), + query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference( + field_path="ABC" + ), + op=op_class.EQUAL, + value=document_pb2.Value(integer_value=123), + ) + ), + ], + ) + ) + self.assertEqual(filter_pb, expected_pb) + + def test__normalize_projection_none(self): + query = self._make_one(mock.sentinel.parent) + self.assertIsNone(query._normalize_projection(None)) + + def test__normalize_projection_empty(self): + projection = self._make_projection_for_select([]) + query = self._make_one(mock.sentinel.parent) + normalized = query._normalize_projection(projection) + field_paths = [field_ref.field_path for field_ref in normalized.fields] + self.assertEqual(field_paths, ["__name__"]) + + def test__normalize_projection_non_empty(self): + projection = self._make_projection_for_select(["a", "b"]) + query = self._make_one(mock.sentinel.parent) + self.assertIs(query._normalize_projection(projection), projection) + + def test__normalize_orders_wo_orders_wo_cursors(self): + query = self._make_one(mock.sentinel.parent) + expected = [] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_w_orders_wo_cursors(self): + query = self._make_one(mock.sentinel.parent).order_by("a") + expected = [query._make_order("a", "ASCENDING")] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_wo_orders_w_snapshot_cursor(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = self._make_one(collection).start_at(snapshot) + expected = [query._make_order("__name__", "ASCENDING")] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_w_name_orders_w_snapshot_cursor(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = ( + self._make_one(collection) + .order_by("__name__", "DESCENDING") + .start_at(snapshot) + ) + expected = [query._make_order("__name__", "DESCENDING")] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = ( + self._make_one(collection) + .where("c", "<=", 20) + .order_by("c", "DESCENDING") + .start_at(snapshot) + ) + expected = [ + query._make_order("c", "DESCENDING"), + query._make_order("__name__", "DESCENDING"), + ] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = self._make_one(collection).where("c", "<=", 20).end_at(snapshot) + expected = [ + query._make_order("c", "ASCENDING"), + query._make_order("__name__", "ASCENDING"), + ] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_cursor_none(self): + query = self._make_one(mock.sentinel.parent) + self.assertIsNone(query._normalize_cursor(None, query._orders)) + + def test__normalize_cursor_no_order(self): + cursor = ([1], True) + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_as_list_mismatched_order(self): + cursor = ([1, 2], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_as_dict_mismatched_order(self): + cursor = ({"a": 1}, True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_delete(self): + from google.cloud.firestore_v1 import DELETE_FIELD + + cursor = ([DELETE_FIELD], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_server_timestamp(self): + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + + cursor = ([SERVER_TIMESTAMP], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_array_remove(self): + from google.cloud.firestore_v1 import ArrayRemove + + cursor = ([ArrayRemove([1, 3, 5])], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_array_union(self): + from google.cloud.firestore_v1 import ArrayUnion + + cursor = ([ArrayUnion([2, 4, 8])], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_as_list_hit(self): + cursor = ([1], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_dict_hit(self): + cursor = ({"b": 1}, True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_dict_with_dot_key_hit(self): + cursor = ({"b.a": 1}, True) + query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_dict_with_inner_data_hit(self): + cursor = ({"b": {"a": 1}}, True) + query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_snapshot_hit(self): + values = {"b": 1} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + cursor = (snapshot, True) + collection = self._make_collection("here") + query = self._make_one(collection).order_by("b", "ASCENDING") + + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_w___name___w_reference(self): + db_string = "projects/my-project/database/(default)" + client = mock.Mock(spec=["_database_string"]) + client._database_string = db_string + parent = mock.Mock(spec=["_path", "_client"]) + parent._client = client + parent._path = ["C"] + query = self._make_one(parent).order_by("__name__", "ASCENDING") + docref = self._make_docref("here", "doc_id") + values = {"a": 7} + snapshot = self._make_snapshot(docref, values) + expected = docref + cursor = (snapshot, True) + + self.assertEqual( + query._normalize_cursor(cursor, query._orders), ([expected], True) + ) + + def test__normalize_cursor_w___name___wo_slash(self): + db_string = "projects/my-project/database/(default)" + client = mock.Mock(spec=["_database_string"]) + client._database_string = db_string + parent = mock.Mock(spec=["_path", "_client", "document"]) + parent._client = client + parent._path = ["C"] + document = parent.document.return_value = mock.Mock(spec=[]) + query = self._make_one(parent).order_by("__name__", "ASCENDING") + cursor = (["b"], True) + expected = document + + self.assertEqual( + query._normalize_cursor(cursor, query._orders), ([expected], True) + ) + parent.document.assert_called_once_with("b") + + def test__to_protobuf_all_fields(self): + from google.protobuf import wrappers_pb2 + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="cat", spec=["id"]) + query1 = self._make_one(parent) + query2 = query1.select(["X", "Y", "Z"]) + query3 = query2.where("Y", ">", 2.5) + query4 = query3.order_by("X") + query5 = query4.limit(17) + query6 = query5.offset(3) + query7 = query6.start_at({"X": 10}) + query8 = query7.end_at({"X": 25}) + + structured_query_pb = query8._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "select": query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in ["X", "Y", "Z"] + ] + ), + "where": query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="Y"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(double_value=2.5), + ) + ), + "order_by": [ + _make_order_pb("X", enums.StructuredQuery.Direction.ASCENDING) + ], + "start_at": query_pb2.Cursor( + values=[document_pb2.Value(integer_value=10)], before=True + ), + "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=25)]), + "offset": 3, + "limit": wrappers_pb2.Int32Value(value=17), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_select_only(self): + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="cat", spec=["id"]) + query1 = self._make_one(parent) + field_paths = ["a.b", "a.c", "d"] + query2 = query1.select(field_paths) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "select": query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths + ] + ), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_where_only(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="dog", spec=["id"]) + query1 = self._make_one(parent) + query2 = query1.where("a", "==", u"b") + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "where": query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="a"), + op=enums.StructuredQuery.FieldFilter.Operator.EQUAL, + value=document_pb2.Value(string_value=u"b"), + ) + ), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_order_by_only(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="fish", spec=["id"]) + query1 = self._make_one(parent) + query2 = query1.order_by("abc") + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "order_by": [ + _make_order_pb("abc", enums.StructuredQuery.Direction.ASCENDING) + ], + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_start_at_only(self): + # NOTE: "only" is wrong since we must have ``order_by`` as well. + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="phish", spec=["id"]) + query = self._make_one(parent).order_by("X.Y").start_after({"X": {"Y": u"Z"}}) + + structured_query_pb = query._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "order_by": [ + _make_order_pb("X.Y", enums.StructuredQuery.Direction.ASCENDING) + ], + "start_at": query_pb2.Cursor( + values=[document_pb2.Value(string_value=u"Z")] + ), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_end_at_only(self): + # NOTE: "only" is wrong since we must have ``order_by`` as well. + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="ghoti", spec=["id"]) + query = self._make_one(parent).order_by("a").end_at({"a": 88}) + + structured_query_pb = query._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "order_by": [ + _make_order_pb("a", enums.StructuredQuery.Direction.ASCENDING) + ], + "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=88)]), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_offset_only(self): + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="cartt", spec=["id"]) + query1 = self._make_one(parent) + offset = 14 + query2 = query1.offset(offset) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "offset": offset, + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_limit_only(self): + from google.protobuf import wrappers_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="donut", spec=["id"]) + query1 = self._make_one(parent) + limit = 31 + query2 = query1.limit(limit) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "limit": wrappers_pb2.Int32Value(value=limit), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + + self.assertEqual(structured_query_pb, expected_pb) + + def test_get_simple(self): + asyncio.run(self._test_get_simple_helper()) + + async def _test_get_simple_helper(self): + import warnings + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = self._make_one(parent) + + with warnings.catch_warnings(record=True) as warned: + get_response = query.get() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("dee", "sleep")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=None, + metadata=client._rpc_metadata, + ) + + # Verify the deprecation + self.assertEqual(len(warned), 1) + self.assertIs(warned[0].category, DeprecationWarning) + + def test_stream_simple(self): + asyncio.run(self._test_stream_simple_helper()) + + async def _test_stream_simple_helper(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("dee", "sleep")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_stream_with_transaction(self): + asyncio.run(self._test_stream_with_transaction_helper()) + + async def _test_stream_with_transaction_helper(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Create a real-ish transaction for this client. + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + # Make a **real** collection reference as parent. + parent = client.collection("declaration") + + # Add a dummy response to the minimal fake GAPIC. + parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream(transaction=transaction) + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("declaration", "burger")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=txn_id, + metadata=client._rpc_metadata, + ) + + def test_stream_no_results(self): + asyncio.run(self._test_stream_no_results_helper()) + + async def _test_stream_no_results_helper(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["run_query"]) + empty_response = _make_query_response() + run_query_response = iter([empty_response]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = self._make_one(parent) + + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + self.assertEqual([x async for x in get_response], []) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_stream_second_response_in_empty_stream(self): + asyncio.run(self._test_stream_second_response_in_empty_stream_helper()) + + async def _test_stream_second_response_in_empty_stream_helper(self): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["run_query"]) + empty_response1 = _make_query_response() + empty_response2 = _make_query_response() + run_query_response = iter([empty_response1, empty_response2]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = self._make_one(parent) + + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + self.assertEqual([x async for x in get_response], []) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_stream_with_skipped_results(self): + asyncio.run(self._test_stream_with_skipped_results_helper()) + + async def _test_stream_with_skipped_results_helper(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("talk", "and", "chew-gum") + + # Add two dummy responses to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + response_pb1 = _make_query_response(skipped_results=1) + name = "{}/clock".format(expected_prefix) + data = {"noon": 12, "nested": {"bird": 10.5}} + response_pb2 = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("talk", "and", "chew-gum", "clock")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_stream_empty_after_first_response(self): + asyncio.run(self._test_stream_empty_after_first_response_helper()) + + async def _test_stream_empty_after_first_response_helper(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + + # Add two dummy responses to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/bark".format(expected_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + self.assertEqual(snapshot.reference._path, ("charles", "bark")) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=None, + metadata=client._rpc_metadata, + ) + + def test_stream_w_collection_group(self): + asyncio.run(self._test_stream_w_collection_group_helper()) + + async def _test_stream_w_collection_group_helper(self): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + other = client.collection("dora") + + # Add two dummy responses to the minimal fake GAPIC. + _, other_prefix = other._parent_info() + name = "{}/bark".format(other_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = self._make_one(parent) + query._all_descendants = True + get_response = query.stream() + self.assertIsInstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + self.assertEqual(len(returned), 1) + snapshot = returned[0] + to_match = other.document("bark") + self.assertEqual(snapshot.reference._document_path, to_match._document_path) + self.assertEqual(snapshot.to_dict(), data) + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + parent_path, + query._to_protobuf(), + transaction=None, + metadata=client._rpc_metadata, + ) + + @mock.patch("google.cloud.firestore_v1.async_query.Watch", autospec=True) + def test_on_snapshot(self, watch): + query = self._make_one(mock.sentinel.parent) + query.on_snapshot(None) + watch.for_query.assert_called_once() + + def test_comparator_no_ordering(self): + query = self._make_one(mock.sentinel.parent) + query._orders = [] + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, -1) + + def test_comparator_no_ordering_same_id(self): + query = self._make_one(mock.sentinel.parent) + query._orders = [] + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument1") + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, 0) + + def test_comparator_ordering(self): + query = self._make_one(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = 1 # ascending + query._orders = [orderByMock] + + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "secondlovelace"}, + } + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, 1) + + def test_comparator_ordering_descending(self): + query = self._make_one(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = -1 # descending + query._orders = [orderByMock] + + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "secondlovelace"}, + } + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, -1) + + def test_comparator_missing_order_by_field_in_data_raises(self): + query = self._make_one(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = 1 # ascending + query._orders = [orderByMock] + + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = {} + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } + + with self.assertRaisesRegex(ValueError, "Can only compare fields "): + query._comparator(doc1, doc2) + + +class Test__enum_from_op_string(unittest.TestCase): + @staticmethod + def _call_fut(op_string): + from google.cloud.firestore_v1.async_query import _enum_from_op_string + + return _enum_from_op_string(op_string) + + @staticmethod + def _get_op_class(): + from google.cloud.firestore_v1.gapic import enums + + return enums.StructuredQuery.FieldFilter.Operator + + def test_lt(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("<"), op_class.LESS_THAN) + + def test_le(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("<="), op_class.LESS_THAN_OR_EQUAL) + + def test_eq(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("=="), op_class.EQUAL) + + def test_ge(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut(">="), op_class.GREATER_THAN_OR_EQUAL) + + def test_gt(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut(">"), op_class.GREATER_THAN) + + def test_array_contains(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) + + def test_in(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("in"), op_class.IN) + + def test_array_contains_any(self): + op_class = self._get_op_class() + self.assertEqual( + self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY + ) + + def test_invalid(self): + with self.assertRaises(ValueError): + self._call_fut("?") + + +class Test__isnan(unittest.TestCase): + @staticmethod + def _call_fut(value): + from google.cloud.firestore_v1.async_query import _isnan + + return _isnan(value) + + def test_valid(self): + self.assertTrue(self._call_fut(float("nan"))) + + def test_invalid(self): + self.assertFalse(self._call_fut(51.5)) + self.assertFalse(self._call_fut(None)) + self.assertFalse(self._call_fut("str")) + self.assertFalse(self._call_fut(int)) + self.assertFalse(self._call_fut(1.0 + 1.0j)) + + +class Test__enum_from_direction(unittest.TestCase): + @staticmethod + def _call_fut(direction): + from google.cloud.firestore_v1.async_query import _enum_from_direction + + return _enum_from_direction(direction) + + def test_success(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.async_query import AsyncQuery + + dir_class = enums.StructuredQuery.Direction + self.assertEqual(self._call_fut(AsyncQuery.ASCENDING), dir_class.ASCENDING) + self.assertEqual(self._call_fut(AsyncQuery.DESCENDING), dir_class.DESCENDING) + + # Ints pass through + self.assertEqual(self._call_fut(dir_class.ASCENDING), dir_class.ASCENDING) + self.assertEqual(self._call_fut(dir_class.DESCENDING), dir_class.DESCENDING) + + def test_failure(self): + with self.assertRaises(ValueError): + self._call_fut("neither-ASCENDING-nor-DESCENDING") + + +class Test__filter_pb(unittest.TestCase): + @staticmethod + def _call_fut(field_or_unary): + from google.cloud.firestore_v1.async_query import _filter_pb + + return _filter_pb(field_or_unary) + + def test_unary(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import query_pb2 + + unary_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="a.b.c"), + op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, + ) + filter_pb = self._call_fut(unary_pb) + expected_pb = query_pb2.StructuredQuery.Filter(unary_filter=unary_pb) + self.assertEqual(filter_pb, expected_pb) + + def test_field(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + field_filter_pb = query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="XYZ"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(double_value=90.75), + ) + filter_pb = self._call_fut(field_filter_pb) + expected_pb = query_pb2.StructuredQuery.Filter(field_filter=field_filter_pb) + self.assertEqual(filter_pb, expected_pb) + + def test_bad_type(self): + with self.assertRaises(ValueError): + self._call_fut(None) + + +class Test__cursor_pb(unittest.TestCase): + @staticmethod + def _call_fut(cursor_pair): + from google.cloud.firestore_v1.async_query import _cursor_pb + + return _cursor_pb(cursor_pair) + + def test_no_pair(self): + self.assertIsNone(self._call_fut(None)) + + def test_success(self): + from google.cloud.firestore_v1.proto import query_pb2 + from google.cloud.firestore_v1 import _helpers + + data = [1.5, 10, True] + cursor_pair = data, True + + cursor_pb = self._call_fut(cursor_pair) + + expected_pb = query_pb2.Cursor( + values=[_helpers.encode_value(value) for value in data], before=True + ) + self.assertEqual(cursor_pb, expected_pb) + + +class Test__query_response_to_snapshot(unittest.TestCase): + @staticmethod + def _call_fut(response_pb, collection, expected_prefix): + from google.cloud.firestore_v1.async_query import _query_response_to_snapshot + + return _query_response_to_snapshot(response_pb, collection, expected_prefix) + + def test_empty(self): + response_pb = _make_query_response() + snapshot = self._call_fut(response_pb, None, None) + self.assertIsNone(snapshot) + + def test_after_offset(self): + skipped_results = 410 + response_pb = _make_query_response(skipped_results=skipped_results) + snapshot = self._call_fut(response_pb, None, None) + self.assertIsNone(snapshot) + + def test_response(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + + client = _make_client() + collection = client.collection("a", "b", "c") + _, expected_prefix = collection._parent_info() + + # Create name for the protobuf. + doc_id = "gigantic" + name = "{}/{}".format(expected_prefix, doc_id) + data = {"a": 901, "b": True} + response_pb = _make_query_response(name=name, data=data) + + snapshot = self._call_fut(response_pb, collection, expected_prefix) + self.assertIsInstance(snapshot, DocumentSnapshot) + expected_path = collection._path + (doc_id,) + self.assertEqual(snapshot.reference._path, expected_path) + self.assertEqual(snapshot.to_dict(), data) + self.assertTrue(snapshot.exists) + self.assertEqual(snapshot.read_time, response_pb.read_time) + self.assertEqual(snapshot.create_time, response_pb.document.create_time) + self.assertEqual(snapshot.update_time, response_pb.document.update_time) + + +class Test__collection_group_query_response_to_snapshot(unittest.TestCase): + @staticmethod + def _call_fut(response_pb, collection): + from google.cloud.firestore_v1.async_query import ( + _collection_group_query_response_to_snapshot, + ) + + return _collection_group_query_response_to_snapshot(response_pb, collection) + + def test_empty(self): + response_pb = _make_query_response() + snapshot = self._call_fut(response_pb, None) + self.assertIsNone(snapshot) + + def test_after_offset(self): + skipped_results = 410 + response_pb = _make_query_response(skipped_results=skipped_results) + snapshot = self._call_fut(response_pb, None) + self.assertIsNone(snapshot) + + def test_response(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + + client = _make_client() + collection = client.collection("a", "b", "c") + other_collection = client.collection("a", "b", "d") + to_match = other_collection.document("gigantic") + data = {"a": 901, "b": True} + response_pb = _make_query_response(name=to_match._document_path, data=data) + + snapshot = self._call_fut(response_pb, collection) + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertEqual(snapshot.reference._document_path, to_match._document_path) + self.assertEqual(snapshot.to_dict(), data) + self.assertTrue(snapshot.exists) + self.assertEqual(snapshot.read_time, response_pb.read_time) + self.assertEqual(snapshot.create_time, response_pb.document.create_time) + self.assertEqual(snapshot.update_time, response_pb.document.update_time) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="project-project"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) + + +def _make_order_pb(field_path, direction): + from google.cloud.firestore_v1.proto import query_pb2 + + return query_pb2.StructuredQuery.Order( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + direction=direction, + ) + + +def _make_query_response(**kwargs): + # kwargs supported are ``skipped_results``, ``name`` and ``data`` + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + kwargs["read_time"] = read_time + + name = kwargs.pop("name", None) + data = kwargs.pop("data", None) + if name is not None and data is not None: + document_pb = document_pb2.Document( + name=name, fields=_helpers.encode_dict(data) + ) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + document_pb.update_time.CopyFrom(update_time) + document_pb.create_time.CopyFrom(create_time) + + kwargs["document"] = document_pb + + return firestore_pb2.RunQueryResponse(**kwargs) From fb568fd6ea9ba603ffe9cea63ec4035760fe6847 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Fri, 19 Jun 2020 11:11:23 -0500 Subject: [PATCH 16/47] fix: AsyncQuery.get async_generator nesting --- google/cloud/firestore_v1/async_query.py | 61 ++++++++++++------------ 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 4061902db8..83c1e105f3 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -181,7 +181,7 @@ def select(self, field_paths): more information on **field paths**. If the current query already has a projection set (i.e. has already - called :meth:`~google.cloud.firestore_v1.query.Query.select`), this + called :meth:`~google.cloud.firestore_v1.query.AsyncQuery.select`), this will overwrite it. Args: @@ -190,7 +190,7 @@ def select(self, field_paths): of document fields in the query results. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A "projected" query. Acts as a copy of the current query, modified with the newly added projection. Raises: @@ -224,7 +224,7 @@ def where(self, field_path, op_string, value): See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for more information on **field paths**. - Returns a new :class:`~google.cloud.firestore_v1.query.Query` that + Returns a new :class:`~google.cloud.firestore_v1.query.AsyncQuery` that filters on a specific field path, according to an operation (e.g. ``==`` or "equals") and a particular value to be paired with that operation. @@ -240,7 +240,7 @@ def where(self, field_path, op_string, value): allowed operation. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A filtered query. Acts as a copy of the current query, modified with the newly added filter. @@ -301,7 +301,7 @@ def order_by(self, field_path, direction=ASCENDING): See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for more information on **field paths**. - Successive :meth:`~google.cloud.firestore_v1.query.Query.order_by` + Successive :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by` calls will further refine the ordering of results returned by the query (i.e. the new "order by" fields will be added to existing ones). @@ -313,7 +313,7 @@ def order_by(self, field_path, direction=ASCENDING): :attr:`ASCENDING`. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: An ordered query. Acts as a copy of the current query, modified with the newly added "order by" constraint. @@ -349,7 +349,7 @@ def limit(self, count): the query. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A limited query. Acts as a copy of the current query, modified with the newly added "limit" filter. """ @@ -376,7 +376,7 @@ def offset(self, num_to_skip): of query results. (Must be non-negative.) Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: An offset query. Acts as a copy of the current query, modified with the newly added "offset" field. """ @@ -412,7 +412,7 @@ def _cursor_helper(self, document_fields, before, start): When the query is sent to the server, the ``document_fields`` will be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. Args: document_fields @@ -427,7 +427,7 @@ def _cursor_helper(self, document_fields, before, start): cursor (:data:`True`) or an ``end_at`` cursor (:data:`False`). Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A query with cursor. Acts as a copy of the current query, modified with the newly added "start at" cursor. """ @@ -465,12 +465,12 @@ def start_at(self, document_fields): If the current query already has specified a start cursor -- either via this method or - :meth:`~google.cloud.firestore_v1.query.Query.start_after` -- this + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.start_after` -- this will overwrite it. When the query is sent to the server, the ``document_fields`` will be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. Args: document_fields @@ -480,7 +480,7 @@ def start_at(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A query with cursor. Acts as a copy of the current query, modified with the newly added "start at" cursor. @@ -495,12 +495,12 @@ def start_after(self, document_fields): If the current query already has specified a start cursor -- either via this method or - :meth:`~google.cloud.firestore_v1.query.Query.start_at` -- this will + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.start_at` -- this will overwrite it. When the query is sent to the server, the ``document_fields`` will be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. Args: document_fields @@ -510,7 +510,7 @@ def start_after(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A query with cursor. Acts as a copy of the current query, modified with the newly added "start after" cursor. """ @@ -524,12 +524,12 @@ def end_before(self, document_fields): If the current query already has specified an end cursor -- either via this method or - :meth:`~google.cloud.firestore_v1.query.Query.end_at` -- this will + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.end_at` -- this will overwrite it. When the query is sent to the server, the ``document_fields`` will be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. Args: document_fields @@ -539,7 +539,7 @@ def end_before(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A query with cursor. Acts as a copy of the current query, modified with the newly added "end before" cursor. """ @@ -553,12 +553,12 @@ def end_at(self, document_fields): If the current query already has specified an end cursor -- either via this method or - :meth:`~google.cloud.firestore_v1.query.Query.end_before` -- this will + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.end_before` -- this will overwrite it. When the query is sent to the server, the ``document_fields`` will be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. Args: document_fields @@ -568,7 +568,7 @@ def end_at(self, document_fields): of values that represent a position in a query result set. Returns: - :class:`~google.cloud.firestore_v1.query.Query`: + :class:`~google.cloud.firestore_v1.query.AsyncQuery`: A query with cursor. Acts as a copy of the current query, modified with the newly added "end at" cursor. """ @@ -731,11 +731,12 @@ def _to_protobuf(self): async def get(self, transaction=None): """Deprecated alias for :meth:`stream`.""" warnings.warn( - "'Query.get' is deprecated: please use 'Query.stream' instead.", + "'AsyncQuery.get' is deprecated: please use 'AsyncQuery.stream' instead.", DeprecationWarning, stacklevel=2, ) - return await self.stream(transaction=transaction) + async for d in self.stream(transaction=transaction): + yield d async def stream(self, transaction=None): """Read the documents in the collection that match this query. @@ -822,12 +823,12 @@ def _comparator(self, doc1, doc2): # Add implicit sorting by name, using the last specified direction. if len(_orders) == 0: - lastDirection = Query.ASCENDING + lastDirection = AsyncQuery.ASCENDING else: if _orders[-1].direction == 1: - lastDirection = Query.ASCENDING + lastDirection = AsyncQuery.ASCENDING else: - lastDirection = Query.DESCENDING + lastDirection = AsyncQuery.DESCENDING orderBys = list(_orders) @@ -912,8 +913,8 @@ def _enum_from_direction(direction): Args: direction (str): A direction to order by. Must be one of - :attr:`~google.cloud.firestore.Query.ASCENDING` or - :attr:`~google.cloud.firestore.Query.DESCENDING`. + :attr:`~google.cloud.firestore.AsyncQuery.ASCENDING` or + :attr:`~google.cloud.firestore.AsyncQuery.DESCENDING`. Returns: int: The enum corresponding to ``direction``. @@ -929,7 +930,7 @@ def _enum_from_direction(direction): elif direction == AsyncQuery.DESCENDING: return enums.StructuredQuery.Direction.DESCENDING else: - msg = _BAD_DIR_STRING.format(direction, Query.ASCENDING, Query.DESCENDING) + msg = _BAD_DIR_STRING.format(direction, AsyncQuery.ASCENDING, AsyncQuery.DESCENDING) raise ValueError(msg) From 438bcf3a36dc814ce88fb971ec0a4613c017bf39 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Fri, 19 Jun 2020 20:40:42 -0500 Subject: [PATCH 17/47] feat: add async_transaction integration and tests --- google/cloud/firestore_v1/async_client.py | 12 +- .../cloud/firestore_v1/async_transaction.py | 443 +++++++ tests/unit/v1/async/test_async_client.py | 4 +- tests/unit/v1/async/test_async_transaction.py | 1021 +++++++++++++++++ 4 files changed, 1472 insertions(+), 8 deletions(-) create mode 100644 google/cloud/firestore_v1/async_transaction.py create mode 100644 tests/unit/v1/async/test_async_transaction.py diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index e482cae3ea..b9ff127df8 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -40,7 +40,7 @@ from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.gapic import firestore_client from google.cloud.firestore_v1.gapic.transports import firestore_grpc_transport -from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.async_transaction import AsyncTransaction DEFAULT_DATABASE = "(default)" @@ -423,7 +423,7 @@ async def get_all(self, references, field_paths=None, transaction=None): paths (``.``-delimited list of field names) to use as a projection of document fields in the returned results. If no value is provided, all fields will be returned. - transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`]): An existing transaction that these ``references`` will be retrieved in. @@ -471,20 +471,20 @@ def batch(self): def transaction(self, **kwargs): """Get a transaction that uses this client. - See :class:`~google.cloud.firestore_v1.transaction.Transaction` for + See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for more information on transactions and the constructor arguments. Args: kwargs (Dict[str, Any]): The keyword arguments (other than ``client``) to pass along to the - :class:`~google.cloud.firestore_v1.transaction.Transaction` + :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` constructor. Returns: - :class:`~google.cloud.firestore_v1.transaction.Transaction`: + :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`: A transaction attached to this client. """ - return Transaction(self, **kwargs) + return AsyncTransaction(self, **kwargs) def _reference_info(references): diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py new file mode 100644 index 0000000000..188d5596c0 --- /dev/null +++ b/google/cloud/firestore_v1/async_transaction.py @@ -0,0 +1,443 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for applying Google Cloud Firestore changes in a transaction.""" + + +import asyncio +import random +import time + +import six + +from google.api_core import exceptions +from google.cloud.firestore_v1 import async_batch +from google.cloud.firestore_v1 import types +from google.cloud.firestore_v1.async_document import AsyncDocumentReference +from google.cloud.firestore_v1.async_query import AsyncQuery + + +MAX_ATTEMPTS = 5 +"""int: Default number of transaction attempts (with retries).""" +_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." +_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." +_CANT_ROLLBACK = _MISSING_ID_TEMPLATE.format("rolled back") +_CANT_COMMIT = _MISSING_ID_TEMPLATE.format("committed") +_WRITE_READ_ONLY = "Cannot perform write operation in read-only transaction." +_INITIAL_SLEEP = 1.0 +"""float: Initial "max" for sleep interval. To be used in :func:`_sleep`.""" +_MAX_SLEEP = 30.0 +"""float: Eventual "max" sleep time. To be used in :func:`_sleep`.""" +_MULTIPLIER = 2.0 +"""float: Multiplier for exponential backoff. To be used in :func:`_sleep`.""" +_EXCEED_ATTEMPTS_TEMPLATE = "Failed to commit transaction in {:d} attempts." +_CANT_RETRY_READ_ONLY = "Only read-write transactions can be retried." + + +class AsyncTransaction(async_batch.AsyncWriteBatch): + """Accumulate read-and-write operations to be sent in a transaction. + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + The client that created this transaction. + max_attempts (Optional[int]): The maximum number of attempts for + the transaction (i.e. allowing retries). Defaults to + :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`. + read_only (Optional[bool]): Flag indicating if the transaction + should be read-only or should allow writes. Defaults to + :data:`False`. + """ + + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): + super(AsyncTransaction, self).__init__(client) + self._max_attempts = max_attempts + self._read_only = read_only + self._id = None + + def _add_write_pbs(self, write_pbs): + """Add `Write`` protobufs to this transaction. + + Args: + write_pbs (List[google.cloud.proto.firestore.v1.\ + write_pb2.Write]): A list of write protobufs to be added. + + Raises: + ValueError: If this transaction is read-only. + """ + if self._read_only: + raise ValueError(_WRITE_READ_ONLY) + + super(AsyncTransaction, self)._add_write_pbs(write_pbs) + + def _options_protobuf(self, retry_id): + """Convert the current object to protobuf. + + The ``retry_id`` value is used when retrying a transaction that + failed (e.g. due to contention). It is intended to be the "first" + transaction that failed (i.e. if multiple retries are needed). + + Args: + retry_id (Union[bytes, NoneType]): Transaction ID of a transaction + to be retried. + + Returns: + Optional[google.cloud.firestore_v1.types.TransactionOptions]: + The protobuf ``TransactionOptions`` if ``read_only==True`` or if + there is a transaction ID to be retried, else :data:`None`. + + Raises: + ValueError: If ``retry_id`` is not :data:`None` but the + transaction is read-only. + """ + if retry_id is not None: + if self._read_only: + raise ValueError(_CANT_RETRY_READ_ONLY) + + return types.TransactionOptions( + read_write=types.TransactionOptions.ReadWrite( + retry_transaction=retry_id + ) + ) + elif self._read_only: + return types.TransactionOptions( + read_only=types.TransactionOptions.ReadOnly() + ) + else: + return None + + @property + def in_progress(self): + """Determine if this transaction has already begun. + + Returns: + bool: Indicates if the transaction has started. + """ + return self._id is not None + + @property + def id(self): + """Get the current transaction ID. + + Returns: + Optional[bytes]: The transaction ID (or :data:`None` if the + current transaction is not in progress). + """ + return self._id + + async def _begin(self, retry_id=None): + """Begin the transaction. + + Args: + retry_id (Optional[bytes]): Transaction ID of a transaction to be + retried. + + Raises: + ValueError: If the current transaction has already begun. + """ + if self.in_progress: + msg = _CANT_BEGIN.format(self._id) + raise ValueError(msg) + + transaction_response = self._client._firestore_api.begin_transaction( + self._client._database_string, + options_=self._options_protobuf(retry_id), + metadata=self._client._rpc_metadata, + ) + self._id = transaction_response.transaction + + def _clean_up(self): + """Clean up the instance after :meth:`_rollback`` or :meth:`_commit``. + + This intended to occur on success or failure of the associated RPCs. + """ + self._write_pbs = [] + self._id = None + + async def _rollback(self): + """Roll back the transaction. + + Raises: + ValueError: If no transaction is in progress. + """ + if not self.in_progress: + raise ValueError(_CANT_ROLLBACK) + + try: + # NOTE: The response is just ``google.protobuf.Empty``. + self._client._firestore_api.rollback( + self._client._database_string, + self._id, + metadata=self._client._rpc_metadata, + ) + finally: + self._clean_up() + + async def _commit(self): + """Transactionally commit the changes accumulated. + + Returns: + List[:class:`google.cloud.proto.firestore.v1.write_pb2.WriteResult`, ...]: + The write results corresponding to the changes committed, returned + in the same order as the changes were applied to this transaction. + A write result contains an ``update_time`` field. + + Raises: + ValueError: If no transaction is in progress. + """ + if not self.in_progress: + raise ValueError(_CANT_COMMIT) + + commit_response = await _commit_with_retry(self._client, self._write_pbs, self._id) + + self._clean_up() + return list(commit_response.write_results) + + async def get_all(self, references): + """Retrieves multiple documents from Firestore. + + Args: + references (List[.AsyncDocumentReference, ...]): Iterable of document + references to be retrieved. + + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + return self._client.get_all(references, transaction=self) + + async def get(self, ref_or_query): + """ + Retrieve a document or a query result from the database. + Args: + ref_or_query The document references or query object to return. + Yields: + .DocumentSnapshot: The next document snapshot that fulfills the + query, or :data:`None` if the document does not exist. + """ + if isinstance(ref_or_query, AsyncDocumentReference): + return self._client.get_all([ref_or_query], transaction=self) + elif isinstance(ref_or_query, AsyncQuery): + return ref_or_query.stream(transaction=self) + else: + raise ValueError( + 'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.' + ) + + +class _Transactional(object): + """Provide a callable object to use as a transactional decorater. + + This is surfaced via + :func:`~google.cloud.firestore_v1.async_transaction.transactional`. + + Args: + to_wrap (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]): + A callable that should be run (and retried) in a transaction. + """ + + def __init__(self, to_wrap): + self.to_wrap = to_wrap + self.current_id = None + """Optional[bytes]: The current transaction ID.""" + self.retry_id = None + """Optional[bytes]: The ID of the first attempted transaction.""" + + def _reset(self): + """Unset the transaction IDs.""" + self.current_id = None + self.retry_id = None + + async def _pre_commit(self, transaction, *args, **kwargs): + """Begin transaction and call the wrapped callable. + + If the callable raises an exception, the transaction will be rolled + back. If not, the transaction will be "ready" for ``Commit`` (i.e. + it will have staged writes). + + Args: + transaction + (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + A transaction to execute the callable within. + args (Tuple[Any, ...]): The extra positional arguments to pass + along to the wrapped callable. + kwargs (Dict[str, Any]): The extra keyword arguments to pass + along to the wrapped callable. + + Returns: + Any: result of the wrapped callable. + + Raises: + Exception: Any failure caused by ``to_wrap``. + """ + # Force the ``transaction`` to be not "in progress". + transaction._clean_up() + await transaction._begin(retry_id=self.retry_id) + + # Update the stored transaction IDs. + self.current_id = transaction._id + if self.retry_id is None: + self.retry_id = self.current_id + try: + return self.to_wrap(transaction, *args, **kwargs) + except: # noqa + # NOTE: If ``rollback`` fails this will lose the information + # from the original failure. + await transaction._rollback() + raise + + async def _maybe_commit(self, transaction): + """Try to commit the transaction. + + If the transaction is read-write and the ``Commit`` fails with the + ``ABORTED`` status code, it will be retried. Any other failure will + not be caught. + + Args: + transaction + (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + The transaction to be ``Commit``-ed. + + Returns: + bool: Indicating if the commit succeeded. + """ + try: + await transaction._commit() + return True + except exceptions.GoogleAPICallError as exc: + if transaction._read_only: + raise + + if isinstance(exc, exceptions.Aborted): + # If a read-write transaction returns ABORTED, retry. + return False + else: + raise + + async def __call__(self, transaction, *args, **kwargs): + """Execute the wrapped callable within a transaction. + + Args: + transaction + (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + A transaction to execute the callable within. + args (Tuple[Any, ...]): The extra positional arguments to pass + along to the wrapped callable. + kwargs (Dict[str, Any]): The extra keyword arguments to pass + along to the wrapped callable. + + Returns: + Any: The result of the wrapped callable. + + Raises: + ValueError: If the transaction does not succeed in + ``max_attempts``. + """ + self._reset() + + for attempt in six.moves.xrange(transaction._max_attempts): + result = await self._pre_commit(transaction, *args, **kwargs) + succeeded = await self._maybe_commit(transaction) + if succeeded: + return result + + # Subsequent requests will use the failed transaction ID as part of + # the ``BeginTransactionRequest`` when restarting this transaction + # (via ``options.retry_transaction``). This preserves the "spot in + # line" of the transaction, so exponential backoff is not required + # in this case. + + await transaction._rollback() + msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + raise ValueError(msg) + + +def transactional(to_wrap): + """Decorate a callable so that it runs in a transaction. + + Args: + to_wrap + (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): + A callable that should be run (and retried) in a transaction. + + Returns: + Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]: + the wrapped callable. + """ + return _Transactional(to_wrap) + + +async def _commit_with_retry(client, write_pbs, transaction_id): + """Call ``Commit`` on the GAPIC client with retry / sleep. + + Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level + retry is handled by the underlying GAPICd client, but in this case it + doesn't because ``Commit`` is not always idempotent. But here we know it + is "idempotent"-like because it has a transaction ID. We also need to do + our own retry to special-case the ``INVALID_ARGUMENT`` error. + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + A client with GAPIC client and configuration details. + write_pbs (List[:class:`google.cloud.proto.firestore.v1.write_pb2.Write`, ...]): + A ``Write`` protobuf instance to be committed. + transaction_id (bytes): + ID of an existing transaction that this commit will run in. + + Returns: + :class:`google.cloud.firestore_v1.types.CommitResponse`: + The protobuf response from ``Commit``. + + Raises: + ~google.api_core.exceptions.GoogleAPICallError: If a non-retryable + exception is encountered. + """ + current_sleep = _INITIAL_SLEEP + while True: + try: + return client._firestore_api.commit( + client._database_string, + write_pbs, + transaction=transaction_id, + metadata=client._rpc_metadata, + ) + except exceptions.ServiceUnavailable: + # Retry + pass + + current_sleep = await _sleep(current_sleep) + + +async def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER): + """Sleep and produce a new sleep time. + + .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ + 2015/03/backoff.html + + Select a duration between zero and ``current_sleep``. It might seem + counterintuitive to have so much jitter, but + `Exponential Backoff And Jitter`_ argues that "full jitter" is + the best strategy. + + Args: + current_sleep (float): The current "max" for sleep interval. + max_sleep (Optional[float]): Eventual "max" sleep time + multiplier (Optional[float]): Multiplier for exponential backoff. + + Returns: + float: Newly doubled ``current_sleep`` or ``max_sleep`` (whichever + is smaller) + """ + actual_sleep = random.uniform(0.0, current_sleep) + await asyncio.sleep(actual_sleep) + return min(multiplier * current_sleep, max_sleep) diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index f7c92976b9..47db52eca9 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -553,11 +553,11 @@ def test_batch(self): self.assertEqual(batch._write_pbs, []) def test_transaction(self): - from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.async_transaction import AsyncTransaction client = self._make_default_one() transaction = client.transaction(max_attempts=3, read_only=True) - self.assertIsInstance(transaction, Transaction) + self.assertIsInstance(transaction, AsyncTransaction) self.assertEqual(transaction._write_pbs, []) self.assertEqual(transaction._max_attempts, 3) self.assertTrue(transaction._read_only) diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py new file mode 100644 index 0000000000..c9fcfeda89 --- /dev/null +++ b/tests/unit/v1/async/test_async_transaction.py @@ -0,0 +1,1021 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import unittest +import mock + + +class TestAsyncTransaction(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + return AsyncTransaction + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor_defaults(self): + from google.cloud.firestore_v1.async_transaction import MAX_ATTEMPTS + + transaction = self._make_one(mock.sentinel.client) + self.assertIs(transaction._client, mock.sentinel.client) + self.assertEqual(transaction._write_pbs, []) + self.assertEqual(transaction._max_attempts, MAX_ATTEMPTS) + self.assertFalse(transaction._read_only) + self.assertIsNone(transaction._id) + + def test_constructor_explicit(self): + transaction = self._make_one( + mock.sentinel.client, max_attempts=10, read_only=True + ) + self.assertIs(transaction._client, mock.sentinel.client) + self.assertEqual(transaction._write_pbs, []) + self.assertEqual(transaction._max_attempts, 10) + self.assertTrue(transaction._read_only) + self.assertIsNone(transaction._id) + + def test__add_write_pbs_failure(self): + from google.cloud.firestore_v1.async_transaction import _WRITE_READ_ONLY + + batch = self._make_one(mock.sentinel.client, read_only=True) + self.assertEqual(batch._write_pbs, []) + with self.assertRaises(ValueError) as exc_info: + batch._add_write_pbs([mock.sentinel.write]) + + self.assertEqual(exc_info.exception.args, (_WRITE_READ_ONLY,)) + self.assertEqual(batch._write_pbs, []) + + def test__add_write_pbs(self): + batch = self._make_one(mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + batch._add_write_pbs([mock.sentinel.write]) + self.assertEqual(batch._write_pbs, [mock.sentinel.write]) + + def test__options_protobuf_read_only(self): + from google.cloud.firestore_v1.proto import common_pb2 + + transaction = self._make_one(mock.sentinel.client, read_only=True) + options_pb = transaction._options_protobuf(None) + expected_pb = common_pb2.TransactionOptions( + read_only=common_pb2.TransactionOptions.ReadOnly() + ) + self.assertEqual(options_pb, expected_pb) + + def test__options_protobuf_read_only_retry(self): + from google.cloud.firestore_v1.async_transaction import _CANT_RETRY_READ_ONLY + + transaction = self._make_one(mock.sentinel.client, read_only=True) + retry_id = b"illuminate" + + with self.assertRaises(ValueError) as exc_info: + transaction._options_protobuf(retry_id) + + self.assertEqual(exc_info.exception.args, (_CANT_RETRY_READ_ONLY,)) + + def test__options_protobuf_read_write(self): + transaction = self._make_one(mock.sentinel.client) + options_pb = transaction._options_protobuf(None) + self.assertIsNone(options_pb) + + def test__options_protobuf_on_retry(self): + from google.cloud.firestore_v1.proto import common_pb2 + + transaction = self._make_one(mock.sentinel.client) + retry_id = b"hocus-pocus" + options_pb = transaction._options_protobuf(retry_id) + expected_pb = common_pb2.TransactionOptions( + read_write=common_pb2.TransactionOptions.ReadWrite( + retry_transaction=retry_id + ) + ) + self.assertEqual(options_pb, expected_pb) + + def test_in_progress_property(self): + transaction = self._make_one(mock.sentinel.client) + self.assertFalse(transaction.in_progress) + transaction._id = b"not-none-bites" + self.assertTrue(transaction.in_progress) + + def test_id_property(self): + transaction = self._make_one(mock.sentinel.client) + transaction._id = mock.sentinel.eye_dee + self.assertIs(transaction.id, mock.sentinel.eye_dee) + + def test__begin(self): + from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.proto import firestore_pb2 + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + txn_id = b"to-begin" + response = firestore_pb2.BeginTransactionResponse(transaction=txn_id) + firestore_api.begin_transaction.return_value = response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and ``begin()`` it. + transaction = self._make_one(client) + self.assertIsNone(transaction._id) + + ret_val = asyncio.run(transaction._begin()) + self.assertIsNone(ret_val) + self.assertEqual(transaction._id, txn_id) + + # Verify the called mock. + firestore_api.begin_transaction.assert_called_once_with( + client._database_string, options_=None, metadata=client._rpc_metadata + ) + + def test__begin_failure(self): + from google.cloud.firestore_v1.async_transaction import _CANT_BEGIN + + client = _make_client() + transaction = self._make_one(client) + transaction._id = b"not-none" + + with self.assertRaises(ValueError) as exc_info: + asyncio.run(transaction._begin()) + + err_msg = _CANT_BEGIN.format(transaction._id) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + def test__clean_up(self): + transaction = self._make_one(mock.sentinel.client) + transaction._write_pbs.extend( + [mock.sentinel.write_pb1, mock.sentinel.write_pb2] + ) + transaction._id = b"not-this-time-my-friend" + + ret_val = transaction._clean_up() + self.assertIsNone(ret_val) + + self.assertEqual(transaction._write_pbs, []) + self.assertIsNone(transaction._id) + + def test__rollback(self): + from google.protobuf import empty_pb2 + from google.cloud.firestore_v1.gapic import firestore_client + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + firestore_api.rollback.return_value = empty_pb2.Empty() + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and roll it back. + transaction = self._make_one(client) + txn_id = b"to-be-r\x00lled" + transaction._id = txn_id + ret_val = asyncio.run(transaction._rollback()) + self.assertIsNone(ret_val) + self.assertIsNone(transaction._id) + + # Verify the called mock. + firestore_api.rollback.assert_called_once_with( + client._database_string, txn_id, metadata=client._rpc_metadata + ) + + def test__rollback_not_allowed(self): + from google.cloud.firestore_v1.async_transaction import _CANT_ROLLBACK + + client = _make_client() + transaction = self._make_one(client) + self.assertIsNone(transaction._id) + + with self.assertRaises(ValueError) as exc_info: + asyncio.run(transaction._rollback()) + + self.assertEqual(exc_info.exception.args, (_CANT_ROLLBACK,)) + + def test__rollback_failure(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.gapic import firestore_client + + # Create a minimal fake GAPIC with a dummy failure. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + exc = exceptions.InternalServerError("Fire during rollback.") + firestore_api.rollback.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and roll it back. + transaction = self._make_one(client) + txn_id = b"roll-bad-server" + transaction._id = txn_id + + with self.assertRaises(exceptions.InternalServerError) as exc_info: + asyncio.run(transaction._rollback()) + + self.assertIs(exc_info.exception, exc) + self.assertIsNone(transaction._id) + self.assertEqual(transaction._write_pbs, []) + + # Verify the called mock. + firestore_api.rollback.assert_called_once_with( + client._database_string, txn_id, metadata=client._rpc_metadata + ) + + def test__commit(self): + from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + commit_response = firestore_pb2.CommitResponse( + write_results=[write_pb2.WriteResult()] + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client("phone-joe") + client._firestore_api_internal = firestore_api + + # Actually make a transaction with some mutations and call _commit(). + transaction = self._make_one(client) + txn_id = b"under-over-thru-woods" + transaction._id = txn_id + document = client.document("zap", "galaxy", "ship", "space") + transaction.set(document, {"apple": 4.5}) + write_pbs = transaction._write_pbs[::] + + write_results = asyncio.run(transaction._commit()) + self.assertEqual(write_results, list(commit_response.write_results)) + # Make sure transaction has no more "changes". + self.assertIsNone(transaction._id) + self.assertEqual(transaction._write_pbs, []) + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + client._database_string, + write_pbs, + transaction=txn_id, + metadata=client._rpc_metadata, + ) + + def test__commit_not_allowed(self): + from google.cloud.firestore_v1.async_transaction import _CANT_COMMIT + + transaction = self._make_one(mock.sentinel.client) + self.assertIsNone(transaction._id) + with self.assertRaises(ValueError) as exc_info: + asyncio.run(transaction._commit()) + + self.assertEqual(exc_info.exception.args, (_CANT_COMMIT,)) + + def test__commit_failure(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.gapic import firestore_client + + # Create a minimal fake GAPIC with a dummy failure. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + exc = exceptions.InternalServerError("Fire during commit.") + firestore_api.commit.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction with some mutations and call _commit(). + transaction = self._make_one(client) + txn_id = b"beep-fail-commit" + transaction._id = txn_id + transaction.create(client.document("up", "down"), {"water": 1.0}) + transaction.delete(client.document("up", "left")) + write_pbs = transaction._write_pbs[::] + + with self.assertRaises(exceptions.InternalServerError) as exc_info: + asyncio.run(transaction._commit()) + + self.assertIs(exc_info.exception, exc) + self.assertEqual(transaction._id, txn_id) + self.assertEqual(transaction._write_pbs, write_pbs) + + # Verify the called mock. + firestore_api.commit.assert_called_once_with( + client._database_string, + write_pbs, + transaction=txn_id, + metadata=client._rpc_metadata, + ) + + def test_get_all(self): + client = mock.Mock(spec=["get_all"]) + transaction = self._make_one(client) + ref1, ref2 = mock.Mock(), mock.Mock() + result = asyncio.run(transaction.get_all([ref1, ref2])) + client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction) + self.assertIs(result, client.get_all.return_value) + + def test_get_document_ref(self): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + client = mock.Mock(spec=["get_all"]) + transaction = self._make_one(client) + ref = AsyncDocumentReference("documents", "doc-id") + result = asyncio.run(transaction.get(ref)) + client.get_all.assert_called_once_with([ref], transaction=transaction) + self.assertIs(result, client.get_all.return_value) + + def test_get_w_query(self): + from google.cloud.firestore_v1.async_query import AsyncQuery + + client = mock.Mock(spec=[]) + transaction = self._make_one(client) + query = AsyncQuery(parent=mock.Mock(spec=[])) + query.stream = mock.MagicMock() + result = asyncio.run(transaction.get(query)) + query.stream.assert_called_once_with(transaction=transaction) + self.assertIs(result, query.stream.return_value) + + def test_get_failure(self): + client = _make_client() + transaction = self._make_one(client) + ref_or_query = object() + with self.assertRaises(ValueError): + asyncio.run(transaction.get(ref_or_query)) + + +class Test_Transactional(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.async_transaction import _Transactional + + return _Transactional + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + wrapped = self._make_one(mock.sentinel.callable_) + self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) + self.assertIsNone(wrapped.current_id) + self.assertIsNone(wrapped.retry_id) + + def test__reset(self): + wrapped = self._make_one(mock.sentinel.callable_) + wrapped.current_id = b"not-none" + wrapped.retry_id = b"also-not" + + ret_val = wrapped._reset() + self.assertIsNone(ret_val) + + self.assertIsNone(wrapped.current_id) + self.assertIsNone(wrapped.retry_id) + + def test__pre_commit_success(self): + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"totes-began" + transaction = _make_transaction(txn_id) + result = asyncio.run(wrapped._pre_commit(transaction, "pos", key="word")) + self.assertIs(result, mock.sentinel.result) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "pos", key="word") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + transaction._client._database_string, + options_=None, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + def test__pre_commit_retry_id_already_set_success(self): + from google.cloud.firestore_v1.proto import common_pb2 + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + txn_id1 = b"already-set" + wrapped.retry_id = txn_id1 + + txn_id2 = b"ok-here-too" + transaction = _make_transaction(txn_id2) + result = asyncio.run(wrapped._pre_commit(transaction)) + self.assertIs(result, mock.sentinel.result) + + self.assertEqual(transaction._id, txn_id2) + self.assertEqual(wrapped.current_id, txn_id2) + self.assertEqual(wrapped.retry_id, txn_id1) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction) + firestore_api = transaction._client._firestore_api + options_ = common_pb2.TransactionOptions( + read_write=common_pb2.TransactionOptions.ReadWrite( + retry_transaction=txn_id1 + ) + ) + firestore_api.begin_transaction.assert_called_once_with( + transaction._client._database_string, + options_=options_, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + def test__pre_commit_failure(self): + exc = RuntimeError("Nope not today.") + to_wrap = mock.Mock(side_effect=exc, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"gotta-fail" + transaction = _make_transaction(txn_id) + with self.assertRaises(RuntimeError) as exc_info: + asyncio.run(wrapped._pre_commit(transaction, 10, 20)) + self.assertIs(exc_info.exception, exc) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, 10, 20) + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + transaction._client._database_string, + options_=None, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + transaction._client._database_string, + txn_id, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + def test__pre_commit_failure_with_rollback_failure(self): + from google.api_core import exceptions + + exc1 = ValueError("I will not be only failure.") + to_wrap = mock.Mock(side_effect=exc1, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"both-will-fail" + transaction = _make_transaction(txn_id) + # Actually force the ``rollback`` to fail as well. + exc2 = exceptions.InternalServerError("Rollback blues.") + firestore_api = transaction._client._firestore_api + firestore_api.rollback.side_effect = exc2 + + # Try to ``_pre_commit`` + with self.assertRaises(exceptions.InternalServerError) as exc_info: + asyncio.run(wrapped._pre_commit(transaction, a="b", c="zebra")) + self.assertIs(exc_info.exception, exc2) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, a="b", c="zebra") + firestore_api.begin_transaction.assert_called_once_with( + transaction._client._database_string, + options_=None, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + transaction._client._database_string, + txn_id, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + def test__maybe_commit_success(self): + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"nyet" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + succeeded = asyncio.run(wrapped._maybe_commit(transaction)) + self.assertTrue(succeeded) + + # On success, _id is reset. + self.assertIsNone(transaction._id) + + # Verify mocks. + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + transaction._client._database_string, + [], + transaction=txn_id, + metadata=transaction._client._rpc_metadata, + ) + + def test__maybe_commit_failure_read_only(self): + from google.api_core import exceptions + + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"failed" + transaction = _make_transaction(txn_id, read_only=True) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail (use ABORTED, but cannot + # retry since read-only). + exc = exceptions.Aborted("Read-only did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with self.assertRaises(exceptions.Aborted) as exc_info: + asyncio.run(wrapped._maybe_commit(transaction)) + self.assertIs(exc_info.exception, exc) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + transaction._client._database_string, + [], + transaction=txn_id, + metadata=transaction._client._rpc_metadata, + ) + + def test__maybe_commit_failure_can_retry(self): + from google.api_core import exceptions + + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"failed-but-retry" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Read-write did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + succeeded = asyncio.run(wrapped._maybe_commit(transaction)) + self.assertFalse(succeeded) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + transaction._client._database_string, + [], + transaction=txn_id, + metadata=transaction._client._rpc_metadata, + ) + + def test__maybe_commit_failure_cannot_retry(self): + from google.api_core import exceptions + + wrapped = self._make_one(mock.sentinel.callable_) + + txn_id = b"failed-but-not-retryable" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.InternalServerError("Real bad thing") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with self.assertRaises(exceptions.InternalServerError) as exc_info: + asyncio.run(wrapped._maybe_commit(transaction)) + self.assertIs(exc_info.exception, exc) + + self.assertEqual(transaction._id, txn_id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + transaction._client._database_string, + [], + transaction=txn_id, + metadata=transaction._client._rpc_metadata, + ) + + def test___call__success_first_attempt(self): + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction(txn_id) + result = asyncio.run(wrapped(transaction, "a", b="c")) + self.assertIs(result, mock.sentinel.result) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "a", b="c") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + transaction._client._database_string, + options_=None, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + transaction._client._database_string, + [], + transaction=txn_id, + metadata=transaction._client._rpc_metadata, + ) + + def test___call__success_second_attempt(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction(txn_id) + + # Actually force the ``commit`` to fail on first / succeed on second. + exc = exceptions.Aborted("Contention junction.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = [ + exc, + firestore_pb2.CommitResponse(write_results=[write_pb2.WriteResult()]), + ] + + # Call the __call__-able ``wrapped``. + result = asyncio.run(wrapped(transaction, "a", b="c")) + self.assertIs(result, mock.sentinel.result) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + wrapped_call = mock.call(transaction, "a", b="c") + self.assertEqual(to_wrap.mock_calls, [wrapped_call, wrapped_call]) + firestore_api = transaction._client._firestore_api + db_str = transaction._client._database_string + options_ = common_pb2.TransactionOptions( + read_write=common_pb2.TransactionOptions.ReadWrite(retry_transaction=txn_id) + ) + self.assertEqual( + firestore_api.begin_transaction.mock_calls, + [ + mock.call( + db_str, options_=None, metadata=transaction._client._rpc_metadata + ), + mock.call( + db_str, + options_=options_, + metadata=transaction._client._rpc_metadata, + ), + ], + ) + firestore_api.rollback.assert_not_called() + commit_call = mock.call( + db_str, [], transaction=txn_id, metadata=transaction._client._rpc_metadata + ) + self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) + + def test___call__failure(self): + from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _EXCEED_ATTEMPTS_TEMPLATE + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = self._make_one(to_wrap) + + txn_id = b"only-one-shot" + transaction = _make_transaction(txn_id, max_attempts=1) + + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Contention just once.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + # Call the __call__-able ``wrapped``. + with self.assertRaises(ValueError) as exc_info: + asyncio.run(wrapped(transaction, "here", there=1.5)) + + err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + self.assertIsNone(transaction._id) + self.assertEqual(wrapped.current_id, txn_id) + self.assertEqual(wrapped.retry_id, txn_id) + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( + transaction._client._database_string, + options_=None, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + transaction._client._database_string, + txn_id, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_called_once_with( + transaction._client._database_string, + [], + transaction=txn_id, + metadata=transaction._client._rpc_metadata, + ) + + +class Test_transactional(unittest.TestCase): + @staticmethod + def _call_fut(to_wrap): + from google.cloud.firestore_v1.async_transaction import transactional + + return transactional(to_wrap) + + def test_it(self): + from google.cloud.firestore_v1.async_transaction import _Transactional + + wrapped = self._call_fut(mock.sentinel.callable_) + self.assertIsInstance(wrapped, _Transactional) + self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) + + +class Test__commit_with_retry(unittest.TestCase): + @staticmethod + def _call_fut(client, write_pbs, transaction_id): + from google.cloud.firestore_v1.async_transaction import _commit_with_retry + + return asyncio.run(_commit_with_retry(client, write_pbs, transaction_id)) + + @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") + def test_success_first_attempt(self, _sleep): + from google.cloud.firestore_v1.gapic import firestore_client + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + + # Attach the fake GAPIC to a real client. + client = _make_client("summer") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"cheeeeeez" + commit_response = self._call_fut(client, mock.sentinel.write_pbs, txn_id) + self.assertIs(commit_response, firestore_api.commit.return_value) + + # Verify mocks used. + _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + client._database_string, + mock.sentinel.write_pbs, + transaction=txn_id, + metadata=client._rpc_metadata, + ) + + @mock.patch("google.cloud.firestore_v1.async_transaction._sleep", side_effect=[2.0, 4.0]) + def test_success_third_attempt(self, _sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.gapic import firestore_client + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first two requests fail and the third succeeds. + firestore_api.commit.side_effect = [ + exceptions.ServiceUnavailable("Server sleepy."), + exceptions.ServiceUnavailable("Server groggy."), + mock.sentinel.commit_response, + ] + + # Attach the fake GAPIC to a real client. + client = _make_client("outside") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"the-world\x00" + commit_response = self._call_fut(client, mock.sentinel.write_pbs, txn_id) + self.assertIs(commit_response, mock.sentinel.commit_response) + + # Verify mocks used. + self.assertEqual(_sleep.call_count, 2) + _sleep.assert_any_call(1.0) + _sleep.assert_any_call(2.0) + # commit() called same way 3 times. + commit_call = mock.call( + client._database_string, + mock.sentinel.write_pbs, + transaction=txn_id, + metadata=client._rpc_metadata, + ) + self.assertEqual( + firestore_api.commit.mock_calls, [commit_call, commit_call, commit_call] + ) + + @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") + def test_failure_first_attempt(self, _sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.gapic import firestore_client + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first request fails with an un-retryable error. + exc = exceptions.ResourceExhausted("We ran out of fries.") + firestore_api.commit.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" + with self.assertRaises(exceptions.ResourceExhausted) as exc_info: + self._call_fut(client, mock.sentinel.write_pbs, txn_id) + + self.assertIs(exc_info.exception, exc) + + # Verify mocks used. + _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + client._database_string, + mock.sentinel.write_pbs, + transaction=txn_id, + metadata=client._rpc_metadata, + ) + + @mock.patch("google.cloud.firestore_v1.async_transaction._sleep", return_value=2.0) + def test_failure_second_attempt(self, _sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.gapic import firestore_client + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first request fails retry-able and second + # fails non-retryable. + exc1 = exceptions.ServiceUnavailable("Come back next time.") + exc2 = exceptions.InternalServerError("Server on fritz.") + firestore_api.commit.side_effect = [exc1, exc2] + + # Attach the fake GAPIC to a real client. + client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"the-journey-when-and-where-well-go" + with self.assertRaises(exceptions.InternalServerError) as exc_info: + self._call_fut(client, mock.sentinel.write_pbs, txn_id) + + self.assertIs(exc_info.exception, exc2) + + # Verify mocks used. + _sleep.assert_called_once_with(1.0) + # commit() called same way 2 times. + commit_call = mock.call( + client._database_string, + mock.sentinel.write_pbs, + transaction=txn_id, + metadata=client._rpc_metadata, + ) + self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) + + +class Test__sleep(unittest.TestCase): + @staticmethod + def _call_fut(current_sleep, **kwargs): + from google.cloud.firestore_v1.async_transaction import _sleep + + return asyncio.run(_sleep(current_sleep, **kwargs)) + + @mock.patch("random.uniform", return_value=5.5) + @mock.patch("asyncio.sleep", return_value=None) + def test_defaults(self, sleep, uniform): + curr_sleep = 10.0 + self.assertLessEqual(uniform.return_value, curr_sleep) + + new_sleep = self._call_fut(curr_sleep) + self.assertEqual(new_sleep, 2.0 * curr_sleep) + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + @mock.patch("random.uniform", return_value=10.5) + @mock.patch("asyncio.sleep", return_value=None) + def test_explicit(self, sleep, uniform): + curr_sleep = 12.25 + self.assertLessEqual(uniform.return_value, curr_sleep) + + multiplier = 1.5 + new_sleep = self._call_fut(curr_sleep, max_sleep=100.0, multiplier=multiplier) + self.assertEqual(new_sleep, multiplier * curr_sleep) + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + @mock.patch("random.uniform", return_value=6.75) + @mock.patch("asyncio.sleep", return_value=None) + def test_exceeds_max(self, sleep, uniform): + curr_sleep = 20.0 + self.assertLessEqual(uniform.return_value, curr_sleep) + + max_sleep = 38.5 + new_sleep = self._call_fut(curr_sleep, max_sleep=max_sleep, multiplier=2.0) + self.assertEqual(new_sleep, max_sleep) + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="feral-tom-cat"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) + + +def _make_transaction(txn_id, **txn_kwargs): + from google.protobuf import empty_pb2 + from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + # Create a fake GAPIC ... + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # ... with a dummy ``BeginTransactionResponse`` result ... + begin_response = firestore_pb2.BeginTransactionResponse(transaction=txn_id) + firestore_api.begin_transaction.return_value = begin_response + # ... and a dummy ``Rollback`` result ... + firestore_api.rollback.return_value = empty_pb2.Empty() + # ... and a dummy ``Commit`` result. + commit_response = firestore_pb2.CommitResponse( + write_results=[write_pb2.WriteResult()] + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + return AsyncTransaction(client, **txn_kwargs) From 48685ef4a19a63f029ceadc7942d709dbec0184c Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 10:35:01 -0500 Subject: [PATCH 18/47] fix: linter errors --- google/cloud/firestore_v1/async_query.py | 4 +++- google/cloud/firestore_v1/async_transaction.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 83c1e105f3..b6001914ab 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -930,7 +930,9 @@ def _enum_from_direction(direction): elif direction == AsyncQuery.DESCENDING: return enums.StructuredQuery.Direction.DESCENDING else: - msg = _BAD_DIR_STRING.format(direction, AsyncQuery.ASCENDING, AsyncQuery.DESCENDING) + msg = _BAD_DIR_STRING.format( + direction, AsyncQuery.ASCENDING, AsyncQuery.DESCENDING + ) raise ValueError(msg) diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 188d5596c0..0d0456318e 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -17,7 +17,6 @@ import asyncio import random -import time import six @@ -198,7 +197,9 @@ async def _commit(self): if not self.in_progress: raise ValueError(_CANT_COMMIT) - commit_response = await _commit_with_retry(self._client, self._write_pbs, self._id) + commit_response = await _commit_with_retry( + self._client, self._write_pbs, self._id + ) self._clean_up() return list(commit_response.write_results) From 36e914dd5538077b8456b2d49c3904de2cad1438 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 10:35:20 -0500 Subject: [PATCH 19/47] feat: refactor async tests to use aiounittest and pytest-asyncio --- noxfile.py | 3 +- tests/unit/v1/async/test_async_batch.py | 46 ++--- tests/unit/v1/async/test_async_client.py | 59 +++--- tests/unit/v1/async/test_async_collection.py | 38 ++-- tests/unit/v1/async/test_async_document.py | 141 +++++++------ tests/unit/v1/async/test_async_query.py | 70 +++---- tests/unit/v1/async/test_async_transaction.py | 186 +++++++++++------- 7 files changed, 292 insertions(+), 251 deletions(-) diff --git a/noxfile.py b/noxfile.py index cafa9785c2..236113a516 100644 --- a/noxfile.py +++ b/noxfile.py @@ -98,9 +98,10 @@ def unit(session): ) -@nox.session(python=["3.7", "3.8"]) +@nox.session(python=["3.6", "3.7", "3.8"]) def unit_async(session): """Run the unit test suite for async tests.""" + session.install("pytest-asyncio", "aiounittest") default(session, os.path.join("tests", "unit", "v1", "async"), None) diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py index 7df76a6dae..aa71999b5c 100644 --- a/tests/unit/v1/async/test_async_batch.py +++ b/tests/unit/v1/async/test_async_batch.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import unittest +import pytest +import aiounittest import mock -class TestAsyncWriteBatch(unittest.TestCase): +class TestAsyncWriteBatch(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_batch import AsyncWriteBatch @@ -153,7 +153,8 @@ def test_delete(self): new_write_pb = write_pb2.Write(delete=reference._document_path) self.assertEqual(batch._write_pbs, [new_write_pb]) - def test_commit(self): + @pytest.mark.asyncio + async def test_commit(self): from google.protobuf import timestamp_pb2 from google.cloud.firestore_v1.proto import firestore_pb2 from google.cloud.firestore_v1.proto import write_pb2 @@ -179,7 +180,7 @@ def test_commit(self): batch.delete(document2) write_pbs = batch._write_pbs[::] - write_results = asyncio.run(batch.commit()) + write_results = await batch.commit() self.assertEqual(write_results, list(commit_response.write_results)) self.assertEqual(batch.write_results, write_results) self.assertEqual(batch.commit_time, timestamp) @@ -194,7 +195,8 @@ def test_commit(self): metadata=client._rpc_metadata, ) - def test_as_context_mgr_wo_error(self): + @pytest.mark.asyncio + async def test_as_context_mgr_wo_error(self): from google.protobuf import timestamp_pb2 from google.cloud.firestore_v1.proto import firestore_pb2 from google.cloud.firestore_v1.proto import write_pb2 @@ -212,9 +214,11 @@ def test_as_context_mgr_wo_error(self): document1 = client.document("a", "b") document2 = client.document("c", "d", "e", "f") - write_pbs = asyncio.run( - self._as_context_mgr_wo_error_helper(batch, document1, document2) - ) + async with batch as ctx_mgr: + self.assertIs(ctx_mgr, batch) + ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.delete(document2) + write_pbs = batch._write_pbs[::] self.assertEqual(batch.write_results, list(commit_response.write_results)) self.assertEqual(batch.commit_time, timestamp) @@ -229,15 +233,8 @@ def test_as_context_mgr_wo_error(self): metadata=client._rpc_metadata, ) - async def _as_context_mgr_wo_error_helper(self, batch, document1, document2): - async with batch as ctx_mgr: - self.assertIs(ctx_mgr, batch) - ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) - ctx_mgr.delete(document2) - write_pbs = batch._write_pbs[::] - return write_pbs - - def test_as_context_mgr_w_error(self): + @pytest.mark.asyncio + async def test_as_context_mgr_w_error(self): firestore_api = mock.Mock(spec=["commit"]) client = _make_client() client._firestore_api_internal = firestore_api @@ -245,7 +242,11 @@ def test_as_context_mgr_w_error(self): document1 = client.document("a", "b") document2 = client.document("c", "d", "e", "f") - asyncio.run(self._as_context_mgr_w_error_helper(batch, document1, document2)) + with self.assertRaises(RuntimeError): + async with batch as ctx_mgr: + ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) + ctx_mgr.delete(document2) + raise RuntimeError("testing") self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) @@ -254,13 +255,6 @@ def test_as_context_mgr_w_error(self): firestore_api.commit.assert_not_called() - async def _as_context_mgr_w_error_helper(self, batch, document1, document2): - with self.assertRaises(RuntimeError): - async with batch as ctx_mgr: - ctx_mgr.create(document1, {"ten": 10, "buck": u"ets"}) - ctx_mgr.delete(document2) - raise RuntimeError("testing") - def _value_pb(**kwargs): from google.cloud.firestore_v1.proto.document_pb2 import Value diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index 47db52eca9..f5af7f5107 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio +import pytest import datetime import types -import unittest +import aiounittest import mock -class TestAsyncClient(unittest.TestCase): +class TestAsyncClient(aiounittest.AsyncTestCase): PROJECT = "my-prahjekt" @@ -333,7 +333,8 @@ def test_write_bad_arg(self): extra = "{!r} was provided".format("spinach") self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR, extra)) - def test_collections(self): + @pytest.mark.asyncio + async def test_collections(self): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -356,7 +357,7 @@ def _next_page(self): iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = list(asyncio.run(client.collections())) + collections = list(await client.collections()) self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): @@ -398,7 +399,8 @@ def _info_for_get_all(self, data1, data2): return client, document1, document2, response1, response2 - def test_get_all(self): + @pytest.mark.asyncio + async def test_get_all(self): from google.cloud.firestore_v1.proto import common_pb2 from google.cloud.firestore_v1.async_document import DocumentSnapshot @@ -409,13 +411,11 @@ def test_get_all(self): # Exercise the mocked ``batch_get_documents``. field_paths = ["a", "b"] - snapshots = asyncio.run( - self._get_all_helper( - client, - [document1, document2], - [response1, response2], - field_paths=field_paths, - ) + snapshots = await self._get_all_helper( + client, + [document1, document2], + [response1, response2], + field_paths=field_paths, ) self.assertEqual(len(snapshots), 2) @@ -440,7 +440,8 @@ def test_get_all(self): metadata=client._rpc_metadata, ) - def test_get_all_with_transaction(self): + @pytest.mark.asyncio + async def test_get_all_with_transaction(self): from google.cloud.firestore_v1.async_document import DocumentSnapshot data = {"so-much": 484} @@ -451,10 +452,8 @@ def test_get_all_with_transaction(self): transaction._id = txn_id # Exercise the mocked ``batch_get_documents``. - snapshots = asyncio.run( - self._get_all_helper( - client, [document], [response], transaction=transaction - ) + snapshots = await self._get_all_helper( + client, [document], [response], transaction=transaction ) self.assertEqual(len(snapshots), 1) @@ -473,7 +472,8 @@ def test_get_all_with_transaction(self): metadata=client._rpc_metadata, ) - def test_get_all_unknown_result(self): + @pytest.mark.asyncio + async def test_get_all_unknown_result(self): from google.cloud.firestore_v1.async_client import _BAD_DOC_TEMPLATE info = self._info_for_get_all({"z": 28.5}, {}) @@ -481,7 +481,7 @@ def test_get_all_unknown_result(self): # Exercise the mocked ``batch_get_documents``. with self.assertRaises(ValueError) as exc_info: - asyncio.run(self._get_all_helper(client, [document], [response])) + await self._get_all_helper(client, [document], [response]) err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) self.assertEqual(exc_info.exception.args, (err_msg,)) @@ -496,7 +496,8 @@ def test_get_all_unknown_result(self): metadata=client._rpc_metadata, ) - def test_get_all_wrong_order(self): + @pytest.mark.asyncio + async def test_get_all_wrong_order(self): from google.cloud.firestore_v1.async_document import DocumentSnapshot data1 = {"up": 10} @@ -507,12 +508,8 @@ def test_get_all_wrong_order(self): response3 = _make_batch_response(missing=document3._document_path) # Exercise the mocked ``batch_get_documents``. - snapshots = asyncio.run( - self._get_all_helper( - client, - [document1, document2, document3], - [response2, response1, response3], - ) + snapshots = await self._get_all_helper( + client, [document1, document2, document3], [response2, response1, response3] ) self.assertEqual(len(snapshots), 3) @@ -564,7 +561,7 @@ def test_transaction(self): self.assertIsNone(transaction._id) -class Test__reference_info(unittest.TestCase): +class Test__reference_info(aiounittest.AsyncTestCase): @staticmethod def _call_fut(references): from google.cloud.firestore_v1.async_client import _reference_info @@ -601,7 +598,7 @@ def test_it(self): self.assertEqual(reference_map, expected_map) -class Test__get_reference(unittest.TestCase): +class Test__get_reference(aiounittest.AsyncTestCase): @staticmethod def _call_fut(document_path, reference_map): from google.cloud.firestore_v1.async_client import _get_reference @@ -624,7 +621,7 @@ def test_failure(self): self.assertEqual(exc_info.exception.args, (err_msg,)) -class Test__parse_batch_get(unittest.TestCase): +class Test__parse_batch_get(aiounittest.AsyncTestCase): @staticmethod def _call_fut(get_doc_response, reference_map, client=mock.sentinel.client): from google.cloud.firestore_v1.async_client import _parse_batch_get @@ -702,7 +699,7 @@ def test_unknown_result_type(self): response_pb.WhichOneof.assert_called_once_with("result") -class Test__get_doc_mask(unittest.TestCase): +class Test__get_doc_mask(aiounittest.AsyncTestCase): @staticmethod def _call_fut(field_paths): from google.cloud.firestore_v1.async_client import _get_doc_mask diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index c9afe7486e..26a11ea735 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import pytest import types -import unittest +import aiounittest import mock import six -class TestAsyncCollectionReference(unittest.TestCase): +class TestAsyncCollectionReference(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -190,7 +189,8 @@ def test__parent_info_nested(self): prefix = "{}/{}".format(expected_path, collection_id2) self.assertEqual(expected_prefix, prefix) - def test_add_auto_assigned(self): + @pytest.mark.asyncio + async def test_add_auto_assigned(self): from google.cloud.firestore_v1.proto import document_pb2 from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1 import SERVER_TIMESTAMP @@ -223,7 +223,7 @@ def test_add_auto_assigned(self): random_doc_id = "DEADBEEF" with patch as patched: patched.return_value = random_doc_id - update_time, document_ref = asyncio.run(collection.add(document_data)) + update_time, document_ref = await collection.add(document_data) # Verify the response and the mocks. self.assertIs(update_time, mock.sentinel.update_time) @@ -256,7 +256,8 @@ def _write_pb_for_create(document_path, document_data): current_document=common_pb2.Precondition(exists=False), ) - def test_add_explicit_id(self): + @pytest.mark.asyncio + async def test_add_explicit_id(self): from google.cloud.firestore_v1.async_document import AsyncDocumentReference # Create a minimal fake GAPIC with a dummy response. @@ -279,8 +280,8 @@ def test_add_explicit_id(self): collection = self._make_one("parent", client=client) document_data = {"zorp": 208.75, "i-did-not": b"know that"} doc_id = "child" - update_time, document_ref = asyncio.run( - collection.add(document_data, document_id=doc_id) + update_time, document_ref = await collection.add( + document_data, document_id=doc_id ) # Verify the response and the mocks. @@ -430,7 +431,8 @@ def test_end_at(self): self.assertIs(query._parent, collection) self.assertEqual(query._end_at, (doc_fields, False)) - def _list_documents_helper(self, page_size=None): + @pytest.mark.asyncio + async def _list_documents_helper(self, page_size=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_document import AsyncDocumentReference @@ -460,11 +462,9 @@ def _next_page(self): collection = self._make_one("collection", client=client) if page_size is not None: - documents = list( - asyncio.run(collection.list_documents(page_size=page_size)) - ) + documents = list(await collection.list_documents(page_size=page_size)) else: - documents = list(asyncio.run(collection.list_documents())) + documents = list(await collection.list_documents()) # Verify the response and the mocks. self.assertEqual(len(documents), len(document_ids)) @@ -482,11 +482,13 @@ def _next_page(self): metadata=client._rpc_metadata, ) - def test_list_documents_wo_page_size(self): - self._list_documents_helper() + @pytest.mark.asyncio + async def test_list_documents_wo_page_size(self): + await self._list_documents_helper() - def test_list_documents_w_page_size(self): - self._list_documents_helper(page_size=25) + @pytest.mark.asyncio + async def test_list_documents_w_page_size(self): + await self._list_documents_helper(page_size=25) @pytest.mark.skip(reason="no way of currently testing this") @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @@ -555,7 +557,7 @@ def test_on_snapshot(self, watch): watch.for_query.assert_called_once() -class Test__auto_id(unittest.TestCase): +class Test__auto_id(aiounittest.AsyncTestCase): @staticmethod def _call_fut(): from google.cloud.firestore_v1.async_collection import _auto_id diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py index e3a04918d6..2f26f6a2a8 100644 --- a/tests/unit/v1/async/test_async_document.py +++ b/tests/unit/v1/async/test_async_document.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio +import pytest import collections -import unittest +import aiounittest import mock -class TestAsyncDocumentReference(unittest.TestCase): +class TestAsyncDocumentReference(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_document import AsyncDocumentReference @@ -217,7 +217,8 @@ def _make_commit_repsonse(write_results=None): response.commit_time = mock.sentinel.commit_time return response - def test_create(self): + @pytest.mark.asyncio + async def test_create(self): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) firestore_api.commit.return_value = self._make_commit_repsonse() @@ -229,7 +230,7 @@ def test_create(self): # Actually make a document and call create(). document = self._make_one("foo", "twelve", client=client) document_data = {"hello": "goodbye", "count": 99} - write_result = asyncio.run(document.create(document_data)) + write_result = await document.create(document_data) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -241,7 +242,8 @@ def test_create(self): metadata=client._rpc_metadata, ) - def test_create_empty(self): + @pytest.mark.asyncio + async def test_create_empty(self): # Create a minimal fake GAPIC with a dummy response. from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_document import DocumentSnapshot @@ -264,8 +266,8 @@ def test_create_empty(self): # Actually make a document and call create(). document = self._make_one("foo", "twelve", client=client) document_data = {} - write_result = asyncio.run(document.create(document_data)) - self.assertTrue(asyncio.run(write_result.get()).exists) + write_result = await document.create(document_data) + self.assertTrue((await write_result.get()).exists) @staticmethod def _write_pb_for_set(document_path, document_data, merge): @@ -293,7 +295,8 @@ def _write_pb_for_set(document_path, document_data, merge): write_pbs.update_mask.CopyFrom(mask) return write_pbs - def _set_helper(self, merge=False, **option_kwargs): + @pytest.mark.asyncio + async def _set_helper(self, merge=False, **option_kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) firestore_api.commit.return_value = self._make_commit_repsonse() @@ -305,7 +308,7 @@ def _set_helper(self, merge=False, **option_kwargs): # Actually make a document and call create(). document = self._make_one("User", "Interface", client=client) document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} - write_result = asyncio.run(document.set(document_data, merge)) + write_result = await document.set(document_data, merge) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -318,11 +321,13 @@ def _set_helper(self, merge=False, **option_kwargs): metadata=client._rpc_metadata, ) - def test_set(self): - self._set_helper() + @pytest.mark.asyncio + async def test_set(self): + await self._set_helper() - def test_set_merge(self): - self._set_helper(merge=True) + @pytest.mark.asyncio + async def test_set_merge(self): + await self._set_helper(merge=True) @staticmethod def _write_pb_for_update(document_path, update_values, field_paths): @@ -339,7 +344,8 @@ def _write_pb_for_update(document_path, update_values, field_paths): current_document=common_pb2.Precondition(exists=True), ) - def _update_helper(self, **option_kwargs): + @pytest.mark.asyncio + async def _update_helper(self, **option_kwargs): from google.cloud.firestore_v1.transforms import DELETE_FIELD # Create a minimal fake GAPIC with a dummy response. @@ -358,10 +364,10 @@ def _update_helper(self, **option_kwargs): ) if option_kwargs: option = client.write_option(**option_kwargs) - write_result = asyncio.run(document.update(field_updates, option=option)) + write_result = await document.update(field_updates, option=option) else: option = None - write_result = asyncio.run(document.update(field_updates)) + write_result = await document.update(field_updates) # Verify the response and the mocks. self.assertIs(write_result, mock.sentinel.write_result) @@ -382,20 +388,24 @@ def _update_helper(self, **option_kwargs): metadata=client._rpc_metadata, ) - def test_update_with_exists(self): + @pytest.mark.asyncio + async def test_update_with_exists(self): with self.assertRaises(ValueError): - self._update_helper(exists=True) + await self._update_helper(exists=True) - def test_update(self): - self._update_helper() + @pytest.mark.asyncio + async def test_update(self): + await self._update_helper() - def test_update_with_precondition(self): + @pytest.mark.asyncio + async def test_update_with_precondition(self): from google.protobuf import timestamp_pb2 timestamp = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) - self._update_helper(last_update_time=timestamp) + await self._update_helper(last_update_time=timestamp) - def test_empty_update(self): + @pytest.mark.asyncio + async def test_empty_update(self): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) firestore_api.commit.return_value = self._make_commit_repsonse() @@ -409,9 +419,10 @@ def test_empty_update(self): # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. field_updates = {} with self.assertRaises(ValueError): - asyncio.run(document.update(field_updates)) + await document.update(field_updates) - def _delete_helper(self, **option_kwargs): + @pytest.mark.asyncio + async def _delete_helper(self, **option_kwargs): from google.cloud.firestore_v1.proto import write_pb2 # Create a minimal fake GAPIC with a dummy response. @@ -426,10 +437,10 @@ def _delete_helper(self, **option_kwargs): document = self._make_one("where", "we-are", client=client) if option_kwargs: option = client.write_option(**option_kwargs) - delete_time = asyncio.run(document.delete(option=option)) + delete_time = await document.delete(option=option) else: option = None - delete_time = asyncio.run(document.delete()) + delete_time = await document.delete() # Verify the response and the mocks. self.assertIs(delete_time, mock.sentinel.commit_time) @@ -443,16 +454,21 @@ def _delete_helper(self, **option_kwargs): metadata=client._rpc_metadata, ) - def test_delete(self): - self._delete_helper() + @pytest.mark.asyncio + async def test_delete(self): + await self._delete_helper() - def test_delete_with_option(self): + @pytest.mark.asyncio + async def test_delete_with_option(self): from google.protobuf import timestamp_pb2 timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) - self._delete_helper(last_update_time=timestamp_pb) + await self._delete_helper(last_update_time=timestamp_pb) - def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): + @pytest.mark.asyncio + async def _get_helper( + self, field_paths=None, use_transaction=False, not_found=False + ): from google.api_core.exceptions import NotFound from google.cloud.firestore_v1.proto import common_pb2 from google.cloud.firestore_v1.proto import document_pb2 @@ -483,9 +499,7 @@ def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): else: transaction = None - snapshot = asyncio.run( - document.get(field_paths=field_paths, transaction=transaction) - ) + snapshot = await document.get(field_paths=field_paths, transaction=transaction) self.assertIs(snapshot.reference, document) if not_found: @@ -519,26 +533,33 @@ def _get_helper(self, field_paths=None, use_transaction=False, not_found=False): metadata=client._rpc_metadata, ) - def test_get_not_found(self): - self._get_helper(not_found=True) + @pytest.mark.asyncio + async def test_get_not_found(self): + await self._get_helper(not_found=True) - def test_get_default(self): - self._get_helper() + @pytest.mark.asyncio + async def test_get_default(self): + await self._get_helper() - def test_get_w_string_field_path(self): + @pytest.mark.asyncio + async def test_get_w_string_field_path(self): with self.assertRaises(ValueError): - self._get_helper(field_paths="foo") + await self._get_helper(field_paths="foo") - def test_get_with_field_path(self): - self._get_helper(field_paths=["foo"]) + @pytest.mark.asyncio + async def test_get_with_field_path(self): + await self._get_helper(field_paths=["foo"]) - def test_get_with_multiple_field_paths(self): - self._get_helper(field_paths=["foo", "bar.baz"]) + @pytest.mark.asyncio + async def test_get_with_multiple_field_paths(self): + await self._get_helper(field_paths=["foo", "bar.baz"]) - def test_get_with_transaction(self): - self._get_helper(use_transaction=True) + @pytest.mark.asyncio + async def test_get_with_transaction(self): + await self._get_helper(use_transaction=True) - def _collections_helper(self, page_size=None): + @pytest.mark.asyncio + async def _collections_helper(self, page_size=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -565,9 +586,9 @@ def _next_page(self): # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) if page_size is not None: - collections = list(asyncio.run(document.collections(page_size=page_size))) + collections = list(await document.collections(page_size=page_size)) else: - collections = list(asyncio.run(document.collections())) + collections = list(await document.collections()) # Verify the response and the mocks. self.assertEqual(len(collections), len(collection_ids)) @@ -580,11 +601,13 @@ def _next_page(self): document._document_path, page_size=page_size, metadata=client._rpc_metadata ) - def test_collections_wo_page_size(self): - self._collections_helper() + @pytest.mark.asyncio + async def test_collections_wo_page_size(self): + await self._collections_helper() - def test_collections_w_page_size(self): - self._collections_helper(page_size=10) + @pytest.mark.asyncio + async def test_collections_w_page_size(self): + await self._collections_helper(page_size=10) @mock.patch("google.cloud.firestore_v1.async_document.Watch", autospec=True) def test_on_snapshot(self, watch): @@ -594,7 +617,7 @@ def test_on_snapshot(self, watch): watch.for_document.assert_called_once() -class TestDocumentSnapshot(unittest.TestCase): +class TestDocumentSnapshot(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_document import DocumentSnapshot @@ -740,7 +763,7 @@ def test_non_existent(self): self.assertIsNone(as_dict) -class Test__get_document_path(unittest.TestCase): +class Test__get_document_path(aiounittest.AsyncTestCase): @staticmethod def _call_fut(client, path): from google.cloud.firestore_v1.async_document import _get_document_path @@ -759,7 +782,7 @@ def test_it(self): self.assertEqual(document_path, expected) -class Test__consume_single_get(unittest.TestCase): +class Test__consume_single_get(aiounittest.AsyncTestCase): @staticmethod def _call_fut(response_iterator): from google.cloud.firestore_v1.async_document import _consume_single_get @@ -782,7 +805,7 @@ def test_failure_too_many(self): self._call_fut(response_iterator) -class Test__first_write_result(unittest.TestCase): +class Test__first_write_result(aiounittest.AsyncTestCase): @staticmethod def _call_fut(write_results): from google.cloud.firestore_v1.async_document import _first_write_result diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py index 9b47641522..8f773a501e 100644 --- a/tests/unit/v1/async/test_async_query.py +++ b/tests/unit/v1/async/test_async_query.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio +import pytest import datetime import types -import unittest +import aiounittest import mock import six -class TestAsyncQuery(unittest.TestCase): +class TestAsyncQuery(aiounittest.AsyncTestCase): if six.PY2: - assertRaisesRegex = unittest.TestCase.assertRaisesRegexp + assertRaisesRegex = aiounittest.AsyncTestCase.assertRaisesRegexp @staticmethod def _get_target_class(): @@ -1063,10 +1063,8 @@ def test__to_protobuf_limit_only(self): self.assertEqual(structured_query_pb, expected_pb) - def test_get_simple(self): - asyncio.run(self._test_get_simple_helper()) - - async def _test_get_simple_helper(self): + @pytest.mark.asyncio + async def test_get_simple(self): import warnings # Create a minimal fake GAPIC. @@ -1112,10 +1110,8 @@ async def _test_get_simple_helper(self): self.assertEqual(len(warned), 1) self.assertIs(warned[0].category, DeprecationWarning) - def test_stream_simple(self): - asyncio.run(self._test_stream_simple_helper()) - - async def _test_stream_simple_helper(self): + @pytest.mark.asyncio + async def test_stream_simple(self): # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -1152,10 +1148,8 @@ async def _test_stream_simple_helper(self): metadata=client._rpc_metadata, ) - def test_stream_with_transaction(self): - asyncio.run(self._test_stream_with_transaction_helper()) - - async def _test_stream_with_transaction_helper(self): + @pytest.mark.asyncio + async def test_stream_with_transaction(self): # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -1196,10 +1190,8 @@ async def _test_stream_with_transaction_helper(self): metadata=client._rpc_metadata, ) - def test_stream_no_results(self): - asyncio.run(self._test_stream_no_results_helper()) - - async def _test_stream_no_results_helper(self): + @pytest.mark.asyncio + async def test_stream_no_results(self): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["run_query"]) empty_response = _make_query_response() @@ -1227,10 +1219,8 @@ async def _test_stream_no_results_helper(self): metadata=client._rpc_metadata, ) - def test_stream_second_response_in_empty_stream(self): - asyncio.run(self._test_stream_second_response_in_empty_stream_helper()) - - async def _test_stream_second_response_in_empty_stream_helper(self): + @pytest.mark.asyncio + async def test_stream_second_response_in_empty_stream(self): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["run_query"]) empty_response1 = _make_query_response() @@ -1259,10 +1249,8 @@ async def _test_stream_second_response_in_empty_stream_helper(self): metadata=client._rpc_metadata, ) - def test_stream_with_skipped_results(self): - asyncio.run(self._test_stream_with_skipped_results_helper()) - - async def _test_stream_with_skipped_results_helper(self): + @pytest.mark.asyncio + async def test_stream_with_skipped_results(self): # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -1300,10 +1288,8 @@ async def _test_stream_with_skipped_results_helper(self): metadata=client._rpc_metadata, ) - def test_stream_empty_after_first_response(self): - asyncio.run(self._test_stream_empty_after_first_response_helper()) - - async def _test_stream_empty_after_first_response_helper(self): + @pytest.mark.asyncio + async def test_stream_empty_after_first_response(self): # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -1341,10 +1327,8 @@ async def _test_stream_empty_after_first_response_helper(self): metadata=client._rpc_metadata, ) - def test_stream_w_collection_group(self): - asyncio.run(self._test_stream_w_collection_group_helper()) - - async def _test_stream_w_collection_group_helper(self): + @pytest.mark.asyncio + async def test_stream_w_collection_group(self): # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -1482,7 +1466,7 @@ def test_comparator_missing_order_by_field_in_data_raises(self): query._comparator(doc1, doc2) -class Test__enum_from_op_string(unittest.TestCase): +class Test__enum_from_op_string(aiounittest.AsyncTestCase): @staticmethod def _call_fut(op_string): from google.cloud.firestore_v1.async_query import _enum_from_op_string @@ -1534,7 +1518,7 @@ def test_invalid(self): self._call_fut("?") -class Test__isnan(unittest.TestCase): +class Test__isnan(aiounittest.AsyncTestCase): @staticmethod def _call_fut(value): from google.cloud.firestore_v1.async_query import _isnan @@ -1552,7 +1536,7 @@ def test_invalid(self): self.assertFalse(self._call_fut(1.0 + 1.0j)) -class Test__enum_from_direction(unittest.TestCase): +class Test__enum_from_direction(aiounittest.AsyncTestCase): @staticmethod def _call_fut(direction): from google.cloud.firestore_v1.async_query import _enum_from_direction @@ -1576,7 +1560,7 @@ def test_failure(self): self._call_fut("neither-ASCENDING-nor-DESCENDING") -class Test__filter_pb(unittest.TestCase): +class Test__filter_pb(aiounittest.AsyncTestCase): @staticmethod def _call_fut(field_or_unary): from google.cloud.firestore_v1.async_query import _filter_pb @@ -1614,7 +1598,7 @@ def test_bad_type(self): self._call_fut(None) -class Test__cursor_pb(unittest.TestCase): +class Test__cursor_pb(aiounittest.AsyncTestCase): @staticmethod def _call_fut(cursor_pair): from google.cloud.firestore_v1.async_query import _cursor_pb @@ -1639,7 +1623,7 @@ def test_success(self): self.assertEqual(cursor_pb, expected_pb) -class Test__query_response_to_snapshot(unittest.TestCase): +class Test__query_response_to_snapshot(aiounittest.AsyncTestCase): @staticmethod def _call_fut(response_pb, collection, expected_prefix): from google.cloud.firestore_v1.async_query import _query_response_to_snapshot @@ -1681,7 +1665,7 @@ def test_response(self): self.assertEqual(snapshot.update_time, response_pb.document.update_time) -class Test__collection_group_query_response_to_snapshot(unittest.TestCase): +class Test__collection_group_query_response_to_snapshot(aiounittest.AsyncTestCase): @staticmethod def _call_fut(response_pb, collection): from google.cloud.firestore_v1.async_query import ( diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py index c9fcfeda89..d4b0020832 100644 --- a/tests/unit/v1/async/test_async_transaction.py +++ b/tests/unit/v1/async/test_async_transaction.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import unittest +import pytest +import aiounittest import mock -class TestAsyncTransaction(unittest.TestCase): +class TestAsyncTransaction(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_transaction import AsyncTransaction @@ -115,7 +115,8 @@ def test_id_property(self): transaction._id = mock.sentinel.eye_dee self.assertIs(transaction.id, mock.sentinel.eye_dee) - def test__begin(self): + @pytest.mark.asyncio + async def test__begin(self): from google.cloud.firestore_v1.gapic import firestore_client from google.cloud.firestore_v1.proto import firestore_pb2 @@ -135,7 +136,7 @@ def test__begin(self): transaction = self._make_one(client) self.assertIsNone(transaction._id) - ret_val = asyncio.run(transaction._begin()) + ret_val = await transaction._begin() self.assertIsNone(ret_val) self.assertEqual(transaction._id, txn_id) @@ -144,7 +145,8 @@ def test__begin(self): client._database_string, options_=None, metadata=client._rpc_metadata ) - def test__begin_failure(self): + @pytest.mark.asyncio + async def test__begin_failure(self): from google.cloud.firestore_v1.async_transaction import _CANT_BEGIN client = _make_client() @@ -152,7 +154,7 @@ def test__begin_failure(self): transaction._id = b"not-none" with self.assertRaises(ValueError) as exc_info: - asyncio.run(transaction._begin()) + await transaction._begin() err_msg = _CANT_BEGIN.format(transaction._id) self.assertEqual(exc_info.exception.args, (err_msg,)) @@ -170,7 +172,8 @@ def test__clean_up(self): self.assertEqual(transaction._write_pbs, []) self.assertIsNone(transaction._id) - def test__rollback(self): + @pytest.mark.asyncio + async def test__rollback(self): from google.protobuf import empty_pb2 from google.cloud.firestore_v1.gapic import firestore_client @@ -188,7 +191,7 @@ def test__rollback(self): transaction = self._make_one(client) txn_id = b"to-be-r\x00lled" transaction._id = txn_id - ret_val = asyncio.run(transaction._rollback()) + ret_val = await transaction._rollback() self.assertIsNone(ret_val) self.assertIsNone(transaction._id) @@ -197,7 +200,8 @@ def test__rollback(self): client._database_string, txn_id, metadata=client._rpc_metadata ) - def test__rollback_not_allowed(self): + @pytest.mark.asyncio + async def test__rollback_not_allowed(self): from google.cloud.firestore_v1.async_transaction import _CANT_ROLLBACK client = _make_client() @@ -205,11 +209,12 @@ def test__rollback_not_allowed(self): self.assertIsNone(transaction._id) with self.assertRaises(ValueError) as exc_info: - asyncio.run(transaction._rollback()) + await transaction._rollback() self.assertEqual(exc_info.exception.args, (_CANT_ROLLBACK,)) - def test__rollback_failure(self): + @pytest.mark.asyncio + async def test__rollback_failure(self): from google.api_core import exceptions from google.cloud.firestore_v1.gapic import firestore_client @@ -230,7 +235,7 @@ def test__rollback_failure(self): transaction._id = txn_id with self.assertRaises(exceptions.InternalServerError) as exc_info: - asyncio.run(transaction._rollback()) + await transaction._rollback() self.assertIs(exc_info.exception, exc) self.assertIsNone(transaction._id) @@ -241,7 +246,8 @@ def test__rollback_failure(self): client._database_string, txn_id, metadata=client._rpc_metadata ) - def test__commit(self): + @pytest.mark.asyncio + async def test__commit(self): from google.cloud.firestore_v1.gapic import firestore_client from google.cloud.firestore_v1.proto import firestore_pb2 from google.cloud.firestore_v1.proto import write_pb2 @@ -267,7 +273,7 @@ def test__commit(self): transaction.set(document, {"apple": 4.5}) write_pbs = transaction._write_pbs[::] - write_results = asyncio.run(transaction._commit()) + write_results = await transaction._commit() self.assertEqual(write_results, list(commit_response.write_results)) # Make sure transaction has no more "changes". self.assertIsNone(transaction._id) @@ -281,17 +287,19 @@ def test__commit(self): metadata=client._rpc_metadata, ) - def test__commit_not_allowed(self): + @pytest.mark.asyncio + async def test__commit_not_allowed(self): from google.cloud.firestore_v1.async_transaction import _CANT_COMMIT transaction = self._make_one(mock.sentinel.client) self.assertIsNone(transaction._id) with self.assertRaises(ValueError) as exc_info: - asyncio.run(transaction._commit()) + await transaction._commit() self.assertEqual(exc_info.exception.args, (_CANT_COMMIT,)) - def test__commit_failure(self): + @pytest.mark.asyncio + async def test__commit_failure(self): from google.api_core import exceptions from google.cloud.firestore_v1.gapic import firestore_client @@ -315,7 +323,7 @@ def test__commit_failure(self): write_pbs = transaction._write_pbs[::] with self.assertRaises(exceptions.InternalServerError) as exc_info: - asyncio.run(transaction._commit()) + await transaction._commit() self.assertIs(exc_info.exception, exc) self.assertEqual(transaction._id, txn_id) @@ -329,44 +337,48 @@ def test__commit_failure(self): metadata=client._rpc_metadata, ) - def test_get_all(self): + @pytest.mark.asyncio + async def test_get_all(self): client = mock.Mock(spec=["get_all"]) transaction = self._make_one(client) ref1, ref2 = mock.Mock(), mock.Mock() - result = asyncio.run(transaction.get_all([ref1, ref2])) + result = await transaction.get_all([ref1, ref2]) client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction) self.assertIs(result, client.get_all.return_value) - def test_get_document_ref(self): + @pytest.mark.asyncio + async def test_get_document_ref(self): from google.cloud.firestore_v1.async_document import AsyncDocumentReference client = mock.Mock(spec=["get_all"]) transaction = self._make_one(client) ref = AsyncDocumentReference("documents", "doc-id") - result = asyncio.run(transaction.get(ref)) + result = await transaction.get(ref) client.get_all.assert_called_once_with([ref], transaction=transaction) self.assertIs(result, client.get_all.return_value) - def test_get_w_query(self): + @pytest.mark.asyncio + async def test_get_w_query(self): from google.cloud.firestore_v1.async_query import AsyncQuery client = mock.Mock(spec=[]) transaction = self._make_one(client) query = AsyncQuery(parent=mock.Mock(spec=[])) query.stream = mock.MagicMock() - result = asyncio.run(transaction.get(query)) + result = await transaction.get(query) query.stream.assert_called_once_with(transaction=transaction) self.assertIs(result, query.stream.return_value) - def test_get_failure(self): + @pytest.mark.asyncio + async def test_get_failure(self): client = _make_client() transaction = self._make_one(client) ref_or_query = object() with self.assertRaises(ValueError): - asyncio.run(transaction.get(ref_or_query)) + await transaction.get(ref_or_query) -class Test_Transactional(unittest.TestCase): +class Test_Transactional(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_transaction import _Transactional @@ -394,13 +406,14 @@ def test__reset(self): self.assertIsNone(wrapped.current_id) self.assertIsNone(wrapped.retry_id) - def test__pre_commit_success(self): + @pytest.mark.asyncio + async def test__pre_commit_success(self): to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"totes-began" transaction = _make_transaction(txn_id) - result = asyncio.run(wrapped._pre_commit(transaction, "pos", key="word")) + result = await wrapped._pre_commit(transaction, "pos", key="word") self.assertIs(result, mock.sentinel.result) self.assertEqual(transaction._id, txn_id) @@ -418,7 +431,8 @@ def test__pre_commit_success(self): firestore_api.rollback.assert_not_called() firestore_api.commit.assert_not_called() - def test__pre_commit_retry_id_already_set_success(self): + @pytest.mark.asyncio + async def test__pre_commit_retry_id_already_set_success(self): from google.cloud.firestore_v1.proto import common_pb2 to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) @@ -428,7 +442,7 @@ def test__pre_commit_retry_id_already_set_success(self): txn_id2 = b"ok-here-too" transaction = _make_transaction(txn_id2) - result = asyncio.run(wrapped._pre_commit(transaction)) + result = await wrapped._pre_commit(transaction) self.assertIs(result, mock.sentinel.result) self.assertEqual(transaction._id, txn_id2) @@ -451,7 +465,8 @@ def test__pre_commit_retry_id_already_set_success(self): firestore_api.rollback.assert_not_called() firestore_api.commit.assert_not_called() - def test__pre_commit_failure(self): + @pytest.mark.asyncio + async def test__pre_commit_failure(self): exc = RuntimeError("Nope not today.") to_wrap = mock.Mock(side_effect=exc, spec=[]) wrapped = self._make_one(to_wrap) @@ -459,7 +474,7 @@ def test__pre_commit_failure(self): txn_id = b"gotta-fail" transaction = _make_transaction(txn_id) with self.assertRaises(RuntimeError) as exc_info: - asyncio.run(wrapped._pre_commit(transaction, 10, 20)) + await wrapped._pre_commit(transaction, 10, 20) self.assertIs(exc_info.exception, exc) self.assertIsNone(transaction._id) @@ -481,7 +496,8 @@ def test__pre_commit_failure(self): ) firestore_api.commit.assert_not_called() - def test__pre_commit_failure_with_rollback_failure(self): + @pytest.mark.asyncio + async def test__pre_commit_failure_with_rollback_failure(self): from google.api_core import exceptions exc1 = ValueError("I will not be only failure.") @@ -497,7 +513,7 @@ def test__pre_commit_failure_with_rollback_failure(self): # Try to ``_pre_commit`` with self.assertRaises(exceptions.InternalServerError) as exc_info: - asyncio.run(wrapped._pre_commit(transaction, a="b", c="zebra")) + await wrapped._pre_commit(transaction, a="b", c="zebra") self.assertIs(exc_info.exception, exc2) self.assertIsNone(transaction._id) @@ -518,13 +534,14 @@ def test__pre_commit_failure_with_rollback_failure(self): ) firestore_api.commit.assert_not_called() - def test__maybe_commit_success(self): + @pytest.mark.asyncio + async def test__maybe_commit_success(self): wrapped = self._make_one(mock.sentinel.callable_) txn_id = b"nyet" transaction = _make_transaction(txn_id) transaction._id = txn_id # We won't call ``begin()``. - succeeded = asyncio.run(wrapped._maybe_commit(transaction)) + succeeded = await wrapped._maybe_commit(transaction) self.assertTrue(succeeded) # On success, _id is reset. @@ -541,7 +558,8 @@ def test__maybe_commit_success(self): metadata=transaction._client._rpc_metadata, ) - def test__maybe_commit_failure_read_only(self): + @pytest.mark.asyncio + async def test__maybe_commit_failure_read_only(self): from google.api_core import exceptions wrapped = self._make_one(mock.sentinel.callable_) @@ -559,7 +577,7 @@ def test__maybe_commit_failure_read_only(self): firestore_api.commit.side_effect = exc with self.assertRaises(exceptions.Aborted) as exc_info: - asyncio.run(wrapped._maybe_commit(transaction)) + await wrapped._maybe_commit(transaction) self.assertIs(exc_info.exception, exc) self.assertEqual(transaction._id, txn_id) @@ -576,7 +594,8 @@ def test__maybe_commit_failure_read_only(self): metadata=transaction._client._rpc_metadata, ) - def test__maybe_commit_failure_can_retry(self): + @pytest.mark.asyncio + async def test__maybe_commit_failure_can_retry(self): from google.api_core import exceptions wrapped = self._make_one(mock.sentinel.callable_) @@ -592,7 +611,7 @@ def test__maybe_commit_failure_can_retry(self): firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = exc - succeeded = asyncio.run(wrapped._maybe_commit(transaction)) + succeeded = await wrapped._maybe_commit(transaction) self.assertFalse(succeeded) self.assertEqual(transaction._id, txn_id) @@ -609,7 +628,8 @@ def test__maybe_commit_failure_can_retry(self): metadata=transaction._client._rpc_metadata, ) - def test__maybe_commit_failure_cannot_retry(self): + @pytest.mark.asyncio + async def test__maybe_commit_failure_cannot_retry(self): from google.api_core import exceptions wrapped = self._make_one(mock.sentinel.callable_) @@ -626,7 +646,7 @@ def test__maybe_commit_failure_cannot_retry(self): firestore_api.commit.side_effect = exc with self.assertRaises(exceptions.InternalServerError) as exc_info: - asyncio.run(wrapped._maybe_commit(transaction)) + await wrapped._maybe_commit(transaction) self.assertIs(exc_info.exception, exc) self.assertEqual(transaction._id, txn_id) @@ -643,13 +663,14 @@ def test__maybe_commit_failure_cannot_retry(self): metadata=transaction._client._rpc_metadata, ) - def test___call__success_first_attempt(self): + @pytest.mark.asyncio + async def test___call__success_first_attempt(self): to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) txn_id = b"whole-enchilada" transaction = _make_transaction(txn_id) - result = asyncio.run(wrapped(transaction, "a", b="c")) + result = await wrapped(transaction, "a", b="c") self.assertIs(result, mock.sentinel.result) self.assertIsNone(transaction._id) @@ -672,7 +693,8 @@ def test___call__success_first_attempt(self): metadata=transaction._client._rpc_metadata, ) - def test___call__success_second_attempt(self): + @pytest.mark.asyncio + async def test___call__success_second_attempt(self): from google.api_core import exceptions from google.cloud.firestore_v1.proto import common_pb2 from google.cloud.firestore_v1.proto import firestore_pb2 @@ -693,7 +715,7 @@ def test___call__success_second_attempt(self): ] # Call the __call__-able ``wrapped``. - result = asyncio.run(wrapped(transaction, "a", b="c")) + result = await wrapped(transaction, "a", b="c") self.assertIs(result, mock.sentinel.result) self.assertIsNone(transaction._id) @@ -727,9 +749,12 @@ def test___call__success_second_attempt(self): ) self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) - def test___call__failure(self): + @pytest.mark.asyncio + async def test___call__failure(self): from google.api_core import exceptions - from google.cloud.firestore_v1.async_transaction import _EXCEED_ATTEMPTS_TEMPLATE + from google.cloud.firestore_v1.async_transaction import ( + _EXCEED_ATTEMPTS_TEMPLATE, + ) to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) @@ -744,7 +769,7 @@ def test___call__failure(self): # Call the __call__-able ``wrapped``. with self.assertRaises(ValueError) as exc_info: - asyncio.run(wrapped(transaction, "here", there=1.5)) + await wrapped(transaction, "here", there=1.5) err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) self.assertEqual(exc_info.exception.args, (err_msg,)) @@ -773,7 +798,7 @@ def test___call__failure(self): ) -class Test_transactional(unittest.TestCase): +class Test_transactional(aiounittest.AsyncTestCase): @staticmethod def _call_fut(to_wrap): from google.cloud.firestore_v1.async_transaction import transactional @@ -788,15 +813,17 @@ def test_it(self): self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) -class Test__commit_with_retry(unittest.TestCase): +class Test__commit_with_retry(aiounittest.AsyncTestCase): @staticmethod - def _call_fut(client, write_pbs, transaction_id): + @pytest.mark.asyncio + async def _call_fut(client, write_pbs, transaction_id): from google.cloud.firestore_v1.async_transaction import _commit_with_retry - return asyncio.run(_commit_with_retry(client, write_pbs, transaction_id)) + return await _commit_with_retry(client, write_pbs, transaction_id) @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") - def test_success_first_attempt(self, _sleep): + @pytest.mark.asyncio + async def test_success_first_attempt(self, _sleep): from google.cloud.firestore_v1.gapic import firestore_client # Create a minimal fake GAPIC with a dummy result. @@ -810,7 +837,7 @@ def test_success_first_attempt(self, _sleep): # Call function and check result. txn_id = b"cheeeeeez" - commit_response = self._call_fut(client, mock.sentinel.write_pbs, txn_id) + commit_response = await self._call_fut(client, mock.sentinel.write_pbs, txn_id) self.assertIs(commit_response, firestore_api.commit.return_value) # Verify mocks used. @@ -822,8 +849,11 @@ def test_success_first_attempt(self, _sleep): metadata=client._rpc_metadata, ) - @mock.patch("google.cloud.firestore_v1.async_transaction._sleep", side_effect=[2.0, 4.0]) - def test_success_third_attempt(self, _sleep): + @mock.patch( + "google.cloud.firestore_v1.async_transaction._sleep", side_effect=[2.0, 4.0] + ) + @pytest.mark.asyncio + async def test_success_third_attempt(self, _sleep): from google.api_core import exceptions from google.cloud.firestore_v1.gapic import firestore_client @@ -844,7 +874,7 @@ def test_success_third_attempt(self, _sleep): # Call function and check result. txn_id = b"the-world\x00" - commit_response = self._call_fut(client, mock.sentinel.write_pbs, txn_id) + commit_response = await self._call_fut(client, mock.sentinel.write_pbs, txn_id) self.assertIs(commit_response, mock.sentinel.commit_response) # Verify mocks used. @@ -863,7 +893,8 @@ def test_success_third_attempt(self, _sleep): ) @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") - def test_failure_first_attempt(self, _sleep): + @pytest.mark.asyncio + async def test_failure_first_attempt(self, _sleep): from google.api_core import exceptions from google.cloud.firestore_v1.gapic import firestore_client @@ -882,7 +913,7 @@ def test_failure_first_attempt(self, _sleep): # Call function and check result. txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" with self.assertRaises(exceptions.ResourceExhausted) as exc_info: - self._call_fut(client, mock.sentinel.write_pbs, txn_id) + await self._call_fut(client, mock.sentinel.write_pbs, txn_id) self.assertIs(exc_info.exception, exc) @@ -896,7 +927,8 @@ def test_failure_first_attempt(self, _sleep): ) @mock.patch("google.cloud.firestore_v1.async_transaction._sleep", return_value=2.0) - def test_failure_second_attempt(self, _sleep): + @pytest.mark.asyncio + async def test_failure_second_attempt(self, _sleep): from google.api_core import exceptions from google.cloud.firestore_v1.gapic import firestore_client @@ -917,7 +949,7 @@ def test_failure_second_attempt(self, _sleep): # Call function and check result. txn_id = b"the-journey-when-and-where-well-go" with self.assertRaises(exceptions.InternalServerError) as exc_info: - self._call_fut(client, mock.sentinel.write_pbs, txn_id) + await self._call_fut(client, mock.sentinel.write_pbs, txn_id) self.assertIs(exc_info.exception, exc2) @@ -933,20 +965,22 @@ def test_failure_second_attempt(self, _sleep): self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) -class Test__sleep(unittest.TestCase): +class Test__sleep(aiounittest.AsyncTestCase): @staticmethod - def _call_fut(current_sleep, **kwargs): + @pytest.mark.asyncio + async def _call_fut(current_sleep, **kwargs): from google.cloud.firestore_v1.async_transaction import _sleep - return asyncio.run(_sleep(current_sleep, **kwargs)) + return await _sleep(current_sleep, **kwargs) @mock.patch("random.uniform", return_value=5.5) @mock.patch("asyncio.sleep", return_value=None) - def test_defaults(self, sleep, uniform): + @pytest.mark.asyncio + async def test_defaults(self, sleep, uniform): curr_sleep = 10.0 self.assertLessEqual(uniform.return_value, curr_sleep) - new_sleep = self._call_fut(curr_sleep) + new_sleep = await self._call_fut(curr_sleep) self.assertEqual(new_sleep, 2.0 * curr_sleep) uniform.assert_called_once_with(0.0, curr_sleep) @@ -954,12 +988,15 @@ def test_defaults(self, sleep, uniform): @mock.patch("random.uniform", return_value=10.5) @mock.patch("asyncio.sleep", return_value=None) - def test_explicit(self, sleep, uniform): + @pytest.mark.asyncio + async def test_explicit(self, sleep, uniform): curr_sleep = 12.25 self.assertLessEqual(uniform.return_value, curr_sleep) multiplier = 1.5 - new_sleep = self._call_fut(curr_sleep, max_sleep=100.0, multiplier=multiplier) + new_sleep = await self._call_fut( + curr_sleep, max_sleep=100.0, multiplier=multiplier + ) self.assertEqual(new_sleep, multiplier * curr_sleep) uniform.assert_called_once_with(0.0, curr_sleep) @@ -967,12 +1004,15 @@ def test_explicit(self, sleep, uniform): @mock.patch("random.uniform", return_value=6.75) @mock.patch("asyncio.sleep", return_value=None) - def test_exceeds_max(self, sleep, uniform): + @pytest.mark.asyncio + async def test_exceeds_max(self, sleep, uniform): curr_sleep = 20.0 self.assertLessEqual(uniform.return_value, curr_sleep) max_sleep = 38.5 - new_sleep = self._call_fut(curr_sleep, max_sleep=max_sleep, multiplier=2.0) + new_sleep = await self._call_fut( + curr_sleep, max_sleep=max_sleep, multiplier=2.0 + ) self.assertEqual(new_sleep, max_sleep) uniform.assert_called_once_with(0.0, curr_sleep) From 0eec38c42c18f7aa1b539964912c520ab82cfa88 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 11:47:29 -0500 Subject: [PATCH 20/47] feat: remove duplicate code from async_client --- google/cloud/firestore_v1/async_client.py | 360 ++------------------ google/cloud/firestore_v1/async_document.py | 171 +--------- 2 files changed, 21 insertions(+), 510 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index b9ff127df8..4567e7e7db 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -25,9 +25,19 @@ """ import os -import google.api_core.client_options -from google.api_core.gapic_v1 import client_info -from google.cloud.client import ClientWithProject +from google.cloud.firestore_v1.client import ( + Client, + DEFAULT_DATABASE, + _BAD_OPTION_ERR, + _BAD_DOC_TEMPLATE, + _CLIENT_INFO, + _FIRESTORE_EMULATOR_HOST, + _reference_info, + _get_reference, + _parse_batch_get, + _get_doc_mask, + _item_to_collection_ref, +) from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import __version__ @@ -43,21 +53,7 @@ from google.cloud.firestore_v1.async_transaction import AsyncTransaction -DEFAULT_DATABASE = "(default)" -"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" -_BAD_OPTION_ERR = ( - "Exactly one of ``last_update_time`` or ``exists`` " "must be provided." -) -_BAD_DOC_TEMPLATE = ( - "Document {!r} appeared in response but was not present among references" -) -_ACTIVE_TXN = "There is already an active transaction." -_INACTIVE_TXN = "There is no active transaction." -_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) -_FIRESTORE_EMULATOR_HOST = "FIRESTORE_EMULATOR_HOST" - - -class AsyncClient(ClientWithProject): +class AsyncClient(Client): """Client for interacting with Google Cloud Firestore API. .. note:: @@ -85,16 +81,6 @@ class AsyncClient(ClientWithProject): should be set through client_options. """ - SCOPE = ( - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - ) - """The scopes required for authenticating with the Firestore service.""" - - _firestore_api_internal = None - _database_string_internal = None - _rpc_metadata_internal = None - def __init__( self, project=None, @@ -107,110 +93,12 @@ def __init__( # will have no impact since the _http() @property only lazily # creates a working HTTP object. super(AsyncClient, self).__init__( - project=project, credentials=credentials, _http=None + project=project, + credentials=credentials, + database=database, + client_info=client_info, + client_options=client_options, ) - self._client_info = client_info - if client_options: - if type(client_options) == dict: - client_options = google.api_core.client_options.from_dict( - client_options - ) - self._client_options = client_options - - self._database = database - self._emulator_host = os.getenv(_FIRESTORE_EMULATOR_HOST) - - @property - def _firestore_api(self): - """Lazy-loading getter GAPIC Firestore API. - - Returns: - :class:`~google.cloud.gapic.firestore.v1`.firestore_client.FirestoreClient: - >> snapshot.to_dict() - { - 'top1': { - 'middle2': { - 'bottom3': 20, - 'bottom4': 22, - }, - 'middle5': True, - }, - 'top6': b'\x00\x01 foo', - } - - a **field path** can be used to access the nested data. For - example: - - .. code-block:: python - - >>> snapshot.get('top1') - { - 'middle2': { - 'bottom3': 20, - 'bottom4': 22, - }, - 'middle5': True, - } - >>> snapshot.get('top1.middle2') - { - 'bottom3': 20, - 'bottom4': 22, - } - >>> snapshot.get('top1.middle2.bottom3') - 20 - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - A copy is returned since the data may contain mutable values, - but the data stored in the snapshot must remain immutable. - - Args: - field_path (str): A field path (``.``-delimited list of - field names). - - Returns: - Any or None: - (A copy of) the value stored for the ``field_path`` or - None if snapshot document does not exist. - - Raises: - KeyError: If the ``field_path`` does not match nested data - in the snapshot. - """ - if not self._exists: - return None - nested_data = field_path_module.get_nested_value(field_path, self._data) - return copy.deepcopy(nested_data) - - def to_dict(self): - """Retrieve the data contained in this snapshot. - - A copy is returned since the data may contain mutable values, - but the data stored in the snapshot must remain immutable. - - Returns: - Dict[str, Any] or None: - The data in the snapshot. Returns None if reference - does not exist. - """ - if not self._exists: - return None - return copy.deepcopy(self._data) - - def _get_document_path(client, path): """Convert a path tuple into a full path string. From 81102650a64b5639d85d02cca8b4dd0095f90012 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 14:23:20 -0500 Subject: [PATCH 21/47] feat: remove duplicate code from async_batch --- google/cloud/firestore_v1/async_batch.py | 103 +---------------------- 1 file changed, 3 insertions(+), 100 deletions(-) diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index eed0bacbfa..2afeb2ebbf 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -16,9 +16,10 @@ from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.batch import WriteBatch -class AsyncWriteBatch(object): +class AsyncWriteBatch(WriteBatch): """Accumulate write operations to be sent in a batch. This has the same set of methods for write operations that @@ -31,105 +32,7 @@ class AsyncWriteBatch(object): """ def __init__(self, client): - self._client = client - self._write_pbs = [] - self.write_results = None - self.commit_time = None - - def _add_write_pbs(self, write_pbs): - """Add `Write`` protobufs to this transaction. - - This method intended to be over-ridden by subclasses. - - Args: - write_pbs (List[google.cloud.proto.firestore.v1.\ - write_pb2.Write]): A list of write protobufs to be added. - """ - self._write_pbs.extend(write_pbs) - - def create(self, reference, document_data): - """Add a "change" to this batch to create a document. - - If the document given by ``reference`` already exists, then this - batch will fail when :meth:`commit`-ed. - - Args: - reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): - A document reference to be created in this batch. - document_data (dict): Property names and values to use for - creating a document. - """ - write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) - self._add_write_pbs(write_pbs) - - def set(self, reference, document_data, merge=False): - """Add a "change" to replace a document. - - See - :meth:`google.cloud.firestore_v1.async_document.AsyncDocumentReference.set` for - more information on how ``option`` determines how the change is - applied. - - Args: - reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): - A document reference that will have values set in this batch. - document_data (dict): - Property names and values to use for replacing a document. - merge (Optional[bool] or Optional[List]): - If True, apply merging instead of overwriting the state - of the document. - """ - if merge is not False: - write_pbs = _helpers.pbs_for_set_with_merge( - reference._document_path, document_data, merge - ) - else: - write_pbs = _helpers.pbs_for_set_no_merge( - reference._document_path, document_data - ) - - self._add_write_pbs(write_pbs) - - def update(self, reference, field_updates, option=None): - """Add a "change" to update a document. - - See - :meth:`google.cloud.firestore_v1.async_document.AsyncDocumentReference.update` - for more information on ``field_updates`` and ``option``. - - Args: - reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): - A document reference that will be updated in this batch. - field_updates (dict): - Field names or paths to update and values to update with. - option (Optional[:class:`~google.cloud.firestore_v1.async_client.WriteOption`]): - A write option to make assertions / preconditions on the server - state of the document before applying changes. - """ - if option.__class__.__name__ == "ExistsOption": - raise ValueError("you must not pass an explicit write option to " "update.") - write_pbs = _helpers.pbs_for_update( - reference._document_path, field_updates, option - ) - self._add_write_pbs(write_pbs) - - def delete(self, reference, option=None): - """Add a "change" to delete a document. - - See - :meth:`google.cloud.firestore_v1.async_document.AsyncDocumentReference.delete` - for more information on how ``option`` determines how the change is - applied. - - Args: - reference (:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`): - A document reference that will be deleted in this batch. - option (Optional[:class:`~google.cloud.firestore_v1.async_client.WriteOption`]): - A write option to make assertions / preconditions on the server - state of the document before applying changes. - """ - write_pb = _helpers.pb_for_delete(reference._document_path, option) - self._add_write_pbs([write_pb]) + super(AsyncWriteBatch, self).__init__(client=client) async def commit(self): """Commit the changes accumulated in this batch. From 80270deb1835878f599a5c5bc4e3a3363b4bde41 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 14:58:35 -0500 Subject: [PATCH 22/47] feat: remove duplicate code from async_collection --- google/cloud/firestore_v1/async_collection.py | 88 ++----------------- tests/unit/v1/async/test_async_collection.py | 2 +- 2 files changed, 9 insertions(+), 81 deletions(-) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index c6a0fea3cb..32c4baf654 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -18,15 +18,19 @@ import six +from google.cloud.firestore_v1.collection import ( + CollectionReference, + _AUTO_ID_CHARS, + _auto_id, + _item_to_document_ref, +) from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import async_document -_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" - -class AsyncCollectionReference(object): +class AsyncCollectionReference(CollectionReference): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation @@ -53,83 +57,7 @@ class AsyncCollectionReference(object): """ def __init__(self, *path, **kwargs): - _helpers.verify_path(path, is_collection=True) - self._path = path - self._client = kwargs.pop("client", None) - if kwargs: - raise TypeError( - "Received unexpected arguments", kwargs, "Only `client` is supported" - ) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self._path == other._path and self._client == other._client - - @property - def id(self): - """The collection identifier. - - Returns: - str: The last component of the path. - """ - return self._path[-1] - - @property - def parent(self): - """Document that owns the current collection. - - Returns: - Optional[:class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`]: - The parent document, if the current collection is not a - top-level collection. - """ - if len(self._path) == 1: - return None - else: - parent_path = self._path[:-1] - return self._client.document(*parent_path) - - def document(self, document_id=None): - """Create a sub-document underneath the current collection. - - Args: - document_id (Optional[str]): The document identifier - within the current collection. If not provided, will default - to a random 20 character string composed of digits, - uppercase and lowercase and letters. - - Returns: - :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference`: - The child document. - """ - if document_id is None: - document_id = _auto_id() - - child_path = self._path + (document_id,) - return self._client.document(*child_path) - - def _parent_info(self): - """Get fully-qualified parent path and prefix for this collection. - - Returns: - Tuple[str, str]: Pair of - - * the fully-qualified (with database and project) path to the - parent of this collection (will either be the database path - or a document path). - * the prefix to a document in this collection. - """ - parent_doc = self.parent - if parent_doc is None: - parent_path = _helpers.DOCUMENT_PATH_DELIMITER.join( - (self._client._database_string, "documents") - ) - else: - parent_path = parent_doc._document_path - - expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) - return parent_path, expected_prefix + super(AsyncCollectionReference, self).__init__(*path, **kwargs) async def add(self, document_data, document_id=None): """Create a document in the Firestore database with the provided data. diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 26a11ea735..3865d32a50 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -140,7 +140,7 @@ def test_document_factory_explicit_id(self): self.assertEqual(child._path, (collection_id, document_id)) @mock.patch( - "google.cloud.firestore_v1.async_collection._auto_id", + "google.cloud.firestore_v1.collection._auto_id", return_value="zorpzorpthreezorp012", ) def test_document_factory_auto_id(self, mock_auto_id): From 47199eadbbe4adc7ec2edfda4bd164f4e75de180 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 15:11:40 -0500 Subject: [PATCH 23/47] feat: remove duplicate code from async_document --- google/cloud/firestore_v1/async_document.py | 232 +------------------- google/cloud/firestore_v1/document.py | 4 +- 2 files changed, 12 insertions(+), 224 deletions(-) diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 7bd2b43fb3..7d17417179 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -18,16 +18,22 @@ import six -from google.cloud.firestore_v1.document import DocumentReference, DocumentSnapshot +from google.cloud.firestore_v1.document import ( + DocumentReference, + DocumentSnapshot, + _get_document_path, + _consume_single_get, + _first_write_result, + _item_to_collection_ref, +) from google.api_core import exceptions from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import field_path as field_path_module from google.cloud.firestore_v1.proto import common_pb2 from google.cloud.firestore_v1.watch import Watch -class AsyncDocumentReference(object): +class AsyncDocumentReference(DocumentReference): """A reference to a document in a Firestore database. The document may already exist or can be created by this class. @@ -52,137 +58,8 @@ class AsyncDocumentReference(object): TypeError: If a keyword other than ``client`` is used. """ - _document_path_internal = None - def __init__(self, *path, **kwargs): - _helpers.verify_path(path, is_collection=False) - self._path = path - self._client = kwargs.pop("client", None) - if kwargs: - raise TypeError( - "Received unexpected arguments", kwargs, "Only `client` is supported" - ) - - def __copy__(self): - """Shallow copy the instance. - - We leave the client "as-is" but tuple-unpack the path. - - Returns: - .AsyncDocumentReference: A copy of the current document. - """ - result = self.__class__(*self._path, client=self._client) - result._document_path_internal = self._document_path_internal - return result - - def __deepcopy__(self, unused_memo): - """Deep copy the instance. - - This isn't a true deep copy, wee leave the client "as-is" but - tuple-unpack the path. - - Returns: - .AsyncDocumentReference: A copy of the current document. - """ - return self.__copy__() - - def __eq__(self, other): - """Equality check against another instance. - - Args: - other (Any): A value to compare against. - - Returns: - Union[bool, NotImplementedType]: Indicating if the values are - equal. - """ - if isinstance(other, AsyncDocumentReference): - return self._client == other._client and self._path == other._path - else: - return NotImplemented - - def __hash__(self): - return hash(self._path) + hash(self._client) - - def __ne__(self, other): - """Inequality check against another instance. - - Args: - other (Any): A value to compare against. - - Returns: - Union[bool, NotImplementedType]: Indicating if the values are - not equal. - """ - if isinstance(other, AsyncDocumentReference): - return self._client != other._client or self._path != other._path - else: - return NotImplemented - - @property - def path(self): - """Database-relative for this document. - - Returns: - str: The document's relative path. - """ - return "/".join(self._path) - - @property - def _document_path(self): - """Create and cache the full path for this document. - - Of the form: - - ``projects/{project_id}/databases/{database_id}/... - documents/{document_path}`` - - Returns: - str: The full document path. - - Raises: - ValueError: If the current document reference has no ``client``. - """ - if self._document_path_internal is None: - if self._client is None: - raise ValueError("A document reference requires a `client`.") - self._document_path_internal = _get_document_path(self._client, self._path) - - return self._document_path_internal - - @property - def id(self): - """The document identifier (within its collection). - - Returns: - str: The last component of the path. - """ - return self._path[-1] - - @property - def parent(self): - """Collection that owns the current document. - - Returns: - :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`: - The parent collection. - """ - parent_path = self._path[:-1] - return self._client.collection(*parent_path) - - def collection(self, collection_id): - """Create a sub-collection underneath the current document. - - Args: - collection_id (str): The sub-collection identifier (sometimes - referred to as the "kind"). - - Returns: - :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`: - The child collection. - """ - child_path = self._path + (collection_id,) - return self._client.collection(*child_path) + super(AsyncDocumentReference, self).__init__(*path, **kwargs) async def create(self, document_data): """Create the current document in the Firestore database. @@ -530,92 +407,3 @@ def on_snapshot(document_snapshot, changes, read_time): return Watch.for_document( self, callback, DocumentSnapshot, AsyncDocumentReference ) - - -def _get_document_path(client, path): - """Convert a path tuple into a full path string. - - Of the form: - - ``projects/{project_id}/databases/{database_id}/... - documents/{document_path}`` - - Args: - client (:class:`~google.cloud.firestore_v1.client.Client`): - The client that holds configuration details and a GAPIC client - object. - path (Tuple[str, ...]): The components in a document path. - - Returns: - str: The fully-qualified document path. - """ - parts = (client._database_string, "documents") + path - return _helpers.DOCUMENT_PATH_DELIMITER.join(parts) - - -def _consume_single_get(response_iterator): - """Consume a gRPC stream that should contain a single response. - - The stream will correspond to a ``BatchGetDocuments`` request made - for a single document. - - Args: - response_iterator (~google.cloud.exceptions.GrpcRendezvous): A - streaming iterator returned from a ``BatchGetDocuments`` - request. - - Returns: - ~google.cloud.proto.firestore.v1.\ - firestore_pb2.BatchGetDocumentsResponse: The single "get" - response in the batch. - - Raises: - ValueError: If anything other than exactly one response is returned. - """ - # Calling ``list()`` consumes the entire iterator. - all_responses = list(response_iterator) - if len(all_responses) != 1: - raise ValueError( - "Unexpected response from `BatchGetDocumentsResponse`", - all_responses, - "Expected only one result", - ) - - return all_responses[0] - - -def _first_write_result(write_results): - """Get first write result from list. - - For cases where ``len(write_results) > 1``, this assumes the writes - occurred at the same time (e.g. if an update and transform are sent - at the same time). - - Args: - write_results (List[google.cloud.proto.firestore.v1.\ - write_pb2.WriteResult, ...]: The write results from a - ``CommitResponse``. - - Returns: - google.cloud.firestore_v1.types.WriteResult: The - lone write result from ``write_results``. - - Raises: - ValueError: If there are zero write results. This is likely to - **never** occur, since the backend should be stable. - """ - if not write_results: - raise ValueError("Expected at least one write result") - - return write_results[0] - - -def _item_to_collection_ref(iterator, item): - """Convert collection ID to collection ref. - - Args: - iterator (google.api_core.page_iterator.GRPCIterator): - iterator response - item (str): ID of the collection - """ - return iterator.document.collection(item) diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 571315e875..c51c7c5c74 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -94,7 +94,7 @@ def __eq__(self, other): Union[bool, NotImplementedType]: Indicating if the values are equal. """ - if isinstance(other, DocumentReference): + if isinstance(other, self.__class__): return self._client == other._client and self._path == other._path else: return NotImplemented @@ -112,7 +112,7 @@ def __ne__(self, other): Union[bool, NotImplementedType]: Indicating if the values are not equal. """ - if isinstance(other, DocumentReference): + if isinstance(other, self.__class__): return self._client != other._client or self._path != other._path else: return NotImplemented From da04662edce88291aa494055dd1852a288e913f4 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 16:01:35 -0500 Subject: [PATCH 24/47] fix: remove unused imports --- google/cloud/firestore_v1/async_batch.py | 1 - google/cloud/firestore_v1/async_client.py | 11 - google/cloud/firestore_v1/async_collection.py | 26 - google/cloud/firestore_v1/async_document.py | 4 - google/cloud/firestore_v1/async_query.py | 883 +----------------- tests/unit/v1/async/test_async_client.py | 14 +- tests/unit/v1/async/test_async_collection.py | 2 +- tests/unit/v1/async/test_async_document.py | 6 +- tests/unit/v1/async/test_async_query.py | 10 +- 9 files changed, 38 insertions(+), 919 deletions(-) diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 2afeb2ebbf..3aa882800c 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -15,7 +15,6 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.batch import WriteBatch diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 4567e7e7db..8b21396a42 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -23,33 +23,22 @@ * a :class:`~google.cloud.firestore_v1.client.Client` owns a :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ -import os from google.cloud.firestore_v1.client import ( Client, DEFAULT_DATABASE, - _BAD_OPTION_ERR, - _BAD_DOC_TEMPLATE, _CLIENT_INFO, - _FIRESTORE_EMULATOR_HOST, _reference_info, - _get_reference, _parse_batch_get, _get_doc_mask, _item_to_collection_ref, ) from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import __version__ from google.cloud.firestore_v1.async_query import AsyncQuery -from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1.async_document import AsyncDocumentReference -from google.cloud.firestore_v1.async_document import DocumentSnapshot -from google.cloud.firestore_v1.field_path import render_field_path -from google.cloud.firestore_v1.gapic import firestore_client -from google.cloud.firestore_v1.gapic.transports import firestore_grpc_transport from google.cloud.firestore_v1.async_transaction import AsyncTransaction diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 32c4baf654..a10f82e256 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -13,18 +13,14 @@ # limitations under the License. """Classes for representing collections for the Google Cloud Firestore API.""" -import random import warnings -import six from google.cloud.firestore_v1.collection import ( CollectionReference, - _AUTO_ID_CHARS, _auto_id, _item_to_document_ref, ) -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import async_document @@ -375,25 +371,3 @@ def on_snapshot(collection_snapshot, changes, read_time): async_document.DocumentSnapshot, async_document.AsyncDocumentReference, ) - - -def _auto_id(): - """Generate a "random" automatically generated ID. - - Returns: - str: A 20 character string composed of digits, uppercase and - lowercase and letters. - """ - return "".join(random.choice(_AUTO_ID_CHARS) for _ in six.moves.xrange(20)) - - -def _item_to_document_ref(iterator, item): - """Convert Document resource to document ref. - - Args: - iterator (google.api_core.page_iterator.GRPCIterator): - iterator response - item (dict): document resource - """ - document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] - return iterator.collection.document(document_id) diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 7d17417179..64c8b394d2 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -14,15 +14,11 @@ """Classes for representing documents for the Google Cloud Firestore API.""" -import copy - import six from google.cloud.firestore_v1.document import ( DocumentReference, DocumentSnapshot, - _get_document_path, - _consume_single_get, _first_write_result, _item_to_collection_ref, ) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index b6001914ab..f8c33e16df 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -18,55 +18,20 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ -import copy -import math import warnings -from google.protobuf import wrappers_pb2 -import six +from google.cloud.firestore_v1.query import ( + Query, + _query_response_to_snapshot, + _collection_group_query_response_to_snapshot, +) from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import document -from google.cloud.firestore_v1 import field_path as field_path_module -from google.cloud.firestore_v1 import transforms -from google.cloud.firestore_v1.gapic import enums -from google.cloud.firestore_v1.proto import query_pb2 -from google.cloud.firestore_v1.order import Order +from google.cloud.firestore_v1 import async_document from google.cloud.firestore_v1.watch import Watch -_EQ_OP = "==" -_operator_enum = enums.StructuredQuery.FieldFilter.Operator -_COMPARISON_OPERATORS = { - "<": _operator_enum.LESS_THAN, - "<=": _operator_enum.LESS_THAN_OR_EQUAL, - _EQ_OP: _operator_enum.EQUAL, - ">=": _operator_enum.GREATER_THAN_OR_EQUAL, - ">": _operator_enum.GREATER_THAN, - "array_contains": _operator_enum.ARRAY_CONTAINS, - "in": _operator_enum.IN, - "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, -} -_BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." -_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' -_INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values." -_BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}." -_INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values." -_MISSING_ORDER_BY = ( - 'The "order by" field path {!r} is not present in the cursor data {!r}. ' - "All fields sent to ``order_by()`` must be present in the fields " - "if passed to one of ``start_at()`` / ``start_after()`` / " - "``end_before()`` / ``end_at()`` to define a cursor." -) -_NO_ORDERS_FOR_CURSOR = ( - "Attempting to create a cursor with no fields to order on. " - "When defining a cursor with one of ``start_at()`` / ``start_after()`` / " - "``end_before()`` / ``end_at()``, all fields in the cursor must " - "come from fields set in ``order_by()``." -) -_MISMATCH_CURSOR_W_ORDER_BY = "The cursor {!r} does not match the order fields {!r}." - -class AsyncQuery(object): +class AsyncQuery(Query): """Represents a query to the Firestore API. Instances of this class are considered immutable: all methods that @@ -122,11 +87,6 @@ class AsyncQuery(object): When true, selects all descendant collections. """ - ASCENDING = "ASCENDING" - """str: Sort query results in ascending order on a field.""" - DESCENDING = "DESCENDING" - """str: Sort query results in descending order on a field.""" - def __init__( self, parent, @@ -139,595 +99,18 @@ def __init__( end_at=None, all_descendants=False, ): - self._parent = parent - self._projection = projection - self._field_filters = field_filters - self._orders = orders - self._limit = limit - self._offset = offset - self._start_at = start_at - self._end_at = end_at - self._all_descendants = all_descendants - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return ( - self._parent == other._parent - and self._projection == other._projection - and self._field_filters == other._field_filters - and self._orders == other._orders - and self._limit == other._limit - and self._offset == other._offset - and self._start_at == other._start_at - and self._end_at == other._end_at - and self._all_descendants == other._all_descendants - ) - - @property - def _client(self): - """The client of the parent collection. - - Returns: - :class:`~google.cloud.firestore_v1.client.Client`: - The client that owns this query. - """ - return self._parent._client - - def select(self, field_paths): - """Project documents matching query to a limited set of fields. - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - If the current query already has a projection set (i.e. has already - called :meth:`~google.cloud.firestore_v1.query.AsyncQuery.select`), this - will overwrite it. - - Args: - field_paths (Iterable[str, ...]): An iterable of field paths - (``.``-delimited list of field names) to use as a projection - of document fields in the query results. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A "projected" query. Acts as a copy of the current query, - modified with the newly added projection. - Raises: - ValueError: If any ``field_path`` is invalid. - """ - field_paths = list(field_paths) - for field_path in field_paths: - field_path_module.split_field_path(field_path) # raises - - new_projection = query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ) - return self.__class__( - self._parent, - projection=new_projection, - field_filters=self._field_filters, - orders=self._orders, - limit=self._limit, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - def where(self, field_path, op_string, value): - """Filter the query on a field. - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - Returns a new :class:`~google.cloud.firestore_v1.query.AsyncQuery` that - filters on a specific field path, according to an operation (e.g. - ``==`` or "equals") and a particular value to be paired with that - operation. - - Args: - field_path (str): A field path (``.``-delimited list of - field names) for the field to filter on. - op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``, - ``in``, ``array_contains`` and ``array_contains_any``. - value (Any): The value to compare the field against in the filter. - If ``value`` is :data:`None` or a NaN, then ``==`` is the only - allowed operation. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A filtered query. Acts as a copy of the current query, - modified with the newly added filter. - - Raises: - ValueError: If ``field_path`` is invalid. - ValueError: If ``value`` is a NaN or :data:`None` and - ``op_string`` is not ``==``. - """ - field_path_module.split_field_path(field_path) # raises - - if value is None: - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - filter_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, - ) - elif _isnan(value): - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - filter_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=enums.StructuredQuery.UnaryFilter.Operator.IS_NAN, - ) - elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): - raise ValueError(_INVALID_WHERE_TRANSFORM) - else: - filter_pb = query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=_enum_from_op_string(op_string), - value=_helpers.encode_value(value), - ) - - new_filters = self._field_filters + (filter_pb,) - return self.__class__( - self._parent, - projection=self._projection, - field_filters=new_filters, - orders=self._orders, - limit=self._limit, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - @staticmethod - def _make_order(field_path, direction): - """Helper for :meth:`order_by`.""" - return query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - direction=_enum_from_direction(direction), - ) - - def order_by(self, field_path, direction=ASCENDING): - """Modify the query to add an order clause on a specific field. - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - Successive :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by` - calls will further refine the ordering of results returned by the query - (i.e. the new "order by" fields will be added to existing ones). - - Args: - field_path (str): A field path (``.``-delimited list of - field names) on which to order the query results. - direction (Optional[str]): The direction to order by. Must be one - of :attr:`ASCENDING` or :attr:`DESCENDING`, defaults to - :attr:`ASCENDING`. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - An ordered query. Acts as a copy of the current query, modified - with the newly added "order by" constraint. - - Raises: - ValueError: If ``field_path`` is invalid. - ValueError: If ``direction`` is not one of :attr:`ASCENDING` or - :attr:`DESCENDING`. - """ - field_path_module.split_field_path(field_path) # raises - - order_pb = self._make_order(field_path, direction) - - new_orders = self._orders + (order_pb,) - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=new_orders, - limit=self._limit, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - def limit(self, count): - """Limit a query to return a fixed number of results. - - If the current query already has a limit set, this will overwrite it. - - Args: - count (int): Maximum number of documents to return that match - the query. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A limited query. Acts as a copy of the current query, modified - with the newly added "limit" filter. - """ - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=self._orders, - limit=count, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, + super(AsyncQuery, self).__init__( + parent=parent, + projection=projection, + field_filters=field_filters, + orders=orders, + limit=limit, + offset=offset, + start_at=start_at, + end_at=end_at, + all_descendants=all_descendants, ) - def offset(self, num_to_skip): - """Skip to an offset in a query. - - If the current query already has specified an offset, this will - overwrite it. - - Args: - num_to_skip (int): The number of results to skip at the beginning - of query results. (Must be non-negative.) - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - An offset query. Acts as a copy of the current query, modified - with the newly added "offset" field. - """ - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=self._orders, - limit=self._limit, - offset=num_to_skip, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - def _check_snapshot(self, document_fields): - """Validate local snapshots for non-collection-group queries. - - Raises: - ValueError: for non-collection-group queries, if the snapshot - is from a different collection. - """ - if self._all_descendants: - return - - if document_fields.reference._path[:-1] != self._parent._path: - raise ValueError("Cannot use snapshot from another collection as a cursor.") - - def _cursor_helper(self, document_fields, before, start): - """Set values to be used for a ``start_at`` or ``end_at`` cursor. - - The values will later be used in a query protobuf. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - before (bool): Flag indicating if the document in - ``document_fields`` should (:data:`False`) or - shouldn't (:data:`True`) be included in the result set. - start (Optional[bool]): determines if the cursor is a ``start_at`` - cursor (:data:`True`) or an ``end_at`` cursor (:data:`False`). - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "start at" cursor. - """ - if isinstance(document_fields, tuple): - document_fields = list(document_fields) - elif isinstance(document_fields, document.DocumentSnapshot): - self._check_snapshot(document_fields) - else: - # NOTE: We copy so that the caller can't modify after calling. - document_fields = copy.deepcopy(document_fields) - - cursor_pair = document_fields, before - query_kwargs = { - "projection": self._projection, - "field_filters": self._field_filters, - "orders": self._orders, - "limit": self._limit, - "offset": self._offset, - "all_descendants": self._all_descendants, - } - if start: - query_kwargs["start_at"] = cursor_pair - query_kwargs["end_at"] = self._end_at - else: - query_kwargs["start_at"] = self._start_at - query_kwargs["end_at"] = cursor_pair - - return self.__class__(self._parent, **query_kwargs) - - def start_at(self, document_fields): - """Start query results at a particular document value. - - The result set will **include** the document specified by - ``document_fields``. - - If the current query already has specified a start cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.start_after` -- this - will overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A query with cursor. Acts as - a copy of the current query, modified with the newly added - "start at" cursor. - """ - return self._cursor_helper(document_fields, before=True, start=True) - - def start_after(self, document_fields): - """Start query results after a particular document value. - - The result set will **exclude** the document specified by - ``document_fields``. - - If the current query already has specified a start cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.start_at` -- this will - overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "start after" cursor. - """ - return self._cursor_helper(document_fields, before=False, start=True) - - def end_before(self, document_fields): - """End query results before a particular document value. - - The result set will **exclude** the document specified by - ``document_fields``. - - If the current query already has specified an end cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.end_at` -- this will - overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "end before" cursor. - """ - return self._cursor_helper(document_fields, before=True, start=False) - - def end_at(self, document_fields): - """End query results at a particular document value. - - The result set will **include** the document specified by - ``document_fields``. - - If the current query already has specified an end cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.end_before` -- this will - overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.AsyncQuery.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.AsyncQuery`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "end at" cursor. - """ - return self._cursor_helper(document_fields, before=False, start=False) - - def _filters_pb(self): - """Convert all the filters into a single generic Filter protobuf. - - This may be a lone field filter or unary filter, may be a composite - filter or may be :data:`None`. - - Returns: - :class:`google.cloud.firestore_v1.types.StructuredQuery.Filter`: - A "generic" filter representing the current query's filters. - """ - num_filters = len(self._field_filters) - if num_filters == 0: - return None - elif num_filters == 1: - return _filter_pb(self._field_filters[0]) - else: - composite_filter = query_pb2.StructuredQuery.CompositeFilter( - op=enums.StructuredQuery.CompositeFilter.Operator.AND, - filters=[_filter_pb(filter_) for filter_ in self._field_filters], - ) - return query_pb2.StructuredQuery.Filter(composite_filter=composite_filter) - - @staticmethod - def _normalize_projection(projection): - """Helper: convert field paths to message.""" - if projection is not None: - - fields = list(projection.fields) - - if not fields: - field_ref = query_pb2.StructuredQuery.FieldReference( - field_path="__name__" - ) - return query_pb2.StructuredQuery.Projection(fields=[field_ref]) - - return projection - - def _normalize_orders(self): - """Helper: adjust orders based on cursors, where clauses.""" - orders = list(self._orders) - _has_snapshot_cursor = False - - if self._start_at: - if isinstance(self._start_at[0], document.DocumentSnapshot): - _has_snapshot_cursor = True - - if self._end_at: - if isinstance(self._end_at[0], document.DocumentSnapshot): - _has_snapshot_cursor = True - - if _has_snapshot_cursor: - should_order = [ - _enum_from_op_string(key) - for key in _COMPARISON_OPERATORS - if key not in (_EQ_OP, "array_contains") - ] - order_keys = [order.field.field_path for order in orders] - for filter_ in self._field_filters: - field = filter_.field.field_path - if filter_.op in should_order and field not in order_keys: - orders.append(self._make_order(field, "ASCENDING")) - if not orders: - orders.append(self._make_order("__name__", "ASCENDING")) - else: - order_keys = [order.field.field_path for order in orders] - if "__name__" not in order_keys: - direction = orders[-1].direction # enum? - orders.append(self._make_order("__name__", direction)) - - return orders - - def _normalize_cursor(self, cursor, orders): - """Helper: convert cursor to a list of values based on orders.""" - if cursor is None: - return - - if not orders: - raise ValueError(_NO_ORDERS_FOR_CURSOR) - - document_fields, before = cursor - - order_keys = [order.field.field_path for order in orders] - - if isinstance(document_fields, document.DocumentSnapshot): - snapshot = document_fields - document_fields = snapshot.to_dict() - document_fields["__name__"] = snapshot.reference - - if isinstance(document_fields, dict): - # Transform to list using orders - values = [] - data = document_fields - for order_key in order_keys: - try: - if order_key in data: - values.append(data[order_key]) - else: - values.append( - field_path_module.get_nested_value(order_key, data) - ) - except KeyError: - msg = _MISSING_ORDER_BY.format(order_key, data) - raise ValueError(msg) - document_fields = values - - if len(document_fields) != len(orders): - msg = _MISMATCH_CURSOR_W_ORDER_BY.format(document_fields, order_keys) - raise ValueError(msg) - - _transform_bases = (transforms.Sentinel, transforms._ValueList) - - for index, key_field in enumerate(zip(order_keys, document_fields)): - key, field = key_field - - if isinstance(field, _transform_bases): - msg = _INVALID_CURSOR_TRANSFORM - raise ValueError(msg) - - if key == "__name__" and isinstance(field, six.string_types): - document_fields[index] = self._parent.document(field) - - return document_fields, before - - def _to_protobuf(self): - """Convert the current query into the equivalent protobuf. - - Returns: - :class:`google.cloud.firestore_v1.types.StructuredQuery`: - The query protobuf. - """ - projection = self._normalize_projection(self._projection) - orders = self._normalize_orders() - start_at = self._normalize_cursor(self._start_at, orders) - end_at = self._normalize_cursor(self._end_at, orders) - - query_kwargs = { - "select": projection, - "from": [ - query_pb2.StructuredQuery.CollectionSelector( - collection_id=self._parent.id, all_descendants=self._all_descendants - ) - ], - "where": self._filters_pb(), - "order_by": orders, - "start_at": _cursor_pb(start_at), - "end_at": _cursor_pb(end_at), - } - if self._offset is not None: - query_kwargs["offset"] = self._offset - if self._limit is not None: - query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) - - return query_pb2.StructuredQuery(**query_kwargs) - async def get(self, transaction=None): """Deprecated alias for :meth:`stream`.""" warnings.warn( @@ -762,7 +145,7 @@ async def stream(self, transaction=None): An existing transaction that this query will run in. Yields: - :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + :class:`~google.cloud.firestore_v1.async_document.DocumentSnapshot`: The next document that fulfills the query. """ parent_path, expected_prefix = self._parent._parent_info() @@ -815,230 +198,8 @@ def on_snapshot(docs, changes, read_time): query_watch.unsubscribe() """ return Watch.for_query( - self, callback, document.DocumentSnapshot, document.DocumentReference + self, + callback, + async_document.DocumentSnapshot, + async_document.AsyncDocumentReference, ) - - def _comparator(self, doc1, doc2): - _orders = self._orders - - # Add implicit sorting by name, using the last specified direction. - if len(_orders) == 0: - lastDirection = AsyncQuery.ASCENDING - else: - if _orders[-1].direction == 1: - lastDirection = AsyncQuery.ASCENDING - else: - lastDirection = AsyncQuery.DESCENDING - - orderBys = list(_orders) - - order_pb = query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path="id"), - direction=_enum_from_direction(lastDirection), - ) - orderBys.append(order_pb) - - for orderBy in orderBys: - if orderBy.field.field_path == "id": - # If ordering by docuent id, compare resource paths. - comp = Order()._compare_to(doc1.reference._path, doc2.reference._path) - else: - if ( - orderBy.field.field_path not in doc1._data - or orderBy.field.field_path not in doc2._data - ): - raise ValueError( - "Can only compare fields that exist in the " - "DocumentSnapshot. Please include the fields you are " - "ordering on in your select() call." - ) - v1 = doc1._data[orderBy.field.field_path] - v2 = doc2._data[orderBy.field.field_path] - encoded_v1 = _helpers.encode_value(v1) - encoded_v2 = _helpers.encode_value(v2) - comp = Order().compare(encoded_v1, encoded_v2) - - if comp != 0: - # 1 == Ascending, -1 == Descending - return orderBy.direction * comp - - return 0 - - -def _enum_from_op_string(op_string): - """Convert a string representation of a binary operator to an enum. - - These enums come from the protobuf message definition - ``StructuredQuery.FieldFilter.Operator``. - - Args: - op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=`` - and ``>``. - - Returns: - int: The enum corresponding to ``op_string``. - - Raises: - ValueError: If ``op_string`` is not a valid operator. - """ - try: - return _COMPARISON_OPERATORS[op_string] - except KeyError: - choices = ", ".join(sorted(_COMPARISON_OPERATORS.keys())) - msg = _BAD_OP_STRING.format(op_string, choices) - raise ValueError(msg) - - -def _isnan(value): - """Check if a value is NaN. - - This differs from ``math.isnan`` in that **any** input type is - allowed. - - Args: - value (Any): A value to check for NaN-ness. - - Returns: - bool: Indicates if the value is the NaN float. - """ - if isinstance(value, float): - return math.isnan(value) - else: - return False - - -def _enum_from_direction(direction): - """Convert a string representation of a direction to an enum. - - Args: - direction (str): A direction to order by. Must be one of - :attr:`~google.cloud.firestore.AsyncQuery.ASCENDING` or - :attr:`~google.cloud.firestore.AsyncQuery.DESCENDING`. - - Returns: - int: The enum corresponding to ``direction``. - - Raises: - ValueError: If ``direction`` is not a valid direction. - """ - if isinstance(direction, int): - return direction - - if direction == AsyncQuery.ASCENDING: - return enums.StructuredQuery.Direction.ASCENDING - elif direction == AsyncQuery.DESCENDING: - return enums.StructuredQuery.Direction.DESCENDING - else: - msg = _BAD_DIR_STRING.format( - direction, AsyncQuery.ASCENDING, AsyncQuery.DESCENDING - ) - raise ValueError(msg) - - -def _filter_pb(field_or_unary): - """Convert a specific protobuf filter to the generic filter type. - - Args: - field_or_unary (Union[google.cloud.proto.firestore.v1.\ - query_pb2.StructuredQuery.FieldFilter, google.cloud.proto.\ - firestore.v1.query_pb2.StructuredQuery.FieldFilter]): A - field or unary filter to convert to a generic filter. - - Returns: - google.cloud.firestore_v1.types.\ - StructuredQuery.Filter: A "generic" filter. - - Raises: - ValueError: If ``field_or_unary`` is not a field or unary filter. - """ - if isinstance(field_or_unary, query_pb2.StructuredQuery.FieldFilter): - return query_pb2.StructuredQuery.Filter(field_filter=field_or_unary) - elif isinstance(field_or_unary, query_pb2.StructuredQuery.UnaryFilter): - return query_pb2.StructuredQuery.Filter(unary_filter=field_or_unary) - else: - raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) - - -def _cursor_pb(cursor_pair): - """Convert a cursor pair to a protobuf. - - If ``cursor_pair`` is :data:`None`, just returns :data:`None`. - - Args: - cursor_pair (Optional[Tuple[list, bool]]): Two-tuple of - - * a list of field values. - * a ``before`` flag - - Returns: - Optional[google.cloud.firestore_v1.types.Cursor]: A - protobuf cursor corresponding to the values. - """ - if cursor_pair is not None: - data, before = cursor_pair - value_pbs = [_helpers.encode_value(value) for value in data] - return query_pb2.Cursor(values=value_pbs, before=before) - - -def _query_response_to_snapshot(response_pb, collection, expected_prefix): - """Parse a query response protobuf to a document snapshot. - - Args: - response_pb (google.cloud.proto.firestore.v1.\ - firestore_pb2.RunQueryResponse): A - collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): - A reference to the collection that initiated the query. - expected_prefix (str): The expected prefix for fully-qualified - document names returned in the query results. This can be computed - directly from ``collection`` via :meth:`_parent_info`. - - Returns: - Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: - A snapshot of the data returned in the query. If - ``response_pb.document`` is not set, the snapshot will be :data:`None`. - """ - if not response_pb.HasField("document"): - return None - - document_id = _helpers.get_doc_id(response_pb.document, expected_prefix) - reference = collection.document(document_id) - data = _helpers.decode_dict(response_pb.document.fields, collection._client) - snapshot = document.DocumentSnapshot( - reference, - data, - exists=True, - read_time=response_pb.read_time, - create_time=response_pb.document.create_time, - update_time=response_pb.document.update_time, - ) - return snapshot - - -def _collection_group_query_response_to_snapshot(response_pb, collection): - """Parse a query response protobuf to a document snapshot. - - Args: - response_pb (google.cloud.proto.firestore.v1.\ - firestore_pb2.RunQueryResponse): A - collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): - A reference to the collection that initiated the query. - - Returns: - Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: - A snapshot of the data returned in the query. If - ``response_pb.document`` is not set, the snapshot will be :data:`None`. - """ - if not response_pb.HasField("document"): - return None - reference = collection._client.document(response_pb.document.name) - data = _helpers.decode_dict(response_pb.document.fields, collection._client) - snapshot = document.DocumentSnapshot( - reference, - data, - exists=True, - read_time=response_pb.read_time, - create_time=response_pb.document.create_time, - update_time=response_pb.document.update_time, - ) - return snapshot diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index f5af7f5107..651ae4168f 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -51,7 +51,7 @@ def test_constructor(self): self.assertIsNone(client._emulator_host) def test_constructor_with_emulator_host(self): - from google.cloud.firestore_v1.async_client import _FIRESTORE_EMULATOR_HOST + from google.cloud.firestore_v1.client import _FIRESTORE_EMULATOR_HOST credentials = _make_credentials() emulator_host = "localhost:8081" @@ -306,7 +306,7 @@ def test_write_option_exists(self): self.assertTrue(option2._exists) def test_write_open_neither_arg(self): - from google.cloud.firestore_v1.async_client import _BAD_OPTION_ERR + from google.cloud.firestore_v1.client import _BAD_OPTION_ERR klass = self._get_target_class() with self.assertRaises(TypeError) as exc_info: @@ -315,7 +315,7 @@ def test_write_open_neither_arg(self): self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) def test_write_multiple_args(self): - from google.cloud.firestore_v1.async_client import _BAD_OPTION_ERR + from google.cloud.firestore_v1.client import _BAD_OPTION_ERR klass = self._get_target_class() with self.assertRaises(TypeError) as exc_info: @@ -324,7 +324,7 @@ def test_write_multiple_args(self): self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) def test_write_bad_arg(self): - from google.cloud.firestore_v1.async_client import _BAD_OPTION_ERR + from google.cloud.firestore_v1.client import _BAD_OPTION_ERR klass = self._get_target_class() with self.assertRaises(TypeError) as exc_info: @@ -474,7 +474,7 @@ async def test_get_all_with_transaction(self): @pytest.mark.asyncio async def test_get_all_unknown_result(self): - from google.cloud.firestore_v1.async_client import _BAD_DOC_TEMPLATE + from google.cloud.firestore_v1.client import _BAD_DOC_TEMPLATE info = self._info_for_get_all({"z": 28.5}, {}) client, document, _, _, response = info @@ -601,7 +601,7 @@ def test_it(self): class Test__get_reference(aiounittest.AsyncTestCase): @staticmethod def _call_fut(document_path, reference_map): - from google.cloud.firestore_v1.async_client import _get_reference + from google.cloud.firestore_v1.client import _get_reference return _get_reference(document_path, reference_map) @@ -611,7 +611,7 @@ def test_success(self): self.assertIs(self._call_fut(doc_path, reference_map), mock.sentinel.reference) def test_failure(self): - from google.cloud.firestore_v1.async_client import _BAD_DOC_TEMPLATE + from google.cloud.firestore_v1.client import _BAD_DOC_TEMPLATE doc_path = "1/888/call-now" with self.assertRaises(ValueError) as exc_info: diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 3865d32a50..021390d3b6 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -566,7 +566,7 @@ def _call_fut(): @mock.patch("random.choice") def test_it(self, mock_rand_choice): - from google.cloud.firestore_v1.async_collection import _AUTO_ID_CHARS + from google.cloud.firestore_v1.collection import _AUTO_ID_CHARS mock_result = "0123456789abcdefghij" mock_rand_choice.side_effect = list(mock_result) diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py index 2f26f6a2a8..a86f588b48 100644 --- a/tests/unit/v1/async/test_async_document.py +++ b/tests/unit/v1/async/test_async_document.py @@ -766,7 +766,7 @@ def test_non_existent(self): class Test__get_document_path(aiounittest.AsyncTestCase): @staticmethod def _call_fut(client, path): - from google.cloud.firestore_v1.async_document import _get_document_path + from google.cloud.firestore_v1.document import _get_document_path return _get_document_path(client, path) @@ -785,7 +785,7 @@ def test_it(self): class Test__consume_single_get(aiounittest.AsyncTestCase): @staticmethod def _call_fut(response_iterator): - from google.cloud.firestore_v1.async_document import _consume_single_get + from google.cloud.firestore_v1.document import _consume_single_get return _consume_single_get(response_iterator) @@ -808,7 +808,7 @@ def test_failure_too_many(self): class Test__first_write_result(aiounittest.AsyncTestCase): @staticmethod def _call_fut(write_results): - from google.cloud.firestore_v1.async_document import _first_write_result + from google.cloud.firestore_v1.document import _first_write_result return _first_write_result(write_results) diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py index 8f773a501e..b18a1c7c7c 100644 --- a/tests/unit/v1/async/test_async_query.py +++ b/tests/unit/v1/async/test_async_query.py @@ -1469,7 +1469,7 @@ def test_comparator_missing_order_by_field_in_data_raises(self): class Test__enum_from_op_string(aiounittest.AsyncTestCase): @staticmethod def _call_fut(op_string): - from google.cloud.firestore_v1.async_query import _enum_from_op_string + from google.cloud.firestore_v1.query import _enum_from_op_string return _enum_from_op_string(op_string) @@ -1521,7 +1521,7 @@ def test_invalid(self): class Test__isnan(aiounittest.AsyncTestCase): @staticmethod def _call_fut(value): - from google.cloud.firestore_v1.async_query import _isnan + from google.cloud.firestore_v1.query import _isnan return _isnan(value) @@ -1539,7 +1539,7 @@ def test_invalid(self): class Test__enum_from_direction(aiounittest.AsyncTestCase): @staticmethod def _call_fut(direction): - from google.cloud.firestore_v1.async_query import _enum_from_direction + from google.cloud.firestore_v1.query import _enum_from_direction return _enum_from_direction(direction) @@ -1563,7 +1563,7 @@ def test_failure(self): class Test__filter_pb(aiounittest.AsyncTestCase): @staticmethod def _call_fut(field_or_unary): - from google.cloud.firestore_v1.async_query import _filter_pb + from google.cloud.firestore_v1.query import _filter_pb return _filter_pb(field_or_unary) @@ -1601,7 +1601,7 @@ def test_bad_type(self): class Test__cursor_pb(aiounittest.AsyncTestCase): @staticmethod def _call_fut(cursor_pair): - from google.cloud.firestore_v1.async_query import _cursor_pb + from google.cloud.firestore_v1.query import _cursor_pb return _cursor_pb(cursor_pair) From 461f76ec39774f8ed605d99d1f2648ceb77135e8 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 16:01:55 -0500 Subject: [PATCH 25/47] fix: remove duplicate test --- tests/unit/v1/async/test_async_client.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index 651ae4168f..bcae9228e3 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -598,29 +598,6 @@ def test_it(self): self.assertEqual(reference_map, expected_map) -class Test__get_reference(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(document_path, reference_map): - from google.cloud.firestore_v1.client import _get_reference - - return _get_reference(document_path, reference_map) - - def test_success(self): - doc_path = "a/b/c" - reference_map = {doc_path: mock.sentinel.reference} - self.assertIs(self._call_fut(doc_path, reference_map), mock.sentinel.reference) - - def test_failure(self): - from google.cloud.firestore_v1.client import _BAD_DOC_TEMPLATE - - doc_path = "1/888/call-now" - with self.assertRaises(ValueError) as exc_info: - self._call_fut(doc_path, {}) - - err_msg = _BAD_DOC_TEMPLATE.format(doc_path) - self.assertEqual(exc_info.exception.args, (err_msg,)) - - class Test__parse_batch_get(aiounittest.AsyncTestCase): @staticmethod def _call_fut(get_doc_response, reference_map, client=mock.sentinel.client): From 0df1c2ed00afa6d963b8b1890d4f6771a8067972 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 16:29:24 -0500 Subject: [PATCH 26/47] feat: remove duplicate code from async_transaction --- .../cloud/firestore_v1/async_transaction.py | 45 +++++++------------ tests/unit/v1/async/test_async_transaction.py | 8 ++-- 2 files changed, 20 insertions(+), 33 deletions(-) diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 0d0456318e..f22380fddd 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -20,6 +20,19 @@ import six +from google.cloud.firestore_v1.transaction import ( + MAX_ATTEMPTS, + _CANT_BEGIN, + _CANT_ROLLBACK, + _CANT_COMMIT, + _WRITE_READ_ONLY, + _INITIAL_SLEEP, + _MAX_SLEEP, + _MULTIPLIER, + _EXCEED_ATTEMPTS_TEMPLATE, + _CANT_RETRY_READ_ONLY, + _Transactional, +) from google.api_core import exceptions from google.cloud.firestore_v1 import async_batch from google.cloud.firestore_v1 import types @@ -27,23 +40,6 @@ from google.cloud.firestore_v1.async_query import AsyncQuery -MAX_ATTEMPTS = 5 -"""int: Default number of transaction attempts (with retries).""" -_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." -_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." -_CANT_ROLLBACK = _MISSING_ID_TEMPLATE.format("rolled back") -_CANT_COMMIT = _MISSING_ID_TEMPLATE.format("committed") -_WRITE_READ_ONLY = "Cannot perform write operation in read-only transaction." -_INITIAL_SLEEP = 1.0 -"""float: Initial "max" for sleep interval. To be used in :func:`_sleep`.""" -_MAX_SLEEP = 30.0 -"""float: Eventual "max" sleep time. To be used in :func:`_sleep`.""" -_MULTIPLIER = 2.0 -"""float: Multiplier for exponential backoff. To be used in :func:`_sleep`.""" -_EXCEED_ATTEMPTS_TEMPLATE = "Failed to commit transaction in {:d} attempts." -_CANT_RETRY_READ_ONLY = "Only read-write transactions can be retried." - - class AsyncTransaction(async_batch.AsyncWriteBatch): """Accumulate read-and-write operations to be sent in a transaction. @@ -236,7 +232,7 @@ async def get(self, ref_or_query): ) -class _Transactional(object): +class _AsyncTransactional(_Transactional): """Provide a callable object to use as a transactional decorater. This is surfaced via @@ -248,16 +244,7 @@ class _Transactional(object): """ def __init__(self, to_wrap): - self.to_wrap = to_wrap - self.current_id = None - """Optional[bytes]: The current transaction ID.""" - self.retry_id = None - """Optional[bytes]: The ID of the first attempted transaction.""" - - def _reset(self): - """Unset the transaction IDs.""" - self.current_id = None - self.retry_id = None + super(_AsyncTransactional, self).__init__(to_wrap) async def _pre_commit(self, transaction, *args, **kwargs): """Begin transaction and call the wrapped callable. @@ -375,7 +362,7 @@ def transactional(to_wrap): Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]: the wrapped callable. """ - return _Transactional(to_wrap) + return _AsyncTransactional(to_wrap) async def _commit_with_retry(client, write_pbs, transaction_id): diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py index d4b0020832..57e8e6e1b6 100644 --- a/tests/unit/v1/async/test_async_transaction.py +++ b/tests/unit/v1/async/test_async_transaction.py @@ -381,9 +381,9 @@ async def test_get_failure(self): class Test_Transactional(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): - from google.cloud.firestore_v1.async_transaction import _Transactional + from google.cloud.firestore_v1.async_transaction import _AsyncTransactional - return _Transactional + return _AsyncTransactional def _make_one(self, *args, **kwargs): klass = self._get_target_class() @@ -806,10 +806,10 @@ def _call_fut(to_wrap): return transactional(to_wrap) def test_it(self): - from google.cloud.firestore_v1.async_transaction import _Transactional + from google.cloud.firestore_v1.async_transaction import _AsyncTransactional wrapped = self._call_fut(mock.sentinel.callable_) - self.assertIsInstance(wrapped, _Transactional) + self.assertIsInstance(wrapped, _AsyncTransactional) self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) From bac2541eeedf417c3775c0349bfbf1b7e60f6fff Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 17:48:43 -0500 Subject: [PATCH 27/47] fix: remove unused Python2 compatibility --- tests/unit/v1/async/test_async_query.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py index b18a1c7c7c..296fead036 100644 --- a/tests/unit/v1/async/test_async_query.py +++ b/tests/unit/v1/async/test_async_query.py @@ -18,14 +18,9 @@ import aiounittest import mock -import six class TestAsyncQuery(aiounittest.AsyncTestCase): - - if six.PY2: - assertRaisesRegex = aiounittest.AsyncTestCase.assertRaisesRegexp - @staticmethod def _get_target_class(): from google.cloud.firestore_v1.async_query import AsyncQuery From 02003238f5b354b2759a64d88ae013ff126082dc Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 17:49:19 -0500 Subject: [PATCH 28/47] fix: resolve async generator tests --- google/cloud/firestore_v1/async_collection.py | 24 ++++++------- tests/unit/v1/async/test_async_collection.py | 36 ++++++++++++------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index a10f82e256..121d6f324d 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -21,7 +21,7 @@ _auto_id, _item_to_document_ref, ) -from google.cloud.firestore_v1.async_query import AsyncQuery +from google.cloud.firestore_v1 import async_query from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import async_document @@ -129,7 +129,7 @@ def select(self, field_paths): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A "projected" query. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.select(field_paths) def where(self, field_path, op_string, value): @@ -153,7 +153,7 @@ def where(self, field_path, op_string, value): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A filtered query. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.where(field_path, op_string, value) def order_by(self, field_path, **kwargs): @@ -175,7 +175,7 @@ def order_by(self, field_path, **kwargs): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: An "order by" query. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.order_by(field_path, **kwargs) def limit(self, count): @@ -193,7 +193,7 @@ def limit(self, count): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A limited query. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.limit(count) def offset(self, num_to_skip): @@ -211,7 +211,7 @@ def offset(self, num_to_skip): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: An offset query. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.offset(num_to_skip) def start_at(self, document_fields): @@ -232,7 +232,7 @@ def start_at(self, document_fields): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.start_at(document_fields) def start_after(self, document_fields): @@ -253,7 +253,7 @@ def start_after(self, document_fields): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.start_after(document_fields) def end_before(self, document_fields): @@ -274,7 +274,7 @@ def end_before(self, document_fields): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.end_before(document_fields) def end_at(self, document_fields): @@ -295,7 +295,7 @@ def end_at(self, document_fields): :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: A query with cursor. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) return query.end_at(document_fields) async def get(self, transaction=None): @@ -335,7 +335,7 @@ async def stream(self, transaction=None): :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: The next document that fulfills the query. """ - query = AsyncQuery(self) + query = async_query.AsyncQuery(self) async for d in query.stream(transaction=transaction): yield d @@ -366,7 +366,7 @@ def on_snapshot(collection_snapshot, changes, read_time): collection_watch.unsubscribe() """ return Watch.for_query( - AsyncQuery(self), + async_query.AsyncQuery(self), callback, async_document.DocumentSnapshot, async_document.AsyncDocumentReference, diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 021390d3b6..80c4f5d7bc 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -490,27 +490,30 @@ async def test_list_documents_wo_page_size(self): async def test_list_documents_w_page_size(self): await self._list_documents_helper(page_size=25) - @pytest.mark.skip(reason="no way of currently testing this") @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - def test_get(self, query_class): + @pytest.mark.asyncio + async def test_get(self, query_class): import warnings collection = self._make_one("collection") with warnings.catch_warnings(record=True) as warned: get_response = collection.get() + async for _ in get_response: + pass + query_class.assert_called_once_with(collection) query_instance = query_class.return_value - self.assertIs(get_response, query_instance.stream.return_value) + # self.assertIs(get_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=None) # Verify the deprecation self.assertEqual(len(warned), 1) self.assertIs(warned[0].category, DeprecationWarning) - @pytest.mark.skip(reason="no way of currently testing this") @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - def test_get_with_transaction(self, query_class): + @pytest.mark.asyncio + async def test_get_with_transaction(self, query_class): import warnings collection = self._make_one("collection") @@ -518,36 +521,45 @@ def test_get_with_transaction(self, query_class): with warnings.catch_warnings(record=True) as warned: get_response = collection.get(transaction=transaction) + async for _ in get_response: + pass + query_class.assert_called_once_with(collection) query_instance = query_class.return_value - self.assertIs(get_response, query_instance.stream.return_value) + # self.assertIs(get_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=transaction) # Verify the deprecation self.assertEqual(len(warned), 1) self.assertIs(warned[0].category, DeprecationWarning) - @pytest.mark.skip(reason="no way of currently testing this") @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - def test_stream(self, query_class): + @pytest.mark.asyncio + async def test_stream(self, query_class): collection = self._make_one("collection") stream_response = collection.stream() + async for _ in stream_response: + pass + query_class.assert_called_once_with(collection) query_instance = query_class.return_value - self.assertIs(stream_response, query_instance.stream.return_value) + # self.assertIs(stream_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=None) - @pytest.mark.skip(reason="no way of currently testing this") @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - def test_stream_with_transaction(self, query_class): + @pytest.mark.asyncio + async def test_stream_with_transaction(self, query_class): collection = self._make_one("collection") transaction = mock.sentinel.txn stream_response = collection.stream(transaction=transaction) + async for _ in stream_response: + pass + query_class.assert_called_once_with(collection) query_instance = query_class.return_value - self.assertIs(stream_response, query_instance.stream.return_value) + # self.assertIs(stream_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=transaction) @mock.patch("google.cloud.firestore_v1.async_collection.Watch", autospec=True) From 8362fc82eee7c138a30f9c4e0d15f5a0e291766b Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 22 Jun 2020 18:21:16 -0500 Subject: [PATCH 29/47] fix: create mock async generator to get full coverage --- tests/unit/v1/async/test_async_collection.py | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 80c4f5d7bc..874f825100 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -20,6 +20,15 @@ import six +class MockAsyncIter: + def __init__(self, count): + self.count = count + + async def __aiter__(self, **_): + for i in range(self.count): + yield i + + class TestAsyncCollectionReference(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): @@ -495,6 +504,8 @@ async def test_list_documents_w_page_size(self): async def test_get(self, query_class): import warnings + query_class.return_value.stream.return_value = MockAsyncIter(3) + collection = self._make_one("collection") with warnings.catch_warnings(record=True) as warned: get_response = collection.get() @@ -504,7 +515,6 @@ async def test_get(self, query_class): query_class.assert_called_once_with(collection) query_instance = query_class.return_value - # self.assertIs(get_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=None) # Verify the deprecation @@ -516,6 +526,8 @@ async def test_get(self, query_class): async def test_get_with_transaction(self, query_class): import warnings + query_class.return_value.stream.return_value = MockAsyncIter(3) + collection = self._make_one("collection") transaction = mock.sentinel.txn with warnings.catch_warnings(record=True) as warned: @@ -526,7 +538,6 @@ async def test_get_with_transaction(self, query_class): query_class.assert_called_once_with(collection) query_instance = query_class.return_value - # self.assertIs(get_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=transaction) # Verify the deprecation @@ -536,6 +547,8 @@ async def test_get_with_transaction(self, query_class): @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_stream(self, query_class): + query_class.return_value.stream.return_value = MockAsyncIter(3) + collection = self._make_one("collection") stream_response = collection.stream() @@ -544,12 +557,13 @@ async def test_stream(self, query_class): query_class.assert_called_once_with(collection) query_instance = query_class.return_value - # self.assertIs(stream_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=None) @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_stream_with_transaction(self, query_class): + query_class.return_value.stream.return_value = MockAsyncIter(3) + collection = self._make_one("collection") transaction = mock.sentinel.txn stream_response = collection.stream(transaction=transaction) @@ -559,7 +573,6 @@ async def test_stream_with_transaction(self, query_class): query_class.assert_called_once_with(collection) query_instance = query_class.return_value - # self.assertIs(stream_response, query_instance.stream.return_value) query_instance.stream.assert_called_once_with(transaction=transaction) @mock.patch("google.cloud.firestore_v1.async_collection.Watch", autospec=True) From 758a8d6405ea6f2c47ebaf81ece4e11bb91148d8 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 23 Jun 2020 13:46:26 -0500 Subject: [PATCH 30/47] fix: copyright date --- google/cloud/firestore_v1/async_batch.py | 2 +- google/cloud/firestore_v1/async_client.py | 2 +- google/cloud/firestore_v1/async_collection.py | 2 +- google/cloud/firestore_v1/async_document.py | 2 +- google/cloud/firestore_v1/async_query.py | 2 +- google/cloud/firestore_v1/async_transaction.py | 2 +- tests/unit/v1/async/test_async_batch.py | 2 +- tests/unit/v1/async/test_async_client.py | 2 +- tests/unit/v1/async/test_async_collection.py | 2 +- tests/unit/v1/async/test_async_document.py | 2 +- tests/unit/v1/async/test_async_query.py | 2 +- tests/unit/v1/async/test_async_transaction.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 3aa882800c..495bee06ce 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 8b21396a42..d93ca6fb54 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 202 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 121d6f324d..cdc1a80fba 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 64c8b394d2..1111f6b19d 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index f8c33e16df..83024284ef 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index f22380fddd..069f8168c3 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py index aa71999b5c..301bc58a81 100644 --- a/tests/unit/v1/async/test_async_batch.py +++ b/tests/unit/v1/async/test_async_batch.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index bcae9228e3..76220f01e1 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 874f825100..055d7ae353 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py index a86f588b48..09b2f951e5 100644 --- a/tests/unit/v1/async/test_async_document.py +++ b/tests/unit/v1/async/test_async_document.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py index 296fead036..73e5f3d764 100644 --- a/tests/unit/v1/async/test_async_query.py +++ b/tests/unit/v1/async/test_async_query.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py index 57e8e6e1b6..32b061c8d0 100644 --- a/tests/unit/v1/async/test_async_transaction.py +++ b/tests/unit/v1/async/test_async_transaction.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 07ea88301c537827cfb75295f4860d5f2d9b0652 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 23 Jun 2020 18:46:06 -0500 Subject: [PATCH 31/47] feat: create Client/AsyncClient superclass --- google/cloud/firestore_v1/async_client.py | 58 +-- google/cloud/firestore_v1/base_client.py | 491 ++++++++++++++++++++++ google/cloud/firestore_v1/client.py | 417 ++---------------- tests/unit/v1/async/test_async_client.py | 299 +------------ tests/unit/v1/test_base_client.py | 358 ++++++++++++++++ tests/unit/v1/test_client.py | 322 +------------- 6 files changed, 924 insertions(+), 1021 deletions(-) create mode 100644 google/cloud/firestore_v1/base_client.py create mode 100644 tests/unit/v1/test_base_client.py diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index d93ca6fb54..04ff127edf 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -24,14 +24,15 @@ :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ -from google.cloud.firestore_v1.client import ( - Client, +from google.cloud.firestore_v1.base_client import ( + BaseClient, DEFAULT_DATABASE, _CLIENT_INFO, _reference_info, _parse_batch_get, _get_doc_mask, _item_to_collection_ref, + _path_helper, ) from google.cloud.firestore_v1 import _helpers @@ -42,7 +43,7 @@ from google.cloud.firestore_v1.async_transaction import AsyncTransaction -class AsyncClient(Client): +class AsyncClient(BaseClient): """Client for interacting with Google Cloud Firestore API. .. note:: @@ -78,9 +79,6 @@ def __init__( client_info=_CLIENT_INFO, client_options=None, ): - # NOTE: This API has no use for the _http argument, but sending it - # will have no impact since the _http() @property only lazily - # creates a working HTTP object. super(AsyncClient, self).__init__( project=project, credentials=credentials, @@ -118,12 +116,7 @@ def collection(self, *collection_path): :class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`: A reference to a collection in the Firestore database. """ - if len(collection_path) == 1: - path = collection_path[0].split(_helpers.DOCUMENT_PATH_DELIMITER) - else: - path = collection_path - - return AsyncCollectionReference(*path, client=self) + return AsyncCollectionReference(*_path_helper(collection_path), client=self) def collection_group(self, collection_id): """ @@ -135,20 +128,19 @@ def collection_group(self, collection_id): >>> query = client.collection_group('mygroup') - @param {string} collectionId Identifies the collections to query over. - Every collection or subcollection with this ID as the last segment of its - path will be included. Cannot contain a slash. - @returns {AsyncQuery} The created AsyncQuery. - """ - if "/" in collection_id: - raise ValueError( - "Invalid collection_id " - + collection_id - + ". Collection IDs must not contain '/'." - ) + Args: + collection_id (str) Identifies the collections to query over. + + Every collection or subcollection with this ID as the last segment of its + path will be included. Cannot contain a slash. - collection = self.collection(collection_id) - return AsyncQuery(collection, all_descendants=True) + Returns: + :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: + The created AsyncQuery. + """ + return AsyncQuery( + self._get_collection_reference(collection_id), all_descendants=True + ) def document(self, *document_path): """Get a reference to a document in a collection. @@ -181,19 +173,9 @@ def document(self, *document_path): :class:`~google.cloud.firestore_v1.document.AsyncDocumentReference`: A reference to a document in a collection. """ - if len(document_path) == 1: - path = document_path[0].split(_helpers.DOCUMENT_PATH_DELIMITER) - else: - path = document_path - - # AsyncDocumentReference takes a relative path. Strip the database string if present. - base_path = self._database_string + "/documents/" - joined_path = _helpers.DOCUMENT_PATH_DELIMITER.join(path) - if joined_path.startswith(base_path): - joined_path = joined_path[len(base_path) :] - path = joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) - - return AsyncDocumentReference(*path, client=self) + return AsyncDocumentReference( + *self._document_path_helper(*document_path), client=self + ) async def get_all(self, references, field_paths=None, transaction=None): """Retrieve a batch of documents. diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py new file mode 100644 index 0000000000..d020c251a7 --- /dev/null +++ b/google/cloud/firestore_v1/base_client.py @@ -0,0 +1,491 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client for interacting with the Google Cloud Firestore API. + +This is the base from which all interactions with the API occur. + +In the hierarchy of API concepts + +* a :class:`~google.cloud.firestore_v1.client.Client` owns a + :class:`~google.cloud.firestore_v1.collection.CollectionReference` +* a :class:`~google.cloud.firestore_v1.client.Client` owns a + :class:`~google.cloud.firestore_v1.document.DocumentReference` +""" +import os + +import google.api_core.client_options +from google.api_core.gapic_v1 import client_info +from google.cloud.client import ClientWithProject + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import __version__ +from google.cloud.firestore_v1 import types +from google.cloud.firestore_v1.document import DocumentSnapshot +from google.cloud.firestore_v1.field_path import render_field_path +from google.cloud.firestore_v1.gapic import firestore_client +from google.cloud.firestore_v1.gapic.transports import firestore_grpc_transport + + +DEFAULT_DATABASE = "(default)" +"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" +_BAD_OPTION_ERR = ( + "Exactly one of ``last_update_time`` or ``exists`` " "must be provided." +) +_BAD_DOC_TEMPLATE = ( + "Document {!r} appeared in response but was not present among references" +) +_ACTIVE_TXN = "There is already an active transaction." +_INACTIVE_TXN = "There is no active transaction." +_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) +_FIRESTORE_EMULATOR_HOST = "FIRESTORE_EMULATOR_HOST" + + +class BaseClient(ClientWithProject): + """Client for interacting with Google Cloud Firestore API. + + .. note:: + + Since the Cloud Firestore API requires the gRPC transport, no + ``_http`` argument is accepted by this class. + + Args: + project (Optional[str]): The project which the client acts on behalf + of. If not passed, falls back to the default inferred + from the environment. + credentials (Optional[~google.auth.credentials.Credentials]): The + OAuth2 Credentials to use for this client. If not passed, falls + back to the default inferred from the environment. + database (Optional[str]): The database name that the client targets. + For now, :attr:`DEFAULT_DATABASE` (the default value) is the + only valid database. + client_info (Optional[google.api_core.gapic_v1.client_info.ClientInfo]): + The client info used to send a user-agent string along with API + requests. If ``None``, then default info will be used. Generally, + you only need to set this if you're developing your own library + or partner tool. + client_options (Union[dict, google.api_core.client_options.ClientOptions]): + Client options used to set user options on the client. API Endpoint + should be set through client_options. + """ + + SCOPE = ( + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/datastore", + ) + """The scopes required for authenticating with the Firestore service.""" + + _firestore_api_internal = None + _database_string_internal = None + _rpc_metadata_internal = None + + def __init__( + self, + project=None, + credentials=None, + database=DEFAULT_DATABASE, + client_info=_CLIENT_INFO, + client_options=None, + ): + # NOTE: This API has no use for the _http argument, but sending it + # will have no impact since the _http() @property only lazily + # creates a working HTTP object. + super(BaseClient, self).__init__( + project=project, credentials=credentials, _http=None + ) + self._client_info = client_info + if client_options: + if type(client_options) == dict: + client_options = google.api_core.client_options.from_dict( + client_options + ) + self._client_options = client_options + + self._database = database + self._emulator_host = os.getenv(_FIRESTORE_EMULATOR_HOST) + + @property + def _firestore_api(self): + """Lazy-loading getter GAPIC Firestore API. + + Returns: + :class:`~google.cloud.gapic.firestore.v1`.firestore_client.FirestoreClient: + >> query = client.collection_group('mygroup') - @param {string} collectionId Identifies the collections to query over. - Every collection or subcollection with this ID as the last segment of its - path will be included. Cannot contain a slash. - @returns {Query} The created Query. - """ - if "/" in collection_id: - raise ValueError( - "Invalid collection_id " - + collection_id - + ". Collection IDs must not contain '/'." - ) + Args: + collection_id (str) Identifies the collections to query over. - collection = self.collection(collection_id) - return query.Query(collection, all_descendants=True) + Every collection or subcollection with this ID as the last segment of its + path will be included. Cannot contain a slash. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + The created Query. + """ + return Query( + self._get_collection_reference(collection_id), all_descendants=True + ) def document(self, *document_path): """Get a reference to a document in a collection. @@ -304,97 +173,9 @@ def document(self, *document_path): :class:`~google.cloud.firestore_v1.document.DocumentReference`: A reference to a document in a collection. """ - if len(document_path) == 1: - path = document_path[0].split(_helpers.DOCUMENT_PATH_DELIMITER) - else: - path = document_path - - # DocumentReference takes a relative path. Strip the database string if present. - base_path = self._database_string + "/documents/" - joined_path = _helpers.DOCUMENT_PATH_DELIMITER.join(path) - if joined_path.startswith(base_path): - joined_path = joined_path[len(base_path) :] - path = joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER) - - return DocumentReference(*path, client=self) - - @staticmethod - def field_path(*field_names): - """Create a **field path** from a list of nested field names. - - A **field path** is a ``.``-delimited concatenation of the field - names. It is used to represent a nested field. For example, - in the data - - .. code-block:: python - - data = { - 'aa': { - 'bb': { - 'cc': 10, - }, - }, - } - - the field path ``'aa.bb.cc'`` represents the data stored in - ``data['aa']['bb']['cc']``. - - Args: - field_names (Tuple[str, ...]): The list of field names. - - Returns: - str: The ``.``-delimited field path. - """ - return render_field_path(field_names) - - @staticmethod - def write_option(**kwargs): - """Create a write option for write operations. - - Write operations include :meth:`~google.cloud.DocumentReference.set`, - :meth:`~google.cloud.DocumentReference.update` and - :meth:`~google.cloud.DocumentReference.delete`. - - One of the following keyword arguments must be provided: - - * ``last_update_time`` (:class:`google.protobuf.timestamp_pb2.\ - Timestamp`): A timestamp. When set, the target document must - exist and have been last updated at that time. Protobuf - ``update_time`` timestamps are typically returned from methods - that perform write operations as part of a "write result" - protobuf or directly. - * ``exists`` (:class:`bool`): Indicates if the document being modified - should already exist. - - Providing no argument would make the option have no effect (so - it is not allowed). Providing multiple would be an apparent - contradiction, since ``last_update_time`` assumes that the - document **was** updated (it can't have been updated if it - doesn't exist) and ``exists`` indicate that it is unknown if the - document exists or not. - - Args: - kwargs (Dict[str, Any]): The keyword arguments described above. - - Raises: - TypeError: If anything other than exactly one argument is - provided by the caller. - - Returns: - :class:`~google.cloud.firestore_v1.client.WriteOption`: - The option to be used to configure a write message. - """ - if len(kwargs) != 1: - raise TypeError(_BAD_OPTION_ERR) - - name, value = kwargs.popitem() - if name == "last_update_time": - return _helpers.LastUpdateOption(value) - elif name == "exists": - return _helpers.ExistsOption(value) - else: - extra = "{!r} was provided".format(name) - raise TypeError(_BAD_OPTION_ERR, extra) + return DocumentReference( + *self._document_path_helper(*document_path), client=self + ) def get_all(self, references, field_paths=None, transaction=None): """Retrieve a batch of documents. @@ -485,135 +266,3 @@ def transaction(self, **kwargs): A transaction attached to this client. """ return Transaction(self, **kwargs) - - -def _reference_info(references): - """Get information about document references. - - Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`. - - Args: - references (List[.DocumentReference, ...]): Iterable of document - references. - - Returns: - Tuple[List[str, ...], Dict[str, .DocumentReference]]: A two-tuple of - - * fully-qualified documents paths for each reference in ``references`` - * a mapping from the paths to the original reference. (If multiple - ``references`` contains multiple references to the same document, - that key will be overwritten in the result.) - """ - document_paths = [] - reference_map = {} - for reference in references: - doc_path = reference._document_path - document_paths.append(doc_path) - reference_map[doc_path] = reference - - return document_paths, reference_map - - -def _get_reference(document_path, reference_map): - """Get a document reference from a dictionary. - - This just wraps a simple dictionary look-up with a helpful error that is - specific to :meth:`~google.cloud.firestore.client.Client.get_all`, the - **public** caller of this function. - - Args: - document_path (str): A fully-qualified document path. - reference_map (Dict[str, .DocumentReference]): A mapping (produced - by :func:`_reference_info`) of fully-qualified document paths to - document references. - - Returns: - .DocumentReference: The matching reference. - - Raises: - ValueError: If ``document_path`` has not been encountered. - """ - try: - return reference_map[document_path] - except KeyError: - msg = _BAD_DOC_TEMPLATE.format(document_path) - raise ValueError(msg) - - -def _parse_batch_get(get_doc_response, reference_map, client): - """Parse a `BatchGetDocumentsResponse` protobuf. - - Args: - get_doc_response (~google.cloud.proto.firestore.v1.\ - firestore_pb2.BatchGetDocumentsResponse): A single response (from - a stream) containing the "get" response for a document. - reference_map (Dict[str, .DocumentReference]): A mapping (produced - by :func:`_reference_info`) of fully-qualified document paths to - document references. - client (:class:`~google.cloud.firestore_v1.client.Client`): - A client that has a document factory. - - Returns: - [.DocumentSnapshot]: The retrieved snapshot. - - Raises: - ValueError: If the response has a ``result`` field (a oneof) other - than ``found`` or ``missing``. - """ - result_type = get_doc_response.WhichOneof("result") - if result_type == "found": - reference = _get_reference(get_doc_response.found.name, reference_map) - data = _helpers.decode_dict(get_doc_response.found.fields, client) - snapshot = DocumentSnapshot( - reference, - data, - exists=True, - read_time=get_doc_response.read_time, - create_time=get_doc_response.found.create_time, - update_time=get_doc_response.found.update_time, - ) - elif result_type == "missing": - reference = _get_reference(get_doc_response.missing, reference_map) - snapshot = DocumentSnapshot( - reference, - None, - exists=False, - read_time=get_doc_response.read_time, - create_time=None, - update_time=None, - ) - else: - raise ValueError( - "`BatchGetDocumentsResponse.result` (a oneof) had a field other " - "than `found` or `missing` set, or was unset" - ) - return snapshot - - -def _get_doc_mask(field_paths): - """Get a document mask if field paths are provided. - - Args: - field_paths (Optional[Iterable[str, ...]]): An iterable of field - paths (``.``-delimited list of field names) to use as a - projection of document fields in the returned results. - - Returns: - Optional[google.cloud.firestore_v1.types.DocumentMask]: A mask - to project documents to a restricted set of field paths. - """ - if field_paths is None: - return None - else: - return types.DocumentMask(field_paths=field_paths) - - -def _item_to_collection_ref(iterator, item): - """Convert collection ID to collection ref. - - Args: - iterator (google.api_core.page_iterator.GRPCIterator): - iterator response - item (str): ID of the collection - """ - return iterator.client.collection(item) diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index 76220f01e1..e83fd7db08 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -51,7 +51,7 @@ def test_constructor(self): self.assertIsNone(client._emulator_host) def test_constructor_with_emulator_host(self): - from google.cloud.firestore_v1.client import _FIRESTORE_EMULATOR_HOST + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST credentials = _make_credentials() emulator_host = "localhost:8081" @@ -88,102 +88,6 @@ def test_constructor_w_client_options(self): ) self.assertEqual(client._target, "foo-firestore.googleapis.com") - @mock.patch( - "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", - autospec=True, - return_value=mock.sentinel.firestore_api, - ) - def test__firestore_api_property(self, mock_client): - mock_client.SERVICE_ADDRESS = "endpoint" - client = self._make_default_one() - client_info = client._client_info = mock.Mock() - self.assertIsNone(client._firestore_api_internal) - firestore_api = client._firestore_api - self.assertIs(firestore_api, mock_client.return_value) - self.assertIs(firestore_api, client._firestore_api_internal) - mock_client.assert_called_once_with( - transport=client._transport, client_info=client_info - ) - - # Call again to show that it is cached, but call count is still 1. - self.assertIs(client._firestore_api, mock_client.return_value) - self.assertEqual(mock_client.call_count, 1) - - @mock.patch( - "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", - autospec=True, - return_value=mock.sentinel.firestore_api, - ) - @mock.patch( - "google.cloud.firestore_v1.gapic.transports.firestore_grpc_transport.firestore_pb2_grpc.grpc.insecure_channel", - autospec=True, - ) - def test__firestore_api_property_with_emulator( - self, mock_insecure_channel, mock_client - ): - emulator_host = "localhost:8081" - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - client = self._make_default_one() - - self.assertIsNone(client._firestore_api_internal) - firestore_api = client._firestore_api - self.assertIs(firestore_api, mock_client.return_value) - self.assertIs(firestore_api, client._firestore_api_internal) - - mock_insecure_channel.assert_called_once_with(emulator_host) - - # Call again to show that it is cached, but call count is still 1. - self.assertIs(client._firestore_api, mock_client.return_value) - self.assertEqual(mock_client.call_count, 1) - - def test___database_string_property(self): - credentials = _make_credentials() - database = "cheeeeez" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - self.assertIsNone(client._database_string_internal) - database_string = client._database_string - expected = "projects/{}/databases/{}".format(client.project, client._database) - self.assertEqual(database_string, expected) - self.assertIs(database_string, client._database_string_internal) - - # Swap it out with a unique value to verify it is cached. - client._database_string_internal = mock.sentinel.cached - self.assertIs(client._database_string, mock.sentinel.cached) - - def test___rpc_metadata_property(self): - credentials = _make_credentials() - database = "quanta" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - - self.assertEqual( - client._rpc_metadata, - [("google-cloud-resource-prefix", client._database_string)], - ) - - def test__rpc_metadata_property_with_emulator(self): - emulator_host = "localhost:8081" - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - - credentials = _make_credentials() - database = "quanta" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - - self.assertEqual( - client._rpc_metadata, - [ - ("google-cloud-resource-prefix", client._database_string), - ("authorization", "Bearer owner"), - ], - ) - def test_collection_factory(self): from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -213,6 +117,15 @@ def test_collection_factory_nested(self): self.assertIs(collection2._client, client) self.assertIsInstance(collection2, AsyncCollectionReference) + def test__get_collection_reference(self): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + client = self._make_default_one() + collection = client._get_collection_reference("collectionId") + + self.assertIs(collection._client, client) + self.assertIsInstance(collection, AsyncCollectionReference) + def test_collection_group(self): client = self._make_default_one() query = client.collection_group("collectionId").where("foo", "==", u"bar") @@ -277,62 +190,6 @@ def test_document_factory_w_nested_path(self): self.assertIs(document2._client, client) self.assertIsInstance(document2, AsyncDocumentReference) - def test_field_path(self): - klass = self._get_target_class() - self.assertEqual(klass.field_path("a", "b", "c"), "a.b.c") - - def test_write_option_last_update(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1._helpers import LastUpdateOption - - timestamp = timestamp_pb2.Timestamp(seconds=1299767599, nanos=811111097) - - klass = self._get_target_class() - option = klass.write_option(last_update_time=timestamp) - self.assertIsInstance(option, LastUpdateOption) - self.assertEqual(option._last_update_time, timestamp) - - def test_write_option_exists(self): - from google.cloud.firestore_v1._helpers import ExistsOption - - klass = self._get_target_class() - - option1 = klass.write_option(exists=False) - self.assertIsInstance(option1, ExistsOption) - self.assertFalse(option1._exists) - - option2 = klass.write_option(exists=True) - self.assertIsInstance(option2, ExistsOption) - self.assertTrue(option2._exists) - - def test_write_open_neither_arg(self): - from google.cloud.firestore_v1.client import _BAD_OPTION_ERR - - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option() - - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) - - def test_write_multiple_args(self): - from google.cloud.firestore_v1.client import _BAD_OPTION_ERR - - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option(exists=False, last_update_time=mock.sentinel.timestamp) - - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) - - def test_write_bad_arg(self): - from google.cloud.firestore_v1.client import _BAD_OPTION_ERR - - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option(spinach="popeye") - - extra = "{!r} was provided".format("spinach") - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR, extra)) - @pytest.mark.asyncio async def test_collections(self): from google.api_core.page_iterator import Iterator @@ -474,7 +331,7 @@ async def test_get_all_with_transaction(self): @pytest.mark.asyncio async def test_get_all_unknown_result(self): - from google.cloud.firestore_v1.client import _BAD_DOC_TEMPLATE + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE info = self._info_for_get_all({"z": 28.5}, {}) client, document, _, _, response = info @@ -561,140 +418,6 @@ def test_transaction(self): self.assertIsNone(transaction._id) -class Test__reference_info(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(references): - from google.cloud.firestore_v1.async_client import _reference_info - - return _reference_info(references) - - def test_it(self): - from google.cloud.firestore_v1.async_client import AsyncClient - - credentials = _make_credentials() - client = AsyncClient(project="hi-projject", credentials=credentials) - - reference1 = client.document("a", "b") - reference2 = client.document("a", "b", "c", "d") - reference3 = client.document("a", "b") - reference4 = client.document("f", "g") - - doc_path1 = reference1._document_path - doc_path2 = reference2._document_path - doc_path3 = reference3._document_path - doc_path4 = reference4._document_path - self.assertEqual(doc_path1, doc_path3) - - document_paths, reference_map = self._call_fut( - [reference1, reference2, reference3, reference4] - ) - self.assertEqual(document_paths, [doc_path1, doc_path2, doc_path3, doc_path4]) - # reference3 over-rides reference1. - expected_map = { - doc_path2: reference2, - doc_path3: reference3, - doc_path4: reference4, - } - self.assertEqual(reference_map, expected_map) - - -class Test__parse_batch_get(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(get_doc_response, reference_map, client=mock.sentinel.client): - from google.cloud.firestore_v1.async_client import _parse_batch_get - - return _parse_batch_get(get_doc_response, reference_map, client) - - @staticmethod - def _dummy_ref_string(): - from google.cloud.firestore_v1.async_client import DEFAULT_DATABASE - - project = u"bazzzz" - collection_id = u"fizz" - document_id = u"buzz" - return u"projects/{}/databases/{}/documents/{}/{}".format( - project, DEFAULT_DATABASE, collection_id, document_id - ) - - def test_found(self): - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1.async_document import DocumentSnapshot - - now = datetime.datetime.utcnow() - read_time = _datetime_to_pb_timestamp(now) - delta = datetime.timedelta(seconds=100) - update_time = _datetime_to_pb_timestamp(now - delta) - create_time = _datetime_to_pb_timestamp(now - 2 * delta) - - ref_string = self._dummy_ref_string() - document_pb = document_pb2.Document( - name=ref_string, - fields={ - "foo": document_pb2.Value(double_value=1.5), - "bar": document_pb2.Value(string_value=u"skillz"), - }, - create_time=create_time, - update_time=update_time, - ) - response_pb = _make_batch_response(found=document_pb, read_time=read_time) - - reference_map = {ref_string: mock.sentinel.reference} - snapshot = self._call_fut(response_pb, reference_map) - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertIs(snapshot._reference, mock.sentinel.reference) - self.assertEqual(snapshot._data, {"foo": 1.5, "bar": u"skillz"}) - self.assertTrue(snapshot._exists) - self.assertEqual(snapshot.read_time, read_time) - self.assertEqual(snapshot.create_time, create_time) - self.assertEqual(snapshot.update_time, update_time) - - def test_missing(self): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - ref_string = self._dummy_ref_string() - response_pb = _make_batch_response(missing=ref_string) - document = AsyncDocumentReference("fizz", "bazz", client=mock.sentinel.client) - reference_map = {ref_string: document} - snapshot = self._call_fut(response_pb, reference_map) - self.assertFalse(snapshot.exists) - self.assertEqual(snapshot.id, "bazz") - self.assertIsNone(snapshot._data) - - def test_unset_result_type(self): - response_pb = _make_batch_response() - with self.assertRaises(ValueError): - self._call_fut(response_pb, {}) - - def test_unknown_result_type(self): - response_pb = mock.Mock(spec=["WhichOneof"]) - response_pb.WhichOneof.return_value = "zoob_value" - - with self.assertRaises(ValueError): - self._call_fut(response_pb, {}) - - response_pb.WhichOneof.assert_called_once_with("result") - - -class Test__get_doc_mask(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(field_paths): - from google.cloud.firestore_v1.async_client import _get_doc_mask - - return _get_doc_mask(field_paths) - - def test_none(self): - self.assertIsNone(self._call_fut(None)) - - def test_paths(self): - from google.cloud.firestore_v1.proto import common_pb2 - - field_paths = ["a.b", "c"] - result = self._call_fut(field_paths) - expected = common_pb2.DocumentMask(field_paths=field_paths) - self.assertEqual(result, expected) - - def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py new file mode 100644 index 0000000000..40523104d2 --- /dev/null +++ b/tests/unit/v1/test_base_client.py @@ -0,0 +1,358 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import unittest + +import mock + + +class TestClient(unittest.TestCase): + + PROJECT = "my-prahjekt" + + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.client import Client + + return Client + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def _make_default_one(self): + credentials = _make_credentials() + return self._make_one(project=self.PROJECT, credentials=credentials) + + @mock.patch( + "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", + autospec=True, + return_value=mock.sentinel.firestore_api, + ) + def test__firestore_api_property(self, mock_client): + mock_client.SERVICE_ADDRESS = "endpoint" + client = self._make_default_one() + client_info = client._client_info = mock.Mock() + self.assertIsNone(client._firestore_api_internal) + firestore_api = client._firestore_api + self.assertIs(firestore_api, mock_client.return_value) + self.assertIs(firestore_api, client._firestore_api_internal) + mock_client.assert_called_once_with( + transport=client._transport, client_info=client_info + ) + + # Call again to show that it is cached, but call count is still 1. + self.assertIs(client._firestore_api, mock_client.return_value) + self.assertEqual(mock_client.call_count, 1) + + @mock.patch( + "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", + autospec=True, + return_value=mock.sentinel.firestore_api, + ) + @mock.patch( + "google.cloud.firestore_v1.gapic.transports.firestore_grpc_transport.firestore_pb2_grpc.grpc.insecure_channel", + autospec=True, + ) + def test__firestore_api_property_with_emulator( + self, mock_insecure_channel, mock_client + ): + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + client = self._make_default_one() + + self.assertIsNone(client._firestore_api_internal) + firestore_api = client._firestore_api + self.assertIs(firestore_api, mock_client.return_value) + self.assertIs(firestore_api, client._firestore_api_internal) + + mock_insecure_channel.assert_called_once_with(emulator_host) + + # Call again to show that it is cached, but call count is still 1. + self.assertIs(client._firestore_api, mock_client.return_value) + self.assertEqual(mock_client.call_count, 1) + + def test___database_string_property(self): + credentials = _make_credentials() + database = "cheeeeez" + client = self._make_one( + project=self.PROJECT, credentials=credentials, database=database + ) + self.assertIsNone(client._database_string_internal) + database_string = client._database_string + expected = "projects/{}/databases/{}".format(client.project, client._database) + self.assertEqual(database_string, expected) + self.assertIs(database_string, client._database_string_internal) + + # Swap it out with a unique value to verify it is cached. + client._database_string_internal = mock.sentinel.cached + self.assertIs(client._database_string, mock.sentinel.cached) + + def test___rpc_metadata_property(self): + credentials = _make_credentials() + database = "quanta" + client = self._make_one( + project=self.PROJECT, credentials=credentials, database=database + ) + + self.assertEqual( + client._rpc_metadata, + [("google-cloud-resource-prefix", client._database_string)], + ) + + def test__rpc_metadata_property_with_emulator(self): + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + + credentials = _make_credentials() + database = "quanta" + client = self._make_one( + project=self.PROJECT, credentials=credentials, database=database + ) + + self.assertEqual( + client._rpc_metadata, + [ + ("google-cloud-resource-prefix", client._database_string), + ("authorization", "Bearer owner"), + ], + ) + + def test_field_path(self): + klass = self._get_target_class() + self.assertEqual(klass.field_path("a", "b", "c"), "a.b.c") + + def test_write_option_last_update(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import LastUpdateOption + + timestamp = timestamp_pb2.Timestamp(seconds=1299767599, nanos=811111097) + + klass = self._get_target_class() + option = klass.write_option(last_update_time=timestamp) + self.assertIsInstance(option, LastUpdateOption) + self.assertEqual(option._last_update_time, timestamp) + + def test_write_option_exists(self): + from google.cloud.firestore_v1._helpers import ExistsOption + + klass = self._get_target_class() + + option1 = klass.write_option(exists=False) + self.assertIsInstance(option1, ExistsOption) + self.assertFalse(option1._exists) + + option2 = klass.write_option(exists=True) + self.assertIsInstance(option2, ExistsOption) + self.assertTrue(option2._exists) + + def test_write_open_neither_arg(self): + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + + klass = self._get_target_class() + with self.assertRaises(TypeError) as exc_info: + klass.write_option() + + self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) + + def test_write_multiple_args(self): + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + + klass = self._get_target_class() + with self.assertRaises(TypeError) as exc_info: + klass.write_option(exists=False, last_update_time=mock.sentinel.timestamp) + + self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) + + def test_write_bad_arg(self): + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + + klass = self._get_target_class() + with self.assertRaises(TypeError) as exc_info: + klass.write_option(spinach="popeye") + + extra = "{!r} was provided".format("spinach") + self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR, extra)) + + +class Test__reference_info(unittest.TestCase): + @staticmethod + def _call_fut(references): + from google.cloud.firestore_v1.base_client import _reference_info + + return _reference_info(references) + + def test_it(self): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + client = Client(project="hi-projject", credentials=credentials) + + reference1 = client.document("a", "b") + reference2 = client.document("a", "b", "c", "d") + reference3 = client.document("a", "b") + reference4 = client.document("f", "g") + + doc_path1 = reference1._document_path + doc_path2 = reference2._document_path + doc_path3 = reference3._document_path + doc_path4 = reference4._document_path + self.assertEqual(doc_path1, doc_path3) + + document_paths, reference_map = self._call_fut( + [reference1, reference2, reference3, reference4] + ) + self.assertEqual(document_paths, [doc_path1, doc_path2, doc_path3, doc_path4]) + # reference3 over-rides reference1. + expected_map = { + doc_path2: reference2, + doc_path3: reference3, + doc_path4: reference4, + } + self.assertEqual(reference_map, expected_map) + + +class Test__get_reference(unittest.TestCase): + @staticmethod + def _call_fut(document_path, reference_map): + from google.cloud.firestore_v1.base_client import _get_reference + + return _get_reference(document_path, reference_map) + + def test_success(self): + doc_path = "a/b/c" + reference_map = {doc_path: mock.sentinel.reference} + self.assertIs(self._call_fut(doc_path, reference_map), mock.sentinel.reference) + + def test_failure(self): + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE + + doc_path = "1/888/call-now" + with self.assertRaises(ValueError) as exc_info: + self._call_fut(doc_path, {}) + + err_msg = _BAD_DOC_TEMPLATE.format(doc_path) + self.assertEqual(exc_info.exception.args, (err_msg,)) + + +class Test__parse_batch_get(unittest.TestCase): + @staticmethod + def _call_fut(get_doc_response, reference_map, client=mock.sentinel.client): + from google.cloud.firestore_v1.base_client import _parse_batch_get + + return _parse_batch_get(get_doc_response, reference_map, client) + + @staticmethod + def _dummy_ref_string(): + from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE + + project = u"bazzzz" + collection_id = u"fizz" + document_id = u"buzz" + return u"projects/{}/databases/{}/documents/{}/{}".format( + project, DEFAULT_DATABASE, collection_id, document_id + ) + + def test_found(self): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1.document import DocumentSnapshot + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + + ref_string = self._dummy_ref_string() + document_pb = document_pb2.Document( + name=ref_string, + fields={ + "foo": document_pb2.Value(double_value=1.5), + "bar": document_pb2.Value(string_value=u"skillz"), + }, + create_time=create_time, + update_time=update_time, + ) + response_pb = _make_batch_response(found=document_pb, read_time=read_time) + + reference_map = {ref_string: mock.sentinel.reference} + snapshot = self._call_fut(response_pb, reference_map) + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertIs(snapshot._reference, mock.sentinel.reference) + self.assertEqual(snapshot._data, {"foo": 1.5, "bar": u"skillz"}) + self.assertTrue(snapshot._exists) + self.assertEqual(snapshot.read_time, read_time) + self.assertEqual(snapshot.create_time, create_time) + self.assertEqual(snapshot.update_time, update_time) + + def test_missing(self): + from google.cloud.firestore_v1.document import DocumentReference + + ref_string = self._dummy_ref_string() + response_pb = _make_batch_response(missing=ref_string) + document = DocumentReference("fizz", "bazz", client=mock.sentinel.client) + reference_map = {ref_string: document} + snapshot = self._call_fut(response_pb, reference_map) + self.assertFalse(snapshot.exists) + self.assertEqual(snapshot.id, "bazz") + self.assertIsNone(snapshot._data) + + def test_unset_result_type(self): + response_pb = _make_batch_response() + with self.assertRaises(ValueError): + self._call_fut(response_pb, {}) + + def test_unknown_result_type(self): + response_pb = mock.Mock(spec=["WhichOneof"]) + response_pb.WhichOneof.return_value = "zoob_value" + + with self.assertRaises(ValueError): + self._call_fut(response_pb, {}) + + response_pb.WhichOneof.assert_called_once_with("result") + + +class Test__get_doc_mask(unittest.TestCase): + @staticmethod + def _call_fut(field_paths): + from google.cloud.firestore_v1.base_client import _get_doc_mask + + return _get_doc_mask(field_paths) + + def test_none(self): + self.assertIsNone(self._call_fut(None)) + + def test_paths(self): + from google.cloud.firestore_v1.proto import common_pb2 + + field_paths = ["a.b", "c"] + result = self._call_fut(field_paths) + expected = common_pb2.DocumentMask(field_paths=field_paths) + self.assertEqual(result, expected) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_batch_response(**kwargs): + from google.cloud.firestore_v1.proto import firestore_pb2 + + return firestore_pb2.BatchGetDocumentsResponse(**kwargs) diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 7ec062422a..4e295c467d 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -50,7 +50,7 @@ def test_constructor(self): self.assertIsNone(client._emulator_host) def test_constructor_with_emulator_host(self): - from google.cloud.firestore_v1.client import _FIRESTORE_EMULATOR_HOST + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST credentials = _make_credentials() emulator_host = "localhost:8081" @@ -87,102 +87,6 @@ def test_constructor_w_client_options(self): ) self.assertEqual(client._target, "foo-firestore.googleapis.com") - @mock.patch( - "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", - autospec=True, - return_value=mock.sentinel.firestore_api, - ) - def test__firestore_api_property(self, mock_client): - mock_client.SERVICE_ADDRESS = "endpoint" - client = self._make_default_one() - client_info = client._client_info = mock.Mock() - self.assertIsNone(client._firestore_api_internal) - firestore_api = client._firestore_api - self.assertIs(firestore_api, mock_client.return_value) - self.assertIs(firestore_api, client._firestore_api_internal) - mock_client.assert_called_once_with( - transport=client._transport, client_info=client_info - ) - - # Call again to show that it is cached, but call count is still 1. - self.assertIs(client._firestore_api, mock_client.return_value) - self.assertEqual(mock_client.call_count, 1) - - @mock.patch( - "google.cloud.firestore_v1.gapic.firestore_client.FirestoreClient", - autospec=True, - return_value=mock.sentinel.firestore_api, - ) - @mock.patch( - "google.cloud.firestore_v1.gapic.transports.firestore_grpc_transport.firestore_pb2_grpc.grpc.insecure_channel", - autospec=True, - ) - def test__firestore_api_property_with_emulator( - self, mock_insecure_channel, mock_client - ): - emulator_host = "localhost:8081" - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - client = self._make_default_one() - - self.assertIsNone(client._firestore_api_internal) - firestore_api = client._firestore_api - self.assertIs(firestore_api, mock_client.return_value) - self.assertIs(firestore_api, client._firestore_api_internal) - - mock_insecure_channel.assert_called_once_with(emulator_host) - - # Call again to show that it is cached, but call count is still 1. - self.assertIs(client._firestore_api, mock_client.return_value) - self.assertEqual(mock_client.call_count, 1) - - def test___database_string_property(self): - credentials = _make_credentials() - database = "cheeeeez" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - self.assertIsNone(client._database_string_internal) - database_string = client._database_string - expected = "projects/{}/databases/{}".format(client.project, client._database) - self.assertEqual(database_string, expected) - self.assertIs(database_string, client._database_string_internal) - - # Swap it out with a unique value to verify it is cached. - client._database_string_internal = mock.sentinel.cached - self.assertIs(client._database_string, mock.sentinel.cached) - - def test___rpc_metadata_property(self): - credentials = _make_credentials() - database = "quanta" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - - self.assertEqual( - client._rpc_metadata, - [("google-cloud-resource-prefix", client._database_string)], - ) - - def test__rpc_metadata_property_with_emulator(self): - emulator_host = "localhost:8081" - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - - credentials = _make_credentials() - database = "quanta" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - - self.assertEqual( - client._rpc_metadata, - [ - ("google-cloud-resource-prefix", client._database_string), - ("authorization", "Bearer owner"), - ], - ) - def test_collection_factory(self): from google.cloud.firestore_v1.collection import CollectionReference @@ -212,6 +116,15 @@ def test_collection_factory_nested(self): self.assertIs(collection2._client, client) self.assertIsInstance(collection2, CollectionReference) + def test__get_collection_reference(self): + from google.cloud.firestore_v1.collection import CollectionReference + + client = self._make_default_one() + collection = client._get_collection_reference("collectionId") + + self.assertIs(collection._client, client) + self.assertIsInstance(collection, CollectionReference) + def test_collection_group(self): client = self._make_default_one() query = client.collection_group("collectionId").where("foo", "==", u"bar") @@ -276,62 +189,6 @@ def test_document_factory_w_nested_path(self): self.assertIs(document2._client, client) self.assertIsInstance(document2, DocumentReference) - def test_field_path(self): - klass = self._get_target_class() - self.assertEqual(klass.field_path("a", "b", "c"), "a.b.c") - - def test_write_option_last_update(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1._helpers import LastUpdateOption - - timestamp = timestamp_pb2.Timestamp(seconds=1299767599, nanos=811111097) - - klass = self._get_target_class() - option = klass.write_option(last_update_time=timestamp) - self.assertIsInstance(option, LastUpdateOption) - self.assertEqual(option._last_update_time, timestamp) - - def test_write_option_exists(self): - from google.cloud.firestore_v1._helpers import ExistsOption - - klass = self._get_target_class() - - option1 = klass.write_option(exists=False) - self.assertIsInstance(option1, ExistsOption) - self.assertFalse(option1._exists) - - option2 = klass.write_option(exists=True) - self.assertIsInstance(option2, ExistsOption) - self.assertTrue(option2._exists) - - def test_write_open_neither_arg(self): - from google.cloud.firestore_v1.client import _BAD_OPTION_ERR - - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option() - - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) - - def test_write_multiple_args(self): - from google.cloud.firestore_v1.client import _BAD_OPTION_ERR - - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option(exists=False, last_update_time=mock.sentinel.timestamp) - - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) - - def test_write_bad_arg(self): - from google.cloud.firestore_v1.client import _BAD_OPTION_ERR - - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option(spinach="popeye") - - extra = "{!r} was provided".format("spinach") - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR, extra)) - def test_collections(self): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page @@ -469,7 +326,7 @@ def test_get_all_with_transaction(self): ) def test_get_all_unknown_result(self): - from google.cloud.firestore_v1.client import _BAD_DOC_TEMPLATE + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE info = self._info_for_get_all({"z": 28.5}, {}) client, document, _, _, response = info @@ -555,163 +412,6 @@ def test_transaction(self): self.assertIsNone(transaction._id) -class Test__reference_info(unittest.TestCase): - @staticmethod - def _call_fut(references): - from google.cloud.firestore_v1.client import _reference_info - - return _reference_info(references) - - def test_it(self): - from google.cloud.firestore_v1.client import Client - - credentials = _make_credentials() - client = Client(project="hi-projject", credentials=credentials) - - reference1 = client.document("a", "b") - reference2 = client.document("a", "b", "c", "d") - reference3 = client.document("a", "b") - reference4 = client.document("f", "g") - - doc_path1 = reference1._document_path - doc_path2 = reference2._document_path - doc_path3 = reference3._document_path - doc_path4 = reference4._document_path - self.assertEqual(doc_path1, doc_path3) - - document_paths, reference_map = self._call_fut( - [reference1, reference2, reference3, reference4] - ) - self.assertEqual(document_paths, [doc_path1, doc_path2, doc_path3, doc_path4]) - # reference3 over-rides reference1. - expected_map = { - doc_path2: reference2, - doc_path3: reference3, - doc_path4: reference4, - } - self.assertEqual(reference_map, expected_map) - - -class Test__get_reference(unittest.TestCase): - @staticmethod - def _call_fut(document_path, reference_map): - from google.cloud.firestore_v1.client import _get_reference - - return _get_reference(document_path, reference_map) - - def test_success(self): - doc_path = "a/b/c" - reference_map = {doc_path: mock.sentinel.reference} - self.assertIs(self._call_fut(doc_path, reference_map), mock.sentinel.reference) - - def test_failure(self): - from google.cloud.firestore_v1.client import _BAD_DOC_TEMPLATE - - doc_path = "1/888/call-now" - with self.assertRaises(ValueError) as exc_info: - self._call_fut(doc_path, {}) - - err_msg = _BAD_DOC_TEMPLATE.format(doc_path) - self.assertEqual(exc_info.exception.args, (err_msg,)) - - -class Test__parse_batch_get(unittest.TestCase): - @staticmethod - def _call_fut(get_doc_response, reference_map, client=mock.sentinel.client): - from google.cloud.firestore_v1.client import _parse_batch_get - - return _parse_batch_get(get_doc_response, reference_map, client) - - @staticmethod - def _dummy_ref_string(): - from google.cloud.firestore_v1.client import DEFAULT_DATABASE - - project = u"bazzzz" - collection_id = u"fizz" - document_id = u"buzz" - return u"projects/{}/databases/{}/documents/{}/{}".format( - project, DEFAULT_DATABASE, collection_id, document_id - ) - - def test_found(self): - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1.document import DocumentSnapshot - - now = datetime.datetime.utcnow() - read_time = _datetime_to_pb_timestamp(now) - delta = datetime.timedelta(seconds=100) - update_time = _datetime_to_pb_timestamp(now - delta) - create_time = _datetime_to_pb_timestamp(now - 2 * delta) - - ref_string = self._dummy_ref_string() - document_pb = document_pb2.Document( - name=ref_string, - fields={ - "foo": document_pb2.Value(double_value=1.5), - "bar": document_pb2.Value(string_value=u"skillz"), - }, - create_time=create_time, - update_time=update_time, - ) - response_pb = _make_batch_response(found=document_pb, read_time=read_time) - - reference_map = {ref_string: mock.sentinel.reference} - snapshot = self._call_fut(response_pb, reference_map) - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertIs(snapshot._reference, mock.sentinel.reference) - self.assertEqual(snapshot._data, {"foo": 1.5, "bar": u"skillz"}) - self.assertTrue(snapshot._exists) - self.assertEqual(snapshot.read_time, read_time) - self.assertEqual(snapshot.create_time, create_time) - self.assertEqual(snapshot.update_time, update_time) - - def test_missing(self): - from google.cloud.firestore_v1.document import DocumentReference - - ref_string = self._dummy_ref_string() - response_pb = _make_batch_response(missing=ref_string) - document = DocumentReference("fizz", "bazz", client=mock.sentinel.client) - reference_map = {ref_string: document} - snapshot = self._call_fut(response_pb, reference_map) - self.assertFalse(snapshot.exists) - self.assertEqual(snapshot.id, "bazz") - self.assertIsNone(snapshot._data) - - def test_unset_result_type(self): - response_pb = _make_batch_response() - with self.assertRaises(ValueError): - self._call_fut(response_pb, {}) - - def test_unknown_result_type(self): - response_pb = mock.Mock(spec=["WhichOneof"]) - response_pb.WhichOneof.return_value = "zoob_value" - - with self.assertRaises(ValueError): - self._call_fut(response_pb, {}) - - response_pb.WhichOneof.assert_called_once_with("result") - - -class Test__get_doc_mask(unittest.TestCase): - @staticmethod - def _call_fut(field_paths): - from google.cloud.firestore_v1.client import _get_doc_mask - - return _get_doc_mask(field_paths) - - def test_none(self): - self.assertIsNone(self._call_fut(None)) - - def test_paths(self): - from google.cloud.firestore_v1.proto import common_pb2 - - field_paths = ["a.b", "c"] - result = self._call_fut(field_paths) - expected = common_pb2.DocumentMask(field_paths=field_paths) - self.assertEqual(result, expected) - - def _make_credentials(): import google.auth.credentials From eedd62c4ebab8941b002cde6f1927ac723dea741 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 23 Jun 2020 18:53:28 -0500 Subject: [PATCH 32/47] fix: base client test class --- tests/unit/v1/test_base_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index 40523104d2..1452b7aa85 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -18,7 +18,7 @@ import mock -class TestClient(unittest.TestCase): +class TestBaseClient(unittest.TestCase): PROJECT = "my-prahjekt" From 3a52326cffbd282ddf2e8f467a5defcf888fae1c Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 23 Jun 2020 19:05:13 -0500 Subject: [PATCH 33/47] feat: create WriteBatch/AsyncWriteBatch superclass --- google/cloud/firestore_v1/async_batch.py | 4 +- google/cloud/firestore_v1/base_batch.py | 132 +++++++++++++++++ google/cloud/firestore_v1/batch.py | 104 +------------- tests/unit/v1/async/test_async_batch.py | 123 ---------------- tests/unit/v1/test_base_batch.py | 172 +++++++++++++++++++++++ tests/unit/v1/test_batch.py | 123 ---------------- 6 files changed, 309 insertions(+), 349 deletions(-) create mode 100644 google/cloud/firestore_v1/base_batch.py create mode 100644 tests/unit/v1/test_base_batch.py diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 495bee06ce..7fb18e90e2 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -15,10 +15,10 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" -from google.cloud.firestore_v1.batch import WriteBatch +from google.cloud.firestore_v1.base_batch import BaseWriteBatch -class AsyncWriteBatch(WriteBatch): +class AsyncWriteBatch(BaseWriteBatch): """Accumulate write operations to be sent in a batch. This has the same set of methods for write operations that diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py new file mode 100644 index 0000000000..45f8c49d99 --- /dev/null +++ b/google/cloud/firestore_v1/base_batch.py @@ -0,0 +1,132 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for batch requests to the Google Cloud Firestore API.""" + + +from google.cloud.firestore_v1 import _helpers + + +class BaseWriteBatch(object): + """Accumulate write operations to be sent in a batch. + + This has the same set of methods for write operations that + :class:`~google.cloud.firestore_v1.document.DocumentReference` does, + e.g. :meth:`~google.cloud.firestore_v1.document.DocumentReference.create`. + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + The client that created this batch. + """ + + def __init__(self, client): + self._client = client + self._write_pbs = [] + self.write_results = None + self.commit_time = None + + def _add_write_pbs(self, write_pbs): + """Add `Write`` protobufs to this transaction. + + This method intended to be over-ridden by subclasses. + + Args: + write_pbs (List[google.cloud.proto.firestore.v1.\ + write_pb2.Write]): A list of write protobufs to be added. + """ + self._write_pbs.extend(write_pbs) + + def create(self, reference, document_data): + """Add a "change" to this batch to create a document. + + If the document given by ``reference`` already exists, then this + batch will fail when :meth:`commit`-ed. + + Args: + reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): + A document reference to be created in this batch. + document_data (dict): Property names and values to use for + creating a document. + """ + write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) + self._add_write_pbs(write_pbs) + + def set(self, reference, document_data, merge=False): + """Add a "change" to replace a document. + + See + :meth:`google.cloud.firestore_v1.document.DocumentReference.set` for + more information on how ``option`` determines how the change is + applied. + + Args: + reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): + A document reference that will have values set in this batch. + document_data (dict): + Property names and values to use for replacing a document. + merge (Optional[bool] or Optional[List]): + If True, apply merging instead of overwriting the state + of the document. + """ + if merge is not False: + write_pbs = _helpers.pbs_for_set_with_merge( + reference._document_path, document_data, merge + ) + else: + write_pbs = _helpers.pbs_for_set_no_merge( + reference._document_path, document_data + ) + + self._add_write_pbs(write_pbs) + + def update(self, reference, field_updates, option=None): + """Add a "change" to update a document. + + See + :meth:`google.cloud.firestore_v1.document.DocumentReference.update` + for more information on ``field_updates`` and ``option``. + + Args: + reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): + A document reference that will be updated in this batch. + field_updates (dict): + Field names or paths to update and values to update with. + option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + """ + if option.__class__.__name__ == "ExistsOption": + raise ValueError("you must not pass an explicit write option to " "update.") + write_pbs = _helpers.pbs_for_update( + reference._document_path, field_updates, option + ) + self._add_write_pbs(write_pbs) + + def delete(self, reference, option=None): + """Add a "change" to delete a document. + + See + :meth:`google.cloud.firestore_v1.document.DocumentReference.delete` + for more information on how ``option`` determines how the change is + applied. + + Args: + reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): + A document reference that will be deleted in this batch. + option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): + A write option to make assertions / preconditions on the server + state of the document before applying changes. + """ + write_pb = _helpers.pb_for_delete(reference._document_path, option) + self._add_write_pbs([write_pb]) diff --git a/google/cloud/firestore_v1/batch.py b/google/cloud/firestore_v1/batch.py index 56483af10c..9a48e460a5 100644 --- a/google/cloud/firestore_v1/batch.py +++ b/google/cloud/firestore_v1/batch.py @@ -15,10 +15,10 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" -from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_batch import BaseWriteBatch -class WriteBatch(object): +class WriteBatch(BaseWriteBatch): """Accumulate write operations to be sent in a batch. This has the same set of methods for write operations that @@ -31,105 +31,7 @@ class WriteBatch(object): """ def __init__(self, client): - self._client = client - self._write_pbs = [] - self.write_results = None - self.commit_time = None - - def _add_write_pbs(self, write_pbs): - """Add `Write`` protobufs to this transaction. - - This method intended to be over-ridden by subclasses. - - Args: - write_pbs (List[google.cloud.proto.firestore.v1.\ - write_pb2.Write]): A list of write protobufs to be added. - """ - self._write_pbs.extend(write_pbs) - - def create(self, reference, document_data): - """Add a "change" to this batch to create a document. - - If the document given by ``reference`` already exists, then this - batch will fail when :meth:`commit`-ed. - - Args: - reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): - A document reference to be created in this batch. - document_data (dict): Property names and values to use for - creating a document. - """ - write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) - self._add_write_pbs(write_pbs) - - def set(self, reference, document_data, merge=False): - """Add a "change" to replace a document. - - See - :meth:`google.cloud.firestore_v1.document.DocumentReference.set` for - more information on how ``option`` determines how the change is - applied. - - Args: - reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): - A document reference that will have values set in this batch. - document_data (dict): - Property names and values to use for replacing a document. - merge (Optional[bool] or Optional[List]): - If True, apply merging instead of overwriting the state - of the document. - """ - if merge is not False: - write_pbs = _helpers.pbs_for_set_with_merge( - reference._document_path, document_data, merge - ) - else: - write_pbs = _helpers.pbs_for_set_no_merge( - reference._document_path, document_data - ) - - self._add_write_pbs(write_pbs) - - def update(self, reference, field_updates, option=None): - """Add a "change" to update a document. - - See - :meth:`google.cloud.firestore_v1.document.DocumentReference.update` - for more information on ``field_updates`` and ``option``. - - Args: - reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): - A document reference that will be updated in this batch. - field_updates (dict): - Field names or paths to update and values to update with. - option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): - A write option to make assertions / preconditions on the server - state of the document before applying changes. - """ - if option.__class__.__name__ == "ExistsOption": - raise ValueError("you must not pass an explicit write option to " "update.") - write_pbs = _helpers.pbs_for_update( - reference._document_path, field_updates, option - ) - self._add_write_pbs(write_pbs) - - def delete(self, reference, option=None): - """Add a "change" to delete a document. - - See - :meth:`google.cloud.firestore_v1.document.DocumentReference.delete` - for more information on how ``option`` determines how the change is - applied. - - Args: - reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): - A document reference that will be deleted in this batch. - option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]): - A write option to make assertions / preconditions on the server - state of the document before applying changes. - """ - write_pb = _helpers.pb_for_delete(reference._document_path, option) - self._add_write_pbs([write_pb]) + super(WriteBatch, self).__init__(client=client) def commit(self): """Commit the changes accumulated in this batch. diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py index 301bc58a81..6b6a8af774 100644 --- a/tests/unit/v1/async/test_async_batch.py +++ b/tests/unit/v1/async/test_async_batch.py @@ -36,123 +36,6 @@ def test_constructor(self): self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) - def test__add_write_pbs(self): - batch = self._make_one(mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) - batch._add_write_pbs([mock.sentinel.write1, mock.sentinel.write2]) - self.assertEqual(batch._write_pbs, [mock.sentinel.write1, mock.sentinel.write2]) - - def test_create(self): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("this", "one") - document_data = {"a": 10, "b": 2.5} - ret_val = batch.create(reference, document_data) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={ - "a": _value_pb(integer_value=document_data["a"]), - "b": _value_pb(double_value=document_data["b"]), - }, - ), - current_document=common_pb2.Precondition(exists=False), - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_set(self): - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("another", "one") - field = "zapzap" - value = u"meadows and flowers" - document_data = {field: value} - ret_val = batch.set(reference, document_data) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={field: _value_pb(string_value=value)}, - ) - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_set_merge(self): - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("another", "one") - field = "zapzap" - value = u"meadows and flowers" - document_data = {field: value} - ret_val = batch.set(reference, document_data, merge=True) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={field: _value_pb(string_value=value)}, - ), - update_mask={"field_paths": [field]}, - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_update(self): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("cats", "cradle") - field_path = "head.foot" - value = u"knees toes shoulders" - field_updates = {field_path: value} - - ret_val = batch.update(reference, field_updates) - self.assertIsNone(ret_val) - - map_pb = document_pb2.MapValue(fields={"foot": _value_pb(string_value=value)}) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={"head": _value_pb(map_value=map_pb)}, - ), - update_mask=common_pb2.DocumentMask(field_paths=[field_path]), - current_document=common_pb2.Precondition(exists=True), - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_delete(self): - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("early", "mornin", "dawn", "now") - ret_val = batch.delete(reference) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write(delete=reference._document_path) - self.assertEqual(batch._write_pbs, [new_write_pb]) - @pytest.mark.asyncio async def test_commit(self): from google.protobuf import timestamp_pb2 @@ -256,12 +139,6 @@ async def test_as_context_mgr_w_error(self): firestore_api.commit.assert_not_called() -def _value_pb(**kwargs): - from google.cloud.firestore_v1.proto.document_pb2 import Value - - return Value(**kwargs) - - def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_base_batch.py b/tests/unit/v1/test_base_batch.py new file mode 100644 index 0000000000..824ebbc87c --- /dev/null +++ b/tests/unit/v1/test_base_batch.py @@ -0,0 +1,172 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import mock + + +class TestBaseWriteBatch(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.base_batch import BaseWriteBatch + + return BaseWriteBatch + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + batch = self._make_one(mock.sentinel.client) + self.assertIs(batch._client, mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + self.assertIsNone(batch.write_results) + self.assertIsNone(batch.commit_time) + + def test__add_write_pbs(self): + batch = self._make_one(mock.sentinel.client) + self.assertEqual(batch._write_pbs, []) + batch._add_write_pbs([mock.sentinel.write1, mock.sentinel.write2]) + self.assertEqual(batch._write_pbs, [mock.sentinel.write1, mock.sentinel.write2]) + + def test_create(self): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("this", "one") + document_data = {"a": 10, "b": 2.5} + ret_val = batch.create(reference, document_data) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={ + "a": _value_pb(integer_value=document_data["a"]), + "b": _value_pb(double_value=document_data["b"]), + }, + ), + current_document=common_pb2.Precondition(exists=False), + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_set(self): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("another", "one") + field = "zapzap" + value = u"meadows and flowers" + document_data = {field: value} + ret_val = batch.set(reference, document_data) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={field: _value_pb(string_value=value)}, + ) + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_set_merge(self): + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("another", "one") + field = "zapzap" + value = u"meadows and flowers" + document_data = {field: value} + ret_val = batch.set(reference, document_data, merge=True) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={field: _value_pb(string_value=value)}, + ), + update_mask={"field_paths": [field]}, + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_update(self): + from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("cats", "cradle") + field_path = "head.foot" + value = u"knees toes shoulders" + field_updates = {field_path: value} + + ret_val = batch.update(reference, field_updates) + self.assertIsNone(ret_val) + + map_pb = document_pb2.MapValue(fields={"foot": _value_pb(string_value=value)}) + new_write_pb = write_pb2.Write( + update=document_pb2.Document( + name=reference._document_path, + fields={"head": _value_pb(map_value=map_pb)}, + ), + update_mask=common_pb2.DocumentMask(field_paths=[field_path]), + current_document=common_pb2.Precondition(exists=True), + ) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + def test_delete(self): + from google.cloud.firestore_v1.proto import write_pb2 + + client = _make_client() + batch = self._make_one(client) + self.assertEqual(batch._write_pbs, []) + + reference = client.document("early", "mornin", "dawn", "now") + ret_val = batch.delete(reference) + self.assertIsNone(ret_val) + new_write_pb = write_pb2.Write(delete=reference._document_path) + self.assertEqual(batch._write_pbs, [new_write_pb]) + + +def _value_pb(**kwargs): + from google.cloud.firestore_v1.proto.document_pb2 import Value + + return Value(**kwargs) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="seventy-nine"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index 08421d6039..cf971b87e3 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -35,123 +35,6 @@ def test_constructor(self): self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) - def test__add_write_pbs(self): - batch = self._make_one(mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) - batch._add_write_pbs([mock.sentinel.write1, mock.sentinel.write2]) - self.assertEqual(batch._write_pbs, [mock.sentinel.write1, mock.sentinel.write2]) - - def test_create(self): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("this", "one") - document_data = {"a": 10, "b": 2.5} - ret_val = batch.create(reference, document_data) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={ - "a": _value_pb(integer_value=document_data["a"]), - "b": _value_pb(double_value=document_data["b"]), - }, - ), - current_document=common_pb2.Precondition(exists=False), - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_set(self): - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("another", "one") - field = "zapzap" - value = u"meadows and flowers" - document_data = {field: value} - ret_val = batch.set(reference, document_data) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={field: _value_pb(string_value=value)}, - ) - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_set_merge(self): - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("another", "one") - field = "zapzap" - value = u"meadows and flowers" - document_data = {field: value} - ret_val = batch.set(reference, document_data, merge=True) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={field: _value_pb(string_value=value)}, - ), - update_mask={"field_paths": [field]}, - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_update(self): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("cats", "cradle") - field_path = "head.foot" - value = u"knees toes shoulders" - field_updates = {field_path: value} - - ret_val = batch.update(reference, field_updates) - self.assertIsNone(ret_val) - - map_pb = document_pb2.MapValue(fields={"foot": _value_pb(string_value=value)}) - new_write_pb = write_pb2.Write( - update=document_pb2.Document( - name=reference._document_path, - fields={"head": _value_pb(map_value=map_pb)}, - ), - update_mask=common_pb2.DocumentMask(field_paths=[field_path]), - current_document=common_pb2.Precondition(exists=True), - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_delete(self): - from google.cloud.firestore_v1.proto import write_pb2 - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("early", "mornin", "dawn", "now") - ret_val = batch.delete(reference) - self.assertIsNone(ret_val) - new_write_pb = write_pb2.Write(delete=reference._document_path) - self.assertEqual(batch._write_pbs, [new_write_pb]) - def test_commit(self): from google.protobuf import timestamp_pb2 from google.cloud.firestore_v1.proto import firestore_pb2 @@ -252,12 +135,6 @@ def test_as_context_mgr_w_error(self): firestore_api.commit.assert_not_called() -def _value_pb(**kwargs): - from google.cloud.firestore_v1.proto.document_pb2 import Value - - return Value(**kwargs) - - def _make_credentials(): import google.auth.credentials From 2e7b4998deb502b180fc7f5c02b777519a128f33 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 23 Jun 2020 20:22:46 -0500 Subject: [PATCH 34/47] feat: create CollectionReference/AsyncCollectionReference superclass --- google/cloud/firestore_v1/async_collection.py | 201 +--------- google/cloud/firestore_v1/base_collection.py | 352 ++++++++++++++++++ google/cloud/firestore_v1/collection.py | 300 +-------------- tests/unit/v1/async/test_async_collection.py | 149 +------- tests/unit/v1/test_base_collection.py | 202 ++++++++++ tests/unit/v1/test_collection.py | 149 +------- 6 files changed, 602 insertions(+), 751 deletions(-) create mode 100644 google/cloud/firestore_v1/base_collection.py create mode 100644 tests/unit/v1/test_base_collection.py diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index cdc1a80fba..77c43107f7 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -16,8 +16,8 @@ import warnings -from google.cloud.firestore_v1.collection import ( - CollectionReference, +from google.cloud.firestore_v1.base_collection import ( + BaseCollectionReference, _auto_id, _item_to_document_ref, ) @@ -26,7 +26,7 @@ from google.cloud.firestore_v1 import async_document -class AsyncCollectionReference(CollectionReference): +class AsyncCollectionReference(BaseCollectionReference): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation @@ -55,6 +55,14 @@ class AsyncCollectionReference(CollectionReference): def __init__(self, *path, **kwargs): super(AsyncCollectionReference, self).__init__(*path, **kwargs) + def _query(self): + """Query factory. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query` + """ + return async_query.AsyncQuery(self) + async def add(self, document_data, document_id=None): """Create a document in the Firestore database with the provided data. @@ -113,191 +121,6 @@ async def list_documents(self, page_size=None): iterator.item_to_value = _item_to_document_ref return iterator - def select(self, field_paths): - """Create a "select" query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.select` for - more information on this method. - - Args: - field_paths (Iterable[str, ...]): An iterable of field paths - (``.``-delimited list of field names) to use as a projection - of document fields in the query results. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - A "projected" query. - """ - query = async_query.AsyncQuery(self) - return query.select(field_paths) - - def where(self, field_path, op_string, value): - """Create a "where" query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.where` for - more information on this method. - - Args: - field_path (str): A field path (``.``-delimited list of - field names) for the field to filter on. - op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=`` - and ``>``. - value (Any): The value to compare the field against in the filter. - If ``value`` is :data:`None` or a NaN, then ``==`` is the only - allowed operation. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - A filtered query. - """ - query = async_query.AsyncQuery(self) - return query.where(field_path, op_string, value) - - def order_by(self, field_path, **kwargs): - """Create an "order by" query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.order_by` for - more information on this method. - - Args: - field_path (str): A field path (``.``-delimited list of - field names) on which to order the query results. - kwargs (Dict[str, Any]): The keyword arguments to pass along - to the query. The only supported keyword is ``direction``, - see :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.order_by` - for more information. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - An "order by" query. - """ - query = async_query.AsyncQuery(self) - return query.order_by(field_path, **kwargs) - - def limit(self, count): - """Create a limited query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.limit` for - more information on this method. - - Args: - count (int): Maximum number of documents to return that match - the query. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - A limited query. - """ - query = async_query.AsyncQuery(self) - return query.limit(count) - - def offset(self, num_to_skip): - """Skip to an offset in a query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.offset` for - more information on this method. - - Args: - num_to_skip (int): The number of results to skip at the beginning - of query results. (Must be non-negative.) - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - An offset query. - """ - query = async_query.AsyncQuery(self) - return query.offset(num_to_skip) - - def start_at(self, document_fields): - """Start query at a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.start_at` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - A query with cursor. - """ - query = async_query.AsyncQuery(self) - return query.start_at(document_fields) - - def start_after(self, document_fields): - """Start query after a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.start_after` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - A query with cursor. - """ - query = async_query.AsyncQuery(self) - return query.start_after(document_fields) - - def end_before(self, document_fields): - """End query before a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.end_before` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - A query with cursor. - """ - query = async_query.AsyncQuery(self) - return query.end_before(document_fields) - - def end_at(self, document_fields): - """End query at a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.async_query.AsyncQuery.end_at` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.async_query.AsyncQuery`: - A query with cursor. - """ - query = async_query.AsyncQuery(self) - return query.end_at(document_fields) - async def get(self, transaction=None): """Deprecated alias for :meth:`stream`.""" warnings.warn( @@ -366,7 +189,7 @@ def on_snapshot(collection_snapshot, changes, read_time): collection_watch.unsubscribe() """ return Watch.for_query( - async_query.AsyncQuery(self), + self._query(), callback, async_document.DocumentSnapshot, async_document.AsyncDocumentReference, diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py new file mode 100644 index 0000000000..179f17f2cc --- /dev/null +++ b/google/cloud/firestore_v1/base_collection.py @@ -0,0 +1,352 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing collections for the Google Cloud Firestore API.""" +import random +import six + +from google.cloud.firestore_v1 import _helpers + +_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + + +class BaseCollectionReference(object): + """A reference to a collection in a Firestore database. + + The collection may already exist or this class can facilitate creation + of documents within the collection. + + Args: + path (Tuple[str, ...]): The components in the collection path. + This is a series of strings representing each collection and + sub-collection ID, as well as the document IDs for any documents + that contain a sub-collection. + kwargs (dict): The keyword arguments for the constructor. The only + supported keyword is ``client`` and it must be a + :class:`~google.cloud.firestore_v1.client.Client` if provided. It + represents the client that created this collection reference. + + Raises: + ValueError: if + + * the ``path`` is empty + * there are an even number of elements + * a collection ID in ``path`` is not a string + * a document ID in ``path`` is not a string + TypeError: If a keyword other than ``client`` is used. + """ + + def __init__(self, *path, **kwargs): + _helpers.verify_path(path, is_collection=True) + self._path = path + self._client = kwargs.pop("client", None) + if kwargs: + raise TypeError( + "Received unexpected arguments", kwargs, "Only `client` is supported" + ) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self._path == other._path and self._client == other._client + + @property + def id(self): + """The collection identifier. + + Returns: + str: The last component of the path. + """ + return self._path[-1] + + @property + def parent(self): + """Document that owns the current collection. + + Returns: + Optional[:class:`~google.cloud.firestore_v1.document.DocumentReference`]: + The parent document, if the current collection is not a + top-level collection. + """ + if len(self._path) == 1: + return None + else: + parent_path = self._path[:-1] + return self._client.document(*parent_path) + + def _query(self): + raise NotImplementedError + + def document(self, document_id=None): + """Create a sub-document underneath the current collection. + + Args: + document_id (Optional[str]): The document identifier + within the current collection. If not provided, will default + to a random 20 character string composed of digits, + uppercase and lowercase and letters. + + Returns: + :class:`~google.cloud.firestore_v1.document.DocumentReference`: + The child document. + """ + if document_id is None: + document_id = _auto_id() + + child_path = self._path + (document_id,) + return self._client.document(*child_path) + + def _parent_info(self): + """Get fully-qualified parent path and prefix for this collection. + + Returns: + Tuple[str, str]: Pair of + + * the fully-qualified (with database and project) path to the + parent of this collection (will either be the database path + or a document path). + * the prefix to a document in this collection. + """ + parent_doc = self.parent + if parent_doc is None: + parent_path = _helpers.DOCUMENT_PATH_DELIMITER.join( + (self._client._database_string, "documents") + ) + else: + parent_path = parent_doc._document_path + + expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) + return parent_path, expected_prefix + + def add(self, document_data, document_id=None): + raise NotImplementedError + + def list_documents(self, page_size=None): + raise NotImplementedError + + def select(self, field_paths): + """Create a "select" query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.select` for + more information on this method. + + Args: + field_paths (Iterable[str, ...]): An iterable of field paths + (``.``-delimited list of field names) to use as a projection + of document fields in the query results. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A "projected" query. + """ + query = self._query() + return query.select(field_paths) + + def where(self, field_path, op_string, value): + """Create a "where" query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.where` for + more information on this method. + + Args: + field_path (str): A field path (``.``-delimited list of + field names) for the field to filter on. + op_string (str): A comparison operation in the form of a string. + Acceptable values are ``<``, ``<=``, ``==``, ``>=`` + and ``>``. + value (Any): The value to compare the field against in the filter. + If ``value`` is :data:`None` or a NaN, then ``==`` is the only + allowed operation. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A filtered query. + """ + query = self._query() + return query.where(field_path, op_string, value) + + def order_by(self, field_path, **kwargs): + """Create an "order by" query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.order_by` for + more information on this method. + + Args: + field_path (str): A field path (``.``-delimited list of + field names) on which to order the query results. + kwargs (Dict[str, Any]): The keyword arguments to pass along + to the query. The only supported keyword is ``direction``, + see :meth:`~google.cloud.firestore_v1.query.Query.order_by` + for more information. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An "order by" query. + """ + query = self._query() + return query.order_by(field_path, **kwargs) + + def limit(self, count): + """Create a limited query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.limit` for + more information on this method. + + Args: + count (int): Maximum number of documents to return that match + the query. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A limited query. + """ + query = self._query() + return query.limit(count) + + def offset(self, num_to_skip): + """Skip to an offset in a query with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.offset` for + more information on this method. + + Args: + num_to_skip (int): The number of results to skip at the beginning + of query results. (Must be non-negative.) + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An offset query. + """ + query = self._query() + return query.offset(num_to_skip) + + def start_at(self, document_fields): + """Start query at a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.start_at` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = self._query() + return query.start_at(document_fields) + + def start_after(self, document_fields): + """Start query after a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.start_after` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = self._query() + return query.start_after(document_fields) + + def end_before(self, document_fields): + """End query before a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.end_before` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = self._query() + return query.end_before(document_fields) + + def end_at(self, document_fields): + """End query at a cursor with this collection as parent. + + See + :meth:`~google.cloud.firestore_v1.query.Query.end_at` for + more information on this method. + + Args: + document_fields (Union[:class:`~google.cloud.firestore_v1.\ + document.DocumentSnapshot`, dict, list, tuple]): + A document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. + """ + query = self._query() + return query.end_at(document_fields) + + def get(self, transaction=None): + raise NotImplementedError + + def stream(self, transaction=None): + raise NotImplementedError + + def on_snapshot(self, callback): + raise NotImplementedError + + +def _auto_id(): + """Generate a "random" automatically generated ID. + + Returns: + str: A 20 character string composed of digits, uppercase and + lowercase and letters. + """ + return "".join(random.choice(_AUTO_ID_CHARS) for _ in six.moves.xrange(20)) + + +def _item_to_document_ref(iterator, item): + """Convert Document resource to document ref. + + Args: + iterator (google.api_core.page_iterator.GRPCIterator): + iterator response + item (dict): document resource + """ + document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] + return iterator.collection.document(document_id) diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 27c3eeaa31..8659af0ed8 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -13,20 +13,19 @@ # limitations under the License. """Classes for representing collections for the Google Cloud Firestore API.""" -import random import warnings -import six - -from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_collection import ( + BaseCollectionReference, + _auto_id, + _item_to_document_ref, +) from google.cloud.firestore_v1 import query as query_mod from google.cloud.firestore_v1.watch import Watch from google.cloud.firestore_v1 import document -_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" - -class CollectionReference(object): +class CollectionReference(BaseCollectionReference): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation @@ -53,83 +52,15 @@ class CollectionReference(object): """ def __init__(self, *path, **kwargs): - _helpers.verify_path(path, is_collection=True) - self._path = path - self._client = kwargs.pop("client", None) - if kwargs: - raise TypeError( - "Received unexpected arguments", kwargs, "Only `client` is supported" - ) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self._path == other._path and self._client == other._client - - @property - def id(self): - """The collection identifier. - - Returns: - str: The last component of the path. - """ - return self._path[-1] - - @property - def parent(self): - """Document that owns the current collection. - - Returns: - Optional[:class:`~google.cloud.firestore_v1.document.DocumentReference`]: - The parent document, if the current collection is not a - top-level collection. - """ - if len(self._path) == 1: - return None - else: - parent_path = self._path[:-1] - return self._client.document(*parent_path) - - def document(self, document_id=None): - """Create a sub-document underneath the current collection. - - Args: - document_id (Optional[str]): The document identifier - within the current collection. If not provided, will default - to a random 20 character string composed of digits, - uppercase and lowercase and letters. - - Returns: - :class:`~google.cloud.firestore_v1.document.DocumentReference`: - The child document. - """ - if document_id is None: - document_id = _auto_id() + super(CollectionReference, self).__init__(*path, **kwargs) - child_path = self._path + (document_id,) - return self._client.document(*child_path) - - def _parent_info(self): - """Get fully-qualified parent path and prefix for this collection. + def _query(self): + """Query factory. Returns: - Tuple[str, str]: Pair of - - * the fully-qualified (with database and project) path to the - parent of this collection (will either be the database path - or a document path). - * the prefix to a document in this collection. + :class:`~google.cloud.firestore_v1.query.Query` """ - parent_doc = self.parent - if parent_doc is None: - parent_path = _helpers.DOCUMENT_PATH_DELIMITER.join( - (self._client._database_string, "documents") - ) - else: - parent_path = parent_doc._document_path - - expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) - return parent_path, expected_prefix + return query_mod.Query(self) def add(self, document_data, document_id=None): """Create a document in the Firestore database with the provided data. @@ -189,191 +120,6 @@ def list_documents(self, page_size=None): iterator.item_to_value = _item_to_document_ref return iterator - def select(self, field_paths): - """Create a "select" query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.select` for - more information on this method. - - Args: - field_paths (Iterable[str, ...]): An iterable of field paths - (``.``-delimited list of field names) to use as a projection - of document fields in the query results. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A "projected" query. - """ - query = query_mod.Query(self) - return query.select(field_paths) - - def where(self, field_path, op_string, value): - """Create a "where" query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.where` for - more information on this method. - - Args: - field_path (str): A field path (``.``-delimited list of - field names) for the field to filter on. - op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=`` - and ``>``. - value (Any): The value to compare the field against in the filter. - If ``value`` is :data:`None` or a NaN, then ``==`` is the only - allowed operation. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A filtered query. - """ - query = query_mod.Query(self) - return query.where(field_path, op_string, value) - - def order_by(self, field_path, **kwargs): - """Create an "order by" query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.order_by` for - more information on this method. - - Args: - field_path (str): A field path (``.``-delimited list of - field names) on which to order the query results. - kwargs (Dict[str, Any]): The keyword arguments to pass along - to the query. The only supported keyword is ``direction``, - see :meth:`~google.cloud.firestore_v1.query.Query.order_by` - for more information. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - An "order by" query. - """ - query = query_mod.Query(self) - return query.order_by(field_path, **kwargs) - - def limit(self, count): - """Create a limited query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.limit` for - more information on this method. - - Args: - count (int): Maximum number of documents to return that match - the query. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A limited query. - """ - query = query_mod.Query(self) - return query.limit(count) - - def offset(self, num_to_skip): - """Skip to an offset in a query with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.offset` for - more information on this method. - - Args: - num_to_skip (int): The number of results to skip at the beginning - of query results. (Must be non-negative.) - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - An offset query. - """ - query = query_mod.Query(self) - return query.offset(num_to_skip) - - def start_at(self, document_fields): - """Start query at a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.start_at` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. - """ - query = query_mod.Query(self) - return query.start_at(document_fields) - - def start_after(self, document_fields): - """Start query after a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.start_after` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. - """ - query = query_mod.Query(self) - return query.start_after(document_fields) - - def end_before(self, document_fields): - """End query before a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.end_before` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. - """ - query = query_mod.Query(self) - return query.end_before(document_fields) - - def end_at(self, document_fields): - """End query at a cursor with this collection as parent. - - See - :meth:`~google.cloud.firestore_v1.query.Query.end_at` for - more information on this method. - - Args: - document_fields (Union[:class:`~google.cloud.firestore_v1.\ - document.DocumentSnapshot`, dict, list, tuple]): - A document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. - """ - query = query_mod.Query(self) - return query.end_at(document_fields) - def get(self, transaction=None): """Deprecated alias for :meth:`stream`.""" warnings.warn( @@ -440,30 +186,8 @@ def on_snapshot(collection_snapshot, changes, read_time): collection_watch.unsubscribe() """ return Watch.for_query( - query_mod.Query(self), + self._query(), callback, document.DocumentSnapshot, document.DocumentReference, ) - - -def _auto_id(): - """Generate a "random" automatically generated ID. - - Returns: - str: A 20 character string composed of digits, uppercase and - lowercase and letters. - """ - return "".join(random.choice(_AUTO_ID_CHARS) for _ in six.moves.xrange(20)) - - -def _item_to_document_ref(iterator, item): - """Convert Document resource to document ref. - - Args: - iterator (google.api_core.page_iterator.GRPCIterator): - iterator response - item (dict): document resource - """ - document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] - return iterator.collection.document(document_id) diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 055d7ae353..91c64373ca 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -42,10 +42,18 @@ def _make_one(self, *args, **kwargs): @staticmethod def _get_public_methods(klass): - return set( - name - for name, value in six.iteritems(klass.__dict__) - if (not name.startswith("_") and isinstance(value, types.FunctionType)) + return set().union( + *( + ( + name + for name, value in six.iteritems(class_.__dict__) + if ( + not name.startswith("_") + and isinstance(value, types.FunctionType) + ) + ) + for class_ in (klass,) + klass.__bases__ + ) ) def test_query_method_matching(self): @@ -85,119 +93,6 @@ def test_constructor_invalid_kwarg(self): with self.assertRaises(TypeError): self._make_one("Coh-lek-shun", donut=True) - def test___eq___other_type(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = object() - self.assertFalse(collection == other) - - def test___eq___different_path_same_client(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = self._make_one("other", client=client) - self.assertFalse(collection == other) - - def test___eq___same_path_different_client(self): - client = mock.sentinel.client - other_client = mock.sentinel.other_client - collection = self._make_one("name", client=client) - other = self._make_one("name", client=other_client) - self.assertFalse(collection == other) - - def test___eq___same_path_same_client(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = self._make_one("name", client=client) - self.assertTrue(collection == other) - - def test_id_property(self): - collection_id = "hi-bob" - collection = self._make_one(collection_id) - self.assertEqual(collection.id, collection_id) - - def test_parent_property(self): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - collection_id1 = "grocery-store" - document_id = "market" - collection_id2 = "darth" - client = _make_client() - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - - parent = collection.parent - self.assertIsInstance(parent, AsyncDocumentReference) - self.assertIs(parent._client, client) - self.assertEqual(parent._path, (collection_id1, document_id)) - - def test_parent_property_top_level(self): - collection = self._make_one("tahp-leh-vull") - self.assertIsNone(collection.parent) - - def test_document_factory_explicit_id(self): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - collection_id = "grocery-store" - document_id = "market" - client = _make_client() - collection = self._make_one(collection_id, client=client) - - child = collection.document(document_id) - self.assertIsInstance(child, AsyncDocumentReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_id, document_id)) - - @mock.patch( - "google.cloud.firestore_v1.collection._auto_id", - return_value="zorpzorpthreezorp012", - ) - def test_document_factory_auto_id(self, mock_auto_id): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - collection_name = "space-town" - client = _make_client() - collection = self._make_one(collection_name, client=client) - - child = collection.document() - self.assertIsInstance(child, AsyncDocumentReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_name, mock_auto_id.return_value)) - - mock_auto_id.assert_called_once_with() - - def test__parent_info_top_level(self): - client = _make_client() - collection_id = "soap" - collection = self._make_one(collection_id, client=client) - - parent_path, expected_prefix = collection._parent_info() - - expected_path = "projects/{}/databases/{}/documents".format( - client.project, client._database - ) - self.assertEqual(parent_path, expected_path) - prefix = "{}/{}".format(expected_path, collection_id) - self.assertEqual(expected_prefix, prefix) - - def test__parent_info_nested(self): - collection_id1 = "bar" - document_id = "baz" - collection_id2 = "chunk" - client = _make_client() - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - - parent_path, expected_prefix = collection._parent_info() - - expected_path = "projects/{}/databases/{}/documents/{}/{}".format( - client.project, client._database, collection_id1, document_id - ) - self.assertEqual(parent_path, expected_path) - prefix = "{}/{}".format(expected_path, collection_id2) - self.assertEqual(expected_prefix, prefix) - @pytest.mark.asyncio async def test_add_auto_assigned(self): from google.cloud.firestore_v1.proto import document_pb2 @@ -582,26 +477,6 @@ def test_on_snapshot(self, watch): watch.for_query.assert_called_once() -class Test__auto_id(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(): - from google.cloud.firestore_v1.async_collection import _auto_id - - return _auto_id() - - @mock.patch("random.choice") - def test_it(self, mock_rand_choice): - from google.cloud.firestore_v1.collection import _AUTO_ID_CHARS - - mock_result = "0123456789abcdefghij" - mock_rand_choice.side_effect = list(mock_result) - result = self._call_fut() - self.assertEqual(result, mock_result) - - mock_calls = [mock.call(_AUTO_ID_CHARS)] * 20 - self.assertEqual(mock_rand_choice.mock_calls, mock_calls) - - def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py new file mode 100644 index 0000000000..c73a10a818 --- /dev/null +++ b/tests/unit/v1/test_base_collection.py @@ -0,0 +1,202 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import mock + + +class TestCollectionReference(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + return BaseCollectionReference + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = self._make_one( + collection_id1, document_id, collection_id2, client=client + ) + self.assertIs(collection._client, client) + expected_path = (collection_id1, document_id, collection_id2) + self.assertEqual(collection._path, expected_path) + + def test_constructor_invalid_path(self): + with self.assertRaises(ValueError): + self._make_one() + with self.assertRaises(ValueError): + self._make_one(99, "doc", "bad-collection-id") + with self.assertRaises(ValueError): + self._make_one("bad-document-ID", None, "sub-collection") + with self.assertRaises(ValueError): + self._make_one("Just", "A-Document") + + def test_constructor_invalid_kwarg(self): + with self.assertRaises(TypeError): + self._make_one("Coh-lek-shun", donut=True) + + def test___eq___other_type(self): + client = mock.sentinel.client + collection = self._make_one("name", client=client) + other = object() + self.assertFalse(collection == other) + + def test___eq___different_path_same_client(self): + client = mock.sentinel.client + collection = self._make_one("name", client=client) + other = self._make_one("other", client=client) + self.assertFalse(collection == other) + + def test___eq___same_path_different_client(self): + client = mock.sentinel.client + other_client = mock.sentinel.other_client + collection = self._make_one("name", client=client) + other = self._make_one("name", client=other_client) + self.assertFalse(collection == other) + + def test___eq___same_path_same_client(self): + client = mock.sentinel.client + collection = self._make_one("name", client=client) + other = self._make_one("name", client=client) + self.assertTrue(collection == other) + + def test_id_property(self): + collection_id = "hi-bob" + collection = self._make_one(collection_id) + self.assertEqual(collection.id, collection_id) + + def test_parent_property(self): + from google.cloud.firestore_v1.document import DocumentReference + + collection_id1 = "grocery-store" + document_id = "market" + collection_id2 = "darth" + client = _make_client() + collection = self._make_one( + collection_id1, document_id, collection_id2, client=client + ) + + parent = collection.parent + self.assertIsInstance(parent, DocumentReference) + self.assertIs(parent._client, client) + self.assertEqual(parent._path, (collection_id1, document_id)) + + def test_parent_property_top_level(self): + collection = self._make_one("tahp-leh-vull") + self.assertIsNone(collection.parent) + + def test_document_factory_explicit_id(self): + from google.cloud.firestore_v1.document import DocumentReference + + collection_id = "grocery-store" + document_id = "market" + client = _make_client() + collection = self._make_one(collection_id, client=client) + + child = collection.document(document_id) + self.assertIsInstance(child, DocumentReference) + self.assertIs(child._client, client) + self.assertEqual(child._path, (collection_id, document_id)) + + @mock.patch( + "google.cloud.firestore_v1.base_collection._auto_id", + return_value="zorpzorpthreezorp012", + ) + def test_document_factory_auto_id(self, mock_auto_id): + from google.cloud.firestore_v1.document import DocumentReference + + collection_name = "space-town" + client = _make_client() + collection = self._make_one(collection_name, client=client) + + child = collection.document() + self.assertIsInstance(child, DocumentReference) + self.assertIs(child._client, client) + self.assertEqual(child._path, (collection_name, mock_auto_id.return_value)) + + mock_auto_id.assert_called_once_with() + + def test__parent_info_top_level(self): + client = _make_client() + collection_id = "soap" + collection = self._make_one(collection_id, client=client) + + parent_path, expected_prefix = collection._parent_info() + + expected_path = "projects/{}/databases/{}/documents".format( + client.project, client._database + ) + self.assertEqual(parent_path, expected_path) + prefix = "{}/{}".format(expected_path, collection_id) + self.assertEqual(expected_prefix, prefix) + + def test__parent_info_nested(self): + collection_id1 = "bar" + document_id = "baz" + collection_id2 = "chunk" + client = _make_client() + collection = self._make_one( + collection_id1, document_id, collection_id2, client=client + ) + + parent_path, expected_prefix = collection._parent_info() + + expected_path = "projects/{}/databases/{}/documents/{}/{}".format( + client.project, client._database, collection_id1, document_id + ) + self.assertEqual(parent_path, expected_path) + prefix = "{}/{}".format(expected_path, collection_id2) + self.assertEqual(expected_prefix, prefix) + + +class Test__auto_id(unittest.TestCase): + @staticmethod + def _call_fut(): + from google.cloud.firestore_v1.base_collection import _auto_id + + return _auto_id() + + @mock.patch("random.choice") + def test_it(self, mock_rand_choice): + from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS + + mock_result = "0123456789abcdefghij" + mock_rand_choice.side_effect = list(mock_result) + result = self._call_fut() + self.assertEqual(result, mock_result) + + mock_calls = [mock.call(_AUTO_ID_CHARS)] * 20 + self.assertEqual(mock_rand_choice.mock_calls, mock_calls) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project="project-project", credentials=credentials) diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index fde538b9db..1ef2e66746 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -32,10 +32,18 @@ def _make_one(self, *args, **kwargs): @staticmethod def _get_public_methods(klass): - return set( - name - for name, value in six.iteritems(klass.__dict__) - if (not name.startswith("_") and isinstance(value, types.FunctionType)) + return set().union( + *( + ( + name + for name, value in six.iteritems(class_.__dict__) + if ( + not name.startswith("_") + and isinstance(value, types.FunctionType) + ) + ) + for class_ in (klass,) + klass.__bases__ + ) ) def test_query_method_matching(self): @@ -75,119 +83,6 @@ def test_constructor_invalid_kwarg(self): with self.assertRaises(TypeError): self._make_one("Coh-lek-shun", donut=True) - def test___eq___other_type(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = object() - self.assertFalse(collection == other) - - def test___eq___different_path_same_client(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = self._make_one("other", client=client) - self.assertFalse(collection == other) - - def test___eq___same_path_different_client(self): - client = mock.sentinel.client - other_client = mock.sentinel.other_client - collection = self._make_one("name", client=client) - other = self._make_one("name", client=other_client) - self.assertFalse(collection == other) - - def test___eq___same_path_same_client(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = self._make_one("name", client=client) - self.assertTrue(collection == other) - - def test_id_property(self): - collection_id = "hi-bob" - collection = self._make_one(collection_id) - self.assertEqual(collection.id, collection_id) - - def test_parent_property(self): - from google.cloud.firestore_v1.document import DocumentReference - - collection_id1 = "grocery-store" - document_id = "market" - collection_id2 = "darth" - client = _make_client() - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - - parent = collection.parent - self.assertIsInstance(parent, DocumentReference) - self.assertIs(parent._client, client) - self.assertEqual(parent._path, (collection_id1, document_id)) - - def test_parent_property_top_level(self): - collection = self._make_one("tahp-leh-vull") - self.assertIsNone(collection.parent) - - def test_document_factory_explicit_id(self): - from google.cloud.firestore_v1.document import DocumentReference - - collection_id = "grocery-store" - document_id = "market" - client = _make_client() - collection = self._make_one(collection_id, client=client) - - child = collection.document(document_id) - self.assertIsInstance(child, DocumentReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_id, document_id)) - - @mock.patch( - "google.cloud.firestore_v1.collection._auto_id", - return_value="zorpzorpthreezorp012", - ) - def test_document_factory_auto_id(self, mock_auto_id): - from google.cloud.firestore_v1.document import DocumentReference - - collection_name = "space-town" - client = _make_client() - collection = self._make_one(collection_name, client=client) - - child = collection.document() - self.assertIsInstance(child, DocumentReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_name, mock_auto_id.return_value)) - - mock_auto_id.assert_called_once_with() - - def test__parent_info_top_level(self): - client = _make_client() - collection_id = "soap" - collection = self._make_one(collection_id, client=client) - - parent_path, expected_prefix = collection._parent_info() - - expected_path = "projects/{}/databases/{}/documents".format( - client.project, client._database - ) - self.assertEqual(parent_path, expected_path) - prefix = "{}/{}".format(expected_path, collection_id) - self.assertEqual(expected_prefix, prefix) - - def test__parent_info_nested(self): - collection_id1 = "bar" - document_id = "baz" - collection_id2 = "chunk" - client = _make_client() - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - - parent_path, expected_prefix = collection._parent_info() - - expected_path = "projects/{}/databases/{}/documents/{}/{}".format( - client.project, client._database, collection_id1, document_id - ) - self.assertEqual(parent_path, expected_path) - prefix = "{}/{}".format(expected_path, collection_id2) - self.assertEqual(expected_prefix, prefix) - def test_add_auto_assigned(self): from google.cloud.firestore_v1.proto import document_pb2 from google.cloud.firestore_v1.document import DocumentReference @@ -545,26 +440,6 @@ def test_on_snapshot(self, watch): watch.for_query.assert_called_once() -class Test__auto_id(unittest.TestCase): - @staticmethod - def _call_fut(): - from google.cloud.firestore_v1.collection import _auto_id - - return _auto_id() - - @mock.patch("random.choice") - def test_it(self, mock_rand_choice): - from google.cloud.firestore_v1.collection import _AUTO_ID_CHARS - - mock_result = "0123456789abcdefghij" - mock_rand_choice.side_effect = list(mock_result) - result = self._call_fut() - self.assertEqual(result, mock_result) - - mock_calls = [mock.call(_AUTO_ID_CHARS)] * 20 - self.assertEqual(mock_rand_choice.mock_calls, mock_calls) - - def _make_credentials(): import google.auth.credentials From 96cd765c7a0b6bac2dc2df9a37662aed9f29d38d Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 24 Jun 2020 16:51:40 -0500 Subject: [PATCH 35/47] feat: create DocumentReference/AsyncDocumentReference superclass --- google/cloud/firestore_v1/async_document.py | 6 +- google/cloud/firestore_v1/base_client.py | 2 +- google/cloud/firestore_v1/base_document.py | 457 ++++++++++++++++++++ google/cloud/firestore_v1/document.py | 401 +---------------- tests/unit/v1/async/test_async_document.py | 368 +--------------- tests/unit/v1/test_base_document.py | 427 ++++++++++++++++++ tests/unit/v1/test_document.py | 368 +--------------- 7 files changed, 911 insertions(+), 1118 deletions(-) create mode 100644 google/cloud/firestore_v1/base_document.py create mode 100644 tests/unit/v1/test_base_document.py diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 1111f6b19d..1cd66b57d7 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -16,8 +16,8 @@ import six -from google.cloud.firestore_v1.document import ( - DocumentReference, +from google.cloud.firestore_v1.base_document import ( + BaseDocumentReference, DocumentSnapshot, _first_write_result, _item_to_collection_ref, @@ -29,7 +29,7 @@ from google.cloud.firestore_v1.watch import Watch -class AsyncDocumentReference(DocumentReference): +class AsyncDocumentReference(BaseDocumentReference): """A reference to a document in a Firestore database. The document may already exist or can be created by this class. diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index d020c251a7..ff6e0f40cc 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -32,7 +32,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import __version__ from google.cloud.firestore_v1 import types -from google.cloud.firestore_v1.document import DocumentSnapshot +from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.gapic import firestore_client from google.cloud.firestore_v1.gapic.transports import firestore_grpc_transport diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py new file mode 100644 index 0000000000..f04956293e --- /dev/null +++ b/google/cloud/firestore_v1/base_document.py @@ -0,0 +1,457 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing documents for the Google Cloud Firestore API.""" + +import copy + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import field_path as field_path_module + + +class BaseDocumentReference(object): + """A reference to a document in a Firestore database. + + The document may already exist or can be created by this class. + + Args: + path (Tuple[str, ...]): The components in the document path. + This is a series of strings representing each collection and + sub-collection ID, as well as the document IDs for any documents + that contain a sub-collection (as well as the base document). + kwargs (dict): The keyword arguments for the constructor. The only + supported keyword is ``client`` and it must be a + :class:`~google.cloud.firestore_v1.client.Client`. It represents + the client that created this document reference. + + Raises: + ValueError: if + + * the ``path`` is empty + * there are an even number of elements + * a collection ID in ``path`` is not a string + * a document ID in ``path`` is not a string + TypeError: If a keyword other than ``client`` is used. + """ + + _document_path_internal = None + + def __init__(self, *path, **kwargs): + _helpers.verify_path(path, is_collection=False) + self._path = path + self._client = kwargs.pop("client", None) + if kwargs: + raise TypeError( + "Received unexpected arguments", kwargs, "Only `client` is supported" + ) + + def __copy__(self): + """Shallow copy the instance. + + We leave the client "as-is" but tuple-unpack the path. + + Returns: + .DocumentReference: A copy of the current document. + """ + result = self.__class__(*self._path, client=self._client) + result._document_path_internal = self._document_path_internal + return result + + def __deepcopy__(self, unused_memo): + """Deep copy the instance. + + This isn't a true deep copy, wee leave the client "as-is" but + tuple-unpack the path. + + Returns: + .DocumentReference: A copy of the current document. + """ + return self.__copy__() + + def __eq__(self, other): + """Equality check against another instance. + + Args: + other (Any): A value to compare against. + + Returns: + Union[bool, NotImplementedType]: Indicating if the values are + equal. + """ + if isinstance(other, self.__class__): + return self._client == other._client and self._path == other._path + else: + return NotImplemented + + def __hash__(self): + return hash(self._path) + hash(self._client) + + def __ne__(self, other): + """Inequality check against another instance. + + Args: + other (Any): A value to compare against. + + Returns: + Union[bool, NotImplementedType]: Indicating if the values are + not equal. + """ + if isinstance(other, self.__class__): + return self._client != other._client or self._path != other._path + else: + return NotImplemented + + @property + def path(self): + """Database-relative for this document. + + Returns: + str: The document's relative path. + """ + return "/".join(self._path) + + @property + def _document_path(self): + """Create and cache the full path for this document. + + Of the form: + + ``projects/{project_id}/databases/{database_id}/... + documents/{document_path}`` + + Returns: + str: The full document path. + + Raises: + ValueError: If the current document reference has no ``client``. + """ + if self._document_path_internal is None: + if self._client is None: + raise ValueError("A document reference requires a `client`.") + self._document_path_internal = _get_document_path(self._client, self._path) + + return self._document_path_internal + + @property + def id(self): + """The document identifier (within its collection). + + Returns: + str: The last component of the path. + """ + return self._path[-1] + + @property + def parent(self): + """Collection that owns the current document. + + Returns: + :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + The parent collection. + """ + parent_path = self._path[:-1] + return self._client.collection(*parent_path) + + def collection(self, collection_id): + """Create a sub-collection underneath the current document. + + Args: + collection_id (str): The sub-collection identifier (sometimes + referred to as the "kind"). + + Returns: + :class:`~google.cloud.firestore_v1.collection.CollectionReference`: + The child collection. + """ + child_path = self._path + (collection_id,) + return self._client.collection(*child_path) + + def create(self, document_data): + raise NotImplementedError + + def set(self, document_data, merge=False): + raise NotImplementedError + + def update(self, field_updates, option=None): + raise NotImplementedError + + def delete(self, option=None): + raise NotImplementedError + + def get(self, field_paths=None, transaction=None): + raise NotImplementedError + + def collections(self, page_size=None): + raise NotImplementedError + + def on_snapshot(self, callback): + raise NotImplementedError + + +class DocumentSnapshot(object): + """A snapshot of document data in a Firestore database. + + This represents data retrieved at a specific time and may not contain + all fields stored for the document (i.e. a hand-picked selection of + fields may have been retrieved). + + Instances of this class are not intended to be constructed by hand, + rather they'll be returned as responses to various methods, such as + :meth:`~google.cloud.DocumentReference.get`. + + Args: + reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): + A document reference corresponding to the document that contains + the data in this snapshot. + data (Dict[str, Any]): + The data retrieved in the snapshot. + exists (bool): + Indicates if the document existed at the time the snapshot was + retrieved. + read_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): + The time that this snapshot was read from the server. + create_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): + The time that this document was created. + update_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): + The time that this document was last updated. + """ + + def __init__(self, reference, data, exists, read_time, create_time, update_time): + self._reference = reference + # We want immutable data, so callers can't modify this value + # out from under us. + self._data = copy.deepcopy(data) + self._exists = exists + self.read_time = read_time + self.create_time = create_time + self.update_time = update_time + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self._reference == other._reference and self._data == other._data + + def __hash__(self): + seconds = self.update_time.seconds + nanos = self.update_time.nanos + return hash(self._reference) + hash(seconds) + hash(nanos) + + @property + def _client(self): + """The client that owns the document reference for this snapshot. + + Returns: + :class:`~google.cloud.firestore_v1.client.Client`: + The client that owns this document. + """ + return self._reference._client + + @property + def exists(self): + """Existence flag. + + Indicates if the document existed at the time this snapshot + was retrieved. + + Returns: + bool: The existence flag. + """ + return self._exists + + @property + def id(self): + """The document identifier (within its collection). + + Returns: + str: The last component of the path of the document. + """ + return self._reference.id + + @property + def reference(self): + """Document reference corresponding to document that owns this data. + + Returns: + :class:`~google.cloud.firestore_v1.document.DocumentReference`: + A document reference corresponding to this document. + """ + return self._reference + + def get(self, field_path): + """Get a value from the snapshot data. + + If the data is nested, for example: + + .. code-block:: python + + >>> snapshot.to_dict() + { + 'top1': { + 'middle2': { + 'bottom3': 20, + 'bottom4': 22, + }, + 'middle5': True, + }, + 'top6': b'\x00\x01 foo', + } + + a **field path** can be used to access the nested data. For + example: + + .. code-block:: python + + >>> snapshot.get('top1') + { + 'middle2': { + 'bottom3': 20, + 'bottom4': 22, + }, + 'middle5': True, + } + >>> snapshot.get('top1.middle2') + { + 'bottom3': 20, + 'bottom4': 22, + } + >>> snapshot.get('top1.middle2.bottom3') + 20 + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + A copy is returned since the data may contain mutable values, + but the data stored in the snapshot must remain immutable. + + Args: + field_path (str): A field path (``.``-delimited list of + field names). + + Returns: + Any or None: + (A copy of) the value stored for the ``field_path`` or + None if snapshot document does not exist. + + Raises: + KeyError: If the ``field_path`` does not match nested data + in the snapshot. + """ + if not self._exists: + return None + nested_data = field_path_module.get_nested_value(field_path, self._data) + return copy.deepcopy(nested_data) + + def to_dict(self): + """Retrieve the data contained in this snapshot. + + A copy is returned since the data may contain mutable values, + but the data stored in the snapshot must remain immutable. + + Returns: + Dict[str, Any] or None: + The data in the snapshot. Returns None if reference + does not exist. + """ + if not self._exists: + return None + return copy.deepcopy(self._data) + + +def _get_document_path(client, path): + """Convert a path tuple into a full path string. + + Of the form: + + ``projects/{project_id}/databases/{database_id}/... + documents/{document_path}`` + + Args: + client (:class:`~google.cloud.firestore_v1.client.Client`): + The client that holds configuration details and a GAPIC client + object. + path (Tuple[str, ...]): The components in a document path. + + Returns: + str: The fully-qualified document path. + """ + parts = (client._database_string, "documents") + path + return _helpers.DOCUMENT_PATH_DELIMITER.join(parts) + + +def _consume_single_get(response_iterator): + """Consume a gRPC stream that should contain a single response. + + The stream will correspond to a ``BatchGetDocuments`` request made + for a single document. + + Args: + response_iterator (~google.cloud.exceptions.GrpcRendezvous): A + streaming iterator returned from a ``BatchGetDocuments`` + request. + + Returns: + ~google.cloud.proto.firestore.v1.\ + firestore_pb2.BatchGetDocumentsResponse: The single "get" + response in the batch. + + Raises: + ValueError: If anything other than exactly one response is returned. + """ + # Calling ``list()`` consumes the entire iterator. + all_responses = list(response_iterator) + if len(all_responses) != 1: + raise ValueError( + "Unexpected response from `BatchGetDocumentsResponse`", + all_responses, + "Expected only one result", + ) + + return all_responses[0] + + +def _first_write_result(write_results): + """Get first write result from list. + + For cases where ``len(write_results) > 1``, this assumes the writes + occurred at the same time (e.g. if an update and transform are sent + at the same time). + + Args: + write_results (List[google.cloud.proto.firestore.v1.\ + write_pb2.WriteResult, ...]: The write results from a + ``CommitResponse``. + + Returns: + google.cloud.firestore_v1.types.WriteResult: The + lone write result from ``write_results``. + + Raises: + ValueError: If there are zero write results. This is likely to + **never** occur, since the backend should be stable. + """ + if not write_results: + raise ValueError("Expected at least one write result") + + return write_results[0] + + +def _item_to_collection_ref(iterator, item): + """Convert collection ID to collection ref. + + Args: + iterator (google.api_core.page_iterator.GRPCIterator): + iterator response + item (str): ID of the collection + """ + return iterator.document.collection(item) diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index c51c7c5c74..bbe2ca19cd 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -14,18 +14,22 @@ """Classes for representing documents for the Google Cloud Firestore API.""" -import copy - import six +from google.cloud.firestore_v1.base_document import ( + BaseDocumentReference, + DocumentSnapshot, + _first_write_result, + _item_to_collection_ref, +) + from google.api_core import exceptions from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import field_path as field_path_module from google.cloud.firestore_v1.proto import common_pb2 from google.cloud.firestore_v1.watch import Watch -class DocumentReference(object): +class DocumentReference(BaseDocumentReference): """A reference to a document in a Firestore database. The document may already exist or can be created by this class. @@ -50,137 +54,8 @@ class DocumentReference(object): TypeError: If a keyword other than ``client`` is used. """ - _document_path_internal = None - def __init__(self, *path, **kwargs): - _helpers.verify_path(path, is_collection=False) - self._path = path - self._client = kwargs.pop("client", None) - if kwargs: - raise TypeError( - "Received unexpected arguments", kwargs, "Only `client` is supported" - ) - - def __copy__(self): - """Shallow copy the instance. - - We leave the client "as-is" but tuple-unpack the path. - - Returns: - .DocumentReference: A copy of the current document. - """ - result = self.__class__(*self._path, client=self._client) - result._document_path_internal = self._document_path_internal - return result - - def __deepcopy__(self, unused_memo): - """Deep copy the instance. - - This isn't a true deep copy, wee leave the client "as-is" but - tuple-unpack the path. - - Returns: - .DocumentReference: A copy of the current document. - """ - return self.__copy__() - - def __eq__(self, other): - """Equality check against another instance. - - Args: - other (Any): A value to compare against. - - Returns: - Union[bool, NotImplementedType]: Indicating if the values are - equal. - """ - if isinstance(other, self.__class__): - return self._client == other._client and self._path == other._path - else: - return NotImplemented - - def __hash__(self): - return hash(self._path) + hash(self._client) - - def __ne__(self, other): - """Inequality check against another instance. - - Args: - other (Any): A value to compare against. - - Returns: - Union[bool, NotImplementedType]: Indicating if the values are - not equal. - """ - if isinstance(other, self.__class__): - return self._client != other._client or self._path != other._path - else: - return NotImplemented - - @property - def path(self): - """Database-relative for this document. - - Returns: - str: The document's relative path. - """ - return "/".join(self._path) - - @property - def _document_path(self): - """Create and cache the full path for this document. - - Of the form: - - ``projects/{project_id}/databases/{database_id}/... - documents/{document_path}`` - - Returns: - str: The full document path. - - Raises: - ValueError: If the current document reference has no ``client``. - """ - if self._document_path_internal is None: - if self._client is None: - raise ValueError("A document reference requires a `client`.") - self._document_path_internal = _get_document_path(self._client, self._path) - - return self._document_path_internal - - @property - def id(self): - """The document identifier (within its collection). - - Returns: - str: The last component of the path. - """ - return self._path[-1] - - @property - def parent(self): - """Collection that owns the current document. - - Returns: - :class:`~google.cloud.firestore_v1.collection.CollectionReference`: - The parent collection. - """ - parent_path = self._path[:-1] - return self._client.collection(*parent_path) - - def collection(self, collection_id): - """Create a sub-collection underneath the current document. - - Args: - collection_id (str): The sub-collection identifier (sometimes - referred to as the "kind"). - - Returns: - :class:`~google.cloud.firestore_v1.collection.CollectionReference`: - The child collection. - """ - child_path = self._path + (collection_id,) - return self._client.collection(*child_path) + super(DocumentReference, self).__init__(*path, **kwargs) def create(self, document_data): """Create the current document in the Firestore database. @@ -526,261 +401,3 @@ def on_snapshot(document_snapshot, changes, read_time): doc_watch.unsubscribe() """ return Watch.for_document(self, callback, DocumentSnapshot, DocumentReference) - - -class DocumentSnapshot(object): - """A snapshot of document data in a Firestore database. - - This represents data retrieved at a specific time and may not contain - all fields stored for the document (i.e. a hand-picked selection of - fields may have been retrieved). - - Instances of this class are not intended to be constructed by hand, - rather they'll be returned as responses to various methods, such as - :meth:`~google.cloud.DocumentReference.get`. - - Args: - reference (:class:`~google.cloud.firestore_v1.document.DocumentReference`): - A document reference corresponding to the document that contains - the data in this snapshot. - data (Dict[str, Any]): - The data retrieved in the snapshot. - exists (bool): - Indicates if the document existed at the time the snapshot was - retrieved. - read_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): - The time that this snapshot was read from the server. - create_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): - The time that this document was created. - update_time (:class:`google.protobuf.timestamp_pb2.Timestamp`): - The time that this document was last updated. - """ - - def __init__(self, reference, data, exists, read_time, create_time, update_time): - self._reference = reference - # We want immutable data, so callers can't modify this value - # out from under us. - self._data = copy.deepcopy(data) - self._exists = exists - self.read_time = read_time - self.create_time = create_time - self.update_time = update_time - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self._reference == other._reference and self._data == other._data - - def __hash__(self): - seconds = self.update_time.seconds - nanos = self.update_time.nanos - return hash(self._reference) + hash(seconds) + hash(nanos) - - @property - def _client(self): - """The client that owns the document reference for this snapshot. - - Returns: - :class:`~google.cloud.firestore_v1.client.Client`: - The client that owns this document. - """ - return self._reference._client - - @property - def exists(self): - """Existence flag. - - Indicates if the document existed at the time this snapshot - was retrieved. - - Returns: - bool: The existence flag. - """ - return self._exists - - @property - def id(self): - """The document identifier (within its collection). - - Returns: - str: The last component of the path of the document. - """ - return self._reference.id - - @property - def reference(self): - """Document reference corresponding to document that owns this data. - - Returns: - :class:`~google.cloud.firestore_v1.document.DocumentReference`: - A document reference corresponding to this document. - """ - return self._reference - - def get(self, field_path): - """Get a value from the snapshot data. - - If the data is nested, for example: - - .. code-block:: python - - >>> snapshot.to_dict() - { - 'top1': { - 'middle2': { - 'bottom3': 20, - 'bottom4': 22, - }, - 'middle5': True, - }, - 'top6': b'\x00\x01 foo', - } - - a **field path** can be used to access the nested data. For - example: - - .. code-block:: python - - >>> snapshot.get('top1') - { - 'middle2': { - 'bottom3': 20, - 'bottom4': 22, - }, - 'middle5': True, - } - >>> snapshot.get('top1.middle2') - { - 'bottom3': 20, - 'bottom4': 22, - } - >>> snapshot.get('top1.middle2.bottom3') - 20 - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - A copy is returned since the data may contain mutable values, - but the data stored in the snapshot must remain immutable. - - Args: - field_path (str): A field path (``.``-delimited list of - field names). - - Returns: - Any or None: - (A copy of) the value stored for the ``field_path`` or - None if snapshot document does not exist. - - Raises: - KeyError: If the ``field_path`` does not match nested data - in the snapshot. - """ - if not self._exists: - return None - nested_data = field_path_module.get_nested_value(field_path, self._data) - return copy.deepcopy(nested_data) - - def to_dict(self): - """Retrieve the data contained in this snapshot. - - A copy is returned since the data may contain mutable values, - but the data stored in the snapshot must remain immutable. - - Returns: - Dict[str, Any] or None: - The data in the snapshot. Returns None if reference - does not exist. - """ - if not self._exists: - return None - return copy.deepcopy(self._data) - - -def _get_document_path(client, path): - """Convert a path tuple into a full path string. - - Of the form: - - ``projects/{project_id}/databases/{database_id}/... - documents/{document_path}`` - - Args: - client (:class:`~google.cloud.firestore_v1.client.Client`): - The client that holds configuration details and a GAPIC client - object. - path (Tuple[str, ...]): The components in a document path. - - Returns: - str: The fully-qualified document path. - """ - parts = (client._database_string, "documents") + path - return _helpers.DOCUMENT_PATH_DELIMITER.join(parts) - - -def _consume_single_get(response_iterator): - """Consume a gRPC stream that should contain a single response. - - The stream will correspond to a ``BatchGetDocuments`` request made - for a single document. - - Args: - response_iterator (~google.cloud.exceptions.GrpcRendezvous): A - streaming iterator returned from a ``BatchGetDocuments`` - request. - - Returns: - ~google.cloud.proto.firestore.v1.\ - firestore_pb2.BatchGetDocumentsResponse: The single "get" - response in the batch. - - Raises: - ValueError: If anything other than exactly one response is returned. - """ - # Calling ``list()`` consumes the entire iterator. - all_responses = list(response_iterator) - if len(all_responses) != 1: - raise ValueError( - "Unexpected response from `BatchGetDocumentsResponse`", - all_responses, - "Expected only one result", - ) - - return all_responses[0] - - -def _first_write_result(write_results): - """Get first write result from list. - - For cases where ``len(write_results) > 1``, this assumes the writes - occurred at the same time (e.g. if an update and transform are sent - at the same time). - - Args: - write_results (List[google.cloud.proto.firestore.v1.\ - write_pb2.WriteResult, ...]: The write results from a - ``CommitResponse``. - - Returns: - google.cloud.firestore_v1.types.WriteResult: The - lone write result from ``write_results``. - - Raises: - ValueError: If there are zero write results. This is likely to - **never** occur, since the backend should be stable. - """ - if not write_results: - raise ValueError("Expected at least one write result") - - return write_results[0] - - -def _item_to_collection_ref(iterator, item): - """Convert collection ID to collection ref. - - Args: - iterator (google.api_core.page_iterator.GRPCIterator): - iterator response - item (str): ID of the collection - """ - return iterator.document.collection(item) diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py index 09b2f951e5..be265d7bfd 100644 --- a/tests/unit/v1/async/test_async_document.py +++ b/tests/unit/v1/async/test_async_document.py @@ -61,138 +61,14 @@ def test_constructor_invalid_kwarg(self): with self.assertRaises(TypeError): self._make_one("Coh-lek-shun", "Dahk-yu-mehnt", burger=18.75) - def test___copy__(self): - client = _make_client("rain") - document = self._make_one("a", "b", client=client) - # Access the document path so it is copied. - doc_path = document._document_path - self.assertEqual(doc_path, document._document_path_internal) - - new_document = document.__copy__() - self.assertIsNot(new_document, document) - self.assertIs(new_document._client, document._client) - self.assertEqual(new_document._path, document._path) - self.assertEqual( - new_document._document_path_internal, document._document_path_internal - ) - - def test___deepcopy__calls_copy(self): - client = mock.sentinel.client - document = self._make_one("a", "b", client=client) - document.__copy__ = mock.Mock(return_value=mock.sentinel.new_doc, spec=[]) - - unused_memo = {} - new_document = document.__deepcopy__(unused_memo) - self.assertIs(new_document, mock.sentinel.new_doc) - document.__copy__.assert_called_once_with() - - def test__eq__same_type(self): - document1 = self._make_one("X", "YY", client=mock.sentinel.client) - document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) - document3 = self._make_one("X", "YY", client=mock.sentinel.client2) - document4 = self._make_one("X", "YY", client=mock.sentinel.client) - - pairs = ((document1, document2), (document1, document3), (document2, document3)) - for candidate1, candidate2 in pairs: - # We use == explicitly since assertNotEqual would use !=. - equality_val = candidate1 == candidate2 - self.assertFalse(equality_val) - - # Check the only equal one. - self.assertEqual(document1, document4) - self.assertIsNot(document1, document4) - - def test__eq__other_type(self): - document = self._make_one("X", "YY", client=mock.sentinel.client) - other = object() - equality_val = document == other - self.assertFalse(equality_val) - self.assertIs(document.__eq__(other), NotImplemented) - - def test___hash__(self): - client = mock.MagicMock() - client.__hash__.return_value = 234566789 - document = self._make_one("X", "YY", client=client) - self.assertEqual(hash(document), hash(("X", "YY")) + hash(client)) - - def test__ne__same_type(self): - document1 = self._make_one("X", "YY", client=mock.sentinel.client) - document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) - document3 = self._make_one("X", "YY", client=mock.sentinel.client2) - document4 = self._make_one("X", "YY", client=mock.sentinel.client) - - self.assertNotEqual(document1, document2) - self.assertNotEqual(document1, document3) - self.assertNotEqual(document2, document3) - - # We use != explicitly since assertEqual would use ==. - inequality_val = document1 != document4 - self.assertFalse(inequality_val) - self.assertIsNot(document1, document4) - - def test__ne__other_type(self): - document = self._make_one("X", "YY", client=mock.sentinel.client) - other = object() - self.assertNotEqual(document, other) - self.assertIs(document.__ne__(other), NotImplemented) - - def test__document_path_property(self): - project = "hi-its-me-ok-bye" - client = _make_client(project=project) - - collection_id = "then" - document_id = "090909iii" - document = self._make_one(collection_id, document_id, client=client) - doc_path = document._document_path - expected = "projects/{}/databases/{}/documents/{}/{}".format( - project, client._database, collection_id, document_id - ) - self.assertEqual(doc_path, expected) - self.assertIs(document._document_path_internal, doc_path) - - # Make sure value is cached. - document._document_path_internal = mock.sentinel.cached - self.assertIs(document._document_path, mock.sentinel.cached) - - def test__document_path_property_no_client(self): - document = self._make_one("hi", "bye") - self.assertIsNone(document._client) - with self.assertRaises(ValueError): - getattr(document, "_document_path") - - self.assertIsNone(document._document_path_internal) - - def test_id_property(self): - document_id = "867-5309" - document = self._make_one("Co-lek-shun", document_id) - self.assertEqual(document.id, document_id) - - def test_parent_property(self): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - - collection_id = "grocery-store" - document_id = "market" - client = _make_client() - document = self._make_one(collection_id, document_id, client=client) - - parent = document.parent - self.assertIsInstance(parent, AsyncCollectionReference) - self.assertIs(parent._client, client) - self.assertEqual(parent._path, (collection_id,)) - - def test_collection_factory(self): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - - collection_id = "grocery-store" - document_id = "market" - new_collection = "fruits" - client = _make_client() - document = self._make_one(collection_id, document_id, client=client) + @staticmethod + def _make_commit_repsonse(write_results=None): + from google.cloud.firestore_v1.proto import firestore_pb2 - child = document.collection(new_collection) - self.assertIsInstance(child, AsyncCollectionReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_id, document_id, new_collection)) + response = mock.create_autospec(firestore_pb2.CommitResponse) + response.write_results = write_results or [mock.sentinel.write_result] + response.commit_time = mock.sentinel.commit_time + return response @staticmethod def _write_pb_for_create(document_path, document_data): @@ -208,15 +84,6 @@ def _write_pb_for_create(document_path, document_data): current_document=common_pb2.Precondition(exists=False), ) - @staticmethod - def _make_commit_repsonse(write_results=None): - from google.cloud.firestore_v1.proto import firestore_pb2 - - response = mock.create_autospec(firestore_pb2.CommitResponse) - response.write_results = write_results or [mock.sentinel.write_result] - response.commit_time = mock.sentinel.commit_time - return response - @pytest.mark.asyncio async def test_create(self): # Create a minimal fake GAPIC with a dummy response. @@ -617,227 +484,6 @@ def test_on_snapshot(self, watch): watch.for_document.assert_called_once() -class TestDocumentSnapshot(aiounittest.AsyncTestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_document import DocumentSnapshot - - return DocumentSnapshot - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def _make_reference(self, *args, **kwargs): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - return AsyncDocumentReference(*args, **kwargs) - - def _make_w_ref(self, ref_path=("a", "b"), data={}, exists=True): - client = mock.sentinel.client - reference = self._make_reference(*ref_path, client=client) - return self._make_one( - reference, - data, - exists, - mock.sentinel.read_time, - mock.sentinel.create_time, - mock.sentinel.update_time, - ) - - def test_constructor(self): - client = mock.sentinel.client - reference = self._make_reference("hi", "bye", client=client) - data = {"zoop": 83} - snapshot = self._make_one( - reference, - data, - True, - mock.sentinel.read_time, - mock.sentinel.create_time, - mock.sentinel.update_time, - ) - self.assertIs(snapshot._reference, reference) - self.assertEqual(snapshot._data, data) - self.assertIsNot(snapshot._data, data) # Make sure copied. - self.assertTrue(snapshot._exists) - self.assertIs(snapshot.read_time, mock.sentinel.read_time) - self.assertIs(snapshot.create_time, mock.sentinel.create_time) - self.assertIs(snapshot.update_time, mock.sentinel.update_time) - - def test___eq___other_type(self): - snapshot = self._make_w_ref() - other = object() - self.assertFalse(snapshot == other) - - def test___eq___different_reference_same_data(self): - snapshot = self._make_w_ref(("a", "b")) - other = self._make_w_ref(("c", "d")) - self.assertFalse(snapshot == other) - - def test___eq___same_reference_different_data(self): - snapshot = self._make_w_ref(("a", "b")) - other = self._make_w_ref(("a", "b"), {"foo": "bar"}) - self.assertFalse(snapshot == other) - - def test___eq___same_reference_same_data(self): - snapshot = self._make_w_ref(("a", "b"), {"foo": "bar"}) - other = self._make_w_ref(("a", "b"), {"foo": "bar"}) - self.assertTrue(snapshot == other) - - def test___hash__(self): - from google.protobuf import timestamp_pb2 - - client = mock.MagicMock() - client.__hash__.return_value = 234566789 - reference = self._make_reference("hi", "bye", client=client) - data = {"zoop": 83} - update_time = timestamp_pb2.Timestamp(seconds=123456, nanos=123456789) - snapshot = self._make_one( - reference, data, True, None, mock.sentinel.create_time, update_time - ) - self.assertEqual( - hash(snapshot), hash(reference) + hash(123456) + hash(123456789) - ) - - def test__client_property(self): - reference = self._make_reference( - "ok", "fine", "now", "fore", client=mock.sentinel.client - ) - snapshot = self._make_one(reference, {}, False, None, None, None) - self.assertIs(snapshot._client, mock.sentinel.client) - - def test_exists_property(self): - reference = mock.sentinel.reference - - snapshot1 = self._make_one(reference, {}, False, None, None, None) - self.assertFalse(snapshot1.exists) - snapshot2 = self._make_one(reference, {}, True, None, None, None) - self.assertTrue(snapshot2.exists) - - def test_id_property(self): - document_id = "around" - reference = self._make_reference( - "look", document_id, client=mock.sentinel.client - ) - snapshot = self._make_one(reference, {}, True, None, None, None) - self.assertEqual(snapshot.id, document_id) - self.assertEqual(reference.id, document_id) - - def test_reference_property(self): - snapshot = self._make_one(mock.sentinel.reference, {}, True, None, None, None) - self.assertIs(snapshot.reference, mock.sentinel.reference) - - def test_get(self): - data = {"one": {"bold": "move"}} - snapshot = self._make_one(None, data, True, None, None, None) - - first_read = snapshot.get("one") - second_read = snapshot.get("one") - self.assertEqual(first_read, data.get("one")) - self.assertIsNot(first_read, data.get("one")) - self.assertEqual(first_read, second_read) - self.assertIsNot(first_read, second_read) - - with self.assertRaises(KeyError): - snapshot.get("two") - - def test_nonexistent_snapshot(self): - snapshot = self._make_one(None, None, False, None, None, None) - self.assertIsNone(snapshot.get("one")) - - def test_to_dict(self): - data = {"a": 10, "b": ["definitely", "mutable"], "c": {"45": 50}} - snapshot = self._make_one(None, data, True, None, None, None) - as_dict = snapshot.to_dict() - self.assertEqual(as_dict, data) - self.assertIsNot(as_dict, data) - # Check that the data remains unchanged. - as_dict["b"].append("hi") - self.assertEqual(data, snapshot.to_dict()) - self.assertNotEqual(data, as_dict) - - def test_non_existent(self): - snapshot = self._make_one(None, None, False, None, None, None) - as_dict = snapshot.to_dict() - self.assertIsNone(as_dict) - - -class Test__get_document_path(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(client, path): - from google.cloud.firestore_v1.document import _get_document_path - - return _get_document_path(client, path) - - def test_it(self): - project = "prah-jekt" - client = _make_client(project=project) - path = ("Some", "Document", "Child", "Shockument") - document_path = self._call_fut(client, path) - - expected = "projects/{}/databases/{}/documents/{}".format( - project, client._database, "/".join(path) - ) - self.assertEqual(document_path, expected) - - -class Test__consume_single_get(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(response_iterator): - from google.cloud.firestore_v1.document import _consume_single_get - - return _consume_single_get(response_iterator) - - def test_success(self): - response_iterator = iter([mock.sentinel.result]) - result = self._call_fut(response_iterator) - self.assertIs(result, mock.sentinel.result) - - def test_failure_not_enough(self): - response_iterator = iter([]) - with self.assertRaises(ValueError): - self._call_fut(response_iterator) - - def test_failure_too_many(self): - response_iterator = iter([None, None]) - with self.assertRaises(ValueError): - self._call_fut(response_iterator) - - -class Test__first_write_result(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(write_results): - from google.cloud.firestore_v1.document import _first_write_result - - return _first_write_result(write_results) - - def test_success(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - single_result = write_pb2.WriteResult( - update_time=timestamp_pb2.Timestamp(seconds=1368767504, nanos=458000123) - ) - write_results = [single_result] - result = self._call_fut(write_results) - self.assertIs(result, single_result) - - def test_failure_not_enough(self): - write_results = [] - with self.assertRaises(ValueError): - self._call_fut(write_results) - - def test_more_than_one(self): - from google.cloud.firestore_v1.proto import write_pb2 - - result1 = write_pb2.WriteResult() - result2 = write_pb2.WriteResult() - write_results = [result1, result2] - result = self._call_fut(write_results) - self.assertIs(result, result1) - - def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_base_document.py b/tests/unit/v1/test_base_document.py new file mode 100644 index 0000000000..7e61f4cbb2 --- /dev/null +++ b/tests/unit/v1/test_base_document.py @@ -0,0 +1,427 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import mock + + +class TestDocumentReference(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.document import DocumentReference + + return DocumentReference + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + collection_id1 = "users" + document_id1 = "alovelace" + collection_id2 = "platform" + document_id2 = "*nix" + client = mock.MagicMock() + client.__hash__.return_value = 1234 + + document = self._make_one( + collection_id1, document_id1, collection_id2, document_id2, client=client + ) + self.assertIs(document._client, client) + expected_path = "/".join( + (collection_id1, document_id1, collection_id2, document_id2) + ) + self.assertEqual(document.path, expected_path) + + def test_constructor_invalid_path(self): + with self.assertRaises(ValueError): + self._make_one() + with self.assertRaises(ValueError): + self._make_one(None, "before", "bad-collection-id", "fifteen") + with self.assertRaises(ValueError): + self._make_one("bad-document-ID", None) + with self.assertRaises(ValueError): + self._make_one("Just", "A-Collection", "Sub") + + def test_constructor_invalid_kwarg(self): + with self.assertRaises(TypeError): + self._make_one("Coh-lek-shun", "Dahk-yu-mehnt", burger=18.75) + + def test___copy__(self): + client = _make_client("rain") + document = self._make_one("a", "b", client=client) + # Access the document path so it is copied. + doc_path = document._document_path + self.assertEqual(doc_path, document._document_path_internal) + + new_document = document.__copy__() + self.assertIsNot(new_document, document) + self.assertIs(new_document._client, document._client) + self.assertEqual(new_document._path, document._path) + self.assertEqual( + new_document._document_path_internal, document._document_path_internal + ) + + def test___deepcopy__calls_copy(self): + client = mock.sentinel.client + document = self._make_one("a", "b", client=client) + document.__copy__ = mock.Mock(return_value=mock.sentinel.new_doc, spec=[]) + + unused_memo = {} + new_document = document.__deepcopy__(unused_memo) + self.assertIs(new_document, mock.sentinel.new_doc) + document.__copy__.assert_called_once_with() + + def test__eq__same_type(self): + document1 = self._make_one("X", "YY", client=mock.sentinel.client) + document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) + document3 = self._make_one("X", "YY", client=mock.sentinel.client2) + document4 = self._make_one("X", "YY", client=mock.sentinel.client) + + pairs = ((document1, document2), (document1, document3), (document2, document3)) + for candidate1, candidate2 in pairs: + # We use == explicitly since assertNotEqual would use !=. + equality_val = candidate1 == candidate2 + self.assertFalse(equality_val) + + # Check the only equal one. + self.assertEqual(document1, document4) + self.assertIsNot(document1, document4) + + def test__eq__other_type(self): + document = self._make_one("X", "YY", client=mock.sentinel.client) + other = object() + equality_val = document == other + self.assertFalse(equality_val) + self.assertIs(document.__eq__(other), NotImplemented) + + def test___hash__(self): + client = mock.MagicMock() + client.__hash__.return_value = 234566789 + document = self._make_one("X", "YY", client=client) + self.assertEqual(hash(document), hash(("X", "YY")) + hash(client)) + + def test__ne__same_type(self): + document1 = self._make_one("X", "YY", client=mock.sentinel.client) + document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) + document3 = self._make_one("X", "YY", client=mock.sentinel.client2) + document4 = self._make_one("X", "YY", client=mock.sentinel.client) + + self.assertNotEqual(document1, document2) + self.assertNotEqual(document1, document3) + self.assertNotEqual(document2, document3) + + # We use != explicitly since assertEqual would use ==. + inequality_val = document1 != document4 + self.assertFalse(inequality_val) + self.assertIsNot(document1, document4) + + def test__ne__other_type(self): + document = self._make_one("X", "YY", client=mock.sentinel.client) + other = object() + self.assertNotEqual(document, other) + self.assertIs(document.__ne__(other), NotImplemented) + + def test__document_path_property(self): + project = "hi-its-me-ok-bye" + client = _make_client(project=project) + + collection_id = "then" + document_id = "090909iii" + document = self._make_one(collection_id, document_id, client=client) + doc_path = document._document_path + expected = "projects/{}/databases/{}/documents/{}/{}".format( + project, client._database, collection_id, document_id + ) + self.assertEqual(doc_path, expected) + self.assertIs(document._document_path_internal, doc_path) + + # Make sure value is cached. + document._document_path_internal = mock.sentinel.cached + self.assertIs(document._document_path, mock.sentinel.cached) + + def test__document_path_property_no_client(self): + document = self._make_one("hi", "bye") + self.assertIsNone(document._client) + with self.assertRaises(ValueError): + getattr(document, "_document_path") + + self.assertIsNone(document._document_path_internal) + + def test_id_property(self): + document_id = "867-5309" + document = self._make_one("Co-lek-shun", document_id) + self.assertEqual(document.id, document_id) + + def test_parent_property(self): + from google.cloud.firestore_v1.collection import CollectionReference + + collection_id = "grocery-store" + document_id = "market" + client = _make_client() + document = self._make_one(collection_id, document_id, client=client) + + parent = document.parent + self.assertIsInstance(parent, CollectionReference) + self.assertIs(parent._client, client) + self.assertEqual(parent._path, (collection_id,)) + + def test_collection_factory(self): + from google.cloud.firestore_v1.collection import CollectionReference + + collection_id = "grocery-store" + document_id = "market" + new_collection = "fruits" + client = _make_client() + document = self._make_one(collection_id, document_id, client=client) + + child = document.collection(new_collection) + self.assertIsInstance(child, CollectionReference) + self.assertIs(child._client, client) + self.assertEqual(child._path, (collection_id, document_id, new_collection)) + + +class TestDocumentSnapshot(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.document import DocumentSnapshot + + return DocumentSnapshot + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def _make_reference(self, *args, **kwargs): + from google.cloud.firestore_v1.document import DocumentReference + + return DocumentReference(*args, **kwargs) + + def _make_w_ref(self, ref_path=("a", "b"), data={}, exists=True): + client = mock.sentinel.client + reference = self._make_reference(*ref_path, client=client) + return self._make_one( + reference, + data, + exists, + mock.sentinel.read_time, + mock.sentinel.create_time, + mock.sentinel.update_time, + ) + + def test_constructor(self): + client = mock.sentinel.client + reference = self._make_reference("hi", "bye", client=client) + data = {"zoop": 83} + snapshot = self._make_one( + reference, + data, + True, + mock.sentinel.read_time, + mock.sentinel.create_time, + mock.sentinel.update_time, + ) + self.assertIs(snapshot._reference, reference) + self.assertEqual(snapshot._data, data) + self.assertIsNot(snapshot._data, data) # Make sure copied. + self.assertTrue(snapshot._exists) + self.assertIs(snapshot.read_time, mock.sentinel.read_time) + self.assertIs(snapshot.create_time, mock.sentinel.create_time) + self.assertIs(snapshot.update_time, mock.sentinel.update_time) + + def test___eq___other_type(self): + snapshot = self._make_w_ref() + other = object() + self.assertFalse(snapshot == other) + + def test___eq___different_reference_same_data(self): + snapshot = self._make_w_ref(("a", "b")) + other = self._make_w_ref(("c", "d")) + self.assertFalse(snapshot == other) + + def test___eq___same_reference_different_data(self): + snapshot = self._make_w_ref(("a", "b")) + other = self._make_w_ref(("a", "b"), {"foo": "bar"}) + self.assertFalse(snapshot == other) + + def test___eq___same_reference_same_data(self): + snapshot = self._make_w_ref(("a", "b"), {"foo": "bar"}) + other = self._make_w_ref(("a", "b"), {"foo": "bar"}) + self.assertTrue(snapshot == other) + + def test___hash__(self): + from google.protobuf import timestamp_pb2 + + client = mock.MagicMock() + client.__hash__.return_value = 234566789 + reference = self._make_reference("hi", "bye", client=client) + data = {"zoop": 83} + update_time = timestamp_pb2.Timestamp(seconds=123456, nanos=123456789) + snapshot = self._make_one( + reference, data, True, None, mock.sentinel.create_time, update_time + ) + self.assertEqual( + hash(snapshot), hash(reference) + hash(123456) + hash(123456789) + ) + + def test__client_property(self): + reference = self._make_reference( + "ok", "fine", "now", "fore", client=mock.sentinel.client + ) + snapshot = self._make_one(reference, {}, False, None, None, None) + self.assertIs(snapshot._client, mock.sentinel.client) + + def test_exists_property(self): + reference = mock.sentinel.reference + + snapshot1 = self._make_one(reference, {}, False, None, None, None) + self.assertFalse(snapshot1.exists) + snapshot2 = self._make_one(reference, {}, True, None, None, None) + self.assertTrue(snapshot2.exists) + + def test_id_property(self): + document_id = "around" + reference = self._make_reference( + "look", document_id, client=mock.sentinel.client + ) + snapshot = self._make_one(reference, {}, True, None, None, None) + self.assertEqual(snapshot.id, document_id) + self.assertEqual(reference.id, document_id) + + def test_reference_property(self): + snapshot = self._make_one(mock.sentinel.reference, {}, True, None, None, None) + self.assertIs(snapshot.reference, mock.sentinel.reference) + + def test_get(self): + data = {"one": {"bold": "move"}} + snapshot = self._make_one(None, data, True, None, None, None) + + first_read = snapshot.get("one") + second_read = snapshot.get("one") + self.assertEqual(first_read, data.get("one")) + self.assertIsNot(first_read, data.get("one")) + self.assertEqual(first_read, second_read) + self.assertIsNot(first_read, second_read) + + with self.assertRaises(KeyError): + snapshot.get("two") + + def test_nonexistent_snapshot(self): + snapshot = self._make_one(None, None, False, None, None, None) + self.assertIsNone(snapshot.get("one")) + + def test_to_dict(self): + data = {"a": 10, "b": ["definitely", "mutable"], "c": {"45": 50}} + snapshot = self._make_one(None, data, True, None, None, None) + as_dict = snapshot.to_dict() + self.assertEqual(as_dict, data) + self.assertIsNot(as_dict, data) + # Check that the data remains unchanged. + as_dict["b"].append("hi") + self.assertEqual(data, snapshot.to_dict()) + self.assertNotEqual(data, as_dict) + + def test_non_existent(self): + snapshot = self._make_one(None, None, False, None, None, None) + as_dict = snapshot.to_dict() + self.assertIsNone(as_dict) + + +class Test__get_document_path(unittest.TestCase): + @staticmethod + def _call_fut(client, path): + from google.cloud.firestore_v1.base_document import _get_document_path + + return _get_document_path(client, path) + + def test_it(self): + project = "prah-jekt" + client = _make_client(project=project) + path = ("Some", "Document", "Child", "Shockument") + document_path = self._call_fut(client, path) + + expected = "projects/{}/databases/{}/documents/{}".format( + project, client._database, "/".join(path) + ) + self.assertEqual(document_path, expected) + + +class Test__consume_single_get(unittest.TestCase): + @staticmethod + def _call_fut(response_iterator): + from google.cloud.firestore_v1.base_document import _consume_single_get + + return _consume_single_get(response_iterator) + + def test_success(self): + response_iterator = iter([mock.sentinel.result]) + result = self._call_fut(response_iterator) + self.assertIs(result, mock.sentinel.result) + + def test_failure_not_enough(self): + response_iterator = iter([]) + with self.assertRaises(ValueError): + self._call_fut(response_iterator) + + def test_failure_too_many(self): + response_iterator = iter([None, None]) + with self.assertRaises(ValueError): + self._call_fut(response_iterator) + + +class Test__first_write_result(unittest.TestCase): + @staticmethod + def _call_fut(write_results): + from google.cloud.firestore_v1.base_document import _first_write_result + + return _first_write_result(write_results) + + def test_success(self): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.proto import write_pb2 + + single_result = write_pb2.WriteResult( + update_time=timestamp_pb2.Timestamp(seconds=1368767504, nanos=458000123) + ) + write_results = [single_result] + result = self._call_fut(write_results) + self.assertIs(result, single_result) + + def test_failure_not_enough(self): + write_results = [] + with self.assertRaises(ValueError): + self._call_fut(write_results) + + def test_more_than_one(self): + from google.cloud.firestore_v1.proto import write_pb2 + + result1 = write_pb2.WriteResult() + result2 = write_pb2.WriteResult() + write_results = [result1, result2] + result = self._call_fut(write_results) + self.assertIs(result, result1) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="project-project"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index 89a19df674..cc80aa9646 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -60,138 +60,14 @@ def test_constructor_invalid_kwarg(self): with self.assertRaises(TypeError): self._make_one("Coh-lek-shun", "Dahk-yu-mehnt", burger=18.75) - def test___copy__(self): - client = _make_client("rain") - document = self._make_one("a", "b", client=client) - # Access the document path so it is copied. - doc_path = document._document_path - self.assertEqual(doc_path, document._document_path_internal) - - new_document = document.__copy__() - self.assertIsNot(new_document, document) - self.assertIs(new_document._client, document._client) - self.assertEqual(new_document._path, document._path) - self.assertEqual( - new_document._document_path_internal, document._document_path_internal - ) - - def test___deepcopy__calls_copy(self): - client = mock.sentinel.client - document = self._make_one("a", "b", client=client) - document.__copy__ = mock.Mock(return_value=mock.sentinel.new_doc, spec=[]) - - unused_memo = {} - new_document = document.__deepcopy__(unused_memo) - self.assertIs(new_document, mock.sentinel.new_doc) - document.__copy__.assert_called_once_with() - - def test__eq__same_type(self): - document1 = self._make_one("X", "YY", client=mock.sentinel.client) - document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) - document3 = self._make_one("X", "YY", client=mock.sentinel.client2) - document4 = self._make_one("X", "YY", client=mock.sentinel.client) - - pairs = ((document1, document2), (document1, document3), (document2, document3)) - for candidate1, candidate2 in pairs: - # We use == explicitly since assertNotEqual would use !=. - equality_val = candidate1 == candidate2 - self.assertFalse(equality_val) - - # Check the only equal one. - self.assertEqual(document1, document4) - self.assertIsNot(document1, document4) - - def test__eq__other_type(self): - document = self._make_one("X", "YY", client=mock.sentinel.client) - other = object() - equality_val = document == other - self.assertFalse(equality_val) - self.assertIs(document.__eq__(other), NotImplemented) - - def test___hash__(self): - client = mock.MagicMock() - client.__hash__.return_value = 234566789 - document = self._make_one("X", "YY", client=client) - self.assertEqual(hash(document), hash(("X", "YY")) + hash(client)) - - def test__ne__same_type(self): - document1 = self._make_one("X", "YY", client=mock.sentinel.client) - document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) - document3 = self._make_one("X", "YY", client=mock.sentinel.client2) - document4 = self._make_one("X", "YY", client=mock.sentinel.client) - - self.assertNotEqual(document1, document2) - self.assertNotEqual(document1, document3) - self.assertNotEqual(document2, document3) - - # We use != explicitly since assertEqual would use ==. - inequality_val = document1 != document4 - self.assertFalse(inequality_val) - self.assertIsNot(document1, document4) - - def test__ne__other_type(self): - document = self._make_one("X", "YY", client=mock.sentinel.client) - other = object() - self.assertNotEqual(document, other) - self.assertIs(document.__ne__(other), NotImplemented) - - def test__document_path_property(self): - project = "hi-its-me-ok-bye" - client = _make_client(project=project) - - collection_id = "then" - document_id = "090909iii" - document = self._make_one(collection_id, document_id, client=client) - doc_path = document._document_path - expected = "projects/{}/databases/{}/documents/{}/{}".format( - project, client._database, collection_id, document_id - ) - self.assertEqual(doc_path, expected) - self.assertIs(document._document_path_internal, doc_path) - - # Make sure value is cached. - document._document_path_internal = mock.sentinel.cached - self.assertIs(document._document_path, mock.sentinel.cached) - - def test__document_path_property_no_client(self): - document = self._make_one("hi", "bye") - self.assertIsNone(document._client) - with self.assertRaises(ValueError): - getattr(document, "_document_path") - - self.assertIsNone(document._document_path_internal) - - def test_id_property(self): - document_id = "867-5309" - document = self._make_one("Co-lek-shun", document_id) - self.assertEqual(document.id, document_id) - - def test_parent_property(self): - from google.cloud.firestore_v1.collection import CollectionReference - - collection_id = "grocery-store" - document_id = "market" - client = _make_client() - document = self._make_one(collection_id, document_id, client=client) - - parent = document.parent - self.assertIsInstance(parent, CollectionReference) - self.assertIs(parent._client, client) - self.assertEqual(parent._path, (collection_id,)) - - def test_collection_factory(self): - from google.cloud.firestore_v1.collection import CollectionReference - - collection_id = "grocery-store" - document_id = "market" - new_collection = "fruits" - client = _make_client() - document = self._make_one(collection_id, document_id, client=client) + @staticmethod + def _make_commit_repsonse(write_results=None): + from google.cloud.firestore_v1.proto import firestore_pb2 - child = document.collection(new_collection) - self.assertIsInstance(child, CollectionReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_id, document_id, new_collection)) + response = mock.create_autospec(firestore_pb2.CommitResponse) + response.write_results = write_results or [mock.sentinel.write_result] + response.commit_time = mock.sentinel.commit_time + return response @staticmethod def _write_pb_for_create(document_path, document_data): @@ -207,15 +83,6 @@ def _write_pb_for_create(document_path, document_data): current_document=common_pb2.Precondition(exists=False), ) - @staticmethod - def _make_commit_repsonse(write_results=None): - from google.cloud.firestore_v1.proto import firestore_pb2 - - response = mock.create_autospec(firestore_pb2.CommitResponse) - response.write_results = write_results or [mock.sentinel.write_result] - response.commit_time = mock.sentinel.commit_time - return response - def test_create(self): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) @@ -591,227 +458,6 @@ def test_on_snapshot(self, watch): watch.for_document.assert_called_once() -class TestDocumentSnapshot(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.document import DocumentSnapshot - - return DocumentSnapshot - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def _make_reference(self, *args, **kwargs): - from google.cloud.firestore_v1.document import DocumentReference - - return DocumentReference(*args, **kwargs) - - def _make_w_ref(self, ref_path=("a", "b"), data={}, exists=True): - client = mock.sentinel.client - reference = self._make_reference(*ref_path, client=client) - return self._make_one( - reference, - data, - exists, - mock.sentinel.read_time, - mock.sentinel.create_time, - mock.sentinel.update_time, - ) - - def test_constructor(self): - client = mock.sentinel.client - reference = self._make_reference("hi", "bye", client=client) - data = {"zoop": 83} - snapshot = self._make_one( - reference, - data, - True, - mock.sentinel.read_time, - mock.sentinel.create_time, - mock.sentinel.update_time, - ) - self.assertIs(snapshot._reference, reference) - self.assertEqual(snapshot._data, data) - self.assertIsNot(snapshot._data, data) # Make sure copied. - self.assertTrue(snapshot._exists) - self.assertIs(snapshot.read_time, mock.sentinel.read_time) - self.assertIs(snapshot.create_time, mock.sentinel.create_time) - self.assertIs(snapshot.update_time, mock.sentinel.update_time) - - def test___eq___other_type(self): - snapshot = self._make_w_ref() - other = object() - self.assertFalse(snapshot == other) - - def test___eq___different_reference_same_data(self): - snapshot = self._make_w_ref(("a", "b")) - other = self._make_w_ref(("c", "d")) - self.assertFalse(snapshot == other) - - def test___eq___same_reference_different_data(self): - snapshot = self._make_w_ref(("a", "b")) - other = self._make_w_ref(("a", "b"), {"foo": "bar"}) - self.assertFalse(snapshot == other) - - def test___eq___same_reference_same_data(self): - snapshot = self._make_w_ref(("a", "b"), {"foo": "bar"}) - other = self._make_w_ref(("a", "b"), {"foo": "bar"}) - self.assertTrue(snapshot == other) - - def test___hash__(self): - from google.protobuf import timestamp_pb2 - - client = mock.MagicMock() - client.__hash__.return_value = 234566789 - reference = self._make_reference("hi", "bye", client=client) - data = {"zoop": 83} - update_time = timestamp_pb2.Timestamp(seconds=123456, nanos=123456789) - snapshot = self._make_one( - reference, data, True, None, mock.sentinel.create_time, update_time - ) - self.assertEqual( - hash(snapshot), hash(reference) + hash(123456) + hash(123456789) - ) - - def test__client_property(self): - reference = self._make_reference( - "ok", "fine", "now", "fore", client=mock.sentinel.client - ) - snapshot = self._make_one(reference, {}, False, None, None, None) - self.assertIs(snapshot._client, mock.sentinel.client) - - def test_exists_property(self): - reference = mock.sentinel.reference - - snapshot1 = self._make_one(reference, {}, False, None, None, None) - self.assertFalse(snapshot1.exists) - snapshot2 = self._make_one(reference, {}, True, None, None, None) - self.assertTrue(snapshot2.exists) - - def test_id_property(self): - document_id = "around" - reference = self._make_reference( - "look", document_id, client=mock.sentinel.client - ) - snapshot = self._make_one(reference, {}, True, None, None, None) - self.assertEqual(snapshot.id, document_id) - self.assertEqual(reference.id, document_id) - - def test_reference_property(self): - snapshot = self._make_one(mock.sentinel.reference, {}, True, None, None, None) - self.assertIs(snapshot.reference, mock.sentinel.reference) - - def test_get(self): - data = {"one": {"bold": "move"}} - snapshot = self._make_one(None, data, True, None, None, None) - - first_read = snapshot.get("one") - second_read = snapshot.get("one") - self.assertEqual(first_read, data.get("one")) - self.assertIsNot(first_read, data.get("one")) - self.assertEqual(first_read, second_read) - self.assertIsNot(first_read, second_read) - - with self.assertRaises(KeyError): - snapshot.get("two") - - def test_nonexistent_snapshot(self): - snapshot = self._make_one(None, None, False, None, None, None) - self.assertIsNone(snapshot.get("one")) - - def test_to_dict(self): - data = {"a": 10, "b": ["definitely", "mutable"], "c": {"45": 50}} - snapshot = self._make_one(None, data, True, None, None, None) - as_dict = snapshot.to_dict() - self.assertEqual(as_dict, data) - self.assertIsNot(as_dict, data) - # Check that the data remains unchanged. - as_dict["b"].append("hi") - self.assertEqual(data, snapshot.to_dict()) - self.assertNotEqual(data, as_dict) - - def test_non_existent(self): - snapshot = self._make_one(None, None, False, None, None, None) - as_dict = snapshot.to_dict() - self.assertIsNone(as_dict) - - -class Test__get_document_path(unittest.TestCase): - @staticmethod - def _call_fut(client, path): - from google.cloud.firestore_v1.document import _get_document_path - - return _get_document_path(client, path) - - def test_it(self): - project = "prah-jekt" - client = _make_client(project=project) - path = ("Some", "Document", "Child", "Shockument") - document_path = self._call_fut(client, path) - - expected = "projects/{}/databases/{}/documents/{}".format( - project, client._database, "/".join(path) - ) - self.assertEqual(document_path, expected) - - -class Test__consume_single_get(unittest.TestCase): - @staticmethod - def _call_fut(response_iterator): - from google.cloud.firestore_v1.document import _consume_single_get - - return _consume_single_get(response_iterator) - - def test_success(self): - response_iterator = iter([mock.sentinel.result]) - result = self._call_fut(response_iterator) - self.assertIs(result, mock.sentinel.result) - - def test_failure_not_enough(self): - response_iterator = iter([]) - with self.assertRaises(ValueError): - self._call_fut(response_iterator) - - def test_failure_too_many(self): - response_iterator = iter([None, None]) - with self.assertRaises(ValueError): - self._call_fut(response_iterator) - - -class Test__first_write_result(unittest.TestCase): - @staticmethod - def _call_fut(write_results): - from google.cloud.firestore_v1.document import _first_write_result - - return _first_write_result(write_results) - - def test_success(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.proto import write_pb2 - - single_result = write_pb2.WriteResult( - update_time=timestamp_pb2.Timestamp(seconds=1368767504, nanos=458000123) - ) - write_results = [single_result] - result = self._call_fut(write_results) - self.assertIs(result, single_result) - - def test_failure_not_enough(self): - write_results = [] - with self.assertRaises(ValueError): - self._call_fut(write_results) - - def test_more_than_one(self): - from google.cloud.firestore_v1.proto import write_pb2 - - result1 = write_pb2.WriteResult() - result2 = write_pb2.WriteResult() - write_results = [result1, result2] - result = self._call_fut(write_results) - self.assertIs(result, result1) - - def _make_credentials(): import google.auth.credentials From ddcd71c331acc67af5287dd7d69a92ddd76710b6 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 24 Jun 2020 21:29:20 -0500 Subject: [PATCH 36/47] fix: base document test class name --- tests/unit/v1/test_base_document.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/v1/test_base_document.py b/tests/unit/v1/test_base_document.py index 7e61f4cbb2..f520254edd 100644 --- a/tests/unit/v1/test_base_document.py +++ b/tests/unit/v1/test_base_document.py @@ -17,7 +17,7 @@ import mock -class TestDocumentReference(unittest.TestCase): +class TestBaseDocumentReference(unittest.TestCase): @staticmethod def _get_target_class(): from google.cloud.firestore_v1.document import DocumentReference From c097d0b282b0d294e03bced622e833cf1e12bf4a Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 24 Jun 2020 21:29:35 -0500 Subject: [PATCH 37/47] feat: create Query/AsyncQuery superclass --- google/cloud/firestore_v1/async_query.py | 6 +- google/cloud/firestore_v1/base_query.py | 961 ++++++++++++ google/cloud/firestore_v1/query.py | 872 +---------- tests/unit/v1/async/test_async_collection.py | 4 +- tests/unit/v1/async/test_async_query.py | 1394 +---------------- tests/unit/v1/test_base_query.py | 1441 ++++++++++++++++++ tests/unit/v1/test_collection.py | 4 +- tests/unit/v1/test_query.py | 1390 +---------------- 8 files changed, 2431 insertions(+), 3641 deletions(-) create mode 100644 google/cloud/firestore_v1/base_query.py create mode 100644 tests/unit/v1/test_base_query.py diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 83024284ef..dbfa1866f0 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,8 +20,8 @@ """ import warnings -from google.cloud.firestore_v1.query import ( - Query, +from google.cloud.firestore_v1.base_query import ( + BaseQuery, _query_response_to_snapshot, _collection_group_query_response_to_snapshot, ) @@ -31,7 +31,7 @@ from google.cloud.firestore_v1.watch import Watch -class AsyncQuery(Query): +class AsyncQuery(BaseQuery): """Represents a query to the Firestore API. Instances of this class are considered immutable: all methods that diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py new file mode 100644 index 0000000000..e861ddfb62 --- /dev/null +++ b/google/cloud/firestore_v1/base_query.py @@ -0,0 +1,961 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing queries for the Google Cloud Firestore API. + +A :class:`~google.cloud.firestore_v1.query.Query` can be created directly from +a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be +a more common way to create a query than direct usage of the constructor. +""" +import copy +import math + +from google.protobuf import wrappers_pb2 +import six + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import document +from google.cloud.firestore_v1 import field_path as field_path_module +from google.cloud.firestore_v1 import transforms +from google.cloud.firestore_v1.gapic import enums +from google.cloud.firestore_v1.proto import query_pb2 +from google.cloud.firestore_v1.order import Order + +_EQ_OP = "==" +_operator_enum = enums.StructuredQuery.FieldFilter.Operator +_COMPARISON_OPERATORS = { + "<": _operator_enum.LESS_THAN, + "<=": _operator_enum.LESS_THAN_OR_EQUAL, + _EQ_OP: _operator_enum.EQUAL, + ">=": _operator_enum.GREATER_THAN_OR_EQUAL, + ">": _operator_enum.GREATER_THAN, + "array_contains": _operator_enum.ARRAY_CONTAINS, + "in": _operator_enum.IN, + "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, +} +_BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." +_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' +_INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values." +_BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}." +_INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values." +_MISSING_ORDER_BY = ( + 'The "order by" field path {!r} is not present in the cursor data {!r}. ' + "All fields sent to ``order_by()`` must be present in the fields " + "if passed to one of ``start_at()`` / ``start_after()`` / " + "``end_before()`` / ``end_at()`` to define a cursor." +) +_NO_ORDERS_FOR_CURSOR = ( + "Attempting to create a cursor with no fields to order on. " + "When defining a cursor with one of ``start_at()`` / ``start_after()`` / " + "``end_before()`` / ``end_at()``, all fields in the cursor must " + "come from fields set in ``order_by()``." +) +_MISMATCH_CURSOR_W_ORDER_BY = "The cursor {!r} does not match the order fields {!r}." + + +class BaseQuery(object): + """Represents a query to the Firestore API. + + Instances of this class are considered immutable: all methods that + would modify an instance instead return a new instance. + + Args: + parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + The collection that this query applies to. + projection (Optional[:class:`google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.Projection`]): + A projection of document fields to limit the query results to. + field_filters (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.FieldFilter`, ...]]): + The filters to be applied in the query. + orders (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.Order`, ...]]): + The "order by" entries to use in the query. + limit (Optional[int]): + The maximum number of documents the query is allowed to return. + offset (Optional[int]): + The number of results to skip. + start_at (Optional[Tuple[dict, bool]]): + Two-tuple of : + + * a mapping of fields. Any field that is present in this mapping + must also be present in ``orders`` + * an ``after`` flag + + The fields and the flag combine to form a cursor used as + a starting point in a query result set. If the ``after`` + flag is :data:`True`, the results will start just after any + documents which have fields matching the cursor, otherwise + any matching documents will be included in the result set. + When the query is formed, the document values + will be used in the order given by ``orders``. + end_at (Optional[Tuple[dict, bool]]): + Two-tuple of: + + * a mapping of fields. Any field that is present in this mapping + must also be present in ``orders`` + * a ``before`` flag + + The fields and the flag combine to form a cursor used as + an ending point in a query result set. If the ``before`` + flag is :data:`True`, the results will end just before any + documents which have fields matching the cursor, otherwise + any matching documents will be included in the result set. + When the query is formed, the document values + will be used in the order given by ``orders``. + all_descendants (Optional[bool]): + When false, selects only collections that are immediate children + of the `parent` specified in the containing `RunQueryRequest`. + When true, selects all descendant collections. + """ + + ASCENDING = "ASCENDING" + """str: Sort query results in ascending order on a field.""" + DESCENDING = "DESCENDING" + """str: Sort query results in descending order on a field.""" + + def __init__( + self, + parent, + projection=None, + field_filters=(), + orders=(), + limit=None, + offset=None, + start_at=None, + end_at=None, + all_descendants=False, + ): + self._parent = parent + self._projection = projection + self._field_filters = field_filters + self._orders = orders + self._limit = limit + self._offset = offset + self._start_at = start_at + self._end_at = end_at + self._all_descendants = all_descendants + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return ( + self._parent == other._parent + and self._projection == other._projection + and self._field_filters == other._field_filters + and self._orders == other._orders + and self._limit == other._limit + and self._offset == other._offset + and self._start_at == other._start_at + and self._end_at == other._end_at + and self._all_descendants == other._all_descendants + ) + + @property + def _client(self): + """The client of the parent collection. + + Returns: + :class:`~google.cloud.firestore_v1.client.Client`: + The client that owns this query. + """ + return self._parent._client + + def select(self, field_paths): + """Project documents matching query to a limited set of fields. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + If the current query already has a projection set (i.e. has already + called :meth:`~google.cloud.firestore_v1.query.Query.select`), this + will overwrite it. + + Args: + field_paths (Iterable[str, ...]): An iterable of field paths + (``.``-delimited list of field names) to use as a projection + of document fields in the query results. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A "projected" query. Acts as a copy of the current query, + modified with the newly added projection. + Raises: + ValueError: If any ``field_path`` is invalid. + """ + field_paths = list(field_paths) + for field_path in field_paths: + field_path_module.split_field_path(field_path) # raises + + new_projection = query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths + ] + ) + return self.__class__( + self._parent, + projection=new_projection, + field_filters=self._field_filters, + orders=self._orders, + limit=self._limit, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def where(self, field_path, op_string, value): + """Filter the query on a field. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + Returns a new :class:`~google.cloud.firestore_v1.query.Query` that + filters on a specific field path, according to an operation (e.g. + ``==`` or "equals") and a particular value to be paired with that + operation. + + Args: + field_path (str): A field path (``.``-delimited list of + field names) for the field to filter on. + op_string (str): A comparison operation in the form of a string. + Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``, + ``in``, ``array_contains`` and ``array_contains_any``. + value (Any): The value to compare the field against in the filter. + If ``value`` is :data:`None` or a NaN, then ``==`` is the only + allowed operation. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A filtered query. Acts as a copy of the current query, + modified with the newly added filter. + + Raises: + ValueError: If ``field_path`` is invalid. + ValueError: If ``value`` is a NaN or :data:`None` and + ``op_string`` is not ``==``. + """ + field_path_module.split_field_path(field_path) # raises + + if value is None: + if op_string != _EQ_OP: + raise ValueError(_BAD_OP_NAN_NULL) + filter_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, + ) + elif _isnan(value): + if op_string != _EQ_OP: + raise ValueError(_BAD_OP_NAN_NULL) + filter_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=enums.StructuredQuery.UnaryFilter.Operator.IS_NAN, + ) + elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): + raise ValueError(_INVALID_WHERE_TRANSFORM) + else: + filter_pb = query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=_enum_from_op_string(op_string), + value=_helpers.encode_value(value), + ) + + new_filters = self._field_filters + (filter_pb,) + return self.__class__( + self._parent, + projection=self._projection, + field_filters=new_filters, + orders=self._orders, + limit=self._limit, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + @staticmethod + def _make_order(field_path, direction): + """Helper for :meth:`order_by`.""" + return query_pb2.StructuredQuery.Order( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + direction=_enum_from_direction(direction), + ) + + def order_by(self, field_path, direction=ASCENDING): + """Modify the query to add an order clause on a specific field. + + See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for + more information on **field paths**. + + Successive :meth:`~google.cloud.firestore_v1.query.Query.order_by` + calls will further refine the ordering of results returned by the query + (i.e. the new "order by" fields will be added to existing ones). + + Args: + field_path (str): A field path (``.``-delimited list of + field names) on which to order the query results. + direction (Optional[str]): The direction to order by. Must be one + of :attr:`ASCENDING` or :attr:`DESCENDING`, defaults to + :attr:`ASCENDING`. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An ordered query. Acts as a copy of the current query, modified + with the newly added "order by" constraint. + + Raises: + ValueError: If ``field_path`` is invalid. + ValueError: If ``direction`` is not one of :attr:`ASCENDING` or + :attr:`DESCENDING`. + """ + field_path_module.split_field_path(field_path) # raises + + order_pb = self._make_order(field_path, direction) + + new_orders = self._orders + (order_pb,) + return self.__class__( + self._parent, + projection=self._projection, + field_filters=self._field_filters, + orders=new_orders, + limit=self._limit, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def limit(self, count): + """Limit a query to return a fixed number of results. + + If the current query already has a limit set, this will overwrite it. + + Args: + count (int): Maximum number of documents to return that match + the query. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A limited query. Acts as a copy of the current query, modified + with the newly added "limit" filter. + """ + return self.__class__( + self._parent, + projection=self._projection, + field_filters=self._field_filters, + orders=self._orders, + limit=count, + offset=self._offset, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def offset(self, num_to_skip): + """Skip to an offset in a query. + + If the current query already has specified an offset, this will + overwrite it. + + Args: + num_to_skip (int): The number of results to skip at the beginning + of query results. (Must be non-negative.) + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + An offset query. Acts as a copy of the current query, modified + with the newly added "offset" field. + """ + return self.__class__( + self._parent, + projection=self._projection, + field_filters=self._field_filters, + orders=self._orders, + limit=self._limit, + offset=num_to_skip, + start_at=self._start_at, + end_at=self._end_at, + all_descendants=self._all_descendants, + ) + + def _check_snapshot(self, document_fields): + """Validate local snapshots for non-collection-group queries. + + Raises: + ValueError: for non-collection-group queries, if the snapshot + is from a different collection. + """ + if self._all_descendants: + return + + if document_fields.reference._path[:-1] != self._parent._path: + raise ValueError("Cannot use snapshot from another collection as a cursor.") + + def _cursor_helper(self, document_fields, before, start): + """Set values to be used for a ``start_at`` or ``end_at`` cursor. + + The values will later be used in a query protobuf. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + before (bool): Flag indicating if the document in + ``document_fields`` should (:data:`False`) or + shouldn't (:data:`True`) be included in the result set. + start (Optional[bool]): determines if the cursor is a ``start_at`` + cursor (:data:`True`) or an ``end_at`` cursor (:data:`False`). + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "start at" cursor. + """ + if isinstance(document_fields, tuple): + document_fields = list(document_fields) + elif isinstance(document_fields, document.DocumentSnapshot): + self._check_snapshot(document_fields) + else: + # NOTE: We copy so that the caller can't modify after calling. + document_fields = copy.deepcopy(document_fields) + + cursor_pair = document_fields, before + query_kwargs = { + "projection": self._projection, + "field_filters": self._field_filters, + "orders": self._orders, + "limit": self._limit, + "offset": self._offset, + "all_descendants": self._all_descendants, + } + if start: + query_kwargs["start_at"] = cursor_pair + query_kwargs["end_at"] = self._end_at + else: + query_kwargs["start_at"] = self._start_at + query_kwargs["end_at"] = cursor_pair + + return self.__class__(self._parent, **query_kwargs) + + def start_at(self, document_fields): + """Start query results at a particular document value. + + The result set will **include** the document specified by + ``document_fields``. + + If the current query already has specified a start cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.start_after` -- this + will overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as + a copy of the current query, modified with the newly added + "start at" cursor. + """ + return self._cursor_helper(document_fields, before=True, start=True) + + def start_after(self, document_fields): + """Start query results after a particular document value. + + The result set will **exclude** the document specified by + ``document_fields``. + + If the current query already has specified a start cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.start_at` -- this will + overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "start after" cursor. + """ + return self._cursor_helper(document_fields, before=False, start=True) + + def end_before(self, document_fields): + """End query results before a particular document value. + + The result set will **exclude** the document specified by + ``document_fields``. + + If the current query already has specified an end cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.end_at` -- this will + overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "end before" cursor. + """ + return self._cursor_helper(document_fields, before=True, start=False) + + def end_at(self, document_fields): + """End query results at a particular document value. + + The result set will **include** the document specified by + ``document_fields``. + + If the current query already has specified an end cursor -- either + via this method or + :meth:`~google.cloud.firestore_v1.query.Query.end_before` -- this will + overwrite it. + + When the query is sent to the server, the ``document_fields`` will + be used in the order given by fields set by + :meth:`~google.cloud.firestore_v1.query.Query.order_by`. + + Args: + document_fields + (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): + a document snapshot or a dictionary/list/tuple of fields + representing a query results cursor. A cursor is a collection + of values that represent a position in a query result set. + + Returns: + :class:`~google.cloud.firestore_v1.query.Query`: + A query with cursor. Acts as a copy of the current query, modified + with the newly added "end at" cursor. + """ + return self._cursor_helper(document_fields, before=False, start=False) + + def _filters_pb(self): + """Convert all the filters into a single generic Filter protobuf. + + This may be a lone field filter or unary filter, may be a composite + filter or may be :data:`None`. + + Returns: + :class:`google.cloud.firestore_v1.types.StructuredQuery.Filter`: + A "generic" filter representing the current query's filters. + """ + num_filters = len(self._field_filters) + if num_filters == 0: + return None + elif num_filters == 1: + return _filter_pb(self._field_filters[0]) + else: + composite_filter = query_pb2.StructuredQuery.CompositeFilter( + op=enums.StructuredQuery.CompositeFilter.Operator.AND, + filters=[_filter_pb(filter_) for filter_ in self._field_filters], + ) + return query_pb2.StructuredQuery.Filter(composite_filter=composite_filter) + + @staticmethod + def _normalize_projection(projection): + """Helper: convert field paths to message.""" + if projection is not None: + + fields = list(projection.fields) + + if not fields: + field_ref = query_pb2.StructuredQuery.FieldReference( + field_path="__name__" + ) + return query_pb2.StructuredQuery.Projection(fields=[field_ref]) + + return projection + + def _normalize_orders(self): + """Helper: adjust orders based on cursors, where clauses.""" + orders = list(self._orders) + _has_snapshot_cursor = False + + if self._start_at: + if isinstance(self._start_at[0], document.DocumentSnapshot): + _has_snapshot_cursor = True + + if self._end_at: + if isinstance(self._end_at[0], document.DocumentSnapshot): + _has_snapshot_cursor = True + + if _has_snapshot_cursor: + should_order = [ + _enum_from_op_string(key) + for key in _COMPARISON_OPERATORS + if key not in (_EQ_OP, "array_contains") + ] + order_keys = [order.field.field_path for order in orders] + for filter_ in self._field_filters: + field = filter_.field.field_path + if filter_.op in should_order and field not in order_keys: + orders.append(self._make_order(field, "ASCENDING")) + if not orders: + orders.append(self._make_order("__name__", "ASCENDING")) + else: + order_keys = [order.field.field_path for order in orders] + if "__name__" not in order_keys: + direction = orders[-1].direction # enum? + orders.append(self._make_order("__name__", direction)) + + return orders + + def _normalize_cursor(self, cursor, orders): + """Helper: convert cursor to a list of values based on orders.""" + if cursor is None: + return + + if not orders: + raise ValueError(_NO_ORDERS_FOR_CURSOR) + + document_fields, before = cursor + + order_keys = [order.field.field_path for order in orders] + + if isinstance(document_fields, document.DocumentSnapshot): + snapshot = document_fields + document_fields = snapshot.to_dict() + document_fields["__name__"] = snapshot.reference + + if isinstance(document_fields, dict): + # Transform to list using orders + values = [] + data = document_fields + for order_key in order_keys: + try: + if order_key in data: + values.append(data[order_key]) + else: + values.append( + field_path_module.get_nested_value(order_key, data) + ) + except KeyError: + msg = _MISSING_ORDER_BY.format(order_key, data) + raise ValueError(msg) + document_fields = values + + if len(document_fields) != len(orders): + msg = _MISMATCH_CURSOR_W_ORDER_BY.format(document_fields, order_keys) + raise ValueError(msg) + + _transform_bases = (transforms.Sentinel, transforms._ValueList) + + for index, key_field in enumerate(zip(order_keys, document_fields)): + key, field = key_field + + if isinstance(field, _transform_bases): + msg = _INVALID_CURSOR_TRANSFORM + raise ValueError(msg) + + if key == "__name__" and isinstance(field, six.string_types): + document_fields[index] = self._parent.document(field) + + return document_fields, before + + def _to_protobuf(self): + """Convert the current query into the equivalent protobuf. + + Returns: + :class:`google.cloud.firestore_v1.types.StructuredQuery`: + The query protobuf. + """ + projection = self._normalize_projection(self._projection) + orders = self._normalize_orders() + start_at = self._normalize_cursor(self._start_at, orders) + end_at = self._normalize_cursor(self._end_at, orders) + + query_kwargs = { + "select": projection, + "from": [ + query_pb2.StructuredQuery.CollectionSelector( + collection_id=self._parent.id, all_descendants=self._all_descendants + ) + ], + "where": self._filters_pb(), + "order_by": orders, + "start_at": _cursor_pb(start_at), + "end_at": _cursor_pb(end_at), + } + if self._offset is not None: + query_kwargs["offset"] = self._offset + if self._limit is not None: + query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) + + return query_pb2.StructuredQuery(**query_kwargs) + + def get(self, transaction=None): + raise NotImplementedError + + def stream(self, transaction=None): + raise NotImplementedError + + def on_snapshot(self, callback): + raise NotImplementedError + + def _comparator(self, doc1, doc2): + _orders = self._orders + + # Add implicit sorting by name, using the last specified direction. + if len(_orders) == 0: + lastDirection = BaseQuery.ASCENDING + else: + if _orders[-1].direction == 1: + lastDirection = BaseQuery.ASCENDING + else: + lastDirection = BaseQuery.DESCENDING + + orderBys = list(_orders) + + order_pb = query_pb2.StructuredQuery.Order( + field=query_pb2.StructuredQuery.FieldReference(field_path="id"), + direction=_enum_from_direction(lastDirection), + ) + orderBys.append(order_pb) + + for orderBy in orderBys: + if orderBy.field.field_path == "id": + # If ordering by docuent id, compare resource paths. + comp = Order()._compare_to(doc1.reference._path, doc2.reference._path) + else: + if ( + orderBy.field.field_path not in doc1._data + or orderBy.field.field_path not in doc2._data + ): + raise ValueError( + "Can only compare fields that exist in the " + "DocumentSnapshot. Please include the fields you are " + "ordering on in your select() call." + ) + v1 = doc1._data[orderBy.field.field_path] + v2 = doc2._data[orderBy.field.field_path] + encoded_v1 = _helpers.encode_value(v1) + encoded_v2 = _helpers.encode_value(v2) + comp = Order().compare(encoded_v1, encoded_v2) + + if comp != 0: + # 1 == Ascending, -1 == Descending + return orderBy.direction * comp + + return 0 + + +def _enum_from_op_string(op_string): + """Convert a string representation of a binary operator to an enum. + + These enums come from the protobuf message definition + ``StructuredQuery.FieldFilter.Operator``. + + Args: + op_string (str): A comparison operation in the form of a string. + Acceptable values are ``<``, ``<=``, ``==``, ``>=`` + and ``>``. + + Returns: + int: The enum corresponding to ``op_string``. + + Raises: + ValueError: If ``op_string`` is not a valid operator. + """ + try: + return _COMPARISON_OPERATORS[op_string] + except KeyError: + choices = ", ".join(sorted(_COMPARISON_OPERATORS.keys())) + msg = _BAD_OP_STRING.format(op_string, choices) + raise ValueError(msg) + + +def _isnan(value): + """Check if a value is NaN. + + This differs from ``math.isnan`` in that **any** input type is + allowed. + + Args: + value (Any): A value to check for NaN-ness. + + Returns: + bool: Indicates if the value is the NaN float. + """ + if isinstance(value, float): + return math.isnan(value) + else: + return False + + +def _enum_from_direction(direction): + """Convert a string representation of a direction to an enum. + + Args: + direction (str): A direction to order by. Must be one of + :attr:`~google.cloud.firestore.Query.ASCENDING` or + :attr:`~google.cloud.firestore.Query.DESCENDING`. + + Returns: + int: The enum corresponding to ``direction``. + + Raises: + ValueError: If ``direction`` is not a valid direction. + """ + if isinstance(direction, int): + return direction + + if direction == BaseQuery.ASCENDING: + return enums.StructuredQuery.Direction.ASCENDING + elif direction == BaseQuery.DESCENDING: + return enums.StructuredQuery.Direction.DESCENDING + else: + msg = _BAD_DIR_STRING.format( + direction, BaseQuery.ASCENDING, BaseQuery.DESCENDING + ) + raise ValueError(msg) + + +def _filter_pb(field_or_unary): + """Convert a specific protobuf filter to the generic filter type. + + Args: + field_or_unary (Union[google.cloud.proto.firestore.v1.\ + query_pb2.StructuredQuery.FieldFilter, google.cloud.proto.\ + firestore.v1.query_pb2.StructuredQuery.FieldFilter]): A + field or unary filter to convert to a generic filter. + + Returns: + google.cloud.firestore_v1.types.\ + StructuredQuery.Filter: A "generic" filter. + + Raises: + ValueError: If ``field_or_unary`` is not a field or unary filter. + """ + if isinstance(field_or_unary, query_pb2.StructuredQuery.FieldFilter): + return query_pb2.StructuredQuery.Filter(field_filter=field_or_unary) + elif isinstance(field_or_unary, query_pb2.StructuredQuery.UnaryFilter): + return query_pb2.StructuredQuery.Filter(unary_filter=field_or_unary) + else: + raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) + + +def _cursor_pb(cursor_pair): + """Convert a cursor pair to a protobuf. + + If ``cursor_pair`` is :data:`None`, just returns :data:`None`. + + Args: + cursor_pair (Optional[Tuple[list, bool]]): Two-tuple of + + * a list of field values. + * a ``before`` flag + + Returns: + Optional[google.cloud.firestore_v1.types.Cursor]: A + protobuf cursor corresponding to the values. + """ + if cursor_pair is not None: + data, before = cursor_pair + value_pbs = [_helpers.encode_value(value) for value in data] + return query_pb2.Cursor(values=value_pbs, before=before) + + +def _query_response_to_snapshot(response_pb, collection, expected_prefix): + """Parse a query response protobuf to a document snapshot. + + Args: + response_pb (google.cloud.proto.firestore.v1.\ + firestore_pb2.RunQueryResponse): A + collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + A reference to the collection that initiated the query. + expected_prefix (str): The expected prefix for fully-qualified + document names returned in the query results. This can be computed + directly from ``collection`` via :meth:`_parent_info`. + + Returns: + Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: + A snapshot of the data returned in the query. If + ``response_pb.document`` is not set, the snapshot will be :data:`None`. + """ + if not response_pb.HasField("document"): + return None + + document_id = _helpers.get_doc_id(response_pb.document, expected_prefix) + reference = collection.document(document_id) + data = _helpers.decode_dict(response_pb.document.fields, collection._client) + snapshot = document.DocumentSnapshot( + reference, + data, + exists=True, + read_time=response_pb.read_time, + create_time=response_pb.document.create_time, + update_time=response_pb.document.update_time, + ) + return snapshot + + +def _collection_group_query_response_to_snapshot(response_pb, collection): + """Parse a query response protobuf to a document snapshot. + + Args: + response_pb (google.cloud.proto.firestore.v1.\ + firestore_pb2.RunQueryResponse): A + collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): + A reference to the collection that initiated the query. + + Returns: + Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: + A snapshot of the data returned in the query. If + ``response_pb.document`` is not set, the snapshot will be :data:`None`. + """ + if not response_pb.HasField("document"): + return None + reference = collection._client.document(response_pb.document.name) + data = _helpers.decode_dict(response_pb.document.fields, collection._client) + snapshot = document.DocumentSnapshot( + reference, + data, + exists=True, + read_time=response_pb.read_time, + create_time=response_pb.document.create_time, + update_time=response_pb.document.update_time, + ) + return snapshot diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 6a6326c903..f99c03a8df 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -18,55 +18,20 @@ a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be a more common way to create a query than direct usage of the constructor. """ -import copy -import math import warnings -from google.protobuf import wrappers_pb2 -import six +from google.cloud.firestore_v1.base_query import ( + BaseQuery, + _query_response_to_snapshot, + _collection_group_query_response_to_snapshot, +) from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import document -from google.cloud.firestore_v1 import field_path as field_path_module -from google.cloud.firestore_v1 import transforms -from google.cloud.firestore_v1.gapic import enums -from google.cloud.firestore_v1.proto import query_pb2 -from google.cloud.firestore_v1.order import Order from google.cloud.firestore_v1.watch import Watch -_EQ_OP = "==" -_operator_enum = enums.StructuredQuery.FieldFilter.Operator -_COMPARISON_OPERATORS = { - "<": _operator_enum.LESS_THAN, - "<=": _operator_enum.LESS_THAN_OR_EQUAL, - _EQ_OP: _operator_enum.EQUAL, - ">=": _operator_enum.GREATER_THAN_OR_EQUAL, - ">": _operator_enum.GREATER_THAN, - "array_contains": _operator_enum.ARRAY_CONTAINS, - "in": _operator_enum.IN, - "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, -} -_BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." -_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' -_INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values." -_BAD_DIR_STRING = "Invalid direction {!r}. Must be one of {!r} or {!r}." -_INVALID_CURSOR_TRANSFORM = "Transforms cannot be used as cursor values." -_MISSING_ORDER_BY = ( - 'The "order by" field path {!r} is not present in the cursor data {!r}. ' - "All fields sent to ``order_by()`` must be present in the fields " - "if passed to one of ``start_at()`` / ``start_after()`` / " - "``end_before()`` / ``end_at()`` to define a cursor." -) -_NO_ORDERS_FOR_CURSOR = ( - "Attempting to create a cursor with no fields to order on. " - "When defining a cursor with one of ``start_at()`` / ``start_after()`` / " - "``end_before()`` / ``end_at()``, all fields in the cursor must " - "come from fields set in ``order_by()``." -) -_MISMATCH_CURSOR_W_ORDER_BY = "The cursor {!r} does not match the order fields {!r}." - -class Query(object): +class Query(BaseQuery): """Represents a query to the Firestore API. Instances of this class are considered immutable: all methods that @@ -122,11 +87,6 @@ class Query(object): When true, selects all descendant collections. """ - ASCENDING = "ASCENDING" - """str: Sort query results in ascending order on a field.""" - DESCENDING = "DESCENDING" - """str: Sort query results in descending order on a field.""" - def __init__( self, parent, @@ -139,595 +99,18 @@ def __init__( end_at=None, all_descendants=False, ): - self._parent = parent - self._projection = projection - self._field_filters = field_filters - self._orders = orders - self._limit = limit - self._offset = offset - self._start_at = start_at - self._end_at = end_at - self._all_descendants = all_descendants - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return ( - self._parent == other._parent - and self._projection == other._projection - and self._field_filters == other._field_filters - and self._orders == other._orders - and self._limit == other._limit - and self._offset == other._offset - and self._start_at == other._start_at - and self._end_at == other._end_at - and self._all_descendants == other._all_descendants - ) - - @property - def _client(self): - """The client of the parent collection. - - Returns: - :class:`~google.cloud.firestore_v1.client.Client`: - The client that owns this query. - """ - return self._parent._client - - def select(self, field_paths): - """Project documents matching query to a limited set of fields. - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - If the current query already has a projection set (i.e. has already - called :meth:`~google.cloud.firestore_v1.query.Query.select`), this - will overwrite it. - - Args: - field_paths (Iterable[str, ...]): An iterable of field paths - (``.``-delimited list of field names) to use as a projection - of document fields in the query results. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A "projected" query. Acts as a copy of the current query, - modified with the newly added projection. - Raises: - ValueError: If any ``field_path`` is invalid. - """ - field_paths = list(field_paths) - for field_path in field_paths: - field_path_module.split_field_path(field_path) # raises - - new_projection = query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ) - return self.__class__( - self._parent, - projection=new_projection, - field_filters=self._field_filters, - orders=self._orders, - limit=self._limit, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - def where(self, field_path, op_string, value): - """Filter the query on a field. - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - Returns a new :class:`~google.cloud.firestore_v1.query.Query` that - filters on a specific field path, according to an operation (e.g. - ``==`` or "equals") and a particular value to be paired with that - operation. - - Args: - field_path (str): A field path (``.``-delimited list of - field names) for the field to filter on. - op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``, - ``in``, ``array_contains`` and ``array_contains_any``. - value (Any): The value to compare the field against in the filter. - If ``value`` is :data:`None` or a NaN, then ``==`` is the only - allowed operation. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A filtered query. Acts as a copy of the current query, - modified with the newly added filter. - - Raises: - ValueError: If ``field_path`` is invalid. - ValueError: If ``value`` is a NaN or :data:`None` and - ``op_string`` is not ``==``. - """ - field_path_module.split_field_path(field_path) # raises - - if value is None: - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - filter_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, - ) - elif _isnan(value): - if op_string != _EQ_OP: - raise ValueError(_BAD_OP_NAN_NULL) - filter_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=enums.StructuredQuery.UnaryFilter.Operator.IS_NAN, - ) - elif isinstance(value, (transforms.Sentinel, transforms._ValueList)): - raise ValueError(_INVALID_WHERE_TRANSFORM) - else: - filter_pb = query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=_enum_from_op_string(op_string), - value=_helpers.encode_value(value), - ) - - new_filters = self._field_filters + (filter_pb,) - return self.__class__( - self._parent, - projection=self._projection, - field_filters=new_filters, - orders=self._orders, - limit=self._limit, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - @staticmethod - def _make_order(field_path, direction): - """Helper for :meth:`order_by`.""" - return query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - direction=_enum_from_direction(direction), - ) - - def order_by(self, field_path, direction=ASCENDING): - """Modify the query to add an order clause on a specific field. - - See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for - more information on **field paths**. - - Successive :meth:`~google.cloud.firestore_v1.query.Query.order_by` - calls will further refine the ordering of results returned by the query - (i.e. the new "order by" fields will be added to existing ones). - - Args: - field_path (str): A field path (``.``-delimited list of - field names) on which to order the query results. - direction (Optional[str]): The direction to order by. Must be one - of :attr:`ASCENDING` or :attr:`DESCENDING`, defaults to - :attr:`ASCENDING`. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - An ordered query. Acts as a copy of the current query, modified - with the newly added "order by" constraint. - - Raises: - ValueError: If ``field_path`` is invalid. - ValueError: If ``direction`` is not one of :attr:`ASCENDING` or - :attr:`DESCENDING`. - """ - field_path_module.split_field_path(field_path) # raises - - order_pb = self._make_order(field_path, direction) - - new_orders = self._orders + (order_pb,) - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=new_orders, - limit=self._limit, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - def limit(self, count): - """Limit a query to return a fixed number of results. - - If the current query already has a limit set, this will overwrite it. - - Args: - count (int): Maximum number of documents to return that match - the query. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A limited query. Acts as a copy of the current query, modified - with the newly added "limit" filter. - """ - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=self._orders, - limit=count, - offset=self._offset, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, + super(Query, self).__init__( + parent=parent, + projection=projection, + field_filters=field_filters, + orders=orders, + limit=limit, + offset=offset, + start_at=start_at, + end_at=end_at, + all_descendants=all_descendants, ) - def offset(self, num_to_skip): - """Skip to an offset in a query. - - If the current query already has specified an offset, this will - overwrite it. - - Args: - num_to_skip (int): The number of results to skip at the beginning - of query results. (Must be non-negative.) - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - An offset query. Acts as a copy of the current query, modified - with the newly added "offset" field. - """ - return self.__class__( - self._parent, - projection=self._projection, - field_filters=self._field_filters, - orders=self._orders, - limit=self._limit, - offset=num_to_skip, - start_at=self._start_at, - end_at=self._end_at, - all_descendants=self._all_descendants, - ) - - def _check_snapshot(self, document_fields): - """Validate local snapshots for non-collection-group queries. - - Raises: - ValueError: for non-collection-group queries, if the snapshot - is from a different collection. - """ - if self._all_descendants: - return - - if document_fields.reference._path[:-1] != self._parent._path: - raise ValueError("Cannot use snapshot from another collection as a cursor.") - - def _cursor_helper(self, document_fields, before, start): - """Set values to be used for a ``start_at`` or ``end_at`` cursor. - - The values will later be used in a query protobuf. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - before (bool): Flag indicating if the document in - ``document_fields`` should (:data:`False`) or - shouldn't (:data:`True`) be included in the result set. - start (Optional[bool]): determines if the cursor is a ``start_at`` - cursor (:data:`True`) or an ``end_at`` cursor (:data:`False`). - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "start at" cursor. - """ - if isinstance(document_fields, tuple): - document_fields = list(document_fields) - elif isinstance(document_fields, document.DocumentSnapshot): - self._check_snapshot(document_fields) - else: - # NOTE: We copy so that the caller can't modify after calling. - document_fields = copy.deepcopy(document_fields) - - cursor_pair = document_fields, before - query_kwargs = { - "projection": self._projection, - "field_filters": self._field_filters, - "orders": self._orders, - "limit": self._limit, - "offset": self._offset, - "all_descendants": self._all_descendants, - } - if start: - query_kwargs["start_at"] = cursor_pair - query_kwargs["end_at"] = self._end_at - else: - query_kwargs["start_at"] = self._start_at - query_kwargs["end_at"] = cursor_pair - - return self.__class__(self._parent, **query_kwargs) - - def start_at(self, document_fields): - """Start query results at a particular document value. - - The result set will **include** the document specified by - ``document_fields``. - - If the current query already has specified a start cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.Query.start_after` -- this - will overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. Acts as - a copy of the current query, modified with the newly added - "start at" cursor. - """ - return self._cursor_helper(document_fields, before=True, start=True) - - def start_after(self, document_fields): - """Start query results after a particular document value. - - The result set will **exclude** the document specified by - ``document_fields``. - - If the current query already has specified a start cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.Query.start_at` -- this will - overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "start after" cursor. - """ - return self._cursor_helper(document_fields, before=False, start=True) - - def end_before(self, document_fields): - """End query results before a particular document value. - - The result set will **exclude** the document specified by - ``document_fields``. - - If the current query already has specified an end cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.Query.end_at` -- this will - overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "end before" cursor. - """ - return self._cursor_helper(document_fields, before=True, start=False) - - def end_at(self, document_fields): - """End query results at a particular document value. - - The result set will **include** the document specified by - ``document_fields``. - - If the current query already has specified an end cursor -- either - via this method or - :meth:`~google.cloud.firestore_v1.query.Query.end_before` -- this will - overwrite it. - - When the query is sent to the server, the ``document_fields`` will - be used in the order given by fields set by - :meth:`~google.cloud.firestore_v1.query.Query.order_by`. - - Args: - document_fields - (Union[:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`, dict, list, tuple]): - a document snapshot or a dictionary/list/tuple of fields - representing a query results cursor. A cursor is a collection - of values that represent a position in a query result set. - - Returns: - :class:`~google.cloud.firestore_v1.query.Query`: - A query with cursor. Acts as a copy of the current query, modified - with the newly added "end at" cursor. - """ - return self._cursor_helper(document_fields, before=False, start=False) - - def _filters_pb(self): - """Convert all the filters into a single generic Filter protobuf. - - This may be a lone field filter or unary filter, may be a composite - filter or may be :data:`None`. - - Returns: - :class:`google.cloud.firestore_v1.types.StructuredQuery.Filter`: - A "generic" filter representing the current query's filters. - """ - num_filters = len(self._field_filters) - if num_filters == 0: - return None - elif num_filters == 1: - return _filter_pb(self._field_filters[0]) - else: - composite_filter = query_pb2.StructuredQuery.CompositeFilter( - op=enums.StructuredQuery.CompositeFilter.Operator.AND, - filters=[_filter_pb(filter_) for filter_ in self._field_filters], - ) - return query_pb2.StructuredQuery.Filter(composite_filter=composite_filter) - - @staticmethod - def _normalize_projection(projection): - """Helper: convert field paths to message.""" - if projection is not None: - - fields = list(projection.fields) - - if not fields: - field_ref = query_pb2.StructuredQuery.FieldReference( - field_path="__name__" - ) - return query_pb2.StructuredQuery.Projection(fields=[field_ref]) - - return projection - - def _normalize_orders(self): - """Helper: adjust orders based on cursors, where clauses.""" - orders = list(self._orders) - _has_snapshot_cursor = False - - if self._start_at: - if isinstance(self._start_at[0], document.DocumentSnapshot): - _has_snapshot_cursor = True - - if self._end_at: - if isinstance(self._end_at[0], document.DocumentSnapshot): - _has_snapshot_cursor = True - - if _has_snapshot_cursor: - should_order = [ - _enum_from_op_string(key) - for key in _COMPARISON_OPERATORS - if key not in (_EQ_OP, "array_contains") - ] - order_keys = [order.field.field_path for order in orders] - for filter_ in self._field_filters: - field = filter_.field.field_path - if filter_.op in should_order and field not in order_keys: - orders.append(self._make_order(field, "ASCENDING")) - if not orders: - orders.append(self._make_order("__name__", "ASCENDING")) - else: - order_keys = [order.field.field_path for order in orders] - if "__name__" not in order_keys: - direction = orders[-1].direction # enum? - orders.append(self._make_order("__name__", direction)) - - return orders - - def _normalize_cursor(self, cursor, orders): - """Helper: convert cursor to a list of values based on orders.""" - if cursor is None: - return - - if not orders: - raise ValueError(_NO_ORDERS_FOR_CURSOR) - - document_fields, before = cursor - - order_keys = [order.field.field_path for order in orders] - - if isinstance(document_fields, document.DocumentSnapshot): - snapshot = document_fields - document_fields = snapshot.to_dict() - document_fields["__name__"] = snapshot.reference - - if isinstance(document_fields, dict): - # Transform to list using orders - values = [] - data = document_fields - for order_key in order_keys: - try: - if order_key in data: - values.append(data[order_key]) - else: - values.append( - field_path_module.get_nested_value(order_key, data) - ) - except KeyError: - msg = _MISSING_ORDER_BY.format(order_key, data) - raise ValueError(msg) - document_fields = values - - if len(document_fields) != len(orders): - msg = _MISMATCH_CURSOR_W_ORDER_BY.format(document_fields, order_keys) - raise ValueError(msg) - - _transform_bases = (transforms.Sentinel, transforms._ValueList) - - for index, key_field in enumerate(zip(order_keys, document_fields)): - key, field = key_field - - if isinstance(field, _transform_bases): - msg = _INVALID_CURSOR_TRANSFORM - raise ValueError(msg) - - if key == "__name__" and isinstance(field, six.string_types): - document_fields[index] = self._parent.document(field) - - return document_fields, before - - def _to_protobuf(self): - """Convert the current query into the equivalent protobuf. - - Returns: - :class:`google.cloud.firestore_v1.types.StructuredQuery`: - The query protobuf. - """ - projection = self._normalize_projection(self._projection) - orders = self._normalize_orders() - start_at = self._normalize_cursor(self._start_at, orders) - end_at = self._normalize_cursor(self._end_at, orders) - - query_kwargs = { - "select": projection, - "from": [ - query_pb2.StructuredQuery.CollectionSelector( - collection_id=self._parent.id, all_descendants=self._all_descendants - ) - ], - "where": self._filters_pb(), - "order_by": orders, - "start_at": _cursor_pb(start_at), - "end_at": _cursor_pb(end_at), - } - if self._offset is not None: - query_kwargs["offset"] = self._offset - if self._limit is not None: - query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) - - return query_pb2.StructuredQuery(**query_kwargs) - def get(self, transaction=None): """Deprecated alias for :meth:`stream`.""" warnings.warn( @@ -816,226 +199,3 @@ def on_snapshot(docs, changes, read_time): return Watch.for_query( self, callback, document.DocumentSnapshot, document.DocumentReference ) - - def _comparator(self, doc1, doc2): - _orders = self._orders - - # Add implicit sorting by name, using the last specified direction. - if len(_orders) == 0: - lastDirection = Query.ASCENDING - else: - if _orders[-1].direction == 1: - lastDirection = Query.ASCENDING - else: - lastDirection = Query.DESCENDING - - orderBys = list(_orders) - - order_pb = query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path="id"), - direction=_enum_from_direction(lastDirection), - ) - orderBys.append(order_pb) - - for orderBy in orderBys: - if orderBy.field.field_path == "id": - # If ordering by docuent id, compare resource paths. - comp = Order()._compare_to(doc1.reference._path, doc2.reference._path) - else: - if ( - orderBy.field.field_path not in doc1._data - or orderBy.field.field_path not in doc2._data - ): - raise ValueError( - "Can only compare fields that exist in the " - "DocumentSnapshot. Please include the fields you are " - "ordering on in your select() call." - ) - v1 = doc1._data[orderBy.field.field_path] - v2 = doc2._data[orderBy.field.field_path] - encoded_v1 = _helpers.encode_value(v1) - encoded_v2 = _helpers.encode_value(v2) - comp = Order().compare(encoded_v1, encoded_v2) - - if comp != 0: - # 1 == Ascending, -1 == Descending - return orderBy.direction * comp - - return 0 - - -def _enum_from_op_string(op_string): - """Convert a string representation of a binary operator to an enum. - - These enums come from the protobuf message definition - ``StructuredQuery.FieldFilter.Operator``. - - Args: - op_string (str): A comparison operation in the form of a string. - Acceptable values are ``<``, ``<=``, ``==``, ``>=`` - and ``>``. - - Returns: - int: The enum corresponding to ``op_string``. - - Raises: - ValueError: If ``op_string`` is not a valid operator. - """ - try: - return _COMPARISON_OPERATORS[op_string] - except KeyError: - choices = ", ".join(sorted(_COMPARISON_OPERATORS.keys())) - msg = _BAD_OP_STRING.format(op_string, choices) - raise ValueError(msg) - - -def _isnan(value): - """Check if a value is NaN. - - This differs from ``math.isnan`` in that **any** input type is - allowed. - - Args: - value (Any): A value to check for NaN-ness. - - Returns: - bool: Indicates if the value is the NaN float. - """ - if isinstance(value, float): - return math.isnan(value) - else: - return False - - -def _enum_from_direction(direction): - """Convert a string representation of a direction to an enum. - - Args: - direction (str): A direction to order by. Must be one of - :attr:`~google.cloud.firestore.Query.ASCENDING` or - :attr:`~google.cloud.firestore.Query.DESCENDING`. - - Returns: - int: The enum corresponding to ``direction``. - - Raises: - ValueError: If ``direction`` is not a valid direction. - """ - if isinstance(direction, int): - return direction - - if direction == Query.ASCENDING: - return enums.StructuredQuery.Direction.ASCENDING - elif direction == Query.DESCENDING: - return enums.StructuredQuery.Direction.DESCENDING - else: - msg = _BAD_DIR_STRING.format(direction, Query.ASCENDING, Query.DESCENDING) - raise ValueError(msg) - - -def _filter_pb(field_or_unary): - """Convert a specific protobuf filter to the generic filter type. - - Args: - field_or_unary (Union[google.cloud.proto.firestore.v1.\ - query_pb2.StructuredQuery.FieldFilter, google.cloud.proto.\ - firestore.v1.query_pb2.StructuredQuery.FieldFilter]): A - field or unary filter to convert to a generic filter. - - Returns: - google.cloud.firestore_v1.types.\ - StructuredQuery.Filter: A "generic" filter. - - Raises: - ValueError: If ``field_or_unary`` is not a field or unary filter. - """ - if isinstance(field_or_unary, query_pb2.StructuredQuery.FieldFilter): - return query_pb2.StructuredQuery.Filter(field_filter=field_or_unary) - elif isinstance(field_or_unary, query_pb2.StructuredQuery.UnaryFilter): - return query_pb2.StructuredQuery.Filter(unary_filter=field_or_unary) - else: - raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) - - -def _cursor_pb(cursor_pair): - """Convert a cursor pair to a protobuf. - - If ``cursor_pair`` is :data:`None`, just returns :data:`None`. - - Args: - cursor_pair (Optional[Tuple[list, bool]]): Two-tuple of - - * a list of field values. - * a ``before`` flag - - Returns: - Optional[google.cloud.firestore_v1.types.Cursor]: A - protobuf cursor corresponding to the values. - """ - if cursor_pair is not None: - data, before = cursor_pair - value_pbs = [_helpers.encode_value(value) for value in data] - return query_pb2.Cursor(values=value_pbs, before=before) - - -def _query_response_to_snapshot(response_pb, collection, expected_prefix): - """Parse a query response protobuf to a document snapshot. - - Args: - response_pb (google.cloud.proto.firestore.v1.\ - firestore_pb2.RunQueryResponse): A - collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): - A reference to the collection that initiated the query. - expected_prefix (str): The expected prefix for fully-qualified - document names returned in the query results. This can be computed - directly from ``collection`` via :meth:`_parent_info`. - - Returns: - Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: - A snapshot of the data returned in the query. If - ``response_pb.document`` is not set, the snapshot will be :data:`None`. - """ - if not response_pb.HasField("document"): - return None - - document_id = _helpers.get_doc_id(response_pb.document, expected_prefix) - reference = collection.document(document_id) - data = _helpers.decode_dict(response_pb.document.fields, collection._client) - snapshot = document.DocumentSnapshot( - reference, - data, - exists=True, - read_time=response_pb.read_time, - create_time=response_pb.document.create_time, - update_time=response_pb.document.update_time, - ) - return snapshot - - -def _collection_group_query_response_to_snapshot(response_pb, collection): - """Parse a query response protobuf to a document snapshot. - - Args: - response_pb (google.cloud.proto.firestore.v1.\ - firestore_pb2.RunQueryResponse): A - collection (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): - A reference to the collection that initiated the query. - - Returns: - Optional[:class:`~google.cloud.firestore.document.DocumentSnapshot`]: - A snapshot of the data returned in the query. If - ``response_pb.document`` is not set, the snapshot will be :data:`None`. - """ - if not response_pb.HasField("document"): - return None - reference = collection._client.document(response_pb.document.name) - data = _helpers.decode_dict(response_pb.document.fields, collection._client) - snapshot = document.DocumentSnapshot( - reference, - data, - exists=True, - read_time=response_pb.read_time, - create_time=response_pb.document.create_time, - update_time=response_pb.document.update_time, - ) - return snapshot diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 91c64373ca..9cb97ae263 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -220,7 +220,7 @@ def test_select(self): def _make_field_filter_pb(field_path, op_string, value): from google.cloud.firestore_v1.proto import query_pb2 from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.query import _enum_from_op_string + from google.cloud.firestore_v1.base_query import _enum_from_op_string return query_pb2.StructuredQuery.FieldFilter( field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), @@ -248,7 +248,7 @@ def test_where(self): @staticmethod def _make_order_pb(field_path, direction): from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1.query import _enum_from_direction + from google.cloud.firestore_v1.base_query import _enum_from_direction return query_pb2.StructuredQuery.Order( field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py index 73e5f3d764..5a9edd6d30 100644 --- a/tests/unit/v1/async/test_async_query.py +++ b/tests/unit/v1/async/test_async_query.py @@ -13,12 +13,13 @@ # limitations under the License. import pytest -import datetime import types import aiounittest import mock +from tests.unit.v1.test_base_query import _make_credentials, _make_query_response + class TestAsyncQuery(aiounittest.AsyncTestCase): @staticmethod @@ -43,1021 +44,6 @@ def test_constructor_defaults(self): self.assertIsNone(query._end_at) self.assertFalse(query._all_descendants) - def _make_one_all_fields( - self, limit=9876, offset=12, skip_fields=(), parent=None, all_descendants=True - ): - kwargs = { - "projection": mock.sentinel.projection, - "field_filters": mock.sentinel.filters, - "orders": mock.sentinel.orders, - "limit": limit, - "offset": offset, - "start_at": mock.sentinel.start_at, - "end_at": mock.sentinel.end_at, - "all_descendants": all_descendants, - } - for field in skip_fields: - kwargs.pop(field) - if parent is None: - parent = mock.sentinel.parent - return self._make_one(parent, **kwargs) - - def test_constructor_explicit(self): - limit = 234 - offset = 56 - query = self._make_one_all_fields(limit=limit, offset=offset) - self.assertIs(query._parent, mock.sentinel.parent) - self.assertIs(query._projection, mock.sentinel.projection) - self.assertIs(query._field_filters, mock.sentinel.filters) - self.assertEqual(query._orders, mock.sentinel.orders) - self.assertEqual(query._limit, limit) - self.assertEqual(query._offset, offset) - self.assertIs(query._start_at, mock.sentinel.start_at) - self.assertIs(query._end_at, mock.sentinel.end_at) - self.assertTrue(query._all_descendants) - - def test__client_property(self): - parent = mock.Mock(_client=mock.sentinel.client, spec=["_client"]) - query = self._make_one(parent) - self.assertIs(query._client, mock.sentinel.client) - - def test___eq___other_type(self): - query = self._make_one_all_fields() - other = object() - self.assertFalse(query == other) - - def test___eq___different_parent(self): - parent = mock.sentinel.parent - other_parent = mock.sentinel.other_parent - query = self._make_one_all_fields(parent=parent) - other = self._make_one_all_fields(parent=other_parent) - self.assertFalse(query == other) - - def test___eq___different_projection(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) - query._projection = mock.sentinel.projection - other = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) - other._projection = mock.sentinel.other_projection - self.assertFalse(query == other) - - def test___eq___different_field_filters(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) - query._field_filters = mock.sentinel.field_filters - other = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) - other._field_filters = mock.sentinel.other_field_filters - self.assertFalse(query == other) - - def test___eq___different_orders(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) - query._orders = mock.sentinel.orders - other = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) - other._orders = mock.sentinel.other_orders - self.assertFalse(query == other) - - def test___eq___different_limit(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, limit=10) - other = self._make_one_all_fields(parent=parent, limit=20) - self.assertFalse(query == other) - - def test___eq___different_offset(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, offset=10) - other = self._make_one_all_fields(parent=parent, offset=20) - self.assertFalse(query == other) - - def test___eq___different_start_at(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) - query._start_at = mock.sentinel.start_at - other = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) - other._start_at = mock.sentinel.other_start_at - self.assertFalse(query == other) - - def test___eq___different_end_at(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) - query._end_at = mock.sentinel.end_at - other = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) - other._end_at = mock.sentinel.other_end_at - self.assertFalse(query == other) - - def test___eq___different_all_descendants(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, all_descendants=True) - other = self._make_one_all_fields(parent=parent, all_descendants=False) - self.assertFalse(query == other) - - def test___eq___hit(self): - query = self._make_one_all_fields() - other = self._make_one_all_fields() - self.assertTrue(query == other) - - def _compare_queries(self, query1, query2, attr_name): - attrs1 = query1.__dict__.copy() - attrs2 = query2.__dict__.copy() - - attrs1.pop(attr_name) - attrs2.pop(attr_name) - - # The only different should be in ``attr_name``. - self.assertEqual(len(attrs1), len(attrs2)) - for key, value in attrs1.items(): - self.assertIs(value, attrs2[key]) - - @staticmethod - def _make_projection_for_select(field_paths): - from google.cloud.firestore_v1.proto import query_pb2 - - return query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ) - - def test_select_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query.select(["*"]) - - def test_select(self): - query1 = self._make_one_all_fields(all_descendants=True) - - field_paths2 = ["foo", "bar"] - query2 = query1.select(field_paths2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual( - query2._projection, self._make_projection_for_select(field_paths2) - ) - self._compare_queries(query1, query2, "_projection") - - # Make sure it overrides. - field_paths3 = ["foo.baz"] - query3 = query2.select(field_paths3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual( - query3._projection, self._make_projection_for_select(field_paths3) - ) - self._compare_queries(query2, query3, "_projection") - - def test_where_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query.where("*", "==", 1) - - def test_where(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - query = self._make_one_all_fields( - skip_fields=("field_filters",), all_descendants=True - ) - new_query = query.where("power.level", ">", 9000) - - self.assertIsNot(query, new_query) - self.assertIsInstance(new_query, self._get_target_class()) - self.assertEqual(len(new_query._field_filters), 1) - - field_pb = new_query._field_filters[0] - expected_pb = query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="power.level"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(integer_value=9000), - ) - self.assertEqual(field_pb, expected_pb) - self._compare_queries(query, new_query, "_field_filters") - - def _where_unary_helper(self, value, op_enum, op_string="=="): - from google.cloud.firestore_v1.proto import query_pb2 - - query = self._make_one_all_fields(skip_fields=("field_filters",)) - field_path = "feeeld" - new_query = query.where(field_path, op_string, value) - - self.assertIsNot(query, new_query) - self.assertIsInstance(new_query, self._get_target_class()) - self.assertEqual(len(new_query._field_filters), 1) - - field_pb = new_query._field_filters[0] - expected_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=op_enum, - ) - self.assertEqual(field_pb, expected_pb) - self._compare_queries(query, new_query, "_field_filters") - - def test_where_eq_null(self): - from google.cloud.firestore_v1.gapic import enums - - op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NULL - self._where_unary_helper(None, op_enum) - - def test_where_gt_null(self): - with self.assertRaises(ValueError): - self._where_unary_helper(None, 0, op_string=">") - - def test_where_eq_nan(self): - from google.cloud.firestore_v1.gapic import enums - - op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NAN - self._where_unary_helper(float("nan"), op_enum) - - def test_where_le_nan(self): - with self.assertRaises(ValueError): - self._where_unary_helper(float("nan"), 0, op_string="<=") - - def test_where_w_delete(self): - from google.cloud.firestore_v1 import DELETE_FIELD - - with self.assertRaises(ValueError): - self._where_unary_helper(DELETE_FIELD, 0) - - def test_where_w_server_timestamp(self): - from google.cloud.firestore_v1 import SERVER_TIMESTAMP - - with self.assertRaises(ValueError): - self._where_unary_helper(SERVER_TIMESTAMP, 0) - - def test_where_w_array_remove(self): - from google.cloud.firestore_v1 import ArrayRemove - - with self.assertRaises(ValueError): - self._where_unary_helper(ArrayRemove([1, 3, 5]), 0) - - def test_where_w_array_union(self): - from google.cloud.firestore_v1 import ArrayUnion - - with self.assertRaises(ValueError): - self._where_unary_helper(ArrayUnion([2, 4, 8]), 0) - - def test_order_by_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query.order_by("*") - - def test_order_by(self): - from google.cloud.firestore_v1.gapic import enums - - klass = self._get_target_class() - query1 = self._make_one_all_fields( - skip_fields=("orders",), all_descendants=True - ) - - field_path2 = "a" - query2 = query1.order_by(field_path2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, klass) - order_pb2 = _make_order_pb( - field_path2, enums.StructuredQuery.Direction.ASCENDING - ) - self.assertEqual(query2._orders, (order_pb2,)) - self._compare_queries(query1, query2, "_orders") - - # Make sure it appends to the orders. - field_path3 = "b" - query3 = query2.order_by(field_path3, direction=klass.DESCENDING) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, klass) - order_pb3 = _make_order_pb( - field_path3, enums.StructuredQuery.Direction.DESCENDING - ) - self.assertEqual(query3._orders, (order_pb2, order_pb3)) - self._compare_queries(query2, query3, "_orders") - - def test_limit(self): - query1 = self._make_one_all_fields(all_descendants=True) - - limit2 = 100 - query2 = query1.limit(limit2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual(query2._limit, limit2) - self._compare_queries(query1, query2, "_limit") - - # Make sure it overrides. - limit3 = 10 - query3 = query2.limit(limit3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._limit, limit3) - self._compare_queries(query2, query3, "_limit") - - def test_offset(self): - query1 = self._make_one_all_fields(all_descendants=True) - - offset2 = 23 - query2 = query1.offset(offset2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual(query2._offset, offset2) - self._compare_queries(query1, query2, "_offset") - - # Make sure it overrides. - offset3 = 35 - query3 = query2.offset(offset3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._offset, offset3) - self._compare_queries(query2, query3, "_offset") - - @staticmethod - def _make_collection(*path, **kw): - from google.cloud.firestore_v1 import collection - - return collection.CollectionReference(*path, **kw) - - @staticmethod - def _make_docref(*path, **kw): - from google.cloud.firestore_v1 import document - - return document.DocumentReference(*path, **kw) - - @staticmethod - def _make_snapshot(docref, values): - from google.cloud.firestore_v1 import document - - return document.DocumentSnapshot(docref, values, True, None, None, None) - - def test__cursor_helper_w_dict(self): - values = {"a": 7, "b": "foo"} - query1 = self._make_one(mock.sentinel.parent) - query1._all_descendants = True - query2 = query1._cursor_helper(values, True, True) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._end_at) - self.assertTrue(query2._all_descendants) - - cursor, before = query2._start_at - - self.assertEqual(cursor, values) - self.assertTrue(before) - - def test__cursor_helper_w_tuple(self): - values = (7, "foo") - query1 = self._make_one(mock.sentinel.parent) - query2 = query1._cursor_helper(values, False, True) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._end_at) - - cursor, before = query2._start_at - - self.assertEqual(cursor, list(values)) - self.assertFalse(before) - - def test__cursor_helper_w_list(self): - values = [7, "foo"] - query1 = self._make_one(mock.sentinel.parent) - query2 = query1._cursor_helper(values, True, False) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertEqual(cursor, values) - self.assertIsNot(cursor, values) - self.assertTrue(before) - - def test__cursor_helper_w_snapshot_wrong_collection(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("there", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection) - - with self.assertRaises(ValueError): - query._cursor_helper(snapshot, False, False) - - def test__cursor_helper_w_snapshot_other_collection_all_descendants(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("there", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query1 = self._make_one(collection, all_descendants=True) - - query2 = query1._cursor_helper(snapshot, False, False) - - self.assertIs(query2._parent, collection) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, ()) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertIs(cursor, snapshot) - self.assertFalse(before) - - def test__cursor_helper_w_snapshot(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query1 = self._make_one(collection) - - query2 = query1._cursor_helper(snapshot, False, False) - - self.assertIs(query2._parent, collection) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, ()) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertIs(cursor, snapshot) - self.assertFalse(before) - - def test_start_at(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields( - parent=collection, skip_fields=("orders",), all_descendants=True - ) - query2 = query1.order_by("hi") - - document_fields3 = {"hi": "mom"} - query3 = query2.start_at(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._start_at, (document_fields3, True)) - self._compare_queries(query2, query3, "_start_at") - - # Make sure it overrides. - query4 = query3.order_by("bye") - values5 = {"hi": "zap", "bye": 88} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.start_at(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._start_at, (document_fields5, True)) - self._compare_queries(query4, query5, "_start_at") - - def test_start_after(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("down") - - document_fields3 = {"down": 99.75} - query3 = query2.start_after(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._start_at, (document_fields3, False)) - self._compare_queries(query2, query3, "_start_at") - - # Make sure it overrides. - query4 = query3.order_by("out") - values5 = {"down": 100.25, "out": b"\x00\x01"} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.start_after(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._start_at, (document_fields5, False)) - self._compare_queries(query4, query5, "_start_at") - - def test_end_before(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("down") - - document_fields3 = {"down": 99.75} - query3 = query2.end_before(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._end_at, (document_fields3, True)) - self._compare_queries(query2, query3, "_end_at") - - # Make sure it overrides. - query4 = query3.order_by("out") - values5 = {"down": 100.25, "out": b"\x00\x01"} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.end_before(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._end_at, (document_fields5, True)) - self._compare_queries(query4, query5, "_end_at") - self._compare_queries(query4, query5, "_end_at") - - def test_end_at(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("hi") - - document_fields3 = {"hi": "mom"} - query3 = query2.end_at(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._end_at, (document_fields3, False)) - self._compare_queries(query2, query3, "_end_at") - - # Make sure it overrides. - query4 = query3.order_by("bye") - values5 = {"hi": "zap", "bye": 88} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.end_at(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._end_at, (document_fields5, False)) - self._compare_queries(query4, query5, "_end_at") - - def test__filters_pb_empty(self): - query = self._make_one(mock.sentinel.parent) - self.assertEqual(len(query._field_filters), 0) - self.assertIsNone(query._filters_pb()) - - def test__filters_pb_single(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - query1 = self._make_one(mock.sentinel.parent) - query2 = query1.where("x.y", ">", 50.5) - filter_pb = query2._filters_pb() - expected_pb = query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="x.y"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(double_value=50.5), - ) - ) - self.assertEqual(filter_pb, expected_pb) - - def test__filters_pb_multi(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - query1 = self._make_one(mock.sentinel.parent) - query2 = query1.where("x.y", ">", 50.5) - query3 = query2.where("ABC", "==", 123) - - filter_pb = query3._filters_pb() - op_class = enums.StructuredQuery.FieldFilter.Operator - expected_pb = query_pb2.StructuredQuery.Filter( - composite_filter=query_pb2.StructuredQuery.CompositeFilter( - op=enums.StructuredQuery.CompositeFilter.Operator.AND, - filters=[ - query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference( - field_path="x.y" - ), - op=op_class.GREATER_THAN, - value=document_pb2.Value(double_value=50.5), - ) - ), - query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference( - field_path="ABC" - ), - op=op_class.EQUAL, - value=document_pb2.Value(integer_value=123), - ) - ), - ], - ) - ) - self.assertEqual(filter_pb, expected_pb) - - def test__normalize_projection_none(self): - query = self._make_one(mock.sentinel.parent) - self.assertIsNone(query._normalize_projection(None)) - - def test__normalize_projection_empty(self): - projection = self._make_projection_for_select([]) - query = self._make_one(mock.sentinel.parent) - normalized = query._normalize_projection(projection) - field_paths = [field_ref.field_path for field_ref in normalized.fields] - self.assertEqual(field_paths, ["__name__"]) - - def test__normalize_projection_non_empty(self): - projection = self._make_projection_for_select(["a", "b"]) - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._normalize_projection(projection), projection) - - def test__normalize_orders_wo_orders_wo_cursors(self): - query = self._make_one(mock.sentinel.parent) - expected = [] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_w_orders_wo_cursors(self): - query = self._make_one(mock.sentinel.parent).order_by("a") - expected = [query._make_order("a", "ASCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection).start_at(snapshot) - expected = [query._make_order("__name__", "ASCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_w_name_orders_w_snapshot_cursor(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = ( - self._make_one(collection) - .order_by("__name__", "DESCENDING") - .start_at(snapshot) - ) - expected = [query._make_order("__name__", "DESCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = ( - self._make_one(collection) - .where("c", "<=", 20) - .order_by("c", "DESCENDING") - .start_at(snapshot) - ) - expected = [ - query._make_order("c", "DESCENDING"), - query._make_order("__name__", "DESCENDING"), - ] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection).where("c", "<=", 20).end_at(snapshot) - expected = [ - query._make_order("c", "ASCENDING"), - query._make_order("__name__", "ASCENDING"), - ] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_cursor_none(self): - query = self._make_one(mock.sentinel.parent) - self.assertIsNone(query._normalize_cursor(None, query._orders)) - - def test__normalize_cursor_no_order(self): - cursor = ([1], True) - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_as_list_mismatched_order(self): - cursor = ([1, 2], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_as_dict_mismatched_order(self): - cursor = ({"a": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_delete(self): - from google.cloud.firestore_v1 import DELETE_FIELD - - cursor = ([DELETE_FIELD], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_server_timestamp(self): - from google.cloud.firestore_v1 import SERVER_TIMESTAMP - - cursor = ([SERVER_TIMESTAMP], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_array_remove(self): - from google.cloud.firestore_v1 import ArrayRemove - - cursor = ([ArrayRemove([1, 3, 5])], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_array_union(self): - from google.cloud.firestore_v1 import ArrayUnion - - cursor = ([ArrayUnion([2, 4, 8])], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_as_list_hit(self): - cursor = ([1], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_dict_hit(self): - cursor = ({"b": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_dict_with_dot_key_hit(self): - cursor = ({"b.a": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_dict_with_inner_data_hit(self): - cursor = ({"b": {"a": 1}}, True) - query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_snapshot_hit(self): - values = {"b": 1} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - cursor = (snapshot, True) - collection = self._make_collection("here") - query = self._make_one(collection).order_by("b", "ASCENDING") - - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_w___name___w_reference(self): - db_string = "projects/my-project/database/(default)" - client = mock.Mock(spec=["_database_string"]) - client._database_string = db_string - parent = mock.Mock(spec=["_path", "_client"]) - parent._client = client - parent._path = ["C"] - query = self._make_one(parent).order_by("__name__", "ASCENDING") - docref = self._make_docref("here", "doc_id") - values = {"a": 7} - snapshot = self._make_snapshot(docref, values) - expected = docref - cursor = (snapshot, True) - - self.assertEqual( - query._normalize_cursor(cursor, query._orders), ([expected], True) - ) - - def test__normalize_cursor_w___name___wo_slash(self): - db_string = "projects/my-project/database/(default)" - client = mock.Mock(spec=["_database_string"]) - client._database_string = db_string - parent = mock.Mock(spec=["_path", "_client", "document"]) - parent._client = client - parent._path = ["C"] - document = parent.document.return_value = mock.Mock(spec=[]) - query = self._make_one(parent).order_by("__name__", "ASCENDING") - cursor = (["b"], True) - expected = document - - self.assertEqual( - query._normalize_cursor(cursor, query._orders), ([expected], True) - ) - parent.document.assert_called_once_with("b") - - def test__to_protobuf_all_fields(self): - from google.protobuf import wrappers_pb2 - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="cat", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.select(["X", "Y", "Z"]) - query3 = query2.where("Y", ">", 2.5) - query4 = query3.order_by("X") - query5 = query4.limit(17) - query6 = query5.offset(3) - query7 = query6.start_at({"X": 10}) - query8 = query7.end_at({"X": 25}) - - structured_query_pb = query8._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "select": query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in ["X", "Y", "Z"] - ] - ), - "where": query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="Y"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(double_value=2.5), - ) - ), - "order_by": [ - _make_order_pb("X", enums.StructuredQuery.Direction.ASCENDING) - ], - "start_at": query_pb2.Cursor( - values=[document_pb2.Value(integer_value=10)], before=True - ), - "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=25)]), - "offset": 3, - "limit": wrappers_pb2.Int32Value(value=17), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_select_only(self): - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="cat", spec=["id"]) - query1 = self._make_one(parent) - field_paths = ["a.b", "a.c", "d"] - query2 = query1.select(field_paths) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "select": query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_where_only(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="dog", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.where("a", "==", u"b") - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "where": query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="a"), - op=enums.StructuredQuery.FieldFilter.Operator.EQUAL, - value=document_pb2.Value(string_value=u"b"), - ) - ), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_order_by_only(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="fish", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.order_by("abc") - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [ - _make_order_pb("abc", enums.StructuredQuery.Direction.ASCENDING) - ], - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_start_at_only(self): - # NOTE: "only" is wrong since we must have ``order_by`` as well. - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="phish", spec=["id"]) - query = self._make_one(parent).order_by("X.Y").start_after({"X": {"Y": u"Z"}}) - - structured_query_pb = query._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [ - _make_order_pb("X.Y", enums.StructuredQuery.Direction.ASCENDING) - ], - "start_at": query_pb2.Cursor( - values=[document_pb2.Value(string_value=u"Z")] - ), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_end_at_only(self): - # NOTE: "only" is wrong since we must have ``order_by`` as well. - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="ghoti", spec=["id"]) - query = self._make_one(parent).order_by("a").end_at({"a": 88}) - - structured_query_pb = query._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [ - _make_order_pb("a", enums.StructuredQuery.Direction.ASCENDING) - ], - "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=88)]), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_offset_only(self): - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="cartt", spec=["id"]) - query1 = self._make_one(parent) - offset = 14 - query2 = query1.offset(offset) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "offset": offset, - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_limit_only(self): - from google.protobuf import wrappers_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="donut", spec=["id"]) - query1 = self._make_one(parent) - limit = 31 - query2 = query1.limit(limit) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "limit": wrappers_pb2.Int32Value(value=limit), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - - self.assertEqual(structured_query_pb, expected_pb) - @pytest.mark.asyncio async def test_get_simple(self): import warnings @@ -1370,381 +356,9 @@ def test_on_snapshot(self, watch): query.on_snapshot(None) watch.for_query.assert_called_once() - def test_comparator_no_ordering(self): - query = self._make_one(mock.sentinel.parent) - query._orders = [] - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, -1) - - def test_comparator_no_ordering_same_id(self): - query = self._make_one(mock.sentinel.parent) - query._orders = [] - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument1") - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, 0) - - def test_comparator_ordering(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = 1 # ascending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "secondlovelace"}, - } - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, 1) - - def test_comparator_ordering_descending(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = -1 # descending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "secondlovelace"}, - } - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, -1) - - def test_comparator_missing_order_by_field_in_data_raises(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = 1 # ascending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = {} - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - with self.assertRaisesRegex(ValueError, "Can only compare fields "): - query._comparator(doc1, doc2) - - -class Test__enum_from_op_string(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(op_string): - from google.cloud.firestore_v1.query import _enum_from_op_string - - return _enum_from_op_string(op_string) - - @staticmethod - def _get_op_class(): - from google.cloud.firestore_v1.gapic import enums - - return enums.StructuredQuery.FieldFilter.Operator - - def test_lt(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("<"), op_class.LESS_THAN) - - def test_le(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("<="), op_class.LESS_THAN_OR_EQUAL) - - def test_eq(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("=="), op_class.EQUAL) - - def test_ge(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut(">="), op_class.GREATER_THAN_OR_EQUAL) - - def test_gt(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut(">"), op_class.GREATER_THAN) - - def test_array_contains(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) - - def test_in(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("in"), op_class.IN) - - def test_array_contains_any(self): - op_class = self._get_op_class() - self.assertEqual( - self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY - ) - - def test_invalid(self): - with self.assertRaises(ValueError): - self._call_fut("?") - - -class Test__isnan(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(value): - from google.cloud.firestore_v1.query import _isnan - - return _isnan(value) - - def test_valid(self): - self.assertTrue(self._call_fut(float("nan"))) - - def test_invalid(self): - self.assertFalse(self._call_fut(51.5)) - self.assertFalse(self._call_fut(None)) - self.assertFalse(self._call_fut("str")) - self.assertFalse(self._call_fut(int)) - self.assertFalse(self._call_fut(1.0 + 1.0j)) - - -class Test__enum_from_direction(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(direction): - from google.cloud.firestore_v1.query import _enum_from_direction - - return _enum_from_direction(direction) - - def test_success(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.async_query import AsyncQuery - - dir_class = enums.StructuredQuery.Direction - self.assertEqual(self._call_fut(AsyncQuery.ASCENDING), dir_class.ASCENDING) - self.assertEqual(self._call_fut(AsyncQuery.DESCENDING), dir_class.DESCENDING) - - # Ints pass through - self.assertEqual(self._call_fut(dir_class.ASCENDING), dir_class.ASCENDING) - self.assertEqual(self._call_fut(dir_class.DESCENDING), dir_class.DESCENDING) - - def test_failure(self): - with self.assertRaises(ValueError): - self._call_fut("neither-ASCENDING-nor-DESCENDING") - - -class Test__filter_pb(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(field_or_unary): - from google.cloud.firestore_v1.query import _filter_pb - - return _filter_pb(field_or_unary) - - def test_unary(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import query_pb2 - - unary_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="a.b.c"), - op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, - ) - filter_pb = self._call_fut(unary_pb) - expected_pb = query_pb2.StructuredQuery.Filter(unary_filter=unary_pb) - self.assertEqual(filter_pb, expected_pb) - - def test_field(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - field_filter_pb = query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="XYZ"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(double_value=90.75), - ) - filter_pb = self._call_fut(field_filter_pb) - expected_pb = query_pb2.StructuredQuery.Filter(field_filter=field_filter_pb) - self.assertEqual(filter_pb, expected_pb) - - def test_bad_type(self): - with self.assertRaises(ValueError): - self._call_fut(None) - - -class Test__cursor_pb(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(cursor_pair): - from google.cloud.firestore_v1.query import _cursor_pb - - return _cursor_pb(cursor_pair) - - def test_no_pair(self): - self.assertIsNone(self._call_fut(None)) - - def test_success(self): - from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1 import _helpers - - data = [1.5, 10, True] - cursor_pair = data, True - - cursor_pb = self._call_fut(cursor_pair) - - expected_pb = query_pb2.Cursor( - values=[_helpers.encode_value(value) for value in data], before=True - ) - self.assertEqual(cursor_pb, expected_pb) - - -class Test__query_response_to_snapshot(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(response_pb, collection, expected_prefix): - from google.cloud.firestore_v1.async_query import _query_response_to_snapshot - - return _query_response_to_snapshot(response_pb, collection, expected_prefix) - - def test_empty(self): - response_pb = _make_query_response() - snapshot = self._call_fut(response_pb, None, None) - self.assertIsNone(snapshot) - - def test_after_offset(self): - skipped_results = 410 - response_pb = _make_query_response(skipped_results=skipped_results) - snapshot = self._call_fut(response_pb, None, None) - self.assertIsNone(snapshot) - - def test_response(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - - client = _make_client() - collection = client.collection("a", "b", "c") - _, expected_prefix = collection._parent_info() - - # Create name for the protobuf. - doc_id = "gigantic" - name = "{}/{}".format(expected_prefix, doc_id) - data = {"a": 901, "b": True} - response_pb = _make_query_response(name=name, data=data) - - snapshot = self._call_fut(response_pb, collection, expected_prefix) - self.assertIsInstance(snapshot, DocumentSnapshot) - expected_path = collection._path + (doc_id,) - self.assertEqual(snapshot.reference._path, expected_path) - self.assertEqual(snapshot.to_dict(), data) - self.assertTrue(snapshot.exists) - self.assertEqual(snapshot.read_time, response_pb.read_time) - self.assertEqual(snapshot.create_time, response_pb.document.create_time) - self.assertEqual(snapshot.update_time, response_pb.document.update_time) - - -class Test__collection_group_query_response_to_snapshot(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(response_pb, collection): - from google.cloud.firestore_v1.async_query import ( - _collection_group_query_response_to_snapshot, - ) - - return _collection_group_query_response_to_snapshot(response_pb, collection) - - def test_empty(self): - response_pb = _make_query_response() - snapshot = self._call_fut(response_pb, None) - self.assertIsNone(snapshot) - - def test_after_offset(self): - skipped_results = 410 - response_pb = _make_query_response(skipped_results=skipped_results) - snapshot = self._call_fut(response_pb, None) - self.assertIsNone(snapshot) - - def test_response(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - - client = _make_client() - collection = client.collection("a", "b", "c") - other_collection = client.collection("a", "b", "d") - to_match = other_collection.document("gigantic") - data = {"a": 901, "b": True} - response_pb = _make_query_response(name=to_match._document_path, data=data) - - snapshot = self._call_fut(response_pb, collection) - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertEqual(snapshot.reference._document_path, to_match._document_path) - self.assertEqual(snapshot.to_dict(), data) - self.assertTrue(snapshot.exists) - self.assertEqual(snapshot.read_time, response_pb.read_time) - self.assertEqual(snapshot.create_time, response_pb.document.create_time) - self.assertEqual(snapshot.update_time, response_pb.document.update_time) - - -def _make_credentials(): - import google.auth.credentials - - return mock.Mock(spec=google.auth.credentials.Credentials) - def _make_client(project="project-project"): - from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient credentials = _make_credentials() - return Client(project=project, credentials=credentials) - - -def _make_order_pb(field_path, direction): - from google.cloud.firestore_v1.proto import query_pb2 - - return query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - direction=direction, - ) - - -def _make_query_response(**kwargs): - # kwargs supported are ``skipped_results``, ``name`` and ``data`` - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1 import _helpers - - now = datetime.datetime.utcnow() - read_time = _datetime_to_pb_timestamp(now) - kwargs["read_time"] = read_time - - name = kwargs.pop("name", None) - data = kwargs.pop("data", None) - if name is not None and data is not None: - document_pb = document_pb2.Document( - name=name, fields=_helpers.encode_dict(data) - ) - delta = datetime.timedelta(seconds=100) - update_time = _datetime_to_pb_timestamp(now - delta) - create_time = _datetime_to_pb_timestamp(now - 2 * delta) - document_pb.update_time.CopyFrom(update_time) - document_pb.create_time.CopyFrom(create_time) - - kwargs["document"] = document_pb - - return firestore_pb2.RunQueryResponse(**kwargs) + return AsyncClient(project=project, credentials=credentials) diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py new file mode 100644 index 0000000000..f65c425605 --- /dev/null +++ b/tests/unit/v1/test_base_query.py @@ -0,0 +1,1441 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import unittest + +import mock +import six + + +class TestBaseQuery(unittest.TestCase): + + if six.PY2: + assertRaisesRegex = unittest.TestCase.assertRaisesRegexp + + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.query import Query + + return Query + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor_defaults(self): + query = self._make_one(mock.sentinel.parent) + self.assertIs(query._parent, mock.sentinel.parent) + self.assertIsNone(query._projection) + self.assertEqual(query._field_filters, ()) + self.assertEqual(query._orders, ()) + self.assertIsNone(query._limit) + self.assertIsNone(query._offset) + self.assertIsNone(query._start_at) + self.assertIsNone(query._end_at) + self.assertFalse(query._all_descendants) + + def _make_one_all_fields( + self, limit=9876, offset=12, skip_fields=(), parent=None, all_descendants=True + ): + kwargs = { + "projection": mock.sentinel.projection, + "field_filters": mock.sentinel.filters, + "orders": mock.sentinel.orders, + "limit": limit, + "offset": offset, + "start_at": mock.sentinel.start_at, + "end_at": mock.sentinel.end_at, + "all_descendants": all_descendants, + } + for field in skip_fields: + kwargs.pop(field) + if parent is None: + parent = mock.sentinel.parent + return self._make_one(parent, **kwargs) + + def test_constructor_explicit(self): + limit = 234 + offset = 56 + query = self._make_one_all_fields(limit=limit, offset=offset) + self.assertIs(query._parent, mock.sentinel.parent) + self.assertIs(query._projection, mock.sentinel.projection) + self.assertIs(query._field_filters, mock.sentinel.filters) + self.assertEqual(query._orders, mock.sentinel.orders) + self.assertEqual(query._limit, limit) + self.assertEqual(query._offset, offset) + self.assertIs(query._start_at, mock.sentinel.start_at) + self.assertIs(query._end_at, mock.sentinel.end_at) + self.assertTrue(query._all_descendants) + + def test__client_property(self): + parent = mock.Mock(_client=mock.sentinel.client, spec=["_client"]) + query = self._make_one(parent) + self.assertIs(query._client, mock.sentinel.client) + + def test___eq___other_type(self): + query = self._make_one_all_fields() + other = object() + self.assertFalse(query == other) + + def test___eq___different_parent(self): + parent = mock.sentinel.parent + other_parent = mock.sentinel.other_parent + query = self._make_one_all_fields(parent=parent) + other = self._make_one_all_fields(parent=other_parent) + self.assertFalse(query == other) + + def test___eq___different_projection(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) + query._projection = mock.sentinel.projection + other = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) + other._projection = mock.sentinel.other_projection + self.assertFalse(query == other) + + def test___eq___different_field_filters(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) + query._field_filters = mock.sentinel.field_filters + other = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) + other._field_filters = mock.sentinel.other_field_filters + self.assertFalse(query == other) + + def test___eq___different_orders(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) + query._orders = mock.sentinel.orders + other = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) + other._orders = mock.sentinel.other_orders + self.assertFalse(query == other) + + def test___eq___different_limit(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, limit=10) + other = self._make_one_all_fields(parent=parent, limit=20) + self.assertFalse(query == other) + + def test___eq___different_offset(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, offset=10) + other = self._make_one_all_fields(parent=parent, offset=20) + self.assertFalse(query == other) + + def test___eq___different_start_at(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) + query._start_at = mock.sentinel.start_at + other = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) + other._start_at = mock.sentinel.other_start_at + self.assertFalse(query == other) + + def test___eq___different_end_at(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) + query._end_at = mock.sentinel.end_at + other = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) + other._end_at = mock.sentinel.other_end_at + self.assertFalse(query == other) + + def test___eq___different_all_descendants(self): + parent = mock.sentinel.parent + query = self._make_one_all_fields(parent=parent, all_descendants=True) + other = self._make_one_all_fields(parent=parent, all_descendants=False) + self.assertFalse(query == other) + + def test___eq___hit(self): + query = self._make_one_all_fields() + other = self._make_one_all_fields() + self.assertTrue(query == other) + + def _compare_queries(self, query1, query2, attr_name): + attrs1 = query1.__dict__.copy() + attrs2 = query2.__dict__.copy() + + attrs1.pop(attr_name) + attrs2.pop(attr_name) + + # The only different should be in ``attr_name``. + self.assertEqual(len(attrs1), len(attrs2)) + for key, value in attrs1.items(): + self.assertIs(value, attrs2[key]) + + @staticmethod + def _make_projection_for_select(field_paths): + from google.cloud.firestore_v1.proto import query_pb2 + + return query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths + ] + ) + + def test_select_invalid_path(self): + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query.select(["*"]) + + def test_select(self): + query1 = self._make_one_all_fields(all_descendants=True) + + field_paths2 = ["foo", "bar"] + query2 = query1.select(field_paths2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, self._get_target_class()) + self.assertEqual( + query2._projection, self._make_projection_for_select(field_paths2) + ) + self._compare_queries(query1, query2, "_projection") + + # Make sure it overrides. + field_paths3 = ["foo.baz"] + query3 = query2.select(field_paths3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual( + query3._projection, self._make_projection_for_select(field_paths3) + ) + self._compare_queries(query2, query3, "_projection") + + def test_where_invalid_path(self): + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query.where("*", "==", 1) + + def test_where(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + query = self._make_one_all_fields( + skip_fields=("field_filters",), all_descendants=True + ) + new_query = query.where("power.level", ">", 9000) + + self.assertIsNot(query, new_query) + self.assertIsInstance(new_query, self._get_target_class()) + self.assertEqual(len(new_query._field_filters), 1) + + field_pb = new_query._field_filters[0] + expected_pb = query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="power.level"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(integer_value=9000), + ) + self.assertEqual(field_pb, expected_pb) + self._compare_queries(query, new_query, "_field_filters") + + def _where_unary_helper(self, value, op_enum, op_string="=="): + from google.cloud.firestore_v1.proto import query_pb2 + + query = self._make_one_all_fields(skip_fields=("field_filters",)) + field_path = "feeeld" + new_query = query.where(field_path, op_string, value) + + self.assertIsNot(query, new_query) + self.assertIsInstance(new_query, self._get_target_class()) + self.assertEqual(len(new_query._field_filters), 1) + + field_pb = new_query._field_filters[0] + expected_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + op=op_enum, + ) + self.assertEqual(field_pb, expected_pb) + self._compare_queries(query, new_query, "_field_filters") + + def test_where_eq_null(self): + from google.cloud.firestore_v1.gapic import enums + + op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NULL + self._where_unary_helper(None, op_enum) + + def test_where_gt_null(self): + with self.assertRaises(ValueError): + self._where_unary_helper(None, 0, op_string=">") + + def test_where_eq_nan(self): + from google.cloud.firestore_v1.gapic import enums + + op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NAN + self._where_unary_helper(float("nan"), op_enum) + + def test_where_le_nan(self): + with self.assertRaises(ValueError): + self._where_unary_helper(float("nan"), 0, op_string="<=") + + def test_where_w_delete(self): + from google.cloud.firestore_v1 import DELETE_FIELD + + with self.assertRaises(ValueError): + self._where_unary_helper(DELETE_FIELD, 0) + + def test_where_w_server_timestamp(self): + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + + with self.assertRaises(ValueError): + self._where_unary_helper(SERVER_TIMESTAMP, 0) + + def test_where_w_array_remove(self): + from google.cloud.firestore_v1 import ArrayRemove + + with self.assertRaises(ValueError): + self._where_unary_helper(ArrayRemove([1, 3, 5]), 0) + + def test_where_w_array_union(self): + from google.cloud.firestore_v1 import ArrayUnion + + with self.assertRaises(ValueError): + self._where_unary_helper(ArrayUnion([2, 4, 8]), 0) + + def test_order_by_invalid_path(self): + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query.order_by("*") + + def test_order_by(self): + from google.cloud.firestore_v1.gapic import enums + + klass = self._get_target_class() + query1 = self._make_one_all_fields( + skip_fields=("orders",), all_descendants=True + ) + + field_path2 = "a" + query2 = query1.order_by(field_path2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, klass) + order_pb2 = _make_order_pb( + field_path2, enums.StructuredQuery.Direction.ASCENDING + ) + self.assertEqual(query2._orders, (order_pb2,)) + self._compare_queries(query1, query2, "_orders") + + # Make sure it appends to the orders. + field_path3 = "b" + query3 = query2.order_by(field_path3, direction=klass.DESCENDING) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, klass) + order_pb3 = _make_order_pb( + field_path3, enums.StructuredQuery.Direction.DESCENDING + ) + self.assertEqual(query3._orders, (order_pb2, order_pb3)) + self._compare_queries(query2, query3, "_orders") + + def test_limit(self): + query1 = self._make_one_all_fields(all_descendants=True) + + limit2 = 100 + query2 = query1.limit(limit2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, self._get_target_class()) + self.assertEqual(query2._limit, limit2) + self._compare_queries(query1, query2, "_limit") + + # Make sure it overrides. + limit3 = 10 + query3 = query2.limit(limit3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._limit, limit3) + self._compare_queries(query2, query3, "_limit") + + def test_offset(self): + query1 = self._make_one_all_fields(all_descendants=True) + + offset2 = 23 + query2 = query1.offset(offset2) + self.assertIsNot(query2, query1) + self.assertIsInstance(query2, self._get_target_class()) + self.assertEqual(query2._offset, offset2) + self._compare_queries(query1, query2, "_offset") + + # Make sure it overrides. + offset3 = 35 + query3 = query2.offset(offset3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._offset, offset3) + self._compare_queries(query2, query3, "_offset") + + @staticmethod + def _make_collection(*path, **kw): + from google.cloud.firestore_v1 import collection + + return collection.CollectionReference(*path, **kw) + + @staticmethod + def _make_docref(*path, **kw): + from google.cloud.firestore_v1 import document + + return document.DocumentReference(*path, **kw) + + @staticmethod + def _make_snapshot(docref, values): + from google.cloud.firestore_v1 import document + + return document.DocumentSnapshot(docref, values, True, None, None, None) + + def test__cursor_helper_w_dict(self): + values = {"a": 7, "b": "foo"} + query1 = self._make_one(mock.sentinel.parent) + query1._all_descendants = True + query2 = query1._cursor_helper(values, True, True) + + self.assertIs(query2._parent, mock.sentinel.parent) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, query1._orders) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._end_at) + self.assertTrue(query2._all_descendants) + + cursor, before = query2._start_at + + self.assertEqual(cursor, values) + self.assertTrue(before) + + def test__cursor_helper_w_tuple(self): + values = (7, "foo") + query1 = self._make_one(mock.sentinel.parent) + query2 = query1._cursor_helper(values, False, True) + + self.assertIs(query2._parent, mock.sentinel.parent) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, query1._orders) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._end_at) + + cursor, before = query2._start_at + + self.assertEqual(cursor, list(values)) + self.assertFalse(before) + + def test__cursor_helper_w_list(self): + values = [7, "foo"] + query1 = self._make_one(mock.sentinel.parent) + query2 = query1._cursor_helper(values, True, False) + + self.assertIs(query2._parent, mock.sentinel.parent) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, query1._orders) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._start_at) + + cursor, before = query2._end_at + + self.assertEqual(cursor, values) + self.assertIsNot(cursor, values) + self.assertTrue(before) + + def test__cursor_helper_w_snapshot_wrong_collection(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("there", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = self._make_one(collection) + + with self.assertRaises(ValueError): + query._cursor_helper(snapshot, False, False) + + def test__cursor_helper_w_snapshot_other_collection_all_descendants(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("there", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query1 = self._make_one(collection, all_descendants=True) + + query2 = query1._cursor_helper(snapshot, False, False) + + self.assertIs(query2._parent, collection) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, ()) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._start_at) + + cursor, before = query2._end_at + + self.assertIs(cursor, snapshot) + self.assertFalse(before) + + def test__cursor_helper_w_snapshot(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query1 = self._make_one(collection) + + query2 = query1._cursor_helper(snapshot, False, False) + + self.assertIs(query2._parent, collection) + self.assertIsNone(query2._projection) + self.assertEqual(query2._field_filters, ()) + self.assertEqual(query2._orders, ()) + self.assertIsNone(query2._limit) + self.assertIsNone(query2._offset) + self.assertIsNone(query2._start_at) + + cursor, before = query2._end_at + + self.assertIs(cursor, snapshot) + self.assertFalse(before) + + def test_start_at(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields( + parent=collection, skip_fields=("orders",), all_descendants=True + ) + query2 = query1.order_by("hi") + + document_fields3 = {"hi": "mom"} + query3 = query2.start_at(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._start_at, (document_fields3, True)) + self._compare_queries(query2, query3, "_start_at") + + # Make sure it overrides. + query4 = query3.order_by("bye") + values5 = {"hi": "zap", "bye": 88} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.start_at(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._start_at, (document_fields5, True)) + self._compare_queries(query4, query5, "_start_at") + + def test_start_after(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("down") + + document_fields3 = {"down": 99.75} + query3 = query2.start_after(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._start_at, (document_fields3, False)) + self._compare_queries(query2, query3, "_start_at") + + # Make sure it overrides. + query4 = query3.order_by("out") + values5 = {"down": 100.25, "out": b"\x00\x01"} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.start_after(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._start_at, (document_fields5, False)) + self._compare_queries(query4, query5, "_start_at") + + def test_end_before(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("down") + + document_fields3 = {"down": 99.75} + query3 = query2.end_before(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._end_at, (document_fields3, True)) + self._compare_queries(query2, query3, "_end_at") + + # Make sure it overrides. + query4 = query3.order_by("out") + values5 = {"down": 100.25, "out": b"\x00\x01"} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.end_before(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._end_at, (document_fields5, True)) + self._compare_queries(query4, query5, "_end_at") + self._compare_queries(query4, query5, "_end_at") + + def test_end_at(self): + collection = self._make_collection("here") + query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("hi") + + document_fields3 = {"hi": "mom"} + query3 = query2.end_at(document_fields3) + self.assertIsNot(query3, query2) + self.assertIsInstance(query3, self._get_target_class()) + self.assertEqual(query3._end_at, (document_fields3, False)) + self._compare_queries(query2, query3, "_end_at") + + # Make sure it overrides. + query4 = query3.order_by("bye") + values5 = {"hi": "zap", "bye": 88} + docref = self._make_docref("here", "doc_id") + document_fields5 = self._make_snapshot(docref, values5) + query5 = query4.end_at(document_fields5) + self.assertIsNot(query5, query4) + self.assertIsInstance(query5, self._get_target_class()) + self.assertEqual(query5._end_at, (document_fields5, False)) + self._compare_queries(query4, query5, "_end_at") + + def test__filters_pb_empty(self): + query = self._make_one(mock.sentinel.parent) + self.assertEqual(len(query._field_filters), 0) + self.assertIsNone(query._filters_pb()) + + def test__filters_pb_single(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + query1 = self._make_one(mock.sentinel.parent) + query2 = query1.where("x.y", ">", 50.5) + filter_pb = query2._filters_pb() + expected_pb = query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="x.y"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(double_value=50.5), + ) + ) + self.assertEqual(filter_pb, expected_pb) + + def test__filters_pb_multi(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + query1 = self._make_one(mock.sentinel.parent) + query2 = query1.where("x.y", ">", 50.5) + query3 = query2.where("ABC", "==", 123) + + filter_pb = query3._filters_pb() + op_class = enums.StructuredQuery.FieldFilter.Operator + expected_pb = query_pb2.StructuredQuery.Filter( + composite_filter=query_pb2.StructuredQuery.CompositeFilter( + op=enums.StructuredQuery.CompositeFilter.Operator.AND, + filters=[ + query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference( + field_path="x.y" + ), + op=op_class.GREATER_THAN, + value=document_pb2.Value(double_value=50.5), + ) + ), + query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference( + field_path="ABC" + ), + op=op_class.EQUAL, + value=document_pb2.Value(integer_value=123), + ) + ), + ], + ) + ) + self.assertEqual(filter_pb, expected_pb) + + def test__normalize_projection_none(self): + query = self._make_one(mock.sentinel.parent) + self.assertIsNone(query._normalize_projection(None)) + + def test__normalize_projection_empty(self): + projection = self._make_projection_for_select([]) + query = self._make_one(mock.sentinel.parent) + normalized = query._normalize_projection(projection) + field_paths = [field_ref.field_path for field_ref in normalized.fields] + self.assertEqual(field_paths, ["__name__"]) + + def test__normalize_projection_non_empty(self): + projection = self._make_projection_for_select(["a", "b"]) + query = self._make_one(mock.sentinel.parent) + self.assertIs(query._normalize_projection(projection), projection) + + def test__normalize_orders_wo_orders_wo_cursors(self): + query = self._make_one(mock.sentinel.parent) + expected = [] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_w_orders_wo_cursors(self): + query = self._make_one(mock.sentinel.parent).order_by("a") + expected = [query._make_order("a", "ASCENDING")] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_wo_orders_w_snapshot_cursor(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = self._make_one(collection).start_at(snapshot) + expected = [query._make_order("__name__", "ASCENDING")] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_w_name_orders_w_snapshot_cursor(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = ( + self._make_one(collection) + .order_by("__name__", "DESCENDING") + .start_at(snapshot) + ) + expected = [query._make_order("__name__", "DESCENDING")] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = ( + self._make_one(collection) + .where("c", "<=", 20) + .order_by("c", "DESCENDING") + .start_at(snapshot) + ) + expected = [ + query._make_order("c", "DESCENDING"), + query._make_order("__name__", "DESCENDING"), + ] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where(self): + values = {"a": 7, "b": "foo"} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + collection = self._make_collection("here") + query = self._make_one(collection).where("c", "<=", 20).end_at(snapshot) + expected = [ + query._make_order("c", "ASCENDING"), + query._make_order("__name__", "ASCENDING"), + ] + self.assertEqual(query._normalize_orders(), expected) + + def test__normalize_cursor_none(self): + query = self._make_one(mock.sentinel.parent) + self.assertIsNone(query._normalize_cursor(None, query._orders)) + + def test__normalize_cursor_no_order(self): + cursor = ([1], True) + query = self._make_one(mock.sentinel.parent) + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_as_list_mismatched_order(self): + cursor = ([1, 2], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_as_dict_mismatched_order(self): + cursor = ({"a": 1}, True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_delete(self): + from google.cloud.firestore_v1 import DELETE_FIELD + + cursor = ([DELETE_FIELD], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_server_timestamp(self): + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + + cursor = ([SERVER_TIMESTAMP], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_array_remove(self): + from google.cloud.firestore_v1 import ArrayRemove + + cursor = ([ArrayRemove([1, 3, 5])], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_w_array_union(self): + from google.cloud.firestore_v1 import ArrayUnion + + cursor = ([ArrayUnion([2, 4, 8])], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + with self.assertRaises(ValueError): + query._normalize_cursor(cursor, query._orders) + + def test__normalize_cursor_as_list_hit(self): + cursor = ([1], True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_dict_hit(self): + cursor = ({"b": 1}, True) + query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_dict_with_dot_key_hit(self): + cursor = ({"b.a": 1}, True) + query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_dict_with_inner_data_hit(self): + cursor = ({"b": {"a": 1}}, True) + query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_as_snapshot_hit(self): + values = {"b": 1} + docref = self._make_docref("here", "doc_id") + snapshot = self._make_snapshot(docref, values) + cursor = (snapshot, True) + collection = self._make_collection("here") + query = self._make_one(collection).order_by("b", "ASCENDING") + + self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + + def test__normalize_cursor_w___name___w_reference(self): + db_string = "projects/my-project/database/(default)" + client = mock.Mock(spec=["_database_string"]) + client._database_string = db_string + parent = mock.Mock(spec=["_path", "_client"]) + parent._client = client + parent._path = ["C"] + query = self._make_one(parent).order_by("__name__", "ASCENDING") + docref = self._make_docref("here", "doc_id") + values = {"a": 7} + snapshot = self._make_snapshot(docref, values) + expected = docref + cursor = (snapshot, True) + + self.assertEqual( + query._normalize_cursor(cursor, query._orders), ([expected], True) + ) + + def test__normalize_cursor_w___name___wo_slash(self): + db_string = "projects/my-project/database/(default)" + client = mock.Mock(spec=["_database_string"]) + client._database_string = db_string + parent = mock.Mock(spec=["_path", "_client", "document"]) + parent._client = client + parent._path = ["C"] + document = parent.document.return_value = mock.Mock(spec=[]) + query = self._make_one(parent).order_by("__name__", "ASCENDING") + cursor = (["b"], True) + expected = document + + self.assertEqual( + query._normalize_cursor(cursor, query._orders), ([expected], True) + ) + parent.document.assert_called_once_with("b") + + def test__to_protobuf_all_fields(self): + from google.protobuf import wrappers_pb2 + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="cat", spec=["id"]) + query1 = self._make_one(parent) + query2 = query1.select(["X", "Y", "Z"]) + query3 = query2.where("Y", ">", 2.5) + query4 = query3.order_by("X") + query5 = query4.limit(17) + query6 = query5.offset(3) + query7 = query6.start_at({"X": 10}) + query8 = query7.end_at({"X": 25}) + + structured_query_pb = query8._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "select": query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in ["X", "Y", "Z"] + ] + ), + "where": query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="Y"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(double_value=2.5), + ) + ), + "order_by": [ + _make_order_pb("X", enums.StructuredQuery.Direction.ASCENDING) + ], + "start_at": query_pb2.Cursor( + values=[document_pb2.Value(integer_value=10)], before=True + ), + "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=25)]), + "offset": 3, + "limit": wrappers_pb2.Int32Value(value=17), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_select_only(self): + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="cat", spec=["id"]) + query1 = self._make_one(parent) + field_paths = ["a.b", "a.c", "d"] + query2 = query1.select(field_paths) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "select": query_pb2.StructuredQuery.Projection( + fields=[ + query_pb2.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths + ] + ), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_where_only(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="dog", spec=["id"]) + query1 = self._make_one(parent) + query2 = query1.where("a", "==", u"b") + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "where": query_pb2.StructuredQuery.Filter( + field_filter=query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="a"), + op=enums.StructuredQuery.FieldFilter.Operator.EQUAL, + value=document_pb2.Value(string_value=u"b"), + ) + ), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_order_by_only(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="fish", spec=["id"]) + query1 = self._make_one(parent) + query2 = query1.order_by("abc") + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "order_by": [ + _make_order_pb("abc", enums.StructuredQuery.Direction.ASCENDING) + ], + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_start_at_only(self): + # NOTE: "only" is wrong since we must have ``order_by`` as well. + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="phish", spec=["id"]) + query = self._make_one(parent).order_by("X.Y").start_after({"X": {"Y": u"Z"}}) + + structured_query_pb = query._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "order_by": [ + _make_order_pb("X.Y", enums.StructuredQuery.Direction.ASCENDING) + ], + "start_at": query_pb2.Cursor( + values=[document_pb2.Value(string_value=u"Z")] + ), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_end_at_only(self): + # NOTE: "only" is wrong since we must have ``order_by`` as well. + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="ghoti", spec=["id"]) + query = self._make_one(parent).order_by("a").end_at({"a": 88}) + + structured_query_pb = query._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "order_by": [ + _make_order_pb("a", enums.StructuredQuery.Direction.ASCENDING) + ], + "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=88)]), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_offset_only(self): + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="cartt", spec=["id"]) + query1 = self._make_one(parent) + offset = 14 + query2 = query1.offset(offset) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "offset": offset, + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + self.assertEqual(structured_query_pb, expected_pb) + + def test__to_protobuf_limit_only(self): + from google.protobuf import wrappers_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + parent = mock.Mock(id="donut", spec=["id"]) + query1 = self._make_one(parent) + limit = 31 + query2 = query1.limit(limit) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from": [ + query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) + ], + "limit": wrappers_pb2.Int32Value(value=limit), + } + expected_pb = query_pb2.StructuredQuery(**query_kwargs) + + self.assertEqual(structured_query_pb, expected_pb) + + def test_comparator_no_ordering(self): + query = self._make_one(mock.sentinel.parent) + query._orders = [] + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, -1) + + def test_comparator_no_ordering_same_id(self): + query = self._make_one(mock.sentinel.parent) + query._orders = [] + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument1") + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, 0) + + def test_comparator_ordering(self): + query = self._make_one(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = 1 # ascending + query._orders = [orderByMock] + + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "secondlovelace"}, + } + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, 1) + + def test_comparator_ordering_descending(self): + query = self._make_one(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = -1 # descending + query._orders = [orderByMock] + + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "secondlovelace"}, + } + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } + + sort = query._comparator(doc1, doc2) + self.assertEqual(sort, -1) + + def test_comparator_missing_order_by_field_in_data_raises(self): + query = self._make_one(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = 1 # ascending + query._orders = [orderByMock] + + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = {} + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } + + with self.assertRaisesRegex(ValueError, "Can only compare fields "): + query._comparator(doc1, doc2) + + +class Test__enum_from_op_string(unittest.TestCase): + @staticmethod + def _call_fut(op_string): + from google.cloud.firestore_v1.base_query import _enum_from_op_string + + return _enum_from_op_string(op_string) + + @staticmethod + def _get_op_class(): + from google.cloud.firestore_v1.gapic import enums + + return enums.StructuredQuery.FieldFilter.Operator + + def test_lt(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("<"), op_class.LESS_THAN) + + def test_le(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("<="), op_class.LESS_THAN_OR_EQUAL) + + def test_eq(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("=="), op_class.EQUAL) + + def test_ge(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut(">="), op_class.GREATER_THAN_OR_EQUAL) + + def test_gt(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut(">"), op_class.GREATER_THAN) + + def test_array_contains(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) + + def test_in(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("in"), op_class.IN) + + def test_array_contains_any(self): + op_class = self._get_op_class() + self.assertEqual( + self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY + ) + + def test_invalid(self): + with self.assertRaises(ValueError): + self._call_fut("?") + + +class Test__isnan(unittest.TestCase): + @staticmethod + def _call_fut(value): + from google.cloud.firestore_v1.base_query import _isnan + + return _isnan(value) + + def test_valid(self): + self.assertTrue(self._call_fut(float("nan"))) + + def test_invalid(self): + self.assertFalse(self._call_fut(51.5)) + self.assertFalse(self._call_fut(None)) + self.assertFalse(self._call_fut("str")) + self.assertFalse(self._call_fut(int)) + self.assertFalse(self._call_fut(1.0 + 1.0j)) + + +class Test__enum_from_direction(unittest.TestCase): + @staticmethod + def _call_fut(direction): + from google.cloud.firestore_v1.base_query import _enum_from_direction + + return _enum_from_direction(direction) + + def test_success(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.query import Query + + dir_class = enums.StructuredQuery.Direction + self.assertEqual(self._call_fut(Query.ASCENDING), dir_class.ASCENDING) + self.assertEqual(self._call_fut(Query.DESCENDING), dir_class.DESCENDING) + + # Ints pass through + self.assertEqual(self._call_fut(dir_class.ASCENDING), dir_class.ASCENDING) + self.assertEqual(self._call_fut(dir_class.DESCENDING), dir_class.DESCENDING) + + def test_failure(self): + with self.assertRaises(ValueError): + self._call_fut("neither-ASCENDING-nor-DESCENDING") + + +class Test__filter_pb(unittest.TestCase): + @staticmethod + def _call_fut(field_or_unary): + from google.cloud.firestore_v1.base_query import _filter_pb + + return _filter_pb(field_or_unary) + + def test_unary(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import query_pb2 + + unary_pb = query_pb2.StructuredQuery.UnaryFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="a.b.c"), + op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, + ) + filter_pb = self._call_fut(unary_pb) + expected_pb = query_pb2.StructuredQuery.Filter(unary_filter=unary_pb) + self.assertEqual(filter_pb, expected_pb) + + def test_field(self): + from google.cloud.firestore_v1.gapic import enums + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import query_pb2 + + field_filter_pb = query_pb2.StructuredQuery.FieldFilter( + field=query_pb2.StructuredQuery.FieldReference(field_path="XYZ"), + op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document_pb2.Value(double_value=90.75), + ) + filter_pb = self._call_fut(field_filter_pb) + expected_pb = query_pb2.StructuredQuery.Filter(field_filter=field_filter_pb) + self.assertEqual(filter_pb, expected_pb) + + def test_bad_type(self): + with self.assertRaises(ValueError): + self._call_fut(None) + + +class Test__cursor_pb(unittest.TestCase): + @staticmethod + def _call_fut(cursor_pair): + from google.cloud.firestore_v1.base_query import _cursor_pb + + return _cursor_pb(cursor_pair) + + def test_no_pair(self): + self.assertIsNone(self._call_fut(None)) + + def test_success(self): + from google.cloud.firestore_v1.proto import query_pb2 + from google.cloud.firestore_v1 import _helpers + + data = [1.5, 10, True] + cursor_pair = data, True + + cursor_pb = self._call_fut(cursor_pair) + + expected_pb = query_pb2.Cursor( + values=[_helpers.encode_value(value) for value in data], before=True + ) + self.assertEqual(cursor_pb, expected_pb) + + +class Test__query_response_to_snapshot(unittest.TestCase): + @staticmethod + def _call_fut(response_pb, collection, expected_prefix): + from google.cloud.firestore_v1.base_query import _query_response_to_snapshot + + return _query_response_to_snapshot(response_pb, collection, expected_prefix) + + def test_empty(self): + response_pb = _make_query_response() + snapshot = self._call_fut(response_pb, None, None) + self.assertIsNone(snapshot) + + def test_after_offset(self): + skipped_results = 410 + response_pb = _make_query_response(skipped_results=skipped_results) + snapshot = self._call_fut(response_pb, None, None) + self.assertIsNone(snapshot) + + def test_response(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + + client = _make_client() + collection = client.collection("a", "b", "c") + _, expected_prefix = collection._parent_info() + + # Create name for the protobuf. + doc_id = "gigantic" + name = "{}/{}".format(expected_prefix, doc_id) + data = {"a": 901, "b": True} + response_pb = _make_query_response(name=name, data=data) + + snapshot = self._call_fut(response_pb, collection, expected_prefix) + self.assertIsInstance(snapshot, DocumentSnapshot) + expected_path = collection._path + (doc_id,) + self.assertEqual(snapshot.reference._path, expected_path) + self.assertEqual(snapshot.to_dict(), data) + self.assertTrue(snapshot.exists) + self.assertEqual(snapshot.read_time, response_pb.read_time) + self.assertEqual(snapshot.create_time, response_pb.document.create_time) + self.assertEqual(snapshot.update_time, response_pb.document.update_time) + + +class Test__collection_group_query_response_to_snapshot(unittest.TestCase): + @staticmethod + def _call_fut(response_pb, collection): + from google.cloud.firestore_v1.query import ( + _collection_group_query_response_to_snapshot, + ) + + return _collection_group_query_response_to_snapshot(response_pb, collection) + + def test_empty(self): + response_pb = _make_query_response() + snapshot = self._call_fut(response_pb, None) + self.assertIsNone(snapshot) + + def test_after_offset(self): + skipped_results = 410 + response_pb = _make_query_response(skipped_results=skipped_results) + snapshot = self._call_fut(response_pb, None) + self.assertIsNone(snapshot) + + def test_response(self): + from google.cloud.firestore_v1.document import DocumentSnapshot + + client = _make_client() + collection = client.collection("a", "b", "c") + other_collection = client.collection("a", "b", "d") + to_match = other_collection.document("gigantic") + data = {"a": 901, "b": True} + response_pb = _make_query_response(name=to_match._document_path, data=data) + + snapshot = self._call_fut(response_pb, collection) + self.assertIsInstance(snapshot, DocumentSnapshot) + self.assertEqual(snapshot.reference._document_path, to_match._document_path) + self.assertEqual(snapshot.to_dict(), data) + self.assertTrue(snapshot.exists) + self.assertEqual(snapshot.read_time, response_pb.read_time) + self.assertEqual(snapshot.create_time, response_pb.document.create_time) + self.assertEqual(snapshot.update_time, response_pb.document.update_time) + + +def _make_credentials(): + import google.auth.credentials + + return mock.Mock(spec=google.auth.credentials.Credentials) + + +def _make_client(project="project-project"): + from google.cloud.firestore_v1.client import Client + + credentials = _make_credentials() + return Client(project=project, credentials=credentials) + + +def _make_order_pb(field_path, direction): + from google.cloud.firestore_v1.proto import query_pb2 + + return query_pb2.StructuredQuery.Order( + field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), + direction=direction, + ) + + +def _make_query_response(**kwargs): + # kwargs supported are ``skipped_results``, ``name`` and ``data`` + from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + kwargs["read_time"] = read_time + + name = kwargs.pop("name", None) + data = kwargs.pop("data", None) + if name is not None and data is not None: + document_pb = document_pb2.Document( + name=name, fields=_helpers.encode_dict(data) + ) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + document_pb.update_time.CopyFrom(update_time) + document_pb.create_time.CopyFrom(create_time) + + kwargs["document"] = document_pb + + return firestore_pb2.RunQueryResponse(**kwargs) diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 1ef2e66746..4188d959f4 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -206,7 +206,7 @@ def test_select(self): def _make_field_filter_pb(field_path, op_string, value): from google.cloud.firestore_v1.proto import query_pb2 from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.query import _enum_from_op_string + from google.cloud.firestore_v1.base_query import _enum_from_op_string return query_pb2.StructuredQuery.FieldFilter( field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), @@ -234,7 +234,7 @@ def test_where(self): @staticmethod def _make_order_pb(field_path, direction): from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1.query import _enum_from_direction + from google.cloud.firestore_v1.base_query import _enum_from_direction return query_pb2.StructuredQuery.Order( field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index bdb0e922d0..896706c748 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime import types import unittest import mock import six +from tests.unit.v1.test_base_query import _make_credentials, _make_query_response + class TestQuery(unittest.TestCase): @@ -47,1021 +48,6 @@ def test_constructor_defaults(self): self.assertIsNone(query._end_at) self.assertFalse(query._all_descendants) - def _make_one_all_fields( - self, limit=9876, offset=12, skip_fields=(), parent=None, all_descendants=True - ): - kwargs = { - "projection": mock.sentinel.projection, - "field_filters": mock.sentinel.filters, - "orders": mock.sentinel.orders, - "limit": limit, - "offset": offset, - "start_at": mock.sentinel.start_at, - "end_at": mock.sentinel.end_at, - "all_descendants": all_descendants, - } - for field in skip_fields: - kwargs.pop(field) - if parent is None: - parent = mock.sentinel.parent - return self._make_one(parent, **kwargs) - - def test_constructor_explicit(self): - limit = 234 - offset = 56 - query = self._make_one_all_fields(limit=limit, offset=offset) - self.assertIs(query._parent, mock.sentinel.parent) - self.assertIs(query._projection, mock.sentinel.projection) - self.assertIs(query._field_filters, mock.sentinel.filters) - self.assertEqual(query._orders, mock.sentinel.orders) - self.assertEqual(query._limit, limit) - self.assertEqual(query._offset, offset) - self.assertIs(query._start_at, mock.sentinel.start_at) - self.assertIs(query._end_at, mock.sentinel.end_at) - self.assertTrue(query._all_descendants) - - def test__client_property(self): - parent = mock.Mock(_client=mock.sentinel.client, spec=["_client"]) - query = self._make_one(parent) - self.assertIs(query._client, mock.sentinel.client) - - def test___eq___other_type(self): - query = self._make_one_all_fields() - other = object() - self.assertFalse(query == other) - - def test___eq___different_parent(self): - parent = mock.sentinel.parent - other_parent = mock.sentinel.other_parent - query = self._make_one_all_fields(parent=parent) - other = self._make_one_all_fields(parent=other_parent) - self.assertFalse(query == other) - - def test___eq___different_projection(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) - query._projection = mock.sentinel.projection - other = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) - other._projection = mock.sentinel.other_projection - self.assertFalse(query == other) - - def test___eq___different_field_filters(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) - query._field_filters = mock.sentinel.field_filters - other = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) - other._field_filters = mock.sentinel.other_field_filters - self.assertFalse(query == other) - - def test___eq___different_orders(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) - query._orders = mock.sentinel.orders - other = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) - other._orders = mock.sentinel.other_orders - self.assertFalse(query == other) - - def test___eq___different_limit(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, limit=10) - other = self._make_one_all_fields(parent=parent, limit=20) - self.assertFalse(query == other) - - def test___eq___different_offset(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, offset=10) - other = self._make_one_all_fields(parent=parent, offset=20) - self.assertFalse(query == other) - - def test___eq___different_start_at(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) - query._start_at = mock.sentinel.start_at - other = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) - other._start_at = mock.sentinel.other_start_at - self.assertFalse(query == other) - - def test___eq___different_end_at(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) - query._end_at = mock.sentinel.end_at - other = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) - other._end_at = mock.sentinel.other_end_at - self.assertFalse(query == other) - - def test___eq___different_all_descendants(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, all_descendants=True) - other = self._make_one_all_fields(parent=parent, all_descendants=False) - self.assertFalse(query == other) - - def test___eq___hit(self): - query = self._make_one_all_fields() - other = self._make_one_all_fields() - self.assertTrue(query == other) - - def _compare_queries(self, query1, query2, attr_name): - attrs1 = query1.__dict__.copy() - attrs2 = query2.__dict__.copy() - - attrs1.pop(attr_name) - attrs2.pop(attr_name) - - # The only different should be in ``attr_name``. - self.assertEqual(len(attrs1), len(attrs2)) - for key, value in attrs1.items(): - self.assertIs(value, attrs2[key]) - - @staticmethod - def _make_projection_for_select(field_paths): - from google.cloud.firestore_v1.proto import query_pb2 - - return query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ) - - def test_select_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query.select(["*"]) - - def test_select(self): - query1 = self._make_one_all_fields(all_descendants=True) - - field_paths2 = ["foo", "bar"] - query2 = query1.select(field_paths2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual( - query2._projection, self._make_projection_for_select(field_paths2) - ) - self._compare_queries(query1, query2, "_projection") - - # Make sure it overrides. - field_paths3 = ["foo.baz"] - query3 = query2.select(field_paths3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual( - query3._projection, self._make_projection_for_select(field_paths3) - ) - self._compare_queries(query2, query3, "_projection") - - def test_where_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query.where("*", "==", 1) - - def test_where(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - query = self._make_one_all_fields( - skip_fields=("field_filters",), all_descendants=True - ) - new_query = query.where("power.level", ">", 9000) - - self.assertIsNot(query, new_query) - self.assertIsInstance(new_query, self._get_target_class()) - self.assertEqual(len(new_query._field_filters), 1) - - field_pb = new_query._field_filters[0] - expected_pb = query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="power.level"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(integer_value=9000), - ) - self.assertEqual(field_pb, expected_pb) - self._compare_queries(query, new_query, "_field_filters") - - def _where_unary_helper(self, value, op_enum, op_string="=="): - from google.cloud.firestore_v1.proto import query_pb2 - - query = self._make_one_all_fields(skip_fields=("field_filters",)) - field_path = "feeeld" - new_query = query.where(field_path, op_string, value) - - self.assertIsNot(query, new_query) - self.assertIsInstance(new_query, self._get_target_class()) - self.assertEqual(len(new_query._field_filters), 1) - - field_pb = new_query._field_filters[0] - expected_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=op_enum, - ) - self.assertEqual(field_pb, expected_pb) - self._compare_queries(query, new_query, "_field_filters") - - def test_where_eq_null(self): - from google.cloud.firestore_v1.gapic import enums - - op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NULL - self._where_unary_helper(None, op_enum) - - def test_where_gt_null(self): - with self.assertRaises(ValueError): - self._where_unary_helper(None, 0, op_string=">") - - def test_where_eq_nan(self): - from google.cloud.firestore_v1.gapic import enums - - op_enum = enums.StructuredQuery.UnaryFilter.Operator.IS_NAN - self._where_unary_helper(float("nan"), op_enum) - - def test_where_le_nan(self): - with self.assertRaises(ValueError): - self._where_unary_helper(float("nan"), 0, op_string="<=") - - def test_where_w_delete(self): - from google.cloud.firestore_v1 import DELETE_FIELD - - with self.assertRaises(ValueError): - self._where_unary_helper(DELETE_FIELD, 0) - - def test_where_w_server_timestamp(self): - from google.cloud.firestore_v1 import SERVER_TIMESTAMP - - with self.assertRaises(ValueError): - self._where_unary_helper(SERVER_TIMESTAMP, 0) - - def test_where_w_array_remove(self): - from google.cloud.firestore_v1 import ArrayRemove - - with self.assertRaises(ValueError): - self._where_unary_helper(ArrayRemove([1, 3, 5]), 0) - - def test_where_w_array_union(self): - from google.cloud.firestore_v1 import ArrayUnion - - with self.assertRaises(ValueError): - self._where_unary_helper(ArrayUnion([2, 4, 8]), 0) - - def test_order_by_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query.order_by("*") - - def test_order_by(self): - from google.cloud.firestore_v1.gapic import enums - - klass = self._get_target_class() - query1 = self._make_one_all_fields( - skip_fields=("orders",), all_descendants=True - ) - - field_path2 = "a" - query2 = query1.order_by(field_path2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, klass) - order_pb2 = _make_order_pb( - field_path2, enums.StructuredQuery.Direction.ASCENDING - ) - self.assertEqual(query2._orders, (order_pb2,)) - self._compare_queries(query1, query2, "_orders") - - # Make sure it appends to the orders. - field_path3 = "b" - query3 = query2.order_by(field_path3, direction=klass.DESCENDING) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, klass) - order_pb3 = _make_order_pb( - field_path3, enums.StructuredQuery.Direction.DESCENDING - ) - self.assertEqual(query3._orders, (order_pb2, order_pb3)) - self._compare_queries(query2, query3, "_orders") - - def test_limit(self): - query1 = self._make_one_all_fields(all_descendants=True) - - limit2 = 100 - query2 = query1.limit(limit2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual(query2._limit, limit2) - self._compare_queries(query1, query2, "_limit") - - # Make sure it overrides. - limit3 = 10 - query3 = query2.limit(limit3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._limit, limit3) - self._compare_queries(query2, query3, "_limit") - - def test_offset(self): - query1 = self._make_one_all_fields(all_descendants=True) - - offset2 = 23 - query2 = query1.offset(offset2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual(query2._offset, offset2) - self._compare_queries(query1, query2, "_offset") - - # Make sure it overrides. - offset3 = 35 - query3 = query2.offset(offset3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._offset, offset3) - self._compare_queries(query2, query3, "_offset") - - @staticmethod - def _make_collection(*path, **kw): - from google.cloud.firestore_v1 import collection - - return collection.CollectionReference(*path, **kw) - - @staticmethod - def _make_docref(*path, **kw): - from google.cloud.firestore_v1 import document - - return document.DocumentReference(*path, **kw) - - @staticmethod - def _make_snapshot(docref, values): - from google.cloud.firestore_v1 import document - - return document.DocumentSnapshot(docref, values, True, None, None, None) - - def test__cursor_helper_w_dict(self): - values = {"a": 7, "b": "foo"} - query1 = self._make_one(mock.sentinel.parent) - query1._all_descendants = True - query2 = query1._cursor_helper(values, True, True) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._end_at) - self.assertTrue(query2._all_descendants) - - cursor, before = query2._start_at - - self.assertEqual(cursor, values) - self.assertTrue(before) - - def test__cursor_helper_w_tuple(self): - values = (7, "foo") - query1 = self._make_one(mock.sentinel.parent) - query2 = query1._cursor_helper(values, False, True) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._end_at) - - cursor, before = query2._start_at - - self.assertEqual(cursor, list(values)) - self.assertFalse(before) - - def test__cursor_helper_w_list(self): - values = [7, "foo"] - query1 = self._make_one(mock.sentinel.parent) - query2 = query1._cursor_helper(values, True, False) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertEqual(cursor, values) - self.assertIsNot(cursor, values) - self.assertTrue(before) - - def test__cursor_helper_w_snapshot_wrong_collection(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("there", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection) - - with self.assertRaises(ValueError): - query._cursor_helper(snapshot, False, False) - - def test__cursor_helper_w_snapshot_other_collection_all_descendants(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("there", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query1 = self._make_one(collection, all_descendants=True) - - query2 = query1._cursor_helper(snapshot, False, False) - - self.assertIs(query2._parent, collection) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, ()) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertIs(cursor, snapshot) - self.assertFalse(before) - - def test__cursor_helper_w_snapshot(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query1 = self._make_one(collection) - - query2 = query1._cursor_helper(snapshot, False, False) - - self.assertIs(query2._parent, collection) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, ()) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertIs(cursor, snapshot) - self.assertFalse(before) - - def test_start_at(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields( - parent=collection, skip_fields=("orders",), all_descendants=True - ) - query2 = query1.order_by("hi") - - document_fields3 = {"hi": "mom"} - query3 = query2.start_at(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._start_at, (document_fields3, True)) - self._compare_queries(query2, query3, "_start_at") - - # Make sure it overrides. - query4 = query3.order_by("bye") - values5 = {"hi": "zap", "bye": 88} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.start_at(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._start_at, (document_fields5, True)) - self._compare_queries(query4, query5, "_start_at") - - def test_start_after(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("down") - - document_fields3 = {"down": 99.75} - query3 = query2.start_after(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._start_at, (document_fields3, False)) - self._compare_queries(query2, query3, "_start_at") - - # Make sure it overrides. - query4 = query3.order_by("out") - values5 = {"down": 100.25, "out": b"\x00\x01"} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.start_after(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._start_at, (document_fields5, False)) - self._compare_queries(query4, query5, "_start_at") - - def test_end_before(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("down") - - document_fields3 = {"down": 99.75} - query3 = query2.end_before(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._end_at, (document_fields3, True)) - self._compare_queries(query2, query3, "_end_at") - - # Make sure it overrides. - query4 = query3.order_by("out") - values5 = {"down": 100.25, "out": b"\x00\x01"} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.end_before(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._end_at, (document_fields5, True)) - self._compare_queries(query4, query5, "_end_at") - self._compare_queries(query4, query5, "_end_at") - - def test_end_at(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("hi") - - document_fields3 = {"hi": "mom"} - query3 = query2.end_at(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._end_at, (document_fields3, False)) - self._compare_queries(query2, query3, "_end_at") - - # Make sure it overrides. - query4 = query3.order_by("bye") - values5 = {"hi": "zap", "bye": 88} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.end_at(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._end_at, (document_fields5, False)) - self._compare_queries(query4, query5, "_end_at") - - def test__filters_pb_empty(self): - query = self._make_one(mock.sentinel.parent) - self.assertEqual(len(query._field_filters), 0) - self.assertIsNone(query._filters_pb()) - - def test__filters_pb_single(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - query1 = self._make_one(mock.sentinel.parent) - query2 = query1.where("x.y", ">", 50.5) - filter_pb = query2._filters_pb() - expected_pb = query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="x.y"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(double_value=50.5), - ) - ) - self.assertEqual(filter_pb, expected_pb) - - def test__filters_pb_multi(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - query1 = self._make_one(mock.sentinel.parent) - query2 = query1.where("x.y", ">", 50.5) - query3 = query2.where("ABC", "==", 123) - - filter_pb = query3._filters_pb() - op_class = enums.StructuredQuery.FieldFilter.Operator - expected_pb = query_pb2.StructuredQuery.Filter( - composite_filter=query_pb2.StructuredQuery.CompositeFilter( - op=enums.StructuredQuery.CompositeFilter.Operator.AND, - filters=[ - query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference( - field_path="x.y" - ), - op=op_class.GREATER_THAN, - value=document_pb2.Value(double_value=50.5), - ) - ), - query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference( - field_path="ABC" - ), - op=op_class.EQUAL, - value=document_pb2.Value(integer_value=123), - ) - ), - ], - ) - ) - self.assertEqual(filter_pb, expected_pb) - - def test__normalize_projection_none(self): - query = self._make_one(mock.sentinel.parent) - self.assertIsNone(query._normalize_projection(None)) - - def test__normalize_projection_empty(self): - projection = self._make_projection_for_select([]) - query = self._make_one(mock.sentinel.parent) - normalized = query._normalize_projection(projection) - field_paths = [field_ref.field_path for field_ref in normalized.fields] - self.assertEqual(field_paths, ["__name__"]) - - def test__normalize_projection_non_empty(self): - projection = self._make_projection_for_select(["a", "b"]) - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._normalize_projection(projection), projection) - - def test__normalize_orders_wo_orders_wo_cursors(self): - query = self._make_one(mock.sentinel.parent) - expected = [] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_w_orders_wo_cursors(self): - query = self._make_one(mock.sentinel.parent).order_by("a") - expected = [query._make_order("a", "ASCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection).start_at(snapshot) - expected = [query._make_order("__name__", "ASCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_w_name_orders_w_snapshot_cursor(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = ( - self._make_one(collection) - .order_by("__name__", "DESCENDING") - .start_at(snapshot) - ) - expected = [query._make_order("__name__", "DESCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = ( - self._make_one(collection) - .where("c", "<=", 20) - .order_by("c", "DESCENDING") - .start_at(snapshot) - ) - expected = [ - query._make_order("c", "DESCENDING"), - query._make_order("__name__", "DESCENDING"), - ] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection).where("c", "<=", 20).end_at(snapshot) - expected = [ - query._make_order("c", "ASCENDING"), - query._make_order("__name__", "ASCENDING"), - ] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_cursor_none(self): - query = self._make_one(mock.sentinel.parent) - self.assertIsNone(query._normalize_cursor(None, query._orders)) - - def test__normalize_cursor_no_order(self): - cursor = ([1], True) - query = self._make_one(mock.sentinel.parent) - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_as_list_mismatched_order(self): - cursor = ([1, 2], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_as_dict_mismatched_order(self): - cursor = ({"a": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_delete(self): - from google.cloud.firestore_v1 import DELETE_FIELD - - cursor = ([DELETE_FIELD], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_server_timestamp(self): - from google.cloud.firestore_v1 import SERVER_TIMESTAMP - - cursor = ([SERVER_TIMESTAMP], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_array_remove(self): - from google.cloud.firestore_v1 import ArrayRemove - - cursor = ([ArrayRemove([1, 3, 5])], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_w_array_union(self): - from google.cloud.firestore_v1 import ArrayUnion - - cursor = ([ArrayUnion([2, 4, 8])], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - - def test__normalize_cursor_as_list_hit(self): - cursor = ([1], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_dict_hit(self): - cursor = ({"b": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_dict_with_dot_key_hit(self): - cursor = ({"b.a": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_dict_with_inner_data_hit(self): - cursor = ({"b": {"a": 1}}, True) - query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_as_snapshot_hit(self): - values = {"b": 1} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - cursor = (snapshot, True) - collection = self._make_collection("here") - query = self._make_one(collection).order_by("b", "ASCENDING") - - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - - def test__normalize_cursor_w___name___w_reference(self): - db_string = "projects/my-project/database/(default)" - client = mock.Mock(spec=["_database_string"]) - client._database_string = db_string - parent = mock.Mock(spec=["_path", "_client"]) - parent._client = client - parent._path = ["C"] - query = self._make_one(parent).order_by("__name__", "ASCENDING") - docref = self._make_docref("here", "doc_id") - values = {"a": 7} - snapshot = self._make_snapshot(docref, values) - expected = docref - cursor = (snapshot, True) - - self.assertEqual( - query._normalize_cursor(cursor, query._orders), ([expected], True) - ) - - def test__normalize_cursor_w___name___wo_slash(self): - db_string = "projects/my-project/database/(default)" - client = mock.Mock(spec=["_database_string"]) - client._database_string = db_string - parent = mock.Mock(spec=["_path", "_client", "document"]) - parent._client = client - parent._path = ["C"] - document = parent.document.return_value = mock.Mock(spec=[]) - query = self._make_one(parent).order_by("__name__", "ASCENDING") - cursor = (["b"], True) - expected = document - - self.assertEqual( - query._normalize_cursor(cursor, query._orders), ([expected], True) - ) - parent.document.assert_called_once_with("b") - - def test__to_protobuf_all_fields(self): - from google.protobuf import wrappers_pb2 - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="cat", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.select(["X", "Y", "Z"]) - query3 = query2.where("Y", ">", 2.5) - query4 = query3.order_by("X") - query5 = query4.limit(17) - query6 = query5.offset(3) - query7 = query6.start_at({"X": 10}) - query8 = query7.end_at({"X": 25}) - - structured_query_pb = query8._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "select": query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in ["X", "Y", "Z"] - ] - ), - "where": query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="Y"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(double_value=2.5), - ) - ), - "order_by": [ - _make_order_pb("X", enums.StructuredQuery.Direction.ASCENDING) - ], - "start_at": query_pb2.Cursor( - values=[document_pb2.Value(integer_value=10)], before=True - ), - "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=25)]), - "offset": 3, - "limit": wrappers_pb2.Int32Value(value=17), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_select_only(self): - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="cat", spec=["id"]) - query1 = self._make_one(parent) - field_paths = ["a.b", "a.c", "d"] - query2 = query1.select(field_paths) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "select": query_pb2.StructuredQuery.Projection( - fields=[ - query_pb2.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_where_only(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="dog", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.where("a", "==", u"b") - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "where": query_pb2.StructuredQuery.Filter( - field_filter=query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="a"), - op=enums.StructuredQuery.FieldFilter.Operator.EQUAL, - value=document_pb2.Value(string_value=u"b"), - ) - ), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_order_by_only(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="fish", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.order_by("abc") - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [ - _make_order_pb("abc", enums.StructuredQuery.Direction.ASCENDING) - ], - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_start_at_only(self): - # NOTE: "only" is wrong since we must have ``order_by`` as well. - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="phish", spec=["id"]) - query = self._make_one(parent).order_by("X.Y").start_after({"X": {"Y": u"Z"}}) - - structured_query_pb = query._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [ - _make_order_pb("X.Y", enums.StructuredQuery.Direction.ASCENDING) - ], - "start_at": query_pb2.Cursor( - values=[document_pb2.Value(string_value=u"Z")] - ), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_end_at_only(self): - # NOTE: "only" is wrong since we must have ``order_by`` as well. - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="ghoti", spec=["id"]) - query = self._make_one(parent).order_by("a").end_at({"a": 88}) - - structured_query_pb = query._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [ - _make_order_pb("a", enums.StructuredQuery.Direction.ASCENDING) - ], - "end_at": query_pb2.Cursor(values=[document_pb2.Value(integer_value=88)]), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_offset_only(self): - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="cartt", spec=["id"]) - query1 = self._make_one(parent) - offset = 14 - query2 = query1.offset(offset) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "offset": offset, - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_limit_only(self): - from google.protobuf import wrappers_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - parent = mock.Mock(id="donut", spec=["id"]) - query1 = self._make_one(parent) - limit = 31 - query2 = query1.limit(limit) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from": [ - query_pb2.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "limit": wrappers_pb2.Int32Value(value=limit), - } - expected_pb = query_pb2.StructuredQuery(**query_kwargs) - - self.assertEqual(structured_query_pb, expected_pb) - def test_get_simple(self): import warnings @@ -1366,381 +352,9 @@ def test_on_snapshot(self, watch): query.on_snapshot(None) watch.for_query.assert_called_once() - def test_comparator_no_ordering(self): - query = self._make_one(mock.sentinel.parent) - query._orders = [] - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, -1) - - def test_comparator_no_ordering_same_id(self): - query = self._make_one(mock.sentinel.parent) - query._orders = [] - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument1") - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, 0) - - def test_comparator_ordering(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = 1 # ascending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "secondlovelace"}, - } - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, 1) - - def test_comparator_ordering_descending(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = -1 # descending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "secondlovelace"}, - } - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, -1) - - def test_comparator_missing_order_by_field_in_data_raises(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = 1 # ascending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = {} - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - with self.assertRaisesRegex(ValueError, "Can only compare fields "): - query._comparator(doc1, doc2) - - -class Test__enum_from_op_string(unittest.TestCase): - @staticmethod - def _call_fut(op_string): - from google.cloud.firestore_v1.query import _enum_from_op_string - - return _enum_from_op_string(op_string) - - @staticmethod - def _get_op_class(): - from google.cloud.firestore_v1.gapic import enums - - return enums.StructuredQuery.FieldFilter.Operator - - def test_lt(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("<"), op_class.LESS_THAN) - - def test_le(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("<="), op_class.LESS_THAN_OR_EQUAL) - - def test_eq(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("=="), op_class.EQUAL) - - def test_ge(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut(">="), op_class.GREATER_THAN_OR_EQUAL) - - def test_gt(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut(">"), op_class.GREATER_THAN) - - def test_array_contains(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) - - def test_in(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("in"), op_class.IN) - - def test_array_contains_any(self): - op_class = self._get_op_class() - self.assertEqual( - self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY - ) - - def test_invalid(self): - with self.assertRaises(ValueError): - self._call_fut("?") - - -class Test__isnan(unittest.TestCase): - @staticmethod - def _call_fut(value): - from google.cloud.firestore_v1.query import _isnan - - return _isnan(value) - - def test_valid(self): - self.assertTrue(self._call_fut(float("nan"))) - - def test_invalid(self): - self.assertFalse(self._call_fut(51.5)) - self.assertFalse(self._call_fut(None)) - self.assertFalse(self._call_fut("str")) - self.assertFalse(self._call_fut(int)) - self.assertFalse(self._call_fut(1.0 + 1.0j)) - - -class Test__enum_from_direction(unittest.TestCase): - @staticmethod - def _call_fut(direction): - from google.cloud.firestore_v1.query import _enum_from_direction - - return _enum_from_direction(direction) - - def test_success(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.query import Query - - dir_class = enums.StructuredQuery.Direction - self.assertEqual(self._call_fut(Query.ASCENDING), dir_class.ASCENDING) - self.assertEqual(self._call_fut(Query.DESCENDING), dir_class.DESCENDING) - - # Ints pass through - self.assertEqual(self._call_fut(dir_class.ASCENDING), dir_class.ASCENDING) - self.assertEqual(self._call_fut(dir_class.DESCENDING), dir_class.DESCENDING) - - def test_failure(self): - with self.assertRaises(ValueError): - self._call_fut("neither-ASCENDING-nor-DESCENDING") - - -class Test__filter_pb(unittest.TestCase): - @staticmethod - def _call_fut(field_or_unary): - from google.cloud.firestore_v1.query import _filter_pb - - return _filter_pb(field_or_unary) - - def test_unary(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import query_pb2 - - unary_pb = query_pb2.StructuredQuery.UnaryFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="a.b.c"), - op=enums.StructuredQuery.UnaryFilter.Operator.IS_NULL, - ) - filter_pb = self._call_fut(unary_pb) - expected_pb = query_pb2.StructuredQuery.Filter(unary_filter=unary_pb) - self.assertEqual(filter_pb, expected_pb) - - def test_field(self): - from google.cloud.firestore_v1.gapic import enums - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import query_pb2 - - field_filter_pb = query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path="XYZ"), - op=enums.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document_pb2.Value(double_value=90.75), - ) - filter_pb = self._call_fut(field_filter_pb) - expected_pb = query_pb2.StructuredQuery.Filter(field_filter=field_filter_pb) - self.assertEqual(filter_pb, expected_pb) - - def test_bad_type(self): - with self.assertRaises(ValueError): - self._call_fut(None) - - -class Test__cursor_pb(unittest.TestCase): - @staticmethod - def _call_fut(cursor_pair): - from google.cloud.firestore_v1.query import _cursor_pb - - return _cursor_pb(cursor_pair) - - def test_no_pair(self): - self.assertIsNone(self._call_fut(None)) - - def test_success(self): - from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1 import _helpers - - data = [1.5, 10, True] - cursor_pair = data, True - - cursor_pb = self._call_fut(cursor_pair) - - expected_pb = query_pb2.Cursor( - values=[_helpers.encode_value(value) for value in data], before=True - ) - self.assertEqual(cursor_pb, expected_pb) - - -class Test__query_response_to_snapshot(unittest.TestCase): - @staticmethod - def _call_fut(response_pb, collection, expected_prefix): - from google.cloud.firestore_v1.query import _query_response_to_snapshot - - return _query_response_to_snapshot(response_pb, collection, expected_prefix) - - def test_empty(self): - response_pb = _make_query_response() - snapshot = self._call_fut(response_pb, None, None) - self.assertIsNone(snapshot) - - def test_after_offset(self): - skipped_results = 410 - response_pb = _make_query_response(skipped_results=skipped_results) - snapshot = self._call_fut(response_pb, None, None) - self.assertIsNone(snapshot) - - def test_response(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - - client = _make_client() - collection = client.collection("a", "b", "c") - _, expected_prefix = collection._parent_info() - - # Create name for the protobuf. - doc_id = "gigantic" - name = "{}/{}".format(expected_prefix, doc_id) - data = {"a": 901, "b": True} - response_pb = _make_query_response(name=name, data=data) - - snapshot = self._call_fut(response_pb, collection, expected_prefix) - self.assertIsInstance(snapshot, DocumentSnapshot) - expected_path = collection._path + (doc_id,) - self.assertEqual(snapshot.reference._path, expected_path) - self.assertEqual(snapshot.to_dict(), data) - self.assertTrue(snapshot.exists) - self.assertEqual(snapshot.read_time, response_pb.read_time) - self.assertEqual(snapshot.create_time, response_pb.document.create_time) - self.assertEqual(snapshot.update_time, response_pb.document.update_time) - - -class Test__collection_group_query_response_to_snapshot(unittest.TestCase): - @staticmethod - def _call_fut(response_pb, collection): - from google.cloud.firestore_v1.query import ( - _collection_group_query_response_to_snapshot, - ) - - return _collection_group_query_response_to_snapshot(response_pb, collection) - - def test_empty(self): - response_pb = _make_query_response() - snapshot = self._call_fut(response_pb, None) - self.assertIsNone(snapshot) - - def test_after_offset(self): - skipped_results = 410 - response_pb = _make_query_response(skipped_results=skipped_results) - snapshot = self._call_fut(response_pb, None) - self.assertIsNone(snapshot) - - def test_response(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - - client = _make_client() - collection = client.collection("a", "b", "c") - other_collection = client.collection("a", "b", "d") - to_match = other_collection.document("gigantic") - data = {"a": 901, "b": True} - response_pb = _make_query_response(name=to_match._document_path, data=data) - - snapshot = self._call_fut(response_pb, collection) - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertEqual(snapshot.reference._document_path, to_match._document_path) - self.assertEqual(snapshot.to_dict(), data) - self.assertTrue(snapshot.exists) - self.assertEqual(snapshot.read_time, response_pb.read_time) - self.assertEqual(snapshot.create_time, response_pb.document.create_time) - self.assertEqual(snapshot.update_time, response_pb.document.update_time) - - -def _make_credentials(): - import google.auth.credentials - - return mock.Mock(spec=google.auth.credentials.Credentials) - def _make_client(project="project-project"): from google.cloud.firestore_v1.client import Client credentials = _make_credentials() return Client(project=project, credentials=credentials) - - -def _make_order_pb(field_path, direction): - from google.cloud.firestore_v1.proto import query_pb2 - - return query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - direction=direction, - ) - - -def _make_query_response(**kwargs): - # kwargs supported are ``skipped_results``, ``name`` and ``data`` - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1 import _helpers - - now = datetime.datetime.utcnow() - read_time = _datetime_to_pb_timestamp(now) - kwargs["read_time"] = read_time - - name = kwargs.pop("name", None) - data = kwargs.pop("data", None) - if name is not None and data is not None: - document_pb = document_pb2.Document( - name=name, fields=_helpers.encode_dict(data) - ) - delta = datetime.timedelta(seconds=100) - update_time = _datetime_to_pb_timestamp(now - delta) - create_time = _datetime_to_pb_timestamp(now - 2 * delta) - document_pb.update_time.CopyFrom(update_time) - document_pb.create_time.CopyFrom(create_time) - - kwargs["document"] = document_pb - - return firestore_pb2.RunQueryResponse(**kwargs) From 3a660acc53b18762290f8f2d4f30a1f19fa302c4 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Thu, 25 Jun 2020 16:19:12 -0500 Subject: [PATCH 38/47] refactor: generalize collection tests with mocks --- tests/unit/v1/async/test_async_collection.py | 133 ------------------- tests/unit/v1/test_base_collection.py | 130 ++++++++++++++++++ tests/unit/v1/test_collection.py | 133 ------------------- 3 files changed, 130 insertions(+), 266 deletions(-) diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index 9cb97ae263..dedd12e0e4 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -202,139 +202,6 @@ async def test_add_explicit_id(self): metadata=client._rpc_metadata, ) - def test_select(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - field_paths = ["a", "b"] - query = collection.select(field_paths) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - projection_paths = [ - field_ref.field_path for field_ref in query._projection.fields - ] - self.assertEqual(projection_paths, field_paths) - - @staticmethod - def _make_field_filter_pb(field_path, op_string, value): - from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.base_query import _enum_from_op_string - - return query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=_enum_from_op_string(op_string), - value=_helpers.encode_value(value), - ) - - def test_where(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - field_path = "foo" - op_string = "==" - value = 45 - query = collection.where(field_path, op_string, value) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(len(query._field_filters), 1) - field_filter_pb = query._field_filters[0] - self.assertEqual( - field_filter_pb, self._make_field_filter_pb(field_path, op_string, value) - ) - - @staticmethod - def _make_order_pb(field_path, direction): - from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1.base_query import _enum_from_direction - - return query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - direction=_enum_from_direction(direction), - ) - - def test_order_by(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - field_path = "foo" - direction = AsyncQuery.DESCENDING - query = collection.order_by(field_path, direction=direction) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(len(query._orders), 1) - order_pb = query._orders[0] - self.assertEqual(order_pb, self._make_order_pb(field_path, direction)) - - def test_limit(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - limit = 15 - query = collection.limit(limit) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(query._limit, limit) - - def test_offset(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - offset = 113 - query = collection.offset(offset) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(query._offset, offset) - - def test_start_at(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - doc_fields = {"a": "b"} - query = collection.start_at(doc_fields) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(query._start_at, (doc_fields, True)) - - def test_start_after(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - doc_fields = {"d": "foo", "e": 10} - query = collection.start_after(doc_fields) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(query._start_at, (doc_fields, False)) - - def test_end_before(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - doc_fields = {"bar": 10.5} - query = collection.end_before(doc_fields) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(query._end_at, (doc_fields, True)) - - def test_end_at(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - collection = self._make_one("collection") - doc_fields = {"opportunity": True, "reason": 9} - query = collection.end_at(doc_fields) - - self.assertIsInstance(query, AsyncQuery) - self.assertIs(query._parent, collection) - self.assertEqual(query._end_at, (doc_fields, False)) - @pytest.mark.asyncio async def _list_documents_helper(self, page_size=None): from google.api_core.page_iterator import Iterator diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index c73a10a818..cbdbc2898c 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -168,6 +168,136 @@ def test__parent_info_nested(self): prefix = "{}/{}".format(expected_path, collection_id2) self.assertEqual(expected_prefix, prefix) + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_select(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + field_paths = ["a", "b"] + query = collection.select(field_paths) + + mock_query.select.assert_called_once_with(field_paths) + self.assertEqual(query, mock_query.select.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_where(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + field_path = "foo" + op_string = "==" + value = 45 + query = collection.where(field_path, op_string, value) + + mock_query.where.assert_called_once_with(field_path, op_string, value) + self.assertEqual(query, mock_query.where.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_order_by(self, mock_query): + from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + field_path = "foo" + direction = BaseQuery.DESCENDING + query = collection.order_by(field_path, direction=direction) + + mock_query.order_by.assert_called_once_with(field_path, direction=direction) + self.assertEqual(query, mock_query.order_by.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_limit(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + limit = 15 + query = collection.limit(limit) + + mock_query.limit.assert_called_once_with(limit) + self.assertEqual(query, mock_query.limit.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_offset(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + offset = 113 + query = collection.offset(offset) + + mock_query.offset.assert_called_once_with(offset) + self.assertEqual(query, mock_query.offset.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_start_at(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + doc_fields = {"a": "b"} + query = collection.start_at(doc_fields) + + mock_query.start_at.assert_called_once_with(doc_fields) + self.assertEqual(query, mock_query.start_at.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_start_after(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + doc_fields = {"d": "foo", "e": 10} + query = collection.start_after(doc_fields) + + mock_query.start_after.assert_called_once_with(doc_fields) + self.assertEqual(query, mock_query.start_after.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_end_before(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + doc_fields = {"bar": 10.5} + query = collection.end_before(doc_fields) + + mock_query.end_before.assert_called_once_with(doc_fields) + self.assertEqual(query, mock_query.end_before.return_value) + + @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) + def test_end_at(self, mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query + + collection = self._make_one("collection") + doc_fields = {"opportunity": True, "reason": 9} + query = collection.end_at(doc_fields) + + mock_query.end_at.assert_called_once_with(doc_fields) + self.assertEqual(query, mock_query.end_at.return_value) + class Test__auto_id(unittest.TestCase): @staticmethod diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 4188d959f4..967012d36b 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -188,139 +188,6 @@ def test_add_explicit_id(self): metadata=client._rpc_metadata, ) - def test_select(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - field_paths = ["a", "b"] - query = collection.select(field_paths) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - projection_paths = [ - field_ref.field_path for field_ref in query._projection.fields - ] - self.assertEqual(projection_paths, field_paths) - - @staticmethod - def _make_field_filter_pb(field_path, op_string, value): - from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.base_query import _enum_from_op_string - - return query_pb2.StructuredQuery.FieldFilter( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - op=_enum_from_op_string(op_string), - value=_helpers.encode_value(value), - ) - - def test_where(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - field_path = "foo" - op_string = "==" - value = 45 - query = collection.where(field_path, op_string, value) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(len(query._field_filters), 1) - field_filter_pb = query._field_filters[0] - self.assertEqual( - field_filter_pb, self._make_field_filter_pb(field_path, op_string, value) - ) - - @staticmethod - def _make_order_pb(field_path, direction): - from google.cloud.firestore_v1.proto import query_pb2 - from google.cloud.firestore_v1.base_query import _enum_from_direction - - return query_pb2.StructuredQuery.Order( - field=query_pb2.StructuredQuery.FieldReference(field_path=field_path), - direction=_enum_from_direction(direction), - ) - - def test_order_by(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - field_path = "foo" - direction = Query.DESCENDING - query = collection.order_by(field_path, direction=direction) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(len(query._orders), 1) - order_pb = query._orders[0] - self.assertEqual(order_pb, self._make_order_pb(field_path, direction)) - - def test_limit(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - limit = 15 - query = collection.limit(limit) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(query._limit, limit) - - def test_offset(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - offset = 113 - query = collection.offset(offset) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(query._offset, offset) - - def test_start_at(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - doc_fields = {"a": "b"} - query = collection.start_at(doc_fields) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(query._start_at, (doc_fields, True)) - - def test_start_after(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - doc_fields = {"d": "foo", "e": 10} - query = collection.start_after(doc_fields) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(query._start_at, (doc_fields, False)) - - def test_end_before(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - doc_fields = {"bar": 10.5} - query = collection.end_before(doc_fields) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(query._end_at, (doc_fields, True)) - - def test_end_at(self): - from google.cloud.firestore_v1.query import Query - - collection = self._make_one("collection") - doc_fields = {"opportunity": True, "reason": 9} - query = collection.end_at(doc_fields) - - self.assertIsInstance(query, Query) - self.assertIs(query._parent, collection) - self.assertEqual(query._end_at, (doc_fields, False)) - def _list_documents_helper(self, page_size=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page From fd654f166fdf1e6e0642f6739d5226b26dd0dca1 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Mon, 6 Jul 2020 17:54:16 -0500 Subject: [PATCH 39/47] feat: create Transaction/AsyncTransaction superclass --- .../cloud/firestore_v1/async_transaction.py | 79 +-------- google/cloud/firestore_v1/base_transaction.py | 166 ++++++++++++++++++ google/cloud/firestore_v1/transaction.py | 114 ++---------- tests/unit/v1/async/test_async_transaction.py | 85 ++------- tests/unit/v1/test_base_transaction.py | 121 +++++++++++++ tests/unit/v1/test_transaction.py | 87 ++------- 6 files changed, 337 insertions(+), 315 deletions(-) create mode 100644 google/cloud/firestore_v1/base_transaction.py create mode 100644 tests/unit/v1/test_base_transaction.py diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 069f8168c3..c2a8a14b8b 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -20,7 +20,9 @@ import six -from google.cloud.firestore_v1.transaction import ( +from google.cloud.firestore_v1.base_transaction import ( + _BaseTransactional, + BaseTransaction, MAX_ATTEMPTS, _CANT_BEGIN, _CANT_ROLLBACK, @@ -30,17 +32,15 @@ _MAX_SLEEP, _MULTIPLIER, _EXCEED_ATTEMPTS_TEMPLATE, - _CANT_RETRY_READ_ONLY, - _Transactional, ) + from google.api_core import exceptions from google.cloud.firestore_v1 import async_batch -from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_query import AsyncQuery -class AsyncTransaction(async_batch.AsyncWriteBatch): +class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. Args: @@ -56,9 +56,7 @@ class AsyncTransaction(async_batch.AsyncWriteBatch): def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): super(AsyncTransaction, self).__init__(client) - self._max_attempts = max_attempts - self._read_only = read_only - self._id = None + BaseTransaction.__init__(self, max_attempts, read_only) def _add_write_pbs(self, write_pbs): """Add `Write`` protobufs to this transaction. @@ -75,61 +73,6 @@ def _add_write_pbs(self, write_pbs): super(AsyncTransaction, self)._add_write_pbs(write_pbs) - def _options_protobuf(self, retry_id): - """Convert the current object to protobuf. - - The ``retry_id`` value is used when retrying a transaction that - failed (e.g. due to contention). It is intended to be the "first" - transaction that failed (i.e. if multiple retries are needed). - - Args: - retry_id (Union[bytes, NoneType]): Transaction ID of a transaction - to be retried. - - Returns: - Optional[google.cloud.firestore_v1.types.TransactionOptions]: - The protobuf ``TransactionOptions`` if ``read_only==True`` or if - there is a transaction ID to be retried, else :data:`None`. - - Raises: - ValueError: If ``retry_id`` is not :data:`None` but the - transaction is read-only. - """ - if retry_id is not None: - if self._read_only: - raise ValueError(_CANT_RETRY_READ_ONLY) - - return types.TransactionOptions( - read_write=types.TransactionOptions.ReadWrite( - retry_transaction=retry_id - ) - ) - elif self._read_only: - return types.TransactionOptions( - read_only=types.TransactionOptions.ReadOnly() - ) - else: - return None - - @property - def in_progress(self): - """Determine if this transaction has already begun. - - Returns: - bool: Indicates if the transaction has started. - """ - return self._id is not None - - @property - def id(self): - """Get the current transaction ID. - - Returns: - Optional[bytes]: The transaction ID (or :data:`None` if the - current transaction is not in progress). - """ - return self._id - async def _begin(self, retry_id=None): """Begin the transaction. @@ -151,14 +94,6 @@ async def _begin(self, retry_id=None): ) self._id = transaction_response.transaction - def _clean_up(self): - """Clean up the instance after :meth:`_rollback`` or :meth:`_commit``. - - This intended to occur on success or failure of the associated RPCs. - """ - self._write_pbs = [] - self._id = None - async def _rollback(self): """Roll back the transaction. @@ -232,7 +167,7 @@ async def get(self, ref_or_query): ) -class _AsyncTransactional(_Transactional): +class _AsyncTransactional(_BaseTransactional): """Provide a callable object to use as a transactional decorater. This is surfaced via diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py new file mode 100644 index 0000000000..f477fb0fef --- /dev/null +++ b/google/cloud/firestore_v1/base_transaction.py @@ -0,0 +1,166 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for applying Google Cloud Firestore changes in a transaction.""" + + +from google.cloud.firestore_v1 import types + +MAX_ATTEMPTS = 5 +"""int: Default number of transaction attempts (with retries).""" +_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." +_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." +_CANT_ROLLBACK = _MISSING_ID_TEMPLATE.format("rolled back") +_CANT_COMMIT = _MISSING_ID_TEMPLATE.format("committed") +_WRITE_READ_ONLY = "Cannot perform write operation in read-only transaction." +_INITIAL_SLEEP = 1.0 +"""float: Initial "max" for sleep interval. To be used in :func:`_sleep`.""" +_MAX_SLEEP = 30.0 +"""float: Eventual "max" sleep time. To be used in :func:`_sleep`.""" +_MULTIPLIER = 2.0 +"""float: Multiplier for exponential backoff. To be used in :func:`_sleep`.""" +_EXCEED_ATTEMPTS_TEMPLATE = "Failed to commit transaction in {:d} attempts." +_CANT_RETRY_READ_ONLY = "Only read-write transactions can be retried." + + +class BaseTransaction(object): + """Accumulate read-and-write operations to be sent in a transaction. + + Args: + max_attempts (Optional[int]): The maximum number of attempts for + the transaction (i.e. allowing retries). Defaults to + :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`. + read_only (Optional[bool]): Flag indicating if the transaction + should be read-only or should allow writes. Defaults to + :data:`False`. + """ + + def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False): + self._max_attempts = max_attempts + self._read_only = read_only + self._id = None + + def _add_write_pbs(self, write_pbs): + raise NotImplementedError + + def _options_protobuf(self, retry_id): + """Convert the current object to protobuf. + + The ``retry_id`` value is used when retrying a transaction that + failed (e.g. due to contention). It is intended to be the "first" + transaction that failed (i.e. if multiple retries are needed). + + Args: + retry_id (Union[bytes, NoneType]): Transaction ID of a transaction + to be retried. + + Returns: + Optional[google.cloud.firestore_v1.types.TransactionOptions]: + The protobuf ``TransactionOptions`` if ``read_only==True`` or if + there is a transaction ID to be retried, else :data:`None`. + + Raises: + ValueError: If ``retry_id`` is not :data:`None` but the + transaction is read-only. + """ + if retry_id is not None: + if self._read_only: + raise ValueError(_CANT_RETRY_READ_ONLY) + + return types.TransactionOptions( + read_write=types.TransactionOptions.ReadWrite( + retry_transaction=retry_id + ) + ) + elif self._read_only: + return types.TransactionOptions( + read_only=types.TransactionOptions.ReadOnly() + ) + else: + return None + + @property + def in_progress(self): + """Determine if this transaction has already begun. + + Returns: + bool: Indicates if the transaction has started. + """ + return self._id is not None + + @property + def id(self): + """Get the current transaction ID. + + Returns: + Optional[bytes]: The transaction ID (or :data:`None` if the + current transaction is not in progress). + """ + return self._id + + def _clean_up(self): + """Clean up the instance after :meth:`_rollback`` or :meth:`_commit``. + + This intended to occur on success or failure of the associated RPCs. + """ + self._write_pbs = [] + self._id = None + + def _begin(self, retry_id=None): + raise NotImplementedError + + def _rollback(self): + raise NotImplementedError + + def _commit(self): + raise NotImplementedError + + def get_all(self, references): + raise NotImplementedError + + def get(self, ref_or_query): + raise NotImplementedError + + +class _BaseTransactional(object): + """Provide a callable object to use as a transactional decorater. + + This is surfaced via + :func:`~google.cloud.firestore_v1.transaction.transactional`. + + Args: + to_wrap (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): + A callable that should be run (and retried) in a transaction. + """ + + def __init__(self, to_wrap): + self.to_wrap = to_wrap + self.current_id = None + """Optional[bytes]: The current transaction ID.""" + self.retry_id = None + """Optional[bytes]: The ID of the first attempted transaction.""" + + def _reset(self): + """Unset the transaction IDs.""" + self.current_id = None + self.retry_id = None + + def _pre_commit(self, transaction, *args, **kwargs): + raise NotImplementedError + + def _maybe_commit(self, transaction): + raise NotImplementedError + + def __call__(self, transaction, *args, **kwargs): + raise NotImplementedError diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 04485a84c2..f69f7f61ae 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -20,31 +20,27 @@ import six +from google.cloud.firestore_v1.base_transaction import ( + _BaseTransactional, + BaseTransaction, + MAX_ATTEMPTS, + _CANT_BEGIN, + _CANT_ROLLBACK, + _CANT_COMMIT, + _WRITE_READ_ONLY, + _INITIAL_SLEEP, + _MAX_SLEEP, + _MULTIPLIER, + _EXCEED_ATTEMPTS_TEMPLATE, +) + from google.api_core import exceptions from google.cloud.firestore_v1 import batch -from google.cloud.firestore_v1 import types from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.query import Query -MAX_ATTEMPTS = 5 -"""int: Default number of transaction attempts (with retries).""" -_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." -_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." -_CANT_ROLLBACK = _MISSING_ID_TEMPLATE.format("rolled back") -_CANT_COMMIT = _MISSING_ID_TEMPLATE.format("committed") -_WRITE_READ_ONLY = "Cannot perform write operation in read-only transaction." -_INITIAL_SLEEP = 1.0 -"""float: Initial "max" for sleep interval. To be used in :func:`_sleep`.""" -_MAX_SLEEP = 30.0 -"""float: Eventual "max" sleep time. To be used in :func:`_sleep`.""" -_MULTIPLIER = 2.0 -"""float: Multiplier for exponential backoff. To be used in :func:`_sleep`.""" -_EXCEED_ATTEMPTS_TEMPLATE = "Failed to commit transaction in {:d} attempts." -_CANT_RETRY_READ_ONLY = "Only read-write transactions can be retried." - - -class Transaction(batch.WriteBatch): +class Transaction(batch.WriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. Args: @@ -60,9 +56,7 @@ class Transaction(batch.WriteBatch): def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): super(Transaction, self).__init__(client) - self._max_attempts = max_attempts - self._read_only = read_only - self._id = None + BaseTransaction.__init__(self, max_attempts, read_only) def _add_write_pbs(self, write_pbs): """Add `Write`` protobufs to this transaction. @@ -79,61 +73,6 @@ def _add_write_pbs(self, write_pbs): super(Transaction, self)._add_write_pbs(write_pbs) - def _options_protobuf(self, retry_id): - """Convert the current object to protobuf. - - The ``retry_id`` value is used when retrying a transaction that - failed (e.g. due to contention). It is intended to be the "first" - transaction that failed (i.e. if multiple retries are needed). - - Args: - retry_id (Union[bytes, NoneType]): Transaction ID of a transaction - to be retried. - - Returns: - Optional[google.cloud.firestore_v1.types.TransactionOptions]: - The protobuf ``TransactionOptions`` if ``read_only==True`` or if - there is a transaction ID to be retried, else :data:`None`. - - Raises: - ValueError: If ``retry_id`` is not :data:`None` but the - transaction is read-only. - """ - if retry_id is not None: - if self._read_only: - raise ValueError(_CANT_RETRY_READ_ONLY) - - return types.TransactionOptions( - read_write=types.TransactionOptions.ReadWrite( - retry_transaction=retry_id - ) - ) - elif self._read_only: - return types.TransactionOptions( - read_only=types.TransactionOptions.ReadOnly() - ) - else: - return None - - @property - def in_progress(self): - """Determine if this transaction has already begun. - - Returns: - bool: Indicates if the transaction has started. - """ - return self._id is not None - - @property - def id(self): - """Get the current transaction ID. - - Returns: - Optional[bytes]: The transaction ID (or :data:`None` if the - current transaction is not in progress). - """ - return self._id - def _begin(self, retry_id=None): """Begin the transaction. @@ -155,14 +94,6 @@ def _begin(self, retry_id=None): ) self._id = transaction_response.transaction - def _clean_up(self): - """Clean up the instance after :meth:`_rollback`` or :meth:`_commit``. - - This intended to occur on success or failure of the associated RPCs. - """ - self._write_pbs = [] - self._id = None - def _rollback(self): """Roll back the transaction. @@ -234,7 +165,7 @@ def get(self, ref_or_query): ) -class _Transactional(object): +class _Transactional(_BaseTransactional): """Provide a callable object to use as a transactional decorater. This is surfaced via @@ -246,16 +177,7 @@ class _Transactional(object): """ def __init__(self, to_wrap): - self.to_wrap = to_wrap - self.current_id = None - """Optional[bytes]: The current transaction ID.""" - self.retry_id = None - """Optional[bytes]: The ID of the first attempted transaction.""" - - def _reset(self): - """Unset the transaction IDs.""" - self.current_id = None - self.retry_id = None + super(_Transactional, self).__init__(to_wrap) def _pre_commit(self, transaction, *args, **kwargs): """Begin transaction and call the wrapped callable. diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py index 32b061c8d0..ba8fb8ff16 100644 --- a/tests/unit/v1/async/test_async_transaction.py +++ b/tests/unit/v1/async/test_async_transaction.py @@ -49,7 +49,7 @@ def test_constructor_explicit(self): self.assertIsNone(transaction._id) def test__add_write_pbs_failure(self): - from google.cloud.firestore_v1.async_transaction import _WRITE_READ_ONLY + from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY batch = self._make_one(mock.sentinel.client, read_only=True) self.assertEqual(batch._write_pbs, []) @@ -65,55 +65,18 @@ def test__add_write_pbs(self): batch._add_write_pbs([mock.sentinel.write]) self.assertEqual(batch._write_pbs, [mock.sentinel.write]) - def test__options_protobuf_read_only(self): - from google.cloud.firestore_v1.proto import common_pb2 - - transaction = self._make_one(mock.sentinel.client, read_only=True) - options_pb = transaction._options_protobuf(None) - expected_pb = common_pb2.TransactionOptions( - read_only=common_pb2.TransactionOptions.ReadOnly() - ) - self.assertEqual(options_pb, expected_pb) - - def test__options_protobuf_read_only_retry(self): - from google.cloud.firestore_v1.async_transaction import _CANT_RETRY_READ_ONLY - - transaction = self._make_one(mock.sentinel.client, read_only=True) - retry_id = b"illuminate" - - with self.assertRaises(ValueError) as exc_info: - transaction._options_protobuf(retry_id) - - self.assertEqual(exc_info.exception.args, (_CANT_RETRY_READ_ONLY,)) - - def test__options_protobuf_read_write(self): - transaction = self._make_one(mock.sentinel.client) - options_pb = transaction._options_protobuf(None) - self.assertIsNone(options_pb) - - def test__options_protobuf_on_retry(self): - from google.cloud.firestore_v1.proto import common_pb2 - + def test__clean_up(self): transaction = self._make_one(mock.sentinel.client) - retry_id = b"hocus-pocus" - options_pb = transaction._options_protobuf(retry_id) - expected_pb = common_pb2.TransactionOptions( - read_write=common_pb2.TransactionOptions.ReadWrite( - retry_transaction=retry_id - ) + transaction._write_pbs.extend( + [mock.sentinel.write_pb1, mock.sentinel.write_pb2] ) - self.assertEqual(options_pb, expected_pb) + transaction._id = b"not-this-time-my-friend" - def test_in_progress_property(self): - transaction = self._make_one(mock.sentinel.client) - self.assertFalse(transaction.in_progress) - transaction._id = b"not-none-bites" - self.assertTrue(transaction.in_progress) + ret_val = transaction._clean_up() + self.assertIsNone(ret_val) - def test_id_property(self): - transaction = self._make_one(mock.sentinel.client) - transaction._id = mock.sentinel.eye_dee - self.assertIs(transaction.id, mock.sentinel.eye_dee) + self.assertEqual(transaction._write_pbs, []) + self.assertIsNone(transaction._id) @pytest.mark.asyncio async def test__begin(self): @@ -147,7 +110,7 @@ async def test__begin(self): @pytest.mark.asyncio async def test__begin_failure(self): - from google.cloud.firestore_v1.async_transaction import _CANT_BEGIN + from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN client = _make_client() transaction = self._make_one(client) @@ -159,19 +122,6 @@ async def test__begin_failure(self): err_msg = _CANT_BEGIN.format(transaction._id) self.assertEqual(exc_info.exception.args, (err_msg,)) - def test__clean_up(self): - transaction = self._make_one(mock.sentinel.client) - transaction._write_pbs.extend( - [mock.sentinel.write_pb1, mock.sentinel.write_pb2] - ) - transaction._id = b"not-this-time-my-friend" - - ret_val = transaction._clean_up() - self.assertIsNone(ret_val) - - self.assertEqual(transaction._write_pbs, []) - self.assertIsNone(transaction._id) - @pytest.mark.asyncio async def test__rollback(self): from google.protobuf import empty_pb2 @@ -202,7 +152,7 @@ async def test__rollback(self): @pytest.mark.asyncio async def test__rollback_not_allowed(self): - from google.cloud.firestore_v1.async_transaction import _CANT_ROLLBACK + from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK client = _make_client() transaction = self._make_one(client) @@ -289,7 +239,7 @@ async def test__commit(self): @pytest.mark.asyncio async def test__commit_not_allowed(self): - from google.cloud.firestore_v1.async_transaction import _CANT_COMMIT + from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT transaction = self._make_one(mock.sentinel.client) self.assertIsNone(transaction._id) @@ -395,17 +345,6 @@ def test_constructor(self): self.assertIsNone(wrapped.current_id) self.assertIsNone(wrapped.retry_id) - def test__reset(self): - wrapped = self._make_one(mock.sentinel.callable_) - wrapped.current_id = b"not-none" - wrapped.retry_id = b"also-not" - - ret_val = wrapped._reset() - self.assertIsNone(ret_val) - - self.assertIsNone(wrapped.current_id) - self.assertIsNone(wrapped.retry_id) - @pytest.mark.asyncio async def test__pre_commit_success(self): to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) diff --git a/tests/unit/v1/test_base_transaction.py b/tests/unit/v1/test_base_transaction.py new file mode 100644 index 0000000000..e869f4383d --- /dev/null +++ b/tests/unit/v1/test_base_transaction.py @@ -0,0 +1,121 @@ +# Copyright 2017 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import mock + + +class TestBaseTransaction(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.base_transaction import BaseTransaction + + return BaseTransaction + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor_defaults(self): + from google.cloud.firestore_v1.transaction import MAX_ATTEMPTS + + transaction = self._make_one() + self.assertEqual(transaction._max_attempts, MAX_ATTEMPTS) + self.assertFalse(transaction._read_only) + self.assertIsNone(transaction._id) + + def test_constructor_explicit(self): + transaction = self._make_one(max_attempts=10, read_only=True) + self.assertEqual(transaction._max_attempts, 10) + self.assertTrue(transaction._read_only) + self.assertIsNone(transaction._id) + + def test__options_protobuf_read_only(self): + from google.cloud.firestore_v1.proto import common_pb2 + + transaction = self._make_one(read_only=True) + options_pb = transaction._options_protobuf(None) + expected_pb = common_pb2.TransactionOptions( + read_only=common_pb2.TransactionOptions.ReadOnly() + ) + self.assertEqual(options_pb, expected_pb) + + def test__options_protobuf_read_only_retry(self): + from google.cloud.firestore_v1.base_transaction import _CANT_RETRY_READ_ONLY + + transaction = self._make_one(read_only=True) + retry_id = b"illuminate" + + with self.assertRaises(ValueError) as exc_info: + transaction._options_protobuf(retry_id) + + self.assertEqual(exc_info.exception.args, (_CANT_RETRY_READ_ONLY,)) + + def test__options_protobuf_read_write(self): + transaction = self._make_one() + options_pb = transaction._options_protobuf(None) + self.assertIsNone(options_pb) + + def test__options_protobuf_on_retry(self): + from google.cloud.firestore_v1.proto import common_pb2 + + transaction = self._make_one() + retry_id = b"hocus-pocus" + options_pb = transaction._options_protobuf(retry_id) + expected_pb = common_pb2.TransactionOptions( + read_write=common_pb2.TransactionOptions.ReadWrite( + retry_transaction=retry_id + ) + ) + self.assertEqual(options_pb, expected_pb) + + def test_in_progress_property(self): + transaction = self._make_one() + self.assertFalse(transaction.in_progress) + transaction._id = b"not-none-bites" + self.assertTrue(transaction.in_progress) + + def test_id_property(self): + transaction = self._make_one() + transaction._id = mock.sentinel.eye_dee + self.assertIs(transaction.id, mock.sentinel.eye_dee) + + +class Test_Transactional(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.firestore_v1.base_transaction import _BaseTransactional + + return _BaseTransactional + + def _make_one(self, *args, **kwargs): + klass = self._get_target_class() + return klass(*args, **kwargs) + + def test_constructor(self): + wrapped = self._make_one(mock.sentinel.callable_) + self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) + self.assertIsNone(wrapped.current_id) + self.assertIsNone(wrapped.retry_id) + + def test__reset(self): + wrapped = self._make_one(mock.sentinel.callable_) + wrapped.current_id = b"not-none" + wrapped.retry_id = b"also-not" + + ret_val = wrapped._reset() + self.assertIsNone(ret_val) + + self.assertIsNone(wrapped.current_id) + self.assertIsNone(wrapped.retry_id) diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index da3c2d0b02..72bb7eb37d 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -48,7 +48,7 @@ def test_constructor_explicit(self): self.assertIsNone(transaction._id) def test__add_write_pbs_failure(self): - from google.cloud.firestore_v1.transaction import _WRITE_READ_ONLY + from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY batch = self._make_one(mock.sentinel.client, read_only=True) self.assertEqual(batch._write_pbs, []) @@ -64,55 +64,18 @@ def test__add_write_pbs(self): batch._add_write_pbs([mock.sentinel.write]) self.assertEqual(batch._write_pbs, [mock.sentinel.write]) - def test__options_protobuf_read_only(self): - from google.cloud.firestore_v1.proto import common_pb2 - - transaction = self._make_one(mock.sentinel.client, read_only=True) - options_pb = transaction._options_protobuf(None) - expected_pb = common_pb2.TransactionOptions( - read_only=common_pb2.TransactionOptions.ReadOnly() - ) - self.assertEqual(options_pb, expected_pb) - - def test__options_protobuf_read_only_retry(self): - from google.cloud.firestore_v1.transaction import _CANT_RETRY_READ_ONLY - - transaction = self._make_one(mock.sentinel.client, read_only=True) - retry_id = b"illuminate" - - with self.assertRaises(ValueError) as exc_info: - transaction._options_protobuf(retry_id) - - self.assertEqual(exc_info.exception.args, (_CANT_RETRY_READ_ONLY,)) - - def test__options_protobuf_read_write(self): - transaction = self._make_one(mock.sentinel.client) - options_pb = transaction._options_protobuf(None) - self.assertIsNone(options_pb) - - def test__options_protobuf_on_retry(self): - from google.cloud.firestore_v1.proto import common_pb2 - + def test__clean_up(self): transaction = self._make_one(mock.sentinel.client) - retry_id = b"hocus-pocus" - options_pb = transaction._options_protobuf(retry_id) - expected_pb = common_pb2.TransactionOptions( - read_write=common_pb2.TransactionOptions.ReadWrite( - retry_transaction=retry_id - ) + transaction._write_pbs.extend( + [mock.sentinel.write_pb1, mock.sentinel.write_pb2] ) - self.assertEqual(options_pb, expected_pb) + transaction._id = b"not-this-time-my-friend" - def test_in_progress_property(self): - transaction = self._make_one(mock.sentinel.client) - self.assertFalse(transaction.in_progress) - transaction._id = b"not-none-bites" - self.assertTrue(transaction.in_progress) + ret_val = transaction._clean_up() + self.assertIsNone(ret_val) - def test_id_property(self): - transaction = self._make_one(mock.sentinel.client) - transaction._id = mock.sentinel.eye_dee - self.assertIs(transaction.id, mock.sentinel.eye_dee) + self.assertEqual(transaction._write_pbs, []) + self.assertIsNone(transaction._id) def test__begin(self): from google.cloud.firestore_v1.gapic import firestore_client @@ -144,7 +107,7 @@ def test__begin(self): ) def test__begin_failure(self): - from google.cloud.firestore_v1.transaction import _CANT_BEGIN + from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN client = _make_client() transaction = self._make_one(client) @@ -156,19 +119,6 @@ def test__begin_failure(self): err_msg = _CANT_BEGIN.format(transaction._id) self.assertEqual(exc_info.exception.args, (err_msg,)) - def test__clean_up(self): - transaction = self._make_one(mock.sentinel.client) - transaction._write_pbs.extend( - [mock.sentinel.write_pb1, mock.sentinel.write_pb2] - ) - transaction._id = b"not-this-time-my-friend" - - ret_val = transaction._clean_up() - self.assertIsNone(ret_val) - - self.assertEqual(transaction._write_pbs, []) - self.assertIsNone(transaction._id) - def test__rollback(self): from google.protobuf import empty_pb2 from google.cloud.firestore_v1.gapic import firestore_client @@ -197,7 +147,7 @@ def test__rollback(self): ) def test__rollback_not_allowed(self): - from google.cloud.firestore_v1.transaction import _CANT_ROLLBACK + from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK client = _make_client() transaction = self._make_one(client) @@ -281,7 +231,7 @@ def test__commit(self): ) def test__commit_not_allowed(self): - from google.cloud.firestore_v1.transaction import _CANT_COMMIT + from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT transaction = self._make_one(mock.sentinel.client) self.assertIsNone(transaction._id) @@ -382,17 +332,6 @@ def test_constructor(self): self.assertIsNone(wrapped.current_id) self.assertIsNone(wrapped.retry_id) - def test__reset(self): - wrapped = self._make_one(mock.sentinel.callable_) - wrapped.current_id = b"not-none" - wrapped.retry_id = b"also-not" - - ret_val = wrapped._reset() - self.assertIsNone(ret_val) - - self.assertIsNone(wrapped.current_id) - self.assertIsNone(wrapped.retry_id) - def test__pre_commit_success(self): to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) @@ -728,7 +667,7 @@ def test___call__success_second_attempt(self): def test___call__failure(self): from google.api_core import exceptions - from google.cloud.firestore_v1.transaction import _EXCEED_ATTEMPTS_TEMPLATE + from google.cloud.firestore_v1.base_transaction import _EXCEED_ATTEMPTS_TEMPLATE to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) From ec5448e42422fdb303aed7ef85ca273e23b846a6 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Tue, 14 Jul 2020 21:16:52 -0500 Subject: [PATCH 40/47] feat: add microgen support to async interface --- google/cloud/firestore_v1/async_batch.py | 10 +- google/cloud/firestore_v1/async_client.py | 38 ++- google/cloud/firestore_v1/async_collection.py | 14 +- google/cloud/firestore_v1/async_document.py | 48 +++- google/cloud/firestore_v1/async_query.py | 14 +- .../cloud/firestore_v1/async_transaction.py | 26 +- tests/unit/v1/async/test_async_batch.py | 38 +-- tests/unit/v1/async/test_async_client.py | 65 +++-- tests/unit/v1/async/test_async_collection.py | 46 ++-- tests/unit/v1/async/test_async_document.py | 110 ++++---- tests/unit/v1/async/test_async_query.py | 64 +++-- tests/unit/v1/async/test_async_transaction.py | 245 +++++++++++------- 12 files changed, 438 insertions(+), 280 deletions(-) diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 7fb18e90e2..d29c302356 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -37,15 +37,17 @@ async def commit(self): """Commit the changes accumulated in this batch. Returns: - List[:class:`google.cloud.proto.firestore.v1.write_pb2.WriteResult`, ...]: + List[:class:`google.cloud.proto.firestore.v1.write.WriteResult`, ...]: The write results corresponding to the changes committed, returned in the same order as the changes were applied to this batch. A write result contains an ``update_time`` field. """ commit_response = self._client._firestore_api.commit( - self._client._database_string, - self._write_pbs, - transaction=None, + request={ + "database": self._client._database_string, + "writes": self._write_pbs, + "transaction": None, + }, metadata=self._client._rpc_metadata, ) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 04ff127edf..14e506c417 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -31,7 +31,6 @@ _reference_info, _parse_batch_get, _get_doc_mask, - _item_to_collection_ref, _path_helper, ) @@ -215,10 +214,12 @@ async def get_all(self, references, field_paths=None, transaction=None): document_paths, reference_map = _reference_info(references) mask = _get_doc_mask(field_paths) response_iterator = self._firestore_api.batch_get_documents( - self._database_string, - document_paths, - mask, - transaction=_helpers.get_transaction_id(transaction), + request={ + "database": self._database_string, + "documents": document_paths, + "mask": mask, + "transaction": _helpers.get_transaction_id(transaction), + }, metadata=self._rpc_metadata, ) @@ -233,11 +234,30 @@ async def collections(self): iterator of subcollections of the current document. """ iterator = self._firestore_api.list_collection_ids( - "{}/documents".format(self._database_string), metadata=self._rpc_metadata + request={"parent": "{}/documents".format(self._database_string)}, + metadata=self._rpc_metadata, ) - iterator.client = self - iterator.item_to_value = _item_to_collection_ref - return iterator + + while True: + for i in iterator.collection_ids: + yield self.collection(i) + if iterator.next_page_token: + iterator = self._firestore_api.list_collection_ids( + request={ + "parent": "{}/documents".format(self._database_string), + "page_token": iterator.next_page_token, + }, + metadata=self._rpc_metadata, + ) + else: + return + + # TODO(microgen): currently this method is rewritten to iterate/page itself. + # https://github.com/googleapis/gapic-generator-python/issues/516 + # it seems the generator ought to be able to do this itself. + # iterator.client = self + # iterator.item_to_value = _item_to_collection_ref + # return iterator def batch(self): """Get a batch instance from this client. diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 77c43107f7..aa09e3d9a5 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -111,15 +111,15 @@ async def list_documents(self, page_size=None): parent, _ = self._parent_info() iterator = self._client._firestore_api.list_documents( - parent, - self.id, - page_size=page_size, - show_missing=True, + request={ + "parent": parent, + "collection_id": self.id, + "page_size": page_size, + "show_missing": True, + }, metadata=self._client._rpc_metadata, ) - iterator.collection = self - iterator.item_to_value = _item_to_document_ref - return iterator + return (_item_to_document_ref(self, i) for i in iterator) async def get(self, transaction=None): """Deprecated alias for :meth:`stream`.""" diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 1cd66b57d7..00672153c5 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -20,12 +20,11 @@ BaseDocumentReference, DocumentSnapshot, _first_write_result, - _item_to_collection_ref, ) from google.api_core import exceptions from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.proto import common_pb2 +from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.watch import Watch @@ -274,9 +273,11 @@ async def delete(self, option=None): """ write_pb = _helpers.pb_for_delete(self._document_path, option) commit_response = self._client._firestore_api.commit( - self._client._database_string, - [write_pb], - transaction=None, + request={ + "database": self._client._database_string, + "writes": [write_pb], + "transaction": None, + }, metadata=self._client._rpc_metadata, ) @@ -313,16 +314,18 @@ async def get(self, field_paths=None, transaction=None): raise ValueError("'field_paths' must be a sequence of paths, not a string.") if field_paths is not None: - mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) + mask = common.DocumentMask(field_paths=sorted(field_paths)) else: mask = None firestore_api = self._client._firestore_api try: document_pb = firestore_api.get_document( - self._document_path, - mask=mask, - transaction=_helpers.get_transaction_id(transaction), + request={ + "name": self._document_path, + "mask": mask, + "transaction": _helpers.get_transaction_id(transaction), + }, metadata=self._client._rpc_metadata, ) except exceptions.NotFound: @@ -360,13 +363,30 @@ async def collections(self, page_size=None): iterator will be empty """ iterator = self._client._firestore_api.list_collection_ids( - self._document_path, - page_size=page_size, + request={"parent": self._document_path, "page_size": page_size}, metadata=self._client._rpc_metadata, ) - iterator.document = self - iterator.item_to_value = _item_to_collection_ref - return iterator + + while True: + for i in iterator.collection_ids: + yield self.collection(i) + if iterator.next_page_token: + iterator = self._client._firestore_api.list_collection_ids( + request={ + "parent": self._document_path, + "page_size": page_size, + "page_token": iterator.next_page_token, + }, + metadata=self._client._rpc_metadata, + ) + else: + return + + # TODO(microgen): currently this method is rewritten to iterate/page itself. + # it seems the generator ought to be able to do this itself. + # iterator.document = self + # iterator.item_to_value = _item_to_collection_ref + # return iterator def on_snapshot(self, callback): """Watch this document. diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index dbfa1866f0..dea0c960b7 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -41,13 +41,13 @@ class AsyncQuery(BaseQuery): parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`): The collection that this query applies to. projection (Optional[:class:`google.cloud.proto.firestore.v1.\ - query_pb2.StructuredQuery.Projection`]): + query.StructuredQuery.Projection`]): A projection of document fields to limit the query results to. field_filters (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ - query_pb2.StructuredQuery.FieldFilter`, ...]]): + query.StructuredQuery.FieldFilter`, ...]]): The filters to be applied in the query. orders (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\ - query_pb2.StructuredQuery.Order`, ...]]): + query.StructuredQuery.Order`, ...]]): The "order by" entries to use in the query. limit (Optional[int]): The maximum number of documents the query is allowed to return. @@ -150,9 +150,11 @@ async def stream(self, transaction=None): """ parent_path, expected_prefix = self._parent._parent_info() response_iterator = self._client._firestore_api.run_query( - parent_path, - self._to_protobuf(), - transaction=_helpers.get_transaction_id(transaction), + request={ + "parent": parent_path, + "structured_query": self._to_protobuf(), + "transaction": _helpers.get_transaction_id(transaction), + }, metadata=self._client._rpc_metadata, ) diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index c2a8a14b8b..5690254656 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -63,7 +63,7 @@ def _add_write_pbs(self, write_pbs): Args: write_pbs (List[google.cloud.proto.firestore.v1.\ - write_pb2.Write]): A list of write protobufs to be added. + write.Write]): A list of write protobufs to be added. Raises: ValueError: If this transaction is read-only. @@ -88,8 +88,10 @@ async def _begin(self, retry_id=None): raise ValueError(msg) transaction_response = self._client._firestore_api.begin_transaction( - self._client._database_string, - options_=self._options_protobuf(retry_id), + request={ + "database": self._client._database_string, + "options": self._options_protobuf(retry_id), + }, metadata=self._client._rpc_metadata, ) self._id = transaction_response.transaction @@ -106,8 +108,10 @@ async def _rollback(self): try: # NOTE: The response is just ``google.protobuf.Empty``. self._client._firestore_api.rollback( - self._client._database_string, - self._id, + request={ + "database": self._client._database_string, + "transaction": self._id, + }, metadata=self._client._rpc_metadata, ) finally: @@ -117,7 +121,7 @@ async def _commit(self): """Transactionally commit the changes accumulated. Returns: - List[:class:`google.cloud.proto.firestore.v1.write_pb2.WriteResult`, ...]: + List[:class:`google.cloud.proto.firestore.v1.write.WriteResult`, ...]: The write results corresponding to the changes committed, returned in the same order as the changes were applied to this transaction. A write result contains an ``update_time`` field. @@ -312,7 +316,7 @@ async def _commit_with_retry(client, write_pbs, transaction_id): Args: client (:class:`~google.cloud.firestore_v1.client.Client`): A client with GAPIC client and configuration details. - write_pbs (List[:class:`google.cloud.proto.firestore.v1.write_pb2.Write`, ...]): + write_pbs (List[:class:`google.cloud.proto.firestore.v1.write.Write`, ...]): A ``Write`` protobuf instance to be committed. transaction_id (bytes): ID of an existing transaction that this commit will run in. @@ -329,9 +333,11 @@ async def _commit_with_retry(client, write_pbs, transaction_id): while True: try: return client._firestore_api.commit( - client._database_string, - write_pbs, - transaction=transaction_id, + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": transaction_id, + }, metadata=client._rpc_metadata, ) except exceptions.ServiceUnavailable: diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py index 6b6a8af774..d09e826a7b 100644 --- a/tests/unit/v1/async/test_async_batch.py +++ b/tests/unit/v1/async/test_async_batch.py @@ -39,14 +39,14 @@ def test_constructor(self): @pytest.mark.asyncio async def test_commit(self): from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.Mock(spec=["commit"]) timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) - commit_response = firestore_pb2.CommitResponse( - write_results=[write_pb2.WriteResult(), write_pb2.WriteResult()], + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], commit_time=timestamp, ) firestore_api.commit.return_value = commit_response @@ -66,28 +66,31 @@ async def test_commit(self): write_results = await batch.commit() self.assertEqual(write_results, list(commit_response.write_results)) self.assertEqual(batch.write_results, write_results) - self.assertEqual(batch.commit_time, timestamp) + # TODO(microgen): v2: commit time is already a datetime, though not with nano + # self.assertEqual(batch.commit_time, timestamp) # Make sure batch has no more "changes". self.assertEqual(batch._write_pbs, []) # Verify the mocks. firestore_api.commit.assert_called_once_with( - client._database_string, - write_pbs, - transaction=None, + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, metadata=client._rpc_metadata, ) @pytest.mark.asyncio async def test_as_context_mgr_wo_error(self): from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write firestore_api = mock.Mock(spec=["commit"]) timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) - commit_response = firestore_pb2.CommitResponse( - write_results=[write_pb2.WriteResult(), write_pb2.WriteResult()], + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], commit_time=timestamp, ) firestore_api.commit.return_value = commit_response @@ -104,15 +107,18 @@ async def test_as_context_mgr_wo_error(self): write_pbs = batch._write_pbs[::] self.assertEqual(batch.write_results, list(commit_response.write_results)) - self.assertEqual(batch.commit_time, timestamp) + # TODO(microgen): v2: commit time is already a datetime, though not with nano + # self.assertEqual(batch.commit_time, timestamp) # Make sure batch has no more "changes". self.assertEqual(batch._write_pbs, []) # Verify the mocks. firestore_api.commit.assert_called_once_with( - client._database_string, - write_pbs, - transaction=None, + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, metadata=client._rpc_metadata, ) diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index e83fd7db08..c78b474aa4 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -133,7 +133,7 @@ def test_collection_group(self): assert query._all_descendants assert query._field_filters[0].field.field_path == "foo" assert query._field_filters[0].value.string_value == u"bar" - assert query._field_filters[0].op == query._field_filters[0].EQUAL + assert query._field_filters[0].op == query._field_filters[0].Operator.EQUAL assert query._parent.id == "collectionId" def test_collection_group_no_slashes(self): @@ -201,10 +201,13 @@ async def test_collections(self): firestore_api = mock.Mock(spec=["list_collection_ids"]) client._firestore_api_internal = firestore_api + # TODO(microgen): list_collection_ids isn't a pager. + # https://github.com/googleapis/gapic-generator-python/issues/516 class _Iterator(Iterator): def __init__(self, pages): super(_Iterator, self).__init__(client=None) self._pages = pages + self.collection_ids = pages[0] def _next_page(self): if self._pages: @@ -214,7 +217,7 @@ def _next_page(self): iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = list(await client.collections()) + collections = [c async for c in client.collections()] self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): @@ -224,7 +227,7 @@ def _next_page(self): base_path = client._database_string + "/documents" firestore_api.list_collection_ids.assert_called_once_with( - base_path, metadata=client._rpc_metadata + request={"parent": base_path}, metadata=client._rpc_metadata ) async def _get_all_helper(self, client, references, document_pbs, **kwargs): @@ -251,14 +254,14 @@ def _info_for_get_all(self, data1, data2): document_pb1, read_time = _doc_get_info(document1._document_path, data1) response1 = _make_batch_response(found=document_pb1, read_time=read_time) - document_pb2, read_time = _doc_get_info(document2._document_path, data2) - response2 = _make_batch_response(found=document_pb2, read_time=read_time) + document, read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=read_time) return client, document1, document2, response1, response2 @pytest.mark.asyncio async def test_get_all(self): - from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.async_document import DocumentSnapshot data1 = {"a": u"cheese"} @@ -288,12 +291,14 @@ async def test_get_all(self): # Verify the call to the mock. doc_paths = [document1._document_path, document2._document_path] - mask = common_pb2.DocumentMask(field_paths=field_paths) + mask = common.DocumentMask(field_paths=field_paths) client._firestore_api.batch_get_documents.assert_called_once_with( - client._database_string, - doc_paths, - mask, - transaction=None, + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": mask, + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -322,10 +327,12 @@ async def test_get_all_with_transaction(self): # Verify the call to the mock. doc_paths = [document._document_path] client._firestore_api.batch_get_documents.assert_called_once_with( - client._database_string, - doc_paths, - None, - transaction=txn_id, + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) @@ -346,10 +353,12 @@ async def test_get_all_unknown_result(self): # Verify the call to the mock. doc_paths = [document._document_path] client._firestore_api.batch_get_documents.assert_called_once_with( - client._database_string, - doc_paths, - None, - transaction=None, + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -390,10 +399,12 @@ async def test_get_all_wrong_order(self): document3._document_path, ] client._firestore_api.batch_get_documents.assert_called_once_with( - client._database_string, - doc_paths, - None, - transaction=None, + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -425,13 +436,13 @@ def _make_credentials(): def _make_batch_response(**kwargs): - from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.types import firestore - return firestore_pb2.BatchGetDocumentsResponse(**kwargs) + return firestore.BatchGetDocumentsResponse(**kwargs) def _doc_get_info(ref_string, values): - from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.types import document from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud.firestore_v1 import _helpers @@ -441,7 +452,7 @@ def _doc_get_info(ref_string, values): update_time = _datetime_to_pb_timestamp(now - delta) create_time = _datetime_to_pb_timestamp(now - 2 * delta) - document_pb = document_pb2.Document( + document_pb = document.Document( name=ref_string, fields=_helpers.encode_dict(values), create_time=create_time, diff --git a/tests/unit/v1/async/test_async_collection.py b/tests/unit/v1/async/test_async_collection.py index dedd12e0e4..680b0eb85b 100644 --- a/tests/unit/v1/async/test_async_collection.py +++ b/tests/unit/v1/async/test_async_collection.py @@ -95,7 +95,7 @@ def test_constructor_invalid_kwarg(self): @pytest.mark.asyncio async def test_add_auto_assigned(self): - from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1 import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_create @@ -111,7 +111,7 @@ async def test_add_auto_assigned(self): commit_time=mock.sentinel.commit_time, ) firestore_api.commit.return_value = commit_response - create_doc_response = document_pb2.Document() + create_doc_response = document.Document() firestore_api.create_document.return_value = create_doc_response client = _make_client() client._firestore_api_internal = firestore_api @@ -138,9 +138,11 @@ async def test_add_auto_assigned(self): write_pbs = pbs_for_create(document_ref._document_path, document_data) firestore_api.commit.assert_called_once_with( - client._database_string, - write_pbs, - transaction=None, + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, metadata=client._rpc_metadata, ) # Since we generate the ID locally, we don't call 'create_document'. @@ -148,16 +150,16 @@ async def test_add_auto_assigned(self): @staticmethod def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers - return write_pb2.Write( - update=document_pb2.Document( + return write.Write( + update=document.Document( name=document_path, fields=_helpers.encode_dict(document_data) ), - current_document=common_pb2.Precondition(exists=False), + current_document=common.Precondition(exists=False), ) @pytest.mark.asyncio @@ -196,9 +198,11 @@ async def test_add_explicit_id(self): write_pb = self._write_pb_for_create(document_ref._document_path, document_data) firestore_api.commit.assert_called_once_with( - client._database_string, - [write_pb], - transaction=None, + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -207,8 +211,8 @@ async def _list_documents_helper(self, page_size=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_document import AsyncDocumentReference - from google.cloud.firestore_v1.gapic.firestore_client import FirestoreClient - from google.cloud.firestore_v1.proto.document_pb2 import Document + from google.cloud.firestore_v1.services.firestore.client import FirestoreClient + from google.cloud.firestore_v1.types.document import Document class _Iterator(Iterator): def __init__(self, pages): @@ -246,10 +250,12 @@ def _next_page(self): parent, _ = collection._parent_info() api_client.list_documents.assert_called_once_with( - parent, - collection.id, - page_size=page_size, - show_missing=True, + request={ + "parent": parent, + "collection_id": collection.id, + "page_size": page_size, + "show_missing": True, + }, metadata=client._rpc_metadata, ) diff --git a/tests/unit/v1/async/test_async_document.py b/tests/unit/v1/async/test_async_document.py index be265d7bfd..b59c7282b9 100644 --- a/tests/unit/v1/async/test_async_document.py +++ b/tests/unit/v1/async/test_async_document.py @@ -63,31 +63,32 @@ def test_constructor_invalid_kwarg(self): @staticmethod def _make_commit_repsonse(write_results=None): - from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.types import firestore - response = mock.create_autospec(firestore_pb2.CommitResponse) + response = mock.create_autospec(firestore.CommitResponse) response.write_results = write_results or [mock.sentinel.write_result] response.commit_time = mock.sentinel.commit_time return response @staticmethod def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers - return write_pb2.Write( - update=document_pb2.Document( + return write.Write( + update=document.Document( name=document_path, fields=_helpers.encode_dict(document_data) ), - current_document=common_pb2.Precondition(exists=False), + current_document=common.Precondition(exists=False), ) @pytest.mark.asyncio async def test_create(self): # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["commit"]) + firestore_api = mock.Mock() + firestore_api.commit.mock_add_spec(spec=["commit"]) firestore_api.commit.return_value = self._make_commit_repsonse() # Attach the fake GAPIC to a real client. @@ -103,9 +104,11 @@ async def test_create(self): self.assertIs(write_result, mock.sentinel.write_result) write_pb = self._write_pb_for_create(document._document_path, document_data) firestore_api.commit.assert_called_once_with( - client._database_string, - [write_pb], - transaction=None, + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -138,13 +141,13 @@ async def test_create_empty(self): @staticmethod def _write_pb_for_set(document_path, document_data, merge): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers - write_pbs = write_pb2.Write( - update=document_pb2.Document( + write_pbs = write.Write( + update=document.Document( name=document_path, fields=_helpers.encode_dict(document_data) ) ) @@ -158,8 +161,8 @@ def _write_pb_for_set(document_path, document_data, merge): field_paths = [ field_path.to_api_repr() for field_path in sorted(field_paths) ] - mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) - write_pbs.update_mask.CopyFrom(mask) + mask = common.DocumentMask(field_paths=sorted(field_paths)) + write_pbs._pb.update_mask.CopyFrom(mask._pb) return write_pbs @pytest.mark.asyncio @@ -182,9 +185,11 @@ async def _set_helper(self, merge=False, **option_kwargs): write_pb = self._write_pb_for_set(document._document_path, document_data, merge) firestore_api.commit.assert_called_once_with( - client._database_string, - [write_pb], - transaction=None, + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -198,17 +203,17 @@ async def test_set_merge(self): @staticmethod def _write_pb_for_update(document_path, update_values, field_paths): - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers - return write_pb2.Write( - update=document_pb2.Document( + return write.Write( + update=document.Document( name=document_path, fields=_helpers.encode_dict(update_values) ), - update_mask=common_pb2.DocumentMask(field_paths=field_paths), - current_document=common_pb2.Precondition(exists=True), + update_mask=common.DocumentMask(field_paths=field_paths), + current_document=common.Precondition(exists=True), ) @pytest.mark.asyncio @@ -249,9 +254,11 @@ async def _update_helper(self, **option_kwargs): if option is not None: option.modify_write(write_pb) firestore_api.commit.assert_called_once_with( - client._database_string, - [write_pb], - transaction=None, + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -290,7 +297,7 @@ async def test_empty_update(self): @pytest.mark.asyncio async def _delete_helper(self, **option_kwargs): - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import write # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["commit"]) @@ -311,13 +318,15 @@ async def _delete_helper(self, **option_kwargs): # Verify the response and the mocks. self.assertIs(delete_time, mock.sentinel.commit_time) - write_pb = write_pb2.Write(delete=document._document_path) + write_pb = write.Write(delete=document._document_path) if option is not None: option.modify_write(write_pb) firestore_api.commit.assert_called_once_with( - client._database_string, - [write_pb], - transaction=None, + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -337,15 +346,15 @@ async def _get_helper( self, field_paths=None, use_transaction=False, not_found=False ): from google.api_core.exceptions import NotFound - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import document_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.transaction import Transaction # Create a minimal fake GAPIC with a dummy response. create_time = 123 update_time = 234 firestore_api = mock.Mock(spec=["get_document"]) - response = mock.create_autospec(document_pb2.Document) + response = mock.create_autospec(document.Document) response.fields = {} response.create_time = create_time response.update_time = update_time @@ -384,7 +393,7 @@ async def _get_helper( # Verify the request made to the API if field_paths is not None: - mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) + mask = common.DocumentMask(field_paths=sorted(field_paths)) else: mask = None @@ -394,9 +403,11 @@ async def _get_helper( expected_transaction_id = None firestore_api.get_document.assert_called_once_with( - document._document_path, - mask=mask, - transaction=expected_transaction_id, + request={ + "name": document._document_path, + "mask": mask, + "transaction": expected_transaction_id, + }, metadata=client._rpc_metadata, ) @@ -430,12 +441,14 @@ async def _collections_helper(self, page_size=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - from google.cloud.firestore_v1.gapic.firestore_client import FirestoreClient + from google.cloud.firestore_v1.services.firestore.client import FirestoreClient + # TODO(microgen): https://github.com/googleapis/gapic-generator-python/issues/516 class _Iterator(Iterator): def __init__(self, pages): super(_Iterator, self).__init__(client=None) self._pages = pages + self.collection_ids = pages[0] def _next_page(self): if self._pages: @@ -453,9 +466,9 @@ def _next_page(self): # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) if page_size is not None: - collections = list(await document.collections(page_size=page_size)) + collections = [c async for c in document.collections(page_size=page_size)] else: - collections = list(await document.collections()) + collections = [c async for c in document.collections()] # Verify the response and the mocks. self.assertEqual(len(collections), len(collection_ids)) @@ -465,7 +478,8 @@ def _next_page(self): self.assertEqual(collection.id, collection_id) api_client.list_collection_ids.assert_called_once_with( - document._document_path, page_size=page_size, metadata=client._rpc_metadata + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, ) @pytest.mark.asyncio diff --git a/tests/unit/v1/async/test_async_query.py b/tests/unit/v1/async/test_async_query.py index 5a9edd6d30..87305bfbc6 100644 --- a/tests/unit/v1/async/test_async_query.py +++ b/tests/unit/v1/async/test_async_query.py @@ -81,9 +81,11 @@ async def test_get_simple(self): # Verify the mock call. parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=None, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -123,9 +125,11 @@ async def test_stream_simple(self): # Verify the mock call. parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=None, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -165,9 +169,11 @@ async def test_stream_with_transaction(self): # Verify the mock call. firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=txn_id, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) @@ -194,9 +200,11 @@ async def test_stream_no_results(self): # Verify the mock call. parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=None, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -224,9 +232,11 @@ async def test_stream_second_response_in_empty_stream(self): # Verify the mock call. parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=None, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -263,9 +273,11 @@ async def test_stream_with_skipped_results(self): # Verify the mock call. parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=None, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -302,9 +314,11 @@ async def test_stream_empty_after_first_response(self): # Verify the mock call. parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=None, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, metadata=client._rpc_metadata, ) @@ -344,9 +358,11 @@ async def test_stream_w_collection_group(self): # Verify the mock call. parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( - parent_path, - query._to_protobuf(), - transaction=None, + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, metadata=client._rpc_metadata, ) diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py index ba8fb8ff16..ab9a56033c 100644 --- a/tests/unit/v1/async/test_async_transaction.py +++ b/tests/unit/v1/async/test_async_transaction.py @@ -80,15 +80,17 @@ def test__clean_up(self): @pytest.mark.asyncio async def test__begin(self): - from google.cloud.firestore_v1.gapic import firestore_client - from google.cloud.firestore_v1.proto import firestore_pb2 + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + from google.cloud.firestore_v1.types import firestore # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( firestore_client.FirestoreClient, instance=True ) txn_id = b"to-begin" - response = firestore_pb2.BeginTransactionResponse(transaction=txn_id) + response = firestore.BeginTransactionResponse(transaction=txn_id) firestore_api.begin_transaction.return_value = response # Attach the fake GAPIC to a real client. @@ -105,7 +107,8 @@ async def test__begin(self): # Verify the called mock. firestore_api.begin_transaction.assert_called_once_with( - client._database_string, options_=None, metadata=client._rpc_metadata + request={"database": client._database_string, "options": None}, + metadata=client._rpc_metadata, ) @pytest.mark.asyncio @@ -125,7 +128,9 @@ async def test__begin_failure(self): @pytest.mark.asyncio async def test__rollback(self): from google.protobuf import empty_pb2 - from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( @@ -147,7 +152,8 @@ async def test__rollback(self): # Verify the called mock. firestore_api.rollback.assert_called_once_with( - client._database_string, txn_id, metadata=client._rpc_metadata + request={"database": client._database_string, "transaction": txn_id}, + metadata=client._rpc_metadata, ) @pytest.mark.asyncio @@ -166,7 +172,9 @@ async def test__rollback_not_allowed(self): @pytest.mark.asyncio async def test__rollback_failure(self): from google.api_core import exceptions - from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) # Create a minimal fake GAPIC with a dummy failure. firestore_api = mock.create_autospec( @@ -193,22 +201,23 @@ async def test__rollback_failure(self): # Verify the called mock. firestore_api.rollback.assert_called_once_with( - client._database_string, txn_id, metadata=client._rpc_metadata + request={"database": client._database_string, "transaction": txn_id}, + metadata=client._rpc_metadata, ) @pytest.mark.asyncio async def test__commit(self): - from google.cloud.firestore_v1.gapic import firestore_client - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( firestore_client.FirestoreClient, instance=True ) - commit_response = firestore_pb2.CommitResponse( - write_results=[write_pb2.WriteResult()] - ) + commit_response = firestore.CommitResponse(write_results=[write.WriteResult()]) firestore_api.commit.return_value = commit_response # Attach the fake GAPIC to a real client. @@ -231,9 +240,11 @@ async def test__commit(self): # Verify the mocks. firestore_api.commit.assert_called_once_with( - client._database_string, - write_pbs, - transaction=txn_id, + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) @@ -251,7 +262,9 @@ async def test__commit_not_allowed(self): @pytest.mark.asyncio async def test__commit_failure(self): from google.api_core import exceptions - from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) # Create a minimal fake GAPIC with a dummy failure. firestore_api = mock.create_autospec( @@ -281,9 +294,11 @@ async def test__commit_failure(self): # Verify the called mock. firestore_api.commit.assert_called_once_with( - client._database_string, - write_pbs, - transaction=txn_id, + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) @@ -363,8 +378,10 @@ async def test__pre_commit_success(self): to_wrap.assert_called_once_with(transaction, "pos", key="word") firestore_api = transaction._client._firestore_api firestore_api.begin_transaction.assert_called_once_with( - transaction._client._database_string, - options_=None, + request={ + "database": transaction._client._database_string, + "options": None, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_not_called() @@ -372,7 +389,7 @@ async def test__pre_commit_success(self): @pytest.mark.asyncio async def test__pre_commit_retry_id_already_set_success(self): - from google.cloud.firestore_v1.proto import common_pb2 + from google.cloud.firestore_v1.types import common to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) @@ -391,14 +408,14 @@ async def test__pre_commit_retry_id_already_set_success(self): # Verify mocks. to_wrap.assert_called_once_with(transaction) firestore_api = transaction._client._firestore_api - options_ = common_pb2.TransactionOptions( - read_write=common_pb2.TransactionOptions.ReadWrite( - retry_transaction=txn_id1 - ) + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id1) ) firestore_api.begin_transaction.assert_called_once_with( - transaction._client._database_string, - options_=options_, + request={ + "database": transaction._client._database_string, + "options": options_, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_not_called() @@ -424,13 +441,17 @@ async def test__pre_commit_failure(self): to_wrap.assert_called_once_with(transaction, 10, 20) firestore_api = transaction._client._firestore_api firestore_api.begin_transaction.assert_called_once_with( - transaction._client._database_string, - options_=None, + request={ + "database": transaction._client._database_string, + "options": None, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_called_once_with( - transaction._client._database_string, - txn_id, + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.commit.assert_not_called() @@ -462,13 +483,17 @@ async def test__pre_commit_failure_with_rollback_failure(self): # Verify mocks. to_wrap.assert_called_once_with(transaction, a="b", c="zebra") firestore_api.begin_transaction.assert_called_once_with( - transaction._client._database_string, - options_=None, + request={ + "database": transaction._client._database_string, + "options": None, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_called_once_with( - transaction._client._database_string, - txn_id, + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.commit.assert_not_called() @@ -491,9 +516,11 @@ async def test__maybe_commit_success(self): firestore_api.begin_transaction.assert_not_called() firestore_api.rollback.assert_not_called() firestore_api.commit.assert_called_once_with( - transaction._client._database_string, - [], - transaction=txn_id, + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) @@ -527,9 +554,11 @@ async def test__maybe_commit_failure_read_only(self): firestore_api.begin_transaction.assert_not_called() firestore_api.rollback.assert_not_called() firestore_api.commit.assert_called_once_with( - transaction._client._database_string, - [], - transaction=txn_id, + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) @@ -561,9 +590,11 @@ async def test__maybe_commit_failure_can_retry(self): firestore_api.begin_transaction.assert_not_called() firestore_api.rollback.assert_not_called() firestore_api.commit.assert_called_once_with( - transaction._client._database_string, - [], - transaction=txn_id, + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) @@ -596,9 +627,11 @@ async def test__maybe_commit_failure_cannot_retry(self): firestore_api.begin_transaction.assert_not_called() firestore_api.rollback.assert_not_called() firestore_api.commit.assert_called_once_with( - transaction._client._database_string, - [], - transaction=txn_id, + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) @@ -620,24 +653,25 @@ async def test___call__success_first_attempt(self): to_wrap.assert_called_once_with(transaction, "a", b="c") firestore_api = transaction._client._firestore_api firestore_api.begin_transaction.assert_called_once_with( - transaction._client._database_string, - options_=None, + request={"database": transaction._client._database_string, "options": None}, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_not_called() firestore_api.commit.assert_called_once_with( - transaction._client._database_string, - [], - transaction=txn_id, + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) @pytest.mark.asyncio async def test___call__success_second_attempt(self): from google.api_core import exceptions - from google.cloud.firestore_v1.proto import common_pb2 - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = self._make_one(to_wrap) @@ -650,7 +684,7 @@ async def test___call__success_second_attempt(self): firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = [ exc, - firestore_pb2.CommitResponse(write_results=[write_pb2.WriteResult()]), + firestore.CommitResponse(write_results=[write.WriteResult()]), ] # Call the __call__-able ``wrapped``. @@ -666,25 +700,26 @@ async def test___call__success_second_attempt(self): self.assertEqual(to_wrap.mock_calls, [wrapped_call, wrapped_call]) firestore_api = transaction._client._firestore_api db_str = transaction._client._database_string - options_ = common_pb2.TransactionOptions( - read_write=common_pb2.TransactionOptions.ReadWrite(retry_transaction=txn_id) + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) ) self.assertEqual( firestore_api.begin_transaction.mock_calls, [ mock.call( - db_str, options_=None, metadata=transaction._client._rpc_metadata + request={"database": db_str, "options": None}, + metadata=transaction._client._rpc_metadata, ), mock.call( - db_str, - options_=options_, + request={"database": db_str, "options": options_}, metadata=transaction._client._rpc_metadata, ), ], ) firestore_api.rollback.assert_not_called() commit_call = mock.call( - db_str, [], transaction=txn_id, metadata=transaction._client._rpc_metadata + request={"database": db_str, "writes": [], "transaction": txn_id}, + metadata=transaction._client._rpc_metadata, ) self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) @@ -720,19 +755,25 @@ async def test___call__failure(self): # Verify mocks. to_wrap.assert_called_once_with(transaction, "here", there=1.5) firestore_api.begin_transaction.assert_called_once_with( - transaction._client._database_string, - options_=None, + request={ + "database": transaction._client._database_string, + "options": None, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_called_once_with( - transaction._client._database_string, - txn_id, + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.commit.assert_called_once_with( - transaction._client._database_string, - [], - transaction=txn_id, + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) @@ -763,7 +804,9 @@ async def _call_fut(client, write_pbs, transaction_id): @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") @pytest.mark.asyncio async def test_success_first_attempt(self, _sleep): - from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( @@ -782,9 +825,11 @@ async def test_success_first_attempt(self, _sleep): # Verify mocks used. _sleep.assert_not_called() firestore_api.commit.assert_called_once_with( - client._database_string, - mock.sentinel.write_pbs, - transaction=txn_id, + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) @@ -794,7 +839,9 @@ async def test_success_first_attempt(self, _sleep): @pytest.mark.asyncio async def test_success_third_attempt(self, _sleep): from google.api_core import exceptions - from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( @@ -822,9 +869,11 @@ async def test_success_third_attempt(self, _sleep): _sleep.assert_any_call(2.0) # commit() called same way 3 times. commit_call = mock.call( - client._database_string, - mock.sentinel.write_pbs, - transaction=txn_id, + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) self.assertEqual( @@ -835,7 +884,9 @@ async def test_success_third_attempt(self, _sleep): @pytest.mark.asyncio async def test_failure_first_attempt(self, _sleep): from google.api_core import exceptions - from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( @@ -859,9 +910,11 @@ async def test_failure_first_attempt(self, _sleep): # Verify mocks used. _sleep.assert_not_called() firestore_api.commit.assert_called_once_with( - client._database_string, - mock.sentinel.write_pbs, - transaction=txn_id, + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) @@ -869,7 +922,9 @@ async def test_failure_first_attempt(self, _sleep): @pytest.mark.asyncio async def test_failure_second_attempt(self, _sleep): from google.api_core import exceptions - from google.cloud.firestore_v1.gapic import firestore_client + from google.cloud.firestore_v1.services.firestore import ( + client as firestore_client, + ) # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( @@ -896,9 +951,11 @@ async def test_failure_second_attempt(self, _sleep): _sleep.assert_called_once_with(1.0) # commit() called same way 2 times. commit_call = mock.call( - client._database_string, - mock.sentinel.write_pbs, - transaction=txn_id, + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, metadata=client._rpc_metadata, ) self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) @@ -973,9 +1030,9 @@ def _make_client(project="feral-tom-cat"): def _make_transaction(txn_id, **txn_kwargs): from google.protobuf import empty_pb2 - from google.cloud.firestore_v1.gapic import firestore_client - from google.cloud.firestore_v1.proto import firestore_pb2 - from google.cloud.firestore_v1.proto import write_pb2 + from google.cloud.firestore_v1.services.firestore import client as firestore_client + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.async_transaction import AsyncTransaction # Create a fake GAPIC ... @@ -983,14 +1040,12 @@ def _make_transaction(txn_id, **txn_kwargs): firestore_client.FirestoreClient, instance=True ) # ... with a dummy ``BeginTransactionResponse`` result ... - begin_response = firestore_pb2.BeginTransactionResponse(transaction=txn_id) + begin_response = firestore.BeginTransactionResponse(transaction=txn_id) firestore_api.begin_transaction.return_value = begin_response # ... and a dummy ``Rollback`` result ... firestore_api.rollback.return_value = empty_pb2.Empty() # ... and a dummy ``Commit`` result. - commit_response = firestore_pb2.CommitResponse( - write_results=[write_pb2.WriteResult()] - ) + commit_response = firestore.CommitResponse(write_results=[write.WriteResult()]) firestore_api.commit.return_value = commit_response # Attach the fake GAPIC to a real client. From b6c83805307af13fb3e02824ef9c2556f2b3bc11 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 15 Jul 2020 11:07:45 -0500 Subject: [PATCH 41/47] fix: async client copyright date --- google/cloud/firestore_v1/async_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 14e506c417..4dd17035c8 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -1,4 +1,4 @@ -# Copyright 202 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 18bf8614d76bcace2f7ae4d95846a2218444d8bb Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 15 Jul 2020 15:21:39 -0500 Subject: [PATCH 42/47] fix: standardize assert syntax --- tests/unit/v1/async/test_async_client.py | 12 +++++++----- tests/unit/v1/test_client.py | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/unit/v1/async/test_async_client.py b/tests/unit/v1/async/test_async_client.py index c78b474aa4..6fd9b93d28 100644 --- a/tests/unit/v1/async/test_async_client.py +++ b/tests/unit/v1/async/test_async_client.py @@ -130,11 +130,13 @@ def test_collection_group(self): client = self._make_default_one() query = client.collection_group("collectionId").where("foo", "==", u"bar") - assert query._all_descendants - assert query._field_filters[0].field.field_path == "foo" - assert query._field_filters[0].value.string_value == u"bar" - assert query._field_filters[0].op == query._field_filters[0].Operator.EQUAL - assert query._parent.id == "collectionId" + self.assertTrue(query._all_descendants) + self.assertEqual(query._field_filters[0].field.field_path, "foo") + self.assertEqual(query._field_filters[0].value.string_value, u"bar") + self.assertEqual( + query._field_filters[0].op, query._field_filters[0].Operator.EQUAL + ) + self.assertEqual(query._parent.id, "collectionId") def test_collection_group_no_slashes(self): client = self._make_default_one() diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 8aa5f41d42..40c969295a 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -129,11 +129,13 @@ def test_collection_group(self): client = self._make_default_one() query = client.collection_group("collectionId").where("foo", "==", u"bar") - assert query._all_descendants - assert query._field_filters[0].field.field_path == "foo" - assert query._field_filters[0].value.string_value == u"bar" - assert query._field_filters[0].op == query._field_filters[0].Operator.EQUAL - assert query._parent.id == "collectionId" + self.assertTrue(query._all_descendants) + self.assertEqual(query._field_filters[0].field.field_path, "foo") + self.assertEqual(query._field_filters[0].value.string_value, u"bar") + self.assertEqual( + query._field_filters[0].op, query._field_filters[0].Operator.EQUAL + ) + self.assertEqual(query._parent.id, "collectionId") def test_collection_group_no_slashes(self): client = self._make_default_one() From faccb506b1186e7ef353d2013ec5e6270998edba Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 15 Jul 2020 15:24:53 -0500 Subject: [PATCH 43/47] fix: incorrect copyright date --- tests/unit/v1/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 40c969295a..433fcadfaf 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC All rights reserved. +# Copyright 2020 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 1f4ba241adcfcdc0bba658e89c4904730a948d10 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 15 Jul 2020 15:27:53 -0500 Subject: [PATCH 44/47] fix: incorrect copyright date --- tests/unit/v1/async/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/v1/async/__init__.py b/tests/unit/v1/async/__init__.py index ab67290952..c6334245ae 100644 --- a/tests/unit/v1/async/__init__.py +++ b/tests/unit/v1/async/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 81b09f8016a112a110488825d7b816461e5dfa0a Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 15 Jul 2020 17:13:56 -0500 Subject: [PATCH 45/47] fix: clarify _sleep assertions in transaction --- tests/unit/v1/async/test_async_transaction.py | 1 + tests/unit/v1/test_transaction.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/unit/v1/async/test_async_transaction.py b/tests/unit/v1/async/test_async_transaction.py index ab9a56033c..b27f30e9cd 100644 --- a/tests/unit/v1/async/test_async_transaction.py +++ b/tests/unit/v1/async/test_async_transaction.py @@ -864,6 +864,7 @@ async def test_success_third_attempt(self, _sleep): self.assertIs(commit_response, mock.sentinel.commit_response) # Verify mocks used. + # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds self.assertEqual(_sleep.call_count, 2) _sleep.assert_any_call(1.0) _sleep.assert_any_call(2.0) diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index e4c8389921..a32e58c104 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -831,6 +831,7 @@ def test_success_third_attempt(self, _sleep): self.assertIs(commit_response, mock.sentinel.commit_response) # Verify mocks used. + # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds self.assertEqual(_sleep.call_count, 2) _sleep.assert_any_call(1.0) _sleep.assert_any_call(2.0) From a75ab7082fe12b4ad75311fd2e23b7a86a00d144 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Wed, 15 Jul 2020 17:59:33 -0500 Subject: [PATCH 46/47] fix: clarify error in context manager tests --- tests/unit/v1/async/test_async_batch.py | 2 +- tests/unit/v1/test_batch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py index d09e826a7b..79daf47920 100644 --- a/tests/unit/v1/async/test_async_batch.py +++ b/tests/unit/v1/async/test_async_batch.py @@ -137,9 +137,9 @@ async def test_as_context_mgr_w_error(self): ctx_mgr.delete(document2) raise RuntimeError("testing") + # batch still has its changes, as _aexit_ is not invoked self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) - # batch still has its changes self.assertEqual(len(batch._write_pbs), 2) firestore_api.commit.assert_not_called() diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index e8ab7a2670..b18e671e58 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -133,9 +133,9 @@ def test_as_context_mgr_w_error(self): ctx_mgr.delete(document2) raise RuntimeError("testing") + # batch still has its changes, as _exit_ is not invoked self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) - # batch still has its changes self.assertEqual(len(batch._write_pbs), 2) firestore_api.commit.assert_not_called() From 5914bfa6ea538ee04bff2604fbed87eb19cb07e6 Mon Sep 17 00:00:00 2001 From: Rafi Long Date: Thu, 16 Jul 2020 13:18:57 -0500 Subject: [PATCH 47/47] fix: clarify error in context manager tests --- tests/unit/v1/async/test_async_batch.py | 3 ++- tests/unit/v1/test_batch.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/v1/async/test_async_batch.py b/tests/unit/v1/async/test_async_batch.py index 79daf47920..acb977d869 100644 --- a/tests/unit/v1/async/test_async_batch.py +++ b/tests/unit/v1/async/test_async_batch.py @@ -137,7 +137,8 @@ async def test_as_context_mgr_w_error(self): ctx_mgr.delete(document2) raise RuntimeError("testing") - # batch still has its changes, as _aexit_ is not invoked + # batch still has its changes, as _aexit_ (and commit) is not invoked + # changes are preserved so commit can be retried self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) self.assertEqual(len(batch._write_pbs), 2) diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index b18e671e58..5396540c6d 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -133,7 +133,8 @@ def test_as_context_mgr_w_error(self): ctx_mgr.delete(document2) raise RuntimeError("testing") - # batch still has its changes, as _exit_ is not invoked + # batch still has its changes, as _exit_ (and commit) is not invoked + # changes are preserved so commit can be retried self.assertIsNone(batch.write_results) self.assertIsNone(batch.commit_time) self.assertEqual(len(batch._write_pbs), 2)