-
Notifications
You must be signed in to change notification settings - Fork 1.7k
feat(google-auth): grpc cert rotation handling #16597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
54a513a
b0eafcf
271c279
1d6a2fe
cbd868e
bcb6e41
f6c49cb
fe8deed
7d39f3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
| from __future__ import absolute_import | ||
|
|
||
| import logging | ||
|
|
||
| import threading | ||
| from google.auth import exceptions | ||
| from google.auth.transport import _mtls_helper | ||
| from google.oauth2 import service_account | ||
|
|
@@ -208,7 +208,7 @@ def my_client_cert_callback(): | |
|
|
||
| channel = google.auth.transport.grpc.secure_authorized_channel( | ||
| credentials, request, mtls_endpoint) | ||
|
|
||
| Args: | ||
| credentials (google.auth.credentials.Credentials): The credentials to | ||
| add to requests. | ||
|
|
@@ -253,6 +253,7 @@ def my_client_cert_callback(): | |
| ) | ||
|
|
||
| # If SSL credentials are not explicitly set, try client_cert_callback and ADC. | ||
| cached_cert = None | ||
| if not ssl_credentials: | ||
| use_client_cert = _mtls_helper.check_use_client_cert() | ||
| if use_client_cert and client_cert_callback: | ||
|
|
@@ -261,20 +262,40 @@ def my_client_cert_callback(): | |
| ssl_credentials = grpc.ssl_channel_credentials( | ||
| certificate_chain=cert, private_key=key | ||
| ) | ||
| cached_cert = cert | ||
| elif use_client_cert: | ||
| # Use application default SSL credentials. | ||
| adc_ssl_credentils = SslCredentials() | ||
| ssl_credentials = adc_ssl_credentils.ssl_credentials | ||
| cached_cert = adc_ssl_credentils._cached_cert | ||
|
agrawalradhika-cell marked this conversation as resolved.
Outdated
|
||
| else: | ||
| ssl_credentials = grpc.ssl_channel_credentials() | ||
|
|
||
| # Combine the ssl credentials and the authorization credentials. | ||
| composite_credentials = grpc.composite_channel_credentials( | ||
| ssl_credentials, google_auth_credentials | ||
| ) | ||
|
|
||
| return grpc.secure_channel(target, composite_credentials, **kwargs) | ||
|
|
||
| is_retry = kwargs.pop("_is_retry", False) | ||
| channel = grpc.secure_channel(target, composite_credentials, **kwargs) | ||
| # Check if we are already inside a retry to avoid infinite recursion | ||
| if cached_cert and not is_retry: | ||
| # Package arguments to recreate the channel if rotation occurs | ||
| factory_args = { | ||
| "credentials": credentials, | ||
| "request": request, | ||
| "target": target, | ||
| "ssl_credentials": None, | ||
| "client_cert_callback": client_cert_callback, | ||
| "_is_retry": True, # Hidden flag to stop recursion | ||
| **kwargs | ||
| } | ||
| interceptor = _MTLSCallInterceptor(cached_cert) | ||
|
agrawalradhika-cell marked this conversation as resolved.
Outdated
|
||
|
|
||
| wrapper = _MTLSRefreshingChannel(target, factory_args, channel, cached_cert) | ||
|
|
||
| interceptor._wrapper = wrapper | ||
| return grpc.intercept_channel(wrapper, interceptor) | ||
| return channel | ||
|
|
||
| class SslCredentials: | ||
| """Class for application default SSL credentials. | ||
|
|
@@ -292,6 +313,7 @@ class SslCredentials: | |
|
|
||
| def __init__(self): | ||
| use_client_cert = _mtls_helper.check_use_client_cert() | ||
| self._cached_cert = None | ||
| if not use_client_cert: | ||
| self._is_mtls = False | ||
| else: | ||
|
|
@@ -323,6 +345,7 @@ def ssl_credentials(self): | |
| self._ssl_credentials = grpc.ssl_channel_credentials( | ||
| certificate_chain=cert, private_key=key | ||
| ) | ||
| self._cached_cert = cert | ||
| except exceptions.ClientCertError as caught_exc: | ||
| new_exc = exceptions.MutualTLSChannelError(caught_exc) | ||
| raise new_exc from caught_exc | ||
|
|
@@ -335,3 +358,78 @@ def ssl_credentials(self): | |
| def is_mtls(self): | ||
| """Indicates if the created SSL channel credentials is mutual TLS.""" | ||
| return self._is_mtls | ||
|
|
||
| class _MTLSCallInterceptor(grpc.UnaryUnaryClientInterceptor): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| def __init__(self, cached_cert): | ||
| self._cached_cert = cached_cert | ||
|
agrawalradhika-cell marked this conversation as resolved.
Outdated
|
||
| self._wrapper = None | ||
| self._max_retries = 2 # Set your desired limit here | ||
|
|
||
| def _should_retry(self, code, retry_count): | ||
| if code != grpc.StatusCode.UNAUTHENTICATED or not self._wrapper: | ||
| return False | ||
|
|
||
| if retry_count >= self._max_retries: | ||
| _LOGGER.debug("Max retries reached (%d/%d).", retry_count, self._max_retries) | ||
| return False | ||
|
|
||
| # Fingerprint check logic | ||
| _, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._wrapper._cached_cert) | ||
| return cached_fp != current_fp | ||
|
|
||
| def intercept_unary_unary(self, continuation, client_call_details, request): | ||
| retry_count = 0 | ||
|
|
||
| while True: | ||
| try: | ||
| # Every time we call continuation(), our Wrapper (which is the channel | ||
| # being intercepted) will point to its CURRENT active raw channel. | ||
| response = continuation(client_call_details, request) | ||
| status_code = response.code() | ||
| except grpc.RpcError as e: | ||
| status_code = e.code() | ||
| if not self._should_retry(status_code, retry_count): | ||
| raise e | ||
| # If we should retry, we fall through to the refresh logic below | ||
|
|
||
| if self._should_retry(status_code, retry_count): | ||
| retry_count += 1 | ||
| # Tell the wrapper to swap the channel. | ||
| # We don't need the wrapper to execute the retry; the loop does it! | ||
| self._wrapper.refresh_logic(retry_count) | ||
| continue # Jump back to the start of the while loop | ||
|
|
||
| return response | ||
|
|
||
| class _MTLSRefreshingChannel(grpc.Channel): | ||
| def __init__(self, target, factory_args, initial_channel, initial_cert): | ||
| self._target = target | ||
| self._factory_args = factory_args | ||
| self._channel = initial_channel | ||
| self._cached_cert = initial_cert | ||
| self._lock = threading.Lock() | ||
|
Comment on lines
+406
to
+411
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
def __init__(self, target, factory_args, initial_channel, initial_cert):
self._target = target
self._factory_args = factory_args
self._channel = initial_channel
self._cached_cert = initial_cert
self._lock = threading.Lock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close() |
||
|
|
||
| def refresh_logic(self, count): | ||
| with self._lock: | ||
| # Re-check inside lock to prevent race conditions | ||
| _, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._cached_cert) | ||
| if cached_fp != current_fp: | ||
| print(f"Wrapper: Refreshing mTLS channel. Retry count: {count}") | ||
|
agrawalradhika-cell marked this conversation as resolved.
Outdated
|
||
| old_channel = self._channel | ||
| self._channel = secure_authorized_channel(**self._factory_args) | ||
|
|
||
| creds = _mtls_helper.get_client_ssl_credentials() | ||
| self._cached_cert = creds[1] | ||
| old_channel.close() | ||
|
|
||
| def unary_unary(self, method, *args, **kwargs): | ||
| # Always return a callable from the CURRENT channel | ||
| return self._channel.unary_unary(method, *args, **kwargs) | ||
|
|
||
| # Mandatory passthroughs | ||
| def unary_stream(self, method, *args, **kwargs): return self._channel.unary_stream(method, *args, **kwargs) | ||
| def stream_unary(self, method, *args, **kwargs): return self._channel.stream_unary(method, *args, **kwargs) | ||
| def stream_stream(self, method, *args, **kwargs): return self._channel.stream_stream(method, *args, **kwargs) | ||
| def subscribe(self, *args, **kwargs): return self._channel.subscribe(*args, **kwargs) | ||
| def unsubscribe(self, *args, **kwargs): return self._channel.unsubscribe(*args, **kwargs) | ||
| def close(self): self._channel.close() | ||
Uh oh!
There was an error while loading. Please reload this page.