From ee7b62e89ea2227c8c4f1440df3c381407640ee6 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 19 May 2021 20:45:49 +0200 Subject: [PATCH 001/178] add identity providers --- dirac.cfg | 28 + .../Client/Helpers/Registry.py | 10 + .../Client/Helpers/Resources.py | 141 +--- .../ConfigurationSystem/Client/PathFinder.py | 8 +- .../ConfigurationSystem/Client/Utilities.py | 53 +- src/DIRAC/Core/Security/Locations.py | 17 + src/DIRAC/Core/Security/TokenFile.py | 99 +++ src/DIRAC/Core/Security/TokenInfo.py | 76 ++ .../Client/private/TornadoBaseClient.py | 36 +- .../Core/Tornado/Server/BaseRequestHandler.py | 790 ++++++++++++++++++ .../Core/Tornado/Server/HandlerManager.py | 257 ++++-- src/DIRAC/Core/Tornado/Server/TornadoREST.py | 80 ++ .../Core/Tornado/Server/TornadoServer.py | 166 ++-- .../Core/Tornado/Server/TornadoService.py | 481 +---------- .../Core/Tornado/scripts/tornado_start_AS.py | 55 ++ .../Core/Tornado/scripts/tornado_start_CS.py | 2 +- .../Core/Tornado/scripts/tornado_start_web.py | 66 ++ src/DIRAC/FrameworkSystem/API/AuthHandler.py | 527 ++++++++++++ src/DIRAC/FrameworkSystem/API/__init__.py | 0 src/DIRAC/FrameworkSystem/ConfigTemplate.cfg | 9 + src/DIRAC/FrameworkSystem/DB/AuthDB.py | 385 +++++++++ src/DIRAC/FrameworkSystem/DB/AuthDB.sql | 2 + .../FrameworkSystem/scripts/dirac_login.py | 198 +++++ .../Resources/IdProvider/CheckInIdProvider.py | 30 + .../Resources/IdProvider/DIRACIdProvider.py | 21 + .../Resources/IdProvider/IAMIdProvider.py | 17 + src/DIRAC/Resources/IdProvider/IdProvider.py | 17 +- .../Resources/IdProvider/IdProviderFactory.py | 92 +- .../Resources/IdProvider/OAuth2IdProvider.py | 442 ++++++++++ tests/Integration/Framework/Test_AuthDB.py | 200 +++++ 30 files changed, 3579 insertions(+), 726 deletions(-) create mode 100644 src/DIRAC/Core/Security/TokenFile.py create mode 100644 src/DIRAC/Core/Security/TokenInfo.py create mode 100644 src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py create mode 100644 src/DIRAC/Core/Tornado/Server/TornadoREST.py create mode 100644 src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py create mode 100644 src/DIRAC/Core/Tornado/scripts/tornado_start_web.py create mode 100644 src/DIRAC/FrameworkSystem/API/AuthHandler.py create mode 100644 src/DIRAC/FrameworkSystem/API/__init__.py create mode 100644 src/DIRAC/FrameworkSystem/DB/AuthDB.py create mode 100644 src/DIRAC/FrameworkSystem/DB/AuthDB.sql create mode 100644 src/DIRAC/FrameworkSystem/scripts/dirac_login.py create mode 100644 src/DIRAC/Resources/IdProvider/CheckInIdProvider.py create mode 100644 src/DIRAC/Resources/IdProvider/DIRACIdProvider.py create mode 100644 src/DIRAC/Resources/IdProvider/IAMIdProvider.py create mode 100644 src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py create mode 100644 tests/Integration/Framework/Test_AuthDB.py diff --git a/dirac.cfg b/dirac.cfg index 9af77571bde..502a64166e0 100644 --- a/dirac.cfg +++ b/dirac.cfg @@ -57,6 +57,9 @@ Registry # Real VOMS VO name, if this VO is associated with VOMS VO VOMSName = lhcb + # Registered identity provider associated with VO + IdP = ChechIn + # Section to describe all the VOMS servers that can be used with the given VOMS VO VOMSServers { @@ -99,6 +102,9 @@ Registry # Role of the users in the VO VOMSRole = /lhcb + + # Scope associated with a role of the user in the VO + IdPScope = some_special_scope # Virtual organization associated with the group VOMSVO = lhcb @@ -418,6 +424,28 @@ Systems } Resources { + IdProviders + { + # WebAppDIRAC + # { + # # This type describe DIRAC authorization server client + # ProviderType = DIRAC + # issuer = https://dirac.egi.eu/DIRAC/auth + # client_id = type_client_id_here_receved_after_client_registration + # client_secret = type_client_secret_here_receved_after_client_registration + # } + CheckIn + { + # What supported type of provider does it belong to + ProviderType = OAuth2 + # Description of the client parameters registered on the identity provider side. + # Look here for information about client parameters description https://tools.ietf.org/html/rfc8414#section-2 + issuer = https://aai-dev.egi.eu/oidc + client_id = type_client_id_here_receved_after_client_registration + client_secret = type_client_secret_here_receved_after_client_registration + scope = openid, profile, offline_access, eduperson_entitlement, cert_entitlement + } + } # Section for proxy providers, subsections is the names of the proxy providers # https://dirac.readthedocs.org/en/latest/AdministratorGuide/Resources/proxyprovider.html diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index e9377b1915a..1d6faeb7374 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -428,6 +428,16 @@ def getVOForGroup(group): return getVO() or gConfig.getValue("%s/Groups/%s/VO" % (gBaseRegistrySection, group), "") +def getIdPForGroup(group): + """ Get identity provider for group VO + + :param str group: group name + + :return: str + """ + return getVOOption(getVOForGroup(group), 'IdP') + + def getDefaultVOMSAttribute(): """ Get default VOMS attribute diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py index 631c9b6f49e..bb1e85de049 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py @@ -430,6 +430,24 @@ def getFilterConfig(filterID): return gConfig.getOptionsDict('Resources/LogFilters/%s' % filterID) +def getSettingsNamesForIdPIssuer(issuer): + """ Get identity providers for issuer + + :param str issuer: issuer + + :return: S_OK(list)/S_ERROR() + """ + names = [] + result = getProvidersForInstance('Id') + if not result['OK']: + return result + for name in result['Value']: + nameIssuer = gConfig.getValue('%s/IdProviders/%s/issuer' % (gBaseResourcesSection, name)) + if nameIssuer and issuer.strip('/') == nameIssuer.strip('/'): + names.append(name) + return S_OK(names) if names else S_ERROR('Not found provider wwith %s issuer.' % issuer) + + def getInfoAboutProviders(of=None, providerName=None, option='', section=''): """ Get the information about providers @@ -474,117 +492,40 @@ def getInfoAboutProviders(of=None, providerName=None, option='', section=''): section, option))) -def findGenericCloudCredentials(vo=False, group=False): - """ Get the cloud credentials to use for a specific VO and/or group. """ - if not group and not vo: - return S_ERROR("Need a group or a VO to determine the Generic cloud credentials") - if not vo: - vo = Registry.getVOForGroup(group) - if not vo: - return S_ERROR("Group %s does not have a VO associated" % group) - opsHelper = Operations.Operations(vo=vo) - cloudGroup = opsHelper.getValue("Cloud/GenericCloudGroup", "") - cloudDN = opsHelper.getValue("Cloud/GenericCloudDN", "") - if not cloudDN: - cloudUser = opsHelper.getValue("Cloud/GenericCloudUser", "") - if cloudUser: - result = Registry.getDNForUsername(cloudUser) - if result['OK']: - cloudDN = result['Value'][0] - else: - return S_ERROR("Failed to find suitable CloudDN") - if cloudDN and cloudGroup: - gLogger.verbose("Cloud credentials from CS: %s@%s" % (cloudDN, cloudGroup)) - result = gProxyManager.userHasProxy(cloudDN, cloudGroup, 86400) - if not result['OK']: - return result - return S_OK((cloudDN, cloudGroup)) - return S_ERROR("Cloud credentials not found") +def getProvidersForInstance(instance, providerType=None): + """ Get providers for instance + :param str instance: instance of what this providers + :param str providerType: provider type -def getVMTypes(siteList=None, ceList=None, vmTypeList=None, vo=None): - """ Get CE/vmType options filtered by the provided parameters. + :return: S_OK(list)/S_ERROR() """ + data = [] + instance = "%sProviders" % instance + result = gConfig.getSections(gBaseResourcesSection) + if result['OK']: + if instance not in result['Value']: + return S_OK(data) + result = gConfig.getSections('%s/%s' % (gBaseResourcesSection, instance)) - result = gConfig.getSections('/Resources/Sites') - if not result['OK']: + # Return an empty list if the section does not exist + if not result['OK'] or not result['Value'] or not providerType: return result - resultDict = {} + for prov in result['Value']: + if providerType == gConfig.getValue('%s/%s/%s/ProviderType' % (gBaseResourcesSection, instance, prov)): + data.append(prov) + return S_OK(data) - grids = result['Value'] - for grid in grids: - result = gConfig.getSections('/Resources/Sites/%s' % grid) - if not result['OK']: - continue - sites = result['Value'] - for site in sites: - if siteList is not None and site not in siteList: - continue - if vo: - voList = gConfig.getValue('/Resources/Sites/%s/%s/VO' % (grid, site), []) - if voList and vo not in voList: - continue - result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud' % (grid, site)) - if not result['OK']: - continue - ces = result['Value'] - for ce in ces: - if ceList is not None and ce not in ceList: - continue - if vo: - voList = gConfig.getValue('/Resources/Sites/%s/%s/Cloud/%s/VO' % (grid, site, ce), []) - if voList and vo not in voList: - continue - result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s' % (grid, site, ce)) - if not result['OK']: - continue - ceOptionsDict = result['Value'] - result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud/%s/VMTypes' % (grid, site, ce)) - if not result['OK']: - result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud/%s/Images' % (grid, site, ce)) - if not result['OK']: - return result - vmTypes = result['Value'] - for vmType in vmTypes: - if vmTypeList is not None and vmType not in vmTypeList: - continue - if vo: - voList = gConfig.getValue('/Resources/Sites/%s/%s/Cloud/%s/VMTypes/%s/VO' % (grid, site, ce, vmType), []) - if not voList: - voList = gConfig.getValue('/Resources/Sites/%s/%s/Cloud/%s/Images/%s/VO' % (grid, site, ce, vmType), []) - if voList and vo not in voList: - continue - resultDict.setdefault(site, {}) - resultDict[site].setdefault(ce, ceOptionsDict) - resultDict[site][ce].setdefault('VMTypes', {}) - result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s/VMTypes/%s' % (grid, site, ce, vmType)) - if not result['OK']: - result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s/Images/%s' % (grid, site, ce, vmType)) - if not result['OK']: - continue - vmTypeOptionsDict = result['Value'] - resultDict[site][ce]['VMTypes'][vmType] = vmTypeOptionsDict - return S_OK(resultDict) +def getProviderInfo(provider): + """ Get provider info + :param str provider: provider -def getVMTypeConfig(site, ce='', vmtype=''): - """ Get the VM image type parameters of the specified queue + :return: S_OK(dict)/S_ERROR() """ - tags = [] - grid = site.split('.')[0] - if not ce: - result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud' % (grid, site)) - if not result['OK']: - return result - ceList = result['Value'] - if len(ceList) == 1: - ce = ceList[0] - else: - return S_ERROR('No cloud endpoint specified') - - result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s' % (grid, site, ce)) + result = gConfig.getSections(gBaseResourcesSection) if not result['OK']: return result resultDict = result['Value'] diff --git a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py index 20e63aac9f7..9a38fe9c745 100755 --- a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py +++ b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py @@ -131,8 +131,12 @@ def getExecutorSection(system, executorName=None, component=False, setup=False): return getComponentSection(system, component=executorName, setup=setup, componentCategory="Executors") -def getDatabaseSection(system, dbName=False, setup=False): - """ Get DB section in a system +def getAPISection(APIName, APITuple=False, setup=False): + return getComponentSection(APIName, APITuple, setup, "APIs") + + +def getServiceSection(serviceName, serviceTuple=False, setup=False): + return getComponentSection(serviceName, serviceTuple, setup, "Services") :param str system: system name :param str dbName: DB name diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index d4b9cef6bab..7e285c77cb6 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -536,16 +536,6 @@ def getElasticDBParameters(fullname): return S_OK(parameters) -def getOAuthAPI(instance='Production'): - """ Get OAuth API url - - :param str instance: instance - - :return: str - """ - return gConfig.getValue("/Systems/Framework/%s/URLs/OAuthAPI" % instance) - - def getDIRACGOCDictionary(): """ Create a dictionary containing DIRAC site names and GOCDB site names @@ -578,3 +568,46 @@ def getDIRACGOCDictionary(): log.debug('End function.') return S_OK(dictionary) + + +def getAuthAPI(): + """ Get Auth REST API url + + :return: str + """ + return gConfig.getValue("/Systems/Framework/%s/URLs/AuthAPI" % getSystemInstance("Framework")) + + +def getAuthorisationServerMetadata(issuer=None): + """ Get authoraisation server metadata + + :return: S_OK(dict)/S_ERROR() + """ + data = {'issuer': issuer} + + result = gConfig.getSections('/DIRAC') + if result['OK']: + if 'Authorization' in result['Value']: + result = gConfig.getOptionsDictRecursively('/DIRAC/Authorization') + if result['OK']: + data.update(result['Value']) + if not result['OK']: + return result + if not data['issuer']: + data['issuer'] = getAuthAPI() + if not data['issuer']: + return S_ERROR('No issuer found in DIRAC authorization server configuration.') + + # Search values with type list + for key, v in data.items(): + data[key] = [e for e in v.replace(', ', ',').split(',') if e] if ',' in v else v + return S_OK(data) + + +def isDownloadablePersonalProxy(): + """ Get downloadablePersonalProxy flag + + :return: S_OK(bool)/S_ERROR() + """ + cs_path = '/Systems/Framework/%s/APIs/Auth' % getSystemInstance("Framework") + return gConfig.getOption(cs_path + '/downloadablePersonalProxy') diff --git a/src/DIRAC/Core/Security/Locations.py b/src/DIRAC/Core/Security/Locations.py index fe9e575aca6..8c876caaf88 100644 --- a/src/DIRAC/Core/Security/Locations.py +++ b/src/DIRAC/Core/Security/Locations.py @@ -12,6 +12,23 @@ g_SecurityConfPath = "/DIRAC/Security" +def getTokenLocation(): + """ Get the path of the currently active access token file + """ + envVar = 'DIRAC_TOKEN_FILE' + if envVar in os.environ: + tokenPath = os.path.realpath(os.environ[envVar]) + if os.path.isfile(tokenPath): + return tokenPath + # /tmp/JWTup_u + tokenName = "JWTup_u%d" % os.getuid() + if os.path.isfile("/tmp/%s" % tokenName): + return "/tmp/%s" % tokenName + + # No access token found + return False + + def getProxyLocation(): """ Get the path of the currently active grid proxy file """ diff --git a/src/DIRAC/Core/Security/TokenFile.py b/src/DIRAC/Core/Security/TokenFile.py new file mode 100644 index 00000000000..2001ea1a5a7 --- /dev/null +++ b/src/DIRAC/Core/Security/TokenFile.py @@ -0,0 +1,99 @@ +""" Collection of utilities for dealing with security files (i.e. token files) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import os +import json +import stat +import tempfile + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Utilities import DErrno +from DIRAC.Core.Security.Locations import getTokenLocation + + +def readTokenFromFile(fileName=None): + """ Read token from a file + + :param str fileName: filename to read + + :return: S_OK(dict)/S_ERROR() + """ + if not fileName: + fileName = getTokenLocation() or os.environ.get('DIRAC_TOKEN_FILE', "/tmp/JWTup_u%d" % os.getuid()) + try: + with open(fileName, 'r') as f: + data = f.read() + return S_OK(json.loads(data)) + except Exception as e: + return S_ERROR('Cannot read token.') + + +def writeToTokenFile(tokenContents, fileName=False): + """ Write a token string to file + + :param str tokenContents: token as string + :param str fileName: filename to dump to + + :return: S_OK(str)/S_ERROR() + """ + if not fileName: + try: + fd, tokenLocation = tempfile.mkstemp() + os.close(fd) + except IOError: + return S_ERROR(DErrno.ECTMPF) + fileName = tokenLocation + try: + with open(fileName, 'wb') as fd: + fd.write(tokenContents) + except Exception as e: + return S_ERROR(DErrno.EWF, " %s: %s" % (fileName, repr(e).replace(',)', ')'))) + try: + os.chmod(fileName, stat.S_IRUSR | stat.S_IWUSR) + except Exception as e: + return S_ERROR(DErrno.ESPF, "%s: %s" % (fileName, repr(e).replace(',)', ')'))) + return S_OK(fileName) + + +def writeTokenDictToTokenFile(tokenDict, fileName=None): + """ Write a token dict to file + + :param dict tokenDict: dict object to dump to file + :param str fileName: filename to dump to + + :return: S_OK(str)/S_ERROR() + """ + if not fileName: + fileName = getTokenLocation() or os.environ.get('DIRAC_TOKEN_FILE', "/tmp/JWTup_u%d" % os.getuid()) + try: + retVal = json.dumps(tokenDict) + except Exception as e: + return S_ERROR('Cannot read token.') + return writeToTokenFile(retVal, fileName) + + +def writeTokenDictToTemporaryFile(tokenDict): + """ Write a token dict to a temporary file + + :param dict tokenDict: dict object to dump to file + + :return: S_OK(str)/S_ERROR() -- contain file name + """ + try: + fd, tokenLocation = tempfile.mkstemp() + os.close(fd) + except IOError: + return S_ERROR(DErrno.ECTMPF) + retVal = writeTokenDictToTokenFile(tokenDict, tokenLocation) + if not retVal['OK']: + try: + os.unlink(tokenLocation) + except Exception: + pass + return retVal + return S_OK(tokenLocation) diff --git a/src/DIRAC/Core/Security/TokenInfo.py b/src/DIRAC/Core/Security/TokenInfo.py new file mode 100644 index 00000000000..39efe2e5f23 --- /dev/null +++ b/src/DIRAC/Core/Security/TokenInfo.py @@ -0,0 +1,76 @@ +""" + Set of utilities to retrieve Information from proxy +""" +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function + +import jwt as _jwt +import six +import time + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Utilities import DErrno +from DIRAC.Core.Security import Locations + +from DIRAC.Core.Security.TokenFile import readTokenFromFile +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +__RCSID__ = "$Id$" + + +def getTokenInfo(token=False): + """ Return token info + + :param token: token location or token as dict + + :return: S_OK(dict)/S_ERROR() + """ + # Discover token location + if isinstance(token, dict): + token = OAuth2Token(token) + else: + tokenLocation = token if isinstance(token, six.string_types) else Locations.getTokenLocation() + if not tokenLocation: + return S_ERROR("Cannot find token location.") + result = readTokenFromFile() + if not result['OK']: + return result + token = OAuth2Token(result['Value']) + + payload = _jwt.decode(token['access_token'], options=dict(verify_signature=False)) + result = Registry.getUsernameForDN('/O=DIRAC/CN=%s' % payload['sub']) + if not result['OK']: + return result + payload['username'] = result['Value'] + if payload.get('group'): + payload['properties'] = Registry.getPropertiesForGroup(payload['group']) + return S_OK(payload) + + +def formatTokenInfoAsString(infoDict): + """ Convert a token infoDict into a string + + :param dict infoDict: info + + :return: str + """ + secs = int(infoDict['exp']) - time.time() + hours = int(secs / 3600) + secs -= hours * 3600 + mins = int(secs / 60) + secs -= mins * 60 + exp = "%02d:%02d:%02d" % (hours, mins, secs) + + leftAlign = 13 + contentList = [] + contentList.append('%s: %s' % ('subject'.ljust(leftAlign), infoDict['sub'])) + contentList.append('%s: %s' % ('issuer'.ljust(leftAlign), infoDict['iss'])) + contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), exp)) + contentList.append('%s: %s' % ('username'.ljust(leftAlign), infoDict['username'])) + if infoDict.get('group'): + contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) + if infoDict.get('properties'): + contentList.append('%s: %s' % ('properties'.ljust(leftAlign), ', '.join(infoDict['properties']))) + return "\n".join(contentList) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 5f2fb1b7a95..f84d0d05fd0 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -34,6 +34,7 @@ import errno import requests import six +import os from six.moves import http_client @@ -46,7 +47,8 @@ from DIRAC.Core.DISET.ThreadConfig import ThreadConfig from DIRAC.Core.Security import Locations -from DIRAC.Core.Utilities import Network +from DIRAC.Core.Security.TokenFile import readTokenFromFile +from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode @@ -62,6 +64,7 @@ class TornadoBaseClient(object): __threadConfig = ThreadConfig() VAL_EXTRA_CREDENTIALS_HOST = "hosts" + KW_USE_ACCESS_TOKEN = "useAccessToken" KW_USE_CERTIFICATES = "useCertificates" KW_EXTRA_CREDENTIALS = "extraCredentials" KW_TIMEOUT = "timeout" @@ -103,6 +106,7 @@ def __init__(self, serviceName, **kwargs): self.__ca_location = False self.kwargs = kwargs + self.__useAccessToken = None self.__useCertificates = None # The CS useServerCertificate option can be overridden by explicit argument self.__forceUseCertificates = self.kwargs.get(self.KW_USE_CERTIFICATES) @@ -219,6 +223,15 @@ def __discoverCredentialsToUse(self): else: self.kwargs[self.KW_SKIP_CA_CHECK] = skipCACheck() + # Use tokens? + if self.KW_USE_ACCESS_TOKEN in self.kwargs: + self.__useAccessToken = self.kwargs[self.KW_USE_ACCESS_TOKEN] + else: + if not gConfig.useServerCertificate(): + self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") + if os.environ.get('DIRAC_USE_ACCESS_TOKEN'): + self.__useAccessToken = os.environ['DIRAC_USE_ACCESS_TOKEN'] + # Rewrite a little bit from here: don't need the proxy string, we use the file if self.KW_PROXY_CHAIN in self.kwargs: try: @@ -490,12 +503,20 @@ def _request(self, retry=0, outputFile=None, **kwargs): # getting certificate # Do we use the server certificate ? if self.kwargs[self.KW_USE_CERTIFICATES]: - cert = Locations.getHostCertificateAndKeyLocation() + auth = {'cert': Locations.getHostCertificateAndKeyLocation()} + + # Use access token? + elif self.__useAccessToken: + result = readTokenFromFile() + if not result['OK']: + return result + auth = {'headers': {"Authorization": "Bearer %s" % result['Value']['access_token']}} + # CHRIS 04.02.21 # TODO: add proxyLocation check ? else: - cert = Locations.getProxyLocation() - if not cert: + auth = {'cert': Locations.getProxyLocation()} + if not auth['cert']: gLogger.error("No proxy found") return S_ERROR("No proxy found") @@ -510,9 +531,8 @@ def _request(self, retry=0, outputFile=None, **kwargs): # Default case, just return the result if not outputFile: - call = requests.post(url, data=kwargs, - timeout=self.timeout, verify=verify, - cert=cert) + call = requests.post(url, data=kwargs, timeout=self.timeout, verify=verify, + **auth) # raising the exception for status here # means essentialy that we are losing here the information of what is returned by the server # as error message, since it is not passed to the exception @@ -532,7 +552,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): # Stream download # https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow with requests.post(url, data=kwargs, timeout=self.timeout, verify=verify, - cert=cert, stream=True) as r: + stream=True, **auth) as r: rawText = r.text r.raise_for_status() diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py new file mode 100644 index 00000000000..4ed160480d6 --- /dev/null +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -0,0 +1,790 @@ +""" BaseRequestHandler is the base class for tornados services and etc handlers. + It directly inherits from :py:class:`tornado.web.RequestHandler` +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from io import open + +import jwt as _jwt + +import os +import time +import pprint +import requests +import threading +from datetime import datetime +from six import string_types +from six.moves import http_client +from six.moves.urllib.parse import unquote +from authlib.jose import JsonWebKey, jwt +from authlib.oauth2.rfc6749.util import scope_to_list + + +import tornado +from tornado import gen +from tornado.web import RequestHandler, HTTPError +from tornado.ioloop import IOLoop +from tornado.concurrent import Future + +import DIRAC + +from DIRAC import gConfig, gLogger, S_OK, S_ERROR +from DIRAC.Core.DISET.AuthManager import AuthManager +from DIRAC.Core.Utilities.JEncode import decode, encode +from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error +from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +sLog = gLogger.getSubLogger(__name__.split('.')[-1]) + + +class BaseRequestHandler(RequestHandler): + """ + Base class for all the Handlers. + It directly inherits from :py:class:`tornado.web.RequestHandler` + + Each HTTP request is served by a new instance of this class. + + For the sequence of method called, please refer to + the `tornado documentation `_. + + For compatibility with the existing :py:class:`DIRAC.Core.DISET.TransferClient.TransferClient`, + the handler can define a method ``export_streamToClient``. This is the method that will be called + whenever ``TransferClient.receiveFile`` is called. It is the equivalent of the DISET + ``transfer_toClient``. + Note that this is here only for compatibility, and we discourage using it for new purposes, as it is + bound to disappear. + + The handler only define the ``post`` verb. Please refer to :py:meth:`.post` for the details. + + """ + # Because we initialize at first request, we use a flag to know if it's already done + __init_done = False + # Lock to make sure that two threads are not initializing at the same time + __init_lock = threading.RLock() + + # MonitoringClient, we don't use gMonitor which is not thread-safe + # We also need to add specific attributes for each service + _monitor = None + + # System name with which this component is associated + SYSTEM = None + + # Auth requirements + AUTH_PROPS = None + + # Type of component + MONITORING_COMPONENT = MonitoringClient.COMPONENT_WEB + + # Prefix of methods names + METHOD_PREFIX = "export_" + + # Which grant type to use + USE_AUTHZ_GRANTS = ['SSL', 'JWT'] + + @classmethod + def _initMonitoring(cls, serviceName, fullUrl): + """ + Initialize the monitoring specific to this handler + This has to be called only by :py:meth:`.__initializeService` + to ensure thread safety and unicity of the call. + + :param serviceName: relative URL ``//`` + :param fullUrl: full URl like ``https://://`` + """ + + # Init extra bits of monitoring + + cls._monitor = MonitoringClient() + cls._monitor.setComponentType(cls.MONITORING_COMPONENT) + + cls._monitor.initialize() + + if tornado.process.task_id() is None: # Single process mode + cls._monitor.setComponentName('Tornado/%s' % serviceName) + else: + cls._monitor.setComponentName('Tornado/CPU%d/%s' % (tornado.process.task_id(), serviceName)) + + cls._monitor.setComponentLocation(fullUrl) + + cls._monitor.registerActivity("Queries", "Queries served", "Framework", "queries", MonitoringClient.OP_RATE) + + cls._monitor.setComponentExtraParam('DIRACVersion', DIRAC.version) + cls._monitor.setComponentExtraParam('platform', DIRAC.getPlatform()) + cls._monitor.setComponentExtraParam('startTime', datetime.utcnow()) + + cls._stats = {'requests': 0, 'monitorLastStatsUpdate': time.time()} + + return S_OK() + + @classmethod + def _getServiceName(cls, request): + """ Search service name in request. + + :param object request: tornado Request + + :return: str + """ + raise NotImplementedError('Please, create the _getServiceName class method') + + @classmethod + def _getServiceAuthSection(cls, serviceName): + """ Search service auth section. + + :param str serviceName: service name + + :return: str + """ + return "%s/Authorization" % PathFinder.getServiceSection(serviceName) + + @classmethod + def _getServiceInfo(cls, serviceName, request): + """ Fill service information. + + :param str serviceName: service name + :param object request: tornado Request + + :return: dict + """ + return {} + + @classmethod + def __initializeService(cls, request): + """ + Initialize a service. + The work is only perform once at the first request. + + :param object request: tornado Request + + :returns: S_OK + """ + # If the initialization was already done successfuly, + # we can just return + if cls.__init_done: + return S_OK() + + # Otherwise, do the work but with a lock + with cls.__init_lock: + + # Check again that the initialization was not done by another thread + # while we were waiting for the lock + if cls.__init_done: + return S_OK() + + cls._idps = IdProviderFactory() + + # absoluteUrl: full URL e.g. ``https://://`` + absoluteUrl = request.path + serviceName = cls._getServiceName(request) + + cls._startTime = datetime.utcnow() + sLog.info("First use of %s, initializing service..." % serviceName) + cls._authManager = AuthManager(cls._getServiceAuthSection(serviceName)) + + cls._initMonitoring(serviceName, absoluteUrl) + + cls._serviceName = serviceName + cls._validNames = [serviceName] + serviceInfo = cls._getServiceInfo(serviceName, request) + + cls._serviceInfoDict = serviceInfo + + cls.__monitorLastStatsUpdate = time.time() + + # Some pre-initialization + cls._initializeHandler() + + cls.initializeHandler(serviceInfo) + + cls.__init_done = True + + return S_OK() + + @classmethod + def _initializeHandler(cls): + """ + If you are writing your own framework that follows this class + and you need to add something before initializing the service, + such as initializing the OAuth client, then you need to change this method. + """ + pass + + @classmethod + def initializeHandler(cls, serviceInfo): + """ + This may be overwritten when you write a DIRAC service handler + And it must be a class method. This method is called only one time, + at the first request + + :param dict ServiceInfoDict: infos about services, it contains + 'serviceName', 'serviceSectionPath', + 'csPaths' and 'URL' + """ + pass + + def initializeRequest(self): + """ + Called at every request, may be overwritten in your handler. + """ + pass + + # This is a Tornado magic method + def initialize(self): # pylint: disable=arguments-differ + """ + Initialize the handler, called at every request. + + It just calls :py:meth:`.__initializeService` + + If anything goes wrong, the client will get ``Connection aborted`` + error. See details inside the method. + + ..warning:: + DO NOT REWRITE THIS FUNCTION IN YOUR HANDLER + ==> initialize in DISET became initializeRequest in HTTPS ! + """ + # Only initialized once + if not self.__init_done: + # Ideally, if something goes wrong, we would like to return a Server Error 500 + # but this method cannot write back to the client as per the + # `tornado doc `_. + # So the client will get a ``Connection aborted``` + try: + res = self.__initializeService(self.request) + if not res['OK']: + raise Exception(res['Message']) + except Exception as e: + sLog.error("Error in initialization", repr(e)) + raise + + def _monitorRequest(self): + """ Monitor action for each request + """ + self._stats['requests'] += 1 + self._monitor.setComponentExtraParam('queries', self._stats['requests']) + self._monitor.addMark("Queries") + + def _getMethodName(self): + """ Parse method name. + + :return: str + """ + raise NotImplementedError('Please, create the _getMethodName method') + + def _getMethodArgs(self, args): + """ Decode args. + + :return: list + """ + return args + + def _getMethodAuthProps(self): + """ Resolves the hard coded authorization requirements for method. + + :return: object + """ + try: + return getattr(self, 'auth_' + self.method) + except AttributeError: + if self.AUTH_PROPS and not isinstance(self.AUTH_PROPS, (list, tuple)): + self.AUTH_PROPS = [p.strip() for p in self.AUTH_PROPS.split(",") if p.strip()] + return self.AUTH_PROPS + + def _getMethod(self): + """ Get method object. + + :return: object + """ + try: + return getattr(self, '%s%s' % (self.METHOD_PREFIX, self.method)) + except AttributeError as e: + sLog.error("Invalid method", self.method) + raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) + + def prepare(self): + """ + Tornados prepare method that called before request + """ + + # "method" argument of the POST call. + # This resolves into the ``export_`` method + # on the handler side + # If the argument is not available, the method exists + # and an error 400 ``Bad Request`` is returned to the client + self.method = self._getMethodName() + + self._monitorRequest() + self._prepare() + + def _prepare(self): + """ + Prepare the request. It reads certificates and check authorizations. + We make the assumption that there is always going to be a ``method`` argument + regardless of the HTTP method used + + """ + try: + self.credDict = self._gatherPeerCredentials() + except Exception as e: # pylint: disable=broad-except + # If an error occur when reading certificates we close connection + # It can be strange but the RFC, for HTTP, say's that when error happend + # before authentication we return 401 UNAUTHORIZED instead of 403 FORBIDDEN + sLog.debug(str(e)) + sLog.error( + "Error gathering credentials ", "%s; path %s" % + (self.getRemoteAddress(), self.request.path)) + raise HTTPError(status_code=http_client.UNAUTHORIZED) + + # Check whether we are authorized to perform the query + # Note that performing the authQuery modifies the credDict... + authorized = self._authManager.authQuery(self.method, self.credDict, + self._getMethodAuthProps()) + if not authorized: + extraInfo = '' + if self.credDict.get('DN'): + extraInfo += 'DN: %s' % self.credDict['DN'] + if self.credDict.get('ID'): + extraInfo += 'ID: %s' % self.credDict['ID'] + sLog.error( + "Unauthorized access", "Identity %s; path %s; %s" % + (self.srv_getFormattedRemoteCredentials(), + self.request.path, extraInfo)) + raise HTTPError(status_code=http_client.UNAUTHORIZED) + + # Make post a coroutine. + # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines + # for details + @gen.coroutine + def post(self, *args, **kwargs): # pylint: disable=arguments-differ + """ + Method to handle incoming ``POST`` requests. + Note that all the arguments are already prepared in the :py:meth:`.prepare` + method. + + The ``POST`` arguments expected are: + + * ``method``: name of the method to call + * ``args``: JSON encoded arguments for the method + * ``extraCredentials``: (optional) Extra informations to authenticate client + * ``rawContent``: (optionnal, default False) If set to True, return the raw output + of the method called. + + If ``rawContent`` was requested by the client, the ``Content-Type`` + is ``application/octet-stream``, otherwise we set it to ``application/json`` + and JEncode retVal. + + If ``retVal`` is a dictionary that contains a ``Callstack`` item, + it is removed, not to leak internal information. + + + Example of call using ``requests``:: + + In [20]: url = 'https://server:8443/DataManagement/TornadoFileCatalog' + ...: cert = '/tmp/x509up_u1000' + ...: kwargs = {'method':'whoami'} + ...: caPath = '/home/dirac/ClientInstallDIR/etc/grid-security/certificates/' + ...: with requests.post(url, data=kwargs, cert=cert, verify=caPath) as r: + ...: print r.json() + ...: + {u'OK': True, + u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'group': u'dirac_user', + u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'isLimitedProxy': False, + u'isProxy': True, + u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'properties': [u'NormalUser'], + u'secondsLeft': 85441, + u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch/CN=2409820262', + u'username': u'adminusername', + u'validDN': False, + u'validGroup': False}} + """ + # Execute the method in an executor (basically a separate thread) + # Because of that, we cannot calls certain methods like `self.write` + # in _executeMethod. This is because these methods are not threadsafe + # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + # However, we can still rely on instance attributes to store what should + # be sent back (reminder: there is an instance + # of this class created for each request) + retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + + # retVal is :py:class:`tornado.concurrent.Future` + self._finishFuture(retVal) + + @gen.coroutine + def _executeMethod(self, args): + """ + Execute the method called, this method is ran in an executor + We have several try except to catch the different problem which can occur + + - First, the method does not exist => Attribute error, return an error to client + - second, anything happend during execution => General Exception, send error to client + + .. warning:: + This method is called in an executor, and so cannot use methods like self.write + See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + """ + + sLog.notice( + "Incoming request %s /%s: %s" % + (self.srv_getFormattedRemoteCredentials(), + self._serviceName, + self.method)) + + # getting method + method = self._getMethod() + methodArgs = self._getMethodArgs(args) + + # Execute + try: + self.initializeRequest() + retVal = method(*methodArgs) + except Exception as e: # pylint: disable=broad-except + sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) + raise HTTPError(http_client.INTERNAL_SERVER_ERROR) + + return retVal + + def _finishFuture(self, retVal): + """ Handler Future result + + :param object retVal: tornado.concurrent.Future + """ + + # Wait result only if it's a Future object + self.result = retVal.result() if isinstance(retVal, Future) else retVal + + print('FUTURE RESULT >>>') + print(self.result) + print(self.get_status()) + + # Here it is safe to write back to the client, because we are not + # in a thread anymore + + # Is it S_OK or S_ERROR + if isinstance(self.result, dict) and isinstance(self.result.get('OK'), bool) and ('Value' if self.result['OK'] else 'Message') in self.result: + self._parseDIRACResult(self.result) + + # If set to true, do not JEncode the return of the RPC call + # This is basically only used for file download through + # the 'streamToClient' method. + elif self.get_argument('rawContent', default=False): + # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt + self.set_header("Content-Type", "application/octet-stream") + self.write(self.result) + + # Return simple text or html + elif isinstance(self.result, string_types): + self.write(self.result) + + # JSON + elif isinstance(self.result, dict): + self.set_header("Content-Type", "application/json") + self.write(encode(self.result)) + + self.finish() + + def _parseDIRACResult(self, result): + """ Processing of a standard DIRAC result, + but in a separate method so that it can be modified for another class if necessary + """ + self.set_header("Content-Type", "application/json") + self.write(encode(result)) + + def on_finish(self): + """ + Called after the end of HTTP request. + Log the request duration + """ + elapsedTime = 1000.0 * self.request.request_time() + + argsString = "OK" + try: + if not self.result['OK']: + argsString = "ERROR: %s" % self.result['Message'] + except (AttributeError, KeyError, TypeError): # In case it is not a DIRAC structure + if self._reason != 'OK': + argsString = 'ERROR %s' % self._reason + + sLog.notice("Returning response", "%s %s (%.2f ms) %s" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, + elapsedTime, argsString)) + + def _gatherPeerCredentials(self, grants=None): + """ Returne a dictionary designed to work with the AuthManager, + already written for DISET and re-used for HTTPS. + + :param list grants: grants to use + + :returns: a dict containing the return of :py:meth:`DIRAC.Core.Security.X509Chain.X509Chain.getCredentials` + (not a DIRAC structure !) + """ + err = [] + result = None + + grants = grants or self.USE_AUTHZ_GRANTS + + if not grants: + raise Exception('USE_AUTHZ_GRANTS is not defined.') + + for a in grants: + grant = a.upper() + grantFunc = eval('self._authz%s' % grant) + if not callable(grantFunc): + raise Exception('%s authentication type is not supported.' % grant) + result = grantFunc() + if result['OK']: + for e in err: + sLog.debug(e) + sLog.debug('%s authentication success.' % grant) + return result['Value'] + err.append('%s authentication: %s' % (grant, result['Message'])) + + # Report on failed authentication attempts + raise Exception('; '.join(err)) + + def _authzSSL(self): + """ Load client certchain in DIRAC and extract informations. + + :return: S_OK(dict)/S_ERROR() + """ + peerChain = X509Chain() + derCert = self.request.get_ssl_certificate() + + # Get client certificate pem + if derCert: + chainAsText = derCert.as_pem() + # Here we read all certificate chain + cert_chain = self.request.get_ssl_certificate_chain() + for cert in cert_chain: + chainAsText = cert.as_pem() + elif self.request.headers.get('X-Ssl_client_verify') == 'SUCCESS': + chainAsTextEncoded = self.request.headers.get('X-SSL-CERT') + chainAsText = unquote(chainAsTextEncoded) + else: + return S_ERROR('Not found a valide client certificate.') + + peerChain.loadChainFromString(chainAsText) + + # Retrieve the credentials + res = peerChain.getCredentials(withRegistryInfo=False) + if not res['OK']: + return res + + credDict = res['Value'] + + # We check if client sends extra credentials... + if "extraCredentials" in self.request.arguments: + extraCred = self.get_argument("extraCredentials") + if extraCred: + credDict['extraCredentials'] = decode(extraCred)[0] + return S_OK(credDict) + + def _authzJWT(self): + """ Load token claims in DIRAC and extract informations. + + :return: S_OK(dict)/S_ERROR() + """ + # Export token from headers + token = self.request.headers.get('Authorization') + if not token or len(token.split()) != 2: + return S_ERROR('Not found a bearer access token.') + tokenType, accessToken = token.split() + if tokenType.lower() != 'bearer': + return S_ERROR('Found a not bearer access token.') + + result = self._idps.getIdProviderForToken(accessToken) + if not result['OK']: + return result + cli = result['Value'] + payload = cli.verifyToken(accessToken) + credDict = cli.researchGroup(payload, accessToken) + + return S_OK(credDict) + + def _authzVISITOR(self): + """ Visitor access + + :return: S_OK(dict) + """ + return S_OK({}) + + @property + def log(self): + return sLog + + def getDN(self): + return self.credDict.get('DN', '') + + def getUserName(self): + return self.credDict.get('username', '') + + def getUserGroup(self): + return self.credDict.get('group', '') + + def getProperties(self): + return self.credDict.get('properties', []) + + def isRegisteredUser(self): + return self.credDict.get('username', 'anonymous') != 'anonymous' and self.credDict.get('group') + + auth_ping = ['all'] + + def export_ping(self): + """ + Default ping method, returns some info about server. + + It returns the exact same information as DISET, for transparency purpose. + """ + # COPY FROM DIRAC.Core.DISET.RequestHandler + dInfo = {} + dInfo['version'] = DIRAC.version + dInfo['time'] = datetime.utcnow() + # Uptime + try: + with open("/proc/uptime", 'rt') as oFD: + iUptime = int(float(oFD.readline().split()[0].strip())) + dInfo['host uptime'] = iUptime + except Exception: # pylint: disable=broad-except + pass + startTime = self._startTime + dInfo['service start time'] = self._startTime + serviceUptime = datetime.utcnow() - startTime + dInfo['service uptime'] = serviceUptime.days * 3600 + serviceUptime.seconds + # Load average + try: + with open("/proc/loadavg", 'rt') as oFD: + dInfo['load'] = " ".join(oFD.read().split()[:3]) + except Exception: # pylint: disable=broad-except + pass + dInfo['name'] = self._serviceInfoDict['serviceName'] + stTimes = os.times() + dInfo['cpu times'] = {'user time': stTimes[0], + 'system time': stTimes[1], + 'children user time': stTimes[2], + 'children system time': stTimes[3], + 'elapsed real time': stTimes[4] + } + + return S_OK(dInfo) + + auth_echo = ['all'] + + @staticmethod + def export_echo(data): + """ + This method used for testing the performance of a service + """ + return S_OK(data) + + auth_whoami = ['authenticated'] + + def export_whoami(self): + """ + A simple whoami, returns all credential dictionary, except certificate chain object. + """ + credDict = self.srv_getRemoteCredentials() + if 'x509Chain' in credDict: + # Not serializable + del credDict['x509Chain'] + return S_OK(credDict) + + @classmethod + def srv_getCSOption(cls, optionName, defaultValue=False): + """ + Get an option from the CS section of the services + + :return: Value for serviceSection/optionName in the CS being defaultValue the default + """ + if optionName[0] == "/": + return gConfig.getValue(optionName, defaultValue) + for csPath in cls._serviceInfoDict['csPaths']: + result = gConfig.getOption("%s/%s" % (csPath, optionName, ), defaultValue) + if result['OK']: + return result['Value'] + return defaultValue + + def getCSOption(self, optionName, defaultValue=False): + """ + Just for keeping same public interface + """ + return self.srv_getCSOption(optionName, defaultValue) + + def srv_getRemoteAddress(self): + """ + Get the address of the remote peer. + + :return: Address of remote peer. + """ + + remote_ip = self.request.remote_ip + # Although it would be trivial to add this attribute in _HTTPRequestContext, + # Tornado won't release anymore 5.1 series, so go the hacky way + try: + remote_port = self.request.connection.stream.socket.getpeername()[1] + except Exception: # pylint: disable=broad-except + remote_port = 0 + + return (remote_ip, remote_port) + + def getRemoteAddress(self): + """ + Just for keeping same public interface + """ + return self.srv_getRemoteAddress() + + def srv_getRemoteCredentials(self): + """ + Get the credentials of the remote peer. + + :return: Credentials dictionary of remote peer. + """ + return self.credDict + + def getRemoteCredentials(self): + """ + Get the credentials of the remote peer. + + :return: Credentials dictionary of remote peer. + """ + return self.credDict + + def srv_getFormattedRemoteCredentials(self): + """ + Return the DN of user + + Mostly copy paste from + :py:meth:`DIRAC.Core.DISET.private.Transports.BaseTransport.BaseTransport.getFormattedCredentials` + + Note that the information will be complete only once the AuthManager was called + """ + address = self.getRemoteAddress() + peerId = "" + # Depending on where this is call, it may be that credDict is not yet filled. + # (reminder: AuthQuery fills part of it..) + try: + peerId = "[%s:%s]" % (self.credDict.get('group', 'visitor'), self.credDict.get('username', 'anonymous')) + except AttributeError: + pass + + if address[0].find(":") > -1: + return "([%s]:%s)%s" % (address[0], address[1], peerId) + return "(%s:%s)%s" % (address[0], address[1], peerId) + + def srv_getServiceName(self): + """ + Return the service name + """ + return self._serviceInfoDict['serviceName'] + + def srv_getURL(self): + """ + Return the URL + """ + return self.request.path diff --git a/src/DIRAC/Core/Tornado/Server/HandlerManager.py b/src/DIRAC/Core/Tornado/Server/HandlerManager.py index 1f4397f63e9..f51dadc9ddc 100644 --- a/src/DIRAC/Core/Tornado/Server/HandlerManager.py +++ b/src/DIRAC/Core/Tornado/Server/HandlerManager.py @@ -9,6 +9,8 @@ __RCSID__ = "$Id$" +import inspect +from six import string_types from tornado.web import url as TornadoURL, RequestHandler from DIRAC import gConfig, gLogger, S_ERROR, S_OK @@ -50,109 +52,216 @@ class HandlerManager(object): ``System/Component`` (e.g. ``DataManagement/FileCatalog``) """ - def __init__(self, autoDiscovery=True): + def __init__(self, services, endpoints): """ - Initialization function, you can set autoDiscovery=False to prevent automatic - discovery of handler. If disabled you can use loadHandlersByServiceName() to - load your handlers or loadHandlerInHandlerManager() - - :param autoDiscovery: (default True) Disable the automatic discovery, - can be used to choose service we want to load. + Initialization function, you can set False for both arguments to prevent automatic + discovery of handlers and use `loadServicesHandlers()` to + load your handlers or `loadEndpointsHandlers()` + + :param services: List of service handlers to load. + If ``True``, loads all services from CS + :type services: bool or list + :param endpoints: List of endpoint handlers to load. + If ``True``, loads all endpoints from CS + :type endpoints: bool or list """ + self.loader = None self.__handlers = {} + self.__services = services + self.__endpoints = endpoints self.__objectLoader = ObjectLoader() - self.__autoDiscovery = autoDiscovery - self.loader = ModuleLoader("Service", PathFinder.getServiceSection, RequestHandler, moduleSuffix="Handler") - def __addHandler(self, handlerTuple, url=None): + def __addHandler(self, handlerPath, handler, urls=None, port=None): """ Function which add handler to list of known handlers + :param str handlerPath: module name, e.g.: `Framework/Auth` + :param object handler: handler class + :param list urls: request path + :param int port: port - :param handlerTuple: (path, class) + :return: S_OK()/S_ERROR() """ - # Check if handler not already loaded - if not url or url not in self.__handlers: - gLogger.debug("Find new handler %s" % (handlerTuple[0])) - - # If url is not given, try to discover it - if url is None: - # FIRST TRY: Url is hardcoded - try: - url = handlerTuple[1].LOCATION - # SECOND TRY: URL can be deduced from path - except AttributeError: - gLogger.debug("No location defined for %s try to get it from path" % handlerTuple[0]) - url = urlFinder(handlerTuple[0]) - + # First of all check if we can find route + # If urls is not given, try to discover it + if urls is None: + # FIRST TRY: Url is hardcoded + try: + urls = handler.LOCATION + # SECOND TRY: URL can be deduced from path + except AttributeError: + gLogger.debug("No location defined for %s try to get it from path" % handlerPath) + urls = urlFinder(handlerPath) + + if not urls: + gLogger.warn("URL not found for %s" % (handlerPath)) + return S_ERROR("URL not found for %s" % (handlerPath)) + + for url in urls if isinstance(urls, (list, tuple)) else [urls]: # We add "/" if missing at begin, e.g. we found "Framework/Service" # URL can't be relative in Tornado if url and not url.startswith('/'): url = "/%s" % url - elif not url: - gLogger.warn("URL not found for %s" % (handlerTuple[0])) - return S_ERROR("URL not found for %s" % (handlerTuple[0])) + + # Some new handler + if handlerPath not in self.__handlers: + gLogger.debug("Add new handler %s with port %s" % (handlerPath, port)) + self.__handlers[handlerPath] = {'URLs': [], 'Port': port} + + # Check if URL already loaded + if (url, handler) in self.__handlers[handlerPath]['URLs']: + gLogger.debug("URL: %s already loaded for %s " % (url, handlerPath)) + continue # Finally add the URL to handlers - if url not in self.__handlers: - self.__handlers[url] = handlerTuple[1] - gLogger.info("New handler: %s with URL %s" % (handlerTuple[0], url)) - else: - gLogger.debug("Handler already loaded %s" % (handlerTuple[0])) + gLogger.info("Add new URL %s to %s handler" % (url, handlerPath)) + self.__handlers[handlerPath]['URLs'].append((url, handler)) + return S_OK() - def discoverHandlers(self): + def discoverHandlers(self, handlerInstance): """ Force the discovery of URL, automatic call when we try to get handlers for the first time. You can disable the automatic call with autoDiscovery=False at initialization + + :param str handlerInstance: handler instance, the name of the section in some system section e.g.:: Services, APIs + + :return: list """ - gLogger.debug("Trying to auto-discover the handlers for Tornado") + urls = [] + gLogger.debug("Trying to auto-discover the %s handlers for Tornado" % handlerInstance) # Look in config diracSystems = gConfig.getSections('/Systems') - serviceList = [] if diracSystems['OK']: for system in diracSystems['Value']: try: - instance = PathFinder.getSystemInstance(system) - services = gConfig.getSections('/Systems/%s/%s/Services' % (system, instance)) - if services['OK']: - for service in services['Value']: - newservice = ("%s/%s" % (system, service)) - - # We search in the CS all handlers which used HTTPS as protocol - isHTTPS = gConfig.getValue('/Systems/%s/%s/Services/%s/Protocol' % (system, instance, service)) - if isHTTPS and isHTTPS.lower() == 'https': - serviceList.append(newservice) + sysInstance = PathFinder.getSystemInstance(system) + result = gConfig.getSections('/Systems/%s/%s/%s' % (system, sysInstance, handlerInstance)) + if result['OK']: + for inst in result['Value']: + newInst = ("%s/%s" % (system, inst)) + + if handlerInstance == 'Services': + # We search in the CS all handlers which used HTTPS as protocol + isHTTPS = gConfig.getValue('/Systems/%s/%s/Services/%s/Protocol' % (system, sysInstance, inst)) + if isHTTPS and isHTTPS.lower() == 'https': + urls.append(newInst) + else: + port = gConfig.getValue('/Systems/%s/%s/Services/%s/Port' % (system, sysInstance, inst)) + if port: + newInst += ':%s' % port + urls.append(newInst) # On systems sometime you have things not related to services... except RuntimeError: pass - return self.loadHandlersByServiceName(serviceList) + return urls - def loadHandlersByServiceName(self, servicesNames): + def loadServicesHandlers(self, services=None): """ Load a list of handler from list of service using DIRAC moduleLoader Use :py:class:`DIRAC.Core.Base.private.ModuleLoader` - :param servicesNames: list of service, e.g. ['Framework/Hello', 'Configuration/Server'] + :param services: List of service handlers to load. Default value set at initialization + If ``True``, loads all services from CS + :type services: bool or list + + :return: S_OK()/S_ERROR() + """ + # list of services, e.g. ['Framework/Hello', 'Configuration/Server'] + if isinstance(services, string_types): + services = [services] + # list of services + self.__services = self.__services if services is None else services if services else [] + + if self.__services is True: + self.__services = self.discoverHandlers('Services') + + if self.__services: + self.loader = ModuleLoader("Service", PathFinder.getServiceSection, RequestHandler, moduleSuffix="Handler") + + # Use DIRAC system to load: search in CS if path is given and if not defined + # it search in place it should be (e.g. in DIRAC/FrameworkSystem/Service) + load = self.loader.loadModules(self.__services) + if not load['OK']: + return load + for module in self.loader.getModules().values(): + url = module['loadName'] + + # URL can be like https://domain:port/service/name or just service/name + # Here we just want the service name, for tornado + serviceTuple = url.replace('https://', '').split('/')[-2:] + url = "%s/%s" % (serviceTuple[0], serviceTuple[1]) + self.__addHandler(module['loadName'], module['classObj'], url) + return S_OK() + + def __extractPorts(self, urls): + """ Extract ports from urls + + :param list urls: urls that can contain port, .e.g:: System/Service:port + + :return: (dict, list) + """ + portMapping = {} + newURLs = [] + for _url in urls: + if ':' in _url: + urlTuple = _url.split(':') + if urlTuple[0] not in portMapping: + portMapping[urlTuple[0]] = urlTuple[1] + newURLs.append(urlTuple[0]) + else: + newURLs.append(_url) + return (portMapping, newURLs) + + def loadEndpointsHandlers(self, endpoints=None): """ + Load a list of handler from list of endpoints using DIRAC moduleLoader + Use :py:class:`DIRAC.Core.Base.private.ModuleLoader` - # Use DIRAC system to load: search in CS if path is given and if not defined - # it search in place it should be (e.g. in DIRAC/FrameworkSystem/Service) - if not isinstance(servicesNames, list): - servicesNames = [servicesNames] - - load = self.loader.loadModules(servicesNames) - if not load['OK']: - return load - for module in self.loader.getModules().values(): - url = module['loadName'] - - # URL can be like https://domain:port/service/name or just service/name - # Here we just want the service name, for tornado - serviceTuple = url.replace('https://', '').split('/')[-2:] - url = "%s/%s" % (serviceTuple[0], serviceTuple[1]) - self.__addHandler((module['loadName'], module['classObj']), url) + :param endpoints: List of endpoint handlers to load. Default value set at initialization + If ``True``, loads all endpoints from CS + :type endpoints: bool or list + + :return: S_OK()/S_ERROR() + """ + # list of endpoints, e.g. ['Framework/Proxy', ...] + if isinstance(endpoints, string_types): + endpoints = [endpoints] + # list of endpoints. If __endpoints is ``True`` then list of endpoints will dicover from CS + self.__endpoints = self.__endpoints if endpoints is None else endpoints if endpoints else [] + + if self.__endpoints is True: + self.__endpoints = self.discoverHandlers('APIs') + + if self.__endpoints: + # Extract ports + ports, self.__endpoints = self.__extractPorts(self.__endpoints) + + self.loader = ModuleLoader("API", PathFinder.getAPISection, RequestHandler, moduleSuffix="Handler") + + # Use DIRAC system to load: search in CS if path is given and if not defined + # it search in place it should be (e.g. in DIRAC/FrameworkSystem/API) + load = self.loader.loadModules(self.__endpoints) + if not load['OK']: + return load + for module in self.loader.getModules().values(): + handler = module['classObj'] + if not handler.LOCATION: + handler.LOCATION = urlFinder(module['loadName']) + urls = [] + # Look for methods that are exported + for mName, mObj in inspect.getmembers(handler): + if inspect.ismethod(mObj) and mName.find(handler.METHOD_PREFIX) == 0: + methodName = mName[len(handler.METHOD_PREFIX):] + args = getattr(handler, 'path_%s' % methodName, []) + gLogger.debug(" - Route %s/%s -> %s %s" % (handler.LOCATION, methodName, module['loadName'], mName)) + url = "%s%s" % (handler.LOCATION, '' if methodName == 'index' else ('/%s' % methodName)) + if args: + url += r'[\/]?%s' % '/'.join(args) + urls.append(url) + gLogger.debug(" * %s" % url) + self.__addHandler(module['loadName'], handler, urls, ports.get(module['modName'])) return S_OK() def getHandlersURLs(self): @@ -163,24 +272,18 @@ def getHandlersURLs(self): :returns: a list of URL (not the string with "https://..." but the tornado object) see http://www.tornadoweb.org/en/stable/web.html#tornado.web.URLSpec """ - if not self.__handlers and self.__autoDiscovery: - self.__autoDiscovery = False - self.discoverHandlers() urls = [] - for key in self.__handlers: - urls.append(TornadoURL(key, self.__handlers[key])) + for handlerData in self.__handlers.values(): + for url in handlerData['URLs']: + urls.append(TornadoURL(*url)) return urls def getHandlersDict(self): """ Return all handler dictionary - :returns: dictionary with absolute url as key ("/System/Service") - and tornado.web.url object as value + :returns: dictionary with service name as key ("System/Service") + and tornado.web.url objects as value for 'URLs' key + and port as value for 'Port' key """ - if not self.__handlers and self.__autoDiscovery: - self.__autoDiscovery = False - res = self.discoverHandlers() - if not res['OK']: - gLogger.error("Could not load handlers", res) return self.__handlers diff --git a/src/DIRAC/Core/Tornado/Server/TornadoREST.py b/src/DIRAC/Core/Tornado/Server/TornadoREST.py new file mode 100644 index 00000000000..6d0523fc517 --- /dev/null +++ b/src/DIRAC/Core/Tornado/Server/TornadoREST.py @@ -0,0 +1,80 @@ +""" +TornadoREST is the base class for your RESTful API handlers. +It directly inherits from :py:class:`tornado.web.RequestHandler` +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import tornado.ioloop +from tornado import gen +from tornado.web import HTTPError +from tornado.ioloop import IOLoop +from six.moves import http_client + +import DIRAC + +from DIRAC import gLogger +from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC.Core.Tornado.Server.BaseRequestHandler import BaseRequestHandler + +sLog = gLogger.getSubLogger(__name__) + + +class TornadoREST(BaseRequestHandler): # pylint: disable=abstract-method + USE_AUTHZ_GRANTS = ['SSL', 'JWT', 'VISITOR'] + METHOD_PREFIX = 'web_' + LOCATION = '/' + + @classmethod + def _getServiceName(cls, request): + """ Define endpoint full name + + :param object request: tornado Request + + :return: str + """ + if not cls.SYSTEM: + raise Exception("System name must be defined.") + return "/".join([cls.SYSTEM, cls.__name__]) + + @classmethod + def _getServiceAuthSection(cls, endpointName): + """ Search endpoint auth section. + + :param str endpointName: endpoint name + + :return: str + """ + return "%s/Authorization" % PathFinder.getAPISection(endpointName) + + def _getMethodName(self): + """ Parse method name. + + :return: str + """ + print(self.request.path) + method = self.request.path.replace(self.LOCATION, '', 1).strip('/').split('/')[0] + print(method) + if method and hasattr(self, ''.join([self.METHOD_PREFIX, method])): + return method + elif hasattr(self, '%sindex' % self.METHOD_PREFIX): + gLogger.warn('%s method not implemented. Use the index method to handle this.' % method) + return 'index' + else: + raise NotImplementedError('%s method not implemented. \ + You can use the index method to handle this.' % method) + + @gen.coroutine + def get(self, *args, **kwargs): # pylint: disable=arguments-differ + """ Method to handle incoming ``GET`` requests. + Logic copied from :py:func:`~DIRAC.Core.Tornado.Server.BaseRequestHandler.post`. + """ + # Execute the method in an executor (basically a separate thread) + retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + + # retVal is :py:class:`tornado.concurrent.Future` + self._finishFuture(retVal) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 63aeaf133dc..50b310498dd 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -26,11 +26,11 @@ import tornado.ioloop import DIRAC -from DIRAC import gConfig, gLogger -from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC import gConfig, gLogger, S_OK from DIRAC.Core.Security import Locations -from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager from DIRAC.Core.Utilities import MemStat +from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager +from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient sLog = gLogger.getSubLogger(__name__) @@ -61,35 +61,36 @@ class TornadoServer(object): Example 2:We want to debug service1 and service2 only, and use another port for that :: - services = ['component/service1', 'component/service2'] - serverToLaunch = TornadoServer(services=services, port=1234) + services = ['component/service1:port1', 'component/service2'] + endpoints = ['component/endpoint1:port1', 'component/endpoint2'] + serverToLaunch = TornadoServer(services=services, endpoints=endpoints, port=1234) serverToLaunch.startTornado() """ - def __init__(self, services=None, port=None): + def __init__(self, services=True, endpoints=False, port=None): + """ C'r + + :param list services: (default True) List of service handlers to load. + If ``True``, loads all described in the CS + If ``False``, do not load services + :param list endpoints: (default False) List of endpoint handlers to load. + If ``True``, loads all described in the CS + If ``False``, do not load endpoints + :param int port: Port to listen to. + If ``None``, the port is resolved following the logic described in the class documentation """ - - :param list services: (default None) List of service handlers to load. If ``None``, loads all - :param int port: Port to listen to. If None, the port is resolved following the logic - described in the class documentation - """ - + # Application metadata, routes and settings mapping on the ports + self.__appsSettings = {} + # Default port, if enother is not discover if port is None: port = gConfig.getValue("/Systems/Tornado/%s/Port" % PathFinder.getSystemInstance('Tornado'), 8443) - - if services and not isinstance(services, list): - services = [services] - - # URLs for services. - # Contains Tornado :py:class:`tornado.web.url` object - self.urls = [] - # Other infos self.port = port - self.handlerManager = HandlerManager() - # Monitoring attributes + # Handler manager initialization with default settings + self.handlerManager = HandlerManager(services, endpoints) + # Monitoring attributes self._monitor = MonitoringClient() # temp value for computation, used by the monitoring self.__report = None @@ -98,20 +99,69 @@ def __init__(self, services=None, port=None): self.__monitoringLoopDelay = 60 # In secs # If services are defined, load only these ones (useful for debug purpose or specific services) - if services: - retVal = self.handlerManager.loadHandlersByServiceName(services) - if not retVal['OK']: - sLog.error(retVal['Message']) - raise ImportError("Some services can't be loaded, check the service names and configuration.") - + retVal = self.handlerManager.loadServicesHandlers() + if not retVal['OK']: + sLog.error(retVal['Message']) + raise ImportError("Some services can't be loaded, check the service names and configuration.") + + retVal = self.handlerManager.loadEndpointsHandlers() + if not retVal['OK']: + sLog.error(retVal['Message']) + raise ImportError("Some endpoints can't be loaded, check the endpoint names and configuration.") + + def __calculateAppSettings(self): + """ Calculate application information mapping on the ports + """ # if no service list is given, load services from configuration handlerDict = self.handlerManager.getHandlersDict() - for item in handlerDict.items(): - # handlerDict[key].initializeService(key) - self.urls.append(url(item[0], item[1])) - # If there is no services loaded: - if not self.urls: - raise ImportError("There is no services loaded, please check your configuration") + for data in handlerDict.values(): + port = data.get('Port') or self.port + for hURL in data['URLs']: + if port not in self.__appsSettings: + self.__appsSettings[port] = {'routes': [], 'settings': {}} + if hURL not in self.__appsSettings[port]['routes']: + self.__appsSettings[port]['routes'].append(hURL) + return bool(self.__appsSettings) + + def loadServices(self, services): + """ Load a services + + :param services: List of service handlers to load. Default value set at initialization + If ``True``, loads all services from CS + :type services: bool or list + + :return: S_OK()/S_ERROR() + """ + return self.handlerManager.loadServicesHandlers(services) + + def loadEndpoints(self, endpoints): + """ Load a endpoints + + :param endpoints: List of service handlers to load. Default value set at initialization + If ``True``, loads all endpoints from CS + :type endpoints: bool or list + + :return: S_OK()/S_ERROR() + """ + return self.handlerManager.loadEndpointsHandlers(endpoints) + + def addHandlers(self, routes, settings=None, port=None): + """ Add new routes + + :param list routes: routes + :param dict settings: application settings + :param int port: port + """ + port = port or self.port + if port not in self.__appsSettings: + self.__appsSettings[port] = {'routes': [], 'settings': {}} + if settings: + self.__appsSettings[port]['settings'].update(settings) + for route in routes: + if route not in self.__appsSettings[port]['routes']: + self.__appsSettings[port]['routes'].append(route) + + return S_OK() def startTornado(self): """ @@ -119,13 +169,13 @@ def startTornado(self): This method never returns. """ - sLog.debug("Starting Tornado") - self._initMonitoring() + # If there is no services loaded: + if not self.__calculateAppSettings(): + raise Exception("There is no services loaded, please check your configuration") - router = Application(self.urls, - debug=False, - compress_response=True) + sLog.debug("Starting Tornado") + # Prepare SSL settings certs = Locations.getHostCertificateAndKeyLocation() if certs is False: sLog.fatal("Host certificates not found ! Can't start the Server") @@ -139,29 +189,33 @@ def startTornado(self): 'sslDebug': False, # Set to true if you want to see the TLS debug messages } + # Init monitoring + self._initMonitoring() self.__monitorLastStatsUpdate = time.time() self.__report = self.__startReportToMonitoringLoop() # Starting monitoring, IOLoop waiting time in ms, __monitoringLoopDelay is defined in seconds tornado.ioloop.PeriodicCallback(self.__reportToMonitoring, self.__monitoringLoopDelay * 1000).start() - # If we are running with python3, Tornado will use asyncio, - # and we have to convince it to let us run in a different thread - # Doing this ensures a consistent behavior between py2 and py3 - if six.PY3: - import asyncio # pylint: disable=import-error - asyncio.set_event_loop_policy(tornado.platform.asyncio.AnyThreadEventLoopPolicy()) - - # Start server - server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True) - try: - server.listen(self.port) - except Exception as e: # pylint: disable=broad-except - sLog.exception("Exception starting HTTPServer", e) - raise - sLog.always("Listening on port %s" % self.port) - for service in self.urls: - sLog.debug("Available service: %s" % service) + for port, app in self.__appsSettings.items(): + sLog.debug(" - %s" % "\n - ".join(["%s = %s" % (k, ssl_options[k]) for k in ssl_options])) + + # Default server configuration + settings = dict(compress_response=True, cookie_secret='secret') + + # Merge appllication settings + settings.update(app['settings']) + # Start server + router = Application(app['routes'], **settings) + server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True) + try: + server.listen(port) + except Exception as e: # pylint: disable=broad-except + sLog.exception("Exception starting HTTPServer", e) + raise + sLog.always("Listening on port %s" % port) + for service in app['routes']: + sLog.debug("Available service: %s" % service if isinstance(service, url) else service[0]) IOLoop.current().start() diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index da0976ec12b..cd23bee05ac 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -9,18 +9,6 @@ __RCSID__ = "$Id$" -from io import open - -import os -import time -import threading -from datetime import datetime -from six.moves import http_client -from tornado.web import RequestHandler, HTTPError -from tornado import gen -import tornado.ioloop -from tornado.ioloop import IOLoop - import DIRAC from DIRAC import gConfig, gLogger, S_OK @@ -28,30 +16,14 @@ from DIRAC.Core.DISET.AuthManager import AuthManager from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.Core.Utilities.JEncode import decode, encode -from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient +from DIRAC.Core.Tornado.Server.BaseRequestHandler import BaseRequestHandler +from DIRAC.ConfigurationSystem.Client import PathFinder sLog = gLogger.getSubLogger(__name__) -class TornadoService(RequestHandler): # pylint: disable=abstract-method +class TornadoService(BaseRequestHandler): # pylint: disable=abstract-method """ - Base class for all the Handlers. - It directly inherits from :py:class:`tornado.web.RequestHandler` - - Each HTTP request is served by a new instance of this class. - - For the sequence of method called, please refer to - the `tornado documentation `_. - - For compatibility with the existing :py:class:`DIRAC.Core.DISET.TransferClient.TransferClient`, - the handler can define a method ``export_streamToClient``. This is the method that will be called - whenever ``TransferClient.receiveFile`` is called. It is the equivalent of the DISET - ``transfer_toClient``. - Note that this is here only for compatibility, and we discourage using it for new purposes, as it is - bound to disappear. - - The handler only define the ``post`` verb. Please refer to :py:meth:`.post` for the details. - In order to create a handler for your service, it has to follow a certain skeleton:: @@ -111,155 +83,45 @@ def export_streamToClient(self, myDataToSend, token): These are initialized in the :py:meth:`.initialize` method. """ - - # Because we initialize at first request, we use a flag to know if it's already done - __init_done = False - # Lock to make sure that two threads are not initializing at the same time - __init_lock = threading.RLock() - - # MonitoringClient, we don't use gMonitor which is not thread-safe - # We also need to add specific attributes for each service - _monitor = None + # Prefix of methods names + METHOD_PREFIX = "export_" @classmethod - def _initMonitoring(cls, serviceName, fullUrl): - """ - Initialize the monitoring specific to this handler - This has to be called only by :py:meth:`.__initializeService` - to ensure thread safety and unicity of the call. - - :param serviceName: relative URL ``//`` - :param fullUrl: full URl like ``https://://`` - """ + def _getServiceName(cls, request): + """ Search service name in request. - # Init extra bits of monitoring + :param object request: tornado Request - cls._monitor = MonitoringClient() - cls._monitor.setComponentType(MonitoringClient.COMPONENT_WEB) - - cls._monitor.initialize() - - if tornado.process.task_id() is None: # Single process mode - cls._monitor.setComponentName('Tornado/%s' % serviceName) - else: - cls._monitor.setComponentName('Tornado/CPU%d/%s' % (tornado.process.task_id(), serviceName)) - - cls._monitor.setComponentLocation(fullUrl) - - cls._monitor.registerActivity("Queries", "Queries served", "Framework", "queries", MonitoringClient.OP_RATE) - - cls._monitor.setComponentExtraParam('DIRACVersion', DIRAC.version) - cls._monitor.setComponentExtraParam('platform', DIRAC.getPlatform()) - cls._monitor.setComponentExtraParam('startTime', datetime.utcnow()) - - cls._stats = {'requests': 0, 'monitorLastStatsUpdate': time.time()} - - return S_OK() - - @classmethod - def __initializeService(cls, relativeUrl, absoluteUrl): + :return: str """ - Initialize a service. - The work is only perform once at the first request. - - :param relativeUrl: relative URL, e.g. ``//`` - :param absoluteUrl: full URL e.g. ``https://://`` - - :returns: S_OK - """ - # If the initialization was already done successfuly, - # we can just return - if cls.__init_done: - return S_OK() - - # Otherwise, do the work but with a lock - with cls.__init_lock: - - # Check again that the initialization was not done by another thread - # while we were waiting for the lock - if cls.__init_done: - return S_OK() - - # Url starts with a "/", we just remove it - serviceName = relativeUrl[1:] - - cls._startTime = datetime.utcnow() - sLog.info("First use, initializing service...", "%s" % relativeUrl) - cls._authManager = AuthManager("%s/Authorization" % PathFinder.getServiceSection(serviceName)) - - cls._initMonitoring(serviceName, absoluteUrl) - - cls._serviceName = serviceName - cls._validNames = [serviceName] - serviceInfo = {'serviceName': serviceName, - 'serviceSectionPath': PathFinder.getServiceSection(serviceName), - 'csPaths': [PathFinder.getServiceSection(serviceName)], - 'URL': absoluteUrl - } - cls._serviceInfoDict = serviceInfo - - cls.__monitorLastStatsUpdate = time.time() - - cls.initializeHandler(serviceInfo) - - cls.__init_done = True - - return S_OK() + # Expected path: ``//`` + return request.path[1:] @classmethod - def initializeHandler(cls, serviceInfoDict): - """ - This may be overwritten when you write a DIRAC service handler - And it must be a class method. This method is called only one time, - at the first request + def _getServiceInfo(cls, serviceName, request): + """ Fill service information. - :param dict ServiceInfoDict: infos about services, it contains - 'serviceName', 'serviceSectionPath', - 'csPaths' and 'URL' - """ - pass - - def initializeRequest(self): - """ - Called at every request, may be overwritten in your handler. - """ - pass + :param str serviceName: service name + :param object request: tornado Request - # This is a Tornado magic method - def initialize(self): # pylint: disable=arguments-differ + :return: dict """ - Initialize the handler, called at every request. + return {'serviceName': serviceName, + 'serviceSectionPath': PathFinder.getServiceSection(serviceName), + 'csPaths': [PathFinder.getServiceSection(serviceName)], + 'URL': request.full_url()} - It just calls :py:meth:`.__initializeService` + def _getMethodName(self): + """ Parse method name. - If anything goes wrong, the client will get ``Connection aborted`` - error. See details inside the method. - - ..warning:: - DO NOT REWRITE THIS FUNCTION IN YOUR HANDLER - ==> initialize in DISET became initializeRequest in HTTPS ! + :return: str """ + return self.get_argument("method") - # Only initialized once - if not self.__init_done: - # Ideally, if something goes wrong, we would like to return a Server Error 500 - # but this method cannot write back to the client as per the - # `tornado doc `_. - # So the client will get a ``Connection aborted``` - try: - res = self.__initializeService(self.srv_getURL(), self.request.full_url()) - if not res['OK']: - raise Exception(res['Message']) - except Exception as e: - sLog.error("Error in initialization", repr(e)) - raise - - def prepare(self): - """ - Prepare the request. It reads certificates and check authorizations. - We make the assumption that there is always going to be a ``method`` argument - regardless of the HTTP method used + def _getMethodArgs(self, args): + """ Decode args. + :return: list """ # "method" argument of the POST call. @@ -444,291 +306,4 @@ def __executeMethod(self): # Decode args args_encoded = self.get_body_argument('args', default=encode([])) - - args = decode(args_encoded)[0] - # Execute - try: - self.initializeRequest() - retVal = method(*args) - except Exception as e: # pylint: disable=broad-except - sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) - raise HTTPError(http_client.INTERNAL_SERVER_ERROR) - - return retVal - - # def __write_return(self, retVal): - # """ - # Write back to the client and return. - # It sets some headers (status code, ``Content-Type``). - # If raw content was requested by the client, the ``Content-Type`` - # is ``application/octet-stream``, otherwise we set it to ``application/json`` - # and JEncode retVal. - - # If ``retVal`` is a dictionary that contains a ``Callstack`` item, - # it is removed, not to leak internal information. - - # :param retVal: anything that can be serialized in json. - # """ - - # # In case of error in server side we hide server CallStack to client - # try: - # if 'CallStack' in retVal: - # del retVal['CallStack'] - # except TypeError: - # pass - - # # Set the status - # self.set_status(self._httpStatus) - - # # This is basically only used for file download through - # # the 'streamToClient' method. - # if self.rawContent: - # # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt - # self.set_header("Content-Type", "application/octet-stream") - # returnedData = retVal - # else: - # self.set_header("Content-Type", "application/json") - # returnedData = encode(retVal) - - # self.write(returnedData) - - def on_finish(self): - """ - Called after the end of HTTP request. - Log the request duration - """ - elapsedTime = 1000.0 * self.request.request_time() - - try: - if self.result['OK']: - argsString = "OK" - else: - argsString = "ERROR: %s" % self.result['Message'] - except (AttributeError, KeyError): # In case it is not a DIRAC structure - if self._reason == 'OK': - argsString = 'OK' - else: - argsString = 'ERROR %s' % self._reason - - argsString = "ERROR: %s" % self._reason - sLog.notice("Returning response", "%s %s (%.2f ms) %s" % (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - elapsedTime, argsString)) - - def _gatherPeerCredentials(self): - """ - Load client certchain in DIRAC and extract informations. - - The dictionary returned is designed to work with the AuthManager, - already written for DISET and re-used for HTTPS. - - :returns: a dict containing the return of :py:meth:`DIRAC.Core.Security.X509Chain.X509Chain.getCredentials` - (not a DIRAC structure !) - """ - - chainAsText = self.request.get_ssl_certificate().as_pem() - peerChain = X509Chain() - - # Here we read all certificate chain - cert_chain = self.request.get_ssl_certificate_chain() - for cert in cert_chain: - chainAsText += cert.as_pem() - - peerChain.loadChainFromString(chainAsText) - - # Retrieve the credentials - res = peerChain.getCredentials(withRegistryInfo=False) - if not res['OK']: - raise Exception(res['Message']) - - credDict = res['Value'] - - # We check if client sends extra credentials... - if "extraCredentials" in self.request.arguments: - extraCred = self.get_argument("extraCredentials") - if extraCred: - credDict['extraCredentials'] = decode(extraCred)[0] - return credDict - - -#### -# -# Default method -# -#### - - auth_ping = ['all'] - - def export_ping(self): - """ - Default ping method, returns some info about server. - - It returns the exact same information as DISET, for transparency purpose. - """ - # COPY FROM DIRAC.Core.DISET.RequestHandler - dInfo = {} - dInfo['version'] = DIRAC.version - dInfo['time'] = datetime.utcnow() - # Uptime - try: - with open("/proc/uptime", 'rt') as oFD: - iUptime = int(float(oFD.readline().split()[0].strip())) - dInfo['host uptime'] = iUptime - except Exception: # pylint: disable=broad-except - pass - startTime = self._startTime - dInfo['service start time'] = self._startTime - serviceUptime = datetime.utcnow() - startTime - dInfo['service uptime'] = serviceUptime.days * 3600 + serviceUptime.seconds - # Load average - try: - with open("/proc/loadavg", 'rt') as oFD: - dInfo['load'] = " ".join(oFD.read().split()[:3]) - except Exception: # pylint: disable=broad-except - pass - dInfo['name'] = self._serviceInfoDict['serviceName'] - stTimes = os.times() - dInfo['cpu times'] = {'user time': stTimes[0], - 'system time': stTimes[1], - 'children user time': stTimes[2], - 'children system time': stTimes[3], - 'elapsed real time': stTimes[4] - } - - return S_OK(dInfo) - - auth_echo = ['all'] - - @staticmethod - def export_echo(data): - """ - This method used for testing the performance of a service - """ - return S_OK(data) - - auth_whoami = ['authenticated'] - - def export_whoami(self): - """ - A simple whoami, returns all credential dictionary, except certificate chain object. - """ - credDict = self.srv_getRemoteCredentials() - if 'x509Chain' in credDict: - # Not serializable - del credDict['x509Chain'] - return S_OK(credDict) - -#### -# -# Utilities methods, some getters. -# From DIRAC.Core.DISET.requestHandler to get same interface in the handlers. -# Adapted for Tornado. -# These method are copied from DISET RequestHandler, they are not all used when i'm writing -# these lines. I rewrite them for Tornado to get them ready when a new HTTPS service need them -# -#### - - @classmethod - def srv_getCSOption(cls, optionName, defaultValue=False): - """ - Get an option from the CS section of the services - - :return: Value for serviceSection/optionName in the CS being defaultValue the default - """ - if optionName[0] == "/": - return gConfig.getValue(optionName, defaultValue) - for csPath in cls._serviceInfoDict['csPaths']: - result = gConfig.getOption("%s/%s" % (csPath, optionName, ), defaultValue) - if result['OK']: - return result['Value'] - return defaultValue - - def getCSOption(self, optionName, defaultValue=False): - """ - Just for keeping same public interface - """ - return self.srv_getCSOption(optionName, defaultValue) - - def srv_getRemoteAddress(self): - """ - Get the address of the remote peer. - - :return: Address of remote peer. - """ - - remote_ip = self.request.remote_ip - # Although it would be trivial to add this attribute in _HTTPRequestContext, - # Tornado won't release anymore 5.1 series, so go the hacky way - try: - remote_port = self.request.connection.stream.socket.getpeername()[1] - except Exception: # pylint: disable=broad-except - remote_port = 0 - - return (remote_ip, remote_port) - - def getRemoteAddress(self): - """ - Just for keeping same public interface - """ - return self.srv_getRemoteAddress() - - def srv_getRemoteCredentials(self): - """ - Get the credentials of the remote peer. - - :return: Credentials dictionary of remote peer. - """ - return self.credDict - - def getRemoteCredentials(self): - """ - Get the credentials of the remote peer. - - :return: Credentials dictionary of remote peer. - """ - return self.credDict - - def srv_getFormattedRemoteCredentials(self): - """ - Return the DN of user - - Mostly copy paste from - :py:meth:`DIRAC.Core.DISET.private.Transports.BaseTransport.BaseTransport.getFormattedCredentials` - - Note that the information will be complete only once the AuthManager was called - """ - address = self.getRemoteAddress() - peerId = "" - # Depending on where this is call, it may be that credDict is not yet filled. - # (reminder: AuthQuery fills part of it..) - try: - peerId = "[%s:%s]" % (self.credDict['group'], self.credDict['username']) - except AttributeError: - pass - - if address[0].find(":") > -1: - return "([%s]:%s)%s" % (address[0], address[1], peerId) - return "(%s:%s)%s" % (address[0], address[1], peerId) - -# def getFormattedCredentials(self): -# peerCreds = self.getConnectingCredentials() -# address = self.getRemoteAddress() -# if 'username' in peerCreds: -# peerId = "[%s:%s]" % (peerCreds['group'], peerCreds['username']) -# else: -# peerId = "" -# if address[0].find(":") > -1: -# return "([%s]:%s)%s" % (address[0], address[1], peerId) -# return "(%s:%s)%s" % (address[0], address[1], peerId) - - def srv_getServiceName(self): - """ - Return the service name - """ - return self._serviceInfoDict['serviceName'] - - def srv_getURL(self): - """ - Return the URL - """ - return self.request.path + return decode(args_encoded)[0] diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py new file mode 100644 index 00000000000..b71254634cf --- /dev/null +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import os +import sys +import tornado + +from DIRAC.Core.Utilities.DIRACScript import DIRACScript + + +@DIRACScript() +def main(): + # Must be define BEFORE any dirac import + os.environ['DIRAC_USE_TORNADO_IOLOOP'] = "True" + + from DIRAC.ConfigurationSystem.Client.PathFinder import getAPISection + from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData + from DIRAC.ConfigurationSystem.Client.LocalConfiguration import LocalConfiguration + from DIRAC.Core.Tornado.Server.TornadoServer import TornadoServer + from DIRAC.Core.Utilities.DErrno import includeExtensionErrors + from DIRAC.FrameworkSystem.Client.Logger import gLogger + + localCfg = LocalConfiguration() + localCfg.addMandatoryEntry("/DIRAC/Setup") + localCfg.addDefaultEntry("/DIRAC/Security/UseServerCertificate", "yes") + localCfg.addDefaultEntry("LogLevel", "INFO") + localCfg.addDefaultEntry("LogColor", True) + resultDict = localCfg.loadUserData() + if not resultDict['OK']: + gLogger.initialize("Tornado", "/") + gLogger.error("There were errors when loading configuration", resultDict['Message']) + sys.exit(1) + + includeExtensionErrors() + + gLogger.initialize('Tornado', "/") + + endpoints = ['Framework/Auth'] + try: + asPort = int(gConfigurationData.extractOptionFromCFG('%s/Port' % getAPISection('Framework/Auth'))) + except TypeError: + asPort = None + + serverToLaunch = TornadoServer(False, endpoints, port=asPort) + + serverToLaunch.startTornado() + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py index 61462857ce7..27d30fe591a 100644 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_CS.py @@ -58,7 +58,7 @@ def main(): except TypeError: csPort = None - serverToLaunch = TornadoServer(services='Configuration/Server', port=csPort) + serverToLaunch = TornadoServer(services=['Configuration/Server'], port=csPort) serverToLaunch.startTornado() diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py new file mode 100644 index 00000000000..3a8d37212ab --- /dev/null +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import os +import sys +import tornado +import pprint + +from DIRAC.Core.Utilities.DIRACScript import DIRACScript + + +@DIRACScript() +def main(): + # Must be define BEFORE any dirac import + os.environ['DIRAC_USE_TORNADO_IOLOOP'] = "True" + + from DIRAC.ConfigurationSystem.Client.LocalConfiguration import LocalConfiguration + from DIRAC.Core.Tornado.Server.TornadoServer import TornadoServer + from DIRAC.Core.Utilities.DErrno import includeExtensionErrors + from DIRAC.FrameworkSystem.Client.Logger import gLogger + + localCfg = LocalConfiguration() + localCfg.addMandatoryEntry("/DIRAC/Setup") + localCfg.addDefaultEntry("/DIRAC/Security/UseServerCertificate", "yes") + localCfg.addDefaultEntry("LogLevel", "INFO") + localCfg.addDefaultEntry("LogColor", True) + resultDict = localCfg.loadUserData() + if not resultDict['OK']: + gLogger.initialize("Tornado", "/") + gLogger.error("There were errors when loading configuration", resultDict['Message']) + sys.exit(1) + + includeExtensionErrors() + + gLogger.initialize('Tornado', "/") + + services = ['DataManagement/TornadoFileCatalog'] + endpoints = False + + serverToLaunch = TornadoServer(services, endpoints, port=8000) + + try: + from WebAppDIRAC.Core.App import App + except ImportError as e: + gLogger.fatal('Web portal is not installed. %s' % repr(e)) + sys.exit(1) + + # Get routes and settings for a portal + result = App().getAppToDict(8000) + if not result['OK']: + gLogger.fatal(result['Message']) + sys.exit(1) + app = result['Value'] + + serverToLaunch.addHandlers(app['routes'], app['settings']) + + serverToLaunch.startTornado() + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py new file mode 100644 index 00000000000..ff73f0a84be --- /dev/null +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -0,0 +1,527 @@ +""" This handler basically provides a REST interface to interact with the OAuth 2 authentication server +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import pprint +from io import open + +from dominate import document, tags as dom +from tornado.template import Template + +from authlib.jose import jwk +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6749.util import scope_to_list + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint +from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +__RCSID__ = "$Id$" + + +class AuthHandler(TornadoREST): + # Authorization access to all methods handled by AuthServer instance + USE_AUTHZ_GRANTS = ['JWT', 'VISITOR'] + SYSTEM = 'Framework' + AUTH_PROPS = 'all' + LOCATION = "/DIRAC/auth" + css_align_center = 'display:block;justify-content:center;align-items:center;' + css_center_div = 'height:700px;width:100%;position:absolute;top:50%;left:0;margin-top:-350px;' + css_big_text = 'font-size:28px;' + css_main = ' '.join([css_align_center, css_center_div, css_big_text]) + CSS = """ +.button { + border-radius: 4px; + background-color: #ffffff00; + border: none; + color: black; + text-align: center; + font-size: 14px; + padding: 12px; + width: 100%; + transition: all 0.5s; + cursor: pointer; + margin: 5px; + display: block; /* Make the links appear below each other */ +} +.button a { + color: black; + cursor: pointer; + display: inline-block; + position: relative; + transition: 0.5s; + text-decoration: none; /* Remove underline from links */ +} +.button a:after { + content: '\\00bb'; + position: absolute; + opacity: 0; + top: 0; + right: -20px; + transition: 0.5s; +} +.button:hover a { + padding-right: 25px; +} +.button:hover a:after { + opacity: 1; + right: 0; +}""" + + @classmethod + def initializeHandler(cls, serviceInfo): + """ This method is called only one time, at the first request + + :param dict ServiceInfoDict: infos about services + """ + cls.server = AuthServer() + cls.server.css = dict(CSS=cls.CSS, css_align_center=cls.css_align_center, css_main=cls.css_main) + cls.server.LOCATION = cls.LOCATION + + def initializeRequest(self): + """ Called at every request """ + self.currentPath = self.request.protocol + "://" + self.request.host + self.request.path + # Template for a html UI + self.doc = document('DIRAC authentication') + with self.doc.head: + dom.link(rel='stylesheet', + href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css") + dom.style(self.CSS) + + def _parseDIRACResult(self, result): + """ Here the result which returns handle_response is processed + """ + if not result['OK']: + # If response error is DIRAC server error, not OAuth2 flow error + self.removeSession() + self.set_status = 400 + self.write({'error': 'server_error', + 'description': '%s:\n%s' % (result['Message'], '\n'.join(result['CallStack']))}) + else: + # Successful responses and OAuth2 errors are processed here + status_code, headers, payload, new_session, error = result['Value'][0] + if status_code: + self.set_status(status_code) + if headers: + for key, value in headers: + self.set_header(key, value) + if payload: + self.write(payload) + if new_session: + self.saveSession(new_session) + if error: + self.removeSession() + for method, args_kwargs in result['Value'][1].items(): + eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) + + def saveSession(self, session): + """ Save session to cookie + + :param dict session: session + """ + self.set_secure_cookie('auth_session', json.dumps(session), secure=True, httponly=True) + + def removeSession(self): + """ Remove session from cookie """ + self.clear_cookie('auth_session') + + def getSession(self, state=None, **kw): + """ Get session from cookie + + :param str state: state + + :return: dict + """ + try: + session = json.loads(self.get_secure_cookie('auth_session')) + checkState = (session['state'] == state) if state else None + checkOption = (session[kw.items()[0][0]] == kw.items()[0][0]) if kw else None + except Exception as e: + return None + return session if (checkState or checkOption) else None + + path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] + + def web_index(self, instance): + """ Well known endpoint, specified by + `RFC8414 `_ + + Request examples:: + + GET: LOCATION/.well-known/openid-configuration + GET: LOCATION/.well-known/oauth-authorization-server + + Responce:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "registration_endpoint": "https://domain.com/DIRAC/auth/register", + "userinfo_endpoint": "https://domain.com/DIRAC/auth/userinfo", + "jwks_uri": "https://domain.com/DIRAC/auth/jwk", + "code_challenge_methods_supported": [ + "S256" + ], + "grant_types_supported": [ + "authorization_code", + "code", + "urn:ietf:params:oauth:grant-type:device_code", + "implicit", + "refresh_token" + ], + "token_endpoint": "https://domain.com/DIRAC/auth/token", + "response_types_supported": [ + "code", + "device", + "id_token token", + "id_token", + "token" + ], + "authorization_endpoint": "https://domain.com/DIRAC/auth/authorization", + "issuer": "https://domain.com/DIRAC/auth" + } + """ + if self.request.method == "GET": + return dict(self.server.metadata) + + def web_jwk(self): + """ JWKs endpoint + + Request example:: + + GET LOCATION/jwk + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "keys": [ + { + "e": "AQAB", + "kty": "RSA", + "n": "3Vv5h5...X3Y7k" + } + ] + } + """ + return self.server.db.getJWKs().get('Value', {}) + + def web_revoke(self): + """ Revocation endpoint + + Request example:: + + GET LOCATION/revoke + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + """ + if self.request.method == 'POST': + self.log.verbose('Initialize a Device authentication flow.') + return self.server.create_endpoint_response(RevocationEndpoint.ENDPOINT_NAME, self.request) + + def web_userinfo(self): + """ The UserInfo endpoint can be used to retrieve identity information about a user, + see `spec `_ + + GET LOCATION/userinfo + + Parameters: + +---------------+--------+---------------------------------+--------------------------------------------------+ + | **name** | **in** | **description** | **example** | + +---------------+--------+---------------------------------+--------------------------------------------------+ + | Authorization | header | Provide access token | Bearer jkagfbfd3r4ubf887gqduyqwogasd87 | + +---------------+--------+---------------------------------+--------------------------------------------------+ + + Request example:: + + GET LOCATION/userinfo + Authorization: Bearer + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "sub": "248289761001", + "name": "Bob Smith", + "given_name": "Bob", + "family_name": "Smith", + "group": [ + "dirac_user", + "dirac_admin" + ] + } + """ + # Token verification + # token = ResourceProtector().acquire_token(self.request, '') + # return {'sub': token.sub, 'issuer': token.issuer, 'group': token.groups[0]} + userinfo = self.getRemoteCredentials() + return userinfo + + path_device = ['([A-z0-9-_]*)'] + + def web_device(self, provider=None): + """ The device authorization endpoint can be used to request device and user codes. + This endpoint is used to start the device flow authorization process and user code verification. + + POST LOCATION/device/? + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | user code | query | in the last step to confirm recived user | WE8R-WEN9 | + | | | code put it as query parameter (optional) | | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | client_id | query | The public client ID | 3f6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | scope | query | list of scoupes separated by a space, to | g:dirac_user | + | | | add a group you must add "g:" before the | | + | | | group name | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | provider | path | identity provider to autorize (optional) | CheckIn | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + + + User code confirmation:: + + GET LOCATION/device/?user_code= + + Request example, to initialize a Device authentication flow:: + + POST LOCATION/device/CheckIn_dev?client_id=3f1DAj8z6eNw0E6JGq1Vu6efZwyV&scope=g:dirac_admin + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_code": "TglwLiow0HUwowjB9aHH5HqH3bZKP9d420LkNhCEuR", + "verification_uri": "https://marosvn32.in2p3.fr/DIRAC/auth/device", + "interval": 5, + "expires_in": 1800, + "verification_uri_complete": "https://marosvn32.in2p3.fr/DIRAC/auth/device/WSRL-HJMR", + "user_code": "WSRL-HJMR" + } + + Request example, to confirm the user code:: + + POST LOCATION/device/CheckIn_dev/WSRL-HJMR + + Response:: + + HTTP/1.1 200 OK + """ + if self.request.method == 'POST': + self.log.verbose('Initialize a Device authentication flow.') + return self.server.create_endpoint_response(DeviceAuthorizationEndpoint.ENDPOINT_NAME, self.request) + + elif self.request.method == 'GET': + userCode = self.get_argument('user_code', None) + if userCode: + # If received a request with a user code, then prepare a request to authorization endpoint + self.log.verbose('User code verification.') + result = self.server.db.getSessionByUserCode(userCode) + if not result['OK']: + return 'Device code flow authorization session %s expired.' % userCode + session = result['Value'] + # Get original request from session + req = createOAuth2Request(dict(method='GET', uri=session['uri'])) + + groups = [s.split(':')[1] for s in scope_to_list(req.scope) if s.startswith('g:')] + group = groups[0] if groups else None + + if group and not provider: + print(group) + provider = Registry.getIdPForGroup(group) + + print('Use provider:', provider) + + authURL = '%s/authorization/%s?%s&user_code=%s' % (self.LOCATION, provider, req.query, userCode) + # Save session to cookie + return self.server.handle_response(302, {}, [("Location", authURL)], session) + + # If received a request without a user code, then send a form to enter the user code + with self.doc: + dom.div(dom.form(dom._input(type="text", name="user_code", style=self.css_big_text), + dom.button('Submit', type="submit", style=self.css_big_text), + action=self.currentPath, method="GET"), style=self.css_main) + return Template(self.doc.render()).generate() + + path_authorization = ['([A-z0-9-_]*)'] + + def web_authorization(self, provider=None): + """ Authorization endpoint + + GET: LOCATION/authorization/ + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | response_type | query | informs of the desired grant type | code | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | client_id | query | The client ID | 3f6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | scope | query | list of scoupes separated by a space, to | g:dirac_user | + | | | add a group you must add "g:" before the | | + | | | group name | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | provider | path | identity provider to autorize (optional) | CheckIn | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + General options: + provider -- identity provider to autorize + + Device flow: + &user_code=.. (required) + + Authentication code flow: + &scope=.. (optional) + &redirect_uri=.. (optional) + &state=.. (main session id, optional) + &code_challenge=.. (PKCE, optional) + &code_challenge_method=(pain|S256) ('pain' by default, optional) + """ + return self.server.validate_consent_request(self.request, provider) + + def web_redirect(self): + """ Redirect endpoint. + After a user successfully authorizes an application, the authorization server will redirect + the user back to the application with either an authorization code or access token in the URL. + The full URL of this endpoint must be registered in the identity provider. + + Read more in `oauth.com `_. + Specified by `RFC6749 `_. + + GET LOCATION/redirect + + Parameters:: + + &chooseScope=.. to specify new scope(group in our case) (optional) + """ + # Current IdP session state + state = self.get_argument('state') + + # Try to catch errors + if self.get_argument('error', None): + error = OAuth2Error(error=self.get_argument('error'), description=self.get_argument('error_description', '')) + return self.server.handle_error_response(state, error) + + # Check current auth session that was initiated for the selected external identity provider + sessionWithExtIdP = self.getSession(state) + if not sessionWithExtIdP: + return S_ERROR("%s session is expired." % state) + + if not sessionWithExtIdP.get('authed'): + # Parse result of the second authentication flow + self.log.info('%s session, parsing authorization response:\n' % state, + '\n'.join([self.request.uri, self.request.query, self.request.body, str(self.request.headers)])) + + result = self.server.parseIdPAuthorizationResponse(self.request, sessionWithExtIdP) + if not result['OK']: + return result + # Return main session flow + sessionWithExtIdP['authed'] = result['Value'] + + # Research group + grant_user, response = self.__researchDIRACGroup(sessionWithExtIdP) + if not grant_user: + return response + + # RESPONSE to basic DIRAC client request + return self.server.create_authorization_response(response, grant_user) + + def web_token(self): + """ The token endpoint, the description of the parameters will differ depending on the selected grant_type + + POST LOCATION/token + + Parameters: + +----------------+--------+-------------------------------------+---------------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------+---------------------------------------------+ + | grant_type | query | what grant type to use, more | urn:ietf:params:oauth:grant-type:device_code| + | | | supported grant types in *grants | | + +----------------+--------+-------------------------------------+---------------------------------------------+ + | client_id | query | The public client ID | 3f1DAj8z6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------+---------------------------------------------+ + | device_code | query | device code | uW5xL4hr2tqwBPKL5d0JO9Fcc67gLqhJsNqYTSp | + +----------------+--------+-------------------------------------+---------------------------------------------+ + + *:mod:`grants ` + + Request example:: + + POST LOCATION/token?client_id=L86..yV&grant_type=urn:ietf:params:oauth:grant-type:device_code&device_code=uW5 + + Response:: + + HTTP/1.1 400 OK + Content-Type: application/json + + { + "error": "authorization_pending" + } + """ + return self.server.create_token_response(self.request) + + def __researchDIRACGroup(self, extSession): + """ Research DIRAC groups for authorized user + + :param dict extSession: ended authorized external IdP session + + :return: response + """ + # Base DIRAC client auth session + firstRequest = createOAuth2Request(extSession['mainSession']) + # Read requested groups by DIRAC client or user + firstRequest.addScopes(self.get_arguments('chooseScope', [])) + # Read already authed user + username = extSession['authed']['username'] + self.log.debug('Next groups has been found for %s:' % username, ', '.join(firstRequest.groups)) + + # Researche Group + result = Registry.getGroupsForUser(username) + if not result['OK']: + return None, result + validGroups = result['Value'] + if not validGroups: + return None, S_ERROR('No groups found for %s.' % username) + + self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(validGroups)) + + if not firstRequest.groups: + if len(validGroups) == 1: + firstRequest.addScopes(['g:%s' % validGroups[0]]) + else: + # Choose group interface + with self.doc: + with dom.div(style=self.css_main): + with dom.div('Choose group', style=self.css_align_center): + for group in validGroups: + dom.button(dom.a(group, href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, + self.get_argument('state'), group)), + cls='button') + return None, self.server.handle_response(payload=Template(self.doc.render()).generate(), newSession=extSession) + + # Return grant user + return extSession['authed'], firstRequest diff --git a/src/DIRAC/FrameworkSystem/API/__init__.py b/src/DIRAC/FrameworkSystem/API/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg index 8d4289ecb9d..d023e7bf226 100644 --- a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg +++ b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg @@ -1,3 +1,12 @@ +APIs +{ + Auth + { + Port = 8000 + # Allow download personal proxy + downloadablePersonalProxy = True + } +} Services { Gateway diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py new file mode 100644 index 00000000000..911172aeb22 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -0,0 +1,385 @@ +""" Auth class is a front-end to the Auth Database +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import jwt as _jwt + +from time import time +from pprint import pprint +from M2Crypto import RSA, BIO +from sqlalchemy import Column, Integer, Text, String +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound +from sqlalchemy.ext.declarative import declarative_base + +from authlib.jose import KeySet, RSAKey, jwk +from authlib.common.security import generate_token +from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin + +from DIRAC import S_OK, S_ERROR, gLogger, gConfig +from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +__RCSID__ = "$Id$" + + +Model = declarative_base() + + +class Token(Model, OAuth2TokenMixin): + __tablename__ = 'Token' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + # access_token too large for varchar(255) + # 767 bytes is the stated prefix limitation for InnoDB tables in MySQL version 5.6 + # https://stackoverflow.com/questions/1827063/mysql-error-key-specification-without-a-key-length + id = Column(Integer, autoincrement=True, primary_key=True) + access_token = Column(Text, nullable=False) + refresh_token = Column(Text, nullable=False) + expires_at = Column(Integer, nullable=False, default=0) + + +class JWK(Model): + __tablename__ = 'JWK' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + kid = Column(String(255), unique=True, primary_key=True, nullable=False) + key = Column(Text, nullable=False) + expires_at = Column(Integer, nullable=False, default=0) + + +class AuthSession(Model): + __tablename__ = 'AuthSession' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + id = Column(String(255), unique=True, primary_key=True, nullable=False) + state = Column(String(255)) + uri = Column(String(255)) + client_id = Column(String(255)) + user_id = Column(String(255)) + username = Column(String(255)) + expires_at = Column(Integer, nullable=False, default=0) + expires_in = Column(Integer, nullable=False, default=0) + interval = Column(Integer, nullable=False, default=5) + verification_uri = Column(String(255)) + verification_uri_complete = Column(String(255)) + user_code = Column(String(255)) + device_code = Column(String(255)) + scope = Column(String(255)) + + +class AuthDB(SQLAlchemyDB): + """ AuthDB class is a front-end to the OAuth Database + """ + def __init__(self): + """ Constructor + """ + super(AuthDB, self).__init__() + self._initializeConnection('Framework/AuthDB') + result = self.__initializeDB() + if not result['OK']: + raise Exception("Can't create tables: %s" % result['Message']) + self.session = scoped_session(self.sessionMaker_o) + + def __initializeDB(self): + """ Create the tables + """ + tablesInDB = self.inspector.get_table_names() + + # Token + if 'Token' not in tablesInDB: + try: + Token.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + # JWK + if 'JWK' not in tablesInDB: + try: + JWK.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + # AuthSession + if 'AuthSession' not in tablesInDB: + try: + AuthSession.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + return S_OK() + + def getToken(self, token, token_type_hint='refresh_token'): + """ Find Token for refresh token + + :param str token: token + :param str token_type_hint: token type + + :return: S_OK()/S_ERROR() + """ + session = self.session() + try: + session.query(Token).filter(Token.expires_at < time()).delete() + if token_type_hint == 'access_token': + token = session.query(Token).filter(Token.access_token == token).first() + else: + token = session.query(Token).filter(Token.refresh_token == token).first() + if not token: + return self.__result(session, S_ERROR("Token not found.")) + except NoResultFound: + return self.__result(session, S_ERROR("Token not found.")) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)))) + + def revokeToken(self, token): + """ Revoke token + + :param dict token: token to revoke + + :return: S_OK()/S_ERROR() + """ + session = self.session() + try: + token = session.query(Token).filter(Token.access_token == token['access_token']).first() + token.revoked = True + except NoResultFound: + return self.__result(session, S_OK()) + except Exception as e: + return self.__result(session, S_ERROR('Could not revoke token: %s' % e)) + return self.__result(session, S_OK()) + + def storeToken(self, token): + """ Save token + + :param dict token: token info + + :return: S_OK(str)/S_ERROR() + """ + token['expires_at'] = int(_jwt.decode(token['refresh_token'], options=dict(verify_signature=False))['exp']) + gLogger.debug('Store token:', dict(token)) + attrts = dict((k, v) for k, v in dict(token).items() if k in list(Token.__dict__.keys())) + session = self.session() + try: + session.query(Token).filter(Token.access_token == token['access_token']).delete() + session.add(Token(**attrts)) + except Exception as e: + return self.__result(session, S_ERROR('Could not add Token: %s' % e)) + return self.__result(session, S_OK('Token successfully added')) + + def removeTokens(self): + """ Get active keys + + :return: S_OK(list)/S_ERROR() + """ + session = self.session() + try: + session.query(Token).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK()) + + def generateRSAKeys(self): + """ Generate an RSA keypair with an exponent of 65537 in PEM format + + :return: S_OK/S_ERROR + """ + key = RSAKey.generate_key(key_size=1024, is_private=True) + dictKey = dict(key=json.dumps(key.as_dict()), + expires_at=time() + (30 * 24 *3600), + kid=KeySet([key]).as_dict()['keys'][0]['kid']) + + session = self.session() + try: + session.add(JWK(**dictKey)) + except Exception as e: + return self.__result(session, S_ERROR('Could not generate keys: %s' % e)) + return self.__result(session, S_OK(dictKey)) + + def getKeySet(self): + """ Get key set + + :return: S_OK(obj)/S_ERROR() + """ + keys = [] + result = self.getActiveKeys() + if result['OK'] and not result['Value']: + result = self.generateRSAKeys() + if result['OK']: + result = self.getActiveKeys() + if not result['OK']: + return result + for keyDict in result['Value']: + key = RSAKey.import_key(json.loads(keyDict['key'])) + keys.append(key) + return S_OK(KeySet(keys)) + + def getJWKs(self): + """ Get JWKs list + + :return: S_OK(dict)/S_ERROR() + """ + keys = [] + result = self.getKeySet() + if not result['OK']: + return result + for k in result['Value'].as_dict()['keys']: + keys.append({'n': k['n'], "kty": k['kty'], "e": k['e'], "kid": k['kid']}) + return S_OK({'keys': keys}) + + def getPrivateKey(self): + """ Get private key + + :return: S_OK(obj)/S_ERROR() + """ + result = self.getActiveKeys() + if not result['OK']: + return result + newer = {} + for d in result['Value']: + if d['expires_at'] > newer.get('expires_at', time() + (24 * 3600)): + newer = d + if not newer.get('key'): + result = self.generateRSAKeys() + if not result['OK']: + return result + newer = result['Value'] + return S_OK({'key': RSAKey.import_key(json.loads(newer['key'])), 'kid': newer['kid']}) + + def getActiveKeys(self): + """ Get active keys + + :return: S_OK(list)/S_ERROR() + """ + session = self.session() + try: + # Remove all expired jwks + session.query(JWK).filter(JWK.expires_at < time()).delete() + jwks = session.query(JWK).filter(JWK.expires_at > time()).all() + except NoResultFound: + return self.__result(session, S_OK([])) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK([self.__rowToDict(jwk) for jwk in jwks])) + + def removeKeys(self): + """ Get active keys + + :return: S_OK(list)/S_ERROR() + """ + session = self.session() + try: + session.query(JWK).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK()) + + def addSession(self, data): + """ Add new session + + :param dict data: session metadata + + :return: S_OK(dict)/S_ERROR() + """ + attrts = {} + if not data.get('expires_at'): + data['expires_at'] = data['expires_in'] + time() + gLogger.debug('Add authorization session:', data) + for k, v in data.items(): + if k not in AuthSession.__dict__.keys(): + self.log.warn('%s is not expected as authentication session attribute.' % k) + else: + attrts[k] = v + session = self.session() + try: + session.add(AuthSession(**attrts)) + except Exception as e: + return self.__result(session, S_ERROR('Could not add Token: %s' % e)) + return self.__result(session, S_OK('Token successfully added')) + + def updateSession(self, data, sessionID): + """ Update session data + + :param dict data: data info + :param str sessionID: sessionID + + :return: S_OK(object)/S_ERROR() + """ + self.removeSession(sessionID=sessionID) + return self.addSession(data) + + def removeSession(self, sessionID): + """ Remove session + + :param str sessionID: session id + + :return: S_OK()/S_ERROR() + """ + session = self.session() + try: + # Remove all expired sessions + session.query(AuthSession).filter(AuthSession.expires_at < time()).delete() + session.query(AuthSession).filter(AuthSession.id == sessionID).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK()) + + def getSession(self, sessionID): + """ Get client + + :param str sessionID: session id + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + resData = session.query(AuthSession).filter(AuthSession.id == sessionID).first() + except MultipleResultsFound: + return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) + except NoResultFound: + return self.__result(session, S_ERROR("%s session is expired." % sessionID)) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(resData))) + + def getSessionByUserCode(self, userCode): + """ Get client + + :param str userCode: user code + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + resData = session.query(AuthSession).filter(AuthSession.user_code == userCode).first() + except MultipleResultsFound: + return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) + except NoResultFound: + return self.__result(session, S_ERROR("%s session is expired." % sessionID)) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(resData))) + + def __result(self, session, result=None): + try: + if not result['OK']: + session.rollback() + else: + session.commit() + except Exception as e: + session.rollback() + result = S_ERROR('Could not commit: %s' % (e)) + session.close() + return result + + def __rowToDict(self, row): + """ Convert sqlalchemy row to dictionary + + :param object row: sqlalchemy row + + :return: dict + """ + return {c.name: str(getattr(row, c.name)) for c in row.__table__.columns} if row else {} diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.sql b/src/DIRAC/FrameworkSystem/DB/AuthDB.sql new file mode 100644 index 00000000000..bfd422307e6 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.sql @@ -0,0 +1,2 @@ +# Everything is created by the DB object upon instantiation if it does not exists. +use AuthDB; \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py new file mode 100644 index 00000000000..ab2851d2eb0 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python +######################################################################## +# File : dirac-login.py +# Author : Adrian Casajus +######################################################################## +""" +Login to DIRAC. + +Example: + $ dirac-login -g dirac_user +""" +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function + +import os +import sys +import urllib3 +import requests +import threading + +import DIRAC +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.Base import Script +from DIRAC.Core.Utilities.DIRACScript import DIRACScript +from DIRAC.Core.Security.TokenFile import writeTokenDictToTokenFile +from DIRAC.Core.Security.ProxyFile import writeToProxyFile +from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString +from DIRAC.Core.Security.TokenInfo import getTokenInfo, formatTokenInfoAsString +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +__RCSID__ = "$Id$" + + +class Params(object): + + def __init__(self): + self.proxy = False + self.group = None + self.lifetime = None + self.provider = 'DIRACCLI' + self.issuer = None + self.proxyLoc = '/tmp/x509up_u%s' % os.getuid() + + def returnProxy(self, _arg): + """ Set email + + :return: S_OK() + """ + self.proxy = True + return S_OK() + + def setGroup(self, arg): + """ Set email + + :param str arg: group + + :return: S_OK() + """ + self.group = arg + return S_OK() + + def setProvider(self, arg): + """ Set email + + :param str arg: provider + + :return: S_OK() + """ + self.provider = arg + return S_OK() + + def setIssuer(self, arg): + """ Set email + + :param str arg: issuer + + :return: S_OK() + """ + self.issuer = arg + return S_OK() + + def setLivetime(self, arg): + """ Set email + + :param str arg: lifetime + + :return: S_OK() + """ + self.lifetime = arg + return S_OK() + + def registerCLISwitches(self): + """ Register CLI switches """ + Script.registerSwitch( + "P", + "proxy", + "return with an access token also a proxy certificate with DIRAC group extension", + self.returnProxy) + Script.registerSwitch( + "g:", + "group=", + "set DIRAC group", + self.setGroup) + Script.registerSwitch( + "O:", + "provider=", + "set identity provider", + self.setProvider) + Script.registerSwitch( + "I:", + "issuer=", + "set issuer", + self.setIssuer) + Script.registerSwitch( + "T:", + "lifetime=", + "set proxy lifetime in a hours", + self.setLivetime) + + def doOAuthMagic(self): + """ Magic method with tokens + + :return: S_OK()/S_ERROR() + """ + params = {} + if self.issuer: + params['issuer'] = self.issuer + result = IdProviderFactory().getIdProvider(self.provider, **params) + if not result['OK']: + return result + idpObj = result['Value'] + if self.group: + idpObj.scope += '+g:%s' % self.group + if self.proxy: + idpObj.scope += '+proxy' + if self.lifetime: + idpObj.scope += '+lifetime:%s' % (int(self.lifetime) * 3600) + + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + # Submit Device authorisation flow + try: + result = idpObj.authorization() + except KeyboardInterrupt as e: + return S_ERROR(repr(e)) + if not result['OK']: + return result + + if self.proxy: + result = writeToProxyFile(idpObj.token['proxy'].encode("UTF-8"), self.proxyLoc) + if not result['OK']: + return result + gLogger.notice('Proxy is saved to %s.' % self.proxyLoc) + else: + result = writeTokenDictToTokenFile(idpObj.token) + if not result['OK']: + return result + gLogger.notice('Token is saved in %s.' % result['Value']) + + result = Script.enableCS() + if not result['OK']: + return S_ERROR("Cannot contact CS to get user list") + DIRAC.gConfig.forceRefresh() + + if self.proxy: + result = getProxyInfo(self.proxyLoc) + if not result['OK']: + return result['Message'] + gLogger.notice(formatProxyInfoAsString(result['Value'])) + else: + result = getTokenInfo(self.proxyLoc) + if not result['OK']: + return result['Message'] + gLogger.notice(formatTokenInfoAsString(result['Value'])) + + return S_OK(self.proxyLoc) + + +@DIRACScript() +def main(): + piParams = Params() + piParams.registerCLISwitches() + + Script.disableCS() + Script.parseCommandLine(ignoreErrors=True) + DIRAC.gConfig.setOptionValue("/DIRAC/Security/UseServerCertificate", "False") + + resultDoMagic = piParams.doOAuthMagic() + if not resultDoMagic['OK']: + gLogger.fatal(resultDoMagic['Message']) + sys.exit(1) + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py new file mode 100644 index 00000000000..a4976f3a9f2 --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py @@ -0,0 +1,30 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider + +__RCSID__ = "$Id$" + + +class CheckInIdProvider(OAuth2IdProvider): + + # urn:mace:egi.eu:group:registry:training.egi.eu:role=member#aai.egi.eu' + NAMESPACE = 'urn:mace:egi.eu:group:registry' + SIGN = '#aai.egi.eu' + PARAM_SCOPE = 'eduperson_entitlement?value=' + + def researchGroup(self, payload, token=None): + """ Research group + """ + if token: + self.token = token + claims = self.getUserProfile() + credDict = self.parseBasic(claims) + credDict.update(self.parseEduperson(claims)) + cerdDict = self.userDiscover(credDict) + credDict['provider'] = self.name + + return credDict diff --git a/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py b/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py new file mode 100644 index 00000000000..68d6c22495b --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py @@ -0,0 +1,21 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider +from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata + +__RCSID__ = "$Id$" + + +class DIRACIdProvider(OAuth2IdProvider): + + def fetch_metadata(self, url=None): + """ Fetch metada + """ + self.metadata.update(collectMetadata(self.metadata['issuer'])) + if url: + return self.get(url, withhold_token=True).json() + diff --git a/src/DIRAC/Resources/IdProvider/IAMIdProvider.py b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py new file mode 100644 index 00000000000..a8187e61d82 --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py @@ -0,0 +1,17 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider + +__RCSID__ = "$Id$" + + +class IAMIdProvider(OAuth2IdProvider): + + def researchGroup(self, payload, token): + """ Research group + """ + pass diff --git a/src/DIRAC/Resources/IdProvider/IdProvider.py b/src/DIRAC/Resources/IdProvider/IdProvider.py index a3ea4c52383..4e209f6b14a 100644 --- a/src/DIRAC/Resources/IdProvider/IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/IdProvider.py @@ -4,16 +4,27 @@ from __future__ import division from __future__ import print_function -from DIRAC import gLogger +from DIRAC import gLogger, S_OK __RCSID__ = "$Id$" class IdProvider(object): - def __init__(self, parameters=None): + def __init__(self, *args, **kwargs): + """ C'or + """ self.log = gLogger.getSubLogger(self.__class__.__name__) - self.parameters = parameters + self.parameters = kwargs.get('parameters', {}) + self._initialization() + + def _initialization(self): + """ Initialization """ + pass def setParameters(self, parameters): + """ Set parameters + + :param dict parameters: parameters of the identity Provider + """ self.parameters = parameters diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 1134654f218..65907c5e433 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -10,52 +10,112 @@ from __future__ import division from __future__ import print_function +import jwt as _jwt + from DIRAC import S_OK, S_ERROR, gLogger -from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getInfoAboutProviders +from DIRAC.Core.Utilities import ObjectLoader, ThreadSafe +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderInfo, getSettingsNamesForIdPIssuer +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS __RCSID__ = "$Id$" +gCacheMetadata = ThreadSafe.Synchronizer() + class IdProviderFactory(object): - ############################################################################# def __init__(self): """ Standard constructor """ self.log = gLogger.getSubLogger('IdProviderFactory') + self.cacheMetadata = DictCache() + + @gCacheMetadata + def getMetadata(self, idP): + return self.cacheMetadata.get(idP) + + @gCacheMetadata + def addMetadata(self, idP, data, time=24 * 3600): + if data: + self.cacheMetadata.add(idP, time, data) + + def getIdProviderForToken(self, token): + """ This method returns a IdProvider instance corresponding to the supplied + issuer in a token. + + :param str token: token + + :return: S_OK(IdProvider)/S_ERROR() + """ + data = {} + + # Read token without verification to get issuer + issuer = _jwt.decode(token, options=dict(verify_signature=False))['iss'].strip('/') + + result = getSettingsNamesForIdPIssuer(issuer) + if result['OK']: + return self.getIdProvider(result['Value'][0]) + + _result = getAuthorisationServerMetadata() + if not _result['OK']: + return _result + if issuer == _result['Value'].get('issuer', '').strip('/'): + return self.getIdProvider(DEFAULT_CLIENTS.keys()[0]) - ############################################################################# - def getIdProvider(self, idProvider): - """ This method returns a IdProvider instance corresponding to the supplied name. + return result - :param str idProvider: the name of the Identity Provider + def getIdProvider(self, name, **kwargs): + """ This method returns a IdProvider instance corresponding to the supplied + name. + + :param str name: the name of the Identity Provider :return: S_OK(IdProvider)/S_ERROR() """ - result = getInfoAboutProviders(of='Id', providerName=idProvider, option="all", section="all") + self.log.debug('Search %s configuration..' % name) + pDict = DEFAULT_CLIENTS.get(name, {}) + if pDict: + result = getAuthorisationServerMetadata() + if not result['OK']: + return result + pDict.update(result['Value']) + pDict.update(kwargs) + + result = getProviderInfo(name) if not result['OK']: - return result - pDict = result['Value'] - pDict['ProviderName'] = idProvider + if not pDict: + self.log.error('Failed to read configuration', '%s: %s' % (name, result['Message'])) + return result + gLogger.debug(result['Message']) + else: + pDict.update(result['Value']) + pDict['ProviderName'] = name + pType = pDict['ProviderType'] - self.log.verbose('Creating IdProvider', 'of %s type with the name %s' % (pType, idProvider)) + self.log.verbose('Creating IdProvider of %s type with the name %s' % (pType, name)) subClassName = "%sIdProvider" % (pType) - result = ObjectLoader().loadObject('Resources.IdProvider.%s' % subClassName) + objectLoader = ObjectLoader.ObjectLoader() + result = objectLoader.loadObject('Resources.IdProvider.%s' % subClassName, subClassName) if not result['OK']: self.log.error('Failed to load object', '%s: %s' % (subClassName, result['Message'])) return result pClass = result['Value'] try: - provider = pClass() - provider.setParameters(pDict) + meta = self.getMetadata(name) + if meta: + pDict.update(meta) + provider = pClass(**pDict) + if not meta and hasattr(provider, 'metadata'): + self.addMetadata(name, provider.metadata) except Exception as x: msg = 'IdProviderFactory could not instantiate %s object: %s' % (subClassName, str(x)) self.log.exception() - self.log.error(msg) + self.log.warn(msg) return S_ERROR(msg) return S_OK(provider) diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py new file mode 100644 index 00000000000..8bfca998fbd --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -0,0 +1,442 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import six +import jwt as _jwt +import time +import pprint +import requests +from requests import exceptions +from authlib.jose import JsonWebKey, jwt +from authlib.common.urls import url_decode +from authlib.common.security import generate_token +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.oauth2.rfc6749.parameters import prepare_token_request +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from authlib.oauth2.rfc8628 import DEVICE_CODE_GRANT_TYPE +from authlib.integrations.requests_client import OAuth2Session +from authlib.oidc.discovery.well_known import get_well_known_url +from authlib.oauth2.rfc7636 import create_s256_code_challenge + +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +from DIRAC import S_OK, S_ERROR, gLogger +from DIRAC.Resources.IdProvider.IdProvider import IdProvider +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getVOForGroup, getGroupOption + +__RCSID__ = "$Id$" + +DEFAULT_HEADERS = { + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8' +} + + +def claimParser(claimDict, attributes): + """ Parse claims to write it as DIRAC profile + + :param dict claimDict: claims + :param dict attributes: contain claim and regex to parse it + :param dict profile: to fill parsed data + + :return: dict + """ + profile = {} + result = None + for claim, reg in attributes.items(): + if claim not in claimDict: + continue + profile[claim] = {} + if isinstance(claimDict[claim], dict): + result = claimParser(claimDict[claim], reg) + if result: + profile[claim] = result + elif isinstance(claimDict[claim], six.string_types): + result = re.compile(reg).match(claimDict[claim]) + if result: + for k, v in result.groupdict().items(): + profile[claim][k] = v + else: + profile[claim] = [] + for claimItem in claimDict[claim]: + if isinstance(reg, dict): + result = claimParser(claimItem, reg) + if result: + profile[claim].append(result) + else: + result = re.compile(reg).match(claimItem) + if result: + profile[claim].append(result.groupdict()) + + return profile + + +class OAuth2IdProvider(IdProvider, OAuth2Session): + + def __init__(self, name=None, token_endpoint_auth_method='client_secret_post', revocation_endpoint_auth_method=None, + scope='', token=None, token_placement='header', update_token=None, **parameters): + """ OIDCClient constructor + """ + if 'ProviderName' not in parameters: + parameters['ProviderName'] = name + IdProvider.__init__(self, **parameters) + OAuth2Session.__init__(self, token_endpoint_auth_method=token_endpoint_auth_method, + revocation_endpoint_auth_method=revocation_endpoint_auth_method, + scope=scope, token=token, token_placement=token_placement, + update_token=update_token, **parameters) + self.jwks = parameters.get('jwks') + # Convert scope to list + scope = scope or '' + self.scope = list_to_scope([s.strip() for s in scope.strip().replace('+', ' ').split(',' if ',' in scope else ' ')]) + self.parameters = parameters + self.name = parameters['ProviderName'] + self.verify = False + + self.server_metadata_url = parameters.get('server_metadata_url', get_well_known_url(self.metadata['issuer'], True)) + + self.log.debug('"%s" OAuth2 IdP initialization done:' % self.name, + '\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s' % (self.client_id, + self.client_secret, + pprint.pformat(self.metadata))) + + def verifyToken(self, accessToken): + """ Verify access token + + :param str accessToken: access token + """ + pprint.pprint(self.jwks) + try: + # Try to decode token + gLogger.debug("Try to decode token:", accessToken) + return jwt.decode(accessToken, JsonWebKey.import_key_set(self.jwks)) + except Exception: + # If we have outdated keys, we try to update them from identity provider + gLogger.debug("Try to update %s jwks.." % self.metadata['issuer']) + self.jwks = self.fetch_metadata(self.get_metadata('jwks_uri')) + pprint.pprint(self.jwks) + return jwt.decode(accessToken, JsonWebKey.import_key_set(self.jwks)) + + def update_token(self, token, refresh_token): + pass + + def refreshToken(self, refresh_token): + """ Refresh token + + :param str token: refresh_token + + :return: dict + """ + return self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token) + + def revokeToken(self, token=None, token_type_hint='refresh_token'): + """ Revoke token + + :param str token: token + :param str token_type_hint: token type + """ + self.revoke_token(self.get_metadata('revocation_endpoint'), token=token, token_type_hint=token_type_hint) + + def get_metadata(self, option=None): + """ Get metadata + """ + if not self.metadata.get(option): + self.fetch_metadata() + return self.metadata.get(option) + + def fetch_metadata(self, url=None): + """ Fetch metada + """ + data = self.get(url or self.server_metadata_url, withhold_token=True).json() + self.metadata.update(data) + + def researchGroup(self, payload, token): + """ Research group + """ + credDict = self.parseBasic(payload) + if not credDict.get('group'): + cerdDict = self.userDiscover(credDict) + credDict['provider'] = self.name + return credDict + + def authorization(self, group=None): + """ Authorizaion through DeviceCode flow + """ + result = self.submitDeviceCodeAuthorizationFlow(group) + if not result['OK']: + return result + response = result['Value'] + + # Notify user to go to authorization endpoint + showURL = 'Use next link to continue, your user code is "%s"\n%s' % (response['user_code'], + response['verification_uri']) + gLogger.notice(showURL) + + return self.waitFinalStatusOfDeviceCodeAuthorizationFlow(response['device_code']) + + def submitNewSession(self, pkce=True): + """ Submit new authorization session + + :param bool pkce: use PKCE + + :return: S_OK(str)/S_ERROR() + """ + session = {} + params = dict(state=generate_token(10)) + # Create PKCE verifier + if pkce: + session['code_verifier'] = generate_token(48) + params['code_challenge_method'] = 'S256' + params['code_challenge'] = create_s256_code_challenge(session['code_verifier']) + url, state = self.create_authorization_url(self.get_metadata('authorization_endpoint'), **params) + return url, state, session + + def parseAuthResponse(self, response, session=None): + """ Make user info dict: + + :param dict response: response on request to get user profile + :param object session: session + + :return: S_OK(dict)/S_ERROR() + """ + response = createOAuth2Request(response) + + self.log.debug('Try to parse authentication response:', pprint.pformat(response.data)) + + if not session: + session = {} + + self.log.debug('Current session is:\n', pprint.pformat(session)) + + self.fetchToken(authorization_response=response.uri, code_verifier=session.get('code_verifier')) + # Get user info + claims = self.getUserProfile() + credDict = self.parseBasic(claims) + credDict.update(self.parseEduperson(claims)) + cerdDict = self.userDiscover(credDict) + + self.log.debug('Got response dictionary:\n', pprint.pformat(cerdDict)) + + # Store token + self.token['user_id'] = credDict['ID'] + self.log.debug('Store token to the database:\n', pprint.pformat(dict(self.token))) + + return S_OK(credDict) + + def fetchToken(self, **kwargs): + """ Fetch token + """ + self.fetch_access_token(self.get_metadata('token_endpoint'), **kwargs) + self.token['client_id'] = self.client_id + self.token['provider'] = self.name + return OAuth2Token(self.token) + + def parseBasic(self, claimDict): + """ Parse basic claims + + :param dict claimDict: claims + + :return: S_OK(dict)/S_ERROR() + """ + credDict = {} + credDict['ID'] = claimDict['sub'] + credDict['DN'] = '/O=DIRAC/CN=%s' % credDict['ID'] + credDict['group'] = claimDict.get('group') + return credDict + + def parseEduperson(self, claimDict): + """ Parse eduperson claims + + :return: dict + """ + credDict = {} + attributes = { + 'eduperson_unique_id': '^(?P.*)', + 'eduperson_entitlement': '^(?P[A-z,.,_,-,:]+):(group:registry|group):(?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*' + } + if 'eduperson_entitlement' not in claimDict: + claimDict = self.getUserProfile() + resDict = claimParser(claimDict, attributes) + if not resDict: + return credDict + credDict['ID'] = resDict['eduperson_unique_id']['ID'] + credDict['VOs'] = {} + for voDict in resDict['eduperson_entitlement']: + if voDict['VO'] not in credDict['VOs']: + credDict['VOs'][voDict['VO']] = {'VORoles': []} + if voDict['VORole'] not in credDict['VOs'][voDict['VO']]['VORoles']: + credDict['VOs'][voDict['VO']]['VORoles'].append(voDict['VORole']) + return credDict + + def userDiscover(self, credDict): + credDict['DIRACGroups'] = [] + for vo, voData in credDict.get('VOs', {}).items(): + result = getVOMSRoleGroupMapping(vo) + pprint.pprint(result) + if result['OK']: + for role in voData['VORoles']: + groups = result['Value']['VOMSDIRAC'].get('/%s' % role) + if groups: + credDict['DIRACGroups'] = list(set(credDict['DIRACGroups'] + groups)) + if credDict['DIRACGroups']: + credDict['group'] = credDict['DIRACGroups'][0] + return credDict + + def submitDeviceCodeAuthorizationFlow(self, group=None): + """ Submit authorization flow + + :return: S_OK(dict)/S_ERROR() -- dictionary with device code flow response + """ + groupScopes = [] + if group: + result = self.getGroupScopes(group) + if not result['OK']: + return result + groupScopes = result['Value'] + + try: + r = requests.post(self.get_metadata('device_authorization_endpoint'), data=dict( + client_id=self.client_id, scope=list_to_scope(scope_to_list(self.scope) + groupScopes) + ), verify=self.verify) + r.raise_for_status() + deviceResponse = r.json() + if 'error' in deviceResponse: + return S_ERROR('%s: %s' % (deviceResponse['error'], deviceResponse.get('description', ''))) + + # Check if all main keys are present here + for k in ['user_code', 'device_code', 'verification_uri']: + if not deviceResponse.get(k): + return S_ERROR('Mandatory %s key is absent in authentication response.' % k) + + return S_OK(deviceResponse) + except requests.exceptions.Timeout: + return S_ERROR('Authentication server is not answer, timeout.') + except requests.exceptions.RequestException as ex: + return S_ERROR(repr(ex)) + except Exception as ex: + return S_ERROR('Cannot read authentication response: %s' % repr(ex)) + + def waitFinalStatusOfDeviceCodeAuthorizationFlow(self, deviceCode, interval=5, timeout=300): + """ Submit waiting loop process, that will monitor current authorization session status + + :param str deviceCode: received device code + :param int interval: waiting interval + :param int timeout: max time of waiting + + :return: S_OK(dict)/S_ERROR() - dictionary contain access/refresh token and some metadata + """ + __start = time.time() + + gLogger.notice('Authorization pending.. (use CNTL + C to stop)') + while True: + time.sleep(int(interval)) + if time.time() - __start > timeout: + return S_ERROR('Time out.') + r = requests.post(self.get_metadata('token_endpoint'), data=dict(client_id=self.client_id, + grant_type=DEVICE_CODE_GRANT_TYPE, + device_code=deviceCode), verify=self.verify) + token = r.json() + if not token: + return S_ERROR('Resived token is empty!') + if 'error' not in token: + self.token = token + return S_OK(token) + if token['error'] != 'authorization_pending': + return S_ERROR(token['error'] + ' : ' + token.get('description', '')) + + def getGroupScopes(self, group): + """ Get group scopes + + :param str group: DIRAC group + + :return: list + """ + idPScope = getGroupOption(group, 'IdPScope') + if not idPScope: + return S_ERROR('Cannot find role for %s' % group) + return S_OK(scope_to_list(idPScope)) + + def exchangeGroup(self, group): + """ Get new tokens for group scope + + :param str group: requested group + + :return: dict -- token + """ + result = self.getGroupScopes(group) + if not result['OK']: + return result + groupScopes = result['Value'] + try: + token = self.exchange_token(self.get_metadata('token_endpoint'), subject_token=self.token['access_token'], + subject_token_type='urn:ietf:params:oauth:token-type:access_token', + scope=list_to_scope(scope_to_list(self.scope) + groupScopes)) + if not token: + return S_ERROR('Cannot exchange token with %s group.' % group) + self.token = token + return S_OK(token) + + except Exception as e: + return S_ERROR(repr(e)) + + def getUserProfile(self): + return self.get(self.get_metadata('userinfo_endpoint')).json() + + def exchange_token(self, url, subject_token=None, subject_token_type=None, body='', + refresh_token=None, access_token=None, auth=None, headers=None, **kwargs): + """ Exchange a new access token + + :param url: Exchange Token endpoint, must be HTTPS. + :param str subject_token: subject_token + :param str subject_token_type: token type https://tools.ietf.org/html/rfc8693#section-3 + :param body: Optional application/x-www-form-urlencoded body to add the + include in the token request. Prefer kwargs over body. + :param str refresh_token: refresh token + :param str access_token: access token + :param auth: An auth tuple or method as accepted by requests. + :param headers: Dict to default request headers with. + :return: A :class:`OAuth2Token` object (a dict too). + """ + session_kwargs = self._extract_session_request_params(kwargs) + refresh_token = refresh_token or self.token.get('refresh_token') + access_token = access_token or self.token.get('access_token') + subject_token = subject_token or refresh_token + subject_token_type = subject_token_type or 'urn:ietf:params:oauth:token-type:refresh_token' + if 'scope' not in kwargs and self.scope: + kwargs['scope'] = self.scope + body = prepare_token_request('urn:ietf:params:oauth:grant-type:token-exchange', body, + subject_token=subject_token, subject_token_type=subject_token_type, **kwargs) + + if headers is None: + headers = DEFAULT_HEADERS + + for hook in self.compliance_hook.get('exchange_token_request', []): + url, headers, body = hook(url, headers, body) + + if auth is None: + auth = self.client_auth(self.token_endpoint_auth_method) + + return self._exchange_token(url, refresh_token=refresh_token, body=body, headers=headers, + auth=auth, **session_kwargs) + + def _exchange_token(self, url, body='', refresh_token=None, headers=None, auth=None, **kwargs): + resp = self.session.post(url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs) + + for hook in self.compliance_hook.get('exchange_token_response', []): + resp = hook(resp) + + token = self.parse_response_token(resp.json()) + if 'refresh_token' not in token: + self.token['refresh_token'] = refresh_token + + if callable(self.update_token): + self.update_token(self.token, refresh_token=refresh_token) + + return self.token + + def generateState(self, session=None): + return session or generate_token(10) diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py new file mode 100644 index 00000000000..412689ec9dc --- /dev/null +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -0,0 +1,200 @@ +""" This is a test of the AuthDB + It supposes that the DB is present and installed in DIRAC +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +from authlib.jose import JsonWebKey, JsonWebSignature, jwt +from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads + +from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB + +db = AuthDB() + +payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + +exp_payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()) - 10, + 'exp': int(time.time()) - 10, + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + + +def test_Token(): + """ Try to revoke/save/get tokens + """ + # Remove all tokens + result = db.removeTokens() + assert result['OK'], result['Message'] + + # Get key + result = db.getPrivateKey() + assert result['OK'], result['Message'] + privat_key = result['Value']['key'] + + # Sign token + token = dict(access_token=jwt.encode({'alg': 'RS256'}, payload, privat_key), + expires_in=864000, + token_type='Bearer', + client_id='1hlUgttap3P9oTSXUwpIT50TVHxCflN3O98uHP217Y', + scope='g:checkin-integration_user', + refresh_token=jwt.encode({'alg': 'RS256'}, payload, privat_key)) + # Expired token + exp_token = dict(access_token=jwt.encode({'alg': 'RS256'}, exp_payload, privat_key), + expires_in=864000, + token_type='Bearer', + client_id='1hlUgttap3P9oTSXUwpIT50TVHxCflN3O98uHP217Y', + scope='g:checkin-integration_user', + refresh_token=jwt.encode({'alg': 'RS256'}, exp_payload, privat_key)) + + # Store tokens + result = db.storeToken(token) + assert result['OK'], result['Message'] + result = db.storeToken(token) + assert result['OK'], result['Message'] + + # Check token + result = db.getToken(token['refresh_token']) + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == token['access_token'] + assert result['Value']['refresh_token'] == token['refresh_token'] + assert result['Value']['revoked'] == False + + # Check expired token + result = db.getToken(exp_token['refresh_token']) + assert not result['OK'] + + # Revoke token + result = db.revokeToken(token) + assert result['OK'], result['Message'] + + # Check if token revoked + result = db.getToken(token['refresh_token']) + assert result['OK'], result['Message'] + assert result['Value']['revoked'] == True + + +def test_keys(): + """ Try to store/get/remove keys + """ + # JWS + jws = JsonWebSignature(algorithms=['RS256']) + code_payload = {'user_id': 'user', + 'scope': 'scope', + 'redirect_uri': 'redirect_uri', + 'client_id': 'client', + 'code_challenge': 'code_challenge'} + + # Token metadata + header = {'alg': 'RS256'} + payload = {'sub': 'user', + 'iss': 'issuer', + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + + # Remove all keys + result = db.removeKeys() + assert result['OK'], result['Message'] + + # Check active keys + result = db.getActiveKeys() + assert result['OK'], result['Message'] + assert result['Value'] == [] + + # Create new one + result = db.getPrivateKey() + assert result['OK'], result['Message'] + + # Sign token + header['kid'] = result['Value']['kid'] + private_key = result['Value']['key'] + token = jwt.encode(header, payload, private_key) + # Sign auth code + code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) + + # Get public key set + result = db.getKeySet() + assert result['OK'], result['Message'] + _payload = jwt.decode(token, JsonWebKey.import_key_set(result['Value'].as_dict())) + assert _payload == payload + data = jws.deserialize_compact(code, result['Value'].keys[0]) + _code_payload = json_loads(urlsafe_b64decode(data['payload'])) + assert _code_payload == code_payload + + # Get JWK + result = db.getJWKs() + assert result['OK'], result['Message'] + _payload = jwt.decode(token, JsonWebKey.import_key_set(result['Value'])) + assert _payload == payload, result['Value'] + + +def test_Sessions(): + """ Try to store/get/remove Sessions + """ + # Example of the new session metadata + sData1 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/DIRAC/auth/device', + 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF'} + + # Example of the updated session + sData2 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/DIRAC/auth/device', + 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF', + 'user_id': 'username'} + + # Remove old session + db.removeSession(sData1['id']) + + # Add session + result = db.addSession(sData1) + assert result['OK'], result['Message'] + + # Get session + result = db.getSessionByUserCode(sData1['user_code']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData1['device_code'] + assert result['Value'].get('user_id') != sData2['user_id'] + + # Update session + result = db.updateSession(sData2, sData1['id']) + assert result['OK'], result['Message'] + + # Get session + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData2['device_code'] + assert result['Value']['user_id'] == sData2['user_id'] + + # Remove session + result = db.removeSession(sData2['id']) + assert result['OK'], result['Message'] + + # Make sure that the session is absent + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert not result['Value'] From 8eb3455d00c6b0e5598e2b833d372725789d36af Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 19 May 2021 20:48:00 +0200 Subject: [PATCH 002/178] add authorization server --- .../private/authorization/AuthServer.py | 445 ++++++++++++++++++ .../private/authorization/__init__.py | 0 .../authorization/grants/AuthorizationCode.py | 121 +++++ .../authorization/grants/DeviceFlow.py | 144 ++++++ .../authorization/grants/RefreshToken.py | 44 ++ .../authorization/grants/RevokeToken.py | 39 ++ .../private/authorization/grants/__init__.py | 3 + .../private/authorization/utils/Clients.py | 59 +++ .../private/authorization/utils/Requests.py | 66 +++ .../private/authorization/utils/Tokens.py | 68 +++ .../private/authorization/utils/__init__.py | 3 + 11 files changed, 992 insertions(+) create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/__init__.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py new file mode 100644 index 00000000000..513c5a91178 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -0,0 +1,445 @@ +""" This class provides authorization server activity. """ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +from time import time +import pprint +from dominate import document, tags as dom +from tornado.template import Template + +from authlib.jose import jwt +from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6750 import BearerToken +from authlib.oauth2.rfc7636 import CodeChallenge +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from authlib.oauth2.rfc6749.util import scope_to_list + +from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint +from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, + DeviceCodeGrant) +from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant #, OpenIDCode +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import Client, DEFAULT_CLIENTS +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request + +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB +from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata, isDownloadablePersonalProxy +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getEmailsForGroup, getDNForUsername +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo +from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import getSetup +from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient + +import logging +import sys +log = logging.getLogger('authlib') +log.addHandler(logging.StreamHandler(sys.stdout)) +log.setLevel(logging.DEBUG) +log = gLogger.getSubLogger(__name__) + + +def collectMetadata(issuer=None): + """ Collect metadata """ + result = getAuthorisationServerMetadata() + if not result['OK']: + raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) + metadata = result['Value'] + metadata['jwks_uri'] = metadata['issuer'] + '/jwk' + metadata['token_endpoint'] = metadata['issuer'] + '/token' + metadata['userinfo_endpoint'] = metadata['issuer'] + '/userinfo' + metadata['revocation_endpoint'] = metadata['issuer'] + '/revoke' + metadata['authorization_endpoint'] = metadata['issuer'] + '/authorization' + metadata['device_authorization_endpoint'] = metadata['issuer'] + '/device' + metadata['grant_types_supported'] = ['code', 'authorization_code', 'refresh_token', + 'urn:ietf:params:oauth:grant-type:device_code'] + metadata['response_types_supported'] = ['code', 'device', 'token'] + metadata['code_challenge_methods_supported'] = ['S256'] + return AuthorizationServerMetadata(metadata) + + +class AuthServer(_AuthorizationServer): + """ Implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`. + + Initialize:: + + server = AuthServer() + """ + css = {} + LOCATION = None + + def __init__(self): + self.db = AuthDB() + self.__tokenDB = TokenDB() + self.proxyCli = ProxyManagerClient() + self.idps = IdProviderFactory() + # Privide two authlib methods query_client and save_token + _AuthorizationServer.__init__(self, query_client=self.getClient, save_token=self.saveToken) + self.generate_token = self.generateProxyOrToken + self.bearerToken = BearerToken(self.access_token_generator, self.refresh_token_generator) + self.config = {} + self.metadata = collectMetadata() + self.metadata.validate() + # Register configured grants + self.register_grant(RefreshTokenGrant) + self.register_grant(DeviceCodeGrant) + self.register_endpoint(DeviceAuthorizationEndpoint) + self.register_endpoint(RevocationEndpoint) + self.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)])#, OpenIDCode(require_nonce=False)]) + + def addSession(self, session): + self.db.addSession(session) + + def getSession(self, session): + self.db.getSession(session) + + def saveToken(self, token, request): + """ Store tokens + + :param dict token: tokens + :param object request: Request object + """ + if token.get('refresh_token'): + token['client_id'] = request.client.client_id + result = self.db.storeToken(token) + if not result['OK']: + gLogger.error(result['Message']) + + def getClient(self, clientID): + """ Search authorization client + + :param str clientID: client ID + + :return: object + """ + data = {} + gLogger.debug('Try to query %s client' % clientID) + result = getProvidersForInstance('Id', 'DIRAC') + if not result['OK']: + gLogger.error(result['Message']) + return None + + clients = list(set(result['Value'] + list(DEFAULT_CLIENTS.keys()))) + for client in clients: + data = DEFAULT_CLIENTS.get(client, {}) + result = getProviderInfo(client) + if not result['OK']: + gLogger.debug(result['Message']) + else: + data.update(result['Value']) + if data.get('client_id') and data['client_id'] == clientID: + gLogger.debug('Found client:\n', pprint.pformat(data)) + return Client(data) + + return None + + def __getScope(self, scope, param): + """ Get parameter scope + + :param str scope: scope + :param str param: parameter scope + + :return: str or None + """ + try: + return [s.split(':')[1] for s in scope_to_list(scope) if s.startswith('%s:' % param)][0] + except: + return None + + def generateProxyOrToken(self, client, grant_type, user=None, scope=None, + expires_in=None, include_refresh_token=True): + """ Generate proxy or tokens after authorization + """ + if 'proxy' in scope_to_list(scope): + # Try to return user proxy if proxy scope present in the authorization request + if not isDownloadablePersonalProxy(): + raise Exception("You can't get proxy, configuration settings(downloadablePersonalProxy) not allow to do that.") + + group = self.__getScope(scope, 'g') + lifetime = self.__getScope(scope, 'lifetime') + gLogger.debug('Try to query %s@%s proxy%s' % (user, group, ('with lifetime:%s' % lifetime) if lifetime else '')) + result = getUsernameForDN('/O=DIRAC/CN=%s' % user) + if result['OK']: + result = getDNForUsername(result['Value']) + if not result['OK']: + raise Exception(result['Message']) + userDNs = result['Value'] + err = [] + for dn in userDNs: + gLogger.debug('Try to get proxy for %s' % dn) + if lifetime: + result = self.proxyCli.downloadProxy(dn, group, requiredTimeLeft=int(lifetime)) + else: + result = self.proxyCli.downloadProxy(dn, group) + if not result['OK']: + err.append(result['Message']) + else: + gLogger.info('Proxy was created.') + result = result['Value'].dumpAllToString() + if not result['OK']: + raise Exception(result['Message']) + return {'proxy': result['Value']} + raise Exception('; '.join(err)) + + return self.bearerToken(client, grant_type, user=user, scope=scope, expires_in=expires_in, + include_refresh_token=include_refresh_token) + + def getIdPAuthorization(self, providerName, request): + """ Submit subsession and return dict with authorization url and session number + + :param str providerName: provider name + :param object request: main session request + + :return: S_OK(response)/S_ERROR() -- dictionary contain response generated by `handle_response` + """ + result = self.idps.getIdProvider(providerName) + if not result['OK']: + return result + idpObj = result['Value'] + authURL, state, session = idpObj.submitNewSession() + session['state'] = state + session['Provider'] = providerName + session['mainSession'] = request if isinstance(request, dict) else request.toDict() + + gLogger.verbose('Redirect to', authURL) + return self.handle_response(302, {}, [("Location", authURL)], session) + + def parseIdPAuthorizationResponse(self, response, session): + """ Fill session by user profile, tokens, comment, OIDC authorize status, etc. + Prepare dict with user parameters, if DN is absent there try to get it. + Create new or modify existing DIRAC user and store the session + + :param dict response: authorization response + :param str session: session + + :return: S_OK(dict)/S_ERROR() + """ + providerName = session.pop('Provider') + gLogger.debug('Try to parse authentification response from %s:\n' % providerName, pprint.pformat(response)) + # Parse response + result = self.idps.getIdProvider(providerName) + if not result['OK']: + return result + provObj = result['Value'] + result = provObj.parseAuthResponse(response, session) + if not result['OK']: + return result + + # FINISHING with IdP auth result + credDict = result['Value'] + + result = self.__tokenDB.updateToken(provObj.token, user_id=provObj.token['user_id']) + if not result['OK']: + return result + + gLogger.debug("Read profile:", pprint.pformat(credDict)) + # Is ID registred? + result = getUsernameForDN(credDict['DN']) + if not result['OK']: + comment = '%s ID is not registred in the DIRAC.' % credDict['ID'] + result = self.__registerNewUser(providerName, credDict) + if result['OK']: + comment += ' Administrators have been notified about you.' + else: + comment += ' Please, contact the DIRAC administrators.' + return S_ERROR(comment) + credDict['username'] = result['Value'] + return S_OK(credDict) + + def access_token_generator(self, client, grant_type, user, scope): + """ A function to generate ``access_token`` + + :param object client: Client object + :param str grant_type: grant type + :param str user: user unique id + :param str scope: scope + + :return: str + """ + gLogger.debug('GENERATE DIRAC ACCESS TOKEN for "%s" with "%s" scopes.' % (user, scope)) + return self.signToken({'sub': user, + 'iss': self.metadata['issuer'], + 'iat': int(time()), + 'exp': int(time()) + (self.__getScope(scope, 'lifetime') or (12 * 3600)), + 'scope': scope, + 'setup': getSetup(), + 'group': self.__getScope(scope, 'g')}) + + def refresh_token_generator(self, client, grant_type, user, scope): + """ A function to generate ``refresh_token`` + + :param object client: Client object + :param str grant_type: grant type + :param str user: user unique id + :param str scope: scope + + :return: str + """ + gLogger.debug('GENERATE DIRAC REFRESH TOKEN for "%s" with "%s" scopes.' % (user, scope)) + return self.signToken({'sub': user, + 'iss': self.metadata['issuer'], + # 'iat': int(time()), + 'exp': int(time()) + (24 * 3600)}) + + def signToken(self, payload): + """ Sign token + + :param dict payload: token payload + + ;return: str + """ + result = self.db.getPrivateKey() + if not result['OK']: + raise Exception(result['Message']) + + # Sign token + key = result['Value']['key'] + kid = result['Value']['kid'] + header = {'alg': 'RS256', 'kid': kid} + # Need to use enum==0.3.1 for python 2.7 + return jwt.encode(header, payload, key) + + def get_error_uris(self, request): + error_uris = self.config.get('error_uris') + if error_uris: + return dict(error_uris) + + def create_oauth2_request(self, request, method_cls=OAuth2Request, use_json=False): + gLogger.debug('Create OAuth2 request', 'with json' if use_json else '') + return createOAuth2Request(request, method_cls, use_json) + + def create_json_request(self, request): + return self.create_oauth2_request(request, HttpRequest, True) + + def handle_error_response(self, request, error): + return self.handle_response(*error(translations=self.get_translations(request), + error_uris=self.get_error_uris(request)), error=True) + + def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, error=None, **actions): + gLogger.debug('Handle authorization response with %s status code:' % status_code, payload) + gLogger.debug('Headers:', headers) + if newSession: + gLogger.debug('newSession:', newSession) + return S_OK([[status_code, headers, payload, newSession, error], actions]) + + def create_authorization_response(self, response, username): + result = super(AuthServer, self).create_authorization_response(response, username) + if result['OK']: + # Remove auth session + result['Value'][0][4] = True + return result + + def validate_consent_request(self, request, provider=None): + """ Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization:: + + :param object request: tornado request + :param provider: provider + + :return: response generated by `handle_response` or S_ERROR or html + """ + if request.method != 'GET': + return 'Use GET method to access this endpoint.' + try: + req = self.create_oauth2_request(request) + gLogger.info('Validate consent request for', req.state) + grant = self.get_authorization_grant(req) + gLogger.debug('Use grant:', grant) + grant.validate_consent_request() + if not hasattr(grant, 'prompt'): + grant.prompt = None + + # Check Identity Provider + provider, providerChooser = self.validateIdentityProvider(req, provider) + if not provider: + return providerChooser + + # Submit second auth flow through IdP + return self.getIdPAuthorization(provider, req) + except OAuth2Error as error: + return self.handle_error_response(None, error) + + def validateIdentityProvider(self, request, provider): + """ Check if identity provider registred in DIRAC + + :param object request: request + :param str provider: provider name + + :return: str, S_OK()/S_ERROR() -- provider name and html page to choose it + """ + # Research supported IdPs + result = getProvidersForInstance('Id') + if not result['OK']: + return None, result + idPs = result['Value'] + + # Remove settings of the DIRAC AS + result = getProvidersForInstance('Id', 'DIRAC') + if not result['OK']: + return None, result + for dCli in result['Value']: + if dCli in idPs: + idPs.remove(dCli) + + if not idPs: + return None, S_ERROR('No identity providers found.') + + if not provider: + if len(idPs) == 1: + return idPs[0], None + # Choose IdP interface + doc = document('DIRAC authentication') + with doc.head: + dom.link(rel='stylesheet', + href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css") + dom.style(self.css['CSS']) + with doc: + with dom.div(style=self.css['css_main']): + with dom.div('Choose identity provider', style=self.css['css_align_center']): + for idP in idPs: + # data: Status, Comment, Action + dom.button(dom.a(idP, href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query)), + cls='button') + return None, self.handle_response(payload=Template(doc.render()).generate()) + + # Check IdP + if provider not in idPs: + return None, S_ERROR('%s is not registered in DIRAC.' % provider) + + return provider, None + + def __registerNewUser(self, provider, userProfile): + """ Register new user + + :param str provider: provider + :param dict userProfile: user information dictionary + + :return: S_OK()/S_ERROR() + """ + from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient + + username = userProfile['DN'] + + mail = {} + mail['subject'] = "[SessionManager] User %s to be added." % username + mail['body'] = 'User %s was authenticated by ' % userProfile['FullName'] + mail['body'] += provider + mail['body'] += "\n\nAuto updating of the user database is not allowed." + mail['body'] += " New user %s to be added," % username + mail['body'] += "with the following information:\n" + mail['body'] += "\nUser name: %s\n" % username + mail['body'] += "\nUser profile:\n%s" % pprint.pformat(userProfile) + mail['body'] += "\n\n------" + mail['body'] += "\n This is a notification from the DIRAC AuthManager service, please do not reply.\n" + result = S_OK() + for addresses in getEmailsForGroup('dirac_admin'): + result = NotificationClient().sendMail(addresses, mail['subject'], mail['body'], localAttempt=False) + if not result['OK']: + self.log.error(result['Message']) + if result['OK']: + self.log.info(result['Value'], "administrators have been notified about a new user.") + return result diff --git a/src/DIRAC/FrameworkSystem/private/authorization/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py new file mode 100644 index 00000000000..95d76fdc8d0 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -0,0 +1,121 @@ +""" This class describe Authorization Code grant type +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from time import time +from pprint import pprint +from authlib.jose import JsonWebSignature +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6749.grants import AuthorizationCodeGrant as _AuthorizationCodeGrant +from authlib.oauth2.rfc7636 import CodeChallenge +from authlib.common.encoding import to_unicode, json_dumps, json_b64encode, urlsafe_b64decode, json_loads + +from DIRAC import gLogger, S_OK, S_ERROR + + +class OAuth2Code(dict): + def __init__(self, params): + params['auth_time'] = params.get('auth_time', int(time())) + super(OAuth2Code, self).__init__(params) + + @property + def user(self): + return self.get('user_id') + + @property + def code_challenge(self): + return self.get('code_challenge') + + @property + def code_challenge_method(self): + return self.get('code_challenge_method', 'pain') + + def is_expired(self): + return self.get('auth_time') + 300 < time() + + def get_redirect_uri(self): + return self.get('redirect_uri') + + def get_scope(self): + return self.get('scope', '') + + def get_auth_time(self): + return self.get('auth_time') + + def get_nonce(self): + return self.get('nonce') + + +class AuthorizationCodeGrant(_AuthorizationCodeGrant): + """ See :class:`authlib.oauth2.rfc6749.grants.AuthorizationCodeGrant` """ + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def save_authorization_code(self, code, request): + pass + + def delete_authorization_code(self, authorization_code): + pass + + def query_authorization_code(self, code, client): + """ Parse authorization code + + :param code: authorization code as JWS + :param client: client + + :return: OAuth2Code or None + """ + gLogger.debug('Query authorization code:', code) + jws = JsonWebSignature(algorithms=['RS256']) + result = self.server.db.getKeySet() + if not result['OK']: + raise Exception(result['Message']) + err = None + data = None + for key in result['Value'].keys: + try: + data = jws.deserialize_compact(code, key) + except Exception as e: + err = e + if err: + gLogger.error('Cannot get authorization code:', repr(err)) + return None + try: + item = OAuth2Code(json_loads(urlsafe_b64decode(data['payload']))) + gLogger.debug('Authorization code scope:', item.get_scope()) + except Exception as e: + gLogger.error('Cannot read authorization code:', repr(e)) + return None + if not item.is_expired(): + return item + + def authenticate_user(self, authorization_code): + """ Authenticate the user related to this authorization_code. + + :param authorization_code: authorization code + """ + return authorization_code.user + + def generate_authorization_code(self): + """ The method to generate "code" value for authorization code data. + + :return: str + """ + gLogger.debug('Generate authorization code for credentials:', self.request.user) + pprint(self.request.data) + jws = JsonWebSignature(algorithms=['RS256']) + protected = {'alg': 'RS256'} + code = OAuth2Code({'user_id': self.request.user['ID'], + # These scope already contain DIRAC groups + 'scope': self.request.data['scope'], + 'redirect_uri': self.request.args['redirect_uri'], + 'client_id': self.request.args['client_id'], + 'code_challenge': self.request.args.get('code_challenge'), + 'code_challenge_method': self.request.args.get('code_challenge_method')}) + gLogger.debug('Authorization code generated:', dict(code)) + result = self.server.db.getPrivateKey() + if not result['OK']: + raise OAuth2Error('Cannot check authorization code: %s' % result['Message']) + key = result['Value']['key'] + return jws.serialize_compact(protected, json_b64encode(dict(code)), key) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py new file mode 100644 index 00000000000..e403de89853 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -0,0 +1,144 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749.grants import AuthorizationEndpointMixin +from authlib.oauth2.rfc6749.errors import InvalidClientError, UnauthorizedClientError +from authlib.oauth2.rfc8628 import (DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint, + DeviceCodeGrant as _DeviceCodeGrant, + DeviceCredentialDict) + +from DIRAC import gLogger +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata + +log = gLogger.getSubLogger(__name__) + + +class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): + """ See :class:`authlib.oauth2.rfc8628.DeviceAuthorizationEndpoint` """ + + def create_endpoint_response(self, req): + """ See :func:`authlib.oauth2.rfc8628.DeviceAuthorizationEndpoint.create_endpoint_response` """ + # Share original request object to endpoint class before create_endpoint_response + self.req = req + return super(DeviceAuthorizationEndpoint, self).create_endpoint_response(req) + + def get_verification_uri(self): + """ Create verification uri when `DeviceCode` flow initialized + + :return: str + """ + result = getAuthorisationServerMetadata() + if not result['OK']: + raise OAuth2Error('Cannot prepare authorization server metadata. %s' % result['Message']) + return result['Value']['issuer'] + '/device' + + def save_device_credential(self, client_id, scope, data): + """ Save device credentials + + :param str client_id: client id + :param str scope: request scopes + :param dict data: device credentials + """ + data.update(dict(uri='{api}?{query}&response_type=device&client_id={client_id}&scope={scope}'.format( + api=data['verification_uri'], query=self.req.query, client_id=client_id, scope=scope, + ), id=data['device_code'], client_id=client_id, scope=scope)) + result = self.server.db.addSession(data) + if not result['OK']: + raise OAuth2Error('Cannot save device credentials', result['Message']) + + +class DeviceCodeGrant(_DeviceCodeGrant, AuthorizationEndpointMixin): + """ See :class:`authlib.oauth2.rfc8628.DeviceCodeGrant` """ + RESPONSE_TYPES = {'device'} + + def validate_authorization_request(self): + """ Validate authorization request + + :return: None + """ + # Validate client for this request + client_id = self.request.client_id + log.debug('Validate authorization request of', client_id) + if client_id is None: + raise InvalidClientError(state=self.request.state) + client = self.server.query_client(client_id) + if not client: + raise InvalidClientError(state=self.request.state) + response_type = self.request.response_type + if not client.check_response_type(response_type): + raise UnauthorizedClientError('The client is not authorized to use "response_type={}"'.format(response_type)) + self.request.client = client + self.validate_requested_scope() + + # Check user_code, when user go to authorization endpoint + userCode = self.request.args.get('user_code') + if not userCode: + raise OAuth2Error('user_code is absent.') + + # Get session from cookie + if not self.server.db.getSessionByUserCode(userCode): + raise OAuth2Error('Session with %s user code is expired.' % userCode) + return None + + def create_authorization_response(self, redirect_uri, user): + """ Mark session as authed with received user + + :param str redirect_uri: redirect uri + :param dict user: dictionary with username and userID + + :return: result of `handle_response` + """ + result = self.server.db.getSessionByUserCode(self.request.data['user_code']) + if not result['OK']: + raise OAuth2Error(result['Message']) + data = result['Value'] + data.update(dict(user_id=user['ID'], uri=self.request.uri, + username=user['username'], scope=self.request.scope)) + # Save session with user + result = self.server.db.updateSession(data, data['id']) + if not result['OK']: + raise OAuth2Error('Cannot save authorization result', result['Message']) + return 200, 'Authorization complite.' + + def query_device_credential(self, device_code): + """ Get device credential from previously savings via ``DeviceAuthorizationEndpoint``. + + :param str device_code: device code + + :return: dict + """ + result = self.server.db.getSession(device_code) + if not result['OK']: + raise OAuth2Error(result['Message']) + data = result['Value'] + if not data: + return None + result = getAuthorisationServerMetadata() + if not result['OK']: + raise OAuth2Error('Cannot prepare authorization server metadata. %s' % result['Message']) + data['verification_uri'] = result['Value']['issuer'] + '/device' + data['expires_at'] = int(data['expires_in']) + int(time.time()) + data['interval'] = DeviceAuthorizationEndpoint.INTERVAL + return DeviceCredentialDict(data) + + def query_user_grant(self, user_code): + """ Check if user alredy authed and return it to token generator + + :param str user_code: user code + + :return: str, bool -- user dict and user auth status + """ + result = self.server.db.getSessionByUserCode(user_code) + if not result['OK']: + raise OAuth2Error('Cannot found authorization session', result['Message']) + data = result['Value'] + return (data['user_id'], True) if data.get('username') != "None" else None + + def should_slow_down(self, credential, now): + """ The authorization request is still pending and polling should continue, + but the interval MUST be increased by 5 seconds for this and all subsequent requests. + """ + return False diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py new file mode 100644 index 00000000000..40283cd4745 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant + +from DIRAC import gLogger + + +class RefreshTokenGrant(_RefreshTokenGrant): + """ See :class:`authlib.oauth2.rfc6749.grants.RefreshTokenGrant` """ + + def authenticate_refresh_token(self, refresh_token): + """ Get credential for token + + :param str refresh_token: refresh token + + :return: dict or None + """ + # Check auth session + result = self.server.db.getToken(refresh_token) + if not result['OK']: + raise OAuth2Error('Cannot get token', result['Message']) + token = result['Value'] + return None if token.revoked else token + + def authenticate_user(self, credential): + """ Authorize user + + :param object credential: credential + + :return: str + """ + return credential.sub + + def revoke_old_credential(self, credential): + """ Remove old credential + + :param object credential: credential + """ + result = self.server.db.revokeToken(credential) + if not result['OK']: + gLogger.error(result['Message']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py new file mode 100644 index 00000000000..ed1afc8a81d --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py @@ -0,0 +1,39 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint + +from DIRAC import gLogger + + +class RevocationEndpoint(_RevocationEndpoint): + """ See :class:`authlib.oauth2.rfc7009.RevocationEndpoint` """ + + def query_token(self, token, token_type_hint, client): + """ Query requested token from database. + + :param str token: token + :param str token_type_hint: token type + :param client: client + + :return: str + """ + result = self.server.db.getToken(token, token_type_hint) + if not result['OK']: + gLogger.error(result['Message']) + return None + rv = result['Value'] + client_id = client.get_client_id() + if rv and rv.client_id == client_id: + return rv + return None + + def revoke_token(self, token): + """ Mark the give token as revoked. + + :param dict token: token dict + """ + result = self.server.db.revokeToken(token) + if not result['OK']: + gLogger.error(result['Message']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py new file mode 100644 index 00000000000..878841c1840 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py new file mode 100644 index 00000000000..0e067614eb8 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -0,0 +1,59 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import json +import time + +from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope + +from DIRAC import gLogger + +__RCSID__ = "$Id$" + +DEFAULT_SCOPE = 'proxy g: lifetime:' + +DEFAULT_CLIENTS = { + 'DIRACCLI': dict( + ProviderType = 'DIRAC', + client_id='DIRAC_CLI', + response_types=['device'], + grant_types=['urn:ietf:params:oauth:grant-type:device_code'] + ), + 'WebAppDIRAC': dict( + ProviderType = 'DIRAC', + token_endpoint_auth_method='client_secret_basic', + response_types=['code'], + grant_types=['authorization_code', 'refresh_token'] + ) +} + +class Client(OAuth2ClientMixin): + def __init__(self, params): + + super(Client, self).__init__() + client_metadata = params.get('client_metadata', params) + client_metadata['scope'] = ' '.join([client_metadata.get('scope', ''), DEFAULT_SCOPE]) + if params.get('redirect_uri') and not client_metadata.get('redirect_uris'): + client_metadata['redirect_uris'] = [params['redirect_uri']] + self.client_id = params['client_id'] + self.client_secret = params.get('client_secret', '') + self.client_id_issued_at = params.get('client_id_issued_at', int(time.time())) + self.client_secret_expires_at = params.get('client_secret_expires_at', 0) + if isinstance(client_metadata, dict): + self._client_metadata = json.dumps(client_metadata) + else: + self._client_metadata = client_metadata + + def get_allowed_scope(self, scope): + if not isinstance(scope, six.string_types): + scope = list_to_scope(scope) + allowed = scope_to_list(super(Client, self).get_allowed_scope(scope)) + for s in scope_to_list(scope): + for def_scope in scope_to_list(DEFAULT_SCOPE): + if s.startswith(def_scope) and s not in allowed: + allowed.append(s) + gLogger.debug('Try to allow "%s" scope:' % scope, allowed) + return list_to_scope(list(set(allowed))) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py new file mode 100644 index 00000000000..e5f38473fa8 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -0,0 +1,66 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tornado.escape import json_decode +from authlib.common.encoding import to_unicode +from authlib.oauth2 import OAuth2Request as _OAuth2Request +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope + +__RCSID__ = "$Id$" + + +class OAuth2Request(_OAuth2Request): + """ OAuth request object """ + + def addScopes(self, scopes): + """ Add new scopes to query + + :param list scopes: scopes + """ + # Remove "scope" argument from uri + self.uri = re.sub(r"&scope(=[^&]*)?|^scope(=[^&]*)?&?", "", self.uri) + # Add "scope" argument to uri with new scopes + self.uri += "&scope=%s" % '+'.join(list(set(scope_to_list(self.scope or '') + scopes))) or '' + # Reinit all attributes with new uri + self.__init__(self.method, to_unicode(self.uri)) + + @property + def groups(self): + """ Serarch DIRAC groups in scopes + + :return: list + """ + return [s.split(':')[1] for s in scope_to_list(self.scope) if s.startswith('g:')] + + def toDict(self): + """ Convert class to dictionary + + :return: dict + """ + return {'method': self.method, 'uri': self.uri} + + +def createOAuth2Request(request, method_cls=OAuth2Request, use_json=False): + """ Create request object + + :param request: request + :type request: object, dict + :param object method_cls: returned class + :param str use_json: if data is json + + :return: object -- `OAuth2Request` + """ + if isinstance(request, method_cls): + return request + if isinstance(request, dict): + return method_cls(request['method'], request['uri'], request.get('body'), request.get('headers')) + if use_json: + body = json_decode(request.body) + else: + body = {} + for k, v in request.body_arguments.items(): + body[k] = ' '.join(v) + return method_cls(request.method, request.full_url(), body, request.headers) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py new file mode 100644 index 00000000000..6be7777e937 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from time import time +import functools +from contextlib import contextmanager + +from authlib.jose import jwt +from authlib.oauth2 import OAuth2Error, ResourceProtector as _ResourceProtector +from authlib.oauth2.rfc6749 import MissingAuthorizationError, HttpRequest +from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator +from authlib.oauth2.rfc6749.wrappers import OAuth2Token as _OAuth2Token +from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope + + +class OAuth2Token(_OAuth2Token, OAuth2TokenMixin): + """ Implementation a Token object """ + + def __init__(self, params=None, **kwargs): + kwargs.update(params or {}) + kwargs['revoked'] = False if kwargs.get('revoked', 'False') == 'False' else True + self.sub = kwargs.get('sub') + self.issuer = kwargs.get('iss') + self.client_id = kwargs.get('client_id', kwargs.get('aud')) + self.token_type = kwargs.get('token_type') + self.access_token = kwargs.get('access_token') + self.refresh_token = kwargs.get('refresh_token') + self.scope = kwargs.get('scope') + self.revoked = kwargs.get('revoked') + self.issued_at = int(kwargs.get('issued_at', kwargs.get('iat', time()))) + self.expires_in = int(kwargs.get('expires_in', 0)) + self.expires_at = int(kwargs.get('expires_at', kwargs.get('exp', 0))) + if not self.issued_at: + raise Exception('Missing "iat" in token.') + if not self.expires_at: + if not self.expires_in: + raise Exception('Cannot calculate token "expires_at".') + self.expires_at = self.issued_at + self.expires_in + if not self.expires_in: + self.expires_in = self.expires_at - self.issued_at + kwargs.update({'client_id': self.client_id, + 'token_type': self.token_type, + 'access_token': self.access_token, + 'refresh_token': self.refresh_token, + 'scope': self.scope, + 'revoked': self.revoked, + 'issued_at': self.issued_at, + 'expires_in': self.expires_in, + 'expires_at': self.expires_at}) + super(OAuth2Token, self).__init__(kwargs) + + @property + def scopes(self): + """ Get tokens scopes + + :return: list + """ + return scope_to_list(self.scope) or [] + + @property + def groups(self): + """ Get tokens groups + + :return: list + """ + return [s.split(':')[1] for s in self.scopes if s.startswith('g:')] diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py new file mode 100644 index 00000000000..878841c1840 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function From a0946bb4f9a8514785c5df09b1baead70b7d4a96 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 19 May 2021 21:43:55 +0200 Subject: [PATCH 003/178] fix tests --- environment-py3.yml | 4 +++ environment.yml | 4 +++ .../ConfigurationSystem/Client/Utilities.py | 1 + .../Core/Tornado/Server/BaseRequestHandler.py | 13 ++++---- .../Core/Tornado/scripts/tornado_start_web.py | 2 +- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 14 ++++----- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 30 +++++++++---------- .../private/authorization/AuthServer.py | 29 +++++++++--------- .../authorization/grants/AuthorizationCode.py | 4 +-- .../authorization/grants/DeviceFlow.py | 2 +- .../authorization/grants/RevokeToken.py | 2 +- .../private/authorization/utils/Clients.py | 5 ++-- .../FrameworkSystem/scripts/dirac_login.py | 12 ++++---- .../Resources/IdProvider/CheckInIdProvider.py | 2 +- .../Resources/IdProvider/DIRACIdProvider.py | 1 - .../Resources/IdProvider/OAuth2IdProvider.py | 15 +++++----- tests/Integration/Framework/Test_AuthDB.py | 14 ++++----- 17 files changed, 83 insertions(+), 71 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index 169d35228c7..6ee6ea34490 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -88,6 +88,10 @@ dependencies: #- tornado >=5.0.0,<6.0.0 - typing >=3.6.6 - pyyaml + # OAuth2 + - authlib == 1.0.0 + - pyjwt + - dominate - pip: # This is a fork of tornado with a patch to allow for configurable iostream # It should eventually be part of DIRACGrid diff --git a/environment.yml b/environment.yml index 3413b74c980..03decc4470b 100644 --- a/environment.yml +++ b/environment.yml @@ -72,6 +72,10 @@ dependencies: # Pin OpenSSL to avoid: https://github.com/DIRACGrid/DIRAC/issues/4489 - openssl <1.1 - selectors2 + # OAuth2 + - authlib == 1.0.0 + - pyjwt + - dominate - pip: - diraccfg # This is a fork of tornado with a patch to allow for configurable iostream diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index 7e285c77cb6..fe3b4e51d38 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -20,6 +20,7 @@ from DIRAC.ConfigurationSystem.Client.PathFinder import getDatabaseSection from DIRAC.Core.Utilities.Glue2 import getGlue2CEInfo from DIRAC.Core.Utilities.SiteSEMapping import getSEHosts +from DIRAC.ConfigurationSystem.Client.PathFinder import getSystemInstance from DIRAC.DataManagementSystem.Utilities.DMSHelpers import DMSHelpers diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py index 4ed160480d6..611f31a0529 100644 --- a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -196,10 +196,10 @@ def __initializeService(cls, request): cls._serviceInfoDict = serviceInfo cls.__monitorLastStatsUpdate = time.time() - + # Some pre-initialization cls._initializeHandler() - + cls.initializeHandler(serviceInfo) cls.__init_done = True @@ -468,8 +468,9 @@ def _finishFuture(self, retVal): # in a thread anymore # Is it S_OK or S_ERROR - if isinstance(self.result, dict) and isinstance(self.result.get('OK'), bool) and ('Value' if self.result['OK'] else 'Message') in self.result: - self._parseDIRACResult(self.result) + if isinstance(self.result, dict): + if isinstance(self.result.get('OK'), bool) and ('Value' if self.result['OK'] else 'Message') in self.result: + self._parseDIRACResult(self.result) # If set to true, do not JEncode the return of the RPC call # This is basically only used for file download through @@ -489,7 +490,7 @@ def _finishFuture(self, retVal): self.write(encode(self.result)) self.finish() - + def _parseDIRACResult(self, result): """ Processing of a standard DIRAC result, but in a separate method so that it can be modified for another class if necessary @@ -614,7 +615,7 @@ def _authzVISITOR(self): :return: S_OK(dict) """ return S_OK({}) - + @property def log(self): return sLog diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py index 3a8d37212ab..d9023de29d6 100644 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py @@ -49,7 +49,7 @@ def main(): except ImportError as e: gLogger.fatal('Web portal is not installed. %s' % repr(e)) sys.exit(1) - + # Get routes and settings for a portal result = App().getAppToDict(8000) if not result['OK']: diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index ff73f0a84be..ea98acf9ba1 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -85,7 +85,7 @@ def initializeHandler(cls, serviceInfo): cls.server = AuthServer() cls.server.css = dict(CSS=cls.CSS, css_align_center=cls.css_align_center, css_main=cls.css_main) cls.server.LOCATION = cls.LOCATION - + def initializeRequest(self): """ Called at every request """ self.currentPath = self.request.protocol + "://" + self.request.host + self.request.path @@ -102,7 +102,7 @@ def _parseDIRACResult(self, result): if not result['OK']: # If response error is DIRAC server error, not OAuth2 flow error self.removeSession() - self.set_status = 400 + self.set_status(400) self.write({'error': 'server_error', 'description': '%s:\n%s' % (result['Message'], '\n'.join(result['CallStack']))}) else: @@ -121,7 +121,7 @@ def _parseDIRACResult(self, result): self.removeSession() for method, args_kwargs in result['Value'][1].items(): eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) - + def saveSession(self, session): """ Save session to cookie @@ -132,7 +132,7 @@ def saveSession(self, session): def removeSession(self): """ Remove session from cookie """ self.clear_cookie('auth_session') - + def getSession(self, state=None, **kw): """ Get session from cookie @@ -346,7 +346,7 @@ def web_device(self, provider=None): # Get original request from session req = createOAuth2Request(dict(method='GET', uri=session['uri'])) - groups = [s.split(':')[1] for s in scope_to_list(req.scope) if s.startswith('g:')] + groups = [s.split(':')[1] for s in scope_to_list(req.scope) if s.startswith('g:')] # pylint: disable=no-member group = groups[0] if groups else None if group and not provider: @@ -355,7 +355,7 @@ def web_device(self, provider=None): print('Use provider:', provider) - authURL = '%s/authorization/%s?%s&user_code=%s' % (self.LOCATION, provider, req.query, userCode) + authURL = '%s/authorization/%s?%s&user_code=%s' % (self.LOCATION, provider, req.query, userCode) # pylint: disable=no-member # Save session to cookie return self.server.handle_response(302, {}, [("Location", authURL)], session) @@ -435,7 +435,7 @@ def web_redirect(self): # Parse result of the second authentication flow self.log.info('%s session, parsing authorization response:\n' % state, '\n'.join([self.request.uri, self.request.query, self.request.body, str(self.request.headers)])) - + result = self.server.parseIdPAuthorizationResponse(self.request, sessionWithExtIdP) if not result['OK']: return result diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 911172aeb22..3dde68ea1fa 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -172,7 +172,7 @@ def storeToken(self, token): def removeTokens(self): """ Get active keys - + :return: S_OK(list)/S_ERROR() """ session = self.session() @@ -184,14 +184,14 @@ def removeTokens(self): def generateRSAKeys(self): """ Generate an RSA keypair with an exponent of 65537 in PEM format - + :return: S_OK/S_ERROR """ key = RSAKey.generate_key(key_size=1024, is_private=True) dictKey = dict(key=json.dumps(key.as_dict()), - expires_at=time() + (30 * 24 *3600), + expires_at=time() + (30 * 24 * 3600), kid=KeySet([key]).as_dict()['keys'][0]['kid']) - + session = self.session() try: session.add(JWK(**dictKey)) @@ -201,7 +201,7 @@ def generateRSAKeys(self): def getKeySet(self): """ Get key set - + :return: S_OK(obj)/S_ERROR() """ keys = [] @@ -216,10 +216,10 @@ def getKeySet(self): key = RSAKey.import_key(json.loads(keyDict['key'])) keys.append(key) return S_OK(KeySet(keys)) - + def getJWKs(self): """ Get JWKs list - + :return: S_OK(dict)/S_ERROR() """ keys = [] @@ -229,10 +229,10 @@ def getJWKs(self): for k in result['Value'].as_dict()['keys']: keys.append({'n': k['n'], "kty": k['kty'], "e": k['e'], "kid": k['kid']}) return S_OK({'keys': keys}) - + def getPrivateKey(self): """ Get private key - + :return: S_OK(obj)/S_ERROR() """ result = self.getActiveKeys() @@ -251,7 +251,7 @@ def getPrivateKey(self): def getActiveKeys(self): """ Get active keys - + :return: S_OK(list)/S_ERROR() """ session = self.session() @@ -264,10 +264,10 @@ def getActiveKeys(self): except Exception as e: return self.__result(session, S_ERROR(str(e))) return self.__result(session, S_OK([self.__rowToDict(jwk) for jwk in jwks])) - + def removeKeys(self): """ Get active keys - + :return: S_OK(list)/S_ERROR() """ session = self.session() @@ -344,7 +344,7 @@ def getSession(self, sessionID): except Exception as e: return self.__result(session, S_ERROR(str(e))) return self.__result(session, S_OK(self.__rowToDict(resData))) - + def getSessionByUserCode(self, userCode): """ Get client @@ -356,9 +356,9 @@ def getSessionByUserCode(self, userCode): try: resData = session.query(AuthSession).filter(AuthSession.user_code == userCode).first() except MultipleResultsFound: - return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) + return self.__result(session, S_ERROR("%s is not unique ID." % userCode)) except NoResultFound: - return self.__result(session, S_ERROR("%s session is expired." % sessionID)) + return self.__result(session, S_ERROR("Session for %s user code is expired." % userCode)) except Exception as e: return self.__result(session, S_ERROR(str(e))) return self.__result(session, S_OK(self.__rowToDict(resData))) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 513c5a91178..caff2b1f701 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -21,13 +21,12 @@ from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, DeviceCodeGrant) -from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant #, OpenIDCode +from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant from DIRAC.FrameworkSystem.private.authorization.utils.Clients import Client, DEFAULT_CLIENTS from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request from DIRAC import gLogger, S_OK, S_ERROR from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB -from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata, isDownloadablePersonalProxy from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getEmailsForGroup, getDNForUsername @@ -74,7 +73,7 @@ class AuthServer(_AuthorizationServer): def __init__(self): self.db = AuthDB() - self.__tokenDB = TokenDB() + # self.__tokenDB = TokenDB() self.proxyCli = ProxyManagerClient() self.idps = IdProviderFactory() # Privide two authlib methods query_client and save_token @@ -89,11 +88,11 @@ def __init__(self): self.register_grant(DeviceCodeGrant) self.register_endpoint(DeviceAuthorizationEndpoint) self.register_endpoint(RevocationEndpoint) - self.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)])#, OpenIDCode(require_nonce=False)]) + self.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)]) def addSession(self, session): self.db.addSession(session) - + def getSession(self, session): self.db.getSession(session) @@ -108,7 +107,7 @@ def saveToken(self, token, request): result = self.db.storeToken(token) if not result['OK']: gLogger.error(result['Message']) - + def getClient(self, clientID): """ Search authorization client @@ -147,7 +146,7 @@ def __getScope(self, scope, param): """ try: return [s.split(':')[1] for s in scope_to_list(scope) if s.startswith('%s:' % param)][0] - except: + except Exception: return None def generateProxyOrToken(self, client, grant_type, user=None, scope=None, @@ -228,13 +227,15 @@ def parseIdPAuthorizationResponse(self, response, session): result = provObj.parseAuthResponse(response, session) if not result['OK']: return result - + # FINISHING with IdP auth result credDict = result['Value'] - result = self.__tokenDB.updateToken(provObj.token, user_id=provObj.token['user_id']) - if not result['OK']: - return result + + # ########### TODO: This place will store the original tokens ############ # + # result = self.__tokenDB.updateToken(provObj.token, user_id=provObj.token['user_id']) + # if not result['OK']: + # return result gLogger.debug("Read profile:", pprint.pformat(credDict)) # Is ID registred? @@ -352,7 +353,7 @@ def validate_consent_request(self, request, provider=None): grant.validate_consent_request() if not hasattr(grant, 'prompt'): grant.prompt = None - + # Check Identity Provider provider, providerChooser = self.validateIdentityProvider(req, provider) if not provider: @@ -403,7 +404,7 @@ def validateIdentityProvider(self, request, provider): for idP in idPs: # data: Status, Comment, Action dom.button(dom.a(idP, href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query)), - cls='button') + cls='button') return None, self.handle_response(payload=Template(doc.render()).generate()) # Check IdP @@ -411,7 +412,7 @@ def validateIdentityProvider(self, request, provider): return None, S_ERROR('%s is not registered in DIRAC.' % provider) return provider, None - + def __registerNewUser(self, provider, userProfile): """ Register new user diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py index 95d76fdc8d0..df1fe9d0703 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -92,14 +92,14 @@ def query_authorization_code(self, code, client): def authenticate_user(self, authorization_code): """ Authenticate the user related to this authorization_code. - + :param authorization_code: authorization code """ return authorization_code.user def generate_authorization_code(self): """ The method to generate "code" value for authorization code data. - + :return: str """ gLogger.debug('Generate authorization code for credentials:', self.request.user) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index e403de89853..16651f2e513 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -56,7 +56,7 @@ class DeviceCodeGrant(_DeviceCodeGrant, AuthorizationEndpointMixin): def validate_authorization_request(self): """ Validate authorization request - + :return: None """ # Validate client for this request diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py index ed1afc8a81d..3e8980acaff 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py @@ -12,7 +12,7 @@ class RevocationEndpoint(_RevocationEndpoint): def query_token(self, token, token_type_hint, client): """ Query requested token from database. - + :param str token: token :param str token_type_hint: token type :param client: client diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index 0e067614eb8..2a5c2f661a9 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -17,19 +17,20 @@ DEFAULT_CLIENTS = { 'DIRACCLI': dict( - ProviderType = 'DIRAC', + ProviderType='DIRAC', client_id='DIRAC_CLI', response_types=['device'], grant_types=['urn:ietf:params:oauth:grant-type:device_code'] ), 'WebAppDIRAC': dict( - ProviderType = 'DIRAC', + ProviderType='DIRAC', token_endpoint_auth_method='client_secret_basic', response_types=['code'], grant_types=['authorization_code', 'refresh_token'] ) } + class Client(OAuth2ClientMixin): def __init__(self, params): diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index ab2851d2eb0..359459ad120 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -49,7 +49,7 @@ def returnProxy(self, _arg): """ self.proxy = True return S_OK() - + def setGroup(self, arg): """ Set email @@ -59,7 +59,7 @@ def setGroup(self, arg): """ self.group = arg return S_OK() - + def setProvider(self, arg): """ Set email @@ -69,7 +69,7 @@ def setProvider(self, arg): """ self.provider = arg return S_OK() - + def setIssuer(self, arg): """ Set email @@ -79,7 +79,7 @@ def setIssuer(self, arg): """ self.issuer = arg return S_OK() - + def setLivetime(self, arg): """ Set email @@ -136,7 +136,7 @@ def doOAuthMagic(self): idpObj.scope += '+proxy' if self.lifetime: idpObj.scope += '+lifetime:%s' % (int(self.lifetime) * 3600) - + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # Submit Device authorisation flow @@ -146,7 +146,7 @@ def doOAuthMagic(self): return S_ERROR(repr(e)) if not result['OK']: return result - + if self.proxy: result = writeToProxyFile(idpObj.token['proxy'].encode("UTF-8"), self.proxyLoc) if not result['OK']: diff --git a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py index a4976f3a9f2..2eb0230a7e8 100644 --- a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py @@ -26,5 +26,5 @@ def researchGroup(self, payload, token=None): credDict.update(self.parseEduperson(claims)) cerdDict = self.userDiscover(credDict) credDict['provider'] = self.name - + return credDict diff --git a/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py b/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py index 68d6c22495b..135f273fb13 100644 --- a/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py @@ -18,4 +18,3 @@ def fetch_metadata(self, url=None): self.metadata.update(collectMetadata(self.metadata['issuer'])) if url: return self.get(url, withhold_token=True).json() - diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 8bfca998fbd..17edfc25a25 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -120,7 +120,7 @@ def verifyToken(self, accessToken): self.jwks = self.fetch_metadata(self.get_metadata('jwks_uri')) pprint.pprint(self.jwks) return jwt.decode(accessToken, JsonWebKey.import_key_set(self.jwks)) - + def update_token(self, token, refresh_token): pass @@ -152,6 +152,8 @@ def fetch_metadata(self, url=None): """ Fetch metada """ data = self.get(url or self.server_metadata_url, withhold_token=True).json() + if url: + return data self.metadata.update(data) def researchGroup(self, payload, token): @@ -211,7 +213,7 @@ def parseAuthResponse(self, response, session=None): session = {} self.log.debug('Current session is:\n', pprint.pformat(session)) - + self.fetchToken(authorization_response=response.uri, code_verifier=session.get('code_verifier')) # Get user info claims = self.getUserProfile() @@ -256,7 +258,8 @@ def parseEduperson(self, claimDict): credDict = {} attributes = { 'eduperson_unique_id': '^(?P.*)', - 'eduperson_entitlement': '^(?P[A-z,.,_,-,:]+):(group:registry|group):(?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*' + 'eduperson_entitlement': '%s:%s' % ('^(?P[A-z,.,_,-,:]+):(group:registry|group)', + '(?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*') } if 'eduperson_entitlement' not in claimDict: claimDict = self.getUserProfile() @@ -300,7 +303,7 @@ def submitDeviceCodeAuthorizationFlow(self, group=None): try: r = requests.post(self.get_metadata('device_authorization_endpoint'), data=dict( - client_id=self.client_id, scope=list_to_scope(scope_to_list(self.scope) + groupScopes) + client_id=self.client_id, scope=list_to_scope(scope_to_list(self.scope) + groupScopes) ), verify=self.verify) r.raise_for_status() deviceResponse = r.json() @@ -337,8 +340,7 @@ def waitFinalStatusOfDeviceCodeAuthorizationFlow(self, deviceCode, interval=5, t if time.time() - __start > timeout: return S_ERROR('Time out.') r = requests.post(self.get_metadata('token_endpoint'), data=dict(client_id=self.client_id, - grant_type=DEVICE_CODE_GRANT_TYPE, - device_code=deviceCode), verify=self.verify) + grant_type=DEVICE_CODE_GRANT_TYPE, device_code=deviceCode), verify=self.verify) token = r.json() if not token: return S_ERROR('Resived token is empty!') @@ -379,7 +381,6 @@ def exchangeGroup(self, group): return S_ERROR('Cannot exchange token with %s group.' % group) self.token = token return S_OK(token) - except Exception as e: return S_ERROR(repr(e)) diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 412689ec9dc..82affc82470 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -68,7 +68,7 @@ def test_Token(): assert result['OK'], result['Message'] assert result['Value']['access_token'] == token['access_token'] assert result['Value']['refresh_token'] == token['refresh_token'] - assert result['Value']['revoked'] == False + assert not result['Value']['revoked'] # Check expired token result = db.getToken(exp_token['refresh_token']) @@ -81,7 +81,7 @@ def test_Token(): # Check if token revoked result = db.getToken(token['refresh_token']) assert result['OK'], result['Message'] - assert result['Value']['revoked'] == True + assert result['Value']['revoked'] def test_keys(): @@ -90,10 +90,10 @@ def test_keys(): # JWS jws = JsonWebSignature(algorithms=['RS256']) code_payload = {'user_id': 'user', - 'scope': 'scope', - 'redirect_uri': 'redirect_uri', - 'client_id': 'client', - 'code_challenge': 'code_challenge'} + 'scope': 'scope', + 'redirect_uri': 'redirect_uri', + 'client_id': 'client', + 'code_challenge': 'code_challenge'} # Token metadata header = {'alg': 'RS256'} @@ -153,7 +153,7 @@ def test_Sessions(): 'user_code': 'MDKP-MXMF', 'verification_uri': 'https://domain.com/DIRAC/auth/device', 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF'} - + # Example of the updated session sData2 = {'client_id': 'DIRAC_CLI', 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', From c6e5bb803fa9bffb92e6f773dae67728af547014 Mon Sep 17 00:00:00 2001 From: Andrii Lytovchenko Date: Thu, 20 May 2021 18:28:48 +0200 Subject: [PATCH 004/178] Update src/DIRAC/ConfigurationSystem/Client/Utilities.py Co-authored-by: fstagni --- src/DIRAC/ConfigurationSystem/Client/Utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index fe3b4e51d38..37c9a17a768 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -580,7 +580,7 @@ def getAuthAPI(): def getAuthorisationServerMetadata(issuer=None): - """ Get authoraisation server metadata + """ Get authorization server metadata :return: S_OK(dict)/S_ERROR() """ From 0d10464bdb540a9cfd15e5ef819657b5874ee444 Mon Sep 17 00:00:00 2001 From: Andrii Lytovchenko Date: Thu, 20 May 2021 18:57:54 +0200 Subject: [PATCH 005/178] Update src/DIRAC/Core/Security/TokenFile.py Co-authored-by: Chris Burr --- src/DIRAC/Core/Security/TokenFile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Security/TokenFile.py b/src/DIRAC/Core/Security/TokenFile.py index 2001ea1a5a7..ad569c22ad0 100644 --- a/src/DIRAC/Core/Security/TokenFile.py +++ b/src/DIRAC/Core/Security/TokenFile.py @@ -49,7 +49,7 @@ def writeToTokenFile(tokenContents, fileName=False): return S_ERROR(DErrno.ECTMPF) fileName = tokenLocation try: - with open(fileName, 'wb') as fd: + with open(fileName, 'wt') as fd: fd.write(tokenContents) except Exception as e: return S_ERROR(DErrno.EWF, " %s: %s" % (fileName, repr(e).replace(',)', ')'))) From 3d07ee5c4b6d41cd0ff216b1cc114df31c0dc9e9 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 20 May 2021 20:29:32 +0200 Subject: [PATCH 006/178] add new packages to setup.cfg --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index 93674cb8269..449b78b6456 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,9 @@ install_requires = six sqlalchemy subprocess32 + authlib + pyjwt + dominate zip_safe = False include_package_data = True From 547d9492b0563666ba1862c481eb7bb9c12edc5b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 20 May 2021 20:54:53 +0200 Subject: [PATCH 007/178] fix new packages, add scripts to setup --- environment-py3.yml | 2 +- environment.yml | 2 +- requirements.txt | 73 +++++++++++++++++++++++++++++++++++++++++++++ setup.cfg | 5 +++- 4 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 requirements.txt diff --git a/environment-py3.yml b/environment-py3.yml index 6ee6ea34490..f96ea36e775 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,7 +89,7 @@ dependencies: - typing >=3.6.6 - pyyaml # OAuth2 - - authlib == 1.0.0 + - authlib >=1.0.0 - pyjwt - dominate - pip: diff --git a/environment.yml b/environment.yml index 03decc4470b..d0a455611eb 100644 --- a/environment.yml +++ b/environment.yml @@ -73,7 +73,7 @@ dependencies: - openssl <1.1 - selectors2 # OAuth2 - - authlib == 1.0.0 + - authlib >=1.0.0 - pyjwt - dominate - pip: diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000000..ff3483ecd60 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,73 @@ +# From repo +fts3-rest + +#Patch for tornado +git+https://github.com/DIRACGrid/tornado.git@iostreamConfigurable +git+https://github.com/DIRACGrid/tornado_m2crypto.git + +# From pypi +apache-libcloud +boto3 +#asn1 +M2Crypto>=0.36 +autopep8==1.3.3 +cachetools<4 +certifi +coverage +docutils +diraccfg +elasticsearch-dsl~=6.3.1 +CMRESHandler>=1.0.0b4 +funcsigs +future +futures>=3.0.5 +GitPython>=2.1.0 +# newer versions of matplotlib require python 3 +matplotlib>=2.1.0,<3.0 +mock>=1.0.1 +MySQL-python>=1.2.5 +importlib_resources +jinja2 +ipython==5.3.0 +numpy>=1.10.1 +pexpect>=4.0.1 +pillow +psutil>=4.2.0 +pyasn1>0.4.1 +pyasn1_modules +Pygments>=1.5 +parameterized +pylint>=1.6.5 +pyparsing>=2.0.6 +pytest>=3.6 +pytest-cov>=2.2.0 +pytest-mock +pytz +readline>=6.2.4 +recommonmark +requests>=2.9.1 +rucio-clients >=1.25.6 +simplejson>=3.8.1 +six>=1.10 +# Freeze until all problems with 1.4 are solved +sqlalchemy==1.3.* +xmltodict +# more recent version are python 3 only +stomp.py==4.1.23 +suds-jurko>=0.6 +sphinx +# typing comes in via m2crypto. newer versions of typing caused an error in hypothesis +typing==3.6.6 +hypothesis +python-json-logger>=0.1.8 +multi-mechanize>=1.2.0 +caniusepython3 +subprocess32 +flaky +ldap3 +# setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 +setuptools_scm<6.0 +# OAuth2 +authlib >=1.0.0 +pyjwt +dominate \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 449b78b6456..a2a278abe47 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,7 +55,7 @@ install_requires = six sqlalchemy subprocess32 - authlib + authlib >=1.0.0 pyjwt dominate zip_safe = False @@ -131,6 +131,8 @@ console_scripts = # Core.Tornado tornado-start-CS = DIRAC.Core.Tornado.scripts.tornado_start_CS:main [server] tornado-start-all = DIRAC.Core.Tornado.scripts.tornado_start_all:main [server] + tornado-start-AS = DIRAC.Core.Tornado.scripts.tornado_start_AS:main [server] + tornado-start-web = DIRAC.Core.Tornado.scripts.tornado_start_web:main [server] # DataManagementSystem dirac-admin-allow-se = DIRAC.DataManagementSystem.scripts.dirac_admin_allow_se:main [admin] dirac-admin-ban-se = DIRAC.DataManagementSystem.scripts.dirac_admin_ban_se:main [admin] @@ -161,6 +163,7 @@ console_scripts = dirac-dms-user-lfns = DIRAC.DataManagementSystem.scripts.dirac_dms_user_lfns:main dirac-dms-user-quota = DIRAC.DataManagementSystem.scripts.dirac_dms_user_quota:main # FrameworkSystem + dirac-login = DIRAC.FrameworkSystem.scripts.dirac_login:main [server] dirac-admin-get-CAs = DIRAC.FrameworkSystem.scripts.dirac_admin_get_CAs:main [server] dirac-admin-get-proxy = DIRAC.FrameworkSystem.scripts.dirac_admin_get_proxy:main [admin] dirac-admin-proxy-upload = DIRAC.FrameworkSystem.scripts.dirac_admin_proxy_upload:main [admin] From 121a3c0003296bb3ffe20291476f203603a531fd Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 20 May 2021 20:56:03 +0200 Subject: [PATCH 008/178] add test --- tests/Integration/all_integration_server_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/Integration/all_integration_server_tests.sh b/tests/Integration/all_integration_server_tests.sh index d98f978a193..8d2915c2a7d 100644 --- a/tests/Integration/all_integration_server_tests.sh +++ b/tests/Integration/all_integration_server_tests.sh @@ -27,6 +27,7 @@ pytest "${THIS_DIR}/Core/Test_MySQLDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( echo -e "*** $(date -u) **** FRAMEWORK TESTS (partially skipped) ****\n" pytest "${THIS_DIR}/Framework/Test_InstalledComponentsDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) python "${THIS_DIR}/Framework/Test_ProxyDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +python "${THIS_DIR}/Framework/Test_AuthDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #pytest ${THIS_DIR}/Framework/Test_LoggingDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #-------------------------------------------------------------------------------# From 1e89d963b74dbbd8e1e91f9ddfa19156d38ee3cf Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 20 May 2021 20:56:32 +0200 Subject: [PATCH 009/178] add AuthDB to cfg --- src/DIRAC/Core/scripts/install_full.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DIRAC/Core/scripts/install_full.cfg b/src/DIRAC/Core/scripts/install_full.cfg index 248a69b59d9..3410055c941 100755 --- a/src/DIRAC/Core/scripts/install_full.cfg +++ b/src/DIRAC/Core/scripts/install_full.cfg @@ -101,6 +101,7 @@ LocalInstallation Databases += FTSDB Databases += ComponentMonitoringDB Databases += ProxyDB + Databases += AuthDB Databases += PilotAgentsDB Databases += AccountingDB Databases += TransformationDB From 86895db40c5b1d09f0c95518e05bd56305389f91 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 20 May 2021 20:57:11 +0200 Subject: [PATCH 010/178] smal fixes --- src/DIRAC/Core/Security/TokenFile.py | 4 ++-- src/DIRAC/Core/Security/TokenInfo.py | 6 +++--- .../Core/Tornado/Server/BaseRequestHandler.py | 8 -------- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 18 +++++++----------- .../private/authorization/grants/__init__.py | 3 --- .../private/authorization/utils/__init__.py | 3 --- .../FrameworkSystem/scripts/dirac_login.py | 5 +---- .../Resources/IdProvider/IdProviderFactory.py | 4 ++-- .../Resources/IdProvider/OAuth2IdProvider.py | 7 +------ 9 files changed, 16 insertions(+), 42 deletions(-) diff --git a/src/DIRAC/Core/Security/TokenFile.py b/src/DIRAC/Core/Security/TokenFile.py index ad569c22ad0..45fe00b6a75 100644 --- a/src/DIRAC/Core/Security/TokenFile.py +++ b/src/DIRAC/Core/Security/TokenFile.py @@ -52,11 +52,11 @@ def writeToTokenFile(tokenContents, fileName=False): with open(fileName, 'wt') as fd: fd.write(tokenContents) except Exception as e: - return S_ERROR(DErrno.EWF, " %s: %s" % (fileName, repr(e).replace(',)', ')'))) + return S_ERROR(DErrno.EWF, " %s: %s" % (fileName, repr(e))) try: os.chmod(fileName, stat.S_IRUSR | stat.S_IWUSR) except Exception as e: - return S_ERROR(DErrno.ESPF, "%s: %s" % (fileName, repr(e).replace(',)', ')'))) + return S_ERROR(DErrno.ESPF, "%s: %s" % (fileName, repr(e))) return S_OK(fileName) diff --git a/src/DIRAC/Core/Security/TokenInfo.py b/src/DIRAC/Core/Security/TokenInfo.py index 39efe2e5f23..c056e7bd897 100644 --- a/src/DIRAC/Core/Security/TokenInfo.py +++ b/src/DIRAC/Core/Security/TokenInfo.py @@ -1,11 +1,11 @@ """ - Set of utilities to retrieve Information from proxy + Set of utilities to retrieve Information from token """ from __future__ import division from __future__ import absolute_import from __future__ import print_function -import jwt as _jwt +import jwt import six import time @@ -39,7 +39,7 @@ def getTokenInfo(token=False): return result token = OAuth2Token(result['Value']) - payload = _jwt.decode(token['access_token'], options=dict(verify_signature=False)) + payload = jwt.decode(token['access_token'], options=dict(verify_signature=False)) result = Registry.getUsernameForDN('/O=DIRAC/CN=%s' % payload['sub']) if not result['OK']: return result diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py index 611f31a0529..36eaa9d5fee 100644 --- a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -9,20 +9,13 @@ from io import open -import jwt as _jwt - import os import time -import pprint -import requests import threading from datetime import datetime from six import string_types from six.moves import http_client from six.moves.urllib.parse import unquote -from authlib.jose import JsonWebKey, jwt -from authlib.oauth2.rfc6749.util import scope_to_list - import tornado from tornado import gen @@ -38,7 +31,6 @@ from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory sLog = gLogger.getSubLogger(__name__.split('.')[-1]) diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 3dde68ea1fa..599907c72f9 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -4,22 +4,19 @@ from __future__ import division from __future__ import print_function +import jwt import json -import jwt as _jwt from time import time -from pprint import pprint -from M2Crypto import RSA, BIO from sqlalchemy import Column, Integer, Text, String from sqlalchemy.orm import scoped_session from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound from sqlalchemy.ext.declarative import declarative_base -from authlib.jose import KeySet, RSAKey, jwk -from authlib.common.security import generate_token +from authlib.jose import KeySet, RSAKey from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin -from DIRAC import S_OK, S_ERROR, gLogger, gConfig +from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token @@ -159,7 +156,7 @@ def storeToken(self, token): :return: S_OK(str)/S_ERROR() """ - token['expires_at'] = int(_jwt.decode(token['refresh_token'], options=dict(verify_signature=False))['exp']) + token['expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False))['exp']) gLogger.debug('Store token:', dict(token)) attrts = dict((k, v) for k, v in dict(token).items() if k in list(Token.__dict__.keys())) session = self.session() @@ -256,14 +253,13 @@ def getActiveKeys(self): """ session = self.session() try: - # Remove all expired jwks - session.query(JWK).filter(JWK.expires_at < time()).delete() - jwks = session.query(JWK).filter(JWK.expires_at > time()).all() + # Remove all expis + session.query(JWK).filter(JWK.expires_at < time()).delete()s = session.query(JWK).filter(JWK.expires_at > time()).all() except NoResultFound: return self.__result(session, S_OK([])) except Exception as e: return self.__result(session, S_ERROR(str(e))) - return self.__result(session, S_OK([self.__rowToDict(jwk) for jwk in jwks])) + return self.__result(session, S_OK([self.__rowToD) s])) def removeKeys(self): """ Get active keys diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py index 878841c1840..e69de29bb2d 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py @@ -1,3 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py index 878841c1840..e69de29bb2d 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py @@ -1,3 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 359459ad120..6229fe5e112 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -1,7 +1,7 @@ #!/usr/bin/env python ######################################################################## # File : dirac-login.py -# Author : Adrian Casajus +# Author : Andrii Lytovchenko ######################################################################## """ Login to DIRAC. @@ -15,7 +15,6 @@ import os import sys -import urllib3 import requests import threading @@ -137,8 +136,6 @@ def doOAuthMagic(self): if self.lifetime: idpObj.scope += '+lifetime:%s' % (int(self.lifetime) * 3600) - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - # Submit Device authorisation flow try: result = idpObj.authorization() diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 65907c5e433..1d42cb4fbee 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -10,7 +10,7 @@ from __future__ import division from __future__ import print_function -import jwt as _jwt +import jwt from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Utilities import ObjectLoader, ThreadSafe @@ -52,7 +52,7 @@ def getIdProviderForToken(self, token): data = {} # Read token without verification to get issuer - issuer = _jwt.decode(token, options=dict(verify_signature=False))['iss'].strip('/') + issuer = jwt.decode(token, options=dict(verify_signature=False))['iss'].strip('/') result = getSettingsNamesForIdPIssuer(issuer) if result['OK']: diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 17edfc25a25..e449d16a46a 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -6,7 +6,6 @@ import re import six -import jwt as _jwt import time import pprint import requests @@ -16,7 +15,6 @@ from authlib.common.security import generate_token from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope from authlib.oauth2.rfc6749.parameters import prepare_token_request -from authlib.oauth2.rfc8414 import AuthorizationServerMetadata from authlib.oauth2.rfc8628 import DEVICE_CODE_GRANT_TYPE from authlib.integrations.requests_client import OAuth2Session from authlib.oidc.discovery.well_known import get_well_known_url @@ -27,7 +25,7 @@ from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Resources.IdProvider.IdProvider import IdProvider -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getVOForGroup, getGroupOption +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getGroupOption __RCSID__ = "$Id$" @@ -109,7 +107,6 @@ def verifyToken(self, accessToken): :param str accessToken: access token """ - pprint.pprint(self.jwks) try: # Try to decode token gLogger.debug("Try to decode token:", accessToken) @@ -118,7 +115,6 @@ def verifyToken(self, accessToken): # If we have outdated keys, we try to update them from identity provider gLogger.debug("Try to update %s jwks.." % self.metadata['issuer']) self.jwks = self.fetch_metadata(self.get_metadata('jwks_uri')) - pprint.pprint(self.jwks) return jwt.decode(accessToken, JsonWebKey.import_key_set(self.jwks)) def update_token(self, token, refresh_token): @@ -279,7 +275,6 @@ def userDiscover(self, credDict): credDict['DIRACGroups'] = [] for vo, voData in credDict.get('VOs', {}).items(): result = getVOMSRoleGroupMapping(vo) - pprint.pprint(result) if result['OK']: for role in voData['VORoles']: groups = result['Value']['VOMSDIRAC'].get('/%s' % role) From 0ed89fd72f3ab525d693df49b784ac7c3f3957d0 Mon Sep 17 00:00:00 2001 From: Andrii Lytovchenko Date: Fri, 21 May 2021 11:48:30 +0200 Subject: [PATCH 011/178] Update src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py Co-authored-by: chaen --- src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py index bb1e85de049..70c82b6ad4f 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py @@ -445,7 +445,7 @@ def getSettingsNamesForIdPIssuer(issuer): nameIssuer = gConfig.getValue('%s/IdProviders/%s/issuer' % (gBaseResourcesSection, name)) if nameIssuer and issuer.strip('/') == nameIssuer.strip('/'): names.append(name) - return S_OK(names) if names else S_ERROR('Not found provider wwith %s issuer.' % issuer) + return S_OK(names) if names else S_ERROR('Not found provider with %s issuer.' % issuer) def getInfoAboutProviders(of=None, providerName=None, option='', section=''): From 3bccdf1427a4f474f3500a5e893138677b6ba18e Mon Sep 17 00:00:00 2001 From: Andrii Lytovchenko Date: Fri, 21 May 2021 11:54:58 +0200 Subject: [PATCH 012/178] Update src/DIRAC/Core/Security/TokenFile.py Co-authored-by: chaen --- src/DIRAC/Core/Security/TokenFile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Security/TokenFile.py b/src/DIRAC/Core/Security/TokenFile.py index 45fe00b6a75..9f83bbef5d5 100644 --- a/src/DIRAC/Core/Security/TokenFile.py +++ b/src/DIRAC/Core/Security/TokenFile.py @@ -26,7 +26,7 @@ def readTokenFromFile(fileName=None): if not fileName: fileName = getTokenLocation() or os.environ.get('DIRAC_TOKEN_FILE', "/tmp/JWTup_u%d" % os.getuid()) try: - with open(fileName, 'r') as f: + with open(fileName, 'rt') as f: data = f.read() return S_OK(json.loads(data)) except Exception as e: From 83666007e2800a6cd1aab9caa9ed7af6ab3b0056 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Fri, 21 May 2021 12:30:46 +0200 Subject: [PATCH 013/178] fix discoverCredentialsToUse --- .../Core/Tornado/Client/private/TornadoBaseClient.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index f84d0d05fd0..15193892820 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -206,6 +206,11 @@ def __discoverCredentialsToUse(self): -> if KW_SKIP_CA_CHECK is not in kwargs and we are using the certificates, set KW_SKIP_CA_CHECK to false in kwargs -> if KW_SKIP_CA_CHECK is not in kwargs and we are not using the certificate, check the skipCACheck + * Baerer token: + -> If KW_USE_ACCESS_TOKEN in kwargs, sets it in self.__useAccessToken + -> If not, check "/DIRAC/Security/UseTokens", and sets it in self.__useAccessToken + and kwargs[KW_USE_ACCESS_TOKEN] + -> If DIRAC_USE_ACCESS_TOKEN' in os.environ, sets it in self.__useAccessToken * Proxy Chain WARNING: MOSTLY COPY/PASTE FROM Core/Diset/private/BaseClient @@ -227,9 +232,9 @@ def __discoverCredentialsToUse(self): if self.KW_USE_ACCESS_TOKEN in self.kwargs: self.__useAccessToken = self.kwargs[self.KW_USE_ACCESS_TOKEN] else: - if not gConfig.useServerCertificate(): - self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") - if os.environ.get('DIRAC_USE_ACCESS_TOKEN'): + self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") + self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken + if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: self.__useAccessToken = os.environ['DIRAC_USE_ACCESS_TOKEN'] # Rewrite a little bit from here: don't need the proxy string, we use the file From 2cfb5f5e562dee2536c04f960e3c7d4111223e46 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Fri, 21 May 2021 13:39:13 +0200 Subject: [PATCH 014/178] fix bugs, issues --- .../ConfigurationSystem/Client/Utilities.py | 7 +++++-- src/DIRAC/Core/Security/TokenFile.py | 4 ++-- src/DIRAC/Core/Security/TokenInfo.py | 21 ++++++++++--------- .../Core/Tornado/Server/TornadoServer.py | 2 +- .../private/authorization/utils/Tokens.py | 8 +------ .../FrameworkSystem/scripts/dirac_login.py | 13 ++++++------ 6 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index 37c9a17a768..b19f9ece94c 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -600,8 +600,11 @@ def getAuthorisationServerMetadata(issuer=None): return S_ERROR('No issuer found in DIRAC authorization server configuration.') # Search values with type list - for key, v in data.items(): - data[key] = [e for e in v.replace(', ', ',').split(',') if e] if ',' in v else v + for key in data: + if ',' in data[key]: + # Convert to list + data[key] = data[key].replace(', ', ',').split(',') + data[key] = [item for item in data[key] if item] return S_OK(data) diff --git a/src/DIRAC/Core/Security/TokenFile.py b/src/DIRAC/Core/Security/TokenFile.py index 9f83bbef5d5..9817ec75ed3 100644 --- a/src/DIRAC/Core/Security/TokenFile.py +++ b/src/DIRAC/Core/Security/TokenFile.py @@ -30,7 +30,7 @@ def readTokenFromFile(fileName=None): data = f.read() return S_OK(json.loads(data)) except Exception as e: - return S_ERROR('Cannot read token.') + return S_ERROR('Cannot read token. %s' % repr(e)) def writeToTokenFile(tokenContents, fileName=False): @@ -73,7 +73,7 @@ def writeTokenDictToTokenFile(tokenDict, fileName=None): try: retVal = json.dumps(tokenDict) except Exception as e: - return S_ERROR('Cannot read token.') + return S_ERROR('Cannot read token. %s' % repr(e)) return writeToTokenFile(retVal, fileName) diff --git a/src/DIRAC/Core/Security/TokenInfo.py b/src/DIRAC/Core/Security/TokenInfo.py index c056e7bd897..f306f5e1fcb 100644 --- a/src/DIRAC/Core/Security/TokenInfo.py +++ b/src/DIRAC/Core/Security/TokenInfo.py @@ -5,7 +5,6 @@ from __future__ import absolute_import from __future__ import print_function -import jwt import six import time @@ -16,6 +15,7 @@ from DIRAC.Core.Security.TokenFile import readTokenFromFile from DIRAC.ConfigurationSystem.Client.Helpers import Registry from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory __RCSID__ = "$Id$" @@ -34,12 +34,17 @@ def getTokenInfo(token=False): tokenLocation = token if isinstance(token, six.string_types) else Locations.getTokenLocation() if not tokenLocation: return S_ERROR("Cannot find token location.") - result = readTokenFromFile() + result = readTokenFromFile(tokenLocation) if not result['OK']: return result token = OAuth2Token(result['Value']) - payload = jwt.decode(token['access_token'], options=dict(verify_signature=False)) + result = IdProviderFactory().getIdProviderForToken(accessToken) + if not result['OK']: + return result + cli = result['Value'] + payload = cli.verifyToken(accessToken) + result = Registry.getUsernameForDN('/O=DIRAC/CN=%s' % payload['sub']) if not result['OK']: return result @@ -56,18 +61,14 @@ def formatTokenInfoAsString(infoDict): :return: str """ - secs = int(infoDict['exp']) - time.time() - hours = int(secs / 3600) - secs -= hours * 3600 - mins = int(secs / 60) - secs -= mins * 60 - exp = "%02d:%02d:%02d" % (hours, mins, secs) + secsLeft = int(infoDict['exp']) - time.time() + strTimeleft = datetime.fromtimestamp(secs).strftime("%I:%M:%S") leftAlign = 13 contentList = [] contentList.append('%s: %s' % ('subject'.ljust(leftAlign), infoDict['sub'])) contentList.append('%s: %s' % ('issuer'.ljust(leftAlign), infoDict['iss'])) - contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), exp)) + contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), strTimeleft)) contentList.append('%s: %s' % ('username'.ljust(leftAlign), infoDict['username'])) if infoDict.get('group'): contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 50b310498dd..aabe2df86a3 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -62,7 +62,7 @@ class TornadoServer(object): Example 2:We want to debug service1 and service2 only, and use another port for that :: services = ['component/service1:port1', 'component/service2'] - endpoints = ['component/endpoint1:port1', 'component/endpoint2'] + endpoints = ['component/endpoint1', 'component/endpoint2'] serverToLaunch = TornadoServer(services=services, endpoints=endpoints, port=1234) serverToLaunch.startTornado() diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 6be7777e937..6864e0cdb95 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -3,16 +3,10 @@ from __future__ import print_function from time import time -import functools -from contextlib import contextmanager -from authlib.jose import jwt -from authlib.oauth2 import OAuth2Error, ResourceProtector as _ResourceProtector -from authlib.oauth2.rfc6749 import MissingAuthorizationError, HttpRequest -from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator +from authlib.oauth2.rfc6749.util import scope_to_list from authlib.oauth2.rfc6749.wrappers import OAuth2Token as _OAuth2Token from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin -from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope class OAuth2Token(_OAuth2Token, OAuth2TokenMixin): diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 6229fe5e112..871fd0ba7f1 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -40,6 +40,7 @@ def __init__(self): self.provider = 'DIRACCLI' self.issuer = None self.proxyLoc = '/tmp/x509up_u%s' % os.getuid() + self.tokenLoc = '/tmp/JWTup_u%s' % os.getuid() def returnProxy(self, _arg): """ Set email @@ -139,8 +140,8 @@ def doOAuthMagic(self): # Submit Device authorisation flow try: result = idpObj.authorization() - except KeyboardInterrupt as e: - return S_ERROR(repr(e)) + except KeyboardInterrupt: + return S_ERROR('User canceled the operation..') if not result['OK']: return result @@ -150,10 +151,10 @@ def doOAuthMagic(self): return result gLogger.notice('Proxy is saved to %s.' % self.proxyLoc) else: - result = writeTokenDictToTokenFile(idpObj.token) + result = writeTokenDictToTokenFile(idpObj.token, self.tokenLoc) if not result['OK']: return result - gLogger.notice('Token is saved in %s.' % result['Value']) + gLogger.notice('Token is saved in %s.' % self.tokenLoc) result = Script.enableCS() if not result['OK']: @@ -166,12 +167,12 @@ def doOAuthMagic(self): return result['Message'] gLogger.notice(formatProxyInfoAsString(result['Value'])) else: - result = getTokenInfo(self.proxyLoc) + result = getTokenInfo(self.tokenLoc) if not result['OK']: return result['Message'] gLogger.notice(formatTokenInfoAsString(result['Value'])) - return S_OK(self.proxyLoc) + return S_OK() @DIRACScript() From ef4219ad2cf7e4d8430e689e3510fa69b249b581 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:16:04 +0200 Subject: [PATCH 015/178] fix conf method --- .../ConfigurationSystem/Client/Utilities.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index b19f9ece94c..c669ccaca43 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -579,32 +579,24 @@ def getAuthAPI(): return gConfig.getValue("/Systems/Framework/%s/URLs/AuthAPI" % getSystemInstance("Framework")) -def getAuthorisationServerMetadata(issuer=None): +def getAuthorizationServerMetadata(issuer=None): """ Get authorization server metadata :return: S_OK(dict)/S_ERROR() """ - data = {'issuer': issuer} - - result = gConfig.getSections('/DIRAC') - if result['OK']: - if 'Authorization' in result['Value']: - result = gConfig.getOptionsDictRecursively('/DIRAC/Authorization') - if result['OK']: - data.update(result['Value']) + result = gConfig.getOptionsDictRecursively('/DIRAC/Authorization') if not result['OK']: - return result - if not data['issuer']: - data['issuer'] = getAuthAPI() + return {'issuer': issuer} if issuer else result + data = result['Value'] + + # Research DIRAC Authorization Server issuer + data['issuer'] = data.get('issuer', issuer) if not data['issuer']: - return S_ERROR('No issuer found in DIRAC authorization server configuration.') - - # Search values with type list - for key in data: - if ',' in data[key]: - # Convert to list - data[key] = data[key].replace(', ', ',').split(',') - data[key] = [item for item in data[key] if item] + try: + data['issuer'] = getAuthAPI() + except Exception as e: + return S_ERROR('No issuer found in DIRAC authorization server: %s' % repr(e)) + return S_OK(data) From 9ce128d188ccd650d862237693d99a50ffa5b309 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:16:48 +0200 Subject: [PATCH 016/178] verify token on client --- src/DIRAC/Core/Security/TokenInfo.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/DIRAC/Core/Security/TokenInfo.py b/src/DIRAC/Core/Security/TokenInfo.py index f306f5e1fcb..42a6253d351 100644 --- a/src/DIRAC/Core/Security/TokenInfo.py +++ b/src/DIRAC/Core/Security/TokenInfo.py @@ -7,6 +7,7 @@ import six import time +import datetime from DIRAC import S_OK, S_ERROR from DIRAC.Core.Utilities import DErrno @@ -37,13 +38,14 @@ def getTokenInfo(token=False): result = readTokenFromFile(tokenLocation) if not result['OK']: return result - token = OAuth2Token(result['Value']) + token = OAuth2Token(result['Value'])['access_token'] - result = IdProviderFactory().getIdProviderForToken(accessToken) + result = IdProviderFactory().getIdProviderForToken(token) if not result['OK']: return result cli = result['Value'] - payload = cli.verifyToken(accessToken) + cli.updateJWKs() + payload = cli.verifyToken(token) result = Registry.getUsernameForDN('/O=DIRAC/CN=%s' % payload['sub']) if not result['OK']: @@ -62,7 +64,7 @@ def formatTokenInfoAsString(infoDict): :return: str """ secsLeft = int(infoDict['exp']) - time.time() - strTimeleft = datetime.fromtimestamp(secs).strftime("%I:%M:%S") + strTimeleft = datetime.datetime.fromtimestamp(secsLeft).strftime("%I:%M:%S") leftAlign = 13 contentList = [] From 99e5357349d1b90e16bdc5ea125b56e733279b0a Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:26:30 +0200 Subject: [PATCH 017/178] add refresh jwks, fix bugs --- .../Core/Tornado/Server/BaseRequestHandler.py | 112 ++++++++++++++---- 1 file changed, 88 insertions(+), 24 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py index 36eaa9d5fee..6adc5f24847 100644 --- a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -10,7 +10,10 @@ from io import open import os +import jwt as _jwt +from authlib.jose import JsonWebKey, jwt import time +import requests import threading from datetime import datetime from six import string_types @@ -32,6 +35,7 @@ from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance sLog = gLogger.getSubLogger(__name__.split('.')[-1]) @@ -80,6 +84,14 @@ class BaseRequestHandler(RequestHandler): # Which grant type to use USE_AUTHZ_GRANTS = ['SSL', 'JWT'] + # Key updates are started at initialization, ie at the first request, + # this parameter shows that the first request of keys is made and + # it is possible to use them for check of tokens + __init_jwk_done = False + _idps = IdProviderFactory() + _idp = {} + _jwks = {} + @classmethod def _initMonitoring(cls, serviceName, fullUrl): """ @@ -146,6 +158,47 @@ def _getServiceInfo(cls, serviceName, request): """ return {} + @classmethod + @gen.coroutine + def __refreshJWKsLoop(cls): + """ Auto refresh JWKs + """ + while True: + + # Research Identity Providers + result = getProvidersForInstance('Id') + if result['OK']: + for providerName in list(set(result['Value'] + ['DIRACCLI'])): + result = cls._idps.getIdProvider(providerName) + if result['OK']: + issuer = result['Value'].issuer.strip('/') + jwks_uri = result['Value'].get_metadata('jwks_uri') + cls._idp[issuer] = result['Value'] + + gLogger.debug('Updating public keys..') + retVal = yield IOLoop.current().run_in_executor(None, cls.__refreshJWKs, jwks_uri) + result = retVal.result() + if not result['OK']: + gLogger.error('%s keys not updated' % issuer, result['Message']) + else: + gLogger.debug('%s keys updated' % issuer, result['Value']) + cls._jwks[issuer] = result['Value'] + + cls.__init_jwk_done = True + yield gen.sleep(24 * 3600) + + @classmethod + @gen.coroutine + def __refreshJWKs(cls, jwks_uri): + """ Updating public keys + """ + try: + response = requests.get(jwks_uri, verify=False) + response.raise_for_status() + return S_OK(response.json()) + except requests.exceptions.RequestException as e: + return S_ERROR("Error %s" % e) + @classmethod def __initializeService(cls, request): """ @@ -169,7 +222,8 @@ def __initializeService(cls, request): if cls.__init_done: return S_OK() - cls._idps = IdProviderFactory() + # Run automatic public key updates + IOLoop.current().spawn_callback(cls.__refreshJWKsLoop) # absoluteUrl: full URL e.g. ``https://://`` absoluteUrl = request.path @@ -311,7 +365,6 @@ def prepare(self): self.method = self._getMethodName() self._monitorRequest() - self._prepare() def _prepare(self): """ @@ -422,6 +475,9 @@ def _executeMethod(self, args): This method is called in an executor, and so cannot use methods like self.write See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes """ + # Because the keys are updated during the first initialization, the + # first authorization process must start after initialization + self._prepare() sLog.notice( "Incoming request %s /%s: %s" % @@ -452,17 +508,12 @@ def _finishFuture(self, retVal): # Wait result only if it's a Future object self.result = retVal.result() if isinstance(retVal, Future) else retVal - print('FUTURE RESULT >>>') - print(self.result) - print(self.get_status()) - # Here it is safe to write back to the client, because we are not # in a thread anymore # Is it S_OK or S_ERROR - if isinstance(self.result, dict): - if isinstance(self.result.get('OK'), bool) and ('Value' if self.result['OK'] else 'Message') in self.result: - self._parseDIRACResult(self.result) + if self.isDIRACResult(self.result): + self._parseDIRACResult(self.result) # If set to true, do not JEncode the return of the RPC call # This is basically only used for file download through @@ -483,6 +534,13 @@ def _finishFuture(self, retVal): self.finish() + def isDIRACResult(self, result): + """ Check if it DIRAC result + """ + if isinstance(result, dict): + if isinstance(result.get('OK'), bool) and ('Value' if result['OK'] else 'Message') in result: + return True + def _parseDIRACResult(self, result): """ Processing of a standard DIRAC result, but in a separate method so that it can be modified for another class if necessary @@ -579,25 +637,31 @@ def _authzSSL(self): credDict['extraCredentials'] = decode(extraCred)[0] return S_OK(credDict) - def _authzJWT(self): + def _authzJWT(self, accessToken=None): """ Load token claims in DIRAC and extract informations. + :param str accessToken: access_token + :return: S_OK(dict)/S_ERROR() """ - # Export token from headers - token = self.request.headers.get('Authorization') - if not token or len(token.split()) != 2: - return S_ERROR('Not found a bearer access token.') - tokenType, accessToken = token.split() - if tokenType.lower() != 'bearer': - return S_ERROR('Found a not bearer access token.') - - result = self._idps.getIdProviderForToken(accessToken) - if not result['OK']: - return result - cli = result['Value'] - payload = cli.verifyToken(accessToken) - credDict = cli.researchGroup(payload, accessToken) + if not self.__init_jwk_done: + time.sleep(5) + + if not accessToken: + # Export token from headers + token = self.request.headers.get('Authorization') + if not token or len(token.split()) != 2: + return S_ERROR('Not found a bearer access token.') + tokenType, accessToken = token.split() + if tokenType.lower() != 'bearer': + return S_ERROR('Found a not bearer access token.') + + # Read token without verification to get issuer + issuer = _jwt.decode(accessToken, options=dict(verify_signature=False))['iss'].strip('/') + + # Verify token + payload = self._idp[issuer].verifyToken(accessToken, self._jwks[issuer]) + credDict = self._idp[issuer].researchGroup(payload, accessToken) return S_OK(credDict) From 0313f378cd2da665ed3e796fb74af8e6478383b3 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:29:22 +0200 Subject: [PATCH 018/178] fix bugs --- .../Core/Tornado/Server/HandlerManager.py | 21 ++++++++++++------- .../Core/Tornado/Server/TornadoServer.py | 2 +- .../Core/Tornado/scripts/tornado_start_all.py | 2 +- .../Core/Tornado/scripts/tornado_start_web.py | 5 +---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/HandlerManager.py b/src/DIRAC/Core/Tornado/Server/HandlerManager.py index f51dadc9ddc..cf5d56fc386 100644 --- a/src/DIRAC/Core/Tornado/Server/HandlerManager.py +++ b/src/DIRAC/Core/Tornado/Server/HandlerManager.py @@ -139,18 +139,20 @@ def discoverHandlers(self, handlerInstance): sysInstance = PathFinder.getSystemInstance(system) result = gConfig.getSections('/Systems/%s/%s/%s' % (system, sysInstance, handlerInstance)) if result['OK']: - for inst in result['Value']: - newInst = ("%s/%s" % (system, inst)) + for instName in result['Value']: + newInst = ("%s/%s" % (system, instName)) + port = gConfig.getValue('/Systems/%s/%s/%s/%s/Port' % (system, sysInstance, + handlerInstance, instName)) + if port: + newInst += ':%s' % port if handlerInstance == 'Services': # We search in the CS all handlers which used HTTPS as protocol - isHTTPS = gConfig.getValue('/Systems/%s/%s/Services/%s/Protocol' % (system, sysInstance, inst)) + isHTTPS = gConfig.getValue('/Systems/%s/%s/%s/%s/Protocol' % (system, sysInstance, + handlerInstance, instName)) if isHTTPS and isHTTPS.lower() == 'https': urls.append(newInst) else: - port = gConfig.getValue('/Systems/%s/%s/Services/%s/Port' % (system, sysInstance, inst)) - if port: - newInst += ':%s' % port urls.append(newInst) # On systems sometime you have things not related to services... except RuntimeError: @@ -178,6 +180,9 @@ def loadServicesHandlers(self, services=None): self.__services = self.discoverHandlers('Services') if self.__services: + # Extract ports + ports, self.__services = self.__extractPorts(self.__services) + self.loader = ModuleLoader("Service", PathFinder.getServiceSection, RequestHandler, moduleSuffix="Handler") # Use DIRAC system to load: search in CS if path is given and if not defined @@ -192,7 +197,7 @@ def loadServicesHandlers(self, services=None): # Here we just want the service name, for tornado serviceTuple = url.replace('https://', '').split('/')[-2:] url = "%s/%s" % (serviceTuple[0], serviceTuple[1]) - self.__addHandler(module['loadName'], module['classObj'], url) + self.__addHandler(module['loadName'], module['classObj'], url, ports.get(module['modName'])) return S_OK() def __extractPorts(self, urls): @@ -225,7 +230,7 @@ def loadEndpointsHandlers(self, endpoints=None): :return: S_OK()/S_ERROR() """ - # list of endpoints, e.g. ['Framework/Proxy', ...] + # list of endpoints, e.g. ['Framework/Auth', ...] if isinstance(endpoints, string_types): endpoints = [endpoints] # list of endpoints. If __endpoints is ``True`` then list of endpoints will dicover from CS diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index aabe2df86a3..7be28d08933 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -209,7 +209,7 @@ def startTornado(self): router = Application(app['routes'], **settings) server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True) try: - server.listen(port) + server.listen(int(port)) except Exception as e: # pylint: disable=broad-except sLog.exception("Exception starting HTTPServer", e) raise diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py index 4a28a5b4704..c0b729a7890 100644 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_all.py @@ -60,7 +60,7 @@ def main(): gLogger.initialize('Tornado', "/") - serverToLaunch = TornadoServer() + serverToLaunch = TornadoServer(endpoints=True) serverToLaunch.startTornado() diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py index d9023de29d6..23cfe7cb093 100644 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py @@ -39,10 +39,7 @@ def main(): gLogger.initialize('Tornado', "/") - services = ['DataManagement/TornadoFileCatalog'] - endpoints = False - - serverToLaunch = TornadoServer(services, endpoints, port=8000) + serverToLaunch = TornadoServer(False) try: from WebAppDIRAC.Core.App import App From 285a779d36a22836328d543447596a68b0dff766 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:31:13 +0200 Subject: [PATCH 019/178] fix AuthDB --- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 599907c72f9..d6f7143bdcd 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -253,13 +253,14 @@ def getActiveKeys(self): """ session = self.session() try: - # Remove all expis - session.query(JWK).filter(JWK.expires_at < time()).delete()s = session.query(JWK).filter(JWK.expires_at > time()).all() + # Remove all expired jwks + session.query(JWK).filter(JWK.expires_at < time()).delete() + jwks = session.query(JWK).filter(JWK.expires_at > time()).all() except NoResultFound: return self.__result(session, S_OK([])) except Exception as e: return self.__result(session, S_ERROR(str(e))) - return self.__result(session, S_OK([self.__rowToD) s])) + return self.__result(session, S_OK([self.__rowToDict(jwk) for jwk in jwks])) def removeKeys(self): """ Get active keys From 214e44973e6610ea25041f5c242647b01f6a76f1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:36:07 +0200 Subject: [PATCH 020/178] move getClient to Client object --- .../private/authorization/AuthServer.py | 36 ++----------- .../authorization/grants/DeviceFlow.py | 6 +-- .../private/authorization/utils/Clients.py | 50 ++++++++++++++++--- 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index caff2b1f701..4667552f781 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -22,13 +22,13 @@ from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, DeviceCodeGrant) from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant -from DIRAC.FrameworkSystem.private.authorization.utils.Clients import Client, DEFAULT_CLIENTS +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIACClientByID from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request from DIRAC import gLogger, S_OK, S_ERROR from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata, isDownloadablePersonalProxy +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata, isDownloadablePersonalProxy from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getEmailsForGroup, getDNForUsername from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import getSetup @@ -44,7 +44,7 @@ def collectMetadata(issuer=None): """ Collect metadata """ - result = getAuthorisationServerMetadata() + result = getAuthorizationServerMetadata(issuer) if not result['OK']: raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) metadata = result['Value'] @@ -77,7 +77,7 @@ def __init__(self): self.proxyCli = ProxyManagerClient() self.idps = IdProviderFactory() # Privide two authlib methods query_client and save_token - _AuthorizationServer.__init__(self, query_client=self.getClient, save_token=self.saveToken) + _AuthorizationServer.__init__(self, query_client=getDIACClientByID, save_token=self.saveToken) self.generate_token = self.generateProxyOrToken self.bearerToken = BearerToken(self.access_token_generator, self.refresh_token_generator) self.config = {} @@ -108,34 +108,6 @@ def saveToken(self, token, request): if not result['OK']: gLogger.error(result['Message']) - def getClient(self, clientID): - """ Search authorization client - - :param str clientID: client ID - - :return: object - """ - data = {} - gLogger.debug('Try to query %s client' % clientID) - result = getProvidersForInstance('Id', 'DIRAC') - if not result['OK']: - gLogger.error(result['Message']) - return None - - clients = list(set(result['Value'] + list(DEFAULT_CLIENTS.keys()))) - for client in clients: - data = DEFAULT_CLIENTS.get(client, {}) - result = getProviderInfo(client) - if not result['OK']: - gLogger.debug(result['Message']) - else: - data.update(result['Value']) - if data.get('client_id') and data['client_id'] == clientID: - gLogger.debug('Found client:\n', pprint.pformat(data)) - return Client(data) - - return None - def __getScope(self, scope, param): """ Get parameter scope diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index 16651f2e513..4158e2639d3 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -11,7 +11,7 @@ DeviceCredentialDict) from DIRAC import gLogger -from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata log = gLogger.getSubLogger(__name__) @@ -30,7 +30,7 @@ def get_verification_uri(self): :return: str """ - result = getAuthorisationServerMetadata() + result = getAuthorizationServerMetadata() if not result['OK']: raise OAuth2Error('Cannot prepare authorization server metadata. %s' % result['Message']) return result['Value']['issuer'] + '/device' @@ -116,7 +116,7 @@ def query_device_credential(self, device_code): data = result['Value'] if not data: return None - result = getAuthorisationServerMetadata() + result = getAuthorizationServerMetadata() if not result['OK']: raise OAuth2Error('Cannot prepare authorization server metadata. %s' % result['Message']) data['verification_uri'] = result['Value']['issuer'] + '/device' diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index 2a5c2f661a9..ad0b4bdf8f9 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -5,9 +5,11 @@ import six import json import time +import pprint from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo from DIRAC import gLogger @@ -17,26 +19,58 @@ DEFAULT_CLIENTS = { 'DIRACCLI': dict( - ProviderType='DIRAC', + verify=False, client_id='DIRAC_CLI', response_types=['device'], - grant_types=['urn:ietf:params:oauth:grant-type:device_code'] + grant_types=['urn:ietf:params:oauth:grant-type:device_code'], + ProviderType='DIRACCLI' ), - 'WebAppDIRAC': dict( - ProviderType='DIRAC', - token_endpoint_auth_method='client_secret_basic', + 'DIRACWeb': dict( + # token_endpoint_auth_method='client_secret_basic', + token_endpoint_auth_method='client_secret_post', response_types=['code'], - grant_types=['authorization_code', 'refresh_token'] + grant_types=['authorization_code', 'refresh_token'], + ProviderType='DIRACWeb' ) } +def getDIACClientByID(clientID): + """ Search authorization client + + :param str clientID: client ID + + :return: object or None + """ + gLogger.debug('Try to query %s client' % clientID) + if clientID == DEFAULT_CLIENTS['DIRACCLI']['client_id']: + gLogger.debug('Found client:\n', pprint.pformat(DEFAULT_CLIENTS['DIRACCLI'])) + return Client(DEFAULT_CLIENTS['DIRACCLI']) + + result = getProvidersForInstance('Id') + if not result['OK']: + gLogger.error(result['Message']) + return None + + for client in result['Value']: + result = getProviderInfo(client) + if not result['OK']: + gLogger.debug(result['Message']) + continue + data = DEFAULT_CLIENTS.get(result['Value']['ProviderType'], {}) + data.update(result['Value']) + if data.get('client_id') and data['client_id'] == clientID: + gLogger.debug('Found client:\n', pprint.pformat(data)) + return Client(data) + + return None + + class Client(OAuth2ClientMixin): def __init__(self, params): - super(Client, self).__init__() client_metadata = params.get('client_metadata', params) - client_metadata['scope'] = ' '.join([client_metadata.get('scope', ''), DEFAULT_SCOPE]) + client_metadata['scope'] = ' '.join(list(set([client_metadata.get('scope', ''), DEFAULT_SCOPE]))) if params.get('redirect_uri') and not client_metadata.get('redirect_uris'): client_metadata['redirect_uris'] = [params['redirect_uri']] self.client_id = params['client_id'] From 27a479d9e488a85a9b55a51322bf079ce74d7c7d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:40:22 +0200 Subject: [PATCH 021/178] fix idps --- src/DIRAC/Resources/IdProvider/IdProvider.py | 15 ++-- .../Resources/IdProvider/IdProviderFactory.py | 13 +-- .../Resources/IdProvider/OAuth2IdProvider.py | 88 +++++++++---------- 3 files changed, 56 insertions(+), 60 deletions(-) diff --git a/src/DIRAC/Resources/IdProvider/IdProvider.py b/src/DIRAC/Resources/IdProvider/IdProvider.py index 4e209f6b14a..d1f8e6def02 100644 --- a/src/DIRAC/Resources/IdProvider/IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/IdProvider.py @@ -4,21 +4,25 @@ from __future__ import division from __future__ import print_function -from DIRAC import gLogger, S_OK +from DIRAC import gLogger __RCSID__ = "$Id$" class IdProvider(object): - def __init__(self, *args, **kwargs): + DEFAULT_METADATA = {} + + def __init__(self, **kwargs): """ C'or """ self.log = gLogger.getSubLogger(self.__class__.__name__) - self.parameters = kwargs.get('parameters', {}) - self._initialization() + meta = self.DEFAULT_METADATA + meta.update(kwargs) + self.setParameters(meta) + self._initialization(**meta) - def _initialization(self): + def _initialization(self, **kwargs): """ Initialization """ pass @@ -28,3 +32,4 @@ def setParameters(self, parameters): :param dict parameters: parameters of the identity Provider """ self.parameters = parameters + self.name = parameters.get('ProviderName') diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 1d42cb4fbee..bd608f05489 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -16,7 +16,7 @@ from DIRAC.Core.Utilities import ObjectLoader, ThreadSafe from DIRAC.Core.Utilities.DictCache import DictCache from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderInfo, getSettingsNamesForIdPIssuer -from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS __RCSID__ = "$Id$" @@ -34,7 +34,7 @@ def __init__(self): @gCacheMetadata def getMetadata(self, idP): - return self.cacheMetadata.get(idP) + return self.cacheMetadata.get(idP) or {} @gCacheMetadata def addMetadata(self, idP, data, time=24 * 3600): @@ -58,7 +58,7 @@ def getIdProviderForToken(self, token): if result['OK']: return self.getIdProvider(result['Value'][0]) - _result = getAuthorisationServerMetadata() + _result = getAuthorizationServerMetadata() if not _result['OK']: return _result if issuer == _result['Value'].get('issuer', '').strip('/'): @@ -77,7 +77,7 @@ def getIdProvider(self, name, **kwargs): self.log.debug('Search %s configuration..' % name) pDict = DEFAULT_CLIENTS.get(name, {}) if pDict: - result = getAuthorisationServerMetadata() + result = getAuthorizationServerMetadata() if not result['OK']: return result pDict.update(result['Value']) @@ -106,12 +106,7 @@ def getIdProvider(self, name, **kwargs): pClass = result['Value'] try: - meta = self.getMetadata(name) - if meta: - pDict.update(meta) provider = pClass(**pDict) - if not meta and hasattr(provider, 'metadata'): - self.addMetadata(name, provider.metadata) except Exception as x: msg = 'IdProviderFactory could not instantiate %s object: %s' % (subClassName, str(x)) self.log.exception() diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index e449d16a46a..09b6e6780c7 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -9,7 +9,6 @@ import time import pprint import requests -from requests import exceptions from authlib.jose import JsonWebKey, jwt from authlib.common.urls import url_decode from authlib.common.security import generate_token @@ -23,7 +22,7 @@ from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token -from DIRAC import S_OK, S_ERROR, gLogger +from DIRAC import S_OK, S_ERROR from DIRAC.Resources.IdProvider.IdProvider import IdProvider from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getGroupOption @@ -75,50 +74,42 @@ def claimParser(claimDict, attributes): class OAuth2IdProvider(IdProvider, OAuth2Session): + """ Base class to describe the configuration of the OAuth2 client of the corresponding provider. + """ - def __init__(self, name=None, token_endpoint_auth_method='client_secret_post', revocation_endpoint_auth_method=None, - scope='', token=None, token_placement='header', update_token=None, **parameters): - """ OIDCClient constructor - """ - if 'ProviderName' not in parameters: - parameters['ProviderName'] = name - IdProvider.__init__(self, **parameters) - OAuth2Session.__init__(self, token_endpoint_auth_method=token_endpoint_auth_method, - revocation_endpoint_auth_method=revocation_endpoint_auth_method, - scope=scope, token=token, token_placement=token_placement, - update_token=update_token, **parameters) - self.jwks = parameters.get('jwks') - # Convert scope to list - scope = scope or '' - self.scope = list_to_scope([s.strip() for s in scope.strip().replace('+', ' ').split(',' if ',' in scope else ' ')]) - self.parameters = parameters - self.name = parameters['ProviderName'] - self.verify = False - - self.server_metadata_url = parameters.get('server_metadata_url', get_well_known_url(self.metadata['issuer'], True)) - + def __init__(self, **kwargs): + """ Initialization """ + IdProvider.__init__(self, **kwargs) + OAuth2Session.__init__(self, **kwargs) + self.jwks_fetch_last = 0 + self.metadata_fetch_last = 0 + self.issuer = self.metadata['issuer'] + self.scope = self.scope or '' + self.jwks = kwargs.get('jwks') + self.verify = kwargs.get('verify', False) + self.token_placement = kwargs.get('token_placement', 'header') + self.code_challenge_method = 'S256' + self.token_endpoint_auth_method = kwargs.get('token_endpoint_auth_method', 'client_secret_post') + self.server_metadata_url = kwargs.get('server_metadata_url', get_well_known_url(self.metadata['issuer'], True)) + self.metadata_fetch_last = time.time() - self.METADATA_REFRESH_RATE self.log.debug('"%s" OAuth2 IdP initialization done:' % self.name, '\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s' % (self.client_id, self.client_secret, pprint.pformat(self.metadata))) - def verifyToken(self, accessToken): + def verifyToken(self, accessToken, jwks=None): """ Verify access token :param str accessToken: access token + + :return: dict """ - try: - # Try to decode token - gLogger.debug("Try to decode token:", accessToken) - return jwt.decode(accessToken, JsonWebKey.import_key_set(self.jwks)) - except Exception: - # If we have outdated keys, we try to update them from identity provider - gLogger.debug("Try to update %s jwks.." % self.metadata['issuer']) - self.jwks = self.fetch_metadata(self.get_metadata('jwks_uri')) - return jwt.decode(accessToken, JsonWebKey.import_key_set(self.jwks)) - - def update_token(self, token, refresh_token): - pass + jwks = jwks or self.jwks + self.log.debug("Try to decode token %s with JWKs:\n" % accessToken, pprint.pformat(jwks)) + if not jwks: + raise Exception("JWKs not found.") + # Try to decode and verify token + return jwt.decode(accessToken, JsonWebKey.import_key_set(jwks)) def refreshToken(self, refresh_token): """ Refresh token @@ -139,18 +130,21 @@ def revokeToken(self, token=None, token_type_hint='refresh_token'): def get_metadata(self, option=None): """ Get metadata + + :param str option: option + + :return: option value """ if not self.metadata.get(option): self.fetch_metadata() return self.metadata.get(option) - def fetch_metadata(self, url=None): + def fetch_metadata(self): """ Fetch metada """ - data = self.get(url or self.server_metadata_url, withhold_token=True).json() - if url: - return data - self.metadata.update(data) + if self.metadata_fetch_last < (time.time() - self.METADATA_REFRESH_RATE): + data = self.get(self.server_metadata_url, withhold_token=True).json() + self.metadata.update(data) def researchGroup(self, payload, token): """ Research group @@ -172,7 +166,7 @@ def authorization(self, group=None): # Notify user to go to authorization endpoint showURL = 'Use next link to continue, your user code is "%s"\n%s' % (response['user_code'], response['verification_uri']) - gLogger.notice(showURL) + self.log.notice(showURL) return self.waitFinalStatusOfDeviceCodeAuthorizationFlow(response['device_code']) @@ -227,6 +221,8 @@ def parseAuthResponse(self, response, session=None): def fetchToken(self, **kwargs): """ Fetch token + + :return: dict """ self.fetch_access_token(self.get_metadata('token_endpoint'), **kwargs) self.token['client_id'] = self.client_id @@ -311,9 +307,9 @@ def submitDeviceCodeAuthorizationFlow(self, group=None): return S_ERROR('Mandatory %s key is absent in authentication response.' % k) return S_OK(deviceResponse) - except requests.exceptions.Timeout: + except requests..Timeout: return S_ERROR('Authentication server is not answer, timeout.') - except requests.exceptions.RequestException as ex: + except requests..RequestException as ex: return S_ERROR(repr(ex)) except Exception as ex: return S_ERROR('Cannot read authentication response: %s' % repr(ex)) @@ -329,7 +325,7 @@ def waitFinalStatusOfDeviceCodeAuthorizationFlow(self, deviceCode, interval=5, t """ __start = time.time() - gLogger.notice('Authorization pending.. (use CNTL + C to stop)') + self.log.notice('Authorization pending.. (use CNTL + C to stop)') while True: time.sleep(int(interval)) if time.time() - __start > timeout: @@ -377,7 +373,7 @@ def exchangeGroup(self, group): self.token = token return S_OK(token) except Exception as e: - return S_ERROR(repr(e)) + return S_ERROR('Cannot exchange token with %s group: %s' % (group,repr(e))) def getUserProfile(self): return self.get(self.get_metadata('userinfo_endpoint')).json() From 30aea4b8fe88d6c52e2d3a899ea7231f0c4c2334 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:40:40 +0200 Subject: [PATCH 022/178] remove comment in cs --- dirac.cfg | 8 -------- 1 file changed, 8 deletions(-) diff --git a/dirac.cfg b/dirac.cfg index 502a64166e0..1f76f0340c8 100644 --- a/dirac.cfg +++ b/dirac.cfg @@ -426,14 +426,6 @@ Resources { IdProviders { - # WebAppDIRAC - # { - # # This type describe DIRAC authorization server client - # ProviderType = DIRAC - # issuer = https://dirac.egi.eu/DIRAC/auth - # client_id = type_client_id_here_receved_after_client_registration - # client_secret = type_client_secret_here_receved_after_client_registration - # } CheckIn { # What supported type of provider does it belong to From a1ca3512c92854b3a42e80432346e3cd31e16ca9 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:44:25 +0200 Subject: [PATCH 023/178] add dirac AS clients --- ...RACIdProvider.py => DIRACCLIIdProvider.py} | 0 .../IdProvider/DIRACWebIdProvider.py | 21 +++++++++++++++++++ .../Resources/IdProvider/OAuth2IdProvider.py | 4 ++-- 3 files changed, 23 insertions(+), 2 deletions(-) rename src/DIRAC/Resources/IdProvider/{DIRACIdProvider.py => DIRACCLIIdProvider.py} (100%) create mode 100644 src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py diff --git a/src/DIRAC/Resources/IdProvider/DIRACIdProvider.py b/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py similarity index 100% rename from src/DIRAC/Resources/IdProvider/DIRACIdProvider.py rename to src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py diff --git a/src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py b/src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py new file mode 100644 index 00000000000..2753310e58d --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py @@ -0,0 +1,21 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider +from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS + +__RCSID__ = "$Id$" + + +class DIRACWebIdProvider(OAuth2IdProvider): + + DEFAULT_METADATA = DEFAULT_CLIENTS['DIRACWeb'] + + def fetch_metadata(self): + """ Fetch metada + """ + self.metadata.update(collectMetadata(self.metadata['issuer'])) diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 09b6e6780c7..fc916dbb004 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -307,9 +307,9 @@ def submitDeviceCodeAuthorizationFlow(self, group=None): return S_ERROR('Mandatory %s key is absent in authentication response.' % k) return S_OK(deviceResponse) - except requests..Timeout: + except requests.exceptions.Timeout: return S_ERROR('Authentication server is not answer, timeout.') - except requests..RequestException as ex: + except requests.exceptions.RequestException as ex: return S_ERROR(repr(ex)) except Exception as ex: return S_ERROR('Cannot read authentication response: %s' % repr(ex)) From 995890833fdaa4a5cfd057ffe112e190ae89a18d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 13:47:39 +0200 Subject: [PATCH 024/178] fix package versions --- environment-py3.yml | 4 ++-- environment.yml | 4 ++-- requirements.txt | 4 ++-- setup.cfg | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index f96ea36e775..43be46f9f56 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,8 +89,8 @@ dependencies: - typing >=3.6.6 - pyyaml # OAuth2 - - authlib >=1.0.0 - - pyjwt + - authlib <=0.15.3 + - pyjwt <=1.7.1 - dominate - pip: # This is a fork of tornado with a patch to allow for configurable iostream diff --git a/environment.yml b/environment.yml index d0a455611eb..63c52b2ad65 100644 --- a/environment.yml +++ b/environment.yml @@ -73,8 +73,8 @@ dependencies: - openssl <1.1 - selectors2 # OAuth2 - - authlib >=1.0.0 - - pyjwt + - authlib <=0.15.3 + - pyjwt <=1.7.1 - dominate - pip: - diraccfg diff --git a/requirements.txt b/requirements.txt index ff3483ecd60..ed1477c0d8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,6 @@ ldap3 # setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 setuptools_scm<6.0 # OAuth2 -authlib >=1.0.0 -pyjwt +authlib <=0.15.3 +pyjwt <=1.7.1 dominate \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index a2a278abe47..9f0cf392d64 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,8 +55,8 @@ install_requires = six sqlalchemy subprocess32 - authlib >=1.0.0 - pyjwt + authlib <=0.15.3 + pyjwt <=1.7.1 dominate zip_safe = False include_package_data = True From 5cff7a96c81028cc5c0b9eab7fc04ce1ea572b76 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 24 May 2021 18:33:17 +0200 Subject: [PATCH 025/178] fix tests --- .../ConfigurationSystem/Client/Helpers/Resources.py | 12 ++++-------- src/DIRAC/ConfigurationSystem/Client/Utilities.py | 2 +- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 7 +++---- .../private/authorization/AuthServer.py | 11 ++++------- src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py | 6 ++++-- 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py index 70c82b6ad4f..3a62be1f107 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py @@ -500,13 +500,9 @@ def getProvidersForInstance(instance, providerType=None): :return: S_OK(list)/S_ERROR() """ - data = [] + providers = [] instance = "%sProviders" % instance - result = gConfig.getSections(gBaseResourcesSection) - if result['OK']: - if instance not in result['Value']: - return S_OK(data) - result = gConfig.getSections('%s/%s' % (gBaseResourcesSection, instance)) + result = gConfig.getSections('%s/%s' % (gBaseResourcesSection, instance)) # Return an empty list if the section does not exist if not result['OK'] or not result['Value'] or not providerType: @@ -514,8 +510,8 @@ def getProvidersForInstance(instance, providerType=None): for prov in result['Value']: if providerType == gConfig.getValue('%s/%s/%s/ProviderType' % (gBaseResourcesSection, instance, prov)): - data.append(prov) - return S_OK(data) + providers.append(prov) + return S_OK(providers) def getProviderInfo(provider): diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index c669ccaca43..b76b26958ec 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -596,7 +596,7 @@ def getAuthorizationServerMetadata(issuer=None): data['issuer'] = getAuthAPI() except Exception as e: return S_ERROR('No issuer found in DIRAC authorization server: %s' % repr(e)) - + return S_OK(data) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index ea98acf9ba1..447dc12cda7 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -350,12 +350,11 @@ def web_device(self, provider=None): group = groups[0] if groups else None if group and not provider: - print(group) provider = Registry.getIdPForGroup(group) - print('Use provider:', provider) - - authURL = '%s/authorization/%s?%s&user_code=%s' % (self.LOCATION, provider, req.query, userCode) # pylint: disable=no-member + self.log.debug('Use provider:', provider) + # pylint: disable=no-member + authURL = '%s/authorization/%s?%s&user_code=%s' % (self.LOCATION, provider, req.query, userCode) # Save session to cookie return self.server.handle_response(302, {}, [("Location", authURL)], session) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 4667552f781..9ac7437dbd7 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -203,11 +203,8 @@ def parseIdPAuthorizationResponse(self, response, session): # FINISHING with IdP auth result credDict = result['Value'] - - # ########### TODO: This place will store the original tokens ############ # - # result = self.__tokenDB.updateToken(provObj.token, user_id=provObj.token['user_id']) - # if not result['OK']: - # return result + # ########### TODO: This line will store the original tokens ############ # + # updateToken(provObj.token, user_id=provObj.token['user_id']) gLogger.debug("Read profile:", pprint.pformat(credDict)) # Is ID registred? @@ -412,7 +409,7 @@ def __registerNewUser(self, provider, userProfile): for addresses in getEmailsForGroup('dirac_admin'): result = NotificationClient().sendMail(addresses, mail['subject'], mail['body'], localAttempt=False) if not result['OK']: - self.log.error(result['Message']) + gLogger.error(result['Message']) if result['OK']: - self.log.info(result['Value'], "administrators have been notified about a new user.") + gLogger.info(result['Value'], "administrators have been notified about a new user.") return result diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index fc916dbb004..818635da378 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -77,11 +77,12 @@ class OAuth2IdProvider(IdProvider, OAuth2Session): """ Base class to describe the configuration of the OAuth2 client of the corresponding provider. """ + METADATA_REFRESH_RATE = 24 * 3600 + def __init__(self, **kwargs): """ Initialization """ IdProvider.__init__(self, **kwargs) OAuth2Session.__init__(self, **kwargs) - self.jwks_fetch_last = 0 self.metadata_fetch_last = 0 self.issuer = self.metadata['issuer'] self.scope = self.scope or '' @@ -145,6 +146,7 @@ def fetch_metadata(self): if self.metadata_fetch_last < (time.time() - self.METADATA_REFRESH_RATE): data = self.get(self.server_metadata_url, withhold_token=True).json() self.metadata.update(data) + self.metadata_fetch_last = time.time() def researchGroup(self, payload, token): """ Research group @@ -373,7 +375,7 @@ def exchangeGroup(self, group): self.token = token return S_OK(token) except Exception as e: - return S_ERROR('Cannot exchange token with %s group: %s' % (group,repr(e))) + return S_ERROR('Cannot exchange token with %s group: %s' % (group, repr(e))) def getUserProfile(self): return self.get(self.get_metadata('userinfo_endpoint')).json() From c892f6da5f99f41adc0d985c925fce4d6881aafa Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 26 May 2021 11:52:11 +0200 Subject: [PATCH 026/178] fix issues --- .../Client/Helpers/Resources.py | 18 -- src/DIRAC/Core/Security/Locations.py | 17 -- src/DIRAC/Core/Security/TokenFile.py | 99 ----------- src/DIRAC/Core/Security/TokenInfo.py | 79 --------- .../Client/private/TornadoBaseClient.py | 2 +- .../Core/Tornado/Server/BaseRequestHandler.py | 7 +- .../Core/Tornado/scripts/tornado_start_AS.py | 55 ------- .../Core/Tornado/scripts/tornado_start_web.py | 63 ------- .../private/authorization/AuthServer.py | 2 +- .../authorization/grants/AuthorizationCode.py | 5 +- .../private/authorization/utils/Clients.py | 4 +- .../private/authorization/utils/Tokens.py | 155 +++++++++++++++++- .../FrameworkSystem/scripts/dirac_login.py | 4 +- .../Resources/IdProvider/IdProviderFactory.py | 2 +- .../Resources/IdProvider/OAuth2IdProvider.py | 11 ++ src/DIRAC/Resources/IdProvider/Utilities.py | 69 ++++++++ tests/Integration/Framework/Test_AuthDB.py | 3 + tests/Jenkins/dirac_ci.sh | 3 + 18 files changed, 253 insertions(+), 345 deletions(-) delete mode 100644 src/DIRAC/Core/Security/TokenFile.py delete mode 100644 src/DIRAC/Core/Security/TokenInfo.py delete mode 100644 src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py delete mode 100644 src/DIRAC/Core/Tornado/scripts/tornado_start_web.py create mode 100644 src/DIRAC/Resources/IdProvider/Utilities.py diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py index 3a62be1f107..ac8cdf47bf8 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py @@ -430,24 +430,6 @@ def getFilterConfig(filterID): return gConfig.getOptionsDict('Resources/LogFilters/%s' % filterID) -def getSettingsNamesForIdPIssuer(issuer): - """ Get identity providers for issuer - - :param str issuer: issuer - - :return: S_OK(list)/S_ERROR() - """ - names = [] - result = getProvidersForInstance('Id') - if not result['OK']: - return result - for name in result['Value']: - nameIssuer = gConfig.getValue('%s/IdProviders/%s/issuer' % (gBaseResourcesSection, name)) - if nameIssuer and issuer.strip('/') == nameIssuer.strip('/'): - names.append(name) - return S_OK(names) if names else S_ERROR('Not found provider with %s issuer.' % issuer) - - def getInfoAboutProviders(of=None, providerName=None, option='', section=''): """ Get the information about providers diff --git a/src/DIRAC/Core/Security/Locations.py b/src/DIRAC/Core/Security/Locations.py index 8c876caaf88..fe9e575aca6 100644 --- a/src/DIRAC/Core/Security/Locations.py +++ b/src/DIRAC/Core/Security/Locations.py @@ -12,23 +12,6 @@ g_SecurityConfPath = "/DIRAC/Security" -def getTokenLocation(): - """ Get the path of the currently active access token file - """ - envVar = 'DIRAC_TOKEN_FILE' - if envVar in os.environ: - tokenPath = os.path.realpath(os.environ[envVar]) - if os.path.isfile(tokenPath): - return tokenPath - # /tmp/JWTup_u - tokenName = "JWTup_u%d" % os.getuid() - if os.path.isfile("/tmp/%s" % tokenName): - return "/tmp/%s" % tokenName - - # No access token found - return False - - def getProxyLocation(): """ Get the path of the currently active grid proxy file """ diff --git a/src/DIRAC/Core/Security/TokenFile.py b/src/DIRAC/Core/Security/TokenFile.py deleted file mode 100644 index 9817ec75ed3..00000000000 --- a/src/DIRAC/Core/Security/TokenFile.py +++ /dev/null @@ -1,99 +0,0 @@ -""" Collection of utilities for dealing with security files (i.e. token files) -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -__RCSID__ = "$Id$" - -import os -import json -import stat -import tempfile - -from DIRAC import S_OK, S_ERROR -from DIRAC.Core.Utilities import DErrno -from DIRAC.Core.Security.Locations import getTokenLocation - - -def readTokenFromFile(fileName=None): - """ Read token from a file - - :param str fileName: filename to read - - :return: S_OK(dict)/S_ERROR() - """ - if not fileName: - fileName = getTokenLocation() or os.environ.get('DIRAC_TOKEN_FILE', "/tmp/JWTup_u%d" % os.getuid()) - try: - with open(fileName, 'rt') as f: - data = f.read() - return S_OK(json.loads(data)) - except Exception as e: - return S_ERROR('Cannot read token. %s' % repr(e)) - - -def writeToTokenFile(tokenContents, fileName=False): - """ Write a token string to file - - :param str tokenContents: token as string - :param str fileName: filename to dump to - - :return: S_OK(str)/S_ERROR() - """ - if not fileName: - try: - fd, tokenLocation = tempfile.mkstemp() - os.close(fd) - except IOError: - return S_ERROR(DErrno.ECTMPF) - fileName = tokenLocation - try: - with open(fileName, 'wt') as fd: - fd.write(tokenContents) - except Exception as e: - return S_ERROR(DErrno.EWF, " %s: %s" % (fileName, repr(e))) - try: - os.chmod(fileName, stat.S_IRUSR | stat.S_IWUSR) - except Exception as e: - return S_ERROR(DErrno.ESPF, "%s: %s" % (fileName, repr(e))) - return S_OK(fileName) - - -def writeTokenDictToTokenFile(tokenDict, fileName=None): - """ Write a token dict to file - - :param dict tokenDict: dict object to dump to file - :param str fileName: filename to dump to - - :return: S_OK(str)/S_ERROR() - """ - if not fileName: - fileName = getTokenLocation() or os.environ.get('DIRAC_TOKEN_FILE', "/tmp/JWTup_u%d" % os.getuid()) - try: - retVal = json.dumps(tokenDict) - except Exception as e: - return S_ERROR('Cannot read token. %s' % repr(e)) - return writeToTokenFile(retVal, fileName) - - -def writeTokenDictToTemporaryFile(tokenDict): - """ Write a token dict to a temporary file - - :param dict tokenDict: dict object to dump to file - - :return: S_OK(str)/S_ERROR() -- contain file name - """ - try: - fd, tokenLocation = tempfile.mkstemp() - os.close(fd) - except IOError: - return S_ERROR(DErrno.ECTMPF) - retVal = writeTokenDictToTokenFile(tokenDict, tokenLocation) - if not retVal['OK']: - try: - os.unlink(tokenLocation) - except Exception: - pass - return retVal - return S_OK(tokenLocation) diff --git a/src/DIRAC/Core/Security/TokenInfo.py b/src/DIRAC/Core/Security/TokenInfo.py deleted file mode 100644 index 42a6253d351..00000000000 --- a/src/DIRAC/Core/Security/TokenInfo.py +++ /dev/null @@ -1,79 +0,0 @@ -""" - Set of utilities to retrieve Information from token -""" -from __future__ import division -from __future__ import absolute_import -from __future__ import print_function - -import six -import time -import datetime - -from DIRAC import S_OK, S_ERROR -from DIRAC.Core.Utilities import DErrno -from DIRAC.Core.Security import Locations - -from DIRAC.Core.Security.TokenFile import readTokenFromFile -from DIRAC.ConfigurationSystem.Client.Helpers import Registry -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory - -__RCSID__ = "$Id$" - - -def getTokenInfo(token=False): - """ Return token info - - :param token: token location or token as dict - - :return: S_OK(dict)/S_ERROR() - """ - # Discover token location - if isinstance(token, dict): - token = OAuth2Token(token) - else: - tokenLocation = token if isinstance(token, six.string_types) else Locations.getTokenLocation() - if not tokenLocation: - return S_ERROR("Cannot find token location.") - result = readTokenFromFile(tokenLocation) - if not result['OK']: - return result - token = OAuth2Token(result['Value'])['access_token'] - - result = IdProviderFactory().getIdProviderForToken(token) - if not result['OK']: - return result - cli = result['Value'] - cli.updateJWKs() - payload = cli.verifyToken(token) - - result = Registry.getUsernameForDN('/O=DIRAC/CN=%s' % payload['sub']) - if not result['OK']: - return result - payload['username'] = result['Value'] - if payload.get('group'): - payload['properties'] = Registry.getPropertiesForGroup(payload['group']) - return S_OK(payload) - - -def formatTokenInfoAsString(infoDict): - """ Convert a token infoDict into a string - - :param dict infoDict: info - - :return: str - """ - secsLeft = int(infoDict['exp']) - time.time() - strTimeleft = datetime.datetime.fromtimestamp(secsLeft).strftime("%I:%M:%S") - - leftAlign = 13 - contentList = [] - contentList.append('%s: %s' % ('subject'.ljust(leftAlign), infoDict['sub'])) - contentList.append('%s: %s' % ('issuer'.ljust(leftAlign), infoDict['iss'])) - contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), strTimeleft)) - contentList.append('%s: %s' % ('username'.ljust(leftAlign), infoDict['username'])) - if infoDict.get('group'): - contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) - if infoDict.get('properties'): - contentList.append('%s: %s' % ('properties'.ljust(leftAlign), ', '.join(infoDict['properties']))) - return "\n".join(contentList) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 15193892820..f06af7ae615 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -47,9 +47,9 @@ from DIRAC.Core.DISET.ThreadConfig import ThreadConfig from DIRAC.Core.Security import Locations -from DIRAC.Core.Security.TokenFile import readTokenFromFile from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import readTokenFromFile # TODO CHRIS: refactor all the messy `discover` methods diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py index 6adc5f24847..376c7e6eb6d 100644 --- a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -10,8 +10,7 @@ from io import open import os -import jwt as _jwt -from authlib.jose import JsonWebKey, jwt +import jwt import time import requests import threading @@ -35,7 +34,7 @@ from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance +from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance sLog = gLogger.getSubLogger(__name__.split('.')[-1]) @@ -657,7 +656,7 @@ def _authzJWT(self, accessToken=None): return S_ERROR('Found a not bearer access token.') # Read token without verification to get issuer - issuer = _jwt.decode(accessToken, options=dict(verify_signature=False))['iss'].strip('/') + issuer = jwt.decode(accessToken, options=dict(verify_signature=False))['iss'].strip('/') # Verify token payload = self._idp[issuer].verifyToken(accessToken, self._jwks[issuer]) diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py deleted file mode 100644 index b71254634cf..00000000000 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_AS.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -__RCSID__ = "$Id$" - -import os -import sys -import tornado - -from DIRAC.Core.Utilities.DIRACScript import DIRACScript - - -@DIRACScript() -def main(): - # Must be define BEFORE any dirac import - os.environ['DIRAC_USE_TORNADO_IOLOOP'] = "True" - - from DIRAC.ConfigurationSystem.Client.PathFinder import getAPISection - from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData - from DIRAC.ConfigurationSystem.Client.LocalConfiguration import LocalConfiguration - from DIRAC.Core.Tornado.Server.TornadoServer import TornadoServer - from DIRAC.Core.Utilities.DErrno import includeExtensionErrors - from DIRAC.FrameworkSystem.Client.Logger import gLogger - - localCfg = LocalConfiguration() - localCfg.addMandatoryEntry("/DIRAC/Setup") - localCfg.addDefaultEntry("/DIRAC/Security/UseServerCertificate", "yes") - localCfg.addDefaultEntry("LogLevel", "INFO") - localCfg.addDefaultEntry("LogColor", True) - resultDict = localCfg.loadUserData() - if not resultDict['OK']: - gLogger.initialize("Tornado", "/") - gLogger.error("There were errors when loading configuration", resultDict['Message']) - sys.exit(1) - - includeExtensionErrors() - - gLogger.initialize('Tornado', "/") - - endpoints = ['Framework/Auth'] - try: - asPort = int(gConfigurationData.extractOptionFromCFG('%s/Port' % getAPISection('Framework/Auth'))) - except TypeError: - asPort = None - - serverToLaunch = TornadoServer(False, endpoints, port=asPort) - - serverToLaunch.startTornado() - - -if __name__ == "__main__": - main() diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py deleted file mode 100644 index 23cfe7cb093..00000000000 --- a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -__RCSID__ = "$Id$" - -import os -import sys -import tornado -import pprint - -from DIRAC.Core.Utilities.DIRACScript import DIRACScript - - -@DIRACScript() -def main(): - # Must be define BEFORE any dirac import - os.environ['DIRAC_USE_TORNADO_IOLOOP'] = "True" - - from DIRAC.ConfigurationSystem.Client.LocalConfiguration import LocalConfiguration - from DIRAC.Core.Tornado.Server.TornadoServer import TornadoServer - from DIRAC.Core.Utilities.DErrno import includeExtensionErrors - from DIRAC.FrameworkSystem.Client.Logger import gLogger - - localCfg = LocalConfiguration() - localCfg.addMandatoryEntry("/DIRAC/Setup") - localCfg.addDefaultEntry("/DIRAC/Security/UseServerCertificate", "yes") - localCfg.addDefaultEntry("LogLevel", "INFO") - localCfg.addDefaultEntry("LogColor", True) - resultDict = localCfg.loadUserData() - if not resultDict['OK']: - gLogger.initialize("Tornado", "/") - gLogger.error("There were errors when loading configuration", resultDict['Message']) - sys.exit(1) - - includeExtensionErrors() - - gLogger.initialize('Tornado', "/") - - serverToLaunch = TornadoServer(False) - - try: - from WebAppDIRAC.Core.App import App - except ImportError as e: - gLogger.fatal('Web portal is not installed. %s' % repr(e)) - sys.exit(1) - - # Get routes and settings for a portal - result = App().getAppToDict(8000) - if not result['OK']: - gLogger.fatal(result['Message']) - sys.exit(1) - app = result['Value'] - - serverToLaunch.addHandlers(app['routes'], app['settings']) - - serverToLaunch.startTornado() - - -if __name__ == "__main__": - main() diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 9ac7437dbd7..32882c1d9d6 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -27,10 +27,10 @@ from DIRAC import gLogger, S_OK, S_ERROR from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB +from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata, isDownloadablePersonalProxy from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getEmailsForGroup, getDNForUsername -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import getSetup from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py index df1fe9d0703..7691e896850 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -8,15 +8,18 @@ from pprint import pprint from authlib.jose import JsonWebSignature from authlib.oauth2.base import OAuth2Error -from authlib.oauth2.rfc6749.grants import AuthorizationCodeGrant as _AuthorizationCodeGrant from authlib.oauth2.rfc7636 import CodeChallenge +from authlib.oauth2.rfc6749.grants import AuthorizationCodeGrant as _AuthorizationCodeGrant from authlib.common.encoding import to_unicode, json_dumps, json_b64encode, urlsafe_b64decode, json_loads from DIRAC import gLogger, S_OK, S_ERROR class OAuth2Code(dict): + """ This class describe Authorization Code object """ + def __init__(self, params): + """ C'or """ params['auth_time'] = params.get('auth_time', int(time())) super(OAuth2Code, self).__init__(params) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index ad0b4bdf8f9..ca6deb84446 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -7,9 +7,9 @@ import time import pprint -from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo +from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin +from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance, getProviderInfo from DIRAC import gLogger diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 6864e0cdb95..54e1cfcc91a 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -2,13 +2,164 @@ from __future__ import division from __future__ import print_function -from time import time +import six +import time +import datetime + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Utilities import DErrno +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from authlib.oauth2.rfc6749.util import scope_to_list from authlib.oauth2.rfc6749.wrappers import OAuth2Token as _OAuth2Token from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin +def getTokenLocation(): + """ Get the path of the currently active access token file + """ + envVar = 'DIRAC_TOKEN_FILE' + if envVar in os.environ: + tokenPath = os.path.realpath(os.environ[envVar]) + if os.path.isfile(tokenPath): + return tokenPath + # /tmp/JWTup_u + return "/tmp/JWTup_u%s" % os.getuid() + + +def readTokenFromFile(fileName=None): + """ Read token from a file + + :param str fileName: filename to read + + :return: S_OK(dict)/S_ERROR() + """ + fileName = fileName or getTokenLocation() + try: + with open(fileName, 'rt') as f: + tokenDict = f.read() + return S_OK(json.loads(tokenDict)) + except Exception as e: + return S_ERROR('Cannot read token. %s' % repr(e)) + + +def writeToTokenFile(tokenContents, fileName): + """ Write a token string to file + + :param str tokenContents: token as string + :param str fileName: filename to dump to + + :return: S_OK(str)/S_ERROR() + """ + try: + with open(fileName, 'wt') as fd: + fd.write(tokenContents) + except Exception as e: + return S_ERROR(DErrno.EWF, " %s: %s" % (fileName, repr(e))) + try: + os.chmod(fileName, stat.S_IRUSR | stat.S_IWUSR) + except Exception as e: + return S_ERROR(DErrno.ESPF, "%s: %s" % (fileName, repr(e))) + return S_OK(fileName) + + +def writeTokenDictToTokenFile(tokenDict, fileName=None): + """ Write a token dict to file + + :param dict tokenDict: dict object to dump to file + :param str fileName: filename to dump to + + :return: S_OK(str)/S_ERROR() + """ + fileName = fileName or getTokenLocation() + try: + retVal = json.dumps(tokenDict) + except Exception as e: + return S_ERROR('Cannot dump token to string. %s' % repr(e)) + return writeToTokenFile(retVal, fileName) + + +def writeTokenDictToTemporaryFile(tokenDict): + """ Write a token dict to a temporary file + + :param dict tokenDict: dict object to dump to file + + :return: S_OK(str)/S_ERROR() -- contain file name + """ + try: + fd, tokenLocation = tempfile.mkstemp() + os.close(fd) + except IOError: + return S_ERROR(DErrno.ECTMPF) + retVal = writeTokenDictToTokenFile(tokenDict, tokenLocation) + if not retVal['OK']: + try: + os.unlink(tokenLocation) + except Exception: + pass + return retVal + return S_OK(tokenLocation) + + +def getTokenInfo(token=False): + """ Return token info + + :param token: token location or token as dict + + :return: S_OK(dict)/S_ERROR() + """ + # Discover token location + if isinstance(token, dict): + token = OAuth2Token(token) + else: + tokenLocation = token if isinstance(token, six.string_types) else getTokenLocation() + if not tokenLocation: + return S_ERROR("Cannot find token location.") + result = readTokenFromFile(tokenLocation) + if not result['OK']: + return result + token = OAuth2Token(result['Value'])['access_token'] + + result = IdProviderFactory().getIdProviderForToken(token) + if not result['OK']: + return S_ERROR("Cannot load provider: %s" % result['Message']) + cli = result['Value'] + cli.updateJWKs() + payload = cli.verifyToken(token) + + result = Registry.getUsernameForDN('/O=DIRAC/CN=%s' % payload['sub']) + if not result['OK']: + return result + payload['username'] = result['Value'] + if payload.get('group'): + payload['properties'] = Registry.getPropertiesForGroup(payload['group']) + return S_OK(payload) + + +def formatTokenInfoAsString(infoDict): + """ Convert a token infoDict into a string + + :param dict infoDict: info + + :return: str + """ + secsLeft = int(infoDict['exp']) - time.time() + strTimeleft = datetime.datetime.fromtimestamp(secsLeft).strftime("%I:%M:%S") + + leftAlign = 13 + contentList = [] + contentList.append('%s: %s' % ('subject'.ljust(leftAlign), infoDict['sub'])) + contentList.append('%s: %s' % ('issuer'.ljust(leftAlign), infoDict['iss'])) + contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), strTimeleft)) + contentList.append('%s: %s' % ('username'.ljust(leftAlign), infoDict['username'])) + if infoDict.get('group'): + contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) + if infoDict.get('properties'): + contentList.append('%s: %s' % ('properties'.ljust(leftAlign), ', '.join(infoDict['properties']))) + return "\n".join(contentList) + + class OAuth2Token(_OAuth2Token, OAuth2TokenMixin): """ Implementation a Token object """ @@ -23,7 +174,7 @@ def __init__(self, params=None, **kwargs): self.refresh_token = kwargs.get('refresh_token') self.scope = kwargs.get('scope') self.revoked = kwargs.get('revoked') - self.issued_at = int(kwargs.get('issued_at', kwargs.get('iat', time()))) + self.issued_at = int(kwargs.get('issued_at', kwargs.get('iat', time.time()))) self.expires_in = int(kwargs.get('expires_in', 0)) self.expires_at = int(kwargs.get('expires_at', kwargs.get('exp', 0))) if not self.issued_at: diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 871fd0ba7f1..519a9a99f61 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -22,11 +22,11 @@ from DIRAC import gLogger, S_OK, S_ERROR from DIRAC.Core.Base import Script from DIRAC.Core.Utilities.DIRACScript import DIRACScript -from DIRAC.Core.Security.TokenFile import writeTokenDictToTokenFile from DIRAC.Core.Security.ProxyFile import writeToProxyFile from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString -from DIRAC.Core.Security.TokenInfo import getTokenInfo, formatTokenInfoAsString from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (writeTokenDictToTokenFile, + getTokenInfo, formatTokenInfoAsString) __RCSID__ = "$Id$" diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index bd608f05489..1608d51c654 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -15,7 +15,7 @@ from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Utilities import ObjectLoader, ThreadSafe from DIRAC.Core.Utilities.DictCache import DictCache -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderInfo, getSettingsNamesForIdPIssuer +from DIRAC.Resources.IdProvider.Utilities import getProviderInfo, getSettingsNamesForIdPIssuer from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 818635da378..f01e28bf2cf 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -147,6 +147,17 @@ def fetch_metadata(self): data = self.get(self.server_metadata_url, withhold_token=True).json() self.metadata.update(data) self.metadata_fetch_last = time.time() + + def updateJWKs(self): + """ Update JWKs + """ + try: + response = requests.get(get_metadata('jwks_uri'), verify=self.verify) + response.raise_for_status() + self.jwks = response.json() + return S_OK(self.jwks) + except requests.exceptions.RequestException as e: + return S_ERROR("Error %s" % e) def researchGroup(self, payload, token): """ Research group diff --git a/src/DIRAC/Resources/IdProvider/Utilities.py b/src/DIRAC/Resources/IdProvider/Utilities.py new file mode 100644 index 00000000000..ca2134336ce --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/Utilities.py @@ -0,0 +1,69 @@ +""" Utilities for the IdProvider package +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from DIRAC import S_OK, S_ERROR, gConfig + + +def getSettingsNamesForIdPIssuer(issuer): + """ Get identity providers for issuer + + :param str issuer: issuer + + :return: S_OK(list)/S_ERROR() + """ + names = [] + result = getProvidersForInstance('Id') + if not result['OK']: + return result + for name in result['Value']: + nameIssuer = gConfig.getValue('/Resources/IdProviders/%s/issuer' % name) + if nameIssuer and issuer.strip('/') == nameIssuer.strip('/'): + names.append(name) + return S_OK(names) if names else S_ERROR('Not found provider with %s issuer.' % issuer) + + +def getProvidersForInstance(instance, providerType=None): + """ Get providers for instance + + :param str instance: instance of what this providers + :param str providerType: provider type + + :return: S_OK(list)/S_ERROR() + """ + providers = [] + instance = "%sProviders" % instance + result = gConfig.getSections('/Resources/%s' % instance) + + # Return an empty list if the section does not exist + if not result['OK'] or not result['Value'] or not providerType: + return result + + for prov in result['Value']: + if providerType == gConfig.getValue('/Resources/%s/%s/ProviderType' % (instance, prov)): + providers.append(prov) + return S_OK(providers) + + +def getProviderInfo(provider): + """ Get provider info + + :param str provider: provider + + :return: S_OK(dict)/S_ERROR() + """ + result = gConfig.getSections('/Resources') + if not result['OK']: + return result + for section in result['Value']: + if section.endswith('Providers'): + result = getProvidersForInstance(section[:-9]) + if not result['OK']: + return result + if provider in result['Value']: + return gConfig.getOptionsDictRecursively("/Resources/%s/%s/" % (section, provider)) + return S_ERROR('%s provider not found.' % provider) diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 82affc82470..28c1bf69111 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -9,6 +9,9 @@ from authlib.jose import JsonWebKey, JsonWebSignature, jwt from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB db = AuthDB() diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index 00d67d444f0..58efed0009d 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -191,6 +191,9 @@ installSite() { echo "==> Completed installation" + # Hotfix to pass tests + pip install authlib==0.15.3 pyjwt==1.7.1 dominate + } From df5f4d220f46e9127a6da2f7ae0adfc87cbf5739 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 26 May 2021 15:24:37 +0200 Subject: [PATCH 027/178] whitespace --- src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index f01e28bf2cf..1cfe0c639ce 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -147,7 +147,7 @@ def fetch_metadata(self): data = self.get(self.server_metadata_url, withhold_token=True).json() self.metadata.update(data) self.metadata_fetch_last = time.time() - + def updateJWKs(self): """ Update JWKs """ From 66b018995f48d27b57641830450c13e1659f739d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 27 May 2021 10:06:12 +0200 Subject: [PATCH 028/178] fix bugs --- src/DIRAC/ConfigurationSystem/Client/Utilities.py | 4 ++-- .../private/authorization/grants/AuthorizationCode.py | 2 +- .../FrameworkSystem/private/authorization/utils/Tokens.py | 3 +++ src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py | 2 +- src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py | 2 +- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index b76b26958ec..386e38bcdf0 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -589,7 +589,7 @@ def getAuthorizationServerMetadata(issuer=None): return {'issuer': issuer} if issuer else result data = result['Value'] - # Research DIRAC Authorization Server issuer + # Search DIRAC Authorization Server issuer data['issuer'] = data.get('issuer', issuer) if not data['issuer']: try: @@ -606,4 +606,4 @@ def isDownloadablePersonalProxy(): :return: S_OK(bool)/S_ERROR() """ cs_path = '/Systems/Framework/%s/APIs/Auth' % getSystemInstance("Framework") - return gConfig.getOption(cs_path + '/downloadablePersonalProxy') + return gConfig.getValue(cs_path + '/downloadablePersonalProxy', "false").lower() in ("y", "yes", "true") diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py index 7691e896850..4063c20d8d9 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -119,6 +119,6 @@ def generate_authorization_code(self): gLogger.debug('Authorization code generated:', dict(code)) result = self.server.db.getPrivateKey() if not result['OK']: - raise OAuth2Error('Cannot check authorization code: %s' % result['Message']) + raise OAuth2Error('Cannot generate authorization code: %s' % result['Message']) key = result['Value']['key'] return jws.serialize_compact(protected, json_b64encode(dict(code)), key) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 54e1cfcc91a..6b499f65258 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -2,8 +2,11 @@ from __future__ import division from __future__ import print_function +import os import six +import stat import time +import json import datetime from DIRAC import S_OK, S_ERROR diff --git a/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py b/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py index 135f273fb13..21ac3ca7d4c 100644 --- a/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py @@ -10,7 +10,7 @@ __RCSID__ = "$Id$" -class DIRACIdProvider(OAuth2IdProvider): +class DIRACCLIIdProvider(OAuth2IdProvider): def fetch_metadata(self, url=None): """ Fetch metada diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 1cfe0c639ce..17955a9a8a3 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -152,7 +152,7 @@ def updateJWKs(self): """ Update JWKs """ try: - response = requests.get(get_metadata('jwks_uri'), verify=self.verify) + response = requests.get(self.get_metadata('jwks_uri'), verify=self.verify) response.raise_for_status() self.jwks = response.json() return S_OK(self.jwks) From 5592c44c87d3925737d74f7b65944aa01ad40988 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 27 May 2021 10:12:10 +0200 Subject: [PATCH 029/178] remove tornado_start_AS --- setup.cfg | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 9f0cf392d64..997b87fca20 100644 --- a/setup.cfg +++ b/setup.cfg @@ -131,8 +131,6 @@ console_scripts = # Core.Tornado tornado-start-CS = DIRAC.Core.Tornado.scripts.tornado_start_CS:main [server] tornado-start-all = DIRAC.Core.Tornado.scripts.tornado_start_all:main [server] - tornado-start-AS = DIRAC.Core.Tornado.scripts.tornado_start_AS:main [server] - tornado-start-web = DIRAC.Core.Tornado.scripts.tornado_start_web:main [server] # DataManagementSystem dirac-admin-allow-se = DIRAC.DataManagementSystem.scripts.dirac_admin_allow_se:main [admin] dirac-admin-ban-se = DIRAC.DataManagementSystem.scripts.dirac_admin_ban_se:main [admin] From 19a033bafb49630bd5a610edc5220ba5b8eae8d4 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:26:25 +0200 Subject: [PATCH 030/178] add fetch tokens to base client --- .../Client/private/TornadoBaseClient.py | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index f06af7ae615..ac68da14dc6 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -49,7 +49,8 @@ from DIRAC.Core.Security import Locations from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import readTokenFromFile +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile # TODO CHRIS: refactor all the messy `discover` methods @@ -106,6 +107,7 @@ def __init__(self, serviceName, **kwargs): self.__ca_location = False self.kwargs = kwargs + self.__idp = None self.__useAccessToken = None self.__useCertificates = None # The CS useServerCertificate option can be overridden by explicit argument @@ -237,6 +239,12 @@ def __discoverCredentialsToUse(self): if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: self.__useAccessToken = os.environ['DIRAC_USE_ACCESS_TOKEN'] + if self.__useAccessToken: + result = IdProviderFactory().getIdProvider('DIRACCLI') + if not result['OK']: + return result + self.__idp = result['Value'] + # Rewrite a little bit from here: don't need the proxy string, we use the file if self.KW_PROXY_CHAIN in self.kwargs: try: @@ -512,10 +520,26 @@ def _request(self, retry=0, outputFile=None, **kwargs): # Use access token? elif self.__useAccessToken: - result = readTokenFromFile() + result = getLocalTokenDict() if not result['OK']: return result - auth = {'headers': {"Authorization": "Bearer %s" % result['Value']['access_token']}} + token = result['Value'] + + # Check if access token expired + if token.is_expired(): + if not token.get('refresh_token'): + return S_ERROR('Access token expired.') + + # Try to refresh token + result = self.__idp.refreshToken(token['refresh_token']) + if result['OK']: + token = result['Value'] + result = writeTokenDictToTokenFile(token) + if not result['OK']: + return result + gLogger.notice('Token is saved in %s.' % result['Value']) + + auth = {'headers': {"Authorization": "Bearer %s" % token['access_token']}} # CHRIS 04.02.21 # TODO: add proxyLocation check ? From c9a5900f660d0d6abf5c38d33a80cdf67c6c52cb Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:27:24 +0200 Subject: [PATCH 031/178] change authz with tokens, more docs --- .../Core/Tornado/Server/BaseRequestHandler.py | 318 ++++++------------ src/DIRAC/Core/Tornado/Server/TornadoREST.py | 78 ++++- .../Core/Tornado/Server/TornadoService.py | 143 ++++++++ 3 files changed, 322 insertions(+), 217 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py index 376c7e6eb6d..1b995a6f2f2 100644 --- a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -12,7 +12,6 @@ import os import jwt import time -import requests import threading from datetime import datetime from six import string_types @@ -40,24 +39,81 @@ class BaseRequestHandler(RequestHandler): - """ - Base class for all the Handlers. - It directly inherits from :py:class:`tornado.web.RequestHandler` - - Each HTTP request is served by a new instance of this class. - - For the sequence of method called, please refer to - the `tornado documentation `_. - - For compatibility with the existing :py:class:`DIRAC.Core.DISET.TransferClient.TransferClient`, - the handler can define a method ``export_streamToClient``. This is the method that will be called - whenever ``TransferClient.receiveFile`` is called. It is the equivalent of the DISET - ``transfer_toClient``. - Note that this is here only for compatibility, and we discourage using it for new purposes, as it is - bound to disappear. - - The handler only define the ``post`` verb. Please refer to :py:meth:`.post` for the details. - + """ Base class for all the Handlers. + It directly inherits from :py:class:`tornado.web.RequestHandler` + + Each HTTP request is served by a new instance of this class. + + For the sequence of method called, please refer to + the `tornado documentation `_. + + This class is basic for :py:class:`DIRAC.Core.Tornado.Server.TornadoService.TornadoService` + and :py:class:`DIRAC.Core.Tornado.Server.TornadoREST.TornadoREST`. + + In order to create a class that inherits from `BaseRequestHandler`, it has to + follow a certain skeleton:: + + class TornadoInstance(BaseRequestHandler): + + # Prefix of methods names + METHOD_PREFIX = "export_" + + @classmethod + def _getServiceName(cls, request): + ''' Search service name in request + ''' + return request.path[1:] + + @classmethod + def _getServiceInfo(cls, serviceName, request): + ''' Fill service information. + ''' + return {'serviceName': serviceName, + 'serviceSectionPath': PathFinder.getServiceSection(serviceName), + 'csPaths': [PathFinder.getServiceSection(serviceName)], + 'URL': request.full_url()} + + @classmethod + def _getServiceAuthSection(cls, serviceName): + ''' Search service auth section. + ''' + return "%s/Authorization" % PathFinder.getServiceSection(serviceName) + + def _getMethodName(self): + ''' Parse method name. + ''' + return self.get_argument("method") + + def _getMethodArgs(self, args): + ''' Decode args. + ''' + args_encoded = self.get_body_argument('args', default=encode([])) + return decode(args_encoded)[0] + + # Make post a coroutine. + # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines + # for details + @gen.coroutine + def post(self, *args, **kwargs): # pylint: disable=arguments-differ + ''' Describe HTTP method to use + ''' + # Execute the method in an executor (basically a separate thread) + # Because of that, we cannot calls certain methods like `self.write` + # in _executeMethod. This is because these methods are not threadsafe + # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + # However, we can still rely on instance attributes to store what should + # be sent back (reminder: there is an instance + # of this class created for each request) + retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + # retVal is :py:class:`tornado.concurrent.Future` + self._finishFuture(retVal) + + For compatibility with the existing :py:class:`DIRAC.Core.DISET.TransferClient.TransferClient`, + the handler can define a method ``export_streamToClient``. This is the method that will be called + whenever ``TransferClient.receiveFile`` is called. It is the equivalent of the DISET + ``transfer_toClient``. + Note that this is here only for compatibility, and we discourage using it for new purposes, as it is + bound to disappear. """ # Because we initialize at first request, we use a flag to know if it's already done __init_done = False @@ -83,13 +139,9 @@ class BaseRequestHandler(RequestHandler): # Which grant type to use USE_AUTHZ_GRANTS = ['SSL', 'JWT'] - # Key updates are started at initialization, ie at the first request, - # this parameter shows that the first request of keys is made and - # it is possible to use them for check of tokens - __init_jwk_done = False + # Definition of identity providers _idps = IdProviderFactory() _idp = {} - _jwks = {} @classmethod def _initMonitoring(cls, serviceName, fullUrl): @@ -144,7 +196,7 @@ def _getServiceAuthSection(cls, serviceName): :return: str """ - return "%s/Authorization" % PathFinder.getServiceSection(serviceName) + raise NotImplementedError('Please, create the _getServiceAuthSection class method') @classmethod def _getServiceInfo(cls, serviceName, request): @@ -158,45 +210,18 @@ def _getServiceInfo(cls, serviceName, request): return {} @classmethod - @gen.coroutine - def __refreshJWKsLoop(cls): - """ Auto refresh JWKs - """ - while True: - - # Research Identity Providers - result = getProvidersForInstance('Id') - if result['OK']: - for providerName in list(set(result['Value'] + ['DIRACCLI'])): - result = cls._idps.getIdProvider(providerName) - if result['OK']: - issuer = result['Value'].issuer.strip('/') - jwks_uri = result['Value'].get_metadata('jwks_uri') - cls._idp[issuer] = result['Value'] - - gLogger.debug('Updating public keys..') - retVal = yield IOLoop.current().run_in_executor(None, cls.__refreshJWKs, jwks_uri) - result = retVal.result() - if not result['OK']: - gLogger.error('%s keys not updated' % issuer, result['Message']) - else: - gLogger.debug('%s keys updated' % issuer, result['Value']) - cls._jwks[issuer] = result['Value'] - - cls.__init_jwk_done = True - yield gen.sleep(24 * 3600) - - @classmethod - @gen.coroutine - def __refreshJWKs(cls, jwks_uri): - """ Updating public keys - """ - try: - response = requests.get(jwks_uri, verify=False) - response.raise_for_status() - return S_OK(response.json()) - except requests.exceptions.RequestException as e: - return S_ERROR("Error %s" % e) + def __loadIdPs(cls): + """ Load identity providers that will be used to verify tokens + """ + gLogger.info('Load identit providers..') + # Research Identity Providers + result = getProvidersForInstance('Id') + if result['OK']: + for providerName in result['Value']: + result = cls._idps.getIdProvider(providerName) + if not result['OK']: + gLogger.exception(result['Message']) + cls._idp[result['Value'].issuer.strip('/')] = result['Value'] @classmethod def __initializeService(cls, request): @@ -221,8 +246,8 @@ def __initializeService(cls, request): if cls.__init_done: return S_OK() - # Run automatic public key updates - IOLoop.current().spawn_callback(cls.__refreshJWKsLoop) + # Load all registred identity providers + cls.__loadIdPs() # absoluteUrl: full URL e.g. ``https://://`` absoluteUrl = request.path @@ -365,6 +390,8 @@ def prepare(self): self._monitorRequest() + self._prepare() + def _prepare(self): """ Prepare the request. It reads certificates and check authorizations. @@ -390,77 +417,16 @@ def _prepare(self): self._getMethodAuthProps()) if not authorized: extraInfo = '' - if self.credDict.get('DN'): - extraInfo += 'DN: %s' % self.credDict['DN'] if self.credDict.get('ID'): extraInfo += 'ID: %s' % self.credDict['ID'] + elif self.credDict.get('DN'): + extraInfo += 'DN: %s' % self.credDict['DN'] sLog.error( "Unauthorized access", "Identity %s; path %s; %s" % (self.srv_getFormattedRemoteCredentials(), self.request.path, extraInfo)) raise HTTPError(status_code=http_client.UNAUTHORIZED) - # Make post a coroutine. - # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines - # for details - @gen.coroutine - def post(self, *args, **kwargs): # pylint: disable=arguments-differ - """ - Method to handle incoming ``POST`` requests. - Note that all the arguments are already prepared in the :py:meth:`.prepare` - method. - - The ``POST`` arguments expected are: - - * ``method``: name of the method to call - * ``args``: JSON encoded arguments for the method - * ``extraCredentials``: (optional) Extra informations to authenticate client - * ``rawContent``: (optionnal, default False) If set to True, return the raw output - of the method called. - - If ``rawContent`` was requested by the client, the ``Content-Type`` - is ``application/octet-stream``, otherwise we set it to ``application/json`` - and JEncode retVal. - - If ``retVal`` is a dictionary that contains a ``Callstack`` item, - it is removed, not to leak internal information. - - - Example of call using ``requests``:: - - In [20]: url = 'https://server:8443/DataManagement/TornadoFileCatalog' - ...: cert = '/tmp/x509up_u1000' - ...: kwargs = {'method':'whoami'} - ...: caPath = '/home/dirac/ClientInstallDIR/etc/grid-security/certificates/' - ...: with requests.post(url, data=kwargs, cert=cert, verify=caPath) as r: - ...: print r.json() - ...: - {u'OK': True, - u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', - u'group': u'dirac_user', - u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', - u'isLimitedProxy': False, - u'isProxy': True, - u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', - u'properties': [u'NormalUser'], - u'secondsLeft': 85441, - u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch/CN=2409820262', - u'username': u'adminusername', - u'validDN': False, - u'validGroup': False}} - """ - # Execute the method in an executor (basically a separate thread) - # Because of that, we cannot calls certain methods like `self.write` - # in _executeMethod. This is because these methods are not threadsafe - # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - # However, we can still rely on instance attributes to store what should - # be sent back (reminder: there is an instance - # of this class created for each request) - retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) - - # retVal is :py:class:`tornado.concurrent.Future` - self._finishFuture(retVal) - @gen.coroutine def _executeMethod(self, args): """ @@ -474,9 +440,6 @@ def _executeMethod(self, args): This method is called in an executor, and so cannot use methods like self.write See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes """ - # Because the keys are updated during the first initialization, the - # first authorization process must start after initialization - self._prepare() sLog.notice( "Incoming request %s /%s: %s" % @@ -503,15 +466,16 @@ def _finishFuture(self, retVal): :param object retVal: tornado.concurrent.Future """ - # Wait result only if it's a Future object self.result = retVal.result() if isinstance(retVal, Future) else retVal # Here it is safe to write back to the client, because we are not # in a thread anymore - # Is it S_OK or S_ERROR - if self.isDIRACResult(self.result): + # Is it S_OK or S_ERROR? + if (isinstance(self.result, dict) and + isinstance(self.result.get('OK'), bool) and + ('Value' if self.result['OK'] else 'Message') in self.result): self._parseDIRACResult(self.result) # If set to true, do not JEncode the return of the RPC call @@ -533,13 +497,6 @@ def _finishFuture(self, retVal): self.finish() - def isDIRACResult(self, result): - """ Check if it DIRAC result - """ - if isinstance(result, dict): - if isinstance(result.get('OK'), bool) and ('Value' if result['OK'] else 'Message') in result: - return True - def _parseDIRACResult(self, result): """ Processing of a standard DIRAC result, but in a separate method so that it can be modified for another class if necessary @@ -643,9 +600,6 @@ def _authzJWT(self, accessToken=None): :return: S_OK(dict)/S_ERROR() """ - if not self.__init_jwk_done: - time.sleep(5) - if not accessToken: # Export token from headers token = self.request.headers.get('Authorization') @@ -656,13 +610,14 @@ def _authzJWT(self, accessToken=None): return S_ERROR('Found a not bearer access token.') # Read token without verification to get issuer - issuer = jwt.decode(accessToken, options=dict(verify_signature=False))['iss'].strip('/') - + self.log.debug('Read issuer from access token', accessToken) + issuer = jwt.decode(accessToken, leeway=300, options=dict(verify_signature=False, + verify_aud=False))['iss'].strip('/') # Verify token - payload = self._idp[issuer].verifyToken(accessToken, self._jwks[issuer]) - credDict = self._idp[issuer].researchGroup(payload, accessToken) - - return S_OK(credDict) + self.log.debug('Verify access token') + result = self._idp[issuer].verifyToken(accessToken) + self.log.debug('Search user group') + return self._idp[issuer].researchGroup(result['Value'], accessToken) if result['OK'] else result def _authzVISITOR(self): """ Visitor access @@ -690,67 +645,6 @@ def getProperties(self): def isRegisteredUser(self): return self.credDict.get('username', 'anonymous') != 'anonymous' and self.credDict.get('group') - auth_ping = ['all'] - - def export_ping(self): - """ - Default ping method, returns some info about server. - - It returns the exact same information as DISET, for transparency purpose. - """ - # COPY FROM DIRAC.Core.DISET.RequestHandler - dInfo = {} - dInfo['version'] = DIRAC.version - dInfo['time'] = datetime.utcnow() - # Uptime - try: - with open("/proc/uptime", 'rt') as oFD: - iUptime = int(float(oFD.readline().split()[0].strip())) - dInfo['host uptime'] = iUptime - except Exception: # pylint: disable=broad-except - pass - startTime = self._startTime - dInfo['service start time'] = self._startTime - serviceUptime = datetime.utcnow() - startTime - dInfo['service uptime'] = serviceUptime.days * 3600 + serviceUptime.seconds - # Load average - try: - with open("/proc/loadavg", 'rt') as oFD: - dInfo['load'] = " ".join(oFD.read().split()[:3]) - except Exception: # pylint: disable=broad-except - pass - dInfo['name'] = self._serviceInfoDict['serviceName'] - stTimes = os.times() - dInfo['cpu times'] = {'user time': stTimes[0], - 'system time': stTimes[1], - 'children user time': stTimes[2], - 'children system time': stTimes[3], - 'elapsed real time': stTimes[4] - } - - return S_OK(dInfo) - - auth_echo = ['all'] - - @staticmethod - def export_echo(data): - """ - This method used for testing the performance of a service - """ - return S_OK(data) - - auth_whoami = ['authenticated'] - - def export_whoami(self): - """ - A simple whoami, returns all credential dictionary, except certificate chain object. - """ - credDict = self.srv_getRemoteCredentials() - if 'x509Chain' in credDict: - # Not serializable - del credDict['x509Chain'] - return S_OK(credDict) - @classmethod def srv_getCSOption(cls, optionName, defaultValue=False): """ diff --git a/src/DIRAC/Core/Tornado/Server/TornadoREST.py b/src/DIRAC/Core/Tornado/Server/TornadoREST.py index 6d0523fc517..7b966c00c3a 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoREST.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoREST.py @@ -25,6 +25,50 @@ class TornadoREST(BaseRequestHandler): # pylint: disable=abstract-method + """ Base class for all the endpoints handlers. + It directly inherits from :py:class:`DIRAC.Core.Tornado.Server.BaseRequestHandler.BaseRequestHandler` + + Each HTTP request is served by a new instance of this class. + + In order to create a handler for your service, it has to + follow a certain skeleton:: + + from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST + class yourEndpointHandler(TornadoREST): + + @classmethod + def initializeHandler(cls, infosDict): + ''' Called only once when the first request for this handler arrives Useful for initializing DB or so. + ''' + pass + + def initializeRequest(self): + ''' Called at the beginning of each request + ''' + pass + + # Specify the path arguments + path_someMethod = ['([A-z0-9-_]*)'] + + # Specify the default permission for the method + # See :py:class:`DIRAC.Core.DISET.AuthManager.AuthManager` + auth_someMethod = ['authenticated'] + + def web_someMethod(self, provider=None): + ''' Your method + ''' + return S_OK(provider) + + Note that because we inherit from :py:class:`tornado.web.RequestHandler` + and we are running using executors, the methods you export cannot write + back directly to the client. Please see inline comments for more details. + + In order to pass information around and keep some states, we use instance attributes. + These are initialized in the :py:meth:`.initialize` method. + + The handler define the ``post`` and ``get`` verbs. Please refer to :py:meth:`.post` for the details. + """ + USE_AUTHZ_GRANTS = ['SSL', 'JWT', 'VISITOR'] METHOD_PREFIX = 'web_' LOCATION = '/' @@ -56,9 +100,7 @@ def _getMethodName(self): :return: str """ - print(self.request.path) method = self.request.path.replace(self.LOCATION, '', 1).strip('/').split('/')[0] - print(method) if method and hasattr(self, ''.join([self.METHOD_PREFIX, method])): return method elif hasattr(self, '%sindex' % self.METHOD_PREFIX): @@ -71,10 +113,36 @@ def _getMethodName(self): @gen.coroutine def get(self, *args, **kwargs): # pylint: disable=arguments-differ """ Method to handle incoming ``GET`` requests. - Logic copied from :py:func:`~DIRAC.Core.Tornado.Server.BaseRequestHandler.post`. + Note that all the arguments are already prepared in the :py:meth:`.prepare` method. """ - # Execute the method in an executor (basically a separate thread) retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + self._finishFuture(retVal) - # retVal is :py:class:`tornado.concurrent.Future` + @gen.coroutine + def post(self, *args, **kwargs): # pylint: disable=arguments-differ + """ Method to handle incoming ``POST`` requests. + Note that all the arguments are already prepared in the :py:meth:`.prepare` method. + """ + retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) self._finishFuture(retVal) + + auth_echo = ['all'] + + @staticmethod + def web_echo(data): + """ + This method used for testing the performance of a service + """ + return S_OK(data) + + auth_whoami = ['authenticated'] + + def web_whoami(self): + """ + A simple whoami, returns all credential dictionary, except certificate chain object. + """ + credDict = self.srv_getRemoteCredentials() + if 'x509Chain' in credDict: + # Not serializable + del credDict['x509Chain'] + return S_OK(credDict) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index cd23bee05ac..4f765b9f6ba 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -9,6 +9,10 @@ __RCSID__ = "$Id$" +import tornado.ioloop +from tornado import gen +from tornado.ioloop import IOLoop + import DIRAC from DIRAC import gConfig, gLogger, S_OK @@ -24,6 +28,11 @@ class TornadoService(BaseRequestHandler): # pylint: disable=abstract-method """ + Base class for all the sevices handlers. + It directly inherits from :py:class:`DIRAC.Core.Tornado.Server.BaseRequestHandler.BaseRequestHandler` + + Each HTTP request is served by a new instance of this class. + In order to create a handler for your service, it has to follow a certain skeleton:: @@ -82,6 +91,8 @@ def export_streamToClient(self, myDataToSend, token): In order to pass information around and keep some states, we use instance attributes. These are initialized in the :py:meth:`.initialize` method. + The handler only define the ``post`` verb. Please refer to :py:meth:`.post` for the details. + """ # Prefix of methods names METHOD_PREFIX = "export_" @@ -111,6 +122,16 @@ def _getServiceInfo(cls, serviceName, request): 'csPaths': [PathFinder.getServiceSection(serviceName)], 'URL': request.full_url()} + @classmethod + def _getServiceAuthSection(cls, serviceName): + """ Search service auth section. + + :param str serviceName: service name + + :return: str + """ + return "%s/Authorization" % PathFinder.getServiceSection(serviceName) + def _getMethodName(self): """ Parse method name. @@ -307,3 +328,125 @@ def __executeMethod(self): # Decode args args_encoded = self.get_body_argument('args', default=encode([])) return decode(args_encoded)[0] + + # Make post a coroutine. + # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines + # for details + @gen.coroutine + def post(self, *args, **kwargs): # pylint: disable=arguments-differ + """ + Method to handle incoming ``POST`` requests. + Note that all the arguments are already prepared in the :py:meth:`.prepare` + method. + + The ``POST`` arguments expected are: + + * ``method``: name of the method to call + * ``args``: JSON encoded arguments for the method + * ``extraCredentials``: (optional) Extra informations to authenticate client + * ``rawContent``: (optionnal, default False) If set to True, return the raw output + of the method called. + + If ``rawContent`` was requested by the client, the ``Content-Type`` + is ``application/octet-stream``, otherwise we set it to ``application/json`` + and JEncode retVal. + + If ``retVal`` is a dictionary that contains a ``Callstack`` item, + it is removed, not to leak internal information. + + + Example of call using ``requests``:: + + In [20]: url = 'https://server:8443/DataManagement/TornadoFileCatalog' + ...: cert = '/tmp/x509up_u1000' + ...: kwargs = {'method':'whoami'} + ...: caPath = '/home/dirac/ClientInstallDIR/etc/grid-security/certificates/' + ...: with requests.post(url, data=kwargs, cert=cert, verify=caPath) as r: + ...: print r.json() + ...: + {u'OK': True, + u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'group': u'dirac_user', + u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'isLimitedProxy': False, + u'isProxy': True, + u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'properties': [u'NormalUser'], + u'secondsLeft': 85441, + u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch/CN=2409820262', + u'username': u'adminusername', + u'validDN': False, + u'validGroup': False}} + """ + # Execute the method in an executor (basically a separate thread) + # Because of that, we cannot calls certain methods like `self.write` + # in _executeMethod. This is because these methods are not threadsafe + # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + # However, we can still rely on instance attributes to store what should + # be sent back (reminder: there is an instance + # of this class created for each request) + retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + + # retVal is :py:class:`tornado.concurrent.Future` + self._finishFuture(retVal) + + auth_ping = ['all'] + + def export_ping(self): + """ + Default ping method, returns some info about server. + + It returns the exact same information as DISET, for transparency purpose. + """ + # COPY FROM DIRAC.Core.DISET.RequestHandler + dInfo = {} + dInfo['version'] = DIRAC.version + dInfo['time'] = datetime.utcnow() + # Uptime + try: + with open("/proc/uptime", 'rt') as oFD: + iUptime = int(float(oFD.readline().split()[0].strip())) + dInfo['host uptime'] = iUptime + except Exception: # pylint: disable=broad-except + pass + startTime = self._startTime + dInfo['service start time'] = self._startTime + serviceUptime = datetime.utcnow() - startTime + dInfo['service uptime'] = serviceUptime.days * 3600 + serviceUptime.seconds + # Load average + try: + with open("/proc/loadavg", 'rt') as oFD: + dInfo['load'] = " ".join(oFD.read().split()[:3]) + except Exception: # pylint: disable=broad-except + pass + dInfo['name'] = self._serviceInfoDict['serviceName'] + stTimes = os.times() + dInfo['cpu times'] = {'user time': stTimes[0], + 'system time': stTimes[1], + 'children user time': stTimes[2], + 'children system time': stTimes[3], + 'elapsed real time': stTimes[4] + } + + return S_OK(dInfo) + + auth_echo = ['all'] + + @staticmethod + def export_echo(data): + """ + This method used for testing the performance of a service + """ + return S_OK(data) + + auth_whoami = ['authenticated'] + + def export_whoami(self): + """ + A simple whoami, returns all credential dictionary, except certificate chain object. + """ + credDict = self.srv_getRemoteCredentials() + if 'x509Chain' in credDict: + # Not serializable + del credDict['x509Chain'] + return S_OK(credDict) From 4ae8cd288963b818e4c86866af7884ff48081fa1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:30:33 +0200 Subject: [PATCH 032/178] add refresh token encryption --- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 194 +++++++++++++++++-------- 1 file changed, 137 insertions(+), 57 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index d6f7143bdcd..a1017db5d22 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -6,17 +6,20 @@ import jwt import json +import time +import pprint +import M2Crypto -from time import time from sqlalchemy import Column, Integer, Text, String from sqlalchemy.orm import scoped_session from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound from sqlalchemy.ext.declarative import declarative_base from authlib.jose import KeySet, RSAKey +from authlib.common.encoding import urlsafe_b64decode, urlsafe_b64encode, to_bytes, to_unicode, json_b64encode from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin -from DIRAC import S_OK, S_ERROR, gLogger +from DIRAC import S_OK, S_ERROR from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token @@ -26,6 +29,22 @@ Model = declarative_base() +def encrypt(data, key): + """ Encryption with key """ + cipher = M2Crypto.EVP.Cipher(alg='aes_256_cbc', key=key[16:], iv=key[:16], op=1) + ciphertext = cipher.update(data.encode('utf-8')) + cipher.final() + ciphertext = urlsafe_b64encode(ciphertext) + return ciphertext + + +def decrypt(ciphertext, key): + """ Decryption with key """ + cipher = M2Crypto.EVP.Cipher(alg='aes_256_cbc', key=key[16:], iv=key[:16], op=0) + data = cipher.update(urlsafe_b64decode(to_bytes(ciphertext))) + cipher.final() + data = to_unicode(data.decode('utf-8')) + return data + + class Token(Model, OAuth2TokenMixin): __tablename__ = 'Token' __table_args__ = {'mysql_engine': 'InnoDB', @@ -34,9 +53,14 @@ class Token(Model, OAuth2TokenMixin): # 767 bytes is the stated prefix limitation for InnoDB tables in MySQL version 5.6 # https://stackoverflow.com/questions/1827063/mysql-error-key-specification-without-a-key-length id = Column(Integer, autoincrement=True, primary_key=True) + kid = Column(String(255)) + user_id = Column(String(255)) + provider = Column(String(255)) + client_id = Column(String(255)) + expires_at = Column(Integer, nullable=False, default=0) access_token = Column(Text, nullable=False) refresh_token = Column(Text, nullable=False) - expires_at = Column(Integer, nullable=False, default=0) + rt_expires_at = Column(Integer, nullable=False, default=0) class JWK(Model): @@ -53,19 +77,19 @@ class AuthSession(Model): __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8'} id = Column(String(255), unique=True, primary_key=True, nullable=False) - state = Column(String(255)) uri = Column(String(255)) - client_id = Column(String(255)) + state = Column(String(255)) + scope = Column(String(255)) user_id = Column(String(255)) username = Column(String(255)) + client_id = Column(String(255)) + user_code = Column(String(255)) + device_code = Column(String(255)) + interval = Column(Integer, nullable=False, default=5) expires_at = Column(Integer, nullable=False, default=0) expires_in = Column(Integer, nullable=False, default=0) - interval = Column(Integer, nullable=False, default=5) verification_uri = Column(String(255)) verification_uri_complete = Column(String(255)) - user_code = Column(String(255)) - device_code = Column(String(255)) - scope = Column(String(255)) class AuthDB(SQLAlchemyDB): @@ -109,63 +133,109 @@ def __initializeDB(self): return S_OK() - def getToken(self, token, token_type_hint='refresh_token'): - """ Find Token for refresh token + def encryptRefreshToken(self, token, metadata): + """ Encrypt refresh token - :param str token: token - :param str token_type_hint: token type + :param dict token: token dict + :param str client_id: client ID + :param str provider: provider name - :return: S_OK()/S_ERROR() + :return: S_OK(dict)/S_ERROR() """ - session = self.session() + for field in ['expires_at', 'client_id', 'provider']: + if not metadata.get(field): + return S_ERROR('%s field is absent in metadata.' % field) + # Get secret key + key = self.getPrivateKey() + if not key['OK']: + return key + # Encrypt refresh token try: - session.query(Token).filter(Token.expires_at < time()).delete() - if token_type_hint == 'access_token': - token = session.query(Token).filter(Token.access_token == token).first() - else: - token = session.query(Token).filter(Token.refresh_token == token).first() - if not token: - return self.__result(session, S_ERROR("Token not found.")) - except NoResultFound: - return self.__result(session, S_ERROR("Token not found.")) + metadata['kid'] = key['Value']['kid'] + metadata['refresh_token'] = encrypt(token['refresh_token'], key['Value']['strkey']) + token['refresh_token'] = json_b64encode(metadata) + return S_OK(token) except Exception as e: - return self.__result(session, S_ERROR(str(e))) - return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)))) + self.log.exception(e) + return S_ERROR('Cannot encode refresh token: %s' % repr(e)) - def revokeToken(self, token): - """ Revoke token + def decryptRefreshToken(self, token): + """ Decrypt refresh token - :param dict token: token to revoke + :param dict token: token dict - :return: S_OK()/S_ERROR() + :return: S_OK(dict)/S_ERROR() + """ + try: + decoded = json.loads(urlsafe_b64decode(token['refresh_token'])) + except Exception as e: + return S_ERROR('Cannot find secret key: %s' % repr(e)) + # Get secret key by key ID + key = self.getPrivateKey(decoded['kid']) + if not key['OK']: + return key + # Decript refresh token + try: + token['refresh_token'] = decrypt(decoded['refresh_token'], key['Value']['strkey']) + token['expires_at'] = decoded['expires_at'] + token['client_id'] = decoded['client_id'] + token['provider'] = decoded['provider'] + return S_OK(OAuth2Token(token)) + except Exception as e: + self.log.exception(e) + return S_ERROR('Cannot decode refresh token: %s' % repr(e)) + + def getTokenForUserProvider(self, userID, provider): + """ Get token for user ID and provider name + + :param str userID: user ID + :param str provider: provider + + :return: S_OK(dict)/S_ERROR() """ session = self.session() try: - token = session.query(Token).filter(Token.access_token == token['access_token']).first() - token.revoked = True - except NoResultFound: - return self.__result(session, S_OK()) + token = session.query(Token).filter(Token.rt_expires_at > time.time()).filter(Token.user_id == userID)\ + .filter(Token.provider == provider).first() except Exception as e: - return self.__result(session, S_ERROR('Could not revoke token: %s' % e)) - return self.__result(session, S_OK()) + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)) if token else None)) - def storeToken(self, token): - """ Save token + def updateToken(self, token, userID, provider): + """ Update tokens :param dict token: token info + :param str userID: user ID + :param str provider: provider - :return: S_OK(str)/S_ERROR() + :return: S_OK(list)/S_ERROR() """ - token['expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False))['exp']) - gLogger.debug('Store token:', dict(token)) + token['user_id'] = userID + token['provider'] = provider + try: + token['rt_expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False, verify_aud=False))['exp']) + except Exception as e: + self.log.debug('Cannot get refresh token expires time: %s' % repr(e)) + + token['rt_expires_at'] = int(token.get('rt_expires_at', 24 * 3600 + time.time())) + if token['rt_expires_at'] < time.time(): + return S_ERROR('Cannot store expired refresh token.') + + attrts = dict((k, v) for k, v in dict(token).items() if k in list(Token.__dict__.keys())) + self.log.debug('Store token:', pprint.pformat(attrts)) session = self.session() try: - session.query(Token).filter(Token.access_token == token['access_token']).delete() + session.query(Token).filter(Token.expires_at < time.time()).delete() + oldTokens = session.query(Token).filter(Token.user_id == userID)\ + .filter(Token.provider == provider).all() session.add(Token(**attrts)) + session.query(Token).filter(Token.user_id == userID).filter(Token.provider == provider)\ + .filter(Token.access_token != token['access_token']).delete() except Exception as e: - return self.__result(session, S_ERROR('Could not add Token: %s' % e)) - return self.__result(session, S_OK('Token successfully added')) + return self.__result(session, S_ERROR('Could not add Token: %s' % repr(e))) + self.log.info('Token successfully added for %s user, %s provider' % (token['user_id'], token['provider'])) + return self.__result(session, S_OK([self.__rowToDict(t) for t in oldTokens] if oldTokens else [])) def removeTokens(self): """ Get active keys @@ -186,7 +256,7 @@ def generateRSAKeys(self): """ key = RSAKey.generate_key(key_size=1024, is_private=True) dictKey = dict(key=json.dumps(key.as_dict()), - expires_at=time() + (30 * 24 * 3600), + expires_at=time.time() + (30 * 24 * 3600), kid=KeySet([key]).as_dict()['keys'][0]['kid']) session = self.session() @@ -227,35 +297,45 @@ def getJWKs(self): keys.append({'n': k['n'], "kty": k['kty'], "e": k['e'], "kid": k['kid']}) return S_OK({'keys': keys}) - def getPrivateKey(self): + def getPrivateKey(self, kid=None): """ Get private key + :param str kid: key ID + :return: S_OK(obj)/S_ERROR() """ - result = self.getActiveKeys() + result = self.getActiveKeys(kid) if not result['OK']: return result + jwks = result['Value'] + if kid: + strkey=jwks[0]['key'] + return S_OK(dict(rsakey=RSAKey.import_key(json.loads(strkey)), kid=kid, strkey=strkey)) newer = {} - for d in result['Value']: - if d['expires_at'] > newer.get('expires_at', time() + (24 * 3600)): - newer = d + for jwk in jwks: + if jwk['expires_at'] > newer.get('expires_at', time.time() + (24 * 3600)): + newer = jwk if not newer.get('key'): result = self.generateRSAKeys() if not result['OK']: return result newer = result['Value'] - return S_OK({'key': RSAKey.import_key(json.loads(newer['key'])), 'kid': newer['kid']}) + return S_OK(dict(rsakey=RSAKey.import_key(json.loads(newer['key'])), kid=newer['kid'], strkey=newer['key'])) - def getActiveKeys(self): + def getActiveKeys(self, kid=None): """ Get active keys + :param str kid: key ID + :return: S_OK(list)/S_ERROR() """ session = self.session() try: # Remove all expired jwks - session.query(JWK).filter(JWK.expires_at < time()).delete() - jwks = session.query(JWK).filter(JWK.expires_at > time()).all() + session.query(JWK).filter(JWK.expires_at < time.time()).delete() + jwks = session.query(JWK).filter(JWK.expires_at > time.time()).all() + if kid: + jwks = [jwk for jwk in jwks if jwk.kid == kid] except NoResultFound: return self.__result(session, S_OK([])) except Exception as e: @@ -283,8 +363,8 @@ def addSession(self, data): """ attrts = {} if not data.get('expires_at'): - data['expires_at'] = data['expires_in'] + time() - gLogger.debug('Add authorization session:', data) + data['expires_at'] = data['expires_in'] + time.time() + self.log.debug('Add authorization session:', data) for k, v in data.items(): if k not in AuthSession.__dict__.keys(): self.log.warn('%s is not expected as authentication session attribute.' % k) @@ -318,7 +398,7 @@ def removeSession(self, sessionID): session = self.session() try: # Remove all expired sessions - session.query(AuthSession).filter(AuthSession.expires_at < time()).delete() + session.query(AuthSession).filter(AuthSession.expires_at < time.time()).delete() session.query(AuthSession).filter(AuthSession.id == sessionID).delete() except Exception as e: return self.__result(session, S_ERROR(str(e))) From 4e94c02702e4eabe2bfb59049909bae664abb647 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:32:01 +0200 Subject: [PATCH 033/178] align with changes --- .../authorization/grants/AuthorizationCode.py | 2 +- .../authorization/grants/RefreshToken.py | 47 ++++++++++++------- .../authorization/grants/RevokeToken.py | 27 ++++++----- 3 files changed, 44 insertions(+), 32 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py index 4063c20d8d9..88887485f36 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -120,5 +120,5 @@ def generate_authorization_code(self): result = self.server.db.getPrivateKey() if not result['OK']: raise OAuth2Error('Cannot generate authorization code: %s' % result['Message']) - key = result['Value']['key'] + key = result['Value']['rsakey'] return jws.serialize_compact(protected, json_b64encode(dict(code)), key) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py index 40283cd4745..6e977b9eca3 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -5,12 +5,12 @@ from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant -from DIRAC import gLogger - class RefreshTokenGrant(_RefreshTokenGrant): """ See :class:`authlib.oauth2.rfc6749.grants.RefreshTokenGrant` """ + DEFAULT_EXPIRES_AT = 12 * 3600 + def authenticate_refresh_token(self, refresh_token): """ Get credential for token @@ -18,27 +18,38 @@ def authenticate_refresh_token(self, refresh_token): :return: dict or None """ - # Check auth session - result = self.server.db.getToken(refresh_token) + result = self.server.db.decryptRefreshToken({'refresh_token': refresh_token}) if not result['OK']: - raise OAuth2Error('Cannot get token', result['Message']) - token = result['Value'] - return None if token.revoked else token + raise OAuth2Error(result['Message']) + return result['Value'] + + def _validate_token_scope(self, token): + """ Skip scope validadtion """ + pass def authenticate_user(self, credential): - """ Authorize user + """ Authorize user """ + return True - :param object credential: credential + def issue_token(self, user, credential): + """ Refresh tokens + + :param user: unuse + :param dict credential: token credential - :return: str + :return: dict """ - return credential.sub + result = self.server.idps.getIdProvider(credential['provider']) + if result['OK']: + result = result['Value'].refreshToken(credential['refresh_token']) + if result['OK']: + result = self.server.db.encryptRefreshToken(result['Value'], dict(provider=credential['provider'], + client_id=credential['client_id'], + expires_at=self.DEFAULT_EXPIRES_AT)) + if not result['OK']: + raise OAuth2Error(result['Message']) + return result['Value'] def revoke_old_credential(self, credential): - """ Remove old credential - - :param object credential: credential - """ - result = self.server.db.revokeToken(credential) - if not result['OK']: - gLogger.error(result['Message']) + """ Remove old credential """ + pass \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py index 3e8980acaff..c3f466c333b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py @@ -2,10 +2,9 @@ from __future__ import division from __future__ import print_function +from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint -from DIRAC import gLogger - class RevocationEndpoint(_RevocationEndpoint): """ See :class:`authlib.oauth2.rfc7009.RevocationEndpoint` """ @@ -19,21 +18,23 @@ def query_token(self, token, token_type_hint, client): :return: str """ - result = self.server.db.getToken(token, token_type_hint) - if not result['OK']: - gLogger.error(result['Message']) - return None - rv = result['Value'] - client_id = client.get_client_id() - if rv and rv.client_id == client_id: - return rv - return None + if token_type_hint == 'refresh_token': + result = self.server.db.decryptRefreshToken({'refresh_token': token}) + if not result['OK']: + raise OAuth2Error(result['Message']) + return result['Value'] + return token def revoke_token(self, token): """ Mark the give token as revoked. :param dict token: token dict """ - result = self.server.db.revokeToken(token) + if isinstance(token, dict): + result = self.server.idps.getIdProvider(token['provider']) + else: + result = self.server.idps.getIdProviderForToken(token) + if result['OK']: + result = result['Value'].revokeToken(token) if not result['OK']: - gLogger.error(result['Message']) + raise OAuth2Error(result['Message']) From fd0fe19becc57336ce0cf6ea2da9f659753d4c62 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:37:28 +0200 Subject: [PATCH 034/178] remove DIRAC tokens --- .../private/authorization/AuthServer.py | 138 +++++++----------- 1 file changed, 52 insertions(+), 86 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 32882c1d9d6..89f275eff00 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -3,16 +3,15 @@ from __future__ import division from __future__ import print_function -import json -from time import time +import sys +import time import pprint +import logging from dominate import document, tags as dom from tornado.template import Template -from authlib.jose import jwt from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer from authlib.oauth2.base import OAuth2Error -from authlib.oauth2.rfc6750 import BearerToken from authlib.oauth2.rfc7636 import CodeChallenge from authlib.oauth2.rfc8414 import AuthorizationServerMetadata from authlib.oauth2.rfc6749.util import scope_to_list @@ -31,11 +30,8 @@ from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata, isDownloadablePersonalProxy from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getEmailsForGroup, getDNForUsername -from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import getSetup from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient -import logging -import sys log = logging.getLogger('authlib') log.addHandler(logging.StreamHandler(sys.stdout)) log.setLevel(logging.DEBUG) @@ -77,9 +73,8 @@ def __init__(self): self.proxyCli = ProxyManagerClient() self.idps = IdProviderFactory() # Privide two authlib methods query_client and save_token - _AuthorizationServer.__init__(self, query_client=getDIACClientByID, save_token=self.saveToken) + _AuthorizationServer.__init__(self, query_client=getDIACClientByID, save_token=lambda x, y: None) self.generate_token = self.generateProxyOrToken - self.bearerToken = BearerToken(self.access_token_generator, self.refresh_token_generator) self.config = {} self.metadata = collectMetadata() self.metadata.validate() @@ -96,18 +91,6 @@ def addSession(self, session): def getSession(self, session): self.db.getSession(session) - def saveToken(self, token, request): - """ Store tokens - - :param dict token: tokens - :param object request: Request object - """ - if token.get('refresh_token'): - token['client_id'] = request.client.client_id - result = self.db.storeToken(token) - if not result['OK']: - gLogger.error(result['Message']) - def __getScope(self, scope, param): """ Get parameter scope @@ -125,13 +108,14 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, expires_in=None, include_refresh_token=True): """ Generate proxy or tokens after authorization """ + group = self.__getScope(scope, 'g') + lifetime = self.__getScope(scope, 'lifetime') + provider = getIdPForGroup(group) + if 'proxy' in scope_to_list(scope): # Try to return user proxy if proxy scope present in the authorization request if not isDownloadablePersonalProxy(): raise Exception("You can't get proxy, configuration settings(downloadablePersonalProxy) not allow to do that.") - - group = self.__getScope(scope, 'g') - lifetime = self.__getScope(scope, 'lifetime') gLogger.debug('Try to query %s@%s proxy%s' % (user, group, ('with lifetime:%s' % lifetime) if lifetime else '')) result = getUsernameForDN('/O=DIRAC/CN=%s' % user) if result['OK']: @@ -156,8 +140,33 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, return {'proxy': result['Value']} raise Exception('; '.join(err)) - return self.bearerToken(client, grant_type, user=user, scope=scope, expires_in=expires_in, - include_refresh_token=include_refresh_token) + else: + # Get identity provider + result = self.idps.getIdProvider(provider) + if result['OK']: + idpObj = result['Value'] + # Get actual token from storage + result = self.db.getTokenForUserProvider(user, provider) + if result['OK']: + idpObj.token = result['Value'] + # Try to refresh it if expired + if idpObj.token.is_expired(): + result = idpObj.refreshToken() + if result['OK']: + result = self.db.updateToken(idpObj.token, user, provider) + if not result['OK']: + raise OAuth2Error(result['Message']) + # Ask identity provider tokens with needed group scopes + result = idpObj.exchangeGroup(group) + if result['OK']: + token = result['Value'] + # Encrypt refresh token + result = self.db.encryptRefreshToken(token, dict(provider=idpObj.name, + client_id=client.get_client_id(), + expires_at=12 * 3600 + time.time())) + if not result['OK']: + raise OAuth2Error(result['Message']) + return result['Value'] def getIdPAuthorization(self, providerName, request): """ Submit subsession and return dict with authorization url and session number @@ -195,17 +204,15 @@ def parseIdPAuthorizationResponse(self, response, session): result = self.idps.getIdProvider(providerName) if not result['OK']: return result - provObj = result['Value'] - result = provObj.parseAuthResponse(response, session) + idpObj = result['Value'] + result = idpObj.parseAuthResponse(response, session) if not result['OK']: return result - # FINISHING with IdP auth result + # FINISHING with IdP + # As a result of authentication we will receive user credential dictionary credDict = result['Value'] - # ########### TODO: This line will store the original tokens ############ # - # updateToken(provObj.token, user_id=provObj.token['user_id']) - gLogger.debug("Read profile:", pprint.pformat(credDict)) # Is ID registred? result = getUsernameForDN(credDict['DN']) @@ -218,60 +225,19 @@ def parseIdPAuthorizationResponse(self, response, session): comment += ' Please, contact the DIRAC administrators.' return S_ERROR(comment) credDict['username'] = result['Value'] - return S_OK(credDict) - - def access_token_generator(self, client, grant_type, user, scope): - """ A function to generate ``access_token`` - - :param object client: Client object - :param str grant_type: grant type - :param str user: user unique id - :param str scope: scope - :return: str - """ - gLogger.debug('GENERATE DIRAC ACCESS TOKEN for "%s" with "%s" scopes.' % (user, scope)) - return self.signToken({'sub': user, - 'iss': self.metadata['issuer'], - 'iat': int(time()), - 'exp': int(time()) + (self.__getScope(scope, 'lifetime') or (12 * 3600)), - 'scope': scope, - 'setup': getSetup(), - 'group': self.__getScope(scope, 'g')}) - - def refresh_token_generator(self, client, grant_type, user, scope): - """ A function to generate ``refresh_token`` - - :param object client: Client object - :param str grant_type: grant type - :param str user: user unique id - :param str scope: scope - - :return: str - """ - gLogger.debug('GENERATE DIRAC REFRESH TOKEN for "%s" with "%s" scopes.' % (user, scope)) - return self.signToken({'sub': user, - 'iss': self.metadata['issuer'], - # 'iat': int(time()), - 'exp': int(time()) + (24 * 3600)}) - - def signToken(self, payload): - """ Sign token - - :param dict payload: token payload - - ;return: str - """ - result = self.db.getPrivateKey() + # Update token for user. This token will be stored separately in the database and + # updated from time to time. This token will never be transmitted, + # it will be used to make exchange token requests. + result = self.db.updateToken(idpObj.token, credDict['ID'], idpObj.name) if not result['OK']: - raise Exception(result['Message']) + return result + + # Revoke old tokens + for oldToken in result['Value']: + idpObj.revokeToken(oldToken.get('refresh_token')) - # Sign token - key = result['Value']['key'] - kid = result['Value']['kid'] - header = {'alg': 'RS256', 'kid': kid} - # Need to use enum==0.3.1 for python 2.7 - return jwt.encode(header, payload, key) + return S_OK(credDict) def get_error_uris(self, request): error_uris = self.config.get('error_uris') @@ -392,16 +358,16 @@ def __registerNewUser(self, provider, userProfile): """ from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient - username = userProfile['DN'] + username = userProfile['ID'] mail = {} mail['subject'] = "[SessionManager] User %s to be added." % username - mail['body'] = 'User %s was authenticated by ' % userProfile['FullName'] + mail['body'] = 'User %s was authenticated by ' % username mail['body'] += provider mail['body'] += "\n\nAuto updating of the user database is not allowed." mail['body'] += " New user %s to be added," % username mail['body'] += "with the following information:\n" - mail['body'] += "\nUser name: %s\n" % username + mail['body'] += "\nUser ID: %s\n" % username mail['body'] += "\nUser profile:\n%s" % pprint.pformat(userProfile) mail['body'] += "\n\n------" mail['body'] += "\n This is a notification from the DIRAC AuthManager service, please do not reply.\n" From af056280099bee9b01953d57ff1a7b2a78f8907c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:39:02 +0200 Subject: [PATCH 035/178] update clients, token class --- .../private/authorization/utils/Clients.py | 8 +- .../private/authorization/utils/Tokens.py | 263 +++++++++--------- 2 files changed, 135 insertions(+), 136 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index ca6deb84446..b1336c86b82 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -21,8 +21,9 @@ 'DIRACCLI': dict( verify=False, client_id='DIRAC_CLI', + client_secret='secret', response_types=['device'], - grant_types=['urn:ietf:params:oauth:grant-type:device_code'], + grant_types=['urn:ietf:params:oauth:grant-type:device_code', 'refresh_token'], ProviderType='DIRACCLI' ), 'DIRACWeb': dict( @@ -70,7 +71,10 @@ class Client(OAuth2ClientMixin): def __init__(self, params): super(Client, self).__init__() client_metadata = params.get('client_metadata', params) - client_metadata['scope'] = ' '.join(list(set([client_metadata.get('scope', ''), DEFAULT_SCOPE]))) + if client_metadata.get('scope') and DEFAULT_SCOPE not in client_metadata['scope']: + client_metadata['scope'] += ' %s' % DEFAULT_SCOPE + else: + client_metadata['scope'] = DEFAULT_SCOPE if params.get('redirect_uri') and not client_metadata.get('redirect_uris'): client_metadata['redirect_uris'] = [params['redirect_uri']] self.client_id = params['client_id'] diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 6b499f65258..02023c4855a 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -3,6 +3,8 @@ from __future__ import print_function import os +import re +import jwt import six import stat import time @@ -20,15 +22,31 @@ def getTokenLocation(): - """ Get the path of the currently active access token file + """ Research token file location. Use the bearer token discovery protocol + defined by the WLCG (https://zenodo.org/record/3937438) to find one. + + :return: str """ - envVar = 'DIRAC_TOKEN_FILE' - if envVar in os.environ: - tokenPath = os.path.realpath(os.environ[envVar]) - if os.path.isfile(tokenPath): - return tokenPath - # /tmp/JWTup_u - return "/tmp/JWTup_u%s" % os.getuid() + if os.environ.get('BEARER_TOKEN_FILE'): + return os.environ['BEARER_TOKEN_FILE'] + elif os.environ.get('XDG_RUNTIME_DIR'): + return "%s/bt_u%s" % (os.environ['XDG_RUNTIME_DIR'], os.getuid()) + else: + return "/tmp/bt_u%s" % os.getuid() + + +def getLocalTokenDict(location=None): + """ Search local token. Use the bearer token discovery protocol + defined by the WLCG (https://zenodo.org/record/3937438) to find one. + + :param str location: environ variable name or file path + + :return: S_OK(dict)/S_ERROR() + """ + env = (location if location and location.startswith('/') else None) or 'BEARER_TOKEN' + if os.environ.get(env): + return S_OK(OAuth2Token(os.environ[env])) + return readTokenFromFile(location if location and location.startswith('/') else None) def readTokenFromFile(fileName=None): @@ -38,13 +56,13 @@ def readTokenFromFile(fileName=None): :return: S_OK(dict)/S_ERROR() """ - fileName = fileName or getTokenLocation() + location = fileName or getTokenLocation() try: - with open(fileName, 'rt') as f: - tokenDict = f.read() - return S_OK(json.loads(tokenDict)) - except Exception as e: - return S_ERROR('Cannot read token. %s' % repr(e)) + with open(location, 'rt') as f: + token = f.read() + except IOError as e: + return S_ERROR(DErrno.EOF, "Can't open %s token file.\n%s" % (location, repr(e))) + return S_OK(OAuth2Token(token)) def writeToTokenFile(tokenContents, fileName): @@ -55,16 +73,17 @@ def writeToTokenFile(tokenContents, fileName): :return: S_OK(str)/S_ERROR() """ + location = fileName or getTokenLocation() try: - with open(fileName, 'wt') as fd: + with open(location, 'wt') as fd: fd.write(tokenContents) except Exception as e: - return S_ERROR(DErrno.EWF, " %s: %s" % (fileName, repr(e))) + return S_ERROR(DErrno.EWF, " %s: %s" % (location, repr(e))) try: - os.chmod(fileName, stat.S_IRUSR | stat.S_IWUSR) + os.chmod(location, stat.S_IRUSR | stat.S_IWUSR) except Exception as e: - return S_ERROR(DErrno.ESPF, "%s: %s" % (fileName, repr(e))) - return S_OK(fileName) + return S_ERROR(DErrno.ESPF, "%s: %s" % (location, repr(e))) + return S_OK(location) def writeTokenDictToTokenFile(tokenDict, fileName=None): @@ -76,128 +95,41 @@ def writeTokenDictToTokenFile(tokenDict, fileName=None): :return: S_OK(str)/S_ERROR() """ fileName = fileName or getTokenLocation() - try: - retVal = json.dumps(tokenDict) - except Exception as e: - return S_ERROR('Cannot dump token to string. %s' % repr(e)) - return writeToTokenFile(retVal, fileName) + if not isinstance(tokenDict, dict): + return S_ERROR('Token is not a dictionary') + return writeToTokenFile(json.dumps(tokenDict), fileName) -def writeTokenDictToTemporaryFile(tokenDict): - """ Write a token dict to a temporary file - - :param dict tokenDict: dict object to dump to file - - :return: S_OK(str)/S_ERROR() -- contain file name - """ - try: - fd, tokenLocation = tempfile.mkstemp() - os.close(fd) - except IOError: - return S_ERROR(DErrno.ECTMPF) - retVal = writeTokenDictToTokenFile(tokenDict, tokenLocation) - if not retVal['OK']: - try: - os.unlink(tokenLocation) - except Exception: - pass - return retVal - return S_OK(tokenLocation) - - -def getTokenInfo(token=False): - """ Return token info - - :param token: token location or token as dict - - :return: S_OK(dict)/S_ERROR() - """ - # Discover token location - if isinstance(token, dict): - token = OAuth2Token(token) - else: - tokenLocation = token if isinstance(token, six.string_types) else getTokenLocation() - if not tokenLocation: - return S_ERROR("Cannot find token location.") - result = readTokenFromFile(tokenLocation) - if not result['OK']: - return result - token = OAuth2Token(result['Value'])['access_token'] - - result = IdProviderFactory().getIdProviderForToken(token) - if not result['OK']: - return S_ERROR("Cannot load provider: %s" % result['Message']) - cli = result['Value'] - cli.updateJWKs() - payload = cli.verifyToken(token) - - result = Registry.getUsernameForDN('/O=DIRAC/CN=%s' % payload['sub']) - if not result['OK']: - return result - payload['username'] = result['Value'] - if payload.get('group'): - payload['properties'] = Registry.getPropertiesForGroup(payload['group']) - return S_OK(payload) - - -def formatTokenInfoAsString(infoDict): - """ Convert a token infoDict into a string - - :param dict infoDict: info - - :return: str - """ - secsLeft = int(infoDict['exp']) - time.time() - strTimeleft = datetime.datetime.fromtimestamp(secsLeft).strftime("%I:%M:%S") - - leftAlign = 13 - contentList = [] - contentList.append('%s: %s' % ('subject'.ljust(leftAlign), infoDict['sub'])) - contentList.append('%s: %s' % ('issuer'.ljust(leftAlign), infoDict['iss'])) - contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), strTimeleft)) - contentList.append('%s: %s' % ('username'.ljust(leftAlign), infoDict['username'])) - if infoDict.get('group'): - contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) - if infoDict.get('properties'): - contentList.append('%s: %s' % ('properties'.ljust(leftAlign), ', '.join(infoDict['properties']))) - return "\n".join(contentList) - - -class OAuth2Token(_OAuth2Token, OAuth2TokenMixin): +class OAuth2Token(_OAuth2Token): """ Implementation a Token object """ def __init__(self, params=None, **kwargs): + """ Constructor + """ + if isinstance(params, six.string_types): + # Is params a JWT? + if re.match(r"^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*$", params): + params = dict(access_token=params) + else: + params = json.loads(params) + kwargs.update(params or {}) - kwargs['revoked'] = False if kwargs.get('revoked', 'False') == 'False' else True - self.sub = kwargs.get('sub') - self.issuer = kwargs.get('iss') - self.client_id = kwargs.get('client_id', kwargs.get('aud')) - self.token_type = kwargs.get('token_type') - self.access_token = kwargs.get('access_token') - self.refresh_token = kwargs.get('refresh_token') - self.scope = kwargs.get('scope') - self.revoked = kwargs.get('revoked') - self.issued_at = int(kwargs.get('issued_at', kwargs.get('iat', time.time()))) - self.expires_in = int(kwargs.get('expires_in', 0)) - self.expires_at = int(kwargs.get('expires_at', kwargs.get('exp', 0))) - if not self.issued_at: - raise Exception('Missing "iat" in token.') - if not self.expires_at: - if not self.expires_in: - raise Exception('Cannot calculate token "expires_at".') - self.expires_at = self.issued_at + self.expires_in - if not self.expires_in: - self.expires_in = self.expires_at - self.issued_at - kwargs.update({'client_id': self.client_id, - 'token_type': self.token_type, - 'access_token': self.access_token, - 'refresh_token': self.refresh_token, - 'scope': self.scope, - 'revoked': self.revoked, - 'issued_at': self.issued_at, - 'expires_in': self.expires_in, - 'expires_at': self.expires_at}) + if not kwargs.get('expires_at') and kwargs.get('access_token'): + # Get access token expires_at claim + kwargs['expires_at'] = int(self.get_token_attr('exp')) super(OAuth2Token, self).__init__(kwargs) + + def get_client_id(self): + return self.get('client_id') + + def get_scope(self): + return self.get('scope') + + def get_expires_in(self): + return self.get('expires_in') + + def get_expires_at(self): + return self.get('issued_at') + self.get('expires_in') @property def scopes(self): @@ -214,3 +146,66 @@ def groups(self): :return: list """ return [s.split(':')[1] for s in self.scopes if s.startswith('g:')] + + def get_token_attr(self, attr, token_type='access_token'): + """ Get token attribute without verification + + :param str attr: attribute + :param str token_type: token type + + :return: str + """ + if not self.get(token_type): + return None + return jwt.decode(self.get(token_type), options=dict(verify_signature=False, + verify_exp=False, + verify_aud=False, + verify_nbf=False)).get(attr) + + def getInfoAsString(self): + """ Return information about token as string + + :return: str + """ + result = IdProviderFactory().getIdProviderForToken(self.get('access_token')) + if not result['OK']: + return "Cannot load provider: %s" % result['Message'] + cli = result['Value'] + cli.token = self.copy() + result = cli.verifyToken() + if not result['OK']: + return result['Message'] + payload = result['Value'] + result = cli.researchGroup(payload) + if not result['OK']: + return result['Message'] + credDict = result['Value'] + result = Registry.getUsernameForDN(credDict['DN']) + if not result['OK']: + return result['Message'] + credDict['username'] = result['Value'] + if credDict.get('group'): + credDict['properties'] = Registry.getPropertiesForGroup(credDict['group']) + payload.update(credDict) + return self.__formatTokenInfoAsString(payload) + + def __formatTokenInfoAsString(self, infoDict): + """ Convert a token infoDict into a string + + :param dict infoDict: info + + :return: str + """ + secsLeft = int(infoDict['exp']) - time.time() + strTimeleft = datetime.datetime.fromtimestamp(secsLeft).strftime("%I:%M:%S") + leftAlign = 13 + contentList = [] + contentList.append('%s: %s' % ('subject'.ljust(leftAlign), infoDict['sub'])) + contentList.append('%s: %s' % ('issuer'.ljust(leftAlign), infoDict['iss'])) + contentList.append('%s: %s' % ('timeleft'.ljust(leftAlign), strTimeleft)) + contentList.append('%s: %s' % ('username'.ljust(leftAlign), infoDict['username'])) + if infoDict.get('group'): + contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) + if infoDict.get('properties'): + contentList.append('%s: %s' % ('properties'.ljust(leftAlign), ', '.join(infoDict['properties']))) + return "\n".join(contentList) \ No newline at end of file From ca09fa0ba421a4db98497e231d97f80a06ed434c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:41:15 +0200 Subject: [PATCH 036/178] update dirac-login, add dirac-logout --- setup.cfg | 1 + .../FrameworkSystem/scripts/dirac_login.py | 42 +++--- .../FrameworkSystem/scripts/dirac_logout.py | 131 ++++++++++++++++++ 3 files changed, 158 insertions(+), 16 deletions(-) create mode 100644 src/DIRAC/FrameworkSystem/scripts/dirac_logout.py diff --git a/setup.cfg b/setup.cfg index 997b87fca20..73724c20264 100644 --- a/setup.cfg +++ b/setup.cfg @@ -162,6 +162,7 @@ console_scripts = dirac-dms-user-quota = DIRAC.DataManagementSystem.scripts.dirac_dms_user_quota:main # FrameworkSystem dirac-login = DIRAC.FrameworkSystem.scripts.dirac_login:main [server] + dirac-logout = DIRAC.FrameworkSystem.scripts.dirac_logout:main [server] dirac-admin-get-CAs = DIRAC.FrameworkSystem.scripts.dirac_admin_get_CAs:main [server] dirac-admin-get-proxy = DIRAC.FrameworkSystem.scripts.dirac_admin_get_proxy:main [admin] dirac-admin-proxy-upload = DIRAC.FrameworkSystem.scripts.dirac_admin_proxy_upload:main [admin] diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 519a9a99f61..b0691729ae5 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -15,8 +15,6 @@ import os import sys -import requests -import threading import DIRAC from DIRAC import gLogger, S_OK, S_ERROR @@ -25,8 +23,7 @@ from DIRAC.Core.Security.ProxyFile import writeToProxyFile from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (writeTokenDictToTokenFile, - getTokenInfo, formatTokenInfoAsString) +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import writeTokenDictToTokenFile, readTokenFromFile __RCSID__ = "$Id$" @@ -40,10 +37,10 @@ def __init__(self): self.provider = 'DIRACCLI' self.issuer = None self.proxyLoc = '/tmp/x509up_u%s' % os.getuid() - self.tokenLoc = '/tmp/JWTup_u%s' % os.getuid() + self.tokenLoc = None def returnProxy(self, _arg): - """ Set email + """ To return proxy :return: S_OK() """ @@ -51,7 +48,7 @@ def returnProxy(self, _arg): return S_OK() def setGroup(self, arg): - """ Set email + """ Set group :param str arg: group @@ -61,7 +58,7 @@ def setGroup(self, arg): return S_OK() def setProvider(self, arg): - """ Set email + """ Set provider name :param str arg: provider @@ -71,7 +68,7 @@ def setProvider(self, arg): return S_OK() def setIssuer(self, arg): - """ Set email + """ Set issuer :param str arg: issuer @@ -80,6 +77,16 @@ def setIssuer(self, arg): self.issuer = arg return S_OK() + def setTokenFile(self, arg): + """ Set token file + + :param str arg: token file + + :return: S_OK() + """ + self.tokenLoc = arg + return S_OK() + def setLivetime(self, arg): """ Set email @@ -117,6 +124,11 @@ def registerCLISwitches(self): "lifetime=", "set proxy lifetime in a hours", self.setLivetime) + Script.registerSwitch( + "F:", + "file=", + "set token file location", + self.setTokenFile) def doOAuthMagic(self): """ Magic method with tokens @@ -138,10 +150,7 @@ def doOAuthMagic(self): idpObj.scope += '+lifetime:%s' % (int(self.lifetime) * 3600) # Submit Device authorisation flow - try: - result = idpObj.authorization() - except KeyboardInterrupt: - return S_ERROR('User canceled the operation..') + result = idpObj.deviceAuthorization() if not result['OK']: return result @@ -154,6 +163,7 @@ def doOAuthMagic(self): result = writeTokenDictToTokenFile(idpObj.token, self.tokenLoc) if not result['OK']: return result + self.tokenLoc = result['Value'] gLogger.notice('Token is saved in %s.' % self.tokenLoc) result = Script.enableCS() @@ -167,10 +177,10 @@ def doOAuthMagic(self): return result['Message'] gLogger.notice(formatProxyInfoAsString(result['Value'])) else: - result = getTokenInfo(self.tokenLoc) + result = readTokenFromFile(self.tokenLoc) if not result['OK']: - return result['Message'] - gLogger.notice(formatTokenInfoAsString(result['Value'])) + return result + gLogger.notice(result['Value'].getInfoAsString()) return S_OK() diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py new file mode 100644 index 00000000000..92c3978d71c --- /dev/null +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +######################################################################## +# File : dirac-logout.py +# Author : Andrii Lytovchenko +######################################################################## +""" +Logout + +Example: + $ dirac-logout +""" +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function + +import os +import sys + +import DIRAC +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.Base import Script +from DIRAC.Core.Utilities.DIRACScript import DIRACScript +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import readTokenFromFile, getTokenLocation + +__RCSID__ = "$Id$" + + +class Params(object): + + def __init__(self): + self.provider = 'DIRACCLI' + self.issuer = None + self.tokenLoc = None + + def setProvider(self, arg): + """ Set provider name + + :param str arg: provider + + :return: S_OK() + """ + self.provider = arg + return S_OK() + + def setIssuer(self, arg): + """ Set issuer + + :param str arg: issuer + + :return: S_OK() + """ + self.issuer = arg + return S_OK() + + def setTokenFile(self, arg): + """ Set token file + + :param str arg: token file + + :return: S_OK() + """ + self.tokenLoc = arg + return S_OK() + + def registerCLISwitches(self): + """ Register CLI switches """ + Script.registerSwitch( + "O:", + "provider=", + "set identity provider", + self.setProvider) + Script.registerSwitch( + "I:", + "issuer=", + "set issuer", + self.setIssuer) + Script.registerSwitch( + "F:", + "file=", + "set token file location", + self.setTokenFile) + + def doOAuthMagic(self): + """ Magic method with tokens + + :return: S_OK()/S_ERROR() + """ + params = {} + if self.issuer: + params['issuer'] = self.issuer + result = IdProviderFactory().getIdProvider(self.provider, **params) + if not result['OK']: + return result + idpObj = result['Value'] + self.tokenLoc = self.tokenLoc or getTokenLocation() + result = readTokenFromFile(self.tokenLoc) + if not result['OK']: + return result + token = result['Value'] + # Revoke token + for tokenType in ['access_token', 'refresh_token']: + if token.get(tokenType): + result = idpObj.revokeToken(token[tokenType], tokenType) + if not result['OK']: + gLogger.error(result['Message']) + os.unlink(self.tokenLoc) + gLogger.notice('Token is removed from %s.' % self.tokenLoc) + + return S_OK() + + +@DIRACScript() +def main(): + piParams = Params() + piParams.registerCLISwitches() + + Script.disableCS() + Script.parseCommandLine(ignoreErrors=True) + DIRAC.gConfig.setOptionValue("/DIRAC/Security/UseServerCertificate", "False") + + resultDoMagic = piParams.doOAuthMagic() + if not resultDoMagic['OK']: + gLogger.fatal(resultDoMagic['Message']) + sys.exit(1) + + sys.exit(0) + + +if __name__ == "__main__": + main() From d0c11416bfcd58cf766ff11c9e73fd1914c259dd Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 01:46:26 +0200 Subject: [PATCH 037/178] align with changes --- dirac.cfg | 2 +- .../Resources/IdProvider/CheckInIdProvider.py | 33 +- .../Resources/IdProvider/IAMIdProvider.py | 5 +- .../Resources/IdProvider/IdProviderFactory.py | 8 +- .../Resources/IdProvider/OAuth2IdProvider.py | 357 +++++++++++------- src/DIRAC/Resources/IdProvider/Utilities.py | 18 + tests/Integration/Framework/Test_AuthDB.py | 103 ++--- 7 files changed, 322 insertions(+), 204 deletions(-) diff --git a/dirac.cfg b/dirac.cfg index 1f76f0340c8..3836f791a4f 100644 --- a/dirac.cfg +++ b/dirac.cfg @@ -104,7 +104,7 @@ Registry VOMSRole = /lhcb # Scope associated with a role of the user in the VO - IdPScope = some_special_scope + IdPRole = some_special_scope # Virtual organization associated with the group VOMSVO = lhcb diff --git a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py index 2eb0230a7e8..88c0e7b2542 100644 --- a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py @@ -4,6 +4,7 @@ from __future__ import division from __future__ import print_function +from DIRAC import S_OK from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider __RCSID__ = "$Id$" @@ -11,20 +12,24 @@ class CheckInIdProvider(OAuth2IdProvider): - # urn:mace:egi.eu:group:registry:training.egi.eu:role=member#aai.egi.eu' - NAMESPACE = 'urn:mace:egi.eu:group:registry' - SIGN = '#aai.egi.eu' - PARAM_SCOPE = 'eduperson_entitlement?value=' - - def researchGroup(self, payload, token=None): + def researchGroup(self, payload=None, token=None): """ Research group + + :param str payload: token payload + :param str token: access token + + :return: S_OK(dict)/S_ERROR() """ if token: - self.token = token - claims = self.getUserProfile() - credDict = self.parseBasic(claims) - credDict.update(self.parseEduperson(claims)) - cerdDict = self.userDiscover(credDict) - credDict['provider'] = self.name - - return credDict + self.token = {'access_token': token} + + result = self.getUserProfile() + if not result['OK']: + return result + payload = result['Value'] + + credDict = self.parseBasic(payload) + if not credDict.get('DIRACGroups'): + credDict.update(self.parseEduperson(payload)) + credDict['group'] = credDict.get('DIRACGroups', [None])[0] + return S_OK(credDict) diff --git a/src/DIRAC/Resources/IdProvider/IAMIdProvider.py b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py index a8187e61d82..c3f51e04a7a 100644 --- a/src/DIRAC/Resources/IdProvider/IAMIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py @@ -11,7 +11,4 @@ class IAMIdProvider(OAuth2IdProvider): - def researchGroup(self, payload, token): - """ Research group - """ - pass + pass diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 1608d51c654..80b19de5ec9 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -45,14 +45,18 @@ def getIdProviderForToken(self, token): """ This method returns a IdProvider instance corresponding to the supplied issuer in a token. - :param str token: token + :param token: access token or dict with access_token key :return: S_OK(IdProvider)/S_ERROR() """ + if isinstance(token, dict): + token = token['access_token'] + data = {} # Read token without verification to get issuer - issuer = jwt.decode(token, options=dict(verify_signature=False))['iss'].strip('/') + issuer = jwt.decode(token, leeway=300, + options=dict(verify_signature=False, verify_aud=False))['iss'].strip('/') result = getSettingsNamesForIdPIssuer(issuer) if result['OK']: diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 17955a9a8a3..57bc8d57c53 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -20,11 +20,11 @@ from authlib.oauth2.rfc7636 import create_s256_code_challenge from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Utilities import ThreadSafe from DIRAC.Resources.IdProvider.IdProvider import IdProvider -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getGroupOption +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getGroupOption, getAllGroups __RCSID__ = "$Id$" @@ -33,13 +33,16 @@ 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8' } +gJWKs = ThreadSafe.Synchronizer() +gMetadata = ThreadSafe.Synchronizer() +gRefreshToken = ThreadSafe.Synchronizer() + def claimParser(claimDict, attributes): - """ Parse claims to write it as DIRAC profile + """ Parse claims to dictionary with certain keys :param dict claimDict: claims :param dict attributes: contain claim and regex to parse it - :param dict profile: to fill parsed data :return: dict """ @@ -77,6 +80,7 @@ class OAuth2IdProvider(IdProvider, OAuth2Session): """ Base class to describe the configuration of the OAuth2 client of the corresponding provider. """ + JWKS_REFRESH_RATE = 24 * 3600 METADATA_REFRESH_RATE = 24 * 3600 def __init__(self, **kwargs): @@ -90,85 +94,218 @@ def __init__(self, **kwargs): self.verify = kwargs.get('verify', False) self.token_placement = kwargs.get('token_placement', 'header') self.code_challenge_method = 'S256' - self.token_endpoint_auth_method = kwargs.get('token_endpoint_auth_method', 'client_secret_post') + # self.token_endpoint_auth_method = kwargs.get('token_endpoint_auth_method') #, 'client_secret_post') self.server_metadata_url = kwargs.get('server_metadata_url', get_well_known_url(self.metadata['issuer'], True)) + self.jwks_fetch_last = time.time() - self.JWKS_REFRESH_RATE self.metadata_fetch_last = time.time() - self.METADATA_REFRESH_RATE self.log.debug('"%s" OAuth2 IdP initialization done:' % self.name, - '\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s' % (self.client_id, - self.client_secret, + '\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s' % (self.client_id, self.client_secret, pprint.pformat(self.metadata))) - def verifyToken(self, accessToken, jwks=None): + + def get_metadata(self, option=None): + """ Get metadata + + :param str option: option + + :return: option value + """ + if not self.metadata.get(option): + self.fetch_metadata() + return self.metadata.get(option) + + @gMetadata + def fetch_metadata(self): + """ Fetch metada + """ + if self.metadata_fetch_last < (time.time() - self.METADATA_REFRESH_RATE): + data = self.get(self.server_metadata_url, withhold_token=True).json() + self.metadata.update(data) + self.metadata_fetch_last = time.time() + + @gJWKs + def updateJWKs(self): + """ Update JWKs + """ + if self.jwks_fetch_last < (time.time() - self.JWKS_REFRESH_RATE): + try: + self.jwks = self.get(self.get_metadata('jwks_uri'), withhold_token=True).json() + self.jwks_fetch_last = time.time() + return S_OK(self.jwks) + except Exception as e: + self.log.exception(e) + return S_ERROR("Error %s" % repr(e)) + return S_OK() + + def verifyToken(self, accessToken=None, jwks=None): """ Verify access token :param str accessToken: access token + :param dict jwks: JWKs :return: dict """ - jwks = jwks or self.jwks - self.log.debug("Try to decode token %s with JWKs:\n" % accessToken, pprint.pformat(jwks)) + # Define an access token + if not accessToken: + accessToken = self.token['access_token'] + # Renew a JWKs of an identity provider if needed + if not jwks: + result = self.updateJWKs() + if not result['OK']: + return result + jwks = self.jwks if not jwks: - raise Exception("JWKs not found.") - # Try to decode and verify token - return jwt.decode(accessToken, JsonWebKey.import_key_set(jwks)) + return S_ERROR("JWKs not found.") + # Try to decode and verify an access token + self.log.debug("Try to decode token %s with JWKs:\n" % accessToken, pprint.pformat(jwks)) + try: + return S_OK(jwt.decode(accessToken, JsonWebKey.import_key_set(jwks))) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) - def refreshToken(self, refresh_token): + @gRefreshToken + def refreshToken(self, refresh_token=None): """ Refresh token :param str token: refresh_token :return: dict """ - return self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token) + if not refresh_token: + refresh_token = self.token.get('refresh_token') + try: + return S_OK(self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token)) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + + @gRefreshToken + def fetchToken(self, **kwargs): + """ Fetch token + + :return: dict + """ + try: + self.fetch_access_token(self.get_metadata('token_endpoint'), **kwargs) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + self.token['client_id'] = self.client_id + self.token['provider'] = self.name + return S_OK(self.token) def revokeToken(self, token=None, token_type_hint='refresh_token'): """ Revoke token :param str token: token :param str token_type_hint: token type + + :return: S_OK()/S_ERROR() """ - self.revoke_token(self.get_metadata('revocation_endpoint'), token=token, token_type_hint=token_type_hint) + if not token: + tokn = self.token.get(token_type_hint) + try: + self.revoke_token(self.get_metadata('revocation_endpoint'), token=token, token_type_hint=token_type_hint) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + return S_OK() - def get_metadata(self, option=None): - """ Get metadata + def exchangeGroup(self, group): + """ Get new tokens for group scope - :param str option: option + :param str group: requested group - :return: option value + :return: dict -- token """ - if not self.metadata.get(option): - self.fetch_metadata() - return self.metadata.get(option) + result = self.getGroupScopes(group) + if not result['OK']: + return result + groupScopes = result['Value'] + try: + token = self.exchange_token(self.get_metadata('token_endpoint'), subject_token=self.token['access_token'], + subject_token_type='urn:ietf:params:oauth:token-type:access_token', + scope=list_to_scope(scope_to_list(self.scope) + groupScopes)) + if not token: + return S_ERROR('Cannot exchange token with %s group.' % group) + self.token = token + return S_OK(token) + except Exception as e: + self.log.exception(e) + return S_ERROR('Cannot exchange token with %s group: %s' % (group, repr(e))) - def fetch_metadata(self): - """ Fetch metada + def researchGroup(self, payload=None, token=None): + """ Research group + + :param str payload: token payload + :param str token: access token + + :return: S_OK(dict)/S_ERROR() """ - if self.metadata_fetch_last < (time.time() - self.METADATA_REFRESH_RATE): - data = self.get(self.server_metadata_url, withhold_token=True).json() - self.metadata.update(data) - self.metadata_fetch_last = time.time() + credDict = self.parseBasic(payload) + if not credDict.get('DIRACGroups'): + credDict.update(self.parseEduperson(payload)) + if credDict.get('DIRACGroups'): + self.log.debug('Found next groups:', ', '.join(credDict['DIRACGroups'])) + credDict['group'] = credDict['DIRACGroups'][0] + return S_OK(credDict) - def updateJWKs(self): - """ Update JWKs + def parseBasic(self, claimDict): + """ Parse basic claims + + :param dict claimDict: claims + + :return: S_OK(dict)/S_ERROR() """ - try: - response = requests.get(self.get_metadata('jwks_uri'), verify=self.verify) - response.raise_for_status() - self.jwks = response.json() - return S_OK(self.jwks) - except requests.exceptions.RequestException as e: - return S_ERROR("Error %s" % e) - - def researchGroup(self, payload, token): - """ Research group + self.log.debug('Token payload:', pprint.pformat(claimDict)) + credDict = {} + credDict['ID'] = claimDict['sub'] + credDict['DN'] = '/O=DIRAC/CN=%s' % credDict['ID'] + if claimDict.get('scope'): + self.log.debug('Search groups for %s scope.' % claimDict['scope']) + credDict['DIRACGroups'] = self.getScopeGroups(claimDict['scope']) + return credDict + + def parseEduperson(self, claimDict): + """ Parse eduperson claims + + :return: dict """ - credDict = self.parseBasic(payload) - if not credDict.get('group'): - cerdDict = self.userDiscover(credDict) - credDict['provider'] = self.name + vos = {} + credDict = {} + attributes = { + 'eduperson_unique_id': '^(?P.*)', + 'eduperson_entitlement': '%s:%s' % ('^(?P[A-z,.,_,-,:]+):(group:registry|group)', + '(?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*') + } + self.log.debug('Try to parse eduperson claims..') + # Parse eduperson claims + resDict = claimParser(claimDict, attributes) + if resDict.get('eduperson_unique_id'): + self.log.debug('Found eduperson_unique_id claim:', pprint.pformat(resDict['eduperson_unique_id'])) + credDict['ID'] = resDict['eduperson_unique_id']['ID'] + if resDict.get('eduperson_entitlement'): + self.log.debug('Found eduperson_entitlement claim:', pprint.pformat(resDict['eduperson_entitlement'])) + for voDict in resDict['eduperson_entitlement']: + if voDict['VO'] not in vos: + vos[voDict['VO']] = {'VORoles': []} + if voDict['VORole'] not in vos[voDict['VO']]['VORoles']: + vos[voDict['VO']]['VORoles'].append(voDict['VORole']) + # Search DIRAC groups + for vo in vos: + result = getVOMSRoleGroupMapping(vo) + if not result['OK']: + # Skip VO if it absent in Registry + self.log.debug(result['Message']) + continue + for role in vos[vo]['VORoles']: + groups = result['Value']['VOMSDIRAC'].get('/%s/%s' % (vo, role)) + if groups: + credDict['DIRACGroups'] = list(set(credDict.get('DIRACGroups', []) + groups)) return credDict - def authorization(self, group=None): + def deviceAuthorization(self, group=None): """ Authorizaion through DeviceCode flow """ result = self.submitDeviceCodeAuthorizationFlow(group) @@ -180,8 +317,10 @@ def authorization(self, group=None): showURL = 'Use next link to continue, your user code is "%s"\n%s' % (response['user_code'], response['verification_uri']) self.log.notice(showURL) - - return self.waitFinalStatusOfDeviceCodeAuthorizationFlow(response['device_code']) + try: + return self.waitFinalStatusOfDeviceCodeAuthorizationFlow(response['device_code']) + except KeyboardInterrupt: + return S_ERROR('User canceled the operation..') def submitNewSession(self, pkce=True): """ Submit new authorization session @@ -218,81 +357,20 @@ def parseAuthResponse(self, response, session=None): self.log.debug('Current session is:\n', pprint.pformat(session)) self.fetchToken(authorization_response=response.uri, code_verifier=session.get('code_verifier')) - # Get user info - claims = self.getUserProfile() - credDict = self.parseBasic(claims) - credDict.update(self.parseEduperson(claims)) - cerdDict = self.userDiscover(credDict) - self.log.debug('Got response dictionary:\n', pprint.pformat(cerdDict)) + result = self.verifyToken(self.token['access_token']) + if result['OK']: + result = self.researchGroup(result['Value']) + if not result['OK']: + return result + credDict = result['Value'] + self.log.debug('Got response dictionary:\n', pprint.pformat(credDict)) # Store token self.token['user_id'] = credDict['ID'] - self.log.debug('Store token to the database:\n', pprint.pformat(dict(self.token))) return S_OK(credDict) - def fetchToken(self, **kwargs): - """ Fetch token - - :return: dict - """ - self.fetch_access_token(self.get_metadata('token_endpoint'), **kwargs) - self.token['client_id'] = self.client_id - self.token['provider'] = self.name - return OAuth2Token(self.token) - - def parseBasic(self, claimDict): - """ Parse basic claims - - :param dict claimDict: claims - - :return: S_OK(dict)/S_ERROR() - """ - credDict = {} - credDict['ID'] = claimDict['sub'] - credDict['DN'] = '/O=DIRAC/CN=%s' % credDict['ID'] - credDict['group'] = claimDict.get('group') - return credDict - - def parseEduperson(self, claimDict): - """ Parse eduperson claims - - :return: dict - """ - credDict = {} - attributes = { - 'eduperson_unique_id': '^(?P.*)', - 'eduperson_entitlement': '%s:%s' % ('^(?P[A-z,.,_,-,:]+):(group:registry|group)', - '(?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*') - } - if 'eduperson_entitlement' not in claimDict: - claimDict = self.getUserProfile() - resDict = claimParser(claimDict, attributes) - if not resDict: - return credDict - credDict['ID'] = resDict['eduperson_unique_id']['ID'] - credDict['VOs'] = {} - for voDict in resDict['eduperson_entitlement']: - if voDict['VO'] not in credDict['VOs']: - credDict['VOs'][voDict['VO']] = {'VORoles': []} - if voDict['VORole'] not in credDict['VOs'][voDict['VO']]['VORoles']: - credDict['VOs'][voDict['VO']]['VORoles'].append(voDict['VORole']) - return credDict - - def userDiscover(self, credDict): - credDict['DIRACGroups'] = [] - for vo, voData in credDict.get('VOs', {}).items(): - result = getVOMSRoleGroupMapping(vo) - if result['OK']: - for role in voData['VORoles']: - groups = result['Value']['VOMSDIRAC'].get('/%s' % role) - if groups: - credDict['DIRACGroups'] = list(set(credDict['DIRACGroups'] + groups)) - if credDict['DIRACGroups']: - credDict['group'] = credDict['DIRACGroups'][0] - return credDict - def submitDeviceCodeAuthorizationFlow(self, group=None): """ Submit authorization flow @@ -361,35 +439,39 @@ def getGroupScopes(self, group): :return: list """ - idPScope = getGroupOption(group, 'IdPScope') + idPScope = getGroupOption(group, 'IdPRole') if not idPScope: return S_ERROR('Cannot find role for %s' % group) return S_OK(scope_to_list(idPScope)) + + def getScopeGroups(self, scope): + """ Get scope groups - def exchangeGroup(self, group): - """ Get new tokens for group scope + :param str scope: scope - :param str group: requested group + :return: list + """ + groups = [] + for group in getAllGroups(): + result = self.getGroupScopes(group) + if not result['OK']: + # Skip DIRAAC group without scope parameter + self.log.debug(result['Message']) + continue + if set(result['Value']).issubset(scope_to_list(scope)): + groups.append(group) + return groups - :return: dict -- token + def getUserProfile(self): + """ Get user profile + + :return: S_OK()/S_ERROR() """ - result = self.getGroupScopes(group) - if not result['OK']: - return result - groupScopes = result['Value'] try: - token = self.exchange_token(self.get_metadata('token_endpoint'), subject_token=self.token['access_token'], - subject_token_type='urn:ietf:params:oauth:token-type:access_token', - scope=list_to_scope(scope_to_list(self.scope) + groupScopes)) - if not token: - return S_ERROR('Cannot exchange token with %s group.' % group) - self.token = token - return S_OK(token) + return S_OK(self.get(self.get_metadata('userinfo_endpoint')).json()) except Exception as e: - return S_ERROR('Cannot exchange token with %s group: %s' % (group, repr(e))) - - def getUserProfile(self): - return self.get(self.get_metadata('userinfo_endpoint')).json() + self.log.exception(e) + return S_ERROR('Cannot get user profile: %s' % repr(e)) def exchange_token(self, url, subject_token=None, subject_token_type=None, body='', refresh_token=None, access_token=None, auth=None, headers=None, **kwargs): @@ -442,6 +524,3 @@ def _exchange_token(self, url, body='', refresh_token=None, headers=None, auth=N self.update_token(self.token, refresh_token=refresh_token) return self.token - - def generateState(self, session=None): - return session or generate_token(10) diff --git a/src/DIRAC/Resources/IdProvider/Utilities.py b/src/DIRAC/Resources/IdProvider/Utilities.py index ca2134336ce..d363c4d76a6 100644 --- a/src/DIRAC/Resources/IdProvider/Utilities.py +++ b/src/DIRAC/Resources/IdProvider/Utilities.py @@ -27,6 +27,24 @@ def getSettingsNamesForIdPIssuer(issuer): return S_OK(names) if names else S_ERROR('Not found provider with %s issuer.' % issuer) +def getSettingsNamesForClientID(clientID): + """ Get identity providers for clientID + + :param str clientID: clientID + + :return: S_OK(list)/S_ERROR() + """ + names = [] + result = getProvidersForInstance('Id') + if not result['OK']: + return result + for name in result['Value']: + res = gConfig.getValue('/Resources/IdProviders/%s/client_id' % name) + if res and clientID == res: + names.append(name) + return S_OK(names) if names else S_ERROR('Not found provider with %s clientID.' % clientID) + + def getProvidersForInstance(instance, providerType=None): """ Get providers for instance diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 28c1bf69111..02f8991fc83 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -6,7 +6,7 @@ from __future__ import print_function import time -from authlib.jose import JsonWebKey, JsonWebSignature, jwt +from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads from DIRAC.Core.Base.Script import parseCommandLine @@ -24,13 +24,37 @@ 'setup': 'setup', 'group': 'my_group'} -exp_payload = {'sub': 'user', - 'iss': 'issuer', - 'iat': int(time.time()) - 10, - 'exp': int(time.time()) - 10, - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} +exp_payload = payload.copy() +exp_payload['iat'] = int(time.time()) - 10 +exp_payload['exp'] = int(time.time()) - 10 + +DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + expires_at=int(time.time()) + 3600) + +New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + issued_at=int(time.time()), + expires_in=int(time.time()) + 3600) + +Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), + refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), + expires_at=int(time.time()) - 10) + + +def test_cryptToken(): + """ Try to encrypt/decrypt refresh token + """ + data = dict(client_id='clientID', provider='provider', expires_at=DToken['expires_at']) + result = db.encryptRefreshToken(DToken.copy(), data.copy()) + assert result['OK'], result['Message'] + assert result['Value']['refresh_token'] != DToken['refresh_token'] + + result = db.decryptRefreshToken({'refresh_token': result['Value']['refresh_token']}) + assert result['OK'], result['Message'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + for k in data: + assert result['Value'][k] == data[k] def test_Token(): @@ -40,51 +64,34 @@ def test_Token(): result = db.removeTokens() assert result['OK'], result['Message'] - # Get key - result = db.getPrivateKey() + # Store tokens + result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC') assert result['OK'], result['Message'] - privat_key = result['Value']['key'] + assert result['Value'] == [] - # Sign token - token = dict(access_token=jwt.encode({'alg': 'RS256'}, payload, privat_key), - expires_in=864000, - token_type='Bearer', - client_id='1hlUgttap3P9oTSXUwpIT50TVHxCflN3O98uHP217Y', - scope='g:checkin-integration_user', - refresh_token=jwt.encode({'alg': 'RS256'}, payload, privat_key)) # Expired token - exp_token = dict(access_token=jwt.encode({'alg': 'RS256'}, exp_payload, privat_key), - expires_in=864000, - token_type='Bearer', - client_id='1hlUgttap3P9oTSXUwpIT50TVHxCflN3O98uHP217Y', - scope='g:checkin-integration_user', - refresh_token=jwt.encode({'alg': 'RS256'}, exp_payload, privat_key)) - - # Store tokens - result = db.storeToken(token) - assert result['OK'], result['Message'] - result = db.storeToken(token) - assert result['OK'], result['Message'] + # result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC') + # assert not result['OK'] # Check token - result = db.getToken(token['refresh_token']) + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') assert result['OK'], result['Message'] - assert result['Value']['access_token'] == token['access_token'] - assert result['Value']['refresh_token'] == token['refresh_token'] - assert not result['Value']['revoked'] + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] - # Check expired token - result = db.getToken(exp_token['refresh_token']) - assert not result['OK'] - - # Revoke token - result = db.revokeToken(token) + # Store new tokens + result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC') assert result['OK'], result['Message'] + # Must return old tokens + assert len(result['Value']) == 1 + assert result['Value'][0]['access_token'] == DToken['access_token'] + assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] - # Check if token revoked - result = db.getToken(token['refresh_token']) + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') assert result['OK'], result['Message'] - assert result['Value']['revoked'] + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] def test_keys(): @@ -118,10 +125,18 @@ def test_keys(): # Create new one result = db.getPrivateKey() assert result['OK'], result['Message'] + assert type(result['Value']['rsakey']) is RSAKey + assert type(result['Value']['strkey']) is str # Sign token header['kid'] = result['Value']['kid'] - private_key = result['Value']['key'] + private_key = result['Value']['rsakey'] + + # Find key by KID + result = db.getPrivateKey(header['kid']) + assert result['OK'], result['Message'] + assert result['Value']['rsakey'] == private_key + token = jwt.encode(header, payload, private_key) # Sign auth code code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) From a5568fd26d405eba405101c7a086b8a33525fba6 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 02:01:22 +0200 Subject: [PATCH 038/178] fix checkin --- dirac.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dirac.cfg b/dirac.cfg index 3836f791a4f..7f3aaaad8ae 100644 --- a/dirac.cfg +++ b/dirac.cfg @@ -58,7 +58,7 @@ Registry VOMSName = lhcb # Registered identity provider associated with VO - IdP = ChechIn + IdP = CheckIn # Section to describe all the VOMS servers that can be used with the given VOMS VO VOMSServers From bdaf81843efeb26a5fb76712996cbdec86721b06 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 02:22:20 +0200 Subject: [PATCH 039/178] fix tests --- src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py | 13 ++++++------- src/DIRAC/Core/Tornado/Server/TornadoREST.py | 2 +- src/DIRAC/Core/Tornado/Server/TornadoService.py | 5 +++++ src/DIRAC/FrameworkSystem/DB/AuthDB.py | 6 +++--- .../private/authorization/AuthServer.py | 3 ++- .../private/authorization/grants/RefreshToken.py | 6 +++--- .../private/authorization/utils/Tokens.py | 10 +++++----- src/DIRAC/FrameworkSystem/scripts/dirac_logout.py | 2 +- src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py | 3 +-- 9 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py index 1b995a6f2f2..89e5d35f1d2 100644 --- a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -49,12 +49,12 @@ class BaseRequestHandler(RequestHandler): This class is basic for :py:class:`DIRAC.Core.Tornado.Server.TornadoService.TornadoService` and :py:class:`DIRAC.Core.Tornado.Server.TornadoREST.TornadoREST`. - + In order to create a class that inherits from `BaseRequestHandler`, it has to follow a certain skeleton:: class TornadoInstance(BaseRequestHandler): - + # Prefix of methods names METHOD_PREFIX = "export_" @@ -72,7 +72,7 @@ def _getServiceInfo(cls, serviceName, request): 'serviceSectionPath': PathFinder.getServiceSection(serviceName), 'csPaths': [PathFinder.getServiceSection(serviceName)], 'URL': request.full_url()} - + @classmethod def _getServiceAuthSection(cls, serviceName): ''' Search service auth section. @@ -113,7 +113,7 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ whenever ``TransferClient.receiveFile`` is called. It is the equivalent of the DISET ``transfer_toClient``. Note that this is here only for compatibility, and we discourage using it for new purposes, as it is - bound to disappear. + bound to disappear. """ # Because we initialize at first request, we use a flag to know if it's already done __init_done = False @@ -473,9 +473,8 @@ def _finishFuture(self, retVal): # in a thread anymore # Is it S_OK or S_ERROR? - if (isinstance(self.result, dict) and - isinstance(self.result.get('OK'), bool) and - ('Value' if self.result['OK'] else 'Message') in self.result): + r = self.result + if isinstance(r, dict) and isinstance(r.get('OK'), bool) and ('Value' if r['OK'] else 'Message') in r: self._parseDIRACResult(self.result) # If set to true, do not JEncode the return of the RPC call diff --git a/src/DIRAC/Core/Tornado/Server/TornadoREST.py b/src/DIRAC/Core/Tornado/Server/TornadoREST.py index 7b966c00c3a..13d484eeba5 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoREST.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoREST.py @@ -17,7 +17,7 @@ import DIRAC -from DIRAC import gLogger +from DIRAC import gLogger, S_OK from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.Core.Tornado.Server.BaseRequestHandler import BaseRequestHandler diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index 4f765b9f6ba..9aa7bca807d 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -9,6 +9,11 @@ __RCSID__ = "$Id$" +from io import open + +import os +import datetime + import tornado.ioloop from tornado import gen from tornado.ioloop import IOLoop diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index a1017db5d22..ec1468e136b 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -213,7 +213,8 @@ def updateToken(self, token, userID, provider): token['user_id'] = userID token['provider'] = provider try: - token['rt_expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False, verify_aud=False))['exp']) + token['rt_expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False, + verify_aud=False))['exp']) except Exception as e: self.log.debug('Cannot get refresh token expires time: %s' % repr(e)) @@ -221,7 +222,6 @@ def updateToken(self, token, userID, provider): if token['rt_expires_at'] < time.time(): return S_ERROR('Cannot store expired refresh token.') - attrts = dict((k, v) for k, v in dict(token).items() if k in list(Token.__dict__.keys())) self.log.debug('Store token:', pprint.pformat(attrts)) session = self.session() @@ -309,7 +309,7 @@ def getPrivateKey(self, kid=None): return result jwks = result['Value'] if kid: - strkey=jwks[0]['key'] + strkey = jwks[0]['key'] return S_OK(dict(rsakey=RSAKey.import_key(json.loads(strkey)), kid=kid, strkey=strkey)) newer = {} for jwk in jwks: diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 89f275eff00..ebee16a46e6 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -29,7 +29,8 @@ from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata, isDownloadablePersonalProxy -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getEmailsForGroup, getDNForUsername +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getUsernameForDN, getEmailsForGroup, + getDNForUsername, getIdPForGroup) from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient log = logging.getLogger('authlib') diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py index 6e977b9eca3..c503d82fd79 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -22,7 +22,7 @@ def authenticate_refresh_token(self, refresh_token): if not result['OK']: raise OAuth2Error(result['Message']) return result['Value'] - + def _validate_token_scope(self, token): """ Skip scope validadtion """ pass @@ -33,7 +33,7 @@ def authenticate_user(self, credential): def issue_token(self, user, credential): """ Refresh tokens - + :param user: unuse :param dict credential: token credential @@ -52,4 +52,4 @@ def issue_token(self, user, credential): def revoke_old_credential(self, credential): """ Remove old credential """ - pass \ No newline at end of file + pass diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 02023c4855a..b8113252382 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -118,7 +118,7 @@ def __init__(self, params=None, **kwargs): # Get access token expires_at claim kwargs['expires_at'] = int(self.get_token_attr('exp')) super(OAuth2Token, self).__init__(kwargs) - + def get_client_id(self): return self.get('client_id') @@ -137,7 +137,7 @@ def scopes(self): :return: list """ - return scope_to_list(self.scope) or [] + return scope_to_list(self.get(scope, '') or [] @property def groups(self): @@ -146,10 +146,10 @@ def groups(self): :return: list """ return [s.split(':')[1] for s in self.scopes if s.startswith('g:')] - + def get_token_attr(self, attr, token_type='access_token'): """ Get token attribute without verification - + :param str attr: attribute :param str token_type: token type @@ -208,4 +208,4 @@ def __formatTokenInfoAsString(self, infoDict): contentList.append('%s: %s' % ('DIRAC group'.ljust(leftAlign), infoDict['group'])) if infoDict.get('properties'): contentList.append('%s: %s' % ('properties'.ljust(leftAlign), ', '.join(infoDict['properties']))) - return "\n".join(contentList) \ No newline at end of file + return "\n".join(contentList) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py index 92c3978d71c..93b5ade9ee5 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py @@ -52,7 +52,7 @@ def setIssuer(self, arg): """ self.issuer = arg return S_OK() - + def setTokenFile(self, arg): """ Set token file diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 57bc8d57c53..3eeee9cce62 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -102,7 +102,6 @@ def __init__(self, **kwargs): '\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s' % (self.client_id, self.client_secret, pprint.pformat(self.metadata))) - def get_metadata(self, option=None): """ Get metadata @@ -443,7 +442,7 @@ def getGroupScopes(self, group): if not idPScope: return S_ERROR('Cannot find role for %s' % group) return S_OK(scope_to_list(idPScope)) - + def getScopeGroups(self, scope): """ Get scope groups From dccbaf1259806263f663fc4ca7dbd9777c510db9 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 13:09:24 +0200 Subject: [PATCH 040/178] fix Token scope --- src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index b8113252382..d99b431e23e 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -137,7 +137,7 @@ def scopes(self): :return: list """ - return scope_to_list(self.get(scope, '') or [] + return scope_to_list(self.get(scope, '')) @property def groups(self): From 533f1424c7e715b86c8aec95351805c1710e4da5 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 13:12:20 +0200 Subject: [PATCH 041/178] fix datetime --- src/DIRAC/Core/Tornado/Server/TornadoService.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index 9aa7bca807d..49326471720 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -12,7 +12,7 @@ from io import open import os -import datetime +from datetime import datetime import tornado.ioloop from tornado import gen From dd9cc3fdf00bfb1cbcf20b13aa19cc22fd0a1aac Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 15:53:48 +0200 Subject: [PATCH 042/178] fix --- src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index d99b431e23e..48218171449 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -137,7 +137,7 @@ def scopes(self): :return: list """ - return scope_to_list(self.get(scope, '')) + return scope_to_list(self.get('scope', '')) @property def groups(self): From 4cf71d6dddbefc841dab581e1e5322a5ff690a35 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 2 Jun 2021 16:22:34 +0200 Subject: [PATCH 043/178] fix setup --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 73724c20264..b76f5702f78 100644 --- a/setup.cfg +++ b/setup.cfg @@ -161,8 +161,8 @@ console_scripts = dirac-dms-user-lfns = DIRAC.DataManagementSystem.scripts.dirac_dms_user_lfns:main dirac-dms-user-quota = DIRAC.DataManagementSystem.scripts.dirac_dms_user_quota:main # FrameworkSystem - dirac-login = DIRAC.FrameworkSystem.scripts.dirac_login:main [server] - dirac-logout = DIRAC.FrameworkSystem.scripts.dirac_logout:main [server] + dirac-login = DIRAC.FrameworkSystem.scripts.dirac_login:main + dirac-logout = DIRAC.FrameworkSystem.scripts.dirac_logout:main dirac-admin-get-CAs = DIRAC.FrameworkSystem.scripts.dirac_admin_get_CAs:main [server] dirac-admin-get-proxy = DIRAC.FrameworkSystem.scripts.dirac_admin_get_proxy:main [admin] dirac-admin-proxy-upload = DIRAC.FrameworkSystem.scripts.dirac_admin_proxy_upload:main [admin] From c86f520ba333a1d2157e334f8bc57a28d4f46d6a Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Fri, 4 Jun 2021 13:57:54 +0200 Subject: [PATCH 044/178] update authz method --- .../private/authorization/utils/Clients.py | 2 -- .../private/authorization/utils/Tokens.py | 30 ++++++++++++++----- .../Resources/IdProvider/CheckInIdProvider.py | 23 +------------- .../Resources/IdProvider/OAuth2IdProvider.py | 15 +++++++--- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index b1336c86b82..e4240f77f85 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -27,8 +27,6 @@ ProviderType='DIRACCLI' ), 'DIRACWeb': dict( - # token_endpoint_auth_method='client_secret_basic', - token_endpoint_auth_method='client_secret_post', response_types=['code'], grant_types=['authorization_code', 'refresh_token'], ProviderType='DIRACWeb' diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 48218171449..995b4433561 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -116,7 +116,7 @@ def __init__(self, params=None, **kwargs): kwargs.update(params or {}) if not kwargs.get('expires_at') and kwargs.get('access_token'): # Get access token expires_at claim - kwargs['expires_at'] = int(self.get_token_attr('exp')) + kwargs['expires_at'] = int(self.get_claim('exp')) super(OAuth2Token, self).__init__(kwargs) def get_client_id(self): @@ -147,20 +147,36 @@ def groups(self): """ return [s.split(':')[1] for s in self.scopes if s.startswith('g:')] - def get_token_attr(self, attr, token_type='access_token'): - """ Get token attribute without verification + def get_payload(self, token_type='access_token'): + """ Decode token - :param str attr: attribute :param str token_type: token type - :return: str + :return: dict """ if not self.get(token_type): - return None + return {} return jwt.decode(self.get(token_type), options=dict(verify_signature=False, verify_exp=False, verify_aud=False, - verify_nbf=False)).get(attr) + verify_nbf=False)) + + def get_claim(self, claim, token_type='access_token'): + """ Get token claim without verification + + :param str attr: attribute + :param str token_type: token type + + :return: str + """ + return get_payload(token_type).get(claim) + + def dump_to_string(self): + """ Dump token dictionary to sting + + :return: str + """ + return json.dumps(dict(self)) def getInfoAsString(self): """ Return information about token as string diff --git a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py index 88c0e7b2542..a8bf1c8b8cb 100644 --- a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py @@ -4,7 +4,6 @@ from __future__ import division from __future__ import print_function -from DIRAC import S_OK from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider __RCSID__ = "$Id$" @@ -12,24 +11,4 @@ class CheckInIdProvider(OAuth2IdProvider): - def researchGroup(self, payload=None, token=None): - """ Research group - - :param str payload: token payload - :param str token: access token - - :return: S_OK(dict)/S_ERROR() - """ - if token: - self.token = {'access_token': token} - - result = self.getUserProfile() - if not result['OK']: - return result - payload = result['Value'] - - credDict = self.parseBasic(payload) - if not credDict.get('DIRACGroups'): - credDict.update(self.parseEduperson(payload)) - credDict['group'] = credDict.get('DIRACGroups', [None])[0] - return S_OK(credDict) + pass \ No newline at end of file diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 3eeee9cce62..4643394bc13 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -25,6 +25,7 @@ from DIRAC.Core.Utilities import ThreadSafe from DIRAC.Resources.IdProvider.IdProvider import IdProvider from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getGroupOption, getAllGroups +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token __RCSID__ = "$Id$" @@ -174,7 +175,8 @@ def refreshToken(self, refresh_token=None): if not refresh_token: refresh_token = self.token.get('refresh_token') try: - return S_OK(self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token)) + token = self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token) + return S_OK(OAuth2Token(dict(token))) except Exception as e: self.log.exception(e) return S_ERROR(repr(e)) @@ -192,7 +194,7 @@ def fetchToken(self, **kwargs): return S_ERROR(repr(e)) self.token['client_id'] = self.client_id self.token['provider'] = self.name - return S_OK(self.token) + return S_OK(OAuth2Token(dict(self.token))) def revokeToken(self, token=None, token_type_hint='refresh_token'): """ Revoke token @@ -228,8 +230,7 @@ def exchangeGroup(self, group): scope=list_to_scope(scope_to_list(self.scope) + groupScopes)) if not token: return S_ERROR('Cannot exchange token with %s group.' % group) - self.token = token - return S_OK(token) + return S_OK(OAuth2Token(dict(token))) except Exception as e: self.log.exception(e) return S_ERROR('Cannot exchange token with %s group: %s' % (group, repr(e))) @@ -242,6 +243,12 @@ def researchGroup(self, payload=None, token=None): :return: S_OK(dict)/S_ERROR() """ + if not token: + token = self.token + + if not payload and token: + payload = OAuth2Token(dict(token)).get_payload() + credDict = self.parseBasic(payload) if not credDict.get('DIRACGroups'): credDict.update(self.parseEduperson(payload)) From f250f24c4a98c1bc11869f719c0399e214458170 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Fri, 4 Jun 2021 14:11:10 +0200 Subject: [PATCH 045/178] add env description --- .../environment_variable_configuration.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst index 4725a3903f5..d93feefb312 100644 --- a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst +++ b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst @@ -65,3 +65,9 @@ DIRAC_X509_HOST_KEY X509_VOMSES Must be set to point to a folder containing VOMSES information. See :ref:`multi_vo_dirac` + +BEARER_TOKEN + If the environment variable is set, then the value is taken to be the token contents + +BEARER_TOKEN_FILE + If the environment variable is set, then its value is interpreted as a filename. The contents of thespecified file are taken to be the token contents. From 292c9dd38b263cb835007ce94abd7113827eca95 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:32:18 +0200 Subject: [PATCH 046/178] provide methods to wrap userID as DN --- .../Client/Helpers/Registry.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index 1d6faeb7374..7d462612324 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -707,3 +707,33 @@ def getEmailsForGroup(groupName): email = getUserOption(username, 'Email', []) emails.append(email) return emails + + +def wrapIDAsDN(userID): + """ Wrap user ID as user DN + + :param str userID: user ID + + :return: str + """ + return '/O=DIRAC/CN=%s' % userID + + +def getIDFromDN(userDN): + """ Parse user ID from user DN + + :param str userDN: user DN + + :return: str + """ + return userDN.strip('/O=DIRAC/CN=') + + +def isDNWrappedID(user): + """ Is it wrapped user ID? + + :param str user: user ID + + :return: bool + """ + return user.startswith('/O=DIRAC/CN=') From 7e847537058f50ff695e6ae033957bc75ab49853 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:33:20 +0200 Subject: [PATCH 047/178] move Authorization conf section to DIRAC/Security --- src/DIRAC/ConfigurationSystem/Client/Utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index 386e38bcdf0..4bf1d0f1259 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -584,7 +584,7 @@ def getAuthorizationServerMetadata(issuer=None): :return: S_OK(dict)/S_ERROR() """ - result = gConfig.getOptionsDictRecursively('/DIRAC/Authorization') + result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization') if not result['OK']: return {'issuer': issuer} if issuer else result data = result['Value'] From 7215564e86e26028dd83d34ea9371a15e082838f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:40:48 +0200 Subject: [PATCH 048/178] move BaseRequestHandler to private, update TornadoService and TornadoREST --- src/DIRAC/Core/Tornado/Server/TornadoREST.py | 16 ++-- .../Core/Tornado/Server/TornadoService.py | 4 +- .../{ => private}/BaseRequestHandler.py | 95 +++++++++++-------- .../Core/Tornado/Server/private/__init__.py | 0 4 files changed, 66 insertions(+), 49 deletions(-) rename src/DIRAC/Core/Tornado/Server/{ => private}/BaseRequestHandler.py (90%) create mode 100644 src/DIRAC/Core/Tornado/Server/private/__init__.py diff --git a/src/DIRAC/Core/Tornado/Server/TornadoREST.py b/src/DIRAC/Core/Tornado/Server/TornadoREST.py index 13d484eeba5..78fd024e266 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoREST.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoREST.py @@ -19,7 +19,7 @@ from DIRAC import gLogger, S_OK from DIRAC.ConfigurationSystem.Client import PathFinder -from DIRAC.Core.Tornado.Server.BaseRequestHandler import BaseRequestHandler +from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import BaseRequestHandler sLog = gLogger.getSubLogger(__name__) @@ -96,7 +96,9 @@ def _getServiceAuthSection(cls, endpointName): return "%s/Authorization" % PathFinder.getAPISection(endpointName) def _getMethodName(self): - """ Parse method name. + """ Parse method name. By default we read the first section in the path + following the coincidence with the value of `LOCATION`. + If such a method is not defined, then try to use the `index` method. :return: str """ @@ -107,15 +109,14 @@ def _getMethodName(self): gLogger.warn('%s method not implemented. Use the index method to handle this.' % method) return 'index' else: - raise NotImplementedError('%s method not implemented. \ - You can use the index method to handle this.' % method) + raise NotImplementedError('%s method not implemented. You can use the index method to handle this.' % method) @gen.coroutine def get(self, *args, **kwargs): # pylint: disable=arguments-differ """ Method to handle incoming ``GET`` requests. Note that all the arguments are already prepared in the :py:meth:`.prepare` method. """ - retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) self._finishFuture(retVal) @gen.coroutine @@ -123,7 +124,7 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ """ Method to handle incoming ``POST`` requests. Note that all the arguments are already prepared in the :py:meth:`.prepare` method. """ - retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) self._finishFuture(retVal) auth_echo = ['all'] @@ -138,8 +139,7 @@ def web_echo(data): auth_whoami = ['authenticated'] def web_whoami(self): - """ - A simple whoami, returns all credential dictionary, except certificate chain object. + """ A simple whoami, returns all credential dictionary, except certificate chain object. """ credDict = self.srv_getRemoteCredentials() if 'x509Chain' in credDict: diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index 49326471720..10568c763cc 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -25,7 +25,7 @@ from DIRAC.Core.DISET.AuthManager import AuthManager from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.Core.Utilities.JEncode import decode, encode -from DIRAC.Core.Tornado.Server.BaseRequestHandler import BaseRequestHandler +from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import BaseRequestHandler from DIRAC.ConfigurationSystem.Client import PathFinder sLog = gLogger.getSubLogger(__name__) @@ -390,7 +390,7 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ # However, we can still rely on instance attributes to store what should # be sent back (reminder: there is an instance # of this class created for each request) - retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) # retVal is :py:class:`tornado.concurrent.Future` self._finishFuture(retVal) diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py similarity index 90% rename from src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py rename to src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 89e5d35f1d2..f95e692621e 100644 --- a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -17,6 +17,7 @@ from six import string_types from six.moves import http_client from six.moves.urllib.parse import unquote +from functools import partial import tornado from tornado import gen @@ -27,6 +28,7 @@ import DIRAC from DIRAC import gConfig, gLogger, S_OK, S_ERROR +from DIRAC.Core.Utilities import DErrno from DIRAC.Core.DISET.AuthManager import AuthManager from DIRAC.Core.Utilities.JEncode import decode, encode from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error @@ -99,12 +101,12 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ ''' # Execute the method in an executor (basically a separate thread) # Because of that, we cannot calls certain methods like `self.write` - # in _executeMethod. This is because these methods are not threadsafe + # in __executeMethod. This is because these methods are not threadsafe # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes # However, we can still rely on instance attributes to store what should # be sent back (reminder: there is an instance # of this class created for each request) - retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + retVal = yield IOLoop.current().run_in_executor(self._prepareExecutor(args)) # retVal is :py:class:`tornado.concurrent.Future` self._finishFuture(retVal) @@ -207,13 +209,14 @@ def _getServiceInfo(cls, serviceName, request): :return: dict """ + gLogger.warn('Service information will not be collected because the _getServiceInfo method is not defined.') return {} @classmethod def __loadIdPs(cls): """ Load identity providers that will be used to verify tokens """ - gLogger.info('Load identit providers..') + gLogger.info('Load identity providers..') # Research Identity Providers result = getProvidersForInstance('Id') if result['OK']: @@ -227,7 +230,7 @@ def __loadIdPs(cls): def __initializeService(cls, request): """ Initialize a service. - The work is only perform once at the first request. + The work is only performed once at the first request. :param object request: tornado Request @@ -356,7 +359,7 @@ def _getMethodArgs(self, args): def _getMethodAuthProps(self): """ Resolves the hard coded authorization requirements for method. - :return: object + :return: list """ try: return getattr(self, 'auth_' + self.method) @@ -366,15 +369,19 @@ def _getMethodAuthProps(self): return self.AUTH_PROPS def _getMethod(self): - """ Get method object. + """ Get method function to call. - :return: object + :return: function """ try: - return getattr(self, '%s%s' % (self.METHOD_PREFIX, self.method)) + method = getattr(self, '%s%s' % (self.METHOD_PREFIX, self.method)) except AttributeError as e: sLog.error("Invalid method", self.method) raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) + if not callable(method): + sLog.error("Invalid method", self.method) + raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) + return method def prepare(self): """ @@ -428,7 +435,7 @@ def _prepare(self): raise HTTPError(status_code=http_client.UNAUTHORIZED) @gen.coroutine - def _executeMethod(self, args): + def __executeMethod(self, targetMethod, args): """ Execute the method called, this method is ran in an executor We have several try except to catch the different problem which can occur @@ -439,6 +446,11 @@ def _executeMethod(self, args): .. warning:: This method is called in an executor, and so cannot use methods like self.write See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + + :param str targetMethod: name of the method to call + :param list args: target method arguments + + :return: Future """ sLog.notice( @@ -447,35 +459,49 @@ def _executeMethod(self, args): self._serviceName, self.method)) - # getting method - method = self._getMethod() - methodArgs = self._getMethodArgs(args) - # Execute try: self.initializeRequest() - retVal = method(*methodArgs) + retVal = targetMethod(*args) except Exception as e: # pylint: disable=broad-except sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) raise HTTPError(http_client.INTERNAL_SERVER_ERROR) return retVal + def _prepareExecutor(self, args): + """ Preparation of necessary arguments for the `__executeMethod` method + + :param list args: arguments passed to the `post`, `get`, etc. tornado methods + + :return: executor, target method with arguments + """ + return None, partial(self.__executeMethod, self._getMethod(), self._getMethodArgs(args)) + def _finishFuture(self, retVal): """ Handler Future result :param object retVal: tornado.concurrent.Future """ - # Wait result only if it's a Future object - self.result = retVal.result() if isinstance(retVal, Future) else retVal + # Wait result of a Future object + self.result = retVal.result() + + # Here it is safe to write back to the client, because we are not in a thread anymore + + # If you need to end the method using tornado methods, outside the thread, + # you need to define the finish_ method. + # This method will be started after __executeMethod is completed. + try: + finishFunc = eval('self.finish_%s' % self.method) + except (NameError, AttributeError): + finishFunc = None - # Here it is safe to write back to the client, because we are not - # in a thread anymore + if callable(finishFunc): + finishFunc() - # Is it S_OK or S_ERROR? - r = self.result - if isinstance(r, dict) and isinstance(r.get('OK'), bool) and ('Value' if r['OK'] else 'Message') in r: - self._parseDIRACResult(self.result) + # In case nothing is returned + elif self.result is None: + self.finish() # If set to true, do not JEncode the return of the RPC call # This is basically only used for file download through @@ -483,25 +509,16 @@ def _finishFuture(self, retVal): elif self.get_argument('rawContent', default=False): # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt self.set_header("Content-Type", "application/octet-stream") - self.write(self.result) + self.finish(self.result) # Return simple text or html elif isinstance(self.result, string_types): - self.write(self.result) + self.finish(self.result) # JSON - elif isinstance(self.result, dict): + else: self.set_header("Content-Type", "application/json") - self.write(encode(self.result)) - - self.finish() - - def _parseDIRACResult(self, result): - """ Processing of a standard DIRAC result, - but in a separate method so that it can be modified for another class if necessary - """ - self.set_header("Content-Type", "application/json") - self.write(encode(result)) + self.finish(encode(self.result)) def on_finish(self): """ @@ -541,7 +558,7 @@ def _gatherPeerCredentials(self, grants=None): for a in grants: grant = a.upper() - grantFunc = eval('self._authz%s' % grant) + grantFunc = getattr(self, '_authz%s' % grant) if not callable(grantFunc): raise Exception('%s authentication type is not supported.' % grant) result = grantFunc() @@ -574,7 +591,7 @@ def _authzSSL(self): chainAsTextEncoded = self.request.headers.get('X-SSL-CERT') chainAsText = unquote(chainAsTextEncoded) else: - return S_ERROR('Not found a valide client certificate.') + return S_ERROR(DErrno.ECERTFIND, 'Valid certificate not found.') peerChain.loadChainFromString(chainAsText) @@ -603,10 +620,10 @@ def _authzJWT(self, accessToken=None): # Export token from headers token = self.request.headers.get('Authorization') if not token or len(token.split()) != 2: - return S_ERROR('Not found a bearer access token.') + return S_ERROR(DErrno.EATOKENFIND, 'Not found a bearer access token.') tokenType, accessToken = token.split() if tokenType.lower() != 'bearer': - return S_ERROR('Found a not bearer access token.') + return S_ERROR(DErrno.ETOKENTYPE, 'Found a not bearer access token.') # Read token without verification to get issuer self.log.debug('Read issuer from access token', accessToken) diff --git a/src/DIRAC/Core/Tornado/Server/private/__init__.py b/src/DIRAC/Core/Tornado/Server/private/__init__.py new file mode 100644 index 00000000000..e69de29bb2d From deb99227af7c363c5e3f5e471d3ecc32b8eaaab8 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:41:36 +0200 Subject: [PATCH 049/178] add DErrno errors --- src/DIRAC/Core/Utilities/DErrno.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/DIRAC/Core/Utilities/DErrno.py b/src/DIRAC/Core/Utilities/DErrno.py index 54b692d2fd1..68ed494f359 100644 --- a/src/DIRAC/Core/Utilities/DErrno.py +++ b/src/DIRAC/Core/Utilities/DErrno.py @@ -105,6 +105,10 @@ EMQCONN = 1142 # Elasticsearch EELNOFOUND = 1146 +# Tokens +EATOKENFIND = 1150 +EATOKENREAD = 1151 +ETOKENTYPE = 1152 # config ESECTION = 1400 @@ -185,6 +189,12 @@ 1142: 'EMQCONN', # Elasticsearch 1146: 'EELNOFOUND', + + # 115X: Tokens + 1150: 'EATOKENFIND', + 1151: 'EATOKENREAD', + 1152: 'ETOKENTYPE', + # Config 1400: "ESECTION", # Processes @@ -260,6 +270,12 @@ EMQCONN: "MQ connection failure", # 114X Elasticsearch EELNOFOUND: "Index not found", + + # 115X: Tokens + EATOKENFIND: "Can't find a bearer access token.", + EATOKENREAD: "Can't read a bearer access token.", + ETOKENTYPE: "Unsupported access token type.", + # Config ESECTION: "Section is not found", # processes From cd7cb627b9db1044241ce37e3cfb3f4c0d6cb928 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:45:17 +0200 Subject: [PATCH 050/178] add TokenManager --- .../Client/TokenManagerClient.py | 18 ++ src/DIRAC/FrameworkSystem/ConfigTemplate.cfg | 10 + src/DIRAC/FrameworkSystem/DB/TokenDB.py | 176 +++++++++++++++++ src/DIRAC/FrameworkSystem/DB/TokenDB.sql | 2 + .../Service/TokenManagerHandler.py | 177 ++++++++++++++++++ tests/Integration/Framework/Test_TokenDB.py | 79 ++++++++ .../all_integration_server_tests.sh | 1 + 7 files changed, 463 insertions(+) create mode 100644 src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py create mode 100644 src/DIRAC/FrameworkSystem/DB/TokenDB.py create mode 100644 src/DIRAC/FrameworkSystem/DB/TokenDB.sql create mode 100644 src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py create mode 100644 tests/Integration/Framework/Test_TokenDB.py diff --git a/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py b/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py new file mode 100644 index 00000000000..c49abfef639 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py @@ -0,0 +1,18 @@ +""" The TokenManagerClient is a class representing the client of the DIRAC TokenManager service. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from DIRAC.Core.Base.Client import Client, createClient + + +@createClient('Framework/TokenManager') +class TokenManagerClient(Client): + """Client exposing the TokenManager Service.""" + + def __init__(self, **kwargs): + super(TokenManagerClient, self).__init__(**kwargs) + self.setServer('Framework/TokenManager') diff --git a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg index d023e7bf226..ec28555ee15 100644 --- a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg +++ b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg @@ -1,11 +1,14 @@ APIs { + ##BEGIN Auth: + # Section to describe RESTful API for DIRAC Authorization Server(AS) Auth { Port = 8000 # Allow download personal proxy downloadablePersonalProxy = True } + ##END } Services { @@ -22,6 +25,13 @@ Services storeHostInfo = Operator } } + ##BEGIN TokenManager: + # Section to describe TokenManager system + TokenManager + { + Protocol = https + } + ##END ##BEGIN ProxyManager: # Section to describe ProxyManager system # https://dirac.readthedocs.org/en/latest/AdministratorGuide/Systems/Framework/ProxyManager/index.html diff --git a/src/DIRAC/FrameworkSystem/DB/TokenDB.py b/src/DIRAC/FrameworkSystem/DB/TokenDB.py new file mode 100644 index 00000000000..e59f0fd8a3d --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/TokenDB.py @@ -0,0 +1,176 @@ +""" Auth class is a front-end to the Auth Database +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import jwt +import time +import pprint + +from sqlalchemy import Column, Integer, Text, String +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.ext.declarative import declarative_base + +from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +__RCSID__ = "$Id$" + + +Model = declarative_base() + + +class Token(Model, OAuth2TokenMixin): + __tablename__ = 'Token' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + # access_token too large for varchar(255) + # 767 bytes is the stated prefix limitation for InnoDB tables in MySQL version 5.6 + # https://stackoverflow.com/questions/1827063/mysql-error-key-specification-without-a-key-length + id = Column(Integer, autoincrement=True, primary_key=True) + kid = Column(String(255)) + user_id = Column(String(255)) + provider = Column(String(255)) + client_id = Column(String(255)) + expires_at = Column(Integer, nullable=False, default=0) + access_token = Column(Text, nullable=False) + refresh_token = Column(Text, nullable=False) + rt_expires_at = Column(Integer, nullable=False, default=0) + + +class TokenDB(SQLAlchemyDB): + """ TokenDB class is a front-end to the OAuth Database + """ + def __init__(self): + """ Constructor + """ + super(TokenDB, self).__init__() + self._initializeConnection('Framework/TokenDB') + result = self.__initializeDB() + if not result['OK']: + raise Exception("Can't create tables: %s" % result['Message']) + self.session = scoped_session(self.sessionMaker_o) + + def __initializeDB(self): + """ Create the tables + """ + tablesInDB = self.inspector.get_table_names() + + # Token + if 'Token' not in tablesInDB: + try: + Token.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + return S_OK() + + def getTokenForUserProvider(self, userID, provider): + """ Get token for user ID and provider name + + :param str userID: user ID + :param str provider: provider + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + token = session.query(Token).filter(Token.rt_expires_at > time.time()).filter(Token.user_id == userID)\ + .filter(Token.provider == provider).first() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)) if token else None)) + + def updateToken(self, token, userID, provider, rt_expired_in): + """ Update tokens + + :param dict token: token info + :param str userID: user ID + :param str provider: provider + :param int rt_expired_in: refresh token lifetime + + :return: S_OK(list)/S_ERROR() + """ + token['user_id'] = userID + token['provider'] = provider + if not token.get('rt_expires_at'): + try: + token['rt_expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False, + verify_aud=False))['exp']) + except Exception as e: + self.log.debug('Cannot get refresh token expires time: %s' % repr(e)) + + token['rt_expires_at'] = int(token.get('rt_expires_at', rt_expired_in + int(time.time()))) + if token['rt_expires_at'] < time.time(): + return S_ERROR('Cannot store expired refresh token.') + + attrts = dict((k, v) for k, v in dict(token).items() if k in list(Token.__dict__.keys())) + self.log.debug('Store token:', pprint.pformat(attrts)) + session = self.session() + try: + session.query(Token).filter(Token.expires_at < time.time()).delete() + oldTokens = session.query(Token).filter(Token.user_id == userID)\ + .filter(Token.provider == provider).all() + session.add(Token(**attrts)) + session.query(Token).filter(Token.user_id == userID).filter(Token.provider == provider)\ + .filter(Token.access_token != token['access_token']).delete() + except Exception as e: + return self.__result(session, S_ERROR('Could not add Token: %s' % repr(e))) + self.log.info('Token successfully added for %s user, %s provider' % (token['user_id'], token['provider'])) + return self.__result(session, S_OK([self.__rowToDict(t) for t in oldTokens] if oldTokens else [])) + + def removeToken(self, access_token=None, refresh_token=None, user_id=None): + """ Remove token + + :param str access_token: access token + :param str refresh_token: refresh token + + :return: S_OK(object)/S_ERROR() + """ + session = self.session() + try: + if access_token: + session.query(Token).filter(Token.access_token == access_token).delete() + elif refresh_token: + session.query(Token).filter(Token.refresh_token == refresh_token).delete() + elif user_id: + session.query(Token).filter(Token.user_id == user_id).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK('Token successfully removed')) + + def getTokensByUserID(self, userID): + session = self.session() + try: + tokens = session.query(Token).filter(Token.user_id == userID).all() + except NoResultFound: + return self.__result(session, S_OK([])) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK([OAuth2Token(self.__rowToDict(t)) for t in tokens])) + + def __result(self, session, result=None): + try: + if not result['OK']: + session.rollback() + else: + session.commit() + except Exception as e: + session.rollback() + result = S_ERROR('Could not commit: %s' % (e)) + session.close() + return result + + def __rowToDict(self, row): + """ Convert sqlalchemy row to dictionary + + :param object row: sqlalchemy row + + :return: dict + """ + return {c.name: str(getattr(row, c.name)) for c in row.__table__.columns} if row else {} diff --git a/src/DIRAC/FrameworkSystem/DB/TokenDB.sql b/src/DIRAC/FrameworkSystem/DB/TokenDB.sql new file mode 100644 index 00000000000..15d9a8185fb --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/TokenDB.sql @@ -0,0 +1,2 @@ +# Everything is created by the DB object upon instantiation if it does not exists. +use TokenDB; \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py new file mode 100644 index 00000000000..9867f7a72e6 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -0,0 +1,177 @@ +""" TokenManagement service + + .. literalinclude:: ../ConfigTemplate.cfg + :start-after: ##BEGIN TokenManager: + :end-before: ##END + :dedent: 2 + :caption: TokenManager options +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import six +import pprint + +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.Security import Properties +from DIRAC.Core.Tornado.Server.TornadoService import TornadoService +from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + + +class TokenManagerHandler(TornadoService): + + __maxExtraLifeFactor = 1.5 + __tokenDB = None + + @classmethod + def initializeHandler(cls, serviceInfoDict): + try: + cls.__tokenDB = TokenDB() + except Exception as e: + gLogger.exception(e) + return S_ERROR('Could not connect to the database %s' % repr(e)) + + cls.idps = IdProviderFactory() + return S_OK() + + def __generateUserTokensInfo(self): + """ Generate information dict about user tokens + + :return: dict + """ + tokensInfo = [] + credDict = self.getRemoteCredentials() + result = Registry.getDNForUsername(credDict['username']) + if not result['OK']: + return result + for dn in result['Value']: + if Registry.isDNWrappedID(dn): + result = self.__tokenDB.getTokensByUserID(Registry.getIDFromDN(dn)) + if not result['OK']: + gLogger.error(result['Message']) + tokensInfo += result['Value'] + return tokensInfo + + def __addKnownUserTokensInfo(self, retDict): + """ Given a S_OK/S_ERR add a tokens entry with info of all the tokens a user has uploaded + + :return: S_OK(dict)/S_ERROR() + """ + retDict['tokens'] = self.__generateUserTokensInfo() + return retDict + + auth_getUserTokensInfo = ['authenticated'] + + def export_getUserTokensInfo(self): + """ Get the info about the user tokens in the system + + :return: S_OK(dict) + """ + return S_OK(self.__generateUserTokensInfo()) + + auth_uploadToken = ['authenticated'] + + def export_updateToken(self, token, userID, provider, rt_expired_in=24 * 3600): + """ Request to delegate tokens to DIRAC + + :param dict token: token + :param str userID: user ID + :param str provider: provider name + :param int rt_expired_in: refresh token expires time + + :return: S_OK(dict)/S_ERROR() -- dict contain uploaded tokens info + """ + self.log.verbose('Update %s user token:\n', pprint.pformat(token)) + result = self.idps.getIdProvider(provider) + if not result['OK']: + return result + idPObj = result['Value'] + result = self.__tokenDB.updateToken(token, userID, provider, rt_expired_in) + if not result['OK']: + return result + for oldToken in result['Value']: + if 'refresh_token' in oldToken and oldToken['refresh_token'] != token['refresh_token']: + self.log.verbose('Revoke old refresh token:\n', pprint.pformat(oldToken)) + idPObj.revokeToken(oldToken['refresh_token']) + return self.__tokenDB.getTokensByUserID(userID) + + def __checkProperties(self, requestedUserDN, requestedUserGroup): + """ Check the properties and return if they can only download limited tokens if authorized + + :param str requestedUserDN: user DN + :param str requestedUserGroup: DIRAC group + + :return: S_OK(boolean)/S_ERROR() + """ + credDict = self.getRemoteCredentials() + if Properties.FULL_DELEGATION in credDict['properties']: + return S_OK(False) + if Properties.LIMITED_DELEGATION in credDict['properties']: + return S_OK(True) + if Properties.PRIVATE_LIMITED_DELEGATION in credDict['properties']: + if credDict['DN'] != requestedUserDN: + return S_ERROR("You are not allowed to download any proxy") + if Properties.PRIVATE_LIMITED_DELEGATION not in Registry.getPropertiesForGroup(requestedUserGroup): + return S_ERROR("You can't download tokens for that group") + return S_OK(True) + # Not authorized! + return S_ERROR("You can't get tokens!") + + def export_getToken(self, username, userGroup): + """ Get a access token for a user/group + + * Properties: + * FullDelegation <- permits full delegation of tokens + * LimitedDelegation <- permits downloading only limited tokens + * PrivateLimitedDelegation <- permits downloading only limited tokens for one self + """ + userID = [] + provider = Registry.getIdPForGroup(userGroup) + + result = self.idps.getIdProvider(provider) + if not result['OK']: + return result + idpObj = result['Value'] + + result = Registry.getDNForUsername(username) + if not result['OK']: + return result + + err = [] + for dn in result['Value']: + if Registry.isDNWrappedID(dn): + result = self.__tokenDB.getTokenForUserProvider(Registry.getIDFromDN(dn), provider) + if not result['OK']: + err.append(result['Message']) + elif result['Value']: + idpObj.token = result['Value'] + result = self.__checkProperties(dn, userGroup) + if result['OK']: + result = idpObj.exchangeGroup(userGroup) + if result['OK']: + return result + if not err: + return S_ERROR('No user ID found for %s' % username) + return S_ERROR('; '.join(err)) + + def export_deleteToken(self, userDN): + """ Delete a token from the DB + + :param str userDN: user DN + + :return: S_OK()/S_ERROR() + """ + credDict = self.getRemoteCredentials() + if Properties.PROXY_MANAGEMENT not in credDict['properties']: + if userDN != credDict['DN']: + return S_ERROR("You aren't allowed!") + retVal = self.__tokenDB.removeToken(user_id=Registry.getIDFromDN(dn)) + if not retVal['OK']: + return retVal + self.__tokenDB.logAction("delete proxy", credDict['DN'], credDict['group'], userDN, userGroup) + return S_OK() diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py new file mode 100644 index 00000000000..02e53103103 --- /dev/null +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -0,0 +1,79 @@ +""" This is a test of the AuthDB + It supposes that the DB is present and installed in DIRAC +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +from authlib.jose import jwt + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB + +db = TokenDB() + +payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + +exp_payload = payload.copy() +exp_payload['iat'] = int(time.time()) - 10 +exp_payload['exp'] = int(time.time()) - 10 + +DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + expires_at=int(time.time()) + 3600) + +New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), + issued_at=int(time.time()), + expires_in=int(time.time()) + 3600) + +Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), + refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), + expires_at=int(time.time()) - 10, + rt_expires_at=int(time.time()) - 10) + + +def test_Token(): + """ Try to revoke/save/get tokens + """ + # Remove all tokens + result = db.removeToken(user_id=123) + assert result['OK'], result['Message'] + + # Store tokens + result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + assert result['Value'] == [] + + # Expired token + result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert not result['OK'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + + # Store new tokens + result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + # Must return old tokens + assert len(result['Value']) == 1 + assert result['Value'][0]['access_token'] == DToken['access_token'] + assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] diff --git a/tests/Integration/all_integration_server_tests.sh b/tests/Integration/all_integration_server_tests.sh index 8d2915c2a7d..1d44e54727f 100644 --- a/tests/Integration/all_integration_server_tests.sh +++ b/tests/Integration/all_integration_server_tests.sh @@ -28,6 +28,7 @@ echo -e "*** $(date -u) **** FRAMEWORK TESTS (partially skipped) ****\n" pytest "${THIS_DIR}/Framework/Test_InstalledComponentsDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) python "${THIS_DIR}/Framework/Test_ProxyDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) python "${THIS_DIR}/Framework/Test_AuthDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +python "${THIS_DIR}/Framework/Test_TokenDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #pytest ${THIS_DIR}/Framework/Test_LoggingDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #-------------------------------------------------------------------------------# From 5df72c44476708f43b94d96ab7a7a3eb877033e6 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:46:39 +0200 Subject: [PATCH 051/178] update DIRAC AS client description --- .../private/authorization/utils/Clients.py | 91 ++++++++----------- .../Resources/IdProvider/IdProviderFactory.py | 33 ++++--- 2 files changed, 54 insertions(+), 70 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index e4240f77f85..19829524257 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -3,93 +3,78 @@ from __future__ import print_function import six -import json import time import pprint +from DIRAC import gConfig, gLogger from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin -from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance, getProviderInfo - -from DIRAC import gLogger __RCSID__ = "$Id$" -DEFAULT_SCOPE = 'proxy g: lifetime:' - DEFAULT_CLIENTS = { - 'DIRACCLI': dict( - verify=False, - client_id='DIRAC_CLI', - client_secret='secret', - response_types=['device'], - grant_types=['urn:ietf:params:oauth:grant-type:device_code', 'refresh_token'], - ProviderType='DIRACCLI' - ), - 'DIRACWeb': dict( - response_types=['code'], - grant_types=['authorization_code', 'refresh_token'], - ProviderType='DIRACWeb' - ) + 'DIRACCLI': dict(client_id='DIRAC_CLI', scope='proxy g: lifetime:', response_types=['device'], + grant_types=['urn:ietf:params:oauth:grant-type:device_code', 'refresh_token'], + token_endpoint_auth_method='none', verify=False, + ProviderType='OAuth2'), + 'DIRACWeb': dict(client_id='DIRAC_Web', scope='g:', response_types=['code'], + grant_types=['authorization_code', 'refresh_token'], + ProviderType='OAuth2') } +def getDIRACClients(): + """ Get DIRAC authorization clients + + :return: S_OK(dict)/S_ERROR() + """ + clients = DEFAULT_CLIENTS.copy() + result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization/Client') + if not result['OK']: + gLogger.error(result['Message']) + confClients = result.get('Value', {}) + for cli in confClients: + if cli not in clients: + clients[cli] = confClients[cli] + else: + clients[cli].update(confClients[cli]) + return clients + + def getDIACClientByID(clientID): - """ Search authorization client + """ Search authorization client. :param str clientID: client ID :return: object or None """ gLogger.debug('Try to query %s client' % clientID) - if clientID == DEFAULT_CLIENTS['DIRACCLI']['client_id']: - gLogger.debug('Found client:\n', pprint.pformat(DEFAULT_CLIENTS['DIRACCLI'])) - return Client(DEFAULT_CLIENTS['DIRACCLI']) - - result = getProvidersForInstance('Id') - if not result['OK']: - gLogger.error(result['Message']) - return None - - for client in result['Value']: - result = getProviderInfo(client) - if not result['OK']: - gLogger.debug(result['Message']) - continue - data = DEFAULT_CLIENTS.get(result['Value']['ProviderType'], {}) - data.update(result['Value']) - if data.get('client_id') and data['client_id'] == clientID: - gLogger.debug('Found client:\n', pprint.pformat(data)) - return Client(data) - + clients = getDIRACClients() + for cli in clients: + if clientID == clients[cli]['client_id']: + gLogger.debug('Found %s client:\n' % cli, pprint.pformat(clients[cli])) + return Client(clients[cli]) + return Client(data) return None class Client(OAuth2ClientMixin): + def __init__(self, params): - super(Client, self).__init__() - client_metadata = params.get('client_metadata', params) - if client_metadata.get('scope') and DEFAULT_SCOPE not in client_metadata['scope']: - client_metadata['scope'] += ' %s' % DEFAULT_SCOPE - else: - client_metadata['scope'] = DEFAULT_SCOPE - if params.get('redirect_uri') and not client_metadata.get('redirect_uris'): - client_metadata['redirect_uris'] = [params['redirect_uri']] + if params.get('redirect_uri') and not params.get('redirect_uris'): + params['redirect_uris'] = [params['redirect_uri']] + self.set_client_metadata(params) self.client_id = params['client_id'] self.client_secret = params.get('client_secret', '') self.client_id_issued_at = params.get('client_id_issued_at', int(time.time())) self.client_secret_expires_at = params.get('client_secret_expires_at', 0) - if isinstance(client_metadata, dict): - self._client_metadata = json.dumps(client_metadata) - else: - self._client_metadata = client_metadata def get_allowed_scope(self, scope): if not isinstance(scope, six.string_types): scope = list_to_scope(scope) allowed = scope_to_list(super(Client, self).get_allowed_scope(scope)) for s in scope_to_list(scope): - for def_scope in scope_to_list(DEFAULT_SCOPE): + for def_scope in scope_to_list(self.scope): if s.startswith(def_scope) and s not in allowed: allowed.append(s) gLogger.debug('Try to allow "%s" scope:' % scope, allowed) diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 80b19de5ec9..a9f3fc9ea4f 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -17,7 +17,7 @@ from DIRAC.Core.Utilities.DictCache import DictCache from DIRAC.Resources.IdProvider.Utilities import getProviderInfo, getSettingsNamesForIdPIssuer from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata -from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients __RCSID__ = "$Id$" @@ -74,33 +74,32 @@ def getIdProvider(self, name, **kwargs): """ This method returns a IdProvider instance corresponding to the supplied name. - :param str name: the name of the Identity Provider + :param str name: the name of the Identity Provider client :return: S_OK(IdProvider)/S_ERROR() """ - self.log.debug('Search %s configuration..' % name) - pDict = DEFAULT_CLIENTS.get(name, {}) - if pDict: + self.log.debug('Search %s identity provider client configuration..' % name) + clients = getDIRACClients() + if name in clients: + # If it is a DIRAC default pre-registred client + pDict = clients[name] result = getAuthorizationServerMetadata() if not result['OK']: return result pDict.update(result['Value']) - pDict.update(kwargs) - - result = getProviderInfo(name) - if not result['OK']: - if not pDict: + else: + # if it is external identity provider client + result = getProviderInfo(name) + if not result['OK']: self.log.error('Failed to read configuration', '%s: %s' % (name, result['Message'])) return result - gLogger.debug(result['Message']) - else: - pDict.update(result['Value']) - pDict['ProviderName'] = name + pDict = result['Value'] - pType = pDict['ProviderType'] + pDict.update(kwargs) + pDict['ProviderName'] = name - self.log.verbose('Creating IdProvider of %s type with the name %s' % (pType, name)) - subClassName = "%sIdProvider" % (pType) + self.log.verbose('Creating IdProvider of %s type with the name %s' % (pDict['ProviderType'], name)) + subClassName = "%sIdProvider" % pDict['ProviderType'] objectLoader = ObjectLoader.ObjectLoader() result = objectLoader.loadObject('Resources.IdProvider.%s' % subClassName, subClassName) From ce92f35ee60a0a8aa2f4cd7235c07e198cd5a418 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:46:55 +0200 Subject: [PATCH 052/178] fix --- src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index ac68da14dc6..d60e7a6cc23 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -531,6 +531,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): return S_ERROR('Access token expired.') # Try to refresh token + self.__idp.scope = None result = self.__idp.refreshToken(token['refresh_token']) if result['OK']: token = result['Value'] From 0c4d3722fbcc3f9095cb22df1fdfad96d7d3b202 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:47:37 +0200 Subject: [PATCH 053/178] optimize DeviceFlow --- .../authorization/grants/DeviceFlow.py | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index 4158e2639d3..f00f9fd01f9 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -10,11 +10,6 @@ DeviceCodeGrant as _DeviceCodeGrant, DeviceCredentialDict) -from DIRAC import gLogger -from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata - -log = gLogger.getSubLogger(__name__) - class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): """ See :class:`authlib.oauth2.rfc8628.DeviceAuthorizationEndpoint` """ @@ -30,10 +25,7 @@ def get_verification_uri(self): :return: str """ - result = getAuthorizationServerMetadata() - if not result['OK']: - raise OAuth2Error('Cannot prepare authorization server metadata. %s' % result['Message']) - return result['Value']['issuer'] + '/device' + return self.server.metadata['device_authorization_endpoint'] def save_device_credential(self, client_id, scope, data): """ Save device credentials @@ -61,7 +53,7 @@ def validate_authorization_request(self): """ # Validate client for this request client_id = self.request.client_id - log.debug('Validate authorization request of', client_id) + self.server.log.debug('Validate authorization request of', client_id) if client_id is None: raise InvalidClientError(state=self.request.state) client = self.server.query_client(client_id) @@ -116,10 +108,7 @@ def query_device_credential(self, device_code): data = result['Value'] if not data: return None - result = getAuthorizationServerMetadata() - if not result['OK']: - raise OAuth2Error('Cannot prepare authorization server metadata. %s' % result['Message']) - data['verification_uri'] = result['Value']['issuer'] + '/device' + data['verification_uri'] = self.server.metadata['device_authorization_endpoint'] data['expires_at'] = int(data['expires_in']) + int(time.time()) data['interval'] = DeviceAuthorizationEndpoint.INTERVAL return DeviceCredentialDict(data) @@ -134,8 +123,7 @@ def query_user_grant(self, user_code): result = self.server.db.getSessionByUserCode(user_code) if not result['OK']: raise OAuth2Error('Cannot found authorization session', result['Message']) - data = result['Value'] - return (data['user_id'], True) if data.get('username') != "None" else None + return (result['Value']['user_id'], True) if result['Value'].get('username') != "None" else None def should_slow_down(self, credential, now): """ The authorization request is still pending and polling should continue, From 054d9ef72878d904252ef1460f738d1f5b2f1033 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:48:13 +0200 Subject: [PATCH 054/178] fix Token class --- .../private/authorization/utils/Tokens.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 995b4433561..397a68404a5 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -23,7 +23,7 @@ def getTokenLocation(): """ Research token file location. Use the bearer token discovery protocol - defined by the WLCG (https://zenodo.org/record/3937438) to find one. + defined by the WLCG (https://doi.org/10.5281/zenodo.3937438) to find one. :return: str """ @@ -37,15 +37,16 @@ def getTokenLocation(): def getLocalTokenDict(location=None): """ Search local token. Use the bearer token discovery protocol - defined by the WLCG (https://zenodo.org/record/3937438) to find one. + defined by the WLCG (https://doi.org/10.5281/zenodo.3937438) to find one. :param str location: environ variable name or file path :return: S_OK(dict)/S_ERROR() """ env = (location if location and location.startswith('/') else None) or 'BEARER_TOKEN' - if os.environ.get(env): - return S_OK(OAuth2Token(os.environ[env])) + token = os.environ.get(env, "").strip() + if token: + return S_OK(OAuth2Token(token)) return readTokenFromFile(location if location and location.startswith('/') else None) @@ -59,7 +60,7 @@ def readTokenFromFile(fileName=None): location = fileName or getTokenLocation() try: with open(location, 'rt') as f: - token = f.read() + token = f.read().strip() except IOError as e: return S_ERROR(DErrno.EOF, "Can't open %s token file.\n%s" % (location, repr(e))) return S_OK(OAuth2Token(token)) @@ -108,19 +109,22 @@ def __init__(self, params=None, **kwargs): """ if isinstance(params, six.string_types): # Is params a JWT? + params = params.strip() if re.match(r"^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*$", params): params = dict(access_token=params) else: params = json.loads(params) kwargs.update(params or {}) + kwargs['issued_at'] = kwargs.get('issued_at', kwargs.get('iat')) + kwargs['expires_at'] = kwargs.get('expires_at', kwargs.get('exp')) if not kwargs.get('expires_at') and kwargs.get('access_token'): # Get access token expires_at claim - kwargs['expires_at'] = int(self.get_claim('exp')) + kwargs['expires_at'] = self.get_claim('exp') super(OAuth2Token, self).__init__(kwargs) def get_client_id(self): - return self.get('client_id') + return self.get('client_id', self.get('azp')) def get_scope(self): return self.get('scope') @@ -129,7 +133,7 @@ def get_expires_in(self): return self.get('expires_in') def get_expires_at(self): - return self.get('issued_at') + self.get('expires_in') + return int(self.get('expires_at', self.get('issued_at') + self.get('expires_in'))) @property def scopes(self): @@ -169,7 +173,7 @@ def get_claim(self, claim, token_type='access_token'): :return: str """ - return get_payload(token_type).get(claim) + return self.get_payload(token_type).get(claim) def dump_to_string(self): """ Dump token dictionary to sting From f565be74f9ad7c2230139febda7e3cdfbd1c297d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:48:47 +0200 Subject: [PATCH 055/178] fix dirac-login --- .../FrameworkSystem/scripts/dirac_login.py | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index b0691729ae5..fb2ba44345f 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -15,6 +15,7 @@ import os import sys +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope import DIRAC from DIRAC import gLogger, S_OK, S_ERROR @@ -34,7 +35,6 @@ def __init__(self): self.proxy = False self.group = None self.lifetime = None - self.provider = 'DIRACCLI' self.issuer = None self.proxyLoc = '/tmp/x509up_u%s' % os.getuid() self.tokenLoc = None @@ -57,16 +57,6 @@ def setGroup(self, arg): self.group = arg return S_OK() - def setProvider(self, arg): - """ Set provider name - - :param str arg: provider - - :return: S_OK() - """ - self.provider = arg - return S_OK() - def setIssuer(self, arg): """ Set issuer @@ -109,11 +99,6 @@ def registerCLISwitches(self): "group=", "set DIRAC group", self.setGroup) - Script.registerSwitch( - "O:", - "provider=", - "set identity provider", - self.setProvider) Script.registerSwitch( "I:", "issuer=", @@ -138,16 +123,18 @@ def doOAuthMagic(self): params = {} if self.issuer: params['issuer'] = self.issuer - result = IdProviderFactory().getIdProvider(self.provider, **params) + result = IdProviderFactory().getIdProvider('DIRACCLI', **params) if not result['OK']: return result idpObj = result['Value'] + scope = [] if self.group: - idpObj.scope += '+g:%s' % self.group + scope.append('g:%s' % self.group) if self.proxy: - idpObj.scope += '+proxy' + scope.append('proxy') if self.lifetime: - idpObj.scope += '+lifetime:%s' % (int(self.lifetime) * 3600) + scope.append('lifetime:%s' % (int(self.lifetime) * 3600)) + idpObj.scope = '+'.join(scope) if scope else None # Submit Device authorisation flow result = idpObj.deviceAuthorization() From b67a701747a1c8a75651b0af187f67f879ff6896 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 19:49:58 +0200 Subject: [PATCH 056/178] add refresh token reuse protection, fixes --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 106 ++++++------ src/DIRAC/FrameworkSystem/DB/AuthDB.py | 162 ++++-------------- .../private/authorization/AuthServer.py | 156 +++++++++++------ .../authorization/grants/RefreshToken.py | 52 ++++-- .../Resources/IdProvider/OAuth2IdProvider.py | 24 ++- tests/Integration/Framework/Test_AuthDB.py | 80 ++++----- 6 files changed, 280 insertions(+), 300 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 447dc12cda7..5f2cf697021 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -1,4 +1,10 @@ """ This handler basically provides a REST interface to interact with the OAuth 2 authentication server + + .. literalinclude:: ../ConfigTemplate.cfg + :start-after: ##BEGIN Auth: + :end-before: ##END + :dedent: 2 + :caption: Auth options """ from __future__ import absolute_import from __future__ import division @@ -19,10 +25,9 @@ from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST from DIRAC.ConfigurationSystem.Client.Helpers import Registry from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint -from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory __RCSID__ = "$Id$" @@ -96,57 +101,41 @@ def initializeRequest(self): href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css") dom.style(self.CSS) - def _parseDIRACResult(self, result): - """ Here the result which returns handle_response is processed - """ - if not result['OK']: - # If response error is DIRAC server error, not OAuth2 flow error - self.removeSession() - self.set_status(400) - self.write({'error': 'server_error', - 'description': '%s:\n%s' % (result['Message'], '\n'.join(result['CallStack']))}) - else: - # Successful responses and OAuth2 errors are processed here - status_code, headers, payload, new_session, error = result['Value'][0] - if status_code: - self.set_status(status_code) - if headers: - for key, value in headers: - self.set_header(key, value) - if payload: - self.write(payload) - if new_session: - self.saveSession(new_session) - if error: - self.removeSession() - for method, args_kwargs in result['Value'][1].items(): - eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) - - def saveSession(self, session): - """ Save session to cookie - - :param dict session: session - """ - self.set_secure_cookie('auth_session', json.dumps(session), secure=True, httponly=True) - - def removeSession(self): - """ Remove session from cookie """ - self.clear_cookie('auth_session') - - def getSession(self, state=None, **kw): - """ Get session from cookie + def _finishFuture(self, retVal): + """ Handler Future result - :param str state: state - - :return: dict + :param object retVal: tornado.concurrent.Future """ - try: - session = json.loads(self.get_secure_cookie('auth_session')) - checkState = (session['state'] == state) if state else None - checkOption = (session[kw.items()[0][0]] == kw.items()[0][0]) if kw else None - except Exception as e: - return None - return session if (checkState or checkOption) else None + self.result = retVal.result() + + # Is it S_OK or S_ERROR? + r = self.result + if isinstance(r, dict) and isinstance(r.get('OK'), bool) and ('Value' if r['OK'] else 'Message') in r: + if not self.result['OK']: + # S_ERROR is interpreted in the OAuth2 error format. + self.set_status(400) + self.write({'error': 'server_error', 'description': self.result['Message']}) + self.clear_cookie('auth_session') + self.log.error('%s\n' % self.result['Message'], ''.join(self.result['CallStack'])) + else: + # Successful responses and OAuth2 errors are processed here + status_code, headers, payload, new_session, error = self.result['Value'][0] + if status_code: + self.set_status(status_code) + if headers: + for key, value in headers: + self.set_header(key, value) + if payload: + self.write(payload) + if new_session: + self.set_secure_cookie('auth_session', json.dumps(new_session), secure=True, httponly=True) + if error: + self.clear_cookie('auth_session') + for method, args_kwargs in self.result['Value'][1].items(): + eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) + self.finish() + else: + super(AuthHandler, self)._finishFuture(retVal) path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] @@ -191,7 +180,7 @@ def web_index(self, instance): } """ if self.request.method == "GET": - return dict(self.server.metadata) + return self.server.metadata def web_jwk(self): """ JWKs endpoint @@ -267,11 +256,7 @@ def web_userinfo(self): ] } """ - # Token verification - # token = ResourceProtector().acquire_token(self.request, '') - # return {'sub': token.sub, 'issuer': token.issuer, 'group': token.groups[0]} - userinfo = self.getRemoteCredentials() - return userinfo + return self.getRemoteCredentials() path_device = ['([A-z0-9-_]*)'] @@ -426,7 +411,12 @@ def web_redirect(self): return self.server.handle_error_response(state, error) # Check current auth session that was initiated for the selected external identity provider - sessionWithExtIdP = self.getSession(state) + try: + session = json.loads(self.get_secure_cookie('auth_session')) + except Exception: + session = {} + + sessionWithExtIdP = session if state and (session.get('state') == state) else None if not sessionWithExtIdP: return S_ERROR("%s session is expired." % state) diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index ec1468e136b..94b0a78bad2 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -16,6 +16,7 @@ from sqlalchemy.ext.declarative import declarative_base from authlib.jose import KeySet, RSAKey +from authlib.common.security import generate_token from authlib.common.encoding import urlsafe_b64decode, urlsafe_b64encode, to_bytes, to_unicode, json_b64encode from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin @@ -29,38 +30,14 @@ Model = declarative_base() -def encrypt(data, key): - """ Encryption with key """ - cipher = M2Crypto.EVP.Cipher(alg='aes_256_cbc', key=key[16:], iv=key[:16], op=1) - ciphertext = cipher.update(data.encode('utf-8')) + cipher.final() - ciphertext = urlsafe_b64encode(ciphertext) - return ciphertext - - -def decrypt(ciphertext, key): - """ Decryption with key """ - cipher = M2Crypto.EVP.Cipher(alg='aes_256_cbc', key=key[16:], iv=key[:16], op=0) - data = cipher.update(urlsafe_b64decode(to_bytes(ciphertext))) + cipher.final() - data = to_unicode(data.decode('utf-8')) - return data - - -class Token(Model, OAuth2TokenMixin): - __tablename__ = 'Token' +class RefreshToken(Model): + __tablename__ = 'RefreshToken' __table_args__ = {'mysql_engine': 'InnoDB', 'mysql_charset': 'utf8'} - # access_token too large for varchar(255) - # 767 bytes is the stated prefix limitation for InnoDB tables in MySQL version 5.6 - # https://stackoverflow.com/questions/1827063/mysql-error-key-specification-without-a-key-length - id = Column(Integer, autoincrement=True, primary_key=True) - kid = Column(String(255)) - user_id = Column(String(255)) - provider = Column(String(255)) - client_id = Column(String(255)) - expires_at = Column(Integer, nullable=False, default=0) + jti = Column(String(255), nullable=False, primary_key=True) + issued_at = Column(Integer, nullable=False, default=0) access_token = Column(Text, nullable=False) - refresh_token = Column(Text, nullable=False) - rt_expires_at = Column(Integer, nullable=False, default=0) + refresh_token = Column(Text) class JWK(Model): @@ -110,10 +87,10 @@ def __initializeDB(self): """ tablesInDB = self.inspector.get_table_names() - # Token - if 'Token' not in tablesInDB: + # RefreshToken + if 'RefreshToken' not in tablesInDB: try: - Token.__table__.create(self.engine) # pylint: disable=no-member + RefreshToken.__table__.create(self.engine) # pylint: disable=no-member except Exception as e: return S_ERROR(e) @@ -133,121 +110,58 @@ def __initializeDB(self): return S_OK() - def encryptRefreshToken(self, token, metadata): - """ Encrypt refresh token + def storeRefreshToken(self, token, tokenID=None): + """ Store refresh token - :param dict token: token dict - :param str client_id: client ID - :param str provider: provider name + :param dict token: tokens as dict + :param str tokenID: token ID :return: S_OK(dict)/S_ERROR() """ - for field in ['expires_at', 'client_id', 'provider']: - if not metadata.get(field): - return S_ERROR('%s field is absent in metadata.' % field) - # Get secret key - key = self.getPrivateKey() - if not key['OK']: - return key - # Encrypt refresh token - try: - metadata['kid'] = key['Value']['kid'] - metadata['refresh_token'] = encrypt(token['refresh_token'], key['Value']['strkey']) - token['refresh_token'] = json_b64encode(metadata) - return S_OK(token) - except Exception as e: - self.log.exception(e) - return S_ERROR('Cannot encode refresh token: %s' % repr(e)) - - def decryptRefreshToken(self, token): - """ Decrypt refresh token + iat = int(time.time()) + jti = tokenID or generate_token(10) + self.log.debug('Store %s token:\n' % jti, pprint.pformat(token)) - :param dict token: token dict - - :return: S_OK(dict)/S_ERROR() - """ - try: - decoded = json.loads(urlsafe_b64decode(token['refresh_token'])) - except Exception as e: - return S_ERROR('Cannot find secret key: %s' % repr(e)) - # Get secret key by key ID - key = self.getPrivateKey(decoded['kid']) - if not key['OK']: - return key - # Decript refresh token + session = self.session() try: - token['refresh_token'] = decrypt(decoded['refresh_token'], key['Value']['strkey']) - token['expires_at'] = decoded['expires_at'] - token['client_id'] = decoded['client_id'] - token['provider'] = decoded['provider'] - return S_OK(OAuth2Token(token)) + session.add(RefreshToken(jti=jti, + issued_at=iat, + access_token=token['access_token'], + refresh_token=token.get('refresh_token'))) except Exception as e: - self.log.exception(e) - return S_ERROR('Cannot decode refresh token: %s' % repr(e)) + return self.__result(session, S_ERROR('Could not add refresh token: %s' % repr(e))) - def getTokenForUserProvider(self, userID, provider): - """ Get token for user ID and provider name + self.log.info('Token with %s ID successfully added:\n' % jti, pprint.pformat(token)) + return S_OK(dict(jti=jti, iat=iat)) - :param str userID: user ID - :param str provider: provider + def revokeRefreshToken(self, tokenID): + """ Revoke refresh token - :return: S_OK(dict)/S_ERROR() + :param str tokenID: refresh token ID + + :return: S_OK()/S_ERROR() """ session = self.session() try: - token = session.query(Token).filter(Token.rt_expires_at > time.time()).filter(Token.user_id == userID)\ - .filter(Token.provider == provider).first() + session.query(RefreshToken).filter(RefreshToken.jti == tokenID).delete() except Exception as e: return self.__result(session, S_ERROR(str(e))) - return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)) if token else None)) - - def updateToken(self, token, userID, provider): - """ Update tokens - - :param dict token: token info - :param str userID: user ID - :param str provider: provider - - :return: S_OK(list)/S_ERROR() - """ - token['user_id'] = userID - token['provider'] = provider - try: - token['rt_expires_at'] = int(jwt.decode(token['refresh_token'], options=dict(verify_signature=False, - verify_aud=False))['exp']) - except Exception as e: - self.log.debug('Cannot get refresh token expires time: %s' % repr(e)) - - token['rt_expires_at'] = int(token.get('rt_expires_at', 24 * 3600 + time.time())) - if token['rt_expires_at'] < time.time(): - return S_ERROR('Cannot store expired refresh token.') + return S_OK() - attrts = dict((k, v) for k, v in dict(token).items() if k in list(Token.__dict__.keys())) - self.log.debug('Store token:', pprint.pformat(attrts)) - session = self.session() - try: - session.query(Token).filter(Token.expires_at < time.time()).delete() - oldTokens = session.query(Token).filter(Token.user_id == userID)\ - .filter(Token.provider == provider).all() - session.add(Token(**attrts)) - session.query(Token).filter(Token.user_id == userID).filter(Token.provider == provider)\ - .filter(Token.access_token != token['access_token']).delete() - except Exception as e: - return self.__result(session, S_ERROR('Could not add Token: %s' % repr(e))) - self.log.info('Token successfully added for %s user, %s provider' % (token['user_id'], token['provider'])) - return self.__result(session, S_OK([self.__rowToDict(t) for t in oldTokens] if oldTokens else [])) + def getCredentialByRefreshToken(self, tokenID): + """ Get refresh token credential - def removeTokens(self): - """ Get active keys + :param str tokenID: refresh token ID - :return: S_OK(list)/S_ERROR() + :return: S_OK(dict)/S_ERROR() """ session = self.session() try: - session.query(Token).delete() + token = session.query(RefreshToken).filter(RefreshToken.jti == tokenID).first() + session.query(RefreshToken).filter(RefreshToken.jti == tokenID).delete() except Exception as e: return self.__result(session, S_ERROR(str(e))) - return self.__result(session, S_OK()) + return self.__result(session, S_OK(OAuth2Token(self.__rowToDict(token)) if token else None)) def generateRSAKeys(self): """ Generate an RSA keypair with an exponent of 65537 in PEM format diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index ebee16a46e6..bb74f45ed69 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function +import re import sys import time import pprint @@ -10,11 +11,13 @@ from dominate import document, tags as dom from tornado.template import Template +from authlib.jose import jwt from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer from authlib.oauth2.base import OAuth2Error +from authlib.common.security import generate_token from authlib.oauth2.rfc7636 import CodeChallenge from authlib.oauth2.rfc8414 import AuthorizationServerMetadata -from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant @@ -29,9 +32,10 @@ from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata, isDownloadablePersonalProxy -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getUsernameForDN, getEmailsForGroup, +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getUsernameForDN, getEmailsForGroup, wrapIDAsDN, getDNForUsername, getIdPForGroup) from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient +from DIRAC.FrameworkSystem.Client.TokenManagerClient import TokenManagerClient log = logging.getLogger('authlib') log.addHandler(logging.StreamHandler(sys.stdout)) @@ -40,7 +44,13 @@ def collectMetadata(issuer=None): - """ Collect metadata """ + """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defines by IETF specification: + https://datatracker.ietf.org/doc/html/rfc8414#section-2 + + :param str issuer: issuer to set + + :return: dict -- dictionary is the AuthorizationServerMetadata object in the same time + """ result = getAuthorizationServerMetadata(issuer) if not result['OK']: raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) @@ -55,11 +65,12 @@ def collectMetadata(issuer=None): 'urn:ietf:params:oauth:grant-type:device_code'] metadata['response_types_supported'] = ['code', 'device', 'token'] metadata['code_challenge_methods_supported'] = ['S256'] + metadata['scopes_supported'] = ['g:', 'proxy', 'lifetime:'] return AuthorizationServerMetadata(metadata) class AuthServer(_AuthorizationServer): - """ Implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`. + """ Implementation of the :class:`authlib.oauth2.rfc6749.AuthorizationServer`. Initialize:: @@ -67,11 +78,13 @@ class AuthServer(_AuthorizationServer): """ css = {} LOCATION = None + REFRESH_TOKEN_EXPIRES_IN = 24 * 3600 def __init__(self): self.db = AuthDB() - # self.__tokenDB = TokenDB() + self.log = log self.proxyCli = ProxyManagerClient() + self.tokenCli = TokenManagerClient() self.idps = IdProviderFactory() # Privide two authlib methods query_client and save_token _AuthorizationServer.__init__(self, query_client=getDIACClientByID, save_token=lambda x, y: None) @@ -92,7 +105,7 @@ def addSession(self, session): def getSession(self, session): self.db.getSession(session) - def __getScope(self, scope, param): + def _getScope(self, scope, param): """ Get parameter scope :param str scope: scope @@ -109,24 +122,28 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, expires_in=None, include_refresh_token=True): """ Generate proxy or tokens after authorization """ - group = self.__getScope(scope, 'g') - lifetime = self.__getScope(scope, 'lifetime') + group = self._getScope(scope, 'g') + lifetime = self._getScope(scope, 'lifetime') provider = getIdPForGroup(group) + # Search DIRAC username + result = getUsernameForDN(wrapIDAsDN(user)) + if not result['OK']: + raise Exception(result['Message']) + userName = result['Value'] + if 'proxy' in scope_to_list(scope): # Try to return user proxy if proxy scope present in the authorization request if not isDownloadablePersonalProxy(): raise Exception("You can't get proxy, configuration settings(downloadablePersonalProxy) not allow to do that.") - gLogger.debug('Try to query %s@%s proxy%s' % (user, group, ('with lifetime:%s' % lifetime) if lifetime else '')) - result = getUsernameForDN('/O=DIRAC/CN=%s' % user) - if result['OK']: - result = getDNForUsername(result['Value']) + self.log.debug('Try to query %s@%s proxy%s' % (user, group, ('with lifetime:%s' % lifetime) if lifetime else '')) + result = getDNForUsername(userName) if not result['OK']: raise Exception(result['Message']) userDNs = result['Value'] err = [] for dn in userDNs: - gLogger.debug('Try to get proxy for %s' % dn) + self.log.debug('Try to get proxy for %s' % dn) if lifetime: result = self.proxyCli.downloadProxy(dn, group, requiredTimeLeft=int(lifetime)) else: @@ -134,7 +151,7 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, if not result['OK']: err.append(result['Message']) else: - gLogger.info('Proxy was created.') + self.log.info('Proxy was created.') result = result['Value'].dumpAllToString() if not result['OK']: raise Exception(result['Message']) @@ -142,33 +159,59 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, raise Exception('; '.join(err)) else: - # Get identity provider - result = self.idps.getIdProvider(provider) - if result['OK']: - idpObj = result['Value'] - # Get actual token from storage - result = self.db.getTokenForUserProvider(user, provider) - if result['OK']: - idpObj.token = result['Value'] - # Try to refresh it if expired - if idpObj.token.is_expired(): - result = idpObj.refreshToken() - if result['OK']: - result = self.db.updateToken(idpObj.token, user, provider) - if not result['OK']: - raise OAuth2Error(result['Message']) - # Ask identity provider tokens with needed group scopes - result = idpObj.exchangeGroup(group) - if result['OK']: - token = result['Value'] - # Encrypt refresh token - result = self.db.encryptRefreshToken(token, dict(provider=idpObj.name, - client_id=client.get_client_id(), - expires_at=12 * 3600 + time.time())) + # Ask TokenManager to generate new tokens for user + result = self.tokenCli.getToken(userName, group) + if not result['OK']: + raise OAuth2Error(result['Message']) + token = result['Value'] + + # Wrap the refresh token and register it to protect against reuse + result = self.registerRefreshToken(dict(sub=user, scope=scope, provider=provider, + azp=client.get_client_id()), token) if not result['OK']: raise OAuth2Error(result['Message']) return result['Value'] + def __signToken(self, payload): + """ Sign token + + :param dict payload: payload + + :return: S_OK(str)/S_ERROR() + """ + result = self.db.getPrivateKey() + if not result['OK']: + return result + key = result['Value']['rsakey'] + kid = result['Value']['kid'] + try: + return S_OK(jwt.encode(dict(alg='RS256', kid=kid), payload, key)) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + + def registerRefreshToken(self, payload, token): + """ Register refresh token to protect it from reuse + + :param dict payload: payload + :param dict token: token as a dictionary + + :return: S_OK(dict)S_ERROR() + """ + result = self.db.storeRefreshToken(token, payload.get('jti')) + if result['OK']: + payload.update(result['Value']) + result = self.__signToken(payload) + if not result['OK']: + if token.get('refresh_token'): + prov = self.idps.getIdProvider(payload['provider']) + if prov['OK']: + prov['Value'].revokeToken(token['refresh_token']) + prov['Value'].revokeToken(token['access_token'], 'access_token') + return result + token['refresh_token'] = result['Value'] + return S_OK(token) + def getIdPAuthorization(self, providerName, request): """ Submit subsession and return dict with authorization url and session number @@ -186,7 +229,7 @@ def getIdPAuthorization(self, providerName, request): session['Provider'] = providerName session['mainSession'] = request if isinstance(request, dict) else request.toDict() - gLogger.verbose('Redirect to', authURL) + self.log.verbose('Redirect to', authURL) return self.handle_response(302, {}, [("Location", authURL)], session) def parseIdPAuthorizationResponse(self, response, session): @@ -200,7 +243,7 @@ def parseIdPAuthorizationResponse(self, response, session): :return: S_OK(dict)/S_ERROR() """ providerName = session.pop('Provider') - gLogger.debug('Try to parse authentification response from %s:\n' % providerName, pprint.pformat(response)) + self.log.debug('Try to parse authentification response from %s:\n' % providerName, pprint.pformat(response)) # Parse response result = self.idps.getIdProvider(providerName) if not result['OK']: @@ -214,7 +257,7 @@ def parseIdPAuthorizationResponse(self, response, session): # As a result of authentication we will receive user credential dictionary credDict = result['Value'] - gLogger.debug("Read profile:", pprint.pformat(credDict)) + self.log.debug("Read profile:", pprint.pformat(credDict)) # Is ID registred? result = getUsernameForDN(credDict['DN']) if not result['OK']: @@ -230,15 +273,8 @@ def parseIdPAuthorizationResponse(self, response, session): # Update token for user. This token will be stored separately in the database and # updated from time to time. This token will never be transmitted, # it will be used to make exchange token requests. - result = self.db.updateToken(idpObj.token, credDict['ID'], idpObj.name) - if not result['OK']: - return result - - # Revoke old tokens - for oldToken in result['Value']: - idpObj.revokeToken(oldToken.get('refresh_token')) - - return S_OK(credDict) + result = self.tokenCli.updateToken(idpObj.token, credDict['ID'], idpObj.name) + return S_OK(credDict) if result['OK'] else result def get_error_uris(self, request): error_uris = self.config.get('error_uris') @@ -246,21 +282,27 @@ def get_error_uris(self, request): return dict(error_uris) def create_oauth2_request(self, request, method_cls=OAuth2Request, use_json=False): - gLogger.debug('Create OAuth2 request', 'with json' if use_json else '') + self.log.debug('Create OAuth2 request', 'with json' if use_json else '') return createOAuth2Request(request, method_cls, use_json) def create_json_request(self, request): return self.create_oauth2_request(request, HttpRequest, True) + def validate_requested_scope(self, scope, state=None): + """ See :func:`authlib.oauth2.rfc6749.authorization_server.validate_requested_scope` """ + # We also consider parametric scope containing ":" charter + extended_scope = list_to_scope([re.sub(r':.*$', ':', s) for s in scope_to_list(scope or '')]) + super(AuthServer, self).validate_requested_scope(extended_scope, state) + def handle_error_response(self, request, error): return self.handle_response(*error(translations=self.get_translations(request), error_uris=self.get_error_uris(request)), error=True) def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, error=None, **actions): - gLogger.debug('Handle authorization response with %s status code:' % status_code, payload) - gLogger.debug('Headers:', headers) + self.log.debug('Handle authorization response with %s status code:' % status_code, payload) + self.log.debug('Headers:', headers) if newSession: - gLogger.debug('newSession:', newSession) + self.log.debug('newSession:', newSession) return S_OK([[status_code, headers, payload, newSession, error], actions]) def create_authorization_response(self, response, username): @@ -283,9 +325,9 @@ def validate_consent_request(self, request, provider=None): return 'Use GET method to access this endpoint.' try: req = self.create_oauth2_request(request) - gLogger.info('Validate consent request for', req.state) + self.log.info('Validate consent request for', req.state) grant = self.get_authorization_grant(req) - gLogger.debug('Use grant:', grant) + self.log.debug('Use grant:', grant) grant.validate_consent_request() if not hasattr(grant, 'prompt'): grant.prompt = None @@ -376,7 +418,7 @@ def __registerNewUser(self, provider, userProfile): for addresses in getEmailsForGroup('dirac_admin'): result = NotificationClient().sendMail(addresses, mail['subject'], mail['body'], localAttempt=False) if not result['OK']: - gLogger.error(result['Message']) + self.log.error(result['Message']) if result['OK']: - gLogger.info(result['Value'], "administrators have been notified about a new user.") + self.log.info(result['Value'], "administrators have been notified about a new user.") return result diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py index c503d82fd79..376382c91ec 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -2,14 +2,18 @@ from __future__ import division from __future__ import print_function +from authlib.jose import JsonWebKey, jwt from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, wrapIDAsDN + class RefreshTokenGrant(_RefreshTokenGrant): """ See :class:`authlib.oauth2.rfc6749.grants.RefreshTokenGrant` """ DEFAULT_EXPIRES_AT = 12 * 3600 + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] def authenticate_refresh_token(self, refresh_token): """ Get credential for token @@ -18,18 +22,38 @@ def authenticate_refresh_token(self, refresh_token): :return: dict or None """ - result = self.server.db.decryptRefreshToken({'refresh_token': refresh_token}) + result = self.server.db.getJWKs() if not result['OK']: raise OAuth2Error(result['Message']) - return result['Value'] + jwks = result['Value'] + rtDict = jwt.decode(refresh_token, JsonWebKey.import_key_set(jwks)) + result = self.server.db.getCredentialByRefreshToken(rtDict['jti']) + if not result['OK']: + raise OAuth2Error(result['Message']) + credential = result['Value'] - def _validate_token_scope(self, token): - """ Skip scope validadtion """ - pass + if int(rtDict['iat']) != int(credential['issued_at']): + # An attempt to reuse the refresh token was detected + prov = self.server.idps.getIdProvider(rtDict['provider']) + if prov['OK']: + prov['Value'].revokeToken(credential['refresh_token']) + prov['Value'].revokeToken(credential['access_token'], 'access_token') + return None + + credential.update(rtDict) + return credential def authenticate_user(self, credential): - """ Authorize user """ - return True + """ Authorize user + + :param dict credential: credential (token payload) + + :return: str or bool + """ + result = getUsernameForDN(wrapIDAsDN(credential['sub'])) + if not result['OK']: + self.server.log.error(result['Message']) + return result.get('Value') def issue_token(self, user, credential): """ Refresh tokens @@ -39,13 +63,15 @@ def issue_token(self, user, credential): :return: dict """ - result = self.server.idps.getIdProvider(credential['provider']) - if result['OK']: - result = result['Value'].refreshToken(credential['refresh_token']) + if credential['refresh_token']: + result = self.server.idps.getIdProvider(credential['provider']) if result['OK']: - result = self.server.db.encryptRefreshToken(result['Value'], dict(provider=credential['provider'], - client_id=credential['client_id'], - expires_at=self.DEFAULT_EXPIRES_AT)) + result = result['Value'].refreshToken(credential['refresh_token']) + else: + result = self.server.tokenCli.getToken(user, self.server._getScope(credential['scope'], 'g')) + if result['OK']: + token = result['Value'] + result = self.server.registerRefreshToken(credential, token) if not result['OK']: raise OAuth2Error(result['Message']) return result['Value'] diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 4643394bc13..97d7749ff3e 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -24,7 +24,8 @@ from DIRAC import S_OK, S_ERROR from DIRAC.Core.Utilities import ThreadSafe from DIRAC.Resources.IdProvider.IdProvider import IdProvider -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getGroupOption, getAllGroups +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getVOMSRoleGroupMapping, getGroupOption, + getAllGroups, wrapIDAsDN) from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token __RCSID__ = "$Id$" @@ -165,17 +166,25 @@ def verifyToken(self, accessToken=None, jwks=None): return S_ERROR(repr(e)) @gRefreshToken - def refreshToken(self, refresh_token=None): + def refreshToken(self, refresh_token=None, group=None, **kwargs): """ Refresh token :param str token: refresh_token + :param str group: DIRAC group :return: dict """ + if group: + # If group set add group scopes to request + result = self.getGroupScopes(group) + if not result['OK']: + return result + kwargs.update(dict(scope=list_to_scope(result['Value']))) + if not refresh_token: refresh_token = self.token.get('refresh_token') try: - token = self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token) + token = self.refresh_token(self.get_metadata('token_endpoint'), refresh_token=refresh_token, **kwargs) return S_OK(OAuth2Token(dict(token))) except Exception as e: self.log.exception(e) @@ -267,7 +276,7 @@ def parseBasic(self, claimDict): self.log.debug('Token payload:', pprint.pformat(claimDict)) credDict = {} credDict['ID'] = claimDict['sub'] - credDict['DN'] = '/O=DIRAC/CN=%s' % credDict['ID'] + credDict['DN'] = wrapIDAsDN(credDict['ID']) if claimDict.get('scope'): self.log.debug('Search groups for %s scope.' % claimDict['scope']) credDict['DIRACGroups'] = self.getScopeGroups(claimDict['scope']) @@ -320,8 +329,11 @@ def deviceAuthorization(self, group=None): response = result['Value'] # Notify user to go to authorization endpoint - showURL = 'Use next link to continue, your user code is "%s"\n%s' % (response['user_code'], - response['verification_uri']) + if response.get('verification_uri_complete'): + showURL = 'Use next link to continue"\n%s' % response['verification_uri_complete'] + else: + showURL = 'Use next link to continue, your user code is "%s"\n%s' % (response['user_code'], + response['verification_uri']) self.log.notice(showURL) try: return self.waitFinalStatusOfDeviceCodeAuthorizationFlow(response['device_code']) diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 02f8991fc83..4be042d4f5c 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -24,74 +24,70 @@ 'setup': 'setup', 'group': 'my_group'} -exp_payload = payload.copy() -exp_payload['iat'] = int(time.time()) - 10 -exp_payload['exp'] = int(time.time()) - 10 - DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), expires_at=int(time.time()) + 3600) New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), - issued_at=int(time.time()), expires_in=int(time.time()) + 3600) -Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), - refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), - expires_at=int(time.time()) - 10) - -def test_cryptToken(): - """ Try to encrypt/decrypt refresh token +def test_RefreshToken(): + """ Try to revoke/save/get refresh tokens """ - data = dict(client_id='clientID', provider='provider', expires_at=DToken['expires_at']) - result = db.encryptRefreshToken(DToken.copy(), data.copy()) - assert result['OK'], result['Message'] - assert result['Value']['refresh_token'] != DToken['refresh_token'] - - result = db.decryptRefreshToken({'refresh_token': result['Value']['refresh_token']}) - assert result['OK'], result['Message'] - assert result['Value']['refresh_token'] == DToken['refresh_token'] - for k in data: - assert result['Value'][k] == data[k] - + preset_jti = '123' -def test_Token(): - """ Try to revoke/save/get tokens - """ - # Remove all tokens - result = db.removeTokens() + # Remove refresh token + result = db.revokeRefreshToken(preset_jti) assert result['OK'], result['Message'] # Store tokens - result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC') + result = db.storeRefreshToken(DToken.copy(), preset_jti) assert result['OK'], result['Message'] - assert result['Value'] == [] + assert result['Value']['jti'] == preset_jti + assert result['Value']['iat'] <= int(time.time()) + + result = db.storeRefreshToken(New_DToken.copy()) + assert result['OK'], result['Message'] + assert result['Value']['jti'] + assert result['Value']['iat'] <= int(time.time()) - # Expired token - # result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC') - # assert not result['OK'] + token_id = result['Value']['jti'] + issued_at = result['Value']['iat'] # Check token - result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + result = db.getCredentialByRefreshToken(preset_jti) assert result['OK'], result['Message'] + assert result['Value']['jti'] == preset_jti assert result['Value']['access_token'] == DToken['access_token'] assert result['Value']['refresh_token'] == DToken['refresh_token'] - # Store new tokens - result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC') + result = db.getCredentialByRefreshToken(token_id) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == token_id + assert int(result['Value']['issued_at']) == issued_at + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] + + # Check token after request + for jti in [token_id, preset_jti]: + result = db.getCredentialByRefreshToken(jti) + assert result['OK'], result['Message'] + assert not result['Value'] + + # Renew tokens + result = db.storeRefreshToken(New_DToken.copy(), token_id) + assert result['OK'], result['Message'] + + # Revoke token + result = db.revokeRefreshToken(token_id) assert result['OK'], result['Message'] - # Must return old tokens - assert len(result['Value']) == 1 - assert result['Value'][0]['access_token'] == DToken['access_token'] - assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] # Check token - result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + result = db.getCredentialByRefreshToken(token_id) assert result['OK'], result['Message'] - assert result['Value']['access_token'] == New_DToken['access_token'] - assert result['Value']['refresh_token'] == New_DToken['refresh_token'] + assert not result['Value'] def test_keys(): From ff2792f028b6e118376e14a8e004ba9d4cb8c8dd Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 20:02:19 +0200 Subject: [PATCH 057/178] add DIRAC_USE_ACCESS_TOKEN description --- .../environment_variable_configuration.rst | 16 +++++++++++----- .../Tornado/Client/private/TornadoBaseClient.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst index d93feefb312..7a613c9a605 100644 --- a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst +++ b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst @@ -21,9 +21,6 @@ DIRAC_DEPRECATED_FAIL If set, the use of functions or objects that are marked ``@deprecated`` will fail. Useful for example in continuous integration tests against future versions of DIRAC -DIRAC_FEWER_CFG_LOCKS - If ``true`` or ``yes`` or ``on`` or ``1`` or ``y`` or ``t``, DIRAC will reduce the number of locks used when accessing the CS for better performance (default, ``no``). - DIRAC_GFAL_GRIDFTP_SESSION_REUSE If set to ``true`` or ``yes`` the GRIDFT SESSION RESUSE option will be set to True, should be set on server installations. See the information in the :ref:`resourcesStorageElement` page. @@ -50,6 +47,12 @@ DIRAC_M2CRYPTO_SSL_METHODS DIRAC_NO_CFG If set to anything, cfg files on the command line must be passed to the command using the --cfg option. +DIRAC_USE_NEWTHREADPOOL + If this environment is set to ``true`` or ``yes``, the concurrent.futures.ThreadPoolExecutor will be used (default=Yes) + +DIRAC_USE_M2CRYPTO + If anything else than ``true`` or ``yes`` (default) DIRAC will revert back to using pyGSI instead of m2crypto for handling certificates, proxies, etc. + DIRACSYSCONFIG If set, its value should be (the full locations on the file system of) one of more DIRAC cfg file(s) (comma separated), whose content will be used for the DIRAC configuration (see :ref:`dirac-cs-structure`) @@ -67,7 +70,10 @@ X509_VOMSES Must be set to point to a folder containing VOMSES information. See :ref:`multi_vo_dirac` BEARER_TOKEN - If the environment variable is set, then the value is taken to be the token contents + If the environment variable is set, then the value is taken to be the token contents(https://doi.org/10.5281/zenodo.3937438). BEARER_TOKEN_FILE - If the environment variable is set, then its value is interpreted as a filename. The contents of thespecified file are taken to be the token contents. + If the environment variable is set, then its value is interpreted as a filename. The content of the specified file is used as token string(https://doi.org/10.5281/zenodo.3937438). + +DIRAC_USE_ACCESS_TOKEN + If this environment is set to ``true`` or ``yes``, the concurrent.futures.ThreadPoolExecutor will be used (default=false) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index d60e7a6cc23..7fcb02ac244 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -237,7 +237,7 @@ def __discoverCredentialsToUse(self): self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: - self.__useAccessToken = os.environ['DIRAC_USE_ACCESS_TOKEN'] + self.__useAccessToken = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ['yes', 'true'] if self.__useAccessToken: result = IdProviderFactory().getIdProvider('DIRACCLI') From a06b0358f31f015d7b583ba001968d551a690223 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 16 Jun 2021 20:28:31 +0200 Subject: [PATCH 058/178] fix --- .../FrameworkSystem/private/authorization/utils/Clients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index 19829524257..b7836dd1307 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -54,7 +54,7 @@ def getDIACClientByID(clientID): if clientID == clients[cli]['client_id']: gLogger.debug('Found %s client:\n' % cli, pprint.pformat(clients[cli])) return Client(clients[cli]) - return Client(data) + return Client(data) return None From dd1a921c9c0d48865682b47681ae1c78d55aaf04 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 17 Jun 2021 02:14:37 +0200 Subject: [PATCH 059/178] fix issues --- .../Client/Helpers/Registry.py | 4 +- .../Client/private/TornadoBaseClient.py | 1 + .../Core/Tornado/Server/HandlerManager.py | 9 +-- .../Server/private/BaseRequestHandler.py | 2 +- .../Service/TokenManagerHandler.py | 7 +-- .../private/authorization/AuthServer.py | 6 +- .../authorization/grants/RevokeToken.py | 23 +++++--- .../private/authorization/utils/Clients.py | 3 +- .../private/authorization/utils/Tokens.py | 43 +++++++++----- .../FrameworkSystem/scripts/dirac_login.py | 33 +++++++++-- .../FrameworkSystem/scripts/dirac_logout.py | 56 +++++++++---------- .../Resources/IdProvider/CheckInIdProvider.py | 2 +- .../Resources/IdProvider/IdProviderFactory.py | 13 +---- src/DIRAC/Resources/IdProvider/Utilities.py | 9 ++- tests/Integration/Framework/Test_AuthDB.py | 2 +- 15 files changed, 120 insertions(+), 93 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index 7d462612324..fc2a8fb356e 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -713,7 +713,7 @@ def wrapIDAsDN(userID): """ Wrap user ID as user DN :param str userID: user ID - + :return: str """ return '/O=DIRAC/CN=%s' % userID @@ -723,7 +723,7 @@ def getIDFromDN(userDN): """ Parse user ID from user DN :param str userDN: user DN - + :return: str """ return userDN.strip('/O=DIRAC/CN=') diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 7fcb02ac244..4e7046a1b29 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -520,6 +520,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): # Use access token? elif self.__useAccessToken: + # Read token from token environ variable or from token file result = getLocalTokenDict() if not result['OK']: return result diff --git a/src/DIRAC/Core/Tornado/Server/HandlerManager.py b/src/DIRAC/Core/Tornado/Server/HandlerManager.py index cf5d56fc386..ce28b4a31b4 100644 --- a/src/DIRAC/Core/Tornado/Server/HandlerManager.py +++ b/src/DIRAC/Core/Tornado/Server/HandlerManager.py @@ -11,6 +11,7 @@ import inspect from six import string_types +from six.moves.urllib.parse import urlparse from tornado.web import url as TornadoURL, RequestHandler from DIRAC import gConfig, gLogger, S_ERROR, S_OK @@ -200,16 +201,16 @@ def loadServicesHandlers(self, services=None): self.__addHandler(module['loadName'], module['classObj'], url, ports.get(module['modName'])) return S_OK() - def __extractPorts(self, urls): - """ Extract ports from urls + def __extractPorts(self, serviceURIs): + """ Extract ports from serviceURIs - :param list urls: urls that can contain port, .e.g:: System/Service:port + :param list serviceURIs: list of uri that can contain port, .e.g:: System/Service:port :return: (dict, list) """ portMapping = {} newURLs = [] - for _url in urls: + for _url in serviceURIs: if ':' in _url: urlTuple = _url.split(':') if urlTuple[0] not in portMapping: diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index f95e692621e..6bc4e1f9b26 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -489,7 +489,7 @@ def _finishFuture(self, retVal): # Here it is safe to write back to the client, because we are not in a thread anymore # If you need to end the method using tornado methods, outside the thread, - # you need to define the finish_ method. + # you need to define the finish_ method. # This method will be started after __executeMethod is completed. try: finishFunc = eval('self.finish_%s' % self.method) diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py index 9867f7a72e6..d661cb065e2 100644 --- a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -1,4 +1,4 @@ -""" TokenManagement service +""" TokenManager service .. literalinclude:: ../ConfigTemplate.cfg :start-after: ##BEGIN TokenManager: @@ -115,7 +115,7 @@ def __checkProperties(self, requestedUserDN, requestedUserGroup): return S_OK(True) if Properties.PRIVATE_LIMITED_DELEGATION in credDict['properties']: if credDict['DN'] != requestedUserDN: - return S_ERROR("You are not allowed to download any proxy") + return S_ERROR("You are not allowed to download any token") if Properties.PRIVATE_LIMITED_DELEGATION not in Registry.getPropertiesForGroup(requestedUserGroup): return S_ERROR("You can't download tokens for that group") return S_OK(True) @@ -170,8 +170,7 @@ def export_deleteToken(self, userDN): if Properties.PROXY_MANAGEMENT not in credDict['properties']: if userDN != credDict['DN']: return S_ERROR("You aren't allowed!") - retVal = self.__tokenDB.removeToken(user_id=Registry.getIDFromDN(dn)) + retVal = self.__tokenDB.removeToken(user_id=Registry.getIDFromDN(userDN)) if not retVal['OK']: return retVal - self.__tokenDB.logAction("delete proxy", credDict['DN'], credDict['group'], userDN, userGroup) return S_OK() diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index bb74f45ed69..ae30e654077 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -46,9 +46,9 @@ def collectMetadata(issuer=None): """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defines by IETF specification: https://datatracker.ietf.org/doc/html/rfc8414#section-2 - + :param str issuer: issuer to set - + :return: dict -- dictionary is the AuthorizationServerMetadata object in the same time """ result = getAuthorizationServerMetadata(issuer) @@ -189,7 +189,7 @@ def __signToken(self, payload): except Exception as e: self.log.exception(e) return S_ERROR(repr(e)) - + def registerRefreshToken(self, payload, token): """ Register refresh token to protect it from reuse diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py index c3f466c333b..5273d9b2b6b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py @@ -16,25 +16,32 @@ def query_token(self, token, token_type_hint, client): :param str token_type_hint: token type :param client: client - :return: str + :return: dict """ if token_type_hint == 'refresh_token': - result = self.server.db.decryptRefreshToken({'refresh_token': token}) + result = self.server.db.getJWKs() + if not result['OK']: + raise OAuth2Error(result['Message']) + jwks = result['Value'] + rtDict = jwt.decode(refresh_token, JsonWebKey.import_key_set(jwks)) + result = self.server.db.getCredentialByRefreshToken(rtDict['jti']) if not result['OK']: raise OAuth2Error(result['Message']) return result['Value'] - return token + return {token_type_hint: token} def revoke_token(self, token): """ Mark the give token as revoked. :param dict token: token dict """ - if isinstance(token, dict): - result = self.server.idps.getIdProvider(token['provider']) - else: - result = self.server.idps.getIdProviderForToken(token) + result = self.server.idps.getIdProviderForToken(token['access_token']) + if not result['OK']: + raise OAuth2Error(result['Message']) if result['OK']: - result = result['Value'].revokeToken(token) + for tokenType in token: + result = result['Value'].revokeToken(token[tokenType], tokenType) + if not result['OK']: + self.server.log.error(result['Message']) if not result['OK']: raise OAuth2Error(result['Message']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index b7836dd1307..ece4c74b5a0 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -39,7 +39,7 @@ def getDIRACClients(): else: clients[cli].update(confClients[cli]) return clients - + def getDIACClientByID(clientID): """ Search authorization client. @@ -54,7 +54,6 @@ def getDIACClientByID(clientID): if clientID == clients[cli]['client_id']: gLogger.debug('Found %s client:\n' % cli, pprint.pformat(clients[cli])) return Client(clients[cli]) - return Client(data) return None diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 397a68404a5..b3963d8aa78 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -20,15 +20,22 @@ from authlib.oauth2.rfc6749.wrappers import OAuth2Token as _OAuth2Token from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin +BEARER_TOKEN_ENV = 'BEARER_TOKEN' +BEARER_TOKEN_FILE_ENV = 'BEARER_TOKEN_FILE' -def getTokenLocation(): + +def getTokenFileLocation(fileName=None): """ Research token file location. Use the bearer token discovery protocol defined by the WLCG (https://doi.org/10.5281/zenodo.3937438) to find one. + :param str fileName: file name to dump to + :return: str """ - if os.environ.get('BEARER_TOKEN_FILE'): - return os.environ['BEARER_TOKEN_FILE'] + if fileName: + return fileName + if os.environ.get(BEARER_TOKEN_FILE_ENV): + return os.environ[BEARER_TOKEN_FILE_ENV] elif os.environ.get('XDG_RUNTIME_DIR'): return "%s/bt_u%s" % (os.environ['XDG_RUNTIME_DIR'], os.getuid()) else: @@ -39,15 +46,21 @@ def getLocalTokenDict(location=None): """ Search local token. Use the bearer token discovery protocol defined by the WLCG (https://doi.org/10.5281/zenodo.3937438) to find one. - :param str location: environ variable name or file path + :param str location: token file path :return: S_OK(dict)/S_ERROR() """ - env = (location if location and location.startswith('/') else None) or 'BEARER_TOKEN' - token = os.environ.get(env, "").strip() - if token: - return S_OK(OAuth2Token(token)) - return readTokenFromFile(location if location and location.startswith('/') else None) + result = readTokenFromEnv() + return result if result['OK'] and result['Value'] else readTokenFromFile(location) + + +def readTokenFromEnv(): + """ Read token from an environ variable + + :return: S_OK(dict or None) + """ + token = os.environ.get(BEARER_TOKEN_ENV, "").strip() + return S_OK(OAuth2Token(token) if token else None) def readTokenFromFile(fileName=None): @@ -55,15 +68,15 @@ def readTokenFromFile(fileName=None): :param str fileName: filename to read - :return: S_OK(dict)/S_ERROR() + :return: S_OK(dict or None)/S_ERROR() """ - location = fileName or getTokenLocation() + location = getTokenFileLocation(fileName) try: with open(location, 'rt') as f: token = f.read().strip() except IOError as e: return S_ERROR(DErrno.EOF, "Can't open %s token file.\n%s" % (location, repr(e))) - return S_OK(OAuth2Token(token)) + return S_OK(OAuth2Token(token) if token else None) def writeToTokenFile(tokenContents, fileName): @@ -74,7 +87,7 @@ def writeToTokenFile(tokenContents, fileName): :return: S_OK(str)/S_ERROR() """ - location = fileName or getTokenLocation() + location = getTokenFileLocation(fileName) try: with open(location, 'wt') as fd: fd.write(tokenContents) @@ -95,7 +108,7 @@ def writeTokenDictToTokenFile(tokenDict, fileName=None): :return: S_OK(str)/S_ERROR() """ - fileName = fileName or getTokenLocation() + fileName = getTokenFileLocation(fileName) if not isinstance(tokenDict, dict): return S_ERROR('Token is not a dictionary') return writeToTokenFile(json.dumps(tokenDict), fileName) @@ -174,7 +187,7 @@ def get_claim(self, claim, token_type='access_token'): :return: str """ return self.get_payload(token_type).get(claim) - + def dump_to_string(self): """ Dump token dictionary to sting diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index fb2ba44345f..526202bb414 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -24,7 +24,8 @@ from DIRAC.Core.Security.ProxyFile import writeToProxyFile from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import writeTokenDictToTokenFile, readTokenFromFile +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (writeTokenDictToTokenFile, readTokenFromFile, + getTokenFileLocation) __RCSID__ = "$Id$" @@ -32,6 +33,7 @@ class Params(object): def __init__(self): + self.provider = 'DIRACCLI' self.proxy = False self.group = None self.lifetime = None @@ -123,7 +125,7 @@ def doOAuthMagic(self): params = {} if self.issuer: params['issuer'] = self.issuer - result = IdProviderFactory().getIdProvider('DIRACCLI', **params) + result = IdProviderFactory().getIdProvider(self.provider, **params) if not result['OK']: return result idpObj = result['Value'] @@ -136,23 +138,42 @@ def doOAuthMagic(self): scope.append('lifetime:%s' % (int(self.lifetime) * 3600)) idpObj.scope = '+'.join(scope) if scope else None + tokenFile = getTokenFileLocation(self.tokenLoc) + # Submit Device authorisation flow result = idpObj.deviceAuthorization() if not result['OK']: return result if self.proxy: + # Save new proxy certificate result = writeToProxyFile(idpObj.token['proxy'].encode("UTF-8"), self.proxyLoc) if not result['OK']: return result gLogger.notice('Proxy is saved to %s.' % self.proxyLoc) else: - result = writeTokenDictToTokenFile(idpObj.token, self.tokenLoc) + # Revoke old tokens from token file + if os.path.isfile(tokenFile): + result = readTokenFromFile(tokenFile) + if not result['OK']: + gLogger.error(result['Message']) + elif result['Value']: + oldToken = result['Value'] + for tokenType in ['access_token', 'refresh_token']: + result = idpObj.revokeToken(oldToken[tokenType], tokenType) + if result['OK']: + gLogger.notice('%s is revoked from' % tokenType, tokenFile) + else: + gLogger.error(result['Message']) + + # Save new tokens to token file + result = writeTokenDictToTokenFile(idpObj.token, tokenFile) if not result['OK']: return result - self.tokenLoc = result['Value'] - gLogger.notice('Token is saved in %s.' % self.tokenLoc) + tokenFile = result['Value'] + gLogger.notice('New token is saved to %s.' % tokenFile) + # Try to get user information result = Script.enableCS() if not result['OK']: return S_ERROR("Cannot contact CS to get user list") @@ -164,7 +185,7 @@ def doOAuthMagic(self): return result['Message'] gLogger.notice(formatProxyInfoAsString(result['Value'])) else: - result = readTokenFromFile(self.tokenLoc) + result = readTokenFromFile(tokenFile) if not result['OK']: return result gLogger.notice(result['Value'].getInfoAsString()) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py index 93b5ade9ee5..1f6f8eb1d81 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py @@ -21,7 +21,8 @@ from DIRAC.Core.Base import Script from DIRAC.Core.Utilities.DIRACScript import DIRACScript from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import readTokenFromFile, getTokenLocation +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (readTokenFromFile, readTokenFromEnv, + getTokenFileLocation, BEARER_TOKEN_ENV) __RCSID__ = "$Id$" @@ -31,17 +32,7 @@ class Params(object): def __init__(self): self.provider = 'DIRACCLI' self.issuer = None - self.tokenLoc = None - - def setProvider(self, arg): - """ Set provider name - - :param str arg: provider - - :return: S_OK() - """ - self.provider = arg - return S_OK() + self.tokenFileLoc = None def setIssuer(self, arg): """ Set issuer @@ -60,16 +51,11 @@ def setTokenFile(self, arg): :return: S_OK() """ - self.tokenLoc = arg + self.tokenFileLoc = arg return S_OK() def registerCLISwitches(self): """ Register CLI switches """ - Script.registerSwitch( - "O:", - "provider=", - "set identity provider", - self.setProvider) Script.registerSwitch( "I:", "issuer=", @@ -86,6 +72,7 @@ def doOAuthMagic(self): :return: S_OK()/S_ERROR() """ + tokens = [] params = {} if self.issuer: params['issuer'] = self.issuer @@ -93,19 +80,26 @@ def doOAuthMagic(self): if not result['OK']: return result idpObj = result['Value'] - self.tokenLoc = self.tokenLoc or getTokenLocation() - result = readTokenFromFile(self.tokenLoc) - if not result['OK']: - return result - token = result['Value'] - # Revoke token - for tokenType in ['access_token', 'refresh_token']: - if token.get(tokenType): - result = idpObj.revokeToken(token[tokenType], tokenType) - if not result['OK']: - gLogger.error(result['Message']) - os.unlink(self.tokenLoc) - gLogger.notice('Token is removed from %s.' % self.tokenLoc) + tokenFile = getTokenFileLocation(self.tokenFileLoc) + + # Try to find token in environ and in a token file and revoke it + for result, location in [(readTokenFromEnv(), BEARER_TOKEN_ENV), + (readTokenFromFile(tokenFile), tokenFile)]: + if not result['OK']: + gLogger.error(result['Message']) + elif result['Value']: + token = result['Value'] + for tokenType in ['access_token', 'refresh_token']: + result = idpObj.revokeToken(token[tokenType], tokenType) + if result['OK']: + gLogger.notice('%s is revoked from' % tokenType, location) + else: + gLogger.error(result['Message']) + + # After remove token file + if os.path.isfile(tokenFile): + os.unlink(tokenFile) + gLogger.notice('%s token file is removed.' % tokenFile) return S_OK() diff --git a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py index a8bf1c8b8cb..dcf19c8bb3d 100644 --- a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py @@ -11,4 +11,4 @@ class CheckInIdProvider(OAuth2IdProvider): - pass \ No newline at end of file + pass diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index a9f3fc9ea4f..526783423a4 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -59,16 +59,9 @@ def getIdProviderForToken(self, token): options=dict(verify_signature=False, verify_aud=False))['iss'].strip('/') result = getSettingsNamesForIdPIssuer(issuer) - if result['OK']: - return self.getIdProvider(result['Value'][0]) - - _result = getAuthorizationServerMetadata() - if not _result['OK']: - return _result - if issuer == _result['Value'].get('issuer', '').strip('/'): - return self.getIdProvider(DEFAULT_CLIENTS.keys()[0]) - - return result + if not result['OK']: + return result + return self.getIdProvider(result['Value']) def getIdProvider(self, name, **kwargs): """ This method returns a IdProvider instance corresponding to the supplied diff --git a/src/DIRAC/Resources/IdProvider/Utilities.py b/src/DIRAC/Resources/IdProvider/Utilities.py index d363c4d76a6..05822524ead 100644 --- a/src/DIRAC/Resources/IdProvider/Utilities.py +++ b/src/DIRAC/Resources/IdProvider/Utilities.py @@ -10,21 +10,20 @@ def getSettingsNamesForIdPIssuer(issuer): - """ Get identity providers for issuer + """ Get identity provider for issuer :param str issuer: issuer - :return: S_OK(list)/S_ERROR() + :return: S_OK(str)/S_ERROR() """ - names = [] result = getProvidersForInstance('Id') if not result['OK']: return result for name in result['Value']: nameIssuer = gConfig.getValue('/Resources/IdProviders/%s/issuer' % name) if nameIssuer and issuer.strip('/') == nameIssuer.strip('/'): - names.append(name) - return S_OK(names) if names else S_ERROR('Not found provider with %s issuer.' % issuer) + return S_OK(name) + return S_ERROR('Not found provider with %s issuer.' % issuer) def getSettingsNamesForClientID(clientID): diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 4be042d4f5c..6fd498711cd 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -47,7 +47,7 @@ def test_RefreshToken(): assert result['OK'], result['Message'] assert result['Value']['jti'] == preset_jti assert result['Value']['iat'] <= int(time.time()) - + result = db.storeRefreshToken(New_DToken.copy()) assert result['OK'], result['Message'] assert result['Value']['jti'] From 634658dda092d6af0b7a7913b4e22e045de29930 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 17 Jun 2021 10:15:34 +0200 Subject: [PATCH 060/178] fix bug --- .../private/authorization/grants/RevokeToken.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py index 5273d9b2b6b..a7c4d9ee781 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +from authlib.jose import JsonWebKey, jwt from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint @@ -23,7 +24,7 @@ def query_token(self, token, token_type_hint, client): if not result['OK']: raise OAuth2Error(result['Message']) jwks = result['Value'] - rtDict = jwt.decode(refresh_token, JsonWebKey.import_key_set(jwks)) + rtDict = jwt.decode(token, JsonWebKey.import_key_set(jwks)) result = self.server.db.getCredentialByRefreshToken(rtDict['jti']) if not result['OK']: raise OAuth2Error(result['Message']) From 61fe0f04aa109afc1e2136ad0d5f3313d2adfdf0 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 17 Jun 2021 16:55:49 +0200 Subject: [PATCH 061/178] fix white space --- src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index b3963d8aa78..dfecfae23b3 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -52,7 +52,7 @@ def getLocalTokenDict(location=None): """ result = readTokenFromEnv() return result if result['OK'] and result['Value'] else readTokenFromFile(location) - + def readTokenFromEnv(): """ Read token from an environ variable From fdfa65424a7586221f1888b90d07e83a36a21ff0 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 17 Jun 2021 17:31:54 +0200 Subject: [PATCH 062/178] fix rebase --- .../environment_variable_configuration.rst | 9 +++------ src/DIRAC/Core/Tornado/Server/HandlerManager.py | 1 - src/DIRAC/Core/scripts/install_full.cfg | 2 ++ tests/Jenkins/dirac_ci.sh | 3 --- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst index 7a613c9a605..03ab7347cef 100644 --- a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst +++ b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst @@ -21,6 +21,9 @@ DIRAC_DEPRECATED_FAIL If set, the use of functions or objects that are marked ``@deprecated`` will fail. Useful for example in continuous integration tests against future versions of DIRAC +DIRAC_FEWER_CFG_LOCKS + If ``true`` or ``yes`` or ``on`` or ``1`` or ``y`` or ``t``, DIRAC will reduce the number of locks used when accessing the CS for better performance (default, ``no``). + DIRAC_GFAL_GRIDFTP_SESSION_REUSE If set to ``true`` or ``yes`` the GRIDFT SESSION RESUSE option will be set to True, should be set on server installations. See the information in the :ref:`resourcesStorageElement` page. @@ -47,12 +50,6 @@ DIRAC_M2CRYPTO_SSL_METHODS DIRAC_NO_CFG If set to anything, cfg files on the command line must be passed to the command using the --cfg option. -DIRAC_USE_NEWTHREADPOOL - If this environment is set to ``true`` or ``yes``, the concurrent.futures.ThreadPoolExecutor will be used (default=Yes) - -DIRAC_USE_M2CRYPTO - If anything else than ``true`` or ``yes`` (default) DIRAC will revert back to using pyGSI instead of m2crypto for handling certificates, proxies, etc. - DIRACSYSCONFIG If set, its value should be (the full locations on the file system of) one of more DIRAC cfg file(s) (comma separated), whose content will be used for the DIRAC configuration (see :ref:`dirac-cs-structure`) diff --git a/src/DIRAC/Core/Tornado/Server/HandlerManager.py b/src/DIRAC/Core/Tornado/Server/HandlerManager.py index ce28b4a31b4..23ca6601747 100644 --- a/src/DIRAC/Core/Tornado/Server/HandlerManager.py +++ b/src/DIRAC/Core/Tornado/Server/HandlerManager.py @@ -11,7 +11,6 @@ import inspect from six import string_types -from six.moves.urllib.parse import urlparse from tornado.web import url as TornadoURL, RequestHandler from DIRAC import gConfig, gLogger, S_ERROR, S_OK diff --git a/src/DIRAC/Core/scripts/install_full.cfg b/src/DIRAC/Core/scripts/install_full.cfg index 3410055c941..9286e003252 100755 --- a/src/DIRAC/Core/scripts/install_full.cfg +++ b/src/DIRAC/Core/scripts/install_full.cfg @@ -101,6 +101,7 @@ LocalInstallation Databases += FTSDB Databases += ComponentMonitoringDB Databases += ProxyDB + Databases += TokenDB Databases += AuthDB Databases += PilotAgentsDB Databases += AccountingDB @@ -121,6 +122,7 @@ LocalInstallation Services += Framework/SecurityLogging Services += Framework/UserProfileManager Services += Framework/ProxyManager + Services += Framework/TokenManager Services += Framework/Plotting Services += Framework/BundleDelivery Services += Monitoring/Monitoring diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index 58efed0009d..00d67d444f0 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -191,9 +191,6 @@ installSite() { echo "==> Completed installation" - # Hotfix to pass tests - pip install authlib==0.15.3 pyjwt==1.7.1 dominate - } From 017e186e9e1a53debfab262c4ee2d0ea9beeab7f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 17 Jun 2021 19:01:50 +0200 Subject: [PATCH 063/178] delete DIRACCLIIdProvider DIRACWebIdProvider --- .../IdProvider/DIRACCLIIdProvider.py | 20 ------------------ .../IdProvider/DIRACWebIdProvider.py | 21 ------------------- 2 files changed, 41 deletions(-) delete mode 100644 src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py delete mode 100644 src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py diff --git a/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py b/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py deleted file mode 100644 index 21ac3ca7d4c..00000000000 --- a/src/DIRAC/Resources/IdProvider/DIRACCLIIdProvider.py +++ /dev/null @@ -1,20 +0,0 @@ -""" IdProvider based on OAuth2 protocol -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider -from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata - -__RCSID__ = "$Id$" - - -class DIRACCLIIdProvider(OAuth2IdProvider): - - def fetch_metadata(self, url=None): - """ Fetch metada - """ - self.metadata.update(collectMetadata(self.metadata['issuer'])) - if url: - return self.get(url, withhold_token=True).json() diff --git a/src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py b/src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py deleted file mode 100644 index 2753310e58d..00000000000 --- a/src/DIRAC/Resources/IdProvider/DIRACWebIdProvider.py +++ /dev/null @@ -1,21 +0,0 @@ -""" IdProvider based on OAuth2 protocol -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider -from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata -from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS - -__RCSID__ = "$Id$" - - -class DIRACWebIdProvider(OAuth2IdProvider): - - DEFAULT_METADATA = DEFAULT_CLIENTS['DIRACWeb'] - - def fetch_metadata(self): - """ Fetch metada - """ - self.metadata.update(collectMetadata(self.metadata['issuer'])) From ff08c15097c4fab6084b24ba1590406b4b34bb84 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 22 Jun 2021 20:36:42 +0200 Subject: [PATCH 064/178] fix py3 things --- environment-py3.yml | 4 +- environment.yml | 4 +- requirements.txt | 4 +- setup.cfg | 4 +- .../Core/Tornado/Server/HandlerManager.py | 2 +- src/DIRAC/Core/Tornado/Server/TornadoREST.py | 3 +- .../Server/private/BaseRequestHandler.py | 74 +++++++------------ src/DIRAC/FrameworkSystem/API/AuthHandler.py | 27 +++---- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 38 ++-------- src/DIRAC/FrameworkSystem/DB/TokenDB.py | 2 +- .../private/authorization/AuthServer.py | 61 ++++++++++----- .../authorization/grants/AuthorizationCode.py | 3 +- .../authorization/grants/DeviceFlow.py | 2 +- .../authorization/grants/RefreshToken.py | 5 +- .../authorization/grants/RevokeToken.py | 6 +- .../private/authorization/utils/Clients.py | 16 ---- .../private/authorization/utils/Requests.py | 7 +- .../private/authorization/utils/Tokens.py | 29 ++++++-- tests/Integration/Framework/Test_AuthDB.py | 34 ++++----- tests/Integration/Framework/Test_TokenDB.py | 12 +-- 20 files changed, 160 insertions(+), 177 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index 43be46f9f56..87d335de478 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,8 +89,8 @@ dependencies: - typing >=3.6.6 - pyyaml # OAuth2 - - authlib <=0.15.3 - - pyjwt <=1.7.1 + - authlib >=1.0.0 + - pyjwt >=2.1.0 - dominate - pip: # This is a fork of tornado with a patch to allow for configurable iostream diff --git a/environment.yml b/environment.yml index 63c52b2ad65..02d8d146250 100644 --- a/environment.yml +++ b/environment.yml @@ -73,8 +73,8 @@ dependencies: - openssl <1.1 - selectors2 # OAuth2 - - authlib <=0.15.3 - - pyjwt <=1.7.1 + - authlib >=1.0.0 + - pyjwt >=2.1.0 - dominate - pip: - diraccfg diff --git a/requirements.txt b/requirements.txt index ed1477c0d8a..f32f12f505a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,6 @@ ldap3 # setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 setuptools_scm<6.0 # OAuth2 -authlib <=0.15.3 -pyjwt <=1.7.1 +authlib >=1.0.0 +pyjwt >=2.1.0 dominate \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index b76f5702f78..97564191d9c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,8 +55,8 @@ install_requires = six sqlalchemy subprocess32 - authlib <=0.15.3 - pyjwt <=1.7.1 + authlib >=1.0.0 + pyjwt >=2.1.0 dominate zip_safe = False include_package_data = True diff --git a/src/DIRAC/Core/Tornado/Server/HandlerManager.py b/src/DIRAC/Core/Tornado/Server/HandlerManager.py index 23ca6601747..3a1fba3cd30 100644 --- a/src/DIRAC/Core/Tornado/Server/HandlerManager.py +++ b/src/DIRAC/Core/Tornado/Server/HandlerManager.py @@ -257,7 +257,7 @@ def loadEndpointsHandlers(self, endpoints=None): urls = [] # Look for methods that are exported for mName, mObj in inspect.getmembers(handler): - if inspect.ismethod(mObj) and mName.find(handler.METHOD_PREFIX) == 0: + if inspect.isroutine(mObj) and mName.find(handler.METHOD_PREFIX) == 0: methodName = mName[len(handler.METHOD_PREFIX):] args = getattr(handler, 'path_%s' % methodName, []) gLogger.debug(" - Route %s/%s -> %s %s" % (handler.LOCATION, methodName, module['loadName'], mName)) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoREST.py b/src/DIRAC/Core/Tornado/Server/TornadoREST.py index 78fd024e266..5857d8281d8 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoREST.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoREST.py @@ -9,11 +9,12 @@ __RCSID__ = "$Id$" +from functools import partial + import tornado.ioloop from tornado import gen from tornado.web import HTTPError from tornado.ioloop import IOLoop -from six.moves import http_client import DIRAC diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 6bc4e1f9b26..2e0c5ed39dd 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -77,7 +77,7 @@ def _getServiceInfo(cls, serviceName, request): @classmethod def _getServiceAuthSection(cls, serviceName): - ''' Search service auth section. + ''' Search service "Authorization" configuration section. ''' return "%s/Authorization" % PathFinder.getServiceSection(serviceName) @@ -104,9 +104,8 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ # in __executeMethod. This is because these methods are not threadsafe # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes # However, we can still rely on instance attributes to store what should - # be sent back (reminder: there is an instance - # of this class created for each request) - retVal = yield IOLoop.current().run_in_executor(self._prepareExecutor(args)) + # be sent back (reminder: there is an instance of this class created for each request) + retVal = yield IOLoop.current().run_in_executor(*self._prepareExecutor(args)) # retVal is :py:class:`tornado.concurrent.Future` self._finishFuture(retVal) @@ -222,9 +221,10 @@ def __loadIdPs(cls): if result['OK']: for providerName in result['Value']: result = cls._idps.getIdProvider(providerName) - if not result['OK']: - gLogger.exception(result['Message']) - cls._idp[result['Value'].issuer.strip('/')] = result['Value'] + if result['OK']: + cls._idp[result['Value'].issuer.strip('/')] = result['Value'] + else: + gLogger.error(result['Message']) @classmethod def __initializeService(cls, request): @@ -361,23 +361,16 @@ def _getMethodAuthProps(self): :return: list """ - try: - return getattr(self, 'auth_' + self.method) - except AttributeError: - if self.AUTH_PROPS and not isinstance(self.AUTH_PROPS, (list, tuple)): - self.AUTH_PROPS = [p.strip() for p in self.AUTH_PROPS.split(",") if p.strip()] - return self.AUTH_PROPS + if self.AUTH_PROPS and not isinstance(self.AUTH_PROPS, (list, tuple)): + self.AUTH_PROPS = [p.strip() for p in self.AUTH_PROPS.split(",") if p.strip()] + return getattr(self, 'auth_' + self.method, self.AUTH_PROPS) def _getMethod(self): """ Get method function to call. :return: function """ - try: - method = getattr(self, '%s%s' % (self.METHOD_PREFIX, self.method)) - except AttributeError as e: - sLog.error("Invalid method", self.method) - raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) + method = getattr(self, '%s%s' % (self.METHOD_PREFIX, self.method), None) if not callable(method): sLog.error("Invalid method", self.method) raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) @@ -416,7 +409,7 @@ def _prepare(self): sLog.error( "Error gathering credentials ", "%s; path %s" % (self.getRemoteAddress(), self.request.path)) - raise HTTPError(status_code=http_client.UNAUTHORIZED) + raise HTTPError(http_client.UNAUTHORIZED, str(e)) # Check whether we are authorized to perform the query # Note that performing the authQuery modifies the credDict... @@ -432,9 +425,8 @@ def _prepare(self): "Unauthorized access", "Identity %s; path %s; %s" % (self.srv_getFormattedRemoteCredentials(), self.request.path, extraInfo)) - raise HTTPError(status_code=http_client.UNAUTHORIZED) + raise HTTPError(http_client.UNAUTHORIZED) - @gen.coroutine def __executeMethod(self, targetMethod, args): """ Execute the method called, this method is ran in an executor @@ -462,12 +454,10 @@ def __executeMethod(self, targetMethod, args): # Execute try: self.initializeRequest() - retVal = targetMethod(*args) + return targetMethod(*args) except Exception as e: # pylint: disable=broad-except sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) - raise HTTPError(http_client.INTERNAL_SERVER_ERROR) - - return retVal + raise e if isinstance(e, HTTPError) else HTTPError(http_client.INTERNAL_SERVER_ERROR, str(e)) def _prepareExecutor(self, args): """ Preparation of necessary arguments for the `__executeMethod` method @@ -483,24 +473,19 @@ def _finishFuture(self, retVal): :param object retVal: tornado.concurrent.Future """ - # Wait result of a Future object - self.result = retVal.result() + self.result = retVal # Here it is safe to write back to the client, because we are not in a thread anymore # If you need to end the method using tornado methods, outside the thread, # you need to define the finish_ method. # This method will be started after __executeMethod is completed. - try: - finishFunc = eval('self.finish_%s' % self.method) - except (NameError, AttributeError): - finishFunc = None - + finishFunc = getattr(self, 'finish_%s' % self.method, None) if callable(finishFunc): finishFunc() # In case nothing is returned - elif self.result is None: + elif retVal is None: self.finish() # If set to true, do not JEncode the return of the RPC call @@ -509,16 +494,16 @@ def _finishFuture(self, retVal): elif self.get_argument('rawContent', default=False): # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt self.set_header("Content-Type", "application/octet-stream") - self.finish(self.result) + self.finish(retVal) # Return simple text or html - elif isinstance(self.result, string_types): - self.finish(self.result) + elif isinstance(retVal, string_types): + self.finish(retVal) # JSON else: self.set_header("Content-Type", "application/json") - self.finish(encode(self.result)) + self.finish(encode(retVal)) def on_finish(self): """ @@ -549,16 +534,13 @@ def _gatherPeerCredentials(self, grants=None): (not a DIRAC structure !) """ err = [] - result = None - - grants = grants or self.USE_AUTHZ_GRANTS - - if not grants: - raise Exception('USE_AUTHZ_GRANTS is not defined.') - for a in grants: - grant = a.upper() - grantFunc = getattr(self, '_authz%s' % grant) + # At least some authorization method must be defined, if nothing is defined, + # the authorization will go through the `_authzVISITOR` method and + # everyone will have access as anonymous@visitor + for grant in (grants or self.USE_AUTHZ_GRANTS or 'VISITOR'): + grant = grant.upper() + grantFunc = getattr(self, '_authz%s' % grant, None) if not callable(grantFunc): raise Exception('%s authentication type is not supported.' % grant) result = grantFunc() diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 5f2cf697021..b02b09ee607 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -106,20 +106,20 @@ def _finishFuture(self, retVal): :param object retVal: tornado.concurrent.Future """ - self.result = retVal.result() + self.result = retVal # Is it S_OK or S_ERROR? - r = self.result + r = retVal if isinstance(r, dict) and isinstance(r.get('OK'), bool) and ('Value' if r['OK'] else 'Message') in r: - if not self.result['OK']: + if not retVal['OK']: # S_ERROR is interpreted in the OAuth2 error format. self.set_status(400) - self.write({'error': 'server_error', 'description': self.result['Message']}) + self.write({'error': 'server_error', 'description': retVal['Message']}) self.clear_cookie('auth_session') - self.log.error('%s\n' % self.result['Message'], ''.join(self.result['CallStack'])) + self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) else: # Successful responses and OAuth2 errors are processed here - status_code, headers, payload, new_session, error = self.result['Value'][0] + status_code, headers, payload, new_session, error = retVal['Value'][0] if status_code: self.set_status(status_code) if headers: @@ -131,7 +131,7 @@ def _finishFuture(self, retVal): self.set_secure_cookie('auth_session', json.dumps(new_session), secure=True, httponly=True) if error: self.clear_cookie('auth_session') - for method, args_kwargs in self.result['Value'][1].items(): + for method, args_kwargs in retVal['Value'][1].items(): eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) self.finish() else: @@ -204,7 +204,8 @@ def web_jwk(self): ] } """ - return self.server.db.getJWKs().get('Value', {}) + result = self.server.db.getKeySet() + return result['Value'].as_dict() if result['OK'] else {} def web_revoke(self): """ Revocation endpoint @@ -406,9 +407,10 @@ def web_redirect(self): state = self.get_argument('state') # Try to catch errors - if self.get_argument('error', None): - error = OAuth2Error(error=self.get_argument('error'), description=self.get_argument('error_description', '')) - return self.server.handle_error_response(state, error) + error = self.get_argument('error', None) + if error: + return self.server.handle_error_response( + state, OAuth2Error(error=error, description=self.get_argument('error_description', ''))) # Check current auth session that was initiated for the selected external identity provider try: @@ -422,8 +424,7 @@ def web_redirect(self): if not sessionWithExtIdP.get('authed'): # Parse result of the second authentication flow - self.log.info('%s session, parsing authorization response:\n' % state, - '\n'.join([self.request.uri, self.request.query, self.request.body, str(self.request.headers)])) + self.log.info('%s session, parsing authorization response:\n' % state, self.request.uri) result = self.server.parseIdPAuthorizationResponse(self.request, sessionWithExtIdP) if not result['OK']: diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 94b0a78bad2..8bcbd110633 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -17,8 +17,6 @@ from authlib.jose import KeySet, RSAKey from authlib.common.security import generate_token -from authlib.common.encoding import urlsafe_b64decode, urlsafe_b64encode, to_bytes, to_unicode, json_b64encode -from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin from DIRAC import S_OK, S_ERROR from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB @@ -169,47 +167,27 @@ def generateRSAKeys(self): :return: S_OK/S_ERROR """ key = RSAKey.generate_key(key_size=1024, is_private=True) - dictKey = dict(key=json.dumps(key.as_dict()), - expires_at=time.time() + (30 * 24 * 3600), - kid=KeySet([key]).as_dict()['keys'][0]['kid']) - + keyDict = dict(key=json.dumps(key.as_dict(True)), kid=key.thumbprint(), expires_at=time.time() + (30 * 24 * 3600)) session = self.session() try: - session.add(JWK(**dictKey)) + session.add(JWK(**keyDict)) except Exception as e: return self.__result(session, S_ERROR('Could not generate keys: %s' % e)) - return self.__result(session, S_OK(dictKey)) + return self.__result(session, S_OK(keyDict)) def getKeySet(self): """ Get key set :return: S_OK(obj)/S_ERROR() """ - keys = [] result = self.getActiveKeys() if result['OK'] and not result['Value']: result = self.generateRSAKeys() if result['OK']: - result = self.getActiveKeys() - if not result['OK']: - return result - for keyDict in result['Value']: - key = RSAKey.import_key(json.loads(keyDict['key'])) - keys.append(key) - return S_OK(KeySet(keys)) - - def getJWKs(self): - """ Get JWKs list - - :return: S_OK(dict)/S_ERROR() - """ - keys = [] - result = self.getKeySet() + result['Value'] = [result['Value']] if not result['OK']: return result - for k in result['Value'].as_dict()['keys']: - keys.append({'n': k['n'], "kty": k['kty'], "e": k['e'], "kid": k['kid']}) - return S_OK({'keys': keys}) + return S_OK(KeySet([RSAKey.import_key(json.loads(key['key'])) for key in result['Value']])) def getPrivateKey(self, kid=None): """ Get private key @@ -224,17 +202,17 @@ def getPrivateKey(self, kid=None): jwks = result['Value'] if kid: strkey = jwks[0]['key'] - return S_OK(dict(rsakey=RSAKey.import_key(json.loads(strkey)), kid=kid, strkey=strkey)) + return S_OK(RSAKey.import_key(json.loads(jwks[0]['key']))) newer = {} for jwk in jwks: - if jwk['expires_at'] > newer.get('expires_at', time.time() + (24 * 3600)): + if int(jwk['expires_at']) > int(newer.get('expires_at', time.time() + (24 * 3600))): newer = jwk if not newer.get('key'): result = self.generateRSAKeys() if not result['OK']: return result newer = result['Value'] - return S_OK(dict(rsakey=RSAKey.import_key(json.loads(newer['key'])), kid=newer['kid'], strkey=newer['key'])) + return S_OK(RSAKey.import_key(json.loads(newer['key']))) def getActiveKeys(self, kid=None): """ Get active keys diff --git a/src/DIRAC/FrameworkSystem/DB/TokenDB.py b/src/DIRAC/FrameworkSystem/DB/TokenDB.py index e59f0fd8a3d..6c6a86633c9 100644 --- a/src/DIRAC/FrameworkSystem/DB/TokenDB.py +++ b/src/DIRAC/FrameworkSystem/DB/TokenDB.py @@ -36,7 +36,6 @@ class Token(Model, OAuth2TokenMixin): kid = Column(String(255)) user_id = Column(String(255)) provider = Column(String(255)) - client_id = Column(String(255)) expires_at = Column(Integer, nullable=False, default=0) access_token = Column(Text, nullable=False) refresh_token = Column(Text, nullable=False) @@ -120,6 +119,7 @@ def updateToken(self, token, userID, provider, rt_expired_in): session.query(Token).filter(Token.user_id == userID).filter(Token.provider == provider)\ .filter(Token.access_token != token['access_token']).delete() except Exception as e: + self.log.exception(e) return self.__result(session, S_ERROR('Could not add Token: %s' % repr(e))) self.log.info('Token successfully added for %s user, %s provider' % (token['user_id'], token['provider'])) return self.__result(session, S_OK([self.__rowToDict(t) for t in oldTokens] if oldTokens else [])) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index ae30e654077..48c4ccacc6c 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -11,7 +11,7 @@ from dominate import document, tags as dom from tornado.template import Template -from authlib.jose import jwt +from authlib.jose import JsonWebKey, jwt from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer from authlib.oauth2.base import OAuth2Error from authlib.common.security import generate_token @@ -24,7 +24,7 @@ from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, DeviceCodeGrant) from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant -from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIACClientByID +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients, Client from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request from DIRAC import gLogger, S_OK, S_ERROR @@ -83,15 +83,16 @@ class AuthServer(_AuthorizationServer): def __init__(self): self.db = AuthDB() self.log = log + self.idps = IdProviderFactory() self.proxyCli = ProxyManagerClient() self.tokenCli = TokenManagerClient() - self.idps = IdProviderFactory() - # Privide two authlib methods query_client and save_token - _AuthorizationServer.__init__(self, query_client=getDIACClientByID, save_token=lambda x, y: None) - self.generate_token = self.generateProxyOrToken - self.config = {} self.metadata = collectMetadata() self.metadata.validate() + _AuthorizationServer.__init__(self, scopes_supported=self.metadata['scopes_supported']) + # Skip authlib method save_token + self.save_token = lambda x, y: None + self.send_signal = lambda *x, **y: None + self.generate_token = self.generateProxyOrToken # Register configured grants self.register_grant(RefreshTokenGrant) self.register_grant(DeviceCodeGrant) @@ -99,6 +100,23 @@ def __init__(self): self.register_endpoint(RevocationEndpoint) self.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)]) + def query_client(self, client_id): + """ Search authorization client. + + :param str clientID: client ID + + :return: object or None + """ + gLogger.debug('Try to query %s client' % client_id) + clients = getDIRACClients() + for cli in clients: + print(clients[cli]['client_id']) + print(client_id) + if client_id == clients[cli]['client_id']: + gLogger.debug('Found %s client:\n' % cli, pprint.pformat(clients[cli])) + return Client(clients[cli]) + return None + def addSession(self, session): self.db.addSession(session) @@ -182,10 +200,25 @@ def __signToken(self, payload): result = self.db.getPrivateKey() if not result['OK']: return result - key = result['Value']['rsakey'] - kid = result['Value']['kid'] + key = result['Value'] try: - return S_OK(jwt.encode(dict(alg='RS256', kid=kid), payload, key)) + return S_OK(jwt.encode(dict(alg='RS256', kid=key.thumbprint()), payload, key).decode('utf-8')) + except Exception as e: + self.log.exception(e) + return S_ERROR(repr(e)) + + def readToken(self, token): + """ Decode self token + + :param str token: token to decode + + :return: S_OK(dict)/S_ERROR() + """ + result = self.db.getKeySet() + if not result['OK']: + return result + try: + return S_OK(jwt.decode(token, JsonWebKey.import_key_set(result['Value'].as_dict()))) except Exception as e: self.log.exception(e) return S_ERROR(repr(e)) @@ -276,11 +309,6 @@ def parseIdPAuthorizationResponse(self, response, session): result = self.tokenCli.updateToken(idpObj.token, credDict['ID'], idpObj.name) return S_OK(credDict) if result['OK'] else result - def get_error_uris(self, request): - error_uris = self.config.get('error_uris') - if error_uris: - return dict(error_uris) - def create_oauth2_request(self, request, method_cls=OAuth2Request, use_json=False): self.log.debug('Create OAuth2 request', 'with json' if use_json else '') return createOAuth2Request(request, method_cls, use_json) @@ -295,8 +323,7 @@ def validate_requested_scope(self, scope, state=None): super(AuthServer, self).validate_requested_scope(extended_scope, state) def handle_error_response(self, request, error): - return self.handle_response(*error(translations=self.get_translations(request), - error_uris=self.get_error_uris(request)), error=True) + return self.handle_response(*error(self.get_error_uri(request, error)), error=True) def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, error=None, **actions): self.log.debug('Handle authorization response with %s status code:' % status_code, payload) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py index 88887485f36..b22ed0a152b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -120,5 +120,4 @@ def generate_authorization_code(self): result = self.server.db.getPrivateKey() if not result['OK']: raise OAuth2Error('Cannot generate authorization code: %s' % result['Message']) - key = result['Value']['rsakey'] - return jws.serialize_compact(protected, json_b64encode(dict(code)), key) + return jws.serialize_compact(protected, json_b64encode(dict(code)), result['Value']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index f00f9fd01f9..92730711c8c 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -125,7 +125,7 @@ def query_user_grant(self, user_code): raise OAuth2Error('Cannot found authorization session', result['Message']) return (result['Value']['user_id'], True) if result['Value'].get('username') != "None" else None - def should_slow_down(self, credential, now): + def should_slow_down(self, credential): """ The authorization request is still pending and polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests. """ diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py index 376382c91ec..5be8d22790f 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -22,11 +22,10 @@ def authenticate_refresh_token(self, refresh_token): :return: dict or None """ - result = self.server.db.getJWKs() + result = self.server.readToken(refresh_token) if not result['OK']: raise OAuth2Error(result['Message']) - jwks = result['Value'] - rtDict = jwt.decode(refresh_token, JsonWebKey.import_key_set(jwks)) + rtDict = result['Value'] result = self.server.db.getCredentialByRefreshToken(rtDict['jti']) if not result['OK']: raise OAuth2Error(result['Message']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py index a7c4d9ee781..729a4c7f718 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RevokeToken.py @@ -2,7 +2,6 @@ from __future__ import division from __future__ import print_function -from authlib.jose import JsonWebKey, jwt from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint @@ -20,11 +19,10 @@ def query_token(self, token, token_type_hint, client): :return: dict """ if token_type_hint == 'refresh_token': - result = self.server.db.getJWKs() + result = self.server.readToken(token) if not result['OK']: raise OAuth2Error(result['Message']) - jwks = result['Value'] - rtDict = jwt.decode(token, JsonWebKey.import_key_set(jwks)) + rtDict = result['Value'] result = self.server.db.getCredentialByRefreshToken(rtDict['jti']) if not result['OK']: raise OAuth2Error(result['Message']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index ece4c74b5a0..f3bbec682fd 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -41,22 +41,6 @@ def getDIRACClients(): return clients -def getDIACClientByID(clientID): - """ Search authorization client. - - :param str clientID: client ID - - :return: object or None - """ - gLogger.debug('Try to query %s client' % clientID) - clients = getDIRACClients() - for cli in clients: - if clientID == clients[cli]['client_id']: - gLogger.debug('Found %s client:\n' % cli, pprint.pformat(clients[cli])) - return Client(clients[cli]) - return None - - class Client(OAuth2ClientMixin): def __init__(self, params): diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index e5f38473fa8..5594eb20b91 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -58,9 +58,6 @@ def createOAuth2Request(request, method_cls=OAuth2Request, use_json=False): if isinstance(request, dict): return method_cls(request['method'], request['uri'], request.get('body'), request.get('headers')) if use_json: - body = json_decode(request.body) - else: - body = {} - for k, v in request.body_arguments.items(): - body[k] = ' '.join(v) + return method_cls(request.method, request.full_url(), json_decode(request.body), request.headers) + body = {k:request.body_arguments[k][-1].decode("utf-8") for k in request.body_arguments if request.body_arguments[k]} return method_cls(request.method, request.full_url(), body, request.headers) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index dfecfae23b3..e6402edb864 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -18,7 +18,6 @@ from authlib.oauth2.rfc6749.util import scope_to_list from authlib.oauth2.rfc6749.wrappers import OAuth2Token as _OAuth2Token -from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin BEARER_TOKEN_ENV = 'BEARER_TOKEN' BEARER_TOKEN_FILE_ENV = 'BEARER_TOKEN_FILE' @@ -136,17 +135,35 @@ def __init__(self, params=None, **kwargs): kwargs['expires_at'] = self.get_claim('exp') super(OAuth2Token, self).__init__(kwargs) - def get_client_id(self): - return self.get('client_id', self.get('azp')) + def check_client(self, client): + """ A method to check if this token is issued to the given client. + + :param client: client object + + :return: bool + """ + return self.get('client_id', self.get('azp')) == client.client_id def get_scope(self): + """ A method to get scope of the authorization code. + + :return: str + """ return self.get('scope') def get_expires_in(self): - return self.get('expires_in') + """ A method to get the ``expires_in`` value of the token. + + :return: int + """ + return int(self.get('expires_in')) - def get_expires_at(self): - return int(self.get('expires_at', self.get('issued_at') + self.get('expires_in'))) + def is_expired(self): + """ A method to define if this token is expired. + + :return: bool + """ + return int(self.get('expires_at', self.get('issued_at') + self.get('expires_in'))) < time.time() @property def scopes(self): diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 6fd498711cd..226952f6627 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -24,12 +24,12 @@ 'setup': 'setup', 'group': 'my_group'} -DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), +DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), expires_at=int(time.time()) + 3600) -New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), +New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), expires_in=int(time.time()) + 3600) @@ -121,37 +121,37 @@ def test_keys(): # Create new one result = db.getPrivateKey() assert result['OK'], result['Message'] - assert type(result['Value']['rsakey']) is RSAKey - assert type(result['Value']['strkey']) is str - # Sign token - header['kid'] = result['Value']['kid'] private_key = result['Value']['rsakey'] + assert private_key is RSAKey + + # Sign token + header['kid'] = private_key.thumbprint() # Find key by KID result = db.getPrivateKey(header['kid']) assert result['OK'], result['Message'] - assert result['Value']['rsakey'] == private_key + assert result['Value'].as_dict(True) == private_key.as_dict(True) + # Sign token token = jwt.encode(header, payload, private_key) # Sign auth code code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) # Get public key set result = db.getKeySet() + keyset = result['Value'] assert result['OK'], result['Message'] - _payload = jwt.decode(token, JsonWebKey.import_key_set(result['Value'].as_dict())) + assert bool([key for key in keyset.as_dict(True)['keys'] if key["kid"] == header['kid']]) + + # Read token + _payload = jwt.decode(token, JsonWebKey.import_key_set(keyset.as_dict())) assert _payload == payload - data = jws.deserialize_compact(code, result['Value'].keys[0]) + # Read auth code + data = jws.deserialize_compact(code, keyset.keys[0]) _code_payload = json_loads(urlsafe_b64decode(data['payload'])) assert _code_payload == code_payload - # Get JWK - result = db.getJWKs() - assert result['OK'], result['Message'] - _payload = jwt.decode(token, JsonWebKey.import_key_set(result['Value'])) - assert _payload == payload, result['Value'] - def test_Sessions(): """ Try to store/get/remove Sessions diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py index 02e53103103..3c3f939adb6 100644 --- a/tests/Integration/Framework/Test_TokenDB.py +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -27,17 +27,17 @@ exp_payload['iat'] = int(time.time()) - 10 exp_payload['exp'] = int(time.time()) - 10 -DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), +DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), expires_at=int(time.time()) + 3600) -New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret"), +New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), issued_at=int(time.time()), expires_in=int(time.time()) + 3600) -Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), - refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret"), +Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), expires_at=int(time.time()) - 10, rt_expires_at=int(time.time()) - 10) From 002f22482965359294b7ff940e5eee020b4b425d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 22 Jun 2021 20:37:03 +0200 Subject: [PATCH 065/178] fix py3 things --- .../FrameworkSystem/private/authorization/grants/RefreshToken.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py index 5be8d22790f..ae2743406b5 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -2,7 +2,6 @@ from __future__ import division from __future__ import print_function -from authlib.jose import JsonWebKey, jwt from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant From 559c09252d401dfe41d8e50e712745bb0adaa6d3 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 22 Jun 2021 20:42:56 +0200 Subject: [PATCH 066/178] update authlib to authlib >=1.0.0.a2 --- environment-py3.yml | 2 +- environment.yml | 2 +- requirements.txt | 2 +- setup.cfg | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index 87d335de478..cb3095a4445 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,7 +89,7 @@ dependencies: - typing >=3.6.6 - pyyaml # OAuth2 - - authlib >=1.0.0 + - authlib >=1.0.0.a2 - pyjwt >=2.1.0 - dominate - pip: diff --git a/environment.yml b/environment.yml index 02d8d146250..bf157e1e9d2 100644 --- a/environment.yml +++ b/environment.yml @@ -73,7 +73,7 @@ dependencies: - openssl <1.1 - selectors2 # OAuth2 - - authlib >=1.0.0 + - authlib >=1.0.0.a2 - pyjwt >=2.1.0 - dominate - pip: diff --git a/requirements.txt b/requirements.txt index f32f12f505a..495b22a5a39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,6 @@ ldap3 # setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 setuptools_scm<6.0 # OAuth2 -authlib >=1.0.0 +authlib >=1.0.0.a2 pyjwt >=2.1.0 dominate \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 97564191d9c..3b9a8a6a337 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,7 +55,7 @@ install_requires = six sqlalchemy subprocess32 - authlib >=1.0.0 + authlib >=1.0.0.a2 pyjwt >=2.1.0 dominate zip_safe = False From 2026596ab2796f935285ed2b33a7b267a5395913 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 22 Jun 2021 21:03:35 +0200 Subject: [PATCH 067/178] update authlib --- environment-py3.yml | 2 +- environment.yml | 2 +- requirements.txt | 2 +- setup.cfg | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index cb3095a4445..e2f1e83df6a 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,7 +89,7 @@ dependencies: - typing >=3.6.6 - pyyaml # OAuth2 - - authlib >=1.0.0.a2 + - Authlib - pyjwt >=2.1.0 - dominate - pip: diff --git a/environment.yml b/environment.yml index bf157e1e9d2..7f9c6c06212 100644 --- a/environment.yml +++ b/environment.yml @@ -73,7 +73,7 @@ dependencies: - openssl <1.1 - selectors2 # OAuth2 - - authlib >=1.0.0.a2 + - Authlib - pyjwt >=2.1.0 - dominate - pip: diff --git a/requirements.txt b/requirements.txt index 495b22a5a39..7c2de1a10d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,6 @@ ldap3 # setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 setuptools_scm<6.0 # OAuth2 -authlib >=1.0.0.a2 +Authlib pyjwt >=2.1.0 dominate \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 3b9a8a6a337..c8c8e1ef741 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,7 +55,7 @@ install_requires = six sqlalchemy subprocess32 - authlib >=1.0.0.a2 + Authlib pyjwt >=2.1.0 dominate zip_safe = False From 959ec69e663695768d52caad219e3219f11b2e1e Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 22 Jun 2021 23:59:46 +0200 Subject: [PATCH 068/178] update authlib --- setup.cfg | 2 +- tests/Jenkins/dirac_ci.sh | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index c8c8e1ef741..2c2eef64697 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,7 +55,7 @@ install_requires = six sqlalchemy subprocess32 - Authlib + Authlib==1.0.0.a2 pyjwt >=2.1.0 dominate zip_safe = False diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index 00d67d444f0..ae2b1d95407 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -189,6 +189,7 @@ installSite() { exit 1 fi + pip install -I Authlib==1.0.0.a2 echo "==> Completed installation" } From 07d4e802f2497b1fa7f63881ea95984d02c6cf64 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 00:16:58 +0200 Subject: [PATCH 069/178] update authlib --- environment-py3.yml | 2 +- environment.yml | 2 +- requirements.txt | 2 +- tests/Jenkins/dirac_ci.sh | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index e2f1e83df6a..578c5db8d0a 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,7 +89,7 @@ dependencies: - typing >=3.6.6 - pyyaml # OAuth2 - - Authlib + - Authlib==1.0.0.a2 - pyjwt >=2.1.0 - dominate - pip: diff --git a/environment.yml b/environment.yml index 7f9c6c06212..80e5ec6d184 100644 --- a/environment.yml +++ b/environment.yml @@ -73,7 +73,7 @@ dependencies: - openssl <1.1 - selectors2 # OAuth2 - - Authlib + - Authlib==1.0.0.a2 - pyjwt >=2.1.0 - dominate - pip: diff --git a/requirements.txt b/requirements.txt index 7c2de1a10d5..8814b94dbb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,6 @@ ldap3 # setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 setuptools_scm<6.0 # OAuth2 -Authlib +Authlib==1.0.0.a2 pyjwt >=2.1.0 dominate \ No newline at end of file diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index ae2b1d95407..00d67d444f0 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -189,7 +189,6 @@ installSite() { exit 1 fi - pip install -I Authlib==1.0.0.a2 echo "==> Completed installation" } From 65980341758f92c97b05a871fd03fe84858b1899 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 10:27:54 +0200 Subject: [PATCH 070/178] update authlib --- environment-py3.yml | 3 ++- environment.yml | 2 +- requirements.txt | 2 +- setup.cfg | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index 578c5db8d0a..edcce165913 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -89,10 +89,11 @@ dependencies: - typing >=3.6.6 - pyyaml # OAuth2 - - Authlib==1.0.0.a2 - pyjwt >=2.1.0 - dominate - pip: + # Prerelease of the required package for integration of OAuth2 + - Authlib>=1.0.0.a2 # This is a fork of tornado with a patch to allow for configurable iostream # It should eventually be part of DIRACGrid - git+https://github.com/DIRACGrid/tornado.git@iostreamConfigurable diff --git a/environment.yml b/environment.yml index 80e5ec6d184..fb949ad10bc 100644 --- a/environment.yml +++ b/environment.yml @@ -73,7 +73,7 @@ dependencies: - openssl <1.1 - selectors2 # OAuth2 - - Authlib==1.0.0.a2 + - authlib - pyjwt >=2.1.0 - dominate - pip: diff --git a/requirements.txt b/requirements.txt index 8814b94dbb3..ad9eab50a04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,6 @@ ldap3 # setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 setuptools_scm<6.0 # OAuth2 -Authlib==1.0.0.a2 +Authlib>=1.0.0.a2 pyjwt >=2.1.0 dominate \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 2c2eef64697..04f5010b25c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,7 +55,7 @@ install_requires = six sqlalchemy subprocess32 - Authlib==1.0.0.a2 + Authlib>=1.0.0.a2 pyjwt >=2.1.0 dominate zip_safe = False From 99391b2e96043eff7496ad545a84a53f3184ef70 Mon Sep 17 00:00:00 2001 From: Andrii Lytovchenko Date: Wed, 23 Jun 2021 13:43:22 +0200 Subject: [PATCH 071/178] Update src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py Co-authored-by: Chris Burr --- src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index fc2a8fb356e..03735ba7a12 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -716,7 +716,7 @@ def wrapIDAsDN(userID): :return: str """ - return '/O=DIRAC/CN=%s' % userID + return '/O=DIRAC/CN=' + userID def getIDFromDN(userDN): From 7de26b0f707aa5c314c03ee5ae10f2caa3d1fe8c Mon Sep 17 00:00:00 2001 From: Andrii Lytovchenko Date: Wed, 23 Jun 2021 13:50:35 +0200 Subject: [PATCH 072/178] Update docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst Co-authored-by: Chris Burr --- .../ServerInstallations/environment_variable_configuration.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst index 03ab7347cef..8cef39b579b 100644 --- a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst +++ b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst @@ -70,7 +70,7 @@ BEARER_TOKEN If the environment variable is set, then the value is taken to be the token contents(https://doi.org/10.5281/zenodo.3937438). BEARER_TOKEN_FILE - If the environment variable is set, then its value is interpreted as a filename. The content of the specified file is used as token string(https://doi.org/10.5281/zenodo.3937438). + If the environment variable is set, then its value is interpreted as a filename. The content of the specified file is used as token string (https://doi.org/10.5281/zenodo.3937438). DIRAC_USE_ACCESS_TOKEN If this environment is set to ``true`` or ``yes``, the concurrent.futures.ThreadPoolExecutor will be used (default=false) From e693ca653af6cf592a4ed7f455ed800e5b3f0676 Mon Sep 17 00:00:00 2001 From: Andrii Lytovchenko Date: Wed, 23 Jun 2021 13:50:46 +0200 Subject: [PATCH 073/178] Update docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst Co-authored-by: Chris Burr --- .../ServerInstallations/environment_variable_configuration.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst index 8cef39b579b..0cb12e4fbef 100644 --- a/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst +++ b/docs/source/AdministratorGuide/ServerInstallations/environment_variable_configuration.rst @@ -67,7 +67,7 @@ X509_VOMSES Must be set to point to a folder containing VOMSES information. See :ref:`multi_vo_dirac` BEARER_TOKEN - If the environment variable is set, then the value is taken to be the token contents(https://doi.org/10.5281/zenodo.3937438). + If the environment variable is set, then the value is taken to be the token contents (https://doi.org/10.5281/zenodo.3937438). BEARER_TOKEN_FILE If the environment variable is set, then its value is interpreted as a filename. The content of the specified file is used as token string (https://doi.org/10.5281/zenodo.3937438). From ed3796e124de833a9957d69cfe452c237892e892 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 14:04:59 +0200 Subject: [PATCH 074/178] remove TokenManager service from local CS --- src/DIRAC/Core/scripts/install_full.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DIRAC/Core/scripts/install_full.cfg b/src/DIRAC/Core/scripts/install_full.cfg index 9286e003252..f5b823b3e74 100755 --- a/src/DIRAC/Core/scripts/install_full.cfg +++ b/src/DIRAC/Core/scripts/install_full.cfg @@ -122,7 +122,6 @@ LocalInstallation Services += Framework/SecurityLogging Services += Framework/UserProfileManager Services += Framework/ProxyManager - Services += Framework/TokenManager Services += Framework/Plotting Services += Framework/BundleDelivery Services += Monitoring/Monitoring From 0876f003e816adda78585316cc1a9c485135ec6d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 14:05:19 +0200 Subject: [PATCH 075/178] fix TornadoBaseClient --- src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 4e7046a1b29..8dedd71dcad 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -235,9 +235,9 @@ def __discoverCredentialsToUse(self): self.__useAccessToken = self.kwargs[self.KW_USE_ACCESS_TOKEN] else: self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") - self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: self.__useAccessToken = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ['yes', 'true'] + self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken if self.__useAccessToken: result = IdProviderFactory().getIdProvider('DIRACCLI') From 019ace42b156b598a37456ce722cf0fa8d120bcc Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 14:14:29 +0200 Subject: [PATCH 076/178] move TornadoConfigurationHandler fixes to separate PR --- .../Service/TornadoConfigurationHandler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py b/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py index 06b634c7116..d5aa1364250 100644 --- a/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py +++ b/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py @@ -48,7 +48,11 @@ def export_getCompressedData(self): Returns the configuration """ sData = self.ServiceInterface.getCompressedConfigurationData() +<<<<<<< HEAD return S_OK(b64encode(sData).decode()) +======= + return S_OK(b64encode(sData)) +>>>>>>> c2f34f040 (move TornadoConfigurationHandler fixes to separate PR) def export_getCompressedDataIfNewer(self, sClientVersion): """ @@ -59,7 +63,11 @@ def export_getCompressedDataIfNewer(self, sClientVersion): sVersion = self.ServiceInterface.getVersion() retDict = {'newestVersion': sVersion} if sClientVersion < sVersion: +<<<<<<< HEAD retDict['data'] = b64encode(self.ServiceInterface.getCompressedConfigurationData()).decode() +======= + retDict['data'] = b64encode(self.ServiceInterface.getCompressedConfigurationData()) +>>>>>>> c2f34f040 (move TornadoConfigurationHandler fixes to separate PR) return S_OK(retDict) def export_publishSlaveServer(self, sURL): From d6d990c526b88f9bbdb33806791ba6510bc3c8f1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 15:03:12 +0200 Subject: [PATCH 077/178] fix issues --- .../Client/Helpers/Registry.py | 17 +++++------------ .../Tornado/Client/private/TornadoBaseClient.py | 9 +++++---- .../Service/TokenManagerHandler.py | 16 ++++++++-------- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index 03735ba7a12..09860fdc65a 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -724,16 +724,9 @@ def getIDFromDN(userDN): :param str userDN: user DN - :return: str - """ - return userDN.strip('/O=DIRAC/CN=') - - -def isDNWrappedID(user): - """ Is it wrapped user ID? - - :param str user: user ID - - :return: bool + :return: S_OK(str)/S_ERROR() """ - return user.startswith('/O=DIRAC/CN=') + prefix = '/O=DIRAC/CN=' + if not userDN.startswith(prefix): + return S_ERROR("%s DN does not contain user ID." % userDN) + return S_OK(userDN[len(prefix):]) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 8dedd71dcad..7800c2859a4 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -212,7 +212,8 @@ def __discoverCredentialsToUse(self): -> If KW_USE_ACCESS_TOKEN in kwargs, sets it in self.__useAccessToken -> If not, check "/DIRAC/Security/UseTokens", and sets it in self.__useAccessToken and kwargs[KW_USE_ACCESS_TOKEN] - -> If DIRAC_USE_ACCESS_TOKEN' in os.environ, sets it in self.__useAccessToken + -> If not, check 'DIRAC_USE_ACCESS_TOKEN' in os.environ, sets it in self.__useAccessToken + and kwargs[KW_USE_ACCESS_TOKEN] * Proxy Chain WARNING: MOSTLY COPY/PASTE FROM Core/Diset/private/BaseClient @@ -233,11 +234,11 @@ def __discoverCredentialsToUse(self): # Use tokens? if self.KW_USE_ACCESS_TOKEN in self.kwargs: self.__useAccessToken = self.kwargs[self.KW_USE_ACCESS_TOKEN] + elif 'DIRAC_USE_ACCESS_TOKEN' in os.environ: + self.__useAccessToken = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") else: self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") - if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: - self.__useAccessToken = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ['yes', 'true'] - self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken + self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken if self.__useAccessToken: result = IdProviderFactory().getIdProvider('DIRACCLI') diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py index d661cb065e2..669a99ed396 100644 --- a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -50,8 +50,9 @@ def __generateUserTokensInfo(self): if not result['OK']: return result for dn in result['Value']: - if Registry.isDNWrappedID(dn): - result = self.__tokenDB.getTokensByUserID(Registry.getIDFromDN(dn)) + result = Registry.getIDFromDN(dn) + if result['OK']: + result = self.__tokenDB.getTokensByUserID(result['Value']) if not result['OK']: gLogger.error(result['Message']) tokensInfo += result['Value'] @@ -144,8 +145,9 @@ def export_getToken(self, username, userGroup): err = [] for dn in result['Value']: - if Registry.isDNWrappedID(dn): - result = self.__tokenDB.getTokenForUserProvider(Registry.getIDFromDN(dn), provider) + result = Registry.getIDFromDN(dn) + if result['OK']: + result = self.__tokenDB.getTokenForUserProvider(result['Value'], provider) if not result['OK']: err.append(result['Message']) elif result['Value']: @@ -170,7 +172,5 @@ def export_deleteToken(self, userDN): if Properties.PROXY_MANAGEMENT not in credDict['properties']: if userDN != credDict['DN']: return S_ERROR("You aren't allowed!") - retVal = self.__tokenDB.removeToken(user_id=Registry.getIDFromDN(userDN)) - if not retVal['OK']: - return retVal - return S_OK() + result = Registry.getIDFromDN(dn) + return self.__tokenDB.removeToken(user_id=result['Value']) if result['OK'] else result From db7075260675d99615e024c6913f3f64c7f2ee59 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 15:52:39 +0200 Subject: [PATCH 078/178] fix bug --- src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py index 669a99ed396..d8187324344 100644 --- a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -172,5 +172,5 @@ def export_deleteToken(self, userDN): if Properties.PROXY_MANAGEMENT not in credDict['properties']: if userDN != credDict['DN']: return S_ERROR("You aren't allowed!") - result = Registry.getIDFromDN(dn) + result = Registry.getIDFromDN(userDN) return self.__tokenDB.removeToken(user_id=result['Value']) if result['OK'] else result From 232b47089662c6baa79646f3b8c690bc2f50d2b7 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 15:54:39 +0200 Subject: [PATCH 079/178] fix env py2 --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index fb949ad10bc..a0cb8b50385 100644 --- a/environment.yml +++ b/environment.yml @@ -74,7 +74,7 @@ dependencies: - selectors2 # OAuth2 - authlib - - pyjwt >=2.1.0 + - pyjwt - dominate - pip: - diraccfg From 80fe3aec3767f7ba1c192dab9d0fe67638544361 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 16:08:52 +0200 Subject: [PATCH 080/178] fix pylint --- src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py | 2 +- .../FrameworkSystem/private/authorization/utils/Requests.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 48c4ccacc6c..0ad375fba28 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -206,7 +206,7 @@ def __signToken(self, payload): except Exception as e: self.log.exception(e) return S_ERROR(repr(e)) - + def readToken(self, token): """ Decode self token diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index 5594eb20b91..bf4f74ff062 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -59,5 +59,6 @@ def createOAuth2Request(request, method_cls=OAuth2Request, use_json=False): return method_cls(request['method'], request['uri'], request.get('body'), request.get('headers')) if use_json: return method_cls(request.method, request.full_url(), json_decode(request.body), request.headers) - body = {k:request.body_arguments[k][-1].decode("utf-8") for k in request.body_arguments if request.body_arguments[k]} + body = {k: request.body_arguments[k][-1].decode("utf-8") + for k in request.body_arguments if request.body_arguments[k]} return method_cls(request.method, request.full_url(), body, request.headers) From fab2dc95c25c2c6c75cd23dbb10abea4ac2732c4 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 16:34:25 +0200 Subject: [PATCH 081/178] add TokenManager to ignore services --- tests/Jenkins/utilities.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Jenkins/utilities.sh b/tests/Jenkins/utilities.sh index 0fee47ce6d5..08bd24da217 100644 --- a/tests/Jenkins/utilities.sh +++ b/tests/Jenkins/utilities.sh @@ -801,7 +801,7 @@ diracServices(){ echo '==> [diracServices]' # Ignore tornado services - local services=$(cut -d '.' -f 1 < services | grep -v Tornado | grep -v PilotsLogging | grep -v StorageElementHandler | grep -v ^ConfigurationSystem | grep -v Plotting | grep -v RAWIntegrity | grep -v RunDBInterface | grep -v ComponentMonitoring | sed 's/System / /g' | sed 's/Handler//g' | sed 's/ /\//g') + local services=$(cut -d '.' -f 1 < services | grep -v Tornado | grep -v TokenManager | grep -v PilotsLogging | grep -v StorageElementHandler | grep -v ^ConfigurationSystem | grep -v Plotting | grep -v RAWIntegrity | grep -v RunDBInterface | grep -v ComponentMonitoring | sed 's/System / /g' | sed 's/Handler//g' | sed 's/ /\//g') # group proxy, will be uploaded explicitly # echo '==> getting/uploading proxy for prod' @@ -852,7 +852,7 @@ diracUninstallServices(){ findServices # Ignore tornado services - local services=$(cut -d '.' -f 1 getting/uploading proxy for prod' From b2dc34cbb494e0599bbda48687cb9a2f5dfd5073 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 17:33:27 +0200 Subject: [PATCH 082/178] add commands to docs --- docs/source/AdministratorGuide/CommandReference/index.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/AdministratorGuide/CommandReference/index.rst b/docs/source/AdministratorGuide/CommandReference/index.rst index 1ca4c9fe267..ec640b98cba 100644 --- a/docs/source/AdministratorGuide/CommandReference/index.rst +++ b/docs/source/AdministratorGuide/CommandReference/index.rst @@ -155,6 +155,9 @@ Other commands: .. toctree:: :maxdepth: 2 + dirac-login + dirac-logout + dirac-admin-accounting-cli dirac-admin-sysadmin-cli From d91f8a0002cfcfcc691f8f996f479191763e6633 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 17:33:46 +0200 Subject: [PATCH 083/178] optimize --- .../Core/Tornado/Server/private/BaseRequestHandler.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 2e0c5ed39dd..f6047825417 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -530,8 +530,7 @@ def _gatherPeerCredentials(self, grants=None): :param list grants: grants to use - :returns: a dict containing the return of :py:meth:`DIRAC.Core.Security.X509Chain.X509Chain.getCredentials` - (not a DIRAC structure !) + :returns: a dict containing user credentials """ err = [] @@ -541,9 +540,7 @@ def _gatherPeerCredentials(self, grants=None): for grant in (grants or self.USE_AUTHZ_GRANTS or 'VISITOR'): grant = grant.upper() grantFunc = getattr(self, '_authz%s' % grant, None) - if not callable(grantFunc): - raise Exception('%s authentication type is not supported.' % grant) - result = grantFunc() + result = grantFunc() if callable(grantFunc) else S_ERROR('%s authentication type is not supported.' % grant) if result['OK']: for e in err: sLog.debug(e) From e7dfd221a7f40439bbfa199b19d0a9bf73e65720 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 17:42:44 +0200 Subject: [PATCH 084/178] fix docs --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index b02b09ee607..e385b00ce67 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -373,6 +373,7 @@ def web_authorization(self, provider=None): | provider | path | identity provider to autorize (optional) | CheckIn | | | | It's possible to add it interactively. | | +----------------+--------+-------------------------------------------+---------------------------------------+ + General options: provider -- identity provider to autorize @@ -449,15 +450,14 @@ def web_token(self): +----------------+--------+-------------------------------------+---------------------------------------------+ | **name** | **in** | **description** | **example** | +----------------+--------+-------------------------------------+---------------------------------------------+ - | grant_type | query | what grant type to use, more | urn:ietf:params:oauth:grant-type:device_code| - | | | supported grant types in *grants | | + | grant_type | query | grant type to use | urn:ietf:params:oauth:grant-type:device_code| +----------------+--------+-------------------------------------+---------------------------------------------+ | client_id | query | The public client ID | 3f1DAj8z6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | +----------------+--------+-------------------------------------+---------------------------------------------+ | device_code | query | device code | uW5xL4hr2tqwBPKL5d0JO9Fcc67gLqhJsNqYTSp | +----------------+--------+-------------------------------------+---------------------------------------------+ - *:mod:`grants ` + :mod:`Supported grant types ` Request example:: From 75d2a604d863c5f00dafd85b837a3680a835f61c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 23:36:58 +0200 Subject: [PATCH 085/178] fix --- .../private/authorization/AuthServer.py | 21 +++++++++---------- .../private/authorization/utils/Clients.py | 2 +- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 0ad375fba28..bba0f43040b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -55,17 +55,18 @@ def collectMetadata(issuer=None): if not result['OK']: raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) metadata = result['Value'] - metadata['jwks_uri'] = metadata['issuer'] + '/jwk' - metadata['token_endpoint'] = metadata['issuer'] + '/token' - metadata['userinfo_endpoint'] = metadata['issuer'] + '/userinfo' - metadata['revocation_endpoint'] = metadata['issuer'] + '/revoke' - metadata['authorization_endpoint'] = metadata['issuer'] + '/authorization' - metadata['device_authorization_endpoint'] = metadata['issuer'] + '/device' + for name, endpoint in [('jwks_uri', 'jwk'), + ('token_endpoint', 'token'), + ('userinfo_endpoint', 'userinfo'), + ('revocation_endpoint', 'revoke'), + ('authorization_endpoint', 'authorization'), + ('device_authorization_endpoint', 'device')]: + metadata[name] = metadata['issuer'].strip('/') + '/' + endpoint + metadata['scopes_supported'] = ['g:', 'proxy', 'lifetime:'] metadata['grant_types_supported'] = ['code', 'authorization_code', 'refresh_token', 'urn:ietf:params:oauth:grant-type:device_code'] metadata['response_types_supported'] = ['code', 'device', 'token'] metadata['code_challenge_methods_supported'] = ['S256'] - metadata['scopes_supported'] = ['g:', 'proxy', 'lifetime:'] return AuthorizationServerMetadata(metadata) @@ -89,7 +90,7 @@ def __init__(self): self.metadata = collectMetadata() self.metadata.validate() _AuthorizationServer.__init__(self, scopes_supported=self.metadata['scopes_supported']) - # Skip authlib method save_token + # Skip authlib method save_token and send_signal self.save_token = lambda x, y: None self.send_signal = lambda *x, **y: None self.generate_token = self.generateProxyOrToken @@ -105,13 +106,11 @@ def query_client(self, client_id): :param str clientID: client ID - :return: object or None + :return: client as object or None """ gLogger.debug('Try to query %s client' % client_id) clients = getDIRACClients() for cli in clients: - print(clients[cli]['client_id']) - print(client_id) if client_id == clients[cli]['client_id']: gLogger.debug('Found %s client:\n' % cli, pprint.pformat(clients[cli])) return Client(clients[cli]) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index f3bbec682fd..66f298d81b1 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -29,7 +29,7 @@ def getDIRACClients(): :return: S_OK(dict)/S_ERROR() """ clients = DEFAULT_CLIENTS.copy() - result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization/Client') + result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization/Clients') if not result['OK']: gLogger.error(result['Message']) confClients = result.get('Value', {}) From 8fc0e00ca4c73d17d2c68af280611e6e9071a53b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 23 Jun 2021 23:40:43 +0200 Subject: [PATCH 086/178] add tests --- .../Integration/Framework/Test_AuthServer.py | 121 ++++++++++++++++++ .../IdProvider/Test_IdProviderFactory.py | 103 +++++++++++++++ .../Resources/IdProvider/__init__.py | 0 3 files changed, 224 insertions(+) create mode 100644 tests/Integration/Framework/Test_AuthServer.py create mode 100644 tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py create mode 100644 tests/Integration/Resources/IdProvider/__init__.py diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py new file mode 100644 index 00000000000..81516a09b96 --- /dev/null +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -0,0 +1,121 @@ +""" This is a test of the AuthServer + It supposes that the AuthDB is present and installed in DIRAC +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import pytest +from mock import MagicMock + +from diraccfg import CFG + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +import DIRAC +from DIRAC import S_OK, S_ERROR, gConfig +from authlib.oauth2.base import OAuth2Error +from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer + +cfg = CFG() +cfg.loadFromBuffer(""" +DIRAC +{ + Security + { + Authorization + { + issuer = https://issuer.url/ + } + } +} +""") +gConfig.loadCFG(cfg) + + +class Proxy(object): + def dumpAllToString(self): + return S_OK('proxy') + + +class ProxyManagerClient(object): + def downloadProxy(self, *args, **kwargs): + return S_OK(Proxy()) + + +class TokenManagerClient(object): + def getToken(self, *args, **kwargs): + return S_OK({'access_token': 'token', 'refresh_token': 'token'}) + + +mockgetIdPForGroup = MagicMock(return_value=S_OK('IdP')) +mockgetDNForUsername = MagicMock(return_value=S_OK('DN')) +mockgetUsernameForDN = MagicMock(return_value=S_OK('user')) +mockisDownloadablePersonalProxy = MagicMock(return_value=True) + + +@pytest.fixture +def server(mocker): + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", + side_effect=mockgetIdPForGroup) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", + side_effect=mockgetDNForUsername) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", + side_effect=mockgetUsernameForDN) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", + side_effect=ProxyManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", + side_effect=TokenManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", + side_effect=mockisDownloadablePersonalProxy) + return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() + + +def test_metadata(server): + """ Check metadata + """ + assert server.metadata.get('issuer') + + +def test_queryClient(server): + """ Try to search some default client + """ + assert not server.query_client('not_exist_client') + assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' + + +def test_queryClient(server): + """ Try to search some default client + """ + assert not server.query_client('not_exist_client') + assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' + + +@pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ + ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), + ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), +]) +def test_generateToken(server, client, grant, user, scope, expires_in, refresh_token, instance, result): + """ Generate tokens + """ + cli = server.query_client(client) + try: + assert server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result + except OAuth2Error as e: + assert False, str(e) + +def test_writeReadRefreshToken(server): + """ Try to search some default client + """ + result = server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) + assert result['OK'], result['Message'] + token = result['Value'] + assert token.get('access_token') == 'token' + assert token.get('refresh_token') != 'token' + + result = server.readToken(token['refresh_token']) + assert result['OK'], result['Message'] + assert result['Value'].get('jti') + assert result['Value'].get('iat') diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py new file mode 100644 index 00000000000..10d207ff442 --- /dev/null +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -0,0 +1,103 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import time +import unittest +from authlib.jose import jwt + +from diraccfg import CFG + +import DIRAC +from DIRAC import gConfig +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS +from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata + +cfg = CFG() +cfg.loadFromBuffer(""" +DIRAC +{ + Security + { + Authorization + { + issuer = https://issuer.url/ + Clients + { + DIRACWeb + { + client_id = client_identificator + client_secret = client_secret_key + redirect_uri = https://redirect.url/ + } + } + } + } +} +Resources +{ + IdProviders + { + SomeIdP + { + ProviderType = OAuth2 + issuer = https://idp.url/ + client_id = IdP_client_id + client_secret = IdP_client_secret + redirect_uri = https://dirac/redirect + jwks_uri = https://idp.url/jwk + scope = openid+profile+offline_access+eduperson_entitlement + } + } +} +""") +gConfig.loadCFG(cfg) + +idps = IdProviderFactory() + + +def test_getDIRACClients(): + """ Try to load default DIRAC authorization client + """ + params = collectMetadata() + + # Try to get DIRAC client authorization settings + result = idps.getIdProvider('DIRACCLI', **params) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == DEFAULT_CLIENTS['DIRACCLI']['client_id'] + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' + + # Try to get DIRAC client authorization settings for Web portal + result = idps.getIdProvider('DIRACWeb', **params) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == 'client_identificator' + assert result['Value'].client_secret == 'client_secret_key' + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' + + +def test_getIdPClients(): + """ Try to load external identity provider settings + """ + # Try to get identity provider by name + result = idps.getIdProvider('SomeIdP', jwks='my_jwks') + assert result['OK'], result['Message'] + assert result['Value'].jwks == 'my_jwks' + assert result['Value'].issuer == 'https://idp.url/' + assert result['Value'].client_id == 'IdP_client_id' + assert result['Value'].client_secret == 'IdP_client_secret' + assert result['Value'].get_metadata('jwks_uri') == 'https://idp.url/jwk' + + # Try to get identity provider for token issued by it + result = idps.getIdProviderForToken(jwt.encode({'alg': 'HS256'}, dict( + sub='user', + iss=result['Value'].issuer, + iat=int(time.time()), + exp=int(time.time()) + (12 * 3600), + ), "secret").decode('utf-8')) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://idp.url/' diff --git a/tests/Integration/Resources/IdProvider/__init__.py b/tests/Integration/Resources/IdProvider/__init__.py new file mode 100644 index 00000000000..e69de29bb2d From a9978da2c04a6cfdd6ccc18a6cce4f76b9401a24 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 00:27:59 +0200 Subject: [PATCH 087/178] use pytest --- .../Core/Tornado/Server/private/BaseRequestHandler.py | 1 + tests/Integration/Framework/Test_AuthServer.py | 8 +------- .../Resources/IdProvider/Test_IdProviderFactory.py | 2 +- tests/Integration/all_integration_server_tests.sh | 7 ++++--- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index f6047825417..eabe6fc9903 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -540,6 +540,7 @@ def _gatherPeerCredentials(self, grants=None): for grant in (grants or self.USE_AUTHZ_GRANTS or 'VISITOR'): grant = grant.upper() grantFunc = getattr(self, '_authz%s' % grant, None) + # pylint: disable=not-callable result = grantFunc() if callable(grantFunc) else S_ERROR('%s authentication type is not supported.' % grant) if result['OK']: for e in err: diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index 81516a09b96..dc5326961cf 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -86,13 +86,6 @@ def test_queryClient(server): assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' -def test_queryClient(server): - """ Try to search some default client - """ - assert not server.query_client('not_exist_client') - assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' - - @pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), @@ -106,6 +99,7 @@ def test_generateToken(server, client, grant, user, scope, expires_in, refresh_t except OAuth2Error as e: assert False, str(e) + def test_writeReadRefreshToken(server): """ Try to search some default client """ diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index 10d207ff442..a829a1d7a63 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -60,7 +60,7 @@ def test_getDIRACClients(): - """ Try to load default DIRAC authorization client + """ Try to load default DIRAC authorization client """ params = collectMetadata() diff --git a/tests/Integration/all_integration_server_tests.sh b/tests/Integration/all_integration_server_tests.sh index 1d44e54727f..aef91e6b90e 100644 --- a/tests/Integration/all_integration_server_tests.sh +++ b/tests/Integration/all_integration_server_tests.sh @@ -26,9 +26,10 @@ pytest "${THIS_DIR}/Core/Test_MySQLDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( #-------------------------------------------------------------------------------# echo -e "*** $(date -u) **** FRAMEWORK TESTS (partially skipped) ****\n" pytest "${THIS_DIR}/Framework/Test_InstalledComponentsDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) -python "${THIS_DIR}/Framework/Test_ProxyDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) -python "${THIS_DIR}/Framework/Test_AuthDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) -python "${THIS_DIR}/Framework/Test_TokenDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_ProxyDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_TokenDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_AuthDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +pytest "${THIS_DIR}/Framework/Test_AuthServer.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #pytest ${THIS_DIR}/Framework/Test_LoggingDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #-------------------------------------------------------------------------------# From 0cc5cd4b091779a786d9139974cc471c9c0e15db Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 00:29:28 +0200 Subject: [PATCH 088/178] fix --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ad9eab50a04..34d658b43ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,5 +69,5 @@ ldap3 setuptools_scm<6.0 # OAuth2 Authlib>=1.0.0.a2 -pyjwt >=2.1.0 +pyjwt dominate \ No newline at end of file From af78dc3de4ed88b2d2019c3d6441372bcc2df8d1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 00:42:07 +0200 Subject: [PATCH 089/178] fix bugs --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 9 ++++----- .../private/authorization/AuthServer.py | 1 - .../authorization/grants/AuthorizationCode.py | 19 +++++++------------ .../private/authorization/utils/Requests.py | 2 +- .../FrameworkSystem/scripts/dirac_login.py | 1 - .../Integration/Framework/Test_AuthServer.py | 2 +- 6 files changed, 13 insertions(+), 21 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index e385b00ce67..7d80f63f396 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -17,11 +17,10 @@ from dominate import document, tags as dom from tornado.template import Template -from authlib.jose import jwk from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc6749.util import scope_to_list -from DIRAC import S_OK, S_ERROR +from DIRAC import S_ERROR from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST from DIRAC.ConfigurationSystem.Client.Helpers import Registry from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer @@ -156,7 +155,7 @@ def web_index(self, instance): { "registration_endpoint": "https://domain.com/DIRAC/auth/register", "userinfo_endpoint": "https://domain.com/DIRAC/auth/userinfo", - "jwks_uri": "https://domain.com/DIRAC/auth/jwk", + "s_uri": "https://domain.com/DIRAC/auth/", "code_challenge_methods_supported": [ "S256" ], @@ -182,12 +181,12 @@ def web_index(self, instance): if self.request.method == "GET": return self.server.metadata - def web_jwk(self): + def web_(self): """ JWKs endpoint Request example:: - GET LOCATION/jwk + GET LOCATION/ Response:: diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index bba0f43040b..06fc7134738 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -14,7 +14,6 @@ from authlib.jose import JsonWebKey, jwt from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer from authlib.oauth2.base import OAuth2Error -from authlib.common.security import generate_token from authlib.oauth2.rfc7636 import CodeChallenge from authlib.oauth2.rfc8414 import AuthorizationServerMetadata from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py index b22ed0a152b..1dc58a8ad42 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -5,14 +5,10 @@ from __future__ import print_function from time import time -from pprint import pprint from authlib.jose import JsonWebSignature from authlib.oauth2.base import OAuth2Error -from authlib.oauth2.rfc7636 import CodeChallenge from authlib.oauth2.rfc6749.grants import AuthorizationCodeGrant as _AuthorizationCodeGrant -from authlib.common.encoding import to_unicode, json_dumps, json_b64encode, urlsafe_b64decode, json_loads - -from DIRAC import gLogger, S_OK, S_ERROR +from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads class OAuth2Code(dict): @@ -69,7 +65,7 @@ def query_authorization_code(self, code, client): :return: OAuth2Code or None """ - gLogger.debug('Query authorization code:', code) + self.server.log.debug('Query authorization code:', code) jws = JsonWebSignature(algorithms=['RS256']) result = self.server.db.getKeySet() if not result['OK']: @@ -82,13 +78,13 @@ def query_authorization_code(self, code, client): except Exception as e: err = e if err: - gLogger.error('Cannot get authorization code:', repr(err)) + self.server.log.error('Cannot get authorization code:', repr(err)) return None try: item = OAuth2Code(json_loads(urlsafe_b64decode(data['payload']))) - gLogger.debug('Authorization code scope:', item.get_scope()) + self.server.log.debug('Authorization code scope:', item.get_scope()) except Exception as e: - gLogger.error('Cannot read authorization code:', repr(e)) + self.server.log.error('Cannot read authorization code:', repr(e)) return None if not item.is_expired(): return item @@ -105,8 +101,7 @@ def generate_authorization_code(self): :return: str """ - gLogger.debug('Generate authorization code for credentials:', self.request.user) - pprint(self.request.data) + self.server.log.debug('Generate authorization code for credentials:', self.request.user) jws = JsonWebSignature(algorithms=['RS256']) protected = {'alg': 'RS256'} code = OAuth2Code({'user_id': self.request.user['ID'], @@ -116,7 +111,7 @@ def generate_authorization_code(self): 'client_id': self.request.args['client_id'], 'code_challenge': self.request.args.get('code_challenge'), 'code_challenge_method': self.request.args.get('code_challenge_method')}) - gLogger.debug('Authorization code generated:', dict(code)) + self.server.log.debug('Authorization code generated:', dict(code)) result = self.server.db.getPrivateKey() if not result['OK']: raise OAuth2Error('Cannot generate authorization code: %s' % result['Message']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index bf4f74ff062..aa53d4bdab4 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -7,7 +7,7 @@ from tornado.escape import json_decode from authlib.common.encoding import to_unicode from authlib.oauth2 import OAuth2Request as _OAuth2Request -from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.oauth2.rfc6749.util import scope_to_list __RCSID__ = "$Id$" diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 526202bb414..e9a1ec836ee 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -15,7 +15,6 @@ import os import sys -from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope import DIRAC from DIRAC import gLogger, S_OK, S_ERROR diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index dc5326961cf..3753477825c 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -8,6 +8,7 @@ import time import pytest from mock import MagicMock +from authlib.oauth2.base import OAuth2Error from diraccfg import CFG @@ -16,7 +17,6 @@ import DIRAC from DIRAC import S_OK, S_ERROR, gConfig -from authlib.oauth2.base import OAuth2Error from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer cfg = CFG() From 49280cdcc944d41f2846e1d2ca1d111e3f666a2a Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 01:19:14 +0200 Subject: [PATCH 090/178] fix --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 7d80f63f396..7f2ee02e891 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -162,8 +162,6 @@ def web_index(self, instance): "grant_types_supported": [ "authorization_code", "code", - "urn:ietf:params:oauth:grant-type:device_code", - "implicit", "refresh_token" ], "token_endpoint": "https://domain.com/DIRAC/auth/token", @@ -446,15 +444,15 @@ def web_token(self): POST LOCATION/token Parameters: - +----------------+--------+-------------------------------------+---------------------------------------------+ - | **name** | **in** | **description** | **example** | - +----------------+--------+-------------------------------------+---------------------------------------------+ - | grant_type | query | grant type to use | urn:ietf:params:oauth:grant-type:device_code| - +----------------+--------+-------------------------------------+---------------------------------------------+ - | client_id | query | The public client ID | 3f1DAj8z6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | - +----------------+--------+-------------------------------------+---------------------------------------------+ - | device_code | query | device code | uW5xL4hr2tqwBPKL5d0JO9Fcc67gLqhJsNqYTSp | - +----------------+--------+-------------------------------------+---------------------------------------------+ + +----------------+--------+-------------------------------+---------------------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------+---------------------------------------------------+ + | grant_type | query | grant type to use | urn:ietf:params:oauth:grant-type:device_code | + +----------------+--------+-------------------------------+---------------------------------------------------+ + | client_id | query | The public client ID | 3f1DAj8z6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------+---------------------------------------------------+ + | device_code | query | device code | uW5xL4hr2tqwBPKL5d0JO9Fcc67gLqhJsNqYTSp | + +----------------+--------+-------------------------------+---------------------------------------------------+ :mod:`Supported grant types ` From d55a66dda5b9e515c53aa53197435a9dfa213173 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 01:40:54 +0200 Subject: [PATCH 091/178] add pytest-mock to extras_require testing --- setup.cfg | 1 + tests/Integration/Framework/Test_AuthDB.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 04f5010b25c..5ccf8dcd2d8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -95,6 +95,7 @@ testing = parameterized pytest pytest-cov + pytest-mock pycodestyle [options.entry_points] diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 226952f6627..b55972aa755 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -122,7 +122,7 @@ def test_keys(): result = db.getPrivateKey() assert result['OK'], result['Message'] - private_key = result['Value']['rsakey'] + private_key = result['Value'] assert private_key is RSAKey # Sign token From c2dcd3786d6b9c38cbc93df2d30636742743f61c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 13:44:21 +0200 Subject: [PATCH 092/178] fix tests --- environment-py3.yml | 2 +- .../Client/private/TornadoBaseClient.py | 10 +- .../Server/private/BaseRequestHandler.py | 70 ++- tests/Integration/Framework/Test_AuthDB.py | 411 +++++++++--------- .../Integration/Framework/Test_AuthServer.py | 203 ++++----- tests/Integration/Framework/Test_TokenDB.py | 147 ++++--- .../IdProvider/Test_IdProviderFactory.py | 151 +++---- 7 files changed, 524 insertions(+), 470 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index edcce165913..26e99aa12e2 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -62,7 +62,6 @@ dependencies: - pyparsing >=2.0.6 - pytest >=3.6 - pytest-cov >=2.2.0 - - pytest-mock - setuptools-scm - shellcheck - typer @@ -92,6 +91,7 @@ dependencies: - pyjwt >=2.1.0 - dominate - pip: + - pytest-mock # Prerelease of the required package for integration of OAuth2 - Authlib>=1.0.0.a2 # This is a fork of tornado with a patch to allow for configurable iostream diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 7800c2859a4..025f5d8b72c 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -49,8 +49,10 @@ from DIRAC.Core.Security import Locations from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile + +if six.PY3: + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile # TODO CHRIS: refactor all the messy `discover` methods @@ -240,7 +242,7 @@ def __discoverCredentialsToUse(self): self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken - if self.__useAccessToken: + if self.__useAccessToken and six.PY3: result = IdProviderFactory().getIdProvider('DIRACCLI') if not result['OK']: return result @@ -520,7 +522,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): auth = {'cert': Locations.getHostCertificateAndKeyLocation()} # Use access token? - elif self.__useAccessToken: + elif self.__useAccessToken and six.PY3: # Read token from token environ variable or from token file result = getLocalTokenDict() if not result['OK']: diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index eabe6fc9903..d819faeeea9 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -10,7 +10,9 @@ from io import open import os -import jwt +import six +if six.PY3: + import jwt import time import threading from datetime import datetime @@ -459,6 +461,39 @@ def __executeMethod(self, targetMethod, args): sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) raise e if isinstance(e, HTTPError) else HTTPError(http_client.INTERNAL_SERVER_ERROR, str(e)) + @gen.coroutine + def __executeMethodPy2(self, targetMethod, args): + """ + Execute the method called, this method is ran in an executor + We have several try except to catch the different problem which can occur + + - First, the method does not exist => Attribute error, return an error to client + - second, anything happend during execution => General Exception, send error to client + + .. warning:: + This method is called in an executor, and so cannot use methods like self.write + See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + + :param str targetMethod: name of the method to call + :param list args: target method arguments + + :return: Future + """ + + sLog.notice( + "Incoming request %s /%s: %s" % + (self.srv_getFormattedRemoteCredentials(), + self._serviceName, + self.method)) + + # Execute + try: + self.initializeRequest() + return targetMethod(*args) + except Exception as e: # pylint: disable=broad-except + sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) + raise e if isinstance(e, HTTPError) else HTTPError(http_client.INTERNAL_SERVER_ERROR, str(e)) + def _prepareExecutor(self, args): """ Preparation of necessary arguments for the `__executeMethod` method @@ -466,14 +501,17 @@ def _prepareExecutor(self, args): :return: executor, target method with arguments """ - return None, partial(self.__executeMethod, self._getMethod(), self._getMethodArgs(args)) + if six.PY3: + return None, partial(self.__executeMethod, self._getMethod(), self._getMethodArgs(args)) + return None, partial(self.__executeMethodPy2, self._getMethod(), self._getMethodArgs(args)) def _finishFuture(self, retVal): """ Handler Future result :param object retVal: tornado.concurrent.Future """ - self.result = retVal + # Wait result only if it's a Future object + self.result = retVal.result() if isinstance(retVal, Future) else retVal # Here it is safe to write back to the client, because we are not in a thread anymore @@ -485,7 +523,7 @@ def _finishFuture(self, retVal): finishFunc() # In case nothing is returned - elif retVal is None: + elif self.result is None: self.finish() # If set to true, do not JEncode the return of the RPC call @@ -494,16 +532,16 @@ def _finishFuture(self, retVal): elif self.get_argument('rawContent', default=False): # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt self.set_header("Content-Type", "application/octet-stream") - self.finish(retVal) + self.finish(self.result) # Return simple text or html - elif isinstance(retVal, string_types): - self.finish(retVal) + elif isinstance(self.result, string_types): + self.finish(self.result) # JSON else: self.set_header("Content-Type", "application/json") - self.finish(encode(retVal)) + self.finish(encode(self.result)) def on_finish(self): """ @@ -607,13 +645,15 @@ def _authzJWT(self, accessToken=None): # Read token without verification to get issuer self.log.debug('Read issuer from access token', accessToken) - issuer = jwt.decode(accessToken, leeway=300, options=dict(verify_signature=False, - verify_aud=False))['iss'].strip('/') - # Verify token - self.log.debug('Verify access token') - result = self._idp[issuer].verifyToken(accessToken) - self.log.debug('Search user group') - return self._idp[issuer].researchGroup(result['Value'], accessToken) if result['OK'] else result + if six.PY3: + issuer = jwt.decode(accessToken, leeway=300, options=dict(verify_signature=False, + verify_aud=False))['iss'].strip('/') + # Verify token + self.log.debug('Verify access token') + result = self._idp[issuer].verifyToken(accessToken) + self.log.debug('Search user group') + return self._idp[issuer].researchGroup(result['Value'], accessToken) if result['OK'] else result + return S_OK({}) def _authzVISITOR(self): """ Visitor access diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index b55972aa755..b6ad39e6304 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -5,210 +5,213 @@ from __future__ import division from __future__ import print_function -import time -from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey -from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads - -from DIRAC.Core.Base.Script import parseCommandLine -parseCommandLine() - -from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB - -db = AuthDB() - -payload = {'sub': 'user', - 'iss': 'issuer', - 'iat': int(time.time()), - 'exp': int(time.time()) + (12 * 3600), - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} - -DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_at=int(time.time()) + 3600) - -New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_in=int(time.time()) + 3600) - - -def test_RefreshToken(): - """ Try to revoke/save/get refresh tokens - """ - preset_jti = '123' - - # Remove refresh token - result = db.revokeRefreshToken(preset_jti) - assert result['OK'], result['Message'] - - # Store tokens - result = db.storeRefreshToken(DToken.copy(), preset_jti) - assert result['OK'], result['Message'] - assert result['Value']['jti'] == preset_jti - assert result['Value']['iat'] <= int(time.time()) - - result = db.storeRefreshToken(New_DToken.copy()) - assert result['OK'], result['Message'] - assert result['Value']['jti'] - assert result['Value']['iat'] <= int(time.time()) - - token_id = result['Value']['jti'] - issued_at = result['Value']['iat'] - - # Check token - result = db.getCredentialByRefreshToken(preset_jti) - assert result['OK'], result['Message'] - assert result['Value']['jti'] == preset_jti - assert result['Value']['access_token'] == DToken['access_token'] - assert result['Value']['refresh_token'] == DToken['refresh_token'] - - result = db.getCredentialByRefreshToken(token_id) - assert result['OK'], result['Message'] - assert result['Value']['jti'] == token_id - assert int(result['Value']['issued_at']) == issued_at - assert result['Value']['access_token'] == New_DToken['access_token'] - assert result['Value']['refresh_token'] == New_DToken['refresh_token'] - - # Check token after request - for jti in [token_id, preset_jti]: - result = db.getCredentialByRefreshToken(jti) +import six + +if six.PY3: + import time + from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey + from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads + + from DIRAC.Core.Base.Script import parseCommandLine + parseCommandLine() + + from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB + + db = AuthDB() + + payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + + DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + + New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_in=int(time.time()) + 3600) + + + def test_RefreshToken(): + """ Try to revoke/save/get refresh tokens + """ + preset_jti = '123' + + # Remove refresh token + result = db.revokeRefreshToken(preset_jti) + assert result['OK'], result['Message'] + + # Store tokens + result = db.storeRefreshToken(DToken.copy(), preset_jti) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == preset_jti + assert result['Value']['iat'] <= int(time.time()) + + result = db.storeRefreshToken(New_DToken.copy()) + assert result['OK'], result['Message'] + assert result['Value']['jti'] + assert result['Value']['iat'] <= int(time.time()) + + token_id = result['Value']['jti'] + issued_at = result['Value']['iat'] + + # Check token + result = db.getCredentialByRefreshToken(preset_jti) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == preset_jti + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + + result = db.getCredentialByRefreshToken(token_id) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == token_id + assert int(result['Value']['issued_at']) == issued_at + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] + + # Check token after request + for jti in [token_id, preset_jti]: + result = db.getCredentialByRefreshToken(jti) + assert result['OK'], result['Message'] + assert not result['Value'] + + # Renew tokens + result = db.storeRefreshToken(New_DToken.copy(), token_id) + assert result['OK'], result['Message'] + + # Revoke token + result = db.revokeRefreshToken(token_id) + assert result['OK'], result['Message'] + + # Check token + result = db.getCredentialByRefreshToken(token_id) assert result['OK'], result['Message'] assert not result['Value'] - # Renew tokens - result = db.storeRefreshToken(New_DToken.copy(), token_id) - assert result['OK'], result['Message'] - - # Revoke token - result = db.revokeRefreshToken(token_id) - assert result['OK'], result['Message'] - - # Check token - result = db.getCredentialByRefreshToken(token_id) - assert result['OK'], result['Message'] - assert not result['Value'] - - -def test_keys(): - """ Try to store/get/remove keys - """ - # JWS - jws = JsonWebSignature(algorithms=['RS256']) - code_payload = {'user_id': 'user', - 'scope': 'scope', - 'redirect_uri': 'redirect_uri', - 'client_id': 'client', - 'code_challenge': 'code_challenge'} - - # Token metadata - header = {'alg': 'RS256'} - payload = {'sub': 'user', - 'iss': 'issuer', - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} - - # Remove all keys - result = db.removeKeys() - assert result['OK'], result['Message'] - - # Check active keys - result = db.getActiveKeys() - assert result['OK'], result['Message'] - assert result['Value'] == [] - - # Create new one - result = db.getPrivateKey() - assert result['OK'], result['Message'] - - private_key = result['Value'] - assert private_key is RSAKey - - # Sign token - header['kid'] = private_key.thumbprint() - - # Find key by KID - result = db.getPrivateKey(header['kid']) - assert result['OK'], result['Message'] - assert result['Value'].as_dict(True) == private_key.as_dict(True) - - # Sign token - token = jwt.encode(header, payload, private_key) - # Sign auth code - code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) - - # Get public key set - result = db.getKeySet() - keyset = result['Value'] - assert result['OK'], result['Message'] - assert bool([key for key in keyset.as_dict(True)['keys'] if key["kid"] == header['kid']]) - - # Read token - _payload = jwt.decode(token, JsonWebKey.import_key_set(keyset.as_dict())) - assert _payload == payload - # Read auth code - data = jws.deserialize_compact(code, keyset.keys[0]) - _code_payload = json_loads(urlsafe_b64decode(data['payload'])) - assert _code_payload == code_payload - - -def test_Sessions(): - """ Try to store/get/remove Sessions - """ - # Example of the new session metadata - sData1 = {'client_id': 'DIRAC_CLI', - 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'expires_in': 1800, - 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'interval': 5, - 'scope': 'g:my_group', - 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', - 'user_code': 'MDKP-MXMF', - 'verification_uri': 'https://domain.com/DIRAC/auth/device', - 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF'} - - # Example of the updated session - sData2 = {'client_id': 'DIRAC_CLI', - 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'expires_in': 1800, - 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'interval': 5, - 'scope': 'g:my_group', - 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', - 'user_code': 'MDKP-MXMF', - 'verification_uri': 'https://domain.com/DIRAC/auth/device', - 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF', - 'user_id': 'username'} - - # Remove old session - db.removeSession(sData1['id']) - - # Add session - result = db.addSession(sData1) - assert result['OK'], result['Message'] - - # Get session - result = db.getSessionByUserCode(sData1['user_code']) - assert result['OK'], result['Message'] - assert result['Value']['device_code'] == sData1['device_code'] - assert result['Value'].get('user_id') != sData2['user_id'] - - # Update session - result = db.updateSession(sData2, sData1['id']) - assert result['OK'], result['Message'] - - # Get session - result = db.getSession(sData2['id']) - assert result['OK'], result['Message'] - assert result['Value']['device_code'] == sData2['device_code'] - assert result['Value']['user_id'] == sData2['user_id'] - - # Remove session - result = db.removeSession(sData2['id']) - assert result['OK'], result['Message'] - - # Make sure that the session is absent - result = db.getSession(sData2['id']) - assert result['OK'], result['Message'] - assert not result['Value'] + + def test_keys(): + """ Try to store/get/remove keys + """ + # JWS + jws = JsonWebSignature(algorithms=['RS256']) + code_payload = {'user_id': 'user', + 'scope': 'scope', + 'redirect_uri': 'redirect_uri', + 'client_id': 'client', + 'code_challenge': 'code_challenge'} + + # Token metadata + header = {'alg': 'RS256'} + payload = {'sub': 'user', + 'iss': 'issuer', + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + + # Remove all keys + result = db.removeKeys() + assert result['OK'], result['Message'] + + # Check active keys + result = db.getActiveKeys() + assert result['OK'], result['Message'] + assert result['Value'] == [] + + # Create new one + result = db.getPrivateKey() + assert result['OK'], result['Message'] + + private_key = result['Value'] + assert private_key is RSAKey + + # Sign token + header['kid'] = private_key.thumbprint() + + # Find key by KID + result = db.getPrivateKey(header['kid']) + assert result['OK'], result['Message'] + assert result['Value'].as_dict(True) == private_key.as_dict(True) + + # Sign token + token = jwt.encode(header, payload, private_key) + # Sign auth code + code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) + + # Get public key set + result = db.getKeySet() + keyset = result['Value'] + assert result['OK'], result['Message'] + assert bool([key for key in keyset.as_dict(True)['keys'] if key["kid"] == header['kid']]) + + # Read token + _payload = jwt.decode(token, JsonWebKey.import_key_set(keyset.as_dict())) + assert _payload == payload + # Read auth code + data = jws.deserialize_compact(code, keyset.keys[0]) + _code_payload = json_loads(urlsafe_b64decode(data['payload'])) + assert _code_payload == code_payload + + + def test_Sessions(): + """ Try to store/get/remove Sessions + """ + # Example of the new session metadata + sData1 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/DIRAC/auth/device', + 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF'} + + # Example of the updated session + sData2 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/DIRAC/auth/device', + 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF', + 'user_id': 'username'} + + # Remove old session + db.removeSession(sData1['id']) + + # Add session + result = db.addSession(sData1) + assert result['OK'], result['Message'] + + # Get session + result = db.getSessionByUserCode(sData1['user_code']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData1['device_code'] + assert result['Value'].get('user_id') != sData2['user_id'] + + # Update session + result = db.updateSession(sData2, sData1['id']) + assert result['OK'], result['Message'] + + # Get session + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData2['device_code'] + assert result['Value']['user_id'] == sData2['user_id'] + + # Remove session + result = db.removeSession(sData2['id']) + assert result['OK'], result['Message'] + + # Make sure that the session is absent + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert not result['Value'] diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index 3753477825c..e087a8b8281 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -5,111 +5,114 @@ from __future__ import division from __future__ import print_function -import time -import pytest -from mock import MagicMock -from authlib.oauth2.base import OAuth2Error +import six -from diraccfg import CFG +if six.PY3: + import time + import pytest + from mock import MagicMock + from authlib.oauth2.base import OAuth2Error -from DIRAC.Core.Base.Script import parseCommandLine -parseCommandLine() + from diraccfg import CFG -import DIRAC -from DIRAC import S_OK, S_ERROR, gConfig -from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer + from DIRAC.Core.Base.Script import parseCommandLine + parseCommandLine() -cfg = CFG() -cfg.loadFromBuffer(""" -DIRAC -{ - Security + import DIRAC + from DIRAC import S_OK, S_ERROR, gConfig + from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer + + cfg = CFG() + cfg.loadFromBuffer(""" + DIRAC { - Authorization + Security { - issuer = https://issuer.url/ + Authorization + { + issuer = https://issuer.url/ + } } } -} -""") -gConfig.loadCFG(cfg) - - -class Proxy(object): - def dumpAllToString(self): - return S_OK('proxy') - - -class ProxyManagerClient(object): - def downloadProxy(self, *args, **kwargs): - return S_OK(Proxy()) - - -class TokenManagerClient(object): - def getToken(self, *args, **kwargs): - return S_OK({'access_token': 'token', 'refresh_token': 'token'}) - - -mockgetIdPForGroup = MagicMock(return_value=S_OK('IdP')) -mockgetDNForUsername = MagicMock(return_value=S_OK('DN')) -mockgetUsernameForDN = MagicMock(return_value=S_OK('user')) -mockisDownloadablePersonalProxy = MagicMock(return_value=True) - - -@pytest.fixture -def server(mocker): - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", - side_effect=mockgetIdPForGroup) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", - side_effect=mockgetDNForUsername) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", - side_effect=mockgetUsernameForDN) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", - side_effect=ProxyManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", - side_effect=TokenManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", - side_effect=mockisDownloadablePersonalProxy) - return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() - - -def test_metadata(server): - """ Check metadata - """ - assert server.metadata.get('issuer') - - -def test_queryClient(server): - """ Try to search some default client - """ - assert not server.query_client('not_exist_client') - assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' - - -@pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ - ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), - ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), -]) -def test_generateToken(server, client, grant, user, scope, expires_in, refresh_token, instance, result): - """ Generate tokens - """ - cli = server.query_client(client) - try: - assert server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result - except OAuth2Error as e: - assert False, str(e) - - -def test_writeReadRefreshToken(server): - """ Try to search some default client - """ - result = server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) - assert result['OK'], result['Message'] - token = result['Value'] - assert token.get('access_token') == 'token' - assert token.get('refresh_token') != 'token' - - result = server.readToken(token['refresh_token']) - assert result['OK'], result['Message'] - assert result['Value'].get('jti') - assert result['Value'].get('iat') + """) + gConfig.loadCFG(cfg) + + + class Proxy(object): + def dumpAllToString(self): + return S_OK('proxy') + + + class ProxyManagerClient(object): + def downloadProxy(self, *args, **kwargs): + return S_OK(Proxy()) + + + class TokenManagerClient(object): + def getToken(self, *args, **kwargs): + return S_OK({'access_token': 'token', 'refresh_token': 'token'}) + + + mockgetIdPForGroup = MagicMock(return_value=S_OK('IdP')) + mockgetDNForUsername = MagicMock(return_value=S_OK('DN')) + mockgetUsernameForDN = MagicMock(return_value=S_OK('user')) + mockisDownloadablePersonalProxy = MagicMock(return_value=True) + + + @pytest.fixture + def server(mocker): + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", + side_effect=mockgetIdPForGroup) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", + side_effect=mockgetDNForUsername) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", + side_effect=mockgetUsernameForDN) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", + side_effect=ProxyManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", + side_effect=TokenManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", + side_effect=mockisDownloadablePersonalProxy) + return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() + + + def test_metadata(server): + """ Check metadata + """ + assert server.metadata.get('issuer') + + + def test_queryClient(server): + """ Try to search some default client + """ + assert not server.query_client('not_exist_client') + assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' + + + @pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ + ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), + ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), + ]) + def test_generateToken(server, client, grant, user, scope, expires_in, refresh_token, instance, result): + """ Generate tokens + """ + cli = server.query_client(client) + try: + assert server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result + except OAuth2Error as e: + assert False, str(e) + + + def test_writeReadRefreshToken(server): + """ Try to search some default client + """ + result = server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) + assert result['OK'], result['Message'] + token = result['Value'] + assert token.get('access_token') == 'token' + assert token.get('refresh_token') != 'token' + + result = server.readToken(token['refresh_token']) + assert result['OK'], result['Message'] + assert result['Value'].get('jti') + assert result['Value'].get('iat') diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py index 3c3f939adb6..4f058a15a75 100644 --- a/tests/Integration/Framework/Test_TokenDB.py +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -5,75 +5,78 @@ from __future__ import division from __future__ import print_function -import time -from authlib.jose import jwt - -from DIRAC.Core.Base.Script import parseCommandLine -parseCommandLine() - -from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB - -db = TokenDB() - -payload = {'sub': 'user', - 'iss': 'issuer', - 'iat': int(time.time()), - 'exp': int(time.time()) + (12 * 3600), - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} - -exp_payload = payload.copy() -exp_payload['iat'] = int(time.time()) - 10 -exp_payload['exp'] = int(time.time()) - 10 - -DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_at=int(time.time()) + 3600) - -New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - issued_at=int(time.time()), - expires_in=int(time.time()) + 3600) - -Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), - expires_at=int(time.time()) - 10, - rt_expires_at=int(time.time()) - 10) - - -def test_Token(): - """ Try to revoke/save/get tokens - """ - # Remove all tokens - result = db.removeToken(user_id=123) - assert result['OK'], result['Message'] - - # Store tokens - result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) - assert result['OK'], result['Message'] - assert result['Value'] == [] - - # Expired token - result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) - assert not result['OK'] - - # Check token - result = db.getTokenForUserProvider(userID=123, provider='DIRAC') - assert result['OK'], result['Message'] - assert result['Value']['access_token'] == DToken['access_token'] - assert result['Value']['refresh_token'] == DToken['refresh_token'] - - # Store new tokens - result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) - assert result['OK'], result['Message'] - # Must return old tokens - assert len(result['Value']) == 1 - assert result['Value'][0]['access_token'] == DToken['access_token'] - assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] - - # Check token - result = db.getTokenForUserProvider(userID=123, provider='DIRAC') - assert result['OK'], result['Message'] - assert result['Value']['access_token'] == New_DToken['access_token'] - assert result['Value']['refresh_token'] == New_DToken['refresh_token'] +import six + +if six.PY3: + import time + from authlib.jose import jwt + + from DIRAC.Core.Base.Script import parseCommandLine + parseCommandLine() + + from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB + + db = TokenDB() + + payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + + exp_payload = payload.copy() + exp_payload['iat'] = int(time.time()) - 10 + exp_payload['exp'] = int(time.time()) - 10 + + DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + + New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + issued_at=int(time.time()), + expires_in=int(time.time()) + 3600) + + Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + expires_at=int(time.time()) - 10, + rt_expires_at=int(time.time()) - 10) + + + def test_Token(): + """ Try to revoke/save/get tokens + """ + # Remove all tokens + result = db.removeToken(user_id=123) + assert result['OK'], result['Message'] + + # Store tokens + result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + assert result['Value'] == [] + + # Expired token + result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert not result['OK'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + + # Store new tokens + result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + # Must return old tokens + assert len(result['Value']) == 1 + assert result['Value'][0]['access_token'] == DToken['access_token'] + assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index a829a1d7a63..688084a08e2 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -4,100 +4,103 @@ __RCSID__ = "$Id$" -import time -import unittest -from authlib.jose import jwt +import six -from diraccfg import CFG +if six.PY3: + import time + import unittest + from authlib.jose import jwt -import DIRAC -from DIRAC import gConfig -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS -from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata + from diraccfg import CFG -cfg = CFG() -cfg.loadFromBuffer(""" -DIRAC -{ - Security + import DIRAC + from DIRAC import gConfig + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS + from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata + + cfg = CFG() + cfg.loadFromBuffer(""" + DIRAC { - Authorization + Security { - issuer = https://issuer.url/ - Clients + Authorization { - DIRACWeb + issuer = https://issuer.url/ + Clients { - client_id = client_identificator - client_secret = client_secret_key - redirect_uri = https://redirect.url/ + DIRACWeb + { + client_id = client_identificator + client_secret = client_secret_key + redirect_uri = https://redirect.url/ + } } } } } -} -Resources -{ - IdProviders + Resources { - SomeIdP + IdProviders { - ProviderType = OAuth2 - issuer = https://idp.url/ - client_id = IdP_client_id - client_secret = IdP_client_secret - redirect_uri = https://dirac/redirect - jwks_uri = https://idp.url/jwk - scope = openid+profile+offline_access+eduperson_entitlement + SomeIdP + { + ProviderType = OAuth2 + issuer = https://idp.url/ + client_id = IdP_client_id + client_secret = IdP_client_secret + redirect_uri = https://dirac/redirect + jwks_uri = https://idp.url/jwk + scope = openid+profile+offline_access+eduperson_entitlement + } } } -} -""") -gConfig.loadCFG(cfg) + """) + gConfig.loadCFG(cfg) -idps = IdProviderFactory() + idps = IdProviderFactory() -def test_getDIRACClients(): - """ Try to load default DIRAC authorization client - """ - params = collectMetadata() + def test_getDIRACClients(): + """ Try to load default DIRAC authorization client + """ + params = collectMetadata() - # Try to get DIRAC client authorization settings - result = idps.getIdProvider('DIRACCLI', **params) - assert result['OK'], result['Message'] - assert result['Value'].issuer == 'https://issuer.url/' - assert result['Value'].client_id == DEFAULT_CLIENTS['DIRACCLI']['client_id'] - assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' + # Try to get DIRAC client authorization settings + result = idps.getIdProvider('DIRACCLI', **params) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == DEFAULT_CLIENTS['DIRACCLI']['client_id'] + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' - # Try to get DIRAC client authorization settings for Web portal - result = idps.getIdProvider('DIRACWeb', **params) - assert result['OK'], result['Message'] - assert result['Value'].issuer == 'https://issuer.url/' - assert result['Value'].client_id == 'client_identificator' - assert result['Value'].client_secret == 'client_secret_key' - assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' + # Try to get DIRAC client authorization settings for Web portal + result = idps.getIdProvider('DIRACWeb', **params) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == 'client_identificator' + assert result['Value'].client_secret == 'client_secret_key' + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' -def test_getIdPClients(): - """ Try to load external identity provider settings - """ - # Try to get identity provider by name - result = idps.getIdProvider('SomeIdP', jwks='my_jwks') - assert result['OK'], result['Message'] - assert result['Value'].jwks == 'my_jwks' - assert result['Value'].issuer == 'https://idp.url/' - assert result['Value'].client_id == 'IdP_client_id' - assert result['Value'].client_secret == 'IdP_client_secret' - assert result['Value'].get_metadata('jwks_uri') == 'https://idp.url/jwk' + def test_getIdPClients(): + """ Try to load external identity provider settings + """ + # Try to get identity provider by name + result = idps.getIdProvider('SomeIdP', jwks='my_jwks') + assert result['OK'], result['Message'] + assert result['Value'].jwks == 'my_jwks' + assert result['Value'].issuer == 'https://idp.url/' + assert result['Value'].client_id == 'IdP_client_id' + assert result['Value'].client_secret == 'IdP_client_secret' + assert result['Value'].get_metadata('jwks_uri') == 'https://idp.url/jwk' - # Try to get identity provider for token issued by it - result = idps.getIdProviderForToken(jwt.encode({'alg': 'HS256'}, dict( - sub='user', - iss=result['Value'].issuer, - iat=int(time.time()), - exp=int(time.time()) + (12 * 3600), - ), "secret").decode('utf-8')) - assert result['OK'], result['Message'] - assert result['Value'].issuer == 'https://idp.url/' + # Try to get identity provider for token issued by it + result = idps.getIdProviderForToken(jwt.encode({'alg': 'HS256'}, dict( + sub='user', + iss=result['Value'].issuer, + iat=int(time.time()), + exp=int(time.time()) + (12 * 3600), + ), "secret").decode('utf-8')) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://idp.url/' From 7ea2d3614edc79d28da15676e11698b6dcc1c9f8 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 16:15:25 +0200 Subject: [PATCH 093/178] fix tests --- tests/Integration/Framework/Test_AuthDB.py | 27 +++++++++---------- .../Integration/Framework/Test_AuthServer.py | 9 ------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index b6ad39e6304..fe5ec3d1b9e 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -20,12 +20,12 @@ db = AuthDB() payload = {'sub': 'user', - 'iss': 'issuer', - 'iat': int(time.time()), - 'exp': int(time.time()) + (12 * 3600), - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), @@ -35,7 +35,6 @@ refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), expires_in=int(time.time()) + 3600) - def test_RefreshToken(): """ Try to revoke/save/get refresh tokens """ @@ -92,7 +91,6 @@ def test_RefreshToken(): assert result['OK'], result['Message'] assert not result['Value'] - def test_keys(): """ Try to store/get/remove keys """ @@ -100,17 +98,17 @@ def test_keys(): jws = JsonWebSignature(algorithms=['RS256']) code_payload = {'user_id': 'user', 'scope': 'scope', - 'redirect_uri': 'redirect_uri', 'client_id': 'client', + 'redirect_uri': 'redirect_uri', 'code_challenge': 'code_challenge'} # Token metadata header = {'alg': 'RS256'} payload = {'sub': 'user', - 'iss': 'issuer', - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} + 'iss': 'issuer', + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} # Remove all keys result = db.removeKeys() @@ -126,7 +124,7 @@ def test_keys(): assert result['OK'], result['Message'] private_key = result['Value'] - assert private_key is RSAKey + assert isinstance(private_key, RSAKey) # Sign token header['kid'] = private_key.thumbprint() @@ -155,7 +153,6 @@ def test_keys(): _code_payload = json_loads(urlsafe_b64decode(data['payload'])) assert _code_payload == code_payload - def test_Sessions(): """ Try to store/get/remove Sessions """ diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index e087a8b8281..c44b40c226e 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -37,28 +37,23 @@ """) gConfig.loadCFG(cfg) - class Proxy(object): def dumpAllToString(self): return S_OK('proxy') - class ProxyManagerClient(object): def downloadProxy(self, *args, **kwargs): return S_OK(Proxy()) - class TokenManagerClient(object): def getToken(self, *args, **kwargs): return S_OK({'access_token': 'token', 'refresh_token': 'token'}) - mockgetIdPForGroup = MagicMock(return_value=S_OK('IdP')) mockgetDNForUsername = MagicMock(return_value=S_OK('DN')) mockgetUsernameForDN = MagicMock(return_value=S_OK('user')) mockisDownloadablePersonalProxy = MagicMock(return_value=True) - @pytest.fixture def server(mocker): mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", @@ -75,20 +70,17 @@ def server(mocker): side_effect=mockisDownloadablePersonalProxy) return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() - def test_metadata(server): """ Check metadata """ assert server.metadata.get('issuer') - def test_queryClient(server): """ Try to search some default client """ assert not server.query_client('not_exist_client') assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' - @pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), @@ -102,7 +94,6 @@ def test_generateToken(server, client, grant, user, scope, expires_in, refresh_t except OAuth2Error as e: assert False, str(e) - def test_writeReadRefreshToken(server): """ Try to search some default client """ From 413af2301a1c1e5f6d968ae03fbe524afca07a9b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 24 Jun 2021 16:35:01 +0200 Subject: [PATCH 094/178] fix tests --- environment-py3.yml | 2 +- setup.cfg | 1 - .../Server/private/BaseRequestHandler.py | 56 ++++++------------- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 6 +- .../FrameworkSystem/scripts/dirac_login.py | 2 +- .../Integration/Framework/Test_AuthServer.py | 6 +- .../IdProvider/Test_IdProviderFactory.py | 2 - 7 files changed, 26 insertions(+), 49 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index 26e99aa12e2..edcce165913 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -62,6 +62,7 @@ dependencies: - pyparsing >=2.0.6 - pytest >=3.6 - pytest-cov >=2.2.0 + - pytest-mock - setuptools-scm - shellcheck - typer @@ -91,7 +92,6 @@ dependencies: - pyjwt >=2.1.0 - dominate - pip: - - pytest-mock # Prerelease of the required package for integration of OAuth2 - Authlib>=1.0.0.a2 # This is a fork of tornado with a patch to allow for configurable iostream diff --git a/setup.cfg b/setup.cfg index 5ccf8dcd2d8..04f5010b25c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -95,7 +95,6 @@ testing = parameterized pytest pytest-cov - pytest-mock pycodestyle [options.entry_points] diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index d819faeeea9..88af2e58bd4 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -447,12 +447,8 @@ def __executeMethod(self, targetMethod, args): :return: Future """ - sLog.notice( - "Incoming request %s /%s: %s" % - (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - self.method)) - + sLog.notice("Incoming request %s /%s: %s" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, self.method)) # Execute try: self.initializeRequest() @@ -463,29 +459,13 @@ def __executeMethod(self, targetMethod, args): @gen.coroutine def __executeMethodPy2(self, targetMethod, args): - """ - Execute the method called, this method is ran in an executor - We have several try except to catch the different problem which can occur + """ The only difference from __executeMethod is the presence of a coroutine decorator - - First, the method does not exist => Attribute error, return an error to client - - second, anything happend during execution => General Exception, send error to client - - .. warning:: - This method is called in an executor, and so cannot use methods like self.write - See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - - :param str targetMethod: name of the method to call - :param list args: target method arguments - - :return: Future + :return: Future """ - sLog.notice( - "Incoming request %s /%s: %s" % - (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - self.method)) - + sLog.notice("Incoming request %s /%s: %s" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, self.method)) # Execute try: self.initializeRequest() @@ -501,9 +481,9 @@ def _prepareExecutor(self, args): :return: executor, target method with arguments """ - if six.PY3: - return None, partial(self.__executeMethod, self._getMethod(), self._getMethodArgs(args)) - return None, partial(self.__executeMethodPy2, self._getMethod(), self._getMethodArgs(args)) + if six.PY2: + return None, partial(self.__executeMethodPy2, self._getMethod(), self._getMethodArgs(args)) + return None, partial(self.__executeMethod, self._getMethod(), self._getMethodArgs(args)) def _finishFuture(self, retVal): """ Handler Future result @@ -645,15 +625,15 @@ def _authzJWT(self, accessToken=None): # Read token without verification to get issuer self.log.debug('Read issuer from access token', accessToken) - if six.PY3: - issuer = jwt.decode(accessToken, leeway=300, options=dict(verify_signature=False, - verify_aud=False))['iss'].strip('/') - # Verify token - self.log.debug('Verify access token') - result = self._idp[issuer].verifyToken(accessToken) - self.log.debug('Search user group') - return self._idp[issuer].researchGroup(result['Value'], accessToken) if result['OK'] else result - return S_OK({}) + if six.PY2: + return S_OK({}) + issuer = jwt.decode(accessToken, leeway=300, options=dict(verify_signature=False, + verify_aud=False))['iss'].strip('/') + # Verify token + self.log.debug('Verify access token') + result = self._idp[issuer].verifyToken(accessToken) + self.log.debug('Search user group') + return self._idp[issuer].researchGroup(result['Value'], accessToken) if result['OK'] else result def _authzVISITOR(self): """ Visitor access diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 7f2ee02e891..ae376acf2f9 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -155,7 +155,7 @@ def web_index(self, instance): { "registration_endpoint": "https://domain.com/DIRAC/auth/register", "userinfo_endpoint": "https://domain.com/DIRAC/auth/userinfo", - "s_uri": "https://domain.com/DIRAC/auth/", + "jwks_uri": "https://domain.com/DIRAC/auth/jwk", "code_challenge_methods_supported": [ "S256" ], @@ -179,12 +179,12 @@ def web_index(self, instance): if self.request.method == "GET": return self.server.metadata - def web_(self): + def web_jwk(self): """ JWKs endpoint Request example:: - GET LOCATION/ + GET LOCATION/jwk Response:: diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index e9a1ec836ee..ceff0366ac1 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -4,7 +4,7 @@ # Author : Andrii Lytovchenko ######################################################################## """ -Login to DIRAC. +Login to DIRAC. For python 3 only. Example: $ dirac-login -g dirac_user diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index c44b40c226e..67d47eace3e 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -6,11 +6,11 @@ from __future__ import print_function import six +import time +import pytest +from mock import MagicMock if six.PY3: - import time - import pytest - from mock import MagicMock from authlib.oauth2.base import OAuth2Error from diraccfg import CFG diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index 688084a08e2..37d9c948ea0 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -61,7 +61,6 @@ idps = IdProviderFactory() - def test_getDIRACClients(): """ Try to load default DIRAC authorization client """ @@ -82,7 +81,6 @@ def test_getDIRACClients(): assert result['Value'].client_secret == 'client_secret_key' assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' - def test_getIdPClients(): """ Try to load external identity provider settings """ From ae99399827ac497ee87dd9494b013a9cc8bebe23 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 28 Jun 2021 19:45:05 +0200 Subject: [PATCH 095/178] compatibility with py2 --- environment-py3.yml | 5 +- environment.yml | 4 + requirements.txt | 2 +- setup.cfg | 4 +- .../Client/private/TornadoBaseClient.py | 10 +- .../Server/private/BaseRequestHandler.py | 8 +- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 16 +- .../private/authorization/AuthServer.py | 13 +- .../authorization/grants/DeviceFlow.py | 4 +- tests/Integration/Framework/Test_AuthDB.py | 392 +++++++++--------- .../Integration/Framework/Test_AuthServer.py | 168 ++++---- tests/Integration/Framework/Test_TokenDB.py | 146 ++++--- .../IdProvider/Test_IdProviderFactory.py | 153 ++++--- tests/Jenkins/dirac_ci.sh | 6 +- tests/Jenkins/utilities.sh | 3 + 15 files changed, 477 insertions(+), 457 deletions(-) diff --git a/environment-py3.yml b/environment-py3.yml index edcce165913..ebe05a23909 100644 --- a/environment-py3.yml +++ b/environment-py3.yml @@ -88,12 +88,11 @@ dependencies: #- tornado >=5.0.0,<6.0.0 - typing >=3.6.6 - pyyaml - # OAuth2 - - pyjwt >=2.1.0 - - dominate - pip: # Prerelease of the required package for integration of OAuth2 - Authlib>=1.0.0.a2 + - dominate + - pyjwt # This is a fork of tornado with a patch to allow for configurable iostream # It should eventually be part of DIRACGrid - git+https://github.com/DIRACGrid/tornado.git@iostreamConfigurable diff --git a/environment.yml b/environment.yml index a0cb8b50385..6c0255a83d3 100644 --- a/environment.yml +++ b/environment.yml @@ -78,6 +78,10 @@ dependencies: - dominate - pip: - diraccfg + # OAuth2 + - dominate + - authlib + - pyjwt # This is a fork of tornado with a patch to allow for configurable iostream - git+https://github.com/DIRACGrid/tornado.git@iostreamConfigurable # This is an extension of Tornado to use M2Crypto diff --git a/requirements.txt b/requirements.txt index 34d658b43ec..c8018f8ec45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,6 @@ ldap3 # setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 setuptools_scm<6.0 # OAuth2 -Authlib>=1.0.0.a2 +Authlib pyjwt dominate \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 04f5010b25c..e32108285b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,8 +55,8 @@ install_requires = six sqlalchemy subprocess32 - Authlib>=1.0.0.a2 - pyjwt >=2.1.0 + Authlib >=1.0.0.a2 + pyjwt dominate zip_safe = False include_package_data = True diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 025f5d8b72c..7800c2859a4 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -49,10 +49,8 @@ from DIRAC.Core.Security import Locations from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode - -if six.PY3: - from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory - from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile # TODO CHRIS: refactor all the messy `discover` methods @@ -242,7 +240,7 @@ def __discoverCredentialsToUse(self): self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken - if self.__useAccessToken and six.PY3: + if self.__useAccessToken: result = IdProviderFactory().getIdProvider('DIRACCLI') if not result['OK']: return result @@ -522,7 +520,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): auth = {'cert': Locations.getHostCertificateAndKeyLocation()} # Use access token? - elif self.__useAccessToken and six.PY3: + elif self.__useAccessToken: # Read token from token environ variable or from token file result = getLocalTokenDict() if not result['OK']: diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 88af2e58bd4..11ef9b9b784 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -10,9 +10,7 @@ from io import open import os -import six -if six.PY3: - import jwt +import jwt import time import threading from datetime import datetime @@ -36,8 +34,8 @@ from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory sLog = gLogger.getSubLogger(__name__.split('.')[-1]) @@ -625,8 +623,6 @@ def _authzJWT(self, accessToken=None): # Read token without verification to get issuer self.log.debug('Read issuer from access token', accessToken) - if six.PY2: - return S_OK({}) issuer = jwt.decode(accessToken, leeway=300, options=dict(verify_signature=False, verify_aud=False))['iss'].strip('/') # Verify token diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 8bcbd110633..3aa71f7c7b1 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -4,18 +4,17 @@ from __future__ import division from __future__ import print_function -import jwt import json import time import pprint -import M2Crypto from sqlalchemy import Column, Integer, Text, String from sqlalchemy.orm import scoped_session from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound from sqlalchemy.ext.declarative import declarative_base -from authlib.jose import KeySet, RSAKey +import authlib +from authlib.jose import KeySet, JsonWebKey from authlib.common.security import generate_token from DIRAC import S_OK, S_ERROR @@ -166,8 +165,9 @@ def generateRSAKeys(self): :return: S_OK/S_ERROR """ - key = RSAKey.generate_key(key_size=1024, is_private=True) - keyDict = dict(key=json.dumps(key.as_dict(True)), kid=key.thumbprint(), expires_at=time.time() + (30 * 24 * 3600)) + key = JsonWebKey.generate_key('RSA', 1024, is_private=True) + keyDict = dict(key=json.dumps(key.as_dict(*([True] if authlib.version >= '1.0.0' else []))), + kid=key.thumbprint(), expires_at=time.time() + (30 * 24 * 3600)) session = self.session() try: session.add(JWK(**keyDict)) @@ -187,7 +187,7 @@ def getKeySet(self): result['Value'] = [result['Value']] if not result['OK']: return result - return S_OK(KeySet([RSAKey.import_key(json.loads(key['key'])) for key in result['Value']])) + return S_OK(KeySet([JsonWebKey.import_key(json.loads(key['key'])) for key in result['Value']])) def getPrivateKey(self, kid=None): """ Get private key @@ -202,7 +202,7 @@ def getPrivateKey(self, kid=None): jwks = result['Value'] if kid: strkey = jwks[0]['key'] - return S_OK(RSAKey.import_key(json.loads(jwks[0]['key']))) + return S_OK(JsonWebKey.import_key(json.loads(jwks[0]['key']))) newer = {} for jwk in jwks: if int(jwk['expires_at']) > int(newer.get('expires_at', time.time() + (24 * 3600))): @@ -212,7 +212,7 @@ def getPrivateKey(self, kid=None): if not result['OK']: return result newer = result['Value'] - return S_OK(RSAKey.import_key(json.loads(newer['key']))) + return S_OK(JsonWebKey.import_key(json.loads(newer['key']))) def getActiveKeys(self, kid=None): """ Get active keys diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 06fc7134738..9349cb7a6a8 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -11,6 +11,7 @@ from dominate import document, tags as dom from tornado.template import Template +import authlib from authlib.jose import JsonWebKey, jwt from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer from authlib.oauth2.base import OAuth2Error @@ -88,7 +89,10 @@ def __init__(self): self.tokenCli = TokenManagerClient() self.metadata = collectMetadata() self.metadata.validate() - _AuthorizationServer.__init__(self, scopes_supported=self.metadata['scopes_supported']) + _AuthorizationServer.__init__(self, **(dict(scopes_supported=self.metadata['scopes_supported']) + if authlib.version >= '1.0.0' else + dict(query_client=self.query_client, + save_token=None, metadata=self.metadata))) # Skip authlib method save_token and send_signal self.save_token = lambda x, y: None self.send_signal = lambda *x, **y: None @@ -320,15 +324,12 @@ def validate_requested_scope(self, scope, state=None): extended_scope = list_to_scope([re.sub(r':.*$', ':', s) for s in scope_to_list(scope or '')]) super(AuthServer, self).validate_requested_scope(extended_scope, state) - def handle_error_response(self, request, error): - return self.handle_response(*error(self.get_error_uri(request, error)), error=True) - - def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, error=None, **actions): + def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, **actions): self.log.debug('Handle authorization response with %s status code:' % status_code, payload) self.log.debug('Headers:', headers) if newSession: self.log.debug('newSession:', newSession) - return S_OK([[status_code, headers, payload, newSession, error], actions]) + return S_OK([[status_code, headers, payload, newSession, 'error' in payload], actions]) def create_authorization_response(self, response, username): result = super(AuthServer, self).create_authorization_response(response, username) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index 92730711c8c..80da66d535f 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -123,9 +123,9 @@ def query_user_grant(self, user_code): result = self.server.db.getSessionByUserCode(user_code) if not result['OK']: raise OAuth2Error('Cannot found authorization session', result['Message']) - return (result['Value']['user_id'], True) if result['Value'].get('username') != "None" else None + return (result['Value']['user_id'], True) if result['Value'].get('username', "None") != "None" else ('', False) - def should_slow_down(self, credential): + def should_slow_down(self, *args): """ The authorization request is still pending and polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests. """ diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index fe5ec3d1b9e..897166f874c 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -6,209 +6,217 @@ from __future__ import print_function import six +import time +import authlib +from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey +from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB + +db = AuthDB() + +payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + +DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + +New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_in=int(time.time()) + 3600) + + +def test_RefreshToken(): + """ Try to revoke/save/get refresh tokens + """ + preset_jti = '123' + + # Remove refresh token + result = db.revokeRefreshToken(preset_jti) + assert result['OK'], result['Message'] + + # Store tokens + result = db.storeRefreshToken(DToken.copy(), preset_jti) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == preset_jti + assert result['Value']['iat'] <= int(time.time()) + + result = db.storeRefreshToken(New_DToken.copy()) + assert result['OK'], result['Message'] + assert result['Value']['jti'] + assert result['Value']['iat'] <= int(time.time()) + + token_id = result['Value']['jti'] + issued_at = result['Value']['iat'] + + # Check token + result = db.getCredentialByRefreshToken(preset_jti) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == preset_jti + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + + result = db.getCredentialByRefreshToken(token_id) + assert result['OK'], result['Message'] + assert result['Value']['jti'] == token_id + assert int(result['Value']['issued_at']) == issued_at + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] + + # Check token after request + for jti in [token_id, preset_jti]: + result = db.getCredentialByRefreshToken(jti) + assert result['OK'], result['Message'] + assert not result['Value'] -if six.PY3: - import time - from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey - from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads - - from DIRAC.Core.Base.Script import parseCommandLine - parseCommandLine() - - from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB - - db = AuthDB() - + # Renew tokens + result = db.storeRefreshToken(New_DToken.copy(), token_id) + assert result['OK'], result['Message'] + + # Revoke token + result = db.revokeRefreshToken(token_id) + assert result['OK'], result['Message'] + + # Check token + result = db.getCredentialByRefreshToken(token_id) + assert result['OK'], result['Message'] + assert not result['Value'] + + +def test_keys(): + """ Try to store/get/remove keys + """ + # JWS + jws = JsonWebSignature(algorithms=['RS256']) + code_payload = {'user_id': 'user', + 'scope': 'scope', + 'client_id': 'client', + 'redirect_uri': 'redirect_uri', + 'code_challenge': 'code_challenge'} + + # Token metadata + header = {'alg': 'RS256'} payload = {'sub': 'user', 'iss': 'issuer', - 'iat': int(time.time()), - 'exp': int(time.time()) + (12 * 3600), 'scope': 'scope', 'setup': 'setup', 'group': 'my_group'} - DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_at=int(time.time()) + 3600) - - New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_in=int(time.time()) + 3600) - - def test_RefreshToken(): - """ Try to revoke/save/get refresh tokens - """ - preset_jti = '123' - - # Remove refresh token - result = db.revokeRefreshToken(preset_jti) - assert result['OK'], result['Message'] - - # Store tokens - result = db.storeRefreshToken(DToken.copy(), preset_jti) - assert result['OK'], result['Message'] - assert result['Value']['jti'] == preset_jti - assert result['Value']['iat'] <= int(time.time()) - - result = db.storeRefreshToken(New_DToken.copy()) - assert result['OK'], result['Message'] - assert result['Value']['jti'] - assert result['Value']['iat'] <= int(time.time()) - - token_id = result['Value']['jti'] - issued_at = result['Value']['iat'] - - # Check token - result = db.getCredentialByRefreshToken(preset_jti) - assert result['OK'], result['Message'] - assert result['Value']['jti'] == preset_jti - assert result['Value']['access_token'] == DToken['access_token'] - assert result['Value']['refresh_token'] == DToken['refresh_token'] - - result = db.getCredentialByRefreshToken(token_id) - assert result['OK'], result['Message'] - assert result['Value']['jti'] == token_id - assert int(result['Value']['issued_at']) == issued_at - assert result['Value']['access_token'] == New_DToken['access_token'] - assert result['Value']['refresh_token'] == New_DToken['refresh_token'] - - # Check token after request - for jti in [token_id, preset_jti]: - result = db.getCredentialByRefreshToken(jti) - assert result['OK'], result['Message'] - assert not result['Value'] - - # Renew tokens - result = db.storeRefreshToken(New_DToken.copy(), token_id) - assert result['OK'], result['Message'] - - # Revoke token - result = db.revokeRefreshToken(token_id) - assert result['OK'], result['Message'] - - # Check token - result = db.getCredentialByRefreshToken(token_id) - assert result['OK'], result['Message'] - assert not result['Value'] - - def test_keys(): - """ Try to store/get/remove keys - """ - # JWS - jws = JsonWebSignature(algorithms=['RS256']) - code_payload = {'user_id': 'user', - 'scope': 'scope', - 'client_id': 'client', - 'redirect_uri': 'redirect_uri', - 'code_challenge': 'code_challenge'} - - # Token metadata - header = {'alg': 'RS256'} - payload = {'sub': 'user', - 'iss': 'issuer', - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} - - # Remove all keys - result = db.removeKeys() - assert result['OK'], result['Message'] + # Remove all keys + result = db.removeKeys() + assert result['OK'], result['Message'] - # Check active keys - result = db.getActiveKeys() - assert result['OK'], result['Message'] - assert result['Value'] == [] + # Check active keys + result = db.getActiveKeys() + assert result['OK'], result['Message'] + assert result['Value'] == [] - # Create new one - result = db.getPrivateKey() - assert result['OK'], result['Message'] + # Create new one + result = db.getPrivateKey() + assert result['OK'], result['Message'] - private_key = result['Value'] - assert isinstance(private_key, RSAKey) + private_key = result['Value'] + assert isinstance(private_key, RSAKey) - # Sign token - header['kid'] = private_key.thumbprint() + # Sign token + header['kid'] = private_key.thumbprint() - # Find key by KID - result = db.getPrivateKey(header['kid']) - assert result['OK'], result['Message'] + # Find key by KID + result = db.getPrivateKey(header['kid']) + assert result['OK'], result['Message'] + if authlib.version >= '1.0.0': assert result['Value'].as_dict(True) == private_key.as_dict(True) - - # Sign token - token = jwt.encode(header, payload, private_key) - # Sign auth code - code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) - - # Get public key set - result = db.getKeySet() - keyset = result['Value'] - assert result['OK'], result['Message'] + else: + assert result['Value'].as_dict() == private_key.as_dict() + + # Sign token + token = jwt.encode(header, payload, private_key) + # Sign auth code + code = jws.serialize_compact(header, json_b64encode(code_payload), private_key) + + # Get public key set + result = db.getKeySet() + keyset = result['Value'] + assert result['OK'], result['Message'] + if authlib.version >= '1.0.0': assert bool([key for key in keyset.as_dict(True)['keys'] if key["kid"] == header['kid']]) - - # Read token - _payload = jwt.decode(token, JsonWebKey.import_key_set(keyset.as_dict())) - assert _payload == payload - # Read auth code - data = jws.deserialize_compact(code, keyset.keys[0]) - _code_payload = json_loads(urlsafe_b64decode(data['payload'])) - assert _code_payload == code_payload - - def test_Sessions(): - """ Try to store/get/remove Sessions - """ - # Example of the new session metadata - sData1 = {'client_id': 'DIRAC_CLI', - 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'expires_in': 1800, - 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'interval': 5, - 'scope': 'g:my_group', - 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', - 'user_code': 'MDKP-MXMF', - 'verification_uri': 'https://domain.com/DIRAC/auth/device', - 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF'} - - # Example of the updated session - sData2 = {'client_id': 'DIRAC_CLI', - 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'expires_in': 1800, - 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', - 'interval': 5, - 'scope': 'g:my_group', - 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', - 'user_code': 'MDKP-MXMF', - 'verification_uri': 'https://domain.com/DIRAC/auth/device', - 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF', - 'user_id': 'username'} - - # Remove old session - db.removeSession(sData1['id']) - - # Add session - result = db.addSession(sData1) - assert result['OK'], result['Message'] - - # Get session - result = db.getSessionByUserCode(sData1['user_code']) - assert result['OK'], result['Message'] - assert result['Value']['device_code'] == sData1['device_code'] - assert result['Value'].get('user_id') != sData2['user_id'] - - # Update session - result = db.updateSession(sData2, sData1['id']) - assert result['OK'], result['Message'] - - # Get session - result = db.getSession(sData2['id']) - assert result['OK'], result['Message'] - assert result['Value']['device_code'] == sData2['device_code'] - assert result['Value']['user_id'] == sData2['user_id'] - - # Remove session - result = db.removeSession(sData2['id']) - assert result['OK'], result['Message'] - - # Make sure that the session is absent - result = db.getSession(sData2['id']) - assert result['OK'], result['Message'] - assert not result['Value'] + else: + assert bool([key for key in keyset.as_dict()['keys'] if key["kid"] == header['kid']]) + + # Read token + _payload = jwt.decode(token, JsonWebKey.import_key_set(keyset.as_dict())) + assert _payload == payload + # Read auth code + data = jws.deserialize_compact(code, keyset.keys[0]) + _code_payload = json_loads(urlsafe_b64decode(data['payload'])) + assert _code_payload == code_payload + + +def test_Sessions(): + """ Try to store/get/remove Sessions + """ + # Example of the new session metadata + sData1 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/DIRAC/auth/device', + 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF'} + + # Example of the updated session + sData2 = {'client_id': 'DIRAC_CLI', + 'device_code': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'expires_in': 1800, + 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', + 'interval': 5, + 'scope': 'g:my_group', + 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'user_code': 'MDKP-MXMF', + 'verification_uri': 'https://domain.com/DIRAC/auth/device', + 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF', + 'user_id': 'username'} + + # Remove old session + db.removeSession(sData1['id']) + + # Add session + result = db.addSession(sData1) + assert result['OK'], result['Message'] + + # Get session + result = db.getSessionByUserCode(sData1['user_code']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData1['device_code'] + assert result['Value'].get('user_id') != sData2['user_id'] + + # Update session + result = db.updateSession(sData2, sData1['id']) + assert result['OK'], result['Message'] + + # Get session + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert result['Value']['device_code'] == sData2['device_code'] + assert result['Value']['user_id'] == sData2['user_id'] + + # Remove session + result = db.removeSession(sData2['id']) + assert result['OK'], result['Message'] + + # Make sure that the session is absent + result = db.getSession(sData2['id']) + assert result['OK'], result['Message'] + assert not result['Value'] diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index 67d47eace3e..b1fec04cf3a 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -10,18 +10,39 @@ import pytest from mock import MagicMock -if six.PY3: - from authlib.oauth2.base import OAuth2Error +from diraccfg import CFG + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +import DIRAC +from DIRAC import S_OK, S_ERROR, gConfig +from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer + + +class Proxy(object): + def dumpAllToString(self): + return S_OK('proxy') + - from diraccfg import CFG +class ProxyManagerClient(object): + def downloadProxy(self, *args, **kwargs): + return S_OK(Proxy()) - from DIRAC.Core.Base.Script import parseCommandLine - parseCommandLine() - import DIRAC - from DIRAC import S_OK, S_ERROR, gConfig - from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer +class TokenManagerClient(object): + def getToken(self, *args, **kwargs): + return S_OK({'access_token': 'token', 'refresh_token': 'token'}) + +mockgetIdPForGroup = MagicMock(return_value=S_OK('IdP')) +mockgetDNForUsername = MagicMock(return_value=S_OK('DN')) +mockgetUsernameForDN = MagicMock(return_value=S_OK('user')) +mockisDownloadablePersonalProxy = MagicMock(return_value=True) + + +@pytest.fixture +def server(mocker): cfg = CFG() cfg.loadFromBuffer(""" DIRAC @@ -36,74 +57,63 @@ } """) gConfig.loadCFG(cfg) - - class Proxy(object): - def dumpAllToString(self): - return S_OK('proxy') - - class ProxyManagerClient(object): - def downloadProxy(self, *args, **kwargs): - return S_OK(Proxy()) - - class TokenManagerClient(object): - def getToken(self, *args, **kwargs): - return S_OK({'access_token': 'token', 'refresh_token': 'token'}) - - mockgetIdPForGroup = MagicMock(return_value=S_OK('IdP')) - mockgetDNForUsername = MagicMock(return_value=S_OK('DN')) - mockgetUsernameForDN = MagicMock(return_value=S_OK('user')) - mockisDownloadablePersonalProxy = MagicMock(return_value=True) - - @pytest.fixture - def server(mocker): - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", - side_effect=mockgetIdPForGroup) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", - side_effect=mockgetDNForUsername) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", - side_effect=mockgetUsernameForDN) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", - side_effect=ProxyManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", - side_effect=TokenManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", - side_effect=mockisDownloadablePersonalProxy) - return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() - - def test_metadata(server): - """ Check metadata - """ - assert server.metadata.get('issuer') - - def test_queryClient(server): - """ Try to search some default client - """ - assert not server.query_client('not_exist_client') - assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' - - @pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ - ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), - ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), - ]) - def test_generateToken(server, client, grant, user, scope, expires_in, refresh_token, instance, result): - """ Generate tokens - """ - cli = server.query_client(client) - try: - assert server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result - except OAuth2Error as e: - assert False, str(e) - - def test_writeReadRefreshToken(server): - """ Try to search some default client - """ - result = server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) - assert result['OK'], result['Message'] - token = result['Value'] - assert token.get('access_token') == 'token' - assert token.get('refresh_token') != 'token' - - result = server.readToken(token['refresh_token']) - assert result['OK'], result['Message'] - assert result['Value'].get('jti') - assert result['Value'].get('iat') + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", + side_effect=mockgetIdPForGroup) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", + side_effect=mockgetDNForUsername) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", + side_effect=mockgetUsernameForDN) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", + side_effect=ProxyManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", + side_effect=TokenManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", + side_effect=mockisDownloadablePersonalProxy) + return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() + + +@pytest.mark.skipif(six.PY2, reason="requires python3") +def test_metadata(server): + """ Check metadata + """ + assert server.metadata.get('issuer') + + +@pytest.mark.skipif(six.PY2, reason="requires python3") +def test_queryClient(server): + """ Try to search some default client + """ + assert not server.query_client('not_exist_client') + assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' + + +@pytest.mark.skipif(six.PY2, reason="requires python3") +@pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ + ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), + ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), +]) +def test_generateToken(server, client, grant, user, scope, expires_in, refresh_token, instance, result): + """ Generate tokens + """ + from authlib.oauth2.base import OAuth2Error + cli = server.query_client(client) + try: + assert server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result + except OAuth2Error as e: + assert False, str(e) + + +@pytest.mark.skipif(six.PY2, reason="requires python3") +def test_writeReadRefreshToken(server): + """ Try to search some default client + """ + result = server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) + assert result['OK'], result['Message'] + token = result['Value'] + assert token.get('access_token') == 'token' + assert token.get('refresh_token') != 'token' + + result = server.readToken(token['refresh_token']) + assert result['OK'], result['Message'] + assert result['Value'].get('jti') + assert result['Value'].get('iat') diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py index 4f058a15a75..4ac8517819e 100644 --- a/tests/Integration/Framework/Test_TokenDB.py +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -6,77 +6,75 @@ from __future__ import print_function import six - -if six.PY3: - import time - from authlib.jose import jwt - - from DIRAC.Core.Base.Script import parseCommandLine - parseCommandLine() - - from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB - - db = TokenDB() - - payload = {'sub': 'user', - 'iss': 'issuer', - 'iat': int(time.time()), - 'exp': int(time.time()) + (12 * 3600), - 'scope': 'scope', - 'setup': 'setup', - 'group': 'my_group'} - - exp_payload = payload.copy() - exp_payload['iat'] = int(time.time()) - 10 - exp_payload['exp'] = int(time.time()) - 10 - - DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_at=int(time.time()) + 3600) - - New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - issued_at=int(time.time()), - expires_in=int(time.time()) + 3600) - - Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), - expires_at=int(time.time()) - 10, - rt_expires_at=int(time.time()) - 10) - - - def test_Token(): - """ Try to revoke/save/get tokens - """ - # Remove all tokens - result = db.removeToken(user_id=123) - assert result['OK'], result['Message'] - - # Store tokens - result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) - assert result['OK'], result['Message'] - assert result['Value'] == [] - - # Expired token - result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) - assert not result['OK'] - - # Check token - result = db.getTokenForUserProvider(userID=123, provider='DIRAC') - assert result['OK'], result['Message'] - assert result['Value']['access_token'] == DToken['access_token'] - assert result['Value']['refresh_token'] == DToken['refresh_token'] - - # Store new tokens - result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) - assert result['OK'], result['Message'] - # Must return old tokens - assert len(result['Value']) == 1 - assert result['Value'][0]['access_token'] == DToken['access_token'] - assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] - - # Check token - result = db.getTokenForUserProvider(userID=123, provider='DIRAC') - assert result['OK'], result['Message'] - assert result['Value']['access_token'] == New_DToken['access_token'] - assert result['Value']['refresh_token'] == New_DToken['refresh_token'] +import time +from authlib.jose import jwt + +from DIRAC.Core.Base.Script import parseCommandLine +parseCommandLine() + +from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB + +db = TokenDB() + +payload = {'sub': 'user', + 'iss': 'issuer', + 'iat': int(time.time()), + 'exp': int(time.time()) + (12 * 3600), + 'scope': 'scope', + 'setup': 'setup', + 'group': 'my_group'} + +exp_payload = payload.copy() +exp_payload['iat'] = int(time.time()) - 10 +exp_payload['exp'] = int(time.time()) - 10 + +DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + +New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + issued_at=int(time.time()), + expires_in=int(time.time()) + 3600) + +Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + expires_at=int(time.time()) - 10, + rt_expires_at=int(time.time()) - 10) + + +def test_Token(): + """ Try to revoke/save/get tokens + """ + # Remove all tokens + result = db.removeToken(user_id=123) + assert result['OK'], result['Message'] + + # Store tokens + result = db.updateToken(DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + assert result['Value'] == [] + + # Expired token + result = db.updateToken(Exp_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert not result['OK'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == DToken['access_token'] + assert result['Value']['refresh_token'] == DToken['refresh_token'] + + # Store new tokens + result = db.updateToken(New_DToken.copy(), userID=123, provider='DIRAC', rt_expired_in=24) + assert result['OK'], result['Message'] + # Must return old tokens + assert len(result['Value']) == 1 + assert result['Value'][0]['access_token'] == DToken['access_token'] + assert result['Value'][0]['refresh_token'] == DToken['refresh_token'] + + # Check token + result = db.getTokenForUserProvider(userID=123, provider='DIRAC') + assert result['OK'], result['Message'] + assert result['Value']['access_token'] == New_DToken['access_token'] + assert result['Value']['refresh_token'] == New_DToken['refresh_token'] diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index 37d9c948ea0..ed021d08cc0 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -4,101 +4,100 @@ __RCSID__ = "$Id$" -import six +import time +import unittest +from authlib.jose import jwt -if six.PY3: - import time - import unittest - from authlib.jose import jwt +from diraccfg import CFG - from diraccfg import CFG +import DIRAC +from DIRAC import gConfig +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS - import DIRAC - from DIRAC import gConfig - from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory - from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS - from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata - - cfg = CFG() - cfg.loadFromBuffer(""" - DIRAC +cfg = CFG() +cfg.loadFromBuffer(""" +DIRAC +{ + Security { - Security + Authorization { - Authorization + issuer = https://issuer.url/ + Clients { - issuer = https://issuer.url/ - Clients + DIRACWeb { - DIRACWeb - { - client_id = client_identificator - client_secret = client_secret_key - redirect_uri = https://redirect.url/ - } + client_id = client_identificator + client_secret = client_secret_key + redirect_uri = https://redirect.url/ } } } } - Resources +} +Resources +{ + IdProviders { - IdProviders + SomeIdP { - SomeIdP - { - ProviderType = OAuth2 - issuer = https://idp.url/ - client_id = IdP_client_id - client_secret = IdP_client_secret - redirect_uri = https://dirac/redirect - jwks_uri = https://idp.url/jwk - scope = openid+profile+offline_access+eduperson_entitlement - } + ProviderType = OAuth2 + issuer = https://idp.url/ + client_id = IdP_client_id + client_secret = IdP_client_secret + redirect_uri = https://dirac/redirect + jwks_uri = https://idp.url/jwk + scope = openid+profile+offline_access+eduperson_entitlement } } - """) - gConfig.loadCFG(cfg) +} +""") +gConfig.loadCFG(cfg) + +idps = IdProviderFactory() + - idps = IdProviderFactory() +def test_getDIRACClients(): + """ Try to load default DIRAC authorization client + """ + params = collectMetadata() - def test_getDIRACClients(): - """ Try to load default DIRAC authorization client - """ - params = collectMetadata() + # Try to get DIRAC client authorization settings + result = idps.getIdProvider('DIRACCLI', **params) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == DEFAULT_CLIENTS['DIRACCLI']['client_id'] + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' - # Try to get DIRAC client authorization settings - result = idps.getIdProvider('DIRACCLI', **params) - assert result['OK'], result['Message'] - assert result['Value'].issuer == 'https://issuer.url/' - assert result['Value'].client_id == DEFAULT_CLIENTS['DIRACCLI']['client_id'] - assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' + # Try to get DIRAC client authorization settings for Web portal + result = idps.getIdProvider('DIRACWeb', **params) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://issuer.url/' + assert result['Value'].client_id == 'client_identificator' + assert result['Value'].client_secret == 'client_secret_key' + assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' - # Try to get DIRAC client authorization settings for Web portal - result = idps.getIdProvider('DIRACWeb', **params) - assert result['OK'], result['Message'] - assert result['Value'].issuer == 'https://issuer.url/' - assert result['Value'].client_id == 'client_identificator' - assert result['Value'].client_secret == 'client_secret_key' - assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' - def test_getIdPClients(): - """ Try to load external identity provider settings - """ - # Try to get identity provider by name - result = idps.getIdProvider('SomeIdP', jwks='my_jwks') - assert result['OK'], result['Message'] - assert result['Value'].jwks == 'my_jwks' - assert result['Value'].issuer == 'https://idp.url/' - assert result['Value'].client_id == 'IdP_client_id' - assert result['Value'].client_secret == 'IdP_client_secret' - assert result['Value'].get_metadata('jwks_uri') == 'https://idp.url/jwk' +def test_getIdPClients(): + """ Try to load external identity provider settings + """ + # Try to get identity provider by name + result = idps.getIdProvider('SomeIdP', jwks='my_jwks') + assert result['OK'], result['Message'] + assert result['Value'].jwks == 'my_jwks' + assert result['Value'].issuer == 'https://idp.url/' + assert result['Value'].client_id == 'IdP_client_id' + assert result['Value'].client_secret == 'IdP_client_secret' + assert result['Value'].get_metadata('jwks_uri') == 'https://idp.url/jwk' - # Try to get identity provider for token issued by it - result = idps.getIdProviderForToken(jwt.encode({'alg': 'HS256'}, dict( - sub='user', - iss=result['Value'].issuer, - iat=int(time.time()), - exp=int(time.time()) + (12 * 3600), - ), "secret").decode('utf-8')) - assert result['OK'], result['Message'] - assert result['Value'].issuer == 'https://idp.url/' + # Try to get identity provider for token issued by it + result = idps.getIdProviderForToken(jwt.encode({'alg': 'HS256'}, dict( + sub='user', + iss=result['Value'].issuer, + iat=int(time.time()), + exp=int(time.time()) + (12 * 3600), + ), "secret").decode('utf-8')) + assert result['OK'], result['Message'] + assert result['Value'].issuer == 'https://idp.url/' diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index 00d67d444f0..9f17f489e21 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -177,8 +177,12 @@ installSite() { fi fi - echo "==> Done installing, now configuring" source "${SERVERINSTALLDIR}/bashrc" + + echo "==> Install OAuth2 requirements" + pip install dominate pyjwt authlib + + echo "==> Done installing, now configuring" if ! dirac-configure --cfg "${SERVERINSTALLDIR}/install.cfg" "${DEBUG}"; then echo "ERROR: dirac-configure failed" >&2 exit 1 diff --git a/tests/Jenkins/utilities.sh b/tests/Jenkins/utilities.sh index 08bd24da217..011f667337b 100644 --- a/tests/Jenkins/utilities.sh +++ b/tests/Jenkins/utilities.sh @@ -383,6 +383,9 @@ installDIRAC() { echo "$DIRAC" echo "$PATH" + echo "==> Install OAuth2 requirements" + pip install dominate pyjwt authlib + # now configuring cmd="dirac-configure -S ${DIRACSETUP} -C ${CSURL} --SkipCAChecks ${CONFIGUREOPTIONS} ${DEBUG}" if ! bash -c "${cmd}"; then From b0918d550c33146211e562058dadf904e7eb66d1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 28 Jun 2021 20:18:20 +0200 Subject: [PATCH 096/178] compatibility with py2 --- .../Server/private/BaseRequestHandler.py | 1 + .../private/authorization/AuthServer.py | 1 + .../Integration/Framework/Test_AuthServer.py | 40 +++++++++---------- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 11ef9b9b784..07bdd20e95b 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -10,6 +10,7 @@ from io import open import os +import six import jwt import time import threading diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 9349cb7a6a8..2265f9e8ac0 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -104,6 +104,7 @@ def __init__(self): self.register_endpoint(RevocationEndpoint) self.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)]) + # pylint: disable=method-hidden def query_client(self, client_id): """ Search authorization client. diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index b1fec04cf3a..f897c7ea5b1 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -42,7 +42,7 @@ def getToken(self, *args, **kwargs): @pytest.fixture -def server(mocker): +def auth_server(mocker): cfg = CFG() cfg.loadFromBuffer(""" DIRAC @@ -58,62 +58,58 @@ def server(mocker): """) gConfig.loadCFG(cfg) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", - side_effect=mockgetIdPForGroup) + side_effect=mockgetIdPForGroup) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", - side_effect=mockgetDNForUsername) + side_effect=mockgetDNForUsername) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", - side_effect=mockgetUsernameForDN) + side_effect=mockgetUsernameForDN) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", - side_effect=ProxyManagerClient) + side_effect=ProxyManagerClient) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", - side_effect=TokenManagerClient) + side_effect=TokenManagerClient) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", - side_effect=mockisDownloadablePersonalProxy) + side_effect=mockisDownloadablePersonalProxy) return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() -@pytest.mark.skipif(six.PY2, reason="requires python3") -def test_metadata(server): +def test_metadata(auth_server): """ Check metadata """ - assert server.metadata.get('issuer') + assert auth_server.metadata.get('issuer') -@pytest.mark.skipif(six.PY2, reason="requires python3") -def test_queryClient(server): +def test_queryClient(auth_server): """ Try to search some default client """ - assert not server.query_client('not_exist_client') - assert server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' + assert not auth_server.query_client('not_exist_client') + assert auth_server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' -@pytest.mark.skipif(six.PY2, reason="requires python3") @pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), ]) -def test_generateToken(server, client, grant, user, scope, expires_in, refresh_token, instance, result): +def test_generateToken(auth_server, client, grant, user, scope, expires_in, refresh_token, instance, result): """ Generate tokens """ from authlib.oauth2.base import OAuth2Error - cli = server.query_client(client) + cli = auth_server.query_client(client) try: - assert server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result + assert auth_server.generate_token(cli, grant, user, scope, expires_in, refresh_token).get(instance) == result except OAuth2Error as e: assert False, str(e) -@pytest.mark.skipif(six.PY2, reason="requires python3") -def test_writeReadRefreshToken(server): +def test_writeReadRefreshToken(auth_server): """ Try to search some default client """ - result = server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) + result = auth_server.registerRefreshToken({}, {'access_token': 'token', 'refresh_token': 'token'}) assert result['OK'], result['Message'] token = result['Value'] assert token.get('access_token') == 'token' assert token.get('refresh_token') != 'token' - result = server.readToken(token['refresh_token']) + result = auth_server.readToken(token['refresh_token']) assert result['OK'], result['Message'] assert result['Value'].get('jti') assert result['Value'].get('iat') From 85aad1f62a874e0aa4abdc1ab85dfeb01c9676f0 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 29 Jun 2021 14:00:40 +0200 Subject: [PATCH 097/178] add mocker, fix authzSSL --- setup.cfg | 1 + src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index e32108285b8..b31ca03e79d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -95,6 +95,7 @@ testing = parameterized pytest pytest-cov + pytest-mock pycodestyle [options.entry_points] diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 07bdd20e95b..dbd8f9f2205 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -576,20 +576,18 @@ def _authzSSL(self): """ peerChain = X509Chain() derCert = self.request.get_ssl_certificate() - # Get client certificate pem if derCert: chainAsText = derCert.as_pem() # Here we read all certificate chain - cert_chain = self.request.get_ssl_certificate_chain() - for cert in cert_chain: - chainAsText = cert.as_pem() + chainAsText += '\n'.join([cert.as_pem() for cert in self.request.get_ssl_certificate_chain()]) elif self.request.headers.get('X-Ssl_client_verify') == 'SUCCESS': chainAsTextEncoded = self.request.headers.get('X-SSL-CERT') chainAsText = unquote(chainAsTextEncoded) else: return S_ERROR(DErrno.ECERTFIND, 'Valid certificate not found.') + # Load full certificate chain peerChain.loadChainFromString(chainAsText) # Retrieve the credentials From 820fcfbdff91264ee2ad966c10812a406afb2c34 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 29 Jun 2021 15:40:21 +0200 Subject: [PATCH 098/178] add mock for testing --- tests/Jenkins/dirac_ci.sh | 2 +- tests/Jenkins/utilities.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index 9f17f489e21..7373e87c6c7 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -180,7 +180,7 @@ installSite() { source "${SERVERINSTALLDIR}/bashrc" echo "==> Install OAuth2 requirements" - pip install dominate pyjwt authlib + pip install dominate pyjwt authlib pytest-mock echo "==> Done installing, now configuring" if ! dirac-configure --cfg "${SERVERINSTALLDIR}/install.cfg" "${DEBUG}"; then diff --git a/tests/Jenkins/utilities.sh b/tests/Jenkins/utilities.sh index 011f667337b..e933b4a943f 100644 --- a/tests/Jenkins/utilities.sh +++ b/tests/Jenkins/utilities.sh @@ -384,7 +384,7 @@ installDIRAC() { echo "$PATH" echo "==> Install OAuth2 requirements" - pip install dominate pyjwt authlib + pip install dominate pyjwt authlib pytest-mock # now configuring cmd="dirac-configure -S ${DIRACSETUP} -C ${CSURL} --SkipCAChecks ${CONFIGUREOPTIONS} ${DEBUG}" From 73d5c6bb9d725e7a43a3fe1906bf398ed240d967 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 29 Jun 2021 21:56:07 +0200 Subject: [PATCH 099/178] take into account the lack of packages in DIRACOS --- .../Client/private/TornadoBaseClient.py | 16 +++++-- .../Server/private/BaseRequestHandler.py | 26 +++++++--- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 20 ++++---- .../FrameworkSystem/scripts/dirac_login.py | 2 +- tests/Integration/Framework/Test_AuthDB.py | 39 +++++++++------ .../Integration/Framework/Test_AuthServer.py | 48 ++++++++++++------- tests/Integration/Framework/Test_TokenDB.py | 41 ++++++++-------- .../IdProvider/Test_IdProviderFactory.py | 16 +++++-- 8 files changed, 132 insertions(+), 76 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 7800c2859a4..abdca7f2f70 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -49,8 +49,16 @@ from DIRAC.Core.Security import Locations from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile + +try: + # DIRACOS not contain required packages + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile +except ImportError as e: + IdProviderFactory = None + if six.PY3: + # But DIRACOS2 must contain required packages + raise e # TODO CHRIS: refactor all the messy `discover` methods @@ -240,7 +248,7 @@ def __discoverCredentialsToUse(self): self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken - if self.__useAccessToken: + if self.__useAccessToken and IdProviderFactory: result = IdProviderFactory().getIdProvider('DIRACCLI') if not result['OK']: return result @@ -520,7 +528,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): auth = {'cert': Locations.getHostCertificateAndKeyLocation()} # Use access token? - elif self.__useAccessToken: + elif self.__useAccessToken and IdProviderFactory: # Read token from token environ variable or from token file result = getLocalTokenDict() if not result['OK']: diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index dbd8f9f2205..c2308e936aa 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -11,7 +11,6 @@ import os import six -import jwt import time import threading from datetime import datetime @@ -36,7 +35,16 @@ from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +try: + # DIRACOS not contain required packages + import jwt + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +except ImportError as e: + IdProviderFactory = None + if six.PY3: + # But DIRACOS2 must contain required packages + raise e sLog = gLogger.getSubLogger(__name__.split('.')[-1]) @@ -138,13 +146,13 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ # Prefix of methods names METHOD_PREFIX = "export_" - # Which grant type to use - USE_AUTHZ_GRANTS = ['SSL', 'JWT'] - # Definition of identity providers - _idps = IdProviderFactory() + _idps = IdProviderFactory() if IdProviderFactory else None _idp = {} + # Which grant type to use + USE_AUTHZ_GRANTS = ['SSL', 'JWT'] + @classmethod def _initMonitoring(cls, serviceName, fullUrl): """ @@ -251,7 +259,8 @@ def __initializeService(cls, request): return S_OK() # Load all registred identity providers - cls.__loadIdPs() + if cls._idps: + cls.__loadIdPs() # absoluteUrl: full URL e.g. ``https://://`` absoluteUrl = request.path @@ -555,6 +564,9 @@ def _gatherPeerCredentials(self, grants=None): # the authorization will go through the `_authzVISITOR` method and # everyone will have access as anonymous@visitor for grant in (grants or self.USE_AUTHZ_GRANTS or 'VISITOR'): + if not self._idps and grant == 'JWT': + # Skip token authorization if authlib and pyjwt packages not installed + continue grant = grant.upper() grantFunc = getattr(self, '_authz%s' % grant, None) # pylint: disable=not-callable diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index ae376acf2f9..74d9ee17e9c 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -14,19 +14,23 @@ import pprint from io import open -from dominate import document, tags as dom from tornado.template import Template -from authlib.oauth2.base import OAuth2Error -from authlib.oauth2.rfc6749.util import scope_to_list - from DIRAC import S_ERROR from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST from DIRAC.ConfigurationSystem.Client.Helpers import Registry -from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer -from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request -from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint -from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint + +try: + from dominate import document, tags as dom + from authlib.oauth2.base import OAuth2Error + from authlib.oauth2.rfc6749.util import scope_to_list + from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer + from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request + from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint + from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint +except ImportError as e: + msg = 'This endpoint requires authlib, pyjwt, dominate and dominate that enabled only for python 3 server installation.' + raise ImportError(msg) if six.PY2 else e __RCSID__ = "$Id$" diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index ceff0366ac1..e9a1ec836ee 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -4,7 +4,7 @@ # Author : Andrii Lytovchenko ######################################################################## """ -Login to DIRAC. For python 3 only. +Login to DIRAC. Example: $ dirac-login -g dirac_user diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index 897166f874c..afd6e22976c 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -1,4 +1,4 @@ -""" This is a test of the AuthDB +""" This is a test of the AuthDB. Requires authlib It supposes that the DB is present and installed in DIRAC """ from __future__ import absolute_import @@ -7,16 +7,21 @@ import six import time -import authlib -from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey -from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads from DIRAC.Core.Base.Script import parseCommandLine parseCommandLine() -from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB - -db = AuthDB() +try: + # DIRACOS not contain required packages + from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey + from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads + from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB + db = AuthDB() +except ImportError as e: + db = None + if six.PY3: + # But DIRACOS2 must contain required packages + raise e payload = {'sub': 'user', 'iss': 'issuer', @@ -26,18 +31,19 @@ 'setup': 'setup', 'group': 'my_group'} -DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_at=int(time.time()) + 3600) - -New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_in=int(time.time()) + 3600) - +@pytest.mark.skipif(six.PY2 and not db, reason="Skiped for Python 2 tests") def test_RefreshToken(): """ Try to revoke/save/get refresh tokens """ + DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + + New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_in=int(time.time()) + 3600) + preset_jti = '123' # Remove refresh token @@ -92,6 +98,7 @@ def test_RefreshToken(): assert not result['Value'] +@pytest.mark.skipif(six.PY2 and not db, reason="Skiped for Python 2 tests") def test_keys(): """ Try to store/get/remove keys """ @@ -161,6 +168,8 @@ def test_keys(): assert _code_payload == code_payload +# DIRACOS not contain required packages +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2 tests") def test_Sessions(): """ Try to store/get/remove Sessions """ diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index f897c7ea5b1..4454877de70 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -1,4 +1,4 @@ -""" This is a test of the AuthServer +""" This is a test of the AuthServer. Requires authlib, pyjwt, dominate It supposes that the AuthDB is present and installed in DIRAC """ from __future__ import absolute_import @@ -17,7 +17,15 @@ import DIRAC from DIRAC import S_OK, S_ERROR, gConfig -from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer + +try: + # DIRACOS not contain required packages + from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer +except ImportError as e: + AuthServer = None + if six.PY3: + # But DIRACOS2 must contain required packages + raise e class Proxy(object): @@ -57,27 +65,31 @@ def auth_server(mocker): } """) gConfig.loadCFG(cfg) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", - side_effect=mockgetIdPForGroup) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", - side_effect=mockgetDNForUsername) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", - side_effect=mockgetUsernameForDN) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", - side_effect=ProxyManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", - side_effect=TokenManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", - side_effect=mockisDownloadablePersonalProxy) - return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() - - + if AuthServer: + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", + side_effect=mockgetIdPForGroup) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", + side_effect=mockgetDNForUsername) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", + side_effect=mockgetUsernameForDN) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", + side_effect=ProxyManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", + side_effect=TokenManagerClient) + mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", + side_effect=mockisDownloadablePersonalProxy) + return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() + return None + + +@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") def test_metadata(auth_server): """ Check metadata """ assert auth_server.metadata.get('issuer') +@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") def test_queryClient(auth_server): """ Try to search some default client """ @@ -85,6 +97,7 @@ def test_queryClient(auth_server): assert auth_server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' +@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") @pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), @@ -100,6 +113,7 @@ def test_generateToken(auth_server, client, grant, user, scope, expires_in, refr assert False, str(e) +@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") def test_writeReadRefreshToken(auth_server): """ Try to search some default client """ diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py index 4ac8517819e..aa943866c81 100644 --- a/tests/Integration/Framework/Test_TokenDB.py +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -1,4 +1,4 @@ -""" This is a test of the AuthDB +""" This is a test of the AuthDB. Requires authlib, pyjwt It supposes that the DB is present and installed in DIRAC """ from __future__ import absolute_import @@ -7,15 +7,10 @@ import six import time -from authlib.jose import jwt from DIRAC.Core.Base.Script import parseCommandLine parseCommandLine() -from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB - -db = TokenDB() - payload = {'sub': 'user', 'iss': 'issuer', 'iat': int(time.time()), @@ -28,24 +23,30 @@ exp_payload['iat'] = int(time.time()) - 10 exp_payload['exp'] = int(time.time()) - 10 -DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - expires_at=int(time.time()) + 3600) - -New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), - issued_at=int(time.time()), - expires_in=int(time.time()) + 3600) - -Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), - refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), - expires_at=int(time.time()) - 10, - rt_expires_at=int(time.time()) - 10) - +# DIRACOS not contain required packages +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2 tests") def test_Token(): """ Try to revoke/save/get tokens """ + from authlib.jose import jwt + from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB + db = TokenDB() + + DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + expires_at=int(time.time()) + 3600) + + New_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), + issued_at=int(time.time()), + expires_in=int(time.time()) + 3600) + + Exp_DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + refresh_token=jwt.encode({'alg': 'HS256'}, exp_payload, "secret").decode('utf-8'), + expires_at=int(time.time()) - 10, + rt_expires_at=int(time.time()) - 10) + # Remove all tokens result = db.removeToken(user_id=123) assert result['OK'], result['Message'] diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index ed021d08cc0..765c796a207 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -6,13 +6,11 @@ import time import unittest -from authlib.jose import jwt from diraccfg import CFG import DIRAC from DIRAC import gConfig -from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS @@ -56,9 +54,18 @@ """) gConfig.loadCFG(cfg) -idps = IdProviderFactory() - +try: + # DIRACOS not contain required packages + from authlib.jose import jwt + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + idps = IdProviderFactory() +except ImportError as e: + IdProviderFactory = None + if six.PY3: + # But DIRACOS2 must contain required packages + raise e +@pytest.mark.skipif(six.PY2 and not IdProviderFactory, reason="Skiped for Python 2 tests") def test_getDIRACClients(): """ Try to load default DIRAC authorization client """ @@ -80,6 +87,7 @@ def test_getDIRACClients(): assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' +@pytest.mark.skipif(six.PY2 and not IdProviderFactory, reason="Skiped for Python 2 tests") def test_getIdPClients(): """ Try to load external identity provider settings """ From add96f272d6d90559eefb629b85e8cfd4a7efd9f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 29 Jun 2021 23:00:13 +0200 Subject: [PATCH 100/178] disable token authentication for python 2 --- .../Client/private/TornadoBaseClient.py | 12 ++------ .../Server/private/BaseRequestHandler.py | 15 ++++------ src/DIRAC/FrameworkSystem/DB/AuthDB.py | 4 ++- .../private/authorization/AuthServer.py | 7 ++--- tests/Integration/Framework/Test_AuthDB.py | 28 ++++++++----------- .../Integration/Framework/Test_AuthServer.py | 15 ++++------ tests/Integration/Framework/Test_TokenDB.py | 13 +++++---- .../IdProvider/Test_IdProviderFactory.py | 13 ++++----- tests/Jenkins/dirac_ci.sh | 6 +--- tests/Jenkins/utilities.sh | 3 -- 10 files changed, 44 insertions(+), 72 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index abdca7f2f70..78dfc09b9c1 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -50,16 +50,10 @@ from DIRAC.Core.Utilities import List, Network from DIRAC.Core.Utilities.JEncode import decode, encode -try: +if six.PY3: # DIRACOS not contain required packages from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import getLocalTokenDict, writeTokenDictToTokenFile -except ImportError as e: - IdProviderFactory = None - if six.PY3: - # But DIRACOS2 must contain required packages - raise e - # TODO CHRIS: refactor all the messy `discover` methods # I do not do it now because I want first to decide @@ -248,7 +242,7 @@ def __discoverCredentialsToUse(self): self.__useAccessToken = gConfig.getValue("/DIRAC/Security/UseTokens", "false").lower() in ("y", "yes", "true") self.kwargs[self.KW_USE_ACCESS_TOKEN] = self.__useAccessToken - if self.__useAccessToken and IdProviderFactory: + if self.__useAccessToken and six.PY3: result = IdProviderFactory().getIdProvider('DIRACCLI') if not result['OK']: return result @@ -528,7 +522,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): auth = {'cert': Locations.getHostCertificateAndKeyLocation()} # Use access token? - elif self.__useAccessToken and IdProviderFactory: + elif self.__useAccessToken and six.PY3: # Read token from token environ variable or from token file result = getLocalTokenDict() if not result['OK']: diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index c2308e936aa..7814269dad7 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -36,15 +36,10 @@ from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance -try: +if six.PY3: # DIRACOS not contain required packages import jwt from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -except ImportError as e: - IdProviderFactory = None - if six.PY3: - # But DIRACOS2 must contain required packages - raise e sLog = gLogger.getSubLogger(__name__.split('.')[-1]) @@ -147,7 +142,7 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ METHOD_PREFIX = "export_" # Definition of identity providers - _idps = IdProviderFactory() if IdProviderFactory else None + _idps = IdProviderFactory() if six.PY3 else None _idp = {} # Which grant type to use @@ -259,7 +254,7 @@ def __initializeService(cls, request): return S_OK() # Load all registred identity providers - if cls._idps: + if six.PY3: cls.__loadIdPs() # absoluteUrl: full URL e.g. ``https://://`` @@ -564,8 +559,8 @@ def _gatherPeerCredentials(self, grants=None): # the authorization will go through the `_authzVISITOR` method and # everyone will have access as anonymous@visitor for grant in (grants or self.USE_AUTHZ_GRANTS or 'VISITOR'): - if not self._idps and grant == 'JWT': - # Skip token authorization if authlib and pyjwt packages not installed + if six.PY3 and grant == 'JWT': + # Skip token authorization for python 2 continue grant = grant.upper() grantFunc = getattr(self, '_authz%s' % grant, None) diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 3aa71f7c7b1..0b0d4063e5a 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -166,7 +166,9 @@ def generateRSAKeys(self): :return: S_OK/S_ERROR """ key = JsonWebKey.generate_key('RSA', 1024, is_private=True) - keyDict = dict(key=json.dumps(key.as_dict(*([True] if authlib.version >= '1.0.0' else []))), + # as_dict has no arguments for authlib < 1.0.0 + # for authlib >= 1.0.0 + keyDict = dict(key=json.dumps(key.as_dict(True)), kid=key.thumbprint(), expires_at=time.time() + (30 * 24 * 3600)) session = self.session() try: diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 2265f9e8ac0..fba1e9df157 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -89,10 +89,9 @@ def __init__(self): self.tokenCli = TokenManagerClient() self.metadata = collectMetadata() self.metadata.validate() - _AuthorizationServer.__init__(self, **(dict(scopes_supported=self.metadata['scopes_supported']) - if authlib.version >= '1.0.0' else - dict(query_client=self.query_client, - save_token=None, metadata=self.metadata))) + # args for authlib < 1.0.0: (query_client=self.query_client, save_token=None, metadata=self.metadata) + # for authlib >= 1.0.0: + _AuthorizationServer.__init__(self, scopes_supported=self.metadata['scopes_supported']) # Skip authlib method save_token and send_signal self.save_token = lambda x, y: None self.send_signal = lambda *x, **y: None diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index afd6e22976c..a1e6d30b4a4 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -7,21 +7,17 @@ import six import time +import pytest from DIRAC.Core.Base.Script import parseCommandLine parseCommandLine() -try: +if six.PY3: # DIRACOS not contain required packages from authlib.jose import JsonWebKey, JsonWebSignature, jwt, RSAKey from authlib.common.encoding import json_b64encode, urlsafe_b64decode, json_loads from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB db = AuthDB() -except ImportError as e: - db = None - if six.PY3: - # But DIRACOS2 must contain required packages - raise e payload = {'sub': 'user', 'iss': 'issuer', @@ -32,7 +28,7 @@ 'group': 'my_group'} -@pytest.mark.skipif(six.PY2 and not db, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_RefreshToken(): """ Try to revoke/save/get refresh tokens """ @@ -98,7 +94,7 @@ def test_RefreshToken(): assert not result['Value'] -@pytest.mark.skipif(six.PY2 and not db, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_keys(): """ Try to store/get/remove keys """ @@ -140,10 +136,9 @@ def test_keys(): # Find key by KID result = db.getPrivateKey(header['kid']) assert result['OK'], result['Message'] - if authlib.version >= '1.0.0': - assert result['Value'].as_dict(True) == private_key.as_dict(True) - else: - assert result['Value'].as_dict() == private_key.as_dict() + # as_dict has no arguments for authlib < 1.0.0 + # for authlib >= 1.0.0: + assert result['Value'].as_dict(True) == private_key.as_dict(True) # Sign token token = jwt.encode(header, payload, private_key) @@ -154,10 +149,9 @@ def test_keys(): result = db.getKeySet() keyset = result['Value'] assert result['OK'], result['Message'] - if authlib.version >= '1.0.0': - assert bool([key for key in keyset.as_dict(True)['keys'] if key["kid"] == header['kid']]) - else: - assert bool([key for key in keyset.as_dict()['keys'] if key["kid"] == header['kid']]) + # as_dict has no arguments for authlib < 1.0.0 + # for authlib >= 1.0.0: + assert bool([key for key in keyset.as_dict(True)['keys'] if key["kid"] == header['kid']]) # Read token _payload = jwt.decode(token, JsonWebKey.import_key_set(keyset.as_dict())) @@ -169,7 +163,7 @@ def test_keys(): # DIRACOS not contain required packages -@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_Sessions(): """ Try to store/get/remove Sessions """ diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index 4454877de70..86ed4f0d37c 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -18,14 +18,9 @@ import DIRAC from DIRAC import S_OK, S_ERROR, gConfig -try: +if six.PY3: # DIRACOS not contain required packages from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer -except ImportError as e: - AuthServer = None - if six.PY3: - # But DIRACOS2 must contain required packages - raise e class Proxy(object): @@ -82,14 +77,14 @@ def auth_server(mocker): return None -@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_metadata(auth_server): """ Check metadata """ assert auth_server.metadata.get('issuer') -@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_queryClient(auth_server): """ Try to search some default client """ @@ -97,7 +92,7 @@ def test_queryClient(auth_server): assert auth_server.query_client('DIRAC_CLI').client_id == 'DIRAC_CLI' -@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") @pytest.mark.parametrize("client, grant, user, scope, expires_in, refresh_token, instance, result", [ ('DIRAC_CLI', None, 'id', 'g:my_group proxy', None, None, 'proxy', 'proxy'), ('DIRAC_CLI', None, 'id', 'g:my_group', None, None, 'access_token', 'token'), @@ -113,7 +108,7 @@ def test_generateToken(auth_server, client, grant, user, scope, expires_in, refr assert False, str(e) -@pytest.mark.skipif(six.PY2 and not AuthServer, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_writeReadRefreshToken(auth_server): """ Try to search some default client """ diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py index aa943866c81..7cc8af1f291 100644 --- a/tests/Integration/Framework/Test_TokenDB.py +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -24,15 +24,18 @@ exp_payload['exp'] = int(time.time()) - 10 -# DIRACOS not contain required packages -@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2 tests") -def test_Token(): - """ Try to revoke/save/get tokens - """ +if six.PY3: + # DIRACOS not contain required packages from authlib.jose import jwt from DIRAC.FrameworkSystem.DB.TokenDB import TokenDB db = TokenDB() + +# DIRACOS not contain required packages +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") +def test_Token(): + """ Try to revoke/save/get tokens + """ DToken = dict(access_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), refresh_token=jwt.encode({'alg': 'HS256'}, payload, "secret").decode('utf-8'), expires_at=int(time.time()) + 3600) diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index 765c796a207..6652bc6fb91 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -5,6 +5,7 @@ __RCSID__ = "$Id$" import time +import pytest import unittest from diraccfg import CFG @@ -54,18 +55,14 @@ """) gConfig.loadCFG(cfg) -try: +if six.PY3: # DIRACOS not contain required packages from authlib.jose import jwt from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory idps = IdProviderFactory() -except ImportError as e: - IdProviderFactory = None - if six.PY3: - # But DIRACOS2 must contain required packages - raise e -@pytest.mark.skipif(six.PY2 and not IdProviderFactory, reason="Skiped for Python 2 tests") + +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_getDIRACClients(): """ Try to load default DIRAC authorization client """ @@ -87,7 +84,7 @@ def test_getDIRACClients(): assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' -@pytest.mark.skipif(six.PY2 and not IdProviderFactory, reason="Skiped for Python 2 tests") +@pytest.mark.skipif(six.PY2, reason="Skiped for Python 2") def test_getIdPClients(): """ Try to load external identity provider settings """ diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index 7373e87c6c7..00d67d444f0 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -177,12 +177,8 @@ installSite() { fi fi - source "${SERVERINSTALLDIR}/bashrc" - - echo "==> Install OAuth2 requirements" - pip install dominate pyjwt authlib pytest-mock - echo "==> Done installing, now configuring" + source "${SERVERINSTALLDIR}/bashrc" if ! dirac-configure --cfg "${SERVERINSTALLDIR}/install.cfg" "${DEBUG}"; then echo "ERROR: dirac-configure failed" >&2 exit 1 diff --git a/tests/Jenkins/utilities.sh b/tests/Jenkins/utilities.sh index e933b4a943f..08bd24da217 100644 --- a/tests/Jenkins/utilities.sh +++ b/tests/Jenkins/utilities.sh @@ -383,9 +383,6 @@ installDIRAC() { echo "$DIRAC" echo "$PATH" - echo "==> Install OAuth2 requirements" - pip install dominate pyjwt authlib pytest-mock - # now configuring cmd="dirac-configure -S ${DIRACSETUP} -C ${CSURL} --SkipCAChecks ${CONFIGUREOPTIONS} ${DEBUG}" if ! bash -c "${cmd}"; then From b46b034c91e72959778baed5c2256e6eb31b759b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 29 Jun 2021 23:25:32 +0200 Subject: [PATCH 101/178] fix tests --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 18 +++++++----------- .../Resources/IdProvider/IdProviderFactory.py | 1 + tests/Integration/Framework/Test_AuthServer.py | 12 ++++++------ tests/Integration/Framework/Test_TokenDB.py | 1 + 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 74d9ee17e9c..616ab17786c 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -20,17 +20,13 @@ from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST from DIRAC.ConfigurationSystem.Client.Helpers import Registry -try: - from dominate import document, tags as dom - from authlib.oauth2.base import OAuth2Error - from authlib.oauth2.rfc6749.util import scope_to_list - from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer - from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request - from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint - from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint -except ImportError as e: - msg = 'This endpoint requires authlib, pyjwt, dominate and dominate that enabled only for python 3 server installation.' - raise ImportError(msg) if six.PY2 else e +from dominate import document, tags as dom +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6749.util import scope_to_list +from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint +from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint __RCSID__ = "$Id$" diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 526783423a4..a8dbcc010df 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -10,6 +10,7 @@ from __future__ import division from __future__ import print_function +import six import jwt from DIRAC import S_OK, S_ERROR, gLogger diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index 86ed4f0d37c..84e351ad576 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -62,17 +62,17 @@ def auth_server(mocker): gConfig.loadCFG(cfg) if AuthServer: mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", - side_effect=mockgetIdPForGroup) + side_effect=mockgetIdPForGroup) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", - side_effect=mockgetDNForUsername) + side_effect=mockgetDNForUsername) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", - side_effect=mockgetUsernameForDN) + side_effect=mockgetUsernameForDN) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", - side_effect=ProxyManagerClient) + side_effect=ProxyManagerClient) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", - side_effect=TokenManagerClient) + side_effect=TokenManagerClient) mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", - side_effect=mockisDownloadablePersonalProxy) + side_effect=mockisDownloadablePersonalProxy) return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() return None diff --git a/tests/Integration/Framework/Test_TokenDB.py b/tests/Integration/Framework/Test_TokenDB.py index 7cc8af1f291..245edd33363 100644 --- a/tests/Integration/Framework/Test_TokenDB.py +++ b/tests/Integration/Framework/Test_TokenDB.py @@ -7,6 +7,7 @@ import six import time +import pytest from DIRAC.Core.Base.Script import parseCommandLine parseCommandLine() From 06a1ebcb9c5f25f88c8072f1a0bcd5f7e038ce73 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 30 Jun 2021 00:20:01 +0200 Subject: [PATCH 102/178] fix tests --- .../Resources/IdProvider/Test_IdProviderFactory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index 6652bc6fb91..aae471c80d7 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -12,8 +12,6 @@ import DIRAC from DIRAC import gConfig -from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata -from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS cfg = CFG() cfg.loadFromBuffer(""" @@ -59,6 +57,8 @@ # DIRACOS not contain required packages from authlib.jose import jwt from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata + from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS idps = IdProviderFactory() From b27535861327ca326e9a0b83b2de6f1f26be8983 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 30 Jun 2021 01:23:25 +0200 Subject: [PATCH 103/178] fix tests --- tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index aae471c80d7..733a882fa3b 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -4,6 +4,7 @@ __RCSID__ = "$Id$" +import six import time import pytest import unittest From 88ad7fefe684d834bde1839474d44a2add8b72fe Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 4 Jul 2021 20:08:05 +0200 Subject: [PATCH 104/178] add TornadoResponse --- .../Server/private/BaseRequestHandler.py | 27 +++++++- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 66 +++++++++++-------- .../private/authorization/AuthServer.py | 17 ++++- 3 files changed, 79 insertions(+), 31 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 7814269dad7..fd9de091e4d 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -44,6 +44,27 @@ sLog = gLogger.getSubLogger(__name__.split('.')[-1]) +class TornadoResponse(object): + """ This class describe result object """ + + def __init__(self, data=None): + """ C'or """ + self.data = data + self.actions = [] + for mName, mObj in inspect.getmembers(RequestHandler): + if inspect.isroutine(mObj) not mName.startswith('_') and mName not 'finish': + setattr(self, mName, lambda *x, **y: self.actions.append((mName, (x, y)))) + + def finish(self, reqObj): + for mName, targs in self.actions: + getattr(reqObj, mName)(*targs[0], **targs[1]) + if not self._finished: + if self.data is None: + getattr(reqObj, 'finish')() + else: + getattr(reqObj, 'finish')(self.data) + + class BaseRequestHandler(RequestHandler): """ Base class for all the Handlers. It directly inherits from :py:class:`tornado.web.RequestHandler` @@ -502,7 +523,11 @@ def _finishFuture(self, retVal): # you need to define the finish_ method. # This method will be started after __executeMethod is completed. finishFunc = getattr(self, 'finish_%s' % self.method, None) - if callable(finishFunc): + + if isinstance(self.result, TornadoResponse): + self.result.finish(self) + + elif callable(finishFunc): finishFunc() # In case nothing is returned diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 616ab17786c..26becacd3a7 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -105,37 +105,49 @@ def _finishFuture(self, retVal): :param object retVal: tornado.concurrent.Future """ - self.result = retVal - - # Is it S_OK or S_ERROR? - r = retVal - if isinstance(r, dict) and isinstance(r.get('OK'), bool) and ('Value' if r['OK'] else 'Message') in r: - if not retVal['OK']: - # S_ERROR is interpreted in the OAuth2 error format. - self.set_status(400) - self.write({'error': 'server_error', 'description': retVal['Message']}) - self.clear_cookie('auth_session') - self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) - else: - # Successful responses and OAuth2 errors are processed here - status_code, headers, payload, new_session, error = retVal['Value'][0] - if status_code: - self.set_status(status_code) - if headers: - for key, value in headers: - self.set_header(key, value) - if payload: - self.write(payload) - if new_session: - self.set_secure_cookie('auth_session', json.dumps(new_session), secure=True, httponly=True) - if error: - self.clear_cookie('auth_session') - for method, args_kwargs in retVal['Value'][1].items(): - eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) + # Wait result only if it's a Future object + self.result = retVal.result() if isinstance(retVal, Future) else retVal + + # Is it S_ERROR? + if self.result.get('OK') is False and 'Message' in self.result: + # S_ERROR is interpreted in the OAuth2 error format. + self.set_status(400) + self.write({'error': 'server_error', 'description': retVal['Message']}) + self.clear_cookie('auth_session') + self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) self.finish() else: super(AuthHandler, self)._finishFuture(retVal) + # # Is it S_OK or S_ERROR? + # r = retVal + # if isinstance(r, dict) and isinstance(r.get('OK'), bool) and ('Value' if r['OK'] else 'Message') in r: + # if not retVal['OK']: + # # S_ERROR is interpreted in the OAuth2 error format. + # self.set_status(400) + # self.write({'error': 'server_error', 'description': retVal['Message']}) + # self.clear_cookie('auth_session') + # self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) + # else: + # # Successful responses and OAuth2 errors are processed here + # status_code, headers, payload, new_session, error = retVal['Value'][0] + # if status_code: + # self.set_status(status_code) + # if headers: + # for key, value in headers: + # self.set_header(key, value) + # if payload: + # self.write(payload) + # if new_session: + # self.set_secure_cookie('auth_session', json.dumps(new_session), secure=True, httponly=True) + # if error: + # self.clear_cookie('auth_session') + # for method, args_kwargs in retVal['Value'][1].items(): + # eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) + # self.finish() + # else: + # super(AuthHandler, self)._finishFuture(retVal) + path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] def web_index(self, instance): diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index fba1e9df157..b0e8e74b7f3 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -36,6 +36,7 @@ getDNForUsername, getIdPForGroup) from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient from DIRAC.FrameworkSystem.Client.TokenManagerClient import TokenManagerClient +from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import TornadoResponse log = logging.getLogger('authlib') log.addHandler(logging.StreamHandler(sys.stdout)) @@ -324,12 +325,22 @@ def validate_requested_scope(self, scope, state=None): extended_scope = list_to_scope([re.sub(r':.*$', ':', s) for s in scope_to_list(scope or '')]) super(AuthServer, self).validate_requested_scope(extended_scope, state) - def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, **actions): + def handle_response(self, status_code=None, payload=None, headers=None, newSession=None): #, **actions): self.log.debug('Handle authorization response with %s status code:' % status_code, payload) - self.log.debug('Headers:', headers) + resp = TornadoResponse(payload) + if status_code: + resp.set_status(status_code) + if headers: + self.log.debug('Headers:', headers) + for key, value in headers: + resp.set_header(key, value) if newSession: self.log.debug('newSession:', newSession) - return S_OK([[status_code, headers, payload, newSession, 'error' in payload], actions]) + resp.set_secure_cookie('auth_session', json.dumps(newSession), secure=True, httponly=True) + if 'error' in payload: + resp.clear_cookie('auth_session') + return resp + # return S_OK([[status_code, headers, payload, newSession, 'error' in payload], actions]) def create_authorization_response(self, response, username): result = super(AuthServer, self).create_authorization_response(response, username) From 0717bd0a0294272f4fcf50b5cc04a12db7852c3c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 5 Jul 2021 14:22:42 +0200 Subject: [PATCH 105/178] fix rebase --- .../Service/TornadoConfigurationHandler.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py b/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py index d5aa1364250..06b634c7116 100644 --- a/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py +++ b/src/DIRAC/ConfigurationSystem/Service/TornadoConfigurationHandler.py @@ -48,11 +48,7 @@ def export_getCompressedData(self): Returns the configuration """ sData = self.ServiceInterface.getCompressedConfigurationData() -<<<<<<< HEAD return S_OK(b64encode(sData).decode()) -======= - return S_OK(b64encode(sData)) ->>>>>>> c2f34f040 (move TornadoConfigurationHandler fixes to separate PR) def export_getCompressedDataIfNewer(self, sClientVersion): """ @@ -63,11 +59,7 @@ def export_getCompressedDataIfNewer(self, sClientVersion): sVersion = self.ServiceInterface.getVersion() retDict = {'newestVersion': sVersion} if sClientVersion < sVersion: -<<<<<<< HEAD retDict['data'] = b64encode(self.ServiceInterface.getCompressedConfigurationData()).decode() -======= - retDict['data'] = b64encode(self.ServiceInterface.getCompressedConfigurationData()) ->>>>>>> c2f34f040 (move TornadoConfigurationHandler fixes to separate PR) return S_OK(retDict) def export_publishSlaveServer(self, sURL): From 6e4c90bfc7b2613766f4f572b5292dabd374d7ac Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 5 Jul 2021 14:41:59 +0200 Subject: [PATCH 106/178] fix bugs --- .../Server/private/BaseRequestHandler.py | 30 +++++++------- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 40 +++---------------- .../private/authorization/AuthServer.py | 21 ++++++---- .../authorization/grants/DeviceFlow.py | 4 +- .../private/authorization/utils/Tokens.py | 2 + .../Resources/IdProvider/IdProviderFactory.py | 1 - 6 files changed, 39 insertions(+), 59 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index fd9de091e4d..e52db1cd426 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -12,6 +12,7 @@ import os import six import time +import inspect import threading from datetime import datetime from six import string_types @@ -52,17 +53,17 @@ def __init__(self, data=None): self.data = data self.actions = [] for mName, mObj in inspect.getmembers(RequestHandler): - if inspect.isroutine(mObj) not mName.startswith('_') and mName not 'finish': - setattr(self, mName, lambda *x, **y: self.actions.append((mName, (x, y)))) + if inspect.isroutine(mObj) and not mName.startswith('_') and mName is not 'finish': + setattr(self, mName, partial(self.__setAction, mName)) + + def __setAction(self, mName, *args, **kwargs): + self.actions.append((mName, args, kwargs)) def finish(self, reqObj): - for mName, targs in self.actions: - getattr(reqObj, mName)(*targs[0], **targs[1]) - if not self._finished: - if self.data is None: - getattr(reqObj, 'finish')() - else: - getattr(reqObj, 'finish')(self.data) + for mName, args, kwargs in self.actions: + getattr(reqObj, mName)(*args, **kwargs) + if not reqObj._finished: + reqObj.finish() if self.data is None else reqObj.finish(self.data) class BaseRequestHandler(RequestHandler): @@ -471,8 +472,8 @@ def __executeMethod(self, targetMethod, args): :return: Future """ - sLog.notice("Incoming request %s /%s: %s" % (self.srv_getFormattedRemoteCredentials(), - self._serviceName, self.method)) + sLog.notice("Incoming request %s /%s: %s(%s)" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, self.method, ', '.join(list(args)))) # Execute try: self.initializeRequest() @@ -505,9 +506,8 @@ def _prepareExecutor(self, args): :return: executor, target method with arguments """ - if six.PY2: - return None, partial(self.__executeMethodPy2, self._getMethod(), self._getMethodArgs(args)) - return None, partial(self.__executeMethod, self._getMethod(), self._getMethodArgs(args)) + return None, partial(self.__executeMethodPy2 if six.PY2 else self.__executeMethod, + self._getMethod(), self._getMethodArgs(args)) def _finishFuture(self, retVal): """ Handler Future result @@ -584,7 +584,7 @@ def _gatherPeerCredentials(self, grants=None): # the authorization will go through the `_authzVISITOR` method and # everyone will have access as anonymous@visitor for grant in (grants or self.USE_AUTHZ_GRANTS or 'VISITOR'): - if six.PY3 and grant == 'JWT': + if six.PY2 and grant == 'JWT': # Skip token authorization for python 2 continue grant = grant.upper() diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 26becacd3a7..693b740f7b8 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -14,15 +14,16 @@ import pprint from io import open +from dominate import document, tags as dom from tornado.template import Template +from tornado.concurrent import Future + +from authlib.oauth2.base import OAuth2Error +from authlib.oauth2.rfc6749.util import scope_to_list from DIRAC import S_ERROR from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST from DIRAC.ConfigurationSystem.Client.Helpers import Registry - -from dominate import document, tags as dom -from authlib.oauth2.base import OAuth2Error -from authlib.oauth2.rfc6749.util import scope_to_list from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint @@ -109,7 +110,7 @@ def _finishFuture(self, retVal): self.result = retVal.result() if isinstance(retVal, Future) else retVal # Is it S_ERROR? - if self.result.get('OK') is False and 'Message' in self.result: + if isinstance(self.result, dict) and self.result.get('OK') is False and 'Message' in self.result: # S_ERROR is interpreted in the OAuth2 error format. self.set_status(400) self.write({'error': 'server_error', 'description': retVal['Message']}) @@ -119,35 +120,6 @@ def _finishFuture(self, retVal): else: super(AuthHandler, self)._finishFuture(retVal) - # # Is it S_OK or S_ERROR? - # r = retVal - # if isinstance(r, dict) and isinstance(r.get('OK'), bool) and ('Value' if r['OK'] else 'Message') in r: - # if not retVal['OK']: - # # S_ERROR is interpreted in the OAuth2 error format. - # self.set_status(400) - # self.write({'error': 'server_error', 'description': retVal['Message']}) - # self.clear_cookie('auth_session') - # self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) - # else: - # # Successful responses and OAuth2 errors are processed here - # status_code, headers, payload, new_session, error = retVal['Value'][0] - # if status_code: - # self.set_status(status_code) - # if headers: - # for key, value in headers: - # self.set_header(key, value) - # if payload: - # self.write(payload) - # if new_session: - # self.set_secure_cookie('auth_session', json.dumps(new_session), secure=True, httponly=True) - # if error: - # self.clear_cookie('auth_session') - # for method, args_kwargs in retVal['Value'][1].items(): - # eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) - # self.finish() - # else: - # super(AuthHandler, self)._finishFuture(retVal) - path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] def web_index(self, instance): diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index b0e8e74b7f3..3aa224eedf0 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -6,6 +6,7 @@ import re import sys import time +import json import pprint import logging from dominate import document, tags as dom @@ -325,7 +326,16 @@ def validate_requested_scope(self, scope, state=None): extended_scope = list_to_scope([re.sub(r':.*$', ':', s) for s in scope_to_list(scope or '')]) super(AuthServer, self).validate_requested_scope(extended_scope, state) - def handle_response(self, status_code=None, payload=None, headers=None, newSession=None): #, **actions): + def handle_response(self, status_code=None, payload=None, headers=None, newSession=None): + """ Handle response + + :param int status_code: http status code + :param payload: response payload + :param list headers: headers + :param dict newSession: session data to store + + :return: TornadoResponse() + """ self.log.debug('Handle authorization response with %s status code:' % status_code, payload) resp = TornadoResponse(payload) if status_code: @@ -340,14 +350,11 @@ def handle_response(self, status_code=None, payload=None, headers=None, newSessi if 'error' in payload: resp.clear_cookie('auth_session') return resp - # return S_OK([[status_code, headers, payload, newSession, 'error' in payload], actions]) def create_authorization_response(self, response, username): - result = super(AuthServer, self).create_authorization_response(response, username) - if result['OK']: - # Remove auth session - result['Value'][0][4] = True - return result + response = super(AuthServer, self).create_authorization_response(response, username) + response.clear_cookie('auth_session') + return response def validate_consent_request(self, request, provider=None): """ Validate current HTTP request for authorization page. This page diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index 80da66d535f..aa2871ed3c5 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -118,12 +118,12 @@ def query_user_grant(self, user_code): :param str user_code: user code - :return: str, bool -- user dict and user auth status + :return: (str, bool) or None -- user dict and user auth status """ result = self.server.db.getSessionByUserCode(user_code) if not result['OK']: raise OAuth2Error('Cannot found authorization session', result['Message']) - return (result['Value']['user_id'], True) if result['Value'].get('username', "None") != "None" else ('', False) + return (result['Value']['user_id'], True) if result['Value'].get('username', "None") != "None" else None def should_slow_down(self, *args): """ The authorization request is still pending and polling should continue, diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index e6402edb864..d78baed25c6 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -119,6 +119,8 @@ class OAuth2Token(_OAuth2Token): def __init__(self, params=None, **kwargs): """ Constructor """ + if six.PY3 and isinstance(params, bytes): + params = params.decode() if isinstance(params, six.string_types): # Is params a JWT? params = params.strip() diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index a8dbcc010df..526783423a4 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -10,7 +10,6 @@ from __future__ import division from __future__ import print_function -import six import jwt from DIRAC import S_OK, S_ERROR, gLogger From 72e09c54d39ba5fa0fb36817b8a1d5f671baa200 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 5 Jul 2021 15:28:29 +0200 Subject: [PATCH 107/178] fix log --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index e52db1cd426..46beaabe017 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -472,8 +472,8 @@ def __executeMethod(self, targetMethod, args): :return: Future """ - sLog.notice("Incoming request %s /%s: %s(%s)" % (self.srv_getFormattedRemoteCredentials(), - self._serviceName, self.method, ', '.join(list(args)))) + sLog.notice("Incoming request %s /%s: %s" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, self.method)) # Execute try: self.initializeRequest() From 3b690971cd6dc0139263af2c18ce55b26fb14c8b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 6 Jul 2021 19:47:41 +0200 Subject: [PATCH 108/178] fix refresher with asyncio issue --- .../private/TornadoRefresher.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py index 7443a13ba43..c28a3119306 100644 --- a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py +++ b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py @@ -4,6 +4,7 @@ __RCSID__ = "$Id$" +import six import time from tornado import gen @@ -76,9 +77,13 @@ def __refreshLoop(self): yield gen.sleep(gConfigurationData.getPropagationTime()) # Publish step is blocking so we have to run it in executor # If we are not doing it, when master try to ping we block the IOLoop - yield _IOLoop.current().run_in_executor(None, self.__AutoRefresh) - @gen.coroutine + # When switching from python 2 to python 3, the following error occurs: + # RuntimeError: There is no current event loop in thread.. + # The reason seems to be that asyncio.get_event_loop() is called in some thread other than the main thread, + # asyncio only generates an event loop for the main thread. + yield _IOLoop.current().run_in_executor(None, self.__AutoRefresh if six.PY3 else self.__AutoRefreshPy2) + def __AutoRefresh(self): """ Auto refresh the configuration @@ -89,6 +94,17 @@ def __AutoRefresh(self): if not self._refreshAndPublish(): # pylint: disable=no-member gLogger.error("Can't refresh configuration from any source") + @gen.coroutine + def __AutoRefreshPy2(self): + """ + Auto refresh the configuration + We disable pylint error because this class must be instanciated + by a mixin to define the methods. for python 2 + """ + if self._refreshEnabled: # pylint: disable=no-member + if not self._refreshAndPublish(): # pylint: disable=no-member + gLogger.error("Can't refresh configuration from any source") + def daemonize(self): """ daemonize is probably not the best name because there is no daemon behind but we must keep it to the same interface of the DISET refresher """ From 509d74cbceba0a7520e3b45b34cf1ec8a3d816c3 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 6 Jul 2021 19:50:10 +0200 Subject: [PATCH 109/178] fix BaseRequestHandler for WebApp case --- .../Server/private/BaseRequestHandler.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 46beaabe017..ce0fba00799 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -606,20 +606,24 @@ def _authzSSL(self): :return: S_OK(dict)/S_ERROR() """ - peerChain = X509Chain() - derCert = self.request.get_ssl_certificate() - # Get client certificate pem + try: + derCert = self.request.get_ssl_certificate() + except Exception: + # If 'IOStream' object has no attribute 'get_ssl_certificate' + derCert = None + + # Get client certificate as pem if derCert: chainAsText = derCert.as_pem() - # Here we read all certificate chain - chainAsText += '\n'.join([cert.as_pem() for cert in self.request.get_ssl_certificate_chain()]) - elif self.request.headers.get('X-Ssl_client_verify') == 'SUCCESS': - chainAsTextEncoded = self.request.headers.get('X-SSL-CERT') - chainAsText = unquote(chainAsTextEncoded) + # Read all certificate chain + chainAsText += ''.join([cert.as_pem() for cert in self.request.get_ssl_certificate_chain()]) + elif self.request.headers.get('X-Ssl_client_verify') == 'SUCCESS' and self.request.headers.get('X-SSL-CERT'): + chainAsText = unquote(self.request.headers.get('X-SSL-CERT')) else: return S_ERROR(DErrno.ECERTFIND, 'Valid certificate not found.') # Load full certificate chain + peerChain = X509Chain() peerChain.loadChainFromString(chainAsText) # Retrieve the credentials From dd1494ccf68b122c10f95876f4e6553e1dd363b0 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 6 Jul 2021 19:50:56 +0200 Subject: [PATCH 110/178] comment sslDebug --- src/DIRAC/Core/Tornado/Server/TornadoServer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 7be28d08933..7a9601eccdf 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -186,7 +186,8 @@ def startTornado(self): 'keyfile': certs[1], 'cert_reqs': M2Crypto.SSL.verify_peer, 'ca_certs': ca, - 'sslDebug': False, # Set to true if you want to see the TLS debug messages + # Failed in tornado '5.1.1', 'sslDebug' not in m2netutil._SSL_CONTEXT_KEYWORDS + #'sslDebug': False, # Set to true if you want to see the TLS debug messages } # Init monitoring From 97d9b485b7a15183b117cfc46ae1bd1c7f936b1d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 6 Jul 2021 19:53:56 +0200 Subject: [PATCH 111/178] add utilities, change LOCATION --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 18 ++++---- .../private/authorization/AuthServer.py | 46 ++++--------------- .../private/authorization/utils/Utilities.py | 37 +++++++++++++++ .../Resources/IdProvider/IdProviderFactory.py | 11 +++-- tests/Integration/Framework/Test_AuthDB.py | 12 ++--- .../IdProvider/Test_IdProviderFactory.py | 7 +-- 6 files changed, 69 insertions(+), 62 deletions(-) create mode 100644 src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 693b740f7b8..cbf3825ce26 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -37,7 +37,7 @@ class AuthHandler(TornadoREST): USE_AUTHZ_GRANTS = ['JWT', 'VISITOR'] SYSTEM = 'Framework' AUTH_PROPS = 'all' - LOCATION = "/DIRAC/auth" + LOCATION = "/auth" css_align_center = 'display:block;justify-content:center;align-items:center;' css_center_div = 'height:700px;width:100%;position:absolute;top:50%;left:0;margin-top:-350px;' css_big_text = 'font-size:28px;' @@ -137,9 +137,9 @@ def web_index(self, instance): Content-Type: application/json { - "registration_endpoint": "https://domain.com/DIRAC/auth/register", - "userinfo_endpoint": "https://domain.com/DIRAC/auth/userinfo", - "jwks_uri": "https://domain.com/DIRAC/auth/jwk", + "registration_endpoint": "https://domain.com/auth/register", + "userinfo_endpoint": "https://domain.com/auth/userinfo", + "jwks_uri": "https://domain.com/auth/jwk", "code_challenge_methods_supported": [ "S256" ], @@ -148,7 +148,7 @@ def web_index(self, instance): "code", "refresh_token" ], - "token_endpoint": "https://domain.com/DIRAC/auth/token", + "token_endpoint": "https://domain.com/auth/token", "response_types_supported": [ "code", "device", @@ -156,8 +156,8 @@ def web_index(self, instance): "id_token", "token" ], - "authorization_endpoint": "https://domain.com/DIRAC/auth/authorization", - "issuer": "https://domain.com/DIRAC/auth" + "authorization_endpoint": "https://domain.com/auth/authorization", + "issuer": "https://domain.com/auth" } """ if self.request.method == "GET": @@ -282,10 +282,10 @@ def web_device(self, provider=None): { "device_code": "TglwLiow0HUwowjB9aHH5HqH3bZKP9d420LkNhCEuR", - "verification_uri": "https://marosvn32.in2p3.fr/DIRAC/auth/device", + "verification_uri": "https://marosvn32.in2p3.fr/auth/device", "interval": 5, "expires_in": 1800, - "verification_uri_complete": "https://marosvn32.in2p3.fr/DIRAC/auth/device/WSRL-HJMR", + "verification_uri_complete": "https://marosvn32.in2p3.fr/auth/device/WSRL-HJMR", "user_code": "WSRL-HJMR" } diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 3aa224eedf0..b5ee8645166 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -17,27 +17,26 @@ from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc7636 import CodeChallenge -from authlib.oauth2.rfc8414 import AuthorizationServerMetadata from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope -from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint -from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant -from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, - DeviceCodeGrant) -from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant -from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients, Client -from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request - from DIRAC import gLogger, S_OK, S_ERROR from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory -from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata, isDownloadablePersonalProxy +from DIRAC.ConfigurationSystem.Client.Utilities import isDownloadablePersonalProxy from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getUsernameForDN, getEmailsForGroup, wrapIDAsDN, getDNForUsername, getIdPForGroup) from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient from DIRAC.FrameworkSystem.Client.TokenManagerClient import TokenManagerClient from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import TornadoResponse +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients, Client +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import collectMetadata +from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint +from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, + DeviceCodeGrant) +from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import AuthorizationCodeGrant log = logging.getLogger('authlib') log.addHandler(logging.StreamHandler(sys.stdout)) @@ -45,33 +44,6 @@ log = gLogger.getSubLogger(__name__) -def collectMetadata(issuer=None): - """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defines by IETF specification: - https://datatracker.ietf.org/doc/html/rfc8414#section-2 - - :param str issuer: issuer to set - - :return: dict -- dictionary is the AuthorizationServerMetadata object in the same time - """ - result = getAuthorizationServerMetadata(issuer) - if not result['OK']: - raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) - metadata = result['Value'] - for name, endpoint in [('jwks_uri', 'jwk'), - ('token_endpoint', 'token'), - ('userinfo_endpoint', 'userinfo'), - ('revocation_endpoint', 'revoke'), - ('authorization_endpoint', 'authorization'), - ('device_authorization_endpoint', 'device')]: - metadata[name] = metadata['issuer'].strip('/') + '/' + endpoint - metadata['scopes_supported'] = ['g:', 'proxy', 'lifetime:'] - metadata['grant_types_supported'] = ['code', 'authorization_code', 'refresh_token', - 'urn:ietf:params:oauth:grant-type:device_code'] - metadata['response_types_supported'] = ['code', 'device', 'token'] - metadata['code_challenge_methods_supported'] = ['S256'] - return AuthorizationServerMetadata(metadata) - - class AuthServer(_AuthorizationServer): """ Implementation of the :class:`authlib.oauth2.rfc6749.AuthorizationServer`. diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py new file mode 100644 index 00000000000..93ab40f3f5c --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -0,0 +1,37 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata + +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata + + +def collectMetadata(issuer=None): + """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defines by IETF specification: + https://datatracker.ietf.org/doc/html/rfc8414#section-2 + + :param str issuer: issuer to set + + :return: dict -- dictionary is the AuthorizationServerMetadata object in the same time + """ + result = getAuthorizationServerMetadata(issuer) + if not result['OK']: + raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) + metadata = result['Value'] + for name, endpoint in [('jwks_uri', 'jwk'), + ('token_endpoint', 'token'), + ('userinfo_endpoint', 'userinfo'), + ('revocation_endpoint', 'revoke'), + ('redirect_uri', 'redirect'), + ('authorization_endpoint', 'authorization'), + ('device_authorization_endpoint', 'device')]: + metadata[name] = metadata['issuer'].strip('/') + '/' + endpoint + metadata['scopes_supported'] = ['g:', 'proxy', 'lifetime:'] + metadata['grant_types_supported'] = ['code', 'authorization_code', 'refresh_token', + 'urn:ietf:params:oauth:grant-type:device_code'] + metadata['response_types_supported'] = ['code', 'device', 'token'] + metadata['code_challenge_methods_supported'] = ['S256'] + return AuthorizationServerMetadata(metadata) \ No newline at end of file diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 526783423a4..75783e0bdf5 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -16,8 +16,8 @@ from DIRAC.Core.Utilities import ObjectLoader, ThreadSafe from DIRAC.Core.Utilities.DictCache import DictCache from DIRAC.Resources.IdProvider.Utilities import getProviderInfo, getSettingsNamesForIdPIssuer -from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import collectMetadata __RCSID__ = "$Id$" @@ -71,15 +71,14 @@ def getIdProvider(self, name, **kwargs): :return: S_OK(IdProvider)/S_ERROR() """ + # Get Authorization Server metadata + asMetaDict = collectMetadata() self.log.debug('Search %s identity provider client configuration..' % name) clients = getDIRACClients() if name in clients: # If it is a DIRAC default pre-registred client pDict = clients[name] - result = getAuthorizationServerMetadata() - if not result['OK']: - return result - pDict.update(result['Value']) + pDict.update(asMetaDict) else: # if it is external identity provider client result = getProviderInfo(name) @@ -87,6 +86,8 @@ def getIdProvider(self, name, **kwargs): self.log.error('Failed to read configuration', '%s: %s' % (name, result['Message'])) return result pDict = result['Value'] + # Set default redirect_uri + pDict['redirect_uri'] = pDict.get('redirect_uri', asMetaDict['redirect_uri']) pDict.update(kwargs) pDict['ProviderName'] = name diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py index a1e6d30b4a4..a52cd2094b0 100644 --- a/tests/Integration/Framework/Test_AuthDB.py +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -174,10 +174,10 @@ def test_Sessions(): 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', 'interval': 5, 'scope': 'g:my_group', - 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'uri': 'https://domain.com/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', 'user_code': 'MDKP-MXMF', - 'verification_uri': 'https://domain.com/DIRAC/auth/device', - 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF'} + 'verification_uri': 'https://domain.com/auth/device', + 'verification_uri_complete': u'https://domain.com/auth/device?user_code=MDKP-MXMF'} # Example of the updated session sData2 = {'client_id': 'DIRAC_CLI', @@ -186,10 +186,10 @@ def test_Sessions(): 'id': 'SsoGTDglu6LThpx0CigM9i9J72B5atZ24ULr6R1awm', 'interval': 5, 'scope': 'g:my_group', - 'uri': 'https://domain.com/DIRAC/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', + 'uri': 'https://domain.com/auth/device?&response_type=device&client_id=DIRAC_CLI&scope=g:my_group', 'user_code': 'MDKP-MXMF', - 'verification_uri': 'https://domain.com/DIRAC/auth/device', - 'verification_uri_complete': u'https://domain.com/DIRAC/auth/device?user_code=MDKP-MXMF', + 'verification_uri': 'https://domain.com/auth/device', + 'verification_uri_complete': u'https://domain.com/auth/device?user_code=MDKP-MXMF', 'user_id': 'username'} # Remove old session diff --git a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py index 733a882fa3b..005c383925a 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProviderFactory.py @@ -58,7 +58,6 @@ # DIRACOS not contain required packages from authlib.jose import jwt from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory - from DIRAC.FrameworkSystem.private.authorization.AuthServer import collectMetadata from DIRAC.FrameworkSystem.private.authorization.utils.Clients import DEFAULT_CLIENTS idps = IdProviderFactory() @@ -67,17 +66,15 @@ def test_getDIRACClients(): """ Try to load default DIRAC authorization client """ - params = collectMetadata() - # Try to get DIRAC client authorization settings - result = idps.getIdProvider('DIRACCLI', **params) + result = idps.getIdProvider('DIRACCLI') assert result['OK'], result['Message'] assert result['Value'].issuer == 'https://issuer.url/' assert result['Value'].client_id == DEFAULT_CLIENTS['DIRACCLI']['client_id'] assert result['Value'].get_metadata('jwks_uri') == 'https://issuer.url/jwk' # Try to get DIRAC client authorization settings for Web portal - result = idps.getIdProvider('DIRACWeb', **params) + result = idps.getIdProvider('DIRACWeb') assert result['OK'], result['Message'] assert result['Value'].issuer == 'https://issuer.url/' assert result['Value'].client_id == 'client_identificator' From c167bf5a2629dc6f353533a98d1826521b8ccc97 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 6 Jul 2021 20:45:58 +0200 Subject: [PATCH 112/178] fix pylint --- src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py | 2 +- src/DIRAC/Core/Tornado/Server/TornadoServer.py | 2 +- .../FrameworkSystem/private/authorization/AuthServer.py | 7 ++++--- .../private/authorization/utils/Utilities.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py index c28a3119306..32dc1ed733f 100644 --- a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py +++ b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py @@ -96,7 +96,7 @@ def __AutoRefresh(self): @gen.coroutine def __AutoRefreshPy2(self): - """ + """ Auto refresh the configuration We disable pylint error because this class must be instanciated by a mixin to define the methods. for python 2 diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 7a9601eccdf..337a0b7173e 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -187,7 +187,7 @@ def startTornado(self): 'cert_reqs': M2Crypto.SSL.verify_peer, 'ca_certs': ca, # Failed in tornado '5.1.1', 'sslDebug' not in m2netutil._SSL_CONTEXT_KEYWORDS - #'sslDebug': False, # Set to true if you want to see the TLS debug messages + # 'sslDebug': False, # Set to true if you want to see the TLS debug messages } # Init monitoring diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index b5ee8645166..dd47b37ece9 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -311,16 +311,17 @@ def handle_response(self, status_code=None, payload=None, headers=None, newSessi self.log.debug('Handle authorization response with %s status code:' % status_code, payload) resp = TornadoResponse(payload) if status_code: - resp.set_status(status_code) + resp.set_status(status_code) # pylint: disable=no-member if headers: self.log.debug('Headers:', headers) for key, value in headers: - resp.set_header(key, value) + resp.set_header(key, value) # pylint: disable=no-member if newSession: self.log.debug('newSession:', newSession) + # pylint: disable=no-member resp.set_secure_cookie('auth_session', json.dumps(newSession), secure=True, httponly=True) if 'error' in payload: - resp.clear_cookie('auth_session') + resp.clear_cookie('auth_session') # pylint: disable=no-member return resp def create_authorization_response(self, response, username): diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 93ab40f3f5c..bf6e64e390b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -34,4 +34,4 @@ def collectMetadata(issuer=None): 'urn:ietf:params:oauth:grant-type:device_code'] metadata['response_types_supported'] = ['code', 'device', 'token'] metadata['code_challenge_methods_supported'] = ['S256'] - return AuthorizationServerMetadata(metadata) \ No newline at end of file + return AuthorizationServerMetadata(metadata) From e9e89bcbfcdf16d7b879d27194d31732e76c6d3b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 7 Jul 2021 10:28:54 +0200 Subject: [PATCH 113/178] fix --- src/DIRAC/Resources/IdProvider/IdProviderFactory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 75783e0bdf5..d07da05e9af 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -73,12 +73,12 @@ def getIdProvider(self, name, **kwargs): """ # Get Authorization Server metadata asMetaDict = collectMetadata() - self.log.debug('Search %s identity provider client configuration..' % name) + self.log.debug('Search configuration for', name) clients = getDIRACClients() if name in clients: # If it is a DIRAC default pre-registred client - pDict = clients[name] - pDict.update(asMetaDict) + pDict = asMetaDict + pDict.update(clients[name]) else: # if it is external identity provider client result = getProviderInfo(name) From d6eefb4fe1ecf1c1b1a34c7738ec290d360d396c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 12 Jul 2021 14:32:17 +0200 Subject: [PATCH 114/178] fix bugs --- .../Server/private/BaseRequestHandler.py | 26 +++++++++++++++---- .../Service/TokenManagerHandler.py | 2 +- .../private/authorization/AuthServer.py | 3 ++- .../private/authorization/utils/Requests.py | 2 +- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index ce0fba00799..40264b98289 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -46,20 +46,36 @@ class TornadoResponse(object): - """ This class describe result object """ + """ This class registers tornadoes with arguments in the order they are called + from TornadoResponse to call them later. + Use:: + + def web_myEndpotin(self): + resp = TornadoResponse('data') + resp.set_status(400) + return resp + """ def __init__(self, data=None): - """ C'or """ + """ C'or + + :param data: response body + """ self.data = data self.actions = [] for mName, mObj in inspect.getmembers(RequestHandler): - if inspect.isroutine(mObj) and not mName.startswith('_') and mName is not 'finish': + if inspect.isroutine(mObj) and not mName.startswith('_') and not mName.startwith('get'): setattr(self, mName, partial(self.__setAction, mName)) def __setAction(self, mName, *args, **kwargs): + """ Register new action """ self.actions.append((mName, args, kwargs)) - def finish(self, reqObj): + def _runActions(self, reqObj): + """ Calling methods in the order of their registration + + :param reqObj: RequestHandler instance + """ for mName, args, kwargs in self.actions: getattr(reqObj, mName)(*args, **kwargs) if not reqObj._finished: @@ -525,7 +541,7 @@ def _finishFuture(self, retVal): finishFunc = getattr(self, 'finish_%s' % self.method, None) if isinstance(self.result, TornadoResponse): - self.result.finish(self) + self.result._runActions(self) elif callable(finishFunc): finishFunc() diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py index d8187324344..94938a8cdf8 100644 --- a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -87,7 +87,7 @@ def export_updateToken(self, token, userID, provider, rt_expired_in=24 * 3600): :return: S_OK(dict)/S_ERROR() -- dict contain uploaded tokens info """ - self.log.verbose('Update %s user token:\n', pprint.pformat(token)) + self.log.verbose('Update %s user token for %s:\n' % (userID, provider), pprint.pformat(token)) result = self.idps.getIdProvider(provider) if not result['OK']: return result diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index dd47b37ece9..b1f40964f83 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -320,7 +320,7 @@ def handle_response(self, status_code=None, payload=None, headers=None, newSessi self.log.debug('newSession:', newSession) # pylint: disable=no-member resp.set_secure_cookie('auth_session', json.dumps(newSession), secure=True, httponly=True) - if 'error' in payload: + if isinstance(payload, dict) and 'error' in payload: resp.clear_cookie('auth_session') # pylint: disable=no-member return resp @@ -367,6 +367,7 @@ def validateIdentityProvider(self, request, provider): :return: str, S_OK()/S_ERROR() -- provider name and html page to choose it """ + self.log.debug("Check if %s identity provider registred in DIRAC.." % provider) # Research supported IdPs result = getProvidersForInstance('Id') if not result['OK']: diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index aa53d4bdab4..375ac67af48 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -33,7 +33,7 @@ def groups(self): :return: list """ - return [s.split(':')[1] for s in scope_to_list(self.scope) if s.startswith('g:')] + return [s.split(':')[1] for s in scope_to_list(self.scope) if s.startswith('g:') and s.split(':')[1]] def toDict(self): """ Convert class to dictionary From dee115714d9a6520b7f1f80673b112ff18c6b993 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 15 Jul 2021 23:42:49 +0200 Subject: [PATCH 115/178] update getAuthorizationServerMetadata to ignore CS errors --- .../ConfigurationSystem/Client/Utilities.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index 4bf1d0f1259..f1138f6e386 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -579,15 +579,25 @@ def getAuthAPI(): return gConfig.getValue("/Systems/Framework/%s/URLs/AuthAPI" % getSystemInstance("Framework")) -def getAuthorizationServerMetadata(issuer=None): +def getAuthorizationServerMetadata(issuer=None, ignoreErrors=False): """ Get authorization server metadata + :param str issuer: issuer + :param bool ignoreErrors: igrnore configuration errors + :return: S_OK(dict)/S_ERROR() """ - result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization') - if not result['OK']: - return {'issuer': issuer} if issuer else result - data = result['Value'] + data = {} + try: + result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization') + if not result['OK']: + return S_OK({'issuer': issuer}) if issuer else result + data = result['Value'] + except Exception as e: + if ignoreErrors: + gLogger.warn(repr(e)) + else: + raise e # Search DIRAC Authorization Server issuer data['issuer'] = data.get('issuer', issuer) @@ -597,7 +607,7 @@ def getAuthorizationServerMetadata(issuer=None): except Exception as e: return S_ERROR('No issuer found in DIRAC authorization server: %s' % repr(e)) - return S_OK(data) + return S_OK(data) if data['issuer'] else S_ERROR('Cannot find DIRAC Authorization Server issuer.') def isDownloadablePersonalProxy(): From 9f1c5fc0817dddc97c8cd163b7865fea3290ea7d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 15 Jul 2021 23:44:14 +0200 Subject: [PATCH 116/178] fixes --- src/DIRAC/Core/Tornado/Server/TornadoServer.py | 7 +++++++ .../Tornado/Server/private/BaseRequestHandler.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 337a0b7173e..5735509d81c 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -198,6 +198,13 @@ def startTornado(self): # Starting monitoring, IOLoop waiting time in ms, __monitoringLoopDelay is defined in seconds tornado.ioloop.PeriodicCallback(self.__reportToMonitoring, self.__monitoringLoopDelay * 1000).start() + if six.PY3: + # If we are running with python3, Tornado will use asyncio, + # and we have to convince it to let us run in a different thread + # Doing this ensures a consistent behavior between py2 and py3 + import asyncio # pylint: disable=import-error + asyncio.set_event_loop_policy(tornado.platform.asyncio.AnyThreadEventLoopPolicy()) + for port, app in self.__appsSettings.items(): sLog.debug(" - %s" % "\n - ".join(["%s = %s" % (k, ssl_options[k]) for k in ssl_options])) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 40264b98289..134f52b967d 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -64,7 +64,7 @@ def __init__(self, data=None): self.data = data self.actions = [] for mName, mObj in inspect.getmembers(RequestHandler): - if inspect.isroutine(mObj) and not mName.startswith('_') and not mName.startwith('get'): + if inspect.isroutine(mObj) and not mName.startswith('_') and not mName.startswith('get'): setattr(self, mName, partial(self.__setAction, mName)) def __setAction(self, mName, *args, **kwargs): @@ -79,7 +79,7 @@ def _runActions(self, reqObj): for mName, args, kwargs in self.actions: getattr(reqObj, mName)(*args, **kwargs) if not reqObj._finished: - reqObj.finish() if self.data is None else reqObj.finish(self.data) + reqObj.finish(self.data if self.data is None else encode(self.data)) class BaseRequestHandler(RequestHandler): @@ -291,10 +291,6 @@ def __initializeService(cls, request): if cls.__init_done: return S_OK() - # Load all registred identity providers - if six.PY3: - cls.__loadIdPs() - # absoluteUrl: full URL e.g. ``https://://`` absoluteUrl = request.path serviceName = cls._getServiceName(request) @@ -318,6 +314,10 @@ def __initializeService(cls, request): cls.initializeHandler(serviceInfo) + # Load all registred identity providers + if six.PY3: + cls.__loadIdPs() + cls.__init_done = True return S_OK() @@ -630,9 +630,9 @@ def _authzSSL(self): # Get client certificate as pem if derCert: - chainAsText = derCert.as_pem() + chainAsText = derCert.as_pem().decode() # Read all certificate chain - chainAsText += ''.join([cert.as_pem() for cert in self.request.get_ssl_certificate_chain()]) + chainAsText += ''.join([cert.as_pem().decode() for cert in self.request.get_ssl_certificate_chain()]) elif self.request.headers.get('X-Ssl_client_verify') == 'SUCCESS' and self.request.headers.get('X-SSL-CERT'): chainAsText = unquote(self.request.headers.get('X-SSL-CERT')) else: From 64910e5e8e5bf0368081b5376a9750baaabe9fca Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 15 Jul 2021 23:44:49 +0200 Subject: [PATCH 117/178] add bytes decoding to JEncode --- src/DIRAC/Core/Utilities/JEncode.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/DIRAC/Core/Utilities/JEncode.py b/src/DIRAC/Core/Utilities/JEncode.py index c2283a41061..0369f401def 100644 --- a/src/DIRAC/Core/Utilities/JEncode.py +++ b/src/DIRAC/Core/Utilities/JEncode.py @@ -105,6 +105,9 @@ def default(self, obj): # pylint: disable=method-hidden # if the object inherits from JSJerializable, try to serialize it elif isinstance(obj, JSerializable): return obj._toJSON() # pylint: disable=protected-access + # if the object a bytes, decode it + elif isinstance(obj, bytes): + return obj.decode() # otherwise, let the parent do return super(DJSONEncoder, self).default(obj) From 7d2150de7aacf207e4413de282adf4675f1d90bb Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 15 Jul 2021 23:45:59 +0200 Subject: [PATCH 118/178] add tokens to dirac_configure --- src/DIRAC/Core/scripts/dirac_configure.py | 86 +++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/src/DIRAC/Core/scripts/dirac_configure.py b/src/DIRAC/Core/scripts/dirac_configure.py index 48bf58f5c57..5c6db36cbab 100755 --- a/src/DIRAC/Core/scripts/dirac_configure.py +++ b/src/DIRAC/Core/scripts/dirac_configure.py @@ -63,6 +63,7 @@ class Params(object): def __init__(self): + self.issuer = None self.logLevel = None self.setup = None self.configurationServer = None @@ -172,6 +173,11 @@ def setExtensions(self, optionValue): DIRAC.gConfig.setOptionValue(cfgInstallPath('Extensions'), self.extensions) return DIRAC.S_OK() + def setIssuer(self, optionValue): + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'True' + self.issuer = optionValue + DIRAC.gConfig.setOptionValue('/DIRAC/Security/Authorization/issuer', self.issuer) + return DIRAC.S_OK() def _runConfigurationWizard(setups, defaultSetup): """The implementation of the configuration wizard""" @@ -289,7 +295,75 @@ def main(): runDiracConfigure(params) +def login(params): + from prompt_toolkit import prompt, print_formatted_text, HTML + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import (writeTokenDictToTokenFile, + getTokenFileLocation, + readTokenFromFile) + # Init authorization client + result = IdProviderFactory().getIdProvider('DIRACCLI', issuer=params.issuer, scope=' ') + if not result['OK']: + return result + idpObj = result['Value'] + + # Get token file path + tokenFile = getTokenFileLocation() + + # Submit Device authorisation flow + result = idpObj.deviceAuthorization() + if not result['OK']: + return result + + # Revoke old tokens from token file + if os.path.isfile(tokenFile): + result = readTokenFromFile(tokenFile) + if not result['OK']: + DIRAC.gLogger.warn(result['Message']) + elif result['Value']: + oldToken = result['Value'] + for tokenType in ['access_token', 'refresh_token']: + result = idpObj.revokeToken(oldToken[tokenType], tokenType) + if result['OK']: + DIRAC.gLogger.debug('%s is revoked from' % tokenType, tokenFile) + else: + DIRAC.gLogger.warn(result['Message']) + + # Save new tokens to token file + result = writeTokenDictToTokenFile(idpObj.token, tokenFile) + if not result['OK']: + return result + DIRAC.gLogger.debug('New token is saved to %s.' % result['Value']) + + # Get server setups and master CS server URL + csURL = idpObj.get_metadata("configuration_server") + setups = idpObj.get_metadata("setups") + + if len(setups) == 1 and csURL: + # If setup only one dont ask user + params.setSetup(setups[0]) + params.setServer(csURL) + elif setups and csURL: + # Ask the user for the appropriate configuration settings + while True: + result = _runConfigurationWizard({setup: csURL for setup in setups}, setups[0]) + if result: + break + print_formatted_text(HTML( + "Wizard failed, retrying... (press Control + C to exit)\n" + )) + # Apply the arguments to the params object + setup, csURL = result + params.setSetup(setup) + params.setServer(csURL) + + confirm = prompt(HTML("Do you want to use tokens instead of certificates by default? "), default="no") + return DIRAC.S_OK('yes') if confirm.lower() in ["y", "yes"] else DIRAC.S_OK(None) + + def runDiracConfigure(params): + if six.PY3: + Script.registerSwitch("", "login=", "Set DIRAC authorization endpoint", params.setIssuer) Script.registerSwitch("S:", "Setup=", "Set as DIRAC setup", params.setSetup) Script.registerSwitch("e:", "Extensions=", "Set as DIRAC extensions", params.setExtensions) Script.registerSwitch("C:", "ConfigurationServer=", "Set as DIRAC configuration server", params.setServer) @@ -320,6 +394,18 @@ def runDiracConfigure(params): Script.parseCommandLine(ignoreErrors=True) + # Use token auth + if params.issuer: + result = login(params) + if not result['OK']: + DIRAC.gLogger.error('Authorization failed: %s' % result['Message']) + DIRAC.exit(1) + useTokens = result['Value'] + if useTokens: + DIRAC.gConfig.setOptionValue('/DIRAC/Security/UseTokens', useTokens) + else: + DIRAC.gLogger.notice('To use tokens, please, set "/DIRAC/Security/UseTokens=yes".') + if not params.logLevel: params.logLevel = DIRAC.gConfig.getValue(cfgInstallPath('LogLevel'), '') if params.logLevel: From 8825d86252f4b70278a7c6491f681b101302a739 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 15 Jul 2021 23:46:27 +0200 Subject: [PATCH 119/178] add additional info to well-known --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index cbf3825ce26..61f011e82da 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -21,7 +21,7 @@ from authlib.oauth2.base import OAuth2Error from authlib.oauth2.rfc6749.util import scope_to_list -from DIRAC import S_ERROR +from DIRAC import S_ERROR, gConfig from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST from DIRAC.ConfigurationSystem.Client.Helpers import Registry from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer @@ -161,7 +161,10 @@ def web_index(self, instance): } """ if self.request.method == "GET": - return self.server.metadata + resDict = dict(setups=gConfig.getSections('DIRAC/Setups').get('Value', []), + configuration_server=gConfig.getValue("/DIRAC/Configuration/MasterServer", "")) + resDict.update(self.server.metadata) + return resDict def web_jwk(self): """ JWKs endpoint @@ -465,7 +468,7 @@ def __researchDIRACGroup(self, extSession): # Base DIRAC client auth session firstRequest = createOAuth2Request(extSession['mainSession']) # Read requested groups by DIRAC client or user - firstRequest.addScopes(self.get_arguments('chooseScope', [])) + firstRequest.addScopes(self.get_arguments('chooseScope')) # Read already authed user username = extSession['authed']['username'] self.log.debug('Next groups has been found for %s:' % username, ', '.join(firstRequest.groups)) @@ -479,7 +482,6 @@ def __researchDIRACGroup(self, extSession): return None, S_ERROR('No groups found for %s.' % username) self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(validGroups)) - if not firstRequest.groups: if len(validGroups) == 1: firstRequest.addScopes(['g:%s' % validGroups[0]]) From 67ff2274afc0322901d64a679a7cc8cc5db4b2aa Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 15 Jul 2021 23:47:24 +0200 Subject: [PATCH 120/178] other fixes --- .../Service/TokenManagerHandler.py | 2 ++ .../private/authorization/AuthServer.py | 35 ++++++++----------- .../private/authorization/utils/Clients.py | 7 ++-- .../private/authorization/utils/Tokens.py | 8 ++++- .../private/authorization/utils/Utilities.py | 4 +-- .../FrameworkSystem/scripts/dirac_login.py | 33 ++++------------- .../Resources/IdProvider/IdProviderFactory.py | 4 ++- .../Resources/IdProvider/OAuth2IdProvider.py | 14 ++++---- 8 files changed, 47 insertions(+), 60 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py index 94938a8cdf8..0d2618f2207 100644 --- a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -133,6 +133,8 @@ def export_getToken(self, username, userGroup): """ userID = [] provider = Registry.getIdPForGroup(userGroup) + if not provider: + return S_ERROR('The %s group belongs to the VO that is not tied to any Identity Provider.') result = self.idps.getIdProvider(provider) if not result['OK']: diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index b1f40964f83..09fd265c883 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -108,7 +108,7 @@ def _getScope(self, scope, param): :return: str or None """ try: - return [s.split(':')[1] for s in scope_to_list(scope) if s.startswith('%s:' % param)][0] + return [s.split(':')[1] for s in scope_to_list(scope) if s.startswith('%s:' % param) and s.split(':')[1]][0] except Exception: return None @@ -123,17 +123,17 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, # Search DIRAC username result = getUsernameForDN(wrapIDAsDN(user)) if not result['OK']: - raise Exception(result['Message']) + raise OAuth2Error(result['Message']) userName = result['Value'] if 'proxy' in scope_to_list(scope): # Try to return user proxy if proxy scope present in the authorization request if not isDownloadablePersonalProxy(): - raise Exception("You can't get proxy, configuration settings(downloadablePersonalProxy) not allow to do that.") + raise OAuth2Error("You can't get proxy, configuration settings(downloadablePersonalProxy) not allow to do that.") self.log.debug('Try to query %s@%s proxy%s' % (user, group, ('with lifetime:%s' % lifetime) if lifetime else '')) result = getDNForUsername(userName) if not result['OK']: - raise Exception(result['Message']) + raise OAuth2Error(result['Message']) userDNs = result['Value'] err = [] for dn in userDNs: @@ -148,9 +148,9 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, self.log.info('Proxy was created.') result = result['Value'].dumpAllToString() if not result['OK']: - raise Exception(result['Message']) + raise OAuth2Error(result['Message']) return {'proxy': result['Value']} - raise Exception('; '.join(err)) + raise OAuth2Error('; '.join(err)) else: # Ask TokenManager to generate new tokens for user @@ -264,14 +264,14 @@ def parseIdPAuthorizationResponse(self, response, session): # FINISHING with IdP # As a result of authentication we will receive user credential dictionary - credDict = result['Value'] + credDict, payload = result['Value'] self.log.debug("Read profile:", pprint.pformat(credDict)) # Is ID registred? result = getUsernameForDN(credDict['DN']) if not result['OK']: comment = '%s ID is not registred in the DIRAC.' % credDict['ID'] - result = self.__registerNewUser(providerName, credDict) + result = self.__registerNewUser(providerName, payload) if result['OK']: comment += ' Administrators have been notified about you.' else: @@ -409,29 +409,24 @@ def validateIdentityProvider(self, request, provider): return provider, None - def __registerNewUser(self, provider, userProfile): + def __registerNewUser(self, provider, payload): """ Register new user :param str provider: provider - :param dict userProfile: user information dictionary + :param dict payload: user information dictionary :return: S_OK()/S_ERROR() """ from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient - username = userProfile['ID'] + username = payload['sub'] mail = {} - mail['subject'] = "[SessionManager] User %s to be added." % username - mail['body'] = 'User %s was authenticated by ' % username - mail['body'] += provider - mail['body'] += "\n\nAuto updating of the user database is not allowed." - mail['body'] += " New user %s to be added," % username - mail['body'] += "with the following information:\n" - mail['body'] += "\nUser ID: %s\n" % username - mail['body'] += "\nUser profile:\n%s" % pprint.pformat(userProfile) + mail['subject'] = "[DIRAC AS] User %s to be added." % username + mail['body'] = 'User %s was authenticated by %s.' % (username, provider) + mail['body'] += "\n\nNew user to be added with the following information:\n%s" % pprint.pformat(payload) mail['body'] += "\n\n------" - mail['body'] += "\n This is a notification from the DIRAC AuthManager service, please do not reply.\n" + mail['body'] += "\n This is a notification from the DIRAC authorization service, please do not reply.\n" result = S_OK() for addresses in getEmailsForGroup('dirac_admin'): result = NotificationClient().sendMail(addresses, mail['subject'], mail['body'], localAttempt=False) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py index 66f298d81b1..d4daf0165a7 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -9,6 +9,7 @@ from DIRAC import gConfig, gLogger from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata __RCSID__ = "$Id$" @@ -29,10 +30,8 @@ def getDIRACClients(): :return: S_OK(dict)/S_ERROR() """ clients = DEFAULT_CLIENTS.copy() - result = gConfig.getOptionsDictRecursively('/DIRAC/Security/Authorization/Clients') - if not result['OK']: - gLogger.error(result['Message']) - confClients = result.get('Value', {}) + result = getAuthorizationServerMetadata(ignoreErrors=True) + confClients = result.get('Value', {}).get('Clients', {}) for cli in confClients: if cli not in clients: clients[cli] = confClients[cli] diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index d78baed25c6..e52bf472e61 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -165,7 +165,13 @@ def is_expired(self): :return: bool """ - return int(self.get('expires_at', self.get('issued_at') + self.get('expires_in'))) < time.time() + if self.get('expires_at'): + return int(self.get('expires_at')) < time.time() + elif self.get('issued_at') and self.get('expires_in'): + return int(self.get('issued_at')) + int(self.get('expires_in')) < time.time() + else: + exp = self.get_payload().get('exp') + return int(exp) < time.time() if exp else True @property def scopes(self): diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index bf6e64e390b..102351c3872 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -9,7 +9,7 @@ from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata -def collectMetadata(issuer=None): +def collectMetadata(issuer=None, ignoreErrors=False): """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defines by IETF specification: https://datatracker.ietf.org/doc/html/rfc8414#section-2 @@ -17,7 +17,7 @@ def collectMetadata(issuer=None): :return: dict -- dictionary is the AuthorizationServerMetadata object in the same time """ - result = getAuthorizationServerMetadata(issuer) + result = getAuthorizationServerMetadata(issuer, ignoreErrors=ignoreErrors) if not result['OK']: raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) metadata = result['Value'] diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index e9a1ec836ee..802af2701be 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -90,31 +90,11 @@ def setLivetime(self, arg): def registerCLISwitches(self): """ Register CLI switches """ - Script.registerSwitch( - "P", - "proxy", - "return with an access token also a proxy certificate with DIRAC group extension", - self.returnProxy) - Script.registerSwitch( - "g:", - "group=", - "set DIRAC group", - self.setGroup) - Script.registerSwitch( - "I:", - "issuer=", - "set issuer", - self.setIssuer) - Script.registerSwitch( - "T:", - "lifetime=", - "set proxy lifetime in a hours", - self.setLivetime) - Script.registerSwitch( - "F:", - "file=", - "set token file location", - self.setTokenFile) + Script.registerSwitch("P", "proxy", "request a proxy certificate with DIRAC group extension", self.returnProxy) + Script.registerSwitch("g:", "group=", "set DIRAC group", self.setGroup) + Script.registerSwitch("I:", "issuer=", "set issuer", self.setIssuer) + Script.registerSwitch("T:", "lifetime=", "set proxy lifetime in a hours", self.setLivetime) + Script.registerSwitch("F:", "file=", "set token file location", self.setTokenFile) def doOAuthMagic(self): """ Magic method with tokens @@ -135,7 +115,7 @@ def doOAuthMagic(self): scope.append('proxy') if self.lifetime: scope.append('lifetime:%s' % (int(self.lifetime) * 3600)) - idpObj.scope = '+'.join(scope) if scope else None + idpObj.scope = '+'.join(scope) if scope else '' tokenFile = getTokenFileLocation(self.tokenLoc) @@ -151,6 +131,7 @@ def doOAuthMagic(self): return result gLogger.notice('Proxy is saved to %s.' % self.proxyLoc) else: + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'True' # Revoke old tokens from token file if os.path.isfile(tokenFile): result = readTokenFromFile(tokenFile) diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index d07da05e9af..ca6fb604fce 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -71,8 +71,10 @@ def getIdProvider(self, name, **kwargs): :return: S_OK(IdProvider)/S_ERROR() """ + if not name: + return S_ERROR('Identity Provider client name must be not None.') # Get Authorization Server metadata - asMetaDict = collectMetadata() + asMetaDict = collectMetadata(kwargs.get('issuer'), ignoreErrors=True) self.log.debug('Search configuration for', name) clients = getDIRACClients() if name in clients: diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 97d7749ff3e..7b104e597a9 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -256,7 +256,7 @@ def researchGroup(self, payload=None, token=None): token = self.token if not payload and token: - payload = OAuth2Token(dict(token)).get_payload() + payload = OAuth2Token(token).get_payload() credDict = self.parseBasic(payload) if not credDict.get('DIRACGroups'): @@ -363,7 +363,7 @@ def parseAuthResponse(self, response, session=None): :param dict response: response on request to get user profile :param object session: session - :return: S_OK(dict)/S_ERROR() + :return: S_OK((dict, dict))/S_ERROR() """ response = createOAuth2Request(response) @@ -377,8 +377,10 @@ def parseAuthResponse(self, response, session=None): self.fetchToken(authorization_response=response.uri, code_verifier=session.get('code_verifier')) result = self.verifyToken(self.token['access_token']) - if result['OK']: - result = self.researchGroup(result['Value']) + if not result['OK']: + return result + payload = result['Value'] + result = self.researchGroup(payload) if not result['OK']: return result credDict = result['Value'] @@ -387,7 +389,7 @@ def parseAuthResponse(self, response, session=None): # Store token self.token['user_id'] = credDict['ID'] - return S_OK(credDict) + return S_OK((credDict, payload)) def submitDeviceCodeAuthorizationFlow(self, group=None): """ Submit authorization flow @@ -448,7 +450,7 @@ def waitFinalStatusOfDeviceCodeAuthorizationFlow(self, deviceCode, interval=5, t self.token = token return S_OK(token) if token['error'] != 'authorization_pending': - return S_ERROR(token['error'] + ' : ' + token.get('description', '')) + return S_ERROR((token.get('error') or 'unknown') + ' : ' + (token.get('error_description') or '')) def getGroupScopes(self, group): """ Get group scopes From 3c3858cdb9be1d2049559d6cd4baa913d0061719 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Fri, 16 Jul 2021 00:07:12 +0200 Subject: [PATCH 121/178] other fixes --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 134f52b967d..f8ff1cab543 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -79,7 +79,7 @@ def _runActions(self, reqObj): for mName, args, kwargs in self.actions: getattr(reqObj, mName)(*args, **kwargs) if not reqObj._finished: - reqObj.finish(self.data if self.data is None else encode(self.data)) + reqObj.finish(self.data) class BaseRequestHandler(RequestHandler): From c481d4e63baf1cefdc57852d09b779eed1752d7a Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:13:01 +0200 Subject: [PATCH 122/178] fix bug --- src/DIRAC/Core/Tornado/Server/TornadoServer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 5735509d81c..b58d2ab32b6 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -186,8 +186,7 @@ def startTornado(self): 'keyfile': certs[1], 'cert_reqs': M2Crypto.SSL.verify_peer, 'ca_certs': ca, - # Failed in tornado '5.1.1', 'sslDebug' not in m2netutil._SSL_CONTEXT_KEYWORDS - # 'sslDebug': False, # Set to true if you want to see the TLS debug messages + 'sslDebug': False, # Set to true if you want to see the TLS debug messages } # Init monitoring From 9b31c33081731dd1c216220147ef6a2e5ba0ba55 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:15:48 +0200 Subject: [PATCH 123/178] optimize --- .../private/authorization/AuthServer.py | 82 +++++++++++-------- 1 file changed, 50 insertions(+), 32 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 09fd265c883..7fdbd233551 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -149,7 +149,7 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, result = result['Value'].dumpAllToString() if not result['OK']: raise OAuth2Error(result['Message']) - return {'proxy': result['Value']} + return {'proxy': result['Value'].decode()} raise OAuth2Error('; '.join(err)) else: @@ -221,22 +221,22 @@ def registerRefreshToken(self, payload, token): token['refresh_token'] = result['Value'] return S_OK(token) - def getIdPAuthorization(self, providerName, request): + def getIdPAuthorization(self, provider, request): """ Submit subsession and return dict with authorization url and session number - :param str providerName: provider name + :param str provider: provider name :param object request: main session request :return: S_OK(response)/S_ERROR() -- dictionary contain response generated by `handle_response` """ - result = self.idps.getIdProvider(providerName) + result = self.idps.getIdProvider(provider) if not result['OK']: return result idpObj = result['Value'] authURL, state, session = idpObj.submitNewSession() session['state'] = state - session['Provider'] = providerName - session['mainSession'] = request if isinstance(request, dict) else request.toDict() + session['Provider'] = provider + session['firstRequest'] = request if isinstance(request, dict) else request.toDict() self.log.verbose('Redirect to', authURL) return self.handle_response(302, {}, [("Location", authURL)], session) @@ -341,7 +341,11 @@ def validate_consent_request(self, request, provider=None): if request.method != 'GET': return 'Use GET method to access this endpoint.' try: - req = self.create_oauth2_request(request) + # Check Identity Provider + req, result = self.validateIdentityProvider(self.create_oauth2_request(request), provider) + if not req: + return result + self.log.info('Validate consent request for', req.state) grant = self.get_authorization_grant(req) self.log.debug('Use grant:', grant) @@ -349,13 +353,8 @@ def validate_consent_request(self, request, provider=None): if not hasattr(grant, 'prompt'): grant.prompt = None - # Check Identity Provider - provider, providerChooser = self.validateIdentityProvider(req, provider) - if not provider: - return providerChooser - # Submit second auth flow through IdP - return self.getIdPAuthorization(provider, req) + return self.getIdPAuthorization(req.provider, req) except OAuth2Error as error: return self.handle_error_response(None, error) @@ -367,27 +366,51 @@ def validateIdentityProvider(self, request, provider): :return: str, S_OK()/S_ERROR() -- provider name and html page to choose it """ - self.log.debug("Check if %s identity provider registred in DIRAC.." % provider) + if provider: + request.provider = provider + + # Find identity provider for group + groupProvider = getIdPForGroup(request.group) if request.groups else None + + # If requested access token for group that is not registred in any identity provider + # or the requested provider does not match the group return error + if request.group and not groupProvider and 'proxy' not in request.scope: + self.db.removeSession(request.sessionID) + return None, S_ERROR('The %s group belongs to the VO that is not tied to any Identity Provider.' % request.group) + # if provider and provider != groupProvider: + self.db.removeSession(request.sessionID) + return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (group, groupProvider, provider)) + # provider = groupProvider + + self.log.debug("Check if %s identity provider registred in DIRAC.." % request.provider) # Research supported IdPs result = getProvidersForInstance('Id') if not result['OK']: + self.db.removeSession(request.sessionID) return None, result - idPs = result['Value'] - - # Remove settings of the DIRAC AS - result = getProvidersForInstance('Id', 'DIRAC') - if not result['OK']: - return None, result - for dCli in result['Value']: - if dCli in idPs: - idPs.remove(dCli) + idPs = result['Value'] if not idPs: + self.db.removeSession(request.sessionID) return None, S_ERROR('No identity providers found.') - if not provider: - if len(idPs) == 1: - return idPs[0], None + if request.provider: + if request.provider not in idPs: + self.db.removeSession(request.sessionID) + return None, S_ERROR('%s identity provider is not registered.' % request.provider) + elif groupProvider and request.provider != groupProvider: + self.db.removeSession(request.sessionID) + return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (group, groupProvider, + request.provider)) + return request, None + + # If no identity provider is specified, it must be assigned + if groupProvider: + request.provider = groupProvider + elif len(idPs) == 1: + # If only one identity provider is registered, then choose it + request.provider = idPs[0] + else: # Choose IdP interface doc = document('DIRAC authentication') with doc.head: @@ -402,12 +425,7 @@ def validateIdentityProvider(self, request, provider): dom.button(dom.a(idP, href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query)), cls='button') return None, self.handle_response(payload=Template(doc.render()).generate()) - - # Check IdP - if provider not in idPs: - return None, S_ERROR('%s is not registered in DIRAC.' % provider) - - return provider, None + return request, None def __registerNewUser(self, provider, payload): """ Register new user From 69b3b089428c62a33faf60103bbe008d49c774db Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:16:26 +0200 Subject: [PATCH 124/178] optimize --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 42 +++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 61f011e82da..90399372a1b 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -12,7 +12,6 @@ import json import pprint -from io import open from dominate import document, tags as dom from tornado.template import Template @@ -23,7 +22,7 @@ from DIRAC import S_ERROR, gConfig from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST -from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getIdPForGroup, getGroupsForUser from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint @@ -113,10 +112,9 @@ def _finishFuture(self, retVal): if isinstance(self.result, dict) and self.result.get('OK') is False and 'Message' in self.result: # S_ERROR is interpreted in the OAuth2 error format. self.set_status(400) - self.write({'error': 'server_error', 'description': retVal['Message']}) self.clear_cookie('auth_session') self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) - self.finish() + self.finish({'error': 'server_error', 'description': retVal['Message']}) else: super(AuthHandler, self)._finishFuture(retVal) @@ -164,6 +162,7 @@ def web_index(self, instance): resDict = dict(setups=gConfig.getSections('DIRAC/Setups').get('Value', []), configuration_server=gConfig.getValue("/DIRAC/Configuration/MasterServer", "")) resDict.update(self.server.metadata) + resDict.pop('Clients', None) return resDict def web_jwk(self): @@ -315,17 +314,10 @@ def web_device(self, provider=None): session = result['Value'] # Get original request from session req = createOAuth2Request(dict(method='GET', uri=session['uri'])) + req.setQueryArguments(id=session['id'], user_code=userCode) - groups = [s.split(':')[1] for s in scope_to_list(req.scope) if s.startswith('g:')] # pylint: disable=no-member - group = groups[0] if groups else None - - if group and not provider: - provider = Registry.getIdPForGroup(group) - - self.log.debug('Use provider:', provider) - # pylint: disable=no-member - authURL = '%s/authorization/%s?%s&user_code=%s' % (self.LOCATION, provider, req.query, userCode) - # Save session to cookie + # Save session to cookie and redirect to authorization endpoint + authURL = '%s?%s' % (req.path.replace('device', 'authorization'), req.query) return self.server.handle_response(302, {}, [("Location", authURL)], session) # If received a request without a user code, then send a form to enter the user code @@ -398,13 +390,12 @@ def web_redirect(self): state, OAuth2Error(error=error, description=self.get_argument('error_description', ''))) # Check current auth session that was initiated for the selected external identity provider - try: - session = json.loads(self.get_secure_cookie('auth_session')) - except Exception: - session = {} + session = self.get_secure_cookie('auth_session') + if not session: + return S_ERROR("%s session is expired." % state) - sessionWithExtIdP = session if state and (session.get('state') == state) else None - if not sessionWithExtIdP: + sessionWithExtIdP = json.loads(session) + if state and not sessionWithExtIdP.get('state') == state: return S_ERROR("%s session is expired." % state) if not sessionWithExtIdP.get('authed'): @@ -466,20 +457,23 @@ def __researchDIRACGroup(self, extSession): :return: response """ # Base DIRAC client auth session - firstRequest = createOAuth2Request(extSession['mainSession']) + firstRequest = createOAuth2Request(extSession['firstRequest']) # Read requested groups by DIRAC client or user firstRequest.addScopes(self.get_arguments('chooseScope')) # Read already authed user username = extSession['authed']['username'] + # Requested arguments in first request + provider = firstRequest.provider self.log.debug('Next groups has been found for %s:' % username, ', '.join(firstRequest.groups)) # Researche Group - result = Registry.getGroupsForUser(username) + result = getGroupsForUser(username) if not result['OK']: return None, result - validGroups = result['Value'] + groups = result['Value'] + validGroups = [group for group in groups if (getIdPForGroup(group) == provider) or ('proxy' in firstRequest.scope)] if not validGroups: - return None, S_ERROR('No groups found for %s.' % username) + return None, S_ERROR('No groups found for %s and for %s Identity Provider.' % (username, provider)) self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(validGroups)) if not firstRequest.groups: From e2225889380887127f36f00c6ca46c8a093aefc7 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:17:57 +0200 Subject: [PATCH 125/178] add getUserInfo --- .../Service/TokenManagerHandler.py | 61 ++++++++++++++++--- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py index 0d2618f2207..77c87c18bca 100644 --- a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -39,6 +39,25 @@ def initializeHandler(cls, serviceInfoDict): cls.idps = IdProviderFactory() return S_OK() + def __generateUsersTokensInfo(self, users): + """ Generate information dict about user tokens + + :return: dict + """ + tokensInfo = [] + credDict = self.getRemoteCredentials() + result = Registry.getDNForUsername(credDict['username']) + if not result['OK']: + return result + for dn in result['Value']: + result = Registry.getIDFromDN(dn) + if result['OK']: + result = self.__tokenDB.getTokensByUserID(result['Value']) + if not result['OK']: + gLogger.error(result['Message']) + tokensInfo += result['Value'] + return tokensInfo + def __generateUserTokensInfo(self): """ Generate information dict about user tokens @@ -74,6 +93,33 @@ def export_getUserTokensInfo(self): :return: S_OK(dict) """ return S_OK(self.__generateUserTokensInfo()) + + auth_getUsersTokensInfo = [Properties.PROXY_MANAGEMENT] + + def export_getUsersTokensInfo(self, users): + """ Get the info about the user tokens in the system + + :param list users: user names + + :return: S_OK(dict) + """ + tokensInfo = [] + for user in users: + result = Registry.getDNForUsername(user) + if not result['OK']: + return result + for dn in result['Value']: + uid = Registry.getIDFromDN(dn).get('Value') + if uid: + result = self.__tokenDB.getTokensByUserID(uid) + if not result['OK']: + gLogger.error(result['Message']) + else: + for tokenDict in result['Value']: + if tokenDict not in tokensInfo: + tokenDict['username'] = user + tokensInfo.append(tokenDict) + return S_OK(tokensInfo) auth_uploadToken = ['authenticated'] @@ -85,7 +131,7 @@ def export_updateToken(self, token, userID, provider, rt_expired_in=24 * 3600): :param str provider: provider name :param int rt_expired_in: refresh token expires time - :return: S_OK(dict)/S_ERROR() -- dict contain uploaded tokens info + :return: S_OK(list)/S_ERROR() -- list contain uploaded tokens info as dictionaries """ self.log.verbose('Update %s user token for %s:\n' % (userID, provider), pprint.pformat(token)) result = self.idps.getIdProvider(provider) @@ -107,7 +153,7 @@ def __checkProperties(self, requestedUserDN, requestedUserGroup): :param str requestedUserDN: user DN :param str requestedUserGroup: DIRAC group - :return: S_OK(boolean)/S_ERROR() + :return: S_OK(bool)/S_ERROR() """ credDict = self.getRemoteCredentials() if Properties.FULL_DELEGATION in credDict['properties']: @@ -134,7 +180,7 @@ def export_getToken(self, username, userGroup): userID = [] provider = Registry.getIdPForGroup(userGroup) if not provider: - return S_ERROR('The %s group belongs to the VO that is not tied to any Identity Provider.') + return S_ERROR('The %s group belongs to the VO that is not tied to any Identity Provider.' % userGroup) result = self.idps.getIdProvider(provider) if not result['OK']: @@ -150,18 +196,15 @@ def export_getToken(self, username, userGroup): result = Registry.getIDFromDN(dn) if result['OK']: result = self.__tokenDB.getTokenForUserProvider(result['Value'], provider) - if not result['OK']: - err.append(result['Message']) - elif result['Value']: + if result['OK'] and result['Value']: idpObj.token = result['Value'] result = self.__checkProperties(dn, userGroup) if result['OK']: result = idpObj.exchangeGroup(userGroup) if result['OK']: return result - if not err: - return S_ERROR('No user ID found for %s' % username) - return S_ERROR('; '.join(err)) + err.append(result.get('Message', 'No token found for %s.' % dn)) + return S_ERROR('; '.join(err or ['No user ID found for %s' % username])) def export_deleteToken(self, userDN): """ Delete a token from the DB From 99c88f4da96bb62f5c25892fd8da98994fb67261 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:18:53 +0200 Subject: [PATCH 126/178] add properties to OAuth2Request --- .../private/authorization/utils/Requests.py | 62 ++++++++++++++++--- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index 375ac67af48..76ead9b67e1 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -20,12 +20,27 @@ def addScopes(self, scopes): :param list scopes: scopes """ - # Remove "scope" argument from uri - self.uri = re.sub(r"&scope(=[^&]*)?|^scope(=[^&]*)?&?", "", self.uri) - # Add "scope" argument to uri with new scopes - self.uri += "&scope=%s" % '+'.join(list(set(scope_to_list(self.scope or '') + scopes))) or '' - # Reinit all attributes with new uri - self.__init__(self.method, to_unicode(self.uri)) + self.setQueryArguments(scope=list(set(scope_to_list(self.scope or '') + scopes))) + + def setQueryArguments(self, **kwargs): + """ Set query arguments """ + for k in kwargs: + # Remove argument from uri + query = re.sub(r"&{argument}(=[^&]*)?|^{argument}(=[^&]*)?&?".format(argument=k), "", self.query) + # Add new one + if query: + query += '&' + query += "%s=%s" % (k, '+'.join(kwargs[k]) if isinstance(kwargs[k], list) else kwargs[k]) + # Re-init class + self.__init__(self.method, to_unicode(self.path + '?' + query)) + + @property + def path(self): + """ URL path + + :return: str + """ + return self.uri.replace('?%s' % (self.query or ''), '') @property def groups(self): @@ -33,7 +48,40 @@ def groups(self): :return: list """ - return [s.split(':')[1] for s in scope_to_list(self.scope) if s.startswith('g:') and s.split(':')[1]] + return [s.split(':')[1] for s in scope_to_list(self.scope or '') if s.startswith('g:') and s.split(':')[1]] + + @property + def group(self): + """ Serarch DIRAC group in scopes + + :return: str + """ + groups = [s.split(':')[1] for s in scope_to_list(self.scope or '') if s.startswith('g:') and s.split(':')[1]] + return groups[0] if groups else None + + @property + def provider(self): + """ Serarch IdP in scopes + + :return: str + """ + return self.data.get('provider') + + @provider.setter + def provider(self, provider): + self.setQueryArguments(provider=provider) + + @property + def sessionID(self): + """ Serarch IdP in scopes + + :return: str + """ + return self.data.get('id') + + @provider.setter + def sessionID(self, sessionID): + self.setQueryArguments(id=sessionID) def toDict(self): """ Convert class to dictionary From 1cdae99b3561ca4e938f0072eca8458f2389a7f1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:20:00 +0200 Subject: [PATCH 127/178] fix bug --- src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index e52bf472e61..6e85fad590b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -187,7 +187,7 @@ def groups(self): :return: list """ - return [s.split(':')[1] for s in self.scopes if s.startswith('g:')] + return [s.split(':')[1] for s in self.scopes if s.startswith('g:') and s.split(':')[1]] def get_payload(self, token_type='access_token'): """ Decode token From 94442d7d8629f2f2068073e296f4f073af844884 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:23:19 +0200 Subject: [PATCH 128/178] fix DIRAC_USE_ACCESS_TOKEN --- src/DIRAC/FrameworkSystem/scripts/dirac_login.py | 1 + src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 802af2701be..153767a407d 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -125,6 +125,7 @@ def doOAuthMagic(self): return result if self.proxy: + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'False' # Save new proxy certificate result = writeToProxyFile(idpObj.token['proxy'].encode("UTF-8"), self.proxyLoc) if not result['OK']: diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py index 8868db2fe1a..9885fa97dbe 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py @@ -254,6 +254,8 @@ def main(): global piParams, pI piParams = Params() piParams.registerCLISwitches() + # Take off tokens + os.environ['DIRAC_USE_ACCESS_TOKEN'] = 'False' Script.disableCS() Script.parseCommandLine(ignoreErrors=True) From 8b6a8f7296be0bc648a51f2e9db0560dba3e2f6c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 19 Jul 2021 22:25:06 +0200 Subject: [PATCH 129/178] modify getGroupScopes --- .../Resources/IdProvider/CheckInIdProvider.py | 18 +++++++++++++++++- .../Resources/IdProvider/IAMIdProvider.py | 16 +++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py index dcf19c8bb3d..e16c5cb124a 100644 --- a/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/CheckInIdProvider.py @@ -4,11 +4,27 @@ from __future__ import division from __future__ import print_function +from authlib.oauth2.rfc6749.util import scope_to_list + +from DIRAC import S_OK from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOForGroup, getGroupOption + __RCSID__ = "$Id$" class CheckInIdProvider(OAuth2IdProvider): - pass + def getGroupScopes(self, group): + """ Get group scopes + + :param str group: DIRAC group + + :return: list + """ + idPScope = getGroupOption(group, 'IdPRole') + if not idPScope: + idPScope = 'eduperson_entitlement?value=urn:mace:egi.eu:group:%s:role=%s#aai.egi.eu' % (getVOForGroup(group), + group.split('_')[1]) + return S_OK(scope_to_list(idPScope)) diff --git a/src/DIRAC/Resources/IdProvider/IAMIdProvider.py b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py index c3f51e04a7a..66a4f201abd 100644 --- a/src/DIRAC/Resources/IdProvider/IAMIdProvider.py +++ b/src/DIRAC/Resources/IdProvider/IAMIdProvider.py @@ -4,11 +4,25 @@ from __future__ import division from __future__ import print_function +from authlib.oauth2.rfc6749.util import scope_to_list + +from DIRAC import S_OK from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOForGroup, getGroupOption __RCSID__ = "$Id$" class IAMIdProvider(OAuth2IdProvider): - pass + def getGroupScopes(self, group): + """ Get group scopes + + :param str group: DIRAC group + + :return: list + """ + idPScope = getGroupOption(group, 'IdPRole') + if not idPScope: + idPScope = 'wlcg.groups:/%s/%s' % (getVOForGroup(group), group.split('_')[1]) + return S_OK(scope_to_list(idPScope)) From 2bb50b2ed3738e2ae1a6897c4df71bc3b49790ff Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 20 Jul 2021 16:10:54 +0200 Subject: [PATCH 130/178] fix pylint --- .../Core/Tornado/Server/private/BaseRequestHandler.py | 2 +- src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py | 2 +- .../FrameworkSystem/private/authorization/AuthServer.py | 7 ++++--- .../private/authorization/utils/Requests.py | 6 +++--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index f8ff1cab543..d9b2d7614b4 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -58,7 +58,7 @@ def web_myEndpotin(self): """ def __init__(self, data=None): """ C'or - + :param data: response body """ self.data = data diff --git a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py index 77c87c18bca..8ffa1f20aae 100644 --- a/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TokenManagerHandler.py @@ -93,7 +93,7 @@ def export_getUserTokensInfo(self): :return: S_OK(dict) """ return S_OK(self.__generateUserTokensInfo()) - + auth_getUsersTokensInfo = [Properties.PROXY_MANAGEMENT] def export_getUsersTokensInfo(self, users): diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 7fdbd233551..f0249825309 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -129,7 +129,7 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, if 'proxy' in scope_to_list(scope): # Try to return user proxy if proxy scope present in the authorization request if not isDownloadablePersonalProxy(): - raise OAuth2Error("You can't get proxy, configuration settings(downloadablePersonalProxy) not allow to do that.") + raise OAuth2Error("You can't get proxy, configuration(downloadablePersonalProxy) not allow to do that.") self.log.debug('Try to query %s@%s proxy%s' % (user, group, ('with lifetime:%s' % lifetime) if lifetime else '')) result = getDNForUsername(userName) if not result['OK']: @@ -379,7 +379,8 @@ def validateIdentityProvider(self, request, provider): return None, S_ERROR('The %s group belongs to the VO that is not tied to any Identity Provider.' % request.group) # if provider and provider != groupProvider: self.db.removeSession(request.sessionID) - return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (group, groupProvider, provider)) + return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (request.group, groupProvider, + request.provider)) # provider = groupProvider self.log.debug("Check if %s identity provider registred in DIRAC.." % request.provider) @@ -400,7 +401,7 @@ def validateIdentityProvider(self, request, provider): return None, S_ERROR('%s identity provider is not registered.' % request.provider) elif groupProvider and request.provider != groupProvider: self.db.removeSession(request.sessionID) - return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (group, groupProvider, + return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (request.group, groupProvider, request.provider)) return request, None diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index 76ead9b67e1..c0c28e2cb52 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -49,7 +49,7 @@ def groups(self): :return: list """ return [s.split(':')[1] for s in scope_to_list(self.scope or '') if s.startswith('g:') and s.split(':')[1]] - + @property def group(self): """ Serarch DIRAC group in scopes @@ -58,7 +58,7 @@ def group(self): """ groups = [s.split(':')[1] for s in scope_to_list(self.scope or '') if s.startswith('g:') and s.split(':')[1]] return groups[0] if groups else None - + @property def provider(self): """ Serarch IdP in scopes @@ -79,7 +79,7 @@ def sessionID(self): """ return self.data.get('id') - @provider.setter + @sessionID.setter def sessionID(self, sessionID): self.setQueryArguments(id=sessionID) From ca4f4978ca793d4ab641b572a4908db75db36a26 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 29 Jul 2021 20:43:02 +0200 Subject: [PATCH 131/178] remove unused --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 90399372a1b..f8e97095504 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -18,7 +18,6 @@ from tornado.concurrent import Future from authlib.oauth2.base import OAuth2Error -from authlib.oauth2.rfc6749.util import scope_to_list from DIRAC import S_ERROR, gConfig from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST From 8d2394d3a462695bea12730225707df831e20cda Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 29 Jul 2021 21:29:17 +0200 Subject: [PATCH 132/178] fix rebase --- .../ConfigurationSystem/Client/PathFinder.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py index 9a38fe9c745..e21addcd5f7 100755 --- a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py +++ b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py @@ -95,6 +95,17 @@ def getComponentSection(system, component=False, setup=False, componentCategory= ) +def getAPISection(system, endpointName=False, setup=False): + """ Get service section in a system + + :param str system: system name + :param str endpointName: endpoint name + + :return: str + """ + return getComponentSection(system, component=endpointName, setup=setup, "APIs") + + def getServiceSection(system, serviceName=False, setup=False): """ Get service section in a system @@ -131,12 +142,8 @@ def getExecutorSection(system, executorName=None, component=False, setup=False): return getComponentSection(system, component=executorName, setup=setup, componentCategory="Executors") -def getAPISection(APIName, APITuple=False, setup=False): - return getComponentSection(APIName, APITuple, setup, "APIs") - - -def getServiceSection(serviceName, serviceTuple=False, setup=False): - return getComponentSection(serviceName, serviceTuple, setup, "Services") +def getDatabaseSection(dbName, dbTuple=False, setup=False): + """ Get DB section in a system :param str system: system name :param str dbName: DB name From 8afdbde7182c52ff6cf1e90e379df09076a78ea0 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 29 Jul 2021 21:42:59 +0200 Subject: [PATCH 133/178] fix rebase --- src/DIRAC/ConfigurationSystem/Client/PathFinder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py index e21addcd5f7..e629a24c0ea 100755 --- a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py +++ b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py @@ -103,7 +103,7 @@ def getAPISection(system, endpointName=False, setup=False): :return: str """ - return getComponentSection(system, component=endpointName, setup=setup, "APIs") + return getComponentSection(system, component=endpointName, setup=setup, componentCategory="APIs") def getServiceSection(system, serviceName=False, setup=False): From ec406a47c249260695904badb624a39d499a4f43 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 29 Jul 2021 21:47:10 +0200 Subject: [PATCH 134/178] fix rebase --- src/DIRAC/ConfigurationSystem/Client/PathFinder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py index e629a24c0ea..ecbb4cd303a 100755 --- a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py +++ b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py @@ -142,7 +142,7 @@ def getExecutorSection(system, executorName=None, component=False, setup=False): return getComponentSection(system, component=executorName, setup=setup, componentCategory="Executors") -def getDatabaseSection(dbName, dbTuple=False, setup=False): +def getDatabaseSection(system, dbName=False, setup=False): """ Get DB section in a system :param str system: system name From 5d18b5b01f0dcce9ea118d6edbb70d7061094c98 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 1 Aug 2021 17:29:55 +0200 Subject: [PATCH 135/178] pass args_kwargs to target method --- .../Server/private/BaseRequestHandler.py | 79 +++++++++++-------- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 36 ++++----- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index d9b2d7614b4..faa1c0ddd44 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -56,6 +56,8 @@ def web_myEndpotin(self): resp.set_status(400) return resp """ + __attrs = inspect.getmembers(RequestHandler) + def __init__(self, data=None): """ C'or @@ -63,13 +65,19 @@ def __init__(self, data=None): """ self.data = data self.actions = [] - for mName, mObj in inspect.getmembers(RequestHandler): + for mName, mObj in self.__attrs: if inspect.isroutine(mObj) and not mName.startswith('_') and not mName.startswith('get'): setattr(self, mName, partial(self.__setAction, mName)) def __setAction(self, mName, *args, **kwargs): - """ Register new action """ + """ Register new action + + :param str mName: RequestHandler method name + + :return: TornadoResponse instance + """ self.actions.append((mName, args, kwargs)) + return self def _runActions(self, reqObj): """ Calling methods in the order of their registration @@ -186,6 +194,7 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ # Which grant type to use USE_AUTHZ_GRANTS = ['SSL', 'JWT'] + @classmethod def _initMonitoring(cls, serviceName, fullUrl): """ @@ -395,9 +404,24 @@ def _getMethodName(self): def _getMethodArgs(self, args): """ Decode args. - :return: list - """ - return args + :return: tuple + """ + # Read information about method + spec = inspect.getargspec(self.mehtodObj) + # Pass all arguments + kwargs = self.request.arguments.copy() if spec.keywords else {} + # Get all defaults from method + defaults = {a: spec.defaults[args.index(a)] for a in args[-len(spec.defaults):]} if spec.defaults else {} + # Calcule kwargs + for arg in spec.args[len(args):]: + if not defaults.get(arg): + kwargs[arg] = self.get_argument(arg) + elif isinstance(defaults[arg], six.string_types): + kwargs[arg] = self.get_arguments(arg, defaults[arg]) + else: + kwargs[arg] = self.get_argument(arg, defaults[arg]) + + return (args, kwargs) def _getMethodAuthProps(self): """ Resolves the hard coded authorization requirements for method. @@ -406,18 +430,18 @@ def _getMethodAuthProps(self): """ if self.AUTH_PROPS and not isinstance(self.AUTH_PROPS, (list, tuple)): self.AUTH_PROPS = [p.strip() for p in self.AUTH_PROPS.split(",") if p.strip()] - return getattr(self, 'auth_' + self.method, self.AUTH_PROPS) + return getattr(self, 'auth_' + self.mehtodName, self.AUTH_PROPS) def _getMethod(self): """ Get method function to call. :return: function """ - method = getattr(self, '%s%s' % (self.METHOD_PREFIX, self.method), None) - if not callable(method): - sLog.error("Invalid method", self.method) + methodObj = getattr(self, '%s%s' % (self.METHOD_PREFIX, self.mehtodName), None) + if not callable(methodObj): + sLog.error("Invalid method", self.mehtodName) raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) - return method + return methodObj def prepare(self): """ @@ -429,7 +453,8 @@ def prepare(self): # on the handler side # If the argument is not available, the method exists # and an error 400 ``Bad Request`` is returned to the client - self.method = self._getMethodName() + self.mehtodName = self._getMethodName() + self.methodObj = self._getMethod() self._monitorRequest() @@ -456,7 +481,7 @@ def _prepare(self): # Check whether we are authorized to perform the query # Note that performing the authQuery modifies the credDict... - authorized = self._authManager.authQuery(self.method, self.credDict, + authorized = self._authManager.authQuery(self.mehtodName, self.credDict, self._getMethodAuthProps()) if not authorized: extraInfo = '' @@ -470,7 +495,7 @@ def _prepare(self): self.request.path, extraInfo)) raise HTTPError(http_client.UNAUTHORIZED) - def __executeMethod(self, targetMethod, args): + def __executeMethod(self, targetMethod, args, kwargs): """ Execute the method called, this method is ran in an executor We have several try except to catch the different problem which can occur @@ -484,33 +509,17 @@ def __executeMethod(self, targetMethod, args): :param str targetMethod: name of the method to call :param list args: target method arguments + :param dict kwargs: target method arguments :return: Future """ sLog.notice("Incoming request %s /%s: %s" % (self.srv_getFormattedRemoteCredentials(), - self._serviceName, self.method)) - # Execute - try: - self.initializeRequest() - return targetMethod(*args) - except Exception as e: # pylint: disable=broad-except - sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) - raise e if isinstance(e, HTTPError) else HTTPError(http_client.INTERNAL_SERVER_ERROR, str(e)) - - @gen.coroutine - def __executeMethodPy2(self, targetMethod, args): - """ The only difference from __executeMethod is the presence of a coroutine decorator - - :return: Future - """ - - sLog.notice("Incoming request %s /%s: %s" % (self.srv_getFormattedRemoteCredentials(), - self._serviceName, self.method)) + self._serviceName, self.mehtodName)) # Execute try: self.initializeRequest() - return targetMethod(*args) + return targetMethod(*args, **kwargs) except Exception as e: # pylint: disable=broad-except sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) raise e if isinstance(e, HTTPError) else HTTPError(http_client.INTERNAL_SERVER_ERROR, str(e)) @@ -522,8 +531,8 @@ def _prepareExecutor(self, args): :return: executor, target method with arguments """ - return None, partial(self.__executeMethodPy2 if six.PY2 else self.__executeMethod, - self._getMethod(), self._getMethodArgs(args)) + return None, partial(gen.coroutine(self.__executeMethod) if six.PY2 else self.__executeMethod, + self.methodObj, *self._getMethodArgs(args)) def _finishFuture(self, retVal): """ Handler Future result @@ -538,7 +547,7 @@ def _finishFuture(self, retVal): # If you need to end the method using tornado methods, outside the thread, # you need to define the finish_ method. # This method will be started after __executeMethod is completed. - finishFunc = getattr(self, 'finish_%s' % self.method, None) + finishFunc = getattr(self, 'finish_%s' % self.mehtodName, None) if isinstance(self.result, TornadoResponse): self.result._runActions(self) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index f8e97095504..40d2dbab2c3 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -119,7 +119,7 @@ def _finishFuture(self, retVal): path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] - def web_index(self, instance): + def web_index(self, *instance): """ Well known endpoint, specified by `RFC8414 `_ @@ -243,7 +243,7 @@ def web_userinfo(self): path_device = ['([A-z0-9-_]*)'] - def web_device(self, provider=None): + def web_device(self, provider=None, user_code=None): """ The device authorization endpoint can be used to request device and user codes. This endpoint is used to start the device flow authorization process and user code verification. @@ -303,17 +303,17 @@ def web_device(self, provider=None): return self.server.create_endpoint_response(DeviceAuthorizationEndpoint.ENDPOINT_NAME, self.request) elif self.request.method == 'GET': - userCode = self.get_argument('user_code', None) - if userCode: + # userCode = self.get_argument('user_code', None) + if user_code: # If received a request with a user code, then prepare a request to authorization endpoint self.log.verbose('User code verification.') - result = self.server.db.getSessionByUserCode(userCode) + result = self.server.db.getSessionByUserCode(user_code) if not result['OK']: - return 'Device code flow authorization session %s expired.' % userCode + return 'Device code flow authorization session %s expired.' % user_code session = result['Value'] # Get original request from session req = createOAuth2Request(dict(method='GET', uri=session['uri'])) - req.setQueryArguments(id=session['id'], user_code=userCode) + req.setQueryArguments(id=session['id'], user_code=user_code) # Save session to cookie and redirect to authorization endpoint authURL = '%s?%s' % (req.path.replace('device', 'authorization'), req.query) @@ -364,7 +364,7 @@ def web_authorization(self, provider=None): """ return self.server.validate_consent_request(self.request, provider) - def web_redirect(self): + def web_redirect(self, state, error=None, error_description='', chooseScope=[]): """ Redirect endpoint. After a user successfully authorizes an application, the authorization server will redirect the user back to the application with either an authorization code or access token in the URL. @@ -379,14 +379,14 @@ def web_redirect(self): &chooseScope=.. to specify new scope(group in our case) (optional) """ - # Current IdP session state - state = self.get_argument('state') + # # Current IdP session state + # state = self.get_argument('state') - # Try to catch errors - error = self.get_argument('error', None) + # # Try to catch errors + # error = self.get_argument('error', None) if error: return self.server.handle_error_response( - state, OAuth2Error(error=error, description=self.get_argument('error_description', ''))) + state, OAuth2Error(error=error, description=error_description)) # self.get_argument('error_description', ''))) # Check current auth session that was initiated for the selected external identity provider session = self.get_secure_cookie('auth_session') @@ -408,7 +408,7 @@ def web_redirect(self): sessionWithExtIdP['authed'] = result['Value'] # Research group - grant_user, response = self.__researchDIRACGroup(sessionWithExtIdP) + grant_user, response = self.__researchDIRACGroup(sessionWithExtIdP, chooseScope, state) if not grant_user: return response @@ -448,7 +448,7 @@ def web_token(self): """ return self.server.create_token_response(self.request) - def __researchDIRACGroup(self, extSession): + def __researchDIRACGroup(self, extSession, chooseScope, state): """ Research DIRAC groups for authorized user :param dict extSession: ended authorized external IdP session @@ -458,7 +458,7 @@ def __researchDIRACGroup(self, extSession): # Base DIRAC client auth session firstRequest = createOAuth2Request(extSession['firstRequest']) # Read requested groups by DIRAC client or user - firstRequest.addScopes(self.get_arguments('chooseScope')) + firstRequest.addScopes(chooseScope) # self.get_arguments('chooseScope')) # Read already authed user username = extSession['authed']['username'] # Requested arguments in first request @@ -484,8 +484,8 @@ def __researchDIRACGroup(self, extSession): with dom.div(style=self.css_main): with dom.div('Choose group', style=self.css_align_center): for group in validGroups: - dom.button(dom.a(group, href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, - self.get_argument('state'), group)), + dom.button(dom.a(group, href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group)), + # self.get_argument('state'), group)), cls='button') return None, self.server.handle_response(payload=Template(self.doc.render()).generate(), newSession=extSession) From aa04255305d8724ac0ccc3f2e3f98abb7aa51c77 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sat, 7 Aug 2021 18:43:02 +0200 Subject: [PATCH 136/178] fix rebase --- requirements.txt | 73 ------------------- .../private/TornadoRefresher.py | 15 +--- 2 files changed, 2 insertions(+), 86 deletions(-) delete mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index c8018f8ec45..00000000000 --- a/requirements.txt +++ /dev/null @@ -1,73 +0,0 @@ -# From repo -fts3-rest - -#Patch for tornado -git+https://github.com/DIRACGrid/tornado.git@iostreamConfigurable -git+https://github.com/DIRACGrid/tornado_m2crypto.git - -# From pypi -apache-libcloud -boto3 -#asn1 -M2Crypto>=0.36 -autopep8==1.3.3 -cachetools<4 -certifi -coverage -docutils -diraccfg -elasticsearch-dsl~=6.3.1 -CMRESHandler>=1.0.0b4 -funcsigs -future -futures>=3.0.5 -GitPython>=2.1.0 -# newer versions of matplotlib require python 3 -matplotlib>=2.1.0,<3.0 -mock>=1.0.1 -MySQL-python>=1.2.5 -importlib_resources -jinja2 -ipython==5.3.0 -numpy>=1.10.1 -pexpect>=4.0.1 -pillow -psutil>=4.2.0 -pyasn1>0.4.1 -pyasn1_modules -Pygments>=1.5 -parameterized -pylint>=1.6.5 -pyparsing>=2.0.6 -pytest>=3.6 -pytest-cov>=2.2.0 -pytest-mock -pytz -readline>=6.2.4 -recommonmark -requests>=2.9.1 -rucio-clients >=1.25.6 -simplejson>=3.8.1 -six>=1.10 -# Freeze until all problems with 1.4 are solved -sqlalchemy==1.3.* -xmltodict -# more recent version are python 3 only -stomp.py==4.1.23 -suds-jurko>=0.6 -sphinx -# typing comes in via m2crypto. newer versions of typing caused an error in hypothesis -typing==3.6.6 -hypothesis -python-json-logger>=0.1.8 -multi-mechanize>=1.2.0 -caniusepython3 -subprocess32 -flaky -ldap3 -# setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 -setuptools_scm<6.0 -# OAuth2 -Authlib -pyjwt -dominate \ No newline at end of file diff --git a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py index 32dc1ed733f..841f2c23186 100644 --- a/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py +++ b/src/DIRAC/ConfigurationSystem/private/TornadoRefresher.py @@ -4,7 +4,7 @@ __RCSID__ = "$Id$" -import six +from six import PY3 import time from tornado import gen @@ -82,7 +82,7 @@ def __refreshLoop(self): # RuntimeError: There is no current event loop in thread.. # The reason seems to be that asyncio.get_event_loop() is called in some thread other than the main thread, # asyncio only generates an event loop for the main thread. - yield _IOLoop.current().run_in_executor(None, self.__AutoRefresh if six.PY3 else self.__AutoRefreshPy2) + yield _IOLoop.current().run_in_executor(None, self.__AutoRefresh if PY3 else gen.coroutine(self.__AutoRefresh)) def __AutoRefresh(self): """ @@ -94,17 +94,6 @@ def __AutoRefresh(self): if not self._refreshAndPublish(): # pylint: disable=no-member gLogger.error("Can't refresh configuration from any source") - @gen.coroutine - def __AutoRefreshPy2(self): - """ - Auto refresh the configuration - We disable pylint error because this class must be instanciated - by a mixin to define the methods. for python 2 - """ - if self._refreshEnabled: # pylint: disable=no-member - if not self._refreshAndPublish(): # pylint: disable=no-member - gLogger.error("Can't refresh configuration from any source") - def daemonize(self): """ daemonize is probably not the best name because there is no daemon behind but we must keep it to the same interface of the DISET refresher """ From bb5e65ddf2fc143ca008d5c2c8ab8a70e29fc66f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 8 Aug 2021 14:15:27 +0200 Subject: [PATCH 137/178] fix target method args --- src/DIRAC/Core/Tornado/Server/TornadoService.py | 4 ++-- .../Tornado/Server/private/BaseRequestHandler.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index 10568c763cc..f343c2fd44c 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -147,7 +147,7 @@ def _getMethodName(self): def _getMethodArgs(self, args): """ Decode args. - :return: list + :return: tuple """ # "method" argument of the POST call. @@ -332,7 +332,7 @@ def __executeMethod(self): # Decode args args_encoded = self.get_body_argument('args', default=encode([])) - return decode(args_encoded)[0] + return (decode(args_encoded)[0], {}) # Make post a coroutine. # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index faa1c0ddd44..1d8dd038f3d 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -404,14 +404,16 @@ def _getMethodName(self): def _getMethodArgs(self, args): """ Decode args. - :return: tuple + :return: tuple -- contain args and kwargs """ # Read information about method - spec = inspect.getargspec(self.mehtodObj) + fArgs, fvArgs, fvKwargs, fDef = inspect.getargspec(self.methodObj) + # Remove self argument from method arguments + fArgs = [a for a in fArgs if a != "self"] # Pass all arguments - kwargs = self.request.arguments.copy() if spec.keywords else {} + kwargs = self.request.arguments.copy() if fvKwargs else {} # Get all defaults from method - defaults = {a: spec.defaults[args.index(a)] for a in args[-len(spec.defaults):]} if spec.defaults else {} + defaults = {a: fDef[args.index(a)] for a in args[-len(fDef):]} if fDef else {} # Calcule kwargs for arg in spec.args[len(args):]: if not defaults.get(arg): @@ -531,8 +533,9 @@ def _prepareExecutor(self, args): :return: executor, target method with arguments """ + args, kwargs = self._getMethodArgs(args) return None, partial(gen.coroutine(self.__executeMethod) if six.PY2 else self.__executeMethod, - self.methodObj, *self._getMethodArgs(args)) + self.methodObj, args, kwargs) def _finishFuture(self, retVal): """ Handler Future result From 55206c46a6f0dfc9e971573fc693f8878b4915d8 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 8 Aug 2021 16:14:59 +0200 Subject: [PATCH 138/178] fix target method args --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 1d8dd038f3d..47db71ef90b 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -413,10 +413,10 @@ def _getMethodArgs(self, args): # Pass all arguments kwargs = self.request.arguments.copy() if fvKwargs else {} # Get all defaults from method - defaults = {a: fDef[args.index(a)] for a in args[-len(fDef):]} if fDef else {} + defaults = {a: fDef[args[-len(default):].index(a)] for a in args[-len(fDef):]} if fDef else {} # Calcule kwargs - for arg in spec.args[len(args):]: - if not defaults.get(arg): + for arg in fArgs[len(args):]: + if arg not in defaults: kwargs[arg] = self.get_argument(arg) elif isinstance(defaults[arg], six.string_types): kwargs[arg] = self.get_arguments(arg, defaults[arg]) From e38b5bb51d80de3f505f4d39f68807748005d1a8 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 8 Aug 2021 16:44:50 +0200 Subject: [PATCH 139/178] fix target method args --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 47db71ef90b..85e8aaab00b 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -413,13 +413,13 @@ def _getMethodArgs(self, args): # Pass all arguments kwargs = self.request.arguments.copy() if fvKwargs else {} # Get all defaults from method - defaults = {a: fDef[args[-len(default):].index(a)] for a in args[-len(fDef):]} if fDef else {} + defaults = {a: fDef[fArgs[-len(default):].index(a)] for a in fArgs[-len(fDef):]} if fDef else {} # Calcule kwargs for arg in fArgs[len(args):]: if arg not in defaults: kwargs[arg] = self.get_argument(arg) - elif isinstance(defaults[arg], six.string_types): - kwargs[arg] = self.get_arguments(arg, defaults[arg]) + elif isinstance(defaults[arg], list): + kwargs[arg] = self.get_arguments(arg) or defaults[arg] else: kwargs[arg] = self.get_argument(arg, defaults[arg]) From 0a4ca001c8e936e8cb7ef4c48158b1712f6d3af2 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 11 Aug 2021 21:32:25 +0200 Subject: [PATCH 140/178] fix bug --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 85e8aaab00b..27cbd7814bc 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -413,7 +413,7 @@ def _getMethodArgs(self, args): # Pass all arguments kwargs = self.request.arguments.copy() if fvKwargs else {} # Get all defaults from method - defaults = {a: fDef[fArgs[-len(default):].index(a)] for a in fArgs[-len(fDef):]} if fDef else {} + defaults = {a: fDef[fArgs[-len(fDef):].index(a)] for a in fArgs[-len(fDef):]} if fDef else {} # Calcule kwargs for arg in fArgs[len(args):]: if arg not in defaults: From 2aa6a7ce6619314fe55299d2a388ef9eba427efd Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 11 Aug 2021 21:56:38 +0200 Subject: [PATCH 141/178] fix tests --- .../Integration/Framework/Test_AuthServer.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/Integration/Framework/Test_AuthServer.py b/tests/Integration/Framework/Test_AuthServer.py index 84e351ad576..7358ffe0ee4 100644 --- a/tests/Integration/Framework/Test_AuthServer.py +++ b/tests/Integration/Framework/Test_AuthServer.py @@ -20,7 +20,7 @@ if six.PY3: # DIRACOS not contain required packages - from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer + from DIRAC.FrameworkSystem.private.authorization import AuthServer class Proxy(object): @@ -45,7 +45,7 @@ def getToken(self, *args, **kwargs): @pytest.fixture -def auth_server(mocker): +def auth_server(monkeypatch): cfg = CFG() cfg.loadFromBuffer(""" DIRAC @@ -60,20 +60,14 @@ def auth_server(mocker): } """) gConfig.loadCFG(cfg) - if AuthServer: - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getIdPForGroup", - side_effect=mockgetIdPForGroup) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getDNForUsername", - side_effect=mockgetDNForUsername) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.getUsernameForDN", - side_effect=mockgetUsernameForDN) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.ProxyManagerClient", - side_effect=ProxyManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.TokenManagerClient", - side_effect=TokenManagerClient) - mocker.patch("DIRAC.FrameworkSystem.private.authorization.AuthServer.isDownloadablePersonalProxy", - side_effect=mockisDownloadablePersonalProxy) - return DIRAC.FrameworkSystem.private.authorization.AuthServer.AuthServer() + if six.PY3: + monkeypatch.setattr(AuthServer, "getIdPForGroup", mockgetIdPForGroup) + monkeypatch.setattr(AuthServer, "getDNForUsername", mockgetDNForUsername) + monkeypatch.setattr(AuthServer, "getUsernameForDN", mockgetUsernameForDN) + monkeypatch.setattr(AuthServer, "ProxyManagerClient", ProxyManagerClient) + monkeypatch.setattr(AuthServer, "TokenManagerClient", TokenManagerClient) + monkeypatch.setattr(AuthServer, "isDownloadablePersonalProxy", mockisDownloadablePersonalProxy) + return AuthServer.AuthServer() return None From 43ab502ea8871f20b6cbea692b77f10292eea5b1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 12 Aug 2021 00:30:29 +0200 Subject: [PATCH 142/178] fix pylint --- .../Server/private/BaseRequestHandler.py | 2 +- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 27cbd7814bc..69a7b7edb56 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -535,7 +535,7 @@ def _prepareExecutor(self, args): """ args, kwargs = self._getMethodArgs(args) return None, partial(gen.coroutine(self.__executeMethod) if six.PY2 else self.__executeMethod, - self.methodObj, args, kwargs) + self.methodObj, args, kwargs) def _finishFuture(self, retVal): """ Handler Future result diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 40d2dbab2c3..38d3639d380 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -375,18 +375,16 @@ def web_redirect(self, state, error=None, error_description='', chooseScope=[]): GET LOCATION/redirect - Parameters:: + :param str state: Current IdP session state + :param str error: IdP error response + :param str error_description: error description + :param list chooseScope: to specify new scope(group in our case) (optional) - &chooseScope=.. to specify new scope(group in our case) (optional) + :return: S_OK()/S_ERROR() """ - # # Current IdP session state - # state = self.get_argument('state') - - # # Try to catch errors - # error = self.get_argument('error', None) + # Try to catch errors if error: - return self.server.handle_error_response( - state, OAuth2Error(error=error, description=error_description)) # self.get_argument('error_description', ''))) + return self.server.handle_error_response(state, OAuth2Error(error=error, description=error_description)) # Check current auth session that was initiated for the selected external identity provider session = self.get_secure_cookie('auth_session') @@ -485,7 +483,6 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): with dom.div('Choose group', style=self.css_align_center): for group in validGroups: dom.button(dom.a(group, href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group)), - # self.get_argument('state'), group)), cls='button') return None, self.server.handle_response(payload=Template(self.doc.render()).generate(), newSession=extSession) From 7ed460cb59ff2df84de800d06c44c8cdbfd7849b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 12 Aug 2021 02:06:48 +0200 Subject: [PATCH 143/178] fix pylint --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 69a7b7edb56..3acbf686b90 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -194,7 +194,6 @@ def post(self, *args, **kwargs): # pylint: disable=arguments-differ # Which grant type to use USE_AUTHZ_GRANTS = ['SSL', 'JWT'] - @classmethod def _initMonitoring(cls, serviceName, fullUrl): """ From b6ee4e4676d0dcf1fc6210a121b268fe87020a31 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 12 Aug 2021 22:52:44 +0200 Subject: [PATCH 144/178] fix decode --- src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index f0249825309..0126486127a 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -149,7 +149,7 @@ def generateProxyOrToken(self, client, grant_type, user=None, scope=None, result = result['Value'].dumpAllToString() if not result['OK']: raise OAuth2Error(result['Message']) - return {'proxy': result['Value'].decode()} + return {'proxy': result['Value'].decode() if isinstance(result['Value'], bytes) else result['Value']} raise OAuth2Error('; '.join(err)) else: From 5461b7abb4575528f16b1bfa4d72e885c75f4a97 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 17:26:52 +0200 Subject: [PATCH 145/178] make https bundle delivery --- .../Client/BundleDeliveryClient.py | 18 +++++++-- .../Service/BundleDeliveryHandler.py | 39 ++++++++++++++++++- .../Resources/IdProvider/OAuth2IdProvider.py | 2 +- 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py index 239e39d27c7..2588875800f 100644 --- a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py +++ b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py @@ -11,7 +11,8 @@ from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Base.Client import Client, createClient -from DIRAC.Core.DISET.TransferClient import TransferClient +from DIRAC.Core.Tornado.Client.TornadoClient import TornadoClient +from DIRAC.Core.Tornado.Client.ClientSelector import TransferClientSelector as TransferClient from DIRAC.Core.Security import Locations, Utilities from DIRAC.Core.Utilities.File import mkDir from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import skipCACheck @@ -20,6 +21,16 @@ __RCSID__ = "$Id$" +class BundleDeliveryJSONClient(TornadoClient): + + def receiveFile(self, buff, fileId): + retVal = self.executeRPC('streamToClient', fileId) + if retVal['OK']: + retVal['Value'] = b64decode(retVal['Value'].encode()) + buff.write(retVal['Value']) + return retVal + + @createClient('Framework/BundleDelivery') class BundleDeliveryClient(Client): @@ -37,7 +48,8 @@ def __getTransferClient(self): if self.transferClient: return self.transferClient return TransferClient("Framework/BundleDelivery", - skipCACheck=skipCACheck()) + skipCACheck=skipCACheck(), + httpsClient=BundleDeliveryJSONClient) def __getHash(self, bundleID, dirToSyncTo): """ Get hash for bundle in directory @@ -63,7 +75,7 @@ def __setHash(self, bundleID, dirToSyncTo, bdHash): """ try: fileName = os.path.join(dirToSyncTo, ".dab.%s" % bundleID) - with open(fileName, "wt") as fd: + with open(fileName, "wb") as fd: fd.write(bdHash) except Exception as e: self.log.error("Could not save hash after synchronization", "%s: %s" % (fileName, str(e))) diff --git a/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py b/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py index 2e57294f5a4..ecb390886d9 100644 --- a/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py @@ -9,8 +9,10 @@ import six import tarfile import os -from DIRAC.Core.DISET.RequestHandler import RequestHandler +from base64 import b64encode, b64decode + from DIRAC import gLogger, S_OK, S_ERROR, gConfig +from DIRAC.Core.Tornado.Server.TornadoService import TornadoService from DIRAC.Core.Utilities.ThreadScheduler import gThreadScheduler from DIRAC.Core.Utilities import File, List from DIRAC.Core.Security import Locations, Utilities @@ -88,7 +90,7 @@ def updateBundles(self): self.__bundles[bId] = (None, None) -class BundleDeliveryHandler(RequestHandler): +class BundleDeliveryHandler(TornadoService): @classmethod def initializeHandler(cls, serviceInfoDict): @@ -99,6 +101,39 @@ def initializeHandler(cls, serviceInfoDict): gThreadScheduler.addPeriodicTask(updateBundleTime, cls.bundleManager.updateBundles) return S_OK() + + types_streamToClient = [] + + def export_streamToClient(self, fileId): + version = "" + if isinstance(fileId, six.string_types): + if fileId in ['CAs', 'CRLs']: + retVal = Utilities.generateCAFile() if fileId == 'CAs' else Utilities.generateRevokedCertsFile() + if not retVal['OK']: + return retVal + with open(retVal['Value'], 'r') as fd: + return S_OK(b64encode(fd.read()).decode()) + bId = fileId + + elif isinstance(fileId, (list, tuple)): + if len(fileId) == 0: + return S_ERROR("No bundle specified!") + bId = fileId[0] + if len(fileId) != 1: + version = fileId[1] + + if not self.bundleManager.bundleExists(bId): + return S_ERROR("Unknown bundle %s" % bId) + + bundleVersion = self.bundleManager.getBundleVersion(bId) + if bundleVersion is None: + return S_ERROR("Empty bundle %s" % bId) + + if version == bundleVersion: + return S_OK(bundleVersion) + + return S_OK(b64encode(self.bundleManager.getBundleData(bId)).decode()) + types_getListOfBundles = [] @classmethod diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 7b104e597a9..08c065f4ed8 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -330,7 +330,7 @@ def deviceAuthorization(self, group=None): # Notify user to go to authorization endpoint if response.get('verification_uri_complete'): - showURL = 'Use next link to continue"\n%s' % response['verification_uri_complete'] + showURL = 'Use next link to continue\n%s' % response['verification_uri_complete'] else: showURL = 'Use next link to continue, your user code is "%s"\n%s' % (response['user_code'], response['verification_uri']) From a96772e108b9c2a46938ec85ce47c791d3ed7db2 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 19:34:28 +0200 Subject: [PATCH 146/178] add info in dirac-info --- src/DIRAC/Core/scripts/dirac_info.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/DIRAC/Core/scripts/dirac_info.py b/src/DIRAC/Core/scripts/dirac_info.py index ebb8bd8834e..3268739c467 100755 --- a/src/DIRAC/Core/scripts/dirac_info.py +++ b/src/DIRAC/Core/scripts/dirac_info.py @@ -60,6 +60,8 @@ def platform(arg): records = [] records.append(('Setup', gConfig.getValue('/DIRAC/Setup', 'Unknown'))) + records.append(('AuthorizationServer', gConfig.getValue('/DIRAC/Security/Authorization/issuer', + '/DIRAC/Security/Authorization/issuer option is absent'))) records.append(('ConfigurationServer', gConfig.getValue('/DIRAC/Configuration/Servers', []))) records.append(('Installation path', DIRAC.rootPath)) @@ -88,6 +90,10 @@ def platform(arg): records.append(('Use Server Certificate', 'Yes')) else: records.append(('Use Server Certificate', 'No')) + useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") + if not useTokens: + useTokens = gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true") + records.append(('Use tokens', 'Yes' if useTokens else 'No')) if gConfig.getValue('/DIRAC/Security/SkipCAChecks', False): records.append(('Skip CA Checks', 'Yes')) else: From 66d8feb22ed36a03fef5a2607254493f90b0d071 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 19:40:02 +0200 Subject: [PATCH 147/178] set issuer if present --- src/DIRAC/FrameworkSystem/scripts/dirac_login.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 153767a407d..5b5c70e3c1d 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -154,6 +154,12 @@ def doOAuthMagic(self): tokenFile = result['Value'] gLogger.notice('New token is saved to %s.' % tokenFile) + if not DIRAC.gConfig.getValue('/DIRAC/Security/Authorization/issuer'): + gLogger.notice('To continue use token you need to add /DIRAC/Security/Authorization/issuer option.') + if not self.issuer: + DIRAC.exit(1) + DIRAC.gConfig.setOptionValue('/DIRAC/Security/Authorization/issuer', self.issuer) + # Try to get user information result = Script.enableCS() if not result['OK']: From 4db9fe6a0c6df026b0da8a646406fd772793af51 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 19:40:19 +0200 Subject: [PATCH 148/178] fix --- src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py | 1 + src/DIRAC/Resources/IdProvider/IdProviderFactory.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py index 2588875800f..ab765cf5d48 100644 --- a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py +++ b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py @@ -8,6 +8,7 @@ import getpass import tarfile from six import BytesIO +from base64 import b64decode from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Base.Client import Client, createClient diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index ca6fb604fce..f1ced5f37ae 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -74,7 +74,10 @@ def getIdProvider(self, name, **kwargs): if not name: return S_ERROR('Identity Provider client name must be not None.') # Get Authorization Server metadata - asMetaDict = collectMetadata(kwargs.get('issuer'), ignoreErrors=True) + try: + asMetaDict = collectMetadata(kwargs.get('issuer'), ignoreErrors=True) + except Exception as e: + return S_ERROR(str(e)) self.log.debug('Search configuration for', name) clients = getDIRACClients() if name in clients: From f69eafad10c3d3644ceed0b2d307bfdf5a27f08c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 20:02:38 +0200 Subject: [PATCH 149/178] add information about user auth --- .../FrameworkSystem/scripts/dirac_login.py | 39 ++++++++++++++++--- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 5b5c70e3c1d..4281411e5fd 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -18,8 +18,7 @@ import DIRAC from DIRAC import gLogger, S_OK, S_ERROR -from DIRAC.Core.Base import Script -from DIRAC.Core.Utilities.DIRACScript import DIRACScript +from DIRAC.Core.Utilities.DIRACScript import DIRACScript as Script from DIRAC.Core.Security.ProxyFile import writeToProxyFile from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory @@ -32,6 +31,7 @@ class Params(object): def __init__(self): + self.info = False self.provider = 'DIRACCLI' self.proxy = False self.group = None @@ -40,6 +40,14 @@ def __init__(self): self.proxyLoc = '/tmp/x509up_u%s' % os.getuid() self.tokenLoc = None + def getInfo(self, _arg): + """ To return user info + + :return: S_OK() + """ + self.info = True + return S_OK() + def returnProxy(self, _arg): """ To return proxy @@ -90,6 +98,7 @@ def setLivetime(self, arg): def registerCLISwitches(self): """ Register CLI switches """ + Script.registerSwitch("", "info", "output current user authorization status", self.getInfo) Script.registerSwitch("P", "proxy", "request a proxy certificate with DIRAC group extension", self.returnProxy) Script.registerSwitch("g:", "group=", "set DIRAC group", self.setGroup) Script.registerSwitch("I:", "issuer=", "set issuer", self.setIssuer) @@ -101,6 +110,26 @@ def doOAuthMagic(self): :return: S_OK()/S_ERROR() """ + tokenFile = getTokenFileLocation(self.tokenLoc) + + if self.info: + # Try to get user information + result = Script.enableCS() + if not result['OK']: + return S_ERROR("Cannot contact CS.") + useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") + if not useTokens and not DIRAC.gConfig.getValue('/DIRAC/Security/UseTokens', + 'false').lower() in ("y", "yes", "true"): + result = getProxyInfo(self.proxyLoc) + if not result['OK']: + return result['Message'] + gLogger.notice(formatProxyInfoAsString(result['Value'])) + else: + result = readTokenFromFile(tokenFile) + if not result['OK']: + return result + gLogger.notice(result['Value'].getInfoAsString()) + return S_OK() params = {} if self.issuer: params['issuer'] = self.issuer @@ -117,8 +146,6 @@ def doOAuthMagic(self): scope.append('lifetime:%s' % (int(self.lifetime) * 3600)) idpObj.scope = '+'.join(scope) if scope else '' - tokenFile = getTokenFileLocation(self.tokenLoc) - # Submit Device authorisation flow result = idpObj.deviceAuthorization() if not result['OK']: @@ -163,7 +190,7 @@ def doOAuthMagic(self): # Try to get user information result = Script.enableCS() if not result['OK']: - return S_ERROR("Cannot contact CS to get user list") + return S_ERROR("Cannot contact CS.") DIRAC.gConfig.forceRefresh() if self.proxy: @@ -180,7 +207,7 @@ def doOAuthMagic(self): return S_OK() -@DIRACScript() +@Script() def main(): piParams = Params() piParams.registerCLISwitches() From 1150ffe77de337cb9c0b7520244690b5a17739ed Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 20:08:59 +0200 Subject: [PATCH 150/178] add notification --- src/DIRAC/FrameworkSystem/scripts/dirac_login.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 4281411e5fd..bbddbe61e7b 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -120,11 +120,13 @@ def doOAuthMagic(self): useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") if not useTokens and not DIRAC.gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true"): + gLogger.notice('You use proxy, to use access token set "DIRAC_USE_ACCESS_TOKEN=True" env.\n') result = getProxyInfo(self.proxyLoc) if not result['OK']: return result['Message'] gLogger.notice(formatProxyInfoAsString(result['Value'])) else: + gLogger.notice('You use access token, to use proxy set "DIRAC_USE_ACCESS_TOKEN=False" env.\n') result = readTokenFromFile(tokenFile) if not result['OK']: return result From 9f83748cd9acd2fc5820abae3fd00072ad1c55a3 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 21:08:02 +0200 Subject: [PATCH 151/178] fix BundleDelivery --- .../Service/BundleDeliveryHandler.py | 39 +--- .../Service/TornadoBundleDeliveryHandler.py | 207 ++++++++++++++++++ .../private/authorization/AuthServer.py | 1 + 3 files changed, 210 insertions(+), 37 deletions(-) create mode 100644 src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py diff --git a/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py b/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py index ecb390886d9..2e57294f5a4 100644 --- a/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/BundleDeliveryHandler.py @@ -9,10 +9,8 @@ import six import tarfile import os -from base64 import b64encode, b64decode - +from DIRAC.Core.DISET.RequestHandler import RequestHandler from DIRAC import gLogger, S_OK, S_ERROR, gConfig -from DIRAC.Core.Tornado.Server.TornadoService import TornadoService from DIRAC.Core.Utilities.ThreadScheduler import gThreadScheduler from DIRAC.Core.Utilities import File, List from DIRAC.Core.Security import Locations, Utilities @@ -90,7 +88,7 @@ def updateBundles(self): self.__bundles[bId] = (None, None) -class BundleDeliveryHandler(TornadoService): +class BundleDeliveryHandler(RequestHandler): @classmethod def initializeHandler(cls, serviceInfoDict): @@ -101,39 +99,6 @@ def initializeHandler(cls, serviceInfoDict): gThreadScheduler.addPeriodicTask(updateBundleTime, cls.bundleManager.updateBundles) return S_OK() - - types_streamToClient = [] - - def export_streamToClient(self, fileId): - version = "" - if isinstance(fileId, six.string_types): - if fileId in ['CAs', 'CRLs']: - retVal = Utilities.generateCAFile() if fileId == 'CAs' else Utilities.generateRevokedCertsFile() - if not retVal['OK']: - return retVal - with open(retVal['Value'], 'r') as fd: - return S_OK(b64encode(fd.read()).decode()) - bId = fileId - - elif isinstance(fileId, (list, tuple)): - if len(fileId) == 0: - return S_ERROR("No bundle specified!") - bId = fileId[0] - if len(fileId) != 1: - version = fileId[1] - - if not self.bundleManager.bundleExists(bId): - return S_ERROR("Unknown bundle %s" % bId) - - bundleVersion = self.bundleManager.getBundleVersion(bId) - if bundleVersion is None: - return S_ERROR("Empty bundle %s" % bId) - - if version == bundleVersion: - return S_OK(bundleVersion) - - return S_OK(b64encode(self.bundleManager.getBundleData(bId)).decode()) - types_getListOfBundles = [] @classmethod diff --git a/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py b/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py new file mode 100644 index 00000000000..fc498374c32 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py @@ -0,0 +1,207 @@ +""" Handler for CAs + CRLs bundles +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import six +import tarfile +import os +from base64 import b64encode, b64decode + +from DIRAC import gLogger, S_OK, S_ERROR, gConfig +from DIRAC.Core.Tornado.Server.TornadoService import TornadoService +from DIRAC.Core.Utilities.ThreadScheduler import gThreadScheduler +from DIRAC.Core.Utilities import File, List +from DIRAC.Core.Security import Locations, Utilities + + +class BundleManager(object): + + def __init__(self, baseCSPath): + self.__csPath = baseCSPath + self.__bundles = {} + self.updateBundles() + + def __getDirsToBundle(self): + dirsToBundle = {} + result = gConfig.getOptionsDict("%s/DirsToBundle" % self.__csPath) + if result['OK']: + dB = result['Value'] + for bId in dB: + dirsToBundle[bId] = List.fromChar(dB[bId]) + if gConfig.getValue("%s/BundleCAs" % self.__csPath, True): + dirsToBundle['CAs'] = [ + "%s/*.0" % + Locations.getCAsLocation(), + "%s/*.signing_policy" % + Locations.getCAsLocation(), + "%s/*.pem" % + Locations.getCAsLocation()] + if gConfig.getValue("%s/BundleCRLs" % self.__csPath, True): + dirsToBundle['CRLs'] = ["%s/*.r0" % Locations.getCAsLocation()] + return dirsToBundle + + def getBundles(self): + return dict([(bId, self.__bundles[bId]) for bId in self.__bundles]) + + def bundleExists(self, bId): + return bId in self.__bundles + + def getBundleVersion(self, bId): + try: + return self.__bundles[bId][0] + except Exception: + return "" + + def getBundleData(self, bId): + try: + return self.__bundles[bId][1] + except Exception: + return "" + + def updateBundles(self): + dirsToBundle = self.__getDirsToBundle() + # Delete bundles that don't have to be updated + for bId in self.__bundles: + if bId not in dirsToBundle: + gLogger.info("Deleting old bundle %s" % bId) + del(self.__bundles[bId]) + for bId in dirsToBundle: + bundlePaths = dirsToBundle[bId] + gLogger.info("Updating %s bundle %s" % (bId, bundlePaths)) + buffer_ = six.BytesIO() + filesToBundle = sorted(File.getGlobbedFiles(bundlePaths)) + if filesToBundle: + commonPath = os.path.commonprefix(filesToBundle) + commonEnd = len(commonPath) + gLogger.info("Bundle will have %s files with common path %s" % (len(filesToBundle), commonPath)) + with tarfile.open('dummy', "w:gz", buffer_) as tarBuffer: + for filePath in filesToBundle: + tarBuffer.add(filePath, filePath[commonEnd:]) + zippedData = buffer_.getvalue() + buffer_.close() + hash_ = File.getMD5ForFiles(filesToBundle) + gLogger.info("Bundled %s : %s bytes (%s)" % (bId, len(zippedData), hash_)) + self.__bundles[bId] = (hash_, zippedData) + else: + self.__bundles[bId] = (None, None) + + +class TornadoBundleDeliveryHandler(TornadoService): + + @classmethod + def initializeHandler(cls, serviceInfoDict): + csPath = serviceInfoDict['serviceSectionPath'] + cls.bundleManager = BundleManager(csPath) + updateBundleTime = gConfig.getValue("%s/BundlesLifeTime" % csPath, 3600 * 6) + gLogger.info("Bundles will be updated each %s secs" % updateBundleTime) + gThreadScheduler.addPeriodicTask(updateBundleTime, cls.bundleManager.updateBundles) + return S_OK() + + + types_streamToClient = [] + + def export_streamToClient(self, fileId): + version = "" + if isinstance(fileId, six.string_types): + if fileId in ['CAs', 'CRLs']: + retVal = Utilities.generateCAFile() if fileId == 'CAs' else Utilities.generateRevokedCertsFile() + if not retVal['OK']: + return retVal + with open(retVal['Value'], 'r') as fd: + return S_OK(b64encode(fd.read()).decode()) + bId = fileId + + elif isinstance(fileId, (list, tuple)): + if len(fileId) == 0: + return S_ERROR("No bundle specified!") + bId = fileId[0] + if len(fileId) != 1: + version = fileId[1] + + if not self.bundleManager.bundleExists(bId): + return S_ERROR("Unknown bundle %s" % bId) + + bundleVersion = self.bundleManager.getBundleVersion(bId) + if bundleVersion is None: + return S_ERROR("Empty bundle %s" % bId) + + if version == bundleVersion: + return S_OK(bundleVersion) + + return S_OK(b64encode(self.bundleManager.getBundleData(bId)).decode()) + + types_getListOfBundles = [] + + @classmethod + def export_getListOfBundles(cls): + return S_OK(cls.bundleManager.getBundles()) + + def transfer_toClient(self, fileId, token, fileHelper): + version = "" + if isinstance(fileId, six.string_types): + if fileId in ['CAs', 'CRLs']: + return self.__transferFile(fileId, fileHelper) + else: + bId = fileId + elif isinstance(fileId, (list, tuple)): + if len(fileId) == 0: + fileHelper.markAsTransferred() + return S_ERROR("No bundle specified!") + elif len(fileId) == 1: + bId = fileId[0] + else: + bId = fileId[0] + version = fileId[1] + if not self.bundleManager.bundleExists(bId): + fileHelper.markAsTransferred() + return S_ERROR("Unknown bundle %s" % bId) + + bundleVersion = self.bundleManager.getBundleVersion(bId) + if bundleVersion is None: + fileHelper.markAsTransferred() + return S_ERROR("Empty bundle %s" % bId) + + if version == bundleVersion: + fileHelper.markAsTransferred() + return S_OK(bundleVersion) + + buffer_ = six.BytesIO(self.bundleManager.getBundleData(bId)) + result = fileHelper.DataSourceToNetwork(buffer_) + buffer_.close() + if not result['OK']: + return result + return S_OK(bundleVersion) + + def __transferFile(self, filetype, fileHelper): + """ + This file is creates and transfers the CAs or CRLs file to the client. + :param str filetype: we can define which file will be transfered to the client + :param object fileHelper: + :return: S_OK or S_ERROR + """ + if filetype == 'CAs': + retVal = Utilities.generateCAFile() + elif filetype == 'CRLs': + retVal = Utilities.generateRevokedCertsFile() + else: + return S_ERROR("Not supported file type %s" % filetype) + + if not retVal['OK']: + return retVal + else: + result = fileHelper.getFileDescriptor(retVal['Value'], 'r') + if not result['OK']: + result = fileHelper.sendEOF() + # better to check again the existence of the file + if not os.path.exists(retVal['Value']): + return S_ERROR('File %s does not exist' % os.path.basename(retVal['Value'])) + else: + return S_ERROR('Failed to get file descriptor') + fileDescriptor = result['Value'] + result = fileHelper.FDToNetwork(fileDescriptor) + fileHelper.oFile.close() # close the file and return + return result diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 0126486127a..f5345a27f42 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -444,6 +444,7 @@ def __registerNewUser(self, provider, payload): mail['subject'] = "[DIRAC AS] User %s to be added." % username mail['body'] = 'User %s was authenticated by %s.' % (username, provider) mail['body'] += "\n\nNew user to be added with the following information:\n%s" % pprint.pformat(payload) + mail['body'] += "\n\nPlease, add '%s' to /Register/Users//DN option.\n" % wrapIDAsDN(username) mail['body'] += "\n\n------" mail['body'] += "\n This is a notification from the DIRAC authorization service, please do not reply.\n" result = S_OK() From 708632b5c8e48346aeb9a8428800c2e04bfea12f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 21:57:14 +0200 Subject: [PATCH 152/178] add BundleDelivery to template --- src/DIRAC/FrameworkSystem/ConfigTemplate.cfg | 4 ++++ src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg index ec28555ee15..d239cd6c5d7 100644 --- a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg +++ b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg @@ -25,6 +25,10 @@ Services storeHostInfo = Operator } } + TornadoBundleDelivery + { + Protocol = https + } ##BEGIN TokenManager: # Section to describe TokenManager system TokenManager diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index f5345a27f42..8cbc81ebc78 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -271,6 +271,7 @@ def parseIdPAuthorizationResponse(self, response, session): result = getUsernameForDN(credDict['DN']) if not result['OK']: comment = '%s ID is not registred in the DIRAC.' % credDict['ID'] + payload.update(idpObj.getUserProfile().get('Value', {})) result = self.__registerNewUser(providerName, payload) if result['OK']: comment += ' Administrators have been notified about you.' From d73b499f4ab7601ef5cf83232a4019f4679f2d78 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 15 Aug 2021 23:25:25 +0200 Subject: [PATCH 153/178] fix pylint --- .../FrameworkSystem/Service/TornadoBundleDeliveryHandler.py | 1 - src/DIRAC/FrameworkSystem/scripts/dirac_login.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py b/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py index fc498374c32..6d7a8ad40fc 100644 --- a/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TornadoBundleDeliveryHandler.py @@ -101,7 +101,6 @@ def initializeHandler(cls, serviceInfoDict): gThreadScheduler.addPeriodicTask(updateBundleTime, cls.bundleManager.updateBundles) return S_OK() - types_streamToClient = [] def export_streamToClient(self, fileId): diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index bbddbe61e7b..cc412d23325 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -124,7 +124,7 @@ def doOAuthMagic(self): result = getProxyInfo(self.proxyLoc) if not result['OK']: return result['Message'] - gLogger.notice(formatProxyInfoAsString(result['Value'])) + gLogger.notice(formatProxyInfoAsString(result['Value'])) else: gLogger.notice('You use access token, to use proxy set "DIRAC_USE_ACCESS_TOKEN=False" env.\n') result = readTokenFromFile(tokenFile) From ccbc60624e3595a86426487b98c0530ebb1e658f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Tue, 17 Aug 2021 22:43:29 +0200 Subject: [PATCH 154/178] fix --- src/DIRAC/Core/scripts/dirac_info.py | 8 ++++---- src/DIRAC/FrameworkSystem/scripts/dirac_login.py | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/DIRAC/Core/scripts/dirac_info.py b/src/DIRAC/Core/scripts/dirac_info.py index 3268739c467..f9b38c330a9 100755 --- a/src/DIRAC/Core/scripts/dirac_info.py +++ b/src/DIRAC/Core/scripts/dirac_info.py @@ -90,10 +90,10 @@ def platform(arg): records.append(('Use Server Certificate', 'Yes')) else: records.append(('Use Server Certificate', 'No')) - useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") - if not useTokens: - useTokens = gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true") - records.append(('Use tokens', 'Yes' if useTokens else 'No')) + if gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true"): + records.append(('Use tokens', 'Yes')) + else: + records.append(('Use tokens', 'No')) if gConfig.getValue('/DIRAC/Security/SkipCAChecks', False): records.append(('Skip CA Checks', 'Yes')) else: diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index cc412d23325..69b5276f5e3 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -117,9 +117,10 @@ def doOAuthMagic(self): result = Script.enableCS() if not result['OK']: return S_ERROR("Cannot contact CS.") - useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") - if not useTokens and not DIRAC.gConfig.getValue('/DIRAC/Security/UseTokens', - 'false').lower() in ("y", "yes", "true"): + useTokens = DIRAC.gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true") + if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: + useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") + if useTokens: gLogger.notice('You use proxy, to use access token set "DIRAC_USE_ACCESS_TOKEN=True" env.\n') result = getProxyInfo(self.proxyLoc) if not result['OK']: From 3aceaebfec41daa398adf0f10c85d7a5cca13485 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 19 Aug 2021 01:57:22 +0200 Subject: [PATCH 155/178] fix rebase --- environment.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/environment.yml b/environment.yml index 6c0255a83d3..fcb11ffd10e 100644 --- a/environment.yml +++ b/environment.yml @@ -72,10 +72,6 @@ dependencies: # Pin OpenSSL to avoid: https://github.com/DIRACGrid/DIRAC/issues/4489 - openssl <1.1 - selectors2 - # OAuth2 - - authlib - - pyjwt - - dominate - pip: - diraccfg # OAuth2 From cbad2379d9d17a95405b172c09fb5a6a1ee6bf2c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 19 Aug 2021 05:47:23 +0200 Subject: [PATCH 156/178] add UI --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 133 ++++++++---------- .../private/authorization/AuthServer.py | 87 +++++++++--- .../authorization/grants/DeviceFlow.py | 30 +++- .../private/authorization/utils/Utilities.py | 24 ++++ 4 files changed, 178 insertions(+), 96 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 38d3639d380..d4313f75eaf 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -13,7 +13,7 @@ import json import pprint -from dominate import document, tags as dom +from dominate import tags as dom from tornado.template import Template from tornado.concurrent import Future @@ -26,6 +26,7 @@ from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML __RCSID__ = "$Id$" @@ -36,48 +37,6 @@ class AuthHandler(TornadoREST): SYSTEM = 'Framework' AUTH_PROPS = 'all' LOCATION = "/auth" - css_align_center = 'display:block;justify-content:center;align-items:center;' - css_center_div = 'height:700px;width:100%;position:absolute;top:50%;left:0;margin-top:-350px;' - css_big_text = 'font-size:28px;' - css_main = ' '.join([css_align_center, css_center_div, css_big_text]) - CSS = """ -.button { - border-radius: 4px; - background-color: #ffffff00; - border: none; - color: black; - text-align: center; - font-size: 14px; - padding: 12px; - width: 100%; - transition: all 0.5s; - cursor: pointer; - margin: 5px; - display: block; /* Make the links appear below each other */ -} -.button a { - color: black; - cursor: pointer; - display: inline-block; - position: relative; - transition: 0.5s; - text-decoration: none; /* Remove underline from links */ -} -.button a:after { - content: '\\00bb'; - position: absolute; - opacity: 0; - top: 0; - right: -20px; - transition: 0.5s; -} -.button:hover a { - padding-right: 25px; -} -.button:hover a:after { - opacity: 1; - right: 0; -}""" @classmethod def initializeHandler(cls, serviceInfo): @@ -86,18 +45,11 @@ def initializeHandler(cls, serviceInfo): :param dict ServiceInfoDict: infos about services """ cls.server = AuthServer() - cls.server.css = dict(CSS=cls.CSS, css_align_center=cls.css_align_center, css_main=cls.css_main) cls.server.LOCATION = cls.LOCATION def initializeRequest(self): """ Called at every request """ self.currentPath = self.request.protocol + "://" + self.request.host + self.request.path - # Template for a html UI - self.doc = document('DIRAC authentication') - with self.doc.head: - dom.link(rel='stylesheet', - href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css") - dom.style(self.CSS) def _finishFuture(self, retVal): """ Handler Future result @@ -112,8 +64,11 @@ def _finishFuture(self, retVal): # S_ERROR is interpreted in the OAuth2 error format. self.set_status(400) self.clear_cookie('auth_session') - self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) - self.finish({'error': 'server_error', 'description': retVal['Message']}) + if retVal['Message'].startswith(""): + self.finish(retVal['Message']) + else: + self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) + self.finish({'error': 'server_error', 'description': retVal['Message']}) else: super(AuthHandler, self)._finishFuture(retVal) @@ -320,11 +275,23 @@ def web_device(self, provider=None, user_code=None): return self.server.handle_response(302, {}, [("Location", authURL)], session) # If received a request without a user code, then send a form to enter the user code - with self.doc: - dom.div(dom.form(dom._input(type="text", name="user_code", style=self.css_big_text), - dom.button('Submit', type="submit", style=self.css_big_text), - action=self.currentPath, method="GET"), style=self.css_main) - return Template(self.doc.render()).generate() + html = getHTML('device flow') + with html: + with dom.div(cls="container"): + with dom.div(cls="row m-5 justify-content-md-center align-items-center"): + dom.div(dom.img(src=self.server.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") + with dom.div(cls="col-md-4"): + dom.small(dom.i(cls="fa fa-ticket-alt")) + dom.small('user code verification..', cls="p-3 h6") + dom.br() + dom.br() + dom.small(dom.i(cls="fa fa-info")) + dom.small('Device flow required user code. You will need to type user code to continue.', cls="p-2") + with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")): + dom.div(dom.form(dom._input(type="text", name="user_code"), + dom.button('Submit', type="submit", cls="btn btn-submit"), + action=self.currentPath, method="GET"), cls='card') + return html.render() path_authorization = ['([A-z0-9-_]*)'] @@ -473,18 +440,40 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): return None, S_ERROR('No groups found for %s and for %s Identity Provider.' % (username, provider)) self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(validGroups)) - if not firstRequest.groups: - if len(validGroups) == 1: - firstRequest.addScopes(['g:%s' % validGroups[0]]) - else: - # Choose group interface - with self.doc: - with dom.div(style=self.css_main): - with dom.div('Choose group', style=self.css_align_center): - for group in validGroups: - dom.button(dom.a(group, href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group)), - cls='button') - return None, self.server.handle_response(payload=Template(self.doc.render()).generate(), newSession=extSession) - - # Return grant user - return extSession['authed'], firstRequest + + # If group already defined in first request, just return it + if firstRequest.groups: + return extSession['authed'], firstRequest + + # If not and we found only one valid group, apply this group + if len(validGroups) == 1: + firstRequest.addScopes(['g:%s' % validGroups[0]]) + return extSession['authed'], firstRequest + + # Else give user chanse to choose group in browser + # Return choose group HTML interface + html = getHTML('group selection') + with html: + with dom.div(cls="container"): + with dom.div(cls="row m-5 justify-content-md-center align-items-center"): + dom.div(dom.img(src=self.server.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") + with dom.div(cls="col-md-4"): + dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) + dom.small('user code verified.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) + dom.small('Identity Provider selected.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-users")) + dom.small('DIRAC group selection..', cls="p-3 h6") + dom.br() + dom.br() + dom.small(dom.i(cls="fa fa-info")) + dom.small('Dirac use groups to describe permissions.', cls="p-2") + dom.small('You will need to select one of the groups to continue.', cls="p-2") + with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")): + for group in validGroups: + with dom.div(cls="card shadow-sm border-0 text-center m-5 p-2"): + dom.h4(group, cls="p-2") + dom.a(href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group), cls="stretched-link") + return None, self.server.handle_response(payload=html.render(), newSession=extSession) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 8cbc81ebc78..36cb0b1735b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -21,7 +21,7 @@ from DIRAC import gLogger, S_OK, S_ERROR from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB -from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance +from DIRAC.Resources.IdProvider.Utilities import getProvidersForInstance, getProviderInfo from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory from DIRAC.ConfigurationSystem.Client.Utilities import isDownloadablePersonalProxy from DIRAC.ConfigurationSystem.Client.Helpers.Registry import (getUsernameForDN, getEmailsForGroup, wrapIDAsDN, @@ -31,7 +31,7 @@ from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import TornadoResponse from DIRAC.FrameworkSystem.private.authorization.utils.Clients import getDIRACClients, Client from DIRAC.FrameworkSystem.private.authorization.utils.Requests import OAuth2Request, createOAuth2Request -from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import collectMetadata +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import collectMetadata, getHTML from DIRAC.FrameworkSystem.private.authorization.grants.RevokeToken import RevocationEndpoint from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, @@ -51,7 +51,6 @@ class AuthServer(_AuthorizationServer): server = AuthServer() """ - css = {} LOCATION = None REFRESH_TOKEN_EXPIRES_IN = 24 * 3600 @@ -270,14 +269,37 @@ def parseIdPAuthorizationResponse(self, response, session): # Is ID registred? result = getUsernameForDN(credDict['DN']) if not result['OK']: - comment = '%s ID is not registred in the DIRAC.' % credDict['ID'] + comment = 'Your ID is not registred in the DIRAC: %s.' % credDict['ID'] payload.update(idpObj.getUserProfile().get('Value', {})) result = self.__registerNewUser(providerName, payload) + if result['OK']: comment += ' Administrators have been notified about you.' else: comment += ' Please, contact the DIRAC administrators.' - return S_ERROR(comment) + + # Notify user about problem + html = getHTML("unregister") + # Create HTML page + with html: + with dom.div(cls="container"): + with dom.div(cls="row m-5 justify-content-md-center align-items-center"): + dom.div(dom.img(src=self.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") + with dom.div(cls="col-md-4"): + dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) + dom.small('user code verified.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) + dom.small('Identity Provider selected.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-exclamation-circle", style="color:red;")) + dom.small('authorization failed.', cls="p-3 h6") + dom.br() + dom.br() + dom.small(dom.i(cls="fa fa-info")) + dom.small(comment, cls="p-2") + + return S_ERROR(html.render()) credDict['username'] = result['Value'] # Update token for user. This token will be stored separately in the database and @@ -409,25 +431,44 @@ def validateIdentityProvider(self, request, provider): # If no identity provider is specified, it must be assigned if groupProvider: request.provider = groupProvider - elif len(idPs) == 1: - # If only one identity provider is registered, then choose it + return request, None + + # If only one identity provider is registered, then choose it + if len(idPs) == 1: request.provider = idPs[0] - else: - # Choose IdP interface - doc = document('DIRAC authentication') - with doc.head: - dom.link(rel='stylesheet', - href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css") - dom.style(self.css['CSS']) - with doc: - with dom.div(style=self.css['css_main']): - with dom.div('Choose identity provider', style=self.css['css_align_center']): - for idP in idPs: - # data: Status, Comment, Action - dom.button(dom.a(idP, href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query)), - cls='button') - return None, self.handle_response(payload=Template(doc.render()).generate()) - return request, None + return request, None + + # Choose IdP HTML interface + html = getHTML("IdP selector", style=".card{transition:.3s;}.card:hover{transform:scale(1.05);}") + # Create HTML page + with html: + with dom.div(cls="container"): + with dom.div(cls="row m-5 justify-content-md-center align-items-center"): + dom.div(dom.img(src=self.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") + with dom.div(cls="col-md-4"): + dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) + dom.small('user code verified.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-user-check")) + dom.small('Identity Provider selection..', cls="p-3 h6") + dom.br() + dom.br() + dom.small(dom.i(cls="fa fa-info")) + dom.small('Dirac itself is not an Identity Provider. You will need to select one to continue.', cls="p-2") + with dom.div(cls="row m-5 justify-content-md-center"): + for idP in idPs: + result = getProviderInfo(idP) + if result['OK']: + logo = result['Value'].get('logoURL') + with dom.div(cls="col-lg-4").add(dom.div(cls="card shadow-lg h-100 border-0")): + with dom.div(cls="row m-2 align-items-center h-100"): + with dom.div(cls="col-lg-8"): + dom.h2(idP) + dom.a(href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query), cls="stretched-link") + if logo: + dom.div(dom.img(src=logo, cls="card-img"), cls="col-lg-4") + # Render into HTML + return None, html.render() def __registerNewUser(self, provider, payload): """ Register new user diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index aa2871ed3c5..f975e591493 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -3,12 +3,14 @@ from __future__ import print_function import time +from dominate import document, tags as dom from authlib.oauth2 import OAuth2Error from authlib.oauth2.rfc6749.grants import AuthorizationEndpointMixin from authlib.oauth2.rfc6749.errors import InvalidClientError, UnauthorizedClientError from authlib.oauth2.rfc8628 import (DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint, DeviceCodeGrant as _DeviceCodeGrant, DeviceCredentialDict) +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): @@ -93,7 +95,33 @@ def create_authorization_response(self, redirect_uri, user): result = self.server.db.updateSession(data, data['id']) if not result['OK']: raise OAuth2Error('Cannot save authorization result', result['Message']) - return 200, 'Authorization complite.' + + # Notify user that authorization complite. + html = getHTML("authorization complite") + # Create HTML page + with html: + with dom.div(cls="container"): + with dom.div(cls="row m-5 justify-content-md-center align-items-center"): + dom.div(dom.img(src=self.server.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") + with dom.div(cls="col-md-4"): + dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) + dom.small('user code verified.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) + dom.small('Identity Provider selected.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-users", style="color:green;")) + dom.small('DIRAC group selected.', cls="p-3 h6") + dom.br() + dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) + dom.small('authorization complite!', cls="p-3 h6") + dom.br() + dom.br() + dom.small(dom.i(cls="fa fa-info")) + dom.small('Authorization has been completed, you can now return to the terminal and close this window.', + cls="p-2") + + return 200, html.render() def query_device_credential(self, device_code): """ Get device credential from previously savings via ``DeviceAuthorizationEndpoint``. diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 102351c3872..3fdb135e9fe 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -4,6 +4,7 @@ __RCSID__ = "$Id$" +from dominate import document, tags as dom from authlib.oauth2.rfc8414 import AuthorizationServerMetadata from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorizationServerMetadata @@ -35,3 +36,26 @@ def collectMetadata(issuer=None, ignoreErrors=False): metadata['response_types_supported'] = ['code', 'device', 'token'] metadata['code_challenge_methods_supported'] = ['S256'] return AuthorizationServerMetadata(metadata) + + +def getHTML(title, style=None): + """ Provide HTML object + + :param str title: browser tab title + :param str style: css as string + + :return: HTML object + """ + html = document("DIRAC - %s" % title) + with html.head: + dom.script(src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/js/all.min.js") + # Enable bootstrap 5 + dom.link(rel='stylesheet', integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC", + href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css", crossorigin="anonymous") + dom.script(src='https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/js/bootstrap.bundle.min.js', + integrity="sha384-MrcW6ZMFYlzcLA8Nl+NtUVF0sA7MsXsP1UyJoMp4YLEuNSfAP+JcXn/tWtIaxVXM", + crossorigin="anonymous") + # Provide additional css + if style: + dom.style(style) + return html From cce156afdcfdfa048fe65c5fb303214600ef66bb Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 19 Aug 2021 09:36:16 +0200 Subject: [PATCH 157/178] fix pylint --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 2 +- src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index d4313f75eaf..cf386e06a2c 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -449,7 +449,7 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): if len(validGroups) == 1: firstRequest.addScopes(['g:%s' % validGroups[0]]) return extSession['authed'], firstRequest - + # Else give user chanse to choose group in browser # Return choose group HTML interface html = getHTML('group selection') diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 36cb0b1735b..6a6b05ff06b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -272,12 +272,12 @@ def parseIdPAuthorizationResponse(self, response, session): comment = 'Your ID is not registred in the DIRAC: %s.' % credDict['ID'] payload.update(idpObj.getUserProfile().get('Value', {})) result = self.__registerNewUser(providerName, payload) - + if result['OK']: comment += ' Administrators have been notified about you.' else: comment += ' Please, contact the DIRAC administrators.' - + # Notify user about problem html = getHTML("unregister") # Create HTML page From ffa1a2f6a20dcf638594ece016b14116ab894419 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sat, 21 Aug 2021 21:21:30 +0200 Subject: [PATCH 158/178] unquote args --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 3acbf686b90..7d728693d4a 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -422,7 +422,7 @@ def _getMethodArgs(self, args): else: kwargs[arg] = self.get_argument(arg, defaults[arg]) - return (args, kwargs) + return ([unquote(a) for a in args], kwargs) def _getMethodAuthProps(self): """ Resolves the hard coded authorization requirements for method. From aa99c8d899ec136029c1be5fffb8c648095e9a1c Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sat, 21 Aug 2021 21:22:00 +0200 Subject: [PATCH 159/178] quote args --- .../FrameworkSystem/private/authorization/utils/Requests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index c0c28e2cb52..ee5e1c66547 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -8,6 +8,7 @@ from authlib.common.encoding import to_unicode from authlib.oauth2 import OAuth2Request as _OAuth2Request from authlib.oauth2.rfc6749.util import scope_to_list +from six.moves.urllib.parse import quote __RCSID__ = "$Id$" @@ -25,12 +26,14 @@ def addScopes(self, scopes): def setQueryArguments(self, **kwargs): """ Set query arguments """ for k in kwargs: + # Quote value before add it to request query + value = '+'.join([quote(str(v)) for v in kwargs[k]]) if isinstance(kwargs[k], list) else quote(str(kwargs[k])) # Remove argument from uri query = re.sub(r"&{argument}(=[^&]*)?|^{argument}(=[^&]*)?&?".format(argument=k), "", self.query) # Add new one if query: query += '&' - query += "%s=%s" % (k, '+'.join(kwargs[k]) if isinstance(kwargs[k], list) else kwargs[k]) + query += "%s=%s" % (k, value) # Re-init class self.__init__(self.method, to_unicode(self.path + '?' + query)) From 08696f31ef79a8150a9dfe3803b6577b69e6ddee Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sat, 21 Aug 2021 21:23:37 +0200 Subject: [PATCH 160/178] optimize --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 66 ++++++------------ .../private/authorization/AuthServer.py | 67 +++++------------- .../authorization/grants/DeviceFlow.py | 26 +------ .../private/authorization/utils/Utilities.py | 69 +++++++++++++++++-- 4 files changed, 103 insertions(+), 125 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index cf386e06a2c..599f0af96ba 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -196,7 +196,7 @@ def web_userinfo(self): """ return self.getRemoteCredentials() - path_device = ['([A-z0-9-_]*)'] + path_device = ['([A-z%0-9-_]*)'] def web_device(self, provider=None, user_code=None): """ The device authorization endpoint can be used to request device and user codes. @@ -258,7 +258,6 @@ def web_device(self, provider=None, user_code=None): return self.server.create_endpoint_response(DeviceAuthorizationEndpoint.ENDPOINT_NAME, self.request) elif self.request.method == 'GET': - # userCode = self.get_argument('user_code', None) if user_code: # If received a request with a user code, then prepare a request to authorization endpoint self.log.verbose('User code verification.') @@ -275,25 +274,14 @@ def web_device(self, provider=None, user_code=None): return self.server.handle_response(302, {}, [("Location", authURL)], session) # If received a request without a user code, then send a form to enter the user code - html = getHTML('device flow') - with html: - with dom.div(cls="container"): - with dom.div(cls="row m-5 justify-content-md-center align-items-center"): - dom.div(dom.img(src=self.server.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") - with dom.div(cls="col-md-4"): - dom.small(dom.i(cls="fa fa-ticket-alt")) - dom.small('user code verification..', cls="p-3 h6") - dom.br() - dom.br() - dom.small(dom.i(cls="fa fa-info")) - dom.small('Device flow required user code. You will need to type user code to continue.', cls="p-2") - with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")): - dom.div(dom.form(dom._input(type="text", name="user_code"), - dom.button('Submit', type="submit", cls="btn btn-submit"), - action=self.currentPath, method="GET"), cls='card') - return html.render() - - path_authorization = ['([A-z0-9-_]*)'] + with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")) as tag: + dom.div(dom.form(dom._input(type="text", name="user_code"), + dom.button('Submit', type="submit", cls="btn btn-submit"), + action=self.currentPath, method="GET"), cls='card') + return getHTML('user code verification..', body=tag, icon='ticket-alt', + info='Device flow required user code. You will need to type user code to continue.',).render() + + path_authorization = ['([A-z%0-9-_]*)'] def web_authorization(self, provider=None): """ Authorization endpoint @@ -423,7 +411,7 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): # Base DIRAC client auth session firstRequest = createOAuth2Request(extSession['firstRequest']) # Read requested groups by DIRAC client or user - firstRequest.addScopes(chooseScope) # self.get_arguments('chooseScope')) + firstRequest.addScopes(chooseScope) # Read already authed user username = extSession['authed']['username'] # Requested arguments in first request @@ -452,28 +440,14 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): # Else give user chanse to choose group in browser # Return choose group HTML interface - html = getHTML('group selection') - with html: - with dom.div(cls="container"): - with dom.div(cls="row m-5 justify-content-md-center align-items-center"): - dom.div(dom.img(src=self.server.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") - with dom.div(cls="col-md-4"): - dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) - dom.small('user code verified.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) - dom.small('Identity Provider selected.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-users")) - dom.small('DIRAC group selection..', cls="p-3 h6") - dom.br() - dom.br() - dom.small(dom.i(cls="fa fa-info")) - dom.small('Dirac use groups to describe permissions.', cls="p-2") - dom.small('You will need to select one of the groups to continue.', cls="p-2") - with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")): - for group in validGroups: - with dom.div(cls="card shadow-sm border-0 text-center m-5 p-2"): - dom.h4(group, cls="p-2") - dom.a(href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group), cls="stretched-link") + with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")) as tag: + for group in validGroups: + with dom.div(cls="card shadow-sm border-0 text-center m-5 p-2"): + dom.h4(group, cls="p-2") + dom.a(href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group), cls="stretched-link") + + html = getHTML('group selection..', body=tag, icon='users', + info='Dirac use groups to describe permissions.' + 'You will need to select one of the groups to continue.') + return None, self.server.handle_response(payload=html.render(), newSession=extSession) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 6a6b05ff06b..8afd7cb0c51 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -279,27 +279,9 @@ def parseIdPAuthorizationResponse(self, response, session): comment += ' Please, contact the DIRAC administrators.' # Notify user about problem - html = getHTML("unregister") - # Create HTML page - with html: - with dom.div(cls="container"): - with dom.div(cls="row m-5 justify-content-md-center align-items-center"): - dom.div(dom.img(src=self.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") - with dom.div(cls="col-md-4"): - dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) - dom.small('user code verified.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) - dom.small('Identity Provider selected.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-exclamation-circle", style="color:red;")) - dom.small('authorization failed.', cls="p-3 h6") - dom.br() - dom.br() - dom.small(dom.i(cls="fa fa-info")) - dom.small(comment, cls="p-2") - + html = getHTML("unregistered user!", info=comment, theme='warning') return S_ERROR(html.render()) + credDict['username'] = result['Value'] # Update token for user. This token will be stored separately in the database and @@ -439,36 +421,23 @@ def validateIdentityProvider(self, request, provider): return request, None # Choose IdP HTML interface - html = getHTML("IdP selector", style=".card{transition:.3s;}.card:hover{transform:scale(1.05);}") - # Create HTML page - with html: - with dom.div(cls="container"): - with dom.div(cls="row m-5 justify-content-md-center align-items-center"): - dom.div(dom.img(src=self.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") - with dom.div(cls="col-md-4"): - dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) - dom.small('user code verified.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-user-check")) - dom.small('Identity Provider selection..', cls="p-3 h6") - dom.br() - dom.br() - dom.small(dom.i(cls="fa fa-info")) - dom.small('Dirac itself is not an Identity Provider. You will need to select one to continue.', cls="p-2") - with dom.div(cls="row m-5 justify-content-md-center"): - for idP in idPs: - result = getProviderInfo(idP) - if result['OK']: - logo = result['Value'].get('logoURL') - with dom.div(cls="col-lg-4").add(dom.div(cls="card shadow-lg h-100 border-0")): - with dom.div(cls="row m-2 align-items-center h-100"): - with dom.div(cls="col-lg-8"): - dom.h2(idP) - dom.a(href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query), cls="stretched-link") - if logo: - dom.div(dom.img(src=logo, cls="card-img"), cls="col-lg-4") + with dom.div(cls="row m-5 justify-content-md-center") as tag: + for idP in idPs: + result = getProviderInfo(idP) + if result['OK']: + logo = result['Value'].get('logoURL') + with dom.div(cls="col-md-6 p-2").add(dom.div(cls="card shadow-lg h-100 border-0")): + with dom.div(cls="row m-2 justify-content-md-center align-items-center h-100"): + with dom.div(cls="col-auto"): + dom.h2(idP) + dom.a(href='%s/authorization/%s?%s' % (self.LOCATION, idP, request.query), cls="stretched-link") + if logo: + dom.div(dom.img(src=logo, cls="card-img"), cls="col-auto") + # Render into HTML - return None, html.render() + return None, getHTML("Identity Provider selection..", body=tag, icon='fingerprint', + info="Dirac itself is not an Identity Provider." + "You will need to select one to continue.").render() def __registerNewUser(self, provider, payload): """ Register new user diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index f975e591493..43a252506d4 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -97,30 +97,8 @@ def create_authorization_response(self, redirect_uri, user): raise OAuth2Error('Cannot save authorization result', result['Message']) # Notify user that authorization complite. - html = getHTML("authorization complite") - # Create HTML page - with html: - with dom.div(cls="container"): - with dom.div(cls="row m-5 justify-content-md-center align-items-center"): - dom.div(dom.img(src=self.server.metadata.get('logoURL', ''), cls="card-img p-5"), cls="col-md-4") - with dom.div(cls="col-md-4"): - dom.small(dom.i(cls="fa fa-ticket-alt", style="color:green;")) - dom.small('user code verified.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) - dom.small('Identity Provider selected.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-users", style="color:green;")) - dom.small('DIRAC group selected.', cls="p-3 h6") - dom.br() - dom.small(dom.i(cls="fa fa-user-check", style="color:green;")) - dom.small('authorization complite!', cls="p-3 h6") - dom.br() - dom.br() - dom.small(dom.i(cls="fa fa-info")) - dom.small('Authorization has been completed, you can now return to the terminal and close this window.', - cls="p-2") - + html = getHTML("authorization complite!", theme='success', + info='Authorization has been completed, you can now return to the terminal and close this window.') return 200, html.render() def query_device_credential(self, device_code): diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 3fdb135e9fe..5d8a0db73ba 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -38,16 +38,39 @@ def collectMetadata(issuer=None, ignoreErrors=False): return AuthorizationServerMetadata(metadata) -def getHTML(title, style=None): +def getHTML(title, info=None, body=None, style=None, state=None, theme=None, icon=None): """ Provide HTML object - :param str title: browser tab title - :param str style: css as string + :param str title: short name of the notification, e.g.: server error + :param str info: some short description if needed, e.g.: Seems it because server down. + :param body: dominate tag object, main content, e.g.: dom.pre(dom.code(result['Message'])) + :param str style: additional css style if needed, e.g.: '.card{color:red;}' + :param str state: response state, if needed, e.g.: 404 + :param str theme: message color theme, the same that in bootstrap 5. + :param str icon: awesome icon name - :return: HTML object + :return: HTML document object """ html = document("DIRAC - %s" % title) + + icon = icon or 'flask' + theme = theme or 'secondary' + if theme == 'warning': + icon = icon or 'exclamation-triangle' + elif theme == 'info': + icon = icon or 'info' + elif theme == 'success': + icon = icon or 'check' + elif theme == 'error': + theme = 'danger' + icon = icon or 'times' + + diracLogo = collectMetadata(ignoreErrors=True).get('logoURL', '') + + # Create head with html.head: + # Provide icons + dom.link(rel="icon", href="/static/core/img/icons/system/favicon.ico", type="image/x-icon") dom.script(src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/js/all.min.js") # Enable bootstrap 5 dom.link(rel='stylesheet', integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC", @@ -56,6 +79,40 @@ def getHTML(title, style=None): integrity="sha384-MrcW6ZMFYlzcLA8Nl+NtUVF0sA7MsXsP1UyJoMp4YLEuNSfAP+JcXn/tWtIaxVXM", crossorigin="anonymous") # Provide additional css - if style: - dom.style(style) + style = ".card{transition:.3s;}.card:hover{transform:scale(1.03);}" + (style or '') + dom.style(style) + + # Create body + with html: + # Background image + dom.i(cls='position-absolute bottom-0 start-0 translate-middle-x fa fa-{icon} m-5 text-{theme}', + style="font-size:40vw;z-index:-1;") + + # A4 page with align center + with dom.div(cls="row vh-100 vw-100 justify-content-md-center align-items-center m-0"): + with dom.div(cls='container', style="max-width: 600px;").add(dom.div(cls="row align-items-center")): + + # Logo + dom.div(dom.img(src=diracLogo, cls="card-img px-2"), cls="col-md-6 my-3") + # Information card + with dom.div(cls="col-md-6 my-3"): + + # Show response state number + if state and state != 200: + dom.div(dom.h1(state, cls='text-center badge bg-{theme} text-wrap'), cls="row py-2") + + # Message title + with dom.div(cls="row"): + dom.div(dom.i(cls="fa fa-{icon} text-{theme}"), cls="col-auto") + dom.div(title, cls='col-auto ps-0 pb-2 fw-bold') + + # Description + if description: + dom.small(dom.i(cls="fa fa-info text-info")) + dom.small(info, cls="ps-1") + + # Add content + if body: + body + return html From fe292b0e9d7ab8f8f7a6ff5450f0aa2eee31a72f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 22 Aug 2021 20:44:01 +0200 Subject: [PATCH 161/178] fix dirac-login --- .../FrameworkSystem/scripts/dirac_login.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 69b5276f5e3..905245562ab 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -114,25 +114,28 @@ def doOAuthMagic(self): if self.info: # Try to get user information - result = Script.enableCS() - if not result['OK']: - return S_ERROR("Cannot contact CS.") + Script.enableCS() + useTokens = DIRAC.gConfig.getValue('/DIRAC/Security/UseTokens', 'false').lower() in ("y", "yes", "true") if 'DIRAC_USE_ACCESS_TOKEN' in os.environ: useTokens = os.environ.get('DIRAC_USE_ACCESS_TOKEN', 'false').lower() in ("y", "yes", "true") if useTokens: - gLogger.notice('You use proxy, to use access token set "DIRAC_USE_ACCESS_TOKEN=True" env.\n') - result = getProxyInfo(self.proxyLoc) - if not result['OK']: - return result['Message'] - gLogger.notice(formatProxyInfoAsString(result['Value'])) - else: - gLogger.notice('You use access token, to use proxy set "DIRAC_USE_ACCESS_TOKEN=False" env.\n') + gLogger.notice('You are currently using access token to access new HTTP DIRAC services.' + ' To use a proxy instead, do the following:\n', + 'export DIRAC_USE_ACCESS_TOKEN=False\n') result = readTokenFromFile(tokenFile) - if not result['OK']: - return result - gLogger.notice(result['Value'].getInfoAsString()) - return S_OK() + if result['OK']: + gLogger.notice(result['Value'].getInfoAsString()) + else: + gLogger.notice('You are currently using proxy to access new HTTP DIRAC services.' + ' To use a access token instead, do the following:\n', + 'export DIRAC_USE_ACCESS_TOKEN=True\n') + result = getProxyInfo(self.proxyLoc) + if result['OK']: + gLogger.notice(formatProxyInfoAsString(result['Value'])) + + return result + params = {} if self.issuer: params['issuer'] = self.issuer From abb751b6a2bec9c238cf45c6cc3b801ec2e0102f Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 22 Aug 2021 20:46:16 +0200 Subject: [PATCH 162/178] handle 404 error --- .../Tornado/Client/private/TornadoBaseClient.py | 2 ++ src/DIRAC/Core/Tornado/Server/TornadoServer.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 78dfc09b9c1..0e7696d825b 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -604,6 +604,8 @@ def _request(self, retry=0, outputFile=None, **kwargs): return S_ERROR(errno.ENOSYS, "%s is not implemented" % kwargs.get('method')) elif status_code in (http_client.FORBIDDEN, http_client.UNAUTHORIZED): return S_ERROR(errno.EACCES, "No access to %s" % url) + elif status_code == http_client.NOT_FOUND: + rawText = "%s is not found" % url # if it is something else, retry raise diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index b58d2ab32b6..40e37a09280 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -21,9 +21,10 @@ 'tornado_m2crypto.m2iostream.M2IOStream') # pylint: disable=wrong-import-position from tornado.httpserver import HTTPServer -from tornado.web import Application, url +from tornado.web import Application, url, RequestHandler from tornado.ioloop import IOLoop import tornado.ioloop +from six.moves import http_client import DIRAC from DIRAC import gConfig, gLogger, S_OK @@ -32,10 +33,18 @@ from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient +from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML sLog = gLogger.getSubLogger(__name__) +class NotFoundHandler(RequestHandler): + """ Handle 404 errors """ + def prepare(self): + self.set_status(http_client.NOT_FOUND.value) + self.finish(getHTML(http_client.NOT_FOUND)) + + class TornadoServer(object): """ Tornado webserver @@ -213,7 +222,7 @@ def startTornado(self): # Merge appllication settings settings.update(app['settings']) # Start server - router = Application(app['routes'], **settings) + router = Application(app['routes'], default_handler_class=NotFoundHandler, **settings) server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True) try: server.listen(int(port)) From a2dc51c7926a8a1da7f1008dca7e20238cef11bf Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 22 Aug 2021 20:46:53 +0200 Subject: [PATCH 163/178] add status code to TornadoResponse --- .../Tornado/Server/private/BaseRequestHandler.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index 7d728693d4a..fe2c1799ea7 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -58,12 +58,14 @@ def web_myEndpotin(self): """ __attrs = inspect.getmembers(RequestHandler) - def __init__(self, data=None): + def __init__(self, payload=None, status_code=None): """ C'or - :param data: response body + :param payload: response body + :param int status_code: response status code """ - self.data = data + self.payload = payload + self.status_code = status_code self.actions = [] for mName, mObj in self.__attrs: if inspect.isroutine(mObj) and not mName.startswith('_') and not mName.startswith('get'): @@ -84,10 +86,12 @@ def _runActions(self, reqObj): :param reqObj: RequestHandler instance """ + if self.status_code: + reqObj.set_status(self.status_code) for mName, args, kwargs in self.actions: getattr(reqObj, mName)(*args, **kwargs) if not reqObj._finished: - reqObj.finish(self.data) + reqObj.finish(self.payload) class BaseRequestHandler(RequestHandler): From 69350fb24058fecfc7a5326d8ef7cf96ad750d4a Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 22 Aug 2021 20:48:29 +0200 Subject: [PATCH 164/178] optimize, catch errors --- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 95 +++++++++---------- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 4 + .../private/authorization/AuthServer.py | 92 ++++++++++-------- .../authorization/grants/DeviceFlow.py | 12 ++- .../private/authorization/utils/Utilities.py | 90 ++++++++++++------ 5 files changed, 171 insertions(+), 122 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index 599f0af96ba..fb74512cf8b 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -51,30 +51,9 @@ def initializeRequest(self): """ Called at every request """ self.currentPath = self.request.protocol + "://" + self.request.host + self.request.path - def _finishFuture(self, retVal): - """ Handler Future result - - :param object retVal: tornado.concurrent.Future - """ - # Wait result only if it's a Future object - self.result = retVal.result() if isinstance(retVal, Future) else retVal - - # Is it S_ERROR? - if isinstance(self.result, dict) and self.result.get('OK') is False and 'Message' in self.result: - # S_ERROR is interpreted in the OAuth2 error format. - self.set_status(400) - self.clear_cookie('auth_session') - if retVal['Message'].startswith(""): - self.finish(retVal['Message']) - else: - self.log.error('%s\n' % retVal['Message'], ''.join(retVal['CallStack'])) - self.finish({'error': 'server_error', 'description': retVal['Message']}) - else: - super(AuthHandler, self)._finishFuture(retVal) - path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] - def web_index(self, *instance): + def web_index(self, well_known, instance=None): """ Well known endpoint, specified by `RFC8414 `_ @@ -259,11 +238,12 @@ def web_device(self, provider=None, user_code=None): elif self.request.method == 'GET': if user_code: - # If received a request with a user code, then prepare a request to authorization endpoint + # If received a request with a user code, then prepare a request to authorization endpoint. self.log.verbose('User code verification.') result = self.server.db.getSessionByUserCode(user_code) - if not result['OK']: - return 'Device code flow authorization session %s expired.' % user_code + if not result['OK'] or not result['Value']: + return getHTML('session is expired.', theme='warning', body=result.get('Message'), + info='Seems device code flow authorization session %s expired.' % user_code) session = result['Value'] # Get original request from session req = createOAuth2Request(dict(method='GET', uri=session['uri'])) @@ -274,12 +254,13 @@ def web_device(self, provider=None, user_code=None): return self.server.handle_response(302, {}, [("Location", authURL)], session) # If received a request without a user code, then send a form to enter the user code - with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")) as tag: - dom.div(dom.form(dom._input(type="text", name="user_code"), - dom.button('Submit', type="submit", cls="btn btn-submit"), - action=self.currentPath, method="GET"), cls='card') + with dom.div(cls='row mt-5 justify-content-md-center') as tag: + with dom.div(cls="col-auto"): + dom.div(dom.form(dom._input(type="text", name="user_code"), + dom.button('Submit', type="submit", cls="btn btn-submit"), + action=self.currentPath, method="GET"), cls='card') return getHTML('user code verification..', body=tag, icon='ticket-alt', - info='Device flow required user code. You will need to type user code to continue.',).render() + info='Device flow required user code. You will need to type user code to continue.') path_authorization = ['([A-z%0-9-_]*)'] @@ -337,18 +318,25 @@ def web_redirect(self, state, error=None, error_description='', chooseScope=[]): :return: S_OK()/S_ERROR() """ - # Try to catch errors - if error: - return self.server.handle_error_response(state, OAuth2Error(error=error, description=error_description)) - # Check current auth session that was initiated for the selected external identity provider session = self.get_secure_cookie('auth_session') if not session: - return S_ERROR("%s session is expired." % state) + return self.server.handle_response( + payload=getHTML("session is expired.", theme="warning", state=400, + info=f"Seems {state} session is expired, please, try again."), delSession=True) sessionWithExtIdP = json.loads(session) if state and not sessionWithExtIdP.get('state') == state: - return S_ERROR("%s session is expired." % state) + return self.server.handle_response( + payload=getHTML("session is expired.", theme="warning", state=400, + info=f"Seems {state} session is expired, please, try again."), delSession=True) + + # Try to catch errors if the authorization on the selected identity provider was unsuccessful + if error: + provider = sessionWithExtIdP.get('Provider') + return self.server.handle_response( + payload=getHTML(error, theme="error", body=error_description, + info=f"Seems {state} session is failed on the {provider}'s' side."), delSession=True) if not sessionWithExtIdP.get('authed'): # Parse result of the second authentication flow @@ -356,7 +344,10 @@ def web_redirect(self, state, error=None, error_description='', chooseScope=[]): result = self.server.parseIdPAuthorizationResponse(self.request, sessionWithExtIdP) if not result['OK']: - return result + if result['Message'].startswith(""): + return self.server.handle_response(payload=result['Message'], delSession=True) + return self.server.handle_response( + payload=getHTML("server error", state=500, info=result['Message']), delSession=True) # Return main session flow sessionWithExtIdP['authed'] = result['Value'] @@ -366,7 +357,10 @@ def web_redirect(self, state, error=None, error_description='', chooseScope=[]): return response # RESPONSE to basic DIRAC client request - return self.server.create_authorization_response(response, grant_user) + resp = self.server.create_authorization_response(response, grant_user) + if not resp.payload.startswith(""): + resp.payload = getHTML('authorization response', state=resp.status_code, body=resp.payload) + return resp def web_token(self): """ The token endpoint, the description of the parameters will differ depending on the selected grant_type @@ -406,7 +400,8 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): :param dict extSession: ended authorized external IdP session - :return: response + :return: -- will return (None, response) to provide error or group selector + will return (grant_user, request) to contionue authorization with choosed group """ # Base DIRAC client auth session firstRequest = createOAuth2Request(extSession['firstRequest']) @@ -421,11 +416,15 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): # Researche Group result = getGroupsForUser(username) if not result['OK']: - return None, result + return None, self.server.handle_response( + getHTML("server error", theme="error", info=result['Message']), delSession=True) groups = result['Value'] + validGroups = [group for group in groups if (getIdPForGroup(group) == provider) or ('proxy' in firstRequest.scope)] if not validGroups: - return None, S_ERROR('No groups found for %s and for %s Identity Provider.' % (username, provider)) + return None, self.server.handle_response(getHTML( + "groups not found.", theme="error", + info=f'No groups found for {username} and for {provider} Identity Provider.'), delSession=True) self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(validGroups)) @@ -439,15 +438,15 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): return extSession['authed'], firstRequest # Else give user chanse to choose group in browser - # Return choose group HTML interface - with dom.div(cls='row mt-5 justify-content-md-center').add(dom.div(cls="col-auto")) as tag: - for group in validGroups: - with dom.div(cls="card shadow-sm border-0 text-center m-5 p-2"): - dom.h4(group, cls="p-2") + with dom.div(cls='row mt-5 justify-content-md-center align-items-center') as tag: + for group in sorted(validGroups): + vo, gr = group.split('_') + with dom.div(cls="col-auto p-2").add(dom.div(cls="card shadow-lg border-0 text-center p-2")): + dom.h4(vo.upper() + ' ' + gr, cls="p-2") dom.a(href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, state, group), cls="stretched-link") html = getHTML('group selection..', body=tag, icon='users', - info='Dirac use groups to describe permissions.' + info='Dirac use groups to describe permissions. ' 'You will need to select one of the groups to continue.') - return None, self.server.handle_response(payload=html.render(), newSession=extSession) + return None, self.server.handle_response(payload=html, newSession=extSession) diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 0b0d4063e5a..2ef49920042 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -308,6 +308,8 @@ def getSession(self, sessionID): session = self.session() try: resData = session.query(AuthSession).filter(AuthSession.id == sessionID).first() + if not resData: + return self.__result(session, S_ERROR("%s session is expired." % sessionID)) except MultipleResultsFound: return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) except NoResultFound: @@ -326,6 +328,8 @@ def getSessionByUserCode(self, userCode): session = self.session() try: resData = session.query(AuthSession).filter(AuthSession.user_code == userCode).first() + if not resData: + return self.__result(session, S_ERROR("Session for %s user code is expired." % userCode)) except MultipleResultsFound: return self.__result(session, S_ERROR("%s is not unique ID." % userCode)) except NoResultFound: diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 8afd7cb0c51..b8c619867f4 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -4,6 +4,7 @@ from __future__ import print_function import re +import six import sys import time import json @@ -230,7 +231,7 @@ def getIdPAuthorization(self, provider, request): """ result = self.idps.getIdProvider(provider) if not result['OK']: - return result + raise Exception(result['Message']) idpObj = result['Value'] authURL, state, session = idpObj.submitNewSession() session['state'] = state @@ -280,7 +281,7 @@ def parseIdPAuthorizationResponse(self, response, session): # Notify user about problem html = getHTML("unregistered user!", info=comment, theme='warning') - return S_ERROR(html.render()) + return S_ERROR(html) credDict['username'] = result['Value'] @@ -303,7 +304,7 @@ def validate_requested_scope(self, scope, state=None): extended_scope = list_to_scope([re.sub(r':.*$', ':', s) for s in scope_to_list(scope or '')]) super(AuthServer, self).validate_requested_scope(extended_scope, state) - def handle_response(self, status_code=None, payload=None, headers=None, newSession=None): + def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, delSession=None): """ Handle response :param int status_code: http status code @@ -314,25 +315,34 @@ def handle_response(self, status_code=None, payload=None, headers=None, newSessi :return: TornadoResponse() """ self.log.debug('Handle authorization response with %s status code:' % status_code, payload) - resp = TornadoResponse(payload) - if status_code: - resp.set_status(status_code) # pylint: disable=no-member + resp = TornadoResponse(payload, status_code) if headers: self.log.debug('Headers:', headers) for key, value in headers: resp.set_header(key, value) # pylint: disable=no-member if newSession: - self.log.debug('newSession:', newSession) + self.log.debug('Initialize new session:', newSession) # pylint: disable=no-member resp.set_secure_cookie('auth_session', json.dumps(newSession), secure=True, httponly=True) - if isinstance(payload, dict) and 'error' in payload: + if delSession or isinstance(payload, dict) and 'error' in payload: resp.clear_cookie('auth_session') # pylint: disable=no-member return resp def create_authorization_response(self, response, username): - response = super(AuthServer, self).create_authorization_response(response, username) - response.clear_cookie('auth_session') - return response + """ Rewrite original Authlib method + `authlib.authlib.oauth2.rfc6749.authorization_server.create_authorization_response` + to catch errors and remove authorization session. + + :return: TornadoResponse object + """ + try: + response = super(AuthServer, self).create_authorization_response(response, username) + response.clear_cookie('auth_session') + return response + except Exception as e: + self.log.exception(e) + return self.handle_response( + payload=getHTML('server error', theme='error', body='traceback', info=repr(e)), delSession=True) def validate_consent_request(self, request, provider=None): """ Validate current HTTP request for authorization page. This page @@ -344,13 +354,18 @@ def validate_consent_request(self, request, provider=None): :return: response generated by `handle_response` or S_ERROR or html """ if request.method != 'GET': - return 'Use GET method to access this endpoint.' + return self.handle_response( + payload=getHTML("use GET method", theme="error", info='Use GET method to access this endpoint.'), + delSession=True) try: + request = self.create_oauth2_request(request) # Check Identity Provider - req, result = self.validateIdentityProvider(self.create_oauth2_request(request), provider) - if not req: - return result + req = self.validateIdentityProvider(request, provider) + # If return IdP selector + if isinstance(req, six.string_types): + return req + self.log.info('Validate consent request for', req.state) grant = self.get_authorization_grant(req) self.log.debug('Use grant:', grant) @@ -360,8 +375,17 @@ def validate_consent_request(self, request, provider=None): # Submit second auth flow through IdP return self.getIdPAuthorization(req.provider, req) + except OAuth2Error as error: - return self.handle_error_response(None, error) + self.db.removeSession(request.sessionID) + code, body, _ = error(None) + return self.handle_response( + payload=getHTML(repr(e), state=code, body=body, info='OAuth2 error.'), delSession=True) + except Exception as e: + self.db.removeSession(request.sessionID) + self.log.exception(e) + return self.handle_response( + payload=getHTML('server error', theme='error', body='traceback', info=repr(e)), delSession=True) def validateIdentityProvider(self, request, provider): """ Check if identity provider registred in DIRAC @@ -369,7 +393,7 @@ def validateIdentityProvider(self, request, provider): :param object request: request :param str provider: provider name - :return: str, S_OK()/S_ERROR() -- provider name and html page to choose it + :return: OAuth2Request object or HTML -- new request with provider name or provider selector """ if provider: request.provider = provider @@ -380,45 +404,35 @@ def validateIdentityProvider(self, request, provider): # If requested access token for group that is not registred in any identity provider # or the requested provider does not match the group return error if request.group and not groupProvider and 'proxy' not in request.scope: - self.db.removeSession(request.sessionID) - return None, S_ERROR('The %s group belongs to the VO that is not tied to any Identity Provider.' % request.group) - # if provider and provider != groupProvider: - self.db.removeSession(request.sessionID) - return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (request.group, groupProvider, - request.provider)) - # provider = groupProvider + raise Exception('The %s group belongs to the VO that is not tied to any Identity Provider.' % request.group) self.log.debug("Check if %s identity provider registred in DIRAC.." % request.provider) # Research supported IdPs result = getProvidersForInstance('Id') if not result['OK']: - self.db.removeSession(request.sessionID) - return None, result + raise Exception(result['Message']) idPs = result['Value'] if not idPs: - self.db.removeSession(request.sessionID) - return None, S_ERROR('No identity providers found.') + raise Exception('No identity providers found.') if request.provider: if request.provider not in idPs: - self.db.removeSession(request.sessionID) - return None, S_ERROR('%s identity provider is not registered.' % request.provider) + raise Exception('%s identity provider is not registered.' % request.provider) elif groupProvider and request.provider != groupProvider: - self.db.removeSession(request.sessionID) - return None, S_ERROR('The %s group Identity Provider is "%s" and not "%s".' % (request.group, groupProvider, - request.provider)) - return request, None + raise Exception('The %s group Identity Provider is "%s" and not "%s".' % (request.group, groupProvider, + request.provider)) + return request # If no identity provider is specified, it must be assigned if groupProvider: request.provider = groupProvider - return request, None + return request # If only one identity provider is registered, then choose it if len(idPs) == 1: request.provider = idPs[0] - return request, None + return request # Choose IdP HTML interface with dom.div(cls="row m-5 justify-content-md-center") as tag: @@ -435,9 +449,9 @@ def validateIdentityProvider(self, request, provider): dom.div(dom.img(src=logo, cls="card-img"), cls="col-auto") # Render into HTML - return None, getHTML("Identity Provider selection..", body=tag, icon='fingerprint', - info="Dirac itself is not an Identity Provider." - "You will need to select one to continue.").render() + return getHTML("Identity Provider selection..", body=tag, icon='fingerprint', + info="Dirac itself is not an Identity Provider. " + "You will need to select one to continue.") def __registerNewUser(self, provider, payload): """ Register new user diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py index 43a252506d4..34a4b144e7b 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -87,19 +87,21 @@ def create_authorization_response(self, redirect_uri, user): """ result = self.server.db.getSessionByUserCode(self.request.data['user_code']) if not result['OK']: - raise OAuth2Error(result['Message']) + return 500, getHTML("server error", theme='error', body=result['Message'], + info='Failed to read %s authorization session.' % self.request.data['user_code']) data = result['Value'] data.update(dict(user_id=user['ID'], uri=self.request.uri, username=user['username'], scope=self.request.scope)) # Save session with user result = self.server.db.updateSession(data, data['id']) if not result['OK']: - raise OAuth2Error('Cannot save authorization result', result['Message']) + return 500, getHTML("server error", theme='error', body=result['Message'], + info='Failed to save %s authorization session status.' % self.request.data['user_code']) # Notify user that authorization complite. - html = getHTML("authorization complite!", theme='success', - info='Authorization has been completed, you can now return to the terminal and close this window.') - return 200, html.render() + return 200, getHTML( + "authorization complite!", theme='success', + info='Authorization has been completed, now you can close this window and return to the terminal.') def query_device_credential(self, device_code): """ Get device credential from previously savings via ``DeviceAuthorizationEndpoint``. diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 5d8a0db73ba..62e1376287a 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -4,6 +4,9 @@ __RCSID__ = "$Id$" +import six +import traceback + from dominate import document, tags as dom from authlib.oauth2.rfc8414 import AuthorizationServerMetadata @@ -38,41 +41,66 @@ def collectMetadata(issuer=None, ignoreErrors=False): return AuthorizationServerMetadata(metadata) -def getHTML(title, info=None, body=None, style=None, state=None, theme=None, icon=None): +def getHTML(title=None, info=None, body=None, style=None, state=None, theme=None, icon=None): """ Provide HTML object :param str title: short name of the notification, e.g.: server error :param str info: some short description if needed, e.g.: Seems it because server down. :param body: dominate tag object, main content, e.g.: dom.pre(dom.code(result['Message'])) :param str style: additional css style if needed, e.g.: '.card{color:red;}' - :param str state: response state, if needed, e.g.: 404 + :param int state: response state, if needed, e.g.: 404 :param str theme: message color theme, the same that in bootstrap 5. :param str icon: awesome icon name :return: HTML document object """ + if title and not isinstance(title, six.string_types): + # Expected HTTPStatus + state = title.value + info = title.description + title = title.phrase + html = document("DIRAC - %s" % title) - icon = icon or 'flask' - theme = theme or 'secondary' - if theme == 'warning': + if state in [400, 401, 403, 404]: + theme = theme or 'warning' + elif state in [500]: + theme = theme or 'danger' + elif state in [200]: + theme = theme or 'success' + + if theme in ['warning', 'warn']: + theme = 'warning' icon = icon or 'exclamation-triangle' elif theme == 'info': icon = icon or 'info' elif theme == 'success': icon = icon or 'check' - elif theme == 'error': + elif theme in ['error', 'danger']: theme = 'danger' icon = icon or 'times' + else: + theme = theme or 'secondary' + icon = icon or 'flask' + + if body and isinstance(body, six.string_types): + body = dom.pre(dom.code(traceback.format_exc() if body == 'traceback' else body), cls="mt-5") diracLogo = collectMetadata(ignoreErrors=True).get('logoURL', '') # Create head with html.head: - # Provide icons - dom.link(rel="icon", href="/static/core/img/icons/system/favicon.ico", type="image/x-icon") + # Meta tags + dom.meta(charset="utf-8") + dom.meta(name="viewport", content="width=device-width, initial-scale=1") + # Favicon + dom.link(rel="shortcut icon", href="/static/core/img/icons/system/favicon.ico", type="image/x-icon") + # Provide awesome icons + # https://fontawesome.com/v4.7/license/ dom.script(src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/js/all.min.js") # Enable bootstrap 5 + # https://getbootstrap.com/docs/5.0/getting-started/introduction/ + # https://getbootstrap.com/docs/5.0/about/license/ dom.link(rel='stylesheet', integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC", href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css", crossorigin="anonymous") dom.script(src='https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/js/bootstrap.bundle.min.js', @@ -85,34 +113,36 @@ def getHTML(title, info=None, body=None, style=None, state=None, theme=None, ico # Create body with html: # Background image - dom.i(cls='position-absolute bottom-0 start-0 translate-middle-x fa fa-{icon} m-5 text-{theme}', + dom.i(cls='position-absolute bottom-0 start-0 translate-middle-x m-5 fa ' + 'fa-%s text-%s' % (icon, theme), style="font-size:40vw;z-index:-1;") # A4 page with align center with dom.div(cls="row vh-100 vw-100 justify-content-md-center align-items-center m-0"): - with dom.div(cls='container', style="max-width: 600px;").add(dom.div(cls="row align-items-center")): + with dom.div(cls='container', style="max-width:600px;") as page: + # Main panel + with dom.div(cls="row align-items-center"): + # Logo + dom.div(dom.img(src=diracLogo, cls="card-img px-2"), cls="col-md-6 my-3") + # Information card + with dom.div(cls="col-md-6 my-3"): + + # Show response state number + if state and state != 200: + dom.div(dom.h1(state, cls='text-center badge bg-%s text-wrap' % theme), cls="row py-2") + + # Message title + with dom.div(cls="row"): + dom.div(dom.i(cls="fa fa-%s text-%s" % (icon, theme)), cls="col-auto") + dom.div(title, cls='col-auto ps-0 pb-2 fw-bold') - # Logo - dom.div(dom.img(src=diracLogo, cls="card-img px-2"), cls="col-md-6 my-3") - # Information card - with dom.div(cls="col-md-6 my-3"): - - # Show response state number - if state and state != 200: - dom.div(dom.h1(state, cls='text-center badge bg-{theme} text-wrap'), cls="row py-2") - - # Message title - with dom.div(cls="row"): - dom.div(dom.i(cls="fa fa-{icon} text-{theme}"), cls="col-auto") - dom.div(title, cls='col-auto ps-0 pb-2 fw-bold') - - # Description - if description: - dom.small(dom.i(cls="fa fa-info text-info")) - dom.small(info, cls="ps-1") + # Description + if info: + dom.small(dom.i(cls="fa fa-info text-info")) + dom.small(info, cls="ps-1") # Add content if body: - body + page.add(body) - return html + return html.render() From 5aead5f18dcff6339da618f267ff688c4134994b Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 22 Aug 2021 21:39:22 +0200 Subject: [PATCH 165/178] fix tests --- .../Core/Tornado/Server/TornadoServer.py | 8 ++--- src/DIRAC/FrameworkSystem/API/AuthHandler.py | 8 ++--- src/DIRAC/FrameworkSystem/DB/AuthDB.py | 4 --- .../private/authorization/AuthServer.py | 2 +- .../private/authorization/utils/Utilities.py | 33 ++++++++++--------- 5 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 40e37a09280..0d7675ad921 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -24,7 +24,6 @@ from tornado.web import Application, url, RequestHandler from tornado.ioloop import IOLoop import tornado.ioloop -from six.moves import http_client import DIRAC from DIRAC import gConfig, gLogger, S_OK @@ -33,7 +32,6 @@ from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient -from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML sLog = gLogger.getSubLogger(__name__) @@ -41,8 +39,10 @@ class NotFoundHandler(RequestHandler): """ Handle 404 errors """ def prepare(self): - self.set_status(http_client.NOT_FOUND.value) - self.finish(getHTML(http_client.NOT_FOUND)) + self.set_status(404) + if six.PY3: + from DIRAC.FrameworkSystem.private.authorization.utils.Utilities import getHTML + self.finish(getHTML('Not found.', state=404, info='Nothing matches the given URI.')) class TornadoServer(object): diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py index fb74512cf8b..e297e9073a3 100644 --- a/src/DIRAC/FrameworkSystem/API/AuthHandler.py +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -323,20 +323,20 @@ def web_redirect(self, state, error=None, error_description='', chooseScope=[]): if not session: return self.server.handle_response( payload=getHTML("session is expired.", theme="warning", state=400, - info=f"Seems {state} session is expired, please, try again."), delSession=True) + info="Seems %s session is expired, please, try again." % state), delSession=True) sessionWithExtIdP = json.loads(session) if state and not sessionWithExtIdP.get('state') == state: return self.server.handle_response( payload=getHTML("session is expired.", theme="warning", state=400, - info=f"Seems {state} session is expired, please, try again."), delSession=True) + info="Seems %s session is expired, please, try again." % state), delSession=True) # Try to catch errors if the authorization on the selected identity provider was unsuccessful if error: provider = sessionWithExtIdP.get('Provider') return self.server.handle_response( payload=getHTML(error, theme="error", body=error_description, - info=f"Seems {state} session is failed on the {provider}'s' side."), delSession=True) + info="Seems %s session is failed on the %s's' side." % (state, provider)), delSession=True) if not sessionWithExtIdP.get('authed'): # Parse result of the second authentication flow @@ -424,7 +424,7 @@ def __researchDIRACGroup(self, extSession, chooseScope, state): if not validGroups: return None, self.server.handle_response(getHTML( "groups not found.", theme="error", - info=f'No groups found for {username} and for {provider} Identity Provider.'), delSession=True) + info='No groups found for %s and for %s Identity Provider.' % (username, provider)), delSession=True) self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(validGroups)) diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py index 2ef49920042..0b0d4063e5a 100644 --- a/src/DIRAC/FrameworkSystem/DB/AuthDB.py +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -308,8 +308,6 @@ def getSession(self, sessionID): session = self.session() try: resData = session.query(AuthSession).filter(AuthSession.id == sessionID).first() - if not resData: - return self.__result(session, S_ERROR("%s session is expired." % sessionID)) except MultipleResultsFound: return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) except NoResultFound: @@ -328,8 +326,6 @@ def getSessionByUserCode(self, userCode): session = self.session() try: resData = session.query(AuthSession).filter(AuthSession.user_code == userCode).first() - if not resData: - return self.__result(session, S_ERROR("Session for %s user code is expired." % userCode)) except MultipleResultsFound: return self.__result(session, S_ERROR("%s is not unique ID." % userCode)) except NoResultFound: diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index b8c619867f4..7d2aab0df45 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -365,7 +365,7 @@ def validate_consent_request(self, request, provider=None): # If return IdP selector if isinstance(req, six.string_types): return req - + self.log.info('Validate consent request for', req.state) grant = self.get_authorization_grant(req) self.log.debug('Use grant:', grant) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 62e1376287a..4f0fdf5465f 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -41,27 +41,26 @@ def collectMetadata(issuer=None, ignoreErrors=False): return AuthorizationServerMetadata(metadata) -def getHTML(title=None, info=None, body=None, style=None, state=None, theme=None, icon=None): +def getHTML(title, info=None, body=None, style=None, state=None, theme=None, icon=None): """ Provide HTML object :param str title: short name of the notification, e.g.: server error - :param str info: some short description if needed, e.g.: Seems it because server down. - :param body: dominate tag object, main content, e.g.: dom.pre(dom.code(result['Message'])) + :param str info: some short description if needed, e.g.: It looks like the server is not responding + :param body: it can be string or dominate tag object, main content, e.g.:: + + from dominate import tags as dom + return getHTML('server error', dom.pre(dom.code(result['Message'])) + :param str style: additional css style if needed, e.g.: '.card{color:red;}' - :param int state: response state, if needed, e.g.: 404 - :param str theme: message color theme, the same that in bootstrap 5. - :param str icon: awesome icon name + :param int state: response state code, if needed, e.g.: 404 + :param str theme: message color theme, the same that in bootstrap 5, e.g.: 'warning' + :param str icon: awesome icon name, e.g.: 'users' - :return: HTML document object + :return: str -- HTML document """ - if title and not isinstance(title, six.string_types): - # Expected HTTPStatus - state = title.value - info = title.description - title = title.phrase - html = document("DIRAC - %s" % title) + # select the color to the state code if state in [400, 401, 403, 404]: theme = theme or 'warning' elif state in [500]: @@ -69,6 +68,7 @@ def getHTML(title=None, info=None, body=None, style=None, state=None, theme=None elif state in [200]: theme = theme or 'success' + # select the icon to the theme if theme in ['warning', 'warn']: theme = 'warning' icon = icon or 'exclamation-triangle' @@ -83,6 +83,7 @@ def getHTML(title=None, info=None, body=None, style=None, state=None, theme=None theme = theme or 'secondary' icon = icon or 'flask' + # If body is text wrap it with tags if body and isinstance(body, six.string_types): body = dom.pre(dom.code(traceback.format_exc() if body == 'traceback' else body), cls="mt-5") @@ -126,7 +127,7 @@ def getHTML(title=None, info=None, body=None, style=None, state=None, theme=None dom.div(dom.img(src=diracLogo, cls="card-img px-2"), cls="col-md-6 my-3") # Information card with dom.div(cls="col-md-6 my-3"): - + # Show response state number if state and state != 200: dom.div(dom.h1(state, cls='text-center badge bg-%s text-wrap' % theme), cls="row py-2") @@ -135,12 +136,12 @@ def getHTML(title=None, info=None, body=None, style=None, state=None, theme=None with dom.div(cls="row"): dom.div(dom.i(cls="fa fa-%s text-%s" % (icon, theme)), cls="col-auto") dom.div(title, cls='col-auto ps-0 pb-2 fw-bold') - + # Description if info: dom.small(dom.i(cls="fa fa-info text-info")) dom.small(info, cls="ps-1") - + # Add content if body: page.add(body) From 68ea6496ef4fd93d7163cbf88712970ad26779e8 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sun, 22 Aug 2021 23:43:32 +0200 Subject: [PATCH 166/178] fix docs --- .../private/authorization/utils/Utilities.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 4f0fdf5465f..87715bc8f58 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -46,11 +46,9 @@ def getHTML(title, info=None, body=None, style=None, state=None, theme=None, ico :param str title: short name of the notification, e.g.: server error :param str info: some short description if needed, e.g.: It looks like the server is not responding - :param body: it can be string or dominate tag object, main content, e.g.:: - - from dominate import tags as dom - return getHTML('server error', dom.pre(dom.code(result['Message'])) - + :param body: it can be string or dominate tag object, e.g.: + from dominate import tags as dom + return getHTML('server error', body=dom.pre(dom.code(result['Message'])) :param str style: additional css style if needed, e.g.: '.card{color:red;}' :param int state: response state code, if needed, e.g.: 404 :param str theme: message color theme, the same that in bootstrap 5, e.g.: 'warning' From f4d3ee77d4de9af136e394fe5394fff7c22170f8 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Mon, 23 Aug 2021 20:13:46 +0200 Subject: [PATCH 167/178] remove unuse --- src/DIRAC/FrameworkSystem/scripts/dirac_login.py | 3 +-- src/DIRAC/FrameworkSystem/scripts/dirac_logout.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 905245562ab..89caf8fa57a 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -32,7 +32,6 @@ class Params(object): def __init__(self): self.info = False - self.provider = 'DIRACCLI' self.proxy = False self.group = None self.lifetime = None @@ -139,7 +138,7 @@ def doOAuthMagic(self): params = {} if self.issuer: params['issuer'] = self.issuer - result = IdProviderFactory().getIdProvider(self.provider, **params) + result = IdProviderFactory().getIdProvider('DIRACCLI', **params) if not result['OK']: return result idpObj = result['Value'] diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py index 1f6f8eb1d81..620954535a4 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_logout.py @@ -30,7 +30,6 @@ class Params(object): def __init__(self): - self.provider = 'DIRACCLI' self.issuer = None self.tokenFileLoc = None @@ -76,7 +75,7 @@ def doOAuthMagic(self): params = {} if self.issuer: params['issuer'] = self.issuer - result = IdProviderFactory().getIdProvider(self.provider, **params) + result = IdProviderFactory().getIdProvider('DIRACCLI', **params) if not result['OK']: return result idpObj = result['Value'] From 60ce2064d5ee2f85d2d4564e43709cecabfd3722 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 25 Aug 2021 20:58:48 +0200 Subject: [PATCH 168/178] fix py3 --- src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py index ab765cf5d48..a20d4b0ac60 100644 --- a/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py +++ b/src/DIRAC/FrameworkSystem/Client/BundleDeliveryClient.py @@ -7,6 +7,7 @@ import os import getpass import tarfile +import six from six import BytesIO from base64 import b64decode @@ -61,9 +62,9 @@ def __getHash(self, bundleID, dirToSyncTo): :return: str """ try: - with open(os.path.join(dirToSyncTo, ".dab.%s" % bundleID), "r") as fd: + with open(os.path.join(dirToSyncTo, ".dab.%s" % bundleID), "rb") as fd: bdHash = fd.read().strip() - return bdHash + return bdHash.decode() if six.PY3 else bdHash except Exception: return "" @@ -77,7 +78,7 @@ def __setHash(self, bundleID, dirToSyncTo, bdHash): try: fileName = os.path.join(dirToSyncTo, ".dab.%s" % bundleID) with open(fileName, "wb") as fd: - fd.write(bdHash) + fd.write(bdHash.encode() if six.PY3 and not isinstance(bdHash, bytes) else bdHash) except Exception as e: self.log.error("Could not save hash after synchronization", "%s: %s" % (fileName, str(e))) From f3dffab20eda3b243c64fcab202c55e47c6b5e6d Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Wed, 25 Aug 2021 22:56:39 +0200 Subject: [PATCH 169/178] fix --- .../FrameworkSystem/private/authorization/utils/Utilities.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 87715bc8f58..1f9f8d397c0 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -85,7 +85,10 @@ def getHTML(title, info=None, body=None, style=None, state=None, theme=None, ico if body and isinstance(body, six.string_types): body = dom.pre(dom.code(traceback.format_exc() if body == 'traceback' else body), cls="mt-5") - diracLogo = collectMetadata(ignoreErrors=True).get('logoURL', '') + try: + diracLogo = collectMetadata(ignoreErrors=True).get('logoURL', '') + except Exception: + diracLogo = '' # Create head with html.head: From 4b6ace7ee50478647f9c02669ffd3e0237b033c7 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 26 Aug 2021 00:48:12 +0200 Subject: [PATCH 170/178] fix initialize --- .../WorkloadManagementSystem/Service/JobMonitoringHandler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py index 12ad6ade42b..073bb079774 100755 --- a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py @@ -666,4 +666,6 @@ def export_getInputData(cls, jobID): class JobMonitoringHandler(JobMonitoringHandlerMixin, RequestHandler): - pass + + def initialize(self): + self.initializeRequest() From dfbb53232dba40aff4bf0f04802877c51d078ab1 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 26 Aug 2021 13:54:05 +0200 Subject: [PATCH 171/178] fix int to str --- src/DIRAC/Interfaces/API/Dirac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Interfaces/API/Dirac.py b/src/DIRAC/Interfaces/API/Dirac.py index e1cb29400c1..8c9b344c874 100755 --- a/src/DIRAC/Interfaces/API/Dirac.py +++ b/src/DIRAC/Interfaces/API/Dirac.py @@ -1772,7 +1772,7 @@ def getJobStatus(self, jobID): if self.jobRepo: self.jobRepo.updateJobs(repoDict) for job, vals in siteDict.items(): # can be an iterator - result[job].update(vals) + result[str(job)].update(vals) return S_OK(result) From 5fe2e1bdec5a9acb1e1a986056e1d8ef27500b18 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sat, 28 Aug 2021 21:39:11 +0200 Subject: [PATCH 172/178] fix JobManagerHandler --- src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py index 7d2aab0df45..3ae5784695d 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -10,8 +10,7 @@ import json import pprint import logging -from dominate import document, tags as dom -from tornado.template import Template +from dominate import tags as dom import authlib from authlib.jose import JsonWebKey, jwt From 3801f8901aaac542f06859881a36e7c56b7c262e Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 2 Sep 2021 13:48:50 +0200 Subject: [PATCH 173/178] fix issues --- src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py | 7 ++++--- src/DIRAC/ConfigurationSystem/Client/PathFinder.py | 2 +- src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index 09860fdc65a..1986f11ccdb 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -13,6 +13,8 @@ __RCSID__ = "$Id$" +ID_DN_PREFIX = "/O=DIRAC/CN='' + # pylint: disable=missing-docstring gBaseRegistrySection = "/Registry" @@ -726,7 +728,6 @@ def getIDFromDN(userDN): :return: S_OK(str)/S_ERROR() """ - prefix = '/O=DIRAC/CN=' - if not userDN.startswith(prefix): + if not userDN.startswith(ID_DN_PREFIX): return S_ERROR("%s DN does not contain user ID." % userDN) - return S_OK(userDN[len(prefix):]) + return S_OK(userDN[len(ID_DN_PREFIX):]) diff --git a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py index ecbb4cd303a..097f8b28eb4 100755 --- a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py +++ b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py @@ -96,7 +96,7 @@ def getComponentSection(system, component=False, setup=False, componentCategory= def getAPISection(system, endpointName=False, setup=False): - """ Get service section in a system + """ Get API section in a system :param str system: system name :param str endpointName: endpoint name diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index 0e7696d825b..a501726214c 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -210,7 +210,7 @@ def __discoverCredentialsToUse(self): -> if KW_SKIP_CA_CHECK is not in kwargs and we are using the certificates, set KW_SKIP_CA_CHECK to false in kwargs -> if KW_SKIP_CA_CHECK is not in kwargs and we are not using the certificate, check the skipCACheck - * Baerer token: + * Bearer token: -> If KW_USE_ACCESS_TOKEN in kwargs, sets it in self.__useAccessToken -> If not, check "/DIRAC/Security/UseTokens", and sets it in self.__useAccessToken and kwargs[KW_USE_ACCESS_TOKEN] From d2a1941a99aa9632bbff0b83dce044fb30f1b9d9 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 2 Sep 2021 14:22:25 +0200 Subject: [PATCH 174/178] afterrebase fix --- .../Client/Helpers/Resources.py | 119 +++++++++++++++--- .../private/authorization/utils/Requests.py | 8 +- .../private/authorization/utils/Tokens.py | 2 +- .../private/authorization/utils/Utilities.py | 2 +- 4 files changed, 106 insertions(+), 25 deletions(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py index ac8cdf47bf8..631c9b6f49e 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py @@ -474,36 +474,117 @@ def getInfoAboutProviders(of=None, providerName=None, option='', section=''): section, option))) -def getProvidersForInstance(instance, providerType=None): - """ Get providers for instance +def findGenericCloudCredentials(vo=False, group=False): + """ Get the cloud credentials to use for a specific VO and/or group. """ + if not group and not vo: + return S_ERROR("Need a group or a VO to determine the Generic cloud credentials") + if not vo: + vo = Registry.getVOForGroup(group) + if not vo: + return S_ERROR("Group %s does not have a VO associated" % group) + opsHelper = Operations.Operations(vo=vo) + cloudGroup = opsHelper.getValue("Cloud/GenericCloudGroup", "") + cloudDN = opsHelper.getValue("Cloud/GenericCloudDN", "") + if not cloudDN: + cloudUser = opsHelper.getValue("Cloud/GenericCloudUser", "") + if cloudUser: + result = Registry.getDNForUsername(cloudUser) + if result['OK']: + cloudDN = result['Value'][0] + else: + return S_ERROR("Failed to find suitable CloudDN") + if cloudDN and cloudGroup: + gLogger.verbose("Cloud credentials from CS: %s@%s" % (cloudDN, cloudGroup)) + result = gProxyManager.userHasProxy(cloudDN, cloudGroup, 86400) + if not result['OK']: + return result + return S_OK((cloudDN, cloudGroup)) + return S_ERROR("Cloud credentials not found") - :param str instance: instance of what this providers - :param str providerType: provider type - :return: S_OK(list)/S_ERROR() +def getVMTypes(siteList=None, ceList=None, vmTypeList=None, vo=None): + """ Get CE/vmType options filtered by the provided parameters. """ - providers = [] - instance = "%sProviders" % instance - result = gConfig.getSections('%s/%s' % (gBaseResourcesSection, instance)) - # Return an empty list if the section does not exist - if not result['OK'] or not result['Value'] or not providerType: + result = gConfig.getSections('/Resources/Sites') + if not result['OK']: return result - for prov in result['Value']: - if providerType == gConfig.getValue('%s/%s/%s/ProviderType' % (gBaseResourcesSection, instance, prov)): - providers.append(prov) - return S_OK(providers) + resultDict = {} + grids = result['Value'] + for grid in grids: + result = gConfig.getSections('/Resources/Sites/%s' % grid) + if not result['OK']: + continue + sites = result['Value'] + for site in sites: + if siteList is not None and site not in siteList: + continue + if vo: + voList = gConfig.getValue('/Resources/Sites/%s/%s/VO' % (grid, site), []) + if voList and vo not in voList: + continue + result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud' % (grid, site)) + if not result['OK']: + continue + ces = result['Value'] + for ce in ces: + if ceList is not None and ce not in ceList: + continue + if vo: + voList = gConfig.getValue('/Resources/Sites/%s/%s/Cloud/%s/VO' % (grid, site, ce), []) + if voList and vo not in voList: + continue + result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s' % (grid, site, ce)) + if not result['OK']: + continue + ceOptionsDict = result['Value'] + result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud/%s/VMTypes' % (grid, site, ce)) + if not result['OK']: + result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud/%s/Images' % (grid, site, ce)) + if not result['OK']: + return result + vmTypes = result['Value'] + for vmType in vmTypes: + if vmTypeList is not None and vmType not in vmTypeList: + continue + if vo: + voList = gConfig.getValue('/Resources/Sites/%s/%s/Cloud/%s/VMTypes/%s/VO' % (grid, site, ce, vmType), []) + if not voList: + voList = gConfig.getValue('/Resources/Sites/%s/%s/Cloud/%s/Images/%s/VO' % (grid, site, ce, vmType), []) + if voList and vo not in voList: + continue + resultDict.setdefault(site, {}) + resultDict[site].setdefault(ce, ceOptionsDict) + resultDict[site][ce].setdefault('VMTypes', {}) + result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s/VMTypes/%s' % (grid, site, ce, vmType)) + if not result['OK']: + result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s/Images/%s' % (grid, site, ce, vmType)) + if not result['OK']: + continue + vmTypeOptionsDict = result['Value'] + resultDict[site][ce]['VMTypes'][vmType] = vmTypeOptionsDict -def getProviderInfo(provider): - """ Get provider info + return S_OK(resultDict) - :param str provider: provider - :return: S_OK(dict)/S_ERROR() +def getVMTypeConfig(site, ce='', vmtype=''): + """ Get the VM image type parameters of the specified queue """ - result = gConfig.getSections(gBaseResourcesSection) + tags = [] + grid = site.split('.')[0] + if not ce: + result = gConfig.getSections('/Resources/Sites/%s/%s/Cloud' % (grid, site)) + if not result['OK']: + return result + ceList = result['Value'] + if len(ceList) == 1: + ce = ceList[0] + else: + return S_ERROR('No cloud endpoint specified') + + result = gConfig.getOptionsDict('/Resources/Sites/%s/%s/Cloud/%s' % (grid, site, ce)) if not result['OK']: return result resultDict = result['Value'] diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py index ee5e1c66547..ab4c7953008 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -47,7 +47,7 @@ def path(self): @property def groups(self): - """ Serarch DIRAC groups in scopes + """ Search DIRAC groups in scopes :return: list """ @@ -55,7 +55,7 @@ def groups(self): @property def group(self): - """ Serarch DIRAC group in scopes + """ Search DIRAC group in scopes :return: str """ @@ -64,7 +64,7 @@ def group(self): @property def provider(self): - """ Serarch IdP in scopes + """ Search IdP in scopes :return: str """ @@ -76,7 +76,7 @@ def provider(self, provider): @property def sessionID(self): - """ Serarch IdP in scopes + """ Search IdP in scopes :return: str """ diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py index 6e85fad590b..68c0c5feed6 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -114,7 +114,7 @@ def writeTokenDictToTokenFile(tokenDict, fileName=None): class OAuth2Token(_OAuth2Token): - """ Implementation a Token object """ + """ Implementation of a Token object """ def __init__(self, params=None, **kwargs): """ Constructor diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py index 1f9f8d397c0..674e0472987 100644 --- a/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Utilities.py @@ -14,7 +14,7 @@ def collectMetadata(issuer=None, ignoreErrors=False): - """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defines by IETF specification: + """ Collect metadata for DIRAC Authorization Server(DAS), a metadata format defined by the IETF specification: https://datatracker.ietf.org/doc/html/rfc8414#section-2 :param str issuer: issuer to set From 38e6132adfd9a80f9f7f7c5e2f75f30e43d7e006 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Thu, 2 Sep 2021 14:32:02 +0200 Subject: [PATCH 175/178] fix --- src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index 1986f11ccdb..750b7287628 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -13,7 +13,7 @@ __RCSID__ = "$Id$" -ID_DN_PREFIX = "/O=DIRAC/CN='' +ID_DN_PREFIX = "/O=DIRAC/CN=" # pylint: disable=missing-docstring From 42e4f6d396ed9efff031c3b6a62b67130f8069d6 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Fri, 3 Sep 2021 14:09:25 +0200 Subject: [PATCH 176/178] getDN --> getUserDN --- src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py index fe2c1799ea7..9a0d2aa6710 100644 --- a/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py +++ b/src/DIRAC/Core/Tornado/Server/private/BaseRequestHandler.py @@ -708,7 +708,7 @@ def _authzVISITOR(self): def log(self): return sLog - def getDN(self): + def getUserDN(self): return self.credDict.get('DN', '') def getUserName(self): From 9876263a5c05ae06734fa1714b0380a7b4917497 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sat, 4 Sep 2021 20:24:31 +0200 Subject: [PATCH 177/178] after rebase --- .../Core/Tornado/Server/TornadoService.py | 187 +----------------- 1 file changed, 1 insertion(+), 186 deletions(-) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index f343c2fd44c..18ad47262b3 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -20,10 +20,7 @@ import DIRAC -from DIRAC import gConfig, gLogger, S_OK -from DIRAC.ConfigurationSystem.Client import PathFinder -from DIRAC.Core.DISET.AuthManager import AuthManager -from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error +from DIRAC import gLogger, S_OK from DIRAC.Core.Utilities.JEncode import decode, encode from DIRAC.Core.Tornado.Server.private.BaseRequestHandler import BaseRequestHandler from DIRAC.ConfigurationSystem.Client import PathFinder @@ -149,188 +146,6 @@ def _getMethodArgs(self, args): :return: tuple """ - - # "method" argument of the POST call. - # This resolves into the ``export_`` method - # on the handler side - # If the argument is not available, the method exists - # and an error 400 ``Bad Request`` is returned to the client - self.method = self.get_argument("method") - - self._stats['requests'] += 1 - self._monitor.setComponentExtraParam('queries', self._stats['requests']) - self._monitor.addMark("Queries") - - try: - self.credDict = self._gatherPeerCredentials() - except Exception: # pylint: disable=broad-except - # If an error occur when reading certificates we close connection - # It can be strange but the RFC, for HTTP, say's that when error happend - # before authentication we return 401 UNAUTHORIZED instead of 403 FORBIDDEN - sLog.error( - "Error gathering credentials", "%s; path %s" % - (self.getRemoteAddress(), self.request.path)) - raise HTTPError(status_code=http_client.UNAUTHORIZED) - - # Resolves the hard coded authorization requirements - try: - hardcodedAuth = getattr(self, 'auth_' + self.method) - except AttributeError: - hardcodedAuth = None - - # Check whether we are authorized to perform the query - # Note that performing the authQuery modifies the credDict... - authorized = self._authManager.authQuery(self.method, self.credDict, hardcodedAuth) - if not authorized: - sLog.error( - "Unauthorized access", "Identity %s; path %s; DN %s" % - (self.srv_getFormattedRemoteCredentials, - self.request.path, - self.credDict['DN'], - )) - raise HTTPError(status_code=http_client.UNAUTHORIZED) - - # Make post a coroutine. - # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines - # for details - @gen.coroutine - def post(self): # pylint: disable=arguments-differ - """ - Method to handle incoming ``POST`` requests. - Note that all the arguments are already prepared in the :py:meth:`.prepare` - method. - - The ``POST`` arguments expected are: - - * ``method``: name of the method to call - * ``args``: JSON encoded arguments for the method - * ``extraCredentials``: (optional) Extra informations to authenticate client - * ``rawContent``: (optionnal, default False) If set to True, return the raw output - of the method called. - - If ``rawContent`` was requested by the client, the ``Content-Type`` - is ``application/octet-stream``, otherwise we set it to ``application/json`` - and JEncode retVal. - - If ``retVal`` is a dictionary that contains a ``Callstack`` item, - it is removed, not to leak internal information. - - - Example of call using ``requests``:: - - In [20]: url = 'https://server:8443/DataManagement/TornadoFileCatalog' - ...: cert = '/tmp/x509up_u1000' - ...: kwargs = {'method':'whoami'} - ...: caPath = '/home/dirac/ClientInstallDIR/etc/grid-security/certificates/' - ...: with requests.post(url, data=kwargs, cert=cert, verify=caPath) as r: - ...: print r.json() - ...: - {u'OK': True, - u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', - u'group': u'dirac_user', - u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', - u'isLimitedProxy': False, - u'isProxy': True, - u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', - u'properties': [u'NormalUser'], - u'secondsLeft': 85441, - u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/CN=2409820262', - u'username': u'adminusername', - u'validDN': False, - u'validGroup': False}} - """ - - sLog.notice( - "Incoming request", "%s /%s: %s" % - (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - self.method)) - - # Execute the method in an executor (basically a separate thread) - # Because of that, we cannot calls certain methods like `self.write` - # in __executeMethod. This is because these methods are not threadsafe - # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - # However, we can still rely on instance attributes to store what should - # be sent back (reminder: there is an instance - # of this class created for each request) - retVal = yield IOLoop.current().run_in_executor(None, self.__executeMethod) - - # retVal is :py:class:`tornado.concurrent.Future` - self.result = retVal.result() - - # Here it is safe to write back to the client, because we are not - # in a thread anymore - - # If set to true, do not JEncode the return of the RPC call - # This is basically only used for file download through - # the 'streamToClient' method. - rawContent = self.get_argument('rawContent', default=False) - - if rawContent: - # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt - self.set_header("Content-Type", "application/octet-stream") - result = self.result - else: - self.set_header("Content-Type", "application/json") - result = encode(self.result) - - self.write(result) - self.finish() - - # This nice idea of streaming to the client cannot work because we are ran in an executor - # and we should not write back to the client in a different thread. - # See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - # def export_streamToClient(self, filename): - # # https://bhch.github.io/posts/2017/12/serving-large-files-with-tornado-safely-without-blocking/ - # #import ipdb; ipdb.set_trace() - # # chunk size to read - # chunk_size = 1024 * 1024 * 1 # 1 MiB - - # with open(filename, 'rb') as f: - # while True: - # chunk = f.read(chunk_size) - # if not chunk: - # break - # try: - # self.write(chunk) # write the chunk to response - # self.flush() # send the chunk to client - # except StreamClosedError: - # # this means the client has closed the connection - # # so break the loop - # break - # finally: - # # deleting the chunk is very important because - # # if many clients are downloading files at the - # # same time, the chunks in memory will keep - # # increasing and will eat up the RAM - # del chunk - # # pause the coroutine so other handlers can run - # yield gen.sleep(0.000000001) # 1 nanosecond - - # return S_OK() - - @gen.coroutine - def __executeMethod(self): - """ - Execute the method called, this method is ran in an executor - We have several try except to catch the different problem which can occur - - - First, the method does not exist => Attribute error, return an error to client - - second, anything happend during execution => General Exception, send error to client - - .. warning:: - This method is called in an executor, and so cannot use methods like self.write - See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - """ - # getting method - try: - # For compatibility reasons with DISET, the methods are still called ``export_*`` - method = getattr(self, 'export_%s' % self.method) - except AttributeError: - sLog.error("Invalid method", self.method) - raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) - - # Decode args args_encoded = self.get_body_argument('args', default=encode([])) return (decode(args_encoded)[0], {}) From 1dc322d79ccacf63d787329d46e2539a14183506 Mon Sep 17 00:00:00 2001 From: TaykYoku Date: Sat, 4 Sep 2021 20:41:43 +0200 Subject: [PATCH 178/178] after rebase --- .../WorkloadManagementSystem/Service/JobMonitoringHandler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py index 073bb079774..12ad6ade42b 100755 --- a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py @@ -666,6 +666,4 @@ def export_getInputData(cls, jobID): class JobMonitoringHandler(JobMonitoringHandlerMixin, RequestHandler): - - def initialize(self): - self.initializeRequest() + pass