Skip to content

Commit a221e9d

Browse files
committed
fix: fix context propagation in distributed tracing (#402)
1 parent ddb2e15 commit a221e9d

File tree

7 files changed

+38
-140
lines changed

7 files changed

+38
-140
lines changed

cmd/optimizely/main.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ import (
3434

3535
"github.com/optimizely/agent/config"
3636
"github.com/optimizely/agent/pkg/metrics"
37-
"github.com/optimizely/agent/pkg/middleware"
3837
"github.com/optimizely/agent/pkg/optimizely"
3938
"github.com/optimizely/agent/pkg/routers"
4039
"github.com/optimizely/agent/pkg/server"
@@ -50,6 +49,7 @@ import (
5049
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
5150
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
5251
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
52+
"go.opentelemetry.io/otel/propagation"
5353
"go.opentelemetry.io/otel/sdk/resource"
5454
sdktrace "go.opentelemetry.io/otel/sdk/trace"
5555
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
@@ -151,7 +151,6 @@ func getStdOutTraceProvider(conf config.OTELTracingConfig) (*sdktrace.TracerProv
151151
return sdktrace.NewTracerProvider(
152152
sdktrace.WithBatcher(exp),
153153
sdktrace.WithResource(res),
154-
sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator(conf.TraceIDHeaderKey)),
155154
), nil
156155
}
157156

@@ -199,7 +198,6 @@ func getRemoteTraceProvider(conf config.OTELTracingConfig) (*sdktrace.TracerProv
199198
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(conf.Services.Remote.SampleRate))),
200199
sdktrace.WithResource(res),
201200
sdktrace.WithSpanProcessor(bsp),
202-
sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator(conf.TraceIDHeaderKey)),
203201
), nil
204202
}
205203

@@ -246,6 +244,7 @@ func main() {
246244
}
247245
}()
248246
otel.SetTracerProvider(tp)
247+
otel.SetTextMapPropagator(propagation.TraceContext{})
249248
log.Info().Msg(fmt.Sprintf("Tracing enabled with service %q", conf.Tracing.OpenTelemetry.Default))
250249
} else {
251250
log.Info().Msg("Tracing disabled")
@@ -275,7 +274,7 @@ func main() {
275274
cancel()
276275
}()
277276

278-
apiRouter := routers.NewDefaultAPIRouter(optlyCache, *conf, agentMetricsRegistry)
277+
apiRouter := routers.NewDefaultAPIRouter(optlyCache, conf.API, agentMetricsRegistry)
279278
adminRouter := routers.NewAdminRouter(*conf)
280279

281280
log.Info().Str("version", conf.Version).Msg("Starting services.")

config.yaml

+3-4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ log:
2929
##
3030
## tracing: tracing configuration
3131
##
32+
## For distributed tracing, trace context should be sent on "traceparent" header
33+
## The value set in HTTP Header must be a hex compliant with the W3C trace-context specification.
34+
## See more at https://www.w3.org/TR/trace-context/#trace-id
3235
tracing:
3336
## bydefault tracing is disabled
3437
## to enable tracing set enabled to true
@@ -43,10 +46,6 @@ tracing:
4346
## tracing environment name
4447
## example: for production environment env can be set as "prod"
4548
env: "dev"
46-
## HTTP Header Key for TraceID in Distributed Tracing
47-
## The value set in HTTP Header must be a hex compliant with the W3C trace-context specification.
48-
## See more at https://www.w3.org/TR/trace-context/#trace-id
49-
traceIDHeaderKey: "X-Optimizely-Trace-ID"
5049
## tracing service configuration
5150
services:
5251
## stdout exporter configuration

config/config.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,10 @@ const (
197197
)
198198

199199
type OTELTracingConfig struct {
200-
Default TracingServiceType `json:"default"`
201-
ServiceName string `json:"serviceName"`
202-
Env string `json:"env"`
203-
TraceIDHeaderKey string `json:"traceIDHeaderKey"`
204-
Services TracingServiceConfig `json:"services"`
200+
Default TracingServiceType `json:"default"`
201+
ServiceName string `json:"serviceName"`
202+
Env string `json:"env"`
203+
Services TracingServiceConfig `json:"services"`
205204
}
206205

207206
type TracingServiceConfig struct {

pkg/middleware/trace.go

+5-58
Original file line numberDiff line numberDiff line change
@@ -18,75 +18,22 @@
1818
package middleware
1919

2020
import (
21-
"context"
22-
crand "crypto/rand"
23-
"encoding/binary"
24-
"math/rand"
2521
"net/http"
26-
"sync"
2722

2823
"github.com/go-chi/chi/v5/middleware"
29-
"github.com/optimizely/agent/config"
30-
"github.com/rs/zerolog/log"
3124
"go.opentelemetry.io/otel"
3225
"go.opentelemetry.io/otel/attribute"
26+
"go.opentelemetry.io/otel/propagation"
3327
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
34-
"go.opentelemetry.io/otel/trace"
3528
)
3629

37-
type traceIDGenerator struct {
38-
sync.Mutex
39-
randSource *rand.Rand
40-
traceIDHeaderKey string
41-
}
42-
43-
func NewTraceIDGenerator(traceIDHeaderKey string) *traceIDGenerator {
44-
var rngSeed int64
45-
_ = binary.Read(crand.Reader, binary.LittleEndian, &rngSeed)
46-
return &traceIDGenerator{
47-
randSource: rand.New(rand.NewSource(rngSeed)),
48-
traceIDHeaderKey: traceIDHeaderKey,
49-
}
50-
}
51-
52-
func (gen *traceIDGenerator) NewSpanID(ctx context.Context, traceID trace.TraceID) trace.SpanID {
53-
gen.Lock()
54-
defer gen.Unlock()
55-
sid := trace.SpanID{}
56-
_, _ = gen.randSource.Read(sid[:])
57-
return sid
58-
}
59-
60-
func (gen *traceIDGenerator) NewIDs(ctx context.Context) (trace.TraceID, trace.SpanID) {
61-
gen.Lock()
62-
defer gen.Unlock()
63-
tid := trace.TraceID{}
64-
_, _ = gen.randSource.Read(tid[:])
65-
sid := trace.SpanID{}
66-
_, _ = gen.randSource.Read(sid[:])
67-
68-
// read trace id from header if provided
69-
traceIDHeader := ctx.Value(gen.traceIDHeaderKey)
70-
if val, ok := traceIDHeader.(string); ok {
71-
if val != "" {
72-
headerTraceId, err := trace.TraceIDFromHex(val)
73-
if err == nil {
74-
tid = headerTraceId
75-
} else {
76-
log.Error().Err(err).Msg("failed to parse trace id from header, invalid trace id")
77-
}
78-
}
79-
}
80-
81-
return tid, sid
82-
}
83-
84-
func AddTracing(conf config.TracingConfig, tracerName, spanName string) func(http.Handler) http.Handler {
30+
func AddTracing(tracerName, spanName string) func(http.Handler) http.Handler {
8531
return func(next http.Handler) http.Handler {
8632
fn := func(w http.ResponseWriter, r *http.Request) {
87-
pctx := context.WithValue(r.Context(), conf.OpenTelemetry.TraceIDHeaderKey, r.Header.Get(conf.OpenTelemetry.TraceIDHeaderKey))
33+
prop := otel.GetTextMapPropagator()
34+
propCtx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
8835

89-
ctx, span := otel.Tracer(tracerName).Start(pctx, spanName)
36+
ctx, span := otel.Tracer(tracerName).Start(propCtx, spanName)
9037
defer span.End()
9138

9239
span.SetAttributes(

pkg/middleware/trace_test.go

+1-46
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,9 @@
1818
package middleware
1919

2020
import (
21-
"context"
2221
"net/http"
2322
"net/http/httptest"
2423
"testing"
25-
26-
"github.com/optimizely/agent/config"
27-
"github.com/stretchr/testify/assert"
28-
"go.opentelemetry.io/otel/trace"
2924
)
3025

3126
func TestAddTracing(t *testing.T) {
@@ -37,7 +32,7 @@ func TestAddTracing(t *testing.T) {
3732

3833
req := httptest.NewRequest("GET", "/", nil)
3934
rr := httptest.NewRecorder()
40-
middleware := http.Handler(AddTracing(config.TracingConfig{}, "test-tracer", "test-span")(handler))
35+
middleware := http.Handler(AddTracing("test-tracer", "test-span")(handler))
4136

4237
// Serve the request through the middleware
4338
middleware.ServeHTTP(rr, req)
@@ -54,43 +49,3 @@ func TestAddTracing(t *testing.T) {
5449
t.Errorf("Expected Content-Type header %v, but got %v", "application/text", typeHeader)
5550
}
5651
}
57-
58-
func TestNewIDs(t *testing.T) {
59-
gen := NewTraceIDGenerator("")
60-
n := 1000
61-
62-
for i := 0; i < n; i++ {
63-
traceID, spanID := gen.NewIDs(context.Background())
64-
assert.Truef(t, traceID.IsValid(), "trace id: %s", traceID.String())
65-
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
66-
}
67-
}
68-
69-
func TestNewSpanID(t *testing.T) {
70-
gen := NewTraceIDGenerator("")
71-
testTraceID := [16]byte{123, 123}
72-
n := 1000
73-
74-
for i := 0; i < n; i++ {
75-
spanID := gen.NewSpanID(context.Background(), testTraceID)
76-
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
77-
}
78-
}
79-
80-
func TestNewSpanIDWithInvalidTraceID(t *testing.T) {
81-
gen := NewTraceIDGenerator("")
82-
spanID := gen.NewSpanID(context.Background(), trace.TraceID{})
83-
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
84-
}
85-
86-
func TestTraceIDWithGivenHeaderValue(t *testing.T) {
87-
traceHeader := "X-Trace-ID"
88-
traceID := "9b8eac67e332c6f8baf1e013de6891bb"
89-
90-
gen := NewTraceIDGenerator(traceHeader)
91-
92-
ctx := context.WithValue(context.Background(), traceHeader, traceID)
93-
genTraceID, _ := gen.NewIDs(ctx)
94-
assert.Truef(t, genTraceID.IsValid(), "trace id: %s", genTraceID.String())
95-
assert.Equal(t, traceID, genTraceID.String())
96-
}

pkg/routers/api.go

+16-17
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ func forbiddenHandler(message string) http.HandlerFunc {
6363
}
6464

6565
// NewDefaultAPIRouter creates a new router with the default backing optimizely.Cache
66-
func NewDefaultAPIRouter(optlyCache optimizely.Cache, agentConf config.AgentConfig, metricsRegistry *metrics.Registry) http.Handler {
67-
conf := agentConf.API
66+
func NewDefaultAPIRouter(optlyCache optimizely.Cache, conf config.APIConfig, metricsRegistry *metrics.Registry) http.Handler {
6867
authProvider := middleware.NewAuth(&conf.Auth)
6968
if authProvider == nil {
7069
log.Error().Msg("unable to initialize api auth middleware.")
@@ -109,19 +108,19 @@ func NewDefaultAPIRouter(optlyCache optimizely.Cache, agentConf config.AgentConf
109108
corsHandler: corsHandler,
110109
}
111110

112-
return NewAPIRouter(spec, agentConf.Tracing)
111+
return NewAPIRouter(spec)
113112
}
114113

115114
// NewAPIRouter returns HTTP API router backed by an optimizely.Cache implementation
116-
func NewAPIRouter(opt *APIOptions, traceConf config.TracingConfig) *chi.Mux {
115+
func NewAPIRouter(opt *APIOptions) *chi.Mux {
117116
r := chi.NewRouter()
118-
WithAPIRouter(opt, r, traceConf)
117+
WithAPIRouter(opt, r)
119118
return r
120119
}
121120

122121
// WithAPIRouter appends routes and middleware to the given router.
123122
// See https://godoc.org/github.com/go-chi/chi/v5#Mux.Group for usage
124-
func WithAPIRouter(opt *APIOptions, r chi.Router, traceConf config.TracingConfig) {
123+
func WithAPIRouter(opt *APIOptions, r chi.Router) {
125124
getConfigTimer := middleware.Metricize("get-config", opt.metricsRegistry)
126125
getDatafileTimer := middleware.Metricize("get-datafile", opt.metricsRegistry)
127126
activateTimer := middleware.Metricize("activate", opt.metricsRegistry)
@@ -134,17 +133,17 @@ func WithAPIRouter(opt *APIOptions, r chi.Router, traceConf config.TracingConfig
134133
createAccesstokenTimer := middleware.Metricize("create-api-access-token", opt.metricsRegistry)
135134
contentTypeMiddleware := chimw.AllowContentType("application/json")
136135

137-
configTracer := middleware.AddTracing(traceConf, "configHandler", "OptimizelyConfig")
138-
datafileTracer := middleware.AddTracing(traceConf, "datafileHandler", "OptimizelyDatafile")
139-
activateTracer := middleware.AddTracing(traceConf, "activateHandler", "Activate")
140-
decideTracer := middleware.AddTracing(traceConf, "decideHandler", "Decide")
141-
trackTracer := middleware.AddTracing(traceConf, "trackHandler", "Track")
142-
overrideTracer := middleware.AddTracing(traceConf, "overrideHandler", "Override")
143-
lookupTracer := middleware.AddTracing(traceConf, "lookupHandler", "Lookup")
144-
saveTracer := middleware.AddTracing(traceConf, "saveHandler", "Save")
145-
sendOdpEventTracer := middleware.AddTracing(traceConf, "sendOdpEventHandler", "SendOdpEvent")
146-
nStreamTracer := middleware.AddTracing(traceConf, "notificationHandler", "SendNotificationEvent")
147-
authTracer := middleware.AddTracing(traceConf, "authHandler", "AuthToken")
136+
configTracer := middleware.AddTracing("configHandler", "OptimizelyConfig")
137+
datafileTracer := middleware.AddTracing("datafileHandler", "OptimizelyDatafile")
138+
activateTracer := middleware.AddTracing("activateHandler", "Activate")
139+
decideTracer := middleware.AddTracing("decideHandler", "Decide")
140+
trackTracer := middleware.AddTracing("trackHandler", "Track")
141+
overrideTracer := middleware.AddTracing("overrideHandler", "Override")
142+
lookupTracer := middleware.AddTracing("lookupHandler", "Lookup")
143+
saveTracer := middleware.AddTracing("saveHandler", "Save")
144+
sendOdpEventTracer := middleware.AddTracing("sendOdpEventHandler", "SendOdpEvent")
145+
nStreamTracer := middleware.AddTracing("notificationHandler", "SendNotificationEvent")
146+
authTracer := middleware.AddTracing("authHandler", "AuthToken")
148147

149148
if opt.maxConns > 0 {
150149
// Note this is NOT a rate limiter, but a concurrency threshold

pkg/routers/api_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ func (suite *APIV1TestSuite) SetupTest() {
126126
corsHandler: testCorsHandler,
127127
}
128128

129-
suite.mux = NewAPIRouter(opts, config.TracingConfig{})
129+
suite.mux = NewAPIRouter(opts)
130130
}
131131

132132
func (suite *APIV1TestSuite) TestValidRoutes() {
@@ -138,7 +138,7 @@ func (suite *APIV1TestSuite) TestValidRoutes() {
138138
}
139139
return http.HandlerFunc(fn)
140140
}
141-
suite.mux = NewAPIRouter(opts, config.TracingConfig{})
141+
suite.mux = NewAPIRouter(opts)
142142

143143
routes := []struct {
144144
method string
@@ -328,7 +328,7 @@ func TestAPIV1TestSuite(t *testing.T) {
328328
}
329329

330330
func TestNewDefaultAPIV1Router(t *testing.T) {
331-
client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry)
331+
client := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry)
332332
assert.NotNil(t, client)
333333
}
334334

@@ -353,7 +353,7 @@ func TestNewDefaultAPIV1RouterInvalidHandlerConfig(t *testing.T) {
353353
EnableNotifications: false,
354354
EnableOverrides: false,
355355
}
356-
client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry)
356+
client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry)
357357
assert.Nil(t, client)
358358
}
359359

@@ -368,12 +368,12 @@ func TestNewDefaultClientRouterInvalidMiddlewareConfig(t *testing.T) {
368368
EnableNotifications: false,
369369
EnableOverrides: false,
370370
}
371-
client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry)
371+
client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry)
372372
assert.Nil(t, client)
373373
}
374374

375375
func TestForbiddenRoutes(t *testing.T) {
376-
mux := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry)
376+
mux := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry)
377377

378378
routes := []struct {
379379
method string

0 commit comments

Comments
 (0)