From c4f1aad8fc874a4164b6cb9e10b99537bdf5430f Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Fri, 6 Nov 2020 03:27:03 +0100 Subject: [PATCH 01/13] added servicer stub generation --- src/betterproto/__init__.py | 1 + src/betterproto/grpc/grpclib_server.py | 29 +++++ src/betterproto/plugin/models.py | 8 ++ src/betterproto/templates/template.py.j2 | 91 +++++++++++++++ .../example_service/example_service.proto | 20 ++++ .../example_service/test_example_service.py | 107 ++++++++++++++++++ 6 files changed, 256 insertions(+) create mode 100644 src/betterproto/grpc/grpclib_server.py create mode 100644 tests/inputs/example_service/example_service.proto create mode 100644 tests/inputs/example_service/test_example_service.py diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index e6daeb3a9..d906166fa 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -25,6 +25,7 @@ from ._types import T from .casing import camel_case, safe_snake_case, snake_case from .grpc.grpclib_client import ServiceStub +from .grpc.grpclib_server import ServiceImplementation if sys.version_info[:2] < (3, 7): # Apply backport of datetime.fromisoformat from 3.7 diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py new file mode 100644 index 000000000..080db0060 --- /dev/null +++ b/src/betterproto/grpc/grpclib_server.py @@ -0,0 +1,29 @@ +from abc import ABC + +import grpclib +import grpclib.server + + +class ServiceImplementation(ABC): + """ + Base class for async gRPC servers. + """ + + __service_name__: str + + def __rpc_methods__(self): + pass + + def __mapping__(self): + mapping = {} + for ( + method, + proto_name, + cardinality, + request_type, + response_type, + ) in self.__rpc_methods__(): + mapping[f"/{self.__service_name__}/{proto_name}"] = grpclib.const.Handler( + method, cardinality, request_type, response_type + ) + return mapping diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index bf314051e..a82b73ddb 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -279,6 +279,10 @@ def proto_name(self) -> str: def py_name(self) -> str: return pythonize_class_name(self.proto_name) + @property + def py_name_as_field(self) -> str: + return pythonize_field_name(self.proto_name) + @property def annotation(self) -> str: if self.repeated: @@ -559,6 +563,10 @@ def __post_init__(self) -> None: def proto_name(self) -> str: return self.proto_obj.name + @property + def full_proto_name(self) -> str: + return f"{self.parent.package_proto_obj.package}.{self.proto_obj.name}" + @property def py_name(self) -> str: return pythonize_class_name(self.proto_name) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 753d340c7..46681c142 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -154,6 +154,97 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% endfor %} +{% for service in output_file.services %} +class {{ service.py_name }}Implementation(betterproto.ServiceImplementation): + {% if service.comment %} +{{ service.comment }} + + {% endif %} + + __service_name__ = "{{ service.full_proto_name }}" + + {% for method in service.methods %} + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + {%- if method.py_input_message and method.py_input_message.fields -%}, + {%- for field in method.py_input_message.fields -%} + {{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%} + Optional[{{ field.annotation }}] + {%- else -%} + {{ field.annotation }} + {%- endif -%} + {%- if not loop.last %}, {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- else -%} + {# Client streaming: need a request iterator instead #} + , {{ method.py_input_message.py_name_as_field }}_iterator: AsyncIterable["{{ method.py_input_message_type }}"] + {%- endif -%} + ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: + {% if method.comment %} +{{ method.comment }} + + {% endif %} + raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) + + {% endfor %} + + {% for method in service.methods %} + async def __rpc_{{ method.py_name }}(self, stream): + {% if not method.client_streaming %} + request = await stream.recv_message() + + request_kwargs = { + {% for field in method.py_input_message.fields %} + "{{ field.py_name }}": request.{{ field.py_name }}, + {% endfor %} + } + + {% else %} + request_kwargs = {"{{ method.py_input_message.py_name_as_field }}_iterator": stream.__aiter__()} + {% endif %} + + {% if not method.server_streaming %} + response = await self.{{ method.py_name }}(**request_kwargs) + await stream.send_message(response) + {% else %} + response_iter = self.{{ method.py_name }}(**request_kwargs) + {# check if response is actually an AsyncIterator #} + {# this might be false if the method just returns without #} + {# yielding at least once #} + {# in that case, we just interpret it as an empty iterator #} + if isinstance(response_iter, AsyncIterable): + async for response_message in response_iter: + await stream.send_message(response_message) + else: + response_iter.close() + {% endif %} + + {% endfor %} + + def __rpc_methods__(self): + return [ + {% for method in service.methods %} + ( + self.__rpc_{{ method.py_name }}, + "{{ method.proto_name }}", + {% if not method.client_streaming and not method.server_streaming %} + grpclib.const.Cardinality.UNARY_UNARY, + {% elif not method.client_streaming and method.server_streaming %} + grpclib.const.Cardinality.UNARY_STREAM, + {% elif method.client_streaming and not method.server_streaming %} + grpclib.const.Cardinality.STREAM_UNARY, + {% else %} + grpclib.const.Cardinality.STREAM_STREAM, + {% endif %} + {{ method.py_input_message_type }}, + {{ method.py_output_message_type }}, + ), + {% endfor %} + ] + +{% endfor %} + {% for i in output_file.imports|sort %} {{ i }} {% endfor %} diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto new file mode 100644 index 000000000..b5cd0bb54 --- /dev/null +++ b/tests/inputs/example_service/example_service.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package example_service; + +service ExampleService { + rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); + rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); + rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse); + rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse); +} + +message ExampleRequest { + string example_string = 1; + int64 example_integer = 2; +} + +message ExampleResponse { + string example_string = 1; + int64 example_integer = 2; +} diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py new file mode 100644 index 000000000..5ec8fe085 --- /dev/null +++ b/tests/inputs/example_service/test_example_service.py @@ -0,0 +1,107 @@ +import asyncio +from typing import AsyncIterator, AsyncIterable + +from grpclib.client import Channel +from grpclib.server import Server + +from tests.output_betterproto.example_service.example_service import ( + ExampleServiceImplementation, + ExampleServiceStub, + ExampleRequest, + ExampleResponse, +) + + +class ExampleService(ExampleServiceImplementation): + async def example_unary_unary( + self, example_string: str, example_integer: int + ) -> "ExampleResponse": + return ExampleResponse( + example_string=example_string, + example_integer=example_integer, + ) + + async def example_unary_stream( + self, example_string: str, example_integer: int + ) -> AsyncIterator["ExampleResponse"]: + response = ExampleResponse( + example_string=example_string, + example_integer=example_integer, + ) + yield response + yield response + yield response + + async def example_stream_unary( + self, example_request_iterator: AsyncIterable["ExampleRequest"] + ) -> "ExampleResponse": + async for example_request in example_request_iterator: + return ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + + async def example_stream_stream( + self, example_request_iterator: AsyncIterable["ExampleRequest"] + ) -> AsyncIterator["ExampleResponse"]: + async for example_request in example_request_iterator: + yield ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + + +async def async_test_server_start(): + host = "localhost" + port = 133337 + + test_string = "test string" + test_int = 42 + + # start server + server = Server([ExampleService()]) + await server.start(host, port) + + # start client + channel = Channel(host=host, port=port) + stub = ExampleServiceStub(channel) + + # unary unary + response = await stub.example_unary_unary( + example_string="test string", + example_integer=42, + ) + assert response.example_string == test_string + assert response.example_integer == test_int + + # unary stream + async for response in stub.example_unary_stream( + example_string="test string", + example_integer=42, + ): + assert response.example_string == test_string + assert response.example_integer == test_int + + # stream unary + request = ExampleRequest( + example_string=test_string, + example_integer=42, + ) + + async def request_iterator(): + yield request + yield request + yield request + + response = await stub.example_stream_unary(request_iterator()) + assert response.example_string == test_string + assert response.example_integer == test_int + + # stream stream + async for response in stub.example_stream_stream(request_iterator()): + assert response.example_string == test_string + assert response.example_integer == test_int + + +def test_server_start(): + asyncio.run(async_test_server_start()) From 694be73d0b16e7cc4340be3fcde6284311f9c025 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Fri, 6 Nov 2020 04:14:28 +0100 Subject: [PATCH 02/13] hacky workaround to make test compatible with test framework --- tests/inputs/example_service/example_service.proto | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto index b5cd0bb54..dc9aaad17 100644 --- a/tests/inputs/example_service/example_service.proto +++ b/tests/inputs/example_service/example_service.proto @@ -18,3 +18,6 @@ message ExampleResponse { string example_string = 1; int64 example_integer = 2; } + +// Suppress test framework error when it's looking for a "Test" message or service +message Test {} From 2030b8712a1555dc5abadbdca0457e1a0fa9c8cc Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Fri, 6 Nov 2020 04:49:27 +0100 Subject: [PATCH 03/13] make test compatible with macOS and python 3.6 --- tests/inputs/example_service/test_example_service.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py index 5ec8fe085..9786ad923 100644 --- a/tests/inputs/example_service/test_example_service.py +++ b/tests/inputs/example_service/test_example_service.py @@ -52,8 +52,8 @@ async def example_stream_stream( async def async_test_server_start(): - host = "localhost" - port = 133337 + host = "127.0.0.1" + port = 13337 test_string = "test string" test_int = 42 @@ -104,4 +104,5 @@ async def request_iterator(): def test_server_start(): - asyncio.run(async_test_server_start()) + loop = asyncio.get_event_loop() + loop.run_until_complete(async_test_server_start()) From 207233b2e4f548d04f1c09edcb4ba6a8e10c1969 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Sat, 21 Nov 2020 09:17:22 +0100 Subject: [PATCH 04/13] cut down on duplicate code - merged __service_name__ and __rpc_methods__ into __mapping__ - __mapping__ is now part of generated class instead of ABC - duplicate code for managing the grpclib stream moved from generated class into ABC --- src/betterproto/grpc/grpclib_server.py | 43 ++++++++++------- src/betterproto/plugin/models.py | 1 + src/betterproto/templates/template.py.j2 | 61 +++++++++++------------- 3 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py index 080db0060..a94b491fd 100644 --- a/src/betterproto/grpc/grpclib_server.py +++ b/src/betterproto/grpc/grpclib_server.py @@ -1,4 +1,6 @@ from abc import ABC +from collections import AsyncIterable +from typing import Callable, Any, Dict import grpclib import grpclib.server @@ -9,21 +11,30 @@ class ServiceImplementation(ABC): Base class for async gRPC servers. """ - __service_name__: str + async def __call_rpc_handler_server_unary( + self, + handler: Callable, + stream: grpclib.server.Stream, + request_kwargs: Dict[str, Any], + ) -> None: - def __rpc_methods__(self): - pass + response = await handler(**request_kwargs) + await stream.send_message(response) - def __mapping__(self): - mapping = {} - for ( - method, - proto_name, - cardinality, - request_type, - response_type, - ) in self.__rpc_methods__(): - mapping[f"/{self.__service_name__}/{proto_name}"] = grpclib.const.Handler( - method, cardinality, request_type, response_type - ) - return mapping + async def __call_rpc_handler_server_stream( + self, + handler: Callable, + stream: grpclib.server.Stream, + request_kwargs: Dict[str, Any], + ) -> None: + + response_iter = handler(**request_kwargs) + # check if response is actually an AsyncIterator + # this might be false if the method just returns without + # yielding at least once + # in that case, we just interpret it as an empty iterator + if isinstance(response_iter, AsyncIterable): + async for response_message in response_iter: + await stream.send_message(response_message) + else: + response_iter.close() diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index a82b73ddb..e19190d42 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -557,6 +557,7 @@ class ServiceCompiler(ProtoContentBase): def __post_init__(self) -> None: # Add service to output file self.output_file.services.append(self) + self.output_file.typing_imports.add("Dict") super().__post_init__() # check for unset fields @property diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 46681c142..79a66190a 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -82,7 +82,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): Optional[{{ field.annotation }}] {%- else -%} {{ field.annotation }} - {%- endif -%} = + {%- endif -%} = {%- if field.py_name not in method.mutable_default_args -%} {{ field.default_value_string }} {%- else -%} @@ -161,8 +161,6 @@ class {{ service.py_name }}Implementation(betterproto.ServiceImplementation): {% endif %} - __service_name__ = "{{ service.full_proto_name }}" - {% for method in service.methods %} async def {{ method.py_name }}(self {%- if not method.client_streaming -%} @@ -205,43 +203,40 @@ class {{ service.py_name }}Implementation(betterproto.ServiceImplementation): {% endif %} {% if not method.server_streaming %} - response = await self.{{ method.py_name }}(**request_kwargs) - await stream.send_message(response) + await self.__call_rpc_handler_server_unary( + self.{{ method.py_name }}, + stream, + request_kwargs, + ) {% else %} - response_iter = self.{{ method.py_name }}(**request_kwargs) - {# check if response is actually an AsyncIterator #} - {# this might be false if the method just returns without #} - {# yielding at least once #} - {# in that case, we just interpret it as an empty iterator #} - if isinstance(response_iter, AsyncIterable): - async for response_message in response_iter: - await stream.send_message(response_message) - else: - response_iter.close() + await self.__call_rpc_handler_server_stream( + self.{{ method.py_name }}, + stream, + request_kwargs, + ) {% endif %} {% endfor %} - def __rpc_methods__(self): - return [ + def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + return { {% for method in service.methods %} - ( - self.__rpc_{{ method.py_name }}, - "{{ method.proto_name }}", - {% if not method.client_streaming and not method.server_streaming %} - grpclib.const.Cardinality.UNARY_UNARY, - {% elif not method.client_streaming and method.server_streaming %} - grpclib.const.Cardinality.UNARY_STREAM, - {% elif method.client_streaming and not method.server_streaming %} - grpclib.const.Cardinality.STREAM_UNARY, - {% else %} - grpclib.const.Cardinality.STREAM_STREAM, - {% endif %} - {{ method.py_input_message_type }}, - {{ method.py_output_message_type }}, - ), + "/{{ service.full_proto_name }}/{{ method.proto_name }}": grpclib.const.Handler( + self.__rpc_{{ method.py_name }}, + {% if not method.client_streaming and not method.server_streaming %} + grpclib.const.Cardinality.UNARY_UNARY, + {% elif not method.client_streaming and method.server_streaming %} + grpclib.const.Cardinality.UNARY_STREAM, + {% elif method.client_streaming and not method.server_streaming %} + grpclib.const.Cardinality.STREAM_UNARY, + {% else %} + grpclib.const.Cardinality.STREAM_STREAM, + {% endif %} + {{ method.py_input_message_type }}, + {{ method.py_output_message_type }}, + ), {% endfor %} - ] + } {% endfor %} From 61e17995d55c071dfe72723bcf1389b263ccb786 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Sat, 21 Nov 2020 09:30:06 +0100 Subject: [PATCH 05/13] fix: remove name mangling from ServiceImplementation ABC also removed left-over from trying to work around the test framework --- src/betterproto/grpc/grpclib_server.py | 4 ++-- src/betterproto/templates/template.py.j2 | 4 ++-- tests/inputs/example_service/example_service.proto | 3 --- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py index a94b491fd..5d4e21b05 100644 --- a/src/betterproto/grpc/grpclib_server.py +++ b/src/betterproto/grpc/grpclib_server.py @@ -11,7 +11,7 @@ class ServiceImplementation(ABC): Base class for async gRPC servers. """ - async def __call_rpc_handler_server_unary( + async def _call_rpc_handler_server_unary( self, handler: Callable, stream: grpclib.server.Stream, @@ -21,7 +21,7 @@ async def __call_rpc_handler_server_unary( response = await handler(**request_kwargs) await stream.send_message(response) - async def __call_rpc_handler_server_stream( + async def _call_rpc_handler_server_stream( self, handler: Callable, stream: grpclib.server.Stream, diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 79a66190a..1ad1995b9 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -203,13 +203,13 @@ class {{ service.py_name }}Implementation(betterproto.ServiceImplementation): {% endif %} {% if not method.server_streaming %} - await self.__call_rpc_handler_server_unary( + await self._call_rpc_handler_server_unary( self.{{ method.py_name }}, stream, request_kwargs, ) {% else %} - await self.__call_rpc_handler_server_stream( + await self._call_rpc_handler_server_stream( self.{{ method.py_name }}, stream, request_kwargs, diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto index dc9aaad17..b5cd0bb54 100644 --- a/tests/inputs/example_service/example_service.proto +++ b/tests/inputs/example_service/example_service.proto @@ -18,6 +18,3 @@ message ExampleResponse { string example_string = 1; int64 example_integer = 2; } - -// Suppress test framework error when it's looking for a "Test" message or service -message Test {} From a54d0ce8fa0f8d957cd0ad4c6a60825c46a14aac Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Sat, 21 Nov 2020 09:48:18 +0100 Subject: [PATCH 06/13] return type hint for wrapper; use method.route --- src/betterproto/templates/template.py.j2 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 1ad1995b9..447ac3c02 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -188,7 +188,7 @@ class {{ service.py_name }}Implementation(betterproto.ServiceImplementation): {% endfor %} {% for method in service.methods %} - async def __rpc_{{ method.py_name }}(self, stream): + async def __rpc_{{ method.py_name }}(self, stream) -> None: {% if not method.client_streaming %} request = await stream.recv_message() @@ -221,7 +221,7 @@ class {{ service.py_name }}Implementation(betterproto.ServiceImplementation): def __mapping__(self) -> Dict[str, grpclib.const.Handler]: return { {% for method in service.methods %} - "/{{ service.full_proto_name }}/{{ method.proto_name }}": grpclib.const.Handler( + "{{ method.route }}": grpclib.const.Handler( self.__rpc_{{ method.py_name }}, {% if not method.client_streaming and not method.server_streaming %} grpclib.const.Cardinality.UNARY_UNARY, From 54af07440b8b51eb7d655691570c25ed07dcf4f7 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Sun, 22 Nov 2020 04:21:21 +0100 Subject: [PATCH 07/13] renamed ServiceImplementation; refactored test added test framework workaround back in --- src/betterproto/__init__.py | 1 - src/betterproto/grpc/__init__.py | 1 + src/betterproto/grpc/grpclib_server.py | 2 +- src/betterproto/templates/template.py.j2 | 2 +- .../example_service/example_service.proto | 3 + .../example_service/test_example_service.py | 85 ++++++++----------- 6 files changed, 42 insertions(+), 52 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index d906166fa..e6daeb3a9 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -25,7 +25,6 @@ from ._types import T from .casing import camel_case, safe_snake_case, snake_case from .grpc.grpclib_client import ServiceStub -from .grpc.grpclib_server import ServiceImplementation if sys.version_info[:2] < (3, 7): # Apply backport of datetime.fromisoformat from 3.7 diff --git a/src/betterproto/grpc/__init__.py b/src/betterproto/grpc/__init__.py index e69de29bb..d851cac0c 100644 --- a/src/betterproto/grpc/__init__.py +++ b/src/betterproto/grpc/__init__.py @@ -0,0 +1 @@ +from .grpclib_server import ServiceBase diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py index 5d4e21b05..53c081a8a 100644 --- a/src/betterproto/grpc/grpclib_server.py +++ b/src/betterproto/grpc/grpclib_server.py @@ -6,7 +6,7 @@ import grpclib.server -class ServiceImplementation(ABC): +class ServiceBase(ABC): """ Base class for async gRPC servers. """ diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 447ac3c02..724526bc1 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -155,7 +155,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% for service in output_file.services %} -class {{ service.py_name }}Implementation(betterproto.ServiceImplementation): +class {{ service.py_name }}Base(betterproto.grpc.ServiceBase): {% if service.comment %} {{ service.comment }} diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto index b5cd0bb54..dc9aaad17 100644 --- a/tests/inputs/example_service/example_service.proto +++ b/tests/inputs/example_service/example_service.proto @@ -18,3 +18,6 @@ message ExampleResponse { string example_string = 1; int64 example_integer = 2; } + +// Suppress test framework error when it's looking for a "Test" message or service +message Test {} diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py index 9786ad923..0ea343b06 100644 --- a/tests/inputs/example_service/test_example_service.py +++ b/tests/inputs/example_service/test_example_service.py @@ -1,18 +1,17 @@ -import asyncio from typing import AsyncIterator, AsyncIterable -from grpclib.client import Channel -from grpclib.server import Server +import pytest +from grpclib.testing import ChannelFor from tests.output_betterproto.example_service.example_service import ( - ExampleServiceImplementation, + ExampleServiceBase, ExampleServiceStub, ExampleRequest, ExampleResponse, ) -class ExampleService(ExampleServiceImplementation): +class ExampleService(ExampleServiceBase): async def example_unary_unary( self, example_string: str, example_integer: int ) -> "ExampleResponse": @@ -51,58 +50,46 @@ async def example_stream_stream( ) -async def async_test_server_start(): - host = "127.0.0.1" - port = 13337 - +@pytest.mark.asyncio +async def test_calls_with_different_cardinalities(): test_string = "test string" test_int = 42 - # start server - server = Server([ExampleService()]) - await server.start(host, port) - - # start client - channel = Channel(host=host, port=port) - stub = ExampleServiceStub(channel) - - # unary unary - response = await stub.example_unary_unary( - example_string="test string", - example_integer=42, - ) - assert response.example_string == test_string - assert response.example_integer == test_int - - # unary stream - async for response in stub.example_unary_stream( - example_string="test string", - example_integer=42, - ): + async with ChannelFor([ExampleService()]) as channel: + stub = ExampleServiceStub(channel) + + # unary unary + response = await stub.example_unary_unary( + example_string="test string", + example_integer=42, + ) assert response.example_string == test_string assert response.example_integer == test_int - # stream unary - request = ExampleRequest( - example_string=test_string, - example_integer=42, - ) - - async def request_iterator(): - yield request - yield request - yield request + # unary stream + async for response in stub.example_unary_stream( + example_string="test string", + example_integer=42, + ): + assert response.example_string == test_string + assert response.example_integer == test_int + + # stream unary + request = ExampleRequest( + example_string=test_string, + example_integer=42, + ) - response = await stub.example_stream_unary(request_iterator()) - assert response.example_string == test_string - assert response.example_integer == test_int + async def request_iterator(): + yield request + yield request + yield request - # stream stream - async for response in stub.example_stream_stream(request_iterator()): + response = await stub.example_stream_unary(request_iterator()) assert response.example_string == test_string assert response.example_integer == test_int - -def test_server_start(): - loop = asyncio.get_event_loop() - loop.run_until_complete(async_test_server_start()) + # stream stream + async for response in stub.example_stream_stream(request_iterator()): + assert response.example_string == test_string + assert response.example_integer == test_int From 2a8754279fac6cf774b05830ac14627ebb8c3544 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Mon, 23 Nov 2020 05:58:07 +0100 Subject: [PATCH 08/13] moved ServiceBase import into generated file --- src/betterproto/grpc/__init__.py | 1 - src/betterproto/templates/template.py.j2 | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/betterproto/grpc/__init__.py b/src/betterproto/grpc/__init__.py index d851cac0c..e69de29bb 100644 --- a/src/betterproto/grpc/__init__.py +++ b/src/betterproto/grpc/__init__.py @@ -1 +0,0 @@ -from .grpclib_server import ServiceBase diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 724526bc1..62d32dce4 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no {% endif %} import betterproto +from betterproto.grpc.grpclib_server import ServiceBase {% if output_file.services %} import grpclib {% endif %} @@ -155,7 +156,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% for service in output_file.services %} -class {{ service.py_name }}Base(betterproto.grpc.ServiceBase): +class {{ service.py_name }}Base(ServiceBase): {% if service.comment %} {{ service.comment }} From 5bbe19ab601140d4f636bd558ff71e23a84548e3 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Thu, 26 Nov 2020 08:08:01 +0100 Subject: [PATCH 09/13] renames to improve consistency; inlined unary server response translate --- src/betterproto/grpc/grpclib_server.py | 10 ---------- src/betterproto/templates/template.py.j2 | 11 ++++------- tests/inputs/example_service/example_service.proto | 5 +---- tests/inputs/example_service/test_example_service.py | 8 ++++---- 4 files changed, 9 insertions(+), 25 deletions(-) diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py index 53c081a8a..59bc7d435 100644 --- a/src/betterproto/grpc/grpclib_server.py +++ b/src/betterproto/grpc/grpclib_server.py @@ -11,16 +11,6 @@ class ServiceBase(ABC): Base class for async gRPC servers. """ - async def _call_rpc_handler_server_unary( - self, - handler: Callable, - stream: grpclib.server.Stream, - request_kwargs: Dict[str, Any], - ) -> None: - - response = await handler(**request_kwargs) - await stream.send_message(response) - async def _call_rpc_handler_server_stream( self, handler: Callable, diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 62d32dce4..11e46f2a3 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -177,7 +177,7 @@ class {{ service.py_name }}Base(ServiceBase): {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message.py_name_as_field }}_iterator: AsyncIterable["{{ method.py_input_message_type }}"] + , request_iterator: AsyncIterator["{{ method.py_input_message_type }}"] {%- endif -%} ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} @@ -189,7 +189,7 @@ class {{ service.py_name }}Base(ServiceBase): {% endfor %} {% for method in service.methods %} - async def __rpc_{{ method.py_name }}(self, stream) -> None: + async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None: {% if not method.client_streaming %} request = await stream.recv_message() @@ -204,11 +204,8 @@ class {{ service.py_name }}Base(ServiceBase): {% endif %} {% if not method.server_streaming %} - await self._call_rpc_handler_server_unary( - self.{{ method.py_name }}, - stream, - request_kwargs, - ) + response = await self.{{ method.py_name }}(**request_kwargs) + await stream.send_message(response) {% else %} await self._call_rpc_handler_server_stream( self.{{ method.py_name }}, diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto index dc9aaad17..96455cc30 100644 --- a/tests/inputs/example_service/example_service.proto +++ b/tests/inputs/example_service/example_service.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package example_service; -service ExampleService { +service Test { rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse); @@ -18,6 +18,3 @@ message ExampleResponse { string example_string = 1; int64 example_integer = 2; } - -// Suppress test framework error when it's looking for a "Test" message or service -message Test {} diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py index 0ea343b06..530f520d5 100644 --- a/tests/inputs/example_service/test_example_service.py +++ b/tests/inputs/example_service/test_example_service.py @@ -4,14 +4,14 @@ from grpclib.testing import ChannelFor from tests.output_betterproto.example_service.example_service import ( - ExampleServiceBase, - ExampleServiceStub, + TestBase, + TestStub, ExampleRequest, ExampleResponse, ) -class ExampleService(ExampleServiceBase): +class ExampleService(TestBase): async def example_unary_unary( self, example_string: str, example_integer: int ) -> "ExampleResponse": @@ -56,7 +56,7 @@ async def test_calls_with_different_cardinalities(): test_int = 42 async with ChannelFor([ExampleService()]) as channel: - stub = ExampleServiceStub(channel) + stub = TestStub(channel) # unary unary response = await stub.example_unary_unary( From b73bd0bb8d427c7b45d07f5ab9efc85c5a1e3f7b Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Tue, 1 Dec 2020 16:14:21 +0100 Subject: [PATCH 10/13] fixed test --- tests/inputs/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/inputs/config.py b/tests/inputs/config.py index 7d1466736..9b7b288d3 100644 --- a/tests/inputs/config.py +++ b/tests/inputs/config.py @@ -17,4 +17,5 @@ "import_service_input_message", "googletypes_service_returns_empty", "googletypes_service_returns_googletype", + "example_service", } From 99e63798abecac213ca52a85e63d732f6853e3c0 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Thu, 3 Dec 2020 00:09:19 +0100 Subject: [PATCH 11/13] removed left-over method; fixed outdated parameter names --- src/betterproto/plugin/models.py | 4 ---- tests/inputs/example_service/test_example_service.py | 8 ++++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index e19190d42..dda40991d 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -564,10 +564,6 @@ def __post_init__(self) -> None: def proto_name(self) -> str: return self.proto_obj.name - @property - def full_proto_name(self) -> str: - return f"{self.parent.package_proto_obj.package}.{self.proto_obj.name}" - @property def py_name(self) -> str: return pythonize_class_name(self.proto_name) diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py index 530f520d5..12d646b14 100644 --- a/tests/inputs/example_service/test_example_service.py +++ b/tests/inputs/example_service/test_example_service.py @@ -32,18 +32,18 @@ async def example_unary_stream( yield response async def example_stream_unary( - self, example_request_iterator: AsyncIterable["ExampleRequest"] + self, request_iterator: AsyncIterator["ExampleRequest"] ) -> "ExampleResponse": - async for example_request in example_request_iterator: + async for example_request in request_iterator: return ExampleResponse( example_string=example_request.example_string, example_integer=example_request.example_integer, ) async def example_stream_stream( - self, example_request_iterator: AsyncIterable["ExampleRequest"] + self, request_iterator: AsyncIterator["ExampleRequest"] ) -> AsyncIterator["ExampleResponse"]: - async for example_request in example_request_iterator: + async for example_request in request_iterator: yield ExampleResponse( example_string=example_request.example_string, example_integer=example_request.example_integer, From 2ff45e78d0a839b2d79ddc6d63f16779fefd14aa Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Thu, 3 Dec 2020 00:12:21 +0100 Subject: [PATCH 12/13] fixed outdated parameter name in template --- src/betterproto/templates/template.py.j2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 11e46f2a3..de53963ed 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -200,7 +200,7 @@ class {{ service.py_name }}Base(ServiceBase): } {% else %} - request_kwargs = {"{{ method.py_input_message.py_name_as_field }}_iterator": stream.__aiter__()} + request_kwargs = {"request_iterator": stream.__aiter__()} {% endif %} {% if not method.server_streaming %} From eac05e9b5ba82a501e2e225cf8a27057f46bdbb0 Mon Sep 17 00:00:00 2001 From: Tim Schmidt Date: Fri, 4 Dec 2020 03:57:08 +0100 Subject: [PATCH 13/13] removed unused / obsolete property function --- src/betterproto/plugin/models.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index dda40991d..1fe55d04d 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -279,10 +279,6 @@ def proto_name(self) -> str: def py_name(self) -> str: return pythonize_class_name(self.proto_name) - @property - def py_name_as_field(self) -> str: - return pythonize_field_name(self.proto_name) - @property def annotation(self) -> str: if self.repeated: