diff --git a/packages/google-cloud-firestore/noxfile.py b/packages/google-cloud-firestore/noxfile.py index 588dd7c0058d..5a7c0a1b8536 100644 --- a/packages/google-cloud-firestore/noxfile.py +++ b/packages/google-cloud-firestore/noxfile.py @@ -71,7 +71,7 @@ SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ALL_PYTHON SYSTEM_TEST_STANDARD_DEPENDENCIES = [ "mock", - "pytest", + "pytest>9.0", "google-cloud-testutils", ] SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ diff --git a/packages/google-cloud-firestore/tests/system/test__helpers.py b/packages/google-cloud-firestore/tests/system/test__helpers.py index 8032ae11918f..603be70705d0 100644 --- a/packages/google-cloud-firestore/tests/system/test__helpers.py +++ b/packages/google-cloud-firestore/tests/system/test__helpers.py @@ -1,9 +1,14 @@ import os import re +import time +import datetime +import contextlib from test_utils.system import EmulatorCreds, unique_resource_id from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST +from google.cloud.firestore import SERVER_TIMESTAMP +from google.api_core.exceptions import AlreadyExists FIRESTORE_CREDS = os.environ.get("FIRESTORE_APPLICATION_CREDENTIALS") FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") @@ -20,3 +25,41 @@ # run all tests against default database, and a named database TEST_DATABASES = [None, FIRESTORE_OTHER_DB] TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB] + + +@contextlib.contextmanager +def system_test_lock(client, lock_name="system_test_lock", max_wait_minutes=65): + """ + Acquires a distributed lock using a Firestore document to prevent concurrent system tests. + """ + lock_ref = client.collection("system_tests").document(lock_name) + start_time = time.time() + max_wait_time = max_wait_minutes * 60 + + while time.time() - start_time < max_wait_time: + try: + lock_ref.create({"created_at": SERVER_TIMESTAMP}) + break # Lock acquired + except AlreadyExists: + lock_doc = lock_ref.get() + if lock_doc.exists: + created_at = lock_doc.to_dict().get("created_at") + if created_at: + now = datetime.datetime.now(datetime.timezone.utc) + age = (now - created_at).total_seconds() + if age > 3600: + print(f"Lock is expired (age: {age}s). Stealing lock.") + lock_ref.delete() + continue + else: + print( + f"Waiting for {lock_name}. Lock is {age:.0f}s old. Sleeping for 15s..." + ) + time.sleep(15) + else: + raise TimeoutError(f"Timed out waiting for {lock_name}") + + try: + yield lock_ref + finally: + lock_ref.delete() diff --git a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py index afff43ac6950..b1f284531102 100644 --- a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py +++ b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py @@ -26,7 +26,7 @@ import yaml from google.api_core.exceptions import GoogleAPIError from google.protobuf.json_format import MessageToDict -from test__helpers import FIRESTORE_EMULATOR, FIRESTORE_ENTERPRISE_DB +from test__helpers import FIRESTORE_EMULATOR, FIRESTORE_ENTERPRISE_DB, system_test_lock from google.cloud.firestore import AsyncClient, Client from google.cloud.firestore_v1 import pipeline_expressions @@ -364,21 +364,23 @@ def client(): client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) data = yaml_loader("data", attach_file_name=False) to_delete = [] - try: - # setup data - batch = client.batch() - for collection_name, documents in data.items(): - collection_ref = client.collection(collection_name) - for document_id, document_data in documents.items(): - document_ref = collection_ref.document(document_id) - to_delete.append(document_ref) - batch.set(document_ref, _parse_yaml_types(document_data)) - batch.commit() - yield client - finally: - # clear data - for document_ref in to_delete: - document_ref.delete() + + with system_test_lock(client, lock_name="pipeline_e2e_lock"): + try: + # setup data + batch = client.batch() + for collection_name, documents in data.items(): + collection_ref = client.collection(collection_name) + for document_id, document_data in documents.items(): + document_ref = collection_ref.document(document_id) + to_delete.append(document_ref) + batch.set(document_ref, _parse_yaml_types(document_data)) + batch.commit() + yield client + finally: + # clear data + for document_ref in to_delete: + document_ref.delete() @pytest.fixture(scope="module") diff --git a/packages/google-cloud-firestore/tests/system/test_system.py b/packages/google-cloud-firestore/tests/system/test_system.py index f3cdf13b09a1..350daa4a5bc5 100644 --- a/packages/google-cloud-firestore/tests/system/test_system.py +++ b/packages/google-cloud-firestore/tests/system/test_system.py @@ -84,62 +84,77 @@ def cleanup(): operation() -def verify_pipeline(query): +@pytest.fixture +def verify_pipeline(subtests): """ - This function ensures a pipeline produces the same - results as the query it is derived from + This fixture provides a subtest function which + ensures a pipeline produces the same results as the query it is derived + from. It can be attached to existing query tests to check both modalities at the same time - """ - from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery - if FIRESTORE_EMULATOR: - pytest.skip("skip pipeline verification on emulator") - - def _clean_results(results): - if isinstance(results, dict): - return {k: _clean_results(v) for k, v in results.items()} - elif isinstance(results, list): - return [_clean_results(r) for r in results] - elif isinstance(results, float) and math.isnan(results): - return "__NAN_VALUE__" - else: - return results + Pipelines are only supported on enterprise dbs. Skip other environments + """ - query_exception = None - query_results = None - try: - try: - if isinstance(query, BaseAggregationQuery): - # aggregation queries return a list of lists of aggregation results - query_results = _clean_results( - list( - itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in query.get()] + def _verifier(query): + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + with subtests.test(msg="verify_pipeline"): + client = query._client + if FIRESTORE_EMULATOR: + pytest.skip("skip pipeline verification on emulator") + if client._database != FIRESTORE_ENTERPRISE_DB: + pytest.skip("pipelines only supports enterprise db") + + def _clean_results(results): + if isinstance(results, dict): + return {k: _clean_results(v) for k, v in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in query.get()] + ) + ) + ) + else: + # other qureies return a simple list of results + query_results = _clean_results( + [s.to_dict() for s in query.get()] ) + except Exception as e: + # if we expect the query to fail, capture the exception + query_exception = e + pipeline = client.pipeline().create_from(query) + if query_exception: + # ensure that the pipeline uses same error as query + with pytest.raises(query_exception.__class__): + pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results( + [s.data() for s in pipeline.execute()] ) - ) - else: - # other qureies return a simple list of results - query_results = _clean_results([s.to_dict() for s in query.get()]) - except Exception as e: - # if we expect the query to fail, capture the exception - query_exception = e - client = query._client - pipeline = client.pipeline().create_from(query) - if query_exception: - # ensure that the pipeline uses same error as query - with pytest.raises(query_exception.__class__): - pipeline.execute() - else: - # ensure results match query - pipeline_results = _clean_results([s.data() for s in pipeline.execute()]) - assert query_results == pipeline_results - except FailedPrecondition as e: - # if testing against a non-enterprise db, skip this check - if ENTERPRISE_MODE_ERROR not in e.message: - raise e + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + return _verifier @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1294,7 +1309,7 @@ def query(collection): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_legacy_where(query_docs, database): +def test_query_stream_legacy_where(query_docs, database, verify_pipeline): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs with pytest.warns( @@ -1311,7 +1326,7 @@ def test_query_stream_legacy_where(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_eq_op(query_docs, database): +def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1323,7 +1338,9 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_array_contains_op(query_docs, database): +def test_query_stream_w_simple_field_array_contains_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1335,7 +1352,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_in_op(query_docs, database): +def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100])) @@ -1348,7 +1365,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_not_eq_op(query_docs, database): +def test_query_stream_w_not_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1371,7 +1388,7 @@ def test_query_stream_w_not_eq_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_not_in_op(query_docs, database): +def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1384,7 +1401,9 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): +def test_query_stream_w_simple_field_array_contains_any_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1399,7 +1418,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_order_by(query_docs, database): +def test_query_stream_w_order_by(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] @@ -1414,7 +1433,7 @@ def test_query_stream_w_order_by(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_field_path(query_docs, database): +def test_query_stream_w_field_path(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1453,7 +1472,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_wo_results(query_docs, database): +def test_query_stream_wo_results(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) @@ -1480,7 +1499,7 @@ def test_query_stream_w_projection(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_multiple_filters(query_docs, database): +def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( filter=FieldFilter("stats.product", "<", 10) @@ -1501,7 +1520,7 @@ def test_query_stream_w_multiple_filters(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_w_offset(query_docs, database): +def test_query_stream_w_offset(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) offset = 3 @@ -1522,7 +1541,9 @@ def test_query_stream_w_offset(query_docs, database): ) @pytest.mark.parametrize("method", ["stream", "get"]) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): +def test_query_stream_or_get_w_no_explain_options( + query_docs, database, method, verify_pipeline +): from google.cloud.firestore_v1.query_profile import QueryExplainError collection, _, allowed_vals = query_docs @@ -1886,7 +1907,7 @@ def test_query_with_order_dot_key(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) -def test_query_unary(client, cleanup, database): +def test_query_unary(client, cleanup, database, verify_pipeline): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) field_name = "foo" @@ -1943,7 +1964,7 @@ def test_query_unary(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_collection_group_queries(client, cleanup, database): +def test_collection_group_queries(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -2020,7 +2041,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_collection_group_queries_filters(client, cleanup, database): +def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -2416,7 +2437,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_array_union(client, cleanup, database): - doc_ref = client.document("gcp-7523", "test-document") + doc_ref = client.document(f"gcp-7523-{UNIQUE_RESOURCE_ID}", "test-document") cleanup(doc_ref.delete) doc_ref.delete() tree_1 = {"forest": {"tree-1": "oak"}} @@ -2811,7 +2832,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_repro_429(client, cleanup, database): +def test_repro_429(client, cleanup, database, verify_pipeline): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) collection = client.collection("repro-429" + UNIQUE_RESOURCE_ID) @@ -3406,7 +3427,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_with_and_composite_filter(collection, database): +def test_query_with_and_composite_filter(collection, database, verify_pipeline): and_filter = And( filters=[ FieldFilter("stats.product", ">", 5), @@ -3422,7 +3443,7 @@ def test_query_with_and_composite_filter(collection, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_with_or_composite_filter(collection, database): +def test_query_with_or_composite_filter(collection, database, verify_pipeline): or_filter = Or( filters=[ FieldFilter("stats.product", ">", 5), @@ -3456,6 +3477,7 @@ def test_aggregation_queries_with_read_time( database, aggregation_type, expected_value, + verify_pipeline, ): """ Ensure that all aggregation queries work when read_time is passed into @@ -3494,7 +3516,7 @@ def test_aggregation_queries_with_read_time( @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_query_with_complex_composite_filter(collection, database): +def test_query_with_complex_composite_filter(collection, database, verify_pipeline): field_filter = FieldFilter("b", "==", 0) or_filter = Or( filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)] @@ -3552,6 +3574,7 @@ def test_aggregation_query_in_transaction( aggregation_type, aggregation_args, expected, + verify_pipeline, ): """ Test creating an aggregation query inside a transaction @@ -3593,7 +3616,7 @@ def in_transaction(transaction): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -def test_or_query_in_transaction(client, cleanup, database): +def test_or_query_in_transaction(client, cleanup, database, verify_pipeline): """ Test running or query inside a transaction. Should pass transaction id along with request """ diff --git a/packages/google-cloud-firestore/tests/system/test_system_async.py b/packages/google-cloud-firestore/tests/system/test_system_async.py index b2806a0dc68e..3a7959830425 100644 --- a/packages/google-cloud-firestore/tests/system/test_system_async.py +++ b/packages/google-cloud-firestore/tests/system/test_system_async.py @@ -164,64 +164,80 @@ async def cleanup(): await operation() -async def verify_pipeline(query): +@pytest.fixture +def verify_pipeline(subtests): """ - This function ensures a pipeline produces the same - results as the query it is derived from + This fixture provide a subtest function which + ensures a pipeline produces the same results as the query it is derived + from It can be attached to existing query tests to check both modalities at the same time + + Pipelines are only supported on enterprise dbs. Skip other environments """ - from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery - if FIRESTORE_EMULATOR: - pytest.skip("skip pipeline verification on emulator") - - def _clean_results(results): - if isinstance(results, dict): - return {k: _clean_results(v) for k, v in results.items()} - elif isinstance(results, list): - return [_clean_results(r) for r in results] - elif isinstance(results, float) and math.isnan(results): - return "__NAN_VALUE__" - else: - return results - - query_exception = None - query_results = None - try: - try: - if isinstance(query, BaseAggregationQuery): - # aggregation queries return a list of lists of aggregation results - query_results = _clean_results( - list( - itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in await query.get()] + async def _verifier(query): + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + + with subtests.test(msg="verify_pipeline"): + client = query._client + if FIRESTORE_EMULATOR: + pytest.skip("skip pipeline verification on emulator") + if client._database != FIRESTORE_ENTERPRISE_DB: + pytest.skip("pipelines only supports enterprise db") + + def _clean_results(results): + if isinstance(results, dict): + return {k: _clean_results(v) for k, v in results.items()} + elif isinstance(results, list): + return [_clean_results(r) for r in results] + elif isinstance(results, float) and math.isnan(results): + return "__NAN_VALUE__" + else: + return results + + query_exception = None + query_results = None + try: + try: + if isinstance(query, BaseAggregationQuery): + # aggregation queries return a list of lists of aggregation results + query_results = _clean_results( + list( + itertools.chain.from_iterable( + [ + [a._to_dict() for a in s] + for s in await query.get() + ] + ) + ) ) + else: + # other qureies return a simple list of results + query_results = _clean_results( + [s.to_dict() for s in await query.get()] + ) + except Exception as e: + # if we expect the query to fail, capture the exception + query_exception = e + pipeline = client.pipeline().create_from(query) + if query_exception: + # ensure that the pipeline uses same error as query + with pytest.raises(query_exception.__class__): + await pipeline.execute() + else: + # ensure results match query + pipeline_results = _clean_results( + [s.data() async for s in pipeline.stream()] ) - ) - else: - # other qureies return a simple list of results - query_results = _clean_results([s.to_dict() for s in await query.get()]) - except Exception as e: - # if we expect the query to fail, capture the exception - query_exception = e - client = query._client - pipeline = client.pipeline().create_from(query) - if query_exception: - # ensure that the pipeline uses same error as query - with pytest.raises(query_exception.__class__): - await pipeline.execute() - else: - # ensure results match query - pipeline_results = _clean_results( - [s.data() async for s in pipeline.stream()] - ) - assert query_results == pipeline_results - except FailedPrecondition as e: - # if testing against a non-enterprise db, skip this check - if ENTERPRISE_MODE_ERROR not in e.message: - raise e + assert query_results == pipeline_results + except FailedPrecondition as e: + # if testing against a non-enterprise db, skip this check + if ENTERPRISE_MODE_ERROR not in e.message: + raise e + + return _verifier @pytest.fixture(scope="module") @@ -1268,7 +1284,7 @@ async def async_query(collection): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_legacy_where(query_docs, database): +async def test_query_stream_legacy_where(query_docs, database, verify_pipeline): """Assert the legacy code still works and returns value, and shows UserWarning""" collection, stored, allowed_vals = query_docs with pytest.warns( @@ -1285,7 +1301,7 @@ async def test_query_stream_legacy_where(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_eq_op(query_docs, database): +async def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} @@ -1297,7 +1313,9 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_array_contains_op(query_docs, database): +async def test_query_stream_w_simple_field_array_contains_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} @@ -1309,7 +1327,7 @@ async def test_query_stream_w_simple_field_array_contains_op(query_docs, databas @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_in_op(query_docs, database): +async def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100])) @@ -1322,7 +1340,9 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): +async def test_query_stream_w_simple_field_array_contains_any_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1337,7 +1357,7 @@ async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, dat @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_order_by(query_docs, database): +async def test_query_stream_w_order_by(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) values = [(snapshot.id, snapshot.to_dict()) async for snapshot in query.stream()] @@ -1352,7 +1372,7 @@ async def test_query_stream_w_order_by(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_field_path(query_docs, database): +async def test_query_stream_w_field_path(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} @@ -1391,7 +1411,7 @@ async def test_query_stream_w_start_end_cursor(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_wo_results(query_docs, database): +async def test_query_stream_wo_results(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) @@ -1418,7 +1438,7 @@ async def test_query_stream_w_projection(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_multiple_filters(query_docs, database): +async def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( "stats.product", "<", 10 @@ -1439,7 +1459,7 @@ async def test_query_stream_w_multiple_filters(query_docs, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_w_offset(query_docs, database): +async def test_query_stream_w_offset(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) offset = 3 @@ -1460,7 +1480,9 @@ async def test_query_stream_w_offset(query_docs, database): ) @pytest.mark.parametrize("method", ["stream", "get"]) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): +async def test_query_stream_or_get_w_no_explain_options( + query_docs, database, method, verify_pipeline +): from google.cloud.firestore_v1.query_profile import QueryExplainError collection, _, allowed_vals = query_docs @@ -1812,7 +1834,7 @@ async def test_query_with_order_dot_key(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) -async def test_query_unary(client, cleanup, database): +async def test_query_unary(client, cleanup, database, verify_pipeline): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) field_name = "foo" @@ -1869,7 +1891,7 @@ async def test_query_unary(client, cleanup, database): @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_collection_group_queries(client, cleanup, database): +async def test_collection_group_queries(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1946,7 +1968,9 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) -async def test_collection_group_queries_filters(client, cleanup, database): +async def test_collection_group_queries_filters( + client, cleanup, database, verify_pipeline +): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [