Skip to content

Commit 6544223

Browse files
committed
Cast arrays using Python instead of Postgres db
1 parent 4fdaac8 commit 6544223

File tree

3 files changed

+156
-92
lines changed

3 files changed

+156
-92
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from setuptools import setup
44

55
setup(name='tap-postgres',
6-
version='0.0.65',
6+
version='0.0.66',
77
description='Singer.io tap for extracting data from PostgreSQL',
88
author='Stitch',
99
url='https://singer.io',

tap_postgres/sync_strategies/logical_replication.py

Lines changed: 45 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
#!/usr/bin/env python3
22
# pylint: disable=missing-docstring,not-an-iterable,too-many-locals,too-many-arguments,invalid-name,too-many-return-statements,too-many-branches,len-as-condition,too-many-nested-blocks,wrong-import-order,duplicate-code, anomalous-backslash-in-string, too-many-statements, singleton-comparison, consider-using-in
33

4-
import singer
4+
from functools import reduce
5+
from select import select
6+
import copy
7+
import csv
58
import datetime
69
import decimal
10+
import json
11+
import re
12+
13+
from dateutil.parser import parse
14+
import psycopg2
15+
import singer
716
from singer import utils, get_bookmark
817
import singer.metadata as metadata
918
import tap_postgres.db as post_db
1019
import tap_postgres.sync_strategies.common as sync_common
11-
from dateutil.parser import parse
12-
import psycopg2
13-
from psycopg2 import sql
14-
import copy
15-
from select import select
16-
from functools import reduce
17-
import json
18-
import re
20+
1921

2022
LOGGER = singer.get_logger()
2123

@@ -65,81 +67,29 @@ def get_stream_version(tap_stream_id, state):
6567

6668
return stream_version
6769

68-
def tuples_to_map(accum, t):
69-
accum[t[0]] = t[1]
70-
return accum
71-
72-
def create_hstore_elem_query(elem):
73-
return sql.SQL("SELECT hstore_to_array({})").format(sql.Literal(elem))
74-
75-
def create_hstore_elem(conn_info, elem):
76-
with post_db.open_connection(conn_info) as conn:
77-
with conn.cursor() as cur:
78-
query = create_hstore_elem_query(elem)
79-
cur.execute(query)
80-
res = cur.fetchone()[0]
81-
hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {})
82-
return hstore_elem
83-
84-
def create_array_elem(elem, sql_datatype, conn_info):
70+
def create_hstore_elem(elem):
71+
array = [(item.replace('"', '').split('=>')) for item in elem]
72+
hstore = {}
73+
for item in array:
74+
if len(item) == 2:
75+
key, value = item
76+
if key in hstore:
77+
raise KeyError('Duplicate key {} found when creating hstore'.format(key))
78+
if value.lower() == 'null':
79+
value = None
80+
d[key] = value
81+
82+
return hstore
83+
84+
def create_array_elem(elem):
8585
if elem is None:
8686
return None
8787

