Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,41 @@ func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) {
}
}

func BenchmarkMinimalPgConnPreparedStatementDescriptionSelect(b *testing.B) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
defer closeConn(b, conn)

pgConn := conn.PgConn()

psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::int8", nil)
if err != nil {
b.Fatal(err)
}

encodedBytes := make([]byte, 8)

b.ResetTimer()
for i := 0; i < b.N; i++ {

rr := pgConn.ExecPreparedStatementDescription(context.Background(), psd, [][]byte{encodedBytes}, []int16{1}, []int16{1})
if err != nil {
b.Fatal(err)
}

for rr.NextRow() {
for i := range rr.Values() {
if !bytes.Equal(rr.Values()[0], encodedBytes) {
b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes)
}
}
}
_, err = rr.Close()
if err != nil {
b.Fatal(err)
}
}
}

func BenchmarkPointerPointerWithNullValues(b *testing.B) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
defer closeConn(b, conn)
Expand Down Expand Up @@ -1282,6 +1317,51 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
}
}

func BenchmarkSelectRowsPgConnExecPreparedStatementDescription(b *testing.B) {
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(b, conn)

rowCounts := getSelectRowsCounts(b)

psd, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
if err != nil {
b.Fatal(err)
}

for _, rowCount := range rowCounts {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
formats := []struct {
name string
code int16
}{
{"text", pgx.TextFormatCode},
{"binary - mostly", pgx.BinaryFormatCode},
}
for _, format := range formats {
b.Run(format.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
rr := conn.PgConn().ExecPreparedStatementDescription(
context.Background(),
psd,
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
nil,
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
)
for rr.NextRow() {
rr.Values()
}

_, err := rr.Close()
if err != nil {
b.Fatal(err)
}
}
})
}
})
}
}

type queryRecorder struct {
conn net.Conn
writeBuf []byte
Expand Down
68 changes: 62 additions & 6 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
)

const (
Expand Down Expand Up @@ -1165,7 +1166,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})

pgConn.execExtendedSuffix(result)
pgConn.execExtendedSuffix(result, nil, nil)

return result
}
Expand All @@ -1190,7 +1191,37 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa

pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})

pgConn.execExtendedSuffix(result)
pgConn.execExtendedSuffix(result, nil, nil)

return result
}

// ExecPreparedStatementDescription enqueues the execution of a prepared statement via the PostgreSQL extended query
// protocol.
//
// This differs from ExecPrepared in that it takes a *StatementDescription instead of just the prepared statement name.
// Because it has the *StatementDescription it can avoid the Describe Portal message that ExecPrepared must send to get
// the result column descriptions.
//
// paramValues are the parameter values. It must be encoded in the format given by paramFormats.
//
// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or
// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if len(paramFormats) is not
// 0, 1, or len(paramValues).
//
// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or binary
// format. If resultFormats is nil all results will be in text format.
//
// ResultReader must be closed before PgConn can be used again.
func (pgConn *PgConn) ExecPreparedStatementDescription(ctx context.Context, statementDescription *StatementDescription, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader {
result := pgConn.execExtendedPrefix(ctx, paramValues)
if result.closed {
return result
}

pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: statementDescription.Name, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})

pgConn.execExtendedSuffix(result, statementDescription, resultFormats)

return result
}
Expand Down Expand Up @@ -1230,8 +1261,10 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result
}

func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
func (pgConn *PgConn) execExtendedSuffix(result *ResultReader, statementDescription *StatementDescription, resultFormats []int16) {
if statementDescription == nil {
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
}
pgConn.frontend.SendExecute(&pgproto3.Execute{})
pgConn.frontend.SendSync(&pgproto3.Sync{})

Expand All @@ -1245,7 +1278,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
return
}

result.readUntilRowDescription()
result.readUntilRowDescription(statementDescription, resultFormats)
}

