Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from .grpc import {{ service.name }}GrpcTransport
{% endif %}
{% if 'rest' in opts.transport %}
from .rest import {{ service.name }}RestTransport
from .rest import {{ service.name }}RestInterceptor
{% endif %}

# Compile a registry of transports.
Expand All @@ -29,6 +30,7 @@ __all__ = (
{% endif %}
{% if 'rest' in opts.transport %}
'{{ service.name }}RestTransport',
'{{ service.name }}RestInterceptor',
{% endif %}
)
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)


class {{ service.name }}RestInterceptor:
"""Interceptor for {{ service.name }}.

Interceptors are used to manipulate requests, request metadata, and responses
in arbitrary ways.
Example use cases include:
* Logging
* Verifying requests according to service or custom semantics
* Stripping extraneous information from responses

These use cases and more can be enabled by injecting an
instance of a custom subclass when constructing the {{ service.name }}RestTransport.

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(response):
logging.log(f"Received response: {response}")
{% endif %}

{% endfor %}
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
client = {{ service.client_name }}(transport=transport)


"""
{% for method in service.methods.values()|sort(attribute="name") if not(method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the request or metadata
before they are sent to the {{ service.name }} server.
"""
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response
{% endif %}

{% endfor %}


@dataclasses.dataclass
class {{service.name}}RestStub:
_session: AuthorizedSession
_host: str
_interceptor: {{ service.name }}RestInterceptor


class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.
Expand Down Expand Up @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
self._prep_wrapped_messages(client_info)

{% if service.has_lro %}
Expand Down Expand Up @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
},
{% endfor %}{# rule in method.http_options #}
]

request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first I imagined we'd be able to pass a list of interceptors that could be chained together, but understand that requires additional architecture here. We can handle the chaining within the body of the pre_/post_ functions.

Copy link
Copy Markdown
Contributor Author

@software-dov software-dov Jan 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, that's easily done by something like this:

class InterceptorChainer:
    def __init__(self, chain):
        assert all(isinstance(i, RestInterceptor) for i in chain)
        # Make our own copy to prevent external modification
        self.chain = list(chain)
        
    def __getattr__(self, name):
        if name.startswith("pre_"):
            def pre(request, metadata):
                for i in self.chain:
                    request, metadata = getattr(i, name)(request, metadata)
                return request, metadata
                
            return pre
                
        elif name.startswith("post_"):
            def post(response):
                for i in self.chain:
                    response = getattr(i, name)(response)
                return response
            
            return post
            
        else:
            raise AttributeError(f"No such attribute: {name}")

request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not method.void %}
# Return the response
{% if method.lro %}
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% else %}
return {{method.output.ident}}.from_json(
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)

{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% if not method.http_options %}
Expand All @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

return stub

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ from google.api_core import grpc_helpers
from google.api_core import path_template
{% if service.has_lro %}
from google.api_core import future
from google.api_core import operation
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% if "rest" in opts.transport %}
Expand Down Expand Up @@ -1113,6 +1114,55 @@ def test_{{ method_name }}_rest_unset_required_fields():

{% endif %}{# required_fields #}

{% if not (method.server_streaming or method.client_streaming) %}
@pytest.mark.parametrize("null_interceptor", [True, False])
def test_{{ method_name }}_rest_interceptors(null_interceptor):
transport = transports.{{ service.name }}RestTransport(
credentials=ga_credentials.AnonymousCredentials(),
interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(),
)
client = {{ service.client_name }}(transport=transport)
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
{% if method.lro %}
mock.patch.object(operation.Operation, "_set_result_from_operation"), \
{% endif %}
{% if not method.void %}
mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \
{% endif %}
mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre:
pre.assert_not_called()
{% if not method.void %}
post.assert_not_called()
{% endif %}

transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},}

req.return_value = Response()
req.return_value.status_code = 200
req.return_value.request = PreparedRequest()
{% if not method.void %}
req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %}
{% endif %}

request = {{ method.input.ident }}()
metadata =[
("key", "val"),
("cephalopod", "squid"),
]
pre.return_value = request, metadata
{% if not method.void %}
post.return_value = {{ method.output.ident }}
{% endif %}

client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
{% if not method.void %}
post.assert_called_once()
{% endif %}
{% endif %}{# streaming #}


def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport
{% endif %}
{% if 'rest' in opts.transport %}
from .rest import {{ service.name }}RestTransport
from .rest import {{ service.name }}RestInterceptor
{% endif %}


Expand All @@ -34,6 +35,7 @@ __all__ = (
{% endif %}
{% if 'rest' in opts.transport %}
'{{ service.name }}RestTransport',
'{{ service.name }}RestInterceptor',
{% endif %}
)
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)


class {{ service.name }}RestInterceptor:
"""Interceptor for {{ service.name }}.

Interceptors are used to manipulate requests, request metadata, and responses
in arbitrary ways.
Example use cases include:
* Logging
* Verifying requests according to service or custom semantics
* Stripping extraneous information from responses

These use cases and more can be enabled by injecting an
instance of a custom subclass when constructing the {{ service.name }}RestTransport.

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(response):
logging.log(f"Received response: {response}")
{% endif %}

{% endfor %}
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
client = {{ service.client_name }}(transport=transport)


"""
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the request or metadata
before they are sent to the {{ service.name }} server.
"""
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response
{% endif %}

{% endfor %}


@dataclasses.dataclass
class {{service.name}}RestStub:
_session: AuthorizedSession
_host: str
_interceptor: {{ service.name }}RestInterceptor


class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.
Expand Down Expand Up @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
self._prep_wrapped_messages(client_info)

{% if service.has_lro %}
Expand Down Expand Up @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
},
{% endfor %}{# rule in method.http_options #}
]

request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not method.void %}
# Return the response
{% if method.lro %}
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% else %}
return {{method.output.ident}}.from_json(
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)

{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% if not method.http_options %}
Expand All @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

return stub

Expand Down
Loading