88-
with post_db.open_connection(conn_info) as conn:
89-
with conn.cursor() as cur:
90-
if sql_datatype == 'bit[]':
91-
cast_datatype = 'boolean[]'
92-
elif sql_datatype == 'boolean[]':
93-
cast_datatype = 'boolean[]'
94-
elif sql_datatype == 'character varying[]':
95-
cast_datatype = 'character varying[]'
96-
elif sql_datatype == 'cidr[]':
97-
cast_datatype = 'cidr[]'
98-
elif sql_datatype == 'citext[]':
99-
cast_datatype = 'text[]'
100-
elif sql_datatype == 'date[]':
101-
cast_datatype = 'text[]'
102-
elif sql_datatype == 'double precision[]':
103-
cast_datatype = 'double precision[]'
104-
elif sql_datatype == 'hstore[]':
105-
cast_datatype = 'text[]'
106-
elif sql_datatype == 'integer[]':
107-
cast_datatype = 'integer[]'
108-
elif sql_datatype == 'bigint[]':
109-
cast_datatype = 'bigint[]'
110-
elif sql_datatype == 'inet[]':
111-
cast_datatype = 'inet[]'
112-
elif sql_datatype == 'json[]':
113-
cast_datatype = 'text[]'
114-
elif sql_datatype == 'jsonb[]':
115-
cast_datatype = 'text[]'
116-
elif sql_datatype == 'macaddr[]':
117-
cast_datatype = 'macaddr[]'
118-
elif sql_datatype == 'money[]':
119-
cast_datatype = 'text[]'
120-
elif sql_datatype == 'numeric[]':
121-
cast_datatype = 'text[]'
122-
elif sql_datatype == 'real[]':
123-
cast_datatype = 'real[]'
124-
elif sql_datatype == 'smallint[]':
125-
cast_datatype = 'smallint[]'
126-
elif sql_datatype == 'text[]':
127-
cast_datatype = 'text[]'
128-
elif sql_datatype in ('time without time zone[]', 'time with time zone[]'):
129-
cast_datatype = 'text[]'
130-
elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'):
131-
cast_datatype = 'text[]'
132-
elif sql_datatype == 'uuid[]':
133-
cast_datatype = 'text[]'
134-
135-
else:
136-
#custom datatypes like enums
137-
cast_datatype = 'text[]'
138-
139-
sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format(elem, cast_datatype)
140-
cur.execute(sql_stmt)
141-
res = cur.fetchone()[0]
142-
return res
88+
elem = [elem[1:-1]]
89+
reader = csv.reader(elem, delimiter=',', escapechar='\\' , quotechar='"')
90+
array = next(reader)
91+
array = [None if element.lower() == 'null' else element for element in array]
92+
return array
14393

14494
#pylint: disable=too-many-branches,too-many-nested-blocks
14595
def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
@@ -166,17 +116,21 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
166116
#for ordinary bits, elem will == '1'
167117
return elem == '1' or elem == True
168118
if sql_datatype == 'boolean':
169-
return elem
119+
return bool(elem)
170120
if sql_datatype == 'hstore':
171-
return create_hstore_elem(conn_info, elem)
121+
return create_hstore_elem(elem)
172122
if 'numeric' in sql_datatype:
173-
return decimal.Decimal(str(elem))
174-
if isinstance(elem, int):
175-
return elem
176-
if isinstance(elem, float):
177-
return elem
178-
if isinstance(elem, str):
179-
return elem
123+
return decimal.Decimal(elem)
124+
if sql_datatype == 'money':
125+
return decimal.Decimal(elem[1:])
126+
if sql_datatype in ('integer', 'smallint', 'bigint'):
127+
return int(elem)
128+
if sql_datatype in ('double precision', 'real', 'float'):
129+
return float(elem)
130+
if sql_datatype in ('text', 'character varying'):
131+
return elem # return as string
132+
if sql_datatype in ('cidr', 'citext', 'json', 'jsonb', 'inet', 'macaddr', 'uuid'):
133+
return elem # return as string
180134

181135
raise Exception("do not know how to marshall value of type {}".format(elem.__class__))
182136

@@ -189,7 +143,7 @@ def selected_array_to_singer_value(elem, sql_datatype, conn_info):
189143
def selected_value_to_singer_value(elem, sql_datatype, conn_info):
190144
#are we dealing with an array?
191145
if sql_datatype.find('[]') > 0:
192-
cleaned_elem = create_array_elem(elem, sql_datatype, conn_info)
146+
cleaned_elem = create_array_elem(elem)
193147
return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), (cleaned_elem or [])))
194148

195149
return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info)

