diff --git a/templates/keynote-2/src/rpc-servers/postgres-rpc-server.ts b/templates/keynote-2/src/rpc-servers/postgres-rpc-server.ts index 72441607c6c..272b4deff4d 100644 --- a/templates/keynote-2/src/rpc-servers/postgres-rpc-server.ts +++ b/templates/keynote-2/src/rpc-servers/postgres-rpc-server.ts @@ -3,7 +3,7 @@ import http from 'node:http'; import { Pool } from 'pg'; import { drizzle } from 'drizzle-orm/node-postgres'; import { pgTable, integer, bigint as pgBigint } from 'drizzle-orm/pg-core'; -import { sql } from 'drizzle-orm'; +import { eq, inArray, sql } from 'drizzle-orm'; import { RpcRequest, RpcResponse } from '../connectors/rpc/rpc_common.ts'; import { getSharedRuntimeDefaults } from '../config.ts'; @@ -27,36 +27,6 @@ const pool = new Pool({ const db = drizzle(pool, { schema: { accounts } }); -const PREPARED = { - getAccountById: { - name: 'get_account', - text: ` - SELECT id, balance - FROM accounts - WHERE id = $1 - LIMIT 1 - `, - }, - transferSelectForUpdate: { - name: 'transfer_select', - text: ` - SELECT id, balance - FROM accounts - WHERE id IN ($1, $2) - ORDER BY id - FOR UPDATE - `, - }, - transferUpdateBalance: { - name: 'transfer_update', - text: ` - UPDATE accounts - SET balance = $1::bigint - WHERE id = $2 - `, - }, -} as const; - async function rpcTransfer(args: Record) { const fromId = Number(args.from_id ?? args.from); const toId = Number(args.to_id ?? args.to); @@ -72,17 +42,14 @@ async function rpcTransfer(args: Record) { if (fromId === toId || amount <= 0) return; const delta = BigInt(amount); - const client = await pool.connect(); - try { - await client.query('BEGIN'); - - const rowsResult = await client.query<{ id: number; balance: string }>({ - name: PREPARED.transferSelectForUpdate.name, - text: PREPARED.transferSelectForUpdate.text, - values: [fromId, toId], - }); - const rows = rowsResult.rows; + await db.transaction(async (tx) => { + const rows = await tx + .select() + .from(accounts) + .where(inArray(accounts.id, [fromId, toId])) + .for('update') + .orderBy(accounts.id); if (rows.length !== 2) { throw new Error('account_missing'); @@ -91,43 +58,32 @@ async function rpcTransfer(args: Record) { const [first, second] = rows; const fromRow = first.id === fromId ? first : second; const toRow = first.id === fromId ? second : first; - const fromBalance = BigInt(fromRow.balance); - - if (fromBalance >= delta) { - const toBalance = BigInt(toRow.balance); - - await client.query({ - name: PREPARED.transferUpdateBalance.name, - text: PREPARED.transferUpdateBalance.text, - values: [(fromBalance - delta).toString(), fromId], - }); - - await client.query({ - name: PREPARED.transferUpdateBalance.name, - text: PREPARED.transferUpdateBalance.text, - values: [(toBalance + delta).toString(), toId], - }); + + if (fromRow.balance < delta) { + return; } - await client.query('COMMIT'); - } catch (err) { - await client.query('ROLLBACK').catch(() => {}); - throw err; - } finally { - client.release(); - } + await tx + .update(accounts) + .set({ balance: fromRow.balance - delta }) + .where(eq(accounts.id, fromId)); + + await tx + .update(accounts) + .set({ balance: toRow.balance + delta }) + .where(eq(accounts.id, toId)); + }); } async function rpcGetAccount(args: Record) { const id = Number(args.id); if (!Number.isInteger(id)) throw new Error('invalid id'); - const rowsResult = await pool.query<{ id: number; balance: string }>({ - name: PREPARED.getAccountById.name, - text: PREPARED.getAccountById.text, - values: [id], - }); - const rows = rowsResult.rows; + const rows = await db + .select() + .from(accounts) + .where(eq(accounts.id, id)) + .limit(1); if (rows.length === 0) return null; const row = rows[0]!;