Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from google.auth import _helpers
from google.auth import environment_vars

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: NO COVER
import google.auth.credentials
import google.auth.transport

Expand Down Expand Up @@ -455,6 +455,61 @@ def start_refresh(self, credentials, request, rab_manager):
self._worker.start()


def _prepare_async_lookup_callable(request):
"""Unwraps a request callable, clones the transport, and returns the new callable.

Args:
request: The original request callable (e.g. functools.partial or raw Request).

Returns:
Tuple[Callable, Any, bool]: A tuple containing the new lookup callable, the
underlying request object, and a boolean indicating if it was cloned.
"""
is_partial = isinstance(request, functools.partial)
base_callable = request.func if is_partial else request

if not hasattr(base_callable, "_clone"):
return request, base_callable, False

cloned_callable = base_callable._clone()
is_cloned = cloned_callable is not base_callable

if is_partial:
new_request = functools.partial(
cloned_callable, *request.args, **request.keywords
)
else:
new_request = cloned_callable

return new_request, cloned_callable, is_cloned


async def _close_cloned_request(lookup_request, is_cloned):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It seems like _prepare_async_lookup_callable and _close_cloned_request would be a good fit for a context manager. That would let us encapsulate these three variables, and enforce automatic closing.

Gemini put this together:

@contextlib.asynccontextmanager
async def _managed_lookup_callable(request):
    """An async context manager that prepares a cloned lookup callable 
    and guarantees its transport is closed on exit.
    """
    lookup_callable, lookup_request, is_cloned = _prepare_async_lookup_callable(request)
    try:
        yield lookup_callable
    finally:
        await _close_cloned_request(lookup_request, is_cloned)


# ... Inside your class/function where _worker is defined:

async def _worker():
    try:
        async with _managed_lookup_callable(request) as lookup_callable:
            regional_access_boundary_info = (
                await credentials._lookup_regional_access_boundary(lookup_callable)
            )
    except Exception as e:
        if _helpers.is_logging_enabled(_LOGGER):
            _LOGGER.warning(
                "Failed regional access boundary lookup: %s", 
                e, 
                exc_info=True
            )
        regional_access_boundary_info = None

But this is just a suggestion that came to mind, I think it's fine to merge as-is too.

"""Safely closes the underlying cloned request transport, if applicable.

Args:
lookup_request (Any): The request object/transport to close.
is_cloned (bool): Whether the request was actually cloned.
"""
if not is_cloned or not hasattr(lookup_request, "close"):
return

is_async = False
try:
maybe_coro = lookup_request.close()
if is_async := inspect.iscoroutine(maybe_coro):
Comment on lines +499 to +500

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we are missing valid awaitable cases here as written (e.g. if Future returned by custom transports)

await maybe_coro
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
adapter_type = " asynchronous " if is_async else " "
_LOGGER.warning(
"Failed to cleanly close cloned%srequest transport: %s",
adapter_type,
e,
exc_info=True,
)


class _AsyncRegionalAccessBoundaryRefreshManager(object):
"""Manages a task for background refreshing of the Regional Access Boundary in async flows."""

Expand Down Expand Up @@ -492,10 +547,18 @@ def start_refresh(self, credentials, request, rab_manager):
return

async def _worker():
lookup_request = None
is_cloned = False
try:
# credentials._lookup_regional_access_boundary should be async in the async creds class
(
lookup_callable,
lookup_request,
is_cloned,
) = _prepare_async_lookup_callable(request)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ClientSession can still be closed and the request will still fail with the current implementation as with the request being created inside the background worker, start_refresh() will potentially return before _clone() runs.

regional_access_boundary_info = (
await credentials._lookup_regional_access_boundary(request)
await credentials._lookup_regional_access_boundary(
lookup_callable
)
)
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
Expand All @@ -505,6 +568,8 @@ async def _worker():
exc_info=True,
)
regional_access_boundary_info = None
finally:
await _close_cloned_request(lookup_request, is_cloned)