tests/test_logical_replication.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from decimal import Decimal
2+
import unittest
3+
from unittest.mock import patch
4+
5+
from utils import get_test_connection_config
6+
from tap_postgres.sync_strategies import logical_replication
7+
8+
9+
class TestHandlingArrays(unittest.TestCase):
10+
def setUp(self):
11+
self.env = patch.dict(
12+
'os.environ', {
13+
'TAP_POSTGRES_HOST':'test',
14+
'TAP_POSTGRES_USER':'test',
15+
'TAP_POSTGRES_PASSWORD':'test',
16+
'TAP_POSTGRES_PORT':'5432'
17+
},
18+
)
19+
20+
self.arrays = [
21+
'{10,01,NULL}',
22+
'{t,f,NULL}',
23+
'{127.0.0.1/32,10.0.0.0/32,NULL}',
24+
'{CASE_INSENSITIVE,case_insensitive,NULL,"CASE,,INSENSITIVE"}',
25+
'{2000-12-31,2001-01-01,NULL}',
26+
'{3.14159265359,3.1415926,NULL}',
27+
'{"\\"foo\\"=>\\"bar\\"","\\"baz\\"=>NULL",NULL}',
28+
'{1,2,NULL}',
29+
'{9223372036854775807,NULL}',
30+
'{198.24.10.0/24,NULL}',
31+
'{"{\\"foo\\":\\"bar\\"}",NULL}',
32+
'{"{\\"foo\\": \\"bar\\"}",NULL}',
33+
'{08:00:2b:01:02:03,NULL}',
34+
'{$19.99,NULL}',
35+
'{19.9999999,NULL}',
36+
'{3.14159,NULL}',
37+
'{0,1,NULL}',
38+
'{foo,bar,NULL,"foo,bar","diederik\'s motel "}',
39+
'{16:38:47,NULL}',
40+
'{"2019-11-19 11:38:47-05",NULL}',
41+
'{123e4567-e89b-12d3-a456-426655440000,NULL}'
42+
]
43+
44+
self.sql_datatypes = {
45+
'bit[]': bool,
46+
'boolean[]': bool,
47+
'cidr[]': str,
48+
'citext[]': str,
49+
'date[]': str,
50+
'double precision[]': float,
51+
'hstore[]': dict,
52+
'integer[]': int,
53+
'bigint[]': int,
54+
'inet[]': str,
55+
'json[]': str,
56+
'jsonb[]': str,
57+
'macaddr[]': str,
58+
'money[]': Decimal,
59+
'numeric[]': Decimal,
60+
'real[]': float,
61+
'smallint[]': int,
62+
'text[]': str,
63+
'time with time zone[]': str,
64+
'timestamp with time zone[]': str,
65+
'uuid[]': str,
66+
}
67+
68+
def test_create_array_elem(self):
69+
expected_arrays = [
70+
['10', '01' ,None],
71+
['t', 'f', None],
72+
['127.0.0.1/32', '10.0.0.0/32', None],
73+
['CASE_INSENSITIVE', 'case_insensitive', None,"CASE,,INSENSITIVE"],
74+
['2000-12-31', '2001-01-01', None],
75+
['3.14159265359','3.1415926', None],
76+
['"foo"=>"bar"', '"baz"=>NULL', None],
77+
['1','2',None],
78+
['9223372036854775807', None],
79+
['198.24.10.0/24', None],
80+
["{\"foo\":\"bar\"}", None],
81+
["{\"foo\": \"bar\"}", None],
82+
['08:00:2b:01:02:03', None],
83+
['$19.99', None],
84+
['19.9999999', None],
85+
['3.14159', None],
86+
['0','1', None],
87+
['foo','bar',None,"foo,bar","diederik\'s motel "],
88+
['16:38:47',None],
89+
["2019-11-19 11:38:47-05",None],
90+
['123e4567-e89b-12d3-a456-426655440000', None],
91+
]
92+
for elem, expected_array in zip(self.arrays, expected_arrays):
93+
array = logical_replication.create_array_elem(elem)
94+
self.assertEqual(array, expected_array)
95+
96+
def test_selected_value_to_singer_value_impl(self):
97+
with self.env:
98+
conn_info = get_test_connection_config()
99+
for elem, sql_datatype in zip(self.arrays, self.sql_datatypes.keys()):
100+
array = logical_replication.selected_value_to_singer_value(elem, sql_datatype, conn_info)
101+
102+
for element in array:
103+
python_datatype = self.sql_datatypes[sql_datatype]
104+
if element:
105+
self.assertIsInstance(element, python_datatype)
106+
107+
if __name__== "__main__":
108+
test1 = TestHandlingArrays()
109+
test1.setUp()
110+
test1.test_selected_value_to_singer_value_impl()

0 commit comments

Comments
 (0)