diff --git a/contrib/internal/telemetrytest/telemetry_test.go b/contrib/internal/telemetrytest/telemetry_test.go index 1c457a22b6..d82146d018 100644 --- a/contrib/internal/telemetrytest/telemetry_test.go +++ b/contrib/internal/telemetrytest/telemetry_test.go @@ -12,6 +12,8 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorilla/mux" "gopkg.in/DataDog/dd-trace-go.v1/internal/telemetry" "gopkg.in/DataDog/dd-trace-go.v1/internal/telemetry/telemetrytest" @@ -23,14 +25,12 @@ import ( // sends the correct data to the telemetry client. func TestIntegrationInfo(t *testing.T) { // mux.NewRouter() uses the net/http and gorilla/mux integration - client := new(telemetrytest.MockClient) - client.On("AppStart").Return() - client.On("MarkIntegrationAsLoaded", telemetry.Integration{Name: "net/http", Version: "", Error: ""}).Return() - client.On("MarkIntegrationAsLoaded", telemetry.Integration{Name: "gorilla/mux", Version: "", Error: ""}).Return() + client := new(telemetrytest.RecordClient) telemetry.StartApp(client) _ = mux.NewRouter() - client.AssertExpectations(t) + assert.Contains(t, client.Integrations, telemetry.Integration{Name: "net/http", Version: "", Error: ""}) + assert.Contains(t, client.Integrations, telemetry.Integration{Name: "gorilla/mux", Version: "", Error: ""}) } type contribPkg struct { diff --git a/ddtrace/opentelemetry/telemetry_test.go b/ddtrace/opentelemetry/telemetry_test.go index d3697a285e..05f6e62e20 100644 --- a/ddtrace/opentelemetry/telemetry_test.go +++ b/ddtrace/opentelemetry/telemetry_test.go @@ -74,9 +74,7 @@ func TestTelemetry(t *testing.T) { t.Setenv(k, v) } telemetryClient := new(telemetrytest.RecordClient) - original := telemetry.GlobalClient() - telemetry.SwapClient(telemetryClient) - defer telemetry.SwapClient(original) + defer telemetry.MockClient(telemetryClient)() p := NewTracerProvider() p.Tracer("") diff --git a/ddtrace/opentelemetry/tracer_test.go b/ddtrace/opentelemetry/tracer_test.go index c7736c4c9d..95982efc4f 100644 --- a/ddtrace/opentelemetry/tracer_test.go +++ b/ddtrace/opentelemetry/tracer_test.go @@ -196,9 +196,7 @@ func TestShutdownOnce(t *testing.T) { func TestSpanTelemetry(t *testing.T) { telemetryClient := new(telemetrytest.RecordClient) - original := telemetry.GlobalClient() - telemetry.SwapClient(telemetryClient) - defer telemetry.SwapClient(original) + defer telemetry.MockClient(telemetryClient)() tp := NewTracerProvider() otel.SetTracerProvider(tp) tr := otel.Tracer("") diff --git a/ddtrace/opentracer/tracer_test.go b/ddtrace/opentracer/tracer_test.go index e0e9c5f14e..4cb0ee4269 100644 --- a/ddtrace/opentracer/tracer_test.go +++ b/ddtrace/opentracer/tracer_test.go @@ -116,9 +116,7 @@ func TestExtractError(t *testing.T) { func TestSpanTelemetry(t *testing.T) { telemetryClient := new(telemetrytest.RecordClient) - original := telemetry.GlobalClient() - telemetry.SwapClient(telemetryClient) - defer telemetry.SwapClient(original) + defer telemetry.MockClient(telemetryClient)() opentracing.SetGlobalTracer(New()) _ = opentracing.StartSpan("opentracing.span") assert.NotZero(t, telemetryClient.Count(telemetry.NamespaceTracers, "spans_created", telemetryTags).Get()) diff --git a/ddtrace/tracer/otel_dd_mappings_test.go b/ddtrace/tracer/otel_dd_mappings_test.go index db5630d460..4427a02a3b 100644 --- a/ddtrace/tracer/otel_dd_mappings_test.go +++ b/ddtrace/tracer/otel_dd_mappings_test.go @@ -31,9 +31,7 @@ func TestAssessSource(t *testing.T) { }) t.Run("both", func(t *testing.T) { telemetryClient := new(telemetrytest.RecordClient) - original := telemetry.GlobalClient() - telemetry.SwapClient(telemetryClient) - defer telemetry.SwapClient(original) + defer telemetry.MockClient(telemetryClient)() // DD_SERVICE prevails t.Setenv("DD_SERVICE", "abc") t.Setenv("OTEL_SERVICE_NAME", "123") diff --git a/profiler/telemetry_test.go b/profiler/telemetry_test.go index f134a53b31..81a4fca36e 100644 --- a/profiler/telemetry_test.go +++ b/profiler/telemetry_test.go @@ -15,19 +15,11 @@ import ( "github.com/stretchr/testify/assert" ) -func mockGlobalClient(client telemetry.Client) func() { - orig := telemetry.GlobalClient() - telemetry.SwapClient(client) - return func() { - telemetry.SwapClient(orig) - } -} - // Test that the profiler sends the correct telemetry information func TestTelemetryEnabled(t *testing.T) { t.Run("tracer start, profiler start", func(t *testing.T) { telemetryClient := new(telemetrytest.RecordClient) - defer mockGlobalClient(telemetryClient)() + defer telemetry.MockClient(telemetryClient)() tracer.Start() defer tracer.Stop() @@ -44,7 +36,7 @@ func TestTelemetryEnabled(t *testing.T) { }) t.Run("only profiler start", func(t *testing.T) { telemetryClient := new(telemetrytest.RecordClient) - defer mockGlobalClient(telemetryClient)() + defer telemetry.MockClient(telemetryClient)() Start( WithProfileTypes( HeapProfile,