rab_manager.process_regional_access_boundary_info(
regional_access_boundary_info
Expand Down
10 changes: 10 additions & 0 deletions packages/google-auth/google/auth/aio/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,13 @@ async def close(self) -> None:
Close the underlying session.
"""
raise NotImplementedError("close must be implemented.")

def _clone(self) -> "Request":
"""Creates a copy of this request adapter.

The base implementation returns `self` (an identical shared instance).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I still think a name other than clone should be considered. Gemini suggests _isolate() or _branch(). But this doesn't matter too much if it's internal

Transport adapters that maintain internal connection pools or stateful
sessions must override this method to return an independent, detached
adapter instance.
"""
return self
Comment thread
nbayati marked this conversation as resolved.
73 changes: 72 additions & 1 deletion packages/google-auth/google/auth/aio/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
else:
try:
from aiohttp import ClientTimeout
except (ImportError, AttributeError):
except (ImportError, AttributeError): # pragma: NO COVER
ClientTimeout = None

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -203,3 +203,74 @@ async def close(self) -> None:
if not self._closed and self._session:
await self._session.close()
self._closed = True

def _clone(self) -> "Request":
"""Creates an independent copy of this request adapter.

Returns:
google.auth.aio.transport.aiohttp.Request: A request adapter copy
running a new aiohttp.ClientSession with identical connection,
proxy, and session configurations.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring can be reworded. I think the current implementation of clone does not actually clone everything. Also doesn't mention unsupported cases.

"""
if self._closed:
raise exceptions.TransportError("Cannot clone a closed transport.")

if not self._session:
new_session = aiohttp.ClientSession(
auto_decompress=False,
trust_env=True,
)
return Request(session=new_session)

session_kwargs: dict = {
"auto_decompress": False,
"trust_env": getattr(self._session, "_trust_env", True),
}

# Copy underlying connection pool settings (SSL context, IP bindings, limits).
orig_connector = getattr(self._session, "_connector", None)
if orig_connector and not orig_connector.closed:
if isinstance(orig_connector, aiohttp.TCPConnector):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _clone() implementation explicitly checks only for standard TCPConnector and UnixConnector instances. If the original session is configured with a custom, proxy, or subclassed connector (such as corporate SOCKS or tunneling proxies), the check falls through and the cloned session is created with a default, direct-connection TCPConnector.

This silently drops the proxy/custom configuration and routes traffic directly over the public internet, which will fail or violate security constraints in enterprise/isolated cloud environments. We should either explicitly support proxy preservation or raise a clear transport exception if an unsupported custom connector is detected.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a really good point.

We can't really support proxy preservation because third-party aiohttp connectors have arbitrary, unknown constructor signatures (meaning we have no way to instantiate a fresh detached copy of them dynamically), and simply shallow-copying the existing connector is unsafe due to shared socket pools. This leaves us two options: fallback to re-using the customer transport and hope that we don't encounter the bug this PR is trying to fix, or raise the exception as you suggested and accept this as a limitation of RAB.

I've decided not to fallback to re-using the customer's transport if we can't clone it, because it's not just that the RAB call would fail, but also there's another risk: if the foreground task closes the session while the background worker is actively reading from it, the forceful socket truncation mid-flight can leave complex corporate proxy connections in a hung or corrupted state, which means that the affects won't be limited to our RAB calls. So I've added the else: raise exceptions.TransportError(...) block, as raising the error here is the safest path. The exception will trigger the 15-minute cooldown and allow the user's main request to proceed safely.

I thought about disabling RAB permanently if we can't clone the transport (thinking what's the point of entering cooldown if we're going to keep trying to clone it and fail), but decided against it. I realized that because credentials objects are frequently instantiated globally and shared across multiple different clients and API surfaces, there's a chance that the next call would be executed over entirely different transports, making the RAB call possible.

# We explicitly do not copy the resolver. The connector
# owns the resolver, and closing the cloned session would
# close the shared resolver, breaking the original session.
session_kwargs["connector"] = aiohttp.TCPConnector(
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
limit=getattr(orig_connector, "_limit", 100),
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
force_close=getattr(orig_connector, "_force_close", False),
local_addr=getattr(orig_connector, "_local_addr", None),
)
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
orig_connector, getattr(aiohttp, "UnixConnector")
):
path = getattr(orig_connector, "_path", None)
if path:
session_kwargs["connector"] = aiohttp.UnixConnector(
path=path,
limit=getattr(orig_connector, "_limit", 100),
force_close=getattr(orig_connector, "_force_close", False),
)
else:
raise exceptions.TransportError(
f"Unsupported connector type for cloning: {type(orig_connector)}"
)

# Preserve distributed tracing configurations.
trace_configs = getattr(self._session, "_trace_configs", None)
if trace_configs:
session_kwargs["trace_configs"] = list(trace_configs)

# Copy session-level defaults (headers, cookies, auth, timeout).
for attr_name, kwarg_name in [
("_default_headers", "headers"),
("_cookie_jar", "cookie_jar"),
("_default_auth", "auth"),
("_timeout", "timeout"),
("_json_serialize", "json_serialize"),
]:
val = getattr(self._session, attr_name, None)
if val is not None:
session_kwargs[kwarg_name] = val

return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore
77 changes: 77 additions & 0 deletions packages/google-auth/google/auth/transport/_aiohttp_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,83 @@ async def __call__(
new_exc = exceptions.TransportError(caught_exc)
raise new_exc from caught_exc

def _clone(self):
"""Create an independent detached copy of this request adapter.

Returns:
google.auth.transport._aiohttp_requests.Request: An independent request adapter
running an isolated aiohttp.ClientSession with identical environment proxy and
observability configurations.
"""
if getattr(self, "_closed", False):
raise exceptions.TransportError("Cannot clone a closed transport.")

if not self.session:
new_session = aiohttp.ClientSession(
auto_decompress=False,
trust_env=True,
)
return Request(session=new_session)

session_kwargs: dict = {
"auto_decompress": False,
"trust_env": getattr(self.session, "_trust_env", True),
}

# Copy underlying connection pool settings (SSL context, IP bindings, limits).
orig_connector = getattr(self.session, "_connector", None)
if orig_connector and not getattr(orig_connector, "closed", True):
if isinstance(orig_connector, aiohttp.TCPConnector):
# We explicitly do not copy the resolver. The connector
# owns the resolver, and closing the cloned session would
# close the shared resolver, breaking the original session.
session_kwargs["connector"] = aiohttp.TCPConnector(
ssl=getattr(orig_connector, "_ssl", None), # type: ignore
limit=getattr(orig_connector, "_limit", 100),
limit_per_host=getattr(orig_connector, "_limit_per_host", 0),
force_close=getattr(orig_connector, "_force_close", False),
local_addr=getattr(orig_connector, "_local_addr", None),
)
elif getattr(aiohttp, "UnixConnector", None) and isinstance(
orig_connector, getattr(aiohttp, "UnixConnector")
):
path = getattr(orig_connector, "_path", None)
if path:
session_kwargs["connector"] = aiohttp.UnixConnector(
path=path,
limit=getattr(orig_connector, "_limit", 100),
force_close=getattr(orig_connector, "_force_close", False),
)
else:
raise exceptions.TransportError(
f"Unsupported connector type for cloning: {type(orig_connector)}"
)

# Preserve distributed tracing configurations.
trace_configs = getattr(self.session, "_trace_configs", None)
if trace_configs:
session_kwargs["trace_configs"] = list(trace_configs)

# Copy session-level defaults (headers, cookies, auth, timeout).
for attr_name, kwarg_name in [
("_default_headers", "headers"),
("_cookie_jar", "cookie_jar"),
("_default_auth", "auth"),
("_timeout", "timeout"),
("_json_serialize", "json_serialize"),
]:
val = getattr(self.session, attr_name, None)
if val is not None:
session_kwargs[kwarg_name] = val

return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore

async def close(self):
"""Cleanly release the underlying aiohttp ClientSession resources."""
if not getattr(self, "_closed", False) and self.session:
await self.session.close()
self._closed = True


class AuthorizedSession(aiohttp.ClientSession):
"""This is an async implementation of the Authorized Session class. We utilize an
Expand Down
Loading
Loading