Skip to content

Commit e6fe71a

Browse files
speedstorm1copybara-github
authored andcommitted
chore: process proxy and base url settings for file uploads
PiperOrigin-RevId: 882260641
1 parent 861ab18 commit e6fe71a

2 files changed

Lines changed: 116 additions & 10 deletions

File tree

google/genai/_api_client.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,21 @@ def _upload_fd(
15381538
The HttpResponse object from the finalize request.
15391539
"""
15401540
offset = 0
1541+
http_options = http_options if http_options else self._http_options
1542+
base_url = (
1543+
http_options.get('base_url')
1544+
if isinstance(http_options, dict)
1545+
else getattr(http_options, 'base_url', None)
1546+
)
1547+
if base_url:
1548+
parsed_base = urlparse(base_url)
1549+
parsed_upload = urlparse(upload_url)
1550+
upload_url = urlunparse(
1551+
parsed_upload._replace(
1552+
scheme=parsed_base.scheme, netloc=parsed_base.netloc
1553+
)
1554+
)
1555+
15411556
# Upload the file in chunks
15421557
while True:
15431558
file_chunk = file.read(CHUNK_SIZE)
@@ -1548,7 +1563,6 @@ def _upload_fd(
15481563
# If last chunk, finalize the upload.
15491564
if chunk_size + offset >= upload_size:
15501565
upload_command += ', finalize'
1551-
http_options = http_options if http_options else self._http_options
15521566
timeout = (
15531567
http_options.get('timeout')
15541568
if isinstance(http_options, dict)
@@ -1562,11 +1576,17 @@ def _upload_fd(
15621576
else self._http_options.timeout
15631577
)
15641578
timeout_in_seconds = get_timeout_in_seconds(timeout)
1565-
upload_headers = {
1579+
user_headers = (
1580+
http_options.get('headers', {})
1581+
if isinstance(http_options, dict)
1582+
else (getattr(http_options, 'headers', {}) or {})
1583+
)
1584+
upload_headers = dict(user_headers) if user_headers else {}
1585+
upload_headers.update({
15661586
'X-Goog-Upload-Command': upload_command,
15671587
'X-Goog-Upload-Offset': str(offset),
15681588
'Content-Length': str(chunk_size),
1569-
}
1589+
})
15701590
populate_server_timeout_header(upload_headers, timeout_in_seconds)
15711591
retry_count = 0
15721592
while retry_count < MAX_RETRY_COUNT:
@@ -1689,6 +1709,21 @@ async def _async_upload_fd(
16891709
The HttpResponse object from the finalized request.
16901710
"""
16911711
offset = 0
1712+
http_options = http_options if http_options else self._http_options
1713+
base_url = (
1714+
http_options.get('base_url')
1715+
if isinstance(http_options, dict)
1716+
else getattr(http_options, 'base_url', None)
1717+
)
1718+
if base_url:
1719+
parsed_base = urlparse(base_url)
1720+
parsed_upload = urlparse(upload_url)
1721+
upload_url = urlunparse(
1722+
parsed_upload._replace(
1723+
scheme=parsed_base.scheme, netloc=parsed_base.netloc
1724+
)
1725+
)
1726+
16921727
# Upload the file in chunks
16931728
if self._use_aiohttp(): # pylint: disable=g-import-not-at-top
16941729
self._aiohttp_session = await self._get_aiohttp_session()
@@ -1704,7 +1739,6 @@ async def _async_upload_fd(
17041739
# If last chunk, finalize the upload.
17051740
if chunk_size + offset >= upload_size:
17061741
upload_command += ', finalize'
1707-
http_options = http_options if http_options else self._http_options
17081742
timeout = (
17091743
http_options.get('timeout')
17101744
if isinstance(http_options, dict)
@@ -1718,11 +1752,17 @@ async def _async_upload_fd(
17181752
else self._http_options.timeout
17191753
)
17201754
timeout_in_seconds = get_timeout_in_seconds(timeout)
1721-
upload_headers = {
1755+
user_headers = (
1756+
http_options.get('headers', {})
1757+
if isinstance(http_options, dict)
1758+
else (getattr(http_options, 'headers', {}) or {})
1759+
)
1760+
upload_headers = dict(user_headers) if user_headers else {}
1761+
upload_headers.update({
17221762
'X-Goog-Upload-Command': upload_command,
1723-
'X-Goog-Upload-Offset': str(offset),
1763+
'X-Goog-Upload-Offset': str(offset),
17241764
'Content-Length': str(chunk_size),
1725-
}
1765+
})
17261766
populate_server_timeout_header(upload_headers, timeout_in_seconds)
17271767

17281768
retry_count = 0
@@ -1780,7 +1820,6 @@ async def _async_upload_fd(
17801820
# If last chunk, finalize the upload.
17811821
if chunk_size + offset >= upload_size:
17821822
upload_command += ', finalize'
1783-
http_options = http_options if http_options else self._http_options
17841823
timeout = (
17851824
http_options.get('timeout')
17861825
if isinstance(http_options, dict)
@@ -1794,11 +1833,17 @@ async def _async_upload_fd(
17941833
else self._http_options.timeout
17951834
)
17961835
timeout_in_seconds = get_timeout_in_seconds(timeout)
1797-
upload_headers = {
1836+
user_headers = (
1837+
http_options.get('headers', {})
1838+
if isinstance(http_options, dict)
1839+
else (getattr(http_options, 'headers', {}) or {})
1840+
)
1841+
upload_headers = dict(user_headers) if user_headers else {}
1842+
upload_headers.update({
17981843
'X-Goog-Upload-Command': upload_command,
17991844
'X-Goog-Upload-Offset': str(offset),
18001845
'Content-Length': str(chunk_size),
1801-
}
1846+
})
18021847
populate_server_timeout_header(upload_headers, timeout_in_seconds)
18031848

18041849
retry_count = 0

google/genai/tests/client/test_upload_errors.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,34 @@ def _httpx_response(code: int, headers: dict = None, content: bytes = b''):
4545
def client():
4646
return api_client.BaseApiClient(vertexai=False, api_key='test_api_key')
4747

48+
49+
def test_upload_url_rewrite(client: api_client.BaseApiClient):
50+
mock_httpx_client = mock.MagicMock(spec=httpx.Client)
51+
mock_httpx_client.request.side_effect = [
52+
_httpx_response(
53+
200, headers={"X-Goog-Upload-Status": "final"}
54+
), # Upload request succeeding
55+
]
56+
client._httpx_client = mock_httpx_client
57+
58+
http_options = types.HttpOptions(base_url="https://my-proxy.company.com")
59+
60+
with io.BytesIO(b"test") as f:
61+
client._upload_fd(
62+
f,
63+
"https://generativelanguage.googleapis.com/upload/v1beta/files?uploadType=resumable",
64+
4,
65+
http_options=http_options,
66+
)
67+
68+
assert mock_httpx_client.request.call_count == 1
69+
call_args = mock_httpx_client.request.call_args[1]
70+
assert (
71+
call_args["url"]
72+
== "https://my-proxy.company.com/upload/v1beta/files?uploadType=resumable"
73+
)
74+
75+
4876
def test_upload_fd_error(client: api_client.BaseApiClient):
4977
error_content = json.dumps({
5078
"error": {
@@ -71,6 +99,39 @@ def test_upload_fd_error(client: api_client.BaseApiClient):
7199

72100
assert mock_httpx_client.request.call_count == 2
73101

102+
103+
@pytest.mark.asyncio
104+
async def test_async_upload_url_rewrite_httpx(client: api_client.BaseApiClient):
105+
mock_async_httpx_client = mock.MagicMock(spec=httpx.AsyncClient)
106+
mock_async_httpx_client.request = mock.AsyncMock(
107+
side_effect=[
108+
_httpx_response(
109+
200, headers={"X-Goog-Upload-Status": "final"}
110+
), # Upload request
111+
]
112+
)
113+
client._async_httpx_client = mock_async_httpx_client
114+
115+
http_options = types.HttpOptions(base_url="https://my-proxy.company.com")
116+
117+
with mock.patch.object(
118+
client, "_use_aiohttp", return_value=False
119+
), io.BytesIO(b"test") as f:
120+
await client._async_upload_fd(
121+
f,
122+
"https://generativelanguage.googleapis.com/upload/v1beta/files?uploadType=resumable",
123+
4,
124+
http_options=http_options,
125+
)
126+
127+
assert mock_async_httpx_client.request.call_count == 1
128+
call_args = mock_async_httpx_client.request.call_args[1]
129+
assert (
130+
call_args["url"]
131+
== "https://my-proxy.company.com/upload/v1beta/files?uploadType=resumable"
132+
)
133+
134+
74135
@pytest.mark.asyncio
75136
async def test_async_upload_fd_error_httpx(client: api_client.BaseApiClient):
76137
error_content = json.dumps({

0 commit comments

Comments
 (0)