Skip to content

Commit

Permalink
optimize the grpc plugin method call. (apache#1470)
Browse files Browse the repository at this point in the history
* optimize the Grpc plugin method call,support generic invoke with unary.

* optimize the Grpc plugin method call,support custom ServerBuilder.

* optimize the Grpc plugin method call,rewrite the code by review.
  • Loading branch information
midnight2104 authored May 21, 2021
1 parent 4cf09da commit 55ecf11
Show file tree
Hide file tree
Showing 23 changed files with 828 additions and 543 deletions.
11 changes: 5 additions & 6 deletions shenyu-client/shenyu-client-grpc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
<artifactId>shenyu-client-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.shenyu</groupId>
<artifactId>shenyu-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-beans</artifactId>
Expand All @@ -52,11 +57,5 @@
<artifactId>spring-core</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-all</artifactId>
<version>${grpc.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

package org.apache.shenyu.client.grpc;

import com.google.common.collect.Lists;
import io.grpc.BindableService;
import io.grpc.ServerServiceDefinition;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.shenyu.client.core.disruptor.ShenyuClientRegisterEventPublisher;
import org.apache.shenyu.client.grpc.common.annotation.ShenyuGrpcClient;
import org.apache.shenyu.client.grpc.common.dto.GrpcExt;
import org.apache.shenyu.client.grpc.json.JsonServerServiceInterceptor;
import org.apache.shenyu.common.utils.GsonUtils;
import org.apache.shenyu.common.utils.IpUtils;
import org.apache.shenyu.register.client.api.ShenyuClientRegisterRepository;
Expand All @@ -36,6 +40,7 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.concurrent.ExecutorService;
Expand All @@ -61,7 +66,10 @@ public class GrpcClientBeanPostProcessor implements BeanPostProcessor {
private final String host;

private final int port;


@Getter
private List<ServerServiceDefinition> serviceDefinitions = Lists.newArrayList();

/**
* Instantiates a new Shenyu client bean post processor.
*
Expand All @@ -74,7 +82,7 @@ public GrpcClientBeanPostProcessor(final ShenyuRegisterCenterConfig config, fina
String ipAndPort = props.getProperty("ipAndPort");
String port = props.getProperty("port");
if (StringUtils.isEmpty(contextPath) || StringUtils.isEmpty(ipAndPort) || StringUtils.isEmpty(port)) {
throw new RuntimeException("tars client must config the contextPath, ipAndPort");
throw new RuntimeException("grpc client must config the contextPath, ipAndPort");
}
this.ipAndPort = ipAndPort;
this.contextPath = contextPath;
Expand All @@ -87,6 +95,7 @@ public GrpcClientBeanPostProcessor(final ShenyuRegisterCenterConfig config, fina
@Override
public Object postProcessAfterInitialization(@NonNull final Object bean, @NonNull final String beanName) throws BeansException {
if (bean instanceof BindableService) {
exportJsonGenericService(bean);
executorService.execute(() -> handler(bean));
}
return bean;
Expand Down Expand Up @@ -157,6 +166,17 @@ private String buildRpcExt(final ShenyuGrpcClient shenyuGrpcClient) {
GrpcExt build = GrpcExt.builder().timeout(shenyuGrpcClient.timeout()).build();
return GsonUtils.getInstance().toJson(build);
}
}

private void exportJsonGenericService(final Object bean) {
BindableService bindableService = (BindableService) bean;
ServerServiceDefinition serviceDefinition = bindableService.bindService();

try {
ServerServiceDefinition jsonDefinition = JsonServerServiceInterceptor.useJsonMessages(serviceDefinition);
serviceDefinitions.add(serviceDefinition);
serviceDefinitions.add(jsonDefinition);
} catch (Exception e) {
log.error("export json generic service is fail", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shenyu.client.grpc.json;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.util.JsonFormat;
import io.grpc.Attributes;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.ServerCall;
import io.grpc.Metadata;

import lombok.extern.slf4j.Slf4j;
import org.apache.shenyu.common.message.JsonResponse;

/**
* Handle response of json generic service.
*
* @param <R> request message
* @param <P> response message
*/
@Slf4j
public class JsonForwardingServerCall<R, P> extends ServerCall<R, P> {

private final ServerCall<P, P> call;

public JsonForwardingServerCall(final ServerCall<P, P> call) {
this.call = call;
}

protected ServerCall<P, P> delegate() {
return call;
}

@SuppressWarnings("unchecked")
@Override
public void sendMessage(final P message) {
String jsonFormat;
JsonResponse rep;
try {
if (message == null) {
delegate().sendMessage(null);
return;
}

jsonFormat = JsonFormat.printer().includingDefaultValueFields().preservingProtoFieldNames()
.print((MessageOrBuilder) message);

rep = JsonResponse.newBuilder().setMessage(jsonFormat).build();
log.debug("begin send json response: {}", jsonFormat);
delegate().sendMessage((P) rep);
} catch (InvalidProtocolBufferException e) {
log.error("handle json message is error", e);
throw Status.INTERNAL.withDescription(e.getMessage()).asRuntimeException();
}
}

@Override
public void request(final int numMessages) {
delegate().request(numMessages);
}

@Override
public void sendHeaders(final Metadata headers) {
delegate().sendHeaders(headers);
}

@Override
public boolean isReady() {
return delegate().isReady();
}

@Override
public void close(final Status status, final Metadata trailers) {
delegate().close(status, trailers);
}

@Override
public boolean isCancelled() {
return delegate().isCancelled();
}

@Override
public void setMessageCompression(final boolean enabled) {
delegate().setMessageCompression(enabled);
}

@Override
public void setCompression(final String compressor) {
delegate().setCompression(compressor);
}

@Override
public Attributes getAttributes() {
return delegate().getAttributes();
}

@SuppressWarnings("unchecked")
@Override
public MethodDescriptor<R, P> getMethodDescriptor() {
return (MethodDescriptor<R, P>) delegate().getMethodDescriptor();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shenyu.client.grpc.json;

import com.google.common.collect.Maps;
import io.grpc.MethodDescriptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCallHandler;
import io.grpc.ServiceDescriptor;
import lombok.extern.slf4j.Slf4j;
import org.apache.shenyu.common.constant.GrpcConstants;
import org.apache.shenyu.common.exception.ShenyuException;
import org.apache.shenyu.common.message.JsonRequest;
import org.apache.shenyu.common.utils.GrpcUtils;
import org.apache.shenyu.common.utils.ReflectUtils;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
* Support json generic invoke.
*/
@Slf4j
public class JsonServerServiceInterceptor {

private static Map<String, Class<?>> requestClazzMap = Maps.newConcurrentMap();

/**
* wrap ServerServiceDefinition to get json ServerServiceDefinition.
* @param serviceDef ServerServiceDefinition
* @return json ServerServiceDefinition
* @throws IllegalArgumentException IllegalArgumentException
* @throws IllegalAccessException IllegalAccessException
*/
public static ServerServiceDefinition useJsonMessages(final ServerServiceDefinition serviceDef)
throws IllegalArgumentException, IllegalAccessException {
return useMarshalledMessages(serviceDef,
io.grpc.protobuf.ProtoUtils.marshaller(GrpcUtils.createDefaultInstance(JsonRequest.class)));
}

/**
* wrap method.
* @param serviceDef ServerServiceDefinition
* @param marshaller message
* @param <T> message type
* @return wrap ServerServiceDefinition
* @throws IllegalArgumentException IllegalArgumentException
* @throws IllegalAccessException IllegalAccessException
*/
public static <T> ServerServiceDefinition useMarshalledMessages(final ServerServiceDefinition serviceDef,
final MethodDescriptor.Marshaller<T> marshaller)
throws IllegalArgumentException, IllegalAccessException {
List<ServerMethodDefinition<?, ?>> wrappedMethods = new ArrayList<>();
List<MethodDescriptor<?, ?>> wrappedDescriptors = new ArrayList<>();

// Wrap the descriptors
for (final ServerMethodDefinition<?, ?> definition : serviceDef.getMethods()) {
MethodDescriptor.Marshaller<?> requestMarshaller = definition.getMethodDescriptor().getRequestMarshaller();
Field defaultInstanceField = ReflectUtils.getField(requestMarshaller.getClass(), "defaultInstance");
if (Objects.isNull(defaultInstanceField)) {
throw new ShenyuException(String.format("can not get defaultInstance Field of %s", requestMarshaller.getClass()));
}

defaultInstanceField.setAccessible(true);

String fullMethodName = definition.getMethodDescriptor().getFullMethodName();
String[] splitMethodName = fullMethodName.split("/");
fullMethodName = splitMethodName[0] + GrpcConstants.GRPC_JSON_GENERIC_SERVICE + "/" + splitMethodName[1];
requestClazzMap.put(fullMethodName, defaultInstanceField.get(requestMarshaller).getClass());

final MethodDescriptor<?, ?> originalMethodDescriptor = definition.getMethodDescriptor();
final MethodDescriptor<T, T> wrappedMethodDescriptor = originalMethodDescriptor
.toBuilder(marshaller, marshaller).build();
wrappedDescriptors.add(wrappedMethodDescriptor);
wrappedMethods.add(wrapMethod(definition, wrappedMethodDescriptor));
}

// Build the new service descriptor
ServiceDescriptor.Builder build = ServiceDescriptor.newBuilder(serviceDef.getServiceDescriptor().getName() + GrpcConstants.GRPC_JSON_GENERIC_SERVICE);
for (MethodDescriptor<?, ?> md : wrappedDescriptors) {
Field fullMethodNameField = ReflectUtils.getField(md.getClass(), "fullMethodName");
if (Objects.isNull(fullMethodNameField)) {
throw new ShenyuException(String.format("can not get fullMethodName Field of %s", md.getClass()));
}
fullMethodNameField.setAccessible(true);
String fullMethodName = (String) fullMethodNameField.get(md);
String[] splitMethodName = fullMethodName.split("/");
fullMethodName = splitMethodName[0] + GrpcConstants.GRPC_JSON_GENERIC_SERVICE + "/" + splitMethodName[1];
fullMethodNameField.set(md, fullMethodName);

Field serviceNameField = ReflectUtils.getField(md.getClass(), "serviceName");
if (Objects.isNull(serviceNameField)) {
throw new ShenyuException(String.format("can not get serviceName Field Field of %s", md.getClass()));
}
serviceNameField.setAccessible(true);
String serviceName = (String) serviceNameField.get(md);
serviceName = serviceName + GrpcConstants.GRPC_JSON_GENERIC_SERVICE;
serviceNameField.set(md, serviceName);

build.addMethod(md);
}
final ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition
.builder(build.build());

// Create the new service definition
for (ServerMethodDefinition<?, ?> definition : wrappedMethods) {
serviceBuilder.addMethod(definition);
}
return serviceBuilder.build();
}

/**
* wrap Method.
* @param definition ServerMethodDefinition
* @param wrappedMethod MethodDescriptor
* @param <R> origin request message
* @param <P> origin response message
* @param <W> wrap request message
* @param <M> wrap response message
* @return wrap method
*/
private static <R, P, W, M> ServerMethodDefinition<W, M> wrapMethod(
final ServerMethodDefinition<R, P> definition,
final MethodDescriptor<W, M> wrappedMethod) {
final ServerCallHandler<W, M> wrappedHandler = wrapHandler(definition.getServerCallHandler()
);
return ServerMethodDefinition.create(wrappedMethod, wrappedHandler);
}

/**
* wrap handler.
* @param originalHandler original handler
* @param <R> origin request message
* @param <P> origin response message
* @param <W> wrap request message
* @param <M> wrap response message
* @return wrap handler
*/
@SuppressWarnings("unchecked")
private static <R, P, W, M> ServerCallHandler<W, M> wrapHandler(
final ServerCallHandler<R, P> originalHandler) {
return new ServerCallHandler<W, M>() {
@SuppressWarnings("rawtypes")
@Override
public ServerCall.Listener<W> startCall(final ServerCall<W, M> call, final Metadata headers) {
final ServerCall<R, P> unwrappedCall = new JsonForwardingServerCall<>((ServerCall<P, P>) call);
final ServerCall.Listener<R> originalListener = originalHandler.startCall(unwrappedCall, headers);
return new ServerJsonListener(originalListener, unwrappedCall);
}
};
}

/**
* get RequestClazzMap.
* @return requestClazzMap
*/
public static Map<String, Class<?>> getRequestClazzMap() {
return requestClazzMap;
}
}
Loading

0 comments on commit 55ecf11

Please sign in to comment.