@@ -21,6 +21,7 @@ typedef void (*table_udf_delete_callback_t)(void *);
2121import "C"
2222
2323import (
24+ "context"
2425 "database/sql"
2526 "runtime"
2627 "runtime/cgo"
@@ -81,32 +82,58 @@ type (
8182 // A ParallelChunkTableFunction is a type which can be bound to return a ParallelChunkTableSource.
8283 ParallelChunkTableFunction = tableFunction [ParallelChunkTableSource ]
8384
84- tableFunction [T any ] struct {
85+ tableFunction [T tableSource ] struct {
8586 // Config returns the table function configuration, including the function arguments.
8687 Config TableFunctionConfig
8788 // BindArguments binds the arguments and returns a TableSource.
8889 BindArguments func (named map [string ]any , args ... any ) (T , error )
90+ // BindArgumentsContext binds the arguments with context and returns a TableSource.
91+ BindArgumentsContext func (ctx context.Context , named map [string ]any , args ... any ) (T , error )
8992 }
9093)
9194
9295func wrapRowTF (f RowTableFunction ) ParallelRowTableFunction {
93- return ParallelRowTableFunction {
96+ tf := ParallelRowTableFunction {
9497 Config : f .Config ,
95- BindArguments : func (named map [string ]any , args ... any ) (ParallelRowTableSource , error ) {
98+ }
99+
100+ if f .BindArguments != nil {
101+ tf .BindArguments = func (named map [string ]any , args ... any ) (ParallelRowTableSource , error ) {
96102 rts , err := f .BindArguments (named , args ... )
97103 return parallelRowTSWrapper {s : rts }, err
98- },
104+ }
105+ }
106+
107+ if f .BindArgumentsContext != nil {
108+ tf .BindArgumentsContext = func (ctx context.Context , named map [string ]any , args ... any ) (ParallelRowTableSource , error ) {
109+ rts , err := f .BindArgumentsContext (ctx , named , args ... )
110+ return parallelRowTSWrapper {s : rts }, err
111+ }
99112 }
113+
114+ return tf
100115}
101116
102117func wrapChunkTF (f ChunkTableFunction ) ParallelChunkTableFunction {
103- return ParallelChunkTableFunction {
118+ tf := ParallelChunkTableFunction {
104119 Config : f .Config ,
105- BindArguments : func (named map [string ]any , args ... any ) (ParallelChunkTableSource , error ) {
120+ }
121+
122+ if f .BindArguments != nil {
123+ tf .BindArguments = func (named map [string ]any , args ... any ) (ParallelChunkTableSource , error ) {
106124 rts , err := f .BindArguments (named , args ... )
107125 return parallelChunkTSWrapper {s : rts }, err
108- },
126+ }
127+ }
128+
129+ if f .BindArgumentsContext != nil {
130+ tf .BindArgumentsContext = func (ctx context.Context , named map [string ]any , args ... any ) (ParallelChunkTableSource , error ) {
131+ rts , err := f .BindArgumentsContext (ctx , named , args ... )
132+ return parallelChunkTSWrapper {s : rts }, err
133+ }
109134 }
135+
136+ return tf
110137}
111138
112139func isRowIdColumn (i mapping.IdxT ) bool {
@@ -139,7 +166,8 @@ func table_udf_bind_chunk(infoPtr unsafe.Pointer) {
139166func udfBindTyped [T tableSource ](infoPtr unsafe.Pointer ) {
140167 info := mapping.BindInfo {Ptr : infoPtr }
141168
142- f := getPinned [tableFunction [T ]](mapping .BindGetExtraInfo (info ))
169+ fc := getPinned [* tableFuncContext [T ]](mapping .BindGetExtraInfo (info ))
170+ f := fc .f
143171 config := f .Config
144172
145173 argCount := len (config .Arguments )
@@ -170,7 +198,24 @@ func udfBindTyped[T tableSource](infoPtr unsafe.Pointer) {
170198 }
171199 }
172200
173- instance , err := f .BindArguments (namedArgs , args ... )
201+ var clientCtx mapping.ClientContext
202+ mapping .TableFunctionGetClientContext (info , & clientCtx )
203+ defer mapping .DestroyClientContext (& clientCtx )
204+ connId := mapping .ClientContextGetConnectionId (clientCtx )
205+
206+ ctx := fc .ctxStore .load (uint64 (connId ))
207+
208+ var instance T
209+ var err error
210+ switch {
211+ case f .BindArgumentsContext != nil :
212+ instance , err = f .BindArgumentsContext (ctx , namedArgs , args ... )
213+ case f .BindArguments != nil :
214+ instance , err = f .BindArguments (namedArgs , args ... )
215+ default :
216+ // We should never reach here due to checks during registration.
217+ panic ("unreachable: no bind function defined" )
218+ }
174219 if err != nil {
175220 mapping .BindSetError (info , err .Error ())
176221 return
@@ -306,16 +351,28 @@ func RegisterTableUDF[TFT TableFunction](conn *sql.Conn, name string, f TFT) err
306351 }
307352}
308353
309- func registerParallelTableUDF [TFT parallelTableFunction ](conn * sql.Conn , name string , f TFT ) error {
354+ type tableFuncContext [T tableSource ] struct {
355+ f tableFunction [T ]
356+ ctxStore * contextStore
357+ }
358+
359+ func registerParallelTableUDF [T tableSource ](conn * sql.Conn , name string , f tableFunction [T ]) error {
310360 function := mapping .CreateTableFunction ()
311361 mapping .TableFunctionSetName (function , name )
312362
313363 var config TableFunctionConfig
314364
365+ // Get the context store for the connection.
366+ ctxStore , err := contextStoreFromConn (conn )
367+ if err != nil {
368+ mapping .DestroyTableFunction (& function )
369+ return err
370+ }
371+
315372 // Pin the table function f.
316- value := pinnedValue [TFT ]{
373+ value := pinnedValue [* tableFuncContext [ T ] ]{
317374 pinner : & runtime.Pinner {},
318- value : f ,
375+ value : & tableFuncContext [ T ]{ f : f , ctxStore : ctxStore } ,
319376 }
320377 h := cgo .NewHandle (value )
321378 value .pinner .Pin (& h )
@@ -342,7 +399,7 @@ func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name st
342399 mapping .TableFunctionSetBind (function , bindCallbackPtr )
343400
344401 config = tableFunc .Config
345- if tableFunc .BindArguments == nil {
402+ if tableFunc .BindArguments == nil && tableFunc . BindArgumentsContext == nil {
346403 return getError (errAPI , errTableUDFMissingBindArgs )
347404 }
348405
@@ -351,7 +408,7 @@ func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name st
351408 mapping .TableFunctionSetBind (function , bindCallbackPtr )
352409
353410 config = tableFunc .Config
354- if tableFunc .BindArguments == nil {
411+ if tableFunc .BindArguments == nil && tableFunc . BindArgumentsContext == nil {
355412 return getError (errAPI , errTableUDFMissingBindArgs )
356413 }
357414
@@ -380,7 +437,7 @@ func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name st
380437 }
381438
382439 // Register the function on the underlying driver connection exposed by c.Raw.
383- err : = conn .Raw (func (driverConn any ) error {
440+ err = conn .Raw (func (driverConn any ) error {
384441 c := driverConn .(* Conn )
385442 state := mapping .RegisterTableFunction (c .conn , function )
386443 mapping .DestroyTableFunction (& function )
0 commit comments