From 15cab6f41716c684a7ab0d14db0bf3d857abcfdc Mon Sep 17 00:00:00 2001 From: Dhruv Shetty Date: Wed, 10 Jun 2026 15:46:35 -0400 Subject: [PATCH] feat: add async support (AsyncMockFirestore and async refs) Ported from mdowds/python-mock-firestore PR #62 (author: anna-hope), taking only the additive async layer and adapting it to this fork's sync API instead of PR #62's refactored one: - async_query uses Query._process_snapshots (extracted from Query.stream, behavior unchanged) instead of PR #62's _process_field_filters/_process_pagination - async_collection.get consumes its async stream instead of inheriting the sync list(stream) implementation; where() supports FieldFilter like this fork's sync where() - async_document.set mirrors this fork's set(merge=True) semantics (flatten_for_merge, __name__ stamping) with proper awaits - MockFirestore._ensure_path uses isinstance so async subclasses route document/collection paths correctly - consume_async_iterable helper added to _helpers --- mockfirestore/__init__.py | 6 + mockfirestore/_helpers.py | 6 +- mockfirestore/async_client.py | 45 ++ mockfirestore/async_collection.py | 76 ++++ mockfirestore/async_document.py | 35 ++ mockfirestore/async_query.py | 14 + mockfirestore/async_transaction.py | 44 ++ mockfirestore/client.py | 2 +- mockfirestore/query.py | 4 +- requirements-dev-minimal.txt | 3 +- tests/test_async_collection_reference.py | 524 +++++++++++++++++++++++ tests/test_async_document_reference.py | 337 +++++++++++++++ tests/test_async_mock_client.py | 25 ++ tests/test_async_query.py | 13 + tests/test_async_transaction.py | 71 +++ 15 files changed, 1201 insertions(+), 4 deletions(-) create mode 100644 mockfirestore/async_client.py create mode 100644 mockfirestore/async_collection.py create mode 100644 mockfirestore/async_document.py create mode 100644 mockfirestore/async_query.py create mode 100644 mockfirestore/async_transaction.py create mode 100644 tests/test_async_collection_reference.py create mode 100644 tests/test_async_document_reference.py create mode 100644 tests/test_async_mock_client.py create mode 100644 tests/test_async_query.py create mode 100644 tests/test_async_transaction.py diff --git a/mockfirestore/__init__.py b/mockfirestore/__init__.py index ac36f1f..f7d5e34 100644 --- a/mockfirestore/__init__.py +++ b/mockfirestore/__init__.py @@ -13,3 +13,9 @@ from mockfirestore.query import Query from mockfirestore._helpers import Timestamp from mockfirestore.transaction import Transaction, BatchTransaction + +from mockfirestore.async_client import AsyncMockFirestore +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.async_collection import AsyncCollectionReference +from mockfirestore.async_query import AsyncQuery +from mockfirestore.async_transaction import AsyncTransaction diff --git a/mockfirestore/_helpers.py b/mockfirestore/_helpers.py index 59637e1..15c3c5d 100644 --- a/mockfirestore/_helpers.py +++ b/mockfirestore/_helpers.py @@ -3,7 +3,7 @@ import string from datetime import datetime as dt from functools import reduce -from typing import (Dict, Any, Tuple, TypeVar, Sequence, Iterator) +from typing import (Dict, Any, Tuple, TypeVar, Sequence, Iterator, AsyncIterable, List) T = TypeVar('T') KeyValuePair = Tuple[str, Dict[str, Any]] @@ -97,6 +97,10 @@ def flatten_for_merge(data: Dict[str, Any], prefix: str = '') -> Dict[str, Any]: return result +async def consume_async_iterable(iterable: AsyncIterable[T]) -> List[T]: + return [item async for item in iterable] + + def get_document_iterator(document: Dict[str, Any], prefix: str = '') -> Iterator[Tuple[str, Any]]: """ :returns: (dot-delimited path, value,) diff --git a/mockfirestore/async_client.py b/mockfirestore/async_client.py new file mode 100644 index 0000000..2eb1ecd --- /dev/null +++ b/mockfirestore/async_client.py @@ -0,0 +1,45 @@ +from typing import AsyncIterable, Iterable + +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.async_collection import AsyncCollectionReference +from mockfirestore.async_transaction import AsyncTransaction +from mockfirestore.client import MockFirestore +from mockfirestore.document import DocumentSnapshot + + +class AsyncMockFirestore(MockFirestore): + def document(self, path: str) -> AsyncDocumentReference: + doc = super().document(path) + assert isinstance(doc, AsyncDocumentReference) + return doc + + def collection(self, path: str) -> AsyncCollectionReference: + path = path.split("/") + + if len(path) % 2 != 1: + raise Exception("Cannot create collection at path {}".format(path)) + + name = path[-1] + if len(path) > 1: + current_position = self._ensure_path(path) + return current_position.collection(name) + else: + if name not in self._data: + self._data[name] = {} + return AsyncCollectionReference(self._data, [name]) + + async def collections(self) -> AsyncIterable[AsyncCollectionReference]: + for collection_name in self._data: + yield AsyncCollectionReference(self._data, [collection_name]) + + async def get_all( + self, + references: Iterable[AsyncDocumentReference], + field_paths=None, + transaction=None, + ) -> AsyncIterable[DocumentSnapshot]: + for doc_ref in set(references): + yield await doc_ref.get() + + def transaction(self, **kwargs) -> AsyncTransaction: + return AsyncTransaction(self, **kwargs) diff --git a/mockfirestore/async_collection.py b/mockfirestore/async_collection.py new file mode 100644 index 0000000..2aaff6c --- /dev/null +++ b/mockfirestore/async_collection.py @@ -0,0 +1,76 @@ +from typing import Optional, List, Tuple, Dict, AsyncIterator, Any, Union +from google.cloud.firestore_v1.base_query import FieldFilter +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.async_query import AsyncQuery +from mockfirestore.collection import CollectionReference +from mockfirestore.document import DocumentSnapshot, DocumentReference +from mockfirestore._helpers import Timestamp, get_by_path, consume_async_iterable + + +class AsyncCollectionReference(CollectionReference): + def document(self, document_id: Optional[str] = None) -> AsyncDocumentReference: + doc_ref = super().document(document_id) + return AsyncDocumentReference( + doc_ref._data, doc_ref._path, parent=doc_ref.parent + ) + + async def get(self) -> List[DocumentSnapshot]: + return await consume_async_iterable(self.stream()) + + async def add( + self, document_data: Dict, document_id: str = None + ) -> Tuple[Timestamp, AsyncDocumentReference]: + timestamp, doc_ref = super().add(document_data, document_id=document_id) + async_doc_ref = AsyncDocumentReference( + doc_ref._data, doc_ref._path, parent=doc_ref.parent + ) + return timestamp, async_doc_ref + + async def list_documents( + self, page_size: Optional[int] = None + ) -> AsyncIterator[DocumentReference]: + docs = super().list_documents() + for doc in docs: + yield doc + + async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: + for key in sorted(get_by_path(self._data, self._path)): + doc_snapshot = await self.document(key).get() + yield doc_snapshot + + def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Optional[Any] = None, + filter: Optional[FieldFilter] = None) -> AsyncQuery: + if filter is not None: + field, op, value = filter.field_path, filter.op_string, filter.value + if field is None or op is None or value is None: + raise ValueError('field, op, and value must be provided (or a FieldFilter instance)') + query = AsyncQuery(self, field_filters=[(field, op, value)]) + return query + + def order_by(self, key: str, direction: Optional[str] = None) -> AsyncQuery: + query = AsyncQuery(self, orders=[(key, direction)]) + return query + + def limit(self, limit_amount: int) -> AsyncQuery: + query = AsyncQuery(self, limit=limit_amount) + return query + + def offset(self, offset: int) -> AsyncQuery: + query = AsyncQuery(self, offset=offset) + return query + + def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, start_at=(document_fields_or_snapshot, True)) + return query + + def start_after(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, start_at=(document_fields_or_snapshot, False)) + return query + + def end_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, end_at=(document_fields_or_snapshot, True)) + return query + + def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery: + query = AsyncQuery(self, end_at=(document_fields_or_snapshot, False)) + return query diff --git a/mockfirestore/async_document.py b/mockfirestore/async_document.py new file mode 100644 index 0000000..3d29bc4 --- /dev/null +++ b/mockfirestore/async_document.py @@ -0,0 +1,35 @@ +from copy import deepcopy +from typing import Dict, Any +from mockfirestore import NotFound +from mockfirestore._helpers import flatten_for_merge +from mockfirestore.document import DocumentReference, DocumentSnapshot + + +class AsyncDocumentReference(DocumentReference): + async def get(self) -> DocumentSnapshot: + return super().get() + + async def delete(self): + super().delete() + + async def set(self, data: Dict[str, Any], merge=False): + # Mirrors the sync set's merge branch with awaits: the sync + # implementation calls self.update, which dispatches to the async + # override here and would return an unawaited coroutine. + if merge: + data = deepcopy(data) + data['__name__'] = self.id + try: + await self.update(flatten_for_merge(data)) + except NotFound: + super().set(data, merge=False) + else: + super().set(data, merge=False) + + async def update(self, data: Dict[str, Any]): + super().update(data) + + def collection(self, name) -> 'AsyncCollectionReference': + from mockfirestore.async_collection import AsyncCollectionReference + coll_ref = super().collection(name) + return AsyncCollectionReference(coll_ref._data, coll_ref._path, self) diff --git a/mockfirestore/async_query.py b/mockfirestore/async_query.py new file mode 100644 index 0000000..cbf42d3 --- /dev/null +++ b/mockfirestore/async_query.py @@ -0,0 +1,14 @@ +from typing import List, AsyncIterator +from mockfirestore.document import DocumentSnapshot +from mockfirestore.query import Query +from mockfirestore._helpers import consume_async_iterable + + +class AsyncQuery(Query): + async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]: + doc_snapshots = await consume_async_iterable(self.parent.stream()) + for doc_snapshot in self._process_snapshots(doc_snapshots): + yield doc_snapshot + + async def get(self, transaction=None) -> List[DocumentSnapshot]: + return await consume_async_iterable(self.stream()) diff --git a/mockfirestore/async_transaction.py b/mockfirestore/async_transaction.py new file mode 100644 index 0000000..6a3b6d2 --- /dev/null +++ b/mockfirestore/async_transaction.py @@ -0,0 +1,44 @@ +from typing import AsyncIterable, Iterable + +from mockfirestore.async_document import AsyncDocumentReference +from mockfirestore.document import DocumentSnapshot +from mockfirestore.transaction import Transaction, WriteResult, _CANT_COMMIT + + +class AsyncTransaction(Transaction): + async def _begin(self, retry_id=None): + return super()._begin() + + async def _rollback(self): + super()._rollback() + + async def _commit(self) -> Iterable[WriteResult]: + if not self.in_progress: + raise ValueError(_CANT_COMMIT) + + results = [] + for write_op in self._write_ops: + await write_op() + results.append(WriteResult()) + self.write_results = results + self._clean_up() + return results + + async def get(self, ref_or_query) -> AsyncIterable[DocumentSnapshot]: + doc_snapshots = super().get(ref_or_query) + async for doc_snapshot in doc_snapshots: + yield doc_snapshot + + async def get_all( + self, references: Iterable[AsyncDocumentReference] + ) -> AsyncIterable[DocumentSnapshot]: + doc_snapshots = super().get_all(references) + async for doc_snapshot in doc_snapshots: + yield doc_snapshot + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + await self.commit() diff --git a/mockfirestore/client.py b/mockfirestore/client.py index dc9164c..c444581 100644 --- a/mockfirestore/client.py +++ b/mockfirestore/client.py @@ -14,7 +14,7 @@ def _ensure_path(self, path): current_position = self for el in path[:-1]: - if type(current_position) in (MockFirestore, DocumentReference): + if isinstance(current_position, (MockFirestore, DocumentReference)): current_position = current_position.collection(el) else: current_position = current_position.document(el) diff --git a/mockfirestore/query.py b/mockfirestore/query.py index ddb9894..1d67ecc 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -33,7 +33,9 @@ def __init__(self, parent: 'CollectionReference', projection=None, def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: doc_snapshots = self.parent.stream() + return iter(self._process_snapshots(doc_snapshots)) + def _process_snapshots(self, doc_snapshots) -> List[DocumentSnapshot]: for field, compare, value in self._field_filters: doc_snapshots = [doc_snapshot for doc_snapshot in doc_snapshots if compare(doc_snapshot._get_by_field_path(field), value)] @@ -57,7 +59,7 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: if self._limit: doc_snapshots = islice(doc_snapshots, self._limit) - return iter(list(doc_snapshots)) + return list(doc_snapshots) def get(self) -> Iterator[DocumentSnapshot]: warnings.warn('Query.get is deprecated, please use Query.stream', diff --git a/requirements-dev-minimal.txt b/requirements-dev-minimal.txt index 38604d8..903c408 100644 --- a/requirements-dev-minimal.txt +++ b/requirements-dev-minimal.txt @@ -1 +1,2 @@ -google-cloud-firestore \ No newline at end of file +google-cloud-firestore +aiounittest diff --git a/tests/test_async_collection_reference.py b/tests/test_async_collection_reference.py new file mode 100644 index 0000000..eb1145f --- /dev/null +++ b/tests/test_async_collection_reference.py @@ -0,0 +1,524 @@ +import aiounittest +from mockfirestore import ( + AsyncMockFirestore, + DocumentReference, + DocumentSnapshot, + AlreadyExists, +) +from mockfirestore._helpers import consume_async_iterable + + +class TestAsyncCollectionReference(aiounittest.AsyncTestCase): + async def test_collection_get_returnsDocuments(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + docs = await consume_async_iterable(fs.collection("foo").stream()) + + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual({"id": 2}, docs[1].to_dict()) + + async def test_collection_get_collectionDoesNotExist(self): + fs = AsyncMockFirestore() + docs = await consume_async_iterable(fs.collection("foo").stream()) + self.assertEqual([], docs) + + async def test_collection_get_nestedCollection(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1, "bar": {"first_nested": {"id": 1.1}}}}} + docs = await consume_async_iterable( + fs.collection("foo").document("first").collection("bar").stream() + ) + self.assertEqual({"id": 1.1}, docs[0].to_dict()) + + async def test_collection_get_nestedCollection_by_path(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1, "bar": {"first_nested": {"id": 1.1}}}}} + docs = await consume_async_iterable(fs.collection("foo/first/bar").stream()) + self.assertEqual({"id": 1.1}, docs[0].to_dict()) + + async def test_collection_get_nestedCollection_collectionDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + docs = await consume_async_iterable( + fs.collection("foo").document("first").collection("bar").stream() + ) + self.assertEqual([], docs) + + async def test_collection_get_nestedCollection_by_path_collectionDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + docs = await consume_async_iterable(fs.collection("foo/first/bar").stream()) + self.assertEqual([], docs) + + async def test_collection_get_ordersByAscendingDocumentId_byDefault(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"beta": {"id": 1}, "alpha": {"id": 2}}} + docs = await consume_async_iterable(fs.collection("foo").stream()) + self.assertEqual({"id": 2}, docs[0].to_dict()) + + async def test_collection_whereEquals(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"valid": True}, "second": {"gumby": False}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("valid", "==", True).stream() + ) + self.assertEqual({"valid": True}, docs[0].to_dict()) + + async def test_collection_whereNotEquals(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", "!=", 1).stream() + ) + self.assertEqual({"count": 5}, docs[0].to_dict()) + + async def test_collection_whereLessThan(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", "<", 5).stream() + ) + self.assertEqual({"count": 1}, docs[0].to_dict()) + + async def test_collection_whereLessThanOrEqual(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", "<=", 5).stream() + ) + self.assertEqual({"count": 1}, docs[0].to_dict()) + self.assertEqual({"count": 5}, docs[1].to_dict()) + + async def test_collection_whereGreaterThan(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", ">", 1).stream() + ) + self.assertEqual({"count": 5}, docs[0].to_dict()) + + async def test_collection_whereGreaterThanOrEqual(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("count", ">=", 1).stream() + ) + self.assertEqual({"count": 1}, docs[0].to_dict()) + self.assertEqual({"count": 5}, docs[1].to_dict()) + + async def test_collection_whereMissingField(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}, "second": {"count": 5}}} + + docs = await consume_async_iterable( + fs.collection("foo").where("no_field", "==", 1).stream() + ) + self.assertEqual(len(docs), 0) + + async def test_collection_whereNestedField(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"nested": {"a": 1}}, "second": {"nested": {"a": 2}}} + } + + docs = await consume_async_iterable( + fs.collection("foo").where("nested.a", "==", 1).stream() + ) + self.assertEqual(len(docs), 1) + self.assertEqual({"nested": {"a": 1}}, docs[0].to_dict()) + + async def test_collection_whereIn(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"field": "a1"}, + "second": {"field": "a2"}, + "third": {"field": "a3"}, + "fourth": {"field": "a4"}, + } + } + + docs = await consume_async_iterable( + fs.collection("foo").where("field", "in", ["a1", "a3"]).stream() + ) + self.assertEqual(len(docs), 2) + self.assertEqual({"field": "a1"}, docs[0].to_dict()) + self.assertEqual({"field": "a3"}, docs[1].to_dict()) + + async def test_collection_whereArrayContains(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"field": ["val4"]}, + "second": {"field": ["val3", "val2"]}, + "third": {"field": ["val3", "val2", "val1"]}, + } + } + + docs = await consume_async_iterable( + fs.collection("foo").where("field", "array_contains", "val1").stream() + ) + self.assertEqual(len(docs), 1) + self.assertEqual(docs[0].to_dict(), {"field": ["val3", "val2", "val1"]}) + + async def test_collection_whereArrayContainsAny(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"field": ["val4"]}, + "second": {"field": ["val3", "val2"]}, + "third": {"field": ["val3", "val2", "val1"]}, + } + } + + contains_any_docs = await consume_async_iterable( + fs.collection("foo") + .where("field", "array_contains_any", ["val1", "val4"]) + .stream() + ) + self.assertEqual(len(contains_any_docs), 2) + self.assertEqual({"field": ["val4"]}, contains_any_docs[0].to_dict()) + self.assertEqual( + {"field": ["val3", "val2", "val1"]}, contains_any_docs[1].to_dict() + ) + + async def test_collection_orderBy(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"order": 2}, "second": {"order": 1}}} + + docs = await consume_async_iterable( + fs.collection("foo").order_by("order").stream() + ) + self.assertEqual({"order": 1}, docs[0].to_dict()) + self.assertEqual({"order": 2}, docs[1].to_dict()) + + async def test_collection_orderBy_descending(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 3}, + "third": {"order": 1}, + } + } + + docs = await consume_async_iterable( + fs.collection("foo").order_by("order", direction="DESCENDING").stream() + ) + self.assertEqual({"order": 3}, docs[0].to_dict()) + self.assertEqual({"order": 2}, docs[1].to_dict()) + self.assertEqual({"order": 1}, docs[2].to_dict()) + + async def test_collection_limit(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + docs = await consume_async_iterable(fs.collection("foo").limit(1).stream()) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_offset(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable(fs.collection("foo").offset(1).stream()) + + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual({"id": 3}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_orderby_offset(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").offset(1).stream() + ) + + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual({"id": 3}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_at(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").start_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_at_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_at_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1}, + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("second").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_at(doc).stream() + ) + self.assertEqual(4, len(docs)) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual({"id": 3}, docs[1].to_dict()) + self.assertEqual({"id": 4}, docs[2].to_dict()) + self.assertEqual({"id": 5}, docs[3].to_dict()) + + async def test_collection_start_after(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").start_after({"id": 1}).stream() + ) + self.assertEqual({"id": 2}, docs[0].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_start_after_similar_objects(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1, "value": 1}, + "second": {"id": 2, "value": 2}, + "third": {"id": 3, "value": 2}, + "fourth": {"id": 4, "value": 3}, + } + } + docs = await consume_async_iterable( + fs.collection("foo") + .order_by("id") + .start_after({"id": 3, "value": 2}) + .stream() + ) + self.assertEqual({"id": 4, "value": 3}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_start_after_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_after({"id": 2}).stream() + ) + self.assertEqual({"id": 3}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_start_after_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("second").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").start_after(doc).stream() + ) + self.assertEqual(3, len(docs)) + self.assertEqual({"id": 3}, docs[0].to_dict()) + self.assertEqual({"id": 4}, docs[1].to_dict()) + self.assertEqual({"id": 5}, docs[2].to_dict()) + + async def test_collection_end_before(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").end_before({"id": 2}).stream() + ) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_end_before_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_before({"id": 2}).stream() + ) + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual(1, len(docs)) + + async def test_collection_end_before_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1}, + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("fourth").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_before(doc).stream() + ) + self.assertEqual(3, len(docs)) + + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual({"id": 3}, docs[2].to_dict()) + + async def test_collection_end_at(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").end_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_end_at_order_by(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"id": 1}, "second": {"id": 2}, "third": {"id": 3}} + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_at({"id": 2}).stream() + ) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual(2, len(docs)) + + async def test_collection_end_at_doc_snapshot(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"id": 1}, + "second": {"id": 2}, + "third": {"id": 3}, + "fourth": {"id": 4}, + "fifth": {"id": 5}, + } + } + + doc = await fs.collection("foo").document("fourth").get() + + docs = await consume_async_iterable( + fs.collection("foo").order_by("id").end_at(doc).stream() + ) + self.assertEqual(4, len(docs)) + + self.assertEqual({"id": 1}, docs[0].to_dict()) + self.assertEqual({"id": 2}, docs[1].to_dict()) + self.assertEqual({"id": 3}, docs[2].to_dict()) + self.assertEqual({"id": 4}, docs[3].to_dict()) + + async def test_collection_limitAndOrderBy(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + docs = await consume_async_iterable( + fs.collection("foo").order_by("order").limit(2).stream() + ) + self.assertEqual({"order": 1}, docs[0].to_dict()) + self.assertEqual({"order": 2}, docs[1].to_dict()) + + async def test_collection_listDocuments(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + doc_refs = await consume_async_iterable(fs.collection("foo").list_documents()) + self.assertEqual(3, len(doc_refs)) + for doc_ref in doc_refs: + self.assertIsInstance(doc_ref, DocumentReference) + + async def test_collection_stream(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + doc_snapshots = await consume_async_iterable(fs.collection("foo").stream()) + self.assertEqual(3, len(doc_snapshots)) + for doc_snapshot in doc_snapshots: + self.assertIsInstance(doc_snapshot, DocumentSnapshot) + + async def test_collection_parent(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": {"order": 2}, + "second": {"order": 1}, + "third": {"order": 3}, + } + } + doc_snapshots = await consume_async_iterable(fs.collection("foo").stream()) + for doc_snapshot in doc_snapshots: + doc_reference = doc_snapshot.reference + subcollection = doc_reference.collection("order") + self.assertIs(subcollection.parent, doc_reference) + + async def test_collection_addDocument(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc_id = "bar" + doc_content = {"id": doc_id, "xy": "z"} + timestamp, doc_ref = await fs.collection("foo").add(doc_content) + doc_snapshot = await doc_ref.get() + self.assertEqual(doc_content, doc_snapshot.to_dict()) + + doc = await fs.collection("foo").document(doc_id).get() + self.assertEqual(doc_content, doc.to_dict()) + + with self.assertRaises(AlreadyExists): + await fs.collection("foo").add(doc_content) + + async def test_collection_useDocumentIdKwarg(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc = await fs.collection("foo").document(document_id="first").get() + self.assertEqual({"id": 1}, doc.to_dict()) diff --git a/tests/test_async_document_reference.py b/tests/test_async_document_reference.py new file mode 100644 index 0000000..aec8ef4 --- /dev/null +++ b/tests/test_async_document_reference.py @@ -0,0 +1,337 @@ +import aiounittest + +from google.cloud import firestore +from mockfirestore import AsyncMockFirestore, NotFound + + +class TestAsyncDocumentReference(aiounittest.AsyncTestCase): + async def test_get_document_by_path(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc = await fs.document("foo/first").get() + self.assertEqual({"id": 1}, doc.to_dict()) + self.assertEqual("first", doc.id) + + async def test_set_document_by_path(self): + fs = AsyncMockFirestore() + fs._data = {} + doc_content = {"id": "bar"} + await fs.document("foo/doc1/bar/doc2").set(doc_content) + doc = await fs.document("foo/doc1/bar/doc2").get() + doc = doc.to_dict() + self.assertEqual(doc_content, doc) + + async def test_document_get_returnsDocument(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 1}, doc.to_dict()) + self.assertEqual("first", doc.id) + + async def test_document_get_documentIdEqualsKey(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + doc_ref = fs.collection("foo").document("first") + self.assertEqual("first", doc_ref.id) + + async def test_document_get_newDocumentReturnsDefaultId(self): + fs = AsyncMockFirestore() + doc_ref = fs.collection("foo").document() + doc = await doc_ref.get() + self.assertNotEqual(None, doc_ref.id) + self.assertFalse(doc.exists) + + async def test_document_get_documentDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc = await fs.collection("foo").document("bar").get() + self.assertEqual({}, doc.to_dict()) + + async def test_get_nestedDocument(self): + fs = AsyncMockFirestore() + fs._data = { + "top_collection": { + "top_document": { + "id": 1, + "nested_collection": {"nested_document": {"id": 1.1}}, + } + } + } + doc = ( + await fs.collection("top_collection") + .document("top_document") + .collection("nested_collection") + .document("nested_document") + .get() + ) + + self.assertEqual({"id": 1.1}, doc.to_dict()) + + async def test_get_nestedDocument_documentDoesNotExist(self): + fs = AsyncMockFirestore() + fs._data = { + "top_collection": {"top_document": {"id": 1, "nested_collection": {}}} + } + doc = ( + await fs.collection("top_collection") + .document("top_document") + .collection("nested_collection") + .document("nested_document") + .get() + ) + + self.assertEqual({}, doc.to_dict()) + + async def test_document_set_setsContentOfDocument(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc_content = {"id": "bar"} + await fs.collection("foo").document("bar").set(doc_content) + doc = await fs.collection("foo").document("bar").get() + self.assertEqual(doc_content, doc.to_dict()) + + async def test_document_set_mergeNewValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").set({"updated": True}, merge=True) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 1, "updated": True}, doc.to_dict()) + + async def test_document_set_mergeNewValueForNonExistentDoc(self): + fs = AsyncMockFirestore() + await fs.collection("foo").document("first").set({"updated": True}, merge=True) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"updated": True}, doc.to_dict()) + + async def test_document_set_overwriteValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").set({"new_id": 1}, merge=False) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"new_id": 1}, doc.to_dict()) + + async def test_document_set_isolation(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {}} + doc_content = {"id": "bar"} + await fs.collection("foo").document("bar").set(doc_content) + doc_content["id"] = "new value" + doc = await fs.collection("foo").document("bar").get() + self.assertEqual({"id": "bar"}, doc.to_dict()) + + async def test_document_update_addNewValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").update({"updated": True}) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 1, "updated": True}, doc.to_dict()) + + async def test_document_update_changeExistingValue(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").update({"id": 2}) + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"id": 2}, doc.to_dict()) + + async def test_document_update_documentDoesNotExist(self): + fs = AsyncMockFirestore() + with self.assertRaises(NotFound): + await fs.collection("foo").document("nonexistent").update({"id": 2}) + docsnap = await fs.collection("foo").document("nonexistent").get() + self.assertFalse(docsnap.exists) + + async def test_document_update_isolation(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"nested": {"id": 1}}}} + update_doc = {"nested": {"id": 2}} + await fs.collection("foo").document("first").update(update_doc) + update_doc["nested"]["id"] = 3 + doc = await fs.collection("foo").document("first").get() + self.assertEqual({"nested": {"id": 2}}, doc.to_dict()) + + async def test_document_update_transformerIncrementBasic(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"count": 1}}} + await fs.collection("foo").document("first").update( + {"count": firestore.Increment(2)} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"count": 3}) + + async def test_document_update_transformerIncrementNested(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": { + "nested": {"count": 1}, + "other": {"likes": 0}, + } + } + } + await fs.collection("foo").document("first").update( + { + "nested": {"count": firestore.Increment(-1)}, + "other": {"likes": firestore.Increment(1), "smoked": "salmon"}, + } + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), + {"nested": {"count": 0}, "other": {"likes": 1, "smoked": "salmon"}}, + ) + + async def test_document_update_transformerIncrementNonExistent(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"spicy": "tuna"}}} + await fs.collection("foo").document("first").update( + {"count": firestore.Increment(1)} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"count": 1, "spicy": "tuna"}) + + async def test_document_delete_documentDoesNotExistAfterDelete(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + await fs.collection("foo").document("first").delete() + doc = await fs.collection("foo").document("first").get() + self.assertEqual(False, doc.exists) + + async def test_document_parent(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}}} + coll = fs.collection("foo") + document = coll.document("first") + self.assertIs(document.parent, coll) + + async def test_document_update_transformerArrayUnionBasic(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2]}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayUnion([3, 4])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2, 3, 4]) + + async def test_document_update_transformerArrayUnionNested(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": { + "first": { + "nested": {"arr": [1]}, + "other": {"labels": ["a"]}, + } + } + } + await fs.collection("foo").document("first").update( + { + "nested": {"arr": firestore.ArrayUnion([2])}, + "other": {"labels": firestore.ArrayUnion(["b"]), "smoked": "salmon"}, + } + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), + { + "nested": {"arr": [1, 2]}, + "other": {"labels": ["a", "b"], "smoked": "salmon"}, + }, + ) + + async def test_document_update_transformerArrayUnionNonExistent(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"spicy": "tuna"}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayUnion([1])} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"arr": [1], "spicy": "tuna"}) + + async def test_document_update_nestedFieldDotNotation(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"nested": {"value": 1, "unchanged": "foo"}}}} + + await fs.collection("foo").document("first").update({"nested.value": 2}) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"nested": {"value": 2, "unchanged": "foo"}}) + + async def test_document_update_nestedFieldDotNotationNestedFieldCreation(self): + fs = AsyncMockFirestore() + fs._data = { + "foo": {"first": {"other": None}} + } # non-existent nested field is created + + await fs.collection("foo").document("first").update({"nested.value": 2}) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {"nested": {"value": 2}, "other": None}) + + async def test_document_update_nestedFieldDotNotationMultipleNested(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"other": None}}} + + await fs.collection("foo").document("first").update( + {"nested.subnested.value": 42} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), {"nested": {"subnested": {"value": 42}}, "other": None} + ) + + async def test_document_update_nestedFieldDotNotationMultipleNestedWithTransformer( + self, + ): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"other": None}}} + + await fs.collection("foo").document("first").update( + {"nested.subnested.value": firestore.ArrayUnion([1, 3])} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual( + doc.to_dict(), {"nested": {"subnested": {"value": [1, 3]}}, "other": None} + ) + + async def test_document_update_transformerSentinel(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"spicy": "tuna"}}} + await fs.collection("foo").document("first").update( + {"spicy": firestore.DELETE_FIELD} + ) + + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict(), {}) + + async def test_document_update_transformerArrayRemoveBasic(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2, 3, 4]}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayRemove([3, 4])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2]) + + async def test_document_update_transformerArrayRemoveNonExistentField(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2, 3, 4]}}} + await fs.collection("foo").document("first").update( + {"arr": firestore.ArrayRemove([5])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2, 3, 4]) + + async def test_document_update_transformerArrayRemoveNonExistentArray(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"arr": [1, 2, 3, 4]}}} + await fs.collection("foo").document("first").update( + {"non_existent_array": firestore.ArrayRemove([1, 2])} + ) + doc = await fs.collection("foo").document("first").get() + self.assertEqual(doc.to_dict()["arr"], [1, 2, 3, 4]) diff --git a/tests/test_async_mock_client.py b/tests/test_async_mock_client.py new file mode 100644 index 0000000..6131e31 --- /dev/null +++ b/tests/test_async_mock_client.py @@ -0,0 +1,25 @@ +import aiounittest + +from mockfirestore import AsyncMockFirestore +from mockfirestore._helpers import consume_async_iterable + + +class TestAsyncMockFirestore(aiounittest.AsyncTestCase): + async def test_client_get_all(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + doc = fs.collection("foo").document("first") + results = await consume_async_iterable(fs.get_all([doc])) + returned_doc_snapshot = results[0].to_dict() + expected_doc_snapshot = (await doc.get()).to_dict() + self.assertEqual(returned_doc_snapshot, expected_doc_snapshot) + + async def test_client_collections(self): + fs = AsyncMockFirestore() + fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}, "bar": {}} + collections = await consume_async_iterable(fs.collections()) + expected_collections = fs._data + + self.assertEqual(len(collections), len(expected_collections)) + for collection in collections: + self.assertTrue(collection._path[0] in expected_collections) diff --git a/tests/test_async_query.py b/tests/test_async_query.py new file mode 100644 index 0000000..4931ba2 --- /dev/null +++ b/tests/test_async_query.py @@ -0,0 +1,13 @@ +import aiounittest + +from mockfirestore import AsyncMockFirestore + + +class TestAsyncMockFirestore(aiounittest.AsyncTestCase): + async def test_query_get(self): + fs = AsyncMockFirestore() + doc_in_fs = {"id": 1} + fs._data = {"foo": {"first": doc_in_fs}} + docs = await fs.collection("foo").where("id", "==", 1).get() + self.assertEqual(len(docs), 1) + self.assertEqual(docs[0].to_dict()["id"], 1) diff --git a/tests/test_async_transaction.py b/tests/test_async_transaction.py new file mode 100644 index 0000000..d365a06 --- /dev/null +++ b/tests/test_async_transaction.py @@ -0,0 +1,71 @@ +import aiounittest + +from mockfirestore import AsyncMockFirestore, AsyncTransaction +from mockfirestore._helpers import consume_async_iterable + + +class TestAsyncTransaction(aiounittest.AsyncTestCase): + def setUp(self) -> None: + self.fs = AsyncMockFirestore() + self.fs._data = {"foo": {"first": {"id": 1}, "second": {"id": 2}}} + + async def test_transaction_getAll(self): + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + docs = [ + self.fs.collection("foo").document(doc_name) + for doc_name in self.fs._data["foo"] + ] + results = await consume_async_iterable(transaction.get_all(docs)) + returned_docs_snapshots = [result.to_dict() for result in results] + expected_doc_snapshots = [(await doc.get()).to_dict() for doc in docs] + for expected_snapshot in expected_doc_snapshots: + self.assertIn(expected_snapshot, returned_docs_snapshots) + + async def test_transaction_getDocument(self): + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + doc = self.fs.collection("foo").document("first") + returned_doc = [doc async for doc in transaction.get(doc)][0] + self.assertEqual((await doc.get()).to_dict(), returned_doc.to_dict()) + + async def test_transaction_getQuery(self): + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + query = self.fs.collection("foo").order_by("id") + returned_docs = [doc.to_dict() async for doc in transaction.get(query)] + query = self.fs.collection("foo").order_by("id") + expected_docs = [doc.to_dict() async for doc in query.stream()] + self.assertEqual(returned_docs, expected_docs) + + async def test_transaction_set_setsContentOfDocument(self): + doc_content = {"id": "3"} + doc_ref = self.fs.collection("foo").document("third") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.set(doc_ref, doc_content) + self.assertEqual((await doc_ref.get()).to_dict(), doc_content) + + async def test_transaction_set_mergeNewValue(self): + doc = self.fs.collection("foo").document("first") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.set(doc, {"updated": True}, merge=True) + updated_doc = {"id": 1, "updated": True} + self.assertEqual((await doc.get()).to_dict(), updated_doc) + + async def test_transaction_update_changeExistingValue(self): + doc = self.fs.collection("foo").document("first") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.update(doc, {"updated": False}) + updated_doc = {"id": 1, "updated": False} + self.assertEqual((await doc.get()).to_dict(), updated_doc) + + async def test_transaction_delete_documentDoesNotExistAfterDelete(self): + doc = self.fs.collection("foo").document("first") + async with AsyncTransaction(self.fs) as transaction: + await transaction._begin() + transaction.delete(doc) + doc = await self.fs.collection("foo").document("first").get() + self.assertEqual(False, doc.exists)