Skip to content

Commit

Permalink
fix: rolling back on managing conn lifecycle using context mgrs: it d…
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas ESTRADA committed Jan 29, 2025
1 parent 41f8ded commit 129b18a
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions sources/pg_legacy_replication/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def get_max_lsn(credentials: ConnectionStringCredentials) -> Optional[int]:
Returns None if the replication slot is empty.
Does not consume the slot, i.e. messages are not flushed.
"""
with _get_conn(credentials) as conn:
cur = conn.cursor()
cur = _get_conn(credentials).cursor()
try:
loc_fn = (
"pg_current_xlog_location"
if get_pg_version(cur) < 100000
Expand All @@ -171,6 +171,8 @@ def get_max_lsn(credentials: ConnectionStringCredentials) -> Optional[int]:
cur.execute(f"SELECT {loc_fn}() - '0/0' as max_lsn;")
lsn: int = cur.fetchone()[0]
return lsn
finally:
cur.connection.close()


def lsn_int_to_hex(lsn: int) -> str:
Expand All @@ -192,13 +194,15 @@ def advance_slot(
the behavior of that method seems odd when used outside of `consume_stream`.
"""
assert upto_lsn > 0
with _get_conn(credentials) as conn:
cur = conn.cursor()
cur = _get_conn(credentials).cursor()
try:
# There is unfortunately no way in pg9.6 to manually advance the replication slot
if get_pg_version(cur) > 100000:
cur.execute(
f"SELECT * FROM pg_replication_slot_advance('{slot_name}', '{lsn_int_to_hex(upto_lsn)}');"
)
finally:
cur.connection.close()


def _get_conn(
Expand Down Expand Up @@ -382,19 +386,20 @@ def __iter__(self) -> Iterator[TableItems]:
Maintains LSN of last consumed commit message in object state.
Advances the slot only when all messages have been consumed.
"""
with get_rep_conn(self.credentials) as conn:
cur = conn.cursor()
cur = get_rep_conn(self.credentials).cursor()
consumer = MessageConsumer(
upto_lsn=self.upto_lsn,
table_qnames=self.table_qnames,
repl_options=self.repl_options,
target_batch_size=self.target_batch_size,
)
try:
cur.start_replication(slot_name=self.slot_name, start_lsn=self.start_lsn)
consumer = MessageConsumer(
upto_lsn=self.upto_lsn,
table_qnames=self.table_qnames,
repl_options=self.repl_options,
target_batch_size=self.target_batch_size,
)
try:
cur.consume_stream(consumer)
except StopReplication: # completed batch or reached `upto_lsn`
yield from self.flush_batch(cur, consumer)
cur.consume_stream(consumer)
except StopReplication: # completed batch or reached `upto_lsn`
yield from self.flush_batch(cur, consumer)
finally:
cur.connection.close()

def flush_batch(
self, cur: ReplicationCursor, consumer: MessageConsumer
Expand Down

0 comments on commit 129b18a

Please sign in to comment.