From 2742c83e358c01f80282ffe853c0b20d487ed732 Mon Sep 17 00:00:00 2001
From: zilto
Date: Thu, 16 Jan 2025 17:23:22 -0500
Subject: [PATCH] added pymongoarrow_schema; linting

---
 sources/mongodb/__init__.py          |  31 ++++--
 sources/mongodb/helpers.py           | 136 +++++++++++++++++++++++----
 tests/mongodb/test_mongodb_source.py |  31 +++---
 3 files changed, 160 insertions(+), 38 deletions(-)

diff --git a/sources/mongodb/__init__.py b/sources/mongodb/__init__.py
index 351b850e7..cea713a3e 100644
--- a/sources/mongodb/__init__.py
+++ b/sources/mongodb/__init__.py
@@ -24,6 +24,8 @@ def mongodb(
     parallel: Optional[bool] = dlt.config.value,
     limit: Optional[int] = None,
     filter_: Optional[Dict[str, Any]] = None,
+    projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+    pymongoarrow_schema: Optional[Any] = None,
 ) -> Iterable[DltResource]:
     """
     A DLT source which loads data from a mongo database using PyMongo.
@@ -41,6 +43,13 @@ def mongodb(
             The maximum number of documents to load. The limit is
             applied to each requested collection separately.
         filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
+        projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields
+            when loading the collection. Supported inputs:
+                include (list) - ["year", "title"]
+                include (dict) - {"year": True, "title": True}
+                exclude (dict) - {"released": False, "runtime": False}
+            Note: Can't mix include and exclude statements: `{"title": True, "released": False}`
+        pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types of a collection to convert BSON to Arrow.
 
     Returns:
         Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -73,12 +82,15 @@ def mongodb(
             parallel=parallel,
             limit=limit,
             filter_=filter_ or {},
-            projection=None,
+            projection=projection,
+            pymongoarrow_schema=pymongoarrow_schema,
         )
 
 
-@dlt.common.configuration.with_config(
-    sections=("sources", "mongodb"), spec=MongoDbCollectionResourceConfiguration
+@dlt.resource(
+    name=lambda args: args["collection"],
+    standalone=True,
+    spec=MongoDbCollectionResourceConfiguration,
 )
 def mongodb_collection(
     connection_url: str = dlt.secrets.value,
@@ -91,7 +103,8 @@ def mongodb_collection(
     chunk_size: Optional[int] = 10000,
     data_item_format: Optional[TDataItemFormat] = "object",
     filter_: Optional[Dict[str, Any]] = None,
-    projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+    projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = dlt.config.value,
+    pymongoarrow_schema: Optional[Any] = None,
 ) -> Any:
     """
     A DLT source which loads a collection from a mongo database using PyMongo.
@@ -111,12 +124,13 @@ def mongodb_collection(
             object - Python objects (dicts, lists).
             arrow - Apache Arrow tables.
         filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
-        projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
+        projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields
             when loading the collection. Supported inputs:
                 include (list) - ["year", "title"]
-                include (dict) - {"year": 1, "title": 1}
-                exclude (dict) - {"released": 0, "runtime": 0}
-            Note: Can't mix include and exclude statements '{"title": 1, "released": 0}`
+                include (dict) - {"year": True, "title": True}
+                exclude (dict) - {"released": False, "runtime": False}
+            Note: Can't mix include and exclude statements: `{"title": True, "released": False}`
+        pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types to convert BSON to Arrow.
 
     Returns:
         Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -145,4 +159,5 @@ def mongodb_collection(
         data_item_format=data_item_format,
         filter_=filter_ or {},
         projection=projection,
+        pymongoarrow_schema=pymongoarrow_schema,
     )
diff --git a/sources/mongodb/helpers.py b/sources/mongodb/helpers.py
index a9f2b5eba..817a58fdd 100644
--- a/sources/mongodb/helpers.py
+++ b/sources/mongodb/helpers.py
@@ -107,10 +107,10 @@ def _filter_op(self) -> Dict[str, Any]:
                 filt[self.cursor_field]["$gt"] = self.incremental.end_value
 
         return filt
-
-    def _projection_op(self, projection) -> Optional[Dict[str, Any]]:
+
+    def _projection_op(self, projection: Optional[Union[Mapping[str, Any], Iterable[str]]]) -> Optional[Dict[str, Any]]:
         """Build a projection operator.
-
+
         A tuple of fields to include or a dict specifying fields to include or exclude.
         The incremental `primary_key` needs to be handle differently for inclusion
         and exclusion projections.
@@ -123,17 +123,16 @@ def _projection_op(self, projection) -> Optional[Dict[str, Any]]:
 
         projection_dict = dict(_fields_list_to_dict(projection, "projection"))
 
-        # NOTE we can still filter on primary_key if it's excluded from projection
         if self.incremental:
             # this is an inclusion projection
-            if any(v == 1 for v in projection.values()):
+            if any(v == 1 for v in projection_dict.values()):
                 # ensure primary_key is included
-                projection_dict.update({self.incremental.primary_key: 1})
+                projection_dict[self.incremental.primary_key] = 1
             # this is an exclusion projection
             else:
                 try:
                     # ensure primary_key isn't excluded
-                    projection_dict.pop(self.incremental.primary_key)
+                    projection_dict.pop(self.incremental.primary_key)  # type: ignore
                 except KeyError:
                     pass  # primary_key was properly not included in exclusion projection
         else:
@@ -174,6 +173,7 @@ def load_documents(
         Args:
             filter_ (Dict[str, Any]): The filter to apply to the collection.
            limit (Optional[int]): The number of documents to load.
+            projection: fields to include or exclude when building the cursor
 
         Yields:
             Iterator[TDataItem]: An iterator of the loaded documents.
@@ -279,6 +279,7 @@ def load_documents(
         Args:
             filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
+            projection: fields to include or exclude when building the cursor
 
         Yields:
             Iterator[TDataItem]: An iterator of the loaded documents.
@@ -300,6 +301,7 @@ def load_documents(
         filter_: Dict[str, Any],
         limit: Optional[int] = None,
         projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+        pymongoarrow_schema: Any = None,
     ) -> Iterator[Any]:
         """
         Load documents from the collection in Apache Arrow format.
@@ -307,12 +309,14 @@ def load_documents(
         Args:
             filter_ (Dict[str, Any]): The filter to apply to the collection.
             limit (Optional[int]): The number of documents to load.
+            projection: fields to include or exclude when building the cursor
+            pymongoarrow_schema: mapping of expected field types used to convert BSON to Arrow
 
         Yields:
             Iterator[Any]: An iterator of the loaded documents.
""" from pymongoarrow.context import PyMongoArrowContext # type: ignore - from pymongoarrow.lib import process_bson_stream + from pymongoarrow.lib import process_bson_stream # type: ignore filter_op = self._filter_op _raise_if_intersection(filter_op, filter_) @@ -330,7 +334,8 @@ def load_documents( cursor = self._limit(cursor, limit) # type: ignore context = PyMongoArrowContext.from_schema( - None, codec_options=self.collection.codec_options + schema=pymongoarrow_schema, + codec_options=self.collection.codec_options ) for batch in cursor: process_bson_stream(batch, context) @@ -343,6 +348,58 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel): Mongo DB collection parallel loader, which uses Apache Arrow for data processing. """ + def load_documents( + self, + filter_: Dict[str, Any], + limit: Optional[int] = None, + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + pymongoarrow_schema: Any = None, + ) -> Iterator[TDataItem]: + """Load documents from the collection in parallel. + + Args: + filter_ (Dict[str, Any]): The filter to apply to the collection. + limit (Optional[int]): The number of documents to load. + projection: selection of fields to create Cursor + pymongoarrow_schema: mapping of field types to convert BSON to Arrow + + Yields: + Iterator[TDataItem]: An iterator of the loaded documents. + """ + yield from self._get_all_batches( + limit=limit, + filter_=filter_, + projection=projection, + pymongoarrow_schema=pymongoarrow_schema + ) + + def _get_all_batches( + self, + filter_: Dict[str, Any], + limit: Optional[int] = None, + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + pymongoarrow_schema: Any = None, + ) -> Iterator[TDataItem]: + """Load all documents from the collection in parallel batches. + + Args: + filter_ (Dict[str, Any]): The filter to apply to the collection. + limit (Optional[int]): The maximum number of documents to load. + projection: selection of fields to create Cursor + pymongoarrow_schema: mapping of field types to convert BSON to Arrow + + Yields: + Iterator[TDataItem]: An iterator of the loaded documents. + """ + batches = self._create_batches(limit=limit) + cursor = self._get_cursor(filter_=filter_, projection=projection) + for batch in batches: + yield self._run_batch( + cursor=cursor, + batch=batch, + pymongoarrow_schema=pymongoarrow_schema, + ) + def _get_cursor( self, filter_: Dict[str, Any], @@ -352,6 +409,7 @@ def _get_cursor( Args: filter_ (Dict[str, Any]): The filter to apply to the collection. + projection: selection of fields to create Cursor Returns: Cursor: The cursor for the collection. @@ -371,14 +429,20 @@ def _get_cursor( return cursor @dlt.defer - def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem: + def _run_batch( + self, + cursor: TCursor, + batch: Dict[str, int], + pymongoarrow_schema: Any = None, + ) -> TDataItem: from pymongoarrow.context import PyMongoArrowContext from pymongoarrow.lib import process_bson_stream cursor = cursor.clone() context = PyMongoArrowContext.from_schema( - None, codec_options=self.collection.codec_options + schema=pymongoarrow_schema, + codec_options=self.collection.codec_options ) for chunk in cursor.skip(batch["skip"]).limit(batch["limit"]): process_bson_stream(chunk, context) @@ -390,7 +454,8 @@ def collection_documents( client: TMongoClient, collection: TCollection, filter_: Dict[str, Any], - projection: Union[Dict[str, Any], List[str]], # TODO kwargs reserved for dlt? 
+ projection: Union[Dict[str, Any], List[str]], + pymongoarrow_schema: "pymongoarrow.schema.Schema", incremental: Optional[dlt.sources.incremental[Any]] = None, parallel: bool = False, limit: Optional[int] = None, @@ -413,12 +478,13 @@ def collection_documents( Supported formats: object - Python objects (dicts, lists). arrow - Apache Arrow tables. - projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns + projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields when loading the collection. Supported inputs: include (list) - ["year", "title"] - include (dict) - {"year": 1, "title": 1} - exclude (dict) - {"released": 0, "runtime": 0} - Note: Can't mix include and exclude statements '{"title": 1, "released": 0}` + include (dict) - {"year": True, "title": True} + exclude (dict) - {"released": False, "runtime": False} + Note: Can't mix include and exclude statements '{"title": True, "released": False}` + pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types of a collection to convert BSON to Arrow Returns: Iterable[DltResource]: A list of DLT resources for each collection to be loaded. @@ -429,6 +495,19 @@ def collection_documents( ) data_item_format = "object" + if data_item_format != "arrow" and pymongoarrow_schema: + dlt.common.logger.warn( + "Received value for `pymongoarrow_schema`, but `data_item_format=='object'` " + "Use `data_item_format=='arrow'` to enforce schema." + ) + + if data_item_format == "arrow" and pymongoarrow_schema and projection: + dlt.common.logger.warn( + "Received values for both `pymongoarrow_schema` and `projection`. Since both " + "create a projection to select fields, `projection` will be ignored." + ) + + if parallel: if data_item_format == "arrow": LoaderClass = CollectionArrowLoaderParallel @@ -443,11 +522,24 @@ def collection_documents( loader = LoaderClass( client, collection, incremental=incremental, chunk_size=chunk_size ) - for data in loader.load_documents(limit=limit, filter_=filter_, projection=projection): - yield data + if isinstance(loader, (CollectionArrowLoader, CollectionArrowLoaderParallel)): + yield from loader.load_documents( + limit=limit, + filter_=filter_, + projection=projection, + pymongoarrow_schema=pymongoarrow_schema, + ) + else: + yield from loader.load_documents(limit=limit, filter_=filter_, projection=projection) def convert_mongo_objs(value: Any) -> Any: + """MongoDB to dlt type conversion when using Python loaders. + + Notes: + The method `ObjectId.__str__()` creates an hexstring using `binascii.hexlify(__id).decode()` + + """ if isinstance(value, (ObjectId, Decimal128)): return str(value) if isinstance(value, _datetime.datetime): @@ -464,6 +556,13 @@ def convert_mongo_objs(value: Any) -> Any: def convert_arrow_columns(table: Any) -> Any: """Convert the given table columns to Python types. + Notes: + Calling str() matches the `convert_mongo_obs()` used in non-arrow code. + Pymongoarrow converts ObjectId to `fixed_size_binary[12]`, which can't be + converted to a string as a vectorized operation because it contains ASCII characters. + + Instead, you need to loop over values using: `value.as_buffer().hex().decode()` + Args: table (pyarrow.lib.Table): The table to convert. 
@@ -539,6 +638,7 @@ class MongoDbCollectionResourceConfiguration(BaseConfiguration):
     incremental: Optional[dlt.sources.incremental] = None  # type: ignore[type-arg]
     write_disposition: Optional[str] = dlt.config.value
     parallel: Optional[bool] = False
+    projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = dlt.config.value
 
     __source_name__ = "mongodb"
 
diff --git a/tests/mongodb/test_mongodb_source.py b/tests/mongodb/test_mongodb_source.py
index 88f4ba1c5..ec75a6dea 100644
--- a/tests/mongodb/test_mongodb_source.py
+++ b/tests/mongodb/test_mongodb_source.py
@@ -422,9 +422,7 @@ def test_projection_list_inclusion(destination_name):
     expected_columns = projection + ["_id", "_dlt_id", "_dlt_load_id"]
 
     movies = mongodb_collection(
-        collection=collection_name,
-        projection=projection,
-        limit=2
+        collection=collection_name, projection=projection, limit=2
     )
     pipeline.run(movies)
     loaded_columns = pipeline.default_schema.get_table_columns(collection_name).keys()
@@ -445,9 +443,7 @@ def test_projection_dict_inclusion(destination_name):
     expected_columns = list(projection.keys()) + ["_id", "_dlt_id", "_dlt_load_id"]
 
     movies = mongodb_collection(
-        collection=collection_name,
-        projection=projection,
-        limit=2
+        collection=collection_name, projection=projection, limit=2
     )
     pipeline.run(movies)
     loaded_columns = pipeline.default_schema.get_table_columns(collection_name).keys()
@@ -465,17 +461,28 @@ def test_projection_dict_exclusion(destination_name):
     )
     collection_name = "movies"
     columns_to_exclude = [
-        "runtime", "released", "year", "plot", "fullplot", "lastupdated", "type",
-        "directors", "imdb", "cast", "countries", "genres", "tomatoes", "num_mflix_comments",
-        "rated", "awards"
+        "runtime",
+        "released",
+        "year",
+        "plot",
+        "fullplot",
+        "lastupdated",
+        "type",
+        "directors",
+        "imdb",
+        "cast",
+        "countries",
+        "genres",
+        "tomatoes",
+        "num_mflix_comments",
+        "rated",
+        "awards",
     ]
     projection = {col: 0 for col in columns_to_exclude}
     expected_columns = ["title", "poster", "_id", "_dlt_id", "_dlt_load_id"]
 
     movies = mongodb_collection(
-        collection=collection_name,
-        projection=projection,
-        limit=2
+        collection=collection_name, projection=projection, limit=2
     )
     pipeline.run(movies)
     loaded_columns = pipeline.default_schema.get_table_columns(collection_name).keys()
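
A minimal usage sketch of the two new arguments, based on the signatures added above. The connection URL, the `sample_mflix.movies` sample collection, and the duckdb destination are illustrative assumptions, not part of the patch:

import dlt
import pyarrow as pa
from pymongoarrow.schema import Schema

from sources.mongodb import mongodb_collection

# Arrow path: pymongoarrow_schema both selects the fields and fixes their
# Arrow types, so documents arrive as typed Arrow tables.
movies_arrow = mongodb_collection(
    connection_url="mongodb://localhost:27017",  # placeholder
    database="sample_mflix",  # placeholder
    collection="movies",
    data_item_format="arrow",
    pymongoarrow_schema=Schema({"title": pa.string(), "year": pa.int64()}),
    limit=2,
)

# Object path: projection alone selects fields; documents stay Python dicts.
movies_objects = mongodb_collection(
    connection_url="mongodb://localhost:27017",  # placeholder
    database="sample_mflix",  # placeholder
    collection="movies",
    projection={"title": True, "year": True},
    limit=2,
)

pipeline = dlt.pipeline(pipeline_name="mongodb_demo", destination="duckdb")
print(pipeline.run(movies_arrow))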
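
And a sketch of the per-value ObjectId conversion described in the `convert_arrow_columns` note: `fixed_size_binary[12]` can't be cast to string in one vectorized call, so each scalar's buffer is hex-encoded individually. The helper name and the `_id` column are illustrative, not from the patch:

import pyarrow as pa

def objectid_column_to_hex(table: pa.Table, column: str = "_id") -> pa.Table:
    # Replace a fixed_size_binary[12] ObjectId column with hex strings.
    idx = table.column_names.index(column)
    # Buffer.hex() returns bytes, so decode() yields a str, matching what
    # str(ObjectId) produces in the non-arrow loaders.
    hex_values = [
        value.as_buffer().hex().decode() if value.is_valid else None
        for value in table[column]
    ]
    return table.set_column(idx, column, pa.array(hex_values, type=pa.string()))

# e.g. one dummy 12-byte ObjectId
t = pa.table({"_id": pa.array([b"\x01" * 12], type=pa.binary(12))})
print(objectid_column_to_hex(t)["_id"])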