From d533ea0db7e3ea891747e5b5fff2c52c718d6507 Mon Sep 17 00:00:00 2001 From: roblabla Date: Thu, 21 Jan 2021 19:53:40 +0100 Subject: [PATCH 1/4] Allow disabling generation of grpclib stubs --- src/betterproto/plugin/models.py | 7 ++++++- src/betterproto/plugin/parser.py | 15 +++++++++++++-- src/betterproto/templates/template.py.j2 | 8 +++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 840140043..7f25efa73 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -203,9 +203,14 @@ def comment(self) -> str: @dataclass -class PluginRequestCompiler: +class Options: + grpc_kind: str = "grpclib" + +@dataclass +class PluginRequestCompiler: plugin_request_obj: CodeGeneratorRequest + options: Options output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) @property diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 21a2caf14..527448420 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -21,6 +21,7 @@ FieldCompiler, MapEntryCompiler, MessageCompiler, + Options, OneOfFieldCompiler, OutputTemplate, PluginRequestCompiler, @@ -33,7 +34,6 @@ if TYPE_CHECKING: from google.protobuf.descriptor import Descriptor - def traverse( proto_file: FieldDescriptorProto, ) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": @@ -60,6 +60,12 @@ def _traverse( _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) ) +def parse_options(plugin_options: List[str]) -> Options: + options = Options() + for option in plugin_options: + if option.startswith("grpc="): + options.grpc_kind = option.split("=", 1)[1] + return options def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: response = CodeGeneratorResponse() @@ -67,7 +73,12 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: plugin_options = request.parameter.split(",") if request.parameter else [] response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL - request_data = PluginRequestCompiler(plugin_request_obj=request) + options = parse_options(plugin_options) + + request_data = PluginRequestCompiler( + plugin_request_obj=request, + options=options + ) # Gather output packages for proto_file in request.proto_file: if ( diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index d27cff610..b7e9dcd48 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -15,8 +15,10 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no {% endif %} import betterproto +{% if output_file.parent_request.options.grpc_kind == "grpclib" %} from betterproto.grpc.grpclib_server import ServiceBase -{% if output_file.services %} +{% endif %} +{% if output_file.services and output_file.parent_request.options.grpc_kind == "grpclib" %} import grpclib {% endif %} @@ -68,6 +70,9 @@ class {{ message.py_name }}(betterproto.Message): {% endfor %} + +{% if output_file.parent_request.options.grpc_kind == "grpclib" %} + {% for service in output_file.services %} class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} @@ -239,6 +244,7 @@ class {{ service.py_name }}Base(ServiceBase): } {% endfor %} +{% endif %} {% for i in output_file.imports|sort %} {{ i }} From a5b89ee00a1ad414d52e24fc56bbb7b61763532e Mon Sep 17 00:00:00 2001 From: roblabla Date: Thu, 21 Jan 2021 19:53:49 +0100 Subject: [PATCH 2/4] Add grpcio base generation support --- src/betterproto/plugin/models.py | 4 ++ src/betterproto/templates/template.py.j2 | 57 ++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 7f25efa73..a1e10ebfb 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -662,6 +662,10 @@ def __post_init__(self) -> None: def proto_name(self) -> str: return self.proto_obj.name + @property + def proto_path(self) -> str: + return self.parent.package + "." + self.proto_name + @property def py_name(self) -> str: return pythonize_class_name(self.proto_name) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index b7e9dcd48..359a94b9d 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -21,6 +21,9 @@ from betterproto.grpc.grpclib_server import ServiceBase {% if output_file.services and output_file.parent_request.options.grpc_kind == "grpclib" %} import grpclib {% endif %} +{% if output_file.services and output_file.parent_request.options.grpc_kind == "grpcio" %} +import grpc +{% endif %} {% if output_file.enums %}{% for enum in output_file.enums %} @@ -246,6 +249,60 @@ class {{ service.py_name }}Base(ServiceBase): {% endfor %} {% endif %} +{% if output_file.parent_request.options.grpc_kind == "grpcio" %} +{% for service in output_file.services %} +class {{ service.py_name }}Base: + {% if service.comment %} +{{ service.comment }} + + {% endif %} + + {% for method in service.methods %} + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + , request: "{{ method.py_input_message_type }}" + {%- else -%} + {# Client streaming: need a request iterator instead #} + , request_iterator: AsyncIterator["{{ method.py_input_message_type }}"] + {%- endif -%} + , context: grpc.aio.ServicerContext + ) -> {% if method.server_streaming %}AsyncGenerator["{{ method.py_output_message_type }}", None]{% else %}"{{ method.py_output_message_type }}"{% endif %}: + {% if method.comment %} +{{ method.comment }} + + {% endif %} + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + {% endfor %} + + def register_to_server(self, server: grpc.aio.Server): + rpc_method_handlers = { + {% for method in service.methods %} + "{{ method.proto_name }}": + {% if not method.client_streaming and not method.server_streaming %} + grpc.unary_unary_rpc_method_handler( + {% elif method.client_streaming and method.server_streaming %} + grpc.stream_stream_rpc_method_handler( + {% elif method.client_streaming %} + grpc.stream_unary_rpc_method_handler( + {% else %} + grpc.unary_stream_rpc_method_handler( + {% endif %} + self.{{ method.py_name }}, + request_deserializer={{ method.py_input_message_type }}.FromString, + response_serializer={{ method.py_input_message_type }}.SerializeToString, + ), + {% endfor %} + } + generic_handler = grpc.method_handlers_generic_handler( + "{{ service.proto_path }}", rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + +{% endfor %} +{% endif %} + {% for i in output_file.imports|sort %} {{ i }} {% endfor %} From 2882a175b241e07906b914b110ae93620c0dd27c Mon Sep 17 00:00:00 2001 From: roblabla Date: Mon, 25 Jan 2021 10:51:03 +0100 Subject: [PATCH 3/4] Centralize parsing of options INCLUDE_GOOGLE option is now parsed in parse_options. --- src/betterproto/plugin/models.py | 1 + src/betterproto/plugin/parser.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index a1e10ebfb..21d0ad89d 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -205,6 +205,7 @@ def comment(self) -> str: @dataclass class Options: grpc_kind: str = "grpclib" + include_google: bool = False @dataclass diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 527448420..7e4701a39 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from google.protobuf.descriptor import Descriptor + def traverse( proto_file: FieldDescriptorProto, ) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": @@ -60,11 +61,14 @@ def _traverse( _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) ) + def parse_options(plugin_options: List[str]) -> Options: options = Options() for option in plugin_options: if option.startswith("grpc="): options.grpc_kind = option.split("=", 1)[1] + if option == "INCLUDE_GOOGLE": + options.include_google = True return options def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: @@ -76,15 +80,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: options = parse_options(plugin_options) request_data = PluginRequestCompiler( - plugin_request_obj=request, - options=options + plugin_request_obj=request, options=options ) # Gather output packages for proto_file in request.proto_file: - if ( - proto_file.package == "google.protobuf" - and "INCLUDE_GOOGLE" not in plugin_options - ): + if proto_file.package == "google.protobuf" and options.include_google: # If not INCLUDE_GOOGLE, # skip re-compiling Google's well-known types continue From 883cf852a139ccb431666d1fe6c5a2b4a6da6217 Mon Sep 17 00:00:00 2001 From: roblabla Date: Mon, 25 Jan 2021 15:04:43 +0100 Subject: [PATCH 4/4] Refactor how registering grpcio servicer works We now have a grpcio module in betterproto that defines the ServicerBase Abstract Base Class and a register_servicers function. Register_servicers takes multiple ServicerBase instances, and registers them to the grpc AIO server. The generated servicers now inherit from ServicerBase. --- pyproject.toml | 4 ++- src/betterproto/grpc/grpcio_server.py | 34 ++++++++++++++++++++++++ src/betterproto/templates/template.py.j2 | 13 ++++----- 3 files changed, 44 insertions(+), 7 deletions(-) create mode 100644 src/betterproto/grpc/grpcio_server.py diff --git a/pyproject.toml b/pyproject.toml index e948e9384..422a39c60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,10 +15,12 @@ packages = [ python = ">=3.6.2,<4.0" black = { version = ">=19.3b0", optional = true } dataclasses = { version = "^0.7", python = ">=3.6, <3.7" } +grpcio = { version = "^1.43.0", optional = true } grpclib = "^0.4.1" jinja2 = { version = "^2.11.2", optional = true } python-dateutil = "^2.8" + [tool.poetry.dev-dependencies] asv = "^0.4.2" black = "^21.11b0" @@ -43,7 +45,7 @@ protoc-gen-python_betterproto = "betterproto.plugin:main" [tool.poetry.extras] compiler = ["black", "jinja2"] - +grpcio = ["grpcio"] # Dev workflow tasks diff --git a/src/betterproto/grpc/grpcio_server.py b/src/betterproto/grpc/grpcio_server.py new file mode 100644 index 000000000..70b55289d --- /dev/null +++ b/src/betterproto/grpc/grpcio_server.py @@ -0,0 +1,34 @@ +from typing import Dict, TYPE_CHECKING +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + import grpc + + +class ServicerBase(ABC): + """ + Base class for async grpcio servers. + """ + + @property + @abstractmethod + def __rpc_methods__(self) -> Dict[str, "grpc.RpcMethodHandler"]: + ... + + @property + @abstractmethod + def __proto_path__(self) -> str: + ... + + +def register_servicers(server: "grpc.aio.Server", *servicers: ServicerBase): + from grpc import method_handlers_generic_handler + + server.add_generic_rpc_handlers( + tuple( + method_handlers_generic_handler( + servicer.__proto_path__, servicer.__rpc_handlers__ + ) + for servicer in servicers + ) + ) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 359a94b9d..26f35d76a 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -23,6 +23,7 @@ import grpclib {% endif %} {% if output_file.services and output_file.parent_request.options.grpc_kind == "grpcio" %} import grpc +from betterproto.grpc.grpcio_server import ServicerBase {% endif %} @@ -251,7 +252,7 @@ class {{ service.py_name }}Base(ServiceBase): {% if output_file.parent_request.options.grpc_kind == "grpcio" %} {% for service in output_file.services %} -class {{ service.py_name }}Base: +class {{ service.py_name }}Base(ServicerBase): {% if service.comment %} {{ service.comment }} @@ -277,8 +278,11 @@ class {{ service.py_name }}Base: {% endfor %} - def register_to_server(self, server: grpc.aio.Server): - rpc_method_handlers = { + __proto_path__ = "{{ service.proto_path }}" + + @property + def __rpc_methods__(self): + return { {% for method in service.methods %} "{{ method.proto_name }}": {% if not method.client_streaming and not method.server_streaming %} @@ -296,9 +300,6 @@ class {{ service.py_name }}Base: ), {% endfor %} } - generic_handler = grpc.method_handlers_generic_handler( - "{{ service.proto_path }}", rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) {% endfor %} {% endif %}