Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mockfirestore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@
from mockfirestore.query import Query
from mockfirestore._helpers import Timestamp
from mockfirestore.transaction import Transaction, BatchTransaction

from mockfirestore.async_client import AsyncMockFirestore
from mockfirestore.async_document import AsyncDocumentReference
from mockfirestore.async_collection import AsyncCollectionReference
from mockfirestore.async_query import AsyncQuery
from mockfirestore.async_transaction import AsyncTransaction
6 changes: 5 additions & 1 deletion mockfirestore/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import string
from datetime import datetime as dt
from functools import reduce
from typing import (Dict, Any, Tuple, TypeVar, Sequence, Iterator)
from typing import (Dict, Any, Tuple, TypeVar, Sequence, Iterator, AsyncIterable, List)

T = TypeVar('T')
KeyValuePair = Tuple[str, Dict[str, Any]]
Expand Down Expand Up @@ -97,6 +97,10 @@ def flatten_for_merge(data: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
return result


async def consume_async_iterable(iterable: AsyncIterable[T]) -> List[T]:
return [item async for item in iterable]


def get_document_iterator(document: Dict[str, Any], prefix: str = '') -> Iterator[Tuple[str, Any]]:
"""
:returns: (dot-delimited path, value,)
Expand Down
45 changes: 45 additions & 0 deletions mockfirestore/async_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import AsyncIterable, Iterable

from mockfirestore.async_document import AsyncDocumentReference
from mockfirestore.async_collection import AsyncCollectionReference
from mockfirestore.async_transaction import AsyncTransaction
from mockfirestore.client import MockFirestore
from mockfirestore.document import DocumentSnapshot


class AsyncMockFirestore(MockFirestore):
def document(self, path: str) -> AsyncDocumentReference:
doc = super().document(path)
assert isinstance(doc, AsyncDocumentReference)
return doc
Comment on lines +11 to +14

def collection(self, path: str) -> AsyncCollectionReference:
path = path.split("/")

if len(path) % 2 != 1:
raise Exception("Cannot create collection at path {}".format(path))

name = path[-1]
if len(path) > 1:
current_position = self._ensure_path(path)
return current_position.collection(name)
else:
if name not in self._data:
self._data[name] = {}
return AsyncCollectionReference(self._data, [name])

async def collections(self) -> AsyncIterable[AsyncCollectionReference]:
for collection_name in self._data:
yield AsyncCollectionReference(self._data, [collection_name])

async def get_all(
self,
references: Iterable[AsyncDocumentReference],
field_paths=None,
transaction=None,
) -> AsyncIterable[DocumentSnapshot]:
for doc_ref in set(references):
yield await doc_ref.get()

def transaction(self, **kwargs) -> AsyncTransaction:
return AsyncTransaction(self, **kwargs)
76 changes: 76 additions & 0 deletions mockfirestore/async_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Optional, List, Tuple, Dict, AsyncIterator, Any, Union
from google.cloud.firestore_v1.base_query import FieldFilter
from mockfirestore.async_document import AsyncDocumentReference
from mockfirestore.async_query import AsyncQuery
from mockfirestore.collection import CollectionReference
from mockfirestore.document import DocumentSnapshot, DocumentReference
from mockfirestore._helpers import Timestamp, get_by_path, consume_async_iterable


class AsyncCollectionReference(CollectionReference):
def document(self, document_id: Optional[str] = None) -> AsyncDocumentReference:
doc_ref = super().document(document_id)
return AsyncDocumentReference(
doc_ref._data, doc_ref._path, parent=doc_ref.parent
)

async def get(self) -> List[DocumentSnapshot]:
return await consume_async_iterable(self.stream())

async def add(
self, document_data: Dict, document_id: str = None
) -> Tuple[Timestamp, AsyncDocumentReference]:
timestamp, doc_ref = super().add(document_data, document_id=document_id)
async_doc_ref = AsyncDocumentReference(
doc_ref._data, doc_ref._path, parent=doc_ref.parent
)
return timestamp, async_doc_ref

async def list_documents(
self, page_size: Optional[int] = None
) -> AsyncIterator[DocumentReference]:
docs = super().list_documents()
for doc in docs:
yield doc

async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]:
for key in sorted(get_by_path(self._data, self._path)):
doc_snapshot = await self.document(key).get()
yield doc_snapshot

def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Optional[Any] = None,
filter: Optional[FieldFilter] = None) -> AsyncQuery:
if filter is not None:
field, op, value = filter.field_path, filter.op_string, filter.value
if field is None or op is None or value is None:
raise ValueError('field, op, and value must be provided (or a FieldFilter instance)')
query = AsyncQuery(self, field_filters=[(field, op, value)])
return query

def order_by(self, key: str, direction: Optional[str] = None) -> AsyncQuery:
query = AsyncQuery(self, orders=[(key, direction)])
return query

def limit(self, limit_amount: int) -> AsyncQuery:
query = AsyncQuery(self, limit=limit_amount)
return query
Comment on lines +54 to +56

def offset(self, offset: int) -> AsyncQuery:
query = AsyncQuery(self, offset=offset)
return query

def start_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery:
query = AsyncQuery(self, start_at=(document_fields_or_snapshot, True))
return query

def start_after(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery:
query = AsyncQuery(self, start_at=(document_fields_or_snapshot, False))
return query

def end_at(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery:
query = AsyncQuery(self, end_at=(document_fields_or_snapshot, True))
return query

def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) -> AsyncQuery:
query = AsyncQuery(self, end_at=(document_fields_or_snapshot, False))
return query
35 changes: 35 additions & 0 deletions mockfirestore/async_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from copy import deepcopy
from typing import Dict, Any
from mockfirestore import NotFound
from mockfirestore._helpers import flatten_for_merge
from mockfirestore.document import DocumentReference, DocumentSnapshot


class AsyncDocumentReference(DocumentReference):
async def get(self) -> DocumentSnapshot:
return super().get()

async def delete(self):
super().delete()

async def set(self, data: Dict[str, Any], merge=False):
# Mirrors the sync set's merge branch with awaits: the sync
# implementation calls self.update, which dispatches to the async
# override here and would return an unawaited coroutine.
if merge:
data = deepcopy(data)
data['__name__'] = self.id
try:
await self.update(flatten_for_merge(data))
except NotFound:
super().set(data, merge=False)
else:
super().set(data, merge=False)

async def update(self, data: Dict[str, Any]):
super().update(data)

def collection(self, name) -> 'AsyncCollectionReference':
from mockfirestore.async_collection import AsyncCollectionReference
coll_ref = super().collection(name)
return AsyncCollectionReference(coll_ref._data, coll_ref._path, self)
14 changes: 14 additions & 0 deletions mockfirestore/async_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import List, AsyncIterator
from mockfirestore.document import DocumentSnapshot
from mockfirestore.query import Query
from mockfirestore._helpers import consume_async_iterable


class AsyncQuery(Query):
async def stream(self, transaction=None) -> AsyncIterator[DocumentSnapshot]:
doc_snapshots = await consume_async_iterable(self.parent.stream())
for doc_snapshot in self._process_snapshots(doc_snapshots):
yield doc_snapshot

async def get(self, transaction=None) -> List[DocumentSnapshot]:
return await consume_async_iterable(self.stream())
44 changes: 44 additions & 0 deletions mockfirestore/async_transaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import AsyncIterable, Iterable

from mockfirestore.async_document import AsyncDocumentReference
from mockfirestore.document import DocumentSnapshot
from mockfirestore.transaction import Transaction, WriteResult, _CANT_COMMIT


class AsyncTransaction(Transaction):
async def _begin(self, retry_id=None):
return super()._begin()

async def _rollback(self):
super()._rollback()

async def _commit(self) -> Iterable[WriteResult]:
if not self.in_progress:
raise ValueError(_CANT_COMMIT)

results = []
for write_op in self._write_ops:
await write_op()
results.append(WriteResult())
self.write_results = results
self._clean_up()
return results

async def get(self, ref_or_query) -> AsyncIterable[DocumentSnapshot]:
doc_snapshots = super().get(ref_or_query)
async for doc_snapshot in doc_snapshots:
yield doc_snapshot

async def get_all(
self, references: Iterable[AsyncDocumentReference]
) -> AsyncIterable[DocumentSnapshot]:
doc_snapshots = super().get_all(references)
async for doc_snapshot in doc_snapshots:
yield doc_snapshot

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
await self.commit()
2 changes: 1 addition & 1 deletion mockfirestore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _ensure_path(self, path):
current_position = self

for el in path[:-1]:
if type(current_position) in (MockFirestore, DocumentReference):
if isinstance(current_position, (MockFirestore, DocumentReference)):
current_position = current_position.collection(el)
else:
current_position = current_position.document(el)
Expand Down
4 changes: 3 additions & 1 deletion mockfirestore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def __init__(self, parent: 'CollectionReference', projection=None,

def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:
doc_snapshots = self.parent.stream()
return iter(self._process_snapshots(doc_snapshots))

def _process_snapshots(self, doc_snapshots) -> List[DocumentSnapshot]:
for field, compare, value in self._field_filters:
doc_snapshots = [doc_snapshot for doc_snapshot in doc_snapshots
if compare(doc_snapshot._get_by_field_path(field), value)]
Expand All @@ -57,7 +59,7 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:
if self._limit:
doc_snapshots = islice(doc_snapshots, self._limit)

return iter(list(doc_snapshots))
return list(doc_snapshots)

def get(self) -> Iterator[DocumentSnapshot]:
warnings.warn('Query.get is deprecated, please use Query.stream',
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev-minimal.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
google-cloud-firestore
google-cloud-firestore
aiounittest
Loading