import contextlib
import itertools
import os
import sqlite3
import tempfile
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional

import requests
from psycopg2.sql import SQL, Identifier

from splitgraph.core.types import (
    Credentials,
    IntrospectionResult,
    Params,
    TableColumn,
    TableInfo,
    TableParams,
)
from splitgraph.engine.postgres.engine import _quote_ident
from splitgraph.hooks.data_source.base import LoadableDataSource

if TYPE_CHECKING:
    from splitgraph.engine.postgres.engine import PostgresEngine

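# List every column of every user table: sqlite_master provides the table names and the
# pragma_table_info() table-valued function provides, per table, one row per column
# (cid, name, declared type, notnull, default value, pk flag).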
LIST_TABLES_QUERY = """
SELECT * FROM
(SELECT
  tbl_name
FROM sqlite_master
WHERE type='table') t
JOIN pragma_table_info(tbl_name) s
ORDER BY 1,2;
"""


# based on https://stackoverflow.com/a/16696317
def download_file(url: str, local_fh: tempfile._TemporaryFileWrapper) -> int:
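    """Stream the file at `url` into the open file handle `local_fh` in 8 KiB chunks
    and return the total number of bytes written."""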
    total_bytes_written = 0
    # Use a custom CA bundle if SSL_CERT_FILE is set; otherwise fall back to
    # requests' default certificate verification.
    with requests.get(url, stream=True, verify=os.environ.get("SSL_CERT_FILE", True)) as r:
        r.raise_for_status()
        for chunk in r.iter_content(chunk_size=8192):
            total_bytes_written += local_fh.write(chunk)
    local_fh.flush()
    return total_bytes_written


@contextmanager
def minio_file(url: str) -> Generator[str, None, None]:
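    """Download the object at `url` into a named temporary file and yield its path.
    The file is deleted when the context manager exits."""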
    with tempfile.NamedTemporaryFile(mode="wb", delete=True) as local_fh:
        download_file(url, local_fh)
        yield local_fh.name


def query_connection(
    con: sqlite3.Connection, sql: str, parameters: Optional[Dict[str, str]] = None
) -> List[Any]:
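    """Run a single query against an SQLite connection and return all rows,
    closing the cursor afterwards."""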
    with contextlib.closing(con.cursor()) as cursor:
        cursor.execute(sql, parameters or {})
        return cursor.fetchall()


@contextmanager
def db_from_minio(url: str) -> Generator[sqlite3.Connection, None, None]:
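    """Download an SQLite database from `url` and yield a connection to the local copy;
    both the connection and the temporary file are cleaned up on exit."""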
    with minio_file(url) as f:
        with contextlib.closing(sqlite3.connect(f)) as con:
            yield con


# partly based on https://stackoverflow.com/questions/1942586/comparison-of-database-column-types-in-mysql-postgresql-and-sqlite-cross-map
def sqlite_to_postgres_type(sqlite_type: str) -> str:
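    """Map an SQLite declared column type to an equivalent PostgreSQL type,
    following SQLite's column affinity rules."""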
    if sqlite_type == "DATETIME":
        return "TIMESTAMP WITHOUT TIME ZONE"
    # from: https://www.sqlite.org/datatype3.html#determination_of_column_affinity
    # If the declared type contains the string "INT" then it is assigned INTEGER affinity.
    if "INT" in sqlite_type:
        return "INTEGER"
    # If the declared type of the column contains any of the strings "CHAR", "CLOB", or "TEXT"
    # then that column has TEXT affinity. Notice that the type VARCHAR contains the string
    # "CHAR" and is thus assigned TEXT affinity.
    if "CHAR" in sqlite_type or "CLOB" in sqlite_type or "TEXT" in sqlite_type:
        return "TEXT"
    # If the declared type for a column contains the string "BLOB" or if no type is specified
    # then the column has affinity BLOB. PostgreSQL has no BLOB type, so map it to BYTEA.
    if "BLOB" in sqlite_type:
        return "BYTEA"
    # If the declared type for a column contains any of the strings "REAL", "FLOA", or "DOUB"
    # then the column has REAL affinity.
    if "REAL" in sqlite_type or "FLOA" in sqlite_type or "DOUB" in sqlite_type:
        return "REAL"
    # Otherwise, the affinity is NUMERIC. TODO: Precision and scale
    return "NUMERIC"


def sqlite_connection_to_introspection_result(con: sqlite3.Connection) -> IntrospectionResult:
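    """Build an IntrospectionResult (table name -> (columns, table params)) from an SQLite
    connection. pragma_table_info() numbers columns from 0, so ordinals are shifted by one;
    a non-zero pk flag marks the column as part of the primary key."""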
    schema = IntrospectionResult({})
    for (
        table_name,
        column_id,
        column_name,
        column_type,
        _notnull,
        _default_value,
        pk,
    ) in query_connection(con, LIST_TABLES_QUERY, {}):
        table = schema.get(table_name, ([], TableParams({})))
        assert isinstance(table, tuple)
        table[0].append(
            TableColumn(column_id + 1, column_name, sqlite_to_postgres_type(column_type), pk != 0)
        )
        schema[table_name] = table
    return schema


class SQLiteDataSource(LoadableDataSource):
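    """Data source that downloads a remote SQLite file over HTTP, introspects its
    schema and copies its tables into the engine."""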

    table_params_schema: Dict[str, Any] = {"type": "object", "properties": {}}

    params_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "HTTP URL to the SQLite file",
                "title": "URL",
            }
        },
    }

    supports_mount = False
    supports_load = True
    supports_sync = False

    _icon_file = "sqlite.svg"  # TODO

    def _load(self, schema: str, tables: Optional[TableInfo] = None):
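        # Download the SQLite file once, then for each table: create a matching
        # table in the target schema and bulk-insert its rows.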
        with db_from_minio(str(self.params.get("url"))) as con:
            introspection_result = sqlite_connection_to_introspection_result(con)
            for table_name, table_definition in introspection_result.items():
                assert isinstance(table_definition, tuple)
                schema_spec = table_definition[0]
                self.engine.create_table(
                    schema=schema,
                    table=table_name,
                    schema_spec=schema_spec,
                )
                table_contents = query_connection(
                    con, "SELECT * FROM {}".format(_quote_ident(table_name))  # nosec
                )
                self.engine.run_sql_batch(
                    SQL("INSERT INTO {0}.{1} ").format(Identifier(schema), Identifier(table_name))
                    + SQL(" VALUES (" + ",".join(itertools.repeat("%s", len(schema_spec))) + ")"),
                    # TODO: break this up into multiple batches for larger sqlite files
                    table_contents,
                )  # nosec

    def introspect(self) -> IntrospectionResult:
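        # Introspection needs the actual file: download it and read the schema.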
        with db_from_minio(str(self.params.get("url"))) as con:
            return sqlite_connection_to_introspection_result(con)

    def __init__(
        self,
        engine: "PostgresEngine",
        credentials: Credentials,
        params: Params,
        tables: Optional[TableInfo] = None,
    ):
        super().__init__(engine, credentials, params, tables)

    @classmethod
    def get_name(cls) -> str:
        return "SQLite files"

    @classmethod
    def get_description(cls) -> str:
        return "SQLite files"

    def get_remote_schema_name(self) -> str:
        # We ignore the schema name and use the bucket/prefix passed in the params instead.
        return "data"