@@ -156,7 +156,7 @@ func TestWithCommandFilter(t *testing.T) {
156156		hook  :=  newTracingHook (
157157			"" ,
158158			WithTracerProvider (provider ),
159- 			WithCommandFilter (BasicCommandFilter ),
159+ 			WithCommandFilter (DefaultCommandFilter ),
160160		)
161161		ctx , span  :=  provider .Tracer ("redis-test" ).Start (context .TODO (), "redis-test" )
162162		cmd  :=  redis .NewCmd (ctx , "auth" , "test-password" )
@@ -181,7 +181,7 @@ func TestWithCommandFilter(t *testing.T) {
181181		hook  :=  newTracingHook (
182182			"" ,
183183			WithTracerProvider (provider ),
184- 			WithCommandFilter (BasicCommandFilter ),
184+ 			WithCommandFilter (DefaultCommandFilter ),
185185		)
186186		ctx , span  :=  provider .Tracer ("redis-test" ).Start (context .TODO (), "redis-test" )
187187		cmd  :=  redis .NewCmd (ctx , "hello" , 3 , "AUTH" , "test-user" , "test-password" )
@@ -206,7 +206,7 @@ func TestWithCommandFilter(t *testing.T) {
206206		hook  :=  newTracingHook (
207207			"" ,
208208			WithTracerProvider (provider ),
209- 			WithCommandFilter (BasicCommandFilter ),
209+ 			WithCommandFilter (DefaultCommandFilter ),
210210		)
211211		ctx , span  :=  provider .Tracer ("redis-test" ).Start (context .TODO (), "redis-test" )
212212		cmd  :=  redis .NewCmd (ctx , "hello" , 3 )
@@ -227,6 +227,120 @@ func TestWithCommandFilter(t *testing.T) {
227227	})
228228}
229229
230+ func  TestWithCommandsFilter (t  * testing.T ) {
231+ 	t .Run ("filter out ping and info commands" , func (t  * testing.T ) {
232+ 		provider  :=  sdktrace .NewTracerProvider ()
233+ 		hook  :=  newTracingHook (
234+ 			"" ,
235+ 			WithTracerProvider (provider ),
236+ 			WithCommandsFilter (func (cmds  []redis.Cmder ) bool  {
237+ 				for  _ , cmd  :=  range  cmds  {
238+ 					if  cmd .Name () ==  "ping"  ||  cmd .Name () ==  "info"  {
239+ 						return  true 
240+ 					}
241+ 				}
242+ 				return  false 
243+ 			}),
244+ 		)
245+ 
246+ 		ctx , span  :=  provider .Tracer ("redis-test" ).Start (context .TODO (), "redis-test" )
247+ 		cmds  :=  []redis.Cmder {
248+ 			redis .NewCmd (ctx , "ping" ),
249+ 			redis .NewCmd (ctx , "info" ),
250+ 		}
251+ 		defer  span .End ()
252+ 
253+ 		processPipelineHook  :=  hook .ProcessPipelineHook (func (ctx  context.Context , cmds  []redis.Cmder ) error  {
254+ 			innerSpan  :=  trace .SpanFromContext (ctx ).(sdktrace.ReadOnlySpan )
255+ 			if  innerSpan .Name () !=  "redis-test"  ||  innerSpan .Name () ==  "redis.pipeline ping\n info"  {
256+ 				t .Fatalf ("ping and info commands should not be traced" )
257+ 			}
258+ 			return  nil 
259+ 		})
260+ 		err  :=  processPipelineHook (ctx , cmds )
261+ 		if  err  !=  nil  {
262+ 			t .Fatal (err )
263+ 		}
264+ 	})
265+ 
266+ 	t .Run ("do not filter ping and info commands" , func (t  * testing.T ) {
267+ 		provider  :=  sdktrace .NewTracerProvider ()
268+ 		hook  :=  newTracingHook (
269+ 			"" ,
270+ 			WithTracerProvider (provider ),
271+ 			WithCommandsFilter (func (cmds  []redis.Cmder ) bool  {
272+ 				return  false  // never filter 
273+ 			}),
274+ 		)
275+ 		ctx , span  :=  provider .Tracer ("redis-test" ).Start (context .TODO (), "redis-test" )
276+ 		cmds  :=  []redis.Cmder {
277+ 			redis .NewCmd (ctx , "ping" ),
278+ 			redis .NewCmd (ctx , "info" ),
279+ 		}
280+ 		defer  span .End ()
281+ 		processPipelineHook  :=  hook .ProcessPipelineHook (func (ctx  context.Context , cmds  []redis.Cmder ) error  {
282+ 			innerSpan  :=  trace .SpanFromContext (ctx ).(sdktrace.ReadOnlySpan )
283+ 			if  innerSpan .Name () !=  "redis.pipeline ping info"  {
284+ 				t .Fatalf ("ping and info commands should be traced" )
285+ 			}
286+ 
287+ 			return  nil 
288+ 		})
289+ 
290+ 		err  :=  processPipelineHook (ctx , cmds )
291+ 		if  err  !=  nil  {
292+ 			t .Fatal (err )
293+ 		}
294+ 	})
295+ }
296+ 
297+ func  TestWithDialFilter (t  * testing.T ) {
298+ 	t .Run ("filter out dial" , func (t  * testing.T ) {
299+ 		provider  :=  sdktrace .NewTracerProvider ()
300+ 		hook  :=  newTracingHook (
301+ 			"" ,
302+ 			WithTracerProvider (provider ),
303+ 			WithDialFilter (true ),
304+ 		)
305+ 		ctx , span  :=  provider .Tracer ("redis-test" ).Start (context .TODO (), "redis-test" )
306+ 		defer  span .End ()
307+ 		dialHook  :=  hook .DialHook (func (ctx  context.Context , network , addr  string ) (conn  net.Conn , err  error ) {
308+ 			innerSpan  :=  trace .SpanFromContext (ctx ).(sdktrace.ReadOnlySpan )
309+ 			if  innerSpan .Name () ==  "redis.dial"  {
310+ 				t .Fatalf ("dial should not be traced" )
311+ 			}
312+ 			return  nil , nil 
313+ 		})
314+ 
315+ 		_ , err  :=  dialHook (ctx , "tcp" , "localhost:6379" )
316+ 		if  err  !=  nil  {
317+ 			t .Fatal (err )
318+ 		}
319+ 	})
320+ 
321+ 	t .Run ("do not filter dial" , func (t  * testing.T ) {
322+ 		provider  :=  sdktrace .NewTracerProvider ()
323+ 		hook  :=  newTracingHook (
324+ 			"" ,
325+ 			WithTracerProvider (provider ),
326+ 			WithDialFilter (false ),
327+ 		)
328+ 		ctx , span  :=  provider .Tracer ("redis-test" ).Start (context .TODO (), "redis-test" )
329+ 		defer  span .End ()
330+ 		dialHook  :=  hook .DialHook (func (ctx  context.Context , network , addr  string ) (conn  net.Conn , err  error ) {
331+ 			innerSpan  :=  trace .SpanFromContext (ctx ).(sdktrace.ReadOnlySpan )
332+ 			if  innerSpan .Name () !=  "redis.dial"  {
333+ 				t .Fatalf ("dial should be traced" )
334+ 			}
335+ 			return  nil , nil 
336+ 		})
337+ 		_ , err  :=  dialHook (ctx , "tcp" , "localhost:6379" )
338+ 		if  err  !=  nil  {
339+ 			t .Fatal (err )
340+ 		}
341+ 	})
342+ }
343+ 
230344func  TestTracingHook_DialHook (t  * testing.T ) {
231345	imsb  :=  tracetest .NewInMemoryExporter ()
232346	provider  :=  sdktrace .NewTracerProvider (sdktrace .WithSyncer (imsb ))
0 commit comments