Skip to content

Commit 29fd42c

Browse files
authored
Make optional OAuth2 request parameters configurable (#486)
1 parent 14c021b commit 29fd42c

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

mkdocs/docs/configuration.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ catalog:
170170
| credential | t-1234:secret | Credential to use for OAuth2 credential flow when initializing the catalog |
171171
| token | FEW23.DFSDF.FSDF | Bearer token value to use for `Authorization` header |
172172
| scope | openid offline corpds:ds:profile | Desired scope of the requested security token (default : catalog) |
173+
| resource | rest_catalog.iceberg.com | URI for the target resource or service |
174+
| audience | rest_catalog | Logical name of target resource or service |
173175
| rest.sigv4-enabled | true | Sign requests to the REST Server using AWS SigV4 protocol |
174176
| rest.signing-region | us-east-1 | The region to use when SigV4 signing a request |
175177
| rest.signing-name | execute-api | The service signing name to use when SigV4 signing a request |

pyiceberg/catalog/rest.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ class Endpoints:
105105
CREDENTIAL = "credential"
106106
GRANT_TYPE = "grant_type"
107107
SCOPE = "scope"
108+
AUDIENCE = "audience"
109+
RESOURCE = "resource"
108110
TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange"
109111
SEMICOLON = ":"
110112
KEY = "key"
@@ -289,16 +291,26 @@ def auth_url(self) -> str:
289291
else:
290292
return self.url(Endpoints.get_token, prefixed=False)
291293

294+
def _extract_optional_oauth_params(self) -> Dict[str, str]:
295+
optional_oauth_param = {SCOPE: self.properties.get(SCOPE) or CATALOG_SCOPE}
296+
set_of_optional_params = {AUDIENCE, RESOURCE}
297+
for param in set_of_optional_params:
298+
if param_value := self.properties.get(param):
299+
optional_oauth_param[param] = param_value
300+
301+
return optional_oauth_param
302+
292303
def _fetch_access_token(self, session: Session, credential: str) -> str:
293304
if SEMICOLON in credential:
294305
client_id, client_secret = credential.split(SEMICOLON)
295306
else:
296307
client_id, client_secret = None, credential
297308

298-
# take scope from properties or use default CATALOG_SCOPE
299-
scope = self.properties.get(SCOPE) or CATALOG_SCOPE
309+
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret}
310+
311+
optional_oauth_params = self._extract_optional_oauth_params()
312+
data.update(optional_oauth_params)
300313

301-
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: scope}
302314
response = session.post(
303315
url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"}
304316
)

tests/catalog/test_rest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
TEST_AUTH_URL = "https://auth-endpoint/"
4848
TEST_TOKEN = "some_jwt_token"
4949
TEST_SCOPE = "openid_offline_corpds_ds_profile"
50+
TEST_AUDIENCE = "test_audience"
51+
TEST_RESOURCE = "test_resource"
52+
5053
TEST_HEADERS = {
5154
"Content-type": "application/json",
5255
"X-Client-Version": "0.14.1",
@@ -137,6 +140,48 @@ def test_token_200_without_optional_fields(rest_mock: Mocker) -> None:
137140
)
138141

139142

143+
def test_token_with_optional_oauth_params(rest_mock: Mocker) -> None:
144+
mock_request = rest_mock.post(
145+
f"{TEST_URI}v1/oauth/tokens",
146+
json={
147+
"access_token": TEST_TOKEN,
148+
"token_type": "Bearer",
149+
"expires_in": 86400,
150+
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
151+
},
152+
status_code=200,
153+
request_headers=OAUTH_TEST_HEADERS,
154+
)
155+
assert (
156+
RestCatalog(
157+
"rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience=TEST_AUDIENCE, resource=TEST_RESOURCE
158+
)._session.headers["Authorization"]
159+
== f"Bearer {TEST_TOKEN}"
160+
)
161+
assert TEST_AUDIENCE in mock_request.last_request.text
162+
assert TEST_RESOURCE in mock_request.last_request.text
163+
164+
165+
def test_token_with_optional_oauth_params_as_empty(rest_mock: Mocker) -> None:
166+
mock_request = rest_mock.post(
167+
f"{TEST_URI}v1/oauth/tokens",
168+
json={
169+
"access_token": TEST_TOKEN,
170+
"token_type": "Bearer",
171+
"expires_in": 86400,
172+
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
173+
},
174+
status_code=200,
175+
request_headers=OAUTH_TEST_HEADERS,
176+
)
177+
assert (
178+
RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience="", resource="")._session.headers["Authorization"]
179+
== f"Bearer {TEST_TOKEN}"
180+
)
181+
assert TEST_AUDIENCE not in mock_request.last_request.text
182+
assert TEST_RESOURCE not in mock_request.last_request.text
183+
184+
140185
def test_token_with_default_scope(rest_mock: Mocker) -> None:
141186
mock_request = rest_mock.post(
142187
f"{TEST_URI}v1/oauth/tokens",

0 commit comments

Comments
 (0)