
Commit

Merge branch 'master' into feat/1495-rename-JSONResponsePaginator
willi-mueller authored Jul 18, 2024
2 parents 55a098e + cb96e69 commit 7dfe1de
Showing 20 changed files with 527 additions and 200 deletions.
132 changes: 89 additions & 43 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -86,6 +86,7 @@ tiktoken = "^0.4.0"

[tool.poetry.group.mongodb.dependencies]
pymongo = "^4.3.3"
pymongoarrow = ">=1.3.0"

[tool.poetry.group.airtable.dependencies]
pyairtable = "^2.1.0.post1"
7 changes: 7 additions & 0 deletions sources/mongodb/__init__.py
@@ -3,6 +3,7 @@
from typing import Any, Iterable, List, Optional

import dlt
from dlt.common.data_writers import TDataItemFormat
from dlt.sources import DltResource

from .helpers import (
@@ -78,6 +79,7 @@ def mongodb_collection(
parallel: Optional[bool] = False,
limit: Optional[int] = None,
chunk_size: Optional[int] = 10000,
data_item_format: Optional[TDataItemFormat] = "object",
) -> Any:
"""
A DLT source which loads a collection from a mongo database using PyMongo.
@@ -92,6 +94,10 @@ parallel: Optional[bool] = False,
parallel (Optional[bool]): Option to enable parallel loading for the collection. Default is False.
limit (Optional[int]): The number of documents to load.
chunk_size (Optional[int]): The number of documents to load in each batch.
data_item_format (Optional[TDataItemFormat]): The data format to use for loading.
Supported formats:
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -117,4 +123,5 @@ def mongodb_collection(
parallel=parallel,
limit=limit,
chunk_size=chunk_size,
data_item_format=data_item_format,
)
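For orientation, here is a minimal usage sketch of the new data_item_format argument. It is illustrative only: the pipeline name, destination, and dataset name are placeholders, the import path is assumed to match how sources/mongodb_pipeline.py imports the source, and the MongoDB connection URL is expected to come from dlt's config/secrets as in the existing examples.

import dlt
from mongodb import mongodb_collection  # import path assumed, as in sources/mongodb_pipeline.py

pipeline = dlt.pipeline(
    pipeline_name="mongo_arrow_demo",  # placeholder pipeline name
    destination="duckdb",              # placeholder destination
    dataset_name="mongo_data",         # placeholder dataset
)

# Request Apache Arrow tables instead of Python dicts for the "comments" collection.
# Requires the optional pymongoarrow dependency added in this commit.
comments = mongodb_collection(
    collection="comments",
    data_item_format="arrow",
)

print(pipeline.run(comments))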
134 changes: 133 additions & 1 deletion sources/mongodb/helpers.py
@@ -6,8 +6,11 @@
import dlt
from bson.decimal128 import Decimal128
from bson.objectid import ObjectId
from bson.regex import Regex
from bson.timestamp import Timestamp
from dlt.common import logger
from dlt.common.configuration.specs import BaseConfiguration, configspec
from dlt.common.data_writers import TDataItemFormat
from dlt.common.time import ensure_pendulum_datetime
from dlt.common.typing import TDataItem
from dlt.common.utils import map_nested_in_place
@@ -16,6 +19,7 @@
from pymongo.collection import Collection
from pymongo.cursor import Cursor


if TYPE_CHECKING:
TMongoClient = MongoClient[Any]
TCollection = Collection[Any] # type: ignore
@@ -198,13 +202,85 @@ def load_documents(self, limit: Optional[int] = None) -> Iterator[TDataItem]:
yield document


class CollectionArrowLoader(CollectionLoader):
"""
Mongo DB collection loader, which uses
Apache Arrow for data processing.
"""

def load_documents(self, limit: Optional[int] = None) -> Iterator[Any]:
"""
Load documents from the collection in Apache Arrow format.
Args:
limit (Optional[int]): The number of documents to load.
Yields:
Iterator[Any]: An iterator of the loaded documents.
"""
from pymongoarrow.context import PyMongoArrowContext # type: ignore
from pymongoarrow.lib import process_bson_stream # type: ignore

context = PyMongoArrowContext.from_schema(
None, codec_options=self.collection.codec_options
)

cursor = self.collection.find_raw_batches(
self._filter_op, batch_size=self.chunk_size
)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore

cursor = self._limit(cursor, limit) # type: ignore

for batch in cursor:
process_bson_stream(batch, context)

table = context.finish()
yield convert_arrow_columns(table)


class CollectionArrowLoaderParallel(CollectionLoaderParallel):
"""
Mongo DB collection parallel loader, which uses
Apache Arrow for data processing.
"""

def _get_cursor(self) -> TCursor:
cursor = self.collection.find_raw_batches(
filter=self._filter_op, batch_size=self.chunk_size
)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore

return cursor

@dlt.defer
def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
from pymongoarrow.context import PyMongoArrowContext
from pymongoarrow.lib import process_bson_stream

cursor = cursor.clone()

context = PyMongoArrowContext.from_schema(
None, codec_options=self.collection.codec_options
)

for chunk in cursor.skip(batch["skip"]).limit(batch["limit"]):
process_bson_stream(chunk, context)

table = context.finish()
yield convert_arrow_columns(table)


def collection_documents(
client: TMongoClient,
collection: TCollection,
incremental: Optional[dlt.sources.incremental[Any]] = None,
parallel: bool = False,
limit: Optional[int] = None,
chunk_size: Optional[int] = 10000,
data_item_format: Optional[TDataItemFormat] = "object",
) -> Iterator[TDataItem]:
"""
A DLT source which loads data from a Mongo database using PyMongo.
@@ -217,11 +293,24 @@ def collection_documents(
parallel (bool): Option to enable parallel loading for the collection. Default is False.
limit (Optional[int]): The maximum number of documents to load.
chunk_size (Optional[int]): The number of documents to load in each batch.
data_item_format (Optional[TDataItemFormat]): The data format to use for loading.
Supported formats:
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
"""
LoaderClass = CollectionLoaderParallel if parallel else CollectionLoader
if parallel:
if data_item_format == "arrow":
LoaderClass = CollectionArrowLoaderParallel
elif data_item_format == "object":
LoaderClass = CollectionLoaderParallel # type: ignore
else:
if data_item_format == "arrow":
LoaderClass = CollectionArrowLoader # type: ignore
elif data_item_format == "object":
LoaderClass = CollectionLoader # type: ignore

loader = LoaderClass(
client, collection, incremental=incremental, chunk_size=chunk_size
@@ -235,9 +324,52 @@ def convert_mongo_objs(value: Any) -> Any:
return str(value)
if isinstance(value, _datetime.datetime):
return ensure_pendulum_datetime(value)
if isinstance(value, Regex):
return value.try_compile().pattern
if isinstance(value, Timestamp):
date = value.as_datetime()
return ensure_pendulum_datetime(date)

return value


def convert_arrow_columns(table: Any) -> Any:
"""Convert the given table columns to Python types.
Args:
table (pyarrow.lib.Table): The table to convert.
Returns:
pyarrow.lib.Table: The table with the columns converted.
"""
from pymongoarrow.types import _is_binary, _is_code, _is_decimal128, _is_objectid # type: ignore
from dlt.common.libs.pyarrow import pyarrow

for i, field in enumerate(table.schema):
if _is_objectid(field.type) or _is_decimal128(field.type):
col_values = [str(value) for value in table[field.name]]
table = table.set_column(
i,
pyarrow.field(field.name, pyarrow.string()),
pyarrow.array(col_values, type=pyarrow.string()),
)
else:
type_ = None
if _is_binary(field.type):
type_ = pyarrow.binary()
elif _is_code(field.type):
type_ = pyarrow.string()

if type_:
col_values = [value.as_py() for value in table[field.name]]
table = table.set_column(
i,
pyarrow.field(field.name, type_),
pyarrow.array(col_values, type=type_),
)
return table


def client_from_credentials(connection_url: str) -> TMongoClient:
client: TMongoClient = MongoClient(
connection_url, uuidRepresentation="standard", tz_aware=True
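As a standalone illustration of the set_column pattern used by convert_arrow_columns above, the sketch below recasts a plain pyarrow decimal column to string, mirroring the Decimal128 branch. It is an assumption-laden toy: the pymongoarrow _is_* type checks are replaced by a hard-coded column choice, and the values are made up.

from decimal import Decimal

import pyarrow as pa

# Toy table standing in for pymongoarrow output; "amount" is inferred as a decimal128 column.
table = pa.table({"amount": [Decimal("1.50"), Decimal("2.75")], "value": [1, 2]})

i = table.schema.get_field_index("amount")
col_values = [str(value) for value in table["amount"]]
table = table.set_column(
    i,
    pa.field("amount", pa.string()),
    pa.array(col_values, type=pa.string()),
)

print(table.schema)  # "amount" is now a string column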
3 changes: 2 additions & 1 deletion sources/mongodb/requirements.txt
@@ -1,2 +1,3 @@
pymongo>=4.3.3
dlt>=0.3.5
pymongoarrow>=1.3.0
dlt>=0.5.1
49 changes: 47 additions & 2 deletions sources/mongodb_pipeline.py
@@ -1,7 +1,6 @@
from typing import List

import dlt
from dlt.common import pendulum
from dlt.common.data_writers import TDataItemFormat
from dlt.common.pipeline import LoadInfo
from dlt.common.typing import TDataItems
from dlt.pipeline.pipeline import Pipeline
@@ -45,6 +44,18 @@ def load_select_collection_db_items(parallel: bool = False) -> TDataItems:
return list(comments)


def load_select_collection_db_items_parallel(
data_item_format: TDataItemFormat, parallel: bool = False
) -> TDataItems:
comments = mongodb_collection(
incremental=dlt.sources.incremental("date"),
parallel=parallel,
data_item_format=data_item_format,
collection="comments",
)
return list(comments)


def load_select_collection_db_filtered(pipeline: Pipeline = None) -> LoadInfo:
"""Use the mongodb source to reflect an entire database schema and load select tables from it.
@@ -116,6 +127,37 @@ def load_entire_database(pipeline: Pipeline = None) -> LoadInfo:
return info


def load_collection_with_arrow(pipeline: Pipeline = None) -> LoadInfo:
"""
Load a MongoDB collection, using Apache
Arrow as the data processor.
"""
if pipeline is None:
# Create a pipeline
pipeline = dlt.pipeline(
pipeline_name="local_mongo",
destination="postgres",
dataset_name="mongo_select_incremental",
full_refresh=True,
)

# Configure the source to load data with Arrow
comments = mongodb_collection(
collection="comments",
incremental=dlt.sources.incremental(
"date",
initial_value=pendulum.DateTime(
2005, 1, 1, tzinfo=pendulum.timezone("UTC")
),
end_value=pendulum.DateTime(2005, 6, 1, tzinfo=pendulum.timezone("UTC")),
),
data_item_format="arrow",
)

info = pipeline.run(comments)
return info


if __name__ == "__main__":
# Credentials for the sample database.
# Load selected tables with different settings
@@ -125,3 +167,6 @@ def load_entire_database(pipeline: Pipeline = None) -> LoadInfo:
# Load all tables from the database.
# Warning: The sample database is large
# print(load_entire_database())

# Load data with Apache Arrow.
# print(load_collection_with_arrow())
6 changes: 5 additions & 1 deletion sources/notion/__init__.py
@@ -28,10 +28,14 @@ def notion_pages(
"""
client = NotionClient(api_key)
pages = client.search(filter_criteria={"value": "page", "property": "object"})

for page in pages:
blocks = client.fetch_resource("blocks", page["id"], "children")["results"]
if page_ids and page["id"] not in page_ids:
continue
yield page

if blocks:
yield blocks


@dlt.source
28 changes: 16 additions & 12 deletions sources/pg_replication/helpers.py
@@ -152,9 +152,11 @@ def init_replication(
table_name=table_name,
schema_name=schema_name,
cur=cur_snap,
include_columns=None
if include_columns is None
else include_columns.get(table_name),
include_columns=(
None
if include_columns is None
else include_columns.get(table_name)
),
)
for table_name in table_names
]
@@ -610,13 +612,13 @@ def __init__(

self.consumed_all: bool = False
# data_items attribute maintains all data items
self.data_items: Dict[
int, List[Union[TDataItem, DataItemWithMeta]]
] = dict() # maps relation_id to list of data items
self.data_items: Dict[int, List[Union[TDataItem, DataItemWithMeta]]] = (
dict()
) # maps relation_id to list of data items
# other attributes only maintain last-seen values
self.last_table_schema: Dict[
int, TTableSchema
] = dict() # maps relation_id to table schema
self.last_table_schema: Dict[int, TTableSchema] = (
dict()
) # maps relation_id to table schema
self.last_commit_ts: pendulum.DateTime
self.last_commit_lsn = None

@@ -751,9 +753,11 @@ def process_change(
lsn=msg_start_lsn,
commit_ts=self.last_commit_ts,
for_delete=isinstance(decoded_msg, Delete),
include_columns=None
if self.include_columns is None
else self.include_columns.get(table_name),
include_columns=(
None
if self.include_columns is None
else self.include_columns.get(table_name)
),
)
self.data_items[decoded_msg.relation_id].append(data_item)

9 changes: 4 additions & 5 deletions sources/pg_replication_pipeline.py
@@ -243,15 +243,14 @@ def get_postgres_pipeline() -> dlt.Pipeline:
Uses workaround to fix destination to `postgres`, so it does not get replaced
during `dlt init`.
"""
pipe = dlt.pipeline(
# this trick prevents dlt init command from replacing "destination" argument to "pipeline"
p_call = dlt.pipeline
pipe = p_call(
pipeline_name="source_pipeline",
# destination="postgres", # don't use this, `dlt init` replaces this
destination=Destination.from_reference("postgres", credentials=PG_CREDS),
dataset_name="source_dataset",
full_refresh=True,
)
pipe.destination = Destination.from_reference(
"postgres", credentials=PG_CREDS
) # use this instead
return pipe

