Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
53 changes: 45 additions & 8 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import saml2.response
from saml2.client import Saml2Client

from synapse.api.errors import AuthError, SynapseError
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.server import respond_with_html
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(self, hs: "synapse.server.HomeServer"):
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
self._error_template = hs.config.sso_error_template

# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
Expand All @@ -84,6 +86,25 @@ def __init__(self, hs: "synapse.server.HomeServer"):
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)

def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and respond with it.

This is used to show errors to the user. The template of this page can
be found under ``synapse/res/templates/sso_error.html``.

Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)

def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
Expand Down Expand Up @@ -146,12 +167,23 @@ async def handle_saml_response(self, request: SynapseRequest) -> None:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
raise SynapseError(400, "Unexpected SAML2 login.")
self._render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
self._render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
)
return

if saml2_auth.not_signed:
raise SynapseError(400, "SAML2 response was not signed.")
self._render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return

logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
Expand All @@ -171,14 +203,19 @@ async def handle_saml_response(self, request: SynapseRequest) -> None:
)

for requirement in self._saml2_attribute_requirements:
_check_attribute_requirement(saml2_auth.ava, requirement)
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return

# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)

# Call the mapper to register/login the user
user_id = await self._map_saml_response_to_user(
resp_bytes, relay_state, user_agent, ip_address
)
Expand Down Expand Up @@ -324,19 +361,19 @@ def expire_sessions(self):
del self._outstanding_requests_dict[reqid]


def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return
return True

logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
raise AuthError(403, "You are not authorized to log in here.")
return False


DOT_REPLACE_PATTERN = re.compile(
Expand Down