diff --git a/dirac.cfg b/dirac.cfg index 9af77571bde..7f3aaaad8ae 100644 --- a/dirac.cfg +++ b/dirac.cfg @@ -57,6 +57,9 @@ Registry # Real VOMS VO name, if this VO is associated with VOMS VO VOMSName = lhcb + # Registered identity provider associated with VO + IdP = CheckIn + # Section to describe all the VOMS servers that can be used with the given VOMS VO VOMSServers { @@ -99,6 +102,9 @@ Registry # Role of the users in the VO VOMSRole = /lhcb + + # Scope associated with a role of the user in the VO + IdPRole = some_special_scope # Virtual organization associated with the group VOMSVO = lhcb @@ -418,6 +424,20 @@ Systems } Resources { + IdProviders + { + CheckIn + { + # What supported type of provider does it belong to + ProviderType = OAuth2 + # Description of the client parameters registered on the identity provider side. + # Look here for information about client parameters description https://tools.ietf.org/html/rfc8414#section-2 + issuer = https://aai-dev.egi.eu/oidc + client_id = type_client_id_here_receved_after_client_registration + client_secret = type_client_secret_here_receved_after_client_registration + scope = openid, profile, offline_access, eduperson_entitlement, cert_entitlement + } + } # Section for proxy providers, subsections is the names of the proxy providers # https://dirac.readthedocs.org/en/latest/AdministratorGuide/Resources/proxyprovider.html diff --git a/docs/source/AdministratorGuide/CommandReference/index.rst b/docs/source/AdministratorGuide/CommandReference/index.rst index 1ca4c9fe267..ec640b98cba 100644 --- a/docs/source/AdministratorGuide/CommandReference/index.rst +++ b/docs/source/AdministratorGuide/CommandReference/index.rst @@ -155,6 +155,9 @@ Other commands: .. toctree:: :maxdepth: 2 + dirac-login + dirac-logout + dirac-admin-accounting-cli dirac-admin-sysadmin-cli diff --git a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst index 4725a3903f5..0cb12e4fbef 100644 --- a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst +++ b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst @@ -65,3 +65,12 @@ DIRAC_X509_HOST_KEY X509_VOMSES Must be set to point to a folder containing VOMSES information. See :ref:`multi_vo_dirac` + +BEARER_TOKEN + If the environment variable is set, then the value is taken to be the token contents (https://doi.org/10.5281/zenodo.3937438). + +BEARER_TOKEN_FILE + If the environment variable is set, then its value is interpreted as a filename. The content of the specified file is used as token string (https://doi.org/10.5281/zenodo.3937438). + +DIRAC_USE_ACCESS_TOKEN + If this environment is set to ``true`` or ``yes``, the concurrent.futures.ThreadPoolExecutor will be used (default=false) diff --git a/environment-py3.yml b/environment-py3.yml index 169d35228c7..ebe05a23909 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,6 +89,10 @@ dependencies: - typing >=3.6.6 - pyyaml - pip: + # Prerelease of the required package for integration of OAuth2 + - Authlib>=1.0.0.a2 + - dominate + - pyjwt # This is a fork of tornado with a patch to allow for configurable iostream # It should eventually be part of DIRACGrid - git+https://github.com/DIRACGrid/tornado.git@iostreamConfigurable diff --git a/environment.yml b/environment.yml index 3413b74c980..fcb11ffd10e 100644 --- a/environment.yml +++ b/environment.yml @@ -74,6 +74,10 @@ dependencies: - selectors2 - pip: - diraccfg + # OAuth2 + - dominate + - authlib + - pyjwt # This is a fork of tornado with a patch to allow for configurable iostream - git+https://github.com/DIRACGrid/tornado.git@iostreamConfigurable # This is an extension of Tornado to use M2Crypto diff --git a/setup.cfg b/setup.cfg index 93674cb8269..b31ca03e79d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,9 @@ install_requires = six sqlalchemy subprocess32 + Authlib >=1.0.0.a2 + pyjwt + dominate zip_safe = False include_package_data = True @@ -92,6 +95,7 @@ testing = parameterized pytest pytest-cov + pytest-mock pycodestyle [options.entry_points] @@ -158,6 +162,8 @@ console_scripts = dirac-dms-user-lfns = DIRAC.DataManagementSystem.scripts.dirac_dms_user_lfns:main dirac-dms-user-quota = DIRAC.DataManagementSystem.scripts.dirac_dms_user_quota:main # FrameworkSystem + dirac-login = DIRAC.FrameworkSystem.scripts.dirac_login:main + dirac-logout = DIRAC.FrameworkSystem.scripts.dirac_logout:main dirac-admin-get-CAs = DIRAC.FrameworkSystem.scripts.dirac_admin_get_CAs:main [server] dirac-admin-get-proxy = DIRAC.FrameworkSystem.scripts.dirac_admin_get_proxy:main [admin] dirac-admin-proxy-upload = DIRAC.FrameworkSystem.scripts.dirac_admin_proxy_upload:main [admin] diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index e9377b1915a..750b7287628 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -13,6 +13,8 @@ __RCSID__ = "$Id$" +ID_DN_PREFIX = "/O=DIRAC/CN=" + # pylint: disable=missing-docstring gBaseRegistrySection = "/Registry" @@ -428,6 +430,16 @@ def getVOForGroup(group): return getVO() or gConfig.getValue("%s/Groups/%s/VO" % (gBaseRegistrySection, group), "") +def getIdPForGroup(group): + """ Get identity provider for group VO + + :param str group: group name + + :return: str + """ + return getVOOption(getVOForGroup(group), 'IdP') + + def getDefaultVOMSAttribute(): """ Get default VOMS attribute @@ -697,3 +709,25 @@ def getEmailsForGroup(groupName): email = getUserOption(username, 'Email', []) emails.append(email) return emails + + +def wrapIDAsDN(userID): + """ Wrap user ID as user DN + + :param str userID: user ID + + :return: str + """ + return '/O=DIRAC/CN=' + userID + + +def getIDFromDN(userDN): + """ Parse user ID from user DN + + :param str userDN: user DN + + :return: S_OK(str)/S_ERROR() + """ + if not userDN.startswith(ID_DN_PREFIX): + return S_ERROR("%s DN does not contain user ID." % userDN) + return S_OK(userDN[len(ID_DN_PREFIX):]) diff --git a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py index 20e63aac9f7..097f8b28eb4 100755 --- a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py +++ b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py @@ -95,6 +95,17 @@ def getComponentSection(system, component=False, setup=False, componentCategory= ) +def getAPISection(system, endpointName=False, setup=False): + """ Get API section in a system + + :param str system: system name + :param str endpointName: endpoint name + + :return: str + """ + return getComponentSection(system, component=endpointName, setup=setup, componentCategory="APIs") + + def getServiceSection(system, serviceName=False, setup=False): """ Get service section in a system diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index d4b9cef6bab..f1138f6e386 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -20,6 +20,7 @@ from DIRAC.ConfigurationSystem.Client.PathFinder import getDatabaseSection from DIRAC.Core.Utilities.Glue2 import getGlue2CEInfo from DIRAC.Core.Utilities.SiteSEMapping import getSEHosts +from DIRAC.ConfigurationSystem.Client.PathFinder import getSystemInstance from DIRAC.DataManagementSystem.Utilities.DMSHelpers import DMSHelpers @@ -536,16 +537,6 @@ def getElasticDBParameters(fullname): return S_OK(parameters) -def getOAuthAPI(instance='Production'): - """ Get OAuth API url - - :param str instance: instance - - :return: str - """ - return gConfig.getValue("/Systems/Framework/%s/URLs/OAuthAPI" % instance) - - def getDIRACGOCDictionary(): """ Create a dictionary containing DIRAC site names and GOCDB site names @@ -578,3 +569,51 @@ def getDIRACGOCDictionary(): log.debug('End function.') return S_OK(dictionary) + + +def getAuthAPI(): + """ Get Auth REST API url + + :return: str + """ + return gConfig.getValue("/Systems/Framework/%s/URLs/AuthAPI" % getSystemInstance("Framework")) + + +def getAuthorizationServerMetadata(issuer=None, ignoreErrors=False): + """ Get authorization server metadata + + :param str issuer: issuer + :param bool ignoreErrors: igrnore configuration errors + + :return: S_OK(dict)/S_ERROR() + """ + data = {} + try: + result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization') + if not result['OK']: + return S_OK({'issuer': issuer}) if issuer else result + data = result['Value'] + except Exception as e: + if ignoreErrors: + gLogger.warn(repr(e)) + else: + raise e + + # Search DIRAC Authorization Server issuer + data['issuer'] = data.get('issuer', issuer) + if not data['issuer']: + try: + data['issuer'] = getAuthAPI() + except Exception as e: + return S_ERROR('No issuer found in DIRAC authorization server: %s' % repr(e)) + + return S_OK(data) if data['issuer'] else S_ERROR('Cannot find DIRAC Authorization Server issuer.') + + +def isDownloadablePersonalProxy(): + """ Get downloadablePersonalProxy flag + + :return: S_OK(bool)/S_ERROR() + """ + cs_path = '/Systems/Framework/%s/APIs/Auth' % getSystemInstance("Framework") + return gConfig.getValue(cs_path + '/downloadablePersonalProxy', "false").lower() in ("y", "yes", "true") diff --git a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py index 7443a13ba43..841f2c23186 100644 --- a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py +++ b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py @@ -4,6 +4,7 @@ __RCSID__ = "$Id$" +from six import PY3 import time from tornado import gen @@ -76,9 +77,13 @@ def __refreshLoop(self): yield gen.sleep(gConfigurationData.getPropagationTime()) # Publish step is blocking so we have to run it in executor # If we are not doing it, when master try to ping we block the IOLoop - yield _IOLoop.current().run_in_executor(None, self.__AutoRefresh) - @gen.coroutine + # When switching from python 2 to python 3, the following error occurs: + # RuntimeError: There is no current event loop in thread.. + # The reason seems to be that asyncio.get_event_loop() is called in some thread other than the main thread, + # asyncio only generates an event loop for the main thread. + yield _IOLoop.current().run_in_executor(None, self.__AutoRefresh if PY3 else gen.coroutine(self.__AutoRefresh)) + def __AutoRefresh(self): """ Auto refresh the configuration diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 5f2fb1b7a95..a501726214c 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -34,6 +34,7 @@ import errno import requests import six +import os from six.moves import http_client @@ -46,9 +47,13 @@ from DIRAC.Core.DISET.ThreadConfig import ThreadConfig from DIRAC.Core.Security import Locations -from DIRAC.Core.Utilities import Network +from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode +if six.PY3: + # DIRACOS not contain required packages + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile # TODO CHRIS: refactor all the messy `discover` methods # I do not do it now because I want first to decide @@ -62,6 +67,7 @@ class TornadoBaseClient(object): __threadConfig = ThreadConfig() VAL_EXTRA_CREDENTIALS_HOST = "hosts" + KW_USE_ACCESS_TOKEN = "useAccessToken" KW_USE_CERTIFICATES = "useCertificates" KW_EXTRA_CREDENTIALS = "extraCredentials" KW_TIMEOUT = "timeout" @@ -103,6 +109,8 @@ def __init__(self, serviceName, **kwargs): self.__ca_location = False self.kwargs = kwargs + self.__idp = None + self.__useAccessToken = None self.__useCertificates = None # The CS useServerCertificate option can be overridden by explicit argument self.__forceUseCertificates = self.kwargs.get(self.KW_USE_CERTIFICATES) @@ -202,6 +210,12 @@ def __discoverCredentialsToUse(self): -> if KW_SKIP_CA_CHECK is not in kwargs and we are using the certificates, set KW_SKIP_CA_CHECK to false in kwargs -> if KW_SKIP_CA_CHECK is not in kwargs and we are not using the certificate, check the skipCACheck + * Bearer token: + -> If KW_USE_ACCESS_TOKEN in kwargs, sets it in self.__useAccessToken + -> If not, check "/DIRAC/Security/UseTokens", and sets it in self.__useAccessToken + and kwargs[KW_USE_ACCESS_TOKEN] + -> If not, check 'DIRAC_USE_ACCESS_TOKEN' in os.environ, sets it in self.__useAccessToken + and kwargs[KW_USE_ACCESS_TOKEN] * Proxy Chain WARNING: MOSTLY COPY/PASTE FROM Core/Diset/private/BaseClient @@ -219,6 +233,21 @@ def __discoverCredentialsToUse(self): else: self.kwargs[self.KW_SKIP_CA_CHECK] = skipCACheck() + # Use tokens? + if self.KW_USE_ACCESS_TOKEN in self.kwargs: + self.__useAccessToken = self.kwargs[self.KW_USE_ACCESS_TOKEN] + elif 'DIRAC_USE_ACCESS_TOKEN' in os.environ: + self.__useAccessToken = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") + else: + self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") + self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken + + if self.__useAccessToken and six.PY3: + result = IdProviderFactory().getIdProvider('DIRACCLI') + if not result['OK']: + return result + self.__idp = result['Value'] + # Rewrite a little bit from here: don't need the proxy string, we use the file if self.KW_PROXY_CHAIN in self.kwargs: try: @@ -490,12 +519,38 @@ def _request(self, retry=0, outputFile=None, **kwargs): # getting certificate # Do we use the server certificate ? if self.kwargs[self.KW_USE_CERTIFICATES]: - cert = Locations.getHostCertificateAndKeyLocation() + auth = {'cert': Locations.getHostCertificateAndKeyLocation()} + + # Use access token? + elif self.__useAccessToken and six.PY3: + # Read token from token environ variable or from token file + result = getLocalTokenDict() + if not result['OK']: + return result + token = result['Value'] + + # Check if access token expired + if token.is_expired(): + if not token.get('refresh_token'): + return S_ERROR('Access token expired.') + + # Try to refresh token + self.__idp.scope = None + result = self.__idp.refreshToken(token['refresh_token']) + if result['OK']: + token = result['Value'] + result = writeTokenDictToTokenFile(token) + if not result['OK']: + return result + gLogger.notice('Token is saved in %s.' % result['Value']) + + auth = {'headers': {"Authorization": "Bearer %s" % token['access_token']}} + # CHRIS 04.02.21 # TODO: add proxyLocation check ? else: - cert = Locations.getProxyLocation() - if not cert: + auth = {'cert': Locations.getProxyLocation()} + if not auth['cert']: gLogger.error("No proxy found") return S_ERROR("No proxy found") @@ -510,9 +565,8 @@ def _request(self, retry=0, outputFile=None, **kwargs): # Default case, just return the result if not outputFile: - call = requests.post(url, data=kwargs, - timeout=self.timeout, verify=verify, - cert=cert) + call = requests.post(url, data=kwargs, timeout=self.timeout, verify=verify, + **auth) # raising the exception for status here # means essentialy that we are losing here the information of what is returned by the server # as error message, since it is not passed to the exception @@ -532,7 +586,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): # Stream download # https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow with requests.post(url, data=kwargs, timeout=self.timeout, verify=verify, - cert=cert, stream=True) as r: + stream=True, **auth) as r: rawText = r.text r.raise_for_status() @@ -550,6 +604,8 @@ def _request(self, retry=0, outputFile=None, **kwargs): return S_ERROR(errno.ENOSYS, "%s is not implemented" % kwargs.get('method')) elif status_code in (http_client.FORBIDDEN, http_client.UNAUTHORIZED): return S_ERROR(errno.EACCES, "No access to %s" % url) + elif status_code == http_client.NOT_FOUND: + rawText = "%s is not found" % url # if it is something else, retry raise diff --git a/src/DIRAC/Core/Tornado/Server/HandlerManager.py b/src/DIRAC/Core/Tornado/Server/HandlerManager.py index 1f4397f63e9..3a1fba3cd30 100644 --- a/src/DIRAC/Core/Tornado/Server/HandlerManager.py +++ b/src/DIRAC/Core/Tornado/Server/HandlerManager.py @@ -9,6 +9,8 @@ __RCSID__ = "$Id$" +import inspect +from six import string_types from tornado.web import url as TornadoURL, RequestHandler from DIRAC import gConfig, gLogger, S_ERROR, S_OK @@ -50,109 +52,221 @@ class HandlerManager(object): ``System/Component`` (e.g. ``DataManagement/FileCatalog``) """ - def __init__(self, autoDiscovery=True): + def __init__(self, services, endpoints): """ - Initialization function, you can set autoDiscovery=False to prevent automatic - discovery of handler. If disabled you can use loadHandlersByServiceName() to - load your handlers or loadHandlerInHandlerManager() - - :param autoDiscovery: (default True) Disable the automatic discovery, - can be used to choose service we want to load. + Initialization function, you can set False for both arguments to prevent automatic + discovery of handlers and use `loadServicesHandlers()` to + load your handlers or `loadEndpointsHandlers()` + + :param services: List of service handlers to load. + If ``True``, loads all services from CS + :type services: bool or list + :param endpoints: List of endpoint handlers to load. + If ``True``, loads all endpoints from CS + :type endpoints: bool or list """ + self.loader = None self.__handlers = {} + self.__services = services + self.__endpoints = endpoints self.__objectLoader = ObjectLoader() - self.__autoDiscovery = autoDiscovery - self.loader = ModuleLoader("Service", PathFinder.getServiceSection, RequestHandler, moduleSuffix="Handler") - def __addHandler(self, handlerTuple, url=None): + def __addHandler(self, handlerPath, handler, urls=None, port=None): """ Function which add handler to list of known handlers + :param str handlerPath: module name, e.g.: `Framework/Auth` + :param object handler: handler class + :param list urls: request path + :param int port: port - :param handlerTuple: (path, class) + :return: S_OK()/S_ERROR() """ - # Check if handler not already loaded - if not url or url not in self.__handlers: - gLogger.debug("Find new handler %s" % (handlerTuple[0])) - - # If url is not given, try to discover it - if url is None: - # FIRST TRY: Url is hardcoded - try: - url = handlerTuple[1].LOCATION - # SECOND TRY: URL can be deduced from path - except AttributeError: - gLogger.debug("No location defined for %s try to get it from path" % handlerTuple[0]) - url = urlFinder(handlerTuple[0]) - + # First of all check if we can find route + # If urls is not given, try to discover it + if urls is None: + # FIRST TRY: Url is hardcoded + try: + urls = handler.LOCATION + # SECOND TRY: URL can be deduced from path + except AttributeError: + gLogger.debug("No location defined for %s try to get it from path" % handlerPath) + urls = urlFinder(handlerPath) + + if not urls: + gLogger.warn("URL not found for %s" % (handlerPath)) + return S_ERROR("URL not found for %s" % (handlerPath)) + + for url in urls if isinstance(urls, (list, tuple)) else [urls]: # We add "/" if missing at begin, e.g. we found "Framework/Service" # URL can't be relative in Tornado if url and not url.startswith('/'): url = "/%s" % url - elif not url: - gLogger.warn("URL not found for %s" % (handlerTuple[0])) - return S_ERROR("URL not found for %s" % (handlerTuple[0])) + + # Some new handler + if handlerPath not in self.__handlers: + gLogger.debug("Add new handler %s with port %s" % (handlerPath, port)) + self.__handlers[handlerPath] = {'URLs': [], 'Port': port} + + # Check if URL already loaded + if (url, handler) in self.__handlers[handlerPath]['URLs']: + gLogger.debug("URL: %s already loaded for %s " % (url, handlerPath)) + continue # Finally add the URL to handlers - if url not in self.__handlers: - self.__handlers[url] = handlerTuple[1] - gLogger.info("New handler: %s with URL %s" % (handlerTuple[0], url)) - else: - gLogger.debug("Handler already loaded %s" % (handlerTuple[0])) + gLogger.info("Add new URL %s to %s handler" % (url, handlerPath)) + self.__handlers[handlerPath]['URLs'].append((url, handler)) + return S_OK() - def discoverHandlers(self): + def discoverHandlers(self, handlerInstance): """ Force the discovery of URL, automatic call when we try to get handlers for the first time. You can disable the automatic call with autoDiscovery=False at initialization + + :param str handlerInstance: handler instance, the name of the section in some system section e.g.:: Services, APIs + + :return: list """ - gLogger.debug("Trying to auto-discover the handlers for Tornado") + urls = [] + gLogger.debug("Trying to auto-discover the %s handlers for Tornado" % handlerInstance) # Look in config diracSystems = gConfig.getSections('/Systems') - serviceList = [] if diracSystems['OK']: for system in diracSystems['Value']: try: - instance = PathFinder.getSystemInstance(system) - services = gConfig.getSections('/Systems/%s/%s/Services' % (system, instance)) - if services['OK']: - for service in services['Value']: - newservice = ("%s/%s" % (system, service)) - - # We search in the CS all handlers which used HTTPS as protocol - isHTTPS = gConfig.getValue('/Systems/%s/%s/Services/%s/Protocol' % (system, instance, service)) - if isHTTPS and isHTTPS.lower() == 'https': - serviceList.append(newservice) + sysInstance = PathFinder.getSystemInstance(system) + result = gConfig.getSections('/Systems/%s/%s/%s' % (system, sysInstance, handlerInstance)) + if result['OK']: + for instName in result['Value']: + newInst = ("%s/%s" % (system, instName)) + port = gConfig.getValue('/Systems/%s/%s/%s/%s/Port' % (system, sysInstance, + handlerInstance, instName)) + if port: + newInst += ':%s' % port + + if handlerInstance == 'Services': + # We search in the CS all handlers which used HTTPS as protocol + isHTTPS = gConfig.getValue('/Systems/%s/%s/%s/%s/Protocol' % (system, sysInstance, + handlerInstance, instName)) + if isHTTPS and isHTTPS.lower() == 'https': + urls.append(newInst) + else: + urls.append(newInst) # On systems sometime you have things not related to services... except RuntimeError: pass - return self.loadHandlersByServiceName(serviceList) + return urls - def loadHandlersByServiceName(self, servicesNames): + def loadServicesHandlers(self, services=None): """ Load a list of handler from list of service using DIRAC moduleLoader Use :py:class:`DIRAC.Core.Base.private.ModuleLoader` - :param servicesNames: list of service, e.g. ['Framework/Hello', 'Configuration/Server'] + :param services: List of service handlers to load. Default value set at initialization + If ``True``, loads all services from CS + :type services: bool or list + + :return: S_OK()/S_ERROR() + """ + # list of services, e.g. ['Framework/Hello', 'Configuration/Server'] + if isinstance(services, string_types): + services = [services] + # list of services + self.__services = self.__services if services is None else services if services else [] + + if self.__services is True: + self.__services = self.discoverHandlers('Services') + + if self.__services: + # Extract ports + ports, self.__services = self.__extractPorts(self.__services) + + self.loader = ModuleLoader("Service", PathFinder.getServiceSection, RequestHandler, moduleSuffix="Handler") + + # Use DIRAC system to load: search in CS if path is given and if not defined + # it search in place it should be (e.g. in DIRAC/FrameworkSystem/Service) + load = self.loader.loadModules(self.__services) + if not load['OK']: + return load + for module in self.loader.getModules().values(): + url = module['loadName'] + + # URL can be like https://domain:port/service/name or just service/name + # Here we just want the service name, for tornado + serviceTuple = url.replace('https://', '').split('/')[-2:] + url = "%s/%s" % (serviceTuple[0], serviceTuple[1]) + self.__addHandler(module['loadName'], module['classObj'], url, ports.get(module['modName'])) + return S_OK() + + def __extractPorts(self, serviceURIs): + """ Extract ports from serviceURIs + + :param list serviceURIs: list of uri that can contain port, .e.g:: System/Service:port + + :return: (dict, list) + """ + portMapping = {} + newURLs = [] + for _url in serviceURIs: + if ':' in _url: + urlTuple = _url.split(':') + if urlTuple[0] not in portMapping: + portMapping[urlTuple[0]] = urlTuple[1] + newURLs.append(urlTuple[0]) + else: + newURLs.append(_url) + return (portMapping, newURLs) + + def loadEndpointsHandlers(self, endpoints=None): """ + Load a list of handler from list of endpoints using DIRAC moduleLoader + Use :py:class:`DIRAC.Core.Base.private.ModuleLoader` - # Use DIRAC system to load: search in CS if path is given and if not defined - # it search in place it should be (e.g. in DIRAC/FrameworkSystem/Service) - if not isinstance(servicesNames, list): - servicesNames = [servicesNames] - - load = self.loader.loadModules(servicesNames) - if not load['OK']: - return load - for module in self.loader.getModules().values(): - url = module['loadName'] - - # URL can be like https://domain:port/service/name or just service/name - # Here we just want the service name, for tornado - serviceTuple = url.replace('https://', '').split('/')[-2:] - url = "%s/%s" % (serviceTuple[0], serviceTuple[1]) - self.__addHandler((module['loadName'], module['classObj']), url) + :param endpoints: List of endpoint handlers to load. Default value set at initialization + If ``True``, loads all endpoints from CS + :type endpoints: bool or list + + :return: S_OK()/S_ERROR() + """ + # list of endpoints, e.g. ['Framework/Auth', ...] + if isinstance(endpoints, string_types): + endpoints = [endpoints] + # list of endpoints. If __endpoints is ``True`` then list of endpoints will dicover from CS + self.__endpoints = self.__endpoints if endpoints is None else endpoints if endpoints else [] + + if self.__endpoints is True: + self.__endpoints = self.discoverHandlers('APIs') + + if self.__endpoints: + # Extract ports + ports, self.__endpoints = self.__extractPorts(self.__endpoints) + + self.loader = ModuleLoader("API", PathFinder.getAPISection, RequestHandler, moduleSuffix="Handler") + + # Use DIRAC system to load: search in CS if path is given and if not defined + # it search in place it should be (e.g. in DIRAC/FrameworkSystem/API) + load = self.loader.loadModules(self.__endpoints) + if not load['OK']: + return load + for module in self.loader.getModules().values(): + handler = module['classObj'] + if not handler.LOCATION: + handler.LOCATION = urlFinder(module['loadName']) + urls = [] + # Look for methods that are exported + for mName, mObj in inspect.getmembers(handler): + if inspect.isroutine(mObj) and mName.find(handler.METHOD_PREFIX) == 0: + methodName = mName[len(handler.METHOD_PREFIX):] + args = getattr(handler, 'path_%s' % methodName, []) + gLogger.debug(" - Route %s/%s -> %s %s" % (handler.LOCATION, methodName, module['loadName'], mName)) + url = "%s%s" % (handler.LOCATION, '' if methodName == 'index' else ('/%s' % methodName)) + if args: + url += r'[\/]?%s' % '/'.join(args) + urls.append(url) + gLogger.debug(" * %s" % url) + self.__addHandler(module['loadName'], handler, urls, ports.get(module['modName'])) return S_OK() def getHandlersURLs(self): @@ -163,24 +277,18 @@ def getHandlersURLs(self): :returns: a list of URL (not the string with "https://..." but the tornado object) see http://www.tornadoweb.org/en/stable/web.html#tornado.web.URLSpec """ - if not self.__handlers and self.__autoDiscovery: - self.__autoDiscovery = False - self.discoverHandlers() urls = [] - for key in self.__handlers: - urls.append(TornadoURL(key, self.__handlers[key])) + for handlerData in self.__handlers.values(): + for url in handlerData['URLs']: + urls.append(TornadoURL(*url)) return urls def getHandlersDict(self): """ Return all handler dictionary - :returns: dictionary with absolute url as key ("/System/Service") - and tornado.web.url object as value + :returns: dictionary with service name as key ("System/Service") + and tornado.web.url objects as value for 'URLs' key + and port as value for 'Port' key """ - if not self.__handlers and self.__autoDiscovery: - self.__autoDiscovery = False - res = self.discoverHandlers() - if not res['OK']: - gLogger.error("Could not load handlers", res) return self.__handlers diff --git a/src/DIRAC/Core/Tornado/Server/TornadoREST.py b/src/DIRAC/Core/Tornado/Server/TornadoREST.py new file mode 100644 index 00000000000..5857d8281d8 --- /dev/null +++ b/src/DIRAC/Core/Tornado/Server/TornadoREST.py @@ -0,0 +1,149 @@ +""" +TornadoREST is the base class for your RESTful API handlers. +It directly inherits from :py:class:`tornado.web.RequestHandler` +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from functools import partial + +import tornado.ioloop +from tornado import gen +from tornado.web import HTTPError +from tornado.ioloop import IOLoop + +import DIRAC + +from DIRAC import gLogger, S_OK +from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import BaseRequestHandler + +sLog = gLogger.getSubLogger(__name__) + + +class TornadoREST(BaseRequestHandler): # pylint: disable=abstract-method + """ Base class for all the endpoints handlers. + It directly inherits from :py:class:`DIRAC.Core.Tornado.Server.BaseRequestHandler.BaseRequestHandler` + + Each HTTP request is served by a new instance of this class. + + In order to create a handler for your service, it has to + follow a certain skeleton:: + + from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST + class yourEndpointHandler(TornadoREST): + + @classmethod + def initializeHandler(cls, infosDict): + ''' Called only once when the first request for this handler arrives Useful for initializing DB or so. + ''' + pass + + def initializeRequest(self): + ''' Called at the beginning of each request + ''' + pass + + # Specify the path arguments + path_someMethod = ['([A-z0-9-_]*)'] + + # Specify the default permission for the method + # See :py:class:`DIRAC.Core.DISET.AuthManager.AuthManager` + auth_someMethod = ['authenticated'] + + def web_someMethod(self, provider=None): + ''' Your method + ''' + return S_OK(provider) + + Note that because we inherit from :py:class:`tornado.web.RequestHandler` + and we are running using executors, the methods you export cannot write + back directly to the client. Please see inline comments for more details. + + In order to pass information around and keep some states, we use instance attributes. + These are initialized in the :py:meth:`.initialize` method. + + The handler define the ``post`` and ``get`` verbs. Please refer to :py:meth:`.post` for the details. + """ + + USE_AUTHZ_GRANTS = ['SSL', 'JWT', 'VISITOR'] + METHOD_PREFIX = 'web_' + LOCATION = '/' + + @classmethod + def _getServiceName(cls, request): + """ Define endpoint full name + + :param object request: tornado Request + + :return: str + """ + if not cls.SYSTEM: + raise Exception("System name must be defined.") + return "/".join([cls.SYSTEM, cls.__name__]) + + @classmethod + def _getServiceAuthSection(cls, endpointName): + """ Search endpoint auth section. + + :param str endpointName: endpoint name + + :return: str + """ + return "%s/Authorization" % PathFinder.getAPISection(endpointName) + + def _getMethodName(self): + """ Parse method name. By default we read the first section in the path + following the coincidence with the value of `LOCATION`. + If such a method is not defined, then try to use the `index` method. + + :return: str + """ + method = self.request.path.replace(self.LOCATION, '', 1).strip('/').split('/')[0] + if method and hasattr(self, ''.join([self.METHOD_PREFIX, method])): + return method + elif hasattr(self, '%sindex' % self.METHOD_PREFIX): + gLogger.warn('%s method not implemented. Use the index method to handle this.' % method) + return 'index' + else: + raise NotImplementedError('%s method not implemented. You can use the index method to handle this.' % method) + + @gen.coroutine + def get(self, *args, **kwargs): # pylint: disable=arguments-differ + """ Method to handle incoming ``GET`` requests. + Note that all the arguments are already prepared in the :py:meth:`.prepare` method. + """ + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) + self._finishFuture(retVal) + + @gen.coroutine + def post(self, *args, **kwargs): # pylint: disable=arguments-differ + """ Method to handle incoming ``POST`` requests. + Note that all the arguments are already prepared in the :py:meth:`.prepare` method. + """ + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) + self._finishFuture(retVal) + + auth_echo = ['all'] + + @staticmethod + def web_echo(data): + """ + This method used for testing the performance of a service + """ + return S_OK(data) + + auth_whoami = ['authenticated'] + + def web_whoami(self): + """ A simple whoami, returns all credential dictionary, except certificate chain object. + """ + credDict = self.srv_getRemoteCredentials() + if 'x509Chain' in credDict: + # Not serializable + del credDict['x509Chain'] + return S_OK(credDict) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 63aeaf133dc..0d7675ad921 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -21,21 +21,30 @@ 'tornado_m2crypto.m2iostream.M2IOStream') # pylint: disable=wrong-import-position from tornado.httpserver import HTTPServer -from tornado.web import Application, url +from tornado.web import Application, url, RequestHandler from tornado.ioloop import IOLoop import tornado.ioloop import DIRAC -from DIRAC import gConfig, gLogger -from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC import gConfig, gLogger, S_OK from DIRAC.Core.Security import Locations -from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager from DIRAC.Core.Utilities import MemStat +from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager +from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient sLog = gLogger.getSubLogger(__name__) +class NotFoundHandler(RequestHandler): + """ Handle 404 errors """ + def prepare(self): + self.set_status(404) + if six.PY3: + from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML + self.finish(getHTML('Not found.', state=404, info='Nothing matches the given URI.')) + + class TornadoServer(object): """ Tornado webserver @@ -61,35 +70,36 @@ class TornadoServer(object): Example 2:We want to debug service1 and service2 only, and use another port for that :: - services = ['component/service1', 'component/service2'] - serverToLaunch = TornadoServer(services=services, port=1234) + services = ['component/service1:port1', 'component/service2'] + endpoints = ['component/endpoint1', 'component/endpoint2'] + serverToLaunch = TornadoServer(services=services, endpoints=endpoints, port=1234) serverToLaunch.startTornado() """ - def __init__(self, services=None, port=None): - """ - - :param list services: (default None) List of service handlers to load. If ``None``, loads all - :param int port: Port to listen to. If None, the port is resolved following the logic - described in the class documentation + def __init__(self, services=True, endpoints=False, port=None): + """ C'r + + :param list services: (default True) List of service handlers to load. + If ``True``, loads all described in the CS + If ``False``, do not load services + :param list endpoints: (default False) List of endpoint handlers to load. + If ``True``, loads all described in the CS + If ``False``, do not load endpoints + :param int port: Port to listen to. + If ``None``, the port is resolved following the logic described in the class documentation """ - + # Application metadata, routes and settings mapping on the ports + self.__appsSettings = {} + # Default port, if enother is not discover if port is None: port = gConfig.getValue("/Systems/Tornado/%s/Port" % PathFinder.getSystemInstance('Tornado'), 8443) - - if services and not isinstance(services, list): - services = [services] - - # URLs for services. - # Contains Tornado :py:class:`tornado.web.url` object - self.urls = [] - # Other infos self.port = port - self.handlerManager = HandlerManager() - # Monitoring attributes + # Handler manager initialization with default settings + self.handlerManager = HandlerManager(services, endpoints) + # Monitoring attributes self._monitor = MonitoringClient() # temp value for computation, used by the monitoring self.__report = None @@ -98,20 +108,69 @@ def __init__(self, services=None, port=None): self.__monitoringLoopDelay = 60 # In secs # If services are defined, load only these ones (useful for debug purpose or specific services) - if services: - retVal = self.handlerManager.loadHandlersByServiceName(services) - if not retVal['OK']: - sLog.error(retVal['Message']) - raise ImportError("Some services can't be loaded, check the service names and configuration.") - + retVal = self.handlerManager.loadServicesHandlers() + if not retVal['OK']: + sLog.error(retVal['Message']) + raise ImportError("Some services can't be loaded, check the service names and configuration.") + + retVal = self.handlerManager.loadEndpointsHandlers() + if not retVal['OK']: + sLog.error(retVal['Message']) + raise ImportError("Some endpoints can't be loaded, check the endpoint names and configuration.") + + def __calculateAppSettings(self): + """ Calculate application information mapping on the ports + """ # if no service list is given, load services from configuration handlerDict = self.handlerManager.getHandlersDict() - for item in handlerDict.items(): - # handlerDict[key].initializeService(key) - self.urls.append(url(item[0], item[1])) - # If there is no services loaded: - if not self.urls: - raise ImportError("There is no services loaded, please check your configuration") + for data in handlerDict.values(): + port = data.get('Port') or self.port + for hURL in data['URLs']: + if port not in self.__appsSettings: + self.__appsSettings[port] = {'routes': [], 'settings': {}} + if hURL not in self.__appsSettings[port]['routes']: + self.__appsSettings[port]['routes'].append(hURL) + return bool(self.__appsSettings) + + def loadServices(self, services): + """ Load a services + + :param services: List of service handlers to load. Default value set at initialization + If ``True``, loads all services from CS + :type services: bool or list + + :return: S_OK()/S_ERROR() + """ + return self.handlerManager.loadServicesHandlers(services) + + def loadEndpoints(self, endpoints): + """ Load a endpoints + + :param endpoints: List of service handlers to load. Default value set at initialization + If ``True``, loads all endpoints from CS + :type endpoints: bool or list + + :return: S_OK()/S_ERROR() + """ + return self.handlerManager.loadEndpointsHandlers(endpoints) + + def addHandlers(self, routes, settings=None, port=None): + """ Add new routes + + :param list routes: routes + :param dict settings: application settings + :param int port: port + """ + port = port or self.port + if port not in self.__appsSettings: + self.__appsSettings[port] = {'routes': [], 'settings': {}} + if settings: + self.__appsSettings[port]['settings'].update(settings) + for route in routes: + if route not in self.__appsSettings[port]['routes']: + self.__appsSettings[port]['routes'].append(route) + + return S_OK() def startTornado(self): """ @@ -119,13 +178,13 @@ def startTornado(self): This method never returns. """ - sLog.debug("Starting Tornado") - self._initMonitoring() + # If there is no services loaded: + if not self.__calculateAppSettings(): + raise Exception("There is no services loaded, please check your configuration") - router = Application(self.urls, - debug=False, - compress_response=True) + sLog.debug("Starting Tornado") + # Prepare SSL settings certs = Locations.getHostCertificateAndKeyLocation() if certs is False: sLog.fatal("Host certificates not found ! Can't start the Server") @@ -139,29 +198,40 @@ def startTornado(self): 'sslDebug': False, # Set to true if you want to see the TLS debug messages } + # Init monitoring + self._initMonitoring() self.__monitorLastStatsUpdate = time.time() self.__report = self.__startReportToMonitoringLoop() # Starting monitoring, IOLoop waiting time in ms, __monitoringLoopDelay is defined in seconds tornado.ioloop.PeriodicCallback(self.__reportToMonitoring, self.__monitoringLoopDelay * 1000).start() - # If we are running with python3, Tornado will use asyncio, - # and we have to convince it to let us run in a different thread - # Doing this ensures a consistent behavior between py2 and py3 if six.PY3: + # If we are running with python3, Tornado will use asyncio, + # and we have to convince it to let us run in a different thread + # Doing this ensures a consistent behavior between py2 and py3 import asyncio # pylint: disable=import-error asyncio.set_event_loop_policy(tornado.platform.asyncio.AnyThreadEventLoopPolicy()) - # Start server - server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True) - try: - server.listen(self.port) - except Exception as e: # pylint: disable=broad-except - sLog.exception("Exception starting HTTPServer", e) - raise - sLog.always("Listening on port %s" % self.port) - for service in self.urls: - sLog.debug("Available service: %s" % service) + for port, app in self.__appsSettings.items(): + sLog.debug(" - %s" % "\n - ".join(["%s = %s" % (k, ssl_options[k]) for k in ssl_options])) + + # Default server configuration + settings = dict(compress_response=True, cookie_secret='secret') + + # Merge appllication settings + settings.update(app['settings']) + # Start server + router = Application(app['routes'], default_handler_class=NotFoundHandler, **settings) + server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True) + try: + server.listen(int(port)) + except Exception as e: # pylint: disable=broad-except + sLog.exception("Exception starting HTTPServer", e) + raise + sLog.always("Listening on port %s" % port) + for service in app['routes']: + sLog.debug("Available service: %s" % service if isinstance(service, url) else service[0]) IOLoop.current().start() diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index da0976ec12b..18ad47262b3 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -12,46 +12,29 @@ from io import open import os -import time -import threading from datetime import datetime -from six.moves import http_client -from tornado.web import RequestHandler, HTTPError -from tornado import gen + import tornado.ioloop +from tornado import gen from tornado.ioloop import IOLoop import DIRAC -from DIRAC import gConfig, gLogger, S_OK -from DIRAC.ConfigurationSystem.Client import PathFinder -from DIRAC.Core.DISET.AuthManager import AuthManager -from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error +from DIRAC import gLogger, S_OK from DIRAC.Core.Utilities.JEncode import decode, encode -from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient +from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import BaseRequestHandler +from DIRAC.ConfigurationSystem.Client import PathFinder sLog = gLogger.getSubLogger(__name__) -class TornadoService(RequestHandler): # pylint: disable=abstract-method +class TornadoService(BaseRequestHandler): # pylint: disable=abstract-method """ - Base class for all the Handlers. - It directly inherits from :py:class:`tornado.web.RequestHandler` + Base class for all the sevices handlers. + It directly inherits from :py:class:`DIRAC.Core.Tornado.Server.BaseRequestHandler.BaseRequestHandler` Each HTTP request is served by a new instance of this class. - For the sequence of method called, please refer to - the `tornado documentation `_. - - For compatibility with the existing :py:class:`DIRAC.Core.DISET.TransferClient.TransferClient`, - the handler can define a method ``export_streamToClient``. This is the method that will be called - whenever ``TransferClient.receiveFile`` is called. It is the equivalent of the DISET - ``transfer_toClient``. - Note that this is here only for compatibility, and we discourage using it for new purposes, as it is - bound to disappear. - - The handler only define the ``post`` verb. Please refer to :py:meth:`.post` for the details. - In order to create a handler for your service, it has to follow a certain skeleton:: @@ -110,203 +93,67 @@ def export_streamToClient(self, myDataToSend, token): In order to pass information around and keep some states, we use instance attributes. These are initialized in the :py:meth:`.initialize` method. - """ - - # Because we initialize at first request, we use a flag to know if it's already done - __init_done = False - # Lock to make sure that two threads are not initializing at the same time - __init_lock = threading.RLock() + The handler only define the ``post`` verb. Please refer to :py:meth:`.post` for the details. - # MonitoringClient, we don't use gMonitor which is not thread-safe - # We also need to add specific attributes for each service - _monitor = None + """ + # Prefix of methods names + METHOD_PREFIX = "export_" @classmethod - def _initMonitoring(cls, serviceName, fullUrl): - """ - Initialize the monitoring specific to this handler - This has to be called only by :py:meth:`.__initializeService` - to ensure thread safety and unicity of the call. - - :param serviceName: relative URL ``//`` - :param fullUrl: full URl like ``https://://`` - """ - - # Init extra bits of monitoring - - cls._monitor = MonitoringClient() - cls._monitor.setComponentType(MonitoringClient.COMPONENT_WEB) - - cls._monitor.initialize() - - if tornado.process.task_id() is None: # Single process mode - cls._monitor.setComponentName('Tornado/%s' % serviceName) - else: - cls._monitor.setComponentName('Tornado/CPU%d/%s' % (tornado.process.task_id(), serviceName)) - - cls._monitor.setComponentLocation(fullUrl) - - cls._monitor.registerActivity("Queries", "Queries served", "Framework", "queries", MonitoringClient.OP_RATE) - - cls._monitor.setComponentExtraParam('DIRACVersion', DIRAC.version) - cls._monitor.setComponentExtraParam('platform', DIRAC.getPlatform()) - cls._monitor.setComponentExtraParam('startTime', datetime.utcnow()) + def _getServiceName(cls, request): + """ Search service name in request. - cls._stats = {'requests': 0, 'monitorLastStatsUpdate': time.time()} + :param object request: tornado Request - return S_OK() + :return: str + """ + # Expected path: ``//`` + return request.path[1:] @classmethod - def __initializeService(cls, relativeUrl, absoluteUrl): - """ - Initialize a service. - The work is only perform once at the first request. + def _getServiceInfo(cls, serviceName, request): + """ Fill service information. - :param relativeUrl: relative URL, e.g. ``//`` - :param absoluteUrl: full URL e.g. ``https://://`` + :param str serviceName: service name + :param object request: tornado Request - :returns: S_OK + :return: dict """ - # If the initialization was already done successfuly, - # we can just return - if cls.__init_done: - return S_OK() - - # Otherwise, do the work but with a lock - with cls.__init_lock: - - # Check again that the initialization was not done by another thread - # while we were waiting for the lock - if cls.__init_done: - return S_OK() - - # Url starts with a "/", we just remove it - serviceName = relativeUrl[1:] - - cls._startTime = datetime.utcnow() - sLog.info("First use, initializing service...", "%s" % relativeUrl) - cls._authManager = AuthManager("%s/Authorization" % PathFinder.getServiceSection(serviceName)) - - cls._initMonitoring(serviceName, absoluteUrl) - - cls._serviceName = serviceName - cls._validNames = [serviceName] - serviceInfo = {'serviceName': serviceName, - 'serviceSectionPath': PathFinder.getServiceSection(serviceName), - 'csPaths': [PathFinder.getServiceSection(serviceName)], - 'URL': absoluteUrl - } - cls._serviceInfoDict = serviceInfo - - cls.__monitorLastStatsUpdate = time.time() - - cls.initializeHandler(serviceInfo) - - cls.__init_done = True - - return S_OK() + return {'serviceName': serviceName, + 'serviceSectionPath': PathFinder.getServiceSection(serviceName), + 'csPaths': [PathFinder.getServiceSection(serviceName)], + 'URL': request.full_url()} @classmethod - def initializeHandler(cls, serviceInfoDict): - """ - This may be overwritten when you write a DIRAC service handler - And it must be a class method. This method is called only one time, - at the first request + def _getServiceAuthSection(cls, serviceName): + """ Search service auth section. - :param dict ServiceInfoDict: infos about services, it contains - 'serviceName', 'serviceSectionPath', - 'csPaths' and 'URL' - """ - pass + :param str serviceName: service name - def initializeRequest(self): - """ - Called at every request, may be overwritten in your handler. + :return: str """ - pass + return "%s/Authorization" % PathFinder.getServiceSection(serviceName) - # This is a Tornado magic method - def initialize(self): # pylint: disable=arguments-differ - """ - Initialize the handler, called at every request. - - It just calls :py:meth:`.__initializeService` + def _getMethodName(self): + """ Parse method name. - If anything goes wrong, the client will get ``Connection aborted`` - error. See details inside the method. - - ..warning:: - DO NOT REWRITE THIS FUNCTION IN YOUR HANDLER - ==> initialize in DISET became initializeRequest in HTTPS ! + :return: str """ + return self.get_argument("method") - # Only initialized once - if not self.__init_done: - # Ideally, if something goes wrong, we would like to return a Server Error 500 - # but this method cannot write back to the client as per the - # `tornado doc `_. - # So the client will get a ``Connection aborted``` - try: - res = self.__initializeService(self.srv_getURL(), self.request.full_url()) - if not res['OK']: - raise Exception(res['Message']) - except Exception as e: - sLog.error("Error in initialization", repr(e)) - raise - - def prepare(self): - """ - Prepare the request. It reads certificates and check authorizations. - We make the assumption that there is always going to be a ``method`` argument - regardless of the HTTP method used + def _getMethodArgs(self, args): + """ Decode args. + :return: tuple """ - - # "method" argument of the POST call. - # This resolves into the ``export_`` method - # on the handler side - # If the argument is not available, the method exists - # and an error 400 ``Bad Request`` is returned to the client - self.method = self.get_argument("method") - - self._stats['requests'] += 1 - self._monitor.setComponentExtraParam('queries', self._stats['requests']) - self._monitor.addMark("Queries") - - try: - self.credDict = self._gatherPeerCredentials() - except Exception: # pylint: disable=broad-except - # If an error occur when reading certificates we close connection - # It can be strange but the RFC, for HTTP, say's that when error happend - # before authentication we return 401 UNAUTHORIZED instead of 403 FORBIDDEN - sLog.error( - "Error gathering credentials", "%s; path %s" % - (self.getRemoteAddress(), self.request.path)) - raise HTTPError(status_code=http_client.UNAUTHORIZED) - - # Resolves the hard coded authorization requirements - try: - hardcodedAuth = getattr(self, 'auth_' + self.method) - except AttributeError: - hardcodedAuth = None - - # Check whether we are authorized to perform the query - # Note that performing the authQuery modifies the credDict... - authorized = self._authManager.authQuery(self.method, self.credDict, hardcodedAuth) - if not authorized: - sLog.error( - "Unauthorized access", "Identity %s; path %s; DN %s" % - (self.srv_getFormattedRemoteCredentials, - self.request.path, - self.credDict['DN'], - )) - raise HTTPError(status_code=http_client.UNAUTHORIZED) + args_encoded = self.get_body_argument('args', default=encode([])) + return (decode(args_encoded)[0], {}) # Make post a coroutine. # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines # for details @gen.coroutine - def post(self): # pylint: disable=arguments-differ + def post(self, *args, **kwargs): # pylint: disable=arguments-differ """ Method to handle incoming ``POST`` requests. Note that all the arguments are already prepared in the :py:meth:`.prepare` @@ -338,224 +185,30 @@ def post(self): # pylint: disable=arguments-differ ...: print r.json() ...: {u'OK': True, - u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', + u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', u'group': u'dirac_user', - u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', + u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', u'isLimitedProxy': False, u'isProxy': True, - u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', + u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', u'properties': [u'NormalUser'], u'secondsLeft': 85441, - u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/CN=2409820262', + u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch/CN=2409820262', u'username': u'adminusername', u'validDN': False, u'validGroup': False}} """ - - sLog.notice( - "Incoming request", "%s /%s: %s" % - (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - self.method)) - # Execute the method in an executor (basically a separate thread) # Because of that, we cannot calls certain methods like `self.write` - # in __executeMethod. This is because these methods are not threadsafe + # in _executeMethod. This is because these methods are not threadsafe # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes # However, we can still rely on instance attributes to store what should # be sent back (reminder: there is an instance # of this class created for each request) - retVal = yield IOLoop.current().run_in_executor(None, self.__executeMethod) + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) # retVal is :py:class:`tornado.concurrent.Future` - self.result = retVal.result() - - # Here it is safe to write back to the client, because we are not - # in a thread anymore - - # If set to true, do not JEncode the return of the RPC call - # This is basically only used for file download through - # the 'streamToClient' method. - rawContent = self.get_argument('rawContent', default=False) - - if rawContent: - # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt - self.set_header("Content-Type", "application/octet-stream") - result = self.result - else: - self.set_header("Content-Type", "application/json") - result = encode(self.result) - - self.write(result) - self.finish() - - # This nice idea of streaming to the client cannot work because we are ran in an executor - # and we should not write back to the client in a different thread. - # See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - # def export_streamToClient(self, filename): - # # https://bhch.github.io/posts/2017/12/serving-large-files-with-tornado-safely-without-blocking/ - # #import ipdb; ipdb.set_trace() - # # chunk size to read - # chunk_size = 1024 * 1024 * 1 # 1 MiB - - # with open(filename, 'rb') as f: - # while True: - # chunk = f.read(chunk_size) - # if not chunk: - # break - # try: - # self.write(chunk) # write the chunk to response - # self.flush() # send the chunk to client - # except StreamClosedError: - # # this means the client has closed the connection - # # so break the loop - # break - # finally: - # # deleting the chunk is very important because - # # if many clients are downloading files at the - # # same time, the chunks in memory will keep - # # increasing and will eat up the RAM - # del chunk - # # pause the coroutine so other handlers can run - # yield gen.sleep(0.000000001) # 1 nanosecond - - # return S_OK() - - @gen.coroutine - def __executeMethod(self): - """ - Execute the method called, this method is ran in an executor - We have several try except to catch the different problem which can occur - - - First, the method does not exist => Attribute error, return an error to client - - second, anything happend during execution => General Exception, send error to client - - .. warning:: - This method is called in an executor, and so cannot use methods like self.write - See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - """ - # getting method - try: - # For compatibility reasons with DISET, the methods are still called ``export_*`` - method = getattr(self, 'export_%s' % self.method) - except AttributeError: - sLog.error("Invalid method", self.method) - raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) - - # Decode args - args_encoded = self.get_body_argument('args', default=encode([])) - - args = decode(args_encoded)[0] - # Execute - try: - self.initializeRequest() - retVal = method(*args) - except Exception as e: # pylint: disable=broad-except - sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) - raise HTTPError(http_client.INTERNAL_SERVER_ERROR) - - return retVal - - # def __write_return(self, retVal): - # """ - # Write back to the client and return. - # It sets some headers (status code, ``Content-Type``). - # If raw content was requested by the client, the ``Content-Type`` - # is ``application/octet-stream``, otherwise we set it to ``application/json`` - # and JEncode retVal. - - # If ``retVal`` is a dictionary that contains a ``Callstack`` item, - # it is removed, not to leak internal information. - - # :param retVal: anything that can be serialized in json. - # """ - - # # In case of error in server side we hide server CallStack to client - # try: - # if 'CallStack' in retVal: - # del retVal['CallStack'] - # except TypeError: - # pass - - # # Set the status - # self.set_status(self._httpStatus) - - # # This is basically only used for file download through - # # the 'streamToClient' method. - # if self.rawContent: - # # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt - # self.set_header("Content-Type", "application/octet-stream") - # returnedData = retVal - # else: - # self.set_header("Content-Type", "application/json") - # returnedData = encode(retVal) - - # self.write(returnedData) - - def on_finish(self): - """ - Called after the end of HTTP request. - Log the request duration - """ - elapsedTime = 1000.0 * self.request.request_time() - - try: - if self.result['OK']: - argsString = "OK" - else: - argsString = "ERROR: %s" % self.result['Message'] - except (AttributeError, KeyError): # In case it is not a DIRAC structure - if self._reason == 'OK': - argsString = 'OK' - else: - argsString = 'ERROR %s' % self._reason - - argsString = "ERROR: %s" % self._reason - sLog.notice("Returning response", "%s %s (%.2f ms) %s" % (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - elapsedTime, argsString)) - - def _gatherPeerCredentials(self): - """ - Load client certchain in DIRAC and extract informations. - - The dictionary returned is designed to work with the AuthManager, - already written for DISET and re-used for HTTPS. - - :returns: a dict containing the return of :py:meth:`DIRAC.Core.Security.X509Chain.X509Chain.getCredentials` - (not a DIRAC structure !) - """ - - chainAsText = self.request.get_ssl_certificate().as_pem() - peerChain = X509Chain() - - # Here we read all certificate chain - cert_chain = self.request.get_ssl_certificate_chain() - for cert in cert_chain: - chainAsText += cert.as_pem() - - peerChain.loadChainFromString(chainAsText) - - # Retrieve the credentials - res = peerChain.getCredentials(withRegistryInfo=False) - if not res['OK']: - raise Exception(res['Message']) - - credDict = res['Value'] - - # We check if client sends extra credentials... - if "extraCredentials" in self.request.arguments: - extraCred = self.get_argument("extraCredentials") - if extraCred: - credDict['extraCredentials'] = decode(extraCred)[0] - return credDict - - -#### -# -# Default method -# -#### + self._finishFuture(retVal) auth_ping = ['all'] @@ -617,118 +270,3 @@ def export_whoami(self): # Not serializable del credDict['x509Chain'] return S_OK(credDict) - -#### -# -# Utilities methods, some getters. -# From DIRAC.Core.DISET.requestHandler to get same interface in the handlers. -# Adapted for Tornado. -# These method are copied from DISET RequestHandler, they are not all used when i'm writing -# these lines. I rewrite them for Tornado to get them ready when a new HTTPS service need them -# -#### - - @classmethod - def srv_getCSOption(cls, optionName, defaultValue=False): - """ - Get an option from the CS section of the services - - :return: Value for serviceSection/optionName in the CS being defaultValue the default - """ - if optionName[0] == "/": - return gConfig.getValue(optionName, defaultValue) - for csPath in cls._serviceInfoDict['csPaths']: - result = gConfig.getOption("%s/%s" % (csPath, optionName, ), defaultValue) - if result['OK']: - return result['Value'] - return defaultValue - - def getCSOption(self, optionName, defaultValue=False): - """ - Just for keeping same public interface - """ - return self.srv_getCSOption(optionName, defaultValue) - - def srv_getRemoteAddress(self): - """ - Get the address of the remote peer. - - :return: Address of remote peer. - """ - - remote_ip = self.request.remote_ip - # Although it would be trivial to add this attribute in _HTTPRequestContext, - # Tornado won't release anymore 5.1 series, so go the hacky way - try: - remote_port = self.request.connection.stream.socket.getpeername()[1] - except Exception: # pylint: disable=broad-except - remote_port = 0 - - return (remote_ip, remote_port) - - def getRemoteAddress(self): - """ - Just for keeping same public interface - """ - return self.srv_getRemoteAddress() - - def srv_getRemoteCredentials(self): - """ - Get the credentials of the remote peer. - - :return: Credentials dictionary of remote peer. - """ - return self.credDict - - def getRemoteCredentials(self): - """ - Get the credentials of the remote peer. - - :return: Credentials dictionary of remote peer. - """ - return self.credDict - - def srv_getFormattedRemoteCredentials(self): - """ - Return the DN of user - - Mostly copy paste from - :py:meth:`DIRAC.Core.DISET.private.Transports.BaseTransport.BaseTransport.getFormattedCredentials` - - Note that the information will be complete only once the AuthManager was called - """ - address = self.getRemoteAddress() - peerId = "" - # Depending on where this is call, it may be that credDict is not yet filled. - # (reminder: AuthQuery fills part of it..) - try: - peerId = "[%s:%s]" % (self.credDict['group'], self.credDict['username']) - except AttributeError: - pass - - if address[0].find(":") > -1: - return "([%s]:%s)%s" % (address[0], address[1], peerId) - return "(%s:%s)%s" % (address[0], address[1], peerId) - -# def getFormattedCredentials(self): -# peerCreds = self.getConnectingCredentials() -# address = self.getRemoteAddress() -# if 'username' in peerCreds: -# peerId = "[%s:%s]" % (peerCreds['group'], peerCreds['username']) -# else: -# peerId = "" -# if address[0].find(":") > -1: -# return "([%s]:%s)%s" % (address[0], address[1], peerId) -# return "(%s:%s)%s" % (address[0], address[1], peerId) - - def srv_getServiceName(self): - """ - Return the service name - """ - return self._serviceInfoDict['serviceName'] - - def srv_getURL(self): - """ - Return the URL - """ - return self.request.path diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py new file mode 100644 index 00000000000..9a0d2aa6710 --- /dev/null +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -0,0 +1,818 @@ +""" BaseRequestHandler is the base class for tornados services and etc handlers. + It directly inherits from :py:class:`tornado.web.RequestHandler` +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from io import open + +import os +import six +import time +import inspect +import threading +from datetime import datetime +from six import string_types +from six.moves import http_client +from six.moves.urllib.parse import unquote +from functools import partial + +import tornado +from tornado import gen +from tornado.web import RequestHandler, HTTPError +from tornado.ioloop import IOLoop +from tornado.concurrent import Future + +import DIRAC + +from DIRAC import gConfig, gLogger, S_OK, S_ERROR +from DIRAC.Core.Utilities import DErrno +from DIRAC.Core.DISET.AuthManager import AuthManager +from DIRAC.Core.Utilities.JEncode import decode, encode +from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error +from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient +from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance + +if six.PY3: + # DIRACOS not contain required packages + import jwt + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +sLog = gLogger.getSubLogger(__name__.split('.')[-1]) + + +class TornadoResponse(object): + """ This class registers tornadoes with arguments in the order they are called + from TornadoResponse to call them later. + + Use:: + + def web_myEndpotin(self): + resp = TornadoResponse('data') + resp.set_status(400) + return resp + """ + __attrs = inspect.getmembers(RequestHandler) + + def __init__(self, payload=None, status_code=None): + """ C'or + + :param payload: response body + :param int status_code: response status code + """ + self.payload = payload + self.status_code = status_code + self.actions = [] + for mName, mObj in self.__attrs: + if inspect.isroutine(mObj) and not mName.startswith('_') and not mName.startswith('get'): + setattr(self, mName, partial(self.__setAction, mName)) + + def __setAction(self, mName, *args, **kwargs): + """ Register new action + + :param str mName: RequestHandler method name + + :return: TornadoResponse instance + """ + self.actions.append((mName, args, kwargs)) + return self + + def _runActions(self, reqObj): + """ Calling methods in the order of their registration + + :param reqObj: RequestHandler instance + """ + if self.status_code: + reqObj.set_status(self.status_code) + for mName, args, kwargs in self.actions: + getattr(reqObj, mName)(*args, **kwargs) + if not reqObj._finished: + reqObj.finish(self.payload) + + +class BaseRequestHandler(RequestHandler): + """ Base class for all the Handlers. + It directly inherits from :py:class:`tornado.web.RequestHandler` + + Each HTTP request is served by a new instance of this class. + + For the sequence of method called, please refer to + the `tornado documentation `_. + + This class is basic for :py:class:`DIRAC.Core.Tornado.Server.TornadoService.TornadoService` + and :py:class:`DIRAC.Core.Tornado.Server.TornadoREST.TornadoREST`. + + In order to create a class that inherits from `BaseRequestHandler`, it has to + follow a certain skeleton:: + + class TornadoInstance(BaseRequestHandler): + + # Prefix of methods names + METHOD_PREFIX = "export_" + + @classmethod + def _getServiceName(cls, request): + ''' Search service name in request + ''' + return request.path[1:] + + @classmethod + def _getServiceInfo(cls, serviceName, request): + ''' Fill service information. + ''' + return {'serviceName': serviceName, + 'serviceSectionPath': PathFinder.getServiceSection(serviceName), + 'csPaths': [PathFinder.getServiceSection(serviceName)], + 'URL': request.full_url()} + + @classmethod + def _getServiceAuthSection(cls, serviceName): + ''' Search service "Authorization" configuration section. + ''' + return "%s/Authorization" % PathFinder.getServiceSection(serviceName) + + def _getMethodName(self): + ''' Parse method name. + ''' + return self.get_argument("method") + + def _getMethodArgs(self, args): + ''' Decode args. + ''' + args_encoded = self.get_body_argument('args', default=encode([])) + return decode(args_encoded)[0] + + # Make post a coroutine. + # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines + # for details + @gen.coroutine + def post(self, *args, **kwargs): # pylint: disable=arguments-differ + ''' Describe HTTP method to use + ''' + # Execute the method in an executor (basically a separate thread) + # Because of that, we cannot calls certain methods like `self.write` + # in __executeMethod. This is because these methods are not threadsafe + # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + # However, we can still rely on instance attributes to store what should + # be sent back (reminder: there is an instance of this class created for each request) + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) + # retVal is :py:class:`tornado.concurrent.Future` + self._finishFuture(retVal) + + For compatibility with the existing :py:class:`DIRAC.Core.DISET.TransferClient.TransferClient`, + the handler can define a method ``export_streamToClient``. This is the method that will be called + whenever ``TransferClient.receiveFile`` is called. It is the equivalent of the DISET + ``transfer_toClient``. + Note that this is here only for compatibility, and we discourage using it for new purposes, as it is + bound to disappear. + """ + # Because we initialize at first request, we use a flag to know if it's already done + __init_done = False + # Lock to make sure that two threads are not initializing at the same time + __init_lock = threading.RLock() + + # MonitoringClient, we don't use gMonitor which is not thread-safe + # We also need to add specific attributes for each service + _monitor = None + + # System name with which this component is associated + SYSTEM = None + + # Auth requirements + AUTH_PROPS = None + + # Type of component + MONITORING_COMPONENT = MonitoringClient.COMPONENT_WEB + + # Prefix of methods names + METHOD_PREFIX = "export_" + + # Definition of identity providers + _idps = IdProviderFactory() if six.PY3 else None + _idp = {} + + # Which grant type to use + USE_AUTHZ_GRANTS = ['SSL', 'JWT'] + + @classmethod + def _initMonitoring(cls, serviceName, fullUrl): + """ + Initialize the monitoring specific to this handler + This has to be called only by :py:meth:`.__initializeService` + to ensure thread safety and unicity of the call. + + :param serviceName: relative URL ``//`` + :param fullUrl: full URl like ``https://://`` + """ + + # Init extra bits of monitoring + + cls._monitor = MonitoringClient() + cls._monitor.setComponentType(cls.MONITORING_COMPONENT) + + cls._monitor.initialize() + + if tornado.process.task_id() is None: # Single process mode + cls._monitor.setComponentName('Tornado/%s' % serviceName) + else: + cls._monitor.setComponentName('Tornado/CPU%d/%s' % (tornado.process.task_id(), serviceName)) + + cls._monitor.setComponentLocation(fullUrl) + + cls._monitor.registerActivity("Queries", "Queries served", "Framework", "queries", MonitoringClient.OP_RATE) + + cls._monitor.setComponentExtraParam('DIRACVersion', DIRAC.version) + cls._monitor.setComponentExtraParam('platform', DIRAC.getPlatform()) + cls._monitor.setComponentExtraParam('startTime', datetime.utcnow()) + + cls._stats = {'requests': 0, 'monitorLastStatsUpdate': time.time()} + + return S_OK() + + @classmethod + def _getServiceName(cls, request): + """ Search service name in request. + + :param object request: tornado Request + + :return: str + """ + raise NotImplementedError('Please, create the _getServiceName class method') + + @classmethod + def _getServiceAuthSection(cls, serviceName): + """ Search service auth section. + + :param str serviceName: service name + + :return: str + """ + raise NotImplementedError('Please, create the _getServiceAuthSection class method') + + @classmethod + def _getServiceInfo(cls, serviceName, request): + """ Fill service information. + + :param str serviceName: service name + :param object request: tornado Request + + :return: dict + """ + gLogger.warn('Service information will not be collected because the _getServiceInfo method is not defined.') + return {} + + @classmethod + def __loadIdPs(cls): + """ Load identity providers that will be used to verify tokens + """ + gLogger.info('Load identity providers..') + # Research Identity Providers + result = getProvidersForInstance('Id') + if result['OK']: + for providerName in result['Value']: + result = cls._idps.getIdProvider(providerName) + if result['OK']: + cls._idp[result['Value'].issuer.strip('/')] = result['Value'] + else: + gLogger.error(result['Message']) + + @classmethod + def __initializeService(cls, request): + """ + Initialize a service. + The work is only performed once at the first request. + + :param object request: tornado Request + + :returns: S_OK + """ + # If the initialization was already done successfuly, + # we can just return + if cls.__init_done: + return S_OK() + + # Otherwise, do the work but with a lock + with cls.__init_lock: + + # Check again that the initialization was not done by another thread + # while we were waiting for the lock + if cls.__init_done: + return S_OK() + + # absoluteUrl: full URL e.g. ``https://://`` + absoluteUrl = request.path + serviceName = cls._getServiceName(request) + + cls._startTime = datetime.utcnow() + sLog.info("First use of %s, initializing service..." % serviceName) + cls._authManager = AuthManager(cls._getServiceAuthSection(serviceName)) + + cls._initMonitoring(serviceName, absoluteUrl) + + cls._serviceName = serviceName + cls._validNames = [serviceName] + serviceInfo = cls._getServiceInfo(serviceName, request) + + cls._serviceInfoDict = serviceInfo + + cls.__monitorLastStatsUpdate = time.time() + + # Some pre-initialization + cls._initializeHandler() + + cls.initializeHandler(serviceInfo) + + # Load all registred identity providers + if six.PY3: + cls.__loadIdPs() + + cls.__init_done = True + + return S_OK() + + @classmethod + def _initializeHandler(cls): + """ + If you are writing your own framework that follows this class + and you need to add something before initializing the service, + such as initializing the OAuth client, then you need to change this method. + """ + pass + + @classmethod + def initializeHandler(cls, serviceInfo): + """ + This may be overwritten when you write a DIRAC service handler + And it must be a class method. This method is called only one time, + at the first request + + :param dict ServiceInfoDict: infos about services, it contains + 'serviceName', 'serviceSectionPath', + 'csPaths' and 'URL' + """ + pass + + def initializeRequest(self): + """ + Called at every request, may be overwritten in your handler. + """ + pass + + # This is a Tornado magic method + def initialize(self): # pylint: disable=arguments-differ + """ + Initialize the handler, called at every request. + + It just calls :py:meth:`.__initializeService` + + If anything goes wrong, the client will get ``Connection aborted`` + error. See details inside the method. + + ..warning:: + DO NOT REWRITE THIS FUNCTION IN YOUR HANDLER + ==> initialize in DISET became initializeRequest in HTTPS ! + """ + # Only initialized once + if not self.__init_done: + # Ideally, if something goes wrong, we would like to return a Server Error 500 + # but this method cannot write back to the client as per the + # `tornado doc `_. + # So the client will get a ``Connection aborted``` + try: + res = self.__initializeService(self.request) + if not res['OK']: + raise Exception(res['Message']) + except Exception as e: + sLog.error("Error in initialization", repr(e)) + raise + + def _monitorRequest(self): + """ Monitor action for each request + """ + self._stats['requests'] += 1 + self._monitor.setComponentExtraParam('queries', self._stats['requests']) + self._monitor.addMark("Queries") + + def _getMethodName(self): + """ Parse method name. + + :return: str + """ + raise NotImplementedError('Please, create the _getMethodName method') + + def _getMethodArgs(self, args): + """ Decode args. + + :return: tuple -- contain args and kwargs + """ + # Read information about method + fArgs, fvArgs, fvKwargs, fDef = inspect.getargspec(self.methodObj) + # Remove self argument from method arguments + fArgs = [a for a in fArgs if a != "self"] + # Pass all arguments + kwargs = self.request.arguments.copy() if fvKwargs else {} + # Get all defaults from method + defaults = {a: fDef[fArgs[-len(fDef):].index(a)] for a in fArgs[-len(fDef):]} if fDef else {} + # Calcule kwargs + for arg in fArgs[len(args):]: + if arg not in defaults: + kwargs[arg] = self.get_argument(arg) + elif isinstance(defaults[arg], list): + kwargs[arg] = self.get_arguments(arg) or defaults[arg] + else: + kwargs[arg] = self.get_argument(arg, defaults[arg]) + + return ([unquote(a) for a in args], kwargs) + + def _getMethodAuthProps(self): + """ Resolves the hard coded authorization requirements for method. + + :return: list + """ + if self.AUTH_PROPS and not isinstance(self.AUTH_PROPS, (list, tuple)): + self.AUTH_PROPS = [p.strip() for p in self.AUTH_PROPS.split(",") if p.strip()] + return getattr(self, 'auth_' + self.mehtodName, self.AUTH_PROPS) + + def _getMethod(self): + """ Get method function to call. + + :return: function + """ + methodObj = getattr(self, '%s%s' % (self.METHOD_PREFIX, self.mehtodName), None) + if not callable(methodObj): + sLog.error("Invalid method", self.mehtodName) + raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) + return methodObj + + def prepare(self): + """ + Tornados prepare method that called before request + """ + + # "method" argument of the POST call. + # This resolves into the ``export_`` method + # on the handler side + # If the argument is not available, the method exists + # and an error 400 ``Bad Request`` is returned to the client + self.mehtodName = self._getMethodName() + self.methodObj = self._getMethod() + + self._monitorRequest() + + self._prepare() + + def _prepare(self): + """ + Prepare the request. It reads certificates and check authorizations. + We make the assumption that there is always going to be a ``method`` argument + regardless of the HTTP method used + + """ + try: + self.credDict = self._gatherPeerCredentials() + except Exception as e: # pylint: disable=broad-except + # If an error occur when reading certificates we close connection + # It can be strange but the RFC, for HTTP, say's that when error happend + # before authentication we return 401 UNAUTHORIZED instead of 403 FORBIDDEN + sLog.debug(str(e)) + sLog.error( + "Error gathering credentials ", "%s; path %s" % + (self.getRemoteAddress(), self.request.path)) + raise HTTPError(http_client.UNAUTHORIZED, str(e)) + + # Check whether we are authorized to perform the query + # Note that performing the authQuery modifies the credDict... + authorized = self._authManager.authQuery(self.mehtodName, self.credDict, + self._getMethodAuthProps()) + if not authorized: + extraInfo = '' + if self.credDict.get('ID'): + extraInfo += 'ID: %s' % self.credDict['ID'] + elif self.credDict.get('DN'): + extraInfo += 'DN: %s' % self.credDict['DN'] + sLog.error( + "Unauthorized access", "Identity %s; path %s; %s" % + (self.srv_getFormattedRemoteCredentials(), + self.request.path, extraInfo)) + raise HTTPError(http_client.UNAUTHORIZED) + + def __executeMethod(self, targetMethod, args, kwargs): + """ + Execute the method called, this method is ran in an executor + We have several try except to catch the different problem which can occur + + - First, the method does not exist => Attribute error, return an error to client + - second, anything happend during execution => General Exception, send error to client + + .. warning:: + This method is called in an executor, and so cannot use methods like self.write + See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + + :param str targetMethod: name of the method to call + :param list args: target method arguments + :param dict kwargs: target method arguments + + :return: Future + """ + + sLog.notice("Incoming request %s /%s: %s" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, self.mehtodName)) + # Execute + try: + self.initializeRequest() + return targetMethod(*args, **kwargs) + except Exception as e: # pylint: disable=broad-except + sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) + raise e if isinstance(e, HTTPError) else HTTPError(http_client.INTERNAL_SERVER_ERROR, str(e)) + + def _prepareExecutor(self, args): + """ Preparation of necessary arguments for the `__executeMethod` method + + :param list args: arguments passed to the `post`, `get`, etc. tornado methods + + :return: executor, target method with arguments + """ + args, kwargs = self._getMethodArgs(args) + return None, partial(gen.coroutine(self.__executeMethod) if six.PY2 else self.__executeMethod, + self.methodObj, args, kwargs) + + def _finishFuture(self, retVal): + """ Handler Future result + + :param object retVal: tornado.concurrent.Future + """ + # Wait result only if it's a Future object + self.result = retVal.result() if isinstance(retVal, Future) else retVal + + # Here it is safe to write back to the client, because we are not in a thread anymore + + # If you need to end the method using tornado methods, outside the thread, + # you need to define the finish_ method. + # This method will be started after __executeMethod is completed. + finishFunc = getattr(self, 'finish_%s' % self.mehtodName, None) + + if isinstance(self.result, TornadoResponse): + self.result._runActions(self) + + elif callable(finishFunc): + finishFunc() + + # In case nothing is returned + elif self.result is None: + self.finish() + + # If set to true, do not JEncode the return of the RPC call + # This is basically only used for file download through + # the 'streamToClient' method. + elif self.get_argument('rawContent', default=False): + # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt + self.set_header("Content-Type", "application/octet-stream") + self.finish(self.result) + + # Return simple text or html + elif isinstance(self.result, string_types): + self.finish(self.result) + + # JSON + else: + self.set_header("Content-Type", "application/json") + self.finish(encode(self.result)) + + def on_finish(self): + """ + Called after the end of HTTP request. + Log the request duration + """ + elapsedTime = 1000.0 * self.request.request_time() + + argsString = "OK" + try: + if not self.result['OK']: + argsString = "ERROR: %s" % self.result['Message'] + except (AttributeError, KeyError, TypeError): # In case it is not a DIRAC structure + if self._reason != 'OK': + argsString = 'ERROR %s' % self._reason + + sLog.notice("Returning response", "%s %s (%.2f ms) %s" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, + elapsedTime, argsString)) + + def _gatherPeerCredentials(self, grants=None): + """ Returne a dictionary designed to work with the AuthManager, + already written for DISET and re-used for HTTPS. + + :param list grants: grants to use + + :returns: a dict containing user credentials + """ + err = [] + + # At least some authorization method must be defined, if nothing is defined, + # the authorization will go through the `_authzVISITOR` method and + # everyone will have access as anonymous@visitor + for grant in (grants or self.USE_AUTHZ_GRANTS or 'VISITOR'): + if six.PY2 and grant == 'JWT': + # Skip token authorization for python 2 + continue + grant = grant.upper() + grantFunc = getattr(self, '_authz%s' % grant, None) + # pylint: disable=not-callable + result = grantFunc() if callable(grantFunc) else S_ERROR('%s authentication type is not supported.' % grant) + if result['OK']: + for e in err: + sLog.debug(e) + sLog.debug('%s authentication success.' % grant) + return result['Value'] + err.append('%s authentication: %s' % (grant, result['Message'])) + + # Report on failed authentication attempts + raise Exception('; '.join(err)) + + def _authzSSL(self): + """ Load client certchain in DIRAC and extract informations. + + :return: S_OK(dict)/S_ERROR() + """ + try: + derCert = self.request.get_ssl_certificate() + except Exception: + # If 'IOStream' object has no attribute 'get_ssl_certificate' + derCert = None + + # Get client certificate as pem + if derCert: + chainAsText = derCert.as_pem().decode() + # Read all certificate chain + chainAsText += ''.join([cert.as_pem().decode() for cert in self.request.get_ssl_certificate_chain()]) + elif self.request.headers.get('X-Ssl_client_verify') == 'SUCCESS' and self.request.headers.get('X-SSL-CERT'): + chainAsText = unquote(self.request.headers.get('X-SSL-CERT')) + else: + return S_ERROR(DErrno.ECERTFIND, 'Valid certificate not found.') + + # Load full certificate chain + peerChain = X509Chain() + peerChain.loadChainFromString(chainAsText) + + # Retrieve the credentials + res = peerChain.getCredentials(withRegistryInfo=False) + if not res['OK']: + return res + + credDict = res['Value'] + + # We check if client sends extra credentials... + if "extraCredentials" in self.request.arguments: + extraCred = self.get_argument("extraCredentials") + if extraCred: + credDict['extraCredentials'] = decode(extraCred)[0] + return S_OK(credDict) + + def _authzJWT(self, accessToken=None): + """ Load token claims in DIRAC and extract informations. + + :param str accessToken: access_token + + :return: S_OK(dict)/S_ERROR() + """ + if not accessToken: + # Export token from headers + token = self.request.headers.get('Authorization') + if not token or len(token.split()) != 2: + return S_ERROR(DErrno.EATOKENFIND, 'Not found a bearer access token.') + tokenType, accessToken = token.split() + if tokenType.lower() != 'bearer': + return S_ERROR(DErrno.ETOKENTYPE, 'Found a not bearer access token.') + + # Read token without verification to get issuer + self.log.debug('Read issuer from access token', accessToken) + issuer = jwt.decode(accessToken, leeway=300, options=dict(verify_signature=False, + verify_aud=False))['iss'].strip('/') + # Verify token + self.log.debug('Verify access token') + result = self._idp[issuer].verifyToken(accessToken) + self.log.debug('Search user group') + return self._idp[issuer].researchGroup(result['Value'], accessToken) if result['OK'] else result + + def _authzVISITOR(self): + """ Visitor access + + :return: S_OK(dict) + """ + return S_OK({}) + + @property + def log(self): + return sLog + + def getUserDN(self): + return self.credDict.get('DN', '') + + def getUserName(self): + return self.credDict.get('username', '') + + def getUserGroup(self): + return self.credDict.get('group', '') + + def getProperties(self): + return self.credDict.get('properties', []) + + def isRegisteredUser(self): + return self.credDict.get('username', 'anonymous') != 'anonymous' and self.credDict.get('group') + + @classmethod + def srv_getCSOption(cls, optionName, defaultValue=False): + """ + Get an option from the CS section of the services + + :return: Value for serviceSection/optionName in the CS being defaultValue the default + """ + if optionName[0] == "/": + return gConfig.getValue(optionName, defaultValue) + for csPath in cls._serviceInfoDict['csPaths']: + result = gConfig.getOption("%s/%s" % (csPath, optionName, ), defaultValue) + if result['OK']: + return result['Value'] + return defaultValue + + def getCSOption(self, optionName, defaultValue=False): + """ + Just for keeping same public interface + """ + return self.srv_getCSOption(optionName, defaultValue) + + def srv_getRemoteAddress(self): + """ + Get the address of the remote peer. + + :return: Address of remote peer. + """ + + remote_ip = self.request.remote_ip + # Although it would be trivial to add this attribute in _HTTPRequestContext, + # Tornado won't release anymore 5.1 series, so go the hacky way + try: + remote_port = self.request.connection.stream.socket.getpeername()[1] + except Exception: # pylint: disable=broad-except + remote_port = 0 + + return (remote_ip, remote_port) + + def getRemoteAddress(self): + """ + Just for keeping same public interface + """ + return self.srv_getRemoteAddress() + + def srv_getRemoteCredentials(self): + """ + Get the credentials of the remote peer. + + :return: Credentials dictionary of remote peer. + """ + return self.credDict + + def getRemoteCredentials(self): + """ + Get the credentials of the remote peer. + + :return: Credentials dictionary of remote peer. + """ + return self.credDict + + def srv_getFormattedRemoteCredentials(self): + """ + Return the DN of user + + Mostly copy paste from + :py:meth:`DIRAC.Core.DISET.private.Transports.BaseTransport.BaseTransport.getFormattedCredentials` + + Note that the information will be complete only once the AuthManager was called + """ + address = self.getRemoteAddress() + peerId = "" + # Depending on where this is call, it may be that credDict is not yet filled. + # (reminder: AuthQuery fills part of it..) + try: + peerId = "[%s:%s]" % (self.credDict.get('group', 'visitor'), self.credDict.get('username', 'anonymous')) + except AttributeError: + pass + + if address[0].find(":") > -1: + return "([%s]:%s)%s" % (address[0], address[1], peerId) + return "(%s:%s)%s" % (address[0], address[1], peerId) + + def srv_getServiceName(self): + """ + Return the service name + """ + return self._serviceInfoDict['serviceName'] + + def srv_getURL(self): + """ + Return the URL + """ + return self.request.path diff --git a/src/DIRAC/Core/Tornado/Server/private/__init__.py b/src/DIRAC/Core/Tornado/Server/private/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py index 61462857ce7..27d30fe591a 100644 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py @@ -58,7 +58,7 @@ def main(): except TypeError: csPort = None - serverToLaunch = TornadoServer(services='Configuration/Server', port=csPort) + serverToLaunch = TornadoServer(services=['Configuration/Server'], port=csPort) serverToLaunch.startTornado() diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py index 4a28a5b4704..c0b729a7890 100644 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py @@ -60,7 +60,7 @@ def main(): gLogger.initialize('Tornado', "/") - serverToLaunch = TornadoServer() + serverToLaunch = TornadoServer(endpoints=True) serverToLaunch.startTornado() diff --git a/src/DIRAC/Core/Utilities/DErrno.py b/src/DIRAC/Core/Utilities/DErrno.py index 54b692d2fd1..68ed494f359 100644 --- a/src/DIRAC/Core/Utilities/DErrno.py +++ b/src/DIRAC/Core/Utilities/DErrno.py @@ -105,6 +105,10 @@ EMQCONN = 1142 # Elasticsearch EELNOFOUND = 1146 +# Tokens +EATOKENFIND = 1150 +EATOKENREAD = 1151 +ETOKENTYPE = 1152 # config ESECTION = 1400 @@ -185,6 +189,12 @@ 1142: 'EMQCONN', # Elasticsearch 1146: 'EELNOFOUND', + + # 115X: Tokens + 1150: 'EATOKENFIND', + 1151: 'EATOKENREAD', + 1152: 'ETOKENTYPE', + # Config 1400: "ESECTION", # Processes @@ -260,6 +270,12 @@ EMQCONN: "MQ connection failure", # 114X Elasticsearch EELNOFOUND: "Index not found", + + # 115X: Tokens + EATOKENFIND: "Can't find a bearer access token.", + EATOKENREAD: "Can't read a bearer access token.", + ETOKENTYPE: "Unsupported access token type.", + # Config ESECTION: "Section is not found", # processes diff --git a/src/DIRAC/Core/Utilities/JEncode.py b/src/DIRAC/Core/Utilities/JEncode.py index c2283a41061..0369f401def 100644 --- a/src/DIRAC/Core/Utilities/JEncode.py +++ b/src/DIRAC/Core/Utilities/JEncode.py @@ -105,6 +105,9 @@ def default(self, obj): # pylint: disable=method-hidden # if the object inherits from JSJerializable, try to serialize it elif isinstance(obj, JSerializable): return obj._toJSON() # pylint: disable=protected-access + # if the object a bytes, decode it + elif isinstance(obj, bytes): + return obj.decode() # otherwise, let the parent do return super(DJSONEncoder, self).default(obj) diff --git a/src/DIRAC/Core/scripts/dirac_configure.py b/src/DIRAC/Core/scripts/dirac_configure.py index 48bf58f5c57..5c6db36cbab 100755 --- a/src/DIRAC/Core/scripts/dirac_configure.py +++ b/src/DIRAC/Core/scripts/dirac_configure.py @@ -63,6 +63,7 @@ class Params(object): def __init__(self): + self.issuer = None self.logLevel = None self.setup = None self.configurationServer = None @@ -172,6 +173,11 @@ def setExtensions(self, optionValue): DIRAC.gConfig.setOptionValue(cfgInstallPath('Extensions'), self.extensions) return DIRAC.S_OK() + def setIssuer(self, optionValue): + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'True' + self.issuer = optionValue + DIRAC.gConfig.setOptionValue('/DIRAC/Security/Authorization/issuer', self.issuer) + return DIRAC.S_OK() def _runConfigurationWizard(setups, defaultSetup): """The implementation of the configuration wizard""" @@ -289,7 +295,75 @@ def main(): runDiracConfigure(params) +def login(params): + from prompt_toolkit import prompt, print_formatted_text, HTML + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (writeTokenDictToTokenFile, + getTokenFileLocation, + readTokenFromFile) + # Init authorization client + result = IdProviderFactory().getIdProvider('DIRACCLI', issuer=params.issuer, scope=' ') + if not result['OK']: + return result + idpObj = result['Value'] + + # Get token file path + tokenFile = getTokenFileLocation() + + # Submit Device authorisation flow + result = idpObj.deviceAuthorization() + if not result['OK']: + return result + + # Revoke old tokens from token file + if os.path.isfile(tokenFile): + result = readTokenFromFile(tokenFile) + if not result['OK']: + DIRAC.gLogger.warn(result['Message']) + elif result['Value']: + oldToken = result['Value'] + for tokenType in ['access_token', 'refresh_token']: + result = idpObj.revokeToken(oldToken[tokenType], tokenType) + if result['OK']: + DIRAC.gLogger.debug('%s is revoked from' % tokenType, tokenFile) + else: + DIRAC.gLogger.warn(result['Message']) + + # Save new tokens to token file + result = writeTokenDictToTokenFile(idpObj.token, tokenFile) + if not result['OK']: + return result + DIRAC.gLogger.debug('New token is saved to %s.' % result['Value']) + + # Get server setups and master CS server URL + csURL = idpObj.get_metadata("configuration_server") + setups = idpObj.get_metadata("setups") + + if len(setups) == 1 and csURL: + # If setup only one dont ask user + params.setSetup(setups[0]) + params.setServer(csURL) + elif setups and csURL: + # Ask the user for the appropriate configuration settings + while True: + result = _runConfigurationWizard({setup: csURL for setup in setups}, setups[0]) + if result: + break + print_formatted_text(HTML( + "Wizard failed, retrying... (press Control + C to exit)\n" + )) + # Apply the arguments to the params object + setup, csURL = result + params.setSetup(setup) + params.setServer(csURL) + + confirm = prompt(HTML("Do you want to use tokens instead of certificates by default? "), default="no") + return DIRAC.S_OK('yes') if confirm.lower() in ["y", "yes"] else DIRAC.S_OK(None) + + def runDiracConfigure(params): + if six.PY3: + Script.registerSwitch("", "login=", "Set DIRAC authorization endpoint", params.setIssuer) Script.registerSwitch("S:", "Setup=", "Set as DIRAC setup", params.setSetup) Script.registerSwitch("e:", "Extensions=", "Set as DIRAC extensions", params.setExtensions) Script.registerSwitch("C:", "ConfigurationServer=", "Set as DIRAC configuration server", params.setServer) @@ -320,6 +394,18 @@ def runDiracConfigure(params): Script.parseCommandLine(ignoreErrors=True) + # Use token auth + if params.issuer: + result = login(params) + if not result['OK']: + DIRAC.gLogger.error('Authorization failed: %s' % result['Message']) + DIRAC.exit(1) + useTokens = result['Value'] + if useTokens: + DIRAC.gConfig.setOptionValue('/DIRAC/Security/UseTokens', useTokens) + else: + DIRAC.gLogger.notice('To use tokens, please, set "/DIRAC/Security/UseTokens=yes".') + if not params.logLevel: params.logLevel = DIRAC.gConfig.getValue(cfgInstallPath('LogLevel'), '') if params.logLevel: diff --git a/src/DIRAC/Core/scripts/dirac_info.py b/src/DIRAC/Core/scripts/dirac_info.py index ebb8bd8834e..f9b38c330a9 100755 --- a/src/DIRAC/Core/scripts/dirac_info.py +++ b/src/DIRAC/Core/scripts/dirac_info.py @@ -60,6 +60,8 @@ def platform(arg): records = [] records.append(('Setup', gConfig.getValue('/DIRAC/Setup', 'Unknown'))) + records.append(('AuthorizationServer', gConfig.getValue('/DIRAC/Security/Authorization/issuer', + '/DIRAC/Security/Authorization/issuer option is absent'))) records.append(('ConfigurationServer', gConfig.getValue('/DIRAC/Configuration/Servers', []))) records.append(('Installation path', DIRAC.rootPath)) @@ -88,6 +90,10 @@ def platform(arg): records.append(('Use Server Certificate', 'Yes')) else: records.append(('Use Server Certificate', 'No')) + if gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true"): + records.append(('Use tokens', 'Yes')) + else: + records.append(('Use tokens', 'No')) if gConfig.getValue('/DIRAC/Security/SkipCAChecks', False): records.append(('Skip CA Checks', 'Yes')) else: diff --git a/src/DIRAC/Core/scripts/install_full.cfg b/src/DIRAC/Core/scripts/install_full.cfg index 248a69b59d9..f5b823b3e74 100755 --- a/src/DIRAC/Core/scripts/install_full.cfg +++ b/src/DIRAC/Core/scripts/install_full.cfg @@ -101,6 +101,8 @@ LocalInstallation Databases += FTSDB Databases += ComponentMonitoringDB Databases += ProxyDB + Databases += TokenDB + Databases += AuthDB Databases += PilotAgentsDB Databases += AccountingDB Databases += TransformationDB diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py new file mode 100644 index 00000000000..e297e9073a3 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -0,0 +1,452 @@ +""" This handler basically provides a REST interface to interact with the OAuth 2 authentication server + + .. literalinclude:: ../ConfigTemplate.cfg + :start-after: ##BEGIN Auth: + :end-before: ##END + :dedent: 2 + :caption: Auth options +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import pprint + +from dominate import tags as dom +from tornado.template import Template +from tornado.concurrent import Future + +from authlib.oauth2.base import OAuth2Error + +from DIRAC import S_ERROR, gConfig +from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getIdPForGroup, getGroupsForUser +from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint +from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML + +__RCSID__ = "$Id$" + + +class AuthHandler(TornadoREST): + # Authorization access to all methods handled by AuthServer instance + USE_AUTHZ_GRANTS = ['JWT', 'VISITOR'] + SYSTEM = 'Framework' + AUTH_PROPS = 'all' + LOCATION = "/auth" + + @classmethod + def initializeHandler(cls, serviceInfo): + """ This method is called only one time, at the first request + + :param dict ServiceInfoDict: infos about services + """ + cls.server = AuthServer() + cls.server.LOCATION = cls.LOCATION + + def initializeRequest(self): + """ Called at every request """ + self.currentPath = self.request.protocol + "://" + self.request.host + self.request.path + + path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] + + def web_index(self, well_known, instance=None): + """ Well known endpoint, specified by + `RFC8414 `_ + + Request examples:: + + GET: LOCATION/.well-known/openid-configuration + GET: LOCATION/.well-known/oauth-authorization-server + + Responce:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "registration_endpoint": "https://domain.com/auth/register", + "userinfo_endpoint": "https://domain.com/auth/userinfo", + "jwks_uri": "https://domain.com/auth/jwk", + "code_challenge_methods_supported": [ + "S256" + ], + "grant_types_supported": [ + "authorization_code", + "code", + "refresh_token" + ], + "token_endpoint": "https://domain.com/auth/token", + "response_types_supported": [ + "code", + "device", + "id_token token", + "id_token", + "token" + ], + "authorization_endpoint": "https://domain.com/auth/authorization", + "issuer": "https://domain.com/auth" + } + """ + if self.request.method == "GET": + resDict = dict(setups=gConfig.getSections('DIRAC/Setups').get('Value', []), + configuration_server=gConfig.getValue("/DIRAC/Configuration/MasterServer", "")) + resDict.update(self.server.metadata) + resDict.pop('Clients', None) + return resDict + + def web_jwk(self): + """ JWKs endpoint + + Request example:: + + GET LOCATION/jwk + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "keys": [ + { + "e": "AQAB", + "kty": "RSA", + "n": "3Vv5h5...X3Y7k" + } + ] + } + """ + result = self.server.db.getKeySet() + return result['Value'].as_dict() if result['OK'] else {} + + def web_revoke(self): + """ Revocation endpoint + + Request example:: + + GET LOCATION/revoke + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + """ + if self.request.method == 'POST': + self.log.verbose('Initialize a Device authentication flow.') + return self.server.create_endpoint_response(RevocationEndpoint.ENDPOINT_NAME, self.request) + + def web_userinfo(self): + """ The UserInfo endpoint can be used to retrieve identity information about a user, + see `spec `_ + + GET LOCATION/userinfo + + Parameters: + +---------------+--------+---------------------------------+--------------------------------------------------+ + | **name** | **in** | **description** | **example** | + +---------------+--------+---------------------------------+--------------------------------------------------+ + | Authorization | header | Provide access token | Bearer jkagfbfd3r4ubf887gqduyqwogasd87 | + +---------------+--------+---------------------------------+--------------------------------------------------+ + + Request example:: + + GET LOCATION/userinfo + Authorization: Bearer + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "sub": "248289761001", + "name": "Bob Smith", + "given_name": "Bob", + "family_name": "Smith", + "group": [ + "dirac_user", + "dirac_admin" + ] + } + """ + return self.getRemoteCredentials() + + path_device = ['([A-z%0-9-_]*)'] + + def web_device(self, provider=None, user_code=None): + """ The device authorization endpoint can be used to request device and user codes. + This endpoint is used to start the device flow authorization process and user code verification. + + POST LOCATION/device/? + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | user code | query | in the last step to confirm recived user | WE8R-WEN9 | + | | | code put it as query parameter (optional) | | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | client_id | query | The public client ID | 3f6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | scope | query | list of scoupes separated by a space, to | g:dirac_user | + | | | add a group you must add "g:" before the | | + | | | group name | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | provider | path | identity provider to autorize (optional) | CheckIn | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + + + User code confirmation:: + + GET LOCATION/device/?user_code= + + Request example, to initialize a Device authentication flow:: + + POST LOCATION/device/CheckIn_dev?client_id=3f1DAj8z6eNw0E6JGq1Vu6efZwyV&scope=g:dirac_admin + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_code": "TglwLiow0HUwowjB9aHH5HqH3bZKP9d420LkNhCEuR", + "verification_uri": "https://marosvn32.in2p3.fr/auth/device", + "interval": 5, + "expires_in": 1800, + "verification_uri_complete": "https://marosvn32.in2p3.fr/auth/device/WSRL-HJMR", + "user_code": "WSRL-HJMR" + } + + Request example, to confirm the user code:: + + POST LOCATION/device/CheckIn_dev/WSRL-HJMR + + Response:: + + HTTP/1.1 200 OK + """ + if self.request.method == 'POST': + self.log.verbose('Initialize a Device authentication flow.') + return self.server.create_endpoint_response(DeviceAuthorizationEndpoint.ENDPOINT_NAME, self.request) + + elif self.request.method == 'GET': + if user_code: + # If received a request with a user code, then prepare a request to authorization endpoint. + self.log.verbose('User code verification.') + result = self.server.db.getSessionByUserCode(user_code) + if not result['OK'] or not result['Value']: + return getHTML('session is expired.', theme='warning', body=result.get('Message'), + info='Seems device code flow authorization session %s expired.' % user_code) + session = result['Value'] + # Get original request from session + req = createOAuth2Request(dict(method='GET', uri=session['uri'])) + req.setQueryArguments(id=session['id'], user_code=user_code) + + # Save session to cookie and redirect to authorization endpoint + authURL = '%s?%s' % (req.path.replace('device', 'authorization'), req.query) + return self.server.handle_response(302, {}, [("Location", authURL)], session) + + # If received a request without a user code, then send a form to enter the user code + with dom.div(cls='row mt-5 justify-content-md-center') as tag: + with dom.div(cls="col-auto"): + dom.div(dom.form(dom._input(type="text", name="user_code"), + dom.button('Submit', type="submit", cls="btn btn-submit"), + action=self.currentPath, method="GET"), cls='card') + return getHTML('user code verification..', body=tag, icon='ticket-alt', + info='Device flow required user code. You will need to type user code to continue.') + + path_authorization = ['([A-z%0-9-_]*)'] + + def web_authorization(self, provider=None): + """ Authorization endpoint + + GET: LOCATION/authorization/ + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | response_type | query | informs of the desired grant type | code | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | client_id | query | The client ID | 3f6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | scope | query | list of scoupes separated by a space, to | g:dirac_user | + | | | add a group you must add "g:" before the | | + | | | group name | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | provider | path | identity provider to autorize (optional) | CheckIn | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + + General options: + provider -- identity provider to autorize + + Device flow: + &user_code=.. (required) + + Authentication code flow: + &scope=.. (optional) + &redirect_uri=.. (optional) + &state=.. (main session id, optional) + &code_challenge=.. (PKCE, optional) + &code_challenge_method=(pain|S256) ('pain' by default, optional) + """ + return self.server.validate_consent_request(self.request, provider) + + def web_redirect(self, state, error=None, error_description='', chooseScope=[]): + """ Redirect endpoint. + After a user successfully authorizes an application, the authorization server will redirect + the user back to the application with either an authorization code or access token in the URL. + The full URL of this endpoint must be registered in the identity provider. + + Read more in `oauth.com `_. + Specified by `RFC6749 `_. + + GET LOCATION/redirect + + :param str state: Current IdP session state + :param str error: IdP error response + :param str error_description: error description + :param list chooseScope: to specify new scope(group in our case) (optional) + + :return: S_OK()/S_ERROR() + """ + # Check current auth session that was initiated for the selected external identity provider + session = self.get_secure_cookie('auth_session') + if not session: + return self.server.handle_response( + payload=getHTML("session is expired.", theme="warning", state=400, + info="Seems %s session is expired, please, try again." % state), delSession=True) + + sessionWithExtIdP = json.loads(session) + if state and not sessionWithExtIdP.get('state') == state: + return self.server.handle_response( + payload=getHTML("session is expired.", theme="warning", state=400, + info="Seems %s session is expired, please, try again." % state), delSession=True) + + # Try to catch errors if the authorization on the selected identity provider was unsuccessful + if error: + provider = sessionWithExtIdP.get('Provider') + return self.server.handle_response( + payload=getHTML(error, theme="error", body=error_description, + info="Seems %s session is failed on the %s's' side." % (state, provider)), delSession=True) + + if not sessionWithExtIdP.get('authed'): + # Parse result of the second authentication flow + self.log.info('%s session, parsing authorization response:\n' % state, self.request.uri) + + result = self.server.parseIdPAuthorizationResponse(self.request, sessionWithExtIdP) + if not result['OK']: + if result['Message'].startswith(""): + return self.server.handle_response(payload=result['Message'], delSession=True) + return self.server.handle_response( + payload=getHTML("server error", state=500, info=result['Message']), delSession=True) + # Return main session flow + sessionWithExtIdP['authed'] = result['Value'] + + # Research group + grant_user, response = self.__researchDIRACGroup(sessionWithExtIdP, chooseScope, state) + if not grant_user: + return response + + # RESPONSE to basic DIRAC client request + resp = self.server.create_authorization_response(response, grant_user) + if not resp.payload.startswith(""): + resp.payload = getHTML('authorization response', state=resp.status_code, body=resp.payload) + return resp + + def web_token(self): + """ The token endpoint, the description of the parameters will differ depending on the selected grant_type + + POST LOCATION/token + + Parameters: + +----------------+--------+-------------------------------+---------------------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------+---------------------------------------------------+ + | grant_type | query | grant type to use | urn:ietf:params:oauth:grant-type:device_code | + +----------------+--------+-------------------------------+---------------------------------------------------+ + | client_id | query | The public client ID | 3f1DAj8z6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------+---------------------------------------------------+ + | device_code | query | device code | uW5xL4hr2tqwBPKL5d0JO9Fcc67gLqhJsNqYTSp | + +----------------+--------+-------------------------------+---------------------------------------------------+ + + :mod:`Supported grant types ` + + Request example:: + + POST LOCATION/token?client_id=L86..yV&grant_type=urn:ietf:params:oauth:grant-type:device_code&device_code=uW5 + + Response:: + + HTTP/1.1 400 OK + Content-Type: application/json + + { + "error": "authorization_pending" + } + """ + return self.server.create_token_response(self.request) + + def __researchDIRACGroup(self, extSession, chooseScope, state): + """ Research DIRAC groups for authorized user + + :param dict extSession: ended authorized external IdP session + + :return: -- will return (None, response) to provide error or group selector + will return (grant_user, request) to contionue authorization with choosed group + """ + # Base DIRAC client auth session + firstRequest = createOAuth2Request(extSession['firstRequest']) + # Read requested groups by DIRAC client or user + firstRequest.addScopes(chooseScope) + # Read already authed user + username = extSession['authed']['username'] + # Requested arguments in first request + provider = firstRequest.provider + self.log.debug('Next groups has been found for %s:' % username, ', '.join(firstRequest.groups)) + + # Researche Group + result = getGroupsForUser(username) + if not result['OK']: + return None, self.server.handle_response( + getHTML("server error", theme="error", info=result['Message']), delSession=True) + groups = result['Value'] + + validGroups = [group for group in groups if (getIdPForGroup(group) == provider) or ('proxy' in firstRequest.scope)] + if not validGroups: + return None, self.server.handle_response(getHTML( + "groups not found.", theme="error", + info='No groups found for %s and for %s Identity Provider.' % (username, provider)), delSession=True) + + self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(validGroups)) + + # If group already defined in first request, just return it + if firstRequest.groups: + return extSession['authed'], firstRequest + + # If not and we found only one valid group, apply this group + if len(validGroups) == 1: + firstRequest.addScopes(['g:%s' % validGroups[0]]) + return extSession['authed'], firstRequest + + # Else give user chanse to choose group in browser + with dom.div(cls='row mt-5 justify-content-md-center align-items-center') as tag: + for group in sorted(validGroups): + vo, gr = group.split('_') + with dom.div(cls="col-auto p-2").add(dom.div(cls="card shadow-lg border-0 text-center p-2")): + dom.h4(vo.upper() + ' ' + gr, cls="p-2") + dom.a(href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group), cls="stretched-link") + + html = getHTML('group selection..', body=tag, icon='users', + info='Dirac use groups to describe permissions. ' + 'You will need to select one of the groups to continue.') + + return None, self.server.handle_response(payload=html, newSession=extSession) diff --git a/src/DIRAC/FrameworkSystem/API/__init__.py b/src/DIRAC/FrameworkSystem/API/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py index 239e39d27c7..a20d4b0ac60 100644 --- a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py +++ b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py @@ -7,11 +7,14 @@ import os import getpass import tarfile +import six from six import BytesIO +from base64 import b64decode from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Base.Client import Client, createClient -from DIRAC.Core.DISET.TransferClient import TransferClient +from DIRAC.Core.Tornado.Client.TornadoClient import TornadoClient +from DIRAC.Core.Tornado.Client.ClientSelector import TransferClientSelector as TransferClient from DIRAC.Core.Security import Locations, Utilities from DIRAC.Core.Utilities.File import mkDir from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import skipCACheck @@ -20,6 +23,16 @@ __RCSID__ = "$Id$" +class BundleDeliveryJSONClient(TornadoClient): + + def receiveFile(self, buff, fileId): + retVal = self.executeRPC('streamToClient', fileId) + if retVal['OK']: + retVal['Value'] = b64decode(retVal['Value'].encode()) + buff.write(retVal['Value']) + return retVal + + @createClient('Framework/BundleDelivery') class BundleDeliveryClient(Client): @@ -37,7 +50,8 @@ def __getTransferClient(self): if self.transferClient: return self.transferClient return TransferClient("Framework/BundleDelivery", - skipCACheck=skipCACheck()) + skipCACheck=skipCACheck(), + httpsClient=BundleDeliveryJSONClient) def __getHash(self, bundleID, dirToSyncTo): """ Get hash for bundle in directory @@ -48,9 +62,9 @@ def __getHash(self, bundleID, dirToSyncTo): :return: str """ try: - with open(os.path.join(dirToSyncTo, ".dab.%s" % bundleID), "r") as fd: + with open(os.path.join(dirToSyncTo, ".dab.%s" % bundleID), "rb") as fd: bdHash = fd.read().strip() - return bdHash + return bdHash.decode() if six.PY3 else bdHash except Exception: return "" @@ -63,8 +77,8 @@ def __setHash(self, bundleID, dirToSyncTo, bdHash): """ try: fileName = os.path.join(dirToSyncTo, ".dab.%s" % bundleID) - with open(fileName, "wt") as fd: - fd.write(bdHash) + with open(fileName, "wb") as fd: + fd.write(bdHash.encode() if six.PY3 and not isinstance(bdHash, bytes) else bdHash) except Exception as e: self.log.error("Could not save hash after synchronization", "%s: %s" % (fileName, str(e))) diff --git a/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py b/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py new file mode 100644 index 00000000000..c49abfef639 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py @@ -0,0 +1,18 @@ +""" The TokenManagerClient is a class representing the client of the DIRAC TokenManager service. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from DIRAC.Core.Base.Client import Client, createClient + + +@createClient('Framework/TokenManager') +class TokenManagerClient(Client): + """Client exposing the TokenManager Service.""" + + def __init__(self, **kwargs): + super(TokenManagerClient, self).__init__(**kwargs) + self.setServer('Framework/TokenManager') diff --git a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg index 8d4289ecb9d..d239cd6c5d7 100644 --- a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg +++ b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg @@ -1,3 +1,15 @@ +APIs +{ + ##BEGIN Auth: + # Section to describe RESTful API for DIRAC Authorization Server(AS) + Auth + { + Port = 8000 + # Allow download personal proxy + downloadablePersonalProxy = True + } + ##END +} Services { Gateway @@ -13,6 +25,17 @@ Services storeHostInfo = Operator } } + TornadoBundleDelivery + { + Protocol = https + } + ##BEGIN TokenManager: + # Section to describe TokenManager system + TokenManager + { + Protocol = https + } + ##END ##BEGIN ProxyManager: # Section to describe ProxyManager system # https://dirac.readthedocs.org/en/latest/AdministratorGuide/Systems/Framework/ProxyManager/index.html diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py new file mode 100644 index 00000000000..0b0d4063e5a --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -0,0 +1,356 @@ +""" Auth class is a front-end to the Auth Database +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import time +import pprint + +from sqlalchemy import Column, Integer, Text, String +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound +from sqlalchemy.ext.declarative import declarative_base + +import authlib +from authlib.jose import KeySet, JsonWebKey +from authlib.common.security import generate_token + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +__RCSID__ = "$Id$" + + +Model = declarative_base() + + +class RefreshToken(Model): + __tablename__ = 'RefreshToken' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + jti = Column(String(255), nullable=False, primary_key=True) + issued_at = Column(Integer, nullable=False, default=0) + access_token = Column(Text, nullable=False) + refresh_token = Column(Text) + + +class JWK(Model): + __tablename__ = 'JWK' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + kid = Column(String(255), unique=True, primary_key=True, nullable=False) + key = Column(Text, nullable=False) + expires_at = Column(Integer, nullable=False, default=0) + + +class AuthSession(Model): + __tablename__ = 'AuthSession' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + id = Column(String(255), unique=True, primary_key=True, nullable=False) + uri = Column(String(255)) + state = Column(String(255)) + scope = Column(String(255)) + user_id = Column(String(255)) + username = Column(String(255)) + client_id = Column(String(255)) + user_code = Column(String(255)) + device_code = Column(String(255)) + interval = Column(Integer, nullable=False, default=5) + expires_at = Column(Integer, nullable=False, default=0) + expires_in = Column(Integer, nullable=False, default=0) + verification_uri = Column(String(255)) + verification_uri_complete = Column(String(255)) + + +class AuthDB(SQLAlchemyDB): + """ AuthDB class is a front-end to the OAuth Database + """ + def __init__(self): + """ Constructor + """ + super(AuthDB, self).__init__() + self._initializeConnection('Framework/AuthDB') + result = self.__initializeDB() + if not result['OK']: + raise Exception("Can't create tables: %s" % result['Message']) + self.session = scoped_session(self.sessionMaker_o) + + def __initializeDB(self): + """ Create the tables + """ + tablesInDB = self.inspector.get_table_names() + + # RefreshToken + if 'RefreshToken' not in tablesInDB: + try: + RefreshToken.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + # JWK + if 'JWK' not in tablesInDB: + try: + JWK.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + # AuthSession + if 'AuthSession' not in tablesInDB: + try: + AuthSession.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + return S_OK() + + def storeRefreshToken(self, token, tokenID=None): + """ Store refresh token + + :param dict token: tokens as dict + :param str tokenID: token ID + + :return: S_OK(dict)/S_ERROR() + """ + iat = int(time.time()) + jti = tokenID or generate_token(10) + self.log.debug('Store %s token:\n' % jti, pprint.pformat(token)) + + session = self.session() + try: + session.add(RefreshToken(jti=jti, + issued_at=iat, + access_token=token['access_token'], + refresh_token=token.get('refresh_token'))) + except Exception as e: + return self.__result(session, S_ERROR('Could not add refresh token: %s' % repr(e))) + + self.log.info('Token with %s ID successfully added:\n' % jti, pprint.pformat(token)) + return S_OK(dict(jti=jti, iat=iat)) + + def revokeRefreshToken(self, tokenID): + """ Revoke refresh token + + :param str tokenID: refresh token ID + + :return: S_OK()/S_ERROR() + """ + session = self.session() + try: + session.query(RefreshToken).filter(RefreshToken.jti == tokenID).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return S_OK() + + def getCredentialByRefreshToken(self, tokenID): + """ Get refresh token credential + + :param str tokenID: refresh token ID + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + token = session.query(RefreshToken).filter(RefreshToken.jti == tokenID).first() + session.query(RefreshToken).filter(RefreshToken.jti == tokenID).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)) if token else None)) + + def generateRSAKeys(self): + """ Generate an RSA keypair with an exponent of 65537 in PEM format + + :return: S_OK/S_ERROR + """ + key = JsonWebKey.generate_key('RSA', 1024, is_private=True) + # as_dict has no arguments for authlib < 1.0.0 + # for authlib >= 1.0.0 + keyDict = dict(key=json.dumps(key.as_dict(True)), + kid=key.thumbprint(), expires_at=time.time() + (30 * 24 * 3600)) + session = self.session() + try: + session.add(JWK(**keyDict)) + except Exception as e: + return self.__result(session, S_ERROR('Could not generate keys: %s' % e)) + return self.__result(session, S_OK(keyDict)) + + def getKeySet(self): + """ Get key set + + :return: S_OK(obj)/S_ERROR() + """ + result = self.getActiveKeys() + if result['OK'] and not result['Value']: + result = self.generateRSAKeys() + if result['OK']: + result['Value'] = [result['Value']] + if not result['OK']: + return result + return S_OK(KeySet([JsonWebKey.import_key(json.loads(key['key'])) for key in result['Value']])) + + def getPrivateKey(self, kid=None): + """ Get private key + + :param str kid: key ID + + :return: S_OK(obj)/S_ERROR() + """ + result = self.getActiveKeys(kid) + if not result['OK']: + return result + jwks = result['Value'] + if kid: + strkey = jwks[0]['key'] + return S_OK(JsonWebKey.import_key(json.loads(jwks[0]['key']))) + newer = {} + for jwk in jwks: + if int(jwk['expires_at']) > int(newer.get('expires_at', time.time() + (24 * 3600))): + newer = jwk + if not newer.get('key'): + result = self.generateRSAKeys() + if not result['OK']: + return result + newer = result['Value'] + return S_OK(JsonWebKey.import_key(json.loads(newer['key']))) + + def getActiveKeys(self, kid=None): + """ Get active keys + + :param str kid: key ID + + :return: S_OK(list)/S_ERROR() + """ + session = self.session() + try: + # Remove all expired jwks + session.query(JWK).filter(JWK.expires_at < time.time()).delete() + jwks = session.query(JWK).filter(JWK.expires_at > time.time()).all() + if kid: + jwks = [jwk for jwk in jwks if jwk.kid == kid] + except NoResultFound: + return self.__result(session, S_OK([])) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK([self.__rowToDict(jwk) for jwk in jwks])) + + def removeKeys(self): + """ Get active keys + + :return: S_OK(list)/S_ERROR() + """ + session = self.session() + try: + session.query(JWK).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK()) + + def addSession(self, data): + """ Add new session + + :param dict data: session metadata + + :return: S_OK(dict)/S_ERROR() + """ + attrts = {} + if not data.get('expires_at'): + data['expires_at'] = data['expires_in'] + time.time() + self.log.debug('Add authorization session:', data) + for k, v in data.items(): + if k not in AuthSession.__dict__.keys(): + self.log.warn('%s is not expected as authentication session attribute.' % k) + else: + attrts[k] = v + session = self.session() + try: + session.add(AuthSession(**attrts)) + except Exception as e: + return self.__result(session, S_ERROR('Could not add Token: %s' % e)) + return self.__result(session, S_OK('Token successfully added')) + + def updateSession(self, data, sessionID): + """ Update session data + + :param dict data: data info + :param str sessionID: sessionID + + :return: S_OK(object)/S_ERROR() + """ + self.removeSession(sessionID=sessionID) + return self.addSession(data) + + def removeSession(self, sessionID): + """ Remove session + + :param str sessionID: session id + + :return: S_OK()/S_ERROR() + """ + session = self.session() + try: + # Remove all expired sessions + session.query(AuthSession).filter(AuthSession.expires_at < time.time()).delete() + session.query(AuthSession).filter(AuthSession.id == sessionID).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK()) + + def getSession(self, sessionID): + """ Get client + + :param str sessionID: session id + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + resData = session.query(AuthSession).filter(AuthSession.id == sessionID).first() + except MultipleResultsFound: + return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) + except NoResultFound: + return self.__result(session, S_ERROR("%s session is expired." % sessionID)) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(resData))) + + def getSessionByUserCode(self, userCode): + """ Get client + + :param str userCode: user code + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + resData = session.query(AuthSession).filter(AuthSession.user_code == userCode).first() + except MultipleResultsFound: + return self.__result(session, S_ERROR("%s is not unique ID." % userCode)) + except NoResultFound: + return self.__result(session, S_ERROR("Session for %s user code is expired." % userCode)) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(resData))) + + def __result(self, session, result=None): + try: + if not result['OK']: + session.rollback() + else: + session.commit() + except Exception as e: + session.rollback() + result = S_ERROR('Could not commit: %s' % (e)) + session.close() + return result + + def __rowToDict(self, row): + """ Convert sqlalchemy row to dictionary + + :param object row: sqlalchemy row + + :return: dict + """ + return {c.name: str(getattr(row, c.name)) for c in row.__table__.columns} if row else {} diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.sql b/src/DIRAC/FrameworkSystem/DB/AuthDB.sql new file mode 100644 index 00000000000..bfd422307e6 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.sql @@ -0,0 +1,2 @@ +# Everything is created by the DB object upon instantiation if it does not exists. +use AuthDB; \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/DB/TokenDB.py b/src/DIRAC/FrameworkSystem/DB/TokenDB.py new file mode 100644 index 00000000000..6c6a86633c9 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/TokenDB.py @@ -0,0 +1,176 @@ +""" Auth class is a front-end to the Auth Database +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import jwt +import time +import pprint + +from sqlalchemy import Column, Integer, Text, String +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.ext.declarative import declarative_base + +from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +__RCSID__ = "$Id$" + + +Model = declarative_base() + + +class Token(Model, OAuth2TokenMixin): + __tablename__ = 'Token' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + # access_token too large for varchar(255) + # 767 bytes is the stated prefix limitation for InnoDB tables in MySQL version 5.6 + # https://stackoverflow.com/questions/1827063/mysql-error-key-specification-without-a-key-length + id = Column(Integer, autoincrement=True, primary_key=True) + kid = Column(String(255)) + user_id = Column(String(255)) + provider = Column(String(255)) + expires_at = Column(Integer, nullable=False, default=0) + access_token = Column(Text, nullable=False) + refresh_token = Column(Text, nullable=False) + rt_expires_at = Column(Integer, nullable=False, default=0) + + +class TokenDB(SQLAlchemyDB): + """ TokenDB class is a front-end to the OAuth Database + """ + def __init__(self): + """ Constructor + """ + super(TokenDB, self).__init__() + self._initializeConnection('Framework/TokenDB') + result = self.__initializeDB() + if not result['OK']: + raise Exception("Can't create tables: %s" % result['Message']) + self.session = scoped_session(self.sessionMaker_o) + + def __initializeDB(self): + """ Create the tables + """ + tablesInDB = self.inspector.get_table_names() + + # Token + if 'Token' not in tablesInDB: + try: + Token.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + return S_OK() + + def getTokenForUserProvider(self, userID, provider): + """ Get token for user ID and provider name + + :param str userID: user ID + :param str provider: provider + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + token = session.query(Token).filter(Token.rt_expires_at > time.time()).filter(Token.user_id == userID)\ + .filter(Token.provider == provider).first() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)) if token else None)) + + def updateToken(self, token, userID, provider, rt_expired_in): + """ Update tokens + + :param dict token: token info + :param str userID: user ID + :param str provider: provider + :param int rt_expired_in: refresh token lifetime + + :return: S_OK(list)/S_ERROR() + """ + token['user_id'] = userID + token['provider'] = provider + if not token.get('rt_expires_at'): + try: + token['rt_expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False, + verify_aud=False))['exp']) + except Exception as e: + self.log.debug('Cannot get refresh token expires time: %s' % repr(e)) + + token['rt_expires_at'] = int(token.get('rt_expires_at', rt_expired_in + int(time.time()))) + if token['rt_expires_at'] < time.time(): + return S_ERROR('Cannot store expired refresh token.') + + attrts = dict((k, v) for k, v in dict(token).items() if k in list(Token.__dict__.keys())) + self.log.debug('Store token:', pprint.pformat(attrts)) + session = self.session() + try: + session.query(Token).filter(Token.expires_at < time.time()).delete() + oldTokens = session.query(Token).filter(Token.user_id == userID)\ + .filter(Token.provider == provider).all() + session.add(Token(**attrts)) + session.query(Token).filter(Token.user_id == userID).filter(Token.provider == provider)\ + .filter(Token.access_token != token['access_token']).delete() + except Exception as e: + self.log.exception(e) + return self.__result(session, S_ERROR('Could not add Token: %s' % repr(e))) + self.log.info('Token successfully added for %s user, %s provider' % (token['user_id'], token['provider'])) + return self.__result(session, S_OK([self.__rowToDict(t) for t in oldTokens] if oldTokens else [])) + + def removeToken(self, access_token=None, refresh_token=None, user_id=None): + """ Remove token + + :param str access_token: access token + :param str refresh_token: refresh token + + :return: S_OK(object)/S_ERROR() + """ + session = self.session() + try: + if access_token: + session.query(Token).filter(Token.access_token == access_token).delete() + elif refresh_token: + session.query(Token).filter(Token.refresh_token == refresh_token).delete() + elif user_id: + session.query(Token).filter(Token.user_id == user_id).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK('Token successfully removed')) + + def getTokensByUserID(self, userID): + session = self.session() + try: + tokens = session.query(Token).filter(Token.user_id == userID).all() + except NoResultFound: + return self.__result(session, S_OK([])) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK([OAuth2Token(self.__rowToDict(t)) for t in tokens])) + + def __result(self, session, result=None): + try: + if not result['OK']: + session.rollback() + else: + session.commit() + except Exception as e: + session.rollback() + result = S_ERROR('Could not commit: %s' % (e)) + session.close() + return result + + def __rowToDict(self, row): + """ Convert sqlalchemy row to dictionary + + :param object row: sqlalchemy row + + :return: dict + """ + return {c.name: str(getattr(row, c.name)) for c in row.__table__.columns} if row else {} diff --git a/src/DIRAC/FrameworkSystem/DB/TokenDB.sql b/src/DIRAC/FrameworkSystem/DB/TokenDB.sql new file mode 100644 index 00000000000..15d9a8185fb --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/TokenDB.sql @@ -0,0 +1,2 @@ +# Everything is created by the DB object upon instantiation if it does not exists. +use TokenDB; \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py new file mode 100644 index 00000000000..8ffa1f20aae --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -0,0 +1,221 @@ +""" TokenManager service + + .. literalinclude:: ../ConfigTemplate.cfg + :start-after: ##BEGIN TokenManager: + :end-before: ##END + :dedent: 2 + :caption: TokenManager options +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import six +import pprint + +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.Security import Properties +from DIRAC.Core.Tornado.Server.TornadoService import TornadoService +from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + + +class TokenManagerHandler(TornadoService): + + __maxExtraLifeFactor = 1.5 + __tokenDB = None + + @classmethod + def initializeHandler(cls, serviceInfoDict): + try: + cls.__tokenDB = TokenDB() + except Exception as e: + gLogger.exception(e) + return S_ERROR('Could not connect to the database %s' % repr(e)) + + cls.idps = IdProviderFactory() + return S_OK() + + def __generateUsersTokensInfo(self, users): + """ Generate information dict about user tokens + + :return: dict + """ + tokensInfo = [] + credDict = self.getRemoteCredentials() + result = Registry.getDNForUsername(credDict['username']) + if not result['OK']: + return result + for dn in result['Value']: + result = Registry.getIDFromDN(dn) + if result['OK']: + result = self.__tokenDB.getTokensByUserID(result['Value']) + if not result['OK']: + gLogger.error(result['Message']) + tokensInfo += result['Value'] + return tokensInfo + + def __generateUserTokensInfo(self): + """ Generate information dict about user tokens + + :return: dict + """ + tokensInfo = [] + credDict = self.getRemoteCredentials() + result = Registry.getDNForUsername(credDict['username']) + if not result['OK']: + return result + for dn in result['Value']: + result = Registry.getIDFromDN(dn) + if result['OK']: + result = self.__tokenDB.getTokensByUserID(result['Value']) + if not result['OK']: + gLogger.error(result['Message']) + tokensInfo += result['Value'] + return tokensInfo + + def __addKnownUserTokensInfo(self, retDict): + """ Given a S_OK/S_ERR add a tokens entry with info of all the tokens a user has uploaded + + :return: S_OK(dict)/S_ERROR() + """ + retDict['tokens'] = self.__generateUserTokensInfo() + return retDict + + auth_getUserTokensInfo = ['authenticated'] + + def export_getUserTokensInfo(self): + """ Get the info about the user tokens in the system + + :return: S_OK(dict) + """ + return S_OK(self.__generateUserTokensInfo()) + + auth_getUsersTokensInfo = [Properties.PROXY_MANAGEMENT] + + def export_getUsersTokensInfo(self, users): + """ Get the info about the user tokens in the system + + :param list users: user names + + :return: S_OK(dict) + """ + tokensInfo = [] + for user in users: + result = Registry.getDNForUsername(user) + if not result['OK']: + return result + for dn in result['Value']: + uid = Registry.getIDFromDN(dn).get('Value') + if uid: + result = self.__tokenDB.getTokensByUserID(uid) + if not result['OK']: + gLogger.error(result['Message']) + else: + for tokenDict in result['Value']: + if tokenDict not in tokensInfo: + tokenDict['username'] = user + tokensInfo.append(tokenDict) + return S_OK(tokensInfo) + + auth_uploadToken = ['authenticated'] + + def export_updateToken(self, token, userID, provider, rt_expired_in=24 * 3600): + """ Request to delegate tokens to DIRAC + + :param dict token: token + :param str userID: user ID + :param str provider: provider name + :param int rt_expired_in: refresh token expires time + + :return: S_OK(list)/S_ERROR() -- list contain uploaded tokens info as dictionaries + """ + self.log.verbose('Update %s user token for %s:\n' % (userID, provider), pprint.pformat(token)) + result = self.idps.getIdProvider(provider) + if not result['OK']: + return result + idPObj = result['Value'] + result = self.__tokenDB.updateToken(token, userID, provider, rt_expired_in) + if not result['OK']: + return result + for oldToken in result['Value']: + if 'refresh_token' in oldToken and oldToken['refresh_token'] != token['refresh_token']: + self.log.verbose('Revoke old refresh token:\n', pprint.pformat(oldToken)) + idPObj.revokeToken(oldToken['refresh_token']) + return self.__tokenDB.getTokensByUserID(userID) + + def __checkProperties(self, requestedUserDN, requestedUserGroup): + """ Check the properties and return if they can only download limited tokens if authorized + + :param str requestedUserDN: user DN + :param str requestedUserGroup: DIRAC group + + :return: S_OK(bool)/S_ERROR() + """ + credDict = self.getRemoteCredentials() + if Properties.FULL_DELEGATION in credDict['properties']: + return S_OK(False) + if Properties.LIMITED_DELEGATION in credDict['properties']: + return S_OK(True) + if Properties.PRIVATE_LIMITED_DELEGATION in credDict['properties']: + if credDict['DN'] != requestedUserDN: + return S_ERROR("You are not allowed to download any token") + if Properties.PRIVATE_LIMITED_DELEGATION not in Registry.getPropertiesForGroup(requestedUserGroup): + return S_ERROR("You can't download tokens for that group") + return S_OK(True) + # Not authorized! + return S_ERROR("You can't get tokens!") + + def export_getToken(self, username, userGroup): + """ Get a access token for a user/group + + * Properties: + * FullDelegation <- permits full delegation of tokens + * LimitedDelegation <- permits downloading only limited tokens + * PrivateLimitedDelegation <- permits downloading only limited tokens for one self + """ + userID = [] + provider = Registry.getIdPForGroup(userGroup) + if not provider: + return S_ERROR('The %s group belongs to the VO that is not tied to any Identity Provider.' % userGroup) + + result = self.idps.getIdProvider(provider) + if not result['OK']: + return result + idpObj = result['Value'] + + result = Registry.getDNForUsername(username) + if not result['OK']: + return result + + err = [] + for dn in result['Value']: + result = Registry.getIDFromDN(dn) + if result['OK']: + result = self.__tokenDB.getTokenForUserProvider(result['Value'], provider) + if result['OK'] and result['Value']: + idpObj.token = result['Value'] + result = self.__checkProperties(dn, userGroup) + if result['OK']: + result = idpObj.exchangeGroup(userGroup) + if result['OK']: + return result + err.append(result.get('Message', 'No token found for %s.' % dn)) + return S_ERROR('; '.join(err or ['No user ID found for %s' % username])) + + def export_deleteToken(self, userDN): + """ Delete a token from the DB + + :param str userDN: user DN + + :return: S_OK()/S_ERROR() + """ + credDict = self.getRemoteCredentials() + if Properties.PROXY_MANAGEMENT not in credDict['properties']: + if userDN != credDict['DN']: + return S_ERROR("You aren't allowed!") + result = Registry.getIDFromDN(userDN) + return self.__tokenDB.removeToken(user_id=result['Value']) if result['OK'] else result diff --git a/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py b/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py new file mode 100644 index 00000000000..6d7a8ad40fc --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py @@ -0,0 +1,206 @@ +""" Handler for CAs + CRLs bundles +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import six +import tarfile +import os +from base64 import b64encode, b64decode + +from DIRAC import gLogger, S_OK, S_ERROR, gConfig +from DIRAC.Core.Tornado.Server.TornadoService import TornadoService +from DIRAC.Core.Utilities.ThreadScheduler import gThreadScheduler +from DIRAC.Core.Utilities import File, List +from DIRAC.Core.Security import Locations, Utilities + + +class BundleManager(object): + + def __init__(self, baseCSPath): + self.__csPath = baseCSPath + self.__bundles = {} + self.updateBundles() + + def __getDirsToBundle(self): + dirsToBundle = {} + result = gConfig.getOptionsDict("%s/DirsToBundle" % self.__csPath) + if result['OK']: + dB = result['Value'] + for bId in dB: + dirsToBundle[bId] = List.fromChar(dB[bId]) + if gConfig.getValue("%s/BundleCAs" % self.__csPath, True): + dirsToBundle['CAs'] = [ + "%s/*.0" % + Locations.getCAsLocation(), + "%s/*.signing_policy" % + Locations.getCAsLocation(), + "%s/*.pem" % + Locations.getCAsLocation()] + if gConfig.getValue("%s/BundleCRLs" % self.__csPath, True): + dirsToBundle['CRLs'] = ["%s/*.r0" % Locations.getCAsLocation()] + return dirsToBundle + + def getBundles(self): + return dict([(bId, self.__bundles[bId]) for bId in self.__bundles]) + + def bundleExists(self, bId): + return bId in self.__bundles + + def getBundleVersion(self, bId): + try: + return self.__bundles[bId][0] + except Exception: + return "" + + def getBundleData(self, bId): + try: + return self.__bundles[bId][1] + except Exception: + return "" + + def updateBundles(self): + dirsToBundle = self.__getDirsToBundle() + # Delete bundles that don't have to be updated + for bId in self.__bundles: + if bId not in dirsToBundle: + gLogger.info("Deleting old bundle %s" % bId) + del(self.__bundles[bId]) + for bId in dirsToBundle: + bundlePaths = dirsToBundle[bId] + gLogger.info("Updating %s bundle %s" % (bId, bundlePaths)) + buffer_ = six.BytesIO() + filesToBundle = sorted(File.getGlobbedFiles(bundlePaths)) + if filesToBundle: + commonPath = os.path.commonprefix(filesToBundle) + commonEnd = len(commonPath) + gLogger.info("Bundle will have %s files with common path %s" % (len(filesToBundle), commonPath)) + with tarfile.open('dummy', "w:gz", buffer_) as tarBuffer: + for filePath in filesToBundle: + tarBuffer.add(filePath, filePath[commonEnd:]) + zippedData = buffer_.getvalue() + buffer_.close() + hash_ = File.getMD5ForFiles(filesToBundle) + gLogger.info("Bundled %s : %s bytes (%s)" % (bId, len(zippedData), hash_)) + self.__bundles[bId] = (hash_, zippedData) + else: + self.__bundles[bId] = (None, None) + + +class TornadoBundleDeliveryHandler(TornadoService): + + @classmethod + def initializeHandler(cls, serviceInfoDict): + csPath = serviceInfoDict['serviceSectionPath'] + cls.bundleManager = BundleManager(csPath) + updateBundleTime = gConfig.getValue("%s/BundlesLifeTime" % csPath, 3600 * 6) + gLogger.info("Bundles will be updated each %s secs" % updateBundleTime) + gThreadScheduler.addPeriodicTask(updateBundleTime, cls.bundleManager.updateBundles) + return S_OK() + + types_streamToClient = [] + + def export_streamToClient(self, fileId): + version = "" + if isinstance(fileId, six.string_types): + if fileId in ['CAs', 'CRLs']: + retVal = Utilities.generateCAFile() if fileId == 'CAs' else Utilities.generateRevokedCertsFile() + if not retVal['OK']: + return retVal + with open(retVal['Value'], 'r') as fd: + return S_OK(b64encode(fd.read()).decode()) + bId = fileId + + elif isinstance(fileId, (list, tuple)): + if len(fileId) == 0: + return S_ERROR("No bundle specified!") + bId = fileId[0] + if len(fileId) != 1: + version = fileId[1] + + if not self.bundleManager.bundleExists(bId): + return S_ERROR("Unknown bundle %s" % bId) + + bundleVersion = self.bundleManager.getBundleVersion(bId) + if bundleVersion is None: + return S_ERROR("Empty bundle %s" % bId) + + if version == bundleVersion: + return S_OK(bundleVersion) + + return S_OK(b64encode(self.bundleManager.getBundleData(bId)).decode()) + + types_getListOfBundles = [] + + @classmethod + def export_getListOfBundles(cls): + return S_OK(cls.bundleManager.getBundles()) + + def transfer_toClient(self, fileId, token, fileHelper): + version = "" + if isinstance(fileId, six.string_types): + if fileId in ['CAs', 'CRLs']: + return self.__transferFile(fileId, fileHelper) + else: + bId = fileId + elif isinstance(fileId, (list, tuple)): + if len(fileId) == 0: + fileHelper.markAsTransferred() + return S_ERROR("No bundle specified!") + elif len(fileId) == 1: + bId = fileId[0] + else: + bId = fileId[0] + version = fileId[1] + if not self.bundleManager.bundleExists(bId): + fileHelper.markAsTransferred() + return S_ERROR("Unknown bundle %s" % bId) + + bundleVersion = self.bundleManager.getBundleVersion(bId) + if bundleVersion is None: + fileHelper.markAsTransferred() + return S_ERROR("Empty bundle %s" % bId) + + if version == bundleVersion: + fileHelper.markAsTransferred() + return S_OK(bundleVersion) + + buffer_ = six.BytesIO(self.bundleManager.getBundleData(bId)) + result = fileHelper.DataSourceToNetwork(buffer_) + buffer_.close() + if not result['OK']: + return result + return S_OK(bundleVersion) + + def __transferFile(self, filetype, fileHelper): + """ + This file is creates and transfers the CAs or CRLs file to the client. + :param str filetype: we can define which file will be transfered to the client + :param object fileHelper: + :return: S_OK or S_ERROR + """ + if filetype == 'CAs': + retVal = Utilities.generateCAFile() + elif filetype == 'CRLs': + retVal = Utilities.generateRevokedCertsFile() + else: + return S_ERROR("Not supported file type %s" % filetype) + + if not retVal['OK']: + return retVal + else: + result = fileHelper.getFileDescriptor(retVal['Value'], 'r') + if not result['OK']: + result = fileHelper.sendEOF() + # better to check again the existence of the file + if not os.path.exists(retVal['Value']): + return S_ERROR('File %s does not exist' % os.path.basename(retVal['Value'])) + else: + return S_ERROR('Failed to get file descriptor') + fileDescriptor = result['Value'] + result = fileHelper.FDToNetwork(fileDescriptor) + fileHelper.oFile.close() # close the file and return + return result diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py new file mode 100644 index 00000000000..3ae5784695d --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -0,0 +1,481 @@ +""" This class provides authorization server activity. """ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import six +import sys +import time +import json +import pprint +import logging +from dominate import tags as dom + +import authlib +from authlib.jose import JsonWebKey, jwt +from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc7636 import CodeChallenge +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope + +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB +from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance, getProviderInfo +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.ConfigurationSystem.Client.Utilities import isDownloadablePersonalProxy +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getUsernameForDN, getEmailsForGroup, wrapIDAsDN, + getDNForUsername, getIdPForGroup) +from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient +from DIRAC.FrameworkSystem.Client.TokenManagerClient import TokenManagerClient +from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import TornadoResponse +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients, Client +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import collectMetadata, getHTML +from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint +from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, + DeviceCodeGrant) +from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant + +log = logging.getLogger('authlib') +log.addHandler(logging.StreamHandler(sys.stdout)) +log.setLevel(logging.DEBUG) +log = gLogger.getSubLogger(__name__) + + +class AuthServer(_AuthorizationServer): + """ Implementation of the :class:`authlib.oauth2.rfc6749.AuthorizationServer`. + + Initialize:: + + server = AuthServer() + """ + LOCATION = None + REFRESH_TOKEN_EXPIRES_IN = 24 * 3600 + + def __init__(self): + self.db = AuthDB() + self.log = log + self.idps = IdProviderFactory() + self.proxyCli = ProxyManagerClient() + self.tokenCli = TokenManagerClient() + self.metadata = collectMetadata() + self.metadata.validate() + # args for authlib < 1.0.0: (query_client=self.query_client, save_token=None, metadata=self.metadata) + # for authlib >= 1.0.0: + _AuthorizationServer.__init__(self, scopes_supported=self.metadata['scopes_supported']) + # Skip authlib method save_token and send_signal + self.save_token = lambda x, y: None + self.send_signal = lambda *x, **y: None + self.generate_token = self.generateProxyOrToken + # Register configured grants + self.register_grant(RefreshTokenGrant) + self.register_grant(DeviceCodeGrant) + self.register_endpoint(DeviceAuthorizationEndpoint) + self.register_endpoint(RevocationEndpoint) + self.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)]) + + # pylint: disable=method-hidden + def query_client(self, client_id): + """ Search authorization client. + + :param str clientID: client ID + + :return: client as object or None + """ + gLogger.debug('Try to query %s client' % client_id) + clients = getDIRACClients() + for cli in clients: + if client_id == clients[cli]['client_id']: + gLogger.debug('Found %s client:\n' % cli, pprint.pformat(clients[cli])) + return Client(clients[cli]) + return None + + def addSession(self, session): + self.db.addSession(session) + + def getSession(self, session): + self.db.getSession(session) + + def _getScope(self, scope, param): + """ Get parameter scope + + :param str scope: scope + :param str param: parameter scope + + :return: str or None + """ + try: + return [s.split(':')[1] for s in scope_to_list(scope) if s.startswith('%s:' % param) and s.split(':')[1]][0] + except Exception: + return None + + def generateProxyOrToken(self, client, grant_type, user=None, scope=None, + expires_in=None, include_refresh_token=True): + """ Generate proxy or tokens after authorization + """ + group = self._getScope(scope, 'g') + lifetime = self._getScope(scope, 'lifetime') + provider = getIdPForGroup(group) + + # Search DIRAC username + result = getUsernameForDN(wrapIDAsDN(user)) + if not result['OK']: + raise OAuth2Error(result['Message']) + userName = result['Value'] + + if 'proxy' in scope_to_list(scope): + # Try to return user proxy if proxy scope present in the authorization request + if not isDownloadablePersonalProxy(): + raise OAuth2Error("You can't get proxy, configuration(downloadablePersonalProxy) not allow to do that.") + self.log.debug('Try to query %s@%s proxy%s' % (user, group, ('with lifetime:%s' % lifetime) if lifetime else '')) + result = getDNForUsername(userName) + if not result['OK']: + raise OAuth2Error(result['Message']) + userDNs = result['Value'] + err = [] + for dn in userDNs: + self.log.debug('Try to get proxy for %s' % dn) + if lifetime: + result = self.proxyCli.downloadProxy(dn, group, requiredTimeLeft=int(lifetime)) + else: + result = self.proxyCli.downloadProxy(dn, group) + if not result['OK']: + err.append(result['Message']) + else: + self.log.info('Proxy was created.') + result = result['Value'].dumpAllToString() + if not result['OK']: + raise OAuth2Error(result['Message']) + return {'proxy': result['Value'].decode() if isinstance(result['Value'], bytes) else result['Value']} + raise OAuth2Error('; '.join(err)) + + else: + # Ask TokenManager to generate new tokens for user + result = self.tokenCli.getToken(userName, group) + if not result['OK']: + raise OAuth2Error(result['Message']) + token = result['Value'] + + # Wrap the refresh token and register it to protect against reuse + result = self.registerRefreshToken(dict(sub=user, scope=scope, provider=provider, + azp=client.get_client_id()), token) + if not result['OK']: + raise OAuth2Error(result['Message']) + return result['Value'] + + def __signToken(self, payload): + """ Sign token + + :param dict payload: payload + + :return: S_OK(str)/S_ERROR() + """ + result = self.db.getPrivateKey() + if not result['OK']: + return result + key = result['Value'] + try: + return S_OK(jwt.encode(dict(alg='RS256', kid=key.thumbprint()), payload, key).decode('utf-8')) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + + def readToken(self, token): + """ Decode self token + + :param str token: token to decode + + :return: S_OK(dict)/S_ERROR() + """ + result = self.db.getKeySet() + if not result['OK']: + return result + try: + return S_OK(jwt.decode(token, JsonWebKey.import_key_set(result['Value'].as_dict()))) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + + def registerRefreshToken(self, payload, token): + """ Register refresh token to protect it from reuse + + :param dict payload: payload + :param dict token: token as a dictionary + + :return: S_OK(dict)S_ERROR() + """ + result = self.db.storeRefreshToken(token, payload.get('jti')) + if result['OK']: + payload.update(result['Value']) + result = self.__signToken(payload) + if not result['OK']: + if token.get('refresh_token'): + prov = self.idps.getIdProvider(payload['provider']) + if prov['OK']: + prov['Value'].revokeToken(token['refresh_token']) + prov['Value'].revokeToken(token['access_token'], 'access_token') + return result + token['refresh_token'] = result['Value'] + return S_OK(token) + + def getIdPAuthorization(self, provider, request): + """ Submit subsession and return dict with authorization url and session number + + :param str provider: provider name + :param object request: main session request + + :return: S_OK(response)/S_ERROR() -- dictionary contain response generated by `handle_response` + """ + result = self.idps.getIdProvider(provider) + if not result['OK']: + raise Exception(result['Message']) + idpObj = result['Value'] + authURL, state, session = idpObj.submitNewSession() + session['state'] = state + session['Provider'] = provider + session['firstRequest'] = request if isinstance(request, dict) else request.toDict() + + self.log.verbose('Redirect to', authURL) + return self.handle_response(302, {}, [("Location", authURL)], session) + + def parseIdPAuthorizationResponse(self, response, session): + """ Fill session by user profile, tokens, comment, OIDC authorize status, etc. + Prepare dict with user parameters, if DN is absent there try to get it. + Create new or modify existing DIRAC user and store the session + + :param dict response: authorization response + :param str session: session + + :return: S_OK(dict)/S_ERROR() + """ + providerName = session.pop('Provider') + self.log.debug('Try to parse authentification response from %s:\n' % providerName, pprint.pformat(response)) + # Parse response + result = self.idps.getIdProvider(providerName) + if not result['OK']: + return result + idpObj = result['Value'] + result = idpObj.parseAuthResponse(response, session) + if not result['OK']: + return result + + # FINISHING with IdP + # As a result of authentication we will receive user credential dictionary + credDict, payload = result['Value'] + + self.log.debug("Read profile:", pprint.pformat(credDict)) + # Is ID registred? + result = getUsernameForDN(credDict['DN']) + if not result['OK']: + comment = 'Your ID is not registred in the DIRAC: %s.' % credDict['ID'] + payload.update(idpObj.getUserProfile().get('Value', {})) + result = self.__registerNewUser(providerName, payload) + + if result['OK']: + comment += ' Administrators have been notified about you.' + else: + comment += ' Please, contact the DIRAC administrators.' + + # Notify user about problem + html = getHTML("unregistered user!", info=comment, theme='warning') + return S_ERROR(html) + + credDict['username'] = result['Value'] + + # Update token for user. This token will be stored separately in the database and + # updated from time to time. This token will never be transmitted, + # it will be used to make exchange token requests. + result = self.tokenCli.updateToken(idpObj.token, credDict['ID'], idpObj.name) + return S_OK(credDict) if result['OK'] else result + + def create_oauth2_request(self, request, method_cls=OAuth2Request, use_json=False): + self.log.debug('Create OAuth2 request', 'with json' if use_json else '') + return createOAuth2Request(request, method_cls, use_json) + + def create_json_request(self, request): + return self.create_oauth2_request(request, HttpRequest, True) + + def validate_requested_scope(self, scope, state=None): + """ See :func:`authlib.oauth2.rfc6749.authorization_server.validate_requested_scope` """ + # We also consider parametric scope containing ":" charter + extended_scope = list_to_scope([re.sub(r':.*$', ':', s) for s in scope_to_list(scope or '')]) + super(AuthServer, self).validate_requested_scope(extended_scope, state) + + def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, delSession=None): + """ Handle response + + :param int status_code: http status code + :param payload: response payload + :param list headers: headers + :param dict newSession: session data to store + + :return: TornadoResponse() + """ + self.log.debug('Handle authorization response with %s status code:' % status_code, payload) + resp = TornadoResponse(payload, status_code) + if headers: + self.log.debug('Headers:', headers) + for key, value in headers: + resp.set_header(key, value) # pylint: disable=no-member + if newSession: + self.log.debug('Initialize new session:', newSession) + # pylint: disable=no-member + resp.set_secure_cookie('auth_session', json.dumps(newSession), secure=True, httponly=True) + if delSession or isinstance(payload, dict) and 'error' in payload: + resp.clear_cookie('auth_session') # pylint: disable=no-member + return resp + + def create_authorization_response(self, response, username): + """ Rewrite original Authlib method + `authlib.authlib.oauth2.rfc6749.authorization_server.create_authorization_response` + to catch errors and remove authorization session. + + :return: TornadoResponse object + """ + try: + response = super(AuthServer, self).create_authorization_response(response, username) + response.clear_cookie('auth_session') + return response + except Exception as e: + self.log.exception(e) + return self.handle_response( + payload=getHTML('server error', theme='error', body='traceback', info=repr(e)), delSession=True) + + def validate_consent_request(self, request, provider=None): + """ Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization:: + + :param object request: tornado request + :param provider: provider + + :return: response generated by `handle_response` or S_ERROR or html + """ + if request.method != 'GET': + return self.handle_response( + payload=getHTML("use GET method", theme="error", info='Use GET method to access this endpoint.'), + delSession=True) + try: + request = self.create_oauth2_request(request) + # Check Identity Provider + req = self.validateIdentityProvider(request, provider) + + # If return IdP selector + if isinstance(req, six.string_types): + return req + + self.log.info('Validate consent request for', req.state) + grant = self.get_authorization_grant(req) + self.log.debug('Use grant:', grant) + grant.validate_consent_request() + if not hasattr(grant, 'prompt'): + grant.prompt = None + + # Submit second auth flow through IdP + return self.getIdPAuthorization(req.provider, req) + + except OAuth2Error as error: + self.db.removeSession(request.sessionID) + code, body, _ = error(None) + return self.handle_response( + payload=getHTML(repr(e), state=code, body=body, info='OAuth2 error.'), delSession=True) + except Exception as e: + self.db.removeSession(request.sessionID) + self.log.exception(e) + return self.handle_response( + payload=getHTML('server error', theme='error', body='traceback', info=repr(e)), delSession=True) + + def validateIdentityProvider(self, request, provider): + """ Check if identity provider registred in DIRAC + + :param object request: request + :param str provider: provider name + + :return: OAuth2Request object or HTML -- new request with provider name or provider selector + """ + if provider: + request.provider = provider + + # Find identity provider for group + groupProvider = getIdPForGroup(request.group) if request.groups else None + + # If requested access token for group that is not registred in any identity provider + # or the requested provider does not match the group return error + if request.group and not groupProvider and 'proxy' not in request.scope: + raise Exception('The %s group belongs to the VO that is not tied to any Identity Provider.' % request.group) + + self.log.debug("Check if %s identity provider registred in DIRAC.." % request.provider) + # Research supported IdPs + result = getProvidersForInstance('Id') + if not result['OK']: + raise Exception(result['Message']) + + idPs = result['Value'] + if not idPs: + raise Exception('No identity providers found.') + + if request.provider: + if request.provider not in idPs: + raise Exception('%s identity provider is not registered.' % request.provider) + elif groupProvider and request.provider != groupProvider: + raise Exception('The %s group Identity Provider is "%s" and not "%s".' % (request.group, groupProvider, + request.provider)) + return request + + # If no identity provider is specified, it must be assigned + if groupProvider: + request.provider = groupProvider + return request + + # If only one identity provider is registered, then choose it + if len(idPs) == 1: + request.provider = idPs[0] + return request + + # Choose IdP HTML interface + with dom.div(cls="row m-5 justify-content-md-center") as tag: + for idP in idPs: + result = getProviderInfo(idP) + if result['OK']: + logo = result['Value'].get('logoURL') + with dom.div(cls="col-md-6 p-2").add(dom.div(cls="card shadow-lg h-100 border-0")): + with dom.div(cls="row m-2 justify-content-md-center align-items-center h-100"): + with dom.div(cls="col-auto"): + dom.h2(idP) + dom.a(href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query), cls="stretched-link") + if logo: + dom.div(dom.img(src=logo, cls="card-img"), cls="col-auto") + + # Render into HTML + return getHTML("Identity Provider selection..", body=tag, icon='fingerprint', + info="Dirac itself is not an Identity Provider. " + "You will need to select one to continue.") + + def __registerNewUser(self, provider, payload): + """ Register new user + + :param str provider: provider + :param dict payload: user information dictionary + + :return: S_OK()/S_ERROR() + """ + from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient + + username = payload['sub'] + + mail = {} + mail['subject'] = "[DIRAC AS] User %s to be added." % username + mail['body'] = 'User %s was authenticated by %s.' % (username, provider) + mail['body'] += "\n\nNew user to be added with the following information:\n%s" % pprint.pformat(payload) + mail['body'] += "\n\nPlease, add '%s' to /Register/Users//DN option.\n" % wrapIDAsDN(username) + mail['body'] += "\n\n------" + mail['body'] += "\n This is a notification from the DIRAC authorization service, please do not reply.\n" + result = S_OK() + for addresses in getEmailsForGroup('dirac_admin'): + result = NotificationClient().sendMail(addresses, mail['subject'], mail['body'], localAttempt=False) + if not result['OK']: + self.log.error(result['Message']) + if result['OK']: + self.log.info(result['Value'], "administrators have been notified about a new user.") + return result diff --git a/src/DIRAC/FrameworkSystem/private/authorization/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py new file mode 100644 index 00000000000..1dc58a8ad42 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -0,0 +1,118 @@ +""" This class describe Authorization Code grant type +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from time import time +from authlib.jose import JsonWebSignature +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6749.grants import AuthorizationCodeGrant as _AuthorizationCodeGrant +from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads + + +class OAuth2Code(dict): + """ This class describe Authorization Code object """ + + def __init__(self, params): + """ C'or """ + params['auth_time'] = params.get('auth_time', int(time())) + super(OAuth2Code, self).__init__(params) + + @property + def user(self): + return self.get('user_id') + + @property + def code_challenge(self): + return self.get('code_challenge') + + @property + def code_challenge_method(self): + return self.get('code_challenge_method', 'pain') + + def is_expired(self): + return self.get('auth_time') + 300 < time() + + def get_redirect_uri(self): + return self.get('redirect_uri') + + def get_scope(self): + return self.get('scope', '') + + def get_auth_time(self): + return self.get('auth_time') + + def get_nonce(self): + return self.get('nonce') + + +class AuthorizationCodeGrant(_AuthorizationCodeGrant): + """ See :class:`authlib.oauth2.rfc6749.grants.AuthorizationCodeGrant` """ + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def save_authorization_code(self, code, request): + pass + + def delete_authorization_code(self, authorization_code): + pass + + def query_authorization_code(self, code, client): + """ Parse authorization code + + :param code: authorization code as JWS + :param client: client + + :return: OAuth2Code or None + """ + self.server.log.debug('Query authorization code:', code) + jws = JsonWebSignature(algorithms=['RS256']) + result = self.server.db.getKeySet() + if not result['OK']: + raise Exception(result['Message']) + err = None + data = None + for key in result['Value'].keys: + try: + data = jws.deserialize_compact(code, key) + except Exception as e: + err = e + if err: + self.server.log.error('Cannot get authorization code:', repr(err)) + return None + try: + item = OAuth2Code(json_loads(urlsafe_b64decode(data['payload']))) + self.server.log.debug('Authorization code scope:', item.get_scope()) + except Exception as e: + self.server.log.error('Cannot read authorization code:', repr(e)) + return None + if not item.is_expired(): + return item + + def authenticate_user(self, authorization_code): + """ Authenticate the user related to this authorization_code. + + :param authorization_code: authorization code + """ + return authorization_code.user + + def generate_authorization_code(self): + """ The method to generate "code" value for authorization code data. + + :return: str + """ + self.server.log.debug('Generate authorization code for credentials:', self.request.user) + jws = JsonWebSignature(algorithms=['RS256']) + protected = {'alg': 'RS256'} + code = OAuth2Code({'user_id': self.request.user['ID'], + # These scope already contain DIRAC groups + 'scope': self.request.data['scope'], + 'redirect_uri': self.request.args['redirect_uri'], + 'client_id': self.request.args['client_id'], + 'code_challenge': self.request.args.get('code_challenge'), + 'code_challenge_method': self.request.args.get('code_challenge_method')}) + self.server.log.debug('Authorization code generated:', dict(code)) + result = self.server.db.getPrivateKey() + if not result['OK']: + raise OAuth2Error('Cannot generate authorization code: %s' % result['Message']) + return jws.serialize_compact(protected, json_b64encode(dict(code)), result['Value']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py new file mode 100644 index 00000000000..34a4b144e7b --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -0,0 +1,140 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +from dominate import document, tags as dom +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749.grants import AuthorizationEndpointMixin +from authlib.oauth2.rfc6749.errors import InvalidClientError, UnauthorizedClientError +from authlib.oauth2.rfc8628 import (DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint, + DeviceCodeGrant as _DeviceCodeGrant, + DeviceCredentialDict) +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML + + +class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): + """ See :class:`authlib.oauth2.rfc8628.DeviceAuthorizationEndpoint` """ + + def create_endpoint_response(self, req): + """ See :func:`authlib.oauth2.rfc8628.DeviceAuthorizationEndpoint.create_endpoint_response` """ + # Share original request object to endpoint class before create_endpoint_response + self.req = req + return super(DeviceAuthorizationEndpoint, self).create_endpoint_response(req) + + def get_verification_uri(self): + """ Create verification uri when `DeviceCode` flow initialized + + :return: str + """ + return self.server.metadata['device_authorization_endpoint'] + + def save_device_credential(self, client_id, scope, data): + """ Save device credentials + + :param str client_id: client id + :param str scope: request scopes + :param dict data: device credentials + """ + data.update(dict(uri='{api}?{query}&response_type=device&client_id={client_id}&scope={scope}'.format( + api=data['verification_uri'], query=self.req.query, client_id=client_id, scope=scope, + ), id=data['device_code'], client_id=client_id, scope=scope)) + result = self.server.db.addSession(data) + if not result['OK']: + raise OAuth2Error('Cannot save device credentials', result['Message']) + + +class DeviceCodeGrant(_DeviceCodeGrant, AuthorizationEndpointMixin): + """ See :class:`authlib.oauth2.rfc8628.DeviceCodeGrant` """ + RESPONSE_TYPES = {'device'} + + def validate_authorization_request(self): + """ Validate authorization request + + :return: None + """ + # Validate client for this request + client_id = self.request.client_id + self.server.log.debug('Validate authorization request of', client_id) + if client_id is None: + raise InvalidClientError(state=self.request.state) + client = self.server.query_client(client_id) + if not client: + raise InvalidClientError(state=self.request.state) + response_type = self.request.response_type + if not client.check_response_type(response_type): + raise UnauthorizedClientError('The client is not authorized to use "response_type={}"'.format(response_type)) + self.request.client = client + self.validate_requested_scope() + + # Check user_code, when user go to authorization endpoint + userCode = self.request.args.get('user_code') + if not userCode: + raise OAuth2Error('user_code is absent.') + + # Get session from cookie + if not self.server.db.getSessionByUserCode(userCode): + raise OAuth2Error('Session with %s user code is expired.' % userCode) + return None + + def create_authorization_response(self, redirect_uri, user): + """ Mark session as authed with received user + + :param str redirect_uri: redirect uri + :param dict user: dictionary with username and userID + + :return: result of `handle_response` + """ + result = self.server.db.getSessionByUserCode(self.request.data['user_code']) + if not result['OK']: + return 500, getHTML("server error", theme='error', body=result['Message'], + info='Failed to read %s authorization session.' % self.request.data['user_code']) + data = result['Value'] + data.update(dict(user_id=user['ID'], uri=self.request.uri, + username=user['username'], scope=self.request.scope)) + # Save session with user + result = self.server.db.updateSession(data, data['id']) + if not result['OK']: + return 500, getHTML("server error", theme='error', body=result['Message'], + info='Failed to save %s authorization session status.' % self.request.data['user_code']) + + # Notify user that authorization complite. + return 200, getHTML( + "authorization complite!", theme='success', + info='Authorization has been completed, now you can close this window and return to the terminal.') + + def query_device_credential(self, device_code): + """ Get device credential from previously savings via ``DeviceAuthorizationEndpoint``. + + :param str device_code: device code + + :return: dict + """ + result = self.server.db.getSession(device_code) + if not result['OK']: + raise OAuth2Error(result['Message']) + data = result['Value'] + if not data: + return None + data['verification_uri'] = self.server.metadata['device_authorization_endpoint'] + data['expires_at'] = int(data['expires_in']) + int(time.time()) + data['interval'] = DeviceAuthorizationEndpoint.INTERVAL + return DeviceCredentialDict(data) + + def query_user_grant(self, user_code): + """ Check if user alredy authed and return it to token generator + + :param str user_code: user code + + :return: (str, bool) or None -- user dict and user auth status + """ + result = self.server.db.getSessionByUserCode(user_code) + if not result['OK']: + raise OAuth2Error('Cannot found authorization session', result['Message']) + return (result['Value']['user_id'], True) if result['Value'].get('username', "None") != "None" else None + + def should_slow_down(self, *args): + """ The authorization request is still pending and polling should continue, + but the interval MUST be increased by 5 seconds for this and all subsequent requests. + """ + return False diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py new file mode 100644 index 00000000000..ae2743406b5 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -0,0 +1,79 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant + +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, wrapIDAsDN + + +class RefreshTokenGrant(_RefreshTokenGrant): + """ See :class:`authlib.oauth2.rfc6749.grants.RefreshTokenGrant` """ + + DEFAULT_EXPIRES_AT = 12 * 3600 + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def authenticate_refresh_token(self, refresh_token): + """ Get credential for token + + :param str refresh_token: refresh token + + :return: dict or None + """ + result = self.server.readToken(refresh_token) + if not result['OK']: + raise OAuth2Error(result['Message']) + rtDict = result['Value'] + result = self.server.db.getCredentialByRefreshToken(rtDict['jti']) + if not result['OK']: + raise OAuth2Error(result['Message']) + credential = result['Value'] + + if int(rtDict['iat']) != int(credential['issued_at']): + # An attempt to reuse the refresh token was detected + prov = self.server.idps.getIdProvider(rtDict['provider']) + if prov['OK']: + prov['Value'].revokeToken(credential['refresh_token']) + prov['Value'].revokeToken(credential['access_token'], 'access_token') + return None + + credential.update(rtDict) + return credential + + def authenticate_user(self, credential): + """ Authorize user + + :param dict credential: credential (token payload) + + :return: str or bool + """ + result = getUsernameForDN(wrapIDAsDN(credential['sub'])) + if not result['OK']: + self.server.log.error(result['Message']) + return result.get('Value') + + def issue_token(self, user, credential): + """ Refresh tokens + + :param user: unuse + :param dict credential: token credential + + :return: dict + """ + if credential['refresh_token']: + result = self.server.idps.getIdProvider(credential['provider']) + if result['OK']: + result = result['Value'].refreshToken(credential['refresh_token']) + else: + result = self.server.tokenCli.getToken(user, self.server._getScope(credential['scope'], 'g')) + if result['OK']: + token = result['Value'] + result = self.server.registerRefreshToken(credential, token) + if not result['OK']: + raise OAuth2Error(result['Message']) + return result['Value'] + + def revoke_old_credential(self, credential): + """ Remove old credential """ + pass diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py new file mode 100644 index 00000000000..729a4c7f718 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint + + +class RevocationEndpoint(_RevocationEndpoint): + """ See :class:`authlib.oauth2.rfc7009.RevocationEndpoint` """ + + def query_token(self, token, token_type_hint, client): + """ Query requested token from database. + + :param str token: token + :param str token_type_hint: token type + :param client: client + + :return: dict + """ + if token_type_hint == 'refresh_token': + result = self.server.readToken(token) + if not result['OK']: + raise OAuth2Error(result['Message']) + rtDict = result['Value'] + result = self.server.db.getCredentialByRefreshToken(rtDict['jti']) + if not result['OK']: + raise OAuth2Error(result['Message']) + return result['Value'] + return {token_type_hint: token} + + def revoke_token(self, token): + """ Mark the give token as revoked. + + :param dict token: token dict + """ + result = self.server.idps.getIdProviderForToken(token['access_token']) + if not result['OK']: + raise OAuth2Error(result['Message']) + if result['OK']: + for tokenType in token: + result = result['Value'].revokeToken(token[tokenType], tokenType) + if not result['OK']: + self.server.log.error(result['Message']) + if not result['OK']: + raise OAuth2Error(result['Message']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py new file mode 100644 index 00000000000..d4daf0165a7 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -0,0 +1,63 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import time +import pprint + +from DIRAC import gConfig, gLogger +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata + +__RCSID__ = "$Id$" + +DEFAULT_CLIENTS = { + 'DIRACCLI': dict(client_id='DIRAC_CLI', scope='proxy g: lifetime:', response_types=['device'], + grant_types=['urn:ietf:params:oauth:grant-type:device_code', 'refresh_token'], + token_endpoint_auth_method='none', verify=False, + ProviderType='OAuth2'), + 'DIRACWeb': dict(client_id='DIRAC_Web', scope='g:', response_types=['code'], + grant_types=['authorization_code', 'refresh_token'], + ProviderType='OAuth2') +} + + +def getDIRACClients(): + """ Get DIRAC authorization clients + + :return: S_OK(dict)/S_ERROR() + """ + clients = DEFAULT_CLIENTS.copy() + result = getAuthorizationServerMetadata(ignoreErrors=True) + confClients = result.get('Value', {}).get('Clients', {}) + for cli in confClients: + if cli not in clients: + clients[cli] = confClients[cli] + else: + clients[cli].update(confClients[cli]) + return clients + + +class Client(OAuth2ClientMixin): + + def __init__(self, params): + if params.get('redirect_uri') and not params.get('redirect_uris'): + params['redirect_uris'] = [params['redirect_uri']] + self.set_client_metadata(params) + self.client_id = params['client_id'] + self.client_secret = params.get('client_secret', '') + self.client_id_issued_at = params.get('client_id_issued_at', int(time.time())) + self.client_secret_expires_at = params.get('client_secret_expires_at', 0) + + def get_allowed_scope(self, scope): + if not isinstance(scope, six.string_types): + scope = list_to_scope(scope) + allowed = scope_to_list(super(Client, self).get_allowed_scope(scope)) + for s in scope_to_list(scope): + for def_scope in scope_to_list(self.scope): + if s.startswith(def_scope) and s not in allowed: + allowed.append(s) + gLogger.debug('Try to allow "%s" scope:' % scope, allowed) + return list_to_scope(list(set(allowed))) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py new file mode 100644 index 00000000000..ab4c7953008 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -0,0 +1,115 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tornado.escape import json_decode +from authlib.common.encoding import to_unicode +from authlib.oauth2 import OAuth2Request as _OAuth2Request +from authlib.oauth2.rfc6749.util import scope_to_list +from six.moves.urllib.parse import quote + +__RCSID__ = "$Id$" + + +class OAuth2Request(_OAuth2Request): + """ OAuth request object """ + + def addScopes(self, scopes): + """ Add new scopes to query + + :param list scopes: scopes + """ + self.setQueryArguments(scope=list(set(scope_to_list(self.scope or '') + scopes))) + + def setQueryArguments(self, **kwargs): + """ Set query arguments """ + for k in kwargs: + # Quote value before add it to request query + value = '+'.join([quote(str(v)) for v in kwargs[k]]) if isinstance(kwargs[k], list) else quote(str(kwargs[k])) + # Remove argument from uri + query = re.sub(r"&{argument}(=[^&]*)?|^{argument}(=[^&]*)?&?".format(argument=k), "", self.query) + # Add new one + if query: + query += '&' + query += "%s=%s" % (k, value) + # Re-init class + self.__init__(self.method, to_unicode(self.path + '?' + query)) + + @property + def path(self): + """ URL path + + :return: str + """ + return self.uri.replace('?%s' % (self.query or ''), '') + + @property + def groups(self): + """ Search DIRAC groups in scopes + + :return: list + """ + return [s.split(':')[1] for s in scope_to_list(self.scope or '') if s.startswith('g:') and s.split(':')[1]] + + @property + def group(self): + """ Search DIRAC group in scopes + + :return: str + """ + groups = [s.split(':')[1] for s in scope_to_list(self.scope or '') if s.startswith('g:') and s.split(':')[1]] + return groups[0] if groups else None + + @property + def provider(self): + """ Search IdP in scopes + + :return: str + """ + return self.data.get('provider') + + @provider.setter + def provider(self, provider): + self.setQueryArguments(provider=provider) + + @property + def sessionID(self): + """ Search IdP in scopes + + :return: str + """ + return self.data.get('id') + + @sessionID.setter + def sessionID(self, sessionID): + self.setQueryArguments(id=sessionID) + + def toDict(self): + """ Convert class to dictionary + + :return: dict + """ + return {'method': self.method, 'uri': self.uri} + + +def createOAuth2Request(request, method_cls=OAuth2Request, use_json=False): + """ Create request object + + :param request: request + :type request: object, dict + :param object method_cls: returned class + :param str use_json: if data is json + + :return: object -- `OAuth2Request` + """ + if isinstance(request, method_cls): + return request + if isinstance(request, dict): + return method_cls(request['method'], request['uri'], request.get('body'), request.get('headers')) + if use_json: + return method_cls(request.method, request.full_url(), json_decode(request.body), request.headers) + body = {k: request.body_arguments[k][-1].decode("utf-8") + for k in request.body_arguments if request.body_arguments[k]} + return method_cls(request.method, request.full_url(), body, request.headers) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py new file mode 100644 index 00000000000..68c0c5feed6 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -0,0 +1,269 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re +import jwt +import six +import stat +import time +import json +import datetime + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Utilities import DErrno +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oauth2.rfc6749.wrappers import OAuth2Token as _OAuth2Token + +BEARER_TOKEN_ENV = 'BEARER_TOKEN' +BEARER_TOKEN_FILE_ENV = 'BEARER_TOKEN_FILE' + + +def getTokenFileLocation(fileName=None): + """ Research token file location. Use the bearer token discovery protocol + defined by the WLCG (https://doi.org/10.5281/zenodo.3937438) to find one. + + :param str fileName: file name to dump to + + :return: str + """ + if fileName: + return fileName + if os.environ.get(BEARER_TOKEN_FILE_ENV): + return os.environ[BEARER_TOKEN_FILE_ENV] + elif os.environ.get('XDG_RUNTIME_DIR'): + return "%s/bt_u%s" % (os.environ['XDG_RUNTIME_DIR'], os.getuid()) + else: + return "/tmp/bt_u%s" % os.getuid() + + +def getLocalTokenDict(location=None): + """ Search local token. Use the bearer token discovery protocol + defined by the WLCG (https://doi.org/10.5281/zenodo.3937438) to find one. + + :param str location: token file path + + :return: S_OK(dict)/S_ERROR() + """ + result = readTokenFromEnv() + return result if result['OK'] and result['Value'] else readTokenFromFile(location) + + +def readTokenFromEnv(): + """ Read token from an environ variable + + :return: S_OK(dict or None) + """ + token = os.environ.get(BEARER_TOKEN_ENV, "").strip() + return S_OK(OAuth2Token(token) if token else None) + + +def readTokenFromFile(fileName=None): + """ Read token from a file + + :param str fileName: filename to read + + :return: S_OK(dict or None)/S_ERROR() + """ + location = getTokenFileLocation(fileName) + try: + with open(location, 'rt') as f: + token = f.read().strip() + except IOError as e: + return S_ERROR(DErrno.EOF, "Can't open %s token file.\n%s" % (location, repr(e))) + return S_OK(OAuth2Token(token) if token else None) + + +def writeToTokenFile(tokenContents, fileName): + """ Write a token string to file + + :param str tokenContents: token as string + :param str fileName: filename to dump to + + :return: S_OK(str)/S_ERROR() + """ + location = getTokenFileLocation(fileName) + try: + with open(location, 'wt') as fd: + fd.write(tokenContents) + except Exception as e: + return S_ERROR(DErrno.EWF, " %s: %s" % (location, repr(e))) + try: + os.chmod(location, stat.S_IRUSR | stat.S_IWUSR) + except Exception as e: + return S_ERROR(DErrno.ESPF, "%s: %s" % (location, repr(e))) + return S_OK(location) + + +def writeTokenDictToTokenFile(tokenDict, fileName=None): + """ Write a token dict to file + + :param dict tokenDict: dict object to dump to file + :param str fileName: filename to dump to + + :return: S_OK(str)/S_ERROR() + """ + fileName = getTokenFileLocation(fileName) + if not isinstance(tokenDict, dict): + return S_ERROR('Token is not a dictionary') + return writeToTokenFile(json.dumps(tokenDict), fileName) + + +class OAuth2Token(_OAuth2Token): + """ Implementation of a Token object """ + + def __init__(self, params=None, **kwargs): + """ Constructor + """ + if six.PY3 and isinstance(params, bytes): + params = params.decode() + if isinstance(params, six.string_types): + # Is params a JWT? + params = params.strip() + if re.match(r"^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*$", params): + params = dict(access_token=params) + else: + params = json.loads(params) + + kwargs.update(params or {}) + kwargs['issued_at'] = kwargs.get('issued_at', kwargs.get('iat')) + kwargs['expires_at'] = kwargs.get('expires_at', kwargs.get('exp')) + if not kwargs.get('expires_at') and kwargs.get('access_token'): + # Get access token expires_at claim + kwargs['expires_at'] = self.get_claim('exp') + super(OAuth2Token, self).__init__(kwargs) + + def check_client(self, client): + """ A method to check if this token is issued to the given client. + + :param client: client object + + :return: bool + """ + return self.get('client_id', self.get('azp')) == client.client_id + + def get_scope(self): + """ A method to get scope of the authorization code. + + :return: str + """ + return self.get('scope') + + def get_expires_in(self): + """ A method to get the ``expires_in`` value of the token. + + :return: int + """ + return int(self.get('expires_in')) + + def is_expired(self): + """ A method to define if this token is expired. + + :return: bool + """ + if self.get('expires_at'): + return int(self.get('expires_at')) < time.time() + elif self.get('issued_at') and self.get('expires_in'): + return int(self.get('issued_at')) + int(self.get('expires_in')) < time.time() + else: + exp = self.get_payload().get('exp') + return int(exp) < time.time() if exp else True + + @property + def scopes(self): + """ Get tokens scopes + + :return: list + """ + return scope_to_list(self.get('scope', '')) + + @property + def groups(self): + """ Get tokens groups + + :return: list + """ + return [s.split(':')[1] for s in self.scopes if s.startswith('g:') and s.split(':')[1]] + + def get_payload(self, token_type='access_token'): + """ Decode token + + :param str token_type: token type + + :return: dict + """ + if not self.get(token_type): + return {} + return jwt.decode(self.get(token_type), options=dict(verify_signature=False, + verify_exp=False, + verify_aud=False, + verify_nbf=False)) + + def get_claim(self, claim, token_type='access_token'): + """ Get token claim without verification + + :param str attr: attribute + :param str token_type: token type + + :return: str + """ + return self.get_payload(token_type).get(claim) + + def dump_to_string(self): + """ Dump token dictionary to sting + + :return: str + """ + return json.dumps(dict(self)) + + def getInfoAsString(self): + """ Return information about token as string + + :return: str + """ + result = IdProviderFactory().getIdProviderForToken(self.get('access_token')) + if not result['OK']: + return "Cannot load provider: %s" % result['Message'] + cli = result['Value'] + cli.token = self.copy() + result = cli.verifyToken() + if not result['OK']: + return result['Message'] + payload = result['Value'] + result = cli.researchGroup(payload) + if not result['OK']: + return result['Message'] + credDict = result['Value'] + result = Registry.getUsernameForDN(credDict['DN']) + if not result['OK']: + return result['Message'] + credDict['username'] = result['Value'] + if credDict.get('group'): + credDict['properties'] = Registry.getPropertiesForGroup(credDict['group']) + payload.update(credDict) + return self.__formatTokenInfoAsString(payload) + + def __formatTokenInfoAsString(self, infoDict): + """ Convert a token infoDict into a string + + :param dict infoDict: info + + :return: str + """ + secsLeft = int(infoDict['exp']) - time.time() + strTimeleft = datetime.datetime.fromtimestamp(secsLeft).strftime("%I:%M:%S") + leftAlign = 13 + contentList = [] + contentList.append('%s: %s' % ('subject'.ljust(leftAlign), infoDict['sub'])) + contentList.append('%s: %s' % ('issuer'.ljust(leftAlign), infoDict['iss'])) + contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), strTimeleft)) + contentList.append('%s: %s' % ('username'.ljust(leftAlign), infoDict['username'])) + if infoDict.get('group'): + contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) + if infoDict.get('properties'): + contentList.append('%s: %s' % ('properties'.ljust(leftAlign), ', '.join(infoDict['properties']))) + return "\n".join(contentList) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py new file mode 100644 index 00000000000..674e0472987 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -0,0 +1,150 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import six +import traceback + +from dominate import document, tags as dom +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata + +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata + + +def collectMetadata(issuer=None, ignoreErrors=False): + """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defined by the IETF specification: + https://datatracker.ietf.org/doc/html/rfc8414#section-2 + + :param str issuer: issuer to set + + :return: dict -- dictionary is the AuthorizationServerMetadata object in the same time + """ + result = getAuthorizationServerMetadata(issuer, ignoreErrors=ignoreErrors) + if not result['OK']: + raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) + metadata = result['Value'] + for name, endpoint in [('jwks_uri', 'jwk'), + ('token_endpoint', 'token'), + ('userinfo_endpoint', 'userinfo'), + ('revocation_endpoint', 'revoke'), + ('redirect_uri', 'redirect'), + ('authorization_endpoint', 'authorization'), + ('device_authorization_endpoint', 'device')]: + metadata[name] = metadata['issuer'].strip('/') + '/' + endpoint + metadata['scopes_supported'] = ['g:', 'proxy', 'lifetime:'] + metadata['grant_types_supported'] = ['code', 'authorization_code', 'refresh_token', + 'urn:ietf:params:oauth:grant-type:device_code'] + metadata['response_types_supported'] = ['code', 'device', 'token'] + metadata['code_challenge_methods_supported'] = ['S256'] + return AuthorizationServerMetadata(metadata) + + +def getHTML(title, info=None, body=None, style=None, state=None, theme=None, icon=None): + """ Provide HTML object + + :param str title: short name of the notification, e.g.: server error + :param str info: some short description if needed, e.g.: It looks like the server is not responding + :param body: it can be string or dominate tag object, e.g.: + from dominate import tags as dom + return getHTML('server error', body=dom.pre(dom.code(result['Message'])) + :param str style: additional css style if needed, e.g.: '.card{color:red;}' + :param int state: response state code, if needed, e.g.: 404 + :param str theme: message color theme, the same that in bootstrap 5, e.g.: 'warning' + :param str icon: awesome icon name, e.g.: 'users' + + :return: str -- HTML document + """ + html = document("DIRAC - %s" % title) + + # select the color to the state code + if state in [400, 401, 403, 404]: + theme = theme or 'warning' + elif state in [500]: + theme = theme or 'danger' + elif state in [200]: + theme = theme or 'success' + + # select the icon to the theme + if theme in ['warning', 'warn']: + theme = 'warning' + icon = icon or 'exclamation-triangle' + elif theme == 'info': + icon = icon or 'info' + elif theme == 'success': + icon = icon or 'check' + elif theme in ['error', 'danger']: + theme = 'danger' + icon = icon or 'times' + else: + theme = theme or 'secondary' + icon = icon or 'flask' + + # If body is text wrap it with tags + if body and isinstance(body, six.string_types): + body = dom.pre(dom.code(traceback.format_exc() if body == 'traceback' else body), cls="mt-5") + + try: + diracLogo = collectMetadata(ignoreErrors=True).get('logoURL', '') + except Exception: + diracLogo = '' + + # Create head + with html.head: + # Meta tags + dom.meta(charset="utf-8") + dom.meta(name="viewport", content="width=device-width, initial-scale=1") + # Favicon + dom.link(rel="shortcut icon", href="/static/core/img/icons/system/favicon.ico", type="image/x-icon") + # Provide awesome icons + # https://fontawesome.com/v4.7/license/ + dom.script(src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/js/all.min.js") + # Enable bootstrap 5 + # https://getbootstrap.com/docs/5.0/getting-started/introduction/ + # https://getbootstrap.com/docs/5.0/about/license/ + dom.link(rel='stylesheet', integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC", + href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css", crossorigin="anonymous") + dom.script(src='https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/js/bootstrap.bundle.min.js', + integrity="sha384-MrcW6ZMFYlzcLA8Nl+NtUVF0sA7MsXsP1UyJoMp4YLEuNSfAP+JcXn/tWtIaxVXM", + crossorigin="anonymous") + # Provide additional css + style = ".card{transition:.3s;}.card:hover{transform:scale(1.03);}" + (style or '') + dom.style(style) + + # Create body + with html: + # Background image + dom.i(cls='position-absolute bottom-0 start-0 translate-middle-x m-5 fa ' + 'fa-%s text-%s' % (icon, theme), + style="font-size:40vw;z-index:-1;") + + # A4 page with align center + with dom.div(cls="row vh-100 vw-100 justify-content-md-center align-items-center m-0"): + with dom.div(cls='container', style="max-width:600px;") as page: + # Main panel + with dom.div(cls="row align-items-center"): + # Logo + dom.div(dom.img(src=diracLogo, cls="card-img px-2"), cls="col-md-6 my-3") + # Information card + with dom.div(cls="col-md-6 my-3"): + + # Show response state number + if state and state != 200: + dom.div(dom.h1(state, cls='text-center badge bg-%s text-wrap' % theme), cls="row py-2") + + # Message title + with dom.div(cls="row"): + dom.div(dom.i(cls="fa fa-%s text-%s" % (icon, theme)), cls="col-auto") + dom.div(title, cls='col-auto ps-0 pb-2 fw-bold') + + # Description + if info: + dom.small(dom.i(cls="fa fa-info text-info")) + dom.small(info, cls="ps-1") + + # Add content + if body: + page.add(body) + + return html.render() diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py new file mode 100644 index 00000000000..89caf8fa57a --- /dev/null +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python +######################################################################## +# File : dirac-login.py +# Author : Andrii Lytovchenko +######################################################################## +""" +Login to DIRAC. + +Example: + $ dirac-login -g dirac_user +""" +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function + +import os +import sys + +import DIRAC +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.Utilities.DIRACScript import DIRACScript as Script +from DIRAC.Core.Security.ProxyFile import writeToProxyFile +from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (writeTokenDictToTokenFile, readTokenFromFile, + getTokenFileLocation) + +__RCSID__ = "$Id$" + + +class Params(object): + + def __init__(self): + self.info = False + self.proxy = False + self.group = None + self.lifetime = None + self.issuer = None + self.proxyLoc = '/tmp/x509up_u%s' % os.getuid() + self.tokenLoc = None + + def getInfo(self, _arg): + """ To return user info + + :return: S_OK() + """ + self.info = True + return S_OK() + + def returnProxy(self, _arg): + """ To return proxy + + :return: S_OK() + """ + self.proxy = True + return S_OK() + + def setGroup(self, arg): + """ Set group + + :param str arg: group + + :return: S_OK() + """ + self.group = arg + return S_OK() + + def setIssuer(self, arg): + """ Set issuer + + :param str arg: issuer + + :return: S_OK() + """ + self.issuer = arg + return S_OK() + + def setTokenFile(self, arg): + """ Set token file + + :param str arg: token file + + :return: S_OK() + """ + self.tokenLoc = arg + return S_OK() + + def setLivetime(self, arg): + """ Set email + + :param str arg: lifetime + + :return: S_OK() + """ + self.lifetime = arg + return S_OK() + + def registerCLISwitches(self): + """ Register CLI switches """ + Script.registerSwitch("", "info", "output current user authorization status", self.getInfo) + Script.registerSwitch("P", "proxy", "request a proxy certificate with DIRAC group extension", self.returnProxy) + Script.registerSwitch("g:", "group=", "set DIRAC group", self.setGroup) + Script.registerSwitch("I:", "issuer=", "set issuer", self.setIssuer) + Script.registerSwitch("T:", "lifetime=", "set proxy lifetime in a hours", self.setLivetime) + Script.registerSwitch("F:", "file=", "set token file location", self.setTokenFile) + + def doOAuthMagic(self): + """ Magic method with tokens + + :return: S_OK()/S_ERROR() + """ + tokenFile = getTokenFileLocation(self.tokenLoc) + + if self.info: + # Try to get user information + Script.enableCS() + + useTokens = DIRAC.gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true") + if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: + useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") + if useTokens: + gLogger.notice('You are currently using access token to access new HTTP DIRAC services.' + ' To use a proxy instead, do the following:\n', + 'export DIRAC_USE_ACCESS_TOKEN=False\n') + result = readTokenFromFile(tokenFile) + if result['OK']: + gLogger.notice(result['Value'].getInfoAsString()) + else: + gLogger.notice('You are currently using proxy to access new HTTP DIRAC services.' + ' To use a access token instead, do the following:\n', + 'export DIRAC_USE_ACCESS_TOKEN=True\n') + result = getProxyInfo(self.proxyLoc) + if result['OK']: + gLogger.notice(formatProxyInfoAsString(result['Value'])) + + return result + + params = {} + if self.issuer: + params['issuer'] = self.issuer + result = IdProviderFactory().getIdProvider('DIRACCLI', **params) + if not result['OK']: + return result + idpObj = result['Value'] + scope = [] + if self.group: + scope.append('g:%s' % self.group) + if self.proxy: + scope.append('proxy') + if self.lifetime: + scope.append('lifetime:%s' % (int(self.lifetime) * 3600)) + idpObj.scope = '+'.join(scope) if scope else '' + + # Submit Device authorisation flow + result = idpObj.deviceAuthorization() + if not result['OK']: + return result + + if self.proxy: + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'False' + # Save new proxy certificate + result = writeToProxyFile(idpObj.token['proxy'].encode("UTF-8"), self.proxyLoc) + if not result['OK']: + return result + gLogger.notice('Proxy is saved to %s.' % self.proxyLoc) + else: + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'True' + # Revoke old tokens from token file + if os.path.isfile(tokenFile): + result = readTokenFromFile(tokenFile) + if not result['OK']: + gLogger.error(result['Message']) + elif result['Value']: + oldToken = result['Value'] + for tokenType in ['access_token', 'refresh_token']: + result = idpObj.revokeToken(oldToken[tokenType], tokenType) + if result['OK']: + gLogger.notice('%s is revoked from' % tokenType, tokenFile) + else: + gLogger.error(result['Message']) + + # Save new tokens to token file + result = writeTokenDictToTokenFile(idpObj.token, tokenFile) + if not result['OK']: + return result + tokenFile = result['Value'] + gLogger.notice('New token is saved to %s.' % tokenFile) + + if not DIRAC.gConfig.getValue('/DIRAC/Security/Authorization/issuer'): + gLogger.notice('To continue use token you need to add /DIRAC/Security/Authorization/issuer option.') + if not self.issuer: + DIRAC.exit(1) + DIRAC.gConfig.setOptionValue('/DIRAC/Security/Authorization/issuer', self.issuer) + + # Try to get user information + result = Script.enableCS() + if not result['OK']: + return S_ERROR("Cannot contact CS.") + DIRAC.gConfig.forceRefresh() + + if self.proxy: + result = getProxyInfo(self.proxyLoc) + if not result['OK']: + return result['Message'] + gLogger.notice(formatProxyInfoAsString(result['Value'])) + else: + result = readTokenFromFile(tokenFile) + if not result['OK']: + return result + gLogger.notice(result['Value'].getInfoAsString()) + + return S_OK() + + +@Script() +def main(): + piParams = Params() + piParams.registerCLISwitches() + + Script.disableCS() + Script.parseCommandLine(ignoreErrors=True) + DIRAC.gConfig.setOptionValue("/DIRAC/Security/UseServerCertificate", "False") + + resultDoMagic = piParams.doOAuthMagic() + if not resultDoMagic['OK']: + gLogger.fatal(resultDoMagic['Message']) + sys.exit(1) + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py new file mode 100644 index 00000000000..620954535a4 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +######################################################################## +# File : dirac-logout.py +# Author : Andrii Lytovchenko +######################################################################## +""" +Logout + +Example: + $ dirac-logout +""" +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function + +import os +import sys + +import DIRAC +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.Base import Script +from DIRAC.Core.Utilities.DIRACScript import DIRACScript +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (readTokenFromFile, readTokenFromEnv, + getTokenFileLocation, BEARER_TOKEN_ENV) + +__RCSID__ = "$Id$" + + +class Params(object): + + def __init__(self): + self.issuer = None + self.tokenFileLoc = None + + def setIssuer(self, arg): + """ Set issuer + + :param str arg: issuer + + :return: S_OK() + """ + self.issuer = arg + return S_OK() + + def setTokenFile(self, arg): + """ Set token file + + :param str arg: token file + + :return: S_OK() + """ + self.tokenFileLoc = arg + return S_OK() + + def registerCLISwitches(self): + """ Register CLI switches """ + Script.registerSwitch( + "I:", + "issuer=", + "set issuer", + self.setIssuer) + Script.registerSwitch( + "F:", + "file=", + "set token file location", + self.setTokenFile) + + def doOAuthMagic(self): + """ Magic method with tokens + + :return: S_OK()/S_ERROR() + """ + tokens = [] + params = {} + if self.issuer: + params['issuer'] = self.issuer + result = IdProviderFactory().getIdProvider('DIRACCLI', **params) + if not result['OK']: + return result + idpObj = result['Value'] + tokenFile = getTokenFileLocation(self.tokenFileLoc) + + # Try to find token in environ and in a token file and revoke it + for result, location in [(readTokenFromEnv(), BEARER_TOKEN_ENV), + (readTokenFromFile(tokenFile), tokenFile)]: + if not result['OK']: + gLogger.error(result['Message']) + elif result['Value']: + token = result['Value'] + for tokenType in ['access_token', 'refresh_token']: + result = idpObj.revokeToken(token[tokenType], tokenType) + if result['OK']: + gLogger.notice('%s is revoked from' % tokenType, location) + else: + gLogger.error(result['Message']) + + # After remove token file + if os.path.isfile(tokenFile): + os.unlink(tokenFile) + gLogger.notice('%s token file is removed.' % tokenFile) + + return S_OK() + + +@DIRACScript() +def main(): + piParams = Params() + piParams.registerCLISwitches() + + Script.disableCS() + Script.parseCommandLine(ignoreErrors=True) + DIRAC.gConfig.setOptionValue("/DIRAC/Security/UseServerCertificate", "False") + + resultDoMagic = piParams.doOAuthMagic() + if not resultDoMagic['OK']: + gLogger.fatal(resultDoMagic['Message']) + sys.exit(1) + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py index 8868db2fe1a..9885fa97dbe 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py @@ -254,6 +254,8 @@ def main(): global piParams, pI piParams = Params() piParams.registerCLISwitches() + # Take off tokens + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'False' Script.disableCS() Script.parseCommandLine(ignoreErrors=True) diff --git a/src/DIRAC/Interfaces/API/Dirac.py b/src/DIRAC/Interfaces/API/Dirac.py index e1cb29400c1..8c9b344c874 100755 --- a/src/DIRAC/Interfaces/API/Dirac.py +++ b/src/DIRAC/Interfaces/API/Dirac.py @@ -1772,7 +1772,7 @@ def getJobStatus(self, jobID): if self.jobRepo: self.jobRepo.updateJobs(repoDict) for job, vals in siteDict.items(): # can be an iterator - result[job].update(vals) + result[str(job)].update(vals) return S_OK(result) diff --git a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py new file mode 100644 index 00000000000..e16c5cb124a --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py @@ -0,0 +1,30 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from authlib.oauth2.rfc6749.util import scope_to_list + +from DIRAC import S_OK +from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOForGroup, getGroupOption + + +__RCSID__ = "$Id$" + + +class CheckInIdProvider(OAuth2IdProvider): + + def getGroupScopes(self, group): + """ Get group scopes + + :param str group: DIRAC group + + :return: list + """ + idPScope = getGroupOption(group, 'IdPRole') + if not idPScope: + idPScope = 'eduperson_entitlement?value=urn:mace:egi.eu:group:%s:role=%s#aai.egi.eu' % (getVOForGroup(group), + group.split('_')[1]) + return S_OK(scope_to_list(idPScope)) diff --git a/src/DIRAC/Resources/IdProvider/IAMIdProvider.py b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py new file mode 100644 index 00000000000..66a4f201abd --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py @@ -0,0 +1,28 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from authlib.oauth2.rfc6749.util import scope_to_list + +from DIRAC import S_OK +from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOForGroup, getGroupOption + +__RCSID__ = "$Id$" + + +class IAMIdProvider(OAuth2IdProvider): + + def getGroupScopes(self, group): + """ Get group scopes + + :param str group: DIRAC group + + :return: list + """ + idPScope = getGroupOption(group, 'IdPRole') + if not idPScope: + idPScope = 'wlcg.groups:/%s/%s' % (getVOForGroup(group), group.split('_')[1]) + return S_OK(scope_to_list(idPScope)) diff --git a/src/DIRAC/Resources/IdProvider/IdProvider.py b/src/DIRAC/Resources/IdProvider/IdProvider.py index a3ea4c52383..d1f8e6def02 100644 --- a/src/DIRAC/Resources/IdProvider/IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/IdProvider.py @@ -11,9 +11,25 @@ class IdProvider(object): - def __init__(self, parameters=None): + DEFAULT_METADATA = {} + + def __init__(self, **kwargs): + """ C'or + """ self.log = gLogger.getSubLogger(self.__class__.__name__) - self.parameters = parameters + meta = self.DEFAULT_METADATA + meta.update(kwargs) + self.setParameters(meta) + self._initialization(**meta) + + def _initialization(self, **kwargs): + """ Initialization """ + pass def setParameters(self, parameters): + """ Set parameters + + :param dict parameters: parameters of the identity Provider + """ self.parameters = parameters + self.name = parameters.get('ProviderName') diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 1134654f218..f1ced5f37ae 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -10,52 +10,109 @@ from __future__ import division from __future__ import print_function +import jwt + from DIRAC import S_OK, S_ERROR, gLogger -from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getInfoAboutProviders +from DIRAC.Core.Utilities import ObjectLoader, ThreadSafe +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.Resources.IdProvider.Utilities import getProviderInfo, getSettingsNamesForIdPIssuer +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import collectMetadata __RCSID__ = "$Id$" +gCacheMetadata = ThreadSafe.Synchronizer() + class IdProviderFactory(object): - ############################################################################# def __init__(self): """ Standard constructor """ self.log = gLogger.getSubLogger('IdProviderFactory') + self.cacheMetadata = DictCache() + + @gCacheMetadata + def getMetadata(self, idP): + return self.cacheMetadata.get(idP) or {} + + @gCacheMetadata + def addMetadata(self, idP, data, time=24 * 3600): + if data: + self.cacheMetadata.add(idP, time, data) - ############################################################################# - def getIdProvider(self, idProvider): - """ This method returns a IdProvider instance corresponding to the supplied name. + def getIdProviderForToken(self, token): + """ This method returns a IdProvider instance corresponding to the supplied + issuer in a token. - :param str idProvider: the name of the Identity Provider + :param token: access token or dict with access_token key :return: S_OK(IdProvider)/S_ERROR() """ - result = getInfoAboutProviders(of='Id', providerName=idProvider, option="all", section="all") + if isinstance(token, dict): + token = token['access_token'] + + data = {} + + # Read token without verification to get issuer + issuer = jwt.decode(token, leeway=300, + options=dict(verify_signature=False, verify_aud=False))['iss'].strip('/') + + result = getSettingsNamesForIdPIssuer(issuer) if not result['OK']: return result - pDict = result['Value'] - pDict['ProviderName'] = idProvider - pType = pDict['ProviderType'] + return self.getIdProvider(result['Value']) + + def getIdProvider(self, name, **kwargs): + """ This method returns a IdProvider instance corresponding to the supplied + name. + + :param str name: the name of the Identity Provider client + + :return: S_OK(IdProvider)/S_ERROR() + """ + if not name: + return S_ERROR('Identity Provider client name must be not None.') + # Get Authorization Server metadata + try: + asMetaDict = collectMetadata(kwargs.get('issuer'), ignoreErrors=True) + except Exception as e: + return S_ERROR(str(e)) + self.log.debug('Search configuration for', name) + clients = getDIRACClients() + if name in clients: + # If it is a DIRAC default pre-registred client + pDict = asMetaDict + pDict.update(clients[name]) + else: + # if it is external identity provider client + result = getProviderInfo(name) + if not result['OK']: + self.log.error('Failed to read configuration', '%s: %s' % (name, result['Message'])) + return result + pDict = result['Value'] + # Set default redirect_uri + pDict['redirect_uri'] = pDict.get('redirect_uri', asMetaDict['redirect_uri']) + + pDict.update(kwargs) + pDict['ProviderName'] = name - self.log.verbose('Creating IdProvider', 'of %s type with the name %s' % (pType, idProvider)) - subClassName = "%sIdProvider" % (pType) + self.log.verbose('Creating IdProvider of %s type with the name %s' % (pDict['ProviderType'], name)) + subClassName = "%sIdProvider" % pDict['ProviderType'] - result = ObjectLoader().loadObject('Resources.IdProvider.%s' % subClassName) + objectLoader = ObjectLoader.ObjectLoader() + result = objectLoader.loadObject('Resources.IdProvider.%s' % subClassName, subClassName) if not result['OK']: self.log.error('Failed to load object', '%s: %s' % (subClassName, result['Message'])) return result pClass = result['Value'] try: - provider = pClass() - provider.setParameters(pDict) + provider = pClass(**pDict) except Exception as x: msg = 'IdProviderFactory could not instantiate %s object: %s' % (subClassName, str(x)) self.log.exception() - self.log.error(msg) + self.log.warn(msg) return S_ERROR(msg) return S_OK(provider) diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py new file mode 100644 index 00000000000..08c065f4ed8 --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -0,0 +1,546 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import six +import time +import pprint +import requests +from authlib.jose import JsonWebKey, jwt +from authlib.common.urls import url_decode +from authlib.common.security import generate_token +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.oauth2.rfc6749.parameters import prepare_token_request +from authlib.oauth2.rfc8628 import DEVICE_CODE_GRANT_TYPE +from authlib.integrations.requests_client import OAuth2Session +from authlib.oidc.discovery.well_known import get_well_known_url +from authlib.oauth2.rfc7636 import create_s256_code_challenge + +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Utilities import ThreadSafe +from DIRAC.Resources.IdProvider.IdProvider import IdProvider +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getVOMSRoleGroupMapping, getGroupOption, + getAllGroups, wrapIDAsDN) +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +__RCSID__ = "$Id$" + +DEFAULT_HEADERS = { + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8' +} + +gJWKs = ThreadSafe.Synchronizer() +gMetadata = ThreadSafe.Synchronizer() +gRefreshToken = ThreadSafe.Synchronizer() + + +def claimParser(claimDict, attributes): + """ Parse claims to dictionary with certain keys + + :param dict claimDict: claims + :param dict attributes: contain claim and regex to parse it + + :return: dict + """ + profile = {} + result = None + for claim, reg in attributes.items(): + if claim not in claimDict: + continue + profile[claim] = {} + if isinstance(claimDict[claim], dict): + result = claimParser(claimDict[claim], reg) + if result: + profile[claim] = result + elif isinstance(claimDict[claim], six.string_types): + result = re.compile(reg).match(claimDict[claim]) + if result: + for k, v in result.groupdict().items(): + profile[claim][k] = v + else: + profile[claim] = [] + for claimItem in claimDict[claim]: + if isinstance(reg, dict): + result = claimParser(claimItem, reg) + if result: + profile[claim].append(result) + else: + result = re.compile(reg).match(claimItem) + if result: + profile[claim].append(result.groupdict()) + + return profile + + +class OAuth2IdProvider(IdProvider, OAuth2Session): + """ Base class to describe the configuration of the OAuth2 client of the corresponding provider. + """ + + JWKS_REFRESH_RATE = 24 * 3600 + METADATA_REFRESH_RATE = 24 * 3600 + + def __init__(self, **kwargs): + """ Initialization """ + IdProvider.__init__(self, **kwargs) + OAuth2Session.__init__(self, **kwargs) + self.metadata_fetch_last = 0 + self.issuer = self.metadata['issuer'] + self.scope = self.scope or '' + self.jwks = kwargs.get('jwks') + self.verify = kwargs.get('verify', False) + self.token_placement = kwargs.get('token_placement', 'header') + self.code_challenge_method = 'S256' + # self.token_endpoint_auth_method = kwargs.get('token_endpoint_auth_method') #, 'client_secret_post') + self.server_metadata_url = kwargs.get('server_metadata_url', get_well_known_url(self.metadata['issuer'], True)) + self.jwks_fetch_last = time.time() - self.JWKS_REFRESH_RATE + self.metadata_fetch_last = time.time() - self.METADATA_REFRESH_RATE + self.log.debug('"%s" OAuth2 IdP initialization done:' % self.name, + '\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s' % (self.client_id, self.client_secret, + pprint.pformat(self.metadata))) + + def get_metadata(self, option=None): + """ Get metadata + + :param str option: option + + :return: option value + """ + if not self.metadata.get(option): + self.fetch_metadata() + return self.metadata.get(option) + + @gMetadata + def fetch_metadata(self): + """ Fetch metada + """ + if self.metadata_fetch_last < (time.time() - self.METADATA_REFRESH_RATE): + data = self.get(self.server_metadata_url, withhold_token=True).json() + self.metadata.update(data) + self.metadata_fetch_last = time.time() + + @gJWKs + def updateJWKs(self): + """ Update JWKs + """ + if self.jwks_fetch_last < (time.time() - self.JWKS_REFRESH_RATE): + try: + self.jwks = self.get(self.get_metadata('jwks_uri'), withhold_token=True).json() + self.jwks_fetch_last = time.time() + return S_OK(self.jwks) + except Exception as e: + self.log.exception(e) + return S_ERROR("Error %s" % repr(e)) + return S_OK() + + def verifyToken(self, accessToken=None, jwks=None): + """ Verify access token + + :param str accessToken: access token + :param dict jwks: JWKs + + :return: dict + """ + # Define an access token + if not accessToken: + accessToken = self.token['access_token'] + # Renew a JWKs of an identity provider if needed + if not jwks: + result = self.updateJWKs() + if not result['OK']: + return result + jwks = self.jwks + if not jwks: + return S_ERROR("JWKs not found.") + # Try to decode and verify an access token + self.log.debug("Try to decode token %s with JWKs:\n" % accessToken, pprint.pformat(jwks)) + try: + return S_OK(jwt.decode(accessToken, JsonWebKey.import_key_set(jwks))) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + + @gRefreshToken + def refreshToken(self, refresh_token=None, group=None, **kwargs): + """ Refresh token + + :param str token: refresh_token + :param str group: DIRAC group + + :return: dict + """ + if group: + # If group set add group scopes to request + result = self.getGroupScopes(group) + if not result['OK']: + return result + kwargs.update(dict(scope=list_to_scope(result['Value']))) + + if not refresh_token: + refresh_token = self.token.get('refresh_token') + try: + token = self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token, **kwargs) + return S_OK(OAuth2Token(dict(token))) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + + @gRefreshToken + def fetchToken(self, **kwargs): + """ Fetch token + + :return: dict + """ + try: + self.fetch_access_token(self.get_metadata('token_endpoint'), **kwargs) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + self.token['client_id'] = self.client_id + self.token['provider'] = self.name + return S_OK(OAuth2Token(dict(self.token))) + + def revokeToken(self, token=None, token_type_hint='refresh_token'): + """ Revoke token + + :param str token: token + :param str token_type_hint: token type + + :return: S_OK()/S_ERROR() + """ + if not token: + tokn = self.token.get(token_type_hint) + try: + self.revoke_token(self.get_metadata('revocation_endpoint'), token=token, token_type_hint=token_type_hint) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + return S_OK() + + def exchangeGroup(self, group): + """ Get new tokens for group scope + + :param str group: requested group + + :return: dict -- token + """ + result = self.getGroupScopes(group) + if not result['OK']: + return result + groupScopes = result['Value'] + try: + token = self.exchange_token(self.get_metadata('token_endpoint'), subject_token=self.token['access_token'], + subject_token_type='urn:ietf:params:oauth:token-type:access_token', + scope=list_to_scope(scope_to_list(self.scope) + groupScopes)) + if not token: + return S_ERROR('Cannot exchange token with %s group.' % group) + return S_OK(OAuth2Token(dict(token))) + except Exception as e: + self.log.exception(e) + return S_ERROR('Cannot exchange token with %s group: %s' % (group, repr(e))) + + def researchGroup(self, payload=None, token=None): + """ Research group + + :param str payload: token payload + :param str token: access token + + :return: S_OK(dict)/S_ERROR() + """ + if not token: + token = self.token + + if not payload and token: + payload = OAuth2Token(token).get_payload() + + credDict = self.parseBasic(payload) + if not credDict.get('DIRACGroups'): + credDict.update(self.parseEduperson(payload)) + if credDict.get('DIRACGroups'): + self.log.debug('Found next groups:', ', '.join(credDict['DIRACGroups'])) + credDict['group'] = credDict['DIRACGroups'][0] + return S_OK(credDict) + + def parseBasic(self, claimDict): + """ Parse basic claims + + :param dict claimDict: claims + + :return: S_OK(dict)/S_ERROR() + """ + self.log.debug('Token payload:', pprint.pformat(claimDict)) + credDict = {} + credDict['ID'] = claimDict['sub'] + credDict['DN'] = wrapIDAsDN(credDict['ID']) + if claimDict.get('scope'): + self.log.debug('Search groups for %s scope.' % claimDict['scope']) + credDict['DIRACGroups'] = self.getScopeGroups(claimDict['scope']) + return credDict + + def parseEduperson(self, claimDict): + """ Parse eduperson claims + + :return: dict + """ + vos = {} + credDict = {} + attributes = { + 'eduperson_unique_id': '^(?P.*)', + 'eduperson_entitlement': '%s:%s' % ('^(?P[A-z,.,_,-,:]+):(group:registry|group)', + '(?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*') + } + self.log.debug('Try to parse eduperson claims..') + # Parse eduperson claims + resDict = claimParser(claimDict, attributes) + if resDict.get('eduperson_unique_id'): + self.log.debug('Found eduperson_unique_id claim:', pprint.pformat(resDict['eduperson_unique_id'])) + credDict['ID'] = resDict['eduperson_unique_id']['ID'] + if resDict.get('eduperson_entitlement'): + self.log.debug('Found eduperson_entitlement claim:', pprint.pformat(resDict['eduperson_entitlement'])) + for voDict in resDict['eduperson_entitlement']: + if voDict['VO'] not in vos: + vos[voDict['VO']] = {'VORoles': []} + if voDict['VORole'] not in vos[voDict['VO']]['VORoles']: + vos[voDict['VO']]['VORoles'].append(voDict['VORole']) + # Search DIRAC groups + for vo in vos: + result = getVOMSRoleGroupMapping(vo) + if not result['OK']: + # Skip VO if it absent in Registry + self.log.debug(result['Message']) + continue + for role in vos[vo]['VORoles']: + groups = result['Value']['VOMSDIRAC'].get('/%s/%s' % (vo, role)) + if groups: + credDict['DIRACGroups'] = list(set(credDict.get('DIRACGroups', []) + groups)) + return credDict + + def deviceAuthorization(self, group=None): + """ Authorizaion through DeviceCode flow + """ + result = self.submitDeviceCodeAuthorizationFlow(group) + if not result['OK']: + return result + response = result['Value'] + + # Notify user to go to authorization endpoint + if response.get('verification_uri_complete'): + showURL = 'Use next link to continue\n%s' % response['verification_uri_complete'] + else: + showURL = 'Use next link to continue, your user code is "%s"\n%s' % (response['user_code'], + response['verification_uri']) + self.log.notice(showURL) + try: + return self.waitFinalStatusOfDeviceCodeAuthorizationFlow(response['device_code']) + except KeyboardInterrupt: + return S_ERROR('User canceled the operation..') + + def submitNewSession(self, pkce=True): + """ Submit new authorization session + + :param bool pkce: use PKCE + + :return: S_OK(str)/S_ERROR() + """ + session = {} + params = dict(state=generate_token(10)) + # Create PKCE verifier + if pkce: + session['code_verifier'] = generate_token(48) + params['code_challenge_method'] = 'S256' + params['code_challenge'] = create_s256_code_challenge(session['code_verifier']) + url, state = self.create_authorization_url(self.get_metadata('authorization_endpoint'), **params) + return url, state, session + + def parseAuthResponse(self, response, session=None): + """ Make user info dict: + + :param dict response: response on request to get user profile + :param object session: session + + :return: S_OK((dict, dict))/S_ERROR() + """ + response = createOAuth2Request(response) + + self.log.debug('Try to parse authentication response:', pprint.pformat(response.data)) + + if not session: + session = {} + + self.log.debug('Current session is:\n', pprint.pformat(session)) + + self.fetchToken(authorization_response=response.uri, code_verifier=session.get('code_verifier')) + + result = self.verifyToken(self.token['access_token']) + if not result['OK']: + return result + payload = result['Value'] + result = self.researchGroup(payload) + if not result['OK']: + return result + credDict = result['Value'] + self.log.debug('Got response dictionary:\n', pprint.pformat(credDict)) + + # Store token + self.token['user_id'] = credDict['ID'] + + return S_OK((credDict, payload)) + + def submitDeviceCodeAuthorizationFlow(self, group=None): + """ Submit authorization flow + + :return: S_OK(dict)/S_ERROR() -- dictionary with device code flow response + """ + groupScopes = [] + if group: + result = self.getGroupScopes(group) + if not result['OK']: + return result + groupScopes = result['Value'] + + try: + r = requests.post(self.get_metadata('device_authorization_endpoint'), data=dict( + client_id=self.client_id, scope=list_to_scope(scope_to_list(self.scope) + groupScopes) + ), verify=self.verify) + r.raise_for_status() + deviceResponse = r.json() + if 'error' in deviceResponse: + return S_ERROR('%s: %s' % (deviceResponse['error'], deviceResponse.get('description', ''))) + + # Check if all main keys are present here + for k in ['user_code', 'device_code', 'verification_uri']: + if not deviceResponse.get(k): + return S_ERROR('Mandatory %s key is absent in authentication response.' % k) + + return S_OK(deviceResponse) + except requests.exceptions.Timeout: + return S_ERROR('Authentication server is not answer, timeout.') + except requests.exceptions.RequestException as ex: + return S_ERROR(repr(ex)) + except Exception as ex: + return S_ERROR('Cannot read authentication response: %s' % repr(ex)) + + def waitFinalStatusOfDeviceCodeAuthorizationFlow(self, deviceCode, interval=5, timeout=300): + """ Submit waiting loop process, that will monitor current authorization session status + + :param str deviceCode: received device code + :param int interval: waiting interval + :param int timeout: max time of waiting + + :return: S_OK(dict)/S_ERROR() - dictionary contain access/refresh token and some metadata + """ + __start = time.time() + + self.log.notice('Authorization pending.. (use CNTL + C to stop)') + while True: + time.sleep(int(interval)) + if time.time() - __start > timeout: + return S_ERROR('Time out.') + r = requests.post(self.get_metadata('token_endpoint'), data=dict(client_id=self.client_id, + grant_type=DEVICE_CODE_GRANT_TYPE, device_code=deviceCode), verify=self.verify) + token = r.json() + if not token: + return S_ERROR('Resived token is empty!') + if 'error' not in token: + self.token = token + return S_OK(token) + if token['error'] != 'authorization_pending': + return S_ERROR((token.get('error') or 'unknown') + ' : ' + (token.get('error_description') or '')) + + def getGroupScopes(self, group): + """ Get group scopes + + :param str group: DIRAC group + + :return: list + """ + idPScope = getGroupOption(group, 'IdPRole') + if not idPScope: + return S_ERROR('Cannot find role for %s' % group) + return S_OK(scope_to_list(idPScope)) + + def getScopeGroups(self, scope): + """ Get scope groups + + :param str scope: scope + + :return: list + """ + groups = [] + for group in getAllGroups(): + result = self.getGroupScopes(group) + if not result['OK']: + # Skip DIRAAC group without scope parameter + self.log.debug(result['Message']) + continue + if set(result['Value']).issubset(scope_to_list(scope)): + groups.append(group) + return groups + + def getUserProfile(self): + """ Get user profile + + :return: S_OK()/S_ERROR() + """ + try: + return S_OK(self.get(self.get_metadata('userinfo_endpoint')).json()) + except Exception as e: + self.log.exception(e) + return S_ERROR('Cannot get user profile: %s' % repr(e)) + + def exchange_token(self, url, subject_token=None, subject_token_type=None, body='', + refresh_token=None, access_token=None, auth=None, headers=None, **kwargs): + """ Exchange a new access token + + :param url: Exchange Token endpoint, must be HTTPS. + :param str subject_token: subject_token + :param str subject_token_type: token type https://tools.ietf.org/html/rfc8693#section-3 + :param body: Optional application/x-www-form-urlencoded body to add the + include in the token request. Prefer kwargs over body. + :param str refresh_token: refresh token + :param str access_token: access token + :param auth: An auth tuple or method as accepted by requests. + :param headers: Dict to default request headers with. + :return: A :class:`OAuth2Token` object (a dict too). + """ + session_kwargs = self._extract_session_request_params(kwargs) + refresh_token = refresh_token or self.token.get('refresh_token') + access_token = access_token or self.token.get('access_token') + subject_token = subject_token or refresh_token + subject_token_type = subject_token_type or 'urn:ietf:params:oauth:token-type:refresh_token' + if 'scope' not in kwargs and self.scope: + kwargs['scope'] = self.scope + body = prepare_token_request('urn:ietf:params:oauth:grant-type:token-exchange', body, + subject_token=subject_token, subject_token_type=subject_token_type, **kwargs) + + if headers is None: + headers = DEFAULT_HEADERS + + for hook in self.compliance_hook.get('exchange_token_request', []): + url, headers, body = hook(url, headers, body) + + if auth is None: + auth = self.client_auth(self.token_endpoint_auth_method) + + return self._exchange_token(url, refresh_token=refresh_token, body=body, headers=headers, + auth=auth, **session_kwargs) + + def _exchange_token(self, url, body='', refresh_token=None, headers=None, auth=None, **kwargs): + resp = self.session.post(url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs) + + for hook in self.compliance_hook.get('exchange_token_response', []): + resp = hook(resp) + + token = self.parse_response_token(resp.json()) + if 'refresh_token' not in token: + self.token['refresh_token'] = refresh_token + + if callable(self.update_token): + self.update_token(self.token, refresh_token=refresh_token) + + return self.token diff --git a/src/DIRAC/Resources/IdProvider/Utilities.py b/src/DIRAC/Resources/IdProvider/Utilities.py new file mode 100644 index 00000000000..05822524ead --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/Utilities.py @@ -0,0 +1,86 @@ +""" Utilities for the IdProvider package +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from DIRAC import S_OK, S_ERROR, gConfig + + +def getSettingsNamesForIdPIssuer(issuer): + """ Get identity provider for issuer + + :param str issuer: issuer + + :return: S_OK(str)/S_ERROR() + """ + result = getProvidersForInstance('Id') + if not result['OK']: + return result + for name in result['Value']: + nameIssuer = gConfig.getValue('/Resources/IdProviders/%s/issuer' % name) + if nameIssuer and issuer.strip('/') == nameIssuer.strip('/'): + return S_OK(name) + return S_ERROR('Not found provider with %s issuer.' % issuer) + + +def getSettingsNamesForClientID(clientID): + """ Get identity providers for clientID + + :param str clientID: clientID + + :return: S_OK(list)/S_ERROR() + """ + names = [] + result = getProvidersForInstance('Id') + if not result['OK']: + return result + for name in result['Value']: + res = gConfig.getValue('/Resources/IdProviders/%s/client_id' % name) + if res and clientID == res: + names.append(name) + return S_OK(names) if names else S_ERROR('Not found provider with %s clientID.' % clientID) + + +def getProvidersForInstance(instance, providerType=None): + """ Get providers for instance + + :param str instance: instance of what this providers + :param str providerType: provider type + + :return: S_OK(list)/S_ERROR() + """ + providers = [] + instance = "%sProviders" % instance + result = gConfig.getSections('/Resources/%s' % instance) + + # Return an empty list if the section does not exist + if not result['OK'] or not result['Value'] or not providerType: + return result + + for prov in result['Value']: + if providerType == gConfig.getValue('/Resources/%s/%s/ProviderType' % (instance, prov)): + providers.append(prov) + return S_OK(providers) + + +def getProviderInfo(provider): + """ Get provider info + + :param str provider: provider + + :return: S_OK(dict)/S_ERROR() + """ + result = gConfig.getSections('/Resources') + if not result['OK']: + return result + for section in result['Value']: + if section.endswith('Providers'): + result = getProvidersForInstance(section[:-9]) + if not result['OK']: + return result + if provider in result['Value']: + return gConfig.getOptionsDictRecursively("/Resources/%s/%s/" % (section, provider)) + return S_ERROR('%s provider not found.' % provider) diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py new file mode 100644 index 00000000000..a52cd2094b0 --- /dev/null +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -0,0 +1,225 @@ +""" This is a test of the AuthDB. Requires authlib + It supposes that the DB is present and installed in DIRAC +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import time +import pytest + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +if six.PY3: + # DIRACOS not contain required packages + from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey + from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads + from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB + db = AuthDB() + +payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_RefreshToken(): + """ Try to revoke/save/get refresh tokens + """ + DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + + New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_in=int(time.time()) + 3600) + + preset_jti = '123' + + # Remove refresh token + result = db.revokeRefreshToken(preset_jti) + assert result['OK'], result['Message'] + + # Store tokens + result = db.storeRefreshToken(DToken.copy(), preset_jti) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == preset_jti + assert result['Value']['iat'] <= int(time.time()) + + result = db.storeRefreshToken(New_DToken.copy()) + assert result['OK'], result['Message'] + assert result['Value']['jti'] + assert result['Value']['iat'] <= int(time.time()) + + token_id = result['Value']['jti'] + issued_at = result['Value']['iat'] + + # Check token + result = db.getCredentialByRefreshToken(preset_jti) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == preset_jti + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + + result = db.getCredentialByRefreshToken(token_id) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == token_id + assert int(result['Value']['issued_at']) == issued_at + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] + + # Check token after request + for jti in [token_id, preset_jti]: + result = db.getCredentialByRefreshToken(jti) + assert result['OK'], result['Message'] + assert not result['Value'] + + # Renew tokens + result = db.storeRefreshToken(New_DToken.copy(), token_id) + assert result['OK'], result['Message'] + + # Revoke token + result = db.revokeRefreshToken(token_id) + assert result['OK'], result['Message'] + + # Check token + result = db.getCredentialByRefreshToken(token_id) + assert result['OK'], result['Message'] + assert not result['Value'] + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_keys(): + """ Try to store/get/remove keys + """ + # JWS + jws = JsonWebSignature(algorithms=['RS256']) + code_payload = {'user_id': 'user', + 'scope': 'scope', + 'client_id': 'client', + 'redirect_uri': 'redirect_uri', + 'code_challenge': 'code_challenge'} + + # Token metadata + header = {'alg': 'RS256'} + payload = {'sub': 'user', + 'iss': 'issuer', + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + + # Remove all keys + result = db.removeKeys() + assert result['OK'], result['Message'] + + # Check active keys + result = db.getActiveKeys() + assert result['OK'], result['Message'] + assert result['Value'] == [] + + # Create new one + result = db.getPrivateKey() + assert result['OK'], result['Message'] + + private_key = result['Value'] + assert isinstance(private_key, RSAKey) + + # Sign token + header['kid'] = private_key.thumbprint() + + # Find key by KID + result = db.getPrivateKey(header['kid']) + assert result['OK'], result['Message'] + # as_dict has no arguments for authlib < 1.0.0 + # for authlib >= 1.0.0: + assert result['Value'].as_dict(True) == private_key.as_dict(True) + + # Sign token + token = jwt.encode(header, payload, private_key) + # Sign auth code + code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) + + # Get public key set + result = db.getKeySet() + keyset = result['Value'] + assert result['OK'], result['Message'] + # as_dict has no arguments for authlib < 1.0.0 + # for authlib >= 1.0.0: + assert bool([key for key in keyset.as_dict(True)['keys'] if key["kid"] == header['kid']]) + + # Read token + _payload = jwt.decode(token, JsonWebKey.import_key_set(keyset.as_dict())) + assert _payload == payload + # Read auth code + data = jws.deserialize_compact(code, keyset.keys[0]) + _code_payload = json_loads(urlsafe_b64decode(data['payload'])) + assert _code_payload == code_payload + + +# DIRACOS not contain required packages +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_Sessions(): + """ Try to store/get/remove Sessions + """ + # Example of the new session metadata + sData1 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/auth/device', + 'verification_uri_complete': u'https://domain.com/auth/device?user_code=MDKP-MXMF'} + + # Example of the updated session + sData2 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/auth/device', + 'verification_uri_complete': u'https://domain.com/auth/device?user_code=MDKP-MXMF', + 'user_id': 'username'} + + # Remove old session + db.removeSession(sData1['id']) + + # Add session + result = db.addSession(sData1) + assert result['OK'], result['Message'] + + # Get session + result = db.getSessionByUserCode(sData1['user_code']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData1['device_code'] + assert result['Value'].get('user_id') != sData2['user_id'] + + # Update session + result = db.updateSession(sData2, sData1['id']) + assert result['OK'], result['Message'] + + # Get session + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData2['device_code'] + assert result['Value']['user_id'] == sData2['user_id'] + + # Remove session + result = db.removeSession(sData2['id']) + assert result['OK'], result['Message'] + + # Make sure that the session is absent + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert not result['Value'] diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py new file mode 100644 index 00000000000..7358ffe0ee4 --- /dev/null +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -0,0 +1,118 @@ +""" This is a test of the AuthServer. Requires authlib, pyjwt, dominate + It supposes that the AuthDB is present and installed in DIRAC +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import time +import pytest +from mock import MagicMock + +from diraccfg import CFG + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +import DIRAC +from DIRAC import S_OK, S_ERROR, gConfig + +if six.PY3: + # DIRACOS not contain required packages + from DIRAC.FrameworkSystem.private.authorization import AuthServer + + +class Proxy(object): + def dumpAllToString(self): + return S_OK('proxy') + + +class ProxyManagerClient(object): + def downloadProxy(self, *args, **kwargs): + return S_OK(Proxy()) + + +class TokenManagerClient(object): + def getToken(self, *args, **kwargs): + return S_OK({'access_token': 'token', 'refresh_token': 'token'}) + + +mockgetIdPForGroup = MagicMock(return_value=S_OK('IdP')) +mockgetDNForUsername = MagicMock(return_value=S_OK('DN')) +mockgetUsernameForDN = MagicMock(return_value=S_OK('user')) +mockisDownloadablePersonalProxy = MagicMock(return_value=True) + + +@pytest.fixture +def auth_server(monkeypatch): + cfg = CFG() + cfg.loadFromBuffer(""" + DIRAC + { + Security + { + Authorization + { + issuer = https://issuer.url/ + } + } + } + """) + gConfig.loadCFG(cfg) + if six.PY3: + monkeypatch.setattr(AuthServer, "getIdPForGroup", mockgetIdPForGroup) + monkeypatch.setattr(AuthServer, "getDNForUsername", mockgetDNForUsername) + monkeypatch.setattr(AuthServer, "getUsernameForDN", mockgetUsernameForDN) + monkeypatch.setattr(AuthServer, "ProxyManagerClient", ProxyManagerClient) + monkeypatch.setattr(AuthServer, "TokenManagerClient", TokenManagerClient) + monkeypatch.setattr(AuthServer, "isDownloadablePersonalProxy", mockisDownloadablePersonalProxy) + return AuthServer.AuthServer() + return None + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_metadata(auth_server): + """ Check metadata + """ + assert auth_server.metadata.get('issuer') + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_queryClient(auth_server): + """ Try to search some default client + """ + assert not auth_server.query_client('not_exist_client') + assert auth_server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +@pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ + ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), + ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), +]) +def test_generateToken(auth_server, client, grant, user, scope, expires_in, refresh_token, instance, result): + """ Generate tokens + """ + from authlib.oauth2.base import OAuth2Error + cli = auth_server.query_client(client) + try: + assert auth_server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result + except OAuth2Error as e: + assert False, str(e) + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_writeReadRefreshToken(auth_server): + """ Try to search some default client + """ + result = auth_server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) + assert result['OK'], result['Message'] + token = result['Value'] + assert token.get('access_token') == 'token' + assert token.get('refresh_token') != 'token' + + result = auth_server.readToken(token['refresh_token']) + assert result['OK'], result['Message'] + assert result['Value'].get('jti') + assert result['Value'].get('iat') diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py new file mode 100644 index 00000000000..245edd33363 --- /dev/null +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -0,0 +1,85 @@ +""" This is a test of the AuthDB. Requires authlib, pyjwt + It supposes that the DB is present and installed in DIRAC +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import time +import pytest + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + +exp_payload = payload.copy() +exp_payload['iat'] = int(time.time()) - 10 +exp_payload['exp'] = int(time.time()) - 10 + + +if six.PY3: + # DIRACOS not contain required packages + from authlib.jose import jwt + from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB + db = TokenDB() + + +# DIRACOS not contain required packages +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_Token(): + """ Try to revoke/save/get tokens + """ + DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + + New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + issued_at=int(time.time()), + expires_in=int(time.time()) + 3600) + + Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + expires_at=int(time.time()) - 10, + rt_expires_at=int(time.time()) - 10) + + # Remove all tokens + result = db.removeToken(user_id=123) + assert result['OK'], result['Message'] + + # Store tokens + result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + assert result['Value'] == [] + + # Expired token + result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert not result['OK'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + + # Store new tokens + result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + # Must return old tokens + assert len(result['Value']) == 1 + assert result['Value'][0]['access_token'] == DToken['access_token'] + assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py new file mode 100644 index 00000000000..005c383925a --- /dev/null +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -0,0 +1,106 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import six +import time +import pytest +import unittest + +from diraccfg import CFG + +import DIRAC +from DIRAC import gConfig + +cfg = CFG() +cfg.loadFromBuffer(""" +DIRAC +{ + Security + { + Authorization + { + issuer = https://issuer.url/ + Clients + { + DIRACWeb + { + client_id = client_identificator + client_secret = client_secret_key + redirect_uri = https://redirect.url/ + } + } + } + } +} +Resources +{ + IdProviders + { + SomeIdP + { + ProviderType = OAuth2 + issuer = https://idp.url/ + client_id = IdP_client_id + client_secret = IdP_client_secret + redirect_uri = https://dirac/redirect + jwks_uri = https://idp.url/jwk + scope = openid+profile+offline_access+eduperson_entitlement + } + } +} +""") +gConfig.loadCFG(cfg) + +if six.PY3: + # DIRACOS not contain required packages + from authlib.jose import jwt + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS + idps = IdProviderFactory() + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_getDIRACClients(): + """ Try to load default DIRAC authorization client + """ + # Try to get DIRAC client authorization settings + result = idps.getIdProvider('DIRACCLI') + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == DEFAULT_CLIENTS['DIRACCLI']['client_id'] + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' + + # Try to get DIRAC client authorization settings for Web portal + result = idps.getIdProvider('DIRACWeb') + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == 'client_identificator' + assert result['Value'].client_secret == 'client_secret_key' + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' + + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_getIdPClients(): + """ Try to load external identity provider settings + """ + # Try to get identity provider by name + result = idps.getIdProvider('SomeIdP', jwks='my_jwks') + assert result['OK'], result['Message'] + assert result['Value'].jwks == 'my_jwks' + assert result['Value'].issuer == 'https://idp.url/' + assert result['Value'].client_id == 'IdP_client_id' + assert result['Value'].client_secret == 'IdP_client_secret' + assert result['Value'].get_metadata('jwks_uri') == 'https://idp.url/jwk' + + # Try to get identity provider for token issued by it + result = idps.getIdProviderForToken(jwt.encode({'alg': 'HS256'}, dict( + sub='user', + iss=result['Value'].issuer, + iat=int(time.time()), + exp=int(time.time()) + (12 * 3600), + ), "secret").decode('utf-8')) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://idp.url/' diff --git a/tests/Integration/Resources/IdProvider/__init__.py b/tests/Integration/Resources/IdProvider/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/Integration/all_integration_server_tests.sh b/tests/Integration/all_integration_server_tests.sh index d98f978a193..aef91e6b90e 100644 --- a/tests/Integration/all_integration_server_tests.sh +++ b/tests/Integration/all_integration_server_tests.sh @@ -26,7 +26,10 @@ pytest "${THIS_DIR}/Core/Test_MySQLDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( #-------------------------------------------------------------------------------# echo -e "*** $(date -u) **** FRAMEWORK TESTS (partially skipped) ****\n" pytest "${THIS_DIR}/Framework/Test_InstalledComponentsDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) -python "${THIS_DIR}/Framework/Test_ProxyDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_ProxyDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_TokenDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_AuthDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_AuthServer.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #pytest ${THIS_DIR}/Framework/Test_LoggingDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #-------------------------------------------------------------------------------# diff --git a/tests/Jenkins/utilities.sh b/tests/Jenkins/utilities.sh index 0fee47ce6d5..08bd24da217 100644 --- a/tests/Jenkins/utilities.sh +++ b/tests/Jenkins/utilities.sh @@ -801,7 +801,7 @@ diracServices(){ echo '==> [diracServices]' # Ignore tornado services - local services=$(cut -d '.' -f 1 < services | grep -v Tornado | grep -v PilotsLogging | grep -v StorageElementHandler | grep -v ^ConfigurationSystem | grep -v Plotting | grep -v RAWIntegrity | grep -v RunDBInterface | grep -v ComponentMonitoring | sed 's/System / /g' | sed 's/Handler//g' | sed 's/ /\//g') + local services=$(cut -d '.' -f 1 < services | grep -v Tornado | grep -v TokenManager | grep -v PilotsLogging | grep -v StorageElementHandler | grep -v ^ConfigurationSystem | grep -v Plotting | grep -v RAWIntegrity | grep -v RunDBInterface | grep -v ComponentMonitoring | sed 's/System / /g' | sed 's/Handler//g' | sed 's/ /\//g') # group proxy, will be uploaded explicitly # echo '==> getting/uploading proxy for prod' @@ -852,7 +852,7 @@ diracUninstallServices(){ findServices # Ignore tornado services - local services=$(cut -d '.' -f 1 getting/uploading proxy for prod'