diff --git a/mockfirestore/_helpers.py b/mockfirestore/_helpers.py index 74222db..59637e1 100644 --- a/mockfirestore/_helpers.py +++ b/mockfirestore/_helpers.py @@ -77,6 +77,26 @@ def nanos(self): return str(self._timestamp).split('.')[1] +def flatten_for_merge(data: Dict[str, Any], prefix: str = '') -> Dict[str, Any]: + """Flatten nested dicts into dot-notation keys for set(merge=True). + + Firestore's set(merge=True) deep-merges at every nesting level. + The mock's update() already handles dot-notation keys correctly, + so flattening first lets the existing logic work for nested dicts. + + Only plain dicts are recursed into; Firestore transform objects + (Sentinel, ArrayUnion, etc.) are treated as leaf values. + """ + result: Dict[str, Any] = {} + for key, value in data.items(): + full_key = '{}.{}'.format(prefix, key) if prefix else key + if isinstance(value, dict) and value: + result.update(flatten_for_merge(value, prefix=full_key)) + else: + result[full_key] = value + return result + + def get_document_iterator(document: Dict[str, Any], prefix: str = '') -> Iterator[Tuple[str, Any]]: """ :returns: (dot-delimited path, value,) diff --git a/mockfirestore/_transformations.py b/mockfirestore/_transformations.py index 3ea6193..3697419 100644 --- a/mockfirestore/_transformations.py +++ b/mockfirestore/_transformations.py @@ -68,7 +68,10 @@ def _apply_updates(document: Dict[str, Any], data: Dict[str, Any]): def _apply_deletes(document: Dict[str, Any], data: List[str]): for key in data: path = key.split(".") - delete_by_path(document, path) + try: + delete_by_path(document, path) + except KeyError: + continue def _apply_arr_deletes(document: Dict[str, Any], data: Dict[str, Any]): diff --git a/mockfirestore/collection.py b/mockfirestore/collection.py index d2af224..4e6ae90 100644 --- a/mockfirestore/collection.py +++ b/mockfirestore/collection.py @@ -26,9 +26,8 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: return DocumentReference(self._data, new_path, parent=self) def get(self) -> Iterable[DocumentSnapshot]: - warnings.warn('Collection.get is deprecated, please use Collection.stream', - category=DeprecationWarning) - return self.stream() + # Stream uses a generator, so we need to convert it to a list for compatibility for .get() method with firestore library + return list(self.stream()) @property def path(self): @@ -127,9 +126,8 @@ def document(self, document_id: Optional[str] = None, path: List[str] = None) -> return ret def get(self) -> Iterable[DocumentSnapshot]: - warnings.warn('Collection.get is deprecated, please use Collection.stream', - category=DeprecationWarning) - return self.stream() + # Stream uses a generator, so we need to convert it to a list for compatibility for .get() method with firestore library + return list(self.stream()) def stream(self, transaction=None) -> Iterable[DocumentSnapshot]: for path in self._path: diff --git a/mockfirestore/document.py b/mockfirestore/document.py index 4cc10b1..96ebc4a 100644 --- a/mockfirestore/document.py +++ b/mockfirestore/document.py @@ -4,7 +4,8 @@ from typing import List, Dict, Any from mockfirestore import NotFound from mockfirestore._helpers import ( - Timestamp, Document, Store, get_by_path, set_by_path, delete_by_path + Timestamp, Document, Store, get_by_path, set_by_path, delete_by_path, + flatten_for_merge ) from mockfirestore._transformations import apply_transformations @@ -85,7 +86,7 @@ def set(self, data: Dict, merge=False): data['__name__'] = self.id if merge: try: - self.update(data) + self.update(flatten_for_merge(data)) except NotFound: self.set(data) else: diff --git a/mockfirestore/query.py b/mockfirestore/query.py index d14d1ef..e5a8681 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -57,12 +57,11 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: if self._limit: doc_snapshots = islice(doc_snapshots, self._limit) - return iter(doc_snapshots) + return iter(list(doc_snapshots)) def get(self) -> Iterator[DocumentSnapshot]: - warnings.warn('Query.get is deprecated, please use Query.stream', - category=DeprecationWarning) - return self.stream() + # Stream uses a generator, so we need to convert it to a list for compatibility for .get() method with firestore library + return list(self.stream()) def _add_field_filter(self, field: str, op: str, value: Any): compare = self._compare_func(op) diff --git a/tests/test_document_reference.py b/tests/test_document_reference.py index ae25341..e6cf58f 100644 --- a/tests/test_document_reference.py +++ b/tests/test_document_reference.py @@ -317,6 +317,39 @@ def test_document_update_transformerSentinel(self): doc = fs.collection("foo").document("first").get().to_dict() self.assertEqual(doc, {}) + def test_document_update_transformerSentinelNonExistentField(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'spicy': 'tuna'} + }} + fs.collection('foo').document('first').update({"nonexistent": firestore.DELETE_FIELD}) + + doc = fs.collection("foo").document("first").get().to_dict() + self.assertEqual(doc, {'spicy': 'tuna'}) + + def test_document_update_transformerSentinelNonExistentNestedField(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'spicy': 'tuna'} + }} + fs.collection('foo').document('first').update({"stats.student123.field": firestore.DELETE_FIELD}) + + doc = fs.collection("foo").document("first").get().to_dict() + self.assertEqual(doc, {'spicy': 'tuna'}) + + def test_document_update_transformerSentinelMixedExistingAndNonExistent(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'spicy': 'tuna', 'remove_me': 'gone'} + }} + fs.collection('foo').document('first').update({ + "remove_me": firestore.DELETE_FIELD, + "nonexistent": firestore.DELETE_FIELD, + }) + + doc = fs.collection("foo").document("first").get().to_dict() + self.assertEqual(doc, {'spicy': 'tuna'}) + def test_document_update_transformerArrayRemoveBasic(self): fs = MockFirestore() fs._data = {"foo": {"first": {"arr": [1, 2, 3, 4]}}}