Skip to content

Commit 74ab329

Browse files
authored
Pass module API to OIDC mapping provider (#16974)
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.
1 parent 05489d8 commit 74ab329

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

changelog.d/16974.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.

docs/sso_mapping_providers.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead.
5050

5151
A custom mapping provider must specify the following methods:
5252

53-
* `def __init__(self, parsed_config)`
53+
* `def __init__(self, parsed_config, module_api)`
5454
- Arguments:
5555
- `parsed_config` - A configuration object that is the return value of the
5656
`parse_config` method. You should set any configuration options needed by
5757
the module here.
58+
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
59+
stable API available for extension modules.
5860
* `def parse_config(config)`
5961
- This method should have the `@staticmethod` decoration.
6062
- Arguments:

synapse/handlers/oidc.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from synapse.http.servlet import parse_string
6666
from synapse.http.site import SynapseRequest
6767
from synapse.logging.context import make_deferred_yieldable
68+
from synapse.module_api import ModuleApi
6869
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
6970
from synapse.util import Clock, json_decoder
7071
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
@@ -421,9 +422,19 @@ def __init__(
421422
# from the IdP's jwks_uri, if required.
422423
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
423424

424-
self._user_mapping_provider = provider.user_mapping_provider_class(
425-
provider.user_mapping_provider_config
425+
user_mapping_provider_init_method = (
426+
provider.user_mapping_provider_class.__init__
426427
)
428+
if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
429+
self._user_mapping_provider = provider.user_mapping_provider_class(
430+
provider.user_mapping_provider_config,
431+
ModuleApi(hs, hs.get_auth_handler()),
432+
)
433+
else:
434+
self._user_mapping_provider = provider.user_mapping_provider_class(
435+
provider.user_mapping_provider_config,
436+
)
437+
427438
self._skip_verification = provider.skip_verification
428439
self._allow_existing_users = provider.allow_existing_users
429440

@@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
15831594
This is the default mapping provider.
15841595
"""
15851596

1586-
def __init__(self, config: JinjaOidcMappingConfig):
1597+
def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
15871598
self._config = config
15881599

15891600
@staticmethod

0 commit comments

Comments
 (0)