added pymongoarrow_schema; linting
zilto committed Jan 16, 2025
1 parent 239a536 commit 2742c83
Showing 3 changed files with 160 additions and 38 deletions.
31 changes: 23 additions & 8 deletions sources/mongodb/__init__.py
@@ -24,6 +24,8 @@ def mongodb(
parallel: Optional[bool] = dlt.config.value,
limit: Optional[int] = None,
filter_: Optional[Dict[str, Any]] = None,
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
pymongoarrow_schema: Optional[Any] = None
) -> Iterable[DltResource]:
"""
A DLT source which loads data from a mongo database using PyMongo.
@@ -41,6 +43,13 @@ def mongodb(
The maximum number of documents to load. The limit is
applied to each requested collection separately.
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields of a collection
when loading the collection. Supported inputs:
include (list) - ["year", "title"]
include (dict) - {"year": True, "title": True}
exclude (dict) - {"released": False, "runtime": False}
Note: Can't mix include and exclude statements: `{"title": True, "released": False}`
pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types of a collection to convert BSON to Arrow
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -73,12 +82,15 @@ def mongodb(
parallel=parallel,
limit=limit,
filter_=filter_ or {},
- projection=None,
+ projection=projection,
+ pymongoarrow_schema=pymongoarrow_schema,
)


- @dlt.common.configuration.with_config(
-     sections=("sources", "mongodb"), spec=MongoDbCollectionResourceConfiguration
+ @dlt.resource(
+     name=lambda args: args["collection"],
+     standalone=True,
+     spec=MongoDbCollectionResourceConfiguration,
)
def mongodb_collection(
connection_url: str = dlt.secrets.value,
@@ -91,7 +103,8 @@ def mongodb_collection(
chunk_size: Optional[int] = 10000,
data_item_format: Optional[TDataItemFormat] = "object",
filter_: Optional[Dict[str, Any]] = None,
- projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+ projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = dlt.config.value,
+ pymongoarrow_schema: Optional[Any] = None
) -> Any:
"""
A DLT source which loads a collection from a mongo database using PyMongo.
@@ -111,12 +124,13 @@ def mongodb_collection(
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
- projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
+ projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields
when loading the collection. Supported inputs:
include (list) - ["year", "title"]
- include (dict) - {"year": 1, "title": 1}
- exclude (dict) - {"released": 0, "runtime": 0}
- Note: Can't mix include and exclude statements '{"title": 1, "released": 0}`
+ include (dict) - {"year": True, "title": True}
+ exclude (dict) - {"released": False, "runtime": False}
+ Note: Can't mix include and exclude statements: `{"title": True, "released": False}`
+ pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types to convert BSON to Arrow
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -145,4 +159,5 @@ def mongodb_collection(
data_item_format=data_item_format,
filter_=filter_ or {},
projection=projection,
pymongoarrow_schema=pymongoarrow_schema,
)
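The two new parameters combine as follows: `projection` trims fields server-side, while `pymongoarrow_schema` declares the Arrow types used for BSON conversion. A minimal usage sketch (connection URL, database, collection names, and destination are assumptions, not part of this commit):

```python
import dlt
import pyarrow as pa
from pymongoarrow.schema import Schema

from sources.mongodb import mongodb_collection

# Object format: select fields with an inclusion projection.
movies = mongodb_collection(
    connection_url="mongodb://localhost:27017",  # hypothetical deployment
    database="sample_mflix",
    collection="movies",
    projection=["year", "title"],
)

# Arrow format: pymongoarrow converts BSON to Arrow using the declared types.
# Passing both `pymongoarrow_schema` and `projection` logs a warning and
# ignores `projection`, since the schema already selects fields.
comments = mongodb_collection(
    connection_url="mongodb://localhost:27017",
    database="sample_mflix",
    collection="comments",
    data_item_format="arrow",
    pymongoarrow_schema=Schema({"name": pa.string(), "text": pa.string()}),
)

pipeline = dlt.pipeline(pipeline_name="mongodb_demo", destination="duckdb")
print(pipeline.run([movies, comments]))
```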
136 changes: 118 additions & 18 deletions sources/mongodb/helpers.py
@@ -107,10 +107,10 @@ def _filter_op(self) -> Dict[str, Any]:
filt[self.cursor_field]["$gt"] = self.incremental.end_value

return filt
- def _projection_op(self, projection) -> Optional[Dict[str, Any]]:
+ def _projection_op(self, projection: Optional[Union[Mapping[str, Any], Iterable[str]]]) -> Optional[Dict[str, Any]]:
"""Build a projection operator.
A tuple of fields to include or a dict specifying fields to include or exclude.
The incremental `primary_key` needs to be handled differently for inclusion
and exclusion projections.
@@ -123,17 +123,16 @@ def _projection_op(self, projection) -> Optional[Dict[str, Any]]:

projection_dict = dict(_fields_list_to_dict(projection, "projection"))

# NOTE we can still filter on primary_key if it's excluded from projection
if self.incremental:
# this is an inclusion projection
- if any(v == 1 for v in projection.values()):
+ if any(v == 1 for v in projection_dict.values()):
# ensure primary_key is included
- projection_dict.update({self.incremental.primary_key: 1})
+ projection_dict.update(m={self.incremental.primary_key: 1})
# this is an exclusion projection
else:
try:
# ensure primary_key isn't excluded
- projection_dict.pop(self.incremental.primary_key)
+ projection_dict.pop(self.incremental.primary_key)  # type: ignore
except KeyError:
pass # primary_key was properly not included in exclusion projection
else:
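This hunk guards the incremental cursor field: an inclusion projection must force the `primary_key` in, and an exclusion projection must not drop it, otherwise incremental filtering would break. A self-contained sketch of that rule (`build_projection` is an illustrative helper, not the source's API):

```python
from collections.abc import Iterable, Mapping
from typing import Any, Dict, Optional, Union

def build_projection(
    projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
    primary_key: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    if projection is None:
        return None
    # A list/tuple of field names is shorthand for an inclusion projection.
    if isinstance(projection, Mapping):
        projection_dict: Dict[str, Any] = dict(projection)
    else:
        projection_dict = {field: 1 for field in projection}
    if primary_key is not None:
        if any(v == 1 for v in projection_dict.values()):  # inclusion (True == 1)
            projection_dict[primary_key] = 1  # ensure the cursor field is loaded
        else:
            projection_dict.pop(primary_key, None)  # never exclude the cursor field
    return projection_dict

assert build_projection(["year", "title"], "_id") == {"year": 1, "title": 1, "_id": 1}
assert build_projection({"runtime": False, "_id": False}, "_id") == {"runtime": False}
```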
Expand Down Expand Up @@ -174,6 +173,7 @@ def load_documents(
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
projection: selection of fields used to create the Cursor
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
Expand Down Expand Up @@ -279,6 +279,7 @@ def load_documents(
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
projection: selection of fields used to create the Cursor
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
@@ -300,19 +301,22 @@ def load_documents(
filter_: Dict[str, Any],
limit: Optional[int] = None,
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
pymongoarrow_schema: Any = None,
) -> Iterator[Any]:
"""
Load documents from the collection in Apache Arrow format.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
projection: selection of fields used to create the Cursor
pymongoarrow_schema: mapping of field types to convert BSON to Arrow
Yields:
Iterator[Any]: An iterator of the loaded documents.
"""
from pymongoarrow.context import PyMongoArrowContext # type: ignore
- from pymongoarrow.lib import process_bson_stream
+ from pymongoarrow.lib import process_bson_stream  # type: ignore

filter_op = self._filter_op
_raise_if_intersection(filter_op, filter_)
@@ -330,7 +334,8 @@ def load_documents(
cursor = self._limit(cursor, limit) # type: ignore

context = PyMongoArrowContext.from_schema(
- None, codec_options=self.collection.codec_options
+ schema=pymongoarrow_schema,
+ codec_options=self.collection.codec_options
)
for batch in cursor:
process_bson_stream(batch, context)
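For intuition about what `PyMongoArrowContext.from_schema` does with the new `schema` argument, the same schema-driven BSON-to-Arrow conversion is available through pymongoarrow's public helper `find_arrow_all`; a sketch (server and collection names are hypothetical):

```python
import pyarrow as pa
from pymongo import MongoClient
from pymongoarrow.api import Schema, find_arrow_all

client = MongoClient("mongodb://localhost:27017")  # hypothetical deployment
collection = client["sample_mflix"]["movies"]

# The schema both selects the listed fields and fixes their Arrow types,
# which is why the source ignores `projection` when a schema is supplied.
schema = Schema({"title": pa.string(), "year": pa.int64()})
table = find_arrow_all(collection, {"year": {"$gte": 2000}}, schema=schema)
print(table.schema)
```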
@@ -343,6 +348,58 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel):
Mongo DB collection parallel loader, which uses
Apache Arrow for data processing.
"""
def load_documents(
self,
filter_: Dict[str, Any],
limit: Optional[int] = None,
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
pymongoarrow_schema: Any = None,
) -> Iterator[TDataItem]:
"""Load documents from the collection in parallel.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The number of documents to load.
projection: selection of fields used to create the Cursor
pymongoarrow_schema: mapping of field types to convert BSON to Arrow
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
yield from self._get_all_batches(
limit=limit,
filter_=filter_,
projection=projection,
pymongoarrow_schema=pymongoarrow_schema
)

def _get_all_batches(
self,
filter_: Dict[str, Any],
limit: Optional[int] = None,
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
pymongoarrow_schema: Any = None,
) -> Iterator[TDataItem]:
"""Load all documents from the collection in parallel batches.
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
limit (Optional[int]): The maximum number of documents to load.
projection: selection of fields used to create the Cursor
pymongoarrow_schema: mapping of field types to convert BSON to Arrow
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
batches = self._create_batches(limit=limit)
cursor = self._get_cursor(filter_=filter_, projection=projection)
for batch in batches:
yield self._run_batch(
cursor=cursor,
batch=batch,
pymongoarrow_schema=pymongoarrow_schema,
)

def _get_cursor(
self,
filter_: Dict[str, Any],
Expand All @@ -352,6 +409,7 @@ def _get_cursor(
Args:
filter_ (Dict[str, Any]): The filter to apply to the collection.
projection: selection of fields used to create the Cursor
Returns:
Cursor: The cursor for the collection.
@@ -371,14 +429,20 @@ def _get_cursor(
return cursor

@dlt.defer
- def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
+ def _run_batch(
+     self,
+     cursor: TCursor,
+     batch: Dict[str, int],
+     pymongoarrow_schema: Any = None,
+ ) -> TDataItem:
from pymongoarrow.context import PyMongoArrowContext
from pymongoarrow.lib import process_bson_stream

cursor = cursor.clone()

context = PyMongoArrowContext.from_schema(
- None, codec_options=self.collection.codec_options
+ schema=pymongoarrow_schema,
+ codec_options=self.collection.codec_options
)
for chunk in cursor.skip(batch["skip"]).limit(batch["limit"]):
process_bson_stream(chunk, context)
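`_create_batches` itself is not part of this diff; a plausible sketch of the skip/limit windowing it implies, where each emitted dict later feeds `cursor.skip(batch["skip"]).limit(batch["limit"])` on a cloned cursor:

```python
from typing import Dict, Iterator, Optional

def make_batches(
    total_count: int, chunk_size: int, limit: Optional[int] = None
) -> Iterator[Dict[str, int]]:
    # Split the collection into fixed-size windows for parallel workers.
    end = min(total_count, limit) if limit is not None else total_count
    for skip in range(0, end, chunk_size):
        yield {"skip": skip, "limit": min(chunk_size, end - skip)}

print(list(make_batches(total_count=25, chunk_size=10)))
# [{'skip': 0, 'limit': 10}, {'skip': 10, 'limit': 10}, {'skip': 20, 'limit': 5}]
```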
@@ -390,7 +454,8 @@ def collection_documents(
client: TMongoClient,
collection: TCollection,
filter_: Dict[str, Any],
- projection: Union[Dict[str, Any], List[str]],  # TODO kwargs reserved for dlt?
+ projection: Union[Dict[str, Any], List[str]],
+ pymongoarrow_schema: "pymongoarrow.schema.Schema",
incremental: Optional[dlt.sources.incremental[Any]] = None,
parallel: bool = False,
limit: Optional[int] = None,
@@ -413,12 +478,13 @@ def collection_documents(
Supported formats:
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
- projection: (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
+ projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select fields
when loading the collection. Supported inputs:
include (list) - ["year", "title"]
- include (dict) - {"year": 1, "title": 1}
- exclude (dict) - {"released": 0, "runtime": 0}
- Note: Can't mix include and exclude statements '{"title": 1, "released": 0}`
+ include (dict) - {"year": True, "title": True}
+ exclude (dict) - {"released": False, "runtime": False}
+ Note: Can't mix include and exclude statements: `{"title": True, "released": False}`
+ pymongoarrow_schema (pymongoarrow.schema.Schema): Mapping of expected field types of a collection to convert BSON to Arrow
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
Expand All @@ -429,6 +495,19 @@ def collection_documents(
)
data_item_format = "object"

if data_item_format != "arrow" and pymongoarrow_schema:
dlt.common.logger.warn(
"Received value for `pymongoarrow_schema`, but `data_item_format=='object'` "
"Use `data_item_format=='arrow'` to enforce schema."
)

if data_item_format == "arrow" and pymongoarrow_schema and projection:
dlt.common.logger.warn(
"Received values for both `pymongoarrow_schema` and `projection`. Since both "
"create a projection to select fields, `projection` will be ignored."
)


if parallel:
if data_item_format == "arrow":
LoaderClass = CollectionArrowLoaderParallel
@@ -443,11 +522,24 @@ def collection_documents(
loader = LoaderClass(
client, collection, incremental=incremental, chunk_size=chunk_size
)
- for data in loader.load_documents(limit=limit, filter_=filter_, projection=projection):
-     yield data
+ if isinstance(loader, (CollectionArrowLoader, CollectionArrowLoaderParallel)):
+     yield from loader.load_documents(
+         limit=limit,
+         filter_=filter_,
+         projection=projection,
+         pymongoarrow_schema=pymongoarrow_schema,
+     )
+ else:
+     yield from loader.load_documents(limit=limit, filter_=filter_, projection=projection)


def convert_mongo_objs(value: Any) -> Any:
"""MongoDB to dlt type conversion when using Python loaders.
Notes:
The method `ObjectId.__str__()` creates a hexstring using `binascii.hexlify(__id).decode()`
"""
if isinstance(value, (ObjectId, Decimal128)):
return str(value)
if isinstance(value, _datetime.datetime):
@@ -464,6 +556,13 @@ def convert_mongo_objs(value: Any) -> Any:
def convert_arrow_columns(table: Any) -> Any:
"""Convert the given table columns to Python types.
Notes:
Calling str() matches the `convert_mongo_objs()` used in non-arrow code.
Pymongoarrow converts ObjectId to `fixed_size_binary[12]`, which can't be
converted to a string as a vectorized operation because the raw bytes are not valid text.
Instead, you need to loop over values using: `value.as_buffer().hex().decode()`
Args:
table (pyarrow.lib.Table): The table to convert.
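A small sketch of the per-value loop that note describes, run on a synthetic `fixed_size_binary[12]` column (the sample bytes are made up):

```python
import pyarrow as pa

# Synthetic ObjectId column, stored as fixed_size_binary[12] as pymongoarrow emits it.
ids = pa.array([b"\x01" * 12, b"\x02" * 12], type=pa.binary(12))

# A vectorized cast to string would fail on raw bytes, so convert value by value,
# mirroring the hexstring produced by str(ObjectId) in the object-format path.
hex_ids = pa.array([value.as_buffer().hex().decode() for value in ids])
print(hex_ids)
```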
@@ -539,6 +638,7 @@ class MongoDbCollectionResourceConfiguration(BaseConfiguration):
incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg]
write_disposition: Optional[str] = dlt.config.value
parallel: Optional[bool] = False
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = dlt.config.value


__source_name__ = "mongodb"
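Because `projection` is now declared with `dlt.config.value` both in `mongodb_collection` and in this spec, it can be injected from dlt's config providers rather than passed in code. A hedged sketch via environment variables (the `SOURCES__MONGODB__*` key layout is assumed from the spec's ("sources", "mongodb") sections):

```python
import json
import os

# dlt's environment provider maps config sections to double-underscore keys;
# complex values such as lists are passed as JSON.
os.environ["SOURCES__MONGODB__PROJECTION"] = json.dumps(["year", "title"])

from sources.mongodb import mongodb_collection

movies = mongodb_collection(
    connection_url="mongodb://localhost:27017",  # hypothetical
    database="sample_mflix",
    collection="movies",
    # `projection` resolves from config; no need to pass it explicitly.
)
```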