Skip to content

Commit cc449fe

Browse files
Merge pull request #56 from wmTJc9IK0Q/table-udf-context-2
Add context to table UDFs
2 parents 2ad3914 + b9bfa58 commit cc449fe

File tree

2 files changed

+144
-15
lines changed

2 files changed

+144
-15
lines changed

table_udf.go

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ typedef void (*table_udf_delete_callback_t)(void *);
2121
import "C"
2222

2323
import (
24+
"context"
2425
"database/sql"
2526
"runtime"
2627
"runtime/cgo"
@@ -81,32 +82,58 @@ type (
8182
// A ParallelChunkTableFunction is a type which can be bound to return a ParallelChunkTableSource.
8283
ParallelChunkTableFunction = tableFunction[ParallelChunkTableSource]
8384

84-
tableFunction[T any] struct {
85+
tableFunction[T tableSource] struct {
8586
// Config returns the table function configuration, including the function arguments.
8687
Config TableFunctionConfig
8788
// BindArguments binds the arguments and returns a TableSource.
8889
BindArguments func(named map[string]any, args ...any) (T, error)
90+
// BindArgumentsContext binds the arguments with context and returns a TableSource.
91+
BindArgumentsContext func(ctx context.Context, named map[string]any, args ...any) (T, error)
8992
}
9093
)
9194

9295
func wrapRowTF(f RowTableFunction) ParallelRowTableFunction {
93-
return ParallelRowTableFunction{
96+
tf := ParallelRowTableFunction{
9497
Config: f.Config,
95-
BindArguments: func(named map[string]any, args ...any) (ParallelRowTableSource, error) {
98+
}
99+
100+
if f.BindArguments != nil {
101+
tf.BindArguments = func(named map[string]any, args ...any) (ParallelRowTableSource, error) {
96102
rts, err := f.BindArguments(named, args...)
97103
return parallelRowTSWrapper{s: rts}, err
98-
},
104+
}
105+
}
106+
107+
if f.BindArgumentsContext != nil {
108+
tf.BindArgumentsContext = func(ctx context.Context, named map[string]any, args ...any) (ParallelRowTableSource, error) {
109+
rts, err := f.BindArgumentsContext(ctx, named, args...)
110+
return parallelRowTSWrapper{s: rts}, err
111+
}
99112
}
113+
114+
return tf
100115
}
101116

102117
func wrapChunkTF(f ChunkTableFunction) ParallelChunkTableFunction {
103-
return ParallelChunkTableFunction{
118+
tf := ParallelChunkTableFunction{
104119
Config: f.Config,
105-
BindArguments: func(named map[string]any, args ...any) (ParallelChunkTableSource, error) {
120+
}
121+
122+
if f.BindArguments != nil {
123+
tf.BindArguments = func(named map[string]any, args ...any) (ParallelChunkTableSource, error) {
106124
rts, err := f.BindArguments(named, args...)
107125
return parallelChunkTSWrapper{s: rts}, err
108-
},
126+
}
127+
}
128+
129+
if f.BindArgumentsContext != nil {
130+
tf.BindArgumentsContext = func(ctx context.Context, named map[string]any, args ...any) (ParallelChunkTableSource, error) {
131+
rts, err := f.BindArgumentsContext(ctx, named, args...)
132+
return parallelChunkTSWrapper{s: rts}, err
133+
}
109134
}
135+
136+
return tf
110137
}
111138

112139
func isRowIdColumn(i mapping.IdxT) bool {
@@ -139,7 +166,8 @@ func table_udf_bind_chunk(infoPtr unsafe.Pointer) {
139166
func udfBindTyped[T tableSource](infoPtr unsafe.Pointer) {
140167
info := mapping.BindInfo{Ptr: infoPtr}
141168

142-
f := getPinned[tableFunction[T]](mapping.BindGetExtraInfo(info))
169+
fc := getPinned[*tableFuncContext[T]](mapping.BindGetExtraInfo(info))
170+
f := fc.f
143171
config := f.Config
144172

145173
argCount := len(config.Arguments)
@@ -170,7 +198,24 @@ func udfBindTyped[T tableSource](infoPtr unsafe.Pointer) {
170198
}
171199
}
172200

173-
instance, err := f.BindArguments(namedArgs, args...)
201+
var clientCtx mapping.ClientContext
202+
mapping.TableFunctionGetClientContext(info, &clientCtx)
203+
defer mapping.DestroyClientContext(&clientCtx)
204+
connId := mapping.ClientContextGetConnectionId(clientCtx)
205+
206+
ctx := fc.ctxStore.load(uint64(connId))
207+
208+
var instance T
209+
var err error
210+
switch {
211+
case f.BindArgumentsContext != nil:
212+
instance, err = f.BindArgumentsContext(ctx, namedArgs, args...)
213+
case f.BindArguments != nil:
214+
instance, err = f.BindArguments(namedArgs, args...)
215+
default:
216+
// We should never reach here due to checks during registration.
217+
panic("unreachable: no bind function defined")
218+
}
174219
if err != nil {
175220
mapping.BindSetError(info, err.Error())
176221
return
@@ -306,16 +351,28 @@ func RegisterTableUDF[TFT TableFunction](conn *sql.Conn, name string, f TFT) err
306351
}
307352
}
308353

309-
func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name string, f TFT) error {
354+
type tableFuncContext[T tableSource] struct {
355+
f tableFunction[T]
356+
ctxStore *contextStore
357+
}
358+
359+
func registerParallelTableUDF[T tableSource](conn *sql.Conn, name string, f tableFunction[T]) error {
310360
function := mapping.CreateTableFunction()
311361
mapping.TableFunctionSetName(function, name)
312362

313363
var config TableFunctionConfig
314364

365+
// Get the context store for the connection.
366+
ctxStore, err := contextStoreFromConn(conn)
367+
if err != nil {
368+
mapping.DestroyTableFunction(&function)
369+
return err
370+
}
371+
315372
// Pin the table function f.
316-
value := pinnedValue[TFT]{
373+
value := pinnedValue[*tableFuncContext[T]]{
317374
pinner: &runtime.Pinner{},
318-
value: f,
375+
value: &tableFuncContext[T]{f: f, ctxStore: ctxStore},
319376
}
320377
h := cgo.NewHandle(value)
321378
value.pinner.Pin(&h)
@@ -342,7 +399,7 @@ func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name st
342399
mapping.TableFunctionSetBind(function, bindCallbackPtr)
343400

344401
config = tableFunc.Config
345-
if tableFunc.BindArguments == nil {
402+
if tableFunc.BindArguments == nil && tableFunc.BindArgumentsContext == nil {
346403
return getError(errAPI, errTableUDFMissingBindArgs)
347404
}
348405

@@ -351,7 +408,7 @@ func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name st
351408
mapping.TableFunctionSetBind(function, bindCallbackPtr)
352409

353410
config = tableFunc.Config
354-
if tableFunc.BindArguments == nil {
411+
if tableFunc.BindArguments == nil && tableFunc.BindArgumentsContext == nil {
355412
return getError(errAPI, errTableUDFMissingBindArgs)
356413
}
357414

@@ -380,7 +437,7 @@ func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name st
380437
}
381438

382439
// Register the function on the underlying driver connection exposed by c.Raw.
383-
err := conn.Raw(func(driverConn any) error {
440+
err = conn.Raw(func(driverConn any) error {
384441
c := driverConn.(*Conn)
385442
state := mapping.RegisterTableFunction(c.conn, function)
386443
mapping.DestroyTableFunction(&function)

table_udf_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ type (
9494
start int64
9595
end int64
9696
}
97+
98+
contextTableUDF struct {
99+
value uint64
100+
}
97101
)
98102

99103
var (
@@ -680,6 +684,50 @@ func (udf *constTableUDF[T]) Cardinality() *CardinalityInfo {
680684
return nil
681685
}
682686

687+
func (udf *contextTableUDF) GetFunction() RowTableFunction {
688+
return RowTableFunction{
689+
Config: TableFunctionConfig{
690+
// No arguments needed for this test, as context is passed directly.
691+
},
692+
BindArgumentsContext: bindContextTableUDF,
693+
}
694+
}
695+
696+
func bindContextTableUDF(ctx context.Context, namedArgs map[string]any, args ...any) (RowTableSource, error) {
697+
v, ok := ctx.Value(testCtxKey).(uint64)
698+
if !ok {
699+
return nil, fmt.Errorf("context does not contain the test value")
700+
}
701+
return &contextTableUDF{value: v}, nil
702+
}
703+
704+
func (udf *contextTableUDF) ColumnInfos() []ColumnInfo {
705+
return []ColumnInfo{{Name: "result", T: typeBigintTableUDF}}
706+
}
707+
708+
func (udf *contextTableUDF) Init() {}
709+
710+
func (udf *contextTableUDF) FillRow(row Row) (bool, error) {
711+
if udf.value == 0 {
712+
return false, nil
713+
}
714+
err := SetRowValue(row, 0, udf.value)
715+
udf.value = 0 // Only return once
716+
return true, err
717+
}
718+
719+
func (udf *contextTableUDF) GetValue(r, c int) any {
720+
return uint64(123) // Dummy value for GetValue, actual value comes from context
721+
}
722+
723+
func (udf *contextTableUDF) GetTypes() []any {
724+
return []any{uint64(0)}
725+
}
726+
727+
func (udf *contextTableUDF) Cardinality() *CardinalityInfo {
728+
return nil
729+
}
730+
683731
func (udf *chunkIncTableUDF) GetFunction() ChunkTableFunction {
684732
return ChunkTableFunction{
685733
Config: TableFunctionConfig{
@@ -948,3 +996,27 @@ func BenchmarkChunkTableUDF(b *testing.B) {
948996
closeRowsWrapper(b, res)
949997
}
950998
}
999+
1000+
func TestContextTableUDF(t *testing.T) {
1001+
db := openDbWrapper(t, `?access_mode=READ_WRITE`)
1002+
defer closeDbWrapper(t, db)
1003+
1004+
conn := openConnWrapper(t, db, context.Background())
1005+
defer closeConnWrapper(t, conn)
1006+
1007+
var udf contextTableUDF
1008+
err := RegisterTableUDF(conn, "context_table_udf", udf.GetFunction())
1009+
require.NoError(t, err)
1010+
1011+
ctx := context.WithValue(context.Background(), testCtxKey, uint64(999))
1012+
1013+
res, err := db.QueryContext(ctx, `SELECT * FROM context_table_udf()`)
1014+
require.NoError(t, err)
1015+
defer closeRowsWrapper(t, res)
1016+
1017+
var result uint64
1018+
require.True(t, res.Next())
1019+
require.NoError(t, res.Scan(&result))
1020+
require.Equal(t, uint64(999), result)
1021+
require.False(t, res.Next())
1022+
}

0 commit comments

Comments
 (0)