// CopyTo executes the copy command sql and copies the results to w.
Expand Down Expand Up @@ -1662,13 +1695,36 @@ func (rr *ResultReader) Close() (CommandTag, error) {

// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any
// error will be stored in the ResultReader.
func (rr *ResultReader) readUntilRowDescription() {
func (rr *ResultReader) readUntilRowDescription(statementDescription *StatementDescription, resultFormats []int16) {
for !rr.commandConcluded {
// Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method.
// This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are
// manually used to construct a query that does not issue a describe statement.
msg, _ := rr.pgConn.peekMessage()
if _, ok := msg.(*pgproto3.DataRow); ok {
if statementDescription != nil {
rr.fieldDescriptions = statementDescription.Fields
// Adjust field descriptions for resultFormats
if len(resultFormats) == 0 {
// No format codes provided, default to text format
for i := range rr.fieldDescriptions {
rr.fieldDescriptions[i].Format = pgtype.TextFormatCode
}
} else if len(resultFormats) == 1 {
// Single format code applies to all columns
for i := range rr.fieldDescriptions {
rr.fieldDescriptions[i].Format = resultFormats[0]
}
} else if len(resultFormats) == len(rr.fieldDescriptions) {
// One format code per column
for i := range rr.fieldDescriptions {
rr.fieldDescriptions[i].Format = resultFormats[i]
}
} else {
// This should be impossible to reach as the mismatch would have been caught earlier.
rr.concludeCommand(CommandTag{}, fmt.Errorf("mismatched result format codes length"))
}
}
return
}

Expand Down
140 changes: 140 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1439,6 +1439,146 @@ func TestConnExecPreparedEmptySQL(t *testing.T) {
ensureConnValid(t, pgConn)
}

func TestConnExecPreparedStatementDescription(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)

psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text as msg", nil)
require.NoError(t, err)
require.NotNil(t, psd)
assert.Len(t, psd.ParamOIDs, 1)
assert.Len(t, psd.Fields, 1)

result := pgConn.ExecPreparedStatementDescription(ctx, psd, [][]byte{[]byte("Hello, world")}, nil, nil)
require.Len(t, result.FieldDescriptions(), 1)
assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)

rowCount := 0
for result.NextRow() {
rowCount += 1
assert.Equal(t, "Hello, world", string(result.Values()[0]))
}
assert.Equal(t, 1, rowCount)
commandTag, err := result.Close()
assert.Equal(t, "SELECT 1", commandTag.String())
assert.NoError(t, err)

ensureConnValid(t, pgConn)
}

type byteCounterConn struct {
conn net.Conn
bytesRead int
bytesWritten int
}

func (cbn *byteCounterConn) Read(b []byte) (n int, err error) {
n, err = cbn.conn.Read(b)
cbn.bytesRead += n
return n, err
}

func (cbn *byteCounterConn) Write(b []byte) (n int, err error) {
n, err = cbn.conn.Write(b)
cbn.bytesWritten += n
return n, err
}

func (cbn *byteCounterConn) Close() error {
return cbn.conn.Close()
}

func (cbn *byteCounterConn) LocalAddr() net.Addr {
return cbn.conn.LocalAddr()
}

func (cbn *byteCounterConn) RemoteAddr() net.Addr {
return cbn.conn.RemoteAddr()
}

func (cbn *byteCounterConn) SetDeadline(t time.Time) error {
return cbn.conn.SetDeadline(t)
}

func (cbn *byteCounterConn) SetReadDeadline(t time.Time) error {
return cbn.conn.SetReadDeadline(t)
}

func (cbn *byteCounterConn) SetWriteDeadline(t time.Time) error {
return cbn.conn.SetWriteDeadline(t)
}

func TestConnExecPreparedStatementDescriptionNetworkUsage(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)

var counterConn *byteCounterConn
config.AfterNetConnect = func(ctx context.Context, config *pgconn.Config, conn net.Conn) (net.Conn, error) {
counterConn = &byteCounterConn{conn: conn}
return counterConn, nil
}

pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)
require.NotNil(t, counterConn)

if pgConn.ParameterStatus("crdb_version") != "" {
t.Skip("Server uses different number of bytes for same operations")
}

psd, err := pgConn.Prepare(ctx, "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
require.NoError(t, err)
require.NotNil(t, psd)
assert.Len(t, psd.ParamOIDs, 1)
assert.Len(t, psd.Fields, 9)

counterConn.bytesWritten = 0
counterConn.bytesRead = 0

result := pgConn.ExecPrepared(ctx,
psd.Name,
[][]byte{[]byte("1")},
nil,
[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode},
).Read()
require.NoError(t, result.Err)
withDescribeBytesWritten := counterConn.bytesWritten
withDescribeBytesRead := counterConn.bytesRead

counterConn.bytesWritten = 0
counterConn.bytesRead = 0

result = pgConn.ExecPreparedStatementDescription(
ctx,
psd,
[][]byte{[]byte("1")},
nil,
[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode, pgx.BinaryFormatCode},
).Read()
require.NoError(t, result.Err)
noDescribeBytesWritten := counterConn.bytesWritten
noDescribeBytesRead := counterConn.bytesRead

assert.Equal(t, 61, withDescribeBytesWritten)
assert.Equal(t, 54, noDescribeBytesWritten)
assert.Equal(t, 391, withDescribeBytesRead)
assert.Equal(t, 153, noDescribeBytesRead)

ensureConnValid(t, pgConn)
}

func TestConnExecBatch(t *testing.T) {
t.Parallel()

Expand Down