diff --git a/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java index 451d8868a..fdf5a2bab 100644 --- a/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java +++ b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/GrpcServiceStubs.java @@ -16,8 +16,6 @@ package com.uber.cadence.internal.compatibility.proto.serviceclient; import com.google.common.base.Strings; -import com.google.protobuf.ByteString; -import com.uber.cadence.api.v1.*; import com.uber.cadence.api.v1.DomainAPIGrpc; import com.uber.cadence.api.v1.MetaAPIGrpc; import com.uber.cadence.api.v1.MetaAPIGrpc.MetaAPIBlockingStub; @@ -32,7 +30,6 @@ import com.uber.cadence.api.v1.WorkflowAPIGrpc.WorkflowAPIBlockingStub; import com.uber.cadence.api.v1.WorkflowAPIGrpc.WorkflowAPIFutureStub; import com.uber.cadence.internal.Version; -import com.uber.cadence.internal.tracing.TracingPropagator; import com.uber.cadence.serviceclient.ClientOptions; import com.uber.cadence.serviceclient.auth.IAuthorizationProvider; import io.grpc.*; @@ -41,13 +38,9 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.propagation.TextMapPropagator; import io.opentelemetry.context.propagation.TextMapSetter; -import io.opentracing.Scope; -import io.opentracing.Span; import io.opentracing.Tracer; import java.nio.charset.StandardCharsets; -import java.util.HashMap; import java.util.Map; -import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.slf4j.Logger; @@ -116,6 +109,7 @@ final class GrpcServiceStubs implements IGrpcServiceStubs { if (!Strings.isNullOrEmpty(options.getIsolationGroup())) { headers.put(ISOLATION_GROUP_HEADER_KEY, options.getIsolationGroup()); } + mergeHeaders(headers, options.getHeaders()); Channel interceptedChannel = ClientInterceptors.intercept( @@ -205,117 +199,7 @@ public void start(Listener responseListener, Metadata headers) { } private ClientInterceptor newOpenTracingInterceptor(Tracer tracer) { - return new ClientInterceptor() { - private final TracingPropagator tracingPropagator = new TracingPropagator(tracer); - private final String OPERATIONFORMAT = "cadence-%s"; - - @Override - public ClientCall interceptCall( - MethodDescriptor method, CallOptions callOptions, Channel next) { - return new ForwardingClientCall.SimpleForwardingClientCall( - next.newCall(method, callOptions)) { - - @Override - public void start(Listener responseListener, Metadata headers) { - Span span = - tracingPropagator.spanByServiceMethod( - String.format(OPERATIONFORMAT, method.getBareMethodName())); - Scope scope = tracer.activateSpan(span); - super.start( - new ForwardingClientCallListener.SimpleForwardingClientCallListener( - responseListener) { - @Override - public void onClose(Status status, Metadata trailers) { - try { - super.onClose(status, trailers); - } finally { - span.finish(); - scope.close(); - } - } - }, - headers); - } - - @SuppressWarnings("unchecked") - @Override - public void sendMessage(ReqT message) { - if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecution") - && message instanceof StartWorkflowExecutionRequest) { - StartWorkflowExecutionRequest request = (StartWorkflowExecutionRequest) message; - Header newHeader = addTracingHeaders(request.getHeader()); - - // cast should not throw error as we are using the builder - message = (ReqT) request.toBuilder().setHeader(newHeader).build(); - } else if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecutionAsync") - && message instanceof StartWorkflowExecutionAsyncRequest) { - StartWorkflowExecutionAsyncRequest request = - (StartWorkflowExecutionAsyncRequest) message; - Header newHeader = addTracingHeaders(request.getRequest().getHeader()); - - // cast should not throw error as we are using the builder - message = - (ReqT) - request - .toBuilder() - .setRequest(request.getRequest().toBuilder().setHeader(newHeader)) - .build(); - } else if (Objects.equals( - method.getBareMethodName(), "SignalWithStartWorkflowExecution") - && message instanceof SignalWithStartWorkflowExecutionRequest) { - SignalWithStartWorkflowExecutionRequest request = - (SignalWithStartWorkflowExecutionRequest) message; - Header newHeader = addTracingHeaders(request.getStartRequest().getHeader()); - - // cast should not throw error as we are using the builder - message = - (ReqT) - request - .toBuilder() - .setStartRequest( - request.getStartRequest().toBuilder().setHeader(newHeader)) - .build(); - } else if (Objects.equals( - method.getBareMethodName(), "SignalWithStartWorkflowExecutionAsync") - && message instanceof SignalWithStartWorkflowExecutionAsyncRequest) { - SignalWithStartWorkflowExecutionAsyncRequest request = - (SignalWithStartWorkflowExecutionAsyncRequest) message; - Header newHeader = - addTracingHeaders(request.getRequest().getStartRequest().getHeader()); - - // cast should not throw error as we are using the builder - message = - (ReqT) - request - .toBuilder() - .setRequest( - request - .getRequest() - .toBuilder() - .setStartRequest( - request - .getRequest() - .getStartRequest() - .toBuilder() - .setHeader(newHeader))) - .build(); - } - super.sendMessage(message); - } - - private Header addTracingHeaders(Header header) { - Map headers = new HashMap<>(); - tracingPropagator.inject(headers); - Header.Builder headerBuilder = header.toBuilder(); - headers.forEach( - (k, v) -> - headerBuilder.putFields( - k, Payload.newBuilder().setData(ByteString.copyFrom(v)).build())); - return headerBuilder.build(); - } - }; - } - }; + return new OpenTracingInterceptor(tracer); } private ClientInterceptor newTracingInterceptor() { @@ -488,4 +372,22 @@ public ClientCall interceptCall( return next.newCall(method, callOptions.withDeadlineAfter(duration, TimeUnit.MILLISECONDS)); } } + + private static void mergeHeaders(Metadata metadata, Map headers) { + if (headers == null) { + return; + } + for (Map.Entry entry : headers.entrySet()) { + Metadata.Key key = Metadata.Key.of(entry.getKey(), Metadata.ASCII_STRING_MARSHALLER); + // Allow headers to overwrite any defaults + if (metadata.containsKey(key)) { + metadata.removeAll(key); + } + // Only replace it if they specify a value. + // This allows for removing headers + if (!Strings.isNullOrEmpty(entry.getValue())) { + metadata.put(key, entry.getValue()); + } + } + } } diff --git a/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/OpenTracingInterceptor.java b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/OpenTracingInterceptor.java new file mode 100644 index 000000000..553e49b03 --- /dev/null +++ b/src/main/java/com/uber/cadence/internal/compatibility/proto/serviceclient/OpenTracingInterceptor.java @@ -0,0 +1,231 @@ +/* + * Modifications Copyright (c) 2017-2021 Uber Technologies Inc. + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +package com.uber.cadence.internal.compatibility.proto.serviceclient; + +import com.google.protobuf.ByteString; +import com.uber.cadence.api.v1.Header; +import com.uber.cadence.api.v1.Payload; +import com.uber.cadence.api.v1.SignalWithStartWorkflowExecutionAsyncRequest; +import com.uber.cadence.api.v1.SignalWithStartWorkflowExecutionRequest; +import com.uber.cadence.api.v1.StartWorkflowExecutionAsyncRequest; +import com.uber.cadence.api.v1.StartWorkflowExecutionRequest; +import com.uber.cadence.internal.tracing.TracingPropagator; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.opentracing.Scope; +import io.opentracing.Span; +import io.opentracing.Tracer; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.Nullable; + +final class OpenTracingInterceptor implements ClientInterceptor { + private static final String OPERATION_FORMAT = "cadence-%s"; + private final Tracer tracer; + private final TracingPropagator tracingPropagator; + + OpenTracingInterceptor(Tracer tracer) { + this.tracer = tracer; + this.tracingPropagator = new TracingPropagator(tracer); + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + + Span span = + tracingPropagator.spanByServiceMethod( + String.format(OPERATION_FORMAT, method.getBareMethodName())); + try (Scope ignored = tracer.activateSpan(span)) { + return new OpenTracingClientCall<>(next, method, callOptions, span); + } + } + + private class OpenTracingClientCall + extends ForwardingClientCall.SimpleForwardingClientCall { + + private final AtomicBoolean finished = new AtomicBoolean(); + private final MethodDescriptor method; + private final Span span; + + public OpenTracingClientCall( + Channel next, MethodDescriptor method, CallOptions callOptions, Span span) { + super(next.newCall(method, callOptions)); + this.method = method; + this.span = span; + } + + @Override + public void start(Listener responseListener, Metadata headers) { + try (Scope ignored = tracer.activateSpan(span)) { + super.start( + new ForwardingClientCallListener.SimpleForwardingClientCallListener( + responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + try { + super.onClose(status, trailers); + } finally { + finishSpan(); + } + } + }, + headers); + } + } + + @Override + public void request(int numMessages) { + try (Scope ignored = tracer.activateSpan(span)) { + super.request(numMessages); + } + } + + @Override + public void setMessageCompression(boolean enabled) { + try (Scope ignored = tracer.activateSpan(span)) { + super.setMessageCompression(enabled); + } + } + + @Override + public boolean isReady() { + try (Scope ignored = tracer.activateSpan(span)) { + return super.isReady(); + } + } + + @Override + public Attributes getAttributes() { + try (Scope ignored = tracer.activateSpan(span)) { + return super.getAttributes(); + } + } + + @Override + public void cancel(@Nullable String message, @Nullable Throwable cause) { + try (Scope ignored = tracer.activateSpan(span)) { + super.cancel(message, cause); + } finally { + finishSpan(); + } + } + + @Override + public void halfClose() { + try (Scope ignored = tracer.activateSpan(span)) { + super.halfClose(); + } + } + + @Override + public void sendMessage(ReqT message) { + try (Scope ignored = tracer.activateSpan(span)) { + message = replaceMessage(message); + super.sendMessage(message); + } + } + + private void finishSpan() { + // Some combination of cancel and onClose can be called so ensure we only finish once + if (finished.compareAndSet(false, true)) { + span.finish(); + } + } + + @SuppressWarnings("unchecked") + private ReqT replaceMessage(ReqT message) { + if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecution") + && message instanceof StartWorkflowExecutionRequest) { + StartWorkflowExecutionRequest request = (StartWorkflowExecutionRequest) message; + Header newHeader = addTracingHeaders(request.getHeader()); + + // cast should not throw error as we are using the builder + message = (ReqT) request.toBuilder().setHeader(newHeader).build(); + } else if (Objects.equals(method.getBareMethodName(), "StartWorkflowExecutionAsync") + && message instanceof StartWorkflowExecutionAsyncRequest) { + StartWorkflowExecutionAsyncRequest request = (StartWorkflowExecutionAsyncRequest) message; + Header newHeader = addTracingHeaders(request.getRequest().getHeader()); + + // cast should not throw error as we are using the builder + message = + (ReqT) + request + .toBuilder() + .setRequest(request.getRequest().toBuilder().setHeader(newHeader)) + .build(); + } else if (Objects.equals(method.getBareMethodName(), "SignalWithStartWorkflowExecution") + && message instanceof SignalWithStartWorkflowExecutionRequest) { + SignalWithStartWorkflowExecutionRequest request = + (SignalWithStartWorkflowExecutionRequest) message; + Header newHeader = addTracingHeaders(request.getStartRequest().getHeader()); + + // cast should not throw error as we are using the builder + message = + (ReqT) + request + .toBuilder() + .setStartRequest(request.getStartRequest().toBuilder().setHeader(newHeader)) + .build(); + } else if (Objects.equals(method.getBareMethodName(), "SignalWithStartWorkflowExecutionAsync") + && message instanceof SignalWithStartWorkflowExecutionAsyncRequest) { + SignalWithStartWorkflowExecutionAsyncRequest request = + (SignalWithStartWorkflowExecutionAsyncRequest) message; + Header newHeader = addTracingHeaders(request.getRequest().getStartRequest().getHeader()); + + // cast should not throw error as we are using the builder + message = + (ReqT) + request + .toBuilder() + .setRequest( + request + .getRequest() + .toBuilder() + .setStartRequest( + request + .getRequest() + .getStartRequest() + .toBuilder() + .setHeader(newHeader))) + .build(); + } + + return message; + } + + private Header addTracingHeaders(Header header) { + Map headers = new HashMap<>(); + tracingPropagator.inject(headers); + Header.Builder headerBuilder = header.toBuilder(); + headers.forEach( + (k, v) -> + headerBuilder.putFields( + k, Payload.newBuilder().setData(ByteString.copyFrom(v)).build())); + return headerBuilder.build(); + } + }; +} diff --git a/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java b/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java index d28b69804..5098145f8 100644 --- a/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java +++ b/src/test/java/com/uber/cadence/internal/compatibility/Thrift2ProtoAdapterTest.java @@ -25,6 +25,7 @@ import ch.qos.logback.classic.Level; import ch.qos.logback.classic.Logger; +import com.google.common.collect.ImmutableMap; import com.uber.cadence.AccessDeniedError; import com.uber.cadence.RefreshWorkflowTasksRequest; import com.uber.cadence.SignalWithStartWorkflowExecutionAsyncRequest; @@ -84,6 +85,13 @@ public class Thrift2ProtoAdapterTest { private static final Metadata.Key AUTHORIZATION_HEADER_KEY = Metadata.Key.of("cadence-authorization", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key EXPECTED_HEADER_KEY = + Metadata.Key.of("rpc-service", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key REMOVED_HEADER_KEY = + Metadata.Key.of("rpc-caller", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key ADDED_HEADER_KEY = + Metadata.Key.of("from-options", Metadata.ASCII_STRING_MARSHALLER); + private static final String ADDED_HEADER_VALUE = "added-value"; private static final StatusRuntimeException GRPC_ACCESS_DENIED = new StatusRuntimeException(Status.PERMISSION_DENIED); @@ -107,11 +115,14 @@ public void setup() { (Logger) LoggerFactory.getLogger( "com.uber.cadence.internal.compatibility.proto.serviceclient.GrpcServiceStubs"); + Map headers = + ImmutableMap.of(REMOVED_HEADER_KEY.name(), "", ADDED_HEADER_KEY.name(), ADDED_HEADER_VALUE); logger.setLevel(Level.TRACE); client = new Thrift2ProtoAdapter( IGrpcServiceStubs.newInstance( ClientOptions.newBuilder() + .setHeaders(headers) .setAuthorizationProvider("foo"::getBytes) .setGRPCChannel(clientChannel) .build())); @@ -119,6 +130,7 @@ public void setup() { new Thrift2ProtoAdapter( IGrpcServiceStubs.newInstance( ClientOptions.newBuilder() + .setHeaders(headers) .setAuthorizationProvider("foo"::getBytes) .setTracer(tracer) .setGRPCChannel(clientChannel) @@ -1020,7 +1032,9 @@ public Server createServer(ServiceDescriptor... descriptors) { } serverBuilder.addService( ServerInterceptors.intercept( - serviceDefinition.build(), new AuthHeaderValidatingInterceptor())); + serviceDefinition.build(), + new AuthHeaderValidatingInterceptor(), + new HeaderValidatingInterceptor())); } return serverBuilder.build().start(); } catch (IOException e) { @@ -1058,7 +1072,41 @@ private static class AuthHeaderValidatingInterceptor implements ServerIntercepto public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { if (!headers.containsKey(AUTHORIZATION_HEADER_KEY)) { - call.close(Status.INVALID_ARGUMENT, new Metadata()); + call.close( + Status.INVALID_ARGUMENT.withDescription( + "Missing auth header: " + AUTHORIZATION_HEADER_KEY.name()), + new Metadata()); + } + return next.startCall(call, headers); + } + } + + private static class HeaderValidatingInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + if (!headers.containsKey(EXPECTED_HEADER_KEY)) { + call.close( + Status.INVALID_ARGUMENT.withDescription("Missing " + EXPECTED_HEADER_KEY.name()), + new Metadata()); + } + String addedHeaderValue = headers.get(ADDED_HEADER_KEY); + if (!ADDED_HEADER_VALUE.equals(addedHeaderValue)) { + call.close( + Status.INVALID_ARGUMENT.withDescription( + "Incorrect value for " + + ADDED_HEADER_KEY.name() + + "; got " + + addedHeaderValue + + " instead of " + + ADDED_HEADER_VALUE), + new Metadata()); + } + if (headers.containsKey(REMOVED_HEADER_KEY)) { + call.close( + Status.INVALID_ARGUMENT.withDescription( + "Unexpected header " + REMOVED_HEADER_KEY.name()), + new Metadata()); } return next.startCall(call, headers); }