diff --git a/dirac.cfg b/dirac.cfg index 8acaceeab14..f71aba70abd 100644 --- a/dirac.cfg +++ b/dirac.cfg @@ -419,6 +419,19 @@ Systems } Resources { + IdProviders + { + CheckIn + { + # What supported type of provider does it belong to + ProviderType = OAuth2 + # Description of the client parameters registered on the identity provider side. + # Look here for information about client parameters description https://tools.ietf.org/html/rfc8414#section-2 + issuer = https://aai-dev.egi.eu/oidc + client_id = type_client_id_here_receved_after_client_registration + scope = openid, profile, offline_access, eduperson_entitlement, cert_entitlement + } + } # Section for proxy providers, subsections is the names of the proxy providers # https://dirac.readthedocs.org/en/latest/AdministratorGuide/Resources/proxyprovider.html diff --git a/docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac-my-great-script.py b/docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac_my_great_script.py similarity index 100% rename from docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac-my-great-script.py rename to docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac_my_great_script.py diff --git a/docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac-ping-info.py b/docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac_ping_info.py similarity index 100% rename from docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac-ping-info.py rename to docs/source/DeveloperGuide/AddingNewComponents/DevelopingCommands/dirac_ping_info.py diff --git a/environment.yml b/environment.yml index 131b8bf0a1a..d4add15bf88 100644 --- a/environment.yml +++ b/environment.yml @@ -69,6 +69,9 @@ dependencies: - typing >=3.6.6 # Pin OpenSSL to avoid: https://github.com/DIRACGrid/DIRAC/issues/4489 - openssl <1.1 + # OAuth + - authlib ==0.15.3 + - termcolor - pip: - diraccfg # This is a fork of tornado with a patch to allow for configurable iostream diff --git a/requirements.txt b/requirements.txt index c910c22a06a..dc8b0630797 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,5 +61,3 @@ caniusepython3 subprocess32 flaky ldap3 -# setuptools_scm comes via tornado. newer versions of setuptools_scm do not support py2 -setuptools_scm<6.0 diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py index e9377b1915a..ebd376b96cd 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Registry.py @@ -1,4 +1,14 @@ -""" Helper for /Registry section +""" Helper for **/Registry** section that contains information about DIRAC users, groups and communities (VOs). + + Currently, user registration is done by writing a user name with some metadata to the DIRAC configuration + in the Registry section. However, if an external resource, such as a VOMS server or an OAuth2 Identity Provider, + is used to obtain a user profile, information from these resources will also be considered only with + a lower priority than DIRAC configuration. + + Thare are present two important imports, that provide caching data:: + + * :mod:`ProxyManagerData ` caches information from VOMS + * :mod:`AuthManagerData ` caches information from IdPs """ from __future__ import absolute_import from __future__ import division @@ -6,14 +16,18 @@ import six import errno +from pprint import pprint from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Utilities.Decorators import deprecated from DIRAC.ConfigurationSystem.Client.Config import gConfig from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import getVO -__RCSID__ = "$Id$" +# Registry use cached data from AuthManager and ProxyManager services +from DIRAC.FrameworkSystem.Client.ProxyManagerData import gProxyManagerData +from DIRAC.FrameworkSystem.Client.AuthManagerData import gAuthManagerData -# pylint: disable=missing-docstring +__RCSID__ = "$Id$" gBaseRegistrySection = "/Registry" @@ -35,18 +49,32 @@ def getUsernameForDN(dn, usersList=None): for username in usersList: if dn in gConfig.getValue("%s/Users/%s/DN" % (gBaseRegistrySection, username), []): return S_OK(username) + + # Get users profiles from session manager cache + result = gAuthManagerData.getIDsForDN(dn) + if result['OK']: + for uid in result['Value']: + result = getUsernameForID(uid) + if result['OK']: + return result + return S_ERROR("No username found for dn %s" % dn) +@deprecated("Use getDNsForUsername or getDNsForUsernameFromCS instead") def getDNForUsername(username): - """ Get user DN for user + dnList = getDNsForUsernameFromCS(username) + return S_OK(dnList) if dnList else S_ERROR("No DN found for user %s" % username) - :param str username: user name - :return: S_OK(str)/S_ERROR() +def getDNsForUsernameFromCS(username): + """ Find DNs for DIRAC user only from CS + + :param str username: DIRAC user + + :return: list -- contain DNs """ - dnList = gConfig.getValue("%s/Users/%s/DN" % (gBaseRegistrySection, username), []) - return S_OK(dnList) if dnList else S_ERROR("No DN found for user %s" % username) + return gConfig.getValue("%s/Users/%s/DN" % (gBaseRegistrySection, username), []) def getDNForHost(host): @@ -60,18 +88,52 @@ def getDNForHost(host): return S_OK(dnList) if dnList else S_ERROR("No DN found for host %s" % host) -def getGroupsForDN(dn): +def getGroupsForDN(dn, groupsList=None): """ Get all possible groups for DN :param str dn: user DN + :param list groupsList: group list where need to search :return: S_OK(list)/S_ERROR() -- contain list of groups """ dn = dn.strip() + groups = [] + if not groupsList: + result = gConfig.getSections("%s/Groups" % gBaseRegistrySection) + if not result['OK']: + return result + groupsList = result['Value'] + result = getUsernameForDN(dn) if not result['OK']: return result - return getGroupsForUser(result['Value']) + user = result['Value'] + + # Get VOMS information cache + result = gProxyManagerData.getActualVOMSesDNs(dnList=[dn]) + if not result['OK']: + return result + vomsData = result['Value'] + + result = getVOsWithVOMS() + if not result['OK']: + return result + vomsVOs = result['Value'] + + for group in groupsList: + if user in getGroupOption(group, 'Users', []): + vo = getGroupOption(group, 'VO') + # Is VOMS VO? + if vo in vomsVOs and vomsData.get(vo) and vomsData[vo]['OK'] and vomsData[vo]['Value']: + voData = vomsData[vo]['Value'] + role = getGroupOption(group, 'VOMSRole') + if not role or role in voData[dn]['VOMSRoles']: + groups.append(group) + else: + # If it's not VOMS VO or cannot get information from VOMS + groups.append(group) + groups = sorted(set(groups)) + return S_OK(groups) if groups else S_ERROR('No groups found for %s' % dn) def __getGroupsWithAttr(attrName, value): @@ -94,14 +156,26 @@ def __getGroupsWithAttr(attrName, value): return S_OK(groups) if groups else S_ERROR("No groups found for %s=%s" % (attrName, value)) -def getGroupsForUser(username): - """ Find groups for user +def getGroupsForUser(username, groupsList=None): + """ Find groups for user or if set reseachedGroup check it for user :param str username: user name + :param list groupsList: groups - :return: S_OK(list)/S_ERROR() -- contain list of groups + :return: S_OK(list or bool)/S_ERROR() -- contain list of groups or status group for user """ - return __getGroupsWithAttr('Users', username) + if not groupsList: + retVal = gConfig.getSections("%s/Groups" % gBaseRegistrySection) + if not retVal['OK']: + return retVal + groupsList = retVal['Value'] + + groups = [] + for group in groupsList: + if username in getGroupOption(group, 'Users', []): + groups.append(group) + groups = sorted(set(groups)) + return S_OK(groups) if groups else S_ERROR('No groups found for %s user' % username) def getGroupsForVO(vo): @@ -205,7 +279,7 @@ def getAllGroups(): return result['Value'] if result['OK'] else [] -def getUsersInGroup(groupName, defaultValue=None): +def getUsersInGroup(group, defaultValue=None): """ Find all users for group :param str group: group name @@ -213,8 +287,8 @@ def getUsersInGroup(groupName, defaultValue=None): :return: list """ - option = "%s/Groups/%s/Users" % (gBaseRegistrySection, groupName) - return gConfig.getValue(option, [] if defaultValue is None else defaultValue) + users = sorted(set(getGroupOption(group, 'Users', []))) + return users or [] if defaultValue is None else defaultValue def getUsersInVO(vo, defaultValue=None): @@ -225,44 +299,56 @@ def getUsersInVO(vo, defaultValue=None): :return: list """ + users = [] result = getGroupsForVO(vo) - if not result['OK'] or not result['Value']: - return [] if defaultValue is None else defaultValue - groups = result['Value'] + if result['OK'] and result['Value']: + for group in result['Value']: + users += getUsersInGroup(group) + users = sorted(set(users)) + return users or [] if defaultValue is None else defaultValue - userList = [] - for group in groups: - userList += getUsersInGroup(group) - return userList +def getDNsInGroup(group, checkStatus=False): + """ Find user DNs for DIRAC group -def getDNsInVO(vo): - """ Get all DNs that have a VO users - - :param str vo: VO name + :param str group: group name + :param bool checkStatus: don't add suspended DNs :return: list """ - DNs = [] - for user in getUsersInVO(vo): - result = getDNForUsername(user) - if result['OK']: - DNs.extend(result['Value']) - return DNs - - -def getDNsInGroup(groupName): - """ Find all DNs for DIRAC group + vomsData = {} + vo = getGroupOption(group, 'VO') - :param str groupName: group name + # Get VOMS information for VO, if it's VOMS VO + result = getVOsWithVOMS([vo]) + if not result['OK']: + return result + if result['Value']: + result = gProxyManagerData.getActualVOMSesDNs(voList=vo) + if not result['OK']: + return result + vomsData = result['Value'] - :return: list - """ DNs = [] - for user in getUsersInGroup(groupName): - result = getDNForUsername(user) - if result['OK']: - DNs.extend(result['Value']) + for username in getGroupOption(group, 'Users', []): + if checkStatus and vo in getUserOption(username, 'Suspended', []): + continue + result = getDNsForUsername(username) + if not result['OK']: + return result + userDNs = result['Value'] + if vomsData.get(vo) and vomsData[vo]['OK']: + voData = vomsData[vo]['Value'] + role = getGroupOption(group, 'VOMSRole') + for dn in userDNs: + if dn in voData and (not checkStatus or not voData[dn]['Suspended']): + if not role or role in voData[dn]['ActiveRoles' if checkStatus else 'VOMSRoles']: + DNs.append(dn) + else: + for dn in userDNs: + if dn and dn not in DNs: + DNs.append(dn) + return DNs @@ -300,8 +386,6 @@ def getPropertiesForEntity(group, name="", dn="", defaultValue=None): :return: defaultValue or list """ - if defaultValue is None: - defaultValue = [] if group == 'hosts': if not name: result = getHostnameForDN(dn) @@ -428,6 +512,16 @@ def getVOForGroup(group): return getVO() or gConfig.getValue("%s/Groups/%s/VO" % (gBaseRegistrySection, group), "") +def getIdPForGroup(group): + """ Get identity provider for group VO + + :param str group: group name + + :return: str + """ + return getVOOption(getVOForGroup(group), 'IdP') + + def getDefaultVOMSAttribute(): """ Get default VOMS attribute @@ -468,15 +562,16 @@ def getVOMSVOForGroup(group): return vomsVO -def getGroupsWithVOMSAttribute(vomsAttr): +def getGroupsWithVOMSAttribute(vomsAttr, groupsList=None): """ Search groups with VOMS attribute :param str vomsAttr: VOMS attribute + :param list groupsList: groups where need to search :return: list """ groups = [] - for group in gConfig.getSections("%s/Groups" % (gBaseRegistrySection)).get('Value', []): + for group in groupsList or getAllGroups(): if vomsAttr == gConfig.getValue("%s/Groups/%s/VOMSRole" % (gBaseRegistrySection, group), ""): groups.append(group) return groups @@ -587,72 +682,30 @@ def getUsernameForID(ID, usersList=None): return S_ERROR("No username found for ID %s" % ID) -def getCAForUsername(username): - """ Get CA option by user name +def getDNProperty(dn, prop, defaultValue=None, username=None): + """ Get user DN property - :param str username: user name - - :return: S_OK(str)/S_ERROR() - """ - dnList = gConfig.getValue("%s/Users/%s/CA" % (gBaseRegistrySection, username), []) - return S_OK(dnList) if dnList else S_ERROR("No CA found for user %s" % username) - - -def getDNProperty(userDN, value, defaultValue=None): - """ Get property from DNProperties section by user DN - - :param str userDN: user DN - :param str value: option that need to get + :param str dn: user DN + :param str prop: property name :param defaultValue: default value + :param str username: username - :return: S_OK()/S_ERROR() -- str or list that contain option value - """ - result = getUsernameForDN(userDN) - if not result['OK']: - return result - pathDNProperties = "%s/Users/%s/DNProperties" % (gBaseRegistrySection, result['Value']) - result = gConfig.getSections(pathDNProperties) - if result['OK']: - for section in result['Value']: - if userDN == gConfig.getValue("%s/%s/DN" % (pathDNProperties, section)): - return S_OK(gConfig.getValue("%s/%s/%s" % (pathDNProperties, section, value), defaultValue)) - return S_OK(defaultValue) - - -def getProxyProvidersForDN(userDN): - """ Get proxy providers by user DN - - :param str userDN: user DN - - :return: S_OK(list)/S_ERROR() + :return: S_OK()/S_ERROR() """ - return getDNProperty(userDN, 'ProxyProviders', []) - - -def getDNFromProxyProviderForUserID(proxyProvider, userID): - """ Get groups by user DN in DNProperties - - :param str proxyProvider: proxy provider name - :param str userID: user identificator + if not username: + result = getUsernameForDN(dn) + if not result['OK']: + return result + username = result['Value'] - :return: S_OK(str)/S_ERROR() - """ - # Get user name - result = getUsernameForID(userID) + root = "%s/Users/%s/DNProperties" % (gBaseRegistrySection, username) + result = gConfig.getSections(root) if not result['OK']: return result - # Get DNs from user - result = getDNForUsername(result['Value']) - if not result['OK']: - return result - for DN in result['Value']: - result = getProxyProvidersForDN(DN) - if not result['OK']: - return result - if proxyProvider in result['Value']: - return S_OK(DN) - return S_ERROR(errno.ENODATA, - "No DN found for %s proxy provider for user ID %s" % (proxyProvider, userID)) + for section in result['Value']: + if dn == gConfig.getValue("%s/%s/DN" % (root, section)): + return S_OK(gConfig.getValue("%s/%s/%s" % (root, section, prop), defaultValue)) + return S_OK(defaultValue) def isDownloadableGroup(groupName): @@ -667,24 +720,6 @@ def isDownloadableGroup(groupName): return True -def getUserDict(username): - """ Get full information from user section - - :param str username: DIRAC user name - - :return: S_OK()/S_ERROR() - """ - resDict = {} - relPath = '%s/Users/%s/' % (gBaseRegistrySection, username) - result = gConfig.getConfigurationTree(relPath) - if not result['OK']: - return result - for key, value in result['Value'].items(): - if value: - resDict[key.replace(relPath, '')] = value - return S_OK(resDict) - - def getEmailsForGroup(groupName): """ Get email list of users in group @@ -697,3 +732,112 @@ def getEmailsForGroup(groupName): email = getUserOption(username, 'Email', []) emails.append(email) return emails + + +def getIDsForUsername(username): + """ Return IDs for DIRAC user + + :param str username: DIRAC user + + :return: list -- contain IDs + """ + return gConfig.getValue("%s/Users/%s/ID" % (gBaseRegistrySection, username), []) + + +def getVOsWithVOMS(voList=None): + """ Get all the configured VOMS VOs + + :param list voList: VOs where to look + + :return: S_OK(list)/S_ERROR() + """ + vos = [] + if not voList: + result = getVOs() + if result['OK']: + voList = result['Value'] + for vo in voList or []: + if getVOOption(vo, 'VOMSName'): + vos.append(vo) + return S_OK(vos) + + +def getDNsForUsername(username): + """ Find all DNs for DIRAC user + + :param str username: DIRAC user + + :return: S_OK(list)/S_ERROR() -- contain DNs + """ + userDNs = getDNsForUsernameFromCS(username) + for uid in getIDsForUsername(username): + result = gAuthManagerData.getDNsForID(uid) + if result['OK']: + for dn in result['Value']: + if dn not in userDNs: + userDNs.append(dn) + return S_OK(userDNs) + + +def getDNForUsernameInGroup(username, group, checkStatus=False): + """ Get user DN for user in group + + :param str username: user name + :param str group: group name + :param bool checkStatus: don't add suspended DNs + + :return: S_OK(str)/S_ERROR() + """ + result = getDNsForUsernameInGroup(username, group, checkStatus) + return S_OK(result['Value'][0]) if result['OK'] else result + + +def getDNsForUsernameInGroup(username, group, checkStatus=False): + """ Get user DN for user in group + + :param str username: user name + :param str group: group name + :param bool checkStatus: don't add suspended DNs + + :return: S_OK(str)/S_ERROR() + """ + if username not in getGroupOption(group, 'Users', []): + return S_ERROR('%s group not have %s user.' % (group, username)) + result = getDNsForUsername(username) + if not result['OK']: + return result + userDNs = result['Value'] + print('== getDNsForUsernameInGroup ==') + pprint(userDNs) + + DNs = [] + vo = getGroupOption(group, 'VO') + if checkStatus and vo in getUserOption(username, 'Suspended', []): + return S_ERROR('%s marked as suspended for %s VO.' % (username, vo)) + result = getVOsWithVOMS([vo]) + if not result['OK']: + return result + if result['Value']: + result = gProxyManagerData.getActualVOMSesDNs(voList=[vo]) + if not result['OK']: + return result + vomsData = result['Value'] + if vomsData.get(vo) and vomsData[vo]['OK']: + voData = vomsData[vo]['Value'] + role = getGroupOption(group, 'VOMSRole') + for dn in userDNs: + if dn in voData and (not checkStatus or not voData[dn]['Suspended']): + if not role or role in voData[dn]['ActiveRoles' if checkStatus else 'VOMSRoles']: + DNs.append(dn) + else: + DNs = userDNs + else: + DNs = userDNs + print('-------------------------') + pprint(DNs) + dns = [e for e in DNs if e] + print('-------------------------') + pprint(dns) + if dns: + return S_OK(dns) + return S_ERROR('For %s@%s not found DN%s.' % (username, group, ' or it suspended' if checkStatus else '')) diff --git a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py index 84455a045f4..8b49c4136c7 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py +++ b/src/DIRAC/ConfigurationSystem/Client/Helpers/Resources.py @@ -428,45 +428,91 @@ def getFilterConfig(filterID): return gConfig.getOptionsDict('Resources/LogFilters/%s' % filterID) -def getInfoAboutProviders(of=None, providerName=None, option='', section=''): - """ Get the information about providers - - :param str of: provider of what(Id, Proxy or etc.) need to look, - None, "all" to get list of instance of what this providers - :param str providerName: provider name, - None, "all" to get list of providers names - :param str option: option name that need to get, - None, "all" to get all options in a section - :param str section: section path in root section of provider, - "all" to get options in all sections - - :return: S_OK()/S_ERROR() +def getIdProviderForIssuer(issuer): + """ Get identity provider for issuer + + :param str issuer: issuer + + :return: S_OK(dict)/S_ERROR() + """ + result = getProvidersForInstance('Id') + if not result['OK']: + return result + for prov in result['Value']: + if issuer.strip('/') == gConfig.getValue('%s/IdProviders/%s/issuer' % (gBaseResourcesSection, prov)).strip('/'): + return S_OK(prov) + return S_ERROR('Not found provider wwith %s issuer.' % issuer) + + +def getProvidersForInstance(instance, providerType=None): + """ Get providers for instance + + :param str instance: instance of what this providers + :param str providerType: provider type + + :return: S_OK(list)/S_ERROR() + """ + data = [] + instance = "%sProviders" % instance + result = gConfig.getSections(gBaseResourcesSection) + if result['OK']: + if instance not in result['Value']: + return S_OK(data) + result = gConfig.getSections('%s/%s' % (gBaseResourcesSection, instance)) + + # Return an empty list if the section does not exist + if not result['OK'] or not result['Value'] or not providerType: + return result + + for prov in result['Value']: + if providerType == gConfig.getValue('%s/%s/%s/ProviderType' % (gBaseResourcesSection, instance, prov)): + data.append(prov) + return S_OK(data) + + +def getProviderByAlias(alias, instance=None): + """ Find provider name by alias + + :param str alias: other registered provider name + :param str instance: provider of what + + :return: S_OK(str)/S_ERROR() """ - if not of or of == "all": + instances = [instance] or [] + if not instances: result = gConfig.getSections(gBaseResourcesSection) if not result['OK']: return result - return S_OK([i.replace('Providers', '') for i in result['Value']]) - if not providerName or providerName == "all": - return gConfig.getSections('%s/%sProviders' % (gBaseResourcesSection, of)) - if not option or option == 'all': - if not section: - return gConfig.getOptionsDict( - "%s/%sProviders/%s" % (gBaseResourcesSection, of, providerName)) - elif section == "all": - resDict = {} - relPath = "%s/%sProviders/%s/" % (gBaseResourcesSection, of, providerName) - result = gConfig.getConfigurationTree(relPath) + for section in result['Value']: + if section.endswith('Providers'): + instances.append(section.rsplit('Providers', 1)[0]) + for instance in instances: + result = getProvidersForInstance(instance) + if not result['OK']: + return result + for provider in result['Value']: + if alias in gConfig.getValue("%s/%sProviders/%s/Aliases" % (gBaseResourcesSection, + instance, provider), []): + return S_OK(provider) + return S_ERROR('Did not find any provider for %s' % alias) + + +def getProviderInfo(provider): + """ Get provider info + + :param str provider: provider + + :return: S_OK(dict)/S_ERROR() + """ + result = gConfig.getSections(gBaseResourcesSection) + if not result['OK']: + return result + for section in result['Value']: + if section.endswith('Providers'): + result = getProvidersForInstance(section[:-9]) if not result['OK']: return result - for key, value in result['Value'].items(): # can be an iterator - if value: - resDict[key.replace(relPath, '')] = value - return S_OK(resDict) - else: - return gConfig.getSections( - '%s/%sProviders/%s/%s/' % (gBaseResourcesSection, of, providerName, section)) - else: - return S_OK(gConfig.getValue( - '%s/%sProviders/%s/%s/%s' % (gBaseResourcesSection, of, providerName, - section, option))) + if provider in result['Value']: + return gConfig.getOptionsDictRecursively("%s/%s/%s/" % (gBaseResourcesSection, + section, provider)) + return S_ERROR('%s provider not found.' % provider) diff --git a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py index 4a748aad28e..5a28cbb0fb0 100755 --- a/src/DIRAC/ConfigurationSystem/Client/PathFinder.py +++ b/src/DIRAC/ConfigurationSystem/Client/PathFinder.py @@ -67,6 +67,10 @@ def getComponentSection(componentName, componentTuple=False, setup=False, compon return "%s/%s/%s" % (systemSection, componentCategory, componentTuple[1]) +def getAPISection(APIName, APITuple=False, setup=False): + return getComponentSection(APIName, APITuple, setup, "APIs") + + def getServiceSection(serviceName, serviceTuple=False, setup=False): return getComponentSection(serviceName, serviceTuple, setup, "Services") diff --git a/src/DIRAC/ConfigurationSystem/Client/Utilities.py b/src/DIRAC/ConfigurationSystem/Client/Utilities.py index d4b9cef6bab..79518f47979 100644 --- a/src/DIRAC/ConfigurationSystem/Client/Utilities.py +++ b/src/DIRAC/ConfigurationSystem/Client/Utilities.py @@ -21,6 +21,7 @@ from DIRAC.Core.Utilities.Glue2 import getGlue2CEInfo from DIRAC.Core.Utilities.SiteSEMapping import getSEHosts from DIRAC.DataManagementSystem.Utilities.DMSHelpers import DMSHelpers +from DIRAC.ConfigurationSystem.Client.PathFinder import getSystemInstance def getGridVOs(): @@ -536,16 +537,6 @@ def getElasticDBParameters(fullname): return S_OK(parameters) -def getOAuthAPI(instance='Production'): - """ Get OAuth API url - - :param str instance: instance - - :return: str - """ - return gConfig.getValue("/Systems/Framework/%s/URLs/OAuthAPI" % instance) - - def getDIRACGOCDictionary(): """ Create a dictionary containing DIRAC site names and GOCDB site names @@ -578,3 +569,133 @@ def getDIRACGOCDictionary(): log.debug('End function.') return S_OK(dictionary) + + +def getAuthAPI(): + """ Get Auth REST API url + + :return: str + """ + return gConfig.getValue("/Systems/Framework/%s/URLs/Auth" % getSystemInstance("Framework")) + + +def getProxyAPI(): + """ Get Proxy REST API url + + :return: str + """ + return gConfig.getValue("/Systems/Framework/%s/URLs/Proxy" % getSystemInstance("Framework")) + + +def getDIRACClientID(): + """ Get DIRAC client public ID + + :return: str + """ + return gConfig.getValue("/DIRAC/ClientID") + + +def getWebClient(): + """ Get registred in the configuration Web authentication client + + :return: S_OK(dict)/S_ERROR() + """ + return getAuthClients(clientName='WEBAPPDIRACCLI') + + +def getDIRACClient(): + """ Get registred in the configuration DIRAC authentication client + + :return: S_OK(dict)/S_ERROR() + """ + return getAuthClients(clientName='DIRACCLI') + + +def getAuthClients(clientID=None, clientName=None): + """ Get all registred in the configuration authentication clients + + :param str clientID: client ID + :param str clientName: client name + + :return: S_OK(dict)/S_ERROR() -- dictionary contain all registred clients in the configuration + """ + clients = {} + path = '/Systems/Framework/%s/APIs/Auth' % getSystemInstance("Framework") + result = gConfig.getSections(path) + if not result['OK']: + return result + + if 'Clients' in result['Value']: + result = gConfig.getOptionsDictRecursively('%s/Clients' % path) + if not result['OK']: + return result + clients = result['Value'] + + for cliName, cliDict in clients.items(): + cliDict['issuer'] = cliDict.get('issuer', getAuthAPI()) + cliDict['authority'] = cliDict.get('authority', getAuthAPI()) + if cliName == 'DIRACCLI': + if not cliDict.get('client_metadata'): + cliDict['client_metadata'] = {'response_types': ['device'], + 'grant_types': ['urn:ietf:params:oauth:grant-type:device_code']} + elif cliName == 'WEBAPPDIRACCLI': + cliDict['token_endpoint_auth_method'] = cliDict.get('token_endpoint_auth_method', 'client_secret_basic') + if not cliDict.get('client_metadata'): + cliDict['client_metadata'] = {'response_types': ['code', 'id_token token', 'token'], + 'redirect_uris': [cliDict['redirect_uri']], + 'token_endpoint_auth_method': cliDict['token_endpoint_auth_method'], + 'grant_types': ['device', 'authorization_code', 'refresh_token', + 'urn:ietf:params:oauth:grant-type:token-exchange']} + if clientName and clientName == cliName: + return S_OK(cliDict) + + if clientID and clientID == cliDict['client_id']: + return S_OK(cliDict) + return S_OK({} if clientID else clients) + + +def getAuthorisationServerMetadata(): + """ Get authoraisation server metadata + + :return: S_OK(dict)/S_ERROR() + """ + path = '/Systems/Framework/%s/APIs/Auth' % getSystemInstance("Framework") + result = gConfig.getSections(path) + if not result['OK']: + return result + + data = {} + if 'AuthorizationServer' in result['Value']: + result = gConfig.getOptionsDictRecursively('%s/AuthorizationServer' % path) + if not result['OK']: + return result + data = result['Value'] + + data['issuer'] = data.get('issuer', getAuthAPI()) + if not data['issuer']: + return S_ERROR('Cannot found the Auth RESTful API base URL in the configuration.') + data['jwks_uri'] = data.get('jwks_uri', data['issuer'] + '/jwk') + data['token_endpoint'] = data.get('token_endpoint', data['issuer'] + '/token') + data['userinfo_endpoint'] = data.get('userinfo_endpoint', data['issuer'] + '/userinfo') + data['registration_endpoint'] = data.get('registration_endpoint', data['issuer'] + '/register') + data['authorization_endpoint'] = data.get('authorization_endpoint', data['issuer'] + '/authorization') + data['grant_types_supported'] = data.get('grant_types_supported', [ + 'code', 'authorization_code', 'urn:ietf:params:oauth:grant-type:device_code', 'refresh_token' + ]) + data['response_types_supported'] = data.get('response_types_supported', [ + 'code', 'device', 'id_token token', 'id_token', 'token' + ]) + data['code_challenge_methods_supported'] = data.get('code_challenge_methods_supported', ['S256']) + # Search values with type list + for key, v in data.items(): + data[key] = [e for e in v.replace(', ', ',').split(',') if e] if ',' in v else v + return S_OK(data) + + +def isDownloadablePersonalProxy(): + """ Get downloadablePersonalProxy flag + + :return: S_OK(bool)/S_ERROR() + """ + cs_path = '/Systems/Framework/%s/APIs/Proxy' % getSystemInstance("Framework") + return gConfig.getOption(cs_path + '/downloadablePersonalProxy') diff --git a/src/DIRAC/ConfigurationSystem/test/Test_agentOptions.py b/src/DIRAC/ConfigurationSystem/test/Test_agentOptions.py index 20c551e7725..c7a5ce34598 100644 --- a/src/DIRAC/ConfigurationSystem/test/Test_agentOptions.py +++ b/src/DIRAC/ConfigurationSystem/test/Test_agentOptions.py @@ -68,7 +68,7 @@ 'Enable']}), ('DIRAC.WorkloadManagementSystem.Agent.StatesAccountingAgent', {}), ('DIRAC.WorkloadManagementSystem.Agent.SiteDirector', - {'SpecialMocks': {'findGenericPilotCredentials': S_OK(('a', 'b'))}}), + {'SpecialMocks': {'findGenericPilotCredentials': S_OK(('a', 'b', 'c'))}}), ] diff --git a/src/DIRAC/Core/Base/API.py b/src/DIRAC/Core/Base/API.py index c01f56fd82a..fc0f0cbc817 100644 --- a/src/DIRAC/Core/Base/API.py +++ b/src/DIRAC/Core/Base/API.py @@ -3,14 +3,15 @@ from __future__ import print_function from __future__ import absolute_import from __future__ import division + import six -import pprint import sys +import pprint from DIRAC import gLogger, gConfig, S_OK, S_ERROR from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNsForUsername from DIRAC.Core.Utilities.Version import getCurrentVersion -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getSites __RCSID__ = '$Id$' @@ -159,9 +160,11 @@ def _getCurrentUser(self): gLogger.debug(formatProxyInfoAsString(proxyInfo)) if 'group' not in proxyInfo: return self._errorReport('Proxy information does not contain the group', res['Message']) - res = getDNForUsername(proxyInfo['username']) - if not res['OK']: - return self._errorReport('Failed to get proxies for user', res['Message']) + result = getDNsForUsername(proxyInfo['username']) + if not result['OK']: + return self._errorReport('Failed to get proxies for user', result['Message']) + if not result['Value']: + return self._errorReport('Failed to get proxies for user', "No DNs found for %s" % proxyInfo['username']) return S_OK(proxyInfo['username']) ############################################################################# diff --git a/src/DIRAC/Core/Base/DB.py b/src/DIRAC/Core/Base/DB.py index af4b9af4fc2..a5770a5ccd6 100755 --- a/src/DIRAC/Core/Base/DB.py +++ b/src/DIRAC/Core/Base/DB.py @@ -5,6 +5,7 @@ from __future__ import division from __future__ import print_function +from DIRAC import S_OK, S_ERROR from DIRAC.Core.Base.DIRACDB import DIRACDB from DIRAC.Core.Utilities.MySQL import MySQL from DIRAC.ConfigurationSystem.Client.Utilities import getDBParameters @@ -17,8 +18,15 @@ class DB(DIRACDB, MySQL): """ def __init__(self, dbname, fullname, debug=False): + """ C'or + :param str dbname: database name + :param str fullname: full name + :param bool debug: debug mode + """ + self.versionDB = 0 self.fullname = fullname + self.versionTable = '%s_Version' % dbname result = getDBParameters(fullname) if not result['OK']: @@ -41,10 +49,42 @@ def __init__(self, dbname, fullname, debug=False): if not self._connected: raise RuntimeError("Can not connect to DB '%s', exiting..." % self.dbName) + # Initialize version + result = self._query("show tables") + if result['OK']: + if self.versionTable not in [t[0] for t in result['Value']]: + result = self._createTables({self.versionTable: {'Fields': {'Version': 'INTEGER NOT NULL'}, + 'PrimaryKey': 'Version'}}) + if not result['OK']: + raise RuntimeError("Can not initialize %s version: %s" % (self.dbName, result['Message'])) + result = self._query("SELECT Version FROM `%s`" % self.versionTable) + if result['OK']: + if len(result['Value']) > 0: + self.versionDB = result['Value'][0][0] + else: + result = self._update("INSERT INTO `%s` (Version) VALUES (%s)" % (self.versionTable, self.versionDB)) + if not result['OK']: + raise RuntimeError("Can not initialize %s version: %s" % (self.dbName, result['Message'])) + self.log.info("===================== MySQL ======================") self.log.info("User: " + self.dbUser) self.log.info("Host: " + self.dbHost) self.log.info("Port: " + str(self.dbPort)) - # self.log.info("Password: "+ self.dbPass) + # self.log.info("Password: "+self.dbPass) self.log.info("DBName: " + self.dbName) self.log.info("==================================================") + + def updateDBVersion(self, version): + """ Update DB version + + :param int version: version number + + :return: S_OK()/S_ERROR() + """ + result = self._query('DELETE FROM `%s`' % self.versionTable) + if result['OK']: + result = self._update("INSERT INTO `%s` (Version) VALUES (%s)" % (self.versionTable, version)) + if not result['OK']: + return S_ERROR("Can not initialize %s version: %s" % (self.dbName, result['Message'])) + self.versionDB = version + return S_OK() diff --git a/src/DIRAC/Core/DISET/AuthManager.py b/src/DIRAC/Core/DISET/AuthManager.py index 847043cc987..e2d6cefb680 100755 --- a/src/DIRAC/Core/DISET/AuthManager.py +++ b/src/DIRAC/Core/DISET/AuthManager.py @@ -5,180 +5,255 @@ from __future__ import print_function import six + +from DIRAC.Core.Utilities import List +from DIRAC.Core.Security import Properties +from DIRAC.FrameworkSystem.Client.Logger import gLogger from DIRAC.ConfigurationSystem.Client.Config import gConfig from DIRAC.ConfigurationSystem.Client.Helpers import Registry -from DIRAC.FrameworkSystem.Client.Logger import gLogger -from DIRAC.Core.Security import Properties -from DIRAC.Core.Utilities import List __RCSID__ = "$Id$" +KW_ID = 'ID' +KW_DN = 'DN' +KW_GROUP = 'group' +KW_USERNAME = 'username' +KW_HOSTS_GROUP = 'hosts' +KW_PROPERTIES = 'properties' +KW_EXTRA_CREDENTIALS = 'extraCredentials' + + +def forwardingCredentials(credDict, logObj=gLogger): + """ Check whether the credentials are being forwarded by a valid source and extract it + + :param dict credDict: Credentials to ckeck + :param object logObj: logger + + :return: bool + """ + if isinstance(credDict.get(KW_EXTRA_CREDENTIALS), (tuple, list)): + retVal = Registry.getHostnameForDN(credDict.get(KW_DN)) + if not retVal['OK']: + logObj.debug("The credentials forwarded not by a host:", credDict.get(KW_DN)) + return False + hostname = retVal['Value'] + if Properties.TRUSTED_HOST not in Registry.getPropertiesForHost(hostname, []): + logObj.debug("The credentials forwarded by a %s host, but it is not a trusted one" % hostname) + return False + if credDict[KW_EXTRA_CREDENTIALS][0][0] == '/': + credDict[KW_DN] = credDict[KW_EXTRA_CREDENTIALS][0] + credDict[KW_ID] = None + else: + credDict[KW_ID] = credDict[KW_EXTRA_CREDENTIALS][0] + credDict[KW_DN] = None + credDict[KW_GROUP] = credDict[KW_EXTRA_CREDENTIALS][1] + del credDict[KW_EXTRA_CREDENTIALS] + return True + return False + + +def authorizeBySession(credDict, logObj=gLogger): + """ Discover the username associated to the authentication session. It will check if the selected group is valid. + The username will be included in the credentials dictionary. And will discover DN for group if last not set. + + :param dict credDict: Credentials to check + :param object logObj: logger + + :return: bool -- specifying whether the username was found + """ + # Find user + result = Registry.getUsernameForID(credDict[KW_ID]) + if not result['OK']: + credDict[KW_USERNAME] = "anonymous" + credDict[KW_GROUP] = "visitor" + return False + credDict[KW_USERNAME] = result['Value'] + return __checkGroup(credDict, logObj=gLogger) + + +def authorizeByCertificate(credDict, logObj=gLogger): + """ Discover the username associated to the certificate DN. It will check if the selected group is valid. + The username will be included in the credentials dictionary. + + :param dict credDict: Credentials to check + :param object logObj: logger + + :return: bool -- specifying whether the username was found + """ + # Search host + result = Registry.getHostnameForDN(credDict[KW_DN]) + if result['OK'] and result['Value']: + credDict[KW_USERNAME] = result['Value'] + credDict[KW_GROUP] = KW_HOSTS_GROUP + return __checkGroup(credDict, logObj=gLogger) + elif credDict.get(KW_GROUP) == KW_HOSTS_GROUP: + logObj.warn("Cannot find hostname for DN %s: %s" % (credDict[KW_DN], result['Message'])) + credDict[KW_USERNAME] = "anonymous" + credDict[KW_GROUP] = "visitor" + return False + + # Search user + result = Registry.getUsernameForDN(credDict[KW_DN]) + if not result['OK']: + credDict[KW_USERNAME] = "anonymous" + credDict[KW_GROUP] = "visitor" + return False + credDict[KW_USERNAME] = result['Value'] + return __checkGroup(credDict, logObj=gLogger) + + +def __checkGroup(credDict, logObj=gLogger): + """ Check/get default group + + :param dict credDict: Credentials to check + :param object logObj: logger + + :return: bool -- specifying whether the username was found + """ + # Find/check group + credDict[KW_PROPERTIES] = [] + if not credDict.get(KW_GROUP) or credDict[KW_GROUP] == 'visitor': + result = Registry.findDefaultGroupForUser(credDict[KW_USERNAME]) + if not result['OK']: + credDict[KW_USERNAME] = "anonymous" + credDict[KW_GROUP] = "visitor" + return False + credDict[KW_GROUP] = result['Value'] + if credDict[KW_GROUP] == KW_HOSTS_GROUP: + credDict[KW_PROPERTIES] = Registry.getPropertiesForHost(credDict[KW_USERNAME], []) + return True + if not Registry.getGroupsForUser(credDict[KW_USERNAME], groupsList=[credDict[KW_GROUP]]).get('Value'): + credDict[KW_USERNAME] = "anonymous" + credDict[KW_GROUP] = "visitor" + return False + # Get DN for user/group + result = Registry.getDNsForUsernameInGroup(credDict[KW_USERNAME], credDict[KW_GROUP], checkStatus=True) + if not result['OK']: + logObj.error(result['Message']) + credDict[KW_GROUP] = "visitor" + return False + # Set DN if authorization not througth certificate + if not credDict.get(KW_DN): + credDict[KW_DN] = result['Value'][0] + # Check if DN match for group + if credDict[KW_DN] not in result["Value"]: + logObj.error('%s DN is not match for %s group.' % (credDict[KW_DN], credDict[KW_GROUP])) + credDict[KW_GROUP] = "visitor" + return False + + # Fill group properties + credDict[KW_PROPERTIES] = Registry.getPropertiesForGroup(credDict[KW_GROUP], []) + return True + class AuthManager(object): """ Handle Service Authorization """ - __authLogger = gLogger.getSubLogger("Authorization") - KW_HOSTS_GROUP = 'hosts' - KW_DN = 'DN' - KW_GROUP = 'group' - KW_EXTRA_CREDENTIALS = 'extraCredentials' - KW_PROPERTIES = 'properties' - KW_USERNAME = 'username' def __init__(self, authSection): - """ - Constructor + """ Constructor - :type authSection: string - :param authSection: Section containing the authorization rules + :param str authSection: Section containing the authorization rules """ self.authSection = authSection def authQuery(self, methodQuery, credDict, defaultProperties=False): - """ - Check if the query is authorized for a credentials dictionary - - :type methodQuery: string - :param methodQuery: Method to test - :type credDict: dictionary - :param credDict: dictionary containing credentials for test. The dictionary can contain the DN - and selected group. - :return: Boolean result of test + """ Check if the query is authorized for a credentials dictionary + + :param str methodQuery: Method to test + :param dict credDict: dictionary containing credentials for test. The dictionary can contain the DN + and selected group. + :param defaultProperties: default properties + :type defaultProperties: list or tuple + + :return: bool -- result of test """ userString = "" - if self.KW_DN in credDict: - userString += "DN=%s" % credDict[self.KW_DN] - if self.KW_GROUP in credDict: - userString += " group=%s" % credDict[self.KW_GROUP] - if self.KW_EXTRA_CREDENTIALS in credDict: - userString += " extraCredentials=%s" % str(credDict[self.KW_EXTRA_CREDENTIALS]) + if KW_ID in credDict: + userString += " ID=%s" % credDict[KW_ID] + if KW_DN in credDict: + userString += " DN=%s" % credDict[KW_DN] + if credDict.get(KW_GROUP): + userString += " group=%s" % credDict[KW_GROUP] + if KW_EXTRA_CREDENTIALS in credDict: + userString += " extraCredentials=%s" % str(credDict[KW_EXTRA_CREDENTIALS]) self.__authLogger.debug("Trying to authenticate %s" % userString) + # Get properties requiredProperties = self.getValidPropertiesForMethod(methodQuery, defaultProperties) + lowerCaseProperties = [prop.lower() for prop in requiredProperties] or ['any'] + allowAll = "any" in lowerCaseProperties or "all" in lowerCaseProperties + # Extract valid groups validGroups = self.getValidGroups(requiredProperties) - lowerCaseProperties = [prop.lower() for prop in requiredProperties] - if not lowerCaseProperties: - lowerCaseProperties = ['any'] + self.__authLogger.info("validGroups: ", validGroups) - allowAll = "any" in lowerCaseProperties or "all" in lowerCaseProperties - # Set no properties by default - credDict[self.KW_PROPERTIES] = [] - # Check non secure backends - if self.KW_DN not in credDict or not credDict[self.KW_DN]: - if allowAll and not validGroups: - self.__authLogger.debug("Accepted request from unsecure transport") - return True + # Read extra credentials + if KW_EXTRA_CREDENTIALS in credDict: + # Is it a host? and HACK TO MAINTAIN COMPATIBILITY + if credDict.get(KW_EXTRA_CREDENTIALS) == KW_HOSTS_GROUP: + credDict[KW_GROUP] = credDict[KW_EXTRA_CREDENTIALS] + del credDict[KW_EXTRA_CREDENTIALS] + # Check if query comes though a gateway/web server + elif forwardingCredentials(credDict, logObj=self.__authLogger): + self.__authLogger.debug("Query comes from a gateway") + return self.authQuery(methodQuery, credDict, requiredProperties) else: - self.__authLogger.debug( - "Explicit property required and query seems to be coming through an unsecure transport") return False - # Check if query comes though a gateway/web server - if self.forwardedCredentials(credDict): - self.__authLogger.debug("Query comes from a gateway") - self.unpackForwardedCredentials(credDict) - return self.authQuery(methodQuery, credDict, requiredProperties) - # Get the properties - # Check for invalid forwarding - if self.KW_EXTRA_CREDENTIALS in credDict: - # Invalid forwarding? - if not isinstance(credDict[self.KW_EXTRA_CREDENTIALS], six.string_types): - self.__authLogger.debug("The credentials seem to be forwarded by a host, but it is not a trusted one") - return False - # Is it a host? - if self.KW_EXTRA_CREDENTIALS in credDict and credDict[self.KW_EXTRA_CREDENTIALS] == self.KW_HOSTS_GROUP: - # Get the nickname of the host - credDict[self.KW_GROUP] = credDict[self.KW_EXTRA_CREDENTIALS] - # HACK TO MAINTAIN COMPATIBILITY - else: - if self.KW_EXTRA_CREDENTIALS in credDict and self.KW_GROUP not in credDict: - credDict[self.KW_GROUP] = credDict[self.KW_EXTRA_CREDENTIALS] - # END OF HACK - # Get the username - if self.KW_DN in credDict and credDict[self.KW_DN]: - if self.KW_GROUP not in credDict: - result = Registry.findDefaultGroupForDN(credDict[self.KW_DN]) - if not result['OK']: - credDict[self.KW_USERNAME] = "anonymous" - credDict[self.KW_GROUP] = "visitor" - else: - credDict[self.KW_GROUP] = result['Value'] - if credDict[self.KW_GROUP] == self.KW_HOSTS_GROUP: - # For host - if not self.getHostNickName(credDict): - self.__authLogger.warn("Host is invalid") - if not allowAll: - return False - # If all, then set anon credentials - credDict[self.KW_USERNAME] = "anonymous" - credDict[self.KW_GROUP] = "visitor" + + # User/group authorization + if not credDict.get(KW_USERNAME): + if credDict.get(KW_DN): + # With certificate + authorized = authorizeByCertificate(credDict, logObj=self.__authLogger) + elif credDict.get(KW_ID): + # With IdP session + authorized = authorizeBySession(credDict, logObj=self.__authLogger) else: - # For users - username = self.getUsername(credDict) - suspended = self.isUserSuspended(credDict) - if not username: - self.__authLogger.warn("User is invalid or does not belong to the group it's saying") - if suspended: - self.__authLogger.warn("User is Suspended") - - if not username or suspended: - if not allowAll: - return False - # If all, then set anon credentials - credDict[self.KW_USERNAME] = "anonymous" - credDict[self.KW_GROUP] = "visitor" + # As visitor + credDict[KW_USERNAME] = "anonymous" + credDict[KW_GROUP] = "visitor" + authorized = False + # Marked as visitor + elif credDict[KW_USERNAME].lower() == 'anonymous' or credDict[KW_GROUP].lower() == "visitor": + authorized = False + # User/group already checked else: - if not allowAll: - return False - credDict[self.KW_USERNAME] = "anonymous" - credDict[self.KW_GROUP] = "visitor" + authorized = True - # If any or all in the props, allow - allowGroup = not validGroups or credDict[self.KW_GROUP] in validGroups - if allowAll and allowGroup: - return True - # Check authorized groups - if "authenticated" in lowerCaseProperties and allowGroup: - return True - if not self.matchProperties(credDict, requiredProperties): + # Access to the service + + # Check free access + if not allowAll and not authorized: + self.__authLogger.debug("User is invalid or does not belong to the group it's saying") + return False + # Match properties + if not self.matchProperties(credDict, list(set(requiredProperties) - set(['Any', 'any', + 'All', 'all', + 'authenticated', + 'Authenticated']))): self.__authLogger.warn("Client is not authorized\nValid properties: %s\nClient: %s" % (requiredProperties, credDict)) return False - elif not allowGroup: + # Match allowed groups + if validGroups and credDict[KW_GROUP] not in validGroups: self.__authLogger.warn("Client is not authorized\nValid groups: %s\nClient: %s" % (validGroups, credDict)) return False - return True - - def getHostNickName(self, credDict): - """ - Discover the host nickname associated to the DN. - The nickname will be included in the credentials dictionary. - - :type credDict: dictionary - :param credDict: Credentials to ckeck - :return: Boolean specifying whether the nickname was found - """ - if self.KW_DN not in credDict: - return True - if self.KW_GROUP not in credDict: - return False - retVal = Registry.getHostnameForDN(credDict[self.KW_DN]) - if not retVal['OK']: - gLogger.warn("Cannot find hostname for DN %s: %s" % (credDict[self.KW_DN], retVal['Message'])) - return False - credDict[self.KW_USERNAME] = retVal['Value'] - credDict[self.KW_PROPERTIES] = Registry.getPropertiesForHost(credDict[self.KW_USERNAME], []) + # Access allowed + if not authorized: + self.__authLogger.debug("Accepted request from unsecure transport") return True def getValidPropertiesForMethod(self, method, defaultProperties=False): - """ - Get all authorized groups for calling a method + """ Get all authorized groups for calling a method + + :param str method: Method to test + :param defaultProperties: default properties + :type defaultProperties: list or tuple - :type method: string - :param method: Method to test - :return: List containing the allowed groups + :return: list -- List containing the allowed groups """ authProps = gConfig.getValue("%s/%s" % (self.authSection, method), []) if authProps: @@ -197,11 +272,11 @@ def getValidPropertiesForMethod(self, method, defaultProperties=False): return [] def getValidGroups(self, rawProperties): - """ Get valid groups as specified in the method authorization rules + """ Get valid groups as specified in the method authorization rules + + :param list rawProperties: all method properties - :param rawProperties: all method properties - :type rawProperties: python:list - :return: list of allowed groups or [] + :return: list -- list of allowed groups """ validGroups = [] for prop in list(rawProperties): @@ -219,103 +294,26 @@ def getValidGroups(self, rawProperties): validGroups = list(set(validGroups)) return validGroups - def forwardedCredentials(self, credDict): - """ - Check whether the credentials are being forwarded by a valid source - - :type credDict: dictionary - :param credDict: Credentials to ckeck - :return: Boolean with the result - """ - if self.KW_EXTRA_CREDENTIALS in credDict and isinstance(credDict[self.KW_EXTRA_CREDENTIALS], (tuple, list)): - if self.KW_DN in credDict: - retVal = Registry.getHostnameForDN(credDict[self.KW_DN]) - if retVal['OK']: - hostname = retVal['Value'] - if Properties.TRUSTED_HOST in Registry.getPropertiesForHost(hostname, []): - return True - return False - - def unpackForwardedCredentials(self, credDict): - """ - Extract the forwarded credentials - - :type credDict: dictionary - :param credDict: Credentials to unpack - """ - credDict[self.KW_DN] = credDict[self.KW_EXTRA_CREDENTIALS][0] - credDict[self.KW_GROUP] = credDict[self.KW_EXTRA_CREDENTIALS][1] - del(credDict[self.KW_EXTRA_CREDENTIALS]) - - def getUsername(self, credDict): - """ - Discover the username associated to the DN. It will check if the selected group is valid. - The username will be included in the credentials dictionary. - - :type credDict: dictionary - :param credDict: Credentials to check - :return: Boolean specifying whether the username was found - """ - if self.KW_DN not in credDict: - return True - if self.KW_GROUP not in credDict: - result = Registry.findDefaultGroupForDN(credDict[self.KW_DN]) - if not result['OK']: - return False - credDict[self.KW_GROUP] = result['Value'] - credDict[self.KW_PROPERTIES] = Registry.getPropertiesForGroup(credDict[self.KW_GROUP], []) - usersInGroup = Registry.getUsersInGroup(credDict[self.KW_GROUP], []) - if not usersInGroup: - return False - retVal = Registry.getUsernameForDN(credDict[self.KW_DN], usersInGroup) - if retVal['OK']: - credDict[self.KW_USERNAME] = retVal['Value'] - return True - return False + def matchProperties(self, credDict, validProps, caseSensitive=False): + """ Return True if one or more properties are in the valid list of properties - def isUserSuspended(self, credDict): - """ Discover if the user is in Suspended status + :param dict credDict: credentials to match + :param list validProps: List of valid properties + :param bool caseSensitive: Map lower case properties to properties to make the check in + lowercase but return the proper case - :param dict credDict: Credentials to check - :return: Boolean True if user is Suspended + :return: bool -- specifying whether any property has matched the valid ones """ - # Update credDict if the username is not there - if self.KW_USERNAME not in credDict: - self.getUsername(credDict) - # If username or group is not known we can not judge if the user is suspended - # These cases are treated elsewhere anyway - if self.KW_USERNAME not in credDict or self.KW_GROUP not in credDict: - return False - suspendedVOList = Registry.getUserOption(credDict[self.KW_USERNAME], 'Suspended', []) - if not suspendedVOList: - return False - vo = Registry.getVOForGroup(credDict[self.KW_GROUP]) - if vo in suspendedVOList: + if not validProps: return True - return False - - def matchProperties(self, credDict, validProps, caseSensitive=False): - """ - Return True if one or more properties are in the valid list of properties - - :type props: list - :param props: List of properties to match - :type validProps: list - :param validProps: List of valid properties - :return: Boolean specifying whether any property has matched the valid ones - """ - - # HACK: Map lower case properties to properties to make the check in lowercase but return the proper case if not caseSensitive: validProps = dict((prop.lower(), prop) for prop in validProps) else: validProps = dict((prop, prop) for prop in validProps) - groupProperties = credDict[self.KW_PROPERTIES] foundProps = [] - for prop in groupProperties: + for prop in credDict[KW_PROPERTIES]: if not caseSensitive: prop = prop.lower() if prop in validProps: foundProps.append(validProps[prop]) - credDict[self.KW_PROPERTIES] = foundProps - return foundProps + return bool(foundProps) diff --git a/src/DIRAC/Core/DISET/RequestHandler.py b/src/DIRAC/Core/DISET/RequestHandler.py index 8715f323162..7c47c867105 100755 --- a/src/DIRAC/Core/DISET/RequestHandler.py +++ b/src/DIRAC/Core/DISET/RequestHandler.py @@ -10,7 +10,6 @@ import six import time import psutil -import six import DIRAC @@ -20,6 +19,8 @@ from DIRAC.ConfigurationSystem.Client.Config import gConfig from DIRAC.FrameworkSystem.Client.Logger import gLogger from DIRAC.Core.Security.Properties import CS_ADMINISTRATOR +from DIRAC.Core.DISET.AuthManager import authorizeByCertificate,\ + forwardingCredentials, authorizeBySession def getServiceOption(serviceInfo, optionName, defaultValue): @@ -100,7 +101,13 @@ def getRemoteCredentials(self): :return: Credentials dictionary of remote peer. """ - return self.__trPool.get(self.__trid).getConnectingCredentials() + credDict = self.__trPool.get(self.__trid).getConnectingCredentials() + forwardingCredentials(credDict, logObj=self.log) + if credDict.get('ID'): + authorizeBySession(credDict, logObj=self.log) + else: + authorizeByCertificate(credDict, logObj=self.log) + return credDict @classmethod def getCSOption(cls, optionName, defaultValue=False): diff --git a/src/DIRAC/Core/DISET/ThreadConfig.py b/src/DIRAC/Core/DISET/ThreadConfig.py index 77d3d19e89d..dfff5e4ccc0 100644 --- a/src/DIRAC/Core/DISET/ThreadConfig.py +++ b/src/DIRAC/Core/DISET/ThreadConfig.py @@ -27,6 +27,7 @@ def reset(self): """ Reset extra information """ self.__DN = False + self.__ID = False self.__group = False self.__deco = False self.__setup = False @@ -73,21 +74,19 @@ def getGroup(self): """ return self.__group - def setID(self, DN, group): + def setID(self, ID): """ Set user ID - :param str DN: user DN - :param str group: user group + :param str ID: user ID """ - self.__DN = DN - self.__group = group + self.__ID = ID def getID(self): """ Return user ID - :return: tuple + :return: str """ - return (self.__DN, self.__group) + return self.__ID def setSetup(self, setup): """ Set setup name @@ -108,19 +107,17 @@ def dump(self): :return: tuple """ - return (self.__DN, self.__group, self.__setup) + return (self.__DN, self.__group, self.__setup, self.__ID) def load(self, tp): """ Save extra information - :param tuple tp: contain DN, group name, setup name + :param tuple tp: contains DN, group name, setup name, userID """ - if tp[0]: - self.__DN = tp[0] - if tp[1]: - self.__group = tp[1] - if tp[2]: - self.__setup = tp[2] + self.__ID = tp[3] or self.__ID + self.__DN = tp[0] or self.__DN + self.__group = tp[1] or self.__group + self.__setup = tp[2] or self.__setup def threadDeco(method): diff --git a/src/DIRAC/Core/DISET/private/BaseClient.py b/src/DIRAC/Core/DISET/private/BaseClient.py index 7df81ad7bba..50a48b75239 100755 --- a/src/DIRAC/Core/DISET/private/BaseClient.py +++ b/src/DIRAC/Core/DISET/private/BaseClient.py @@ -15,7 +15,7 @@ import DIRAC from DIRAC.Core.DISET.private.Protocols import gProtocolDict from DIRAC.FrameworkSystem.Client.Logger import gLogger -from DIRAC.Core.Utilities import List, Network +from DIRAC.Core.Utilities import List, Network, DErrno from DIRAC.Core.Utilities.ReturnValues import S_OK, S_ERROR from DIRAC.ConfigurationSystem.Client.Config import gConfig from DIRAC.ConfigurationSystem.Client.PathFinder import getServiceURL, getServiceFailoverURL @@ -37,6 +37,7 @@ class BaseClient(object): KW_TIMEOUT = "timeout" KW_SETUP = "setup" KW_VO = "VO" + KW_DELEGATED_ID = "delegatedID" KW_DELEGATED_DN = "delegatedDN" KW_DELEGATED_GROUP = "delegatedGroup" KW_IGNORE_GATEWAYS = "ignoreGateways" @@ -259,17 +260,24 @@ def __discoverExtraCredentials(self): self.__extraCredentials = self.kwargs[self.KW_EXTRA_CREDENTIALS] # Are we delegating something? + delegatedID = self.kwargs.get(self.KW_DELEGATED_ID) or self.__threadConfig.getID() delegatedDN = self.kwargs.get(self.KW_DELEGATED_DN) or self.__threadConfig.getDN() delegatedGroup = self.kwargs.get(self.KW_DELEGATED_GROUP) or self.__threadConfig.getGroup() + + if delegatedID: + self.kwargs[self.KW_DELEGATED_ID] = delegatedID if delegatedDN: self.kwargs[self.KW_DELEGATED_DN] = delegatedDN + if delegatedGroup: if not delegatedGroup: result = Registry.findDefaultGroupForDN(delegatedDN) if not result['OK']: return result delegatedGroup = result['Value'] - self.kwargs[self.KW_DELEGATED_GROUP] = delegatedGroup - self.__extraCredentials = (delegatedDN, delegatedGroup) + self.kwargs[self.KW_DELEGATED_GROUP] = delegatedGroup + + if delegatedID or delegatedDN: + self.__extraCredentials = (delegatedDN or delegatedID, delegatedGroup) return S_OK() def __findServiceURL(self): @@ -506,10 +514,10 @@ def _connect(self): # try to reconnect return self._connect() else: - return retVal + return S_ERROR(DErrno.ECONNECT, retVal['Message']) except Exception as e: gLogger.exception(lException=True, lExcInfo=True) - return S_ERROR("Can't connect to %s: %s" % (self.serviceURL, repr(e))) + return S_ERROR(DErrno.ECONNECT, "Can't connect to %s: %s" % (self.serviceURL, repr(e))) # We add the connection to the transport pool gLogger.debug("Connected to: %s" % self.serviceURL) trid = getGlobalTransportPool().add(transport) diff --git a/src/DIRAC/Core/DISET/test/Test_AuthManager.py b/src/DIRAC/Core/DISET/test/Test_AuthManager.py index 56645994f00..0a833d26f8a 100644 --- a/src/DIRAC/Core/DISET/test/Test_AuthManager.py +++ b/src/DIRAC/Core/DISET/test/Test_AuthManager.py @@ -4,14 +4,44 @@ from __future__ import division from __future__ import print_function +import os +import mock +import pickle import unittest from diraccfg import CFG -from DIRAC import gConfig +from DIRAC import gConfig, rootPath, S_OK, S_ERROR from DIRAC.Core.DISET.AuthManager import AuthManager __RCSID__ = "$Id$" +workDir = os.path.join(gConfig.getValue('/LocalSite/InstancePath', rootPath), 'work/ProxyManager') + +voDict = { + 'testVO': S_OK({ + '/User/test/DN/CN=userS': { + 'Suspended': True, + 'VOMSRoles': [u'/testVO'], + 'ActiveRoles': [], + 'SuspendedRoles': [u'/testVO'] + }, + '/User/test/DN/CN=userA': { + 'Suspended': False, + 'VOMSRoles': [u'/testVO'], + 'ActiveRoles': [u'/testVO'], + 'SuspendedRoles': [] + } + }), + 'testVOOther': S_OK({ + '/User/test/DN/CN=userS': { + 'Suspended': False, + 'VOMSRoles': [u'/testVOOther'], + 'ActiveRoles': [u'/testVOOther'], + 'SuspendedRoles': [] + } + }) +} + testSystemsCFG = """ Systems { @@ -46,6 +76,7 @@ testVO { VOAdmin = userA + VOMSName = testVO } testVOBad { @@ -54,6 +85,7 @@ testVOOther { VOAdmin = userA + VOMSName = testVOOther } } Users @@ -91,6 +123,7 @@ { Users = userA, userS VO = testVO + VOMSRole = /testVO Properties = NormalUser } group_test_other @@ -110,9 +143,28 @@ """ +def sf_getVOMSInfo(*args, **kwargs): + return S_OK(voDict) + + +@mock.patch('DIRAC.ConfigurationSystem.Client.Helpers.Registry.gProxyManagerData.getActualVOMSesDNs', + new=sf_getVOMSInfo) class AuthManagerTest(unittest.TestCase): """ Base class for the Modules test cases """ + @classmethod + def setUpClass(cls): + if not os.path.exists(workDir): + os.makedirs(workDir) + for vo, infoDict in voDict.items(): + with open(os.path.join(workDir, vo + '.pkl'), 'wb+') as f: + pickle.dump(infoDict, f, pickle.HIGHEST_PROTOCOL) + + @classmethod + def tearDownClass(cls): + for vo in voDict.keys(): + if os.path.exists(os.path.join(workDir, vo + '.pkl')): + os.remove(os.path.join(workDir, vo + '.pkl')) def setUp(self): self.authMgr = AuthManager('/Systems/Service/Authorization') diff --git a/src/DIRAC/Core/Security/Locations.py b/src/DIRAC/Core/Security/Locations.py index fe9e575aca6..0cd8b7808c6 100644 --- a/src/DIRAC/Core/Security/Locations.py +++ b/src/DIRAC/Core/Security/Locations.py @@ -29,7 +29,19 @@ def getProxyLocation(): # No gridproxy found return False -# Retrieve CA's location + +def getPrivateKeyLocation(): + """ Get the path of the currently active private key(for auth) + """ + # Grid-Security + retVal = gConfig.getOption('%s/Grid-Security' % g_SecurityConfPath) + if retVal['OK']: + keyPath = "%s/private.pem" % retVal['Value'] + if os.path.isfile(keyPath): + return keyPath + + # No private key found + return False def getCAsLocation(): diff --git a/src/DIRAC/Core/Security/VOMSService.py b/src/DIRAC/Core/Security/VOMSService.py index f949b5853aa..667319c9b6c 100644 --- a/src/DIRAC/Core/Security/VOMSService.py +++ b/src/DIRAC/Core/Security/VOMSService.py @@ -49,7 +49,7 @@ def attGetUserNickname(self, dn, _ca=None): :param str dn: user DN :param str _ca: CA, kept for backward compatibility - :return: S_OK with Value: nickname + :return: S_OK with Value: nickname """ if self.userDict is None: @@ -65,16 +65,18 @@ def attGetUserNickname(self, dn, _ca=None): return S_ERROR(DErrno.EVOMS, "No nickname defined") return S_OK(nickname) - def getUsers(self): + def getUsers(self, proxyPath=None): """ Get all the users of the VOMS VO with their detailed information + :param str proxyPath: proxy path + :return: user dictionary keyed by the user DN """ if not self.urls: return S_ERROR(DErrno.ENOAUTH, "No VOMS server defined") - userProxy = getProxyLocation() + userProxy = proxyPath or getProxyLocation() caPath = getCAsLocation() rawUserList = [] result = None @@ -125,7 +127,6 @@ def getUsers(self): resultDict[dn] = user resultDict[dn]['CA'] = cert['issuerString'] resultDict[dn]['certSuspended'] = cert.get('suspended') - resultDict[dn]['suspended'] = user.get('suspended') resultDict[dn]['mail'] = user.get('emailAddress') resultDict[dn]['Roles'] = user.get('fqans') attributes = user.get('attributes') diff --git a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py index a8cfa822328..f7670438082 100644 --- a/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py +++ b/src/DIRAC/Core/Tornado/Client/private/TornadoBaseClient.py @@ -69,6 +69,7 @@ class TornadoBaseClient(object): KW_TIMEOUT = "timeout" KW_SETUP = "setup" KW_VO = "VO" + KW_DELEGATED_ID = "delegatedID" KW_DELEGATED_DN = "delegatedDN" KW_DELEGATED_GROUP = "delegatedGroup" KW_IGNORE_GATEWAYS = "ignoreGateways" @@ -100,6 +101,8 @@ def __init__(self, serviceName, **kwargs): raise TypeError("Service name expected to be a string. Received %s type %s" % (str(serviceName), type(serviceName))) + self.client = None + self._destinationSrv = serviceName self._serviceName = serviceName self.__ca_location = False @@ -220,6 +223,11 @@ def __discoverCredentialsToUse(self): self.kwargs[self.KW_SKIP_CA_CHECK] = False else: self.kwargs[self.KW_SKIP_CA_CHECK] = skipCACheck() + + if not self.__useCertificates: + if os.environ.get('DIRAC_TOKEN') and os.environ.get('DIRAC_TRY_USE_TOKEN'): + self.client = IdProviderFactory().getIdProviderForToken(os.environ['DIRAC_TOKEN']) + self.client.token = os.environ['DIRAC_TOKEN'] # Rewrite a little bit from here: don't need the proxy string, we use the file if self.KW_PROXY_CHAIN in self.kwargs: @@ -249,28 +257,28 @@ def __discoverExtraCredentials(self): WARNING: COPY/PASTE FROM Core/Diset/private/BaseClient """ # which extra credentials to use? - if self.__useCertificates: - self.__extraCredentials = self.VAL_EXTRA_CREDENTIALS_HOST - else: - self.__extraCredentials = "" + self.__extraCredentials = self.VAL_EXTRA_CREDENTIALS_HOST if self.__useCertificates else "" if self.KW_EXTRA_CREDENTIALS in self.kwargs: self.__extraCredentials = self.kwargs[self.KW_EXTRA_CREDENTIALS] + # Are we delegating something? - delegatedDN, delegatedGroup = self.__threadConfig.getID() - if self.KW_DELEGATED_DN in self.kwargs and self.kwargs[self.KW_DELEGATED_DN]: - delegatedDN = self.kwargs[self.KW_DELEGATED_DN] - elif delegatedDN: + delegatedID = self.kwargs.get(self.KW_DELEGATED_ID) or self.__threadConfig.getID() + delegatedDN = self.kwargs.get(self.KW_DELEGATED_DN) or self.__threadConfig.getDN() + delegatedGroup = self.kwargs.get(self.KW_DELEGATED_GROUP) or self.__threadConfig.getGroup() + + if delegatedID: + self.kwargs[self.KW_DELEGATED_ID] = delegatedID + if delegatedDN: self.kwargs[self.KW_DELEGATED_DN] = delegatedDN - if self.KW_DELEGATED_GROUP in self.kwargs and self.kwargs[self.KW_DELEGATED_GROUP]: - delegatedGroup = self.kwargs[self.KW_DELEGATED_GROUP] - elif delegatedGroup: + if delegatedGroup: self.kwargs[self.KW_DELEGATED_GROUP] = delegatedGroup - if delegatedDN: + + if delegatedID or delegatedDN: if not delegatedGroup: result = findDefaultGroupForDN(self.kwargs[self.KW_DELEGATED_DN]) if not result['OK']: return result - self.__extraCredentials = (delegatedDN, delegatedGroup) + self.__extraCredentials = (delegatedID or delegatedDN, delegatedGroup) return S_OK() def __discoverTimeout(self): @@ -500,16 +508,23 @@ def _request(self, retry=0, outputFile=None, **kwargs): gLogger.error("No CAs found!") return S_ERROR("No CAs found!") verify = self.__ca_location - + # getting certificate # Do we use the server certificate ? if self.kwargs[self.KW_USE_CERTIFICATES]: - cert = Locations.getHostCertificateAndKeyLocation() + auth = {'cert': Locations.getHostCertificateAndKeyLocation()} + + # Use access token? + elif os.environ.get('DIRAC_TOKEN') and os.environ.get('DIRAC_TRY_USE_TOKEN'): + # TODO: idp check and refresh tokens + self.client.fetch_access_token() + auth = {'headers': {"Authorization": "Bearer %s" % self.client.token['access_token']}} + # CHRIS 04.02.21 # TODO: add proxyLocation check ? else: - cert = Locations.getProxyLocation() - if not cert: + auth = {'cert': Locations.getProxyLocation()} + if not auth['cert']: gLogger.error("No proxy found") return S_ERROR("No proxy found") @@ -526,7 +541,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): if not outputFile: call = requests.post(url, data=kwargs, timeout=self.timeout, verify=verify, - cert=cert) + **auth) # raising the exception for status here # means essentialy that we are losing here the information of what is returned by the server # as error message, since it is not passed to the exception @@ -546,7 +561,7 @@ def _request(self, retry=0, outputFile=None, **kwargs): # Stream download # https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow with requests.post(url, data=kwargs, timeout=self.timeout, verify=verify, - cert=cert, stream=True) as r: + stream=True, **auth) as r: rawText = r.text r.raise_for_status() diff --git a/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py new file mode 100644 index 00000000000..2af957930d8 --- /dev/null +++ b/src/DIRAC/Core/Tornado/Server/BaseRequestHandler.py @@ -0,0 +1,827 @@ +""" BaseRequestHandler is the base class for tornados services and etc handlers. + It directly inherits from :py:class:`tornado.web.RequestHandler` +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +from io import open + +import jwt +# from jwt import PyJWKClient + +import os +import time +import threading +from six import string_types +from datetime import datetime +from six.moves import http_client +from six.moves.urllib.parse import unquote + +import tornado +from tornado import gen +from tornado.web import RequestHandler, HTTPError +from tornado.ioloop import IOLoop +from tornado.httpclient import HTTPResponse +from tornado.concurrent import Future + +import DIRAC + +from DIRAC import gConfig, gLogger, S_OK, S_ERROR +from DIRAC.Core.DISET.AuthManager import AuthManager +from DIRAC.Core.Utilities.JEncode import decode, encode +from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error +from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +sLog = gLogger.getSubLogger(__name__.split('.')[-1]) + + +class BaseRequestHandler(RequestHandler): + # Because we initialize at first request, we use a flag to know if it's already done + __init_done = False + # Lock to make sure that two threads are not initializing at the same time + __init_lock = threading.RLock() + + # MonitoringClient, we don't use gMonitor which is not thread-safe + # We also need to add specific attributes for each service + _monitor = None + + # System name with which this component is associated + SYSTEM = None + + # If need to return HTTP error instead DIRAC error, e.g.: in REST endpoints + RAISE_DIRAC_ERROR = False + + # Auth requirements + AUTH_PROPS = None + + # Type of component + MONITORING_COMPONENT = MonitoringClient.COMPONENT_WEB + + # Prefix of methods names + METHOD_PREFIX = "export_" + + # Which grant type to use + USE_AUTHZ_GRANTS = ['SSL', 'JWT'] + + @classmethod + def _initMonitoring(cls, serviceName, fullUrl): + """ + Initialize the monitoring specific to this handler + This has to be called only by :py:meth:`.__initializeService` + to ensure thread safety and unicity of the call. + + :param serviceName: relative URL ``//`` + :param fullUrl: full URl like ``https://://`` + """ + + # Init extra bits of monitoring + + cls._monitor = MonitoringClient() + cls._monitor.setComponentType(cls.MONITORING_COMPONENT) + + cls._monitor.initialize() + + if tornado.process.task_id() is None: # Single process mode + cls._monitor.setComponentName('Tornado/%s' % serviceName) + else: + cls._monitor.setComponentName('Tornado/CPU%d/%s' % (tornado.process.task_id(), serviceName)) + + cls._monitor.setComponentLocation(fullUrl) + + cls._monitor.registerActivity("Queries", "Queries served", "Framework", "queries", MonitoringClient.OP_RATE) + + cls._monitor.setComponentExtraParam('DIRACVersion', DIRAC.version) + cls._monitor.setComponentExtraParam('platform', DIRAC.getPlatform()) + cls._monitor.setComponentExtraParam('startTime', datetime.utcnow()) + + cls._stats = {'requests': 0, 'monitorLastStatsUpdate': time.time()} + + return S_OK() + + @classmethod + def _getServiceName(cls, request): + """ Search service name in request. + + :param object request: tornado Request + + :return: str + """ + raise NotImplementedError('Please, create the _getServiceName class method') + + @classmethod + def _getServiceAuthSection(cls, serviceName): + """ Search service auth section. + + :param str serviceName: service name + + :return: str + """ + return "%s/Authorization" % PathFinder.getServiceSection(serviceName) + + @classmethod + def _getServiceInfo(cls, serviceName, request): + """ Fill service information. + + :param str serviceName: service name + :param object request: tornado Request + + :return: dict + """ + return {} + + @classmethod + def __initializeService(cls, request): + """ + Initialize a service. + The work is only perform once at the first request. + + :param object request: tornado Request + + :returns: S_OK + """ + # If the initialization was already done successfuly, + # we can just return + if cls.__init_done: + return S_OK() + + # Otherwise, do the work but with a lock + with cls.__init_lock: + + # Check again that the initialization was not done by another thread + # while we were waiting for the lock + if cls.__init_done: + return S_OK() + + cls._idps = {} + + # Set Identity Providers + idps = IdProviderFactory() + result = getProvidersForInstance('Id') + if result['OK']: + for providerName in result['Value']: + result = idps.getIdProvider(providerName) + if not result['OK']: + break + idpObj = result['Value'] + cls._idps[idpObj.metadata['issuer'].strip('/')] = idpObj + if not result['OK']: + raise Exception("There was a problem loading Identity Providers: %s" % result['Message']) + + + # absoluteUrl: full URL e.g. ``https://://`` + absoluteUrl = request.path + serviceName = cls._getServiceName(request) + + cls._startTime = datetime.utcnow() + sLog.info("First use of %s, initializing service..." % serviceName) + cls._authManager = AuthManager(cls._getServiceAuthSection(serviceName)) + + cls._initMonitoring(serviceName, absoluteUrl) + + cls._serviceName = serviceName + cls._validNames = [serviceName] + serviceInfo = cls._getServiceInfo(serviceName, request) + + cls._serviceInfoDict = serviceInfo + + cls.__monitorLastStatsUpdate = time.time() + + # Some pre-initialization + cls._initializeHandler() + + cls.initializeHandler(serviceInfo) + + cls.__init_done = True + + return S_OK() + + @classmethod + def _initializeHandler(cls): + """ + If you are writing your own framework that follows this class + and you need to add something before initializing the service, + such as initializing the OAuth client, then you need to change this method. + """ + pass + + @classmethod + def initializeHandler(cls, serviceInfo): + """ + This may be overwritten when you write a DIRAC service handler + And it must be a class method. This method is called only one time, + at the first request + + :param dict ServiceInfoDict: infos about services, it contains + 'serviceName', 'serviceSectionPath', + 'csPaths' and 'URL' + """ + pass + + def initializeRequest(self): + """ + Called at every request, may be overwritten in your handler. + """ + pass + + # This is a Tornado magic method + def initialize(self): # pylint: disable=arguments-differ + """ + Initialize the handler, called at every request. + + It just calls :py:meth:`.__initializeService` + + If anything goes wrong, the client will get ``Connection aborted`` + error. See details inside the method. + + ..warning:: + DO NOT REWRITE THIS FUNCTION IN YOUR HANDLER + ==> initialize in DISET became initializeRequest in HTTPS ! + """ + # Only initialized once + if not self.__init_done: + # Ideally, if something goes wrong, we would like to return a Server Error 500 + # but this method cannot write back to the client as per the + # `tornado doc `_. + # So the client will get a ``Connection aborted``` + try: + res = self.__initializeService(self.request) + if not res['OK']: + raise Exception(res['Message']) + except Exception as e: + sLog.error("Error in initialization", repr(e)) + raise + + def _monitorRequest(self): + """ Monitor action for each request + """ + self._stats['requests'] += 1 + self._monitor.setComponentExtraParam('queries', self._stats['requests']) + self._monitor.addMark("Queries") + + def _getMethodName(self): + """ Parse method name. + + :return: str + """ + raise NotImplementedError('Please, create the _getMethodName method') + + def _getMethodArgs(self, args): + """ Decode args. + + :return: list + """ + return args + + def _getMethodAuthProps(self): + """ Resolves the hard coded authorization requirements for method. + + :return: object + """ + try: + return getattr(self, 'auth_' + self.method) + except AttributeError: + if not isinstance(self.AUTH_PROPS, (list, tuple)): + self.AUTH_PROPS = [p.strip() for p in self.AUTH_PROPS.split(",") if p.strip()] + return self.AUTH_PROPS + + def _getMethod(self): + """ Get method object. + + :return: object + """ + try: + return getattr(self, '%s%s' % (self.METHOD_PREFIX, self.method)) + except AttributeError as e: + sLog.error("Invalid method", self.method) + raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) + + def prepare(self): + """ + Tornados prepare method that called before request + """ + + # "method" argument of the POST call. + # This resolves into the ``export_`` method + # on the handler side + # If the argument is not available, the method exists + # and an error 400 ``Bad Request`` is returned to the client + self.method = self._getMethodName() + + self._monitorRequest() + + def _prepare(self): + """ + Prepare the request. It reads certificates and check authorizations. + We make the assumption that there is always going to be a ``method`` argument + regardless of the HTTP method used + + """ + + try: + self.credDict = self._gatherPeerCredentials() + except Exception as e: # pylint: disable=broad-except + # If an error occur when reading certificates we close connection + # It can be strange but the RFC, for HTTP, say's that when error happend + # before authentication we return 401 UNAUTHORIZED instead of 403 FORBIDDEN + sLog.debug(str(e)) + sLog.error( + "Error gathering credentials ", "%s; path %s" % + (self.getRemoteAddress(), self.request.path)) + raise HTTPError(status_code=http_client.UNAUTHORIZED) + + # Check whether we are authorized to perform the query + # Note that performing the authQuery modifies the credDict... + authorized = self._authManager.authQuery(self.method, self.credDict, + self._getMethodAuthProps()) + if not authorized: + extraInfo = '' + if self.credDict.get('DN'): + extraInfo += 'DN: %s' % self.credDict['DN'] + if self.credDict.get('ID'): + extraInfo += 'ID: %s' % self.credDict['ID'] + sLog.error( + "Unauthorized access", "Identity %s; path %s; %s" % + (self.srv_getFormattedRemoteCredentials(), + self.request.path, extraInfo)) + raise HTTPError(status_code=http_client.UNAUTHORIZED) + + # Make post a coroutine. + # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines + # for details + @gen.coroutine + def post(self, *args, **kwargs): # pylint: disable=arguments-differ + """ + Method to handle incoming ``POST`` requests. + Note that all the arguments are already prepared in the :py:meth:`.prepare` + method. + + The ``POST`` arguments expected are: + + * ``method``: name of the method to call + * ``args``: JSON encoded arguments for the method + * ``extraCredentials``: (optional) Extra informations to authenticate client + * ``rawContent``: (optionnal, default False) If set to True, return the raw output + of the method called. + + If ``rawContent`` was requested by the client, the ``Content-Type`` + is ``application/octet-stream``, otherwise we set it to ``application/json`` + and JEncode retVal. + + If ``retVal`` is a dictionary that contains a ``Callstack`` item, + it is removed, not to leak internal information. + + + Example of call using ``requests``:: + + In [20]: url = 'https://server:8443/DataManagement/TornadoFileCatalog' + ...: cert = '/tmp/x509up_u1000' + ...: kwargs = {'method':'whoami'} + ...: caPath = '/home/dirac/ClientInstallDIR/etc/grid-security/certificates/' + ...: with requests.post(url, data=kwargs, cert=cert, verify=caPath) as r: + ...: print r.json() + ...: + {u'OK': True, + u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'group': u'dirac_user', + u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'isLimitedProxy': False, + u'isProxy': True, + u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch', + u'properties': [u'NormalUser'], + u'secondsLeft': 85441, + u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch/CN=2409820262', + u'username': u'adminusername', + u'validDN': False, + u'validGroup': False}} + """ + # Execute the method in an executor (basically a separate thread) + # Because of that, we cannot calls certain methods like `self.write` + # in _executeMethod. This is because these methods are not threadsafe + # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + # However, we can still rely on instance attributes to store what should + # be sent back (reminder: there is an instance + # of this class created for each request) + retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + + # retVal is :py:class:`tornado.concurrent.Future` + self._finishFuture(retVal) + + @gen.coroutine + def _executeMethod(self, args): + """ + Execute the method called, this method is ran in an executor + We have several try except to catch the different problem which can occur + + - First, the method does not exist => Attribute error, return an error to client + - second, anything happend during execution => General Exception, send error to client + + .. warning:: + This method is called in an executor, and so cannot use methods like self.write + See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes + """ + self._prepare() + + sLog.notice( + "Incoming request %s /%s: %s" % + (self.srv_getFormattedRemoteCredentials(), + self._serviceName, + self.method)) + + # getting method + method = self._getMethod() + methodArgs = self._getMethodArgs(args) + + # Execute + try: + self.initializeRequest() + retVal = method(*methodArgs) + except Exception as e: # pylint: disable=broad-except + sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) + raise HTTPError(http_client.INTERNAL_SERVER_ERROR) + + return retVal + + def _finishFuture(self, retVal): + """ Handler Future result + + :param object retVal: tornado.concurrent.Future + """ + + # Wait result only if it's a Future object + self.result = retVal.result() if isinstance(retVal, Future) else retVal + + # Here it is safe to write back to the client, because we are not + # in a thread anymore + + # Is it S_OK or S_ERROR + if isinstance(self.result, dict) and isinstance(self.result.get('OK'), bool) and ('Value' if self.result['OK'] else 'Message') in self.result: + self._parseDIRACResult(self.result) + + # If set to true, do not JEncode the return of the RPC call + # This is basically only used for file download through + # the 'streamToClient' method. + elif self.get_argument('rawContent', default=False): + # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt + self.set_header("Content-Type", "application/octet-stream") + self.write(self.result) + + # Return simple text or html + elif isinstance(self.result, string_types): + self.write(self.result) + + # JSON + else: + # from authlib.consts import default_json_headers + self.set_header("Content-Type", "application/json") + self.write(encode(self.result)) + + self.finish() + + def _parseDIRACResult(self, result): + """ Processing of a standard DIRAC result, + but in a separate method so that it can be modified for another class if necessary + """ + self.set_header("Content-Type", "application/json") + self.write(encode(result)) + + # # Parse HTTPResponse object, e.g.: in AuthHandler endpoint + # if isinstance(self.result, HTTPResponse): + # self.set_status(self.result.code) + # for key in self.result.headers: + # self.set_header(key, self.result.headers[key]) + # if self.result.body: + # self.write(self.result.body) + + # # If set to true, do not JEncode the return of the RPC call + # # This is basically only used for file download through + # # the 'streamToClient' method. + # elif self.get_argument('rawContent', default=False): + # # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt + # self.set_header("Content-Type", "application/octet-stream") + # self.write(self.result) + + # # Return simple text or html + # elif isinstance(self.result, str): + # self.write(self.result) + + # # DIRAC JSON + # else: + # # Convert DIRAC error to HTTP error and return only result 'Value' + # if self.RAISE_DIRAC_ERROR: + # if not self.result['OK']: + # sLog.error(self.result['Message']) + # raise HTTPError(http_client.INTERNAL_SERVER_ERROR, self.result['Message']) + # self.result = self.result['Value'] + # if isinstance(self.result, str): + # self.write(self.result) + # else: + # self.set_header("Content-Type", "application/json") + # self.write(encode(self.result)) + + # self.finish() + + def on_finish(self): + """ + Called after the end of HTTP request. + Log the request duration + """ + elapsedTime = 1000.0 * self.request.request_time() + + argsString = "OK" + try: + if not self.result['OK']: + argsString = "ERROR: %s" % self.result['Message'] + except (AttributeError, KeyError, TypeError): # In case it is not a DIRAC structure + if self._reason != 'OK': + argsString = 'ERROR %s' % self._reason + + sLog.notice("Returning response", "%s %s (%.2f ms) %s" % (self.srv_getFormattedRemoteCredentials(), + self._serviceName, + elapsedTime, argsString)) + + def _gatherPeerCredentials(self, grants=None): + """ Returne a dictionary designed to work with the AuthManager, + already written for DISET and re-used for HTTPS. + + :param list grants: grants to use + + :returns: a dict containing the return of :py:meth:`DIRAC.Core.Security.X509Chain.X509Chain.getCredentials` + (not a DIRAC structure !) + """ + err = [] + result = None + + grants = grants or self.USE_AUTHZ_GRANTS + + if not grants: + raise Exception('USE_AUTHZ_GRANTS is not defined.') + + for a in grants: + grant = a.upper() + try: + result = eval('self._authz%s' % grant)() + except AttributeError: + raise Exception('%s authentication type is not supported.' % grant) + + if result['OK']: + for e in err: + sLog.debug(e) + sLog.debug('%s authentication success.' % grant) + return result['Value'] + err.append('%s authentication: %s' % (grant, result['Message'])) + + # Report on failed authentication attempts + raise Exception('; '.join(err)) + + def _authzSSL(self): + """ Load client certchain in DIRAC and extract informations. + + :return: S_OK(dict)/S_ERROR() + """ + peerChain = X509Chain() + derCert = self.request.get_ssl_certificate() + + # Get client certificate pem + if derCert: + chainAsText = derCert.as_pem() + # Here we read all certificate chain + cert_chain = self.request.get_ssl_certificate_chain() + for cert in cert_chain: + chainAsText = cert.as_pem() + elif self.request.headers.get('X-Ssl_client_verify') == 'SUCCESS': + chainAsTextEncoded = self.request.headers.get('X-SSL-CERT') + chainAsText = unquote(chainAsTextEncoded) + else: + return S_ERROR('Not found a valide client certificate.') + + peerChain.loadChainFromString(chainAsText) + + # Retrieve the credentials + res = peerChain.getCredentials(withRegistryInfo=False) + if not res['OK']: + return res + + credDict = res['Value'] + + # We check if client sends extra credentials... + if "extraCredentials" in self.request.arguments: + extraCred = self.get_argument("extraCredentials") + if extraCred: + credDict['extraCredentials'] = decode(extraCred)[0] + return S_OK(credDict) + + def _authzJWT(self): + """ Load token claims in DIRAC and extract informations. + + :return: S_OK(dict)/S_ERROR() + """ + # Export token + token = self.request.headers.get('Authorization') + if not token or len(token.split()) != 2: + return S_ERROR('Not found a bearer access token.') + tokenType, accessToken = token.split() + if tokenType.lower() != 'bearer': + return S_ERROR('Found a not bearer access token.') + + # Read token without verification to get issuer + issuer = jwt.decode(accessToken, options=dict(verify_signature=False))['iss'].strip('/') + + if not self._idps.get(issuer): + return S_ERROR('%s issuer not registred in DIRAC.' % issuer) + + payload = self._idps[issuer].verifyToken(accessToken) + + # {'ID':.., 'group':.., 'provider':..} + credDict = self._idps[issuer].researchGroup(payload, accessToken) + credDict['token'] = accessToken + return S_OK(credDict) + + def _authzVISITOR(self): + """ Visitor access + + :return: S_OK(dict) + """ + return S_OK({}) + + @property + def log(self): + return sLog + + def getDN(self): + return self.credDict.get('DN', '') + + def getID(self): + return self.credDict.get('ID', '') + + def getUserName(self): + return self.credDict.get('username', '') + + def getUserGroup(self): + return self.credDict.get('group', '') + + def getProperties(self): + return self.credDict.get('properties', []) + + def isRegisteredUser(self): + return self.credDict.get('username', 'anonymous') != 'anonymous' and self.credDict.get('group') + + auth_ping = ['all'] + + def export_ping(self): + """ + Default ping method, returns some info about server. + + It returns the exact same information as DISET, for transparency purpose. + """ + # COPY FROM DIRAC.Core.DISET.RequestHandler + dInfo = {} + dInfo['version'] = DIRAC.version + dInfo['time'] = datetime.utcnow() + # Uptime + try: + with open("/proc/uptime", 'rt') as oFD: + iUptime = int(float(oFD.readline().split()[0].strip())) + dInfo['host uptime'] = iUptime + except Exception: # pylint: disable=broad-except + pass + startTime = self._startTime + dInfo['service start time'] = self._startTime + serviceUptime = datetime.utcnow() - startTime + dInfo['service uptime'] = serviceUptime.days * 3600 + serviceUptime.seconds + # Load average + try: + with open("/proc/loadavg", 'rt') as oFD: + dInfo['load'] = " ".join(oFD.read().split()[:3]) + except Exception: # pylint: disable=broad-except + pass + dInfo['name'] = self._serviceInfoDict['serviceName'] + stTimes = os.times() + dInfo['cpu times'] = {'user time': stTimes[0], + 'system time': stTimes[1], + 'children user time': stTimes[2], + 'children system time': stTimes[3], + 'elapsed real time': stTimes[4] + } + + return S_OK(dInfo) + + auth_echo = ['all'] + + @staticmethod + def export_echo(data): + """ + This method used for testing the performance of a service + """ + return S_OK(data) + + auth_whoami = ['authenticated'] + + def export_whoami(self): + """ + A simple whoami, returns all credential dictionary, except certificate chain object. + """ + credDict = self.srv_getRemoteCredentials() + if 'x509Chain' in credDict: + # Not serializable + del credDict['x509Chain'] + return S_OK(credDict) + + @classmethod + def srv_getCSOption(cls, optionName, defaultValue=False): + """ + Get an option from the CS section of the services + + :return: Value for serviceSection/optionName in the CS being defaultValue the default + """ + if optionName[0] == "/": + return gConfig.getValue(optionName, defaultValue) + for csPath in cls._serviceInfoDict['csPaths']: + result = gConfig.getOption("%s/%s" % (csPath, optionName, ), defaultValue) + if result['OK']: + return result['Value'] + return defaultValue + + def getCSOption(self, optionName, defaultValue=False): + """ + Just for keeping same public interface + """ + return self.srv_getCSOption(optionName, defaultValue) + + def srv_getRemoteAddress(self): + """ + Get the address of the remote peer. + + :return: Address of remote peer. + """ + + remote_ip = self.request.remote_ip + # Although it would be trivial to add this attribute in _HTTPRequestContext, + # Tornado won't release anymore 5.1 series, so go the hacky way + try: + remote_port = self.request.connection.stream.socket.getpeername()[1] + except Exception: # pylint: disable=broad-except + remote_port = 0 + + return (remote_ip, remote_port) + + def getRemoteAddress(self): + """ + Just for keeping same public interface + """ + return self.srv_getRemoteAddress() + + def srv_getRemoteCredentials(self): + """ + Get the credentials of the remote peer. + + :return: Credentials dictionary of remote peer. + """ + return self.credDict + + def getRemoteCredentials(self): + """ + Get the credentials of the remote peer. + + :return: Credentials dictionary of remote peer. + """ + return self.credDict + + def srv_getFormattedRemoteCredentials(self): + """ + Return the DN of user + + Mostly copy paste from + :py:meth:`DIRAC.Core.DISET.private.Transports.BaseTransport.BaseTransport.getFormattedCredentials` + + Note that the information will be complete only once the AuthManager was called + """ + address = self.getRemoteAddress() + peerId = "" + # Depending on where this is call, it may be that credDict is not yet filled. + # (reminder: AuthQuery fills part of it..) + try: + peerId = "[%s:%s]" % (self.credDict['group'], self.credDict['username']) + except AttributeError: + pass + + if address[0].find(":") > -1: + return "([%s]:%s)%s" % (address[0], address[1], peerId) + return "(%s:%s)%s" % (address[0], address[1], peerId) + + def srv_getServiceName(self): + """ + Return the service name + """ + return self._serviceInfoDict['serviceName'] + + def srv_getURL(self): + """ + Return the URL + """ + return self.request.path diff --git a/src/DIRAC/Core/Tornado/Server/HandlerManager.py b/src/DIRAC/Core/Tornado/Server/HandlerManager.py index 1f4397f63e9..b1b9cca2be0 100644 --- a/src/DIRAC/Core/Tornado/Server/HandlerManager.py +++ b/src/DIRAC/Core/Tornado/Server/HandlerManager.py @@ -9,6 +9,8 @@ __RCSID__ = "$Id$" +from six import string_types +import inspect from tornado.web import url as TornadoURL, RequestHandler from DIRAC import gConfig, gLogger, S_ERROR, S_OK @@ -50,109 +52,219 @@ class HandlerManager(object): ``System/Component`` (e.g. ``DataManagement/FileCatalog``) """ - def __init__(self, autoDiscovery=True): + def __init__(self, services, endpoints): """ - Initialization function, you can set autoDiscovery=False to prevent automatic - discovery of handler. If disabled you can use loadHandlersByServiceName() to - load your handlers or loadHandlerInHandlerManager() - - :param autoDiscovery: (default True) Disable the automatic discovery, - can be used to choose service we want to load. + Initialization function, you can set False for both arguments to prevent automatic + discovery of handlers and use `loadServicesHandlers()` to + load your handlers or `loadEndpointsHandlers()` + + :param services: List of service handlers to load. + If ``True``, loads all services from CS + :type services: bool or list + :param endpoints: List of endpoint handlers to load. + If ``True``, loads all endpoints from CS + :type endpoints: bool or list """ + self.loader = None self.__handlers = {} + self.__services = services + self.__endpoints = endpoints self.__objectLoader = ObjectLoader() - self.__autoDiscovery = autoDiscovery - self.loader = ModuleLoader("Service", PathFinder.getServiceSection, RequestHandler, moduleSuffix="Handler") - def __addHandler(self, handlerTuple, url=None): + def __addHandler(self, handlerPath, handler, urls=None, port=None): """ Function which add handler to list of known handlers + :param str handlerPath: module name, e.g.: `Framework/Auth` + :param object handler: handler class + :param list urls: request path + :param int port: port - :param handlerTuple: (path, class) + :return: S_OK()/S_ERROR() """ - # Check if handler not already loaded - if not url or url not in self.__handlers: - gLogger.debug("Find new handler %s" % (handlerTuple[0])) - - # If url is not given, try to discover it - if url is None: - # FIRST TRY: Url is hardcoded - try: - url = handlerTuple[1].LOCATION - # SECOND TRY: URL can be deduced from path - except AttributeError: - gLogger.debug("No location defined for %s try to get it from path" % handlerTuple[0]) - url = urlFinder(handlerTuple[0]) - + # First of all check if we can find route + # If urls is not given, try to discover it + if urls is None: + # FIRST TRY: Url is hardcoded + try: + urls = handler.LOCATION + # SECOND TRY: URL can be deduced from path + except AttributeError: + gLogger.debug("No location defined for %s try to get it from path" % handlerPath) + urls = urlFinder(handlerPath) + + if not urls: + gLogger.warn("URL not found for %s" % (handlerPath)) + return S_ERROR("URL not found for %s" % (handlerPath)) + + for url in urls if isinstance(urls, (list, tuple)) else [urls]: # We add "/" if missing at begin, e.g. we found "Framework/Service" # URL can't be relative in Tornado if url and not url.startswith('/'): url = "/%s" % url - elif not url: - gLogger.warn("URL not found for %s" % (handlerTuple[0])) - return S_ERROR("URL not found for %s" % (handlerTuple[0])) + + # Some new handler + if handlerPath not in self.__handlers: + gLogger.debug("Add new handler %s with port %s" % (handlerPath, port)) + self.__handlers[handlerPath] = {'URLs': [], 'Port': port} + + # Check if URL already loaded + if (url, handler) in self.__handlers[handlerPath]['URLs']: + gLogger.debug("URL: %s already loaded for %s " % (url, handlerPath)) + continue # Finally add the URL to handlers - if url not in self.__handlers: - self.__handlers[url] = handlerTuple[1] - gLogger.info("New handler: %s with URL %s" % (handlerTuple[0], url)) - else: - gLogger.debug("Handler already loaded %s" % (handlerTuple[0])) + gLogger.info("Add new URL %s to %s handler" % (url, handlerPath)) + self.__handlers[handlerPath]['URLs'].append((url, handler)) + return S_OK() - def discoverHandlers(self): + def discoverHandlers(self, handlerInstance): """ Force the discovery of URL, automatic call when we try to get handlers for the first time. You can disable the automatic call with autoDiscovery=False at initialization + + :param str handlerInstance: handler instance, the name of the section in some system section e.g.:: Services, APIs + + :return: list """ - gLogger.debug("Trying to auto-discover the handlers for Tornado") + urls = [] + gLogger.debug("Trying to auto-discover the %s handlers for Tornado" % handlerInstance) # Look in config diracSystems = gConfig.getSections('/Systems') - serviceList = [] if diracSystems['OK']: for system in diracSystems['Value']: try: - instance = PathFinder.getSystemInstance(system) - services = gConfig.getSections('/Systems/%s/%s/Services' % (system, instance)) - if services['OK']: - for service in services['Value']: - newservice = ("%s/%s" % (system, service)) - - # We search in the CS all handlers which used HTTPS as protocol - isHTTPS = gConfig.getValue('/Systems/%s/%s/Services/%s/Protocol' % (system, instance, service)) - if isHTTPS and isHTTPS.lower() == 'https': - serviceList.append(newservice) + sysInstance = PathFinder.getSystemInstance(system) + result = gConfig.getSections('/Systems/%s/%s/%s' % (system, sysInstance, handlerInstance)) + if result['OK']: + for inst in result['Value']: + newInst = ("%s/%s" % (system, inst)) + + if handlerInstance == 'Services': + # We search in the CS all handlers which used HTTPS as protocol + isHTTPS = gConfig.getValue('/Systems/%s/%s/Services/%s/Protocol' % (system, sysInstance, inst)) + if isHTTPS and isHTTPS.lower() == 'https': + urls.append(newInst) + else: + port = gConfig.getValue('/Systems/%s/%s/Services/%s/Port' % (system, sysInstance, inst)) + if port: + newInst += ':%s' % port + urls.append(newInst) # On systems sometime you have things not related to services... except RuntimeError: pass - return self.loadHandlersByServiceName(serviceList) + return urls - def loadHandlersByServiceName(self, servicesNames): + def loadServicesHandlers(self, services=None): """ Load a list of handler from list of service using DIRAC moduleLoader Use :py:class:`DIRAC.Core.Base.private.ModuleLoader` - :param servicesNames: list of service, e.g. ['Framework/Hello', 'Configuration/Server'] + :param services: List of service handlers to load. Default value set at initialization + If ``True``, loads all services from CS + :type services: bool or list + + :return: S_OK()/S_ERROR() """ + # list of services, e.g. ['Framework/Hello', 'Configuration/Server'] + if isinstance(services, string_types): + services = [services] + # list of services + self.__services = self.__services if services is None else services if services else [] + + if self.__services is True: + self.__services = self.discoverHandlers('Services') + + if self.__services: + self.loader = ModuleLoader("Service", PathFinder.getServiceSection, RequestHandler, moduleSuffix="Handler") + + # Use DIRAC system to load: search in CS if path is given and if not defined + # it search in place it should be (e.g. in DIRAC/FrameworkSystem/Service) + load = self.loader.loadModules(self.__services) + if not load['OK']: + return load + for module in self.loader.getModules().values(): + url = module['loadName'] + + # URL can be like https://domain:port/service/name or just service/name + # Here we just want the service name, for tornado + serviceTuple = url.replace('https://', '').split('/')[-2:] + url = "%s/%s" % (serviceTuple[0], serviceTuple[1]) + self.__addHandler(module['loadName'], module['classObj'], url) + return S_OK() + + def __extractPorts(self, urls): + """ Extract ports from urls - # Use DIRAC system to load: search in CS if path is given and if not defined - # it search in place it should be (e.g. in DIRAC/FrameworkSystem/Service) - if not isinstance(servicesNames, list): - servicesNames = [servicesNames] + :param list urls: urls that can contain port, .e.g:: System/Service:port - load = self.loader.loadModules(servicesNames) - if not load['OK']: - return load - for module in self.loader.getModules().values(): - url = module['loadName'] + :return: (dict, list) + """ + portMapping = {} + newURLs = [] + for _url in urls: + if ':' in _url: + urlTuple = _url.split(':') + if urlTuple[0] not in portMapping: + portMapping[urlTuple[0]] = urlTuple[1] + newURLs.append(urlTuple[0]) + else: + newURLs.append(_url) + return (portMapping, newURLs) + + def loadEndpointsHandlers(self, endpoints=None): + """ + Load a list of handler from list of endpoints using DIRAC moduleLoader + Use :py:class:`DIRAC.Core.Base.private.ModuleLoader` - # URL can be like https://domain:port/service/name or just service/name - # Here we just want the service name, for tornado - serviceTuple = url.replace('https://', '').split('/')[-2:] - url = "%s/%s" % (serviceTuple[0], serviceTuple[1]) - self.__addHandler((module['loadName'], module['classObj']), url) + :param endpoints: List of endpoint handlers to load. Default value set at initialization + If ``True``, loads all endpoints from CS + :type endpoints: bool or list + + :return: S_OK()/S_ERROR() + """ + # list of endpoints, e.g. ['Framework/Hello', 'Configuration/Conf'] + if isinstance(endpoints, string_types): + endpoints = [endpoints] + # list of endpoints. If __endpoints is ``True`` then list of endpoints will dicover from CS + self.__endpoints = self.__endpoints if endpoints is None else endpoints if endpoints else [] + + if self.__endpoints is True: + self.__endpoints = self.discoverHandlers('APIs') + + if self.__endpoints: + # Extract ports + ports, self.__endpoints = self.__extractPorts(self.__endpoints) + + self.loader = ModuleLoader("API", PathFinder.getAPISection, RequestHandler, moduleSuffix="Handler") + + # Use DIRAC system to load: search in CS if path is given and if not defined + # it search in place it should be (e.g. in DIRAC/FrameworkSystem/API) + load = self.loader.loadModules(self.__endpoints) + if not load['OK']: + return load + for module in self.loader.getModules().values(): + handler = module['classObj'] + if not handler.LOCATION: + handler.LOCATION = urlFinder(module['loadName']) + urls = [] + # Look for methods that are exported + for mName, mObj in inspect.getmembers(handler): + if inspect.ismethod(mObj) and mName.find(handler.METHOD_PREFIX) == 0: + methodName = mName[len(handler.METHOD_PREFIX):] + args = getattr(handler, 'path_%s' % methodName, []) + # argObj = inspect.getargspec(mObj) + # args = '/'.join(['([A-z0-9_.-]+)'] * (len(argObj.args) - 1 - len(argObj.defaults or []))) + # defs = '/'.join(['([A-z0-9_.-]*)'] * len(argObj.defaults or [])) + gLogger.debug(" - Route %s/%s -> %s %s" % (handler.LOCATION, methodName, module['loadName'], mName)) + url = "%s%s" % (handler.LOCATION, '' if methodName == 'index' else ('/%s' % methodName)) + if args: + url += r'[\/]?%s' % '/'.join(args) + urls.append(url) + gLogger.debug(" * %s" % url) + self.__addHandler(module['loadName'], handler, urls, ports.get(module['modName'])) return S_OK() def getHandlersURLs(self): @@ -163,24 +275,32 @@ def getHandlersURLs(self): :returns: a list of URL (not the string with "https://..." but the tornado object) see http://www.tornadoweb.org/en/stable/web.html#tornado.web.URLSpec """ - if not self.__handlers and self.__autoDiscovery: - self.__autoDiscovery = False - self.discoverHandlers() urls = [] - for key in self.__handlers: - urls.append(TornadoURL(key, self.__handlers[key])) + for handlerData in self.__handlers.values(): + for url in handlerData['URLs']: + urls.append(TornadoURL(*url)) return urls def getHandlersDict(self): """ Return all handler dictionary - :returns: dictionary with absolute url as key ("/System/Service") - and tornado.web.url object as value + :returns: dictionary with service name as key ("System/Service") + and tornado.web.url objects as value for 'URLs' key + and port as value for 'Port' key """ - if not self.__handlers and self.__autoDiscovery: - self.__autoDiscovery = False - res = self.discoverHandlers() - if not res['OK']: - gLogger.error("Could not load handlers", res) return self.__handlers + + def getRoutes(self, defaultPort): + """ Get routes for each of the registred port + + :return: list -- contain tuples with port and list of tornado urls + """ + routes = {} + for data in self.__handlers.values(): + port = data.get('Port') + settings = data.get('Settings') + if port not in routes: + routes[port] = [] + routes[port] = list(set(routes[port]) + set(data['URLs'])) + return list(routes.items()) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoREST.py b/src/DIRAC/Core/Tornado/Server/TornadoREST.py new file mode 100644 index 00000000000..1475d2446bd --- /dev/null +++ b/src/DIRAC/Core/Tornado/Server/TornadoREST.py @@ -0,0 +1,77 @@ +""" +TornadoService is the base class for your handlers. +It directly inherits from :py:class:`tornado.web.RequestHandler` +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import tornado.ioloop +from tornado import gen +from tornado.web import HTTPError +from tornado.ioloop import IOLoop +from six.moves import http_client + +import DIRAC + +from DIRAC import gLogger +from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC.Core.Tornado.Server.BaseRequestHandler import BaseRequestHandler + +sLog = gLogger.getSubLogger(__name__) + + +class TornadoREST(BaseRequestHandler): # pylint: disable=abstract-method + USE_AUTHZ_GRANTS = ['SSL', 'JWT', 'VISITOR'] + METHOD_PREFIX = 'web_' + LOCATION = '/' + + @classmethod + def _getServiceName(cls, request): + """ Define endpoint full name + + :param object request: tornado Request + + :return: str + """ + if not cls.SYSTEM: + raise Exception("System name must be defined.") + return "/".join([cls.SYSTEM, cls.__name__]) + + @classmethod + def _getServiceAuthSection(cls, endpointName): + """ Search endpoint auth section. + + :param str endpointName: endpoint name + + :return: str + """ + return "%s/Authorization" % PathFinder.getAPISection(endpointName) + + def _getMethodName(self): + """ Parse method name. + + :return: str + """ + method = self.request.path.replace(self.LOCATION, '').strip('/').split('/')[0] + if method and hasattr(self, ''.join([self.METHOD_PREFIX, method])): + return method + elif hasattr(self, '%sindex' % self.METHOD_PREFIX): + return 'index' + else: + raise NotImplementedError('%s method not implemented. \ + You can use the index method to handle this.' % method) + + @gen.coroutine + def get(self, *args, **kwargs): # pylint: disable=arguments-differ + """ Method to handle incoming ``GET`` requests. + Logic copied from :py:func:`~DIRAC.Core.Tornado.Server.BaseRequestHandler.post`. + """ + # Execute the method in an executor (basically a separate thread) + retVal = yield IOLoop.current().run_in_executor(None, self._executeMethod, args) + + # retVal is :py:class:`tornado.concurrent.Future` + self._finishFuture(retVal) diff --git a/src/DIRAC/Core/Tornado/Server/TornadoServer.py b/src/DIRAC/Core/Tornado/Server/TornadoServer.py index 3707cea6673..d8f35ff5c37 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoServer.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoServer.py @@ -9,11 +9,12 @@ __RCSID__ = "$Id$" +import os import time import datetime -import os - +import tempfile import M2Crypto +from io import open import tornado.iostream tornado.iostream.SSLIOStream.configure( @@ -25,13 +26,22 @@ import tornado.ioloop import DIRAC -from DIRAC import gConfig, gLogger -from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC import gConfig, gLogger, S_OK, S_ERROR from DIRAC.Core.Security import Locations -from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager from DIRAC.Core.Utilities import MemStat +from DIRAC.Core.Tornado.Server.HandlerManager import HandlerManager +from DIRAC.ConfigurationSystem.Client import PathFinder +from DIRAC.Core.Security import Locations, X509Chain, X509CRL from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient +# FROM WEB +import sys +import signal +import tornado.process +import tornado.autoreload + +from DIRAC.FrameworkSystem.private.authorization.utils.Sessions import SessionManager + sLog = gLogger.getSubLogger(__name__) @@ -60,35 +70,47 @@ class TornadoServer(object): Example 2:We want to debug service1 and service2 only, and use another port for that :: - services = ['component/service1', 'component/service2'] - serverToLaunch = TornadoServer(services=services, port=1234) + services = ['component/service1:port1', 'component/service2'] + endpoints = ['component/endpoint1:port1', 'component/endpoint2'] + serverToLaunch = TornadoServer(services=services, endpoints=endpoints, port=1234) serverToLaunch.startTornado() """ - def __init__(self, services=None, port=None): - """ - - :param list services: (default None) List of service handlers to load. If ``None``, loads all - :param int port: Port to listen to. If None, the port is resolved following the logic - described in the class documentation + def __init__(self, services=True, endpoints=False, port=None, debug=False, balancer=None, processes=None): + """ C'r + + :param list services: (default True) List of service handlers to load. + If ``True``, loads all described in the CS + If ``False``, do not load services + :param list endpoints: (default False) List of endpoint handlers to load. + If ``True``, loads all described in the CS + If ``False``, do not load endpoints + :param int port: Port to listen to. + If ``None``, the port is resolved following the logic described in the class documentation + :param bool debug: debug + :param str balancer: if need to use balancer, e.g.:: `nginx` + :param int processes: number of processes or if it's True just use all server CPUs """ - + # Debug + self.debug = debug + # Balancer, like as nginx + self.balancer = balancer + # Multiprocessor mode settings + self.processes = 1 if processes is None else 0 if processes is True else processes + if processes: + raise ImportError('Multiprocessor mode is not supported.') + # Application metadata, routes and settings mapping on the ports + self.__appsSettings = {} + # Default port, if enother is not discover if port is None: port = gConfig.getValue("/Systems/Tornado/%s/Port" % PathFinder.getSystemInstance('Tornado'), 8443) - - if services and not isinstance(services, list): - services = [services] - - # URLs for services. - # Contains Tornado :py:class:`tornado.web.url` object - self.urls = [] - # Other infos self.port = port - self.handlerManager = HandlerManager() - # Monitoring attributes + # Handler manager initialization with default settings + self.handlerManager = HandlerManager(services, endpoints) + # Monitoring attributes self._monitor = MonitoringClient() # temp value for computation, used by the monitoring self.__report = None @@ -97,20 +119,82 @@ def __init__(self, services=None, port=None): self.__monitoringLoopDelay = 60 # In secs # If services are defined, load only these ones (useful for debug purpose or specific services) - if services: - retVal = self.handlerManager.loadHandlersByServiceName(services) - if not retVal['OK']: - sLog.error(retVal['Message']) - raise ImportError("Some services can't be loaded, check the service names and configuration.") - + retVal = self.handlerManager.loadServicesHandlers() + if not retVal['OK']: + sLog.error(retVal['Message']) + raise ImportError("Some services can't be loaded, check the service names and configuration.") + + retVal = self.handlerManager.loadEndpointsHandlers() + if not retVal['OK']: + sLog.error(retVal['Message']) + raise ImportError("Some endpoints can't be loaded, check the endpoint names and configuration.") + + def __calculateAppSettings(self): + """ Calculate application information mapping on the ports + """ # if no service list is given, load services from configuration handlerDict = self.handlerManager.getHandlersDict() - for item in handlerDict.items(): - # handlerDict[key].initializeService(key) - self.urls.append(url(item[0], item[1])) - # If there is no services loaded: - if not self.urls: - raise ImportError("There is no services loaded, please check your configuration") + for data in handlerDict.values(): + port = data.get('Port') or self.port + for hURL in data['URLs']: + if port not in self.__appsSettings: + self.__appsSettings[port] = {'routes': [], 'settings': {}} + if hURL not in self.__appsSettings[port]['routes']: + self.__appsSettings[port]['routes'].append(hURL) + return bool(self.__appsSettings) + + def loadServices(self, services): + """ Load a services + + :param services: List of service handlers to load. Default value set at initialization + If ``True``, loads all services from CS + :type services: bool or list + + :return: S_OK()/S_ERROR() + """ + return self.handlerManager.loadServicesHandlers(services) + + def loadEndpoints(self, endpoints): + """ Load a endpoints + + :param endpoints: List of service handlers to load. Default value set at initialization + If ``True``, loads all endpoints from CS + :type endpoints: bool or list + + :return: S_OK()/S_ERROR() + """ + return self.handlerManager.loadEndpointsHandlers(endpoints) + + def addHandlers(self, routes, settings=None, port=None): + """ Add new routes + + :param list routes: routes + :param dict settings: application settings + :param int port: port + """ + port = port or self.port + if port not in self.__appsSettings: + self.__appsSettings[port] = {'routes': [], 'settings': {}} + if settings: + self.__appsSettings[port]['settings'].update(settings) + for route in routes: + if route not in self.__appsSettings[port]['routes']: + self.__appsSettings[port]['routes'].append(route) + + return S_OK() + + # def stopChildProcesses(self, sig, frame): + # """ + # It is used to properly stop tornado when more than one process is used. + # In principle this is doing the job of runsv.... + + # :param int sig: the signal sent to the process + # :param object frame: execution frame which contains the child processes + # """ + # for child in frame.f_locals.get('children', []): + # gLogger.info("Stopping child processes: %d" % child) + # os.kill(child, signal.SIGTERM) + # sys.exit(0) def startTornado(self): """ @@ -118,17 +202,17 @@ def startTornado(self): This method never returns. """ - sLog.debug("Starting Tornado") - self._initMonitoring() + # If there is no services loaded: + if not self.__calculateAppSettings(): + raise Exception("There is no services loaded, please check your configuration") - router = Application(self.urls, - debug=False, - compress_response=True) + sLog.debug("Starting Tornado") + # Prepare SSL settings certs = Locations.getHostCertificateAndKeyLocation() if certs is False: sLog.fatal("Host certificates not found ! Can't start the Server") - raise ImportError("Unable to load certificates") + raise Exception("Unable to load certificates") ca = Locations.getCAsLocation() ssl_options = { 'certfile': certs[0], @@ -138,23 +222,69 @@ def startTornado(self): 'sslDebug': False, # Set to true if you want to see the TLS debug messages } + if self.balancer: + # Create CAs for balancer + generateRevokedCertsFile() # it is used by nginx.... + # when NGINX is used then the Conf.HTTPS return False, it means tornado + # does not have to be configured using 443 port + generateCAFile() # if we use Nginx we have to generate the cas as well... + + # ############ + # # please do no move this lines. The lines must be before the fork_processes + # signal.signal(signal.SIGTERM, self.stopChildProcesses) + # signal.signal(signal.SIGINT, self.stopChildProcesses) + + # # Check processes if we're under a load balancert and have only one port + # if self.processes != 1: + # if not self.balancer: + # raise Exception("For multi processor mode, please, use balacer.") + # if len(self.__appsSettings) != 1: + # raise Exception("For multi processor mode, please, use one server port.") + # tornado.process.fork_processes(self.processes, max_restarts=0) + # ############# + + # Init monitoring + self._initMonitoring() self.__monitorLastStatsUpdate = time.time() self.__report = self.__startReportToMonitoringLoop() # Starting monitoring, IOLoop waiting time in ms, __monitoringLoopDelay is defined in seconds tornado.ioloop.PeriodicCallback(self.__reportToMonitoring, self.__monitoringLoopDelay * 1000).start() - # Start server - server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True) - try: - server.listen(self.port) - except Exception as e: # pylint: disable=broad-except - sLog.exception("Exception starting HTTPServer", e) - raise - sLog.always("Listening on port %s" % self.port) - for service in self.urls: - sLog.debug("Available service: %s" % service) - + for port, app in self.__appsSettings.items(): + sLog.debug(" - %s" % "\n - ".join(["%s = %s" % (k, ssl_options[k]) for k in ssl_options])) + + # Default server configuration + settings = dict(debug=self.debug, + compress_response=True, + # Use gLogger instead tornado log + log_function=_logRequest) + + # Merge appllication settings + settings.update(app['settings']) + + # # Don't use autoreload in debug mode for multiprocess + # if self.processes != 1: + # if self.balancer: + # port = 8000 + # port += tornado.process.task_id() or 0 + # settings['debug'] = False + + # Start server + router = Application(app['routes'], **settings) + server = HTTPServer(router, ssl_options=ssl_options, decompress_request=True, xheaders=True) + try: + server.listen(port) + except Exception as e: # pylint: disable=broad-except + sLog.exception("Exception starting HTTPServer", e) + raise + if settings['debug']: + sLog.info("Configuring in developer mode...") + sLog.always("Listening on 127.0.0.1:%s" % port) + for service in app['routes']: + sLog.debug("Available service: %s" % service if isinstance(service, url) else service[0]) + + tornado.autoreload.add_reload_hook(lambda: sLog.verbose("\n == Reloading server...\n")) IOLoop.current().start() def _initMonitoring(self): @@ -224,3 +354,92 @@ def __endReportToMonitoringLoop(self, initialWallTime, initialCPUTime): percentage = cpuTime / wallTime * 100. if percentage > 0: self._monitor.addMark('CPU', percentage) + + +def _logRequest(handler): + """ This function will be called at the end of every request to log the result + + :param object handler: RequestHandler object + """ + print('=== _logRequest') + print('=== %s' % handler) + print(sLog) + print('===============') + status = handler.get_status() + if status < 400: + logm = sLog.notice + elif status < 500: + logm = sLog.warn + else: + logm = sLog.error + request_time = 1000.0 * handler.request.request_time() + logm("%d %s %.2fms" % (status, handler._request_summary(), request_time)) + + +def generateCAFile(): + """ Generate a single CA file with all the PEMs + + :return: str or bool + """ + cert = Locations.getHostCertificateAndKeyLocation() + if cert: + cert = cert[0] + else: + cert = "/opt/dirac/etc/grid-security/hostcert.pem" + + caDir = Locations.getCAsLocation() + for fn in (os.path.join(os.path.dirname(caDir), "cas.pem"), + os.path.join(os.path.dirname(cert), "cas.pem"), + False): + if not fn: + fn = tempfile.mkstemp(prefix="cas.", suffix=".pem")[1] + try: + fd = open(fn, "w") + except IOError: + continue + for caFile in os.listdir(caDir): + caFile = os.path.join(caDir, caFile) + chain = X509Chain.X509Chain() + result = chain.loadChainFromFile(caFile) + if not result['OK']: + continue + expired = chain.hasExpired() + if not expired['OK'] or expired['Value']: + continue + fd.write(chain.dumpAllToString()['Value'].decode('utf-8')) + fd.close() + return fn + return False + + +def generateRevokedCertsFile(): + """ Generate a single CA file with all the PEMs + + :return: str or bool + """ + cert = Locations.getHostCertificateAndKeyLocation() + if cert: + cert = cert[0] + else: + cert = "/opt/dirac/etc/grid-security/hostcert.pem" + + caDir = Locations.getCAsLocation() + for fn in (os.path.join(os.path.dirname(caDir), "allRevokedCerts.pem"), + os.path.join(os.path.dirname(cert), "allRevokedCerts.pem"), + False): + if not fn: + fn = tempfile.mkstemp(prefix="allRevokedCerts", suffix=".pem")[1] + try: + fd = open(fn, "w") + except IOError: + continue + for caFile in os.listdir(caDir): + caFile = os.path.join(caDir, caFile) + chain = X509CRL.X509CRL() + result = chain.loadCRLFromFile(caFile) + if not result['OK']: + continue + fd.write(chain.dumpAllToString()['Value'].decode('utf-8')) + fd.close() + return fn + return False diff --git a/src/DIRAC/Core/Tornado/Server/TornadoService.py b/src/DIRAC/Core/Tornado/Server/TornadoService.py index dbc88d67b28..d7a84c11e81 100644 --- a/src/DIRAC/Core/Tornado/Server/TornadoService.py +++ b/src/DIRAC/Core/Tornado/Server/TornadoService.py @@ -16,24 +16,21 @@ import threading from datetime import datetime from six.moves import http_client -from tornado.web import RequestHandler, HTTPError -from tornado import gen -import tornado.ioloop -from tornado.ioloop import IOLoop import DIRAC from DIRAC import gConfig, gLogger, S_OK, S_ERROR -from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.Core.DISET.AuthManager import AuthManager -from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.Core.Utilities.JEncode import decode, encode +from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error +from DIRAC.Core.Tornado.Server.BaseRequestHandler import BaseRequestHandler +from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.FrameworkSystem.Client.MonitoringClient import MonitoringClient sLog = gLogger.getSubLogger(__name__) -class TornadoService(RequestHandler): # pylint: disable=abstract-method +class TornadoService(BaseRequestHandler): # pylint: disable=abstract-method """ Base class for all the Handlers. It directly inherits from :py:class:`tornado.web.RequestHandler` @@ -111,625 +108,45 @@ def export_streamToClient(self, myDataToSend, token): These are initialized in the :py:meth:`.initialize` method. """ - - # Because we initialize at first request, we use a flag to know if it's already done - __init_done = False - # Lock to make sure that two threads are not initializing at the same time - __init_lock = threading.RLock() - - # MonitoringClient, we don't use gMonitor which is not thread-safe - # We also need to add specific attributes for each service - _monitor = None + # Prefix of methods names + METHOD_PREFIX = "export_" @classmethod - def _initMonitoring(cls, serviceName, fullUrl): - """ - Initialize the monitoring specific to this handler - This has to be called only by :py:meth:`.__initializeService` - to ensure thread safety and unicity of the call. - - :param serviceName: relative URL ``//`` - :param fullUrl: full URl like ``https://://`` - """ - - # Init extra bits of monitoring - - cls._monitor = MonitoringClient() - cls._monitor.setComponentType(MonitoringClient.COMPONENT_WEB) - - cls._monitor.initialize() - - if tornado.process.task_id() is None: # Single process mode - cls._monitor.setComponentName('Tornado/%s' % serviceName) - else: - cls._monitor.setComponentName('Tornado/CPU%d/%s' % (tornado.process.task_id(), serviceName)) - - cls._monitor.setComponentLocation(fullUrl) + def _getServiceName(cls, request): + """ Search service name in request. - cls._monitor.registerActivity("Queries", "Queries served", "Framework", "queries", MonitoringClient.OP_RATE) + :param object request: tornado Request - cls._monitor.setComponentExtraParam('DIRACVersion', DIRAC.version) - cls._monitor.setComponentExtraParam('platform', DIRAC.getPlatform()) - cls._monitor.setComponentExtraParam('startTime', datetime.utcnow()) - - cls._stats = {'requests': 0, 'monitorLastStatsUpdate': time.time()} - - return S_OK() - - @classmethod - def __initializeService(cls, relativeUrl, absoluteUrl): - """ - Initialize a service. - The work is only perform once at the first request. - - :param relativeUrl: relative URL, e.g. ``//`` - :param absoluteUrl: full URL e.g. ``https://://`` - - :returns: S_OK + :return: str """ - # If the initialization was already done successfuly, - # we can just return - if cls.__init_done: - return S_OK() - - # Otherwise, do the work but with a lock - with cls.__init_lock: - - # Check again that the initialization was not done by another thread - # while we were waiting for the lock - if cls.__init_done: - return S_OK() - - # Url starts with a "/", we just remove it - serviceName = relativeUrl[1:] - - cls._startTime = datetime.utcnow() - sLog.info("First use, initializing service...", "%s" % relativeUrl) - cls._authManager = AuthManager("%s/Authorization" % PathFinder.getServiceSection(serviceName)) - - cls._initMonitoring(serviceName, absoluteUrl) - - cls._serviceName = serviceName - cls._validNames = [serviceName] - serviceInfo = {'serviceName': serviceName, - 'serviceSectionPath': PathFinder.getServiceSection(serviceName), - 'csPaths': [PathFinder.getServiceSection(serviceName)], - 'URL': absoluteUrl - } - cls._serviceInfoDict = serviceInfo - - cls.__monitorLastStatsUpdate = time.time() - - cls.initializeHandler(serviceInfo) - - cls.__init_done = True - - return S_OK() + # Expected path: ``//`` + return request.path[1:] @classmethod - def initializeHandler(cls, serviceInfoDict): - """ - This may be overwritten when you write a DIRAC service handler - And it must be a class method. This method is called only one time, - at the first request - - :param dict ServiceInfoDict: infos about services, it contains - 'serviceName', 'serviceSectionPath', - 'csPaths' and 'URL' - """ - pass + def _getServiceInfo(cls, serviceName, request): + """ Fill service information. - def initializeRequest(self): - """ - Called at every request, may be overwritten in your handler. - """ - pass + :param str serviceName: service name + :param object request: tornado Request - # This is a Tornado magic method - def initialize(self): # pylint: disable=arguments-differ + :return: dict """ - Initialize the handler, called at every request. + return {'serviceName': serviceName, + 'serviceSectionPath': PathFinder.getServiceSection(serviceName), + 'csPaths': [PathFinder.getServiceSection(serviceName)], + 'URL': request.full_url()} - It just calls :py:meth:`.__initializeService` + def _getMethodName(self): + """ Parse method name. - If anything goes wrong, the client will get ``Connection aborted`` - error. See details inside the method. - - ..warning:: - DO NOT REWRITE THIS FUNCTION IN YOUR HANDLER - ==> initialize in DISET became initializeRequest in HTTPS ! + :return: str """ + return self.get_argument("method") - # Only initialized once - if not self.__init_done: - # Ideally, if something goes wrong, we would like to return a Server Error 500 - # but this method cannot write back to the client as per the - # `tornado doc `_. - # So the client will get a ``Connection aborted``` - try: - res = self.__initializeService(self.srv_getURL(), self.request.full_url()) - if not res['OK']: - raise Exception(res['Message']) - except Exception as e: - sLog.error("Error in initialization", repr(e)) - raise - - def prepare(self): - """ - Prepare the request. It reads certificates and check authorizations. - We make the assumption that there is always going to be a ``method`` argument - regardless of the HTTP method used + def _getMethodArgs(self, args): + """ Decode args. + :return: list """ - - # "method" argument of the POST call. - # This resolves into the ``export_`` method - # on the handler side - # If the argument is not available, the method exists - # and an error 400 ``Bad Request`` is returned to the client - self.method = self.get_argument("method") - - self._stats['requests'] += 1 - self._monitor.setComponentExtraParam('queries', self._stats['requests']) - self._monitor.addMark("Queries") - - try: - self.credDict = self._gatherPeerCredentials() - except Exception: # pylint: disable=broad-except - # If an error occur when reading certificates we close connection - # It can be strange but the RFC, for HTTP, say's that when error happend - # before authentication we return 401 UNAUTHORIZED instead of 403 FORBIDDEN - sLog.error( - "Error gathering credentials", "%s; path %s" % - (self.getRemoteAddress(), self.request.path)) - raise HTTPError(status_code=http_client.UNAUTHORIZED) - - # Resolves the hard coded authorization requirements - try: - hardcodedAuth = getattr(self, 'auth_' + self.method) - except AttributeError: - hardcodedAuth = None - - # Check whether we are authorized to perform the query - # Note that performing the authQuery modifies the credDict... - authorized = self._authManager.authQuery(self.method, self.credDict, hardcodedAuth) - if not authorized: - sLog.error( - "Unauthorized access", "Identity %s; path %s; DN %s" % - (self.srv_getFormattedRemoteCredentials, - self.request.path, - self.credDict['DN'], - )) - raise HTTPError(status_code=http_client.UNAUTHORIZED) - - # Make post a coroutine. - # See https://www.tornadoweb.org/en/branch5.1/guide/coroutines.html#coroutines - # for details - @gen.coroutine - def post(self): # pylint: disable=arguments-differ - """ - Method to handle incoming ``POST`` requests. - Note that all the arguments are already prepared in the :py:meth:`.prepare` - method. - - The ``POST`` arguments expected are: - - * ``method``: name of the method to call - * ``args``: JSON encoded arguments for the method - * ``extraCredentials``: (optional) Extra informations to authenticate client - * ``rawContent``: (optionnal, default False) If set to True, return the raw output - of the method called. - - If ``rawContent`` was requested by the client, the ``Content-Type`` - is ``application/octet-stream``, otherwise we set it to ``application/json`` - and JEncode retVal. - - If ``retVal`` is a dictionary that contains a ``Callstack`` item, - it is removed, not to leak internal information. - - - Example of call using ``requests``:: - - In [20]: url = 'https://server:8443/DataManagement/TornadoFileCatalog' - ...: cert = '/tmp/x509up_u1000' - ...: kwargs = {'method':'whoami'} - ...: caPath = '/home/dirac/ClientInstallDIR/etc/grid-security/certificates/' - ...: with requests.post(url, data=kwargs, cert=cert, verify=caPath) as r: - ...: print r.json() - ...: - {u'OK': True, - u'Value': {u'DN': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', - u'group': u'dirac_user', - u'identity': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', - u'isLimitedProxy': False, - u'isProxy': True, - u'issuer': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser', - u'properties': [u'NormalUser'], - u'secondsLeft': 85441, - u'subject': u'/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/CN=2409820262', - u'username': u'adminusername', - u'validDN': False, - u'validGroup': False}} - """ - - sLog.notice( - "Incoming request", "%s /%s: %s" % - (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - self.method)) - - # Execute the method in an executor (basically a separate thread) - # Because of that, we cannot calls certain methods like `self.write` - # in __executeMethod. This is because these methods are not threadsafe - # https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - # However, we can still rely on instance attributes to store what should - # be sent back (reminder: there is an instance - # of this class created for each request) - retVal = yield IOLoop.current().run_in_executor(None, self.__executeMethod) - - # retVal is :py:class:`tornado.concurrent.Future` - self.result = retVal.result() - - # Here it is safe to write back to the client, because we are not - # in a thread anymore - - # If set to true, do not JEncode the return of the RPC call - # This is basically only used for file download through - # the 'streamToClient' method. - rawContent = self.get_argument('rawContent', default=False) - - if rawContent: - # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt - self.set_header("Content-Type", "application/octet-stream") - result = self.result - else: - self.set_header("Content-Type", "application/json") - result = encode(self.result) - - self.write(result) - self.finish() - - # This nice idea of streaming to the client cannot work because we are ran in an executor - # and we should not write back to the client in a different thread. - # See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - # def export_streamToClient(self, filename): - # # https://bhch.github.io/posts/2017/12/serving-large-files-with-tornado-safely-without-blocking/ - # #import ipdb; ipdb.set_trace() - # # chunk size to read - # chunk_size = 1024 * 1024 * 1 # 1 MiB - - # with open(filename, 'rb') as f: - # while True: - # chunk = f.read(chunk_size) - # if not chunk: - # break - # try: - # self.write(chunk) # write the chunk to response - # self.flush() # send the chunk to client - # except StreamClosedError: - # # this means the client has closed the connection - # # so break the loop - # break - # finally: - # # deleting the chunk is very important because - # # if many clients are downloading files at the - # # same time, the chunks in memory will keep - # # increasing and will eat up the RAM - # del chunk - # # pause the coroutine so other handlers can run - # yield gen.sleep(0.000000001) # 1 nanosecond - - # return S_OK() - - @gen.coroutine - def __executeMethod(self): - """ - Execute the method called, this method is ran in an executor - We have several try except to catch the different problem which can occur - - - First, the method does not exist => Attribute error, return an error to client - - second, anything happend during execution => General Exception, send error to client - - .. warning:: - This method is called in an executor, and so cannot use methods like self.write - See https://www.tornadoweb.org/en/branch5.1/web.html#thread-safety-notes - """ - - # getting method - try: - # For compatibility reasons with DISET, the methods are still called ``export_*`` - method = getattr(self, 'export_%s' % self.method) - except AttributeError as e: - sLog.error("Invalid method", self.method) - raise HTTPError(status_code=http_client.NOT_IMPLEMENTED) - - # Decode args args_encoded = self.get_body_argument('args', default=encode([])) - - args = decode(args_encoded)[0] - # Execute - try: - self.initializeRequest() - retVal = method(*args) - except Exception as e: # pylint: disable=broad-except - sLog.exception("Exception serving request", "%s:%s" % (str(e), repr(e))) - raise HTTPError(http_client.INTERNAL_SERVER_ERROR) - - return retVal - - # def __write_return(self, retVal): - # """ - # Write back to the client and return. - # It sets some headers (status code, ``Content-Type``). - # If raw content was requested by the client, the ``Content-Type`` - # is ``application/octet-stream``, otherwise we set it to ``application/json`` - # and JEncode retVal. - - # If ``retVal`` is a dictionary that contains a ``Callstack`` item, - # it is removed, not to leak internal information. - - # :param retVal: anything that can be serialized in json. - # """ - - # # In case of error in server side we hide server CallStack to client - # try: - # if 'CallStack' in retVal: - # del retVal['CallStack'] - # except TypeError: - # pass - - # # Set the status - # self.set_status(self._httpStatus) - - # # This is basically only used for file download through - # # the 'streamToClient' method. - # if self.rawContent: - # # See 4.5.1 http://www.rfc-editor.org/rfc/rfc2046.txt - # self.set_header("Content-Type", "application/octet-stream") - # returnedData = retVal - # else: - # self.set_header("Content-Type", "application/json") - # returnedData = encode(retVal) - - # self.write(returnedData) - - def on_finish(self): - """ - Called after the end of HTTP request. - Log the request duration - """ - elapsedTime = 1000.0 * self.request.request_time() - - try: - if self.result['OK']: - argsString = "OK" - else: - argsString = "ERROR: %s" % self.result['Message'] - except (AttributeError, KeyError): # In case it is not a DIRAC structure - if self._reason == 'OK': - argsString = 'OK' - else: - argsString = 'ERROR %s' % self._reason - - argsString = "ERROR: %s" % self._reason - sLog.notice("Returning response", "%s %s (%.2f ms) %s" % (self.srv_getFormattedRemoteCredentials(), - self._serviceName, - elapsedTime, argsString)) - - def _gatherPeerCredentials(self): - """ - Load client certchain in DIRAC and extract informations. - - The dictionary returned is designed to work with the AuthManager, - already written for DISET and re-used for HTTPS. - - :returns: a dict containing the return of :py:meth:`DIRAC.Core.Security.X509Chain.X509Chain.getCredentials` - (not a DIRAC structure !) - """ - - chainAsText = self.request.get_ssl_certificate().as_pem() - peerChain = X509Chain() - - # Here we read all certificate chain - cert_chain = self.request.get_ssl_certificate_chain() - for cert in cert_chain: - chainAsText += cert.as_pem() - - peerChain.loadChainFromString(chainAsText) - - # Retrieve the credentials - res = peerChain.getCredentials(withRegistryInfo=False) - if not res['OK']: - raise Exception(res['Message']) - - credDict = res['Value'] - - # We check if client sends extra credentials... - if "extraCredentials" in self.request.arguments: - extraCred = self.get_argument("extraCredentials") - if extraCred: - credDict['extraCredentials'] = decode(extraCred)[0] - return credDict - - -#### -# -# Default method -# -#### - - auth_ping = ['all'] - - def export_ping(self): - """ - Default ping method, returns some info about server. - - It returns the exact same information as DISET, for transparency purpose. - """ - # COPY FROM DIRAC.Core.DISET.RequestHandler - dInfo = {} - dInfo['version'] = DIRAC.version - dInfo['time'] = datetime.utcnow() - # Uptime - try: - with open("/proc/uptime", 'rt') as oFD: - iUptime = int(float(oFD.readline().split()[0].strip())) - dInfo['host uptime'] = iUptime - except Exception: # pylint: disable=broad-except - pass - startTime = self._startTime - dInfo['service start time'] = self._startTime - serviceUptime = datetime.utcnow() - startTime - dInfo['service uptime'] = serviceUptime.days * 3600 + serviceUptime.seconds - # Load average - try: - with open("/proc/loadavg", 'rt') as oFD: - dInfo['load'] = " ".join(oFD.read().split()[:3]) - except Exception: # pylint: disable=broad-except - pass - dInfo['name'] = self._serviceInfoDict['serviceName'] - stTimes = os.times() - dInfo['cpu times'] = {'user time': stTimes[0], - 'system time': stTimes[1], - 'children user time': stTimes[2], - 'children system time': stTimes[3], - 'elapsed real time': stTimes[4] - } - - return S_OK(dInfo) - - auth_echo = ['all'] - - @staticmethod - def export_echo(data): - """ - This method used for testing the performance of a service - """ - return S_OK(data) - - auth_whoami = ['authenticated'] - - def export_whoami(self): - """ - A simple whoami, returns all credential dictionary, except certificate chain object. - """ - credDict = self.srv_getRemoteCredentials() - if 'x509Chain' in credDict: - # Not serializable - del credDict['x509Chain'] - return S_OK(credDict) - -#### -# -# Utilities methods, some getters. -# From DIRAC.Core.DISET.requestHandler to get same interface in the handlers. -# Adapted for Tornado. -# These method are copied from DISET RequestHandler, they are not all used when i'm writing -# these lines. I rewrite them for Tornado to get them ready when a new HTTPS service need them -# -#### - - @classmethod - def srv_getCSOption(cls, optionName, defaultValue=False): - """ - Get an option from the CS section of the services - - :return: Value for serviceSection/optionName in the CS being defaultValue the default - """ - if optionName[0] == "/": - return gConfig.getValue(optionName, defaultValue) - for csPath in cls._serviceInfoDict['csPaths']: - result = gConfig.getOption("%s/%s" % (csPath, optionName, ), defaultValue) - if result['OK']: - return result['Value'] - return defaultValue - - def getCSOption(self, optionName, defaultValue=False): - """ - Just for keeping same public interface - """ - return self.srv_getCSOption(optionName, defaultValue) - - def srv_getRemoteAddress(self): - """ - Get the address of the remote peer. - - :return: Address of remote peer. - """ - - remote_ip = self.request.remote_ip - # Although it would be trivial to add this attribute in _HTTPRequestContext, - # Tornado won't release anymore 5.1 series, so go the hacky way - try: - remote_port = self.request.connection.stream.socket.getpeername()[1] - except Exception: # pylint: disable=broad-except - remote_port = 0 - - return (remote_ip, remote_port) - - def getRemoteAddress(self): - """ - Just for keeping same public interface - """ - return self.srv_getRemoteAddress() - - def srv_getRemoteCredentials(self): - """ - Get the credentials of the remote peer. - - :return: Credentials dictionary of remote peer. - """ - return self.credDict - - def getRemoteCredentials(self): - """ - Get the credentials of the remote peer. - - :return: Credentials dictionary of remote peer. - """ - return self.credDict - - def srv_getFormattedRemoteCredentials(self): - """ - Return the DN of user - - Mostly copy paste from - :py:meth:`DIRAC.Core.DISET.private.Transports.BaseTransport.BaseTransport.getFormattedCredentials` - - Note that the information will be complete only once the AuthManager was called - """ - address = self.getRemoteAddress() - peerId = "" - # Depending on where this is call, it may be that credDict is not yet filled. - # (reminder: AuthQuery fills part of it..) - try: - peerId = "[%s:%s]" % (self.credDict['group'], self.credDict['username']) - except AttributeError: - pass - - if address[0].find(":") > -1: - return "([%s]:%s)%s" % (address[0], address[1], peerId) - return "(%s:%s)%s" % (address[0], address[1], peerId) - -# def getFormattedCredentials(self): -# peerCreds = self.getConnectingCredentials() -# address = self.getRemoteAddress() -# if 'username' in peerCreds: -# peerId = "[%s:%s]" % (peerCreds['group'], peerCreds['username']) -# else: -# peerId = "" -# if address[0].find(":") > -1: -# return "([%s]:%s)%s" % (address[0], address[1], peerId) -# return "(%s:%s)%s" % (address[0], address[1], peerId) - - def srv_getServiceName(self): - """ - Return the service name - """ - return self._serviceInfoDict['serviceName'] - - def srv_getURL(self): - """ - Return the URL - """ - return self.request.path + return decode(args_encoded)[0] diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_endpoints.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_endpoints.py new file mode 100644 index 00000000000..a62c67ef9c3 --- /dev/null +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_endpoints.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import os +import sys + +from DIRAC.Core.Utilities.DIRACScript import DIRACScript + + +@DIRACScript() +def main(): + # Must be define BEFORE any dirac import + os.environ['DIRAC_USE_TORNADO_IOLOOP'] = "True" + + from DIRAC import gConfig + from DIRAC.ConfigurationSystem.Client import PathFinder + from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData + from DIRAC.ConfigurationSystem.Client.LocalConfiguration import LocalConfiguration + from DIRAC.Core.Tornado.Server.TornadoServer import TornadoServer + from DIRAC.Core.Utilities.DErrno import includeExtensionErrors + from DIRAC.FrameworkSystem.Client.Logger import gLogger + + # We check if there is no configuration server started as master + # If you want to start a master CS you should use Configuration_Server.cfg and + # use tornado-start-CS.py + if gConfigurationData.isMaster() and gConfig.getValue( + '/Systems/Configuration/%s/Services/Server/Protocol' % + PathFinder.getSystemInstance('Configuration'), + 'dips').lower() == 'https': + gLogger.fatal("You can't run the CS and services in the same server!") + sys.exit(0) + + localCfg = LocalConfiguration() + localCfg.addMandatoryEntry("/DIRAC/Setup") + localCfg.addDefaultEntry("/DIRAC/Security/UseServerCertificate", "yes") + localCfg.addDefaultEntry("LogLevel", "INFO") + localCfg.addDefaultEntry("LogColor", True) + resultDict = localCfg.loadUserData() + if not resultDict['OK']: + gLogger.initialize("Tornado", "/") + gLogger.error("There were errors when loading configuration", resultDict['Message']) + sys.exit(1) + + includeExtensionErrors() + + gLogger.initialize('Tornado', "/") + + endpoints = ['Configuration/Configuration', 'Framework/Auth', 'Framework/Proxy'] + serverToLaunch = TornadoServer(False, endpoints, 8000) + serverToLaunch.startTornado() + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py new file mode 100644 index 00000000000..abf50960808 --- /dev/null +++ b/src/DIRAC/Core/Tornado/scripts/tornado_start_web.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import os +import sys +import tornado + +from DIRAC.Core.Utilities.DIRACScript import DIRACScript + + +@DIRACScript() +def main(): + # Must be define BEFORE any dirac import + os.environ['DIRAC_USE_TORNADO_IOLOOP'] = "True" + + from DIRAC import gConfig + from DIRAC.ConfigurationSystem.Client import PathFinder + from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData + from DIRAC.ConfigurationSystem.Client.LocalConfiguration import LocalConfiguration + from DIRAC.Core.Tornado.Server.TornadoServer import TornadoServer + from DIRAC.Core.Utilities.DErrno import includeExtensionErrors + from DIRAC.FrameworkSystem.Client.Logger import gLogger + + # We check if there is no configuration server started as master + # If you want to start a master CS you should use Configuration_Server.cfg and + # use tornado-start-CS.py + if gConfigurationData.isMaster() and gConfig.getValue( + '/Systems/Configuration/%s/Services/Server/Protocol' % + PathFinder.getSystemInstance('Configuration'), + 'dips').lower() == 'https': + gLogger.fatal("You can't run the CS and services in the same server!") + sys.exit(0) + + localCfg = LocalConfiguration() + localCfg.addMandatoryEntry("/DIRAC/Setup") + localCfg.addDefaultEntry("/DIRAC/Security/UseServerCertificate", "yes") + localCfg.addDefaultEntry("LogLevel", "INFO") + localCfg.addDefaultEntry("LogColor", True) + resultDict = localCfg.loadUserData() + if not resultDict['OK']: + gLogger.initialize("Tornado", "/") + gLogger.error("There were errors when loading configuration", resultDict['Message']) + sys.exit(1) + + includeExtensionErrors() + + gLogger.initialize('Tornado', "/") + + services = ['DataManagement/TornadoFileCatalog'] + endpoints = ['Configuration/Configuration', 'Framework/Auth', 'Framework/Proxy'] + + try: + from WebAppDIRAC.Core.App import App + except ImportError: + gLogger.fatal('Web portal is not installed.') + sys.exit(1) + + # Get routes and settings for a portal + result = App().getAppToDict(8000) + if not result['OK']: + gLogger.fatal(result['Message']) + sys.exit(1) + app = result['Value'] + + serverToLaunch = TornadoServer(services, endpoints, port=8000, balancer='nginx') + serverToLaunch.addHandlers(app['routes'], app['settings']) + serverToLaunch.startTornado() + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/Core/Utilities/DErrno.py b/src/DIRAC/Core/Utilities/DErrno.py index ea0cdf23cc4..1a07a92f730 100644 --- a/src/DIRAC/Core/Utilities/DErrno.py +++ b/src/DIRAC/Core/Utilities/DErrno.py @@ -92,6 +92,7 @@ # DISET: 1X EDISET = 1110 ENOAUTH = 1111 +ECONNECT = 1112 # 3rd party security: 2X E3RDPARTY = 1120 EVOMS = 1121 @@ -169,6 +170,7 @@ # 111X: DISET 1110: 'EDISET', 1111: 'ENOAUTH', + 1112: 'ECONNECT', # 112X: 3rd party security 1120: 'E3RDPARTY', 1121: 'EVOMS', @@ -242,6 +244,7 @@ # 111X: DISET EDISET: "DISET Error", ENOAUTH: "Unauthorized query", + ECONNECT: "Connection error", # 112X: 3rd party security E3RDPARTY: "3rd party security service error", EVOMS: "VOMS Error", diff --git a/src/DIRAC/Core/Utilities/DictCache.py b/src/DIRAC/Core/Utilities/DictCache.py index 5eddb9770ee..9aa630b2336 100644 --- a/src/DIRAC/Core/Utilities/DictCache.py +++ b/src/DIRAC/Core/Utilities/DictCache.py @@ -207,6 +207,24 @@ def getKeys(self, validSeconds=0): finally: self.lock.release() + def getDict(self, validSeconds=0): + """ Get dictionary for all contents + + :param int validSeconds: valid time in seconds + + :return: dict + """ + self.lock.acquire() + try: + resDict = {} + limitTime = datetime.datetime.now() + datetime.timedelta(seconds=validSeconds) + for cKey in self.__cache: + if self.__cache[cKey]['expirationTime'] > limitTime: + resDict[cKey] = self.__cache[cKey]['value'] + return resDict + finally: + self.lock.release() + def purgeExpired(self, expiredInSeconds=0): """ Purge all entries that are expired or will be expired in diff --git a/src/DIRAC/Core/Utilities/Proxy.py b/src/DIRAC/Core/Utilities/Proxy.py index 655a5a99651..b3f8309dd27 100644 --- a/src/DIRAC/Core/Utilities/Proxy.py +++ b/src/DIRAC/Core/Utilities/Proxy.py @@ -41,10 +41,9 @@ def undecoratedFunction(foo='bar'): import os from DIRAC import gConfig, gLogger, S_ERROR, S_OK +from DIRAC.Core.Utilities.LockRing import LockRing from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager from DIRAC.ConfigurationSystem.Client.ConfigurationData import gConfigurationData -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSAttributeForGroup, getDNForUsername -from DIRAC.Core.Utilities.LockRing import LockRing __RCSID__ = "$Id$" @@ -104,30 +103,32 @@ def wrapped_fcn(*args, **kwargs): return wrapped_fcn -def getProxy(userDNs, userGroup, vomsAttr, proxyFilePath): - """ do the actual download of the proxy, trying the different DNs +def getProxy(user, userGroup, vomsAttr, proxyFilePath): + """ Do the actual download of the proxy, trying the different DNs + + :param str user: user name or DN + :param str userGroup: group name + :param bool vomsAttr: if need VOMSproxy + :param str proxyPathFile: path to proxy file + + :return: S_OK(object)/S_ERROR() -- return proxy as chain """ - for userDN in userDNs: - if vomsAttr: - result = gProxyManager.downloadVOMSProxyToFile(userDN, userGroup, - requiredVOMSAttribute=vomsAttr, - filePath=proxyFilePath, - requiredTimeLeft=3600, - cacheTime=3600) - else: - result = gProxyManager.downloadProxyToFile(userDN, userGroup, - filePath=proxyFilePath, - requiredTimeLeft=3600, - cacheTime=3600) - - if not result['OK']: - gLogger.error("Can't download %sproxy " % ('VOMS' if vomsAttr else ''), - "of '%s', group %s to file: " % (userDN, userGroup) + result['Message']) - else: - return result + if vomsAttr: + result = gProxyManager.downloadVOMSProxyToFile(user, userGroup, + filePath=proxyFilePath, + requiredTimeLeft=3600, + cacheTime=3600) + else: + result = gProxyManager.downloadProxyToFile(user, userGroup, + filePath=proxyFilePath, + requiredTimeLeft=3600, + cacheTime=3600) - # If proxy not found for any DN, return an error - return S_ERROR("Can't download proxy") + if not result['OK']: + gLogger.error("Can't download %sproxy " % ('VOMS' if vomsAttr else ''), + "of '%s', group %s to file: " % (user, userGroup) + result['Message']) + return S_ERROR("Can't download proxy") + return result def executeWithoutServerCertificate(fcn): @@ -151,7 +152,8 @@ def executeWithoutServerCertificate(fcn): """ def wrapped_fcn(*args, **kwargs): - + """ Wrapped fuction + """ # Get the lock and acquire it executionLock = LockRing().getLock('_UseUserProxy_', recursive=True) executionLock.acquire() @@ -225,20 +227,7 @@ def _putProxy(userDN=None, userName=None, userGroup=None, vomsFlag=None, proxyFi :returns: Tuple of originalUserProxy, useServerCertificate, executionLock """ # Setup user proxy - if userDN: - userDNs = [userDN] - else: - result = getDNForUsername(userName) - if not result['OK']: - return result - userDNs = result['Value'] # a same user may have more than one DN - - vomsAttr = '' - if vomsFlag: - vomsAttr = getVOMSAttributeForGroup(userGroup) - - result = getProxy(userDNs, userGroup, vomsAttr, proxyFilePath) - + result = getProxy(userDN or userName, userGroup, vomsFlag, proxyFilePath) if not result['OK']: return result diff --git a/src/DIRAC/Core/Utilities/Shifter.py b/src/DIRAC/Core/Utilities/Shifter.py index 02928783e4c..5161ee348c9 100644 --- a/src/DIRAC/Core/Utilities/Shifter.py +++ b/src/DIRAC/Core/Utilities/Shifter.py @@ -14,7 +14,8 @@ from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations from DIRAC.ConfigurationSystem.Client.Helpers import cfgPath -from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import findDefaultGroupForUser,\ + getDNForUsernameInGroup, getVOMSAttributeForGroup def getShifterProxy(shifterType, fileName=False): @@ -31,26 +32,28 @@ def getShifterProxy(shifterType, fileName=False): userName = opsHelper.getValue(cfgPath('Shifter', shifterType, 'User'), '') if not userName: return S_ERROR("No shifter User defined for %s" % shifterType) - result = Registry.getDNForUsername(userName) - if not result['OK']: - return result - userDN = result['Value'][0] - result = Registry.findDefaultGroupForDN(userDN) + result = findDefaultGroupForUser(userName) if not result['OK']: return result defaultGroup = result['Value'] userGroup = opsHelper.getValue(cfgPath('Shifter', shifterType, 'Group'), defaultGroup) - vomsAttr = Registry.getVOMSAttributeForGroup(userGroup) + result = getDNForUsernameInGroup(userName, userGroup) + if not result['OK']: + return result + userDN = result['Value'] + if not userDN: + return S_ERROR('No user DN found for shifter %s@%s' % (userName, userGroup)) + vomsAttr = getVOMSAttributeForGroup(userGroup) if vomsAttr: gLogger.info("Getting VOMS [%s] proxy for shifter %s@%s (%s)" % (vomsAttr, userName, userGroup, userDN)) - result = gProxyManager.downloadVOMSProxyToFile(userDN, userGroup, + result = gProxyManager.downloadVOMSProxyToFile(userName, userGroup, filePath=fileName, requiredTimeLeft=86400, cacheTime=86400) else: gLogger.info("Getting proxy for shifter %s@%s (%s)" % (userName, userGroup, userDN)) - result = gProxyManager.downloadProxyToFile(userDN, userGroup, + result = gProxyManager.downloadProxyToFile(userName, userGroup, filePath=fileName, requiredTimeLeft=86400, cacheTime=86400) diff --git a/src/DIRAC/DataManagementSystem/Agent/FTS3Agent.py b/src/DIRAC/DataManagementSystem/Agent/FTS3Agent.py index 3b6b137fb3c..3f5baf1332d 100644 --- a/src/DIRAC/DataManagementSystem/Agent/FTS3Agent.py +++ b/src/DIRAC/DataManagementSystem/Agent/FTS3Agent.py @@ -37,7 +37,6 @@ from DIRAC.Core.Utilities.Time import fromString from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getFTS3ServerDict from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations as opHelper -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername from DIRAC.FrameworkSystem.Client.Logger import gLogger from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager from DIRAC.DataManagementSystem.private import FTS3Utilities @@ -158,13 +157,7 @@ def getFTS3Context(self, username, group, ftsServer, threadID): # We keep a context in the cache for 45 minutes # (so it needs to be valid at least 15 since we add it for one hour) if not contextes.exists(idTuple, 15 * 60): - res = getDNForUsername(username) - if not res['OK']: - return res - # We take the first DN returned - userDN = res['Value'][0] - - log.debug("UserDN %s" % userDN) + log.debug("User: %s, group: %s" % (username, group)) # We dump the proxy to a file. # It has to have a lifetime of self.proxyLifetime @@ -172,7 +165,7 @@ def getFTS3Context(self, username, group, ftsServer, threadID): # we should make our cache a bit less than 2/3rd of the lifetime cacheTime = int(2 * self.proxyLifetime / 3) - 600 res = gProxyManager.downloadVOMSProxyToFile( - userDN, group, requiredTimeLeft=self.proxyLifetime, cacheTime=cacheTime) + username, group, requiredTimeLeft=self.proxyLifetime, cacheTime=cacheTime) if not res['OK']: return res diff --git a/src/DIRAC/DataManagementSystem/Service/FTS3ManagerHandler.py b/src/DIRAC/DataManagementSystem/Service/FTS3ManagerHandler.py index e561ab99d85..59839d095f0 100644 --- a/src/DIRAC/DataManagementSystem/Service/FTS3ManagerHandler.py +++ b/src/DIRAC/DataManagementSystem/Service/FTS3ManagerHandler.py @@ -18,7 +18,7 @@ # from DIRAC from DIRAC import S_OK, S_ERROR, gLogger -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNsForUsernameInGroup from DIRAC.Core.DISET.RequestHandler import RequestHandler, getServiceOption from DIRAC.Core.Security.Properties import FULL_DELEGATION, LIMITED_DELEGATION, TRUSTED_HOST from DIRAC.Core.Utilities import DErrno @@ -70,10 +70,10 @@ def _isAllowed(opObj, remoteCredentials): credProperties = remoteCredentials['properties'] # First, get the DN matching the username - res = getDNForUsername(opObj.username) + res = getDNsForUsernameInGroup(opObj.username, opObj.userGroup) # if we have an error, do not allow if not res['OK']: - gLogger.error("Error retrieving DN for username", res) + gLogger.error("Error retrieving DN for username/group", res) return False # List of DN matching the username diff --git a/src/DIRAC/DataManagementSystem/Service/FileCatalogProxyHandler.py b/src/DIRAC/DataManagementSystem/Service/FileCatalogProxyHandler.py index a1ea67000a7..4bd52c7d529 100644 --- a/src/DIRAC/DataManagementSystem/Service/FileCatalogProxyHandler.py +++ b/src/DIRAC/DataManagementSystem/Service/FileCatalogProxyHandler.py @@ -88,9 +88,9 @@ def __prepareSecurityDetails(self, vomsFlag=True): clientGroup = credDict['group'] gLogger.debug("Getting proxy for %s@%s (%s)" % (clientUsername, clientGroup, clientDN)) if vomsFlag: - result = gProxyManager.downloadVOMSProxyToFile(clientDN, clientGroup) + result = gProxyManager.downloadVOMSProxyToFile(clientDN or clientUsername, clientGroup) else: - result = gProxyManager.downloadProxyToFile(clientDN, clientGroup) + result = gProxyManager.downloadProxyToFile(clientDN or clientUsername, clientGroup) if not result['OK']: return result gLogger.debug("Updating environment.") diff --git a/src/DIRAC/DataManagementSystem/Service/StorageElementProxyHandler.py b/src/DIRAC/DataManagementSystem/Service/StorageElementProxyHandler.py index aa53d34555e..30df7ef9fe7 100644 --- a/src/DIRAC/DataManagementSystem/Service/StorageElementProxyHandler.py +++ b/src/DIRAC/DataManagementSystem/Service/StorageElementProxyHandler.py @@ -270,7 +270,7 @@ def __prepareSecurityDetails(self): clientUsername = credDict['username'] clientGroup = credDict['group'] gLogger.debug("Getting proxy for %s@%s (%s)" % (clientUsername, clientGroup, clientDN)) - res = gProxyManager.downloadVOMSProxy(clientDN, clientGroup) + res = gProxyManager.downloadVOMSProxy(clientDN or clientUsername, clientGroup) if not res['OK']: return res chain = res['Value'] diff --git a/src/DIRAC/FrameworkSystem/API/AuthHandler.py b/src/DIRAC/FrameworkSystem/API/AuthHandler.py new file mode 100644 index 00000000000..5375198f911 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/API/AuthHandler.py @@ -0,0 +1,581 @@ +""" This handler basically provides a REST interface to interact with the OAuth 2 authentication server +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import pprint +import requests +from io import open + +from dominate import document, tags as dom +from tornado.template import Template +from tornado.httputil import HTTPHeaders +from tornado.httpclient import HTTPResponse, HTTPRequest + +from authlib.jose import jwk, jwt +# from authlib.jose import JsonWebKey +from authlib.oauth2.base import OAuth2Error + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager +from DIRAC.FrameworkSystem.private.authorization.AuthServer import AuthServer +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import ResourceProtector +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import ClientRegistrationEndpoint +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import DeviceAuthorizationEndpoint +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +__RCSID__ = "$Id$" + + +class AuthHandler(TornadoREST): + # TODO: docs + # Authorization access to all methods handled by AuthServer instance + USE_AUTHZ_GRANTS = ['VISITOR'] + SYSTEM = 'Framework' + AUTH_PROPS = 'all' + LOCATION = "/DIRAC/auth" + CSS = """ +.button { + border-radius: 4px; + background-color: #ffffff00; + border: none; + color: black; + text-align: center; + font-size: 14px; + padding: 12px; + width: 100%; + transition: all 0.5s; + cursor: pointer; + margin: 5px; + display: block; /* Make the links appear below each other */ +} +.button a { + color: black; + cursor: pointer; + display: inline-block; + position: relative; + transition: 0.5s; + text-decoration: none; /* Remove underline from links */ +} +.button a:after { + content: '\\00bb'; + position: absolute; + opacity: 0; + top: 0; + right: -20px; + transition: 0.5s; +} +.button:hover a { + padding-right: 25px; +} +.button:hover a:after { + opacity: 1; + right: 0; +}""" + + @classmethod + def initializeHandler(cls, serviceInfo): + """ This method is called only one time, at the first request + + :param dict ServiceInfoDict: infos about services + """ + cls.server = AuthServer() + cls.idps = IdProviderFactory() + cls.css = {} + cls.css_align_center = 'display:block;justify-content:center;align-items:center;' + cls.css_center_div = 'height:700px;width:100%;position:absolute;top:50%;left:0;margin-top:-350px;' + cls.css_big_text = 'font-size:28px;' + cls.css_main = ' '.join([cls.css_align_center, cls.css_center_div, cls.css_big_text]) + + def initializeRequest(self): + """ Called at every request """ + self.currentPath = self.request.protocol + "://" + self.request.host + self.request.path + # Template for a html UI + self.doc = document('DIRAC authentication') + with self.doc.head: + dom.link(rel='stylesheet', + href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css") + dom.style(self.CSS) + + def _parseDIRACResult(self, result): + """ Here the result which returns handle_response is processed + """ + if not result['OK']: + # If response error is DIRAC server error, not OAuth2 flow error + self.removeSession() + self.set_status = 400 + self.write({'error': 'server_error', + 'description': '%s:\n%s' % (result['Message'], '\n'.join(result['CallStack']))}) + else: + # Successful responses and OAuth2 errors are processed here + status_code, headers, payload, new_session, error = result['Value'][0] + if status_code: + self.set_status(status_code) + if headers: + for key, value in headers: + self.set_header(key, value) + if payload: + self.write(payload) + if new_session: + self.saveSession(new_session) + if error: + self.removeSession() + for method, args_kwargs in result['Value'][1].items(): + eval('self.%s' % method)(*args_kwargs[0], **args_kwargs[1]) + + def saveSession(self, session): + """ Save session to cookie + + :param dict session: session + """ + self.set_secure_cookie('auth_session', json.dumps(session), secure=True, httponly=True) + + def removeSession(self): + """ Remove session from cookie """ + self.clear_cookie('auth_session') + + def getSession(self, state=None, **kw): + """ Get session from cookie + + :param str state: state + + :return: dict + """ + try: + session = json.loads(self.get_secure_cookie('auth_session')) + checkState = (session['state'] == state) if state else None + checkOption = (session[kw.items()[0][0]] == kw.items()[0][0]) if kw else None + except Exception as e: + return None + return session if (checkState or checkOption) else None + + path_index = ['.well-known/(oauth-authorization-server|openid-configuration)'] + + def web_index(self, instance): + """ Well known endpoint, specified by + `RFC8414 `_ + + Request examples:: + + GET: LOCATION/.well-known/openid-configuration + GET: LOCATION/.well-known/oauth-authorization-server + + Responce:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "registration_endpoint": "https://domain.com/DIRAC/auth/register", + "userinfo_endpoint": "https://domain.com/DIRAC/auth/userinfo", + "jwks_uri": "https://domain.com/DIRAC/auth/jwk", + "code_challenge_methods_supported": [ + "S256" + ], + "grant_types_supported": [ + "authorization_code", + "code", + "urn:ietf:params:oauth:grant-type:device_code", + "implicit", + "refresh_token" + ], + "token_endpoint": "https://domain.com/DIRAC/auth/token", + "response_types_supported": [ + "code", + "device", + "id_token token", + "id_token", + "token" + ], + "authorization_endpoint": "https://domain.com/DIRAC/auth/authorization", + "issuer": "https://domain.com/DIRAC/auth" + } + """ + if self.request.method == "GET": + return dict(self.server.metadata) + + def web_jwk(self): + """ JWKs endpoint + + Request example:: + + GET LOCATION/jwk + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "keys": [ + { + "e": "AQAB", + "kty": "RSA", + "n": "3Vv5h5...X3Y7k" + } + ] + } + """ + if self.request.method == "GET": + with open('/opt/dirac/etc/grid-security/jwtRS256.key.pub', 'rb') as f: + key = f.read() + # # # For newer version + # # # key = JsonWebKey.import_key(key, {'kty': 'RSA'}) + # # # self.finish(key.as_dict()) + return {'keys': [jwk.dumps(key, kty='RSA', alg='RS256')]} + + # auth_userinfo = ["authenticated"] + def web_userinfo(self): + """ The UserInfo endpoint can be used to retrieve identity information about a user, + see `spec `_ + + GET LOCATION/userinfo + + Parameters: + +---------------+--------+---------------------------------+--------------------------------------------------+ + | **name** | **in** | **description** | **example** | + +---------------+--------+---------------------------------+--------------------------------------------------+ + | Authorization | header | Provide access token | Bearer jkagfbfd3r4ubf887gqduyqwogasd87 | + +---------------+--------+---------------------------------+--------------------------------------------------+ + + Request example:: + + GET LOCATION/userinfo + Authorization: Bearer + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "sub": "248289761001", + "name": "Bob Smith", + "given_name": "Bob", + "family_name": "Smith", + "group": [ + "dirac_user", + "dirac_admin" + ] + } + """ + # Token verification + token = ResourceProtector().acquire_token(self.request, '') + return {'sub': token.sub, 'issuer': token.issuer, 'group': token.groups[0]} + # return {'username': credDict['username'], + # 'group': credDict['group']} + # return self.__validateToken() + + def web_register(self): + """ The Client Registration Endpoint, specified by + `RFC7591 `_ + + POST LOCATION/register?data.. + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | Authorization | header | Provide access token | Bearer jkagfbfd3r4ubf887gqduyqwogasd8 | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | grant_types | data | list of grant types, more supported | ["authorization_code","refresh_token"]| + | | | more supported grant types in *grants | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | scope | data | list of scoupes separated by a space | something | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | response_types | data | list of returned responses | ["token","id_token token","code"] | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | redirect_uris | data | Redirection URI to which the response will| ['https://dirac.egi.eu/redirect'] | + | | | be sent. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + + *:mod:`grants ` + + https://wlcg.cloud.cnaf.infn.it/register + + requests.post('https://marosvn32.in2p3.fr/DIRAC/auth/register', + json={'grant_types': ['implicit'], + 'response_types': ['token'], + 'redirect_uris': ['https://dirac.egi.eu'], + 'token_endpoint_auth_method': 'none'}, verify=False).text + requests.post('https://marosvn32.in2p3.fr/DIRAC/auth/register', + json={"scope":"changeGroup", + "token_endpoint_auth_method":"client_secret_basic", + "grant_types":["authorization_code","refresh_token"], + "redirect_uris":["https://marosvn32.in2p3.fr/DIRAC","https://marosvn32.in2p3.fr/DIRAC/loginComplete"], + "response_types":["token","id_token token","code"]}, verify=False).text + """ + return self.server.create_endpoint_response(ClientRegistrationEndpoint.ENDPOINT_NAME, self.request) + + path_device = ['([A-z0-9-_]*)'] + + def web_device(self, provider=None): + """ The device authorization endpoint can be used to request device and user codes. + This endpoint is used to start the device flow authorization process and user code verification. + + POST LOCATION/device/? + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | user code | query | in the last step to confirm recived user | WE8R-WEN9 | + | | | code put it as query parameter (optional) | | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | client_id | query | The public client ID | 3f6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | scope | query | list of scoupes separated by a space, to | g:dirac_user | + | | | add a group you must add "g:" before the | | + | | | group name | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | provider | path | identity provider to autorize (optional) | CheckIn | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + + + User code confirmation:: + + GET LOCATION/device/?user_code= + + Request example, to initialize a Device authentication flow:: + + POST LOCATION/device/CheckIn_dev?client_id=3f1DAj8z6eNw0E6JGq1Vu6efZwyV&scope=g:dirac_admin + + Response:: + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_code": "TglwLiow0HUwowjB9aHH5HqH3bZKP9d420LkNhCEuR", + "verification_uri": "https://marosvn32.in2p3.fr/DIRAC/auth/device", + "interval": 5, + "expires_in": 1800, + "verification_uri_complete": "https://marosvn32.in2p3.fr/DIRAC/auth/device/WSRL-HJMR", + "user_code": "WSRL-HJMR" + } + + Request example, to confirm the user code:: + + POST LOCATION/device/CheckIn_dev/WSRL-HJMR + + Response:: + + HTTP/1.1 200 OK + """ + if self.request.method == 'POST': + group = self.get_argument('group', None) + if group: + provider = Registry.getIdPForGroup(group) + if not provider: + return S_ERROR('No provider found for %s' % group) + result = self.idps.getIdProvider(provider) + if result['OK']: + idPObj = result['Value'] + result = idPObj.submitDeviceCodeAuthorizationFlow(group) + if not result['OK']: + return result + return result['Value'] + + self.log.verbose('Initialize a Device authentication flow.') + return self.server.create_endpoint_response(DeviceAuthorizationEndpoint.ENDPOINT_NAME, self.request) + + elif self.request.method == 'GET': + userCode = self.get_argument('user_code', None) + if userCode: + # If received a request with a user code, then prepare a request to authorization endpoint + self.log.verbose('User code verification.') + session, data = self.server.db.getSessionByOption('user_code', userCode) + if not session: + return 'Device flow authorization session %s expired.' % session + # Get original request from session + req = createOAuth2Request(dict(method='GET', uri=data['uri'])) + authURL = '/authorization/%s?%s&user_code=%s' % (provider, req.query, userCode) + # Save session to cookie + return self.server.handle_response(302, {}, [("Location", authURL)], data) + + # If received a request without a user code, then send a form to enter the user code + with self.doc: + dom.div(dom.form(dom._input(type="text", name="user_code", style=self.css_big_text), + dom.button('Submit', type="submit", style=self.css_big_text), + action=self.currentPath, method="GET"), style=self.css_main) + return Template(self.doc.render()).generate() + + path_authorization = ['([A-z0-9]*)'] + + def web_authorization(self, provider=None): + """ Authorization endpoint + + GET: LOCATION/authorization/ + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | response_type | query | informs of the desired grant type | code | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | client_id | query | The client ID | 3f6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | scope | query | list of scoupes separated by a space, to | g:dirac_user | + | | | add a group you must add "g:" before the | | + | | | group name | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | provider | path | identity provider to autorize (optional) | CheckIn | + | | | It's possible to add it interactively. | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + General options: + provider -- identity provider to autorize + + Device flow: + &user_code=.. (required) + + Authentication code flow: + &scope=.. (optional) + &redirect_uri=.. (optional) + &state=.. (main session id, optional) + &code_challenge=.. (PKCE, optional) + &code_challenge_method=(pain|S256) ('pain' by default, optional) + """ + return self.server.validate_consent_request(self.request, provider) + + def web_redirect(self): + """ Redirect endpoint. + After a user successfully authorizes an application, the authorization server will redirect + the user back to the application with either an authorization code or access token in the URL. + The full URL of this endpoint must be registered in the identity provider. + + Read more in `oauth.com `_. + Specified by `RFC6749 `_. + + GET LOCATION/redirect + + Parameters:: + + &chooseScope=.. to specify new scope(group in our case) (optional) + """ + # Current IdP session state + state = self.get_argument('state') + + # Try to catch errors + if self.get_argument('error', None): + error = OAuth2Error(error=self.get_argument('error'), description=self.get_argument('error_description', '')) + return self.server.handle_error_response(state, error) + + # Check current auth session that was initiated for the selected external identity provider + sessionWithExtIdP = self.getSession(state) + if not sessionWithExtIdP: + return S_ERROR("%s session is expired." % state) + + if not sessionWithExtIdP.get('authed'): + # Parse result of the second authentication flow + self.log.info('%s session, parsing authorization response:\n' % state, + '\n'.join([self.request.uri, self.request.query, self.request.body, str(self.request.headers)])) + + result = self.server.parseIdPAuthorizationResponse(self.request, sessionWithExtIdP) + if not result['OK']: + return result + # Return main session flow + sessionWithExtIdP['authed'] = result['Value'] + + # Research group + grant_user, response = self.__researchDIRACGroup(sessionWithExtIdP) + if not grant_user: + return response + + # RESPONSE to basic DIRAC client request + return self.server.create_authorization_response(response, grant_user) + + def web_token(self): + """ The token endpoint, the description of the parameters will differ depending on the selected grant_type + + POST LOCATION/token + + Parameters: + +----------------+--------+-------------------------------------+---------------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------+---------------------------------------------+ + | grant_type | query | what grant type to use, more | urn:ietf:params:oauth:grant-type:device_code| + | | | supported grant types in *grants | | + +----------------+--------+-------------------------------------+---------------------------------------------+ + | client_id | query | The public client ID | 3f1DAj8z6eNw0E6JGq1VuzRkpWUL9XTxhL86efZw | + +----------------+--------+-------------------------------------+---------------------------------------------+ + | device_code | query | device code | uW5xL4hr2tqwBPKL5d0JO9Fcc67gLqhJsNqYTSp | + +----------------+--------+-------------------------------------+---------------------------------------------+ + + *:mod:`grants ` + + Request example:: + + POST LOCATION/token?client_id=L86..yV&grant_type=urn:ietf:params:oauth:grant-type:device_code&device_code=uW5 + + Response:: + + HTTP/1.1 400 OK + Content-Type: application/json + + { + "error": "authorization_pending" + } + """ + return self.server.create_token_response(self.request) + + def __researchDIRACGroup(self, extSession): + """ Research DIRAC groups for authorized user + + :param dict extSession: ended authorized external IdP session + + :return: response + """ + # Base DIRAC client auth session + firstRequest = createOAuth2Request(extSession['mainSession']) + # Read requested groups by DIRAC client or user + firstRequest.addScopes(self.get_arguments('chooseScope', [])) + # Read already authed user + username, userID = extSession['authed'] + self.log.debug('Next groups has been found for %s:' % username, ', '.join(firstRequest.groups)) + + # Researche Group + result = gProxyManager.getGroupsStatusByUsername(username, firstRequest.groups) + if not result['OK']: + return None, result + groupStatuses = result['Value'] + if not groupStatuses: + return None, S_ERROR('No groups found.') + self.log.debug('The state of %s user groups has been checked:' % username, pprint.pformat(groupStatuses)) + + if not firstRequest.groups: + if len(groupStatuses) == 1: + firstRequest.addScopes(['g:%s' % groupStatuses[0]]) + else: + # Choose group interface + with self.doc: + with dom.div(style=self.css_main): + with dom.div('Choose group', style=self.css_align_center): + for group, data in groupStatuses.items(): + # data: Status, Comment, Action + dom.button(dom.a(group, href='%s?state=%s&chooseScope=g:%s' % (self.currentPath, + self.get_argument('state'), group)), + cls='button') + return None, self.server.handle_response(payload=Template(self.doc.render()).generate(), newSession=extSession) + + for group in firstRequest.groups: + status = groupStatuses[group]['Status'] + action = groupStatuses[group].get('Action') + comment = groupStatuses[group].get('Comment') + + if status == 'needToAuth': + # Submit second auth flow through IdP + idP = action[1][0] + return None, self.server.getIdPAuthorization(idP, firstRequest) + + if status not in ['ready', 'unknown']: + self.log.verbose('%s group has bad status: %s; %s' % (group, status, comment)) + + # Return grant user + return {'username': username, 'user_id': userID}, firstRequest diff --git a/src/DIRAC/FrameworkSystem/API/ProxyHandler.py b/src/DIRAC/FrameworkSystem/API/ProxyHandler.py new file mode 100644 index 00000000000..2966b0e8839 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/API/ProxyHandler.py @@ -0,0 +1,116 @@ +""" Handler to serve the DIRAC proxy data +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Tornado.Server.TornadoREST import TornadoREST +from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient +from DIRAC.ConfigurationSystem.Client.Utilities import isDownloadablePersonalProxy +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsernameInGroup + +__RCSID__ = "$Id$" + + +class ProxyHandler(TornadoREST): + USE_AUTHZ_GRANTS = ['JWT'] + RAISE_DIRAC_ERROR = True + SYSTEM = 'Framework' + AUTH_PROPS = "authenticated" + LOCATION = "/DIRAC" + + @classmethod + def initializeHandler(cls, serviceInfo): + """ Request initialization """ + cls.proxyCli = ProxyManagerClient() + + # path_proxy = [r'([a-z]*)[\/]?([a-z]*)'] + + # def web_proxy(self, user=None, group=None): + def web_proxy(self, user=None, group=None): + """ RESTful endpoints to user proxy management to retrieve personal proxy. + + GET LOCATION/proxy + + Parameters: + +----------------+--------+-------------------------------------------+---------------------------------------+ + | **name** | **in** | **description** | **example** | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | Authorization | header | Provide access token | Bearer jkagfbfd3r4ubf887gqduyqwogasd8 | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | lifetime | query | requested proxy live time in seconds | 3600 | + | | | by default 6 hours (optional) | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | voms | query | to get user proxy with VOMS extension, | true | + | | | by defaul false (optional) | | + +----------------+--------+-------------------------------------------+---------------------------------------+ + | refresh_token | query | to get user proxy need provide all tokens | jkagfbfd3r4ubf887gqduyqwogasd89823hf | + +----------------+--------+-------------------------------------------+---------------------------------------+ + + Request example:: + + GET LOCATION/proxy + Authorization: Bearer + + Response:: + + HTTP/1.1 200 OK + + + """ + voms = self.get_argument('voms', None) + try: + proxyLifeTime = int(self.get_argument('lifetime', 3600 * 6)) + except Exception as e: + return S_ERROR('Cannot read "lifetime" argument. %s' % repr(e)) + + # GET + if self.request.method == 'GET': + # # Return content of Proxy DB + # if 'metadata' in optns: + # pass + + # Return personal proxy + # if not user and not group: + from pprint import pprint + print('=============== PROXY =================') + pprint(self.getRemoteCredentials()) + result = self.__getProxy(self.getUserName(), self.getUserGroup(), voms, proxyLifeTime) + if result['OK']: + return result['Value'] + return result + + # elif user and group: + # return self.__getProxy(user, group, voms, proxyLifeTime) + + else: + return S_ERROR("Wrone request.") + + def __getProxy(self, user, group, voms, lifetime): + """ Get proxy + + :param str user: user name + :param str group: group name + :param bool voms: add voms ext + :param int lifetime: proxy lifetime + + :return: S_OK(str)/S_ERROR() + """ + lifetime = min(lifetime, 3600 * 6) + + # Allowe to take only personal proxy + if self.getUserName() != user or self.getUserGroup() != group: + return S_ERROR('Sorry, only personal proxy is allowed to download') + + if not isDownloadablePersonalProxy(): + return S_ERROR("You can't get proxy, configuration settings(downloadablePersonalProxy) not allow to do that.") + + if voms: + result = self.proxyCli.downloadVOMSProxy(user, group, requiredTimeLeft=lifetime) + else: + result = self.proxyCli.downloadProxy(user, group, requiredTimeLeft=lifetime) + if result['OK']: + self.log.notice('Proxy was created.') + return result['Value'].dumpAllToString() + return result diff --git a/src/DIRAC/FrameworkSystem/API/__init__.py b/src/DIRAC/FrameworkSystem/API/__init__.py new file mode 100644 index 00000000000..2492dd6ec85 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/API/__init__.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +# $HeadURL$ +__RCSID__ = "$Id$" diff --git a/src/DIRAC/FrameworkSystem/Client/AuthManagerClient.py b/src/DIRAC/FrameworkSystem/Client/AuthManagerClient.py new file mode 100644 index 00000000000..e95abb92628 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Client/AuthManagerClient.py @@ -0,0 +1,143 @@ +""" AuthManagerClient has the function to "talk" to the AuthManager service. Also, when requesting information + about users, this information is cached in a separate class + :mod:`AuthManagerData `, and is used, in the Registry for example, + to reduce the number of requests to the server part +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import six +import json +import requests +from authlib.common.security import generate_token + +from diraccfg import CFG + +from DIRAC import rootPath, S_OK, S_ERROR +from DIRAC.Core.Base.Client import Client, createClient +from DIRAC.Core.Utilities import DIRACSingleton +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.Client.AuthManagerData import gAuthManagerData +from DIRAC.FrameworkSystem.private.authorization.utils.Sessions import Session + +__RCSID__ = "$Id$" + + +@createClient('Framework/AuthManager') +@six.add_metaclass(DIRACSingleton.DIRACSingleton) +class AuthManagerClient(Client): + """ Authentication manager + """ + + def __init__(self, *args, **kwargs): + """ Constructor + """ + super(AuthManagerClient, self).__init__(*args, **kwargs) + self.localCfg = CFG() + self.cfgFile = os.path.join(rootPath, 'etc', 'dirac.cfg') + self.localCfg.loadFromFile(self.cfgFile) + self.setServer('Framework/AuthManager') + self.idps = IdProviderFactory() + + def prepareClientCredentials(self): + """ To interact with the server part through OAuth, you must at least be a registered client, + prepare authentication client credentials + + :return: S_OK(dict)/S_ERROR() + """ + clientMetadata = self.localCfg.getAsDict('/LocalInstallation/AuthorizationClient') + + if clientMetadata: + return S_OK(clientMetadata) + + self.log.info('Register new authorization client..') + + try: + # TODO: Fix hardcore url + r = requests.post('https://marosvn32.in2p3.fr/DIRAC/auth/register', {'redirect_uri': ''}, verify=False) + r.raise_for_status() + clientMetadata = r.json() + except requests.exceptions.Timeout: + return S_ERROR('Authentication server is not answer.') + except requests.exceptions.RequestException as ex: + return S_ERROR(r.content or ex) + except Exception as ex: + return S_ERROR('Cannot read response: %s' % ex) + + if not clientMetadata: + return S_ERROR('Cannot get authorization client credentials') + + self.log.debug('Store %s client to local configuration..' % clientMetadata['client_id']) + + data = CFG() + data.loadFromDict(clientMetadata) + comment = "Write fresh client credentials to /LocalInstallation section" + self.localCfg.createNewSection('LocalConfiguration/AuthorizationClient', comment=comment, contents=data) + self.localCfg.writeToFile(self.cfgFile) + + return S_OK(clientMetadata) + + def submitUserAuthorizationFlow(self, client=None, idP=None, group=None, grant='device'): + """ Submit authorization flow + """ + if not client: + # Prepare client + result = self.prepareClientCredentials() + if not result['OK']: + return result + client = result['Value'] + + # TODO: Fix hardcore url + url = 'https://marosvn32.in2p3.fr/DIRAC/auth/device?client_id=%s' % client['client_id'] + if group: + url += '&scope=g:%s' % group + if idP: + url += '&provider=%s' % idP + try: + r = requests.post(url, verify=False) + r.raise_for_status() + return S_OK(r.json()) + except requests.exceptions.Timeout: + return S_ERROR('Authentication server is not answer.') + except requests.exceptions.RequestException as ex: + return S_ERROR(r.content or ex) + except Exception as ex: + return S_ERROR('Cannot read response: %s' % ex) + + def getIdPAuthorization(self, providerName, session): + """ Submit subsession and return dict with authorization url and session number + + :param str providerName: provider name + :param str mainSession: main session identificator + + :return: S_OK(dict)/S_ERROR() -- dictionary contain next keys: + Status -- session status + UserName -- user name, returned if status is 'ready' + Session -- session id, returned if status is 'needToAuth' + """ + session = session or generate_token(10) + result = self.idps.getIdProvider(providerName) # , sessionManager=self.__db) + return result['Value'].submitNewSession(session) if result['OK'] else result + + def parseAuthResponse(self, providerName, response, session): # , username, userProfile): + """ Fill session by user profile, tokens, comment, OIDC authorize status, etc. + Prepare dict with user parameters, if DN is absent there try to get it. + Create new or modify existing DIRAC user and store the session + + :param str providerName: identity provider name + :param dict response: authorization response + :param object session: session data dictionary + + :return: S_OK(dict)/S_ERROR() + """ + result = self._getRPC().parseAuthResponse(providerName, response, dict(session)) # , username, userProfile) + if result['OK']: + username, userID, profile = result['Value'] + if username and profile: + gAuthManagerData.updateProfiles(userID, profile) + return S_OK((username, userID, profile)) if result['OK'] else result + + +gSessionManager = AuthManagerClient() diff --git a/src/DIRAC/FrameworkSystem/Client/AuthManagerData.py b/src/DIRAC/FrameworkSystem/Client/AuthManagerData.py new file mode 100644 index 00000000000..df91e5eae0a --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Client/AuthManagerData.py @@ -0,0 +1,179 @@ +""" This class is located between the client and server part and designed to cache user information requested + from the server part. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +from pprint import pprint + +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.Utilities import ThreadSafe, DIRACSingleton +from DIRAC.Core.Utilities.DictCache import DictCache + +__RCSID__ = "$Id$" + + +gCacheProfiles = ThreadSafe.Synchronizer() + + +@six.add_metaclass(DIRACSingleton.DIRACSingleton) +class AuthManagerData(object): + """ Authentication manager + """ + __cacheIdPToIDs = DictCache() + # # { + # # : [ ], + # # : ... + # # } + + __cacheProfiles = DictCache() + # # { + # # : { + # # DNs: { + # # : { + # # ProxyProvider: [ ], + # # VOMSRoles: [ ], + # # ... + # # }, + # # : { ... }, + # # } + # # }, + # # : { ... } + # # } + + __service = DictCache() + # # { + # # crash: bool + # # } + + def __init__(self): + self.rpc = None + + @gCacheProfiles + def getProfiles(self, userID=None): + """ Get cache information + + :param str userID: user ID + + :return: dict + """ + if userID: + return self.__cacheProfiles.get(userID) or {} + return self.__cacheProfiles.getDict() + + @gCacheProfiles + def updateProfiles(self, userID, data, time=3600 * 24): + """ Add cache information + + :param str userID: user ID + :param dict data: ID information data + :param int time: lifetime + """ + profileDict = self.__cacheProfiles.get(userID) or {} + print('================== CLI DATA updateProfiles ==================') + print('User ID: %s' % userID) + pprint(profileDict) + pprint(data) + for k, v in data.items(): + if v is not None: + profileDict[k] = v + self.__cacheProfiles.add(userID, time, value=profileDict) + ids = self.__cacheIdPToIDs.get(profileDict['Provider']) + if isinstance(ids, list) and userID not in ids: + self.__cacheIdPToIDs.add(profileDict['Provider'], time, userID) + + def __getRPC(self): + """ Get RPC + """ + if not self.rpc: + from DIRAC.Core.Base.Client import Client + self.rpc = Client()._getRPC(url="Framework/AuthManager", timeout=10) + return self.rpc + + def resfreshProfiles(self, userID=None): + """ Refresh profiles cache from service + + :param str userID: userID to update + + :return: S_OK()/S_ERROR() + """ + print('==== resfreshProfiles ====') + servCrash = self.__service.get('crash') + if servCrash and servCrash[1] > 2: + return servCrash[0] + result = self.__getRPC().getIdProfiles(userID) + # If the AuthManager service is down client will ignore it 1 minute + if result.get('Errno', 0) == 1112: + crash = self.__service.get('crash') + self.__service.add('crash', 60, value=(result, (crash[1] + 1) if crash else 1)) + if result['OK'] and result['Value']: + for uid, data in result['Value'].items(): + if data: + self.updateProfiles(uid, data) + return result + + def getIDsForDN(self, dn, provider=None): + """ Find ID for DN + + :param str dn: user DN + + :return: S_OK(list) + """ + userIDs = [] + profile = self.getProfiles() + for resfreshed in [0, 1]: + for uid, data in profile.items(): + if dn not in data.get('DNs', []) or (provider and data['DNs'][dn]['ProxyProvider'] != provider): + continue + userIDs.append(uid) + if userIDs or resfreshed: + break + result = self.resfreshProfiles() + if not result['OK']: + return result + profile = result['Value'] + + return S_OK(userIDs) + + def getDNsForID(self, uid): + """ Find DNs for ID + + :param str uid: user ID + + :return: S_OK(list)/S_ERROR() + """ + print('==== getDNsForID ====') + profile = self.getProfiles(userID=uid) + if not profile: + result = self.resfreshProfiles(userID=uid) + if not result['OK']: + return result + profile = result['Value'].get(uid, {}) + pprint(profile) + print('=====================') + return S_OK(profile.get('DNs', [])) + + def getDNOptionForID(self, uid, dn, option): + """ Find option for DN + + :param str uid: user ID + :param str dn: user DN + :param str option: option to find + + :return: S_OK()/S_ERROR() + """ + profile = self.getProfiles(userID=uid) + if not profile: + result = self.resfreshProfiles(userID=uid) + if not result['OK']: + return result + profile = result['Value'].get(uid, {}) + + if dn in profile.get('DNs', []): + return S_OK(profile['DNs'][dn].get(option)) + return S_OK(None) + + +gAuthManagerData = AuthManagerData() diff --git a/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py b/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py index e1436379f51..e2ed4e4c1ff 100644 --- a/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py +++ b/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py @@ -1130,6 +1130,9 @@ def getStartupComponentStatus(self, componentTupleList): DIRAC.exit(-1) return S_ERROR(error) + if not cList: + return S_ERROR('No components found.') + result = self.execCommand(0, ['runsvstat'] + cList) if not result['OK']: return result diff --git a/src/DIRAC/FrameworkSystem/Client/ProxyGeneration.py b/src/DIRAC/FrameworkSystem/Client/ProxyGeneration.py index 5717f3acfd7..4bf0f78204a 100644 --- a/src/DIRAC/FrameworkSystem/Client/ProxyGeneration.py +++ b/src/DIRAC/FrameworkSystem/Client/ProxyGeneration.py @@ -342,13 +342,6 @@ def generateProxy(params): return S_ERROR("Can't contact DIRAC CS: %s" % retVal['Message']) userDN = chain.getCertInChain(-1)['Value'].getSubjectDN()['Value'] - if not params.diracGroup: - result = Registry.findDefaultGroupForDN(userDN) - if not result['OK']: - gLogger.warn("Could not get a default group for DN %s: %s" % (userDN, result['Message'])) - else: - params.diracGroup = result['Value'] - gLogger.info("Default discovered group is %s" % params.diracGroup) gLogger.info("Checking DN %s" % userDN) retVal = Registry.getUsernameForDN(userDN) if not retVal['OK']: @@ -356,13 +349,21 @@ def generateProxy(params): return S_ERROR("DN %s is not registered" % userDN) username = retVal['Value'] gLogger.info("Username is %s" % username) + + if not params.diracGroup: + result = Registry.findDefaultGroupForUser(username) + if not result['OK']: + gLogger.warn(retVal['Message']) + return S_ERROR("Cannot found group for %s user. %s" % (username, result['Message'])) + params.diracGroup = result['Value'] + retVal = Registry.getGroupsForUser(username) if not retVal['OK']: gLogger.warn(retVal['Message']) return S_ERROR("User %s has no groups defined" % username) groups = retVal['Value'] if params.diracGroup not in groups: - return S_ERROR("Requested group %s is not valid for DN %s" % (params.diracGroup, userDN)) + return S_ERROR("Requested group %s is not valid for %s user" % (params.diracGroup, username)) gLogger.info("Creating proxy for %s@%s (%s)" % (username, params.diracGroup, userDN)) if params.summary: h = int(params.proxyLifeTime / 3600) diff --git a/src/DIRAC/FrameworkSystem/Client/ProxyManagerClient.py b/src/DIRAC/FrameworkSystem/Client/ProxyManagerClient.py index ecb0185617a..414d1c1e058 100755 --- a/src/DIRAC/FrameworkSystem/Client/ProxyManagerClient.py +++ b/src/DIRAC/FrameworkSystem/Client/ProxyManagerClient.py @@ -1,15 +1,20 @@ -""" ProxyManagemerClient has the function to "talk" to the ProxyManager service +""" ProxyManagerClient has the function to "talk" to the ProxyManager service. Also, when requesting information + about users, this information is cached in a separate class + :mod:`ProxyManagerData `, and is used, in the Registry for example, + to reduce the number of requests to the server part """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six import os +import six import datetime from DIRAC import S_OK, S_ERROR, gLogger -from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.Client.ProxyManagerData import gProxyManagerData +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSAttributeForGroup,\ + getUsernameForDN, getDNsForUsernameInGroup, getDNForUsernameInGroup from DIRAC.Core.Utilities import ThreadSafe, DIRACSingleton from DIRAC.Core.Utilities.DictCache import DictCache from DIRAC.Core.Security.ProxyFile import multiProxyArgument, deleteMultiProxy @@ -17,22 +22,25 @@ from DIRAC.Core.Security.X509Request import X509Request # pylint: disable=import-error from DIRAC.Core.Security.VOMS import VOMS from DIRAC.Core.Security import Locations -from DIRAC.Core.DISET.RPCClient import RPCClient +from DIRAC.Core.Base.Client import Client, createClient __RCSID__ = "$Id$" -gUsersSync = ThreadSafe.Synchronizer() gProxiesSync = ThreadSafe.Synchronizer() -gVOMSProxiesSync = ThreadSafe.Synchronizer() +@createClient('Framework/ProxyManager') @six.add_metaclass(DIRACSingleton.DIRACSingleton) -class ProxyManagerClient(object): - def __init__(self): - self.__usersCache = DictCache() +class ProxyManagerClient(Client): + """ Proxy manager client + """ + + def __init__(self, **kwargs): + """ Constructor + """ + super(ProxyManagerClient, self).__init__(**kwargs) + self.setServer('Framework/ProxyManager') self.__proxiesCache = DictCache() - self.__vomsProxiesCache = DictCache() - self.__pilotProxiesCache = DictCache() self.__filesCache = DictCache(self.__deleteTemporalFile) def __deleteTemporalFile(self, filename): @@ -48,153 +56,50 @@ def __deleteTemporalFile(self, filename): def clearCaches(self): """ Clear caches """ - self.__usersCache.purgeAll() self.__proxiesCache.purgeAll() - self.__vomsProxiesCache.purgeAll() - self.__pilotProxiesCache.purgeAll() - - def __getSecondsLeftToExpiration(self, expiration, utc=True): - """ Get time left to expiration in a seconds - - :param datetime expiration: - :param boolean utc: time in utc - - :return: datetime - """ - if utc: - td = expiration - datetime.datetime.utcnow() - else: - td = expiration - datetime.datetime.now() - return td.days * 86400 + td.seconds - - def __refreshUserCache(self, validSeconds=0): - """ Refresh user cache - - :param int validSeconds: required seconds the proxy is valid for - - :return: S_OK()/S_ERROR() - """ - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) - retVal = rpcClient.getRegisteredUsers(validSeconds) - if not retVal['OK']: - return retVal - data = retVal['Value'] - # Update the cache - for record in data: - cacheKey = (record['DN'], record['group']) - self.__usersCache.add(cacheKey, - self.__getSecondsLeftToExpiration(record['expirationtime']), - record) - return S_OK() - - @gUsersSync - def userHasProxy(self, userDN, userGroup, validSeconds=0): - """ Check if a user(DN-group) has a proxy in the proxy management - Updates internal cache if needed to minimize queries to the service - - :param str userDN: user DN - :param str userGroup: user group - :param int validSeconds: proxy valid time in a seconds - - :return: S_OK()/S_ERROR() - """ - - # For backward compatibility reasons with versions prior to v7r1 - # we need to check for proxy with a group - # AND for groupless proxy even if not specified - - cacheKeys = ((userDN, userGroup), (userDN, '')) - for cacheKey in cacheKeys: - if self.__usersCache.exists(cacheKey, validSeconds): - return S_OK(True) - - # Get list of users from the DB with proxys at least 300 seconds - gLogger.verbose("Updating list of users in proxy management") - retVal = self.__refreshUserCache(validSeconds) - if not retVal['OK']: - return retVal - - for cacheKey in cacheKeys: - if self.__usersCache.exists(cacheKey, validSeconds): - return S_OK(True) - return S_OK(False) - - @gUsersSync - def getUserPersistence(self, userDN, userGroup, validSeconds=0): - """ Check if a user(DN-group) has a proxy in the proxy management + def userHasProxy(self, user, group, validSeconds=0): + """ Check if a user-group has a proxy in the proxy management Updates internal cache if needed to minimize queries to the service - :param str userDN: user DN - :param str userGroup: user group + :param str user: user name + :param str group: user group :param int validSeconds: proxy valid time in a seconds :return: S_OK()/S_ERROR() """ - cacheKey = (userDN, userGroup) - userData = self.__usersCache.get(cacheKey, validSeconds) - if userData: - if userData['persistent']: - return S_OK(True) - # Get list of users from the DB with proxys at least 300 seconds - gLogger.verbose("Updating list of users in proxy management") - retVal = self.__refreshUserCache(validSeconds) - if not retVal['OK']: - return retVal - userData = self.__usersCache.get(cacheKey, validSeconds) - if userData: - return S_OK(userData['persistent']) - return S_OK(False) - - def setPersistency(self, userDN, userGroup, persistent): - """ Set the persistency for user/group + dn = user + if not user.startswith('/'): + result = getDNForUsernameInGroup(user, group) + if not result['OK']: + return result + dn = result['Value'] - :param str userDN: user DN - :param str userGroup: user group - :param boolean persistent: presistent flag + result = gProxyManagerData.userHasProxy(dn, group, validSeconds) + if not result['OK'] or result['Value'] or user.startswith('/'): + return result - :return: S_OK()/S_ERROR() - """ - # Hack to ensure bool in the rpc call - persistentFlag = True - if not persistent: - persistentFlag = False - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) - retVal = rpcClient.setPersistency(userDN, userGroup, persistentFlag) - if not retVal['OK']: - return retVal - # Update internal persistency cache - cacheKey = (userDN, userGroup) - record = self.__usersCache.get(cacheKey, 0) - if record: - record['persistent'] = persistentFlag - self.__usersCache.add(cacheKey, - self.__getSecondsLeftToExpiration(record['expirationtime']), - record) - return retVal + result = self.getGroupsStatusByUsername(user, [group]) + return S_OK(True if result['Value'][group]['Status'] == "ready" else False) if result['OK'] else result def uploadProxy(self, proxy=None, restrictLifeTime=0, rfcIfPossible=False): """ Upload a proxy to the proxy management service using delegation :param X509Chain proxy: proxy as a chain :param int restrictLifeTime: proxy live time in a seconds - :param boolean rfcIfPossible: make rfc proxy if possible + :param bool rfcIfPossible: make rfc proxy if possible :return: S_OK(dict)/S_ERROR() -- dict contain proxies """ # Discover proxy location + proxyLocation = proxy if isinstance(proxy, six.string_types) else "" if isinstance(proxy, X509Chain): chain = proxy - proxyLocation = "" else: - if not proxy: + if not proxyLocation: proxyLocation = Locations.getProxyLocation() if not proxyLocation: return S_ERROR("Can't find a valid proxy") - elif isinstance(proxy, six.string_types): - proxyLocation = proxy - else: - return S_ERROR("Can't find a valid proxy") chain = X509Chain() result = chain.loadProxyFromFile(proxyLocation) if not result['OK']: @@ -206,59 +111,71 @@ def uploadProxy(self, proxy=None, restrictLifeTime=0, rfcIfPossible=False): if chain.getDIRACGroup().get('Value') or chain.isVOMS().get('Value'): return S_ERROR("Cannot upload proxy with DIRAC group or VOMS extensions") - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) # Get a delegation request - result = rpcClient.requestDelegationUpload(chain.getRemainingSecs()['Value']) + result = self._getRPC().requestDelegationUpload() if not result['OK']: return result reqDict = result['Value'] # Generate delegated chain chainLifeTime = chain.getRemainingSecs()['Value'] - 60 - if restrictLifeTime and restrictLifeTime < chainLifeTime: - chainLifeTime = restrictLifeTime - retVal = chain.generateChainFromRequestString(reqDict['request'], - lifetime=chainLifeTime, + chainLifeTime = min(restrictLifeTime, chainLifeTime) if restrictLifeTime else chainLifeTime + retVal = chain.generateChainFromRequestString(reqDict['request'], lifetime=chainLifeTime, rfc=rfcIfPossible) if not retVal['OK']: return retVal # Upload! - result = rpcClient.completeDelegationUpload(reqDict['id'], retVal['Value']) - if not result['OK']: - return result - return S_OK(result.get('proxies') or result['Value']) + return self._getRPC().completeDelegationUpload(reqDict['id'], retVal['Value']) @gProxiesSync - def downloadProxy(self, userDN, userGroup, limited=False, requiredTimeLeft=1200, - cacheTime=14400, proxyToConnect=None, token=None): - """ Get a proxy Chain from the proxy management + def __getProxy(self, user, userGroup, limited=False, requiredTimeLeft=1200, cacheTime=14400, + proxyToConnect=None, token=None, voms=None): + """ Get a proxy Chain from the proxy manager - :param str userDN: user DN - :param str userGroup: user group - :param boolean limited: if need limited proxy + :param str user: user name or DN + :param str group: user group + :param bool limited: if need limited proxy :param int requiredTimeLeft: required proxy live time in a seconds :param int cacheTime: store in a cache time in a seconds :param X509Chain proxyToConnect: proxy as a chain :param str token: valid token to get a proxy + :param bool voms: for VOMS proxy :return: S_OK(X509Chain)/S_ERROR() """ - cacheKey = (userDN, userGroup) + if voms and not getVOMSAttributeForGroup(userGroup): + return S_ERROR("No mapping defined for group %s in the CS" % userGroup) + + dn = None + if user.startswith('/'): + dn = user + result = getUsernameForDN(dn) + if result['OK']: + user = result['Value'] + result = getDNsForUsernameInGroup(user, userGroup) + if not result['OK']: + return result + if dn not in result['Value']: + return S_ERROR('"%s" DN not match with %s user, %s group.' % (dn, user, userGroup)) + + cacheKey = (dn or user, userGroup, voms, limited) if self.__proxiesCache.exists(cacheKey, requiredTimeLeft): return S_OK(self.__proxiesCache.get(cacheKey)) + req = X509Request() req.generateProxyRequest(limited=limited) - if proxyToConnect: - rpcClient = RPCClient("Framework/ProxyManager", proxyChain=proxyToConnect, timeout=120) - else: - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) - if token: - retVal = rpcClient.getProxyWithToken(userDN, userGroup, req.dumpRequest()['Value'], - int(cacheTime + requiredTimeLeft), token) - else: - retVal = rpcClient.getProxy(userDN, userGroup, req.dumpRequest()['Value'], - int(cacheTime + requiredTimeLeft)) + # TODO: self._getRPC cannot add new kwargs, so proxyChain need set to ProxyManagerClient() + rpcClient = self._getRPC() # proxyChain=proxyToConnect) if proxyToConnect else self._getRPC() + + result = req.dumpRequest() + if not result['OK']: + return result + requestPem = result['Value'] + requestTime = int(cacheTime + requiredTimeLeft) + + retVal = rpcClient.getProxy(dn or user, userGroup, requestPem, requestTime, token, voms) if not retVal['OK']: return retVal + chain = X509Chain(keyObj=req.getPKey()) retVal = chain.loadChainFromString(retVal['Value']) if not retVal['OK']: @@ -266,190 +183,104 @@ def downloadProxy(self, userDN, userGroup, limited=False, requiredTimeLeft=1200, self.__proxiesCache.add(cacheKey, chain.getRemainingSecs()['Value'], chain) return S_OK(chain) - def downloadProxyToFile(self, userDN, userGroup, limited=False, requiredTimeLeft=1200, - cacheTime=14400, filePath=None, proxyToConnect=None, token=None): - """ Get a proxy Chain from the proxy management and write it to file + def downloadProxy(self, user, group, limited=False, requiredTimeLeft=1200, cacheTime=14400, + proxyToConnect=None, token=None): + """ Get a proxy Chain from the proxy management - :param str userDN: user DN - :param str userGroup: user group - :param boolean limited: if need limited proxy + :param str user: user name or DN + :param str group: user group + :param bool limited: if need limited proxy :param int requiredTimeLeft: required proxy live time in a seconds :param int cacheTime: store in a cache time in a seconds - :param str filePath: path to save proxy :param X509Chain proxyToConnect: proxy as a chain :param str token: valid token to get a proxy :return: S_OK(X509Chain)/S_ERROR() """ - retVal = self.downloadProxy(userDN, userGroup, limited, requiredTimeLeft, cacheTime, proxyToConnect, token) - if not retVal['OK']: - return retVal - chain = retVal['Value'] - retVal = self.dumpProxyToFile(chain, filePath) - if not retVal['OK']: - return retVal - retVal['chain'] = chain - return retVal + return self.__getProxy(user, group, limited=limited, requiredTimeLeft=requiredTimeLeft, + cacheTime=cacheTime, proxyToConnect=proxyToConnect, token=token) - @gVOMSProxiesSync - def downloadVOMSProxy(self, userDN, userGroup, limited=False, requiredTimeLeft=1200, - cacheTime=14400, requiredVOMSAttribute=None, - proxyToConnect=None, token=None): - """ Download a proxy if needed and transform it into a VOMS one + def downloadProxyToFile(self, user, group, limited=False, requiredTimeLeft=1200, cacheTime=14400, + filePath=None, proxyToConnect=None, token=None): + """ Get a proxy Chain from the proxy management and write it to file - :param str userDN: user DN - :param str userGroup: user group - :param boolean limited: if need limited proxy + :param str user: user name or DN + :param str group: user group + :param bool limited: if need limited proxy :param int requiredTimeLeft: required proxy live time in a seconds :param int cacheTime: store in a cache time in a seconds - :param str requiredVOMSAttribute: VOMS attr to add to the proxy + :param str filePath: path to save proxy :param X509Chain proxyToConnect: proxy as a chain :param str token: valid token to get a proxy :return: S_OK(X509Chain)/S_ERROR() """ - cacheKey = (userDN, userGroup, requiredVOMSAttribute, limited) - if self.__vomsProxiesCache.exists(cacheKey, requiredTimeLeft): - return S_OK(self.__vomsProxiesCache.get(cacheKey)) - req = X509Request() - req.generateProxyRequest(limited=limited) - if proxyToConnect: - rpcClient = RPCClient("Framework/ProxyManager", proxyChain=proxyToConnect, timeout=120) - else: - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) - if token: - retVal = rpcClient.getVOMSProxyWithToken(userDN, userGroup, req.dumpRequest()['Value'], - int(cacheTime + requiredTimeLeft), token, requiredVOMSAttribute) - - else: - retVal = rpcClient.getVOMSProxy(userDN, userGroup, req.dumpRequest()['Value'], - int(cacheTime + requiredTimeLeft), requiredVOMSAttribute) - if not retVal['OK']: - return retVal - chain = X509Chain(keyObj=req.getPKey()) - retVal = chain.loadChainFromString(retVal['Value']) - if not retVal['OK']: - return retVal - self.__vomsProxiesCache.add(cacheKey, chain.getRemainingSecs()['Value'], chain) - return S_OK(chain) + retVal = self.downloadProxy(user, group, limited, requiredTimeLeft, cacheTime, proxyToConnect, token) + if retVal['OK']: + chain = retVal['Value'] + retVal = self.dumpProxyToFile(chain, filePath) + if retVal['OK']: + retVal['chain'] = chain + return retVal - def downloadVOMSProxyToFile(self, userDN, userGroup, limited=False, requiredTimeLeft=1200, - cacheTime=14400, requiredVOMSAttribute=None, filePath=None, - proxyToConnect=None, token=None): - """ Download a proxy if needed, transform it into a VOMS one and write it to file + def downloadVOMSProxy(self, user, group, limited=False, requiredTimeLeft=1200, + cacheTime=14400, proxyToConnect=None, token=None): + """ Download a proxy if needed and transform it into a VOMS one - :param str userDN: user DN - :param str userGroup: user group - :param boolean limited: if need limited proxy + :param str user: user name or DN + :param str group: user group + :param bool limited: if need limited proxy :param int requiredTimeLeft: required proxy live time in a seconds :param int cacheTime: store in a cache time in a seconds - :param str requiredVOMSAttribute: VOMS attr to add to the proxy - :param str filePath: path to save proxy :param X509Chain proxyToConnect: proxy as a chain :param str token: valid token to get a proxy :return: S_OK(X509Chain)/S_ERROR() """ - retVal = self.downloadVOMSProxy(userDN, userGroup, limited, requiredTimeLeft, cacheTime, - requiredVOMSAttribute, proxyToConnect, token) - if not retVal['OK']: - return retVal - chain = retVal['Value'] - retVal = self.dumpProxyToFile(chain, filePath) - if not retVal['OK']: - return retVal - retVal['chain'] = chain - return retVal - - def getPilotProxyFromDIRACGroup(self, userDN, userGroup, requiredTimeLeft=43200, proxyToConnect=None): - """ Download a pilot proxy with VOMS extensions depending on the group - - :param str userDN: user DN - :param str userGroup: user group - :param int requiredTimeLeft: required proxy live time in seconds - :param X509Chain proxyToConnect: proxy as a chain - - :return: S_OK(X509Chain)/S_ERROR() - """ - # Assign VOMS attribute - vomsAttr = Registry.getVOMSAttributeForGroup(userGroup) - if not vomsAttr: - gLogger.warn("No voms attribute assigned to group %s when requested pilot proxy" % userGroup) - return self.downloadProxy(userDN, userGroup, limited=False, requiredTimeLeft=requiredTimeLeft, - proxyToConnect=proxyToConnect) - else: - return self.downloadVOMSProxy(userDN, userGroup, limited=False, requiredTimeLeft=requiredTimeLeft, - requiredVOMSAttribute=vomsAttr, proxyToConnect=proxyToConnect) + return self.__getProxy(user, group, limited=limited, requiredTimeLeft=requiredTimeLeft, voms=True, + cacheTime=cacheTime, proxyToConnect=proxyToConnect, token=token) - def getPilotProxyFromVOMSGroup(self, userDN, vomsAttr, requiredTimeLeft=43200, proxyToConnect=None): - """ Download a pilot proxy with VOMS extensions depending on the group + def downloadVOMSProxyToFile(self, user, group, limited=False, requiredTimeLeft=1200, + cacheTime=14400, filePath=None, proxyToConnect=None, token=None): + """ Download a proxy if needed, transform it into a VOMS one and write it to file - :param str userDN: user DN - :param str vomsAttr: VOMS attribute + :param str user: user name or DN + :param str group: user group + :param bool limited: if need limited proxy :param int requiredTimeLeft: required proxy live time in a seconds + :param int cacheTime: store in a cache time in a seconds + :param str filePath: path to save proxy :param X509Chain proxyToConnect: proxy as a chain - - :return: S_OK(X509Chain)/S_ERROR() - """ - groups = Registry.getGroupsWithVOMSAttribute(vomsAttr) - if not groups: - return S_ERROR("No group found that has %s as voms attrs" % vomsAttr) - - for userGroup in groups: - result = self.downloadVOMSProxy(userDN, userGroup, - limited=False, - requiredTimeLeft=requiredTimeLeft, - requiredVOMSAttribute=vomsAttr, - proxyToConnect=proxyToConnect) - if result['OK']: - return result - return result - - def getPayloadProxyFromDIRACGroup(self, userDN, userGroup, requiredTimeLeft, token=None, proxyToConnect=None): - """ Download a payload proxy with VOMS extensions depending on the group - - :param str userDN: user DN - :param str userGroup: user group - :param int requiredTimeLeft: required proxy live time in a seconds :param str token: valid token to get a proxy - :param X509Chain proxyToConnect: proxy as a chain :return: S_OK(X509Chain)/S_ERROR() """ - # Assign VOMS attribute - vomsAttr = Registry.getVOMSAttributeForGroup(userGroup) - if not vomsAttr: - gLogger.verbose("No voms attribute assigned to group %s when requested payload proxy" % userGroup) - return self.downloadProxy(userDN, userGroup, limited=True, requiredTimeLeft=requiredTimeLeft, - proxyToConnect=proxyToConnect, token=token) - else: - return self.downloadVOMSProxy(userDN, userGroup, limited=True, requiredTimeLeft=requiredTimeLeft, - requiredVOMSAttribute=vomsAttr, proxyToConnect=proxyToConnect, - token=token) + retVal = self.downloadVOMSProxy(user, group, limited, requiredTimeLeft, cacheTime, proxyToConnect, token) + if retVal['OK']: + chain = retVal['Value'] + retVal = self.dumpProxyToFile(chain, filePath) + if retVal['OK']: + retVal['chain'] = chain + return retVal - def getPayloadProxyFromVOMSGroup(self, userDN, vomsAttr, token, requiredTimeLeft, proxyToConnect=None): - """ Download a payload proxy with VOMS extensions depending on the VOMS attr + def downloadCorrectProxy(self, user, group, requiredTimeLeft=43200, proxyToConnect=None, token=None): + """ Download a proxy with VOMS extensions depending on the group or simple proxy + if group without VOMS extensions - :param str userDN: user DN - :param str vomsAttr: VOMS attribute - :param str token: valid token to get a proxy + :param str user: user name or DN + :param str group: user group :param int requiredTimeLeft: required proxy live time in a seconds :param X509Chain proxyToConnect: proxy as a chain + :param str token: valid token to get a proxy :return: S_OK(X509Chain)/S_ERROR() """ - groups = Registry.getGroupsWithVOMSAttribute(vomsAttr) - if not groups: - return S_ERROR("No group found that has %s as voms attrs" % vomsAttr) - userGroup = groups[0] - - return self.downloadVOMSProxy(userDN, - userGroup, - limited=True, - requiredTimeLeft=requiredTimeLeft, - requiredVOMSAttribute=vomsAttr, - proxyToConnect=proxyToConnect, - token=token) + if not getVOMSAttributeForGroup(group): + gLogger.verbose("No voms attribute assigned to group %s when requested proxy" % group) + return self.downloadProxy(user, group, limited=False, requiredTimeLeft=requiredTimeLeft, + proxyToConnect=proxyToConnect) + return self.downloadVOMSProxy(user, group, limited=False, requiredTimeLeft=requiredTimeLeft, + proxyToConnect=proxyToConnect) def dumpProxyToFile(self, chain, destinationFile=None, requiredTimeLeft=600): """ Dump a proxy to a file. It's cached so multiple calls won't generate extra files @@ -486,28 +317,17 @@ def deleteGeneratedProxyFile(self, chain): self.__filesCache.delete(chain) return S_OK() - def deleteProxyBundle(self, idList): - """ delete a list of id's - - :param list,tuple idList: list of identity numbers - - :return: S_OK(int)/S_ERROR() - """ - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) - return rpcClient.deleteProxyBundle(idList) - - def requestToken(self, requesterDN, requesterGroup, numUses=1): + def requestToken(self, requester, requesterGroup, numUses=1): """ Request a number of tokens. usesList must be a list of integers and each integer is the number of uses a token must have - :param str requesterDN: user DN + :param str requester: user name :param str requesterGroup: user group :param int numUses: number of uses :return: S_OK(tuple)/S_ERROR() -- tuple contain token, number uses """ - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) - return rpcClient.generateToken(requesterDN, requesterGroup, numUses) + return self._getRPC().generateToken(requester, requesterGroup, numUses) def renewProxy(self, proxyToBeRenewed=None, minLifeTime=3600, newProxyLifeTime=43200, proxyToConnect=None): """ Renew a proxy using the ProxyManager @@ -545,6 +365,10 @@ def renewProxy(self, proxyToBeRenewed=None, minLifeTime=3600, newProxyLifeTime=4 deleteMultiProxy(proxyToConnectDict) return retVal userGroup = retVal['Value'] + result = getUsernameForDN(userDN) + if not result['OK']: + return result + userName = result['Value'] limited = proxyToRenewDict['chain'].isLimitedProxy()['Value'] voms = VOMS() @@ -553,19 +377,12 @@ def renewProxy(self, proxyToBeRenewed=None, minLifeTime=3600, newProxyLifeTime=4 deleteMultiProxy(proxyToRenewDict) deleteMultiProxy(proxyToConnectDict) return retVal - vomsAttrs = retVal['Value'] - if vomsAttrs: - retVal = self.downloadVOMSProxy(userDN, - userGroup, - limited=limited, - requiredTimeLeft=newProxyLifeTime, - requiredVOMSAttribute=vomsAttrs[0], + + if retVal['Value']: + retVal = self.downloadVOMSProxy(userName, userGroup, limited=limited, requiredTimeLeft=newProxyLifeTime, proxyToConnect=proxyToConnectDict['chain']) else: - retVal = self.downloadProxy(userDN, - userGroup, - limited=limited, - requiredTimeLeft=newProxyLifeTime, + retVal = self.downloadProxy(userName, userGroup, limited=limited, requiredTimeLeft=newProxyLifeTime, proxyToConnect=proxyToConnectDict['chain']) deleteMultiProxy(proxyToRenewDict) @@ -573,7 +390,6 @@ def renewProxy(self, proxyToBeRenewed=None, minLifeTime=3600, newProxyLifeTime=4 if not retVal['OK']: return retVal - chain = retVal['Value'] if not proxyToRenewDict['tempFile']: @@ -581,46 +397,50 @@ def renewProxy(self, proxyToBeRenewed=None, minLifeTime=3600, newProxyLifeTime=4 return S_OK(chain) - def getDBContents(self, condDict={}, sorting=[['UserDN', 'DESC']], start=0, limit=0): + def getVOMSAttributes(self, chain): + """ Get the voms attributes for a chain + + :param X509Chain chain: proxy as a chain + + :return: S_OK(str)/S_ERROR() + """ + return VOMS().getVOMSAttributes(chain) + + def getDBContents(self, condDict={}, start=0, limit=0): """ Get the contents of the db :param dict condDict: search condition + :param int start: search limit start + :param int start: search limit amount :return: S_OK(dict)/S_ERROR() -- dict contain fields, record list, total records """ - rpcClient = RPCClient("Framework/ProxyManager", timeout=120) - return rpcClient.getContents(condDict, sorting, start, limit) + return self._getRPC().getContents(condDict, [['UserDN', 'DESC']], 0, 0) - def getVOMSAttributes(self, chain): - """ Get the voms attributes for a chain + def getUploadedProxiesDetails(self, user=None, group=None): + """ Get the details about an uploaded proxy - :param X509Chain chain: proxy as a chain + :param str user: user name + :param str group: group name - :return: S_OK(str)/S_ERROR() + :return: S_OK(dict)/S_ERROR() -- dict contain fields, record list, total records """ - return VOMS().getVOMSAttributes(chain) + return self.getDBContents({'UserName': user, 'UserGroup': group}) - def getUploadedProxyLifeTime(self, DN, group): + def getUploadedProxyLifeTime(self, user, group): """ Get the remaining seconds for an uploaded proxy - :param str DN: user DN - :param str group: group + :param str user: user name + :param str group: group name :return: S_OK(int)/S_ERROR() """ - result = self.getDBContents({'UserDN': [DN], 'UserGroup': [group]}) + result = self.getUploadedProxiesDetails(user, group) if not result['OK']: return result - data = result['Value'] - if len(data['Records']) == 0: - return S_OK(0) - pNames = list(data['ParameterNames']) - dnPos = pNames.index('UserDN') - groupPos = pNames.index('UserGroup') - expiryPos = pNames.index('ExpirationTime') - for row in data['Records']: - if DN == row[dnPos] and group == row[groupPos]: - td = row[expiryPos] - datetime.datetime.utcnow() + for proxyDict in result['Value']['Dictionaries']: + if user == proxyDict['user'] and group == proxyDict['group']: + td = proxyDict['expirationtime'] - datetime.datetime.utcnow() secondsLeft = td.days * 86400 + td.seconds return S_OK(max(0, secondsLeft)) return S_OK(0) @@ -630,7 +450,7 @@ def getUserProxiesInfo(self): :return: S_OK(dict)/S_ERROR() """ - result = RPCClient("Framework/ProxyManager", timeout=120).getUserProxiesInfo() + result = self._getRPC().getUserProxiesInfo() if 'rpcStub' in result: result.pop('rpcStub') return result diff --git a/src/DIRAC/FrameworkSystem/Client/ProxyManagerData.py b/src/DIRAC/FrameworkSystem/Client/ProxyManagerData.py new file mode 100644 index 00000000000..83faa3829f4 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Client/ProxyManagerData.py @@ -0,0 +1,214 @@ +""" This class is located between the client and server part and designed to cache user information requested + from the server part. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import six + +from DIRAC import S_OK, S_ERROR, gLogger +from DIRAC.Core.Utilities import ThreadSafe, DIRACSingleton +from DIRAC.Core.Utilities.DictCache import DictCache + +__RCSID__ = "$Id$" + +gUsersSync = ThreadSafe.Synchronizer() +gVOMSUsersSync = ThreadSafe.Synchronizer() + + +@six.add_metaclass(DIRACSingleton.DIRACSingleton) +class ProxyManagerData(object): + """ Proxy manager client + """ + # # __usersCache cache, with following: + # # Key: (userDN, group) + # # Value: dict + # # { + # # 'DN': , + # # 'user': , + # # 'groups': [], <-- TODOAL: current group + # # 'expirationtime': , + # # 'provider': + # # } + + # # __VOMSesUsersCache cache, with next structure: + # # Key: VOMS VO name + # # Value: S_OK(dict)/S_ERROR() -- request VOMS information result that contain + # # dictionary with following: + # # { : { + # # Suspended: bool, + # # VOMSRoles: [], + # # ActiveRoles: [], + # # SuspendedRoles: [] + # # } + # # } + + def __init__(self): + self.rpc = None + self.__usersCache = DictCache() + self.__VOMSesUsersCache = DictCache() + + @gUsersSync + @gVOMSUsersSync + def clearCaches(self): + """ Clear caches + """ + self.__usersCache.purgeAll() + self.__VOMSesUsersCache.purgeAll() + + @gUsersSync + def __getUsersCache(self, mask=None, time=None): + """ Get cache information + + :param str mask: user ID + :param int time: lifetime + + :return: dict + """ + if mask: + return self.__usersCache.get(mask, time) or {} + return self.__usersCache.getDict() + + @gUsersSync + def __addUsersCache(self, data, time=3600 * 24): + """ Add cache information + + :param dict data: ID information data + :param int time: lifetime + """ + for oid, info in data.items(): + self.__usersCache.add(oid, time, value=info) + + def __getSecondsLeftToExpiration(self, expiration, utc=True): + """ Get time left to expiration in a seconds + + :param datetime expiration: + :param bool utc: time in utc + + :return: datetime + """ + if utc: + td = expiration - datetime.datetime.utcnow() + else: + td = expiration - datetime.datetime.now() + return td.days * 86400 + td.seconds + + def __getRPC(self): + """ Get RPC + """ + if not self.rpc: + from DIRAC.Core.Base.Client import Client + self.rpc = Client()._getRPC(url="Framework/ProxyManager", timeout=120) + return self.rpc + + def __refreshUserCache(self, validSeconds=0): + """ Refresh user cache + + :param int validSeconds: required seconds the proxy is valid for + + :return: S_OK()/S_ERROR() + """ + retVal = self.__getRPC().getRegisteredUsers(validSeconds) + if not retVal['OK']: + return retVal + # Update the cache + resDict = {} + for record in retVal['Value']: + for group in record['groups']: + cacheKey = (record['DN'], group) + resDict[cacheKey] = record + self.__addUsersCache({cacheKey: record}, self.__getSecondsLeftToExpiration(record['expirationtime'])) + return S_OK(resDict) + + @gVOMSUsersSync + def __getVOMSUsersDict(self): + """ Get users dictionary from cache + + :return: dict + """ + return self.__VOMSesUsersCache.getDict() + + @gVOMSUsersSync + def __setVOMSUsersDict(self, usersDict): + """ Set dictionary to cache + + :param dict usersDict: dictionary with VOMS users + """ + for vo, userInfo in usersDict.items(): + self.__VOMSesUsersCache.add(vo, 3600 * 24, value=userInfo) + self.__VOMSesUsersCache.add('Fresh', 3600 * 12, value=True) + + def __refreshVOMSesCache(self): + """ Get fresh info from service about VOMSes + + :return: S_OK()/S_ERROR() + """ + result = self.__getRPC().getVOMSesUsers() + if result['OK']: + self.__setVOMSUsersDict(result['Value']) + return result + + def getActualVOMSesDNs(self, voList=None, dnList=None): + """ Return actual/not suspended DNs from VOMSes + + :param list voList: VOs to get + :param list dnList: DNs to get + + :return: S_OK(dict)/S_ERROR() + """ + vomsUsers = self.__getVOMSUsersDict() + if not vomsUsers.get('Fresh'): + result = self.__refreshVOMSesCache() + if not result['OK']: + return result + vomsUsers = result['Value'] + vomsUsers.pop('Fresh', None) + res = {} + if not vomsUsers: + return S_ERROR('VOMSes has not been updated.') + for vo, voInfo in vomsUsers.items(): + if voList and vo not in voList: + continue + if not voInfo['OK']: + res[vo] = voInfo + continue + res[vo] = S_OK({}) + for dn, data in voInfo['Value'].items(): + if dnList and dn not in dnList: + continue + if dn not in res[vo]['Value']: + res[vo]['Value'][dn] = {'Suspended': data.get('suspended'), + 'VOMSRoles': [], + 'ActiveRoles': [], + 'SuspendedRoles': []} + res[vo]['Value'][dn]['VOMSRoles'] = list(set(res[vo]['Value'][dn]['VOMSRoles'] + data['Roles'])) + if data['certSuspended'] or data.get('suspended'): + res[vo]['Value'][dn]['SuspendedRoles'] = list(set(res[vo]['Value'][dn]['SuspendedRoles'] + data['Roles'])) + else: + res[vo]['Value'][dn]['ActiveRoles'] = list(set(res[vo]['Value'][dn]['ActiveRoles'] + data['Roles'])) + return S_OK(res) + + def userHasProxy(self, userDN, group, validSeconds=0): + """ Check if a user-group has a proxy in the proxy management + Updates internal cache if needed to minimize queries to the service + + :param str userDN: user DN + :param str group: user group + :param int validSeconds: proxy valid time in a seconds + + :return: S_OK(bool)/S_ERROR() + """ + cacheKey = (userDN, group) + if self.__getUsersCache(cacheKey, validSeconds): + return S_OK(True) + # Get list of users from the DB with proxys at least 300 seconds + gLogger.verbose("Updating list of users in proxy management") + result = self.__refreshUserCache(validSeconds) + if result.get('Value', {}).get(cacheKey): + return S_OK(bool(result['Value'].get(cacheKey))) + return result + + +gProxyManagerData = ProxyManagerData() diff --git a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg index 51c7ff8dc4a..678216a0f3e 100644 --- a/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg +++ b/src/DIRAC/FrameworkSystem/ConfigTemplate.cfg @@ -1,3 +1,40 @@ +APIs +{ + Auth + { + Port = 8000 + # Authorization server parameters description, these settings are available by default + # https://tools.ietf.org/html/rfc8414#section-2 + # AuthorizationServer + # { + # issuer = https://marosvn32.in2p3.fr/DIRAC/auth + # } + # Describe DIRAC authorization clients + # https://tools.ietf.org/html/rfc6749#section-2.1 + Clients + { + # DIRAC client instalation described as the public client + DIRACCLI + { + client_id = type here client id + redirect_uri = type here url for redirect + } + # WebAppDIRAC server instalation described as the confidential client + WEBAPPDIRACCLI + { + client_id = type here client id + client_secret = type here client secret + redirect_uri = https://yourportaldomain/DIRAC/loginComplete + } + } + } + Proxy + { + Port = 8000 + # Allow download personal proxy + downloadablePersonalProxy = True + } +} Services { Gateway @@ -14,7 +51,7 @@ Services } } ##BEGIN ProxyManager: - # Section to describe ProxyManager system + # Section to describe ProxyManager service # https://dirac.readthedocs.org/en/latest/AdministratorGuide/Systems/Framework/ProxyManager/index.html ProxyManager { @@ -43,6 +80,23 @@ Services } } ##END + ##BEGIN AuthManager: + # Section to describe AuthManager service + TokenManager + { + CheckIn + { + client_id = type client id here receved after client registration + client_secret = type client secret here receved after client registration + redirect_uri = https://yourdomain/DIRAC/auth/redirect + } + Port = 9151 + Authorization + { + Default = authenticated + } + } + ##END SecurityLogging { Port = 9153 diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.py b/src/DIRAC/FrameworkSystem/DB/AuthDB.py new file mode 100644 index 00000000000..90dce4c6b3c --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.py @@ -0,0 +1,293 @@ +""" Auth class is a front-end to the Auth Database +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +from time import time +from pprint import pprint +from authlib.oauth2.rfc6749.wrappers import OAuth2Token +from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin, OAuth2TokenMixin +from sqlalchemy import Column, Integer, Text, BigInteger, String +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound +from sqlalchemy.types import TypeDecorator, VARCHAR +from sqlalchemy.ext.mutable import Mutable +from sqlalchemy.ext.declarative import declarative_base + +from DIRAC import S_OK, S_ERROR, gLogger, gConfig +from DIRAC.Core.Base.SQLAlchemyDB import SQLAlchemyDB +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthClients + +__RCSID__ = "$Id$" + + +Model = declarative_base() + + +class Token(Model, OAuth2TokenMixin): + __tablename__ = 'Tokens' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + # access_token too large for varchar(255) + # 767 bytes is the stated prefix limitation for InnoDB tables in MySQL version 5.6 + # https://stackoverflow.com/questions/1827063/mysql-error-key-specification-without-a-key-length + access_token = Column(Text, nullable=False) + # client_id too large + client_id = Column(String(255)) + provider = Column(Text) + user_id = Column(String(255), nullable=False, unique=True, primary_key=True) + expires_at = Column(Integer, nullable=False, default=0) + id_token = Column(Text, nullable=False) + + +class AuthSession(Model): + __tablename__ = 'Sessions' + __table_args__ = {'mysql_engine': 'InnoDB', + 'mysql_charset': 'utf8'} + id = Column(String(255), unique=True, primary_key=True, nullable=False) + state = Column(String(255)) + uri = Column(String(255)) + client_id = Column(String(255)) + user_id = Column(String(255)) + username = Column(String(255)) + expires_at = Column(Integer, nullable=False, default=0) + expires_in = Column(Integer, nullable=False, default=0) + interval = Column(Integer, nullable=False, default=5) + verification_uri = Column(String(255)) + verification_uri_complete = Column(String(255)) + user_code = Column(String(255)) + device_code = Column(String(255)) + scope = Column(String(255)) + + +class AuthDB(SQLAlchemyDB): + """ AuthDB class is a front-end to the OAuth Database + """ + # TODO: provide logging instead of print + def __init__(self): + """ Constructor + """ + super(AuthDB, self).__init__() + self._initializeConnection('Framework/AuthDB') + result = self.__initializeDB() + if not result['OK']: + raise Exception("Can't create tables: %s" % result['Message']) + self.session = scoped_session(self.sessionMaker_o) + + def __initializeDB(self): + """ Create the tables + """ + tablesInDB = self.inspector.get_table_names() + + # Tokens + if 'Tokens' not in tablesInDB: + try: + Token.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + # Sessions + if 'Sessions' not in tablesInDB: + try: + AuthSession.__table__.create(self.engine) # pylint: disable=no-member + except Exception as e: + return S_ERROR(e) + + return S_OK() + + def storeToken(self, metadata): + """ Save token + + :param dict metadata: token info + + :return: S_OK(str)/S_ERROR() + """ + attrts = {} + print('========= STORE TOKEN') + pprint(metadata) + print('---------------------') + for k, v in metadata.items(): + if k not in Token.__dict__.keys(): + self.log.warn('%s is not expected as token attribute.' % k) + else: + attrts[k] = v + session = self.session() + try: + session.add(Token(**attrts)) + except Exception as e: + return self.__result(session, S_ERROR('Could not add Token: %s' % e)) + return self.__result(session, S_OK('Token successfully added')) + + def updateToken(self, token, refreshToken): + """ Update token + + :param dict token: token info + :param str refreshToken: refresh token + + :return: S_OK(object)/S_ERROR() + """ + self.removeToken(refresh_token=refreshToken) + return self.storeToken(token) + + def removeToken(self, access_token=None, refresh_token=None): + """ Remove token + + :param str access_token: access token + :param str refresh_token: refresh token + + :return: S_OK(object)/S_ERROR() + """ + session = self.session() + try: + if access_token: + session.query(Token).filter(Token.access_token == access_token).delete() + elif refresh_token: + session.query(Token).filter(Token.refresh_token == refresh_token).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK('Token successfully removed')) + + def getTokenByRefreshToken(self, refresh_token): + session = self.session() + try: + token = session.query(Token).filter(Token.refresh_token == refresh_token).first() + except NoResultFound: + return self.__result(session, S_ERROR("Token not found.")) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(token))) + + def getTokenByUserIDAndProvider(self, userID, provider): + session = self.session() + try: + token = session.query(Token).filter(Token.user_id == userID, Token.provider == provider).first() + except NoResultFound: + return self.__result(session, S_ERROR("Token not found.")) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(token))) + + def getIdPTokens(self, IdP, userIDs=None): + session = self.session() + try: + if userIDs: + tokens = session.query(Token).filter(Token.provider == IdP).filter(Token.user_id.in_(set(userIDs))).all() + else: + tokens = session.query(Token).filter(Token.provider == IdP).all() + except NoResultFound: + return self.__result(session, S_ERROR("Tokens not found.")) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK([OAuth2Token(self.__rowToDict(t)) for t in tokens])) + + def _addSession(self, data): + """ Add new session + + :param dict data: session metadata + + :return: S_OK(dict)/S_ERROR() + """ + print('============ addSession ============') + pprint(data) + session = self.session() + # newAuthSession = Session(**data) + try: + session.add(Session(**data)) + except Exception as e: + return self.__result(session, S_ERROR('Could not add Session: %s' % e)) + return self.__result(session, S_OK()) + + def addSession(self, data): + result = self._addSession(data) + if not result['OK']: + result = self.updateSession(data) + return result + + def updateSession(self, data): + """ Update session + + :param dict data: session data with 'id' key + + :return: S_OK(object)/S_ERROR() + """ + session = self.session() + try: + session.update(AuthSession(**data)).where(AuthSession.id == data['id']) + except MultipleResultsFound: + return self.__result(session, S_ERROR("%s is not unique." % sessionID)) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK()) + + def removeSession(self, sessionID): + """ Remove session + + :param str sessionID: session id + + :return: S_OK()/S_ERROR() + """ + session = self.session() + try: + session.query(AuthSession).filter(AuthSession.id == sessionID).delete() + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK()) + + def getSession(self, sessionID): + """ Get client + + :param str sessionID: session id + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + resData = session.query(AuthSession).filter(AuthSession.id == sessionID).first() + except MultipleResultsFound: + return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) + except NoResultFound: + return self.__result(session, S_ERROR("%s session is expired." % sessionID)) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(resData))) + + def getSessionByUserCode(self, userCode): + """ Get client + + :param str userCode: user code + + :return: S_OK(dict)/S_ERROR() + """ + session = self.session() + try: + resData = session.query(AuthSession).filter(AuthSession.user_code == userCode).first() + except MultipleResultsFound: + return self.__result(session, S_ERROR("%s is not unique ID." % sessionID)) + except NoResultFound: + return self.__result(session, S_ERROR("%s session is expired." % sessionID)) + except Exception as e: + return self.__result(session, S_ERROR(str(e))) + return self.__result(session, S_OK(self.__rowToDict(resData))) + + def __result(self, session, result=None): + try: + if not result['OK']: + session.rollback() + else: + session.commit() + except Exception as e: + session.rollback() + result = S_ERROR('Could not commit: %s' % (e)) + session.close() + return result + + def __rowToDict(self, row): + """ Convert sqlalchemy row to dictionary + + :param object row: sqlalchemy row + + :return: dict + """ + return {c.name: str(getattr(row, c.name)) for c in row.__table__.columns} if row else {} diff --git a/src/DIRAC/FrameworkSystem/DB/AuthDB.sql b/src/DIRAC/FrameworkSystem/DB/AuthDB.sql new file mode 100644 index 00000000000..bfd422307e6 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/DB/AuthDB.sql @@ -0,0 +1,2 @@ +# Everything is created by the DB object upon instantiation if it does not exists. +use AuthDB; \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/DB/ProxyDB.py b/src/DIRAC/FrameworkSystem/DB/ProxyDB.py index 94d613ab347..c9f000ec66c 100755 --- a/src/DIRAC/FrameworkSystem/DB/ProxyDB.py +++ b/src/DIRAC/FrameworkSystem/DB/ProxyDB.py @@ -27,6 +27,7 @@ from DIRAC import gConfig, gLogger, S_OK, S_ERROR from DIRAC.Core.Base.DB import DB from DIRAC.Core.Utilities import DErrno +from DIRAC.Core.Utilities.Decorators import deprecated from DIRAC.Core.Security import Properties from DIRAC.Core.Security.VOMS import VOMS from DIRAC.Core.Security.MyProxy import MyProxy @@ -36,15 +37,19 @@ from DIRAC.ConfigurationSystem.Client.PathFinder import getDatabaseSection from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient from DIRAC.Resources.ProxyProvider.ProxyProviderFactory import ProxyProviderFactory +from DIRAC.FrameworkSystem.Client.AuthManagerData import gAuthManagerData class ProxyDB(DB): + """ Proxy database + """ NOTIFICATION_TIMES = [2592000, 1296000] def __init__(self, useMyProxy=False): DB.__init__(self, 'ProxyDB', 'Framework/ProxyDB') + self.__version = 2 self.__defaultRequestLifetime = 300 # 5min self.__defaultTokenLifetime = 86400 * 7 # 1 week self.__defaultTokenMaxUses = 50 @@ -102,13 +107,11 @@ def __initializeDB(self): } if 'ProxyDB_CleanProxies' not in tablesInDB: - tablesD['ProxyDB_CleanProxies'] = {'Fields': {'UserName': 'VARCHAR(64) NOT NULL', - 'UserDN': 'VARCHAR(255) NOT NULL', - 'ProxyProvider': 'VARCHAR(64) DEFAULT "Certificate"', + tablesD['ProxyDB_CleanProxies'] = {'Fields': {'UserDN': 'VARCHAR(255) NOT NULL', 'Pem': 'BLOB', 'ExpirationTime': 'DATETIME', }, - 'PrimaryKey': ['UserDN', 'ProxyProvider'] + 'PrimaryKey': 'UserDN' } # WARN: Now proxies upload only in ProxyDB_CleanProxies, so this table will not be needed in some future if 'ProxyDB_Proxies' not in tablesInDB: @@ -135,9 +138,9 @@ def __initializeDB(self): if 'ProxyDB_Log' not in tablesInDB: tablesD['ProxyDB_Log'] = {'Fields': {'ID': 'BIGINT NOT NULL AUTO_INCREMENT', - 'IssuerDN': 'VARCHAR(255) NOT NULL', + 'IssuerUsername': 'VARCHAR(255) NOT NULL', 'IssuerGroup': 'VARCHAR(255) NOT NULL', - 'TargetDN': 'VARCHAR(255) NOT NULL', + 'TargetUsername': 'VARCHAR(255) NOT NULL', 'TargetGroup': 'VARCHAR(255) NOT NULL', 'Action': 'VARCHAR(128) NOT NULL', 'Timestamp': 'DATETIME', @@ -148,7 +151,7 @@ def __initializeDB(self): if 'ProxyDB_Tokens' not in tablesInDB: tablesD['ProxyDB_Tokens'] = {'Fields': {'Token': 'VARCHAR(64) NOT NULL', - 'RequesterDN': 'VARCHAR(255) NOT NULL', + 'RequesterUsername': 'VARCHAR(255) NOT NULL', 'RequesterGroup': 'VARCHAR(255) NOT NULL', 'ExpirationTime': 'DATETIME NOT NULL', 'UsesLeft': 'SMALLINT UNSIGNED DEFAULT 1', @@ -207,7 +210,12 @@ def __checkDBVersion(self): :return: S_OK()/S_ERROR() """ - for tableName in ("ProxyDB_CleanProxies", "ProxyDB_Proxies", "ProxyDB_VOMSProxies"): + if self.versionDB == self.__version: # pylint: disable=no-member + return S_OK() + if self.versionDB > self.__version: # pylint: disable=no-member + return S_ERROR('Already installed newer DB version "%s".' % self.versionDB) # pylint: disable=no-member + + for tableName in ("ProxyDB_Proxies", "ProxyDB_VOMSProxies"): result = self._query("describe `%s`" % tableName) if not result['OK']: return result @@ -217,11 +225,36 @@ def __checkDBVersion(self): if not result['OK']: return result - def generateDelegationRequest(self, proxyChain, userDN): + if self.versionDB == 0 and self.versionDB < self.__version: # pylint: disable=no-member + for tb, oldColumn, newColumn in [('ProxyDB_Log', 'IssuerDN', 'IssuerUsername'), + ('ProxyDB_Log', 'TargetDN', 'TargetUsername'), + ('ProxyDB_Tokens', 'RequesterDN', 'RequesterUsername')]: + result = self._query("SHOW COLUMNS FROM `%s` LIKE '%s'" % (tb, oldColumn)) + if result['OK'] and len(result['Value']) > 0: + result = self._query('ALTER TABLE %s CHANGE COLUMN %s %s VARCHAR(255) NOT NULL' % (tb, oldColumn, newColumn)) + if not result['OK']: + return result + result = self.updateDBVersion(1) # pylint: disable=no-member + if not result['OK']: + return result + + if self.versionDB == 1 and self.versionDB < self.__version: # pylint: disable=no-member + for column in ['UserName', 'ProxyProvider']: + result = self._query("SHOW COLUMNS FROM `ProxyDB_CleanProxies` LIKE '%s'" % column) + if result['OK'] and len(result['Value']) > 0: + result = self._query('ALTER TABLE ProxyDB_CleanProxies DROP COLUMN %s' % column) + if not result['OK']: + return result + result = self.updateDBVersion(2) # pylint: disable=no-member + if not result['OK']: + return result + + return S_OK() + + def generateDelegationRequest(self, credDict): """ Generate a request and store it for a given proxy Chain - :param X509Chain() proxyChain: proxy as chain - :param str userDN: user DN + :param dict credDict: dictionary that contain proxy as chain :return: S_OK(dict)/S_ERROR() -- dict contain id and proxy as string of the request """ @@ -229,7 +262,7 @@ def generateDelegationRequest(self, proxyChain, userDN): if not retVal['OK']: return retVal connObj = retVal['Value'] - retVal = proxyChain.generateProxyRequest() + retVal = credDict['x509Chain'].generateProxyRequest() if not retVal['OK']: return retVal request = retVal['Value'] @@ -242,14 +275,13 @@ def generateDelegationRequest(self, proxyChain, userDN): return retVal allStr = reqStr + retVal['Value'] try: - sUserDN = self._escapeString(userDN)['Value'] + sDN = self._escapeString(credDict['DN'])['Value'] sAllStr = self._escapeString(allStr)['Value'] except KeyError: return S_ERROR("Cannot escape DN") - cmd = "INSERT INTO `ProxyDB_Requests` ( Id, UserDN, Pem, ExpirationTime )" - cmd += " VALUES ( 0, %s, %s, TIMESTAMPADD( SECOND, %d, UTC_TIMESTAMP() ) )" % (sUserDN, - sAllStr, - int(self.__defaultRequestLifetime)) + cmd = "INSERT INTO `ProxyDB_Requests` (UserDN, Pem, ExpirationTime) VALUES " + cmd += "(%s, %s, TIMESTAMPADD(SECOND, %d, UTC_TIMESTAMP()))" % (sDN, sAllStr, + int(self.__defaultRequestLifetime)) retVal = self._update(cmd, conn=connObj) if not retVal['OK']: return retVal @@ -263,8 +295,7 @@ def generateDelegationRequest(self, proxyChain, userDN): data = retVal['Value'] if not data: return S_ERROR("Insertion of the request in the db didn't work as expected") - userGroup = proxyChain.getDIRACGroup().get('Value') or "unset" - self.logAction("request upload", userDN, userGroup, userDN, "any") + self.logAction("request upload", credDict['username'], credDict['group'], credDict['username'], "any") # Here we go! return S_OK({'id': data[0][0], 'request': reqStr}) @@ -276,22 +307,21 @@ def __retrieveDelegationRequest(self, requestId, userDN): :return: S_OK(str)/S_ERROR() """ - try: - sUserDN = self._escapeString(userDN)['Value'] - except KeyError: + result = self._escapeString(userDN) + if not result['OK']: return S_ERROR("Cannot escape DN") + sUserDN = result['Value'] + cmd = "SELECT Pem FROM `ProxyDB_Requests` WHERE Id = %s AND UserDN = %s" % (requestId, sUserDN) - retVal = self._query(cmd) - if not retVal['OK']: - return retVal - data = retVal['Value'] - if len(data) == 0: - return S_ERROR("No requests with id %s" % requestId) - request = X509Request() - retVal = request.loadAllFromString(data[0][0]) - if not retVal['OK']: - return retVal - return S_OK(request) + result = self._query(cmd) + if result['OK']: + data = result['Value'] + if len(data) == 0: + return S_ERROR("No requests with id %s" % requestId) + request = X509Request() + result = request.loadAllFromString(data[0][0]) + + return S_OK(request) if result['OK'] else result def purgeExpiredRequests(self): """ Purge expired requests from the db @@ -315,7 +345,7 @@ def completeDelegation(self, requestId, userDN, delegatedPem): """ Complete a delegation and store it in the db :param int requestId: id of the request - :param str userDN: user DN + :param str userDN: DN :param str delegatedPem: delegated proxy as string :return: S_OK()/S_ERROR() @@ -340,40 +370,47 @@ def completeDelegation(self, requestId, userDN, delegatedPem): if not retVal['OK']: return retVal if retVal['Value']: - return S_ERROR("Proxies with DIRAC group extensions not allowed to be uploaded") - retVal = self.__storeProxy(userDN, chain) - return self.deleteRequest(requestId) if retVal['OK'] else retVal + return S_ERROR("Proxies with DIRAC group extensions not allowed to be uploaded.") + + result = self._storeProxy(userDN, chain) + + return self.deleteRequest(requestId) if result['OK'] else result + + def _isProxyExist(self, userDN, timeleft=None): + """ Check if proxy present in DB + + :param str userDN: user DN + :param int timeleft: requirement time - def __storeProxy(self, userDN, chain, proxyProvider=None): + :return: S_OK()/S_ERROR() + """ + result = self._escapeString(userDN) + if not result['OK']: + return result + sUserDN = result['Value'] + cmd = 'SELECT * FROM ProxyDB_CleanProxies WHERE UserDN = %s' % sUserDN + if timeleft: + cmd += ' AND TIMESTAMPDIFF(SECOND, UTC_TIMESTAMP(), ExpirationTime) > %s' % timeleft + result = self._query(cmd) + return S_OK(True if result['Value'] else False) if result['OK'] else result + + def _storeProxy(self, userDN, chain): """ Store user proxy into the Proxy repository for a user specified by his DN and group or proxy provider. :param str userDN: user DN from proxy :param X509Chain() chain: proxy chain - :param str proxyProvider: proxy provider name :return: S_OK()/S_ERROR() """ - retVal = Registry.getUsernameForDN(userDN) - if not retVal['OK']: - return retVal - userName = retVal['Value'] - - if not proxyProvider: - result = Registry.getProxyProvidersForDN(userDN) - if not result['OK']: - return result - proxyProvider = result.get('Value') and result['Value'][0] or 'Certificate' - # Get remaining secs retVal = chain.getRemainingSecs() if not retVal['OK']: return retVal remainingSecs = retVal['Value'] if remainingSecs < self._minSecsToAllowStore: - return S_ERROR( - "Cannot store proxy, remaining secs %s is less than %s" % - (remainingSecs, self._minSecsToAllowStore)) + return S_ERROR("Cannot store proxy, remaining secs %s is less than %s" % + (remainingSecs, self._minSecsToAllowStore)) # Compare the DNs retVal = chain.getIssuerCert() @@ -398,58 +435,110 @@ def __storeProxy(self, userDN, chain, proxyProvider=None): try: sUserDN = self._escapeString(userDN)['Value'] - sTable = 'ProxyDB_CleanProxies' except KeyError: return S_ERROR("Cannot escape DN") # Check what we have already got in the repository cmd = "SELECT TIMESTAMPDIFF( SECOND, UTC_TIMESTAMP(), ExpirationTime ), Pem " - cmd += "FROM `%s` WHERE UserDN=%s " % (sTable, sUserDN) + cmd += "FROM `ProxyDB_CleanProxies` WHERE UserDN=%s " % sUserDN result = self._query(cmd) if not result['OK']: return result + data = result['Value'] # Check if there is a previous ticket for the DN - data = result['Value'] - sqlInsert = True + pemChain = chain.dumpAllToString()['Value'] + dValues = {'UserDN': sUserDN, 'Pem': self._escapeString(pemChain)['Value'], + 'ExpirationTime': 'TIMESTAMPADD( SECOND, %d, UTC_TIMESTAMP() )' % int(remainingSecs)} + cmd = "INSERT INTO `ProxyDB_CleanProxies` (%s) VALUES (%s)" % (", ".join(dValues.keys()), + ", ".join(dValues.values())) if data: - sqlInsert = False + cmd = 'UPDATE `ProxyDB_CleanProxies` SET %s WHERE UserDN = %s' % ( + ", ".join(["%s = %s" % (k, v) for k, v in dValues.items()]), sUserDN) pem = data[0][1] if pem: remainingSecsInDB = data[0][0] if remainingSecs <= remainingSecsInDB: - self.log.info( - "Proxy stored is longer than uploaded, omitting.", - "%s in uploaded, %s in db" % - (remainingSecs, - remainingSecsInDB)) + self.log.info("Proxy stored is longer than uploaded, omitting.", + "%s in uploaded, %s in db" % (remainingSecs, remainingSecsInDB)) return S_OK() - pemChain = chain.dumpAllToString()['Value'] - dValues = {'UserName': self._escapeString(userName)['Value'], - 'UserDN': sUserDN, - 'Pem': self._escapeString(pemChain)['Value'], - 'ExpirationTime': 'TIMESTAMPADD( SECOND, %d, UTC_TIMESTAMP() )' % int(remainingSecs)} - dValues['ProxyProvider'] = "'%s'" % proxyProvider - if sqlInsert: - sqlFields = [] - sqlValues = [] - for key in dValues: - sqlFields.append(key) - sqlValues.append(dValues[key]) - cmd = "INSERT INTO `%s` ( %s ) VALUES ( %s )" % (sTable, ", ".join(sqlFields), ", ".join(sqlValues)) - else: - sqlSet = [] - sqlWhere = [] - for k in dValues: - if k in ('UserDN', 'ProxyProvider'): - sqlWhere.append("%s = %s" % (k, dValues[k])) - else: - sqlSet.append("%s = %s" % (k, dValues[k])) - cmd = "UPDATE `%s` SET %s WHERE %s" % (sTable, ", ".join(sqlSet), " AND ".join(sqlWhere)) + retVal = Registry.getUsernameForDN(userDN) + if not retVal['OK']: + return retVal + userName = retVal['Value'] - self.logAction("store proxy", userDN, proxyProvider, userDN, proxyProvider) + self.logAction("store proxy", userName, 'any', userName, 'any') return self._update(cmd) + def __getPemAndTimeLeft(self, userDN, userGroup, requiredLifeTime=None, vomsAttr=None): + """ Get proxy from DB and add group + + :param str userDN: user DN + :param str userGroup: required DIRAC group + :param int requiredLifeTime: required proxy live time in a seconds + :param str vomsAttr: if need search VOMS proxy first + + :return: S_OK(tuple)/S_ERROR() -- tuple with proxy as chain and proxy live time in a seconds + """ + cmd = 'SELECT Pem, TIMESTAMPDIFF(SECOND, UTC_TIMESTAMP(), ExpirationTime) FROM ' + cmd += '`%%s` WHERE UserDN="%s" AND TIMESTAMPDIFF(SECOND, UTC_TIMESTAMP(), ExpirationTime) > 0' % userDN + if vomsAttr: + # Search VOMS proxy first + result = self._query(cmd % 'ProxyDB_VOMSProxies' + " AND VOMSAttr=%s AND UserGroup=%s" % (vomsAttr, userGroup)) + if not result['OK']: + result = self._query(cmd % 'ProxyDB_CleanProxies') + else: + result = self._query(cmd % 'ProxyDB_CleanProxies') + err = "%s@%s proxy" % (userDN, userGroup) + if not result['OK']: + return S_ERROR("%s getting error: %s" % (err, result['Message'])) + data = result['Value'] + if data and data[0][0]: + if requiredLifeTime and data[0][1] <= requiredLifeTime: + return S_ERROR("%s stored in DB, but with less live time that required " % err) + chain = X509Chain() + result = chain.loadProxyFromString(data[0][0]) + if result['OK']: + result = chain.generateProxyToString(requiredLifeTime or min( + 3600 * 12, data[0][1]), diracGroup=userGroup, rfc=True) + if not result['OK']: + return S_ERROR("%s exist in DB, but %s" % (err, result['Message'])) + return S_OK((result['Value'], requiredLifeTime)) + return S_ERROR("%s with %s group is absent in DB" % (userDN, userGroup)) + + def __generateProxyForDNGroup(self, userDN, userGroup, requiredLifeTime): + """ Generate proxy from proxy provider and store it to DB + + :param str userDN: user DN + :param str userGroup: required DIRAC group + :param int requiredLifeTime: required proxy live time in a seconds + + :return: S_OK(tuple)/S_ERROR() -- tuple with proxy as chain and proxy live time in a seconds + """ + # Try to get proxy + result = self.getProxyProviderForDN(userDN) + if not result['OK']: + return result + if result['Value'] == 'Certificate': + return S_ERROR('No proxy provider found for this DN, need to upload proxy') + + result = ProxyProviderFactory().getProxyProvider(result['Value'], proxyManager=self) + if not result['OK']: + return result + providerObj = result['Value'] + + # Generate the proxy and store in the DB + result = providerObj.getProxy(userDN) + if not result['OK']: + return result + chain = result['Value'] + + # Add group + result = chain.generateProxyToString(requiredLifeTime, diracGroup=userGroup, rfc=True) + if not result['OK']: + return S_ERROR("Cannot generate proxy: %s" % result['Message']) + return S_OK((result['Value'], requiredLifeTime)) + def purgeExpiredProxies(self, sendNotifications=True): """ Purge expired requests from the db @@ -471,47 +560,37 @@ def purgeExpiredProxies(self, sendNotifications=True): return result return S_OK(purged) - def deleteProxy(self, userDN, userGroup=None, proxyProvider=None): + def deleteProxy(self, userDNs): """ Remove proxy of the given user from the repository - :param str userDN: user DN - :param str userGroup: DIRAC group - :param str proxyProvider: proxy provider name + :param list userDNs: user DN list :return: S_OK()/S_ERROR() """ - try: - userDN = self._escapeString(userDN)['Value'] - if userGroup: - userGroup = self._escapeString(userGroup)['Value'] - if proxyProvider: - proxyProvider = self._escapeString(proxyProvider)['Value'] - except KeyError: - return S_ERROR("Invalid DN or group or proxy provider") - errMsgs = [] - req = "DELETE FROM `%%s` WHERE UserDN=%s" % userDN - if proxyProvider or not userGroup: - result = self._update('%s %s' % (req % 'ProxyDB_CleanProxies', - proxyProvider and 'AND ProxyProvider=%s' % proxyProvider or '')) + tables = ['ProxyDB_Proxies', 'ProxyDB_VOMSProxies', 'ProxyDB_CleanProxies'] + escapeUserDNs = [] + for dn in userDNs: + result = self._escapeString(dn) if not result['OK']: - errMsgs.append(result['Message']) - for table in ['ProxyDB_Proxies', 'ProxyDB_VOMSProxies']: - result = self._update('%s %s' % (req % table, - userGroup and 'AND UserGroup=%s' % userGroup or '')) + return S_ERROR("Invalid DN: %s" % result['Message']) + escapeUserDNs.append(result['Value']) + errMsgs = [] + req = "DELETE FROM `%%s` WHERE UserDN IN (%s)" % ', '.join(escapeUserDNs) + for table in tables: + result = self._update(req % table) if not result['OK']: if result['Message'] not in errMsgs: errMsgs.append(result['Message']) - if errMsgs: - return S_ERROR(', '.join(errMsgs)) - return result - def __getPemAndTimeLeft(self, userDN, userGroup=None, vomsAttr=None, proxyProvider=None): + return S_ERROR(', '.join(errMsgs)) if errMsgs else result + + @deprecated("Old method for compatibility with older versions v7r0-") + def __getPemAndTimeLeftOld(self, userDN, userGroup, vomsAttr=None): """ Get proxy from database :param str userDN: user DN :param str userGroup: requested DIRAC group :param str vomsAttr: VOMS name - :param str proxyProvider: proxy provider name :return: S_OK(tuple)/S_ERROR() -- tuple contain proxy as string and remaining seconds """ @@ -523,52 +602,36 @@ def __getPemAndTimeLeft(self, userDN, userGroup=None, vomsAttr=None, proxyProvid sVomsAttr = self._escapeString(vomsAttr)['Value'] except KeyError: return S_ERROR("Invalid DN or Group") - if proxyProvider: - sTable = "`ProxyDB_CleanProxies`" - elif not vomsAttr: + if not vomsAttr: sTable = "`ProxyDB_Proxies`" else: sTable = "`ProxyDB_VOMSProxies`" cmd = "SELECT Pem, TIMESTAMPDIFF( SECOND, UTC_TIMESTAMP(), ExpirationTime ) from %s " % sTable cmd += "WHERE UserDN=%s AND TIMESTAMPDIFF( SECOND, UTC_TIMESTAMP(), ExpirationTime ) > 0" % (sUserDN) - if proxyProvider: - cmd += ' AND ProxyProvider="%s"' % proxyProvider - else: - if userGroup: - cmd += " AND UserGroup=%s" % sUserGroup - if vomsAttr: - cmd += " AND VOMSAttr=%s" % sVomsAttr + if userGroup: + cmd += " AND UserGroup=%s" % sUserGroup + if vomsAttr: + cmd += " AND VOMSAttr=%s" % sVomsAttr retVal = self._query(cmd) if not retVal['OK']: return retVal data = retVal['Value'] for record in data: if record[0]: - if proxyProvider: - chain = X509Chain() - result = chain.loadProxyFromString(record[0]) - if not result['OK']: - return result - result = chain.generateProxyToString(record[1], diracGroup=userGroup, rfc=True) - if not result['OK']: - return result - return S_OK((result['Value'], record[1])) return S_OK((record[0], record[1])) - if userGroup: - userMask = "%s@%s" % (userDN, userGroup) - else: - userMask = userDN + userMask = "%s@%s" % (userDN, userGroup) return S_ERROR("%s has no proxy registered" % userMask) + # WARN: This proxy manager work as myproxy, it seems we no use an external myproxy anymore def renewFromMyProxy(self, userDN, userGroup, lifeTime=None, chain=None): """ Renew proxy from MyProxy :param str userDN: user DN :param str userGroup: user group :param int lifeTime: needed proxy live time in a seconds - :param X509Chain chain: proxy as chain + :param object chain: proxy as X509Chain - :return: S_OK(X509Chain/S_ERROR() + :return: S_OK(X509Chain)/S_ERROR() """ if not lifeTime: lifeTime = 43200 @@ -621,7 +684,7 @@ def renewFromMyProxy(self, userDN, userGroup, lifeTime=None, chain=None): chainGroup = retVal['Value'] if chainGroup != userGroup: return S_ERROR("Mismatch between renewed proxy group and expected: %s vs %s" % (userGroup, chainGroup)) - retVal = self.__storeProxy(userDN, userGroup, mpChain) + retVal = self._storeProxy(userDN, mpChain) if not retVal['OK']: self.log.error("Cannot store proxy after renewal", retVal['Message']) retVal = myProxy.getServiceDN() @@ -632,274 +695,67 @@ def renewFromMyProxy(self, userDN, userGroup, lifeTime=None, chain=None): self.logAction("myproxy renewal", hostDN, "host", userDN, userGroup) return S_OK(mpChain) - # WARN: this method will not be needed if CS section Users//DNProperties will be for every user - # in this case will be used proxy providers that described there - def __getPUSProxy(self, userDN, userGroup, requiredLifetime, requestedVOMSAttr=False): - result = Registry.getGroupsForDN(userDN) - if not result['OK']: - return result - - validGroups = result['Value'] - if userGroup not in validGroups: - return S_ERROR('Invalid group %s for user' % userGroup) - - voName = Registry.getVOForGroup(userGroup) - if not voName: - return S_ERROR('Can not determine VO for group %s' % userGroup) - - retVal = self.__getVOMSAttribute(userGroup, requestedVOMSAttr) - if not retVal['OK']: - return retVal - vomsAttribute = retVal['Value']['attribute'] - vomsVO = retVal['Value']['VOMSVO'] - - puspServiceURL = Registry.getVOOption(voName, 'PUSPServiceURL') - if not puspServiceURL: - return S_ERROR('Can not determine PUSP service URL for VO %s' % voName) - - user = userDN.split(":")[-1] - - puspURL = "%s?voms=%s:%s&proxy-renewal=false&disable-voms-proxy=false" \ - "&rfc-proxy=true&cn-label=user:%s" % (puspServiceURL, vomsVO, vomsAttribute, user) - - try: - proxy = urlopen(puspURL).read() - except Exception: - return S_ERROR('Failed to get proxy from the PUSP server') - - chain = X509Chain() - chain.loadChainFromString(proxy) - chain.loadKeyFromString(proxy) - - result = chain.getCredentials() - if not result['OK']: - return S_ERROR('Failed to get a valid PUSP proxy') - credDict = result['Value'] - if credDict['identity'] != userDN: - return S_ERROR('Requested DN does not match the obtained one in the PUSP proxy') - timeLeft = credDict['secondsLeft'] - - result = chain.generateProxyToString(timeLeft, diracGroup=userGroup) - if not result['OK']: - return result - proxyString = result['Value'] - return S_OK((proxyString, timeLeft)) - - def __generateProxyFromProxyProvider(self, userDN, proxyProvider): - """ Get proxy from proxy provider - - :param str userDN: user DN for what need to create proxy - :param str proxyProvider: proxy provider name that will ganarete proxy - - :return: S_OK(dict)/S_ERROR() -- dict with remaining seconds, proxy as a string and as a chain - """ - gLogger.info('Getting proxy from proxyProvider', '(for "%s" DN by "%s")' % (userDN, proxyProvider)) - result = ProxyProviderFactory().getProxyProvider(proxyProvider) - if not result['OK']: - return result - pp = result['Value'] - result = pp.getProxy(userDN) - if not result['OK']: - return result - proxyStr = result['Value'] - chain = X509Chain() - result = chain.loadProxyFromString(proxyStr) - if not result['OK']: - return result - result = chain.getRemainingSecs() - if not result['OK']: - return result - remainingSecs = result['Value'] - result = self.__storeProxy(userDN, chain, proxyProvider) - if result['OK']: - return S_OK({'proxy': proxyStr, 'chain': chain, 'remainingSecs': remainingSecs}) - return result - - def __getProxyFromProxyProviders(self, userDN, userGroup, requiredLifeTime): - """ Generate new proxy from exist clean proxy or from proxy provider - for use with userDN in the userGroup + def getProxy(self, userDN, userGroup, requiredLifeTime=None, voms=False): + """ Get proxy string from the Proxy Repository for use with userDN in the userGroup :param str userDN: user DN - :param str userGroup: required group name + :param str userGroup: required DIRAC group :param int requiredLifeTime: required proxy live time in a seconds + :param bool voms: if need VOMS attribute - :return: S_OK(tuple)/S_ERROR() -- tuple contain proxy as string and remainig seconds + :return: S_OK(tuple)/S_ERROR() -- tuple with proxy as chain and proxy live time in a seconds """ - result = Registry.getGroupsForDN(userDN) - if not result['OK']: - return S_ERROR('Cannot generate proxy: %s' % result['Message']) - if userGroup not in result['Value']: - return S_ERROR('Cannot generate proxy: Invalid group %s for user' % userGroup) - result = Registry.getProxyProvidersForDN(userDN) - - errMsgs = [] - if result['OK']: - providers = result['Value'] - providers.append('Certificate') - for proxyProvider in providers: - self.log.verbose('Try to get proxy from ProxyDB_CleanProxies') - result = self.__getPemAndTimeLeft(userDN, userGroup, proxyProvider=proxyProvider) - if result['OK'] and (not requiredLifeTime or result['Value'][1] > requiredLifeTime): - return result - if len(providers) == 1: - return S_ERROR('Cannot generate proxy: No proxy providers found for "%s"' % userDN) - self.log.verbose('Try to generate proxy from %s proxy provider' % proxyProvider) - result = self.__generateProxyFromProxyProvider(userDN, proxyProvider) - if result['OK']: - chain = result['Value']['chain'] - remainingSecs = result['Value']['remainingSecs'] - result = chain.generateProxyToString(remainingSecs, diracGroup=userGroup, rfc=True) - if result['OK']: - return S_OK((result['Value'], remainingSecs)) - errMsgs.append('"%s": %s' % (proxyProvider, result['Message'])) + vomsAttr = Registry.getVOMSAttributeForGroup(userGroup) + if not vomsAttr and voms: + return S_ERROR("No mapping defined for group %s in the CS" % userGroup) - return S_ERROR('Cannot generate proxy%s' % (errMsgs and ': ' + ', '.join(errMsgs) or '')) + # Standard proxy is requested + self.log.verbose('Try to get proxy from ProxyDB_CleanProxies') + result = self.__getPemAndTimeLeft(userDN, userGroup, requiredLifeTime, voms and vomsAttr) + if not result['OK']: - def getProxy(self, userDN, userGroup, requiredLifeTime=None): - """ Get proxy string from the Proxy Repository for use with userDN - in the userGroup + # WARN: for compatibility + result = self.__getPemAndTimeLeftOld(userDN, userGroup, voms and vomsAttr) + if not result['OK'] or requiredLifeTime and result['Value'][1] < requiredLifeTime: - :param str userDN: user DN - :param str userGroup: required DIRAC group - :param int requiredLifeTime: required proxy live time in a seconds + errMsg = result.get('Message') or 'Stored proxy have not enough lifetime' + result = self.__generateProxyForDNGroup(userDN, userGroup, requiredLifeTime) + if not result['OK']: + return S_ERROR('%s; %s' % (errMsg, result['Message'])) - :return: S_OK(tuple)/S_ERROR() -- tuple with proxy as chain and proxy live time in a seconds - """ - # Test that group enable to download - if not Registry.isDownloadableGroup(userGroup): - return S_ERROR('"%s" group is disable to download.' % userGroup) - - # WARN: this block will not be needed if CS section Users//DNProperties will be for every user - # in this case will be used proxy providers that described there - # Get the Per User SubProxy if one is requested - if isPUSPdn(userDN): - result = self.__getPUSProxy(userDN, userGroup, requiredLifeTime) - if not result['OK']: - return result - pemData = result['Value'][0] - timeLeft = result['Value'][1] - chain = X509Chain() - result = chain.loadProxyFromString(pemData) - if not result['OK']: - return result - return S_OK((chain, timeLeft)) + pemData, timeLeft = result['Value'] - # Standard proxy is requested - self.log.verbose('Try to get proxy from ProxyDB_Proxies') - retVal = self.__getPemAndTimeLeft(userDN, userGroup) - errMsg = "Can't get proxy%s: " % (requiredLifeTime and ' for %s seconds' % requiredLifeTime or '') - if not retVal['OK']: - errMsg += '%s, try to generate new' % retVal['Message'] - retVal = self.__getProxyFromProxyProviders(userDN, userGroup, requiredLifeTime=requiredLifeTime) - elif requiredLifeTime: - if retVal['Value'][1] < requiredLifeTime and not self.__useMyProxy: - errMsg += 'Stored proxy is not long lived enough, try to generate new' - retVal = self.__getProxyFromProxyProviders(userDN, userGroup, requiredLifeTime=requiredLifeTime) - if not retVal['OK']: - return S_ERROR("%s; %s" % (errMsg, retVal['Message'])) - pemData = retVal['Value'][0] - timeLeft = retVal['Value'][1] chain = X509Chain() result = chain.loadProxyFromString(pemData) - if not retVal['OK']: - return S_ERROR("%s; %s" % (errMsg, retVal['Message'])) - if self.__useMyProxy: - if requiredLifeTime: - if timeLeft < requiredLifeTime: - retVal = self.renewFromMyProxy(userDN, userGroup, lifeTime=requiredLifeTime, chain=chain) - if not retVal['OK']: - return S_ERROR("%s; the proxy lifetime from MyProxy is less than required." % errMsg) - chain = retVal['Value'] + if not result['OK']: + self.deleteProxy([userDN]) + return S_ERROR("Checking %s@%s proxy failed: %s" % (userDN, userGroup, result['Message'])) # Proxy is invalid for some reason, let's delete it if not chain.isValidProxy()['OK']: - self.deleteProxy(userDN, userGroup) + self.deleteProxy([userDN]) return S_ERROR("%s@%s has no proxy registered" % (userDN, userGroup)) - return S_OK((chain, timeLeft)) - - def __getVOMSAttribute(self, userGroup, requiredVOMSAttribute=False): - """ Get VOMS attribute for DIRAC group - - :param str userGroup: DIRAC group - :param boolean requiredVOMSAttribute: VOMS attribute - - :return: S_OK(dict)/S_ERROR() -- dict contain attribute and VOMS VO - """ - if requiredVOMSAttribute: - return S_OK({'attribute': requiredVOMSAttribute, 'VOMSVO': Registry.getVOMSVOForGroup(userGroup)}) - - csVOMSMapping = Registry.getVOMSAttributeForGroup(userGroup) - if not csVOMSMapping: - return S_ERROR("No mapping defined for group %s in the CS" % userGroup) - - return S_OK({'attribute': csVOMSMapping, 'VOMSVO': Registry.getVOMSVOForGroup(userGroup)}) - - def getVOMSProxy(self, userDN, userGroup, requiredLifeTime=None, requestedVOMSAttr=None): - """ Get proxy string from the Proxy Repository for use with userDN - in the userGroup - - :param str userDN: user DN - :param str userGroup: required DIRAC group - :param int requiredLifeTime: required proxy live time in a seconds - :param str requestedVOMSAttr: VOMS attribute - - :return: S_OK(tuple)/S_ERROR() -- tuple with proxy as chain and proxy live time in a seconds - """ - retVal = self.__getVOMSAttribute(userGroup, requestedVOMSAttr) - if not retVal['OK']: - return retVal - vomsAttr = retVal['Value']['attribute'] - vomsVO = retVal['Value']['VOMSVO'] - - # Look in the cache - retVal = self.__getPemAndTimeLeft(userDN, userGroup, vomsAttr) - if retVal['OK']: - pemData = retVal['Value'][0] - vomsTime = retVal['Value'][1] - chain = X509Chain() - retVal = chain.loadProxyFromString(pemData) - if retVal['OK']: - retVal = chain.getRemainingSecs() - if retVal['OK']: - remainingSecs = retVal['Value'] - if requiredLifeTime and requiredLifeTime <= vomsTime and requiredLifeTime <= remainingSecs: - return S_OK((chain, min(vomsTime, remainingSecs))) - - if isPUSPdn(userDN): - # Get the Per User SubProxy if one is requested - result = self.__getPUSProxy(userDN, userGroup, requiredLifeTime, requestedVOMSAttr) - if not result['OK']: - return result - pemData = result['Value'][0] - chain = X509Chain() - result = chain.loadProxyFromString(pemData) - if not result['OK']: - return result - - else: - # Get the stored proxy and dress it with the VOMS extension - retVal = self.getProxy(userDN, userGroup, requiredLifeTime) - if not retVal['OK']: - return retVal - chain, _secsLeft = retVal['Value'] + if voms: + # If VOMS proxy requested vomsMgr = VOMS() - attrs = vomsMgr.getVOMSAttributes(chain).get('Value') or [''] - if attrs[0]: + attrs = vomsMgr.getVOMSAttributes(chain).get('Value') + if attrs and attrs[0]: if vomsAttr != attrs[0]: return S_ERROR("Stored proxy has already a different VOMS attribute %s than requested %s" % (attrs[0], vomsAttr)) else: - retVal = vomsMgr.setVOMSAttributes(chain, vomsAttr, vo=vomsVO) + retVal = vomsMgr.setVOMSAttributes(chain, vomsAttr, vo=Registry.getVOMSVOForGroup(userGroup)) if not retVal['OK']: return S_ERROR("Cannot append voms extension: %s" % retVal['Message']) chain = retVal['Value'] + # We have got the VOMS proxy, store it into the cache + result = self.__storeVOMSProxy(userDN, userGroup, vomsAttr, chain) + if not result['OK']: + return result + timeLeft = result['Value'] - # We have got the VOMS proxy, store it into the cache - result = self.__storeVOMSProxy(userDN, userGroup, vomsAttr, chain) - if not result['OK']: - return result - return S_OK((chain, result['Value'])) + return S_OK((chain, timeLeft)) def __storeVOMSProxy(self, userDN, userGroup, vomsAttr, chain): """ Store VOMS proxy @@ -907,7 +763,7 @@ def __storeVOMSProxy(self, userDN, userGroup, vomsAttr, chain): :param str userDN: user DN :param str userGroup: DIRAC group :param str vomsAttr: VOMS attribute - :param X509Chain() chain: proxy as chain + :param object chain: proxy as X509Chain :return: S_OK(str)/S_ERROR() """ @@ -957,45 +813,24 @@ def getUsers(self, validSecondsLeft=0, userMask=None): :param int validSecondsLeft: validity period expressed in seconds :param str userMask: user name that need to add to search filter - :return: S_OK(list)/S_ERROR() -- list contain dicts with user name, DN, group - expiration time, persistent flag + :return: S_OK(list)/S_ERROR() -- list contain dicts with DN, group, expiration time """ - data = [] + selDict = {} sqlCond = [] if validSecondsLeft: try: validSecondsLeft = int(validSecondsLeft) except ValueError: return S_ERROR("Seconds left has to be an integer") - sqlCond.append("TIMESTAMPDIFF( SECOND, UTC_TIMESTAMP(), ExpirationTime ) > %d" % validSecondsLeft) + sqlCond.append("TIMESTAMPDIFF(SECOND, UTC_TIMESTAMP(), ExpirationTime) > %d" % validSecondsLeft) if userMask: - try: - sUserName = self._escapeString(userMask)['Value'] - except KeyError: - return S_ERROR("Can't escape user name") - sqlCond.append('UserName = %s' % sUserName) - - for table, fields in [('ProxyDB_CleanProxies', ("UserName", "UserDN", "ExpirationTime")), - ('ProxyDB_Proxies', ("UserName", "UserDN", "UserGroup", - "ExpirationTime", "PersistentFlag"))]: - cmd = "SELECT %s FROM `%s`" % (", ".join(fields), table) - if sqlCond: - cmd += " WHERE %s" % " AND ".join(sqlCond) - retVal = self._query(cmd) - if not retVal['OK']: - return retVal - for record in retVal['Value']: - record = list(record) - if table == 'ProxyDB_CleanProxies': - record.insert(2, '') - record.insert(4, False) - data.append({'Name': record[0], - 'DN': record[1], - 'group': record[2], - 'expirationtime': record[3], - 'persistent': record[4] == 'True'}) - return S_OK(data) + result = Registry.getDNsForUsername(userMask) + if not result['OK'] or not result.get('Value'): + return S_OK([]) + selDict['UserDN'] = result['Value'] + result = self.getProxiesContent(selDict, cond=sqlCond) + return S_OK(result['Value']['Dictionaries']) if result['OK'] else result def getCredentialsAboutToExpire(self, requiredSecondsLeft, onlyPersistent=True): """ Get credentials about to expire for MyProxy @@ -1037,7 +872,7 @@ def setPersistencyFlag(self, userDN, userGroup, persistent=True): sqlInsert = True if retVal['OK']: data = retVal['Value'] - if len(data) > 0: + if data: sqlInsert = False if data[0][0] == sqlFlag: return S_OK() @@ -1049,12 +884,9 @@ def setPersistencyFlag(self, userDN, userGroup, persistent=True): if not result['OK']: self.log.error("setPersistencyFlag: Can not retrieve username for DN", userDN) return result - try: - sUserName = self._escapeString(result['Value'])['Value'] - except KeyError: - return S_ERROR("Can't escape user name") - cmd = "INSERT INTO `ProxyDB_Proxies` ( UserName, UserDN, UserGroup, Pem, ExpirationTime, PersistentFlag ) " - cmd += " VALUES( %s, %s, %s, '', UTC_TIMESTAMP(), 'True' )" % (sUserName, sUserDN, sUserGroup) + userName = result['Value'] + cmd = "INSERT INTO `ProxyDB_Proxies` (UserName, UserDN, UserGroup, Pem, ExpirationTime, PersistentFlag)" + cmd += " VALUES ('%s', %s, %s, '', UTC_TIMESTAMP(), 'True' )" % (userName, sUserDN, sUserGroup) else: cmd = "UPDATE `ProxyDB_Proxies` SET PersistentFlag='%s' WHERE UserDN=%s AND UserGroup=%s" % (sqlFlag, sUserDN, @@ -1064,49 +896,99 @@ def setPersistencyFlag(self, userDN, userGroup, persistent=True): return retVal return S_OK() - def getProxiesContent(self, selDict, sortList, start=0, limit=0): - """ Get the contents of the db, parameters are a filter to the db + def getProxiesContent(self, selDict, sortList=None, start=0, limit=0, cond=None): + """ Get the contents of the db, parameters are a filter to the db. - :param dict selDict: selection dict that contain fields and their posible values - :param dict sortList: dict with sorting fields + :param dict selDict: selection dict that contain fields and their possible values + :param sortList: dict with sorting fields + :type sortList: str or list :param int start: search limit start - :param int start: search limit amount + :param int limit: search limit amount + :param dict cond: filters :return: S_OK(dict)/S_ERROR() -- dict contain fields, record list, total records """ - data = [] + paramNames = ("UserName", "UserDN", "UserGroup", "ExpirationTime", "ProxyProvider") + + users = [] + groups = [] + if selDict.get('UserName'): + users = selDict['UserName'] + if not isinstance(users, (list, tuple)): + users = [users] + del selDict["UserName"] + if selDict.get('UserGroup'): + groups = selDict['UserGroup'] + if not isinstance(groups, (list, tuple)): + groups = [groups] + del selDict["UserGroup"] + + if groups or users: + DNs = [] + if groups and users: + for user in users: + for group in groups: + result = Registry.getDNsForUsernameInGroup(user, group) + if result['OK']: + DNs += result['Value'] + elif users: + for user in users: + result = Registry.getDNsForUsername(user) + if result['OK']: + DNs += result['Value'] + elif groups: + for group in groups: + for user in Registry.getUsersInGroup(group): + result = Registry.getDNsForUsername(user) + if result['OK']: + DNs += result['Value'] + + if not DNs: + return S_OK({'ParameterNames': paramNames, 'Records': [], 'TotalRecords': 0, 'Dictionaries': []}) + if selDict.get("UserDN"): + selDNs = selDict["UserDN"] if isinstance(selDict["UserDN"], (list, tuple)) else [selDict["UserDN"]] + selDict["UserDN"] = [] + for dn in selDNs: + if dn in DNs: + selDict["UserDN"].append(dn) + if not selDict["UserDN"]: + return S_OK({'ParameterNames': paramNames, 'Records': [], 'TotalRecords': 0, 'Dictionaries': []}) + else: + selDict["UserDN"] = DNs + + listData = [] + dataRecords = [] sqlWhere = ["Pem is not NULL"] - for table, fields in [('ProxyDB_CleanProxies', ("UserName", "UserDN", "ExpirationTime")), - ('ProxyDB_Proxies', ("UserName", "UserDN", "UserGroup", - "ExpirationTime", "PersistentFlag"))]: + if cond: + sqlWhere += (list(cond) if isinstance(cond, (list, tuple)) else [cond]) + + sqlOrder = [] + if sortList: + for sort in sortList: + if len(sort) == 1: + sort = (sort, "DESC") + elif len(sort) > 2: + return S_ERROR("Invalid sort %s" % sort) + if sort[0] in ['UserDN', 'ExpirationTime']: + if sort[1].upper() not in ("ASC", "DESC"): + return S_ERROR("Invalid sorting order %s" % sort[1]) + sqlOrder.append("%s %s" % (sort[0], sort[1])) + + for table, fields in [('ProxyDB_CleanProxies', ("UserDN", "ExpirationTime")), + ('ProxyDB_Proxies', ("UserDN", "UserGroup", "ExpirationTime"))]: cmd = "SELECT %s FROM `%s`" % (", ".join(fields), table) + for field in selDict: if field not in fields: continue fVal = selDict[field] if isinstance(fVal, (dict, tuple, list)): - sqlWhere.append("%s in (%s)" % - (field, ", ".join([self._escapeString(str(value))['Value'] for value in fVal]))) + if fVal: + sqlWhere.append("%s in (%s)" % + (field, ", ".join([self._escapeString(str(value))['Value'] for value in fVal]))) else: sqlWhere.append("%s = %s" % (field, self._escapeString(str(fVal))['Value'])) - sqlOrder = [] - if sortList: - for sort in sortList: - if len(sort) == 1: - sort = (sort, "DESC") - elif len(sort) > 2: - return S_ERROR("Invalid sort %s" % sort) - if sort[0] not in fields: - if table == 'ProxyDB_CleanProxies' and sort[0] in ['UserGroup', 'PersistentFlag']: - continue - return S_ERROR("Invalid sorting field %s" % sort[0]) - if sort[1].upper() not in ("ASC", "DESC"): - return S_ERROR("Invalid sorting order %s" % sort[1]) - sqlOrder.append("%s %s" % (sort[0], sort[1])) - if sqlWhere: - cmd = "%s WHERE %s" % (cmd, " AND ".join(sqlWhere)) - if sqlOrder: - cmd = "%s ORDER BY %s" % (cmd, ", ".join(sqlOrder)) + cmd += " WHERE %s " % " AND ".join(sqlWhere) if limit: try: start = int(start) @@ -1114,40 +996,68 @@ def getProxiesContent(self, selDict, sortList, start=0, limit=0): except ValueError: return S_ERROR("start and limit have to be integers") cmd += " LIMIT %d,%d" % (start, limit) + cmd += " ORDER BY %s" % (", ".join(sqlOrder) if sqlOrder else "UserDN DESC") retVal = self._query(cmd) if not retVal['OK']: + retVal['Message'] += "\n" + cmd return retVal + for record in retVal['Value']: record = list(record) - if table == 'ProxyDB_CleanProxies': - record.insert(2, '') - record.insert(4, False) - record[4] = record[4] == 'True' - data.append(record) - totalRecords = len(data) - return S_OK({'ParameterNames': fields, 'Records': data, 'TotalRecords': totalRecords}) + result = Registry.getUsernameForDN(record[0]) + if not result['OK']: + gLogger.error(result['Message']) + continue + user = result['Value'] + result = Registry.getGroupsForDN(record[0]) + if not result['OK']: + gLogger.error(result['Message']) + continue + groups = result['Value'] + result = self.getProxyProviderForDN(record[0]) + if not result['OK']: + gLogger.error(result['Message']) + continue + provider = result['Value'] - def logAction(self, action, issuerDN, issuerGroup, targetDN, targetGroup): + if table == 'ProxyDB_CleanProxies': + record.insert(1, groups) + record.append(provider) + + listData.append({'DN': record[0], + 'user': user, + 'groups': [record[1]] if table == 'ProxyDB_Proxies' else groups, + 'expirationtime': record[2], + 'provider': provider}) + + record.insert(0, user) + record.append(provider) + dataRecords.append(record) + return S_OK({'ParameterNames': paramNames, 'Records': dataRecords, 'TotalRecords': len(dataRecords), + 'Dictionaries': listData}) + + def logAction(self, action, issuerUsername, issuerGroup, targetUsername, targetGroup): """ Add an action to the log :param str action: proxy action - :param str issuerDN: user DN of issuer + :param str issuerUsername: user DN of issuer :param str issuerGroup: DIRAC group of issuer - :param str targetDN: user DN of target + :param str targetUsername: user DN of target :param str targetGroup: DIRAC group of target :return: S_ERROR() """ try: sAction = self._escapeString(action)['Value'] - sIssuerDN = self._escapeString(issuerDN)['Value'] + sIssuerUsername = self._escapeString(issuerUsername)['Value'] sIssuerGroup = self._escapeString(issuerGroup)['Value'] - sTargetDN = self._escapeString(targetDN)['Value'] + sTargetUsername = self._escapeString(targetUsername)['Value'] sTargetGroup = self._escapeString(targetGroup)['Value'] except KeyError: return S_ERROR("Can't escape from death") - cmd = "INSERT INTO `ProxyDB_Log` ( Action, IssuerDN, IssuerGroup, TargetDN, TargetGroup, Timestamp ) VALUES " - cmd += "( %s, %s, %s, %s, %s, UTC_TIMESTAMP() )" % (sAction, sIssuerDN, sIssuerGroup, sTargetDN, sTargetGroup) + cmd = "INSERT INTO `ProxyDB_Log` (Action, IssuerUsername, IssuerGroup, TargetUsername, TargetGroup, Timestamp)" + cmd += " VALUES (%s, %s, %s, %s, %s, UTC_TIMESTAMP())" % (sAction, sIssuerUsername, sIssuerGroup, + sTargetUsername, sTargetGroup) retVal = self._update(cmd) if not retVal['OK']: self.log.error("Can't add a proxy action log: ", retVal['Message']) @@ -1161,11 +1071,16 @@ def purgeLogs(self): return self._update(cmd) def getLogsContent(self, selDict, sortList, start=0, limit=0): + """ Function to get the contents of the logs table parameters are a filter to the db + + :param dict selDict: filters + :param list sortList: sort list + :param int start: start number + :param int limit: limit + + :return: S_OK()/S_ERROR() """ - Function to get the contents of the logs table - parameters are a filter to the db - """ - fields = ("Action", "IssuerDN", "IssuerGroup", "TargetDN", "TargetGroup", "Timestamp") + fields = ("Action", "IssuerUsername", "IssuerGroup", "TargetUsername", "TargetGroup", "Timestamp") cmd = "SELECT %s FROM `ProxyDB_Log`" % ", ".join(fields) if selDict: qr = [] @@ -1199,10 +1114,10 @@ def getLogsContent(self, selDict, sortList, start=0, limit=0): totalRecords = retVal['Value'][0][0] return S_OK({'ParameterNames': fields, 'Records': data, 'TotalRecords': totalRecords}) - def generateToken(self, requesterDN, requesterGroup, numUses=1, lifeTime=0, retries=10): + def generateToken(self, requesterUsername, requesterGroup, numUses=1, lifeTime=0, retries=10): """ Generate and return a token and the number of uses for the token - :param str requesterDN: DN of requester + :param str requesterUsername: DN of requester :param str requesterGroup: DIRAC group of requester :param int numUses: number of uses :param int lifeTime: proxy live time in a seconds @@ -1218,9 +1133,9 @@ def generateToken(self, requesterDN, requesterGroup, numUses=1, lifeTime=0, retr rndData = "%s.%s.%s.%s" % (time.time(), random.random(), numUses, lifeTime) m.update(rndData.encode()) token = m.hexdigest() - fieldsSQL = ", ".join(("Token", "RequesterDN", "RequesterGroup", "ExpirationTime", "UsesLeft")) + fieldsSQL = ", ".join(("Token", "RequesterUsername", "RequesterGroup", "ExpirationTime", "UsesLeft")) valuesSQL = ", ".join((self._escapeString(token)['Value'], - self._escapeString(requesterDN)['Value'], + self._escapeString(requesterUsername)['Value'], self._escapeString(requesterGroup)['Value'], "TIMESTAMPADD( SECOND, %d, UTC_TIMESTAMP() )" % int(lifeTime), str(numUses))) @@ -1243,18 +1158,18 @@ def purgeExpiredTokens(self): delSQL = "DELETE FROM `ProxyDB_Tokens` WHERE ExpirationTime < UTC_TIMESTAMP() OR UsesLeft < 1" return self._update(delSQL) - def useToken(self, token, requesterDN, requesterGroup): + def useToken(self, token, requesterUsername, requesterGroup): """ Uses of token count :param str token: token - :param str requesterDN: DN of requester + :param str requesterUsername: user name of requester :param str requesterGroup: DIRAC group of requester :return: S_OK(boolean)/S_ERROR() """ sqlCond = " AND ".join(("UsesLeft > 0", "Token=%s" % self._escapeString(token)['Value'], - "RequesterDN=%s" % self._escapeString(requesterDN)['Value'], + "RequesterUsername=%s" % self._escapeString(requesterUsername)['Value'], "RequesterGroup=%s" % self._escapeString(requesterGroup)['Value'], "ExpirationTime >= UTC_TIMESTAMP()")) updateSQL = "UPDATE `ProxyDB_Tokens` SET UsesLeft = UsesLeft - 1 WHERE %s" % sqlCond @@ -1271,7 +1186,6 @@ def __cleanExpNotifs(self): cmd = "DELETE FROM `ProxyDB_ExpNotifs` WHERE ExpirationTime < UTC_TIMESTAMP()" return self._update(cmd) - # FIXME: Add clean proxy def sendExpirationNotifications(self): """ Send notification about expiration @@ -1311,7 +1225,7 @@ def sendExpirationNotifications(self): if notKey in notifDone and notifDone[notKey] <= notifLimit: # Already notified for this notification limit break - if not self._notifyProxyAboutToExpire(userDN, lTime): + if not self._notifyProxyAboutToExpire(userDN, group, lTime): # Cannot send notification, retry later break try: @@ -1335,10 +1249,11 @@ def sendExpirationNotifications(self): notifDone[notKey] = notifLimit return S_OK(sent) - def _notifyProxyAboutToExpire(self, userDN, lTime): + def _notifyProxyAboutToExpire(self, userDN, userGroup, lTime): """ Send notification mail about to expire :param str userDN: user DN + :param str userGroup: DIRAC group :param int lTime: left proxy live time in a seconds :return: boolean @@ -1360,16 +1275,77 @@ def _notifyProxyAboutToExpire(self, userDN, lTime): information is: DN: %s + Group: %s + + If you plan on keep using this credentials please upload a newer proxy to + DIRAC by executing: + + $ dirac-proxy-init -P -g %s --rfc If you have been issued different certificate, please make sure you have a proxy uploaded with that certificate. Cheers, DIRAC's Proxy Manager -""" % (userName, daysLeft, userDN) +""" % (userName, daysLeft, userDN, userGroup, userGroup) fromAddr = self.getFromAddr() result = self.__notifClient.sendMail(userEMail, msgSubject, msgBody, fromAddress=fromAddr) if not result['OK']: gLogger.error("Could not send email", result['Message']) return False return True + + def getProxyProviderForDN(self, userDN, username=None): + """ Get proxy providers by user DN + + :param str userDN: user DN + :param str username: user name + + :return: S_OK(str)/S_ERROR() + """ + if not username: + result = Registry.getUsernameForDN(userDN) + if not result['OK']: + return result + username = result['Value'] + + result = Registry.getDNProperty(userDN, 'ProxyProviders', username=username, defaultValue=[]) + if result['OK'] and result['Value']: + return S_OK(result['Value'][0]) + + for userID in Registry.getIDsForUsername(username): + result = gAuthManagerData.getDNOptionForID(userID, userDN, 'ProxyProvider') + if not result['OK']: + return result + provider = result['Value'] + if provider: + return S_OK(provider) + return S_OK('Certificate') + + def getValidDNs(self, listDNs, sqlCond=None): + """ Get valid DNs + + :param list listDNs: list DNs + + :return: S_OK()/S_ERROR() + """ + dns = [] + dataRecords = [] + sqlWhere = ["Pem is not NULL"] + if sqlCond: + sqlWhere += (list(sqlCond) if isinstance(sqlCond, (list, tuple)) else [sqlCond]) + sqlWhere.append("UserDN in (%s)" % ", ".join([self._escapeString(str(v))['Value'] for v in listDNs])) + for table, exfield in [('ProxyDB_CleanProxies', ''), ('ProxyDB_Proxies', ', UserGroup')]: + cmd = "SELECT UserDN, ExpirationTime%s FROM `%s`" % (exfield, table) + result = self._query("%s WHERE %s ORDER BY UserDN DESC" % (cmd, " AND ".join(sqlWhere))) + if not result['OK']: + return result + for record in result['Value']: + record = list(record) + if len(record) == 2: + record.append(None) + if record[0] in dns: + continue + dns.append(record[0]) + dataRecords.append(record) + return S_OK(dataRecords) diff --git a/src/DIRAC/FrameworkSystem/Service/AuthManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/AuthManagerHandler.py new file mode 100644 index 00000000000..2cb2fc7ae75 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Service/AuthManagerHandler.py @@ -0,0 +1,443 @@ +""" The AuthManager service provides a toolkit to authenticate through an OIDC session. + + .. literalinclude:: ../ConfigTemplate.cfg + :start-after: ##BEGIN AuthManager: + :end-before: ##END + :dedent: 2 + :caption: AuthManager options + + The main mission is to manage user tokens and clients, namely store/get it from the + :mod:`AuthDB ` database. + The service also collects and caches real-time information about the status of users in registered + identity providers servers. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import six +import time +import pprint +import threading +from authlib.jose import jwt # TODO: need to add authlib to DIRACOS + +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.Core.DISET.RequestHandler import RequestHandler +from DIRAC.Core.Utilities import ThreadSafe +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.Core.Utilities.ThreadScheduler import gThreadScheduler +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderInfo, getProvidersForInstance +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForID, getIDsForUsername, getEmailsForGroup +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthAPI +from DIRAC.FrameworkSystem.private.authorization.utils.Sessions import Session + +from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB + +__RCSID__ = "$Id$" + + +gCacheProfiles = ThreadSafe.Synchronizer() + + +class AuthManagerHandler(RequestHandler): + """ Authentication manager + """ + __cahceIdPIDs = DictCache() + # # { + # # : [ ], + # # : ... + # # } + + __cacheProfiles = DictCache() + # # { + # # : { + # # DNs: { + # # : { + # # ProxyProvider: [ ], + # # VOMSRoles: [ ], + # # ... + # # }, + # # : { ... }, + # # } + # # }, + # # : { ... } + # # } + + __db = None + + @classmethod + @gCacheProfiles + def __getProfiles(cls, userID=None): + """ Get cache information + + :param str userID: user ID + + :return: dict + """ + if userID: + return {userID: cls.__cacheProfiles.get(userID) or {}} + return cls.__cacheProfiles.getDict() + + @classmethod + @gCacheProfiles + def __addProfiles(cls, data, time=3600 * 24): + """ Caching information + + :param dict data: ID information data + :param int time: lifetime + """ + if data: + for oid, info in data.items(): + cls.__cacheProfiles.add(oid, time, value=info) + + @classmethod + def __cleanAuthDB(cls): + """ Check AuthDB for zombie sessions and clean + + :return: S_OK()/S_ERROR() + """ + # cls.log.info("Kill zombie sessions") + # result = cls.__db.getZombieSessions() + # if not result['OK']: + # gLogger.error('Cannot clean zombies: %s' % result['Message']) + # return result + # for idP, sessions in result['Value'].items(): + # result = cls.__idps.getIdProvider(idP, sessionManager=cls.__db) + # if not result['OK']: + # for session in sessions: + # cls.log.error('%s session, with %s IdP, cannot log out:' % (sessions, idP), result['Message']) + # cls.__db.killSession(session) + # continue + # provObj = result['Value'] + # for session in sessions: + # result = provObj.logOut(session) + # if not result['OK']: + # cls.log.error('%s session, with %s IdP, cannot log out:' % (session, idP), result['Message']) + # cls.__db.killSession(session) + + # cls.log.notice("Cleaning is done!") + return S_OK() + + @classmethod + def initializeHandler(cls, serviceInfo): + """ Handler initialization + """ + cls.__db = AuthDB() + cls.__idps = IdProviderFactory() + # gThreadScheduler.addPeriodicTask(3600, cls.__cleanAuthDB) + # result = cls.__cleanAuthDB() + return cls.__refreshProfiles() # if result['OK'] else result + + @classmethod + def __refreshProfiles(cls): + """ Refresh users profiles + + :return: S_OK()/S_ERROR() + """ + def refreshIdP(idP): + """ Process to get information from IdPs using access tokens + + :param str idP: identity provider name + """ + result = cls.__idps.getIdProvider(idP, sessionManager=cls.__db) + if result['OK']: + provObj = result['Value'] + result = provObj.getIDsMetadata() + if result['OK']: + cls.__addProfiles(result['Value']) + if not result['OK']: + return result + + result = getProvidersForInstance('Id') + if not result['OK']: + return result + for idP in result['Value']: + processThread = threading.Thread(target=refreshIdP, args=[idP]) + processThread.start() + + return S_OK() + + def __checkAuth(self): + """ Check authorization rules + + :return: S_OK(tuple)/S_ERROR() -- tuple contain username and IDs + """ + credDict = self.getRemoteCredentials() + if credDict['group'] == 'hosts': + return S_OK((None, 'all')) + + user = credDict["username"] + userIDs = getIDsForUsername(user) + if not userIDs: + return S_ERROR('No registred IDs for %s user.' % user) + + return S_OK((user, userIDs)) + + # types_updateProfile = [] + # auth_updateProfile = ["authenticated", "TrustedHost"] + + # def export_updateProfile(self, userID=None): + # """ Return fresh info from identity providers about users with actual sessions + + # :params: str userID: user ID + + # :return: S_OK(dict)/S_ERROR() + # """ + # result = self.__checkAuth() + # if not result['OK']: + # return result + # user, ids = result["Value"] + + # # For host + # if ids == 'all': + # return S_OK(self.__getProfiles(userID=userID)) + + # # For user + # if userID: + # if userID not in ids: + # return S_ERROR('%s user not have access to %s ID information.' % (user, userID)) + # return S_OK(self.__getProfiles(userID=userID)) + + # data = {} + # for uid in ids: + # idDict = self.__getProfiles(userID=uid) + # if idDict: + # data[uid] = idDict + + # return S_OK(data) + + types_getIdProfiles = [] + auth_getIdProfiles = ["authenticated", "TrustedHost"] + + def export_getIdProfiles(self, userID=None): + """ Return fresh info from identity providers about users with actual sessions + + :params: str userID: user ID + + :return: S_OK(dict)/S_ERROR() + """ + result = self.__checkAuth() + if not result['OK']: + return result + user, ids = result["Value"] + + print('================== export_getIdProfiles ==================') + print('CREDS:') + pprint.pprint(self.getRemoteCredentials()) + print('userID: %s' % userID) + p = self.__getProfiles() + pprint.pprint(p) + + # For host + if ids == 'all': + print('all') + pprint.pprint(self.__getProfiles(userID=userID)) + return S_OK(self.__getProfiles(userID=userID)) + + # For user + if userID: + if userID not in ids: + return S_ERROR('%s user not have access to %s ID information.' % (user, userID)) + print('For user') + pprint.pprint(self.__getProfiles(userID=userID)) + return S_OK(self.__getProfiles(userID=userID)) + + data = {} + for uid in ids: + idDict = self.__getProfiles(userID=uid) + if idDict.get(uid): + data.update(idDict) + print('Else') + pprint.pprint(data) + return S_OK(data) + + types_parseAuthResponse = [six.string_types, dict, dict] # , six.string_types, dict] + + def export_parseAuthResponse(self, providerName, response, sessionDict): # , username, userProfile): + """ Fill session by user profile, tokens, comment, OIDC authorize status, etc. + Prepare dict with user parameters, if DN is absent there try to get it. + Create new or modify existing DIRAC user and store the session + + :param str providerName: provider name + :param dict response: authorization response + :param dict sessionDict: session number + + :return: S_OK(tuple)/S_ERROR() -- tuple contain username, profile and session + """ + self.log.debug('Try to parse authentification response from %s:\n' % providerName, pprint.pformat(response)) + # Parse response + result = self.__idps.getIdProvider(providerName, sessionManager=self.__db) + if result['OK']: + provObj = result['Value'] + result = provObj.parseAuthResponse(response, sessionDict) + if not result['OK']: + return result + # FINISHING with IdP auth result + username, userID, profile = result['Value'] + self.log.debug("Read %s's profile:" % username, pprint.pformat(profile)) + userProfile = profile[providerName][userID] + + # self.log.debug('The next session is identified for %s:\n' % username, pprint.pformat(session)) + # Is ID registred? + result = getUsernameForID(userID) + if not result['OK']: + comment = '%s ID is not registred in the DIRAC.' % userID + result = self.__registerNewUser(providerName, username, userProfile) + if result['OK']: + comment += ' Administrators have been notified about you.' + else: + comment += ' Please, contact the DIRAC administrators.' + return S_ERROR(comment) + + self.log.debug("Add %s's profile to cache:" % username, pprint.pformat(userProfile)) + self.__addProfiles({userID: userProfile}) + + # print('================== export_parseAuthResponse ==================') + # print('userID: %s' % userProfile['ID']) + # print('profile: %s' % userProfile) + # pprint.pprint(self.__getProfiles()) + # print('================== ==================') + return S_OK((result['Value'], userID, userProfile)) + + def __registerNewUser(self, provider, username, userProfile): + """ Register new user + + :param str provider: provider + :param str username: user name + :param dict userProfile: user information dictionary + + :return: S_OK()/S_ERROR() + """ + from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient + + mail = {} + mail['subject'] = "[SessionManager] User %s to be added." % username + mail['body'] = 'User %s was authenticated by ' % userProfile['FullName'] + mail['body'] += provider + mail['body'] += "\n\nAuto updating of the user database is not allowed." + mail['body'] += " New user %s to be added," % username + mail['body'] += "with the following information:\n" + mail['body'] += "\nUser name: %s\n" % username + mail['body'] += "\nUser profile:\n%s" % pprint.pformat(userProfile) + mail['body'] += "\n\n------" + mail['body'] += "\n This is a notification from the DIRAC AuthManager service, please do not reply.\n" + result = S_OK() + for addresses in getEmailsForGroup('dirac_admin'): + result = NotificationClient().sendMail(addresses, mail['subject'], mail['body'], localAttempt=False) + if not result['OK']: + self.log.error(result['Message']) + if result['OK']: + self.log.info(result['Value'], "administrators have been notified about a new user.") + return result + + types_createClient = [dict] + auth_createClient = [] # "authenticated", "TrustedHost"] + + def export_createClient(self, kwargs): + """ Generates a state string to be used in authorizations + + :param str provider: provider + :param str session: session number + + :return: S_OK(str)/S_ERROR() + """ + return self.__db.addClient(**kwargs) + + # types_getClientByID = [six.string_types] + # auth_getClientByID = [] # "authenticated", "TrustedHost"] + + # def export_getClientByID(self, clientID, metadata): + # """ Generates a state string to be used in authorizations + + # :param str provider: provider + # :param str session: session number + + # :return: S_OK(str)/S_ERROR() + # """ + # return self.__db.getClientByID(clientID, **metadata) + + types_storeToken = [dict] + auth_storeToken = ["authenticated"] + + def export_storeToken(self, kwargs): + """ Generates a state string to be used in authorizations + + :param str provider: provider + :param str session: session number + + :return: S_OK(str)/S_ERROR() + """ + return self.__db.storeToken(kwargs) + + types_updateToken = [] + auth_updateToken = ["authenticated"] + + def export_updateToken(self, token, refreshToken): + """ Generates a state string to be used in authorizations + + :param str provider: provider + :param str session: session number + + :return: S_OK(str)/S_ERROR() + """ + result = self.__db.updateToken(token, refreshToken) + return S_OK(dict(result['Value'])) if result['OK'] else result + + types_getTokenByUserIDAndProvider = [six.string_types, six.string_types] + auth_getTokenByUserIDAndProvider = ["authenticated"] + + def export_getTokenByUserIDAndProvider(self, uid, provider): + """ Generates a state string to be used in authorizations + + :param str provider: provider + :param str session: session number + + :return: S_OK(str)/S_ERROR() + """ + # if provider: + # result = self.__idps.getIdProvider(provider, sessionManager=cls.__db) + # if result['OK']: + # provObj = result['Value'] + # result = provObj.getTokenByUserID(uid) + # else: + result = self.__db.getTokenByUserIDAndProvider(uid, provider) + return result + + types_addSession = [dict] + + def export_addSession(self, session): + """ Add session to cache + + :param session: session + :type session: str, dict or Session object + :param int exp: expired time + """ + return self.__db.addSession(dict(session)) + + types_getSession = [six.string_types] + + def export_getSession(self, session): + """ Get session from cache + + :param session: session + :type session: str, Session object + + :return: Session object + """ + print('-- getSession') + pprint(session) + return self.__db.getSession(session) + + types_removeSession = [six.string_types] + + def export_removeSession(self, session): + """ Remove session from cache + + :param session: session + :type session: str, Session object + """ + print('-- removeSession') + pprint(session) + return self.__db.removeSession(session) \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py index 3ec724fe31f..9403a131afc 100644 --- a/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py @@ -1,10 +1,15 @@ -""" ProxyManager is the implementation of the ProxyManagement service in the DISET framework +""" ProxyManager is the implementation of the ProxyManagement service in the DISET framework. .. literalinclude:: ../ConfigTemplate.cfg :start-after: ##BEGIN ProxyManager: :end-before: ##END :dedent: 2 :caption: ProxyManager options + + The main mission is to manage user proxies, namely upload/download, add DISET-specific extensions, such as group, + or VOMS extensions, and the like things. A :mod:`ProxyDB ` database is used to store + proxies. The service also collects and caches real-time information about the status of users in registered VOMS + servers. """ from __future__ import absolute_import from __future__ import division @@ -12,22 +17,166 @@ __RCSID__ = "$Id$" +import os import six -from DIRAC import gLogger, S_OK, S_ERROR -from DIRAC.Core.DISET.RequestHandler import RequestHandler +import pickle +import pprint +import threading + +from DIRAC import gLogger, S_OK, S_ERROR, rootPath, gConfig from DIRAC.Core.Security import Properties +from DIRAC.Core.Security.ProxyFile import writeChainToProxyFile +from DIRAC.Core.Security.VOMSService import VOMSService +from DIRAC.Core.DISET.RequestHandler import RequestHandler +from DIRAC.Core.Utilities import ThreadSafe +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.Core.Utilities.Decorators import deprecated from DIRAC.Core.Utilities.ThreadScheduler import gThreadScheduler from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader +from DIRAC.ConfigurationSystem.Client import PathFinder from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient +from DIRAC.Resources.ProxyProvider.ProxyProviderFactory import ProxyProviderFactory +gVOMSCacheSync = ThreadSafe.Synchronizer() +gVOMSFileSync = ThreadSafe.Synchronizer() -class ProxyManagerHandler(RequestHandler): +class ProxyManagerHandler(RequestHandler): + """ Proxy manager service + + Contain __VOMSesUsersCache cache, with next structure: + Key: VOMS VO name + Value: S_OK(dict)/S_ERROR() -- dictionary formed by :func:`DIRAC.Core.Security.VOMSService.getUsers` + """ + # # { : { + # # Roles: [ ], + # # suspended: bool, + # # certSuspended: bool, + # # ... + # # } + # # } + + __notify = NotificationClient() + __VOMSesUsersCache = DictCache() __maxExtraLifeFactor = 1.5 __proxyDB = None + @classmethod + @gVOMSCacheSync + def saveVOMSInfoToCache(cls, vo, infoDict): + """ Cache VOMS VO information + + :param str vo: VO name + :param dict infoDict: dictionary with information about users + """ + cls.__VOMSesUsersCache.add(vo, 3600 * 24, infoDict) + + @classmethod + @gVOMSFileSync + def saveVOMSInfoToFile(cls, vo, infoDict): + """ Save cache to file + + :param str vo: VO name + :param dict infoDict: dictionary with information about users + """ + if not os.path.exists(cls.__workDir): + os.makedirs(cls.__workDir) + with open(os.path.join(cls.__workDir, vo + '.pkl'), 'wb+') as f: + pickle.dump(infoDict, f, pickle.HIGHEST_PROTOCOL) + + @classmethod + @gVOMSFileSync + def getVOMSInfoFromFile(cls, vo): + """ Load VO cache from file + + :param str vo: VO name + + :return: S_OK(dict)/S_ERROR() -- dictionary with information about users + """ + try: + with open(os.path.join(cls.__workDir, vo + '.pkl'), 'rb') as f: + return S_OK(pickle.load(f)) + except Exception as e: + return S_ERROR(str(e)) + + @classmethod + @gVOMSFileSync + def getVOMSInfoFromCache(cls, vo=None): + """ Load VO cache from file + + :param str vo: VO name + + :return: S_OK(dict)/S_ERROR() -- dictionary with information about users + """ + return cls.__VOMSesUsersCache.get(vo) if vo else cls.__VOMSesUsersCache.getDict() + + @classmethod + def __refreshVOMSesUsersCache(cls, voList=None): + """ Update cache with information about active users from supported VOs + + :param list voList: VOs to update + + :return: S_OK()/S_ERROR() + """ + def getVOInfo(vo): + """ Process to get information from VOMS API + + :param str vo: VO name + """ + usersDict = {} + result = S_ERROR('Cannot find administrators for %s VOMS VO' % vo) + voAdmins = Registry.getVOOption(vo, "VOAdmin", []) + + for group in Registry.getGroupsForVO(vo).get('Value') or []: + for user in voAdmins: + result = Registry.getDNForUsernameInGroup(user, group) + if result['OK']: + # Try to get proxy for any VO admin + result = cls.__proxyDB.getProxy(result['Value'], group, 1800) + if result['OK']: + # Now we have a proxy, lets dump it to file + result = writeChainToProxyFile(result['Value'][0], '/tmp/x509_syncTmp') + if result['OK']: + # Get users from VOMS + result = VOMSService(vo=vo).getUsers(result['Value']) + if result['OK']: + cls.saveVOMSInfoToCache(vo, result) + cls.saveVOMSInfoToFile(vo, result) + return + + gLogger.error(result['Message']) + if not cls.getVOMSInfoFromCache(vo) or not cls.getVOMSInfoFromCache(vo)['OK']: + cls.saveVOMSInfoToCache(vo, result) + # ################# getVOInfo ###################### # + + gLogger.info('Update VOMSes information..') + if not voList: + result = Registry.getVOsWithVOMS() + if not result['OK']: + return result + voList = result['Value'] + + for vo in voList: + processThread = threading.Thread(target=getVOInfo, args=[vo]) + processThread.start() + + # if diracAdminsNotifyDict: + # subject = '[ProxyManager] Cannot update users from %s VOMS VOs.' % ', '.join(diracAdminsNotifyDict.keys()) + # body = pprint.pformat(diracAdminsNotifyDict) + # body += "\n------\n This is a notification from the DIRAC ProxyManager service, please do not reply." + # #cls.__notify.sendMail(getEmailsForGroup('dirac_admin'), subject, body) + return S_OK() + @classmethod def initializeHandler(cls, serviceInfoDict): + """ Initialization + + :param dict serviceInfoDict: service information dictionary + + :return: S_OK()/S_ERROR() + """ + cls.__workDir = os.path.join(gConfig.getValue('/LocalSite/InstancePath', rootPath), 'work/ProxyManager') useMyProxy = cls.srv_getCSOption("UseMyProxy", False) try: result = ObjectLoader().loadObject('FrameworkSystem.DB.ProxyDB') @@ -35,18 +184,37 @@ def initializeHandler(cls, serviceInfoDict): gLogger.error('Failed to load ProxyDB class: %s' % result['Message']) return result dbClass = result['Value'] - cls.__proxyDB = dbClass(useMyProxy=useMyProxy) - except RuntimeError as excp: return S_ERROR("Can't connect to ProxyDB: %s" % excp) gThreadScheduler.addPeriodicTask(900, cls.__proxyDB.purgeExpiredTokens, elapsedTime=900) gThreadScheduler.addPeriodicTask(900, cls.__proxyDB.purgeExpiredRequests, elapsedTime=900) gThreadScheduler.addPeriodicTask(21600, cls.__proxyDB.purgeLogs) gThreadScheduler.addPeriodicTask(3600, cls.__proxyDB.purgeExpiredProxies) + gThreadScheduler.addPeriodicTask(3600 * 24, cls.__refreshVOMSesUsersCache) if useMyProxy: gLogger.info("MyProxy: %s\n MyProxy Server: %s" % (useMyProxy, cls.__proxyDB.getMyProxyServer())) - return S_OK() + return cls.__refreshVOMSesUsersCache() + + types_getVOMSesUsers = [] + + def export_getVOMSesUsers(self): + """ Return fresh info from service about VOMSes + + :return: S_OK(dict)/S_ERROR() + """ + VOMSesUsers = self.getVOMSInfoFromCache() + result = Registry.getVOs() + if not result['OK']: + return result + for vo in result['Value']: + if vo not in VOMSesUsers: + result = self.getVOMSInfoFromFile(vo) + if result['OK']: + VOMSesUsers[vo] = result['Value'] + continue + VOMSesUsers[vo] = S_ERROR('No information from "%s" VOMS VO' % vo) + return S_OK(VOMSesUsers) def __generateUserProxiesInfo(self): """ Generate information dict about user proxies @@ -55,32 +223,23 @@ def __generateUserProxiesInfo(self): """ proxiesInfo = {} credDict = self.getRemoteCredentials() - result = Registry.getDNForUsername(credDict['username']) + result = Registry.getDNsForUsername(credDict['username']) if not result['OK']: return result - selDict = {'UserDN': result['Value']} - result = self.__proxyDB.getProxiesContent(selDict, {}) + result = self.__proxyDB.getProxiesContent({'UserDN': result['Value']}) if not result['OK']: return result - contents = result['Value'] - userDNIndex = contents['ParameterNames'].index("UserDN") - userGroupIndex = contents['ParameterNames'].index("UserGroup") - expirationIndex = contents['ParameterNames'].index("ExpirationTime") - for record in contents['Records']: - userDN = record[userDNIndex] - if userDN not in proxiesInfo: - proxiesInfo[userDN] = {} - userGroup = record[userGroupIndex] - proxiesInfo[userDN][userGroup] = record[expirationIndex] - return proxiesInfo - - def __addKnownUserProxiesInfo(self, retDict): - """ Given a S_OK/S_ERR add a proxies entry with info of all the proxies a user has uploaded + for data in result['Value']['Dictionaries']: + if data['DN'] not in proxiesInfo: + proxiesInfo[data['DN']] = {} + for k, v in data.items(): + proxiesInfo[data['DN']][k] = v - :return: S_OK(dict)/S_ERROR() - """ - retDict['proxies'] = self.__generateUserProxiesInfo() - return retDict + for dn, data in proxiesInfo.items(): + pprint.pprint(data) + data['groups'] = sorted(set(data['groups'])) + + return S_OK(proxiesInfo) auth_getUserProxiesInfo = ['authenticated'] types_getUserProxiesInfo = [] @@ -90,28 +249,31 @@ def export_getUserProxiesInfo(self): :return: S_OK(dict) """ - return S_OK(self.__generateUserProxiesInfo()) + return self.__generateUserProxiesInfo() - # WARN: Since v7r1 requestDelegationUpload method use only first argument! - # WARN: Second argument for compatibility with older versions - types_requestDelegationUpload = [six.integer_types] + # WARN: Since v7r3 requestDelegationUpload method not use arguments! + auth_requestDelegationUpload = ['authenticated'] + types_requestDelegationUpload = [] - def export_requestDelegationUpload(self, requestedUploadTime, diracGroup=None): + def export_requestDelegationUpload(self, requestedUploadTime=None, diracGroup=None): """ Request a delegation. Send a delegation request to client - :param int requestedUploadTime: requested live time - :return: S_OK(dict)/S_ERROR() -- dict contain id and proxy as string of the request """ + if requestedUploadTime: + self.log.warn("Since v7r3 requestDelegationUpload method without arguments!") + if diracGroup: - self.log.warn("Since v7r1 requestDelegationUpload method use only first argument!") + return S_ERROR("Proxy with DIRAC group or VOMS extensions not allowed to be uploaded.") + credDict = self.getRemoteCredentials() - user = '%s:%s' % (credDict['username'], credDict['group']) - result = self.__proxyDB.generateDelegationRequest(credDict['x509Chain'], credDict['DN']) + result = self.__proxyDB.generateDelegationRequest(credDict) if result['OK']: - gLogger.info("Upload request by %s given id %s" % (user, result['Value']['id'])) + gLogger.info("Upload request by %s:%s given id %s" % + (credDict['username'], credDict['group'], result['Value']['id'])) else: - gLogger.error("Upload request failed", "by %s : %s" % (user, result['Message'])) + gLogger.error("Upload request failed", "by %s:%s : %s" % + (credDict['username'], credDict['group'], result['Message'])) return result types_completeDelegationUpload = [six.integer_types, six.string_types] @@ -124,14 +286,15 @@ def export_completeDelegationUpload(self, requestId, pemChain): :return: S_OK(dict)/S_ERROR() -- dict contain proxies """ + credDict = self.getRemoteCredentials() userId = "%s:%s" % (credDict['username'], credDict['group']) retVal = self.__proxyDB.completeDelegation(requestId, credDict['DN'], pemChain) if not retVal['OK']: gLogger.error("Upload proxy failed", "id: %s user: %s message: %s" % (requestId, userId, retVal['Message'])) - return self.__addKnownUserProxiesInfo(retVal) + return retVal gLogger.info("Upload %s by %s completed" % (requestId, userId)) - return self.__addKnownUserProxiesInfo(S_OK()) + return self.__generateUserProxiesInfo() types_getRegisteredUsers = [] @@ -148,21 +311,21 @@ def export_getRegisteredUsers(self, validSecondsRequired=0): return self.__proxyDB.getUsers(validSecondsRequired, userMask=credDict['username']) return self.__proxyDB.getUsers(validSecondsRequired) - def __checkProperties(self, requestedUserDN, requestedUserGroup): + def __checkProperties(self, requestedUsername, requestedUserGroup, credDict): """ Check the properties and return if they can only download limited proxies if authorized - :param str requestedUserDN: user DN + :param str requestedUsername: user name :param str requestedUserGroup: DIRAC group + :param dict credDict: remote credentials - :return: S_OK(boolean)/S_ERROR() + :return: S_OK(bool)/S_ERROR() -- bool indicates whether there are limitation """ - credDict = self.getRemoteCredentials() if Properties.FULL_DELEGATION in credDict['properties']: return S_OK(False) if Properties.LIMITED_DELEGATION in credDict['properties']: return S_OK(True) if Properties.PRIVATE_LIMITED_DELEGATION in credDict['properties']: - if credDict['DN'] != requestedUserDN: + if credDict['username'] != requestedUsername: return S_ERROR("You are not allowed to download any proxy") if Properties.PRIVATE_LIMITED_DELEGATION not in Registry.getPropertiesForGroup(requestedUserGroup): return S_ERROR("You can't download proxies for that group") @@ -172,159 +335,132 @@ def __checkProperties(self, requestedUserDN, requestedUserGroup): types_getProxy = [six.string_types, six.string_types, six.string_types, six.integer_types] - def export_getProxy(self, userDN, userGroup, requestPem, requiredLifetime): - """ Get a proxy for a userDN/userGroup + def export_getProxy(self, instance, group, requestPem, requiredLifetime, token=None, vomsAttribute=None): + """ Get a proxy for a user/group - :param requestPem: PEM encoded request object for delegation - :param requiredLifetime: Argument for length of proxy + :param str instance: user name or DN + :param str group: DIRAC group + :param str requestPem: PEM encoded request object for delegation + :param int requiredLifetime: Argument for length of proxy + :param str token: token that need to use + :param bool vomsAttribute: make proxy with VOMS extension * Properties: * FullDelegation <- permits full delegation of proxies * LimitedDelegation <- permits downloading only limited proxies * PrivateLimitedDelegation <- permits downloading only limited proxies for one self - """ - credDict = self.getRemoteCredentials() - - result = self.__checkProperties(userDN, userGroup) - if not result['OK']: - return result - forceLimited = result['Value'] - - self.__proxyDB.logAction("download proxy", credDict['DN'], credDict['group'], userDN, userGroup) - return self.__getProxy(userDN, userGroup, requestPem, requiredLifetime, forceLimited) - - def __getProxy(self, userDN, userGroup, requestPem, requiredLifetime, forceLimited): - """ Internal to get a proxy - - :param str userDN: user DN - :param str userGroup: DIRAC group - :param str requestPem: dump of request certificate - :param int requiredLifetime: requested live time of proxy - :param boolean forceLimited: limited proxy :return: S_OK(str)/S_ERROR() """ - retVal = self.__proxyDB.getProxy(userDN, userGroup, requiredLifeTime=requiredLifetime) - if not retVal['OK']: - return retVal - chain, secsLeft = retVal['Value'] - # If possible we return a proxy 1.5 longer than requested - requiredLifetime = int(min(secsLeft, requiredLifetime * self.__maxExtraLifeFactor)) - retVal = chain.generateChainFromRequestString(requestPem, - lifetime=requiredLifetime, - requireLimited=forceLimited) - if not retVal['OK']: - return retVal - return S_OK(retVal['Value']) + # Test that group enable to download + if not Registry.isDownloadableGroup(group): + return S_ERROR('"%s" group is disable to download.' % group) - types_getVOMSProxy = [six.string_types, six.string_types, - six.string_types, six.integer_types, - [six.string_types, type(None), bool]] - - def export_getVOMSProxy(self, userDN, userGroup, requestPem, requiredLifetime, vomsAttribute=None): - """ Get a proxy for a userDN/userGroup - - :param requestPem: PEM encoded request object for delegation - :param requiredLifetime: Argument for length of proxy - :param vomsAttribute: VOMS attr to add to the proxy + # Read arguments + result = self.__getDNAndUsername(instance, group) + if not result['OK']: + return result + user, userDNs = result['Value'] - * Properties : - * FullDelegation <- permits full delegation of proxies - * LimitedDelegation <- permits downloading only limited proxies - * PrivateLimitedDelegation <- permits downloading only limited proxies for one self - """ credDict = self.getRemoteCredentials() - result = self.__checkProperties(userDN, userGroup) + if token: + result = self.__proxyDB.useToken(token, credDict['username'], credDict['group']) + if not result['OK']: + return result + if not result['Value']: + return S_ERROR("Proxy token is invalid") + + result = self.__checkProperties(user, group, credDict) if not result['OK']: return result - forceLimited = result['Value'] - self.__proxyDB.logAction("download voms proxy", credDict['DN'], credDict['group'], userDN, userGroup) - return self.__getVOMSProxy(userDN, userGroup, requestPem, requiredLifetime, vomsAttribute, forceLimited) + # Set limitation if tokens are used + forceLimited = True if token else result['Value'] + + log = "download %sproxy%s" % ('VOMS ' if vomsAttribute else '', 'with token' if token else '') + self.__proxyDB.logAction(log, credDict['username'], credDict['group'], user, group) + + errors = [] + # Use loop to fix a possible case of having several DNs for one user/group + for userDN in userDNs: + retVal = self.__proxyDB.getProxy(userDN, group, requiredLifeTime=requiredLifetime, voms=vomsAttribute) + if retVal['OK']: + break + errors.append(retVal['Message']) - def __getVOMSProxy(self, userDN, userGroup, requestPem, requiredLifetime, vomsAttribute, forceLimited): - retVal = self.__proxyDB.getVOMSProxy(userDN, userGroup, - requiredLifeTime=requiredLifetime, - requestedVOMSAttr=vomsAttribute) if not retVal['OK']: - return retVal + return S_ERROR('; '.join(errors)) chain, secsLeft = retVal['Value'] # If possible we return a proxy 1.5 longer than requested requiredLifetime = int(min(secsLeft, requiredLifetime * self.__maxExtraLifeFactor)) - return chain.generateChainFromRequestString(requestPem, - lifetime=requiredLifetime, + return chain.generateChainFromRequestString(requestPem, lifetime=requiredLifetime, requireLimited=forceLimited) - types_setPersistency = [six.string_types, six.string_types, bool] - - def export_setPersistency(self, userDN, userGroup, persistentFlag): - """ Set the persistency for a given dn/group - - :param str userDN: user DN - :param str userGroup: DIRAC group - :param boolean persistentFlag: if proxy persistent - - :return: S_OK()/S_ERROR() - """ - retVal = self.__proxyDB.setPersistencyFlag(userDN, userGroup, persistentFlag) - if not retVal['OK']: - return retVal - credDict = self.getRemoteCredentials() - self.__proxyDB.logAction("set persistency to %s" % bool(persistentFlag), - credDict['DN'], credDict['group'], userDN, userGroup) - return S_OK() - types_deleteProxyBundle = [(list, tuple)] def export_deleteProxyBundle(self, idList): - """ delete a list of id's + """ Delete a list of id's - :param list,tuple idList: list of identity numbers + :param idList: list of identity numbers + :type idList: list or tuple :return: S_OK(int)/S_ERROR() """ errorInDelete = [] deleted = 0 for _id in idList: - if len(_id) != 2: - errorInDelete.append("%s doesn't have two fields" % str(_id)) - retVal = self.export_deleteProxy(_id[0], _id[1]) - if not retVal['OK']: - errorInDelete.append("%s : %s" % (str(_id), retVal['Message'])) + if isinstance(_id, (tuple, list)): + if len(_id) != 2: + errorInDelete.append("%s doesn't have two fields" % str(_id)) + if _id[0]: + retVal = self.export_deleteProxy(_id[0], None if isinstance(_id[1], list) else _id[1]) + if not retVal['OK']: + errorInDelete.append("%s : %s" % (str(_id), retVal['Message'])) + else: + deleted += 1 else: - deleted += 1 + retVal = self.export_deleteProxy(_id) + if not retVal['OK']: + errorInDelete.append("%s : %s" % (str(_id), retVal['Message'])) + else: + deleted += 1 if errorInDelete: return S_ERROR("Could not delete some proxies: %s" % ",".join(errorInDelete)) return S_OK(deleted) - types_deleteProxy = [(list, tuple)] + types_deleteProxy = [six.string_types] - def export_deleteProxy(self, userDN, userGroup): + def export_deleteProxy(self, instance, userGroup=None): """ Delete a proxy from the DB - :param str userDN: user DN + :param str instance: user name or DN :param str userGroup: DIRAC group :return: S_OK()/S_ERROR() """ + result = self.__getDNAndUsername(instance, userGroup) + if not result['OK']: + return result + username, userDNs = result['Value'] + credDict = self.getRemoteCredentials() if Properties.PROXY_MANAGEMENT not in credDict['properties']: - if userDN != credDict['DN']: + if username != credDict['username']: return S_ERROR("You aren't allowed!") - retVal = self.__proxyDB.deleteProxy(userDN, userGroup) + retVal = self.__proxyDB.deleteProxy(userDNs) if not retVal['OK']: return retVal - self.__proxyDB.logAction("delete proxy", credDict['DN'], credDict['group'], userDN, userGroup) + self.__proxyDB.logAction("delete proxy", credDict['username'], credDict['group'], username, userGroup) return S_OK() types_getContents = [dict, (list, tuple), six.integer_types, six.integer_types] - def export_getContents(self, selDict, sortDict, start, limit): + def export_getContents(self, selDict, conn, start=0, limit=0): """ Retrieve the contents of the DB :param dict selDict: selection fields - :param list,tuple sortDict: sorting fields + :param list conn: filters :param int start: search limit start :param int start: search limit amount @@ -333,7 +469,8 @@ def export_getContents(self, selDict, sortDict, start, limit): credDict = self.getRemoteCredentials() if Properties.PROXY_MANAGEMENT not in credDict['properties']: selDict['UserName'] = credDict['username'] - return self.__proxyDB.getProxiesContent(selDict, sortDict, start, limit) + + return self.__proxyDB.getProxiesContent(selDict, conn, start=start, limit=limit) types_getLogContents = [dict, (list, tuple), six.integer_types, six.integer_types] @@ -341,9 +478,10 @@ def export_getLogContents(self, selDict, sortDict, start, limit): """ Retrieve the contents of the DB :param dict selDict: selection fields - :param list,tuple sortDict: search filter + :param sortDict: search filter + :type sortDict: list or tuple :param int start: search limit start - :param int start: search limit amount + :param int limit: search limit amount :return: S_OK(dict)/S_ERROR() -- dict contain fields, record list, total records """ @@ -351,74 +489,311 @@ def export_getLogContents(self, selDict, sortDict, start, limit): types_generateToken = [six.string_types, six.string_types, six.integer_types] - def export_generateToken(self, requesterDN, requesterGroup, tokenUses): + def export_generateToken(self, requester, requesterGroup, tokenUses): """ Generate tokens for proxy retrieval - :param str requesterDN: user DN + :param str requester: user name or DN :param str requesterGroup: DIRAC group :param int tokenUses: number of uses :return: S_OK(tuple)/S_ERROR() -- tuple contain token, number uses """ + # Is instance user DN? + if requester.startswith('/'): + result = Registry.getUsernameForDN(requester) + if not result['OK']: + return result + requester = result['Value'] + credDict = self.getRemoteCredentials() - self.__proxyDB.logAction("generate tokens", credDict['DN'], credDict['group'], requesterDN, requesterGroup) - return self.__proxyDB.generateToken(requesterDN, requesterGroup, numUses=tokenUses) + self.__proxyDB.logAction("generate tokens", credDict['username'], credDict['group'], + requester, requesterGroup) + return self.__proxyDB.generateToken(requester, requesterGroup, numUses=tokenUses) + + types_getVOMSProxyWithToken = [six.string_types, six.string_types, + six.string_types, six.integer_types, [six.string_types, type(None)]] + + @deprecated("This method is deprecated, you can use export_getProxy with token and vomsAttribute parameter") + def export_getVOMSProxyWithToken(self, user, userGroup, requestPem, requiredLifetime, token, vomsAttribute=None): + """ Get a proxy with VOMS extension for a user/userGroup by using token + + :param str user: user name + :param str userGroup: DIRAC group + :param str requestPem: PEM encoded request object for delegation + :param int requiredLifetime: Argument for length of proxy + :param str token: Valid token to get a proxy + + :return: S_OK(str)/S_ERROR() + """ + return self.export_getProxy(user, userGroup, requestPem, requiredLifetime, token=token, vomsAttribute=vomsAttribute) types_getProxyWithToken = [six.string_types, six.string_types, six.string_types, six.integer_types, six.string_types] - def export_getProxyWithToken(self, userDN, userGroup, requestPem, requiredLifetime, token): - """ Get a proxy for a userDN/userGroup + @deprecated("This method is deprecated, you can use export_getProxy with token parameter") + def export_getProxyWithToken(self, user, userGroup, requestPem, requiredLifetime, token): + """ Get a proxy for a user/userGroup by using token - :param requestPem: PEM encoded request object for delegation - :param requiredLifetime: Argument for length of proxy - :param token: Valid token to get a proxy + :param str user: user name + :param str userGroup: DIRAC group + :param str requestPem: PEM encoded request object for delegation + :param int requiredLifetime: Argument for length of proxy + :param str token: Valid token to get a proxy - * Properties: - * FullDelegation <- permits full delegation of proxies - * LimitedDelegation <- permits downloading only limited proxies - * PrivateLimitedDelegation <- permits downloading only limited proxies for one self + :return: S_OK(str)/S_ERROR() """ - credDict = self.getRemoteCredentials() - result = self.__proxyDB.useToken(token, credDict['DN'], credDict['group']) - gLogger.info("Trying to use token %s by %s:%s" % (token, credDict['DN'], credDict['group'])) - if not result['OK']: - return result - if not result['Value']: - return S_ERROR("Proxy token is invalid") - self.__proxyDB.logAction("used token", credDict['DN'], credDict['group'], userDN, userGroup) + return self.export_getProxy(user, userGroup, requestPem, requiredLifetime, token=token) - result = self.__checkProperties(userDN, userGroup) - if not result['OK']: - return result - self.__proxyDB.logAction("download proxy with token", credDict['DN'], credDict['group'], userDN, userGroup) - return self.__getProxy(userDN, userGroup, requestPem, requiredLifetime, True) + types_getVOMSProxy = [six.string_types, six.string_types, + six.string_types, six.integer_types, [six.string_types, type(None)]] - types_getVOMSProxyWithToken = [six.string_types, six.string_types, - six.string_types, six.integer_types, - [six.string_types, type(None)]] + @deprecated("This method is deprecated, you can use export_getProxy with vomsAttribute parameter") + def export_getVOMSProxy(self, user, userGroup, requestPem, requiredLifetime, vomsAttribute=None): + """ Get a proxy with VOMS extension for a user/userGroup - def export_getVOMSProxyWithToken(self, userDN, userGroup, requestPem, requiredLifetime, token, vomsAttribute=None): - """ Get a proxy for a userDN/userGroup + :param str user: user name + :param str userGroup: DIRAC group + :param str requestPem: PEM encoded request object for delegation + :param int requiredLifetime: Argument for length of proxy + :param str token: Valid token to get a proxy - :param requestPem: PEM encoded request object for delegation - :param requiredLifetime: Argument for length of proxy - :param vomsAttribute: VOMS attr to add to the proxy + :return: S_OK(str)/S_ERROR() + """ + return self.export_getProxy(user, userGroup, requestPem, requiredLifetime, vomsAttribute=vomsAttribute) - * Properties : - * FullDelegation <- permits full delegation of proxies - * LimitedDelegation <- permits downloading only limited proxies - * PrivateLimitedDelegation <- permits downloading only limited proxies for one self + types_getGroupsStatusByUsername = [six.string_types] + + def export_getGroupsStatusByUsername(self, username, groups=None): + """ Get status of every group for DIRAC user: + + :param str username: user name + + :return: S_OK(dict)/S_ERROR() """ - credDict = self.getRemoteCredentials() - result = self.__proxyDB.useToken(token, credDict['DN'], credDict['group']) + # # { + # # : { + # # Status: (unknown|failed|suspended|ready|not ready), + # # Comment: .., + # # DN: .., + # # Action: [ , [ ] ] + # # }, + # # { ... } + # # : {} ... } + # # } + print('GET STATUS FOR %s user, %s groups' % (username, groups)) + statusDict = {} + result = Registry.getGroupsForUser(username) if not result['OK']: return result - if not result['Value']: - return S_ERROR("Proxy token is invalid") - self.__proxyDB.logAction("used token", credDict['DN'], credDict['group'], userDN, userGroup) + userGroups = result['Value'] + groups = groups or userGroups + for group in groups: + if group not in userGroups: + return S_ERROR('%s group is not %s user group' % (group, username)) + + provDict = {} + groupDict = {} + # Sort user DNs for groups and proxy providers + for group in groups: + result = Registry.getDNsForUsernameInGroup(username, group) + print('group: %s' % group) + pprint.pprint(result) + if not result['OK']: + if group not in statusDict: + statusDict[group] = [{'Status': 'failed', 'Comment': result['Message']}] + continue + # we get only fist DN for now + for dn in [result['Value'][0]]: + result = self.__proxyDB.getProxyProviderForDN(dn, username=username) + if not result['OK']: + return result + pProvider = result['Value'] + if group not in groupDict: + groupDict[group] = [] + if pProvider not in provDict: + provDict[pProvider] = [] + provDict[pProvider] = list(set(provDict[pProvider] + [dn])) + groupDict[group] = list(set(groupDict[group] + [dn])) + + # Check VOMS VO + for group, dns in groupDict.items(): + if group not in statusDict: + statusDict[group] = [] + + vo = Registry.getGroupOption(group, 'VO') + + result = Registry.getVOsWithVOMS(voList=[vo]) + if not result['OK']: + return result + if not result['Value']: + continue - result = self.__checkProperties(userDN, userGroup) - if not result['OK']: - return result - self.__proxyDB.logAction("download voms proxy with token", credDict['DN'], credDict['group'], userDN, userGroup) - return self.__getVOMSProxy(userDN, userGroup, requestPem, requiredLifetime, vomsAttribute, True) + result = Registry.getVOMSServerInfo(vo) + if not result['OK']: + return result + vomsServers = result['Value'][vo]["Servers"] if result['Value'] else {} + vomsServerURL = 'https://%s:8443/voms/register/start.action' % vomsServers.keys()[0] + + result = self.getVOMSInfoFromCache(vo) + if not result: + result = self.getVOMSInfoFromFile(vo) + if result['OK']: + result = result['Value'] + result = S_ERROR('No information from "%s" VOMS VO' % vo) + + if not result or not result['OK']: + err = '' if not result else result.get('Message', '') + st = {'Status': 'unknown', + "Comment": 'Fail to get %s VOMS VO information depended for this group: %s' % (vo, err)} + for dn in dns: + if dn in groupDict[group]: + groupDict[group].remove(dn) + st['DN'] = dn + statusDict[group].append(st) + continue + + voData = result['Value'] + for dn in dns: + if dn not in voData: + if dn in groupDict[group]: + groupDict[group].remove(dn) + st = {'Status': 'failed', 'DN': dn, 'Action': ['openURL', [vomsServerURL]], + 'Comment': 'Make sure you(%s) are a member of the %s VOMS VO depended for this group. ' % (dn, vo)} + statusDict[group].append(st) + continue + + role = Registry.getGroupOption(group, 'VOMSRole') + print('===================== voData[dn] --->>') + pprint.pprint(voData[dn]) + if not role: + if voData[dn].get('Suspended'): # TODO: voData[dn]['Suspended'] + if dn in groupDict[group]: + groupDict[group].remove(dn) + st = {'Status': 'suspended', 'DN': dn, 'Action': ['openURL', [vomsServerURL]], + 'Comment': 'It seems you(%s) are suspended in the %s VOMS VO depended for this group. ' % + (dn, vo)} + statusDict[group].append(st) + continue + else: + if role not in voData[dn]['VOMSRoles']: + if dn in groupDict[group]: + groupDict[group].remove(dn) + st = {'Status': 'failed', 'DN': dn, 'Action': ['openURL', [vomsServerURL]], + 'Comment': 'It seems you(%s) have no %s role in %s VOMS VO depended for this group. ' % + (dn, role, vo)} + statusDict[group].append(st) + continue + if role in voData[dn]['SuspendedRoles']: + if dn in groupDict[group]: + groupDict[group].remove(dn) + st = {'Status': 'suspended', 'DN': dn, 'Action': ['openURL', [vomsServerURL]], + 'Comment': 'It seems you(%s) are suspended for %s role in the %s VOMS VO depended for this group. ' % + (dn, role, vo)} + statusDict[group].append(st) + continue + + # Check DNs by proxy providers + for prov, dns in provDict.items(): + dns = list(set(dns)) + + # Cut off existed DNs in DB + result = self.__proxyDB.getValidDNs(dns) + if not result['OK']: + return result + for dn, time, group in result['Value']: + if dn in dns: + dns.remove(dn) + st = {'Status': 'ready', 'DN': dn, + "Comment": 'proxy uploaded end valid to %s' % time} + for _group, _dns in groupDict.items(): + if not group or group == _group: + if _group not in statusDict: + statusDict[_group] = [] + if dn in _dns: + statusDict[_group].append(st) + + # If for DN not found proxy provider + if not prov or prov == 'Certificate': + for dn in dns: + st = {'Status': 'not ready', 'DN': dn, "Action": ['upload proxy', ['/']], + "Comment": 'You have no proxy with(%s) uploaded to DIRAC.' % dn} + for group, dns in groupDict.items(): + if group not in statusDict: + statusDict[group] = [] + if dn in dns: + statusDict[group].append(st) + continue + + # If proxy provider exist for DN + result = ProxyProviderFactory().getProxyProvider(prov, proxyManager=self.__proxyDB) + if not result['OK']: + return result + pProvObj = result['Value'] + for dn in dns: + result = pProvObj.checkStatus(dn) + st = result['Value'] if result['OK'] else {'Status': 'unknown', "Comment": result['Message']} + st['DN'] = dn + for group, dns in groupDict.items(): + if group not in statusDict: + statusDict[group] = [] + if dn in dns: + statusDict[group].append(st) + + resD = {} + for group, statuses in statusDict.items(): + for stat in statuses: + if stat['Status'] not in ["ready", "unknown"]: + resD[group] = stat + break + if group not in resD: + for stat in statuses: + if stat['Status'] == "ready": + resD[group] = stat + break + if group not in resD: + resD[group] = statuses[0] + + return S_OK(resD) + + def __getDNAndUsername(self, instance, group=None): + """ Parse instance to understand if it's DN and find username + + :param str instance: user name or DN + :param str group: user group + + :return S_OK(tuple)/S_ERROR() -- tuple contain username and userDN + """ + userDN = None + userName = instance + + # Is instance user DN? + if instance.startswith('/'): + userDN = instance + result = Registry.getUsernameForDN(instance) + if not result['OK']: + return result + userName = result['Value'] + + dns = [userDN] if userDN else [] + if group: + result = Registry.getDNsForUsernameInGroup(instance, group) + if not result['OK']: + return result + dns = result['Value'] + if dns: + if len(dns) > 1: + gLogger.warn('For %s@%s found more than one DN:' % (userName, group), dns) + if userDN: + if userDN not in dns: + return S_ERROR('Requested %s DN not match with %s user, %s group' % (userDN, userName, group)) + dns = [userDN] + + if not dns: + return S_ERROR('No DN were found for %s user' % userName, (', %s group' % group) if group else '') + return S_OK((userName, dns)) + + types_setPersistency = [] + + @deprecated("Unuse") + def export_setPersistency(self, user, userGroup, persistentFlag): + """ Set the persistency for a given DN/group """ + return S_OK(True) diff --git a/src/DIRAC/FrameworkSystem/Utilities/halo.py b/src/DIRAC/FrameworkSystem/Utilities/halo.py new file mode 100644 index 00000000000..3e401baba4b --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Utilities/halo.py @@ -0,0 +1,720 @@ +# -*- coding: utf-8 -*- +# pylint: disable=unsubscriptable-object +""" Beautiful terminal spinners in Python. + Source: https://github.com/manrajgrover/halo + + MIT License + + Copyright (c) 2017 Manraj Singh + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + -- + Source: https://github.com/tartley/colorama + + Copyright (c) 2010 Jonathan Hartley + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name of the copyright holders, nor those of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import, unicode_literals + +import os +import re +import sys +import six +import time +import ctypes +import atexit +import signal +import codecs +import platform +import functools +import threading + +from termcolor import colored +try: + from shutil import get_terminal_size +except ImportError: + from backports.shutil_get_terminal_size import get_terminal_size + +from DIRAC import S_OK, S_ERROR + + +def qrterminal(url): + """ Show QR code + + :param str url: URL to convert to QRCode + + :return: S_OK(str)/S_ERROR() + """ + try: + import pyqrcode # pylint: disable=import-error + except Exception as ex: + return S_ERROR('pyqrcode library is not installed.') + __qr = '\n' + qrA = pyqrcode.create(url).code + qrA.insert(0, [0 for i in range(0, len(qrA[0]))]) + qrA.append([0 for i in range(0, len(qrA[0]))]) + if not (len(qrA) % 2) == 0: + qrA.append([0 for i in range(0, len(qrA[0]))]) + for i in range(0, len(qrA)): + if not (i % 2) == 0: + continue + __qr += '\033[0;30;47m ' + for j in range(0, len(qrA[0])): + p = str(qrA[i][j]) + str(qrA[i + 1][j]) + if p == '11': # black bg + __qr += '\033[0;30;40m \033[0;30;47m' + if p == '10': # upblock + __qr += u'\u2580' + if p == '01': # downblock + __qr += u'\u2584' + if p == '00': # white bg + __qr += ' ' + __qr += ' \033[0m\n' + return S_OK(__qr) + + +class StreamWrapper(object): + """ Wraps a stream (such as stdout), acting as a transparent proxy for all + attribute access apart from method 'write()', which is delegated to our + Converter instance. + Source: https://github.com/tartley/colorama + """ + + def __init__(self, wrapped, converter): + # double-underscore everything to prevent clashes with names of + # attributes on the wrapped stream object. + self.__wrapped = wrapped + self.__convertor = converter + + def __getattr__(self, name): + return getattr(self.__wrapped, name) + + def __enter__(self, *args, **kwargs): + # special method lookup bypasses __getattr__/__getattribute__, see + # https://stackoverflow.com/questions/12632894/why-doesnt-getattr-work-with-exit + # thus, contextlib magic methods are not proxied via __getattr__ + return self.__wrapped.__enter__(*args, **kwargs) + + def __exit__(self, *args, **kwargs): + return self.__wrapped.__exit__(*args, **kwargs) + + def write(self, text): + self.__convertor.write(text) + + def isatty(self): + stream = self.__wrapped + if 'PYCHARM_HOSTED' in os.environ: + if stream is not None and (stream is sys.__stdout__ or stream is sys.__stderr__): + return True + try: + streamIsATTY = stream.isatty + except AttributeError: + return False + else: + return streamIsATTY() + + @property + def closed(self): + stream = self.__wrapped + try: + return stream.closed + except AttributeError: + return True + + +class PreWrapp(object): + """ Source: https://github.com/tartley/colorama + """ + + def __init__(self, wrapped): + if os.name == 'nt': + raise BaseException('Not support') + # The wrapped stream (normally sys.stdout or sys.stderr) + self.wrapped = wrapped + # create the proxy wrapping our output stream + self.stream = StreamWrapper(wrapped, self) + + def write(self, text): + self.wrapped.write(text) + self.wrapped.flush() + self.resetAll() + + def resetAll(self): + if not self.stream.closed: + self.wrapped.write('\033[0m') + + +def resetAll(): + if PreWrapp is not None: # Issue #74: objects might become None at exit + PreWrapp(sys.stdout).resetAll() + + +sys.stdout = PreWrapp(sys.stdout).stream +sys.stderr = PreWrapp(sys.stderr).stream +atexit.register(resetAll) + + +def getEnvironment(): + """ Get the environment in which halo is running + + :return: str -- Environment name + """ + try: + from IPython import get_ipython + except ImportError: + return 'terminal' + try: + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell': # Jupyter notebook or qtconsole + return 'jupyter' + elif shell == 'TerminalInteractiveShell': # Terminal running IPython + return 'ipython' + else: + return 'terminal' # Other type (?) + except NameError: + return 'terminal' + + +def decodeUTF8Text(text): + """ Decode the text from utf-8 format + + :param str text: String to be decoded + + :return: str -- Decoded string + """ + try: + return codecs.decode(text, 'utf-8') + except (TypeError, ValueError): + return text + + +def encodeUTF8Text(text): + """ Encodes the text to utf-8 format + + :param str text: String to be encoded + + :return: str -- Encoded string + """ + try: + return codecs.encode(text, 'utf-8', 'ignore') + except (TypeError, ValueError): + return text + + +class Halo(object): + """ Halo library. + + CLEAR_LINE -- Code to clear the line + """ + class Done(Exception): + """ Done exception """ + pass + + class CursorInfo(ctypes.Structure): + # Need for cursor + if os.name == 'nt': + _fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)] + + CLEAR_LINE = '\033[K' + SPINNER_PLACEMENTS = ('left', 'right',) + + def __init__(self, text='', color='green', textColor=None, spinner=None, + animation=None, placement='left', interval=-1, enabled=True, stream=sys.stdout, result='succeed'): + """ Constructs the Halo object. + + :param str text: Text to display. + :param str color: Color of the text. + :param str textColor: Color of the text to display. + :param str,dict spinner: String or dictionary representing spinner. + :param basesrting animation: Animation to apply if text is too large. Can be one of `bounce`, `marquee`. + Defaults to ellipses. + :param str placement: Side of the text to place the spinner on. Can be `left` or `right`. + Defaults to `left`. + :param int interval: Interval between each frame of the spinner in milliseconds. + :param boolean enabled: Spinner enabled or not. + :param io stream: IO output. + """ + self._newline = None + self._result = result + self._color = color + self._animation = animation + self.spinner = spinner + self.text = text + self._textColor = textColor + self._interval = int(interval) if int(interval) > 0 else self._spinner['interval'] + self._stream = stream + self.placement = placement + self._frameIndex = 0 + self._textIndex = 0 + self._spinnerThread = None + self._stopSpinner = None + self._spinnerId = None + self.enabled = enabled + environment = getEnvironment() + + def cleanUp(): + """ Handle cell execution""" + self.__stop() + + if environment in ('ipython', 'jupyter'): + from IPython import get_ipython + ip = get_ipython() + ip.events.register('post_run_cell', cleanUp) + else: # default terminal + atexit.register(cleanUp) + + def __enter__(self): + """ Starts the spinner on a separate thread. For use in context managers. + """ + return self.start() + + def __exit__(self, eType, eValue, traceback): + """ Stops the spinner. For use in context managers.""" + if eType: + self._newline = False + self._text['original'] = '' + if isinstance(eValue, SystemExit) and eValue.code in [None, 0]: + self.succeed() + else: + self.fail() + elif self._result == 'succeed': + self.succeed() + elif self._result == 'warn': + self.warn() + elif self._result == 'info': + self.info() + else: + self.stop() + + def __call__(self, f): + """ Allow the Halo object to be used as a regular function decorator. + """ + @functools.wraps(f) + def wrapped(*args, **kwargs): + with self: + return f(*args, **kwargs) + return wrapped + + def coloredFrame(self, text, color=None): + """ Colorize text, while stripping nested ANSI color sequences. + + :param str text: text + :param str color: text colors -> red, green, yellow, blue, magenta, cyan, white. + + :return: str + """ + return colored(text, color, attrs=['bold']) + + @property + def spinner(self): + """ Getter for spinner property. + + :return: dict -- spinner value + """ + return self._spinner + + @spinner.setter + def spinner(self, spinner=None): + """ Setter for spinner property. + + :param dict,str spinner: Defines the spinner value with frame and interval + """ + self._spinner = {"interval": 80, "frames": ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]} + self._frameIndex = 0 + self._textIndex = 0 + + @property + def text(self): + """ Getter for text property. + + :return: str -- text value + """ + return self._text['original'] + + @text.setter + def text(self, text): + """ Setter for text property. + + :param str text: Defines the text value for spinner + """ + self._text = self._getText(text) + + @property + def result(self): + """ Getter for result property. + + :return: str -- result value + """ + return self._result + + # pylint: disable=function-redefined + @text.setter + def result(self, result): + """ Setter for result property. + + :param str result: Defines the result of with + """ + self._result = result + + @property + def textColor(self): + """ Getter for text color property. + + :return: str -- text color value + """ + return self._textColor + + @textColor.setter + def textColor(self, textColor): + """ Setter for text color property. + + :param str textColor: Defines the text color value for spinner + """ + self._textColor = textColor + + @property + def color(self): + """ Getter for color property. + + :return: str -- color value + """ + return self._color + + @color.setter + def color(self, color): + """ Setter for color property. + + :param str color: Defines the color value for spinner + """ + self._color = color + + @property + def placement(self): + """ Getter for placement property. + + :return: str -- spinner placement + """ + return self._placement + + @placement.setter + def placement(self, placement): + """ Setter for placement property. + + :param str placement: Defines the placement of the spinner + """ + if placement not in self.SPINNER_PLACEMENTS: + raise ValueError("Unknown spinner placement '{0}', available are {1}".format(placement, self.SPINNER_PLACEMENTS)) + self._placement = placement + + @property + def spinner_id(self): + """ Getter for spinner id + + :return: str -- Spinner id value + """ + return self._spinnerId + + @property + def animation(self): + """ Getter for animation property. + + :return: str -- Spinner animation + """ + return self._animation + + @animation.setter + def animation(self, animation): + """ Setter for animation property. + + :param str animation: Defines the animation of the spinner + """ + self._animation = animation + self._text = self._getText(self._text['original']) + + def _checkStream(self): + """ Returns whether the stream is open, and if applicable, writable + + :return: bool -- Whether the stream is open + """ + if self._stream.closed: + return False + try: + # Attribute access kept separate from invocation, to avoid + # swallowing AttributeErrors from the call which should bubble up. + checkStreamWritable = self._stream.writable + except AttributeError: + pass + else: + return checkStreamWritable() + return True + + def _write(self, s): + """ Write to the stream, if writable + + :params str s: Characters to write to the stream + """ + if self._checkStream(): + self._stream.write(s) + + def _hideCursor(self): + """ Disable the user's blinking cursor + """ + if self._checkStream() and self._stream.isatty(): + for sid in [signal.SIGINT, signal.SIGTSTP]: + signal.signal(sid, self._showCursor) + if os.name == 'nt': + ci = CursorInfo() # pylint: disable=undefined-variable + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci)) + ci.visible = False + ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci)) + elif os.name == 'posix': + sys.stdout.write("\033[?25l") + sys.stdout.flush() + + def _showCursor(self, *args): + """ Re-enable the user's blinking cursor + """ + if self._checkStream() and self._stream.isatty(): + if os.name == 'nt': + ci = CursorInfo() # pylint: disable=undefined-variable + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci)) + ci.visible = True + ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci)) + elif os.name == 'posix': + sys.stdout.write("\033[?25h") + sys.stdout.flush() + if args: + raise SystemExit(args[0]) + + def _getText(self, text): + """ Creates frames based on the selected animation + + :params str text: text + """ + animation = self._animation + strippedText = text.strip() + + # Check which frame of the animation is the widest + maxSpinnerLength = max([len(i) for i in self._spinner['frames']]) + + # If column size is 0 either we are not connected + # to a terminal or something else went wrong. Fallback to 80. + terminalColumns = 80 if get_terminal_size().columns == 0 else get_terminal_size().columns + + # Subtract to the current terminal size the max spinner length + # (-1 to leave room for the extra space between spinner and text) + terminalWidth = terminalColumns - maxSpinnerLength - 1 + textLength = len(strippedText) + frames = [] + if terminalWidth < textLength and animation: + + if animation == 'bounce': + # Make the text bounce back and forth + for x in range(0, textLength - terminalWidth + 1): + frames.append(strippedText[x:terminalWidth + x]) + frames.extend(list(reversed(frames))) + + elif 'marquee': + # Make the text scroll like a marquee + strippedText = strippedText + ' ' + strippedText[:terminalWidth] + for x in range(0, textLength + 1): + frames.append(strippedText[x:terminalWidth + x]) + + elif terminalWidth < textLength and not animation: + # Add ellipsis if text is larger than terminal width and no animation was specified + frames = [strippedText[:terminalWidth - 4] + '... '] + else: + frames = [strippedText] + return {'original': text, 'frames': frames} + + def clear(self): + """ Clears the line and returns cursor to the start. + """ + self._write('\r') + self._write(self.CLEAR_LINE) + return self + + def _renderFrame(self): + """ Renders the frame on the line after clearing it. + """ + if not self.enabled: + # in case we're disabled or stream is closed while still rendering, + # we render the frame and increment the frame index, so the proper + # frame is rendered if we're reenabled or the stream opens again. + return + self.clear() + frame = self.frame() + output = '\r{}'.format(frame) + try: + self._write(output) + except UnicodeEncodeError: + self._write(encodeUTF8Text(output)) + + def render(self): + """ Runs the render until thread flag is set. + """ + while not self._stopSpinner.is_set(): + self._renderFrame() + time.sleep(0.001 * self._interval) + return self + + def frame(self): + """ Builds and returns the frame to be rendered + """ + frames = self._spinner['frames'] + frame = frames[self._frameIndex] + if self._color: + frame = self.coloredFrame(frame, self._color) + self._frameIndex += 1 + self._frameIndex = self._frameIndex % len(frames) + textFrame = self.textFrame() + return u'{0} {1}'.format(*[(textFrame, frame) if self._placement == 'right' else (frame, textFrame)][0]) + + def textFrame(self): + """ Builds and returns the text frame to be rendered + """ + if len(self._text['frames']) == 1: + if self._textColor: + return self.coloredFrame(self._text['frames'][0], self._textColor) + # Return first frame (can't return original text because at this point it might be ellipsed) + return self._text['frames'][0] + frames = self._text['frames'] + frame = frames[self._textIndex] + self._textIndex += 1 + self._textIndex = self._textIndex % len(frames) + return self.coloredFrame(frame, self._textColor) if self._textColor else frame + + def start(self, text=None): + """ Starts the spinner on a separate thread. + + :param str text: Text to be used alongside spinner + """ + if text is not None: + self.text = text + if self._spinnerId is not None: + return self + if not (self.enabled and self._checkStream()): + return self + self._hideCursor() + self._stopSpinner = threading.Event() + self._spinnerThread = threading.Thread(target=self.render) + self._spinnerThread.setDaemon(True) + self._renderFrame() + self._spinnerId = self._spinnerThread.name + self._spinnerThread.start() + return self + + def __stop(self): + if self._spinnerThread and self._spinnerThread.is_alive(): + self._stopSpinner.set() + self._spinnerThread.join() + + if self.enabled: + self.clear() + + self._frameIndex = 0 + self._spinnerId = None + self._showCursor() + return self + + def succeed(self, text=None): + """ Shows and persists success symbol and text and exits. + + :param str text: Text to be shown alongside success symbol. + """ + self._color = 'green' + return self.stop(symbol='✔', text=text) + + def fail(self, text=None): + """ Shows and persists fail symbol and text and exits. + + :param str text: Text to be shown alongside fail symbol. + """ + self._color = 'red' + return self.stop(symbol='✖', text=text) + + def warn(self, text=None): + """ Shows and persists warn symbol and text and exits. + + :param str text: Text to be shown alongside warn symbol. + """ + self._color = 'yellow' + return self.stop(symbol='⚠', text=text) + + def info(self, text=None): + """ Shows and persists info symbol and text and exits. + + :param str text: Text to be shown alongside info symbol. + """ + self._color = 'blue' + return self.stop(symbol='ℹ', text=text) + + def stop(self, text=None, symbol=None): + """ Stops the spinner and persists the final frame to be shown. + + :param str text: Text to be shown in final frame + :param str symbol: Symbol to be shown in final frame + """ + if not (symbol and text): + self.__stop() + if not self.enabled: + return self + self.__stop() + symbol = decodeUTF8Text(symbol) if symbol is not None else '' + text = decodeUTF8Text(text) if text is not None else self._text['original'] + symbol = self.coloredFrame(symbol, self._color) if self._color and symbol else symbol + text = self.coloredFrame(text, self._textColor) if self._textColor and text else text.strip() + output = u'{0} {1}'.format(*[(text, symbol) if self._placement == 'right' else (symbol, text)][0]) + output += '' if self._newline is False else '\n' + try: + self._write(output) + except UnicodeEncodeError: + self._write(encodeUTF8Text(output)) + return self diff --git a/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py new file mode 100644 index 00000000000..12c15cac3e5 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/AuthServer.py @@ -0,0 +1,358 @@ +""" This class provides authorization server activity. """ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import io +import json +from time import time +import pprint +import urlparse +from tornado.httpclient import HTTPResponse +from tornado.httputil import HTTPHeaders +from tornado.template import Template + +from authlib.deprecate import deprecate +from authlib.jose import jwt +from authlib.oauth2 import HttpRequest, AuthorizationServer as _AuthorizationServer +from authlib.oauth2.rfc6749.grants import ImplicitGrant +from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import (DeviceAuthorizationEndpoint, + DeviceCodeGrant, + SaveSessionToDB) +from DIRAC.FrameworkSystem.private.authorization.grants.AuthorizationCode import (OpenIDCode, + AuthorizationCodeGrant) +from DIRAC.FrameworkSystem.private.authorization.grants.RefreshToken import RefreshTokenGrant +from DIRAC.FrameworkSystem.private.authorization.grants.TokenExchange import TokenExchangeGrant +from DIRAC.FrameworkSystem.private.authorization.grants.ImplicitFlow import (OpenIDImplicitGrant, + NotebookImplicitGrant) +from DIRAC.FrameworkSystem.private.authorization.utils.Clients import (ClientRegistrationEndpoint, + ClientManager) +from DIRAC.FrameworkSystem.private.authorization.utils.Sessions import SessionManager +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import (OAuth2Request, + createOAuth2Request) +# from authlib.oidc.core import UserInfo + +from authlib.oauth2.rfc6750 import BearerToken +from authlib.oauth2.rfc7636 import CodeChallenge +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from authlib.common.security import generate_token +from authlib.common.encoding import to_unicode, json_dumps +from authlib.oauth2.base import OAuth2Error + +from DIRAC import gLogger, gConfig, S_OK, S_ERROR +from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance +from DIRAC.ConfigurationSystem.Client.Helpers.CSGlobals import getSetup +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory +from DIRAC.FrameworkSystem.Client.AuthManagerClient import gSessionManager +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthorisationServerMetadata +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForID, getEmailsForGroup +# from DIRAC.Core.Web.SessionData import SessionStorage + +import logging +import sys +log = logging.getLogger('authlib') +log.addHandler(logging.StreamHandler(sys.stdout)) +log.setLevel(logging.DEBUG) +log = gLogger.getSubLogger(__name__) + + +class AuthServer(_AuthorizationServer, ClientManager): #SessionManager + """ Implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`. + + Initialize:: + + server = AuthServer() + """ + metadata_class = AuthorizationServerMetadata + + def __init__(self): + self.db = AuthDB() + self.idps = IdProviderFactory() + ClientManager.__init__(self, self.db) + # Privide two authlib methods query_client and save_token + _AuthorizationServer.__init__(self, query_client=self.getClient, save_token=self.saveToken) + self.generate_token = BearerToken(self.access_token_generator, self.refresh_token_generator) + self.config = {} + self.collectMetadata() + + # self.register_grant(NotebookImplicitGrant) # OpenIDImplicitGrant) + self.register_grant(TokenExchangeGrant) + self.register_grant(RefreshTokenGrant) + self.register_grant(DeviceCodeGrant, [SaveSessionToDB(db=self.db)]) + self.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True), OpenIDCode(require_nonce=False)]) + self.register_endpoint(ClientRegistrationEndpoint) + self.register_endpoint(DeviceAuthorizationEndpoint) + + def collectMetadata(self): + """ Collect metadata """ + self.metadata = {} + result = getAuthorisationServerMetadata() + if not result['OK']: + raise Exception('Cannot prepare authorization server metadata. %s' % result['Message']) + # Verify metadata + metadata = self.metadata_class(result['Value']) + metadata.validate() + self.metadata = metadata + + def addSession(self, session): + self.db.addSession(session) + + def getSession(self, session): + self.db.getSession(session) + + def saveToken(self, token, request): + """ Store tokens + + :param dict token: tokens + :param object request: http Request object, implemented for compatibility with authlib library (unuse) + """ + if 'refresh_token' in token: + return self.db.storeToken(token) + return S_OK(None) + + def getIdPAuthorization(self, providerName, request): + """ Submit subsession and return dict with authorization url and session number + + :param str providerName: provider name + :param object request: main session request + + :return: S_OK(response)/S_ERROR() -- dictionary contain response generated by `handle_response` + """ + result = self.idps.getIdProvider(providerName) + if not result['OK']: + return result + idpObj = result['Value'] + result = idpObj.submitNewSession() + if not result['OK']: + return result + authURL, state, session = result['Value'] + session['state'] = state + session['Provider'] = providerName + session['mainSession'] = request if isinstance(request, dict) else request.toDict() + + gLogger.verbose('Redirect to', authURL) + return self.handle_response(302, {}, [("Location", authURL)], session) + + def parseIdPAuthorizationResponse(self, response, session): + """ Fill session by user profile, tokens, comment, OIDC authorize status, etc. + Prepare dict with user parameters, if DN is absent there try to get it. + Create new or modify existing DIRAC user and store the session + + :param dict response: authorization response + :param str session: session + + :return: S_OK(dict)/S_ERROR() + """ + providerName = session.pop('Provider') + gLogger.debug('Try to parse authentification response from %s:\n' % providerName, pprint.pformat(response)) + # Parse response + result = self.idps.getIdProvider(providerName, sessionManager=self.db) + if not result['OK']: + return result + provObj = result['Value'] + result = provObj.parseAuthResponse(response, session) + if not result['OK']: + return result + # FINISHING with IdP auth result + username, userID, profile = result['Value'] + gLogger.debug("Read %s's profile:" % username, pprint.pformat(profile)) + userProfile = profile[providerName][userID] + # Is ID registred? + result = getUsernameForID(userID) + if not result['OK']: + # if sync with extVO is turn on: + # return autogenerated username and userID + # else: + comment = '%s ID is not registred in the DIRAC.' % userID + result = self.__registerNewUser(providerName, username, userProfile) + if result['OK']: + comment += ' Administrators have been notified about you.' + else: + comment += ' Please, contact the DIRAC administrators.' + return S_ERROR(comment) + username = result['Value'] + return S_OK((username, userID)) + # return gSessionManager.parseAuthResponse(session.pop('Provider'), createOAuth2Request(response).toDict(), + # session) + + def access_token_generator(self, client, grant_type, user, scope): + """ A function to generate ``access_token`` + + :param object client: Client object + :param str grant_type: grant type + :param str user: user unique id + :param str scope: scope + + :return: jwt object + """ + gLogger.debug('GENERATE DIRAC ACCESS TOKEN for "%s" with "%s" scopes.' % (user, scope)) + header = {'alg': 'RS256'} + payload = {'sub': user, + 'iss': self.metadata['issuer'], + 'iat': int(time()), + 'exp': int(time()) + (12 * 3600), + 'scope': scope, + 'setup': getSetup()} + # # + # Return proxy with token in one response? + # # + + # Read private key of DIRAC auth service + with open('/opt/dirac/etc/grid-security/jwtRS256.key', 'r') as f: + key = f.read() + # Need to use enum==0.3.1 for python 2.7 + return jwt.encode(header, payload, key) + + def refresh_token_generator(self, client, grant_type, user, scope): + """ A function to generate ``refresh_token`` + + :param object client: Client object + :param str grant_type: grant type + :param str user: user unique id + :param str scope: scope + + :return: jwt object + """ + gLogger.debug('GENERATE DIRAC REFRESH TOKEN for "%s" with "%s" scopes.' % (user, scope)) + header = {'alg': 'RS256'} + payload = {'sub': user, + 'iss': self.metadata['issuer'], + 'iat': int(time()), + 'exp': int(time()) + (30 * 24 * 3600), + 'scope': scope, + 'setup': getSetup(), + 'client_id': client.client_id} + # Read private key of DIRAC auth service + with open('/opt/dirac/etc/grid-security/jwtRS256.key', 'r') as f: + key = f.read() + # Need to use enum==0.3.1 for python 2.7 + return jwt.encode(header, payload, key) + + def get_error_uris(self, request): + error_uris = self.config.get('error_uris') + if error_uris: + return dict(error_uris) + + def create_oauth2_request(self, request, method_cls=OAuth2Request, use_json=False): + gLogger.debug('Create OAuth2 request', 'with json' if use_json else '') + return createOAuth2Request(request, method_cls, use_json) + + def create_json_request(self, request): + return self.create_oauth2_request(request, HttpRequest, True) + + def handle_error_response(self, request, error): + return self.handle_response(*error(translations=self.get_translations(request), + error_uris=self.get_error_uris(request)), error=True) + + def handle_response(self, status_code=None, payload=None, headers=None, newSession=None, error=None, **actions): + gLogger.debug('Handle authorization response with %s status code:' % status_code, payload) + gLogger.debug('Headers:', headers) + if newSession: + gLogger.debug('newSession:', newSession) + return S_OK([[status_code, headers, payload, newSession, error], actions]) + # return HTTPResponse(self.request, status_code, headers=header, buffer=io.StringIO(payload)) + + def create_authorization_response(self, response, username): + result = super(AuthServer, self).create_authorization_response(response, username) + if result['OK']: + # Remove auth session + result['Value'][0][4] = True + return result + + def validate_consent_request(self, request, provider=None): + """ Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization:: + + :param object request: tornado request + :param provider: provider + + :return: response generated by `handle_response` or S_ERROR or html + """ + if request.method != 'GET': + return 'Use GET method to access this endpoint.' + try: + req = self.create_oauth2_request(request) + # req.data['state'] = req.state or generate_token(10) + gLogger.info('Validate consent request for', req.state) + grant = self.get_authorization_grant(req) + gLogger.debug('Use grant:', grant) + grant.validate_consent_request() + if not hasattr(grant, 'prompt'): + grant.prompt = None + + # Check Identity Provider + provider, providerChooser = self.validateIdentityProvider(req, provider) + if not provider: + return providerChooser + + # Submit second auth flow through IdP + return self.getIdPAuthorization(provider, req) + except OAuth2Error as error: + return self.handle_error_response(None, error) + + def validateIdentityProvider(self, request, provider): + """ Check if identity provider registred in DIRAC + + :param object request: request + :param str provider: provider name + + :return: str, S_OK()/S_ERROR() -- provider name and html page to choose it + """ + # Research supported IdPs + result = getProvidersForInstance('Id') + if not result['OK']: + return None, result + idPs = result['Value'] + if not idPs: + return None, S_ERROR('No identity providers found.') + + if not provider: + if len(idPs) == 1: + return idPs[0], None + # Choose IdP interface + with self.doc: + with dom.div(style=self.css_main): + with dom.div('Choose identity provider', style=self.css_align_center): + for idP in idPs: + # data: Status, Comment, Action + dom.button(dom.a(idP, href='/authorization/%s?%s' % (idP, request.query)), + cls='button') + return None, self.handle_response(payload=Template(self.doc.render()).generate()) + + # Check IdP + if provider not in idPs: + return None, S_ERROR('%s is not registered in DIRAC.' % provider) + + return provider, None + + def __registerNewUser(self, provider, username, userProfile): + """ Register new user + + :param str provider: provider + :param str username: user name + :param dict userProfile: user information dictionary + + :return: S_OK()/S_ERROR() + """ + from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient + + mail = {} + mail['subject'] = "[SessionManager] User %s to be added." % username + mail['body'] = 'User %s was authenticated by ' % userProfile['FullName'] + mail['body'] += provider + mail['body'] += "\n\nAuto updating of the user database is not allowed." + mail['body'] += " New user %s to be added," % username + mail['body'] += "with the following information:\n" + mail['body'] += "\nUser name: %s\n" % username + mail['body'] += "\nUser profile:\n%s" % pprint.pformat(userProfile) + mail['body'] += "\n\n------" + mail['body'] += "\n This is a notification from the DIRAC AuthManager service, please do not reply.\n" + result = S_OK() + for addresses in getEmailsForGroup('dirac_admin'): + result = NotificationClient().sendMail(addresses, mail['subject'], mail['body'], localAttempt=False) + if not result['OK']: + self.log.error(result['Message']) + if result['OK']: + self.log.info(result['Value'], "administrators have been notified about a new user.") + return result \ No newline at end of file diff --git a/src/DIRAC/FrameworkSystem/private/authorization/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/__init__.py new file mode 100644 index 00000000000..dacc66057b4 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/__init__.py @@ -0,0 +1,9 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +# $HeadURL$ +__RCSID__ = "$Id$" + +# from .AuthServer import AuthServer + +# __all__ = ['AuthServer'] diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py new file mode 100644 index 00000000000..6c6f0faf905 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/AuthorizationCode.py @@ -0,0 +1,141 @@ +""" This class describe Authorization Code grant type +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from time import time +from pprint import pprint +from authlib.jose import JsonWebSignature +from authlib.oidc.core import UserInfo +from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from authlib.oauth2.rfc6749.grants import AuthorizationCodeGrant as _AuthorizationCodeGrant +from authlib.oauth2.rfc7636 import CodeChallenge +from authlib.common.encoding import to_unicode, json_dumps, json_b64encode, urlsafe_b64decode, json_loads + +from DIRAC import gLogger, S_OK, S_ERROR + + +class OAuth2Code(dict): + def __init__(self, params): + params['auth_time'] = params.get('auth_time', int(time())) + super(OAuth2Code, self).__init__(params) + + @property + def user(self): + return self.get('user_id') + + @property + def code_challenge(self): + return self.get('code_challenge') + + @property + def code_challenge_method(self): + return self.get('code_challenge_method', 'pain') + + def is_expired(self): + return self.get('auth_time') + 300 < time() + + def get_redirect_uri(self): + return self.get('redirect_uri') + + def get_scope(self): + return self.get('scope', '') + + def get_auth_time(self): + return self.get('auth_time') + + def get_nonce(self): + return self.get('nonce') + + +class OpenIDCode(_OpenIDCode): + def exists_nonce(self, nonce, request): + return False + # try: + # AuthorizationCode.objects.get(client_id=request.client_id, nonce=nonce) + # return True + # except AuthorizationCode.DoesNotExist: + # return False + + def get_jwt_config(self, grant): + with open('/opt/dirac/etc/grid-security/jwtRS256.key', 'rb') as f: + key = f.read() + issuer = grant.server.metadata['issuer'] + return {'key': key, 'alg': 'RS512', 'iss': issuer, 'exp': 3600} + + def generate_user_info(self, user, scope): + print('== generate_user_info ==') + # pprint(self.__dict__) + print(user) + print(scope) + # data = self.server.getSession(self.request.state) + # return UserInfo(sub=user[0], profile=data['profile'], grp=user[1]) + return UserInfo(sub=user[0], grp=user[1]) + + +class AuthorizationCodeGrant(_AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def save_authorization_code(self, code, request): + pass + + def delete_authorization_code(self, authorization_code): + pass + + def query_authorization_code(self, code, client): + """ Parse authorization code + + :param code: authorization code as JWS + :param client: client + + :return: OAuth2Code or None + """ + print('== query_authorization_code ==') + pprint(code) + jws = JsonWebSignature(algorithms=['RS256']) + with open('/opt/dirac/etc/grid-security/jwtRS256.key.pub', 'rb') as f: + key = f.read() + data = jws.deserialize_compact(code, key) + try: + item = OAuth2Code(json_loads(urlsafe_b64decode(data['payload']))) + pprint(dict(item)) + print('get_scope: %s' % item.get_scope()) + except Exception as e: + return None + if not item.is_expired(): + return item + + def authenticate_user(self, authorization_code): + return authorization_code.user + + def generate_authorization_code(self): + """ return code """ + print('========= generate_authorization_code =========') + print('DICT:') + pprint(self.__dict__) + print('Reuest:') + pprint(self.request.user) + pprint(self.request.data) + print('Session:') + # sessionDict = self.server.getSession(self.request.state) + # pprint(sessionDict) + print('-----------------------------------------------') + jws = JsonWebSignature(algorithms=['RS256']) + protected = {'alg': 'RS256'} + code = OAuth2Code({'user_id': self.request.user['user_id'], + # These scopes already contain DIRAC groups + 'scope': self.request.data['scope'], + 'redirect_uri': self.request.args['redirect_uri'], + 'client_id': self.request.args['client_id'], + 'code_challenge': self.request.args.get('code_challenge'), + 'code_challenge_method': self.request.args.get('code_challenge_method')}) + print('--= Payload =--') + pprint(dict(code)) + # payload = json_dumps(dict(code)) # + payload = json_b64encode(dict(code)) + pprint(payload) + print('--= =--') + with open('/opt/dirac/etc/grid-security/jwtRS256.key', 'rb') as f: + key = f.read() + return jws.serialize_compact(protected, payload, key) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py new file mode 100644 index 00000000000..502b9132507 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/DeviceFlow.py @@ -0,0 +1,208 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time +import requests +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749.grants import AuthorizationEndpointMixin +from authlib.oauth2.rfc6749.errors import InvalidClientError, UnauthorizedClientError +from authlib.oauth2.rfc8628 import ( + DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint, + DeviceCodeGrant as _DeviceCodeGrant, + DeviceCredentialDict, + DEVICE_CODE_GRANT_TYPE +) + +from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthAPI, getDIRACClientID + +log = gLogger.getSubLogger(__name__) + + +def submitUserAuthorizationFlow(idP=None, group=None): + """ Submit authorization flow + + :param str idP: identity provider + :param str group: requested group + + :return: S_OK(dict)/S_ERROR() -- dictionary with device code flow response + """ + try: + r = requests.post('{api}/device{provider}?client_id={client_id}{group}'.format( + api=getAuthAPI(), client_id=getDIRACClientID(), + provider=('/%s' % idP) if idP else '', + group = ('&scope=g:%s' % group) if group else '' + ), verify=False) + r.raise_for_status() + deviceResponse = r.json() + + # Check if all main keys are present here + for k in ['user_code', 'device_code', 'verification_uri']: + if not deviceResponse.get(k): + return S_ERROR('Mandatory %s key is absent in authentication response.' % k) + + return S_OK(deviceResponse) + except requests.exceptions.Timeout: + return S_ERROR('Authentication server is not answer, timeout.') + except requests.exceptions.RequestException as ex: + return S_ERROR(r.content or repr(ex)) + except Exception as ex: + return S_ERROR('Cannot read authentication response: %s' % repr(ex)) + + +def waitFinalStatusOfUserAuthorizationFlow(deviceCode, interval=5, timeout=300): + """ Submit waiting loop process, that will monitor current authorization session status + + :param str deviceCode: received device code + :param int interval: waiting interval + :param int timeout: max time of waiting + + :return: S_OK(dict)/S_ERROR() - dictionary contain access/refresh token and some metadata + """ + __start = time.time() + + while True: + time.sleep(int(interval)) + if time.time() - __start > timeout: + return S_ERROR('Time out.') + r = requests.post('{api}/token?client_id={client_id}&grant_type={grant}&device_code={device_code}'.format( + api=getAuthAPI(), client_id = getDIRACClientID(), grant=DEVICE_CODE_GRANT_TYPE, device_code=deviceCode + ), verify=False) + token = r.json() + if not token: + return S_ERROR('Resived token is empty!') + if 'error' not in token: + os.environ['DIRAC_TOKEN'] = r.text + return S_OK(token) + if token['error'] != 'authorization_pending': + return S_ERROR(token['error'] + ' : ' + token.get('description', '')) + + +class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): + URL = '%s/device' % getAuthAPI() + + def create_endpoint_response(self, req): + """ See :func:`authlib.oauth2.rfc8628.DeviceAuthorizationEndpoint.create_endpoint_response` """ + # Share original request object to endpoint class before create_endpoint_response + self.req = req + return super(DeviceAuthorizationEndpoint, self).create_endpoint_response(req) + + def get_verification_uri(self): + """ Create verification uri when `DeviceCode` flow initialized + + :return: str + """ + return self.req.protocol + "://" + self.req.host + self.req.path + + def save_device_credential(self, client_id, scope, data): + """ Save device credentials + + :param str client_id: client id + :param str scope: request scopes + :param dict data: device credentials + """ + data.update(dict(uri='{api}?{query}&response_type=device&client_id={client_id}&scope={scope}'.format( + api=data['verification_uri'], query=self.req.query, client_id=client_id, scope=scope, + ), id=data['device_code'])) + result = self.server.db.addSession(data) + if not result['OK']: + raise OAuth2Error('Cannot save device credentials', result['Message']) + + +class DeviceCodeGrant(_DeviceCodeGrant, AuthorizationEndpointMixin): + RESPONSE_TYPES = {'device'} + + def validate_authorization_request(self): + """ Validate authorization request + + :return: None + """ + # Validate client for this request + client_id = self.request.client_id + log.debug('Validate authorization request of', client_id) + if client_id is None: + raise InvalidClientError(state=self.request.state) + client = self.server.query_client(client_id) + if not client: + raise InvalidClientError(state=self.request.state) + response_type = self.request.response_type + if not client.check_response_type(response_type): + raise UnauthorizedClientError('The client is not authorized to use "response_type={}"'.format(response_type)) + self.request.client = client + self.validate_requested_scope() + + # Check user_code, when user go to authorization endpoint + userCode = self.request.args.get('user_code') + if not userCode: + raise OAuth2Error('user_code is absent.') + + # Get session from cookie + if not self.getSession(user_code=userCode): + raise OAuth2Error('Session with %s user code is expired.' % userCode) + # self.execute_hook('after_validate_authorization_request') + return None + + def create_authorization_response(self, redirect_uri, user): + """ Mark session as authed with received user + + :param str redirect_uri: redirect uri + :param dict user: dictionary with username and userID + + :return: result of `handle_response` + """ + # Save session with user + result = self.server.db.addSession(dict(id=self.request.state, user_id=user['userID'], uri=self.request.uri, + username=user['username'], scope=self.request.scope)) + if not result['OK']: + raise OAuth2Error('Cannot save authorization result', result['Message']) + return 200, 'Authorization complite.' + + def query_device_credential(self, device_code): + result = self.server.db.getSession(device_code) + if not result['OK']: + raise OAuth2Error(result['Message']) + data = result['Value'] + if not data: + return None + data['expires_at'] = int(data['expires_in']) + int(time.time()) + data['interval'] = DeviceAuthorizationEndpoint.INTERVAL + data['verification_uri'] = DeviceAuthorizationEndpoint.URL + return DeviceCredentialDict(data) + + def query_user_grant(self, user_code): + """ Check if user alredy authed and return it to token generator + + :param str user_code: user code + + :return: str, bool -- user dict and user auth status + """ + result = self.server.db.getSessionByUserCode(user_code) + if not result['OK']: + raise OAuth2Error('Cannot found authorization session', result['Message']) + data = result['Value'] + return (data['user_id'], True) if data.get('username') != "None" else None + + def should_slow_down(self, credential, now): + """ If need to slow down requests """ + return False + + +class SaveSessionToDB(object): + """ SaveSessionToDB extension to Device Code Grant. It is used to + seve authorization session of Device Code flow for public clients in MySQL database. + + Then register this extension via:: + + server.register_grant(DeviceCodeGrant, [SaveSessionToDB(db=self.db)]) + """ + def __init__(self, db): + self.db = db + + def __call__(self, grant): + grant.register_hook('after_validate_consent_request', self.save_session) + + def save_session(self, *args, **kwargs): + print('SAVE-SESSION') + print(args) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/ImplicitFlow.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/ImplicitFlow.py new file mode 100644 index 00000000000..160b41c4725 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/ImplicitFlow.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from time import time +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749.grants import AuthorizationEndpointMixin, ImplicitGrant as _ImplicitGrant +from authlib.oauth2.rfc6749.errors import AccessDeniedError +from authlib.oidc.core.grants import OpenIDImplicitGrant as _OpenIDImplicitGrant +from authlib.common.security import generate_token + +from DIRAC import gLogger + +log = gLogger.getSubLogger(__name__) + + +class NotebookImplicitGrant(_ImplicitGrant): + def create_authorization_response(self, redirect_uri, grant_user): + print('== NotebookImplicitGrant: user: %s' % grant_user) + state = self.request.state + if grant_user: + self.request.user = grant_user + # from pprint import pprint + # import inspect + # print("args:") + # pprint(inspect.getargspec(self.generate_token)) + # inspect.getfullargspec(a_method) + token = self.generate_token( # client=self.request.client, + grant_type=self.GRANT_TYPE, + user=grant_user, scope=self.request.scope, include_refresh_token=False) + return 200, token, [] + else: + raise AccessDeniedError(state=state, redirect_uri=redirect_uri, redirect_fragment=True) + # c, p, h = super(NotebookImplicitGrant, self).create_authorization_response(redirect_uri, grant_user) + # return 200, h[0][1], [] + + +class OpenIDImplicitGrant(_OpenIDImplicitGrant): + def validate_authorization_request(self): + redirect_uri = super(OpenIDImplicitGrant, self).validate_authorization_request() + session = self.request.state or generate_token(10) + self.server.updateSession(session, request=self.request, group=self.request.args.get('group')) + return redirect_uri + + def get_jwt_config(self): + with open('/opt/dirac/etc/grid-security/jwtRS256.key', 'rb') as f: + key = f.read() + issuer = self.server.metadata['issuer'] + return dict(key=key, alg='RS256', iss=issuer, exp=3600) + + def generate_user_info(self, user, scopes): + print('=== generate_user_info ===') + print(user) + print(scopes) + data = self.server.getSession(self.request.state) + # return UserInfo(sub=data['userID'], profile=data['profile'], grp=data['group']) + return dict(sub=data['userID'], profile=data['profile'], grp=data['group']) + + def exists_nonce(self, nonce, request): + # TODO: need to implement + return False diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py new file mode 100644 index 00000000000..20ed723628e --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/RefreshToken.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oauth2.rfc6749.grants import RefreshTokenGrant as _RefreshTokenGrant +from authlib.oauth2.base import OAuth2Error + +# from DIRAC.ConfigurationSystem.Client.Helpers import Registry +# from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager +# from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import BearerTokenValidator + + +class RefreshTokenGrant(_RefreshTokenGrant): + # def __init__(self, *args, **kwargs): + # super(RefreshTokenGrant, self).__init__(*args, **kwargs) + # self.validator = BearerTokenValidator() + + def authenticate_refresh_token(self, refresh_token): + """ Get credential for token + + :param str refresh_token: refresh token + + :return: object + """ + # Check auth session + result = self.server.db.getTokenByRefreshToken(refresh_token) + if not result['OK']: + raise OAuth2Error('Cannot get token', result['Message']) + return result['Value'] + # if not session: + # return None + + # # Check token + # token = self.validator(refresh_token, self.request.scope, self.request, 'OR') + + # # # To special flow to change group + # # if self.request.scope and 'changeGroup' in self.request.scope: + # # scopes = scope_to_list(self.request.scope) + # # reqGroups = [s.split(':')[1] for s in scopes if s.startswith('g:')] + # # if len(reqGroups) != 1 or not reqGroups[0]: + # # return None + # # group = reqGroups[0] + # # result = Registry.getUsernameForID(token['sub']) + # # if not result['OK']: + # # return None + # # result = gProxyManager.getGroupsStatusByUsername(result['Value'], group) + # # if not result['OK']: + # # return None + # # if result['Value'][group]['Status'] not in ['ready', 'unknown']: + # # return None + # return self.validator(refresh_token, self.request.scope, self.request, 'OR') + + def authenticate_user(self, credential): + """ Authorize user + + :param object credential: credential + + :return: str + """ + return credential.sub + + def revoke_old_credential(self, credential): + """ Remove old credential + + :param object credential: credential + """ + self.server.db.removeToken(credential['refresh_token']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/TokenExchange.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/TokenExchange.py new file mode 100644 index 00000000000..ea11cd11810 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/TokenExchange.py @@ -0,0 +1,346 @@ +""" + authlib.oauth2.rfc6749.grants.refresh_token + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + A special grant endpoint for refresh_token grant_type. Refreshing an + Access Token per `Section 6`_. + + .. _`Section 6`: https://tools.ietf.org/html/rfc6749#section-6 +""" + +import logging +from authlib.oauth2.rfc6749.grants.base import BaseGrant, TokenEndpointMixin +from authlib.oauth2.rfc6749.util import scope_to_list +from authlib.oauth2.rfc6749.errors import ( + InvalidRequestError, + InvalidScopeError, + InvalidGrantError, + UnauthorizedClientError, +) +log = logging.getLogger(__name__) + +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import BearerTokenValidator + + +class _TokenExchangeGrant(BaseGrant, TokenEndpointMixin): + """ A special grant endpoint for urn:ietf:params:oauth:grant-type:token-exchange grant_type. + Exchanging an Access Token per `Section 6`_. + + .. _`Section 6`: https://tools.ietf.org/html/rfc6749#section-6 + """ + GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:token-exchange' + + #: The authorization server MAY issue a new refresh token + INCLUDE_NEW_REFRESH_TOKEN = False + + def _validate_request_client(self): + # require client authentication for confidential clients or for any + # client that was issued client credentials (or with other + # authentication requirements) + client = self.authenticate_token_endpoint_client() + log.debug('Validate token request of %r', client) + + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError() + + return client + + def _validate_request_token(self, client): + subject_token = self.request.form.get('subject_token') + if subject_token is None: + raise InvalidRequestError('Missing "subject_token" in request.') + + subject_token_type = self.request.form.get('subject_token_type') + if subject_token_type is None: + raise InvalidRequestError('Missing "subject_token_type" in request.') + + actor_token = self.request.form.get('actor_token') + actor_token_type = self.request.form.get('actor_token_type') + if actor_token and actor_token_type is None: + raise InvalidRequestError('Missing "actor_token_type" in request.') + + token = self.authenticate_subject_token(subject_token, subject_token_type) + if not token or token.get_client_id() != client.get_client_id(): + raise InvalidGrantError() + return token + + def _validate_token_scope(self, token): + scope = self.request.scope + if not scope: + return + + original_scope = token.get_scope() + if not original_scope: + raise InvalidScopeError() + + def validate_token_request(self): + """ A client requests a security token by making a token request to the + authorization server's token endpoint using the extension grant type + mechanism defined in Section 4.5 of [RFC6749]. + + Client authentication to the authorization server is done using the + normal mechanisms provided by OAuth 2.0. Section 2.3.1 of [RFC6749] + defines password-based authentication of the client, however, client + authentication is extensible and other mechanisms are possible. For + example, [RFC7523] defines client authentication using bearer JSON + Web Tokens (JWTs) [JWT]. The supported methods of client + authentication and whether or not to allow unauthenticated or + unidentified clients are deployment decisions that are at the + discretion of the authorization server. Note that omitting client + authentication allows for a compromised token to be leveraged via an + STS into other tokens by anyone possessing the compromised token. + Thus, client authentication allows for additional authorization + checks by the STS as to which entities are permitted to impersonate + or receive delegations from other entities. + + The client makes a token exchange request to the token endpoint with + an extension grant type using the HTTP "POST" method. The following + parameters are included in the HTTP request entity-body using the + "application/x-www-form-urlencoded" format with a character encoding + of UTF-8 as described in Appendix B of [RFC6749], per Section 6: + + grant_type + REQUIRED. The value "urn:ietf:params:oauth:grant-type:token- + exchange" indicates that a token exchange is being performed. + + resource + OPTIONAL. A URI that indicates the target service or resource + where the client intends to use the requested security token. + This enables the authorization server to apply policy as + appropriate for the target, such as determining the type and + content of the token to be issued or if and how the token is to be + encrypted. In many cases, a client will not have knowledge of the + logical organization of the systems with which it interacts and + will only know a URI of the service where it intends to use the + token. The "resource" parameter allows the client to indicate to + the authorization server where it intends to use the issued token + by providing the location, typically as an https URL, in the token + exchange request in the same form that will be used to access that + resource. The authorization server will typically have the + capability to map from a resource URI value to an appropriate + policy. The value of the "resource" parameter MUST be an absolute + URI, as specified by Section 4.3 of [RFC3986], that MAY include a + query component and MUST NOT include a fragment component. + Multiple "resource" parameters may be used to indicate that the + issued token is intended to be used at the multiple resources + listed. See [OAUTH-RESOURCE] for additional background and uses + of the "resource" parameter. + + audience + OPTIONAL. The logical name of the target service where the client + intends to use the requested security token. This serves a + purpose similar to the "resource" parameter but with the client + providing a logical name for the target service. Interpretation + of the name requires that the value be something that both the + client and the authorization server understand. An OAuth client + identifier, a SAML entity identifier [OASIS.saml-core-2.0-os], and + an OpenID Connect Issuer Identifier [OpenID.Core] are examples of + things that might be used as "audience" parameter values. + However, "audience" values used with a given authorization server + must be unique within that server to ensure that they are properly + interpreted as the intended type of value. Multiple "audience" + parameters may be used to indicate that the issued token is + intended to be used at the multiple audiences listed. The + "audience" and "resource" parameters may be used together to + indicate multiple target services with a mix of logical names and + resource URIs. + + scope + OPTIONAL. A list of space-delimited, case-sensitive strings, as + defined in Section 3.3 of [RFC6749], that allow the client to + specify the desired scope of the requested security token in the + context of the service or resource where the token will be used. + The values and associated semantics of scope are service specific + and expected to be described in the relevant service + documentation. + + requested_token_type + OPTIONAL. An identifier, as described in Section 3, for the type + of the requested security token. If the requested type is + unspecified, the issued token type is at the discretion of the + authorization server and may be dictated by knowledge of the + requirements of the service or resource indicated by the + "resource" or "audience" parameter. + + subject_token + REQUIRED. A security token that represents the identity of the + party on behalf of whom the request is being made. Typically, the + subject of this token will be the subject of the security token + issued in response to the request. + + subject_token_type + REQUIRED. An identifier, as described in Section 3, that + indicates the type of the security token in the "subject_token" + parameter. + + actor_token + OPTIONAL. A security token that represents the identity of the + acting party. Typically, this will be the party that is + authorized to use the requested security token and act on behalf + of the subject. + + actor_token_type + An identifier, as described in Section 3, that indicates the type + of the security token in the "actor_token" parameter. This is + REQUIRED when the "actor_token" parameter is present in the + request but MUST NOT be included otherwise. + """ + client = self._validate_request_client() + self.request.client = client + token = self._validate_request_token(client) + self._validate_token_scope(token) + self.request.credential = token + + def create_token_response(self): + """If valid and authorized, the authorization server issues an access + token as described in Section 5.1. If the request failed + verification or is invalid, the authorization server returns an error + response as described in Section 5.2. + """ + credential = self.request.credential + user = self.authenticate_user(credential) + if not user: + raise InvalidRequestError('There is no "user" for this token.') + + client = self.request.client + token = self.issue_token(user, credential) + log.debug('Issue token %r to %r', token, client) + + self.request.user = user + self.save_token(token) + self.execute_hook('process_token', token=token) + self.revoke_old_credential(credential) + return 200, token, self.TOKEN_RESPONSE_HEADER + + def issue_token(self, user, credential): + expires_in = credential.get_expires_in() + scope = self.request.scope + if not scope: + scope = credential.get_scope() + + token = self.generate_token(user=user, expires_in=expires_in, scope=scope, + include_refresh_token=self.INCLUDE_NEW_REFRESH_TOKEN) + return token + + def authenticate_subject_token(self, subject_token, subject_token_type): + """Get token information with subject_token string. Developers MUST + implement this method in subclass:: + + def authenticate_subject_token(self, subject_token, subject_token_type): + item = Token.get(**{subject_token_type: subject_token) + if item and item.is_refresh_token_active(): + return item + + :param subject_token: The token issued to the client + :param subject_token_type: The type of the token issued to the client + :return: token + """ + raise NotImplementedError() + + def authenticate_user(self, credential): + """Authenticate the user related to this credential. Developers MUST + implement this method in subclass:: + + def authenticate_user(self, credential): + return User.query.get(credential.user_id) + + :param credential: Token object + :return: user + """ + raise NotImplementedError() + + def revoke_old_credential(self, credential): + """The authorization server MAY revoke the old refresh token after + issuing a new refresh token to the client. Developers MUST implement + this method in subclass:: + + def revoke_old_credential(self, credential): + credential.revoked = True + credential.save() + + :param credential: Token object + """ + raise NotImplementedError() + + +TOKEN_TYPE_IDENTIFIERS = [ + # Indicates that the token is an OAuth 2.0 access token issued by + # the given authorization server. + 'urn:ietf:params:oauth:token-type:access_token', + # Indicates that the token is an OAuth 2.0 refresh token issued by + # the given authorization server. + 'urn:ietf:params:oauth:token-type:refresh_token', + # Indicates that the token is an ID Token as defined in Section 2 of + # [OpenID.Core]. + 'urn:ietf:params:oauth:token-type:id_token', + # Indicates that the token is a base64url-encoded SAML 1.1 + # [OASIS.saml-core-1.1] assertion. + 'urn:ietf:params:oauth:token-type:saml1', + # Indicates that the token is a base64url-encoded SAML 2.0 + # [OASIS.saml-core-2.0-os] assertion. + 'urn:ietf:params:oauth:token-type:saml2' +] + + +class TokenExchangeGrant(_TokenExchangeGrant): + def __init__(self, *args, **kwargs): + super(TokenExchangeGrant, self).__init__(*args, **kwargs) + self.validator = BearerTokenValidator() + + def authenticate_subject_token(self, subject_token, subject_token_type): + """ Get credential for token + + :param str subject_token: subject_token + :param str subject_token_type: token type https://tools.ietf.org/html/rfc8693#section-3 + + :return: object + """ + if subject_token_type.split(':')[-1] != 'refresh_token': + raise InvalidRequestError('Please set refresh_token to "subject_token" in request.') + + ######################## TODO ################## + # Check token in DB + token = self.server.db.getToken(subject_token) + if not token: + return None + + # Check token + token = self.validator(subject_token, self.request.scope, self.request, 'OR') + # token = session.token + + # To special flow to change group + if not self.request.scope: + return token + + scopes = scope_to_list(self.request.scope) + reqGroups = [s.split(':')[1] for s in scopes if s.startswith('g:')] + if len(reqGroups) != 1 or not reqGroups[0]: + return None + group = reqGroups[0] + result = Registry.getUsernameForID(token['sub']) + if not result['OK']: + return None + result = gProxyManager.getGroupsStatusByUsername(result['Value'], [group]) + if not result['OK']: + return None + if result['Value'][group]['Status'] not in ['ready', 'unknown']: + return None + return token + + def authenticate_user(self, credential): + """ Authorize user + + :param object credential: credential + + :return: str + """ + return credential.sub + + def revoke_old_credential(self, credential): + """ Remove old credential + + :param object credential: credential + """ + self.server.removeSession(credential['refresh_token']) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py new file mode 100644 index 00000000000..f23b7a0db73 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/grants/__init__.py @@ -0,0 +1,14 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# from .AuthorizationCode import AuthorizationCodeGrant, OpenIDCode +# from .RefreshToken import RefreshTokenGrant +# from .DeviceFlow import DeviceCodeGrant, DeviceAuthorizationEndpoint +# from .ImplicitFlow import OpenIDImplicitGrant, NotebookImplicitGrant + +# __all__ = [ +# 'AuthorizationCodeGrant', 'RefreshTokenGrant', 'DeviceCodeGrant', +# 'OpenIDCode', 'DeviceAuthorizationEndpoint', 'OpenIDImplicitGrant', +# 'NotebookImplicitGrant' +# ] diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py new file mode 100644 index 00000000000..0672b44336b --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Clients.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import time + +from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin +from authlib.oauth2.rfc7591 import ClientRegistrationEndpoint as _ClientRegistrationEndpoint +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.common.security import generate_token + +from DIRAC.Core.Utilities import ThreadSafe +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthClients + +__RCSID__ = "$Id$" + +gCacheClient = ThreadSafe.Synchronizer() + + +class Client(OAuth2ClientMixin): + def __init__(self, params): + super(Client, self).__init__() + self.client_id = params['client_id'] + self.client_secret = params.get('client_secret', '') + self.client_id_issued_at = params['client_id_issued_at'] + self.client_secret_expires_at = params['client_secret_expires_at'] + if isinstance(params['client_metadata'], dict): + self._client_metadata = json.dumps(params['client_metadata']) + else: + self._client_metadata = params['client_metadata'] + + def get_allowed_scope(self, scope): + if not scope: + return '' + allowed = set(self.scope.split()) + scopes = scope_to_list(scope) + return list_to_scope([s for s in scopes if s in allowed or s.startswith('g:')]) + + +class ClientManager(object): + def __init__(self, database): + self.__db = database + self.__clients = DictCache() + + @gCacheClient + def addClient(self, data): + result = self.__db.addClient(data) + if result['OK']: + data = result['Value'] + self.__clients.add(data['client_id'], 24 * 3600, Client(data)) + return result + + @gCacheClient + def getClient(self, clientID): + print('getClient: %s ' % clientID) + client = self.__clients.get(clientID) + print(client) + if not client: + result = getAuthClients(clientID) + if not result['OK'] or not result['Value']: + result = self.__db.getClient(clientID) + print('getClient result: %s' % result) + if result['OK']: + cliDict = result['Value'] + cliDict['client_id_issued_at'] = cliDict.get('client_id_issued_at', int(time.time())) + cliDict['client_secret_expires_at'] = cliDict.get('client_secret_expires_at', 0) + client = Client(cliDict) + print('getClient client: %s' % str(client)) + self.__clients.add(clientID, 24 * 3600, client) + print('getClient client added') + print('finish: client: %s' % client) + return client + + +class ClientRegistrationEndpoint(_ClientRegistrationEndpoint): + """ The client registration endpoint is an OAuth 2.0 endpoint designed to + allow a client to be registered with the authorization server. See authlib + :mod:`ClientRegistrationEndpoint ` class. + """ + # TODO: align with last version authlib + + def authenticate_user(self, request): + return True + + def authenticate_token(self, request): + # TODO: Provider token verification to allow regster clients only for reg users + return False + + def save_client(self, client_info, client_metadata, request): + print("Save client:") + print(client_info) + print(client_metadata) + for k, v in [('grant_types', + ['authorization_code', 'urn:ietf:params:oauth:grant-type:device_code']), + ('response_types', ['code', 'device']), + ('token_endpoint_auth_method', 'none')]: + if k not in client_metadata: + client_metadata[k] = v + + if client_metadata['token_endpoint_auth_method'] == 'none': + client_info['client_secret'] = '' + # else: + # client_info['client_secret'] = generate_token(48) + + client_info['client_metadata'] = client_metadata + + print(client_info) + result = self.server.addClient(client_info) + return Client(result['Value']) if result['OK'] else None diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/ProfileParser.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/ProfileParser.py new file mode 100644 index 00000000000..5b12c6342ab --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/ProfileParser.py @@ -0,0 +1,227 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__RCSID__ = "$Id$" + +import re +import six + +from DIRAC import S_OK, S_ERROR, gLogger +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderByAlias +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForID, getVOMSRoleGroupMapping + + +""" +{ + : + { + ... + : { + FullName: ... + Provider: ... + DNs: { + : { + ProxyProvider: ..., + VOMSRoles: ..., + } + } + VOs: { + : { + ... + : { + ... + } + } + } + } + } +} +""" + + +def claimParser(claimDict, attributes): + """ Parse claims to write it as DIRAC profile + + :param dict claimDict: claims + :param dict attributes: contain claim and regex to parse it + :param dict profile: to fill parsed data + + :return: dict + """ + profile = {} + result = None + for claim, reg in attributes.items(): + if claim not in claimDict: + continue + profile[claim] = {} + if isinstance(claimDict[claim], dict): + result = claimParser(claimDict[claim], reg) + if result: + profile[claim] = result + elif isinstance(claimDict[claim], six.string_types): + result = re.compile(reg).match(claimDict[claim]) + if result: + for k, v in result.groupdict().items(): + profile[claim][k] = v + else: + profile[claim] = [] + for claimItem in claimDict[claim]: + if isinstance(reg, dict): + result = claimParser(claimItem, reg) + if result: + profile[claim].append(result) + else: + result = re.compile(reg).match(claimItem) + if result: + profile[claim].append(result.groupdict()) + + return profile + + +def parseBasic(claimDict): + """ Parse basic claims + + :param dict claimDict: claims + + :return: S_OK(dict)/S_ERROR() + """ + credDict = {} + credDict['ID'] = claimDict['sub'] + return credDict + + +def parseEduperson(claimDict): + """ Parse eduperson claims + + :param dict claimDict: claims + + :return: dict + """ + credDict = {} + attributes = { + 'eduperson_unique_id': '^(?P.*)', + 'eduperson_entitlement': '^(?P[A-z,.,_,-,:]+):(group:registry|group):\ + (?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*' + } + resDict = claimParser(claimDict, attributes) + if not resDict: + return credDict + credDict['ID'] = resDict['eduperson_unique_id']['ID'] + credDict['VOs'] = {} + for voDict in resDict['eduperson_entitlement']: + if voDict['VO'] not in credDict['VOs']: + credDict['VOs'][voDict['VO']] = {'VORoles': []} + if voDict['VORole'] not in credDict['VOs'][voDict['VO']]['VORoles']: + credDict['VOs'][voDict['VO']]['VORoles'].append(voDict['VORole']) + return credDict + + +def userDiscover(credDict): + result = getUsernameForID(credDict['ID']) + credDict['DIRACUsername'] = result['Value'] if result['OK'] else 'anonymous' + credDict['DIRACGroups'] = [] + for vo, voData in credDict.get('VOs', {}).items(): + result = getVOMSRoleGroupMapping(vo) + if result['OK']: + avilGroups = result['Value']['VOMSDIRAC'] + for role in voData['VORoles']: + groups = result['Value']['VOMSDIRAC'].get('/%s' % role) + if groups: + credDict['DIRACGroups'] = list(set(credDict['DIRACGroups'] + groups)) + return credDict + + +class ProfileParser(object): + def __init__(self, **parameters): + self.provider = parameters['ProviderName'] + self.user_id = None + self.username = None + self.profile = {self.provider: {}} + + def __call__(self, claimDict): + """ Parse claims + """ + self.parseBasic(claimDict) + self.parseEduperson(claimDict) + self.parseCertEntitlement(claimDict) + return S_OK((self.username, self.user_id, self.profile)) + + def parseBasic(self, claimDict): + """ Parse basic claims + + :param dict claimDict: claims + + :return: S_OK(dict)/S_ERROR() + """ + self.user_id = claimDict['sub'] + self.profile[self.provider][self.user_id] = {'DNs': {}, 'VOs': {}} + if claimDict.get('email'): + self.profile[self.provider][self.user_id]['Email'] = claimDict['email'] + gname = claimDict.get('given_name') + fname = claimDict.get('family_name') + pname = claimDict.get('preferred_username') + name = claimDict.get('name') and claimDict['name'].split(' ') + username = pname or gname and fname and gname[0] + fname + username = username or name and len(name) > 1 and name[0][0] + name[1] or '' + self.username = re.sub('[^A-Za-z0-9]+', '', username.lower())[:13] + fullname = gname and fname and ' '.join([gname, fname]) or name and ' '.join(name) or '' + self.profile[self.provider][self.user_id]['FullName'] = fullname + self.profile[self.provider][self.user_id]['Provider'] = self.provider + return self.profile + + def parseCertEntitlement(self, claimDict): + """ Parse cert_entitlement claim + + :param dict claimDict: claims + + :return: dict + """ + r = '^(?P[A-z,.,_,-,:]+):(group:registry|group):(?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*' + attributes = { + 'cert_entitlement': { + 'cert_iss': '(?P.*)', + 'cert_subject_dn': '(?P.*)', + 'eduperson_entitlement': r + } + } + result = claimParser(claimDict, attributes) + if result: + for data in result['cert_entitlement']: + dn = data['cert_subject_dn']['DN'] + # DIRAC understand only DNs with slashes + if not dn.startswith('/'): + gLogger.debug('Convert %s to view with slashes.' % dn) + items = dn.split(',') + items.reverse() + dn = '/' + '/'.join(items) + provider = data['cert_iss']['PROXYPROVIDER'] + if provider: + result = getProviderByAlias(provider, instance='Proxy') + provider = result['Value'] if result['OK'] else 'Certificate' + self.profile[self.provider][self.user_id]['DNs'][dn] = { + 'ProxyProvider': provider, + 'VOMSRoles': [data['eduperson_entitlement']['GROUP']] + } + return self.profile + + def parseEduperson(self, claimDict): + """ Parse eduperson claims + + :param dict claimDict: claims + + :return: dict + """ + attributes = { + 'eduperson_unique_id': '^(?P.*)', + 'eduperson_entitlement': '^(?P[A-z,.,_,-,:]+):(group:registry|group):\ + (?P[A-z,.,_,-]+):role=(?P[A-z,.,_,-]+)[:#].*' + } + resDict = claimParser(claimDict, attributes) + if not resDict: + return self.profile + + self.user_id = resDict['eduperson_unique_id']['ID'] + for voDict in resDict['eduperson_entitlement']: + self.profile[self.provider][self.user_id]['VOs'][voDict['VO']] = {voDict['VOMSRole']: {}} + return self.profile diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py new file mode 100644 index 00000000000..2f6e7aa3ccc --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Requests.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tornado.escape import json_decode +from authlib.oauth2 import OAuth2Request as _OAuth2Request +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope + +__RCSID__ = "$Id$" + + +class OAuth2Request(_OAuth2Request): + """ OAuth request object """ + + def addScopes(self, scopes): + """ Add new scopes to query + + :param list scopes: scopes + """ + # Remove "scope" argument from uri + self.uri = re.sub(r"&scope(=[^&]*)?|^scope(=[^&]*)?&?", "", self.uri) + # Add "scope" argument to uri with new scopes + self.uri += "&scope=%s" % list_to_scope(list(set(scope_to_list(self.scope) + scopes))) or '' + # Reinit all attributes with new uri + self.__init__(self.method, self.uri) + + @property + def groups(self): + """ Serarch DIRAC groups in scopes + + :return: list + """ + return [s.split(':')[1] for s in scope_to_list(self.scope) if s.startswith('g:')] + + def toDict(self): + """ Convert class to dictionary + + :return: dict + """ + return {'method': self.method, 'uri': self.uri} + + +def createOAuth2Request(request, method_cls=OAuth2Request, use_json=False): + """ Create request object + + :param request: request + :type request: object, dict + :param object method_cls: returned class + :param str use_json: if data is json + + :return: object -- `OAuth2Request` + """ + if isinstance(request, method_cls): + return request + if isinstance(request, dict): + return method_cls(request['method'], request['uri'], request.get('body'), request.get('headers')) + if use_json: + body = json_decode(request.body) + else: + body = {} + for k, v in request.body_arguments.items(): + body[k] = ' '.join(v) + return method_cls(request.method, request.full_url(), body, request.headers) diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Sessions.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Sessions.py new file mode 100644 index 00000000000..6d36bfaf716 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Sessions.py @@ -0,0 +1,186 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from time import time +from pprint import pprint + +from DIRAC import gLogger +from DIRAC.Core.Utilities import ThreadSafe +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token + +__RCSID__ = "$Id$" + +gCacheSession = ThreadSafe.Synchronizer() + + +class Session(dict): + """ A dict instance to represent a authentication session object. + + :param session: + :type session: str or dict + """ + + def __init__(self, session, data=None, exp=300): + if isinstance(session, Session): + session = dict(session) + + data = data or {} + if isinstance(session, dict): + session.update(data) + data = session + else: + data['id'] = session + if not data.get('id'): + raise KeyError('Missing "id" for a session.') + if not data.get('expires_at'): + data['expires_at'] = int(time()) + exp + if not data.get('created'): + data['created'] = int(time()) + super(Session, self).__init__(**data) + self.id = data['id'] + self.created = self['created'] + + @property + def status(self): + """ Session status + + :return: int + """ + return self.get('Status', 'submited') + + @property + def age(self): + """ Session age + + :return: int + """ + return int(time()) - self.created + + @property + def token(self): + """ Tokens + + :return: object + """ + return self.get('token') and OAuth2Token(self['token']) + + def update(self, data=None, **kwargs): + """ Update session + + :param dict data: dictionary with new values + + :return: object + """ + kwargs.update(data or {}) + super(Session, self).update(kwargs) + print('updated done') + return self + + +class SessionManager(object): + """ Authentication sessions cache manager """ + + def __init__(self, database, addTime=300, maxAge=3600 * 12): + """ Con'r + + :param int addTime: additional time added to session life + :param int maxAge: max session age + """ + # self.__sessions = DictCache() + self.__db = database + self.__addTime = addTime + self.__maxAge = maxAge + + # @gCacheSession + def addSession(self, session, exp=None, **kwargs): + """ Add session to cache + + :param session: session + :type session: str, dict or Session object + :param int exp: expired time + """ + # print('-- addSession') + # pprint(session) + # exp = min(exp or self.__addTime, self.__maxAge) + # session = Session(session, data=kwargs, exp=exp) + # pprint(session) + + # if session.age > self.__maxAge: + # return self.__sessions.delete(session.id) + # print('ADD SESSION: %s' % session.id) + return self.__db.addSession(dict(session)) + # return self.__sessions.add(session.id, exp, session) + + # @gCacheSession + def getSession(self, session): + """ Get session from cache + + :param session: session + :type session: str, Session object + + :return: Session object + """ + print('-- getSession') + pprint(session) + return self.__db.getSession(session) + # return self.__sessions.get(session.id if isinstance(session, Session) else session) + + # # @gCacheSession + # def getSessions(self): + # """ Get all sessions from cache + + # :return: dict + # """ + # return self.__sessions.getDict() + + # @gCacheSession + def removeSession(self, session): + """ Remove session from cache + + :param session: session + :type session: str, Session object + """ + print('-- removeSession') + pprint(session) + return self.__db.removeSession(session) + # self.__sessions.delete(session.id if isinstance(session, Session) else session) + + def updateSession(self, session, exp=None, createIfNotExist=None, **kwargs): + """ Update session in cache + + :param session: session + :type session: str, Session object + :param int exp: expiration time + """ + print('-- updateSession') + pprint(session) + sessionID = session.id if isinstance(session, Session) else session + session = self.getSession(sessionID) + pprint(session) + exp = exp or self.__addTime + if session and session.age < self.__maxAge: + if (session.age + exp) > self.__maxAge: + exp = self.__maxAge - session.age + if exp: + print('UPDATE SESSION: %s' % session.id) + self.addSession(session.update(kwargs), exp) + elif createIfNotExist: + print('UPDATE hard SESSION: %s' % sessionID) + self.addSession(sessionID, exp, **kwargs) + + def getSessionByOption(self, key, value): + """ Search session by the option + + :param str key: option name + :param str value: option value + + :return: str, Session + """ + if key and value: + sessions = self.getSessions() + for session, data in sessions.items(): + if data.get(key) == value: + return session, data + return None, None diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py new file mode 100644 index 00000000000..3e1e289739a --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/Tokens.py @@ -0,0 +1,128 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from time import time +import functools +from contextlib import contextmanager + +from authlib.jose import jwt +from authlib.oauth2 import OAuth2Error, ResourceProtector as _ResourceProtector +from authlib.oauth2.rfc6749 import MissingAuthorizationError, HttpRequest +from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator +from authlib.oauth2.rfc6749.wrappers import OAuth2Token as _OAuth2Token +from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope + + +class OAuth2Token(_OAuth2Token, OAuth2TokenMixin): + """ Implementation a Token object """ + + def __init__(self, params=None, **kwargs): + kwargs.update(params or {}) + self.sub = kwargs.get('sub') + self.issuer = kwargs.get('iss') + self.client_id = kwargs.get('client_id', kwargs.get('aud')) + self.token_type = kwargs.get('token_type') + self.access_token = kwargs.get('access_token') + self.refresh_token = kwargs.get('refresh_token') + self.scope = kwargs.get('scope') + self.revoked = kwargs.get('revoked') + self.issued_at = int(kwargs.get('issued_at', kwargs.get('iat', time()))) + self.expires_in = int(kwargs.get('expires_in', 0)) + self.expires_at = int(kwargs.get('expires_at', kwargs.get('exp', 0))) + if not self.issued_at: + raise Exception('Missing "iat" in token.') + if not self.expires_at: + if not self.expires_in: + raise Exception('Cannot calculate token "expires_at".') + self.expires_at = self.issued_at + self.expires_in + if not self.expires_in: + self.expires_in = self.expires_at - self.issued_at + kwargs.update({'client_id': self.client_id, + 'token_type': self.token_type, + 'access_token': self.access_token, + 'refresh_token': self.refresh_token, + 'scope': self.scope, + 'revoked': self.revoked, + 'issued_at': self.issued_at, + 'expires_in': self.expires_in, + 'expires_at': self.expires_at}) + super(OAuth2Token, self).__init__(kwargs) + + @property + def scopes(self): + """ Get tokens scopes + + :return: list + """ + return scope_to_list(self.scope) or [] + + @property + def groups(self): + """ Get tokens groups + + :return: list + """ + return [s.split(':')[1] for s in self.scopes if s.startswith('g:')] + + +class ResourceProtector(_ResourceProtector): + """ A protecting method for resource servers. """ + + def __init__(self): + self.validator = BearerTokenValidator() + self._token_validators = {self.validator.TOKEN_TYPE: self.validator} + + def acquire_token(self, request, scope=None, operator='AND'): + """ A method to acquire current valid token with the given scope. + + :param request: Tornado HTTP request instance + :param scope: string or list of scope values + :param operator: value of "AND" or "OR" + + :return: token object + """ + req = HttpRequest(request.method, request.uri, request.body, request.headers) + return self.validate_request(scope, req, operator if callable(operator) else operator.upper()) + + +class BearerTokenValidator(_BearerTokenValidator): + """ Token validator """ + + def authenticate_token(self, token): + """ A method to query token from database with the given token string. + + :param str token: A string to represent the access_token. + + :return: token + """ + # Read public key of DIRAC auth service + with open('/opt/dirac/etc/grid-security/jwtRS256.key.pub', 'rb') as f: + key = f.read() + + # Get claims and verify signature + claims = jwt.decode(token, key) + + # Verify token + claims.validate() + + return OAuth2Token(claims, access_token=token) + + def request_invalid(self, request): + """ Request validation + + :param object request: request + + :return: bool + """ + return False + + def token_revoked(self, token): + """ If token can be revoked + + :param object token: token + + :return: bool + """ + return token.revoked diff --git a/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py b/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py new file mode 100644 index 00000000000..02d96d1b436 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/authorization/utils/__init__.py @@ -0,0 +1,14 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# from .Clients import Client, ClientRegistrationEndpoint, ClientManager +# from .Sessions import SessionManager, Session +# from .Requests import OAuth2Request, createOAuth2Request +# from .Tokens import ResourceProtector, OAuth2Token, BearerTokenValidator + +# __all__ = [ +# 'Client', 'ClientRegistrationEndpoint', 'SessionManager', 'ClientManager', +# 'Session', 'OAuth2Request', 'createOAuth2Request', 'ResourceProtector', 'OAuth2Token', +# 'BearerTokenValidator' +# ] diff --git a/src/DIRAC/FrameworkSystem/private/testNotebookAuth.py b/src/DIRAC/FrameworkSystem/private/testNotebookAuth.py new file mode 100644 index 00000000000..2424896688e --- /dev/null +++ b/src/DIRAC/FrameworkSystem/private/testNotebookAuth.py @@ -0,0 +1,132 @@ +import os +import stat +import requests +import urllib3 + +from DIRAC import gConfig, gLogger, S_OK, S_ERROR +from DIRAC.Core.Utilities.JEncode import decode, encode +from DIRAC.ConfigurationSystem.Client.Helpers import Registry + + +class notebookAuth(object): + """ The main goal of this class provide authentication with access token + """ + + def __init__(self, group, lifetime=3600 * 12, voms=False, aToken=None, proxyPath=None): + """ C'r + + :param str group: requested group + :param int lifetime: requested proxy lifetime + :param bool voms: requested voms extension + :param str aToken: access token or path + :param str proxyPath: proxy path + """ + self.log = gLogger.getSubLogger(__name__) + # Defaulf location for proxy is /tmp/x509up_uXXXX + self.pPath = proxyPath or '/tmp/x509up_u%s' % os.getuid() + self.group = group + self.lifetime = lifetime + self.voms = voms + # Default access token path for notebook: /var/run/secrets/egi.eu/access_token + self.accessToken = aToken or '/var/run/secrets/egi.eu/access_token' + # Load client metadata + result = gConfig.getOptionsDict("/LocalInstallation/AuthorizationClient") + if not result['OK']: + raise Exception("Can't load client settings.") + self.metadata = result['Value'] + # For this open client we don't verify ssl certs + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + def getToken(self): + """ Get access_token + + :return: S_OK(dict)/S_ERROR() + """ + # Read noteboot access token + if self.accessToken.startswith('/'): + with open(self.accessToken, 'rb') as f: + self.accessToken = f.read() + + # Fill authorization URL + url = '%s/authorization' % self.metadata['issuer'] + url += '?client_id=%s' % self.metadata['client_id'] + url += '&redirect_uri=%s' % self.metadata['redirect_uri'] + url += '&response_type=%s' % self.metadata['response_type'] + if self.group: + url += '&scope=g:%s' % self.group + # For this version of code we use only CheckIn provider + url += '&provider=CheckIn&access_token=%s' % self.accessToken + try: + r = requests.get(url, verify=False) + r.raise_for_status() + return S_OK(r.json()) + except requests.exceptions.Timeout: + return S_ERROR('Authentication server is not answer.') + except requests.exceptions.RequestException as ex: + return S_ERROR(r.content or ex) + except Exception as ex: + return S_ERROR('Cannot read response: %s' % ex) + + def getProxyWithToken(self, token): + """ Get proxy with token + + :param str token: access token + + :return: S_OK()/S_ERROR() + """ + # Get REST endpoints from local CS + confUrl = gConfig.getValue("/LocalInstallation/ConfigurationServerAPI") + if not confUrl: + return S_ERROR('Could not get configuration server API URL.') + setup = gConfig.getValue("/DIRAC/Setup") + if not setup: + return S_ERROR('Could not get setup name.') + + # Get REST endpoints from ConfigurationService + try: + r = requests.get('%s/option?path=/Systems/Framework/Production/URLs/ProxyAPI' % confUrl, verify=False) + r.raise_for_status() + proxyAPI = r.text + # proxyAPI = decode(r.text)[0] + except requests.exceptions.Timeout: + return S_ERROR('Time out') + except requests.exceptions.RequestException as e: + return S_ERROR(str(e)) + except Exception as e: + return S_ERROR('Cannot read response: %s' % e) + + # Fill the proxy request URL + # url = '%ss:%s/g:%s/proxy?lifetime=%s' % (proxyAPI, setup, self.group, self.lifetime) + url = '%sproxy?lifetime=%s' % (proxyAPI, self.lifetime) + voms = self.voms or Registry.getGroupOption(self.group, "AutoAddVOMS", False) + if voms: + url += '&voms=%s' % voms + + # Get proxy from REST API + try: + r = requests.get(url, headers={'Authorization': 'Bearer ' + token}, verify=False) + r.raise_for_status() + proxy = r.text + # proxy = decode(r.text)[0] + except requests.exceptions.Timeout: + return S_ERROR('Time out') + except requests.exceptions.RequestException as e: + return S_ERROR(str(e)) + except Exception as e: + return S_ERROR('Cannot read response: %s' % e) + + if not proxy: + return S_ERROR("Result is empty.") + + self.log.notice('Saving proxy.. to %s..' % self.pPath) + + # Save proxy to file + try: + with open(self.pPath, 'w+') as fd: + fd.write(proxy.encode("UTF-8")) + os.chmod(self.pPath, stat.S_IRUSR | stat.S_IWUSR) + except Exception as e: + return S_ERROR("%s :%s" % (self.pPath, repr(e).replace(',)', ')'))) + + self.log.notice('Proxy is saved to %s.' % self.pPath) + return S_OK() diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_admin_get_proxy.py b/src/DIRAC/FrameworkSystem/scripts/dirac_admin_get_proxy.py index fe26f14d58a..10531c9b1ca 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_admin_get_proxy.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_admin_get_proxy.py @@ -42,7 +42,6 @@ class Params(object): proxyPath = False proxyLifeTime = 86400 enableVOMS = False - vomsAttr = None def setLimited(self, args): """ Set limited @@ -89,15 +88,6 @@ def automaticVOMS(self, arg): self.enableVOMS = True return S_OK() - def setVOMSAttr(self, arg): - """ Register CLI switches - - :param str arg: VOMS attribute - """ - self.enableVOMS = True - self.vomsAttr = arg - return S_OK() - def registerCLISwitches(self): """ Register CLI switches """ @@ -105,7 +95,6 @@ def registerCLISwitches(self): Script.registerSwitch("l", "limited", "Get a limited proxy", self.setLimited) Script.registerSwitch("u:", "out=", "File to write as proxy", self.setProxyLocation) Script.registerSwitch("a", "voms", "Get proxy with VOMS extension mapped to the DIRAC group", self.automaticVOMS) - Script.registerSwitch("m:", "vomsAttr=", "VOMS attribute to require", self.setVOMSAttr) @DIRACScript() @@ -120,45 +109,27 @@ def main(): Script.showHelp() userGroup = str(args[1]) - userDN = str(args[0]) - userName = False - if userDN.find("/") != 0: - userName = userDN - retVal = Registry.getDNForUsername(userName) - if not retVal['OK']: - gLogger.notice("Cannot discover DN for username %s\n\t%s" % (userName, retVal['Message'])) + + # First argument is user name + if not str(args[0]).startswith("/"): + userName = str(args[0]) + userDN = None + else: + userDN = str(args[0]) + result = Registry.getUsernameForDN(userDN) + if not result['OK']: + gLogger.notice("DN '%s' is not registered in DIRAC" % userDN) DIRAC.exit(2) - DNList = retVal['Value'] - if len(DNList) > 1: - gLogger.notice("Username %s has more than one DN registered" % userName) - ind = 0 - for dn in DNList: - gLogger.notice("%d %s" % (ind, dn)) - ind += 1 - inp = six.moves.input("Which DN do you want to download? [default 0] ") - if not inp: - inp = 0 - else: - inp = int(inp) - userDN = DNList[inp] - else: - userDN = DNList[0] + userName = result['Value'] if not params.proxyPath: - if not userName: - result = Registry.getUsernameForDN(userDN) - if not result['OK']: - gLogger.notice("DN '%s' is not registered in DIRAC" % userDN) - DIRAC.exit(2) - userName = result['Value'] params.proxyPath = "%s/proxy.%s.%s" % (os.getcwd(), userName, userGroup) if params.enableVOMS: - result = gProxyManager.downloadVOMSProxy(userDN, userGroup, limited=params.limited, - requiredTimeLeft=params.proxyLifeTime, - requiredVOMSAttribute=params.vomsAttr) + result = gProxyManager.downloadVOMSProxy(userDN or userName, userGroup, limited=params.limited, + requiredTimeLeft=params.proxyLifeTime) else: - result = gProxyManager.downloadProxy(userDN, userGroup, limited=params.limited, + result = gProxyManager.downloadProxy(userDN or userName, userGroup, limited=params.limited, requiredTimeLeft=params.proxyLifeTime) if not result['OK']: gLogger.notice('Proxy file cannot be retrieved: %s' % result['Message']) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_admin_users_with_proxy.py b/src/DIRAC/FrameworkSystem/scripts/dirac_admin_users_with_proxy.py index 45c1ed6f10c..e32f10d530b 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_admin_users_with_proxy.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_admin_users_with_proxy.py @@ -58,35 +58,32 @@ def main(): params.registerCLISwitches() Script.parseCommandLine(ignoreErrors=True) args = Script.getPositionalArgs() - result = gProxyManager.getDBContents() + result = gProxyManager.getUploadedProxiesDetails() if not result['OK']: print("Can't retrieve list of users: %s" % result['Message']) DIRAC.exit(1) - keys = result['Value']['ParameterNames'] - records = result['Value']['Records'] dataDict = {} - now = Time.dateTime() - for record in records: - expirationDate = record[3] - dt = expirationDate - now + for infoDict in result['Value']['Dictionaries']: + user = infoDict['user'] + del infoDict['user'] + dt = infoDict['expirationtime'] - Time.dateTime() secsLeft = dt.days * 86400 + dt.seconds if secsLeft > params.proxyLifeTime: - userName, userDN, userGroup, _, persistent = record - if userName not in dataDict: - dataDict[userName] = [] - dataDict[userName].append((userDN, userGroup, expirationDate, persistent)) - - for userName in dataDict: - print("* %s" % userName) - for iP in range(len(dataDict[userName])): - data = dataDict[userName][iP] - print(" DN : %s" % data[0]) - print(" group : %s" % data[1]) - print(" not after : %s" % Time.toString(data[2])) - print(" persistent : %s" % data[3]) - if iP < len(dataDict[userName]) - 1: - print(" -") + infoDict['expirationtime'] = Time.toString(infoDict['expirationtime']) + if user not in dataDict: + dataDict[user] = [] + dataDict[user].append(infoDict) + + keys = result['Value']['Dictionaries'][0].keys() if result['Value']['Dictionaries'] else [''] + strFormat = "{{:<{}}}".format(max(len(i) for i in keys)) + + for user, userDicts in dataDict.items(): + print("* %s" % user) + for userDict in userDicts: + for k, v in userDict.items(): + print(" %s : %s" % (strFormat.format(k), ','.join(v) if isinstance(v, (list, tuple)) else v)) + print(" -") DIRAC.exit(0) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_notebook_proxy_init.py b/src/DIRAC/FrameworkSystem/scripts/dirac_notebook_proxy_init.py new file mode 100644 index 00000000000..15f9d8639b1 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_notebook_proxy_init.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python + +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function + +import os +import sys +import glob +import time +import threading + +import DIRAC +from DIRAC import gConfig, gLogger, S_OK, S_ERROR +from DIRAC.Core.Base import Script +from DIRAC.Core.Security import ProxyInfo # pylint: disable=import-error +from DIRAC.Core.Utilities.DIRACScript import DIRACScript +from DIRAC.FrameworkSystem.Client import ProxyGeneration +from DIRAC.FrameworkSystem.Client.BundleDeliveryClient import BundleDeliveryClient +from DIRAC.FrameworkSystem.private.testNotebookAuth import notebookAuth + +__RCSID__ = "$Id$" + + +class Params(ProxyGeneration.CLIParams): + + addVOMSExt = False + + def setVOMSExt(self, _arg): + """ Set VOMS extention + + :param _arg: unuse + + :return: S_OK() + """ + self.addVOMSExt = True + return S_OK() + + def registerCLISwitches(self): + """ Register CLI switches """ + ProxyGeneration.CLIParams.registerCLISwitches(self) + Script.registerSwitch("M", "VOMS", "Add voms extension", self.setVOMSExt) + + +class ProxyInit(object): + + def __init__(self, piParams): + """ Constructor """ + self.__piParams = piParams + self.__issuerCert = False + self.__proxyGenerated = False + self.__uploadedInfo = {} + + def printInfo(self): + """ Printing utilities + """ + resultProxyInfoAsAString = ProxyInfo.getProxyInfoAsString(self.__proxyGenerated) + if not resultProxyInfoAsAString['OK']: + gLogger.error('Failed to get the new proxy info: %s' % resultProxyInfoAsAString['Message']) + else: + gLogger.notice("Proxy generated:") + gLogger.notice(resultProxyInfoAsAString['Value']) + if self.__uploadedInfo: + gLogger.notice("\nProxies uploaded:") + maxDNLen = 0 + maxProviderLen = len('ProxyProvider') + for userDN, data in self.__uploadedInfo.items(): + maxDNLen = max(maxDNLen, len(userDN)) + maxProviderLen = max(maxProviderLen, len(data['provider'])) + gLogger.notice(" %s | %s | %s | SupportedGroups" % ("DN".ljust(maxDNLen), "ProxyProvider".ljust(maxProviderLen), + "Until (GMT)".ljust(16))) + for userDN, data in self.__uploadedInfo.items(): + gLogger.notice(" %s | %s | %s | " % (userDN.ljust(maxDNLen), data['provider'].ljust(maxProviderLen), + data['expirationtime'].strftime("%Y/%m/%d %H:%M").ljust(16)), + ",".join(data['groups'])) + + def checkCAs(self): + """ Check CAs + + :return: S_OK() + """ + if "X509_CERT_DIR" not in os.environ: + gLogger.warn("X509_CERT_DIR is unset. Abort check of CAs") + return + caDir = os.environ["X509_CERT_DIR"] + # In globus standards .r0 files are CRLs. They have the same names of the CAs but diffent file extension + searchExp = os.path.join(caDir, "*.r0") + crlList = glob.glob(searchExp) + if not crlList: + gLogger.warn("No CRL files found for %s. Abort check of CAs" % searchExp) + return + newestFPath = max(crlList, key=os.path.getmtime) + newestFTime = os.path.getmtime(newestFPath) + if newestFTime > (time.time() - (2 * 24 * 3600)): + # At least one of the files has been updated in the last 2 days + return S_OK() + if not os.access(caDir, os.W_OK): + gLogger.error("Your CRLs appear to be outdated, but you have no access to update them.") + # Try to continue anyway... + return S_OK() + # Update the CAs & CRLs + gLogger.notice("Your CRLs appear to be outdated; attempting to update them...") + bdc = BundleDeliveryClient() + res = bdc.syncCAs() + if not res['OK']: + gLogger.error("Failed to update CAs", res['Message']) + res = bdc.syncCRLs() + if not res['OK']: + gLogger.error("Failed to update CRLs", res['Message']) + # Continue even if the update failed... + return S_OK() + + def doOAuthMagic(self): + """ Magic method + + :return: S_OK()/S_ERROR() + """ + if not self.__piParams.diracGroup: + return S_ERROR('Need to set user group.') + nAuth = notebookAuth( + self.__piParams.diracGroup, + voms=self.__piParams.addVOMSExt, + proxyPath=self.__piParams.proxyLoc) + result = nAuth.getToken() + if not result['OK']: + return result + aToken = result['Value'].get('access_token') + if not aToken: + return S_ERROR('Access token is absent in resporse.') + result = nAuth.getProxyWithToken(aToken) + if not result['OK']: + return result + + result = Script.enableCS() + if not result['OK']: + return S_ERROR("Cannot contact CS to get user list") + threading.Thread(target=self.checkCAs).start() + gConfig.forceRefresh(fromMaster=True) + return S_OK(self.__piParams.proxyLoc) + + +@DIRACScript() +def main(): + piParams = Params() + piParams.registerCLISwitches() + + Script.disableCS() + Script.parseCommandLine(ignoreErrors=True) + DIRAC.gConfig.setOptionValue("/DIRAC/Security/UseServerCertificate", "False") + + pI = ProxyInit(piParams) + gLogger.info(gConfig.getConfigurationTree()) + resultDoMagic = pI.doOAuthMagic() + if not resultDoMagic['OK']: + gLogger.fatal(resultDoMagic['Message']) + sys.exit(1) + + pI.printInfo() + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_destroy.py b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_destroy.py index 54e1710b354..2e04627fcc1 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_destroy.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_destroy.py @@ -14,14 +14,13 @@ import os import DIRAC + from DIRAC import gLogger, S_OK from DIRAC.Core.Base import Script -from DIRAC.Core.Utilities.DIRACScript import DIRACScript - from DIRAC.Core.Security import Locations, ProxyInfo -from DIRAC.Core.DISET.RPCClient import RPCClient -from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager +from DIRAC.Core.Utilities.DIRACScript import DIRACScript from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager class Params(object): @@ -79,7 +78,7 @@ def getProxyGroups(): user_groups = set() for dn in proxies['Value']: - dn_groups = set(proxies['Value'][dn].keys()) + dn_groups = set(proxies['Value'][dn]['groups']) user_groups.update(dn_groups) return user_groups @@ -97,18 +96,15 @@ def mapVoToGroups(voname): return set(vo_dict['Value']) -def deleteRemoteProxy(userdn, vogroup): +def deleteRemoteProxy(userdn): """ - Deletes proxy for a vogroup for the user envoking this function. + Deletes proxy for all groups for the user envoking this function. Returns a list of all deleted proxies (if any). """ - rpcClient = RPCClient("Framework/ProxyManager") - retVal = rpcClient.deleteProxyBundle([(userdn, vogroup)]) - - if retVal['OK']: - gLogger.notice('Deleted proxy for %s.' % vogroup) + if gProxyManager.deleteProxy(userdn)['OK']: + gLogger.notice('Deleted proxy %s.' % userdn) else: - gLogger.error('Failed to delete proxy for %s.' % vogroup) + gLogger.error('Failed to delete proxy %s.' % userdn) def deleteLocalProxy(proxyLoc): @@ -161,7 +157,7 @@ def run(): if not remote_groups: gLogger.notice('No remote proxies found.') for vo_group in remote_groups: - deleteRemoteProxy(userDN, vo_group) + deleteRemoteProxy(userDN) # delete local proxy deleteLocalProxy(proxyLoc) elif options.vos: @@ -174,7 +170,7 @@ def run(): if not vo_groups: gLogger.notice('You have no proxies registered for any of the specified VOs.') for group in vo_groups: - deleteRemoteProxy(userDN, group) + deleteRemoteProxy(userDN) else: deleteLocalProxy(proxyLoc) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_get_uploaded_info.py b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_get_uploaded_info.py index e035b4574ee..f435409ac92 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_get_uploaded_info.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_get_uploaded_info.py @@ -23,6 +23,8 @@ import sys +import DIRAC + from DIRAC import gLogger, S_OK from DIRAC.Core.Base import Script from DIRAC.Core.Utilities.DIRACScript import DIRACScript @@ -64,25 +66,18 @@ def main(): gLogger.notice("Your proxy don`t have username extension") sys.exit(1) - if userName in Registry.getAllUsers(): - if Properties.PROXY_MANAGEMENT not in proxyProps['groupProperties']: - if userName != proxyProps['username'] and userName != proxyProps['issuer']: - gLogger.notice("You can only query info about yourself!") - sys.exit(1) - result = Registry.getDNForUsername(userName) - if not result['OK']: - gLogger.notice("Oops %s" % result['Message']) - dnList = result['Value'] - if not dnList: - gLogger.notice("User %s has no DN defined!" % userName) + if userName not in Registry.getAllUsers(): + gLogger.notice("%s user is not found.") + sys.exit(1) + + if Properties.PROXY_MANAGEMENT not in proxyProps['groupProperties']: + if userName != proxyProps['username'] and userName != proxyProps['issuer']: + gLogger.notice("You can only query info about yourself!") sys.exit(1) - userDNs = dnList - else: - userDNs = [userName] - gLogger.notice("Checking for DNs %s" % " | ".join(userDNs)) + gLogger.notice("Checking for user", userName) pmc = ProxyManagerClient() - result = pmc.getDBContents({'UserDN': userDNs}) + result = pmc.getUploadedProxiesDetails(userName) if not result['OK']: gLogger.notice("Could not retrieve the proxy list: %s" % result['Value']) sys.exit(1) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_info.py b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_info.py index 701f56cd4d6..166712212d5 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_info.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_info.py @@ -146,17 +146,16 @@ def invalidProxy(msg): if uploadedInfo: gLogger.notice("== Proxies uploaded ==") maxDNLen = 0 - maxGroupLen = 0 - for userDN in uploadedInfo: + maxProviderLen = len('ProxyProvider') + for userDN, data in uploadedInfo.items(): maxDNLen = max(maxDNLen, len(userDN)) - for group in uploadedInfo[userDN]: - maxGroupLen = max(maxGroupLen, len(group)) - gLogger.notice(" %s | %s | Until (GMT)" % ("DN".ljust(maxDNLen), "Group".ljust(maxGroupLen))) - for userDN in uploadedInfo: - for group in uploadedInfo[userDN]: - gLogger.notice(" %s | %s | %s" % (userDN.ljust(maxDNLen), - group.ljust(maxGroupLen), - uploadedInfo[userDN][group].strftime("%Y/%m/%d %H:%M"))) + maxProviderLen = max(maxProviderLen, len(data['provider'])) + gLogger.notice(" %s | %s | %s | SupportedGroups" % ("DN".ljust(maxDNLen), "ProxyProvider".ljust(maxProviderLen), + "Until (GMT)".ljust(16))) + for userDN, data in uploadedInfo.items(): + gLogger.notice(" %s | %s | %s | " % (userDN.ljust(maxDNLen), data['provider'].ljust(maxProviderLen), + data['expirationtime'].strftime("%Y/%m/%d %H:%M").ljust(16)), + ",".join(data['groups'])) if params.checkValid: if infoDict['secondsLeft'] == 0: diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py index 6bf4b726294..7265f4b8ba1 100755 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_proxy_init.py @@ -16,39 +16,89 @@ import os import sys +import stat import glob import time +import pickle import datetime import DIRAC - -from DIRAC import gLogger, S_OK, S_ERROR +from DIRAC import gConfig, gLogger, S_OK, S_ERROR from DIRAC.Core.Base import Script +from DIRAC.Core.Security import X509Chain, ProxyInfo, Properties, VOMS # pylint: disable=import-error from DIRAC.Core.Utilities.DIRACScript import DIRACScript -from DIRAC.FrameworkSystem.Client import ProxyGeneration, ProxyUpload -from DIRAC.Core.Security import X509Chain, ProxyInfo, Properties, VOMS from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.FrameworkSystem.Client import ProxyGeneration, ProxyUpload from DIRAC.FrameworkSystem.Client.BundleDeliveryClient import BundleDeliveryClient __RCSID__ = "$Id$" - class Params(ProxyGeneration.CLIParams): + session = None + provider = '' + addEmail = False + addQRcode = False addVOMSExt = False + addProvider = False uploadProxy = True uploadPilot = False + def setEmail(self, arg): + """ Set email + + :param str arg: email + + :return: S_OK() + """ + self.Email = arg + self.addEmail = True + return S_OK() + + def setQRcode(self, _arg): + """ Use QRcode + + :param _arg: unuse + + :return: S_OK() + """ + self.addQRcode = True + return S_OK() + + def setProvider(self, arg): + """ Set provider + + :param str arg: provider + + :return: S_OK() + """ + self.provider = arg + self.addProvider = True + return S_OK() + def setVOMSExt(self, _arg): + """ Set VOMS extention + + :param _arg: unuse + + :return: S_OK() + """ self.addVOMSExt = True return S_OK() def disableProxyUpload(self, _arg): + """ Do not upload proxy + + :param _arg: unuse + + :return: S_OK() + """ self.uploadProxy = False return S_OK() def registerCLISwitches(self): + """ Register CLI switches """ ProxyGeneration.CLIParams.registerCLISwitches(self) Script.registerSwitch( "U", @@ -59,18 +109,26 @@ def registerCLISwitches(self): "no-upload", "Do not upload a long lived proxy to the ProxyManager", self.disableProxyUpload) + Script.registerSwitch("e:", "email=", "Send oauth authentification url on email", self.setEmail) + Script.registerSwitch("P:", "provider=", "Set provider name for authentification", self.setProvider) + Script.registerSwitch("Q", "qrcode", "Print link as QR code", self.setQRcode) Script.registerSwitch("M", "VOMS", "Add voms extension", self.setVOMSExt) class ProxyInit(object): def __init__(self, piParams): + """ Constructor """ self.__piParams = piParams self.__issuerCert = False self.__proxyGenerated = False self.__uploadedInfo = {} def getIssuerCert(self): + """ Get certificate issuer + + :return: str + """ if self.__issuerCert: return self.__issuerCert proxyChain = X509Chain.X509Chain() @@ -86,6 +144,8 @@ def getIssuerCert(self): return self.__issuerCert def certLifeTimeCheck(self): + """ Check certificate live time + """ minLife = Registry.getGroupOption(self.__piParams.diracGroup, "SafeCertificateLifeTime", 2592000) resultIssuerCert = self.getIssuerCert() resultRemainingSecs = resultIssuerCert.getRemainingSecs() # pylint: disable=no-member @@ -97,10 +157,13 @@ def certLifeTimeCheck(self): daysLeft = int(lifeLeft / 86400) msg = "Your certificate will expire in less than %d days. Please renew it!" % daysLeft sep = "=" * (len(msg) + 4) - msg = "%s\n %s \n%s" % (sep, msg, sep) - gLogger.notice(msg) + gLogger.notice("%s\n %s \n%s" % (sep, msg, sep)) def addVOMSExtIfNeeded(self): + """ Add VOMS extension if needed + + :return: S_OK()/S_ERROR() + """ addVOMS = self.__piParams.addVOMSExt or Registry.getGroupOption(self.__piParams.diracGroup, "AutoAddVOMS", False) if not addVOMS: return S_OK() @@ -118,13 +181,12 @@ def addVOMSExtIfNeeded(self): gLogger.notice("Added VOMS attribute %s" % vomsAttr) chain = resultVomsAttributes['Value'] - result = chain.dumpAllToFile(self.__proxyGenerated) - if not result["OK"]: - return result - return S_OK() + return chain.dumpAllToFile(self.__proxyGenerated) def createProxy(self): """ Creates the proxy on disk + + :return: S_OK()/S_ERROR() """ gLogger.notice("Generating proxy...") resultProxyGenerated = ProxyGeneration.generateProxy(piParams) @@ -136,6 +198,8 @@ def createProxy(self): def uploadProxy(self): """ Upload the proxy to the proxyManager service + + :return: S_OK()/S_ERROR() """ issuerCert = self.getIssuerCert() resultUserDN = issuerCert.getSubjectDN() # pylint: disable=no-member @@ -177,19 +241,22 @@ def printInfo(self): if self.__uploadedInfo: gLogger.notice("\nProxies uploaded:") maxDNLen = 0 - maxGroupLen = 0 - for userDN in self.__uploadedInfo: + maxProviderLen = len('ProxyProvider') + for userDN, data in self.__uploadedInfo.items(): maxDNLen = max(maxDNLen, len(userDN)) - for group in self.__uploadedInfo[userDN]: - maxGroupLen = max(maxGroupLen, len(group)) - gLogger.notice(" %s | %s | Until (GMT)" % ("DN".ljust(maxDNLen), "Group".ljust(maxGroupLen))) - for userDN in self.__uploadedInfo: - for group in self.__uploadedInfo[userDN]: - gLogger.notice(" %s | %s | %s" % (userDN.ljust(maxDNLen), - group.ljust(maxGroupLen), - self.__uploadedInfo[userDN][group].strftime("%Y/%m/%d %H:%M"))) + maxProviderLen = max(maxProviderLen, len(data['provider'])) + gLogger.notice(" %s | %s | %s | SupportedGroups" % ("DN".ljust(maxDNLen), "ProxyProvider".ljust(maxProviderLen), + "Until (GMT)".ljust(16))) + for userDN, data in self.__uploadedInfo.items(): + gLogger.notice(" %s | %s | %s | " % (userDN.ljust(maxDNLen), data['provider'].ljust(maxProviderLen), + data['expirationtime'].strftime("%Y/%m/%d %H:%M").ljust(16)), + ",".join(data['groups'])) def checkCAs(self): + """ Check CAs + + :return: S_OK() + """ if "X509_CERT_DIR" not in os.environ: gLogger.warn("X509_CERT_DIR is unset. Abort check of CAs") return @@ -222,6 +289,10 @@ def checkCAs(self): return S_OK() def doTheMagic(self): + """ Magic method + + :return: S_OK()/S_ERROR() + """ proxy = self.createProxy() if not proxy['OK']: return proxy @@ -248,6 +319,137 @@ def doTheMagic(self): return S_OK() + def doOAuthMagic(self): + """ Magic method with tokens + + :return: S_OK()/S_ERROR() + """ + import urllib3 + import threading + import webbrowser + import requests + import json + from authlib.integrations.requests_client import OAuth2Session + + from DIRAC.Core.Utilities.JEncode import encode + from DIRAC.ConfigurationSystem.Client.Utilities import getProxyAPI, getDIRACClientID + from DIRAC.FrameworkSystem.Utilities.halo import Halo, qrterminal + from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import submitUserAuthorizationFlow + from DIRAC.FrameworkSystem.private.authorization.grants.DeviceFlow import waitFinalStatusOfUserAuthorizationFlow + from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + from DIRAC.ConfigurationSystem.Client.Utilities import getAuthAPI, getDIRACClientID + + spinner = Halo() + proxyAPI = getProxyAPI() + + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + # Get IdP + result = IdProviderFactory().getIdProvider(self.__piParams.provider) + if not result['OK']: + return result + + idpObj = result['Value'] + + # Submit Device authorisation flow + with Halo('Authentification from %s.' % self.__piParams.provider) as spin: + if Script.enableCS()['OK']: + result = idpObj.submitDeviceCodeAuthorizationFlow(self.__piParams.diracGroup) + if not result['OK']: + sys.exit(result['Message']) + response = result['Value'] + else: + try: + r = requests.post('{api}/device?{group}'.format( + api=getAuthAPI(), + group = ('group=%s' % self.__piParams.diracGroup) if self.__piParams.diracGroup else '' + ), verify=False) + r.raise_for_status() + response = r.json() + # Check if all main keys are present here + for k in ['user_code', 'device_code', 'verification_uri']: + if not response.get(k): + sys.exit('Mandatory %s key is absent in authentication response.' % k) + except requests.exceptions.Timeout: + sys.exit('Authentication server is not answer, timeout.') + except requests.exceptions.RequestException as ex: + sys.exit(r.content or repr(ex)) + except Exception as ex: + sys.exit('Cannot read authentication response: %s' % repr(ex)) + + deviceCode = response['device_code'] + userCode = response['user_code'] + verURL = response['verification_uri'] + verURLComplete = response.get('verification_uri_complete') + interval = response.get('interval', 5) + + # Notify user to go to authorization endpoint + showURL = 'Use next link to continue, your user code is "%s"\n%s' % (userCode, verURL) + if self.__piParams.addQRcode: + if not verURLComplete: + spinner.warn('Cannot get verification_uri_complete for authentication.') + spinner.info(showURL) + else: + result = qrterminal(verURLComplete) + if not result['OK']: + spinner.fail(result['Message']) + spinner.info(showURL) + else: + # Show QR code + spinner.info('Scan QR code to continue: %s' % result['Value']) + else: + spinner.info(showURL) + + # Try to open in default browser + if webbrowser.open_new_tab(verURL): + spinner.text = '%s opening in default browser..' % verURL + + with Halo('Waiting authorization status..') as spin: + result = idpObj.waitFinalStatusOfDeviceCodeAuthorizationFlow(deviceCode) + if not result['OK']: + sys.exit(result['Message']) + idpObj.token = result['Value'] + + spin.color = 'green' + spin.text = 'Saving token.. to env DIRAC_TOKEN..' + + os.environ["DIRAC_TOKEN"] = json.dumps(idpObj.token) + + spin.text = 'Download proxy..' + url = '%s?lifetime=%s' % (proxyAPI, self.__piParams.proxyLifeTime) + addVOMS = self.__piParams.addVOMSExt or Registry.getGroupOption(self.__piParams.diracGroup, "AutoAddVOMS", False) + if addVOMS: + url += '&voms=%s' % addVOMS + if not idpObj.token.get('refresh_token'): + sys.exit('Refresh token is absent in response.') + url += '&refresh_token=%s' % idpObj.token['refresh_token'] + r = idpObj.get(url) + r.raise_for_status() + proxy = r.text + if not proxy: + sys.exit("Something went wrong, the proxy is empty.") + + if not self.__piParams.proxyLoc: + self.__piParams.proxyLoc = '/tmp/x509up_u%s' % os.getuid() + + spin.color = 'green' + spin.text = 'Saving proxy.. to %s..' % self.__piParams.proxyLoc + try: + with open(self.__piParams.proxyLoc, 'w+') as fd: + fd.write(proxy.encode("UTF-8")) + os.chmod(self.__piParams.proxyLoc, stat.S_IRUSR | stat.S_IWUSR) + except Exception as e: + return S_ERROR("%s :%s" % (self.__piParams.proxyLoc, repr(e).replace(',)', ')'))) + self.__piParams.certLoc = self.__piParams.proxyLoc + spin.text = 'Proxy is saved to %s.' % self.__piParams.proxyLoc + + result = Script.enableCS() + if not result['OK']: + return S_ERROR("Cannot contact CS to get user list") + threading.Thread(target=self.checkCAs).start() + gConfig.forceRefresh(fromMaster=True) + return S_OK(self.__piParams.proxyLoc) + @DIRACScript() def main(): @@ -260,9 +462,13 @@ def main(): DIRAC.gConfig.setOptionValue("/DIRAC/Security/UseServerCertificate", "False") pI = ProxyInit(piParams) - resultDoTheMagic = pI.doTheMagic() - if not resultDoTheMagic['OK']: - gLogger.fatal(resultDoTheMagic['Message']) + gLogger.info(gConfig.getConfigurationTree()) + if piParams.addProvider: + resultDoMagic = pI.doOAuthMagic() + else: + resultDoMagic = pI.doTheMagic() + if not resultDoMagic['OK']: + gLogger.fatal(resultDoMagic['Message']) sys.exit(1) pI.printInfo() diff --git a/src/DIRAC/Interfaces/API/DiracAdmin.py b/src/DIRAC/Interfaces/API/DiracAdmin.py index 6dc18b9c601..7733f24ca26 100755 --- a/src/DIRAC/Interfaces/API/DiracAdmin.py +++ b/src/DIRAC/Interfaces/API/DiracAdmin.py @@ -15,11 +15,13 @@ import os from DIRAC import gLogger, gConfig, S_OK, S_ERROR -from DIRAC.Core.Utilities.PromptUser import promptUser from DIRAC.Core.Base.API import API +from DIRAC.Core.Security.ProxyInfo import getProxyInfo +from DIRAC.Core.Utilities.PromptUser import promptUser +from DIRAC.Core.Utilities.Grid import ldapCEState, ldapCEVOView +from DIRAC.Core.Utilities.Grid import ldapSite, ldapCluster, ldapCE, ldapService from DIRAC.ConfigurationSystem.Client.CSAPI import CSAPI from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOForGroup -from DIRAC.Core.Security.ProxyInfo import getProxyInfo from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager from DIRAC.FrameworkSystem.Client.NotificationClient import NotificationClient from DIRAC.ResourceStatusSystem.Client.ResourceStatusClient import ResourceStatusClient @@ -60,70 +62,48 @@ def __init__(self): ############################################################################# def uploadProxy(self): - """Upload a proxy to the DIRAC WMS. This method - - Example usage: - - >>> print diracAdmin.uploadProxy('dteam_pilot') - {'OK': True, 'Value': 0L} + """ Upload a proxy to the DIRAC WMS. This method - :return: S_OK,S_ERROR + Example usage: - :param permanent: Indefinitely update proxy - :type permanent: boolean + >>> print diracAdmin.uploadProxy('dteam_pilot') + {'OK': True, 'Value': 0L} + :return: S_OK,S_ERROR """ return gProxyManager.uploadProxy() ############################################################################# - def setProxyPersistency(self, userDN, userGroup, persistent=True): - """Set the persistence of a proxy in the Proxy Manager + def checkProxyUploaded(self, user, userGroup, requiredTime): + """ Check if a user(DN-group) has a proxy in the proxy management + Updates internal cache if needed to minimize queries to the service - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.setProxyPersistency( 'some DN', 'dirac group', True )) - {'OK': True } + >>> gLogger.notice(diracAdmin.checkProxyUploaded('user name', 'dirac group', 0)) + {'OK': True, 'Value' : True/False } - :param userDN: User DN - :type userDN: string - :param userGroup: DIRAC Group - :type userGroup: string - :param persistent: Persistent flag - :type persistent: boolean - :return: S_OK,S_ERROR - """ - return gProxyManager.setPersistency(userDN, userGroup, persistent) - - ############################################################################# - def checkProxyUploaded(self, userDN, userGroup, requiredTime): - """Set the persistence of a proxy in the Proxy Manager + :param str user: user name or DN + :param str userGroup: DIRAC Group + :param bool requiredTime: Required life time of the uploaded proxy - Example usage: - - >>> gLogger.notice(diracAdmin.setProxyPersistency( 'some DN', 'dirac group', True )) - {'OK': True, 'Value' : True/False } - - :param userDN: User DN - :type userDN: string - :param userGroup: DIRAC Group - :type userGroup: string - :param requiredTime: Required life time of the uploaded proxy - :type requiredTime: boolean - :return: S_OK,S_ERROR + :return: S_OK,S_ERROR """ - return gProxyManager.userHasProxy(userDN, userGroup, requiredTime) + return gProxyManager.userHasProxy(user, userGroup, requiredTime) ############################################################################# def getSiteMask(self, printOutput=False, status='Active'): - """Retrieve current site mask from WMS Administrator service. + """ Retrieve current site mask from WMS Administrator service. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.getSiteMask()) - {'OK': True, 'Value': 0L} + >>> gLogger.notice(diracAdmin.getSiteMask()) + {'OK': True, 'Value': 0L} - :return: S_OK,S_ERROR + :param bool printOutput: print output + :param str status: site status + :return: S_OK,S_ERROR """ result = self.sitestatus.getSites(siteState=status) @@ -138,15 +118,16 @@ def getSiteMask(self, printOutput=False, status='Active'): ############################################################################# def getBannedSites(self, printOutput=False): - """Retrieve current list of banned and probing sites. + """ Retrieve current list of banned and probing sites. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.getBannedSites()) - {'OK': True, 'Value': []} + >>> gLogger.notice(diracAdmin.getBannedSites()) + {'OK': True, 'Value': []} - :return: S_OK,S_ERROR + :param bool printOutput: print output + :return: S_OK,S_ERROR """ bannedSites = self.sitestatus.getSites(siteState='Banned') @@ -166,14 +147,17 @@ def getBannedSites(self, printOutput=False): ############################################################################# def getSiteSection(self, site, printOutput=False): - """Simple utility to get the list of CEs for DIRAC site name. + """ Simple utility to get the list of CEs for DIRAC site name. + + Example usage: - Example usage: + >>> gLogger.notice(diracAdmin.getSiteSection('LCG.CERN.ch')) + {'OK': True, 'Value':} - >>> gLogger.notice(diracAdmin.getSiteSection('LCG.CERN.ch')) - {'OK': True, 'Value':} + :param str site: site + :param bool printOutput: print output - :return: S_OK,S_ERROR + :return: S_OK,S_ERROR """ gridType = site.split('.')[0] if not gConfig.getSections('/Resources/Sites/%s' % (gridType))['OK']: @@ -186,15 +170,18 @@ def getSiteSection(self, site, printOutput=False): ############################################################################# def allowSite(self, site, comment, printOutput=False): - """Adds the site to the site mask. + """ Adds the site to the site mask. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.allowSite()) - {'OK': True, 'Value': } + >>> gLogger.notice(diracAdmin.allowSite()) + {'OK': True, 'Value': } - :return: S_OK,S_ERROR + :param str site: site + :param str comment: comment + :param bool printOutput: print output + :return: S_OK,S_ERROR """ result = self._checkSiteIsValid(site) if not result['OK']: @@ -223,14 +210,17 @@ def allowSite(self, site, comment, printOutput=False): ############################################################################# def getSiteMaskLogging(self, site=None, printOutput=False): - """Retrieves site mask logging information. + """ Retrieves site mask logging information. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.getSiteMaskLogging('LCG.AUVER.fr')) - {'OK': True, 'Value': } + >>> gLogger.notice(diracAdmin.getSiteMaskLogging('LCG.AUVER.fr')) + {'OK': True, 'Value': } - :return: S_OK,S_ERROR + :param str site: site + :param bool printOutput: print output + + :return: S_OK,S_ERROR """ result = self._checkSiteIsValid(site) if not result['OK']: @@ -252,7 +242,7 @@ def getSiteMaskLogging(self, site=None, printOutput=False): sitesLogging = result['Value'] if isinstance(sitesLogging, dict): - for siteName, tupleList in sitesLogging.items(): # can be an iterator + for siteName, tupleList in sitesLogging.items(): if not siteName: gLogger.notice('\n===> %s\n' % siteName) for tup in tupleList: @@ -269,15 +259,18 @@ def getSiteMaskLogging(self, site=None, printOutput=False): ############################################################################# def banSite(self, site, comment, printOutput=False): - """Removes the site from the site mask. + """ Removes the site from the site mask. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.banSite()) - {'OK': True, 'Value': } + >>> gLogger.notice(diracAdmin.banSite()) + {'OK': True, 'Value': } - :return: S_OK,S_ERROR + :param str site: site + :param str comment: comment + :param bool printOutput: print output + :return: S_OK,S_ERROR """ result = self._checkSiteIsValid(site) if not result['OK']: @@ -306,16 +299,18 @@ def banSite(self, site, comment, printOutput=False): ############################################################################# def getServicePorts(self, setup='', printOutput=False): - """Checks the service ports for the specified setup. If not given this is - taken from the current installation (/DIRAC/Setup) + """ Checks the service ports for the specified setup. If not given this is + taken from the current installation (/DIRAC/Setup) - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.getServicePorts()) - {'OK': True, 'Value':''} + >>> gLogger.notice(diracAdmin.getServicePorts()) + {'OK': True, 'Value':''} - :return: S_OK,S_ERROR + :param str setup: setup + :param bool printOutput: output + :return: S_OK,S_ERROR """ if not setup: setup = gConfig.getValue('/DIRAC/Setup', '') @@ -364,68 +359,76 @@ def getServicePorts(self, setup='', printOutput=False): return S_OK(result) ############################################################################# - def getProxy(self, userDN, userGroup, validity=43200, limited=False): - """Retrieves a proxy with default 12hr validity and stores - this in a file in the local directory by default. + def getProxy(self, user, userGroup, validity=43200, limited=False): + """ Retrieves a proxy with default 12hr validity and stores + this in a file in the local directory by default. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.getProxy()) - {'OK': True, 'Value': } + >>> gLogger.notice(diracAdmin.getProxy()) + {'OK': True, 'Value': } - :return: S_OK,S_ERROR + :param str user: user name or DN + :param str userGroup: group name + :param int validity: proxy lifetime + :param bool limited: limited proxy + :return: S_OK,S_ERROR """ - return gProxyManager.downloadProxy(userDN, userGroup, limited=limited, + return gProxyManager.downloadProxy(user, userGroup, limited=limited, requiredTimeLeft=validity) ############################################################################# - def getVOMSProxy(self, userDN, userGroup, vomsAttr=False, validity=43200, limited=False): - """Retrieves a proxy with default 12hr validity and VOMS extensions and stores - this in a file in the local directory by default. + def getVOMSProxy(self, user, userGroup, validity=43200, limited=False): + """ Retrieves a proxy with default 12hr validity and VOMS extensions and stores + this in a file in the local directory by default. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.getVOMSProxy()) - {'OK': True, 'Value': } + >>> gLogger.notice(diracAdmin.getVOMSProxy()) + {'OK': True, 'Value': } - :return: S_OK,S_ERROR + :param str user: user name or DN + :param str userGroup: group name + :param int validity: proxy lifetime + :param bool limited: limited proxy + :return: S_OK,S_ERROR """ - return gProxyManager.downloadVOMSProxy(userDN, userGroup, limited=limited, - requiredVOMSAttribute=vomsAttr, - requiredTimeLeft=validity) + return gProxyManager.downloadVOMSProxy(user, userGroup, limited=limited, requiredTimeLeft=validity) ############################################################################# - def getPilotProxy(self, userDN, userGroup, validity=43200): - """Retrieves a pilot proxy with default 12hr validity and stores - this in a file in the local directory by default. + def getPilotProxy(self, user, userGroup, validity=43200): + """ Retrieves a pilot proxy with default 12hr validity and stores + this in a file in the local directory by default. - Example usage: + Example usage: - >>> gLogger.notice(diracAdmin.getVOMSProxy()) - {'OK': True, 'Value': } + >>> gLogger.notice(diracAdmin.getVOMSProxy()) + {'OK': True, 'Value': } - :return: S_OK,S_ERROR + :param str user: user name or DN + :param str userGroup: group name + :param int validity: proxy lifetime + :return: S_OK,S_ERROR """ - - return gProxyManager.getPilotProxyFromDIRACGroup(userDN, userGroup, requiredTimeLeft=validity) + return gProxyManager.downloadCorrectProxy(user, userGroup, requiredTimeLeft=validity) ############################################################################# def resetJob(self, jobID): - """Reset a job or list of jobs in the WMS. This operation resets the reschedule - counter for a job or list of jobs and allows them to run as new. + """ Reset a job or list of jobs in the WMS. This operation resets the reschedule + counter for a job or list of jobs and allows them to run as new. - Example:: + Example:: - >>> gLogger.notice(dirac.reset(12345)) - {'OK': True, 'Value': [12345]} + >>> gLogger.notice(dirac.reset(12345)) + {'OK': True, 'Value': [12345]} - :param job: JobID - :type job: integer or list of integers - :return: S_OK,S_ERROR + :param jobID: JobID + :type jobID: integer or list of integers + :return: S_OK,S_ERROR """ if isinstance(jobID, six.string_types): try: @@ -443,16 +446,18 @@ def resetJob(self, jobID): ############################################################################# def getJobPilotOutput(self, jobID, directory=''): - """Retrieve the pilot output for an existing job in the WMS. - The output will be retrieved in a local directory unless - otherwise specified. + """ Retrieve the pilot output for an existing job in the WMS. + The output will be retrieved in a local directory unless + otherwise specified. + + >>> gLogger.notice(dirac.getJobPilotOutput(12345)) + {'OK': True, StdOut:'',StdError:''} - >>> gLogger.notice(dirac.getJobPilotOutput(12345)) - {'OK': True, StdOut:'',StdError:''} + :param jobID: JobID + :type jobID: int or str + :param str directory: path for output - :param job: JobID - :type job: integer or string - :return: S_OK,S_ERROR + :return: S_OK,S_ERROR """ if not directory: directory = self.currentDir @@ -495,14 +500,15 @@ def getJobPilotOutput(self, jobID, directory=''): ############################################################################# def getPilotOutput(self, gridReference, directory=''): - """Retrieve the pilot output (std.out and std.err) for an existing job in the WMS. + """ Retrieve the pilot output (std.out and std.err) for an existing job in the WMS. - >>> gLogger.notice(dirac.getJobPilotOutput(12345)) - {'OK': True, 'Value': {}} + >>> gLogger.notice(dirac.getJobPilotOutput(12345)) + {'OK': True, 'Value': {}} - :param job: JobID - :type job: integer or string - :return: S_OK,S_ERROR + :param str gridReference: Pilot Job Reference + :param str directory: path for output + + :return: S_OK,S_ERROR """ if not isinstance(gridReference, six.string_types): return self._errorReport('Expected string for pilot reference') @@ -552,14 +558,14 @@ def getPilotOutput(self, gridReference, directory=''): ############################################################################# def getPilotInfo(self, gridReference): - """Retrieve info relative to a pilot reference + """ Retrieve info relative to a pilot reference + + >>> gLogger.notice(dirac.getPilotInfo(12345)) + {'OK': True, 'Value': {}} - >>> gLogger.notice(dirac.getPilotInfo(12345)) - {'OK': True, 'Value': {}} + :param str gridReference: Pilot Job Reference - :param gridReference: Pilot Job Reference - :type gridReference: string - :return: S_OK,S_ERROR + :return: S_OK,S_ERROR """ if not isinstance(gridReference, six.string_types): return self._errorReport('Expected string for pilot reference') @@ -569,13 +575,14 @@ def getPilotInfo(self, gridReference): ############################################################################# def killPilot(self, gridReference): - """Kill the pilot specified + """ Kill the pilot specified + + >>> gLogger.notice(dirac.getPilotInfo(12345)) + {'OK': True, 'Value': {}} - >>> gLogger.notice(dirac.getPilotInfo(12345)) - {'OK': True, 'Value': {}} + :param str gridReference: Pilot Job Reference - :param gridReference: Pilot Job Reference - :return: S_OK,S_ERROR + :return: S_OK,S_ERROR """ if not isinstance(gridReference, six.string_types): return self._errorReport('Expected string for pilot reference') @@ -585,14 +592,14 @@ def killPilot(self, gridReference): ############################################################################# def getPilotLoggingInfo(self, gridReference): - """Retrieve the pilot logging info for an existing job in the WMS. + """ Retrieve the pilot logging info for an existing job in the WMS. - >>> gLogger.notice(dirac.getPilotLoggingInfo(12345)) - {'OK': True, 'Value': {"The output of the command"}} + >>> gLogger.notice(dirac.getPilotLoggingInfo(12345)) + {'OK': True, 'Value': {"The output of the command"}} - :param gridReference: Gridp pilot job reference Id - :type gridReference: string - :return: S_OK,S_ERROR + :param str gridReference: Gridp pilot job reference Id + + :return: S_OK,S_ERROR """ if not isinstance(gridReference, six.string_types): return self._errorReport('Expected string for pilot reference') @@ -601,16 +608,16 @@ def getPilotLoggingInfo(self, gridReference): ############################################################################# def getJobPilots(self, jobID): - """Extract the list of submitted pilots and their status for a given - jobID from the WMS. Useful information is printed to the screen. + """ Extract the list of submitted pilots and their status for a given + jobID from the WMS. Useful information is printed to the screen. - >>> gLogger.notice(dirac.getJobPilots()) - {'OK': True, 'Value': {PilotID:{StatusDict}}} + >>> gLogger.notice(dirac.getJobPilots()) + {'OK': True, 'Value': {PilotID:{StatusDict}}} - :param job: JobID - :type job: integer or string - :return: S_OK,S_ERROR + :param job: JobID + :type job: int or str + :return: S_OK,S_ERROR """ if isinstance(jobID, six.string_types): try: @@ -625,15 +632,16 @@ def getJobPilots(self, jobID): ############################################################################# def getPilotSummary(self, startDate='', endDate=''): - """Retrieve the pilot output for an existing job in the WMS. Summary is - printed at INFO level, full dictionary of results also returned. + """ Retrieve the pilot output for an existing job in the WMS. Summary is + printed at INFO level, full dictionary of results also returned. + + >>> gLogger.notice(dirac.getPilotSummary()) + {'OK': True, 'Value': {CE:{Status:Count}}} - >>> gLogger.notice(dirac.getPilotSummary()) - {'OK': True, 'Value': {CE:{Status:Count}}} + :param str startDate: start date + :param str endDate: end date - :param job: JobID - :type job: integer or string - :return: S_OK,S_ERROR + :return: S_OK,S_ERROR """ result = PilotManagerClient().getPilotSummary(startDate, endDate) if not result['OK']: @@ -663,8 +671,13 @@ def getPilotSummary(self, startDate='', endDate=''): ############################################################################# def setSiteProtocols(self, site, protocolsList, printOutput=False): - """ - Allows to set the defined protocols for each SE for a given site. + """ Allows to set the defined protocols for each SE for a given site. + + :param str site: site + :param list protocolsList: protocols + :param bool printOutput: output + + :return: S_OK/S_ERROR """ result = self._checkSiteIsValid(site) if not result['OK']: @@ -720,144 +733,237 @@ def setSiteProtocols(self, site, protocolsList, printOutput=False): ############################################################################# def csSetOption(self, optionPath, optionValue): - """ - Function to modify an existing value in the CS. + """ Function to modify an existing value in the CS. + + :param str optionPath: option path + :param optionValue: value """ return self.csAPI.setOption(optionPath, optionValue) ############################################################################# def csSetOptionComment(self, optionPath, comment): - """ - Function to modify an existing value in the CS. + """ Function to modify an existing value in the CS. + + :param str optionPath: option path + :param str comment: comment """ return self.csAPI.setOptionComment(optionPath, comment) ############################################################################# def csModifyValue(self, optionPath, newValue): - """ - Function to modify an existing value in the CS. + """ Function to modify an existing value in the CS. + + :param str optionPath: option path + :param newValue: value """ return self.csAPI.modifyValue(optionPath, newValue) ############################################################################# def csRegisterUser(self, username, properties): - """ - Registers a user in the CS. - - - username: Username of the user (easy;) - - properties: Dict containing: - - DN - - groups : list/tuple of groups the user belongs to - - : More properties of the user, like mail + """ Registers a user in the CS. + :param str username: user name + :param dict properties: containing DN, groups, etc. + - groups : list/tuple of groups the user belongs to + - : More properties of the user, like mail """ return self.csAPI.addUser(username, properties) ############################################################################# def csDeleteUser(self, user): - """ - Deletes a user from the CS. Can take a list of users + """ Deletes a user from the CS. Can take a list of users + + :param str user: user name """ return self.csAPI.deleteUsers(user) ############################################################################# def csModifyUser(self, username, properties, createIfNonExistant=False): - """ - Modify a user in the CS. Takes the same params as in addUser and - applies the changes + """ Modify a user in the CS. Takes the same params as in addUser and applies the changes + + :param str username: user name + :param dict properties: containing DN, groups, etc. + :param bool createIfNonExistant: create user if non exist """ return self.csAPI.modifyUser(username, properties, createIfNonExistant) ############################################################################# def csListUsers(self, group=False): - """ - Lists the users in the CS. If no group is specified return all users. + """ Lists the users in the CS. If no group is specified return all users. + + :param str group: group name + + :return: list """ return self.csAPI.listUsers(group) ############################################################################# def csDescribeUsers(self, mask=False): - """ - List users and their properties in the CS. - If a mask is given, only users in the mask will be returned + """ List users and their properties in the CS. + + :param str mask: If a mask is given, only users in the mask will be returned + + :return: list """ return self.csAPI.describeUsers(mask) ############################################################################# def csModifyGroup(self, groupname, properties, createIfNonExistant=False): - """ - Modify a user in the CS. Takes the same params as in addGroup and applies - the changes + """ Modify a user in the CS. Takes the same params as in addGroup and applies the changes + + :param str groupname: group name + :param dict properties: properties + :param bool createIfNonExistant: create group if non exist """ return self.csAPI.modifyGroup(groupname, properties, createIfNonExistant) ############################################################################# def csListHosts(self): - """ - Lists the hosts in the CS + """ Lists the hosts in the CS + + :return: list """ return self.csAPI.listHosts() ############################################################################# def csDescribeHosts(self, mask=False): - """ - Gets extended info for the hosts in the CS + """ Gets extended info for the hosts in the CS + + :param mask: mask + + :return: list """ return self.csAPI.describeHosts(mask) ############################################################################# def csModifyHost(self, hostname, properties, createIfNonExistant=False): - """ - Modify a host in the CS. Takes the same params as in addHost and applies - the changes + """ Modify a host in the CS. Takes the same params as in addHost and applies the changes + + :param str hostname: host name + :param dict properties: properties + :param bool createIfNonExistant: create group if non exist """ return self.csAPI.modifyHost(hostname, properties, createIfNonExistant) ############################################################################# def csListGroups(self): - """ - Lists groups in the CS + """ Lists groups in the CS + + :return: list """ return self.csAPI.listGroups() ############################################################################# def csDescribeGroups(self, mask=False): - """ - List groups and their properties in the CS. - If a mask is given, only groups in the mask will be returned + """ List groups and their properties in the CS. + + :param mask: If a mask is given, only groups in the mask will be returned + + :return: list """ return self.csAPI.describeGroups(mask) ############################################################################# def csSyncUsersWithCFG(self, usersCFG): - """ - Synchronize users in cfg with its contents + """ Synchronize users in cfg with its contents + + :param object usersCFG: CFG """ return self.csAPI.syncUsersWithCFG(usersCFG) ############################################################################# def csCommitChanges(self, sortUsers=True): - """ - Commit the changes in the CS + """ Commit the changes in the CS + + :param list sortUsers: sort users """ return self.csAPI.commitChanges(sortUsers=False) ############################################################################# def sendMail(self, address, subject, body, fromAddress=None, localAttempt=True, html=False): - """ - Send mail to specified address with body. + """ Send mail to specified address with body. + + :param str address: address + :param str subject: subject + :param str body: body text + :param str fromAddress: address from who + :param bool localAttempt: local attempt + :param str html: html + + :return: S_OK/S_ERROR """ notification = NotificationClient() return notification.sendMail(address, subject, body, fromAddress, localAttempt, html) ############################################################################# def sendSMS(self, userName, body, fromAddress=None): - """ - Send mail to specified address with body. + """ Send mail to specified address with body. + + :param str userName: user name + :param str body: body text + :param str fromAddress: address from who + + :return: S_OK/S_ERROR """ if len(body) > 160: return S_ERROR('Exceeded maximum SMS length of 160 characters') notification = NotificationClient() return notification.sendSMS(userName, body, fromAddress) + ############################################################################# + def getBDIISite(self, site, host=None): + """ Get information about site from BDII at host + + :param str site: site name + :param str host: host name + """ + return ldapSite(site, host=host) + + ############################################################################# + def getBDIICluster(self, ce, host=None): + """ Get information about ce from BDII at host + + :param str ce: ce name + :param str host: host name + """ + return ldapCluster(ce, host=host) + + ############################################################################# + def getBDIICE(self, ce, host=None): + """ Get information about ce from BDII at host + + :param str ce: ce name + :param str host: host name + """ + return ldapCE(ce, host=host) + + ############################################################################# + def getBDIIService(self, ce, host=None): + """ Get information about ce from BDII at host + + :param str ce: ce name + :param str host: host name + """ + return ldapService(ce, host=host) + + ############################################################################# + def getBDIICEState(self, ce, useVO=voName, host=None): + """ Get information about ce state from BDII at host + + :param str ce: ce name + :param str useVO: VO name + :param str host: host name + """ + return ldapCEState(ce, useVO, host=host) + + ############################################################################# + def getBDIICEVOView(self, ce, useVO=voName, host=None): + """ Get information about ce voview from BDII at host + + :param str ce: CE name + :param str useVO: VO name + :param str host: host name + """ + return ldapCEVOView(ce, useVO, host=host) + # EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF#EOF# diff --git a/src/DIRAC/RequestManagementSystem/Client/Request.py b/src/DIRAC/RequestManagementSystem/Client/Request.py index c5e7f713232..bd3af7d79a7 100644 --- a/src/DIRAC/RequestManagementSystem/Client/Request.py +++ b/src/DIRAC/RequestManagementSystem/Client/Request.py @@ -25,6 +25,7 @@ # # from DIRAC from DIRAC import S_OK, S_ERROR from DIRAC.Core.Security.ProxyInfo import getProxyInfo +from DIRAC.Core.DISET.AuthManager import authorizeByCertificate from DIRAC.RequestManagementSystem.Client.Operation import Operation from DIRAC.RequestManagementSystem.private.JSONUtils import RMSEncoder from DIRAC.DataManagementSystem.Utilities.DMSHelpers import DMSHelpers @@ -77,6 +78,7 @@ def __init__(self, fromDict=None): self.JobID = 0 self.Error = None self.DIRACSetup = None + self.Owner = None self.OwnerDN = None self.RequestName = None self.OwnerGroup = None @@ -84,12 +86,17 @@ def __init__(self, fromDict=None): self.dmsHelper = DMSHelpers() + credDict = {} proxyInfo = getProxyInfo() if proxyInfo["OK"]: - proxyInfo = proxyInfo["Value"] - if proxyInfo["validGroup"] and proxyInfo["validDN"]: - self.OwnerDN = proxyInfo["identity"] - self.OwnerGroup = proxyInfo["group"] + credDict['DN'] = proxyInfo["Value"]['issuer'] + credDict['group'] = proxyInfo["Value"].get('group') + credDict['username'] = proxyInfo["Value"].get('username') + + if authorizeByCertificate(credDict): + self.Owner = credDict["username"] + self.OwnerDN = credDict["DN"] + self.OwnerGroup = credDict["group"] self.__operations__ = [] @@ -397,7 +404,7 @@ def toJSON(self): def _getJSONData(self): """ Returns the data that have to be serialized by JSON """ - attrNames = ['RequestID', "RequestName", "OwnerDN", "OwnerGroup", + attrNames = ['RequestID', "RequestName", "Owner", "OwnerDN", "OwnerGroup", "Status", "Error", "DIRACSetup", "SourceComponent", "JobID", "CreationTime", "SubmitTime", "LastUpdate", "NotBefore"] jsonData = {} diff --git a/src/DIRAC/RequestManagementSystem/private/OperationHandlerBase.py b/src/DIRAC/RequestManagementSystem/private/OperationHandlerBase.py index 0dae0d9447c..c20fd037d38 100644 --- a/src/DIRAC/RequestManagementSystem/private/OperationHandlerBase.py +++ b/src/DIRAC/RequestManagementSystem/private/OperationHandlerBase.py @@ -44,7 +44,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + __RCSID__ = "$Id $" + # # # @file OperationHandlerBase.py # @author Krzysztof.Ciba@NOSPAMgmail.com @@ -54,12 +56,14 @@ # # imports import os import six + # # from DIRAC from DIRAC import gLogger, gConfig, S_ERROR, S_OK from DIRAC.Core.Utilities.Graph import DynamicProps from DIRAC.RequestManagementSystem.Client.Operation import Operation from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getGroupsWithVOMSAttribute +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getGroupsWithVOMSAttribute,\ + getUsernameForDN, getGroupsForUser from DIRAC.Core.Utilities.ReturnValues import returnSingleResult from DIRAC.DataManagementSystem.Client.DataManager import DataManager from DIRAC.Resources.Catalog.FileCatalog import FileCatalog @@ -171,19 +175,33 @@ def getProxyForLFN(self, lfn): dirMeta = dirMeta["Value"] ownerRole = "/%s" % dirMeta["OwnerRole"] if not dirMeta["OwnerRole"].startswith("/") else dirMeta["OwnerRole"] + ownerGroup = dirMeta.get("OwnerGroup") ownerDN = dirMeta["OwnerDN"] + owner = dirMeta.get("Owner") + + if not owner: + result = getUsernameForDN(ownerDN) + if not result['OK']: + return result + owner = result['Value'] + ownerGroups = [ownerGroup] + if not ownerGroup: + result = getGroupsForUser(owner) + if not result['OK']: + return result + ownerGroups = result['Value'] ownerProxy = None - for ownerGroup in getGroupsWithVOMSAttribute(ownerRole): - vomsProxy = gProxyManager.downloadVOMSProxy(ownerDN, ownerGroup, limited=True, - requiredVOMSAttribute=ownerRole) + for ownerGroup in getGroupsWithVOMSAttribute(ownerRole, groupsList=ownerGroups): + vomsProxy = gProxyManager.downloadVOMSProxy(ownerDN or owner, ownerGroup, limited=True) if not vomsProxy["OK"]: - self.log.debug("getProxyForLFN: failed to get VOMS proxy for %s role=%s: %s" % (ownerDN, - ownerRole, - vomsProxy["Message"])) + self.log.debug("getProxyForLFN: failed to get VOMS proxy for %s@%s role=%s: %s" % (owner, + ownerGroup, + ownerRole, + vomsProxy["Message"])) continue ownerProxy = vomsProxy["Value"] - self.log.debug("getProxyForLFN: got proxy for %s@%s [%s]" % (ownerDN, ownerGroup, ownerRole)) + self.log.debug("getProxyForLFN: got proxy for %s@%s [%s]" % (owner, ownerGroup, ownerRole)) break if not ownerProxy: diff --git a/src/DIRAC/RequestManagementSystem/private/RequestTask.py b/src/DIRAC/RequestManagementSystem/private/RequestTask.py index b7f230bfa7c..acf8efc6362 100644 --- a/src/DIRAC/RequestManagementSystem/private/RequestTask.py +++ b/src/DIRAC/RequestManagementSystem/private/RequestTask.py @@ -128,21 +128,21 @@ def __setupManagerProxies(self): userName = shifterDict["Value"].get("User", "") userGroup = shifterDict["Value"].get("Group", "") - userDN = Registry.getDNForUsername(userName) - if not userDN["OK"]: - self.log.error("Cannot get DN For Username", "%s: %s" % (userName, userDN["Message"])) - continue - userDN = userDN["Value"][0] + result = Registry.getDNForUsernameInGroup(userName, userGroup) + if not result['OK']: + return result + userDN = result['Value'] + vomsAttr = Registry.getVOMSAttributeForGroup(userGroup) if vomsAttr: self.log.debug("getting VOMS [%s] proxy for shifter %s@%s (%s)" % (vomsAttr, userName, userGroup, userDN)) - getProxy = gProxyManager.downloadVOMSProxyToFile(userDN, userGroup, + getProxy = gProxyManager.downloadVOMSProxyToFile(userName, userGroup, requiredTimeLeft=1200, cacheTime=4 * 43200) else: self.log.debug("getting proxy for shifter %s@%s (%s)" % (userName, userGroup, userDN)) - getProxy = gProxyManager.downloadProxyToFile(userDN, userGroup, + getProxy = gProxyManager.downloadProxyToFile(userName, userGroup, requiredTimeLeft=1200, cacheTime=4 * 43200) if not getProxy["OK"]: @@ -169,9 +169,13 @@ def setupProxy(self): ownerDN = self.request.OwnerDN ownerGroup = self.request.OwnerGroup + result = Registry.getUsernameForDN(ownerDN) + if not result['OK']: + return result + owner = result['Value'] isShifter = [] for shifter, creds in self.__managersDict.items(): - if creds["ShifterDN"] == ownerDN and creds["ShifterGroup"] == ownerGroup: + if creds["ShifterName"] == owner and creds["ShifterGroup"] == ownerGroup: isShifter.append(shifter) if isShifter: proxyFile = self.__managersDict[isShifter[0]]["ProxyFile"] @@ -179,10 +183,10 @@ def setupProxy(self): return S_OK({"Shifter": isShifter, "ProxyFile": proxyFile}) # # if we're here owner is not a shifter at all - ownerProxyFile = gProxyManager.downloadVOMSProxyToFile(ownerDN, ownerGroup) + ownerProxyFile = gProxyManager.downloadVOMSProxyToFile(owner, ownerGroup) if not ownerProxyFile["OK"] or not ownerProxyFile["Value"]: reason = ownerProxyFile.get("Message", "No valid proxy found in ProxyManager.") - return S_ERROR("Change proxy error for '%s'@'%s': %s" % (ownerDN, ownerGroup, reason)) + return S_ERROR("Change proxy error for '%s'@'%s': %s" % (owner, ownerGroup, reason)) ownerProxyFile = ownerProxyFile["Value"] os.environ["X509_USER_PROXY"] = ownerProxyFile diff --git a/src/DIRAC/RequestManagementSystem/private/RequestValidator.py b/src/DIRAC/RequestManagementSystem/private/RequestValidator.py index 1c087c2f6ce..1a8b3d5d526 100644 --- a/src/DIRAC/RequestManagementSystem/private/RequestValidator.py +++ b/src/DIRAC/RequestManagementSystem/private/RequestValidator.py @@ -290,17 +290,19 @@ def setAndCheckRequestOwner(request, remoteCredentials): the RequestExecutingAgent :param request: the request to test - :param remoteCredentials: credentials from the clients + :param dict remoteCredentials: credentials from the clients :returns: True if everything is fine, False otherwise """ credDN = remoteCredentials['DN'] credGroup = remoteCredentials['group'] + credUsername = remoteCredentials['username'] credProperties = remoteCredentials['properties'] # If the owner or the group was not set, we use the one of the credentials - if not request.OwnerDN or not request.OwnerGroup: + if not request.Owner or not request.OwnerGroup: + request.Owner = credUsername request.OwnerDN = credDN request.OwnerGroup = credGroup return True @@ -308,7 +310,7 @@ def setAndCheckRequestOwner(request, remoteCredentials): # From here onward, we expect the ownerDN/group to already have a value # If the credentials in the Request match those from the credentials, it's OK - if request.OwnerDN == credDN and request.OwnerGroup == credGroup: + if request.Owner == credUsername and request.OwnerGroup == credGroup: return True # From here, something/someone is putting a request on behalf of someone else diff --git a/src/DIRAC/Resources/Catalog/FileCatalogClient.py b/src/DIRAC/Resources/Catalog/FileCatalogClient.py index d19444dfb83..6c66fea3278 100644 --- a/src/DIRAC/Resources/Catalog/FileCatalogClient.py +++ b/src/DIRAC/Resources/Catalog/FileCatalogClient.py @@ -8,8 +8,7 @@ from DIRAC import S_OK, S_ERROR from DIRAC.Core.Tornado.Client.ClientSelector import TransferClientSelector as TransferClient - -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSAttributeForGroup, getDNForUsername +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSAttributeForGroup, getDNForUsernameInGroup from DIRAC.Resources.Catalog.Utilities import checkCatalogArguments from DIRAC.Resources.Catalog.FileCatalogClientBase import FileCatalogClientBase @@ -223,11 +222,8 @@ def getDirectoryMetadata(self, lfns, timeout=120): for path in result['Value']['Successful']: owner = result['Value']['Successful'][path]['Owner'] group = result['Value']['Successful'][path]['OwnerGroup'] - res = getDNForUsername(owner) - if res['OK']: - result['Value']['Successful'][path]['OwnerDN'] = res['Value'][0] - else: - result['Value']['Successful'][path]['OwnerDN'] = '' + ownerDN = getDNForUsernameInGroup(owner, group).get('Value', '') + result['Value']['Successful'][path]['OwnerDN'] = ownerDN result['Value']['Successful'][path]['OwnerRole'] = getVOMSAttributeForGroup(group) return result diff --git a/src/DIRAC/Resources/Catalog/LcgFileCatalogClient.py b/src/DIRAC/Resources/Catalog/LcgFileCatalogClient.py index 66fe95784a8..25bfca66eae 100755 --- a/src/DIRAC/Resources/Catalog/LcgFileCatalogClient.py +++ b/src/DIRAC/Resources/Catalog/LcgFileCatalogClient.py @@ -20,7 +20,7 @@ from DIRAC.Core.Utilities.Time import fromEpoch from DIRAC.Core.Utilities.List import breakListIntoChunks from DIRAC.Core.Security.ProxyInfo import getProxyInfo, formatProxyInfoAsString -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername, getVOMSAttributeForGroup, \ +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNsForUsername, getVOMSAttributeForGroup, \ getVOForGroup, getVOOption from DIRAC.Resources.Catalog.FileCatalogClientBase import FileCatalogClientBase @@ -59,7 +59,7 @@ def getClientCertInfo(): proxyInfo['VOMS'] = getVOMSAttributeForGroup(proxyInfo['group']) errStr = "getClientCertInfo: Proxy information does not contain the VOMs information." gLogger.warn(errStr) - res = getDNForUsername(proxyInfo['username']) + res = getDNsForUsername(proxyInfo['username']) if not res['OK']: errStr = "getClientCertInfo: Error getting known proxies for user." gLogger.error(errStr, res['Message']) diff --git a/src/DIRAC/Resources/Computing/GlobusComputingElement.py b/src/DIRAC/Resources/Computing/GlobusComputingElement.py index e89bfafabd2..e72b1fe366b 100644 --- a/src/DIRAC/Resources/Computing/GlobusComputingElement.py +++ b/src/DIRAC/Resources/Computing/GlobusComputingElement.py @@ -23,7 +23,7 @@ from DIRAC import S_OK, S_ERROR from DIRAC.Core.Utilities.File import makeGuid from DIRAC.Core.Utilities.Grid import executeGridCommand -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getGroupOption +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSAttributeForGroup from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager from DIRAC.Resources.Computing.ComputingElement import ComputingElement from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB @@ -215,9 +215,13 @@ def getJobOutput(self, jobID, _localDir=None): if not result['OK'] or not result['Value']: return S_ERROR('Failed to determine owner for pilot ' + pilotRef) pilotDict = result['Value'][pilotRef] - owner = pilotDict['OwnerDN'] - group = getGroupOption(pilotDict['OwnerGroup'], 'VOMSRole', pilotDict['OwnerGroup']) - ret = gProxyManager.getPilotProxyFromVOMSGroup(owner, group) + owner = pilotDict['Owner'] + ownerDN = pilotDict['OwnerDN'] + group = pilotDict['OwnerGroup'] + if not getVOMSAttributeForGroup(group): + self.log.error("No voms attribute assigned to group %s when requested pilot proxy." % group) + return S_ERROR("Failed to get the pilot's owner proxy") + ret = gProxyManager.downloadVOMSProxy(ownerDN or owner, group) if not ret['OK']: self.log.error(ret['Message']) self.log.error('Could not get proxy:', 'User "%s", Group "%s"' % (owner, group)) diff --git a/src/DIRAC/Resources/IdProvider/IdProvider.py b/src/DIRAC/Resources/IdProvider/IdProvider.py index a3ea4c52383..b7695ae9666 100644 --- a/src/DIRAC/Resources/IdProvider/IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/IdProvider.py @@ -4,16 +4,70 @@ from __future__ import division from __future__ import print_function -from DIRAC import gLogger +from DIRAC import gLogger, S_OK, S_ERROR __RCSID__ = "$Id$" class IdProvider(object): - def __init__(self, parameters=None): + def __init__(self, *args, **kwargs): # parameters=None, sessionManager=None): + """ C'or + + :param dict parameters: parameters of the identity Provider + :param object sessionManager: session manager + """ self.log = gLogger.getSubLogger(self.__class__.__name__) - self.parameters = parameters + self.parameters = kwargs.get('parameters', {}) + self.sessionManager = kwargs.get('sessionManager') + self._initialization() + + def loadMetadata(self): + """ Load metadata to cache if needed + + :return: S_OK()/S_ERROR() + """ + return S_OK() + + def _initialization(self): + """ Initialization """ + pass def setParameters(self, parameters): + """ Set parameters + + :param dict parameters: parameters of the identity Provider + """ self.parameters = parameters + + def setManager(self, sessionManager): + """ Set session manager + + :param object sessionManager: session manager + """ + self.sessionManager = sessionManager + + def setLogger(self, logger): + """ Set logger + + :param object logger: logger + """ + self.log = logger + + def isSessionManagerAble(self): + """ Check if session manager is available + + :return: S_OK()/S_ERROR() + """ + if not self.sessionManager: + try: + from DIRAC.FrameworkSystem.Client.AuthManagerClient import gSessionManager + self.sessionManager = gSessionManager + except Exception as e: + return S_ERROR('Session manager is not available: %s' % e) + return S_OK() + + def getTokenWithAuth(self, *args, **kwargs): + """ Method to provide autherization flow on client side + """ + return S_ERROR('getTokenWithAuth not implemented.') diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index 1134654f218..67524a51adb 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -11,12 +11,15 @@ from __future__ import print_function from DIRAC import S_OK, S_ERROR, gLogger -from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getInfoAboutProviders +from DIRAC.Core.Utilities import ObjectLoader, ThreadSafe +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderInfo __RCSID__ = "$Id$" +gCacheMetadata = ThreadSafe.Synchronizer() + class IdProviderFactory(object): ############################################################################# @@ -24,23 +27,56 @@ def __init__(self): """ Standard constructor """ self.log = gLogger.getSubLogger('IdProviderFactory') + self.cacheMetadata = DictCache() + + @gCacheMetadata + def getMetadata(self, idP): + return self.cacheMetadata.get(idP) + + @gCacheMetadata + def addMetadata(self, idP, data, time=24 * 3600): + if data: + self.cacheMetadata.add(idP, time, data) + + def getIdProviderForToken(self, token): + """ This method returns a IdProvider instance corresponding to the supplied + issuer in a token. + + :param str token: token jwt + + :return: S_OK(IdProvider)/S_ERROR() + """ + # Read token without verification to get issuer + issuer = jwt.decode(accessToken, options=dict(verify_signature=False))['iss'].strip('/') + result = getIdProviderForIssuer(issuer) + if not result['OK'] + return result + return self.getIdProvider(result['Value']) + ############################################################################# - def getIdProvider(self, idProvider): - """ This method returns a IdProvider instance corresponding to the supplied name. + def getIdProvider(self, idProvider, sessionManager=None): + """ This method returns a IdProvider instance corresponding to the supplied + name. :param str idProvider: the name of the Identity Provider + :param object sessionManager: session manager :return: S_OK(IdProvider)/S_ERROR() """ - result = getInfoAboutProviders(of='Id', providerName=idProvider, option="all", section="all") - if not result['OK']: - return result - pDict = result['Value'] - pDict['ProviderName'] = idProvider + if isinstance(idProvider, dict): + pDict = idProvider + else: + result = getProviderInfo(idProvider) + if not result['OK']: + self.log.error('Failed to read configuration', '%s: %s' % (idProvider, result['Message'])) + return result + pDict = result['Value'] + pDict['ProviderName'] = idProvider + pDict['sessionManager'] = sessionManager pType = pDict['ProviderType'] - self.log.verbose('Creating IdProvider', 'of %s type with the name %s' % (pType, idProvider)) + self.log.verbose('Creating IdProvider of %s type with the name %s' % (pType, idProvider)) subClassName = "%sIdProvider" % (pType) result = ObjectLoader().loadObject('Resources.IdProvider.%s' % subClassName) @@ -50,12 +86,22 @@ def getIdProvider(self, idProvider): pClass = result['Value'] try: - provider = pClass() - provider.setParameters(pDict) + meta = self.getMetadata(idProvider) + if meta: + pDict.update(meta) + provider = pClass(**pDict) + if not meta and hasattr(provider, 'metadata'): + # result = provider.loadMetadata() + # if not result['OK']: + # return result + # self.addMetadata(idProvider, result['Value']) + self.addMetadata(idProvider, provider.metadata) + # provider.setParameters(pDict) + # provider.setManager(sessionManager) except Exception as x: msg = 'IdProviderFactory could not instantiate %s object: %s' % (subClassName, str(x)) self.log.exception() - self.log.error(msg) + self.log.warn(msg) return S_ERROR(msg) return S_OK(provider) diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py new file mode 100644 index 00000000000..7dccbcf224f --- /dev/null +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -0,0 +1,466 @@ +""" IdProvider based on OAuth2 protocol +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +# TODO: only for python 3 +# import jwt +# from jwt import PyJWKClient + +from authlib.jose import JsonWebKey, jwt + +import re +import time +import pprint +import requests +from requests import exceptions +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope +from authlib.common.urls import url_decode +from authlib.common.security import generate_token +from authlib.integrations.requests_client import OAuth2Session +from authlib.oidc.discovery.well_known import get_well_known_url +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from authlib.oauth2.rfc6749.parameters import prepare_token_request +from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope + +from DIRAC import S_OK, S_ERROR, gLogger +from DIRAC.Resources.IdProvider.IdProvider import IdProvider +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderByAlias +from DIRAC.FrameworkSystem.private.authorization.utils.Sessions import Session +from DIRAC.FrameworkSystem.private.authorization.utils.Requests import createOAuth2Request +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthClients +from DIRAC.FrameworkSystem.private.authorization.utils.ProfileParser import * +from authlib.oauth2.rfc8628 import DEVICE_CODE_GRANT_TYPE +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSRoleGroupMapping, getVOForGroup, getGroupOption + +__RCSID__ = "$Id$" + +DEFAULT_HEADERS = { + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8' +} + + +def checkResponse(func): + def function_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except exceptions.Timeout: + return S_ERROR('Time out') + except exceptions.RequestException as ex: + return S_ERROR(str(ex)) + return function_wrapper + + +class OAuth2IdProvider(IdProvider, OAuth2Session): + def __init__(self, name=None, token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, token=None, token_placement='header', + update_token=None, **parameters): + """ OIDCClient constructor + """ + if 'ProviderName' not in parameters: + parameters['ProviderName'] = name + IdProvider.__init__(self, **parameters) + OAuth2Session.__init__(self, token_endpoint_auth_method=token_endpoint_auth_method, + revocation_endpoint_auth_method=revocation_endpoint_auth_method, + scope=scope, token=token, token_placement=token_placement, + update_token=update_token, **parameters) + # Convert scope to list + self.parser = ProfileParser(**parameters) + scope = scope or '' + self.scope = [s.strip() for s in scope.strip().replace('+', ' ').split(',' if ',' in scope else ' ')] + self.parameters = parameters + self.exceptions = exceptions + self.name = parameters['ProviderName'] + + # Add hooks to raise HTTP errors + self.hooks['response'] = lambda r, *args, **kwargs: r.raise_for_status() + self.update_token = update_token or self._updateToken + self.store_token = self._storeToken + self.metadata_class = AuthorizationServerMetadata + self.server_metadata_url = parameters.get('server_metadata_url', get_well_known_url(self.metadata['issuer'], True)) + try: + self.metadata_class(self.metadata).validate() + except ValueError: + metadata = self.metadata_class(self.fetch_metadata()) + self.metadata.update(dict((k, v) for k, v in metadata.items() if k not in self.metadata)) + # for k, v in metadata.items(): + # if k not in self.metadata: + # self.metadata[k] = v + self.metadata_class(self.metadata).validate() + + # Set JWKs + self.jwks = parameters.get('jwks', self.fetch_metadata(self.metadata['jwks_uri'])) + if not self.jwks: + raise Exception('Cannot load JWKs for %s' % self.name) + + self.log.debug('"%s" OAuth2 IdP initialization done:' % self.name, + '\nclient_id: %s\nclient_secret: %s\nmetadata:\n%s' % (self.client_id, + self.client_secret, + pprint.pformat(self.metadata))) + + def _storeToken(self, token): + if self.sessionManager: + return self.sessionManager.storeToken(dict(self.token)) + return S_OK(None) + + def _updateToken(self, token, refresh_token): + if self.sessionManager: + # Here "token" is `OAuth2Token` type + self.sessionManager.updateToken(dict(token), refresh_token) + + def request(self, *args, **kwargs): + self.token_endpoint_auth_methods_supported = self.metadata.get('token_endpoint_auth_methods_supported') + if self.token_endpoint_auth_methods_supported: + if self.token_endpoint_auth_method not in self.token_endpoint_auth_methods_supported: + self.token_endpoint_auth_method = self.token_endpoint_auth_methods_supported[0] + return OAuth2Session.request(self, verify=False, *args, **kwargs) + + def verifyToken(self, token): + """ Token verification + + :param token: token + """ + # Research public keys for issuer + # # TODO: only for python 3 + # if not self.jwks: + # self.jwks = PyJWKClient(self.metadata['jwks_uri']) + # signing_key = self.jwks.get_signing_key_from_jwt(token) + + try: + return self._verify_jwt(token) + except Exception: + self.jwks = self.fetch_metadata(self.metadata['jwks_uri']) + return self._verify_jwt(token) + + def _verify_jwt(self, token): + """ + """ + return jwt.decode(token, JsonWebKey.import_key_set(self.jwks)) + + def fetch_metadata(self, url=None): + """ + """ + return self.get(url or self.server_metadata_url, withhold_token=True).json() + + def researchGroup(self, payload, token): + """ Research group + """ + return {} + + def getIDsMetadata(self, ids=None): + """ Metadata for IDs + """ + metadata = {} + result = self.isSessionManagerAble() + if not result['OK']: + return result + result = self.sessionManager.getIdPTokens(self.name, ids) + if not result['OK']: + return result + for token in result['Value']: + if token['user_id'] in metadata: + continue + self.token = token + result = self.__getUserInfo() + if result['OK']: + result = self.parser(result['Value']) + # result = self._parseUserProfile(result['Value']) + if result['OK']: + _, _, profile = result['Value'] + metadata[token['user_id']] = profile[self.name][token['user_id']] + + return S_OK(metadata) + + def submitDeviceCodeAuthorizationFlow(self, group=None): + """ Submit authorization flow + + :return: S_OK(dict)/S_ERROR() -- dictionary with device code flow response + """ + if group: + idPRole = getGroupOption(group, 'IdPRole') + if not idPRole: + return S_ERROR('Cannot find role for %s' % group) + group_scopes = [self.PARAM_SCOPE + idPRole] + + try: + r = requests.post(self.metadata['device_authorization_endpoint'], data=dict( + client_id=self.client_id, scope=list_to_scope(self.scope + group_scopes) + )) + r.raise_for_status() + deviceResponse = r.json() + if 'error' in deviceResponse: + return S_ERROR('%s: %s' % (deviceResponse['error'], deviceResponse.get('description', ''))) + + # Check if all main keys are present here + for k in ['user_code', 'device_code', 'verification_uri']: + if not deviceResponse.get(k): + return S_ERROR('Mandatory %s key is absent in authentication response.' % k) + + return S_OK(deviceResponse) + except requests.exceptions.Timeout: + return S_ERROR('Authentication server is not answer, timeout.') + except requests.exceptions.RequestException as ex: + return S_ERROR(r.content or repr(ex)) + except Exception as ex: + return S_ERROR('Cannot read authentication response: %s' % repr(ex)) + + def waitFinalStatusOfDeviceCodeAuthorizationFlow(self, deviceCode, interval=5, timeout=300): + """ Submit waiting loop process, that will monitor current authorization session status + + :param str deviceCode: received device code + :param int interval: waiting interval + :param int timeout: max time of waiting + + :return: S_OK(dict)/S_ERROR() - dictionary contain access/refresh token and some metadata + """ + __start = time.time() + + while True: + time.sleep(int(interval)) + if time.time() - __start > timeout: + return S_ERROR('Time out.') + r = requests.post(self.metadata['token_endpoint'], data=dict(client_id=self.client_id, + grant_type=DEVICE_CODE_GRANT_TYPE, + device_code=deviceCode)) + token = r.json() + if not token: + return S_ERROR('Resived token is empty!') + if 'error' not in token: + # os.environ['DIRAC_TOKEN'] = r.text + return S_OK(token) + if token['error'] != 'authorization_pending': + return S_ERROR(token['error'] + ' : ' + token.get('description', '')) + + def submitNewSession(self, session=None): + """ Submit new authorization session + + :param str session: session number + + :return: S_OK(str)/S_ERROR() + """ + url, state = self.create_authorization_url(self.metadata['authorization_endpoint'], state=self.generateState(session)) + return S_OK((url, state, {})) + + @checkResponse + def parseAuthResponse(self, response, session=None): + """ Make user info dict: + + :param dict response: response on request to get user profile + :param object session: session + + :return: S_OK(dict)/S_ERROR() + """ + response = createOAuth2Request(response) + + self.log.debug('Try to parse authentication response:', pprint.pformat(response.data)) + + if not session: + session = {} # Session(response.args['state']) + + self.log.debug('Current session is:\n', pprint.pformat(dict(session))) + # self.log.debug('Current metadata is:\n', pprint.pformat(self.metadata)) + + self.fetch_access_token(authorization_response=response.uri, + code_verifier=session.get('code_verifier')) + + # Get user info + result = self.__getUserInfo() + if not result['OK']: + return result + credDict = parseBasic(result['Value']) + credDict.update(parseEduperson(result['Value'])) + cerdDict = userDiscover(credDict) + result = self.parser(result['Value']) + if not result['OK']: + return result + username, userID, userProfile = result['Value'] + userProfile['credDict'] = credDict + + self.log.debug('Got response dictionary:\n', pprint.pformat(userProfile)) + + # Store token + self.token['client_id'] = self.client_id + self.token['provider'] = self.name + self.token['user_id'] = userID + self.log.debug('Store token to the database:\n', pprint.pformat(dict(self.token))) + + result = self.store_token(self.token) + if not result['OK']: + return result + + return S_OK((username, userID, userProfile)) + + def _fillUserProfile(self, useToken=None): + result = self.__getUserInfo(useToken) + return self.parser(result['Value']) if result['OK'] else result + # return self._parseUserProfile(result['Value']) if result['OK'] else result + + def __getUserInfo(self, useToken=None): + self.log.debug('Sent request to userinfo endpoint..') + r = None + try: + r = self.request('GET', self.metadata['userinfo_endpoint'], + withhold_token=useToken) + r.raise_for_status() + return S_OK(r.json()) + except (self.exceptions.RequestException, ValueError) as e: + return S_ERROR("%s: %s" % (repr(e), r.text if r else '')) + + # def _parseUserProfile(self, userProfile): + # """ Parse user profile + + # :param dict userProfile: user profile in OAuht2 format + + # :return: S_OK()/S_ERROR() + # """ + # # Generate username + # gname = userProfile.get('given_name') + # fname = userProfile.get('family_name') + # pname = userProfile.get('preferred_username') + # name = userProfile.get('name') and userProfile['name'].split(' ') + # username = pname or gname and fname and gname[0] + fname + # username = username or name and len(name) > 1 and name[0][0] + name[1] or '' + # username = re.sub('[^A-Za-z0-9]+', '', username.lower())[:13] + # self.log.debug('Parse user name:', username) + + # profile = {} + + # # Set provider + # profile['Provider'] = self.name + + # # Collect user info + # profile['ID'] = userProfile.get('sub') + # if not profile['ID']: + # return S_ERROR('No ID of user found.') + # profile['Email'] = userProfile.get('email') + # profile['FullName'] = gname and fname and ' '.join([gname, fname]) or name and ' '.join(name) or '' + # self.log.debug('Parse user profile:\n', profile) + + # # Default DIRAC groups, configured for IdP + # profile['Groups'] = self.parameters.get('DiracGroups') + # if profile['Groups'] and not isinstance(profile['Groups'], list): + # profile['Groups'] = profile['Groups'].replace(' ', '').split(',') + # self.log.debug('Default groups:', ', '.join(profile['Groups'] or [])) + # self.log.debug('Response Information:', pprint.pformat(userProfile)) + + # self.log.debug('Read regex syntax to get DNs describetion dictionary..') + # userDNs = {} + # dictItemRegex, listItemRegex = {}, None + # try: + # dnClaim = self.parameters['Syntax']['DNs']['claim'] + # for k, v in self.parameters['Syntax']['DNs'].items(): + # if isinstance(v, dict) and v.get('item'): + # dictItemRegex[k] = v['item'] + # elif k == 'item': + # listItemRegex = v + # except Exception as e: + # if not profile['Groups']: + # self.log.warn('No DNs described in Syntax/DNs IdP configuration section were found in the response.', + # "And no DiracGroups were found fo IdP.") + # return S_OK((username, profile)) + + # self.log.debug('Dict type items pattern:\n', pprint.pformat(dictItemRegex)) + # self.log.debug('List type items pattern:\n', pprint.pformat(listItemRegex)) + + # if not userProfile.get(dnClaim) and not profile['Groups']: + # self.log.warn('No "DiracGroups", no claim "%s" that describe DNs found.' % dnClaim) + # else: + # self.log.debug('Found "%s" claim that describe user DNs' % dnClaim) + # if not isinstance(userProfile[dnClaim], list): + # self.log.debug('Convert "%s" claim to list..' % dnClaim) + # userProfile[dnClaim] = userProfile[dnClaim].split(',') + + # for item in userProfile[dnClaim]: + # dnInfo = {} + # self.log.debug('Read "%s" item:' % dnClaim, item) + # if isinstance(item, dict): + # for subClaim, reg in dictItemRegex.items(): + # result = re.compile(reg).match(item[subClaim]) + # if result: + # for k, v in result.groupdict().items(): + # dnInfo[k] = v + # elif listItemRegex: + # result = re.compile(listItemRegex).match(item) + # if result: + # for k, v in result.groupdict().items(): + # dnInfo[k] = v + + # self.log.debug('Read parsed DN information:\n', dnInfo) + # if dnInfo.get('DN'): + # if not dnInfo['DN'].startswith('/'): + # self.log.debug('Convert %s to view with slashes.' % dnInfo['DN']) + # items = dnInfo['DN'].split(',') + # items.reverse() + # dnInfo['DN'] = '/' + '/'.join(items) + # if dnInfo.get('ProxyProvider'): + # self.log.debug('Found %s provider in item,' % dnInfo['ProxyProvider']) + # result = getProviderByAlias(dnInfo['ProxyProvider'], instance='Proxy') + # dnInfo['ProxyProvider'] = result['Value'] if result['OK'] else 'Certificate' + # self.log.debug('In the DIRAC configuration it corresponds to the ', dnInfo['ProxyProvider']) + # userDNs[dnInfo['DN']] = dnInfo + # if userDNs: + # profile['DNs'] = userDNs + + # self.log.verbose('We were able to compile the following profile for %s:\n' % username, profile) + # return S_OK((username, profile)) + + def exchange_token(self, url, subject_token=None, subject_token_type=None, body='', + refresh_token=None, access_token=None, auth=None, headers=None, **kwargs): + """ Fetch a new access token using a refresh token. + + :param url: Refresh Token endpoint, must be HTTPS. + :param str subject_token: subject_token + :param str subject_token_type: token type https://tools.ietf.org/html/rfc8693#section-3 + :param body: Optional application/x-www-form-urlencoded body to add the + include in the token request. Prefer kwargs over body. + :param str refresh_token: refresh token + :param str access_token: access token + :param auth: An auth tuple or method as accepted by requests. + :param headers: Dict to default request headers with. + :return: A :class:`OAuth2Token` object (a dict too). + """ + session_kwargs = self._extract_session_request_params(kwargs) + refresh_token = refresh_token or self.token.get('refresh_token') + access_token = access_token or self.token.get('access_token') + subject_token = subject_token or refresh_token + subject_token_type = subject_token_type or 'urn:ietf:params:oauth:token-type:refresh_token' + if 'scope' not in kwargs and self.scope: + kwargs['scope'] = self.scope + body = prepare_token_request('urn:ietf:params:oauth:grant-type:token-exchange', body, + subject_token=subject_token, subject_token_type=subject_token_type, **kwargs) + + if headers is None: + headers = DEFAULT_HEADERS + + for hook in self.compliance_hook.get('exchange_token_request', []): + url, headers, body = hook(url, headers, body) + + if auth is None: + auth = self.client_auth(self.token_endpoint_auth_method) + + return self._exchange_token(url, refresh_token=refresh_token, body=body, headers=headers, + auth=auth, **session_kwargs) + + def _exchange_token(self, url, body='', refresh_token=None, headers=None, auth=None, **kwargs): + resp = self.session.post(url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs) + + for hook in self.compliance_hook.get('exchange_token_response', []): + resp = hook(resp) + + token = self.parse_response_token(resp.json()) + if 'refresh_token' not in token: + self.token['refresh_token'] = refresh_token + + if callable(self.update_token): + self.update_token(self.token, refresh_token=refresh_token) + + return self.token + + def generateState(self, session=None): + return session or generate_token(10) \ No newline at end of file diff --git a/src/DIRAC/Resources/ProxyProvider/DIRACCAProxyProvider.py b/src/DIRAC/Resources/ProxyProvider/DIRACCAProxyProvider.py index 1409ce7a2ac..e3fba4e7221 100644 --- a/src/DIRAC/Resources/ProxyProvider/DIRACCAProxyProvider.py +++ b/src/DIRAC/Resources/ProxyProvider/DIRACCAProxyProvider.py @@ -97,11 +97,11 @@ def setParameters(self, parameters): if 'Algoritm' in parameters: self.algoritm = parameters['Algoritm'] if 'Match' in parameters: - self.match = [self.fields2nid[f] for f in parameters['Match']] + self.match = [self.fields2nid[f] for f in parameters['Match']] if parameters['Match'][0] else [] if 'Supplied' in parameters: - self.supplied = [self.fields2nid[f] for f in parameters['Supplied']] + self.supplied = [self.fields2nid[f] for f in parameters['Supplied']] if parameters['Supplied'][0] else [] if 'Optional' in parameters: - self.optional = [self.fields2nid[f] for f in parameters['Optional']] + self.optional = [self.fields2nid[f] for f in parameters['Optional']] if parameters['Optional'][0] else [] allFields = self.optional + self.supplied + self.match if 'DNOrder' in parameters: self.dnList = [] @@ -200,7 +200,7 @@ def getProxy(self, userDN): :param str userDN: user DN - :return: S_OK(str)/S_ERROR() -- contain a proxy string + :return: S_OK(object)/S_ERROR() -- contain a X509Chain() object """ self.__X509Name = X509.X509_Name() result = self.checkStatus(userDN) @@ -215,8 +215,15 @@ def getProxy(self, userDN): result = chain.loadKeyFromString(keyStr) if result['OK']: result = chain.generateProxyToString(365 * 24 * 3600, rfc=True) + if result['OK']: + chain = X509Chain() + result = chain.loadProxyFromString(result['Value']) + if result['OK']: - return result + # Store proxy in proxy manager + result = self.proxyManager._storeProxy(userDN, chain) + + return S_OK(chain) if result['OK'] else result def generateDN(self, **kwargs): """ Get DN of the user certificate that will be created @@ -370,7 +377,7 @@ def __createCertM2Crypto(self): userCert.set_version(2) userCert.set_subject(self.__X509Name) userCert.set_serial_number(int(random.random() * 10 ** 10)) - # Add extentionals + # Add extentials userCert.add_ext(X509.new_extension('basicConstraints', 'CA:' + str(False).upper())) userCert.add_ext(X509.new_extension('extendedKeyUsage', 'clientAuth', critical=1)) # Set livetime diff --git a/src/DIRAC/Resources/ProxyProvider/OAuth2ProxyProvider.py b/src/DIRAC/Resources/ProxyProvider/OAuth2ProxyProvider.py new file mode 100644 index 00000000000..45dfcae5344 --- /dev/null +++ b/src/DIRAC/Resources/ProxyProvider/OAuth2ProxyProvider.py @@ -0,0 +1,170 @@ +""" ProxyProvider implementation for the proxy generation using OIDC authorization flow +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pprint +import datetime +import requests +from requests import exceptions + +from DIRAC import S_OK, S_ERROR +from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error +from DIRAC.ConfigurationSystem.Client.Helpers import Registry +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProvidersForInstance, getProviderInfo +from DIRAC.ConfigurationSystem.Client.Utilities import getAuthAPI +from DIRAC.Resources.ProxyProvider.ProxyProvider import ProxyProvider +from DIRAC.Resources.IdProvider.IdProviderFactory import IdProviderFactory + +# from DIRAC.FrameworkSystem.Utilities.OAuth2 import OAuth2 +from DIRAC.FrameworkSystem.Client.AuthManagerClient import gSessionManager +from DIRAC.FrameworkSystem.Client.AuthManagerData import gAuthManagerData + +__RCSID__ = "$Id$" + + +class OAuth2ProxyProvider(ProxyProvider): + + def __init__(self, parameters=None): + # TODO: need do self.idpObj -- idP in contex(access tokens) we do request + super(OAuth2ProxyProvider, self).__init__(parameters) + self.__idps = IdProviderFactory() + + def setParameters(self, parameters): + self.parameters = parameters + self.proxy_endpoint = self.parameters['GetProxyEndpoint'] + self.idProviders = self.parameters['IdProvider'] or [] # TODO: Supported ID Providers + if not isinstance(self.parameters['IdProvider'], list): + self.idProviders = [self.parameters['IdProvider']] + if not self.idProviders: + result = getProvidersForInstance('Id', providerType='OAuth2') # TODO: Its not need + if not result['OK']: + return result + self.idProviders = result['Value'] + + def checkStatus(self, userDN): + """ Read ready to work status of proxy provider + + :param str userDN: user DN + + :return: S_OK(dict)/S_ERROR() -- dictionary contain fields: + - 'Status' with ready to work status[ready, needToAuth] + - 'AccessTokens' with list of access token + """ + result = gAuthManagerData.getIDsForDN(userDN, provider=self.parameters['ProviderName']) + if not result['OK']: + self.log.error(result['Message']) + return result + uid = result['Value'][0] + # TODO: authManagerService must get token throgh iDP with requested lifetime + result = gSessionManager.getTokenByUserIDAndProvider(uid, self.idProviders[0]) + if not result['OK']: + self.log.error(result['Message']) + return result + token = result['Value'] + if not token: + idP = self.idProviders[0] + return S_OK({'Status': 'needToAuth', 'Comment': 'Need to auth with %s identity provider' % idP, + 'Action': ['auth', [idP, 'inThread', '%s/auth/%s' % (getAuthAPI().strip('/'), idP)]]}) + + # Proxy uploaded in DB? + result = self.proxyManager._isProxyExist(userDN, 12 * 3600) + if not result['OK']: + self.log.error(result['Message']) + return result + if not result['Value']: + # Proxy not uploaded in DB, lets generate and upload + result = self.getProxy(userDN, token=token) + if not result['OK']: + self.log.error(result['Message']) + return result + + return S_OK({'Status': 'ready'}) + + def getProxy(self, userDN, token=None): + """ Generate user proxy with OIDC flow authentication + + :param str userDN: user DN + :param list sessions: sessions + + :return: S_OK/S_ERROR, Value is a proxy string + """ + if not token: + result = gAuthManagerData.getIDsForDN(userDN, provider=self.parameters['ProviderName']) + if not result['OK']: + self.log.error(result['Message']) + return result + uid = result['Value'][0] + # TODO: authManagerService must get token throgh iDP with requested lifetime + result = gSessionManager.getTokenByUserIDAndProvider(uid, self.idProviders[0]) + if not result['OK']: + self.log.error(result['Message']) + return result + token = result['Value'] + if not token: + return S_ERROR('Token not found for proxy request.') + + self.log.verbose('For proxy request use token:\n', pprint.pformat(token)) + + # Get proxy request + result = self.__getProxyRequest(token) + # if not result['OK']: + # # expire token to refresh + # token['expires_at'] = 1 + # result = self.__getProxyRequest(token) + if not result['OK']: + return result + if not result['Value']: + return S_ERROR('Returned proxy is empty.') + + self.log.info('Proxy is taken') + proxyStr = result['Value'].encode("utf-8") + + # Get DN + chain = X509Chain() + result = chain.loadProxyFromString(proxyStr) + if not result['OK']: + return result + result = chain.getCredentials() + if not result['OK']: + return result + DN = result['Value']['identity'] + + # Check + if DN != userDN: + return S_ERROR('Received proxy DN "%s" not match with requested DN "%s"' % (DN, userDN)) + + # Store proxy in proxy manager + result = self.proxyManager._storeProxy(DN, chain) + + return S_OK(chain) if result['OK'] else result # {'proxy': proxyStr, 'DN': DN}) + + def __getProxyRequest(self, token): + """ Get user proxy from proxy provider + + :param str session: access token + + :return: S_OK(basestring)/S_ERROR() + """ + result = self.__idps.getIdProvider(self.idProviders[0]) + if not result['OK']: + return result + provObj = result['Value'] + + provObj.token = token + url = '%s?access_type=offline' % self.proxy_endpoint + url += '&proxylifetime=%s' % self.parameters.get('MaxProxyLifetime', 3600 * 24) + url += '&client_id=%s&client_secret=%s' % (provObj.client_id, provObj.client_secret) + + # Get proxy request + self.log.verbose('Send proxy request to %s, with token:\n' % self.proxy_endpoint, pprint.pformat(provObj.token)) + self.log.debug("GET ", url) + + r = None + try: + r = provObj.request('GET', url, withhold_token=True) + r.raise_for_status() + return S_OK(r.text) + except provObj.exceptions.RequestException as e: + return S_ERROR("%s: %s" % (repr(e), r.text if r else '')) diff --git a/src/DIRAC/Resources/ProxyProvider/PUSPProxyProvider.py b/src/DIRAC/Resources/ProxyProvider/PUSPProxyProvider.py index 0b6ea25b50d..26df50abae4 100644 --- a/src/DIRAC/Resources/ProxyProvider/PUSPProxyProvider.py +++ b/src/DIRAC/Resources/ProxyProvider/PUSPProxyProvider.py @@ -72,4 +72,7 @@ def getProxy(self, userDN): if credDict['identity'] != userDN: return S_ERROR('Requested DN does not match the obtained one in the PUSP proxy') - return chain.generateProxyToString(lifeTime=credDict['secondsLeft']) + # Store proxy in proxy manager + result = self.proxyManager._storeProxy(userDN, chain) + + return S_OK(chain) if result['OK'] else result diff --git a/src/DIRAC/Resources/ProxyProvider/ProxyProvider.py b/src/DIRAC/Resources/ProxyProvider/ProxyProvider.py index 00fec638241..b7e19359270 100644 --- a/src/DIRAC/Resources/ProxyProvider/ProxyProvider.py +++ b/src/DIRAC/Resources/ProxyProvider/ProxyProvider.py @@ -4,22 +4,48 @@ from __future__ import division from __future__ import print_function from DIRAC import S_OK, S_ERROR +from DIRAC import S_OK, S_ERROR, gLogger __RCSID__ = "$Id$" class ProxyProvider(object): - def __init__(self, parameters=None): + def __init__(self, parameters=None, proxyManager=None): + """ C'or + :param dict parameters: parameters of the Proxy Provider + :param object proxyManager: proxy manager + """ + self.log = gLogger.getSubLogger(self.__class__.__name__) self.parameters = parameters - self.name = None - if parameters: - self.name = parameters.get('ProviderName') def setParameters(self, parameters): + """ Set parameters + + :param dict parameters: parameters of the proxy Provider + """ self.parameters = parameters - self.name = parameters.get('ProviderName') + + def setManager(self, proxyManager): + """ Set proxy manager + + :param object proxyManager: proxy manager + """ + self.proxyManager = proxyManager + + def isProxyManagerAble(self): + """ Check if proxy manager able + + :return: S_OK()/S_ERROR() + """ + if not self.proxyManager: + try: + from DIRAC.FrameworkSystem.Client.ProxyManagerClient import ProxyManagerClient + self.proxyManager = ProxyManagerClient() + except Exception as e: + return S_ERROR('Proxy manager not able: %s' % e) + return S_OK() def checkStatus(self, userDN): """ Read ready to work status of proxy provider @@ -37,4 +63,4 @@ def generateDN(self, **kwargs): :return: S_OK(str)/S_ERROR() -- contain DN """ - return S_ERROR("Not implemented in %s", self.name) + return S_ERROR("Not implemented in %s", self.parameters.get('ProviderName')) diff --git a/src/DIRAC/Resources/ProxyProvider/ProxyProviderFactory.py b/src/DIRAC/Resources/ProxyProvider/ProxyProviderFactory.py index 7f831ea1c32..b17286036d1 100644 --- a/src/DIRAC/Resources/ProxyProvider/ProxyProviderFactory.py +++ b/src/DIRAC/Resources/ProxyProvider/ProxyProviderFactory.py @@ -9,9 +9,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + from DIRAC import S_OK, S_ERROR, gLogger -from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getInfoAboutProviders +from DIRAC.Core.Utilities import ObjectLoader +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getProviderInfo __RCSID__ = "$Id$" @@ -25,17 +26,18 @@ def __init__(self): self.log = gLogger.getSubLogger(__name__) ############################################################################# - def getProxyProvider(self, proxyProvider): + def getProxyProvider(self, proxyProvider, proxyManager=None): """ This method returns a ProxyProvider instance corresponding to the supplied name. :param str proxyProvider: the name of the Proxy Provider + :param object proxyManager: proxy manager :return: S_OK(ProxyProvider)/S_ERROR() """ if not proxyProvider: return S_ERROR('Provider name not set.') - result = getInfoAboutProviders(of='Proxy', providerName=proxyProvider, option='all', section='all') + result = getProviderInfo(proxyProvider) if not result['OK']: return result ppDict = result['Value'] @@ -55,6 +57,7 @@ def getProxyProvider(self, proxyProvider): try: pProvider = ppClass() pProvider.setParameters(ppDict) + pProvider.setManager(proxyManager) except Exception as x: msg = 'ProxyProviderFactory could not instantiate %s object: %s' % (subClassName, str(x)) self.log.exception() diff --git a/src/DIRAC/Resources/ProxyProvider/test/Test_DIRACCAProxyProvider.py b/src/DIRAC/Resources/ProxyProvider/test/Test_DIRACCAProxyProvider.py index b5ff1c24ebf..be60e8301be 100644 --- a/src/DIRAC/Resources/ProxyProvider/test/Test_DIRACCAProxyProvider.py +++ b/src/DIRAC/Resources/ProxyProvider/test/Test_DIRACCAProxyProvider.py @@ -13,15 +13,14 @@ try: import commands except ImportError: - # Python 3's subprocess module contains a compatibility layer + # PY3 subprocess module contains a compatibility layer import subprocess as commands + import unittest import tempfile - import pytest -import DIRAC -from DIRAC import gLogger +from DIRAC import gLogger, S_OK, rootPath from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.Resources.ProxyProvider.DIRACCAProxyProvider import DIRACCAProxyProvider @@ -46,6 +45,14 @@ 'ProviderName': 'DIRAC_CA_CFG'} +class proxyManager(object): + """ Fake proxyManager + """ + def _storeProxy(self, userDN, chain): + """ Fake store method + """ + return S_OK() + class DIRACCAProviderTestCase(unittest.TestCase): @classmethod @@ -90,23 +97,23 @@ def tearDownClass(cls): class testDIRACCAProvider(DIRACCAProviderTestCase): + """ Base class for the testDIRACCAProvider test cases + """ @pytest.mark.slow def test_getProxy(self): """ Test 'getProxy' - try to get proxies for different users and check it """ - def check(proxyStr, proxyProvider, name): + def check(chain, proxyProvider, name): """ Check proxy - :param str proxyStr: proxy as string + :param object chain: proxy as string :param str proxyProvider: proxy provider name :param str name: proxy name """ proxyFile = os.path.join(testCAPath, proxyProvider + name.replace(' ', '') + '.pem') gLogger.info('Check proxy..') - chain = X509Chain() - result = chain.loadProxyFromString(proxyStr) - self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) + for result in [chain.getRemainingSecs(), chain.getIssuerCert(), chain.getPKeyObj(), @@ -125,8 +132,8 @@ def check(proxyStr, proxyProvider, name): (diracCAConf, 'read configuration file')]: gLogger.info('\n* Try proxy provider that %s..' % log) ca = DIRACCAProxyProvider() - result = ca.setParameters(proxyProvider) - self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) + ca.setParameters(proxyProvider) + ca.setManager(proxyManager()) gLogger.info('* Get proxy using FullName and Email of user..') for name, email, res in [('MrUser', 'good@mail.com', True), diff --git a/src/DIRAC/Resources/ProxyProvider/test/Test_ProxyProviderFactory.py b/src/DIRAC/Resources/ProxyProvider/test/Test_ProxyProviderFactory.py index 8166df1c0e2..f6a9d9a3e42 100644 --- a/src/DIRAC/Resources/ProxyProvider/test/Test_ProxyProviderFactory.py +++ b/src/DIRAC/Resources/ProxyProvider/test/Test_ProxyProviderFactory.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + __RCSID__ = "$Id$" import os @@ -15,25 +16,33 @@ certsPath = os.path.join(os.path.dirname(DIRAC.__file__), 'Core/Security/test/certs') -def sf_getInfoAboutProviders(of, providerName, option, section): - if of == 'Proxy' and option == 'all' and section == 'all': - if providerName == 'MY_DIRACCA': - return S_OK({'ProviderType': 'DIRACCA', - 'CertFile': os.path.join(certsPath, 'ca/ca.cert.pem'), - 'KeyFile': os.path.join(certsPath, 'ca/ca.key.pem'), - 'Supplied': ['O', 'OU', 'CN'], - 'Optional': ['emailAddress'], - 'DNOrder': ['O', 'OU', 'CN', 'emailAddress'], - 'OU': 'CA', - 'C': 'DN', - 'O': 'DIRACCA'}) - elif providerName == 'MY_PUSP': - return S_OK({'ProviderType': 'PUSP', 'ServiceURL': 'https://somedomain'}) +class proxyManager(object): + """ Fake proxyManager + """ + def _storeProxy(self, userDN, chain): + """ Fake store method + """ + return S_OK() + + +def sf_getProviderInfo(providerName): + if providerName == 'MY_DIRACCA': + return S_OK({'ProviderType': 'DIRACCA', + 'CertFile': os.path.join(certsPath, 'ca/ca.cert.pem'), + 'KeyFile': os.path.join(certsPath, 'ca/ca.key.pem'), + 'Supplied': ['O', 'OU', 'CN'], + 'Optional': ['emailAddress'], + 'DNOrder': ['O', 'OU', 'CN', 'emailAddress'], + 'OU': 'CA', + 'C': 'DN', + 'O': 'DIRACCA'}) + elif providerName == 'MY_PUSP': + return S_OK({'ProviderType': 'PUSP', 'ServiceURL': 'https://somedomain'}) return S_ERROR('No proxy provider found') -@mock.patch('DIRAC.Resources.ProxyProvider.ProxyProviderFactory.getInfoAboutProviders', - new=sf_getInfoAboutProviders) +@mock.patch('DIRAC.Resources.ProxyProvider.ProxyProviderFactory.getProviderInfo', + new=sf_getProviderInfo) class ProxyProviderFactoryTest(unittest.TestCase): """ Base class for the ProxyProviderFactory test cases """ @@ -42,7 +51,7 @@ def test_standalone(self): """ Test loading a proxy provider element with everything defined in itself. """ for provider, resultOfGenerateDN in [('MY_DIRACCA', True), ('MY_PUSP', False)]: - result = ProxyProviderFactory().getProxyProvider(provider) + result = ProxyProviderFactory().getProxyProvider(provider, proxyManager=proxyManager()) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) proxyProviderObj = result['Value'] result = proxyProviderObj.generateDN(FullName='test', Email='email@test.org') diff --git a/src/DIRAC/TransformationSystem/Agent/TaskManagerAgentBase.py b/src/DIRAC/TransformationSystem/Agent/TaskManagerAgentBase.py index 22bfb8a6c86..aeea9579c90 100644 --- a/src/DIRAC/TransformationSystem/Agent/TaskManagerAgentBase.py +++ b/src/DIRAC/TransformationSystem/Agent/TaskManagerAgentBase.py @@ -25,7 +25,7 @@ from DIRAC.Core.Utilities.List import breakListIntoChunks from DIRAC.Core.Utilities.Dictionaries import breakDictionaryIntoChunks from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername, getUsernameForDN +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsernameInGroup, getUsernameForDN from DIRAC.FrameworkSystem.Client.MonitoringClient import gMonitor from DIRAC.TransformationSystem.Client.FileReport import FileReport from DIRAC.TransformationSystem.Client.TaskManager import WorkflowTasks @@ -664,7 +664,10 @@ def __getCredentials(self): owner = resCred['Value']['User'] ownerGroup = resCred['Value']['Group'] # returns a list - ownerDN = getDNForUsername(owner)['Value'][0] + result = getDNForUsernameInGroup(owner, ownerGroup) + if not result['OK']: + return result + ownerDN = result['Value'] self.credTuple = (owner, ownerGroup, ownerDN) self.log.info("Cred: Tasks will be submitted with the credentials %s:%s" % (owner, ownerGroup)) return S_OK() diff --git a/src/DIRAC/TransformationSystem/Client/TaskManager.py b/src/DIRAC/TransformationSystem/Client/TaskManager.py index 086a6b77b45..774a22264a1 100644 --- a/src/DIRAC/TransformationSystem/Client/TaskManager.py +++ b/src/DIRAC/TransformationSystem/Client/TaskManager.py @@ -31,7 +31,7 @@ from DIRAC.WorkloadManagementSystem.Client.JobMonitoringClient import JobMonitoringClient from DIRAC.TransformationSystem.Client.TransformationClient import TransformationClient from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsernameInGroup from DIRAC.TransformationSystem.Agent.TransformationAgentsUtilities import TransformationAgentsUtilities COMPONENT_NAME = 'TaskManager' @@ -161,10 +161,10 @@ def prepareTransformationTasks(self, transBody, taskDict, owner='', ownerGroup=' ownerGroup = proxyInfo['group'] if not ownerDN: - res = getDNForUsername(owner) - if not res['OK']: - return res - ownerDN = res['Value'][0] + result = getDNForUsernameInGroup(owner, ownerGroup) + if not result['OK']: + return result + ownerDN = result['Value'] try: transJson = json.loads(transBody) @@ -502,10 +502,10 @@ def prepareTransformationTasks(self, transBody, taskDict, owner='', ownerGroup=' ownerGroup = proxyInfo['group'] if not ownerDN: - res = getDNForUsername(owner) - if not res['OK']: - return res - ownerDN = res['Value'][0] + result = getDNForUsernameInGroup(owner, ownerGroup) + if not result['OK']: + return result + ownerDN = result['Value'] if bulkSubmissionFlag: return self.__prepareTasksBulk(transBody, taskDict, owner, ownerGroup, ownerDN) diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/JobAgent.py b/src/DIRAC/WorkloadManagementSystem/Agent/JobAgent.py index 6d5b119457b..8de404e3bc0 100755 --- a/src/DIRAC/WorkloadManagementSystem/Agent/JobAgent.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/JobAgent.py @@ -284,6 +284,7 @@ def execute(self): jobJDL = matcherInfo['JDL'] jobGroup = matcherInfo['Group'] ownerDN = matcherInfo['DN'] + owner = matcherInfo['User'] optimizerParams = {} for key in matcherInfo: @@ -348,7 +349,7 @@ def execute(self): jobReport.setJobStatus(status='Matched', minor='Job Received by Agent', sendFlag=False) - result_setupProxy = self._setupProxy(ownerDN, jobGroup) + result_setupProxy = self._setupProxy(owner, jobGroup) if not result_setupProxy['OK']: return self._rescheduleFailedJob( jobID, result_setupProxy['Message'], self.stopOnApplicationFailure) @@ -462,12 +463,12 @@ def _getCPUWorkLeft(self, cpuConsumed): return cpuWorkleft ############################################################################# - def _setupProxy(self, ownerDN, ownerGroup): + def _setupProxy(self, owner, ownerGroup): """ Retrieve a proxy for the execution of the job """ if gConfig.getValue('/DIRAC/Security/UseServerCertificate', False): - proxyResult = self._requestProxyFromProxyManager(ownerDN, ownerGroup) + proxyResult = self._requestProxyFromProxyManager(owner, ownerGroup) if not proxyResult['OK']: self.log.error('Failed to setup proxy', proxyResult['Message']) return S_ERROR('Failed to setup proxy: %s' % proxyResult['Message']) @@ -487,7 +488,7 @@ def _setupProxy(self, ownerDN, ownerGroup): groupProps = ret['Value']['groupProperties'] if Properties.GENERIC_PILOT in groupProps or Properties.PILOT in groupProps: - proxyResult = self._requestProxyFromProxyManager(ownerDN, ownerGroup) + proxyResult = self._requestProxyFromProxyManager(owner, ownerGroup) if not proxyResult['OK']: self.log.error('Invalid Proxy', proxyResult['Message']) return S_ERROR('Failed to setup proxy: %s' % proxyResult['Message']) @@ -496,18 +497,18 @@ def _setupProxy(self, ownerDN, ownerGroup): return S_OK(proxyChain) ############################################################################# - def _requestProxyFromProxyManager(self, ownerDN, ownerGroup): + def _requestProxyFromProxyManager(self, owner, ownerGroup): """Retrieves user proxy with correct role for job and sets up environment to run job locally. """ - self.log.info("Requesting proxy', 'for %s@%s" % (ownerDN, ownerGroup)) + self.log.info("Requesting proxy', 'for %s@%s" % (owner, ownerGroup)) token = gConfig.getValue("/Security/ProxyToken", "") if not token: self.log.verbose("No token defined. Trying to download proxy without token") token = False - retVal = gProxyManager.getPayloadProxyFromDIRACGroup(ownerDN, ownerGroup, - self.defaultProxyLength, token) + retVal = gProxyManager.downloadCorrectProxy(owner, ownerGroup, self.defaultProxyLength, + token=token) if not retVal['OK']: self.log.error('Could not retrieve payload proxy', retVal['Message']) os.system('dirac-proxy-info') diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py index 8b3316a2ad8..9a813eb35ca 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py @@ -97,6 +97,7 @@ def __init__(self, *args, **kwargs): # self.voGroups contain all the eligible user groups for pilots submitted by this SiteDirector self.voGroups = [] self.pilotDN = '' + self.pilotUser = '' self.pilotGroup = '' self.platforms = [] self.sites = [] @@ -185,11 +186,15 @@ def beginExecution(self): # Which credentials to use? # are they specific to the SD? (if not, get the generic ones) self.pilotDN = self.am_getOption("PilotDN", self.pilotDN) + self.pilotUser = self.am_getOption("PilotUser", self.pilotUser) self.pilotGroup = self.am_getOption("PilotGroup", self.pilotGroup) - result = findGenericPilotCredentials(vo=self.vo, pilotDN=self.pilotDN, pilotGroup=self.pilotGroup) + result = findGenericPilotCredentials(vo=self.vo, + pilotUser=self.pilotUser, + pilotDN=self.pilotDN, + pilotGroup=self.pilotGroup) if not result['OK']: return result - self.pilotDN, self.pilotGroup = result['Value'] + self.pilotUser, self.pilotGroup, self.pilotDN = result['Value'] # Parameters self.workingDirectory = self.am_getOption('WorkDirectory') @@ -235,6 +240,7 @@ def beginExecution(self): self.log.always('CETypes:', ceTypes) self.log.always('CEs:', ces) self.log.always('PilotDN:', self.pilotDN) + self.log.always('PilotUser:', self.pilotUser) self.log.always('PilotGroup:', self.pilotGroup) result = self.resourcesModule.getQueues(community=self.vo, @@ -553,9 +559,9 @@ def submitPilots(self): # Get the working proxy cpuTime = queueCPUTime + 86400 - self.log.verbose("Getting pilot proxy", - "for %s/%s %d long" % (self.pilotDN, self.pilotGroup, cpuTime)) - result = gProxyManager.getPilotProxyFromDIRACGroup(self.pilotDN, self.pilotGroup, cpuTime) + self.log.verbose("Getting pilot proxy for", + "%s@%s (%s) %d long" % (self.pilotUser, self.pilotGroup, self.pilotDN, cpuTime)) + result = gProxyManager.downloadCorrectProxy(self.pilotUser, self.pilotGroup, cpuTime) if not result['OK']: return result proxy = result['Value'] @@ -1015,9 +1021,9 @@ def getExecutable(self, queue, proxy=None, jobExecDir='', envVariables=None, """ Prepare the full executable for queue :param str queue: queue name - :param bundleProxy: flag that say if to bundle or not the proxy - :type bundleProxy: bool - :param str queue: pilot execution dir (normally an empty string) + :param bool proxy: flag that say if to bundle or not the proxy + :param str jobExecDir: pilot execution dir (normally an empty string) + :param envVariables: env variables :returns: a string the options for the pilot :rtype: str @@ -1154,7 +1160,7 @@ def updatePilotStatus(self): """ # Generate a proxy before feeding the threads to renew the ones of the CEs to perform actions - result = gProxyManager.getPilotProxyFromDIRACGroup(self.pilotDN, self.pilotGroup, 23400) + result = gProxyManager.downloadCorrectProxy(self.pilotUser, self.pilotGroup, 23400) if not result['OK']: return result proxy = result['Value'] @@ -1171,7 +1177,7 @@ def updatePilotStatus(self): ce = self.queueDict[queue]['CE'] if not ce.isProxyValid(120)['OK']: - result = gProxyManager.getPilotProxyFromDIRACGroup(self.pilotDN, self.pilotGroup, 1000) + result = gProxyManager.downloadCorrectProxy(self.pilotUser, self.pilotGroup, 1000) if not result['OK']: return result proxy = result['Value'] diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_JobAgent.py b/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_JobAgent.py index 939dea8a4bf..ed90d2ac5c6 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_JobAgent.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_JobAgent.py @@ -120,18 +120,19 @@ def test__setupProxy(mocker, mockGCReplyInput, mockPMReplyInput, expected): mocker.patch("DIRAC.WorkloadManagementSystem.Agent.JobAgent.AgentModule.__init__") mocker.patch("DIRAC.WorkloadManagementSystem.Agent.JobAgent.AgentModule", side_effect=mockAM) mocker.patch("DIRAC.WorkloadManagementSystem.Agent.JobAgent.gConfig.getValue", side_effect=mockGCReply) - module_str = "DIRAC.WorkloadManagementSystem.Agent.JobAgent.gProxyManager.getPayloadProxyFromDIRACGroup" + module_str = "DIRAC.WorkloadManagementSystem.Agent.JobAgent.gProxyManager.downloadCorrectProxy" mocker.patch(module_str, side_effect=mockPMReply) jobAgent = JobAgent('Test', 'Test1') + owner = 'DIRAC' ownerDN = 'DIRAC' ownerGroup = 'DIRAC' jobAgent.log = gLogger jobAgent.log.setLevel('DEBUG') - result = jobAgent._setupProxy(ownerDN, ownerGroup) + result = jobAgent._setupProxy(owner, ownerGroup) assert result['OK'] == expected['OK'] @@ -179,18 +180,19 @@ def test__requestProxyFromProxyManager(mocker, mockGCReplyInput, mockPMReplyInpu mocker.patch("DIRAC.WorkloadManagementSystem.Agent.JobAgent.AgentModule.__init__") mocker.patch("DIRAC.WorkloadManagementSystem.Agent.JobAgent.AgentModule", side_effect=mockAM) mocker.patch("DIRAC.WorkloadManagementSystem.Agent.JobAgent.gConfig.getValue", side_effect=mockGCReply) - module_str = "DIRAC.WorkloadManagementSystem.Agent.JobAgent.gProxyManager.getPayloadProxyFromDIRACGroup" + module_str = "DIRAC.WorkloadManagementSystem.Agent.JobAgent.gProxyManager.downloadCorrectProxy" mocker.patch(module_str, side_effect=mockPMReply) jobAgent = JobAgent('Test', 'Test1') + owner = 'DIRAC' ownerDN = 'DIRAC' ownerGroup = 'DIRAC' jobAgent.log = gLogger jobAgent.log.setLevel('DEBUG') - result = jobAgent._requestProxyFromProxyManager(ownerDN, ownerGroup) + result = jobAgent._requestProxyFromProxyManager(owner, ownerGroup) assert result['OK'] == expected['OK'] diff --git a/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py b/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py index 68f956fbdab..9b247ad7c28 100644 --- a/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py +++ b/src/DIRAC/WorkloadManagementSystem/Client/Matcher.py @@ -133,7 +133,7 @@ def selectJob(self, resourceDescription, credDict): if resOpt['OK']: for key, value in resOpt['Value'].items(): resultDict[key] = value - resAtt = self.jobDB.getJobAttributes(jobID, ['OwnerDN', 'OwnerGroup']) + resAtt = self.jobDB.getJobAttributes(jobID, ['Owner', 'OwnerDN', 'OwnerGroup']) if not resAtt['OK']: raise RuntimeError('Could not retrieve job attributes') if not resAtt['Value']: @@ -148,6 +148,7 @@ def selectJob(self, resourceDescription, credDict): self._updatePilotJobMapping(resourceDict, jobID) resultDict['DN'] = resAtt['Value']['OwnerDN'] + resultDict['User'] = resAtt['Value']['Owner'] resultDict['Group'] = resAtt['Value']['OwnerGroup'] resultDict['PilotInfoReportedFlag'] = True @@ -340,14 +341,13 @@ def _checkCredentials(self, resourceDict, credDict): resourceDict['OwnerGroup'] = credDict['group'] self.log.notice("Setting the resource group to the credentials group") if 'OwnerDN' in resourceDict and resourceDict['OwnerDN'] != credDict['DN']: - ownerDN = resourceDict['OwnerDN'] - result = Registry.getGroupsForDN(resourceDict['OwnerDN']) + result = Registry.getGroupsForDN(resourceDict['OwnerDN'], groupsList=credDict['group']) if not result['OK']: raise RuntimeError(result['Message']) if credDict['group'] not in result['Value']: # DN is not in the same group! bad boy. self.log.warn("You cannot request jobs from this DN, as it does not belong to your group!", - "(%s)" % ownerDN) + "(%s)" % resourceDict['OwnerDN']) resourceDict['OwnerDN'] = credDict['DN'] # Nothing special, group and DN have to be the same else: diff --git a/src/DIRAC/WorkloadManagementSystem/ConfigTemplate.cfg b/src/DIRAC/WorkloadManagementSystem/ConfigTemplate.cfg index 8b7c82a4e0d..edcd305a765 100644 --- a/src/DIRAC/WorkloadManagementSystem/ConfigTemplate.cfg +++ b/src/DIRAC/WorkloadManagementSystem/ConfigTemplate.cfg @@ -195,6 +195,8 @@ Agents GridEnv = # the DN of the certificate proxy used to submit pilots. If not found here, what is in Operations/Pilot section of the CS will be used PilotDN = + # the user of the certificate proxy used to submit pilots. If not found here, what is in Operations/Pilot section of the CS will be used + PilotUser = # the group of the certificate proxy used to submit pilots. If not found here, what is in Operations/Pilot section of the CS will be used PilotGroup = diff --git a/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py b/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py index 8562f350393..be3adc865dd 100755 --- a/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py +++ b/src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py @@ -314,6 +314,10 @@ def getPilotInfo(self, pilotRef=False, parentId=False, conn=False, paramNames=[] pilotDict[parameters[i]] = row[i] if parameters[i] == 'PilotID': pilotIDs.append(row[i]) + result = getUsernameForDN(pilotDict['OwnerDN']) + if not result['OK']: + return result + pilotDict['Owner'] = result['Value'] resDict[row[0]] = pilotDict result = self.getJobsForPilot(pilotIDs) @@ -1062,8 +1066,10 @@ def getPilotMonitorWeb(self, selectDict, sortList, startItem, maxItems): userList = [userList] dnList = [] for uName in userList: - uList = getDNForUsername(uName)['Value'] - dnList += uList + result = getDNsForUsername(uName) + if not result['OK']: + return result + dnList += result['Value'] selectDict['OwnerDN'] = dnList del selectDict['Owner'] startDate = selectDict.get('FromDate', None) diff --git a/src/DIRAC/WorkloadManagementSystem/Service/JobManagerHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/JobManagerHandler.py index eced66fd5aa..01c0b1e1388 100755 --- a/src/DIRAC/WorkloadManagementSystem/Service/JobManagerHandler.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/JobManagerHandler.py @@ -80,7 +80,7 @@ def initialize(self): self.peerUsesLimitedProxy = credDict['isLimitedProxy'] self.diracSetup = self.serviceInfoDict['clientSetup'] self.maxParametricJobs = self.srv_getCSOption('MaxParametricJobs', MAX_PARAMETRIC_JOBS) - self.jobPolicy = JobPolicy(self.ownerDN, self.ownerGroup, self.userProperties) + self.jobPolicy = JobPolicy(self.owner, self.ownerGroup, self.userProperties) self.jobPolicy.jobDB = self.jobDB return S_OK() @@ -189,11 +189,6 @@ def export_submitJob(self, jobDesc): jobIDList.append(jobID) - # Set persistency flag - retVal = gProxyManager.getUserPersistence(self.ownerDN, self.ownerGroup) - if 'Value' not in retVal or not retVal['Value']: - gProxyManager.setPersistency(self.ownerDN, self.ownerGroup, True) - if parametricJob: result = S_OK(jobIDList) else: @@ -269,7 +264,7 @@ def __checkIfProxyUploadIsRequired(self): :return: bool """ - result = gProxyManager.userHasProxy(self.ownerDN, self.ownerGroup, validSeconds=18000) + result = gProxyManager.userHasProxy(self.owner, self.ownerGroup, validSeconds=18000) if not result['OK']: self.log.error("Can't check if the user has proxy uploaded", result['Message']) return True diff --git a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py index 6580eb9cb72..ef2e54f6f31 100755 --- a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py @@ -54,12 +54,13 @@ def initialize(self): """ credDict = self.getRemoteCredentials() + owner = credDict['username'] ownerDN = credDict['DN'] ownerGroup = credDict['group'] operations = Operations(group=ownerGroup) self.globalJobsInfo = operations.getValue( '/Services/JobMonitoring/GlobalJobsInfo', True) - self.jobPolicy = JobPolicy(ownerDN, ownerGroup, self.globalJobsInfo) + self.jobPolicy = JobPolicy(owner, ownerGroup, self.globalJobsInfo) self.jobPolicy.jobDB = self.jobDB return S_OK() diff --git a/src/DIRAC/WorkloadManagementSystem/Service/JobPolicy.py b/src/DIRAC/WorkloadManagementSystem/Service/JobPolicy.py index 06b35984748..fbef697e267 100755 --- a/src/DIRAC/WorkloadManagementSystem/Service/JobPolicy.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/JobPolicy.py @@ -9,20 +9,19 @@ from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Security import Properties -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getUsernameForDN, getGroupsForUser, \ - getPropertiesForGroup, getUsersInGroup +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getGroupsForUser, getPropertiesForGroup, getUsersInGroup +RIGHT_KILL = 'Kill' +RIGHT_RESET = 'Reset' +RIGHT_DELETE = 'Delete' +RIGHT_SUBMIT = 'Submit' RIGHT_GET_JOB = 'GetJob' RIGHT_GET_INFO = 'GetInfo' +RIGHT_GET_STATS = 'GetStats' +RIGHT_RESCHEDULE = 'Reschedule' RIGHT_GET_SANDBOX = 'GetSandbox' RIGHT_PUT_SANDBOX = 'PutSandbox' RIGHT_CHANGE_STATUS = 'ChangeStatus' -RIGHT_DELETE = 'Delete' -RIGHT_KILL = 'Kill' -RIGHT_SUBMIT = 'Submit' -RIGHT_RESCHEDULE = 'Reschedule' -RIGHT_GET_STATS = 'GetStats' -RIGHT_RESET = 'Reset' ALL_RIGHTS = [RIGHT_GET_JOB, RIGHT_GET_INFO, RIGHT_GET_SANDBOX, RIGHT_PUT_SANDBOX, RIGHT_CHANGE_STATUS, RIGHT_DELETE, RIGHT_KILL, RIGHT_SUBMIT, @@ -51,33 +50,38 @@ class JobPolicy(object): - def __init__(self, userDN, userGroup, allInfo=True): + def __init__(self, username, userGroup, allInfo=True): + """ C'tor - self.userDN = userDN - self.userName = '' - result = getUsernameForDN(userDN) - if result['OK']: - self.userName = result['Value'] - self.userGroup = userGroup - self.userProperties = getPropertiesForGroup(userGroup, []) + :param str username: user name + :param str userGroup: group name + :param bool allInfo: all information + """ self.jobDB = None self.allInfo = allInfo + self.userName = username + self.userGroup = userGroup + self.userProperties = getPropertiesForGroup(userGroup, []) self.__permissions = {} self.__getUserJobPolicy() def getUserRightsForJob(self, jobID, owner=None, group=None): - """ Get access rights to job with jobID for the user specified by - userDN/userGroup + """ Get access rights to job with jobID for the user specified by username/userGroup + + :param str jobID: job ID + :param str owner: user name + :param str group: group name + + :return: S_OK()/S_ERROR() """ if owner is None or group is None: result = self.jobDB.getJobAttributes(jobID, ['Owner', 'OwnerGroup']) if not result['OK']: return result - elif result['Value']: - owner = result['Value']['OwnerDN'] - group = result['Value']['OwnerGroup'] - else: + if not result['Value']: return S_ERROR('Job not found') + owner = result['Value']['OwnerDN'] + group = result['Value']['OwnerGroup'] return self.getJobPolicy(owner, group) @@ -113,10 +117,15 @@ def __getUserJobPolicy(self): for right in PROPERTY_RIGHTS[Properties.GENERIC_PILOT]: self.__permissions[right] = True - def getJobPolicy(self, jobOwner='', jobOwnerGroup=''): - """ Get the job operations rights for a job owned by jobOwnerDN/jobOwnerGroup - for a user with userDN/userGroup. + def getJobPolicy(self, jobOwner=None, jobOwnerGroup=None): + """ Get the job operations rights for a job owned by jobOwner/jobOwnerGroup + for a user with username/userGroup. Returns a dictionary of various operations rights + + :param str jobOwner: user name + :param str jobOwnerGroup: group name + + :return: S_OK(dict)/S_ERROR() """ permDict = dict(self.__permissions) # Job Owner can do everything with his jobs @@ -133,7 +142,12 @@ def getJobPolicy(self, jobOwner='', jobOwnerGroup=''): return S_OK(permDict) def evaluateJobRights(self, jobList, right): - """ Get access rights to jobID for the user ownerDN/ownerGroup + """ Get access rights to jobID for the user owner/ownerGroup + + :param list jobList: job list + :param str right: right + + :return: tuple -- contain valid, invalid, nonauth, owner jobs """ validJobList = [] invalidJobList = [] @@ -180,6 +194,10 @@ def evaluateJobRights(self, jobList, right): def getControlledUsers(self, right): """ Get users and groups which jobs are subject to the given access right + + :param str right: right + + :return: S_OK()/S_ERROR() """ userGroupList = 'ALL' diff --git a/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py index f5e3fc77d86..bb72bcadf2a 100644 --- a/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py @@ -137,15 +137,16 @@ def export_getPilotLoggingInfo(cls, pilotReference): return S_ERROR('Failed to determine owner for pilot ' + pilotReference) pilotDict = result['Value'][pilotReference] - owner = pilotDict['OwnerDN'] - group = pilotDict['OwnerGroup'] + owner = pilotDict['Owner'] + ownerDN = pilotDict['OwnerDN'] + ownerGroup = pilotDict['OwnerGroup'] gridType = pilotDict['GridType'] pilotStamp = pilotDict['PilotStamp'] # Add the pilotStamp to the pilot Reference, some CEs may need it to retrieve the logging info pilotReference = pilotReference + ':::' + pilotStamp return getPilotLoggingInfo(gridType, pilotReference, # pylint: disable=unexpected-keyword-arg - proxyUserDN=owner, proxyUserGroup=group) + proxyUserName=owner, proxyUserGroup=ownerGroup) ############################################################################## types_getPilotSummary = [] @@ -258,9 +259,9 @@ def export_killPilot(cls, pilotRefList): return S_ERROR('Failed to get info for pilot ' + pilotReference) pilotDict = result['Value'][pilotReference] - owner = pilotDict['OwnerDN'] - group = pilotDict['OwnerGroup'] - queue = '@@@'.join([owner, group, pilotDict['GridSite'], pilotDict['DestinationSite'], pilotDict['Queue']]) + owner = pilotDict['Owner'] + ownerGroup = pilotDict['OwnerGroup'] + queue = '@@@'.join([owner, ownerGroup, pilotDict['GridSite'], pilotDict['DestinationSite'], pilotDict['Queue']]) gridType = pilotDict['GridType'] pilotRefDict.setdefault(queue, {}) pilotRefDict[queue].setdefault('PilotList', []) diff --git a/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py b/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py index bff2f9d7cce..1dbc2974c7d 100644 --- a/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py @@ -14,7 +14,7 @@ from DIRAC.Core.Utilities.Grid import executeGridCommand from DIRAC.Core.Utilities.Proxy import executeWithUserProxy from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getQueue -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getGroupOption +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOMSAttributeForGroup from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager from DIRAC.Resources.Computing.ComputingElementFactory import ComputingElementFactory from DIRAC.WorkloadManagementSystem.Client.ServerUtils import pilotAgentsDB @@ -88,7 +88,8 @@ def getGridJobOutput(pilotReference): return S_ERROR('Pilot info is empty') pilotDict = result['Value'][pilotReference] - owner = pilotDict['OwnerDN'] + owner = pilotDict['Owner'] + ownerDN = pilotDict['OwnerDN'] group = pilotDict['OwnerGroup'] # FIXME: What if the OutputSandBox is not StdOut and StdErr, what do we do with other files? @@ -100,7 +101,8 @@ def getGridJobOutput(pilotReference): resultDict = {} resultDict['StdOut'] = stdout resultDict['StdErr'] = error - resultDict['OwnerDN'] = owner + resultDict['Owner'] = owner + resultDict['OwnerDN'] = ownerDN resultDict['OwnerGroup'] = group resultDict['FileList'] = [] return S_OK(resultDict) @@ -122,11 +124,13 @@ def getGridJobOutput(pilotReference): shutil.rmtree(queueDict['WorkingDirectory']) return result ce = result['Value'] - groupVOMS = getGroupOption(group, 'VOMSRole', group) - result = gProxyManager.getPilotProxyFromVOMSGroup(owner, groupVOMS) + if not getVOMSAttributeForGroup(group): + gLogger.error("No voms attribute assigned to group %s when requested pilot proxy." % group) + return S_ERROR("Failed to get the pilot's owner proxy") + result = gProxyManager.downloadVOMSProxy(owner, group) if not result['OK']: gLogger.error('Could not get proxy:', - 'User "%s" Group "%s" : %s' % (owner, groupVOMS, result['Message'])) + 'User "%s" Group "%s" : %s' % (owner, group, result['Message'])) return S_ERROR("Failed to get the pilot's owner proxy") proxy = result['Value'] ce.setProxy(proxy) @@ -147,7 +151,8 @@ def getGridJobOutput(pilotReference): resultDict = {} resultDict['StdOut'] = stdout resultDict['StdErr'] = error - resultDict['OwnerDN'] = owner + resultDict['Owner'] = owner + resultDict['OwnerDN'] = ownerDN resultDict['OwnerGroup'] = group resultDict['FileList'] = [] shutil.rmtree(queueDict['WorkingDirectory']) @@ -177,8 +182,10 @@ def killPilotsInQueues(pilotRefDict): # FIXME: quite hacky. Should be either removed, or based on some flag if gridType in ["CREAM", "ARC", "Globus", "HTCondorCE"]: - group = getGroupOption(group, 'VOMSRole', group) - ret = gProxyManager.getPilotProxyFromVOMSGroup(owner, group) + if not getVOMSAttributeForGroup(group): + gLogger.error("No voms attribute assigned to group %s when requested pilot proxy." % group) + return S_ERROR("Failed to get the pilot's owner proxy") + ret = gProxyManager.downloadVOMSProxy(owner, group) if not ret['OK']: gLogger.error('Could not get proxy:', 'User "%s" Group "%s" : %s' % (owner, group, ret['Message'])) diff --git a/src/DIRAC/WorkloadManagementSystem/private/ConfigHelper.py b/src/DIRAC/WorkloadManagementSystem/private/ConfigHelper.py index 2c0af95eb07..7605f2d0734 100644 --- a/src/DIRAC/WorkloadManagementSystem/private/ConfigHelper.py +++ b/src/DIRAC/WorkloadManagementSystem/private/ConfigHelper.py @@ -11,7 +11,7 @@ from DIRAC.ConfigurationSystem.Client.Helpers import Registry, Operations -def findGenericPilotCredentials(vo=False, group=False, pilotDN='', pilotGroup=''): +def findGenericPilotCredentials(vo=False, group=False, pilotDN='', pilotGroup='', pilotUser=None): """ Looks into the Operations/<>/Pilot section of CS to find the pilot credentials. Then check if the user has a registered proxy in ProxyManager. @@ -21,33 +21,34 @@ def findGenericPilotCredentials(vo=False, group=False, pilotDN='', pilotGroup='' :param str group: group name :param str pilotDN: pilot DN :param str pilotGroup: pilot group + :param str pilotUser: pilot user :return: S_OK(tuple)/S_ERROR() """ if not group and not vo: return S_ERROR("Need a group or a VO to determine the Generic pilot credentials") + vo = vo or Registry.getVOForGroup(group) if not vo: - vo = Registry.getVOForGroup(group) - if not vo: - return S_ERROR("Group %s does not have a VO associated" % group) + return S_ERROR("Group %s does not have a VO associated" % group) opsHelper = Operations.Operations(vo=vo) + pilotDN = pilotDN or opsHelper.getValue("Pilot/GenericPilotDN", "") + pilotUser = pilotUser or opsHelper.getValue("Pilot/GenericPilotUser", "") + pilotGroup = pilotGroup or opsHelper.getValue("Pilot/GenericPilotGroup", "") if not pilotGroup: - pilotGroup = opsHelper.getValue("Pilot/GenericPilotGroup", "") - if not pilotDN: - pilotDN = opsHelper.getValue("Pilot/GenericPilotDN", "") - if not pilotDN: - pilotUser = opsHelper.getValue("Pilot/GenericPilotUser", "") - if pilotUser: - result = Registry.getDNForUsername(pilotUser) - if result['OK']: - pilotDN = result['Value'][0] - if pilotDN and pilotGroup: - gLogger.verbose("Pilot credentials: %s@%s" % (pilotDN, pilotGroup)) - result = gProxyManager.userHasProxy(pilotDN, pilotGroup, 86400) + return S_ERROR("%s does not have group" % pilotDN or pilotUser) + if pilotUser and not pilotDN: + result = Registry.getDNForUsernameInGroup(pilotUser, pilotGroup) if not result['OK']: - return S_ERROR("%s@%s has no proxy in ProxyManager") - return S_OK((pilotDN, pilotGroup)) - - if pilotDN: - return S_ERROR("DN %s does not have group %s" % (pilotDN, pilotGroup)) - return S_ERROR("No generic proxy in the Proxy Manager with groups %s" % pilotGroup) + return result + pilotDN = result['Value'] + if pilotDN and not pilotUser: + result = Registry.getUsernameForDN(pilotDN) + if not result['OK']: + return result + pilotUser = result['Value'] + + gLogger.verbose("Pilot credentials: %s@%s (%s)" % (pilotUser, pilotGroup, pilotDN)) + result = gProxyManager.userHasProxy(pilotUser, pilotGroup, 86400) + if not result['OK']: + return S_ERROR("%s@%s has no proxy in ProxyManager" % (pilotUser, pilotGroup)) + return S_OK((pilotUser, pilotGroup, pilotDN)) diff --git a/src/DIRAC/WorkloadManagementSystem/private/correctors/BaseHistoryCorrector.py b/src/DIRAC/WorkloadManagementSystem/private/correctors/BaseHistoryCorrector.py index 55ed786d6ef..470609cfe29 100644 --- a/src/DIRAC/WorkloadManagementSystem/private/correctors/BaseHistoryCorrector.py +++ b/src/DIRAC/WorkloadManagementSystem/private/correctors/BaseHistoryCorrector.py @@ -10,7 +10,7 @@ import time as nativetime from DIRAC import S_OK, S_ERROR, gLogger -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNsForUsername from DIRAC.WorkloadManagementSystem.private.correctors.BaseCorrector import BaseCorrector @@ -100,7 +100,7 @@ def _getUsageHistoryForTimeSpan(self, timeSpan, groupToUse=""): if groupToUse: mappedData = {} for userName in data: - result = getDNForUsername(userName) + result = getDNsForUsername(userName) if not result['OK']: self.log.error("User does not have any DN assigned", "%s :%s" % (userName, result['Message'])) continue diff --git a/tests/Integration/Framework/Test_AuthDB.py b/tests/Integration/Framework/Test_AuthDB.py new file mode 100644 index 00000000000..e4859908722 --- /dev/null +++ b/tests/Integration/Framework/Test_AuthDB.py @@ -0,0 +1,91 @@ +""" This is a test of the AuthDB + It supposes that the DB is present and installed in DIRAC +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=invalid-name,wrong-import-position,protected-access +import sys +import pytest +import pprint + +from DIRAC import gConfig +from DIRAC.FrameworkSystem.DB.AuthDB import AuthDB + + +pprint.pprint(gConfig.getOptionsDictRecursively('/')) +db = AuthDB() + + +def test_Tokens(): + """ Try to store/get/remove Tokens + """ + # Example of the new token metadata + tData1 = {'access_token': 'eyJraWQiOiJvaWRjIiwiYWxnIjoiUlMyNTYifQ.eyJzdWIiOiI5N2ZhZGY2M2U1NWixTokH0OMjseMTQMk36sU5O', + 'client_id': '2C7823B4-4A85-A912-E5D06D955809', + 'expires_at': 1616538163, + 'expires_in': 3599, + 'id_token': 'eyJraWQiOiJvaWRjIiwiYWxnIjoiUlMyNTYifQ.eyJzdWIiOiI5N2ZhZGY2M2U1NWVhMzlkQGVnaS5ldSIsImF1ZCI61', + 'provider': 'CheckIn', + 'refresh_token': 'eyJhbGciOiJub25lIn0.eyJleHAImp0aSI6IjQwNDI5M2YwLTk4NztNDI0Yi04NDZjLWU1NDQzMWRjMmEzZSJ9.', + 'scope': 'openid offline_access profile eduperson_scoped_affiliation eduperson_unique_id', + 'token_type': 'Bearer', + 'user_id': '97fadf63e5123358a4f084e4c136475e377357c6723269f23eb9aba437fd6d9d@egi.eu'} + + # Example of updated token + tData2 = {'access_token': 'eyJraWQiOiJvaWRjIiwi4e4c136475e377357c6723269f23eb9aba437fd6d9dk36sU5Od', + 'client_id': '2C7823B4-4A85-A912-E5D06D955809', + 'expires_at': 1616538163, + 'expires_in': 3599, + 'id_token': 'eyJraWQiOiJvaWRjIiwiYWxnIjoiUlMy4e4c136475e377357c6723269f23eb9aba4F1ZCI6d1', + 'provider': 'CheckIn', + 'refresh_token': 'eyJhbGciOiJub25lIn0.eyJleHAImp0aSI6IjQ475e377357c6723269f23eb9aba4Fd9.', + 'scope': 'openid offline_access profile eduperson_scoped_affiliation eduperson_unique_id', + 'token_type': 'Bearer', + 'user_id': '97fadf63e5123358a4f084e4c136475e377357c6723269f23eb9aba437fd6d9d@egi.eu'} + + # Remove token + db.removeToken(tData1['access_token']) + db.removeToken(tData2['access_token']) + + # Add token + result = db.storeToken(tData1) + assert result['OK'], result['Message'] + + # Get token + result = db.getTokenByUserIDAndProvider(tData1['user_id'], tData1['provider']) + assert result['OK'], result['Message'] + assert result['Value']['refresh_token'] == tData1['refresh_token'] + assert result['Value']['access_token'] == tData1['access_token'] + + # Update token + result = db.updateToken(tData2, tData1['refresh_token']) + assert result['OK'], result['Message'] + + # Get token + result = db.getTokenByUserIDAndProvider(tData1['user_id'], tData1['provider']) + assert result['OK'], result['Message'] + assert result['Value']['refresh_token'] == tData2['refresh_token'] + assert result['Value']['access_token'] == tData2['access_token'] + + # Get token + result = db.getIdPTokens(tData2['provider']) + assert result['OK'], result['Message'] + aTokens = [] + for token in result['Value']: + aTokens.append(token['access_token']) + assert tData2['access_token'] in aTokens + assert tData1['access_token'] not in aTokens + + # Remove token + result = db.removeToken(tData2['access_token']) + assert result['OK'], result['Message'] + + # Make sure that the Client is absent + result = db.getIdPTokens(tData2['provider']) + assert result['OK'], result['Message'] + aTokens = [] + for token in result['Value']: + aTokens.append(token['access_token']) + assert tData2['access_token'] not in aTokens diff --git a/tests/Integration/Framework/Test_ProxyDB.py b/tests/Integration/Framework/Test_ProxyDB.py index 8ed10903a4c..07cf64c5ce0 100644 --- a/tests/Integration/Framework/Test_ProxyDB.py +++ b/tests/Integration/Framework/Test_ProxyDB.py @@ -57,6 +57,16 @@ } """ % (os.path.join(certsPath, 'ca/ca.cert.pem'), os.path.join(certsPath, 'ca/ca.key.pem')) +usersDNs = {'user_ca': '/C=DN/O=DIRACCA/OU=None/CN=user_ca/emailAddress=user_ca@diracgrid.org', + 'user': '/C=CC/O=DN/O=DIRAC/CN=user', + 'no_user': '/C=CC/O=DN/O=DIRAC/CN=no_user', + 'user_1': '/C=CC/O=DN/O=DIRAC/CN=user_1', + 'user_2': '/C=CC/O=DN/O=DIRAC/CN=user_2', + 'user_3': '/C=CC/O=DN/O=DIRAC/CN=user_3', + 'user_4': '/C=CC/O=DN/O=DIRAC/CN=user_4'} + +userNames = list(usersDNs.keys()) + userCFG = """ Registry { @@ -122,6 +132,7 @@ { Users = user_ca, user, user_1, user_2, user_3 VO = vo_1 + VOMSRole = role_1 } group_2 { @@ -315,12 +326,12 @@ def setUp(self): gLogger.debug('\n') if self.failed: self.fail(self.failed) - db._update('DELETE FROM ProxyDB_Proxies WHERE UserName IN ("user_ca", "user", "user_1", "user_2", "user_3")') - db._update('DELETE FROM ProxyDB_CleanProxies WHERE UserName IN ("user_ca", "user", "user_1", "user_2", "user_3")') + db._update('DELETE FROM ProxyDB_Proxies WHERE UserName IN ("%s")' % '", "'.join(userNames)) + db._update('DELETE FROM ProxyDB_CleanProxies WHERE UserDN IN ("%s")' % '", "'.join(list(usersDNs.values()))) def tearDown(self): - db._update('DELETE FROM ProxyDB_Proxies WHERE UserName IN ("user_ca", "user", "user_1", "user_2", "user_3")') - db._update('DELETE FROM ProxyDB_CleanProxies WHERE UserName IN ("user_ca", "user", "user_1", "user_2", "user_3")') + db._update('DELETE FROM ProxyDB_Proxies WHERE UserName IN ("%s")' % '", "'.join(userNames)) + db._update('DELETE FROM ProxyDB_CleanProxies WHERE UserDN IN ("%s")' % '", "'.join(list(usersDNs.values()))) @classmethod def tearDownClass(cls): @@ -352,16 +363,16 @@ def test_connectDB(self): def test_getUsers(self): """ Test 'getUsers' - try to get users from DB """ - field = '("%%s", "/C=CC/O=DN/O=DIRAC/CN=%%s", %%s "PEM", TIMESTAMPADD(SECOND, %%s, UTC_TIMESTAMP()))%s' % '' + field = '(%%s"/C=CC/O=DN/O=DIRAC/CN=%%s", %%s "PEM", TIMESTAMPADD(SECOND, %%s, UTC_TIMESTAMP()))%s' % '' # Fill table for test gLogger.info('\n* Fill tables for test..') for table, values, fields in [('ProxyDB_Proxies', - [field % ('user', 'user', '"group_1",', '800'), - field % ('user_2', 'user_2', '"group_1",', '-1')], + [field % ('"user", ', 'user', '"group_1",', '800'), + field % ('"user_2", ', 'user_2', '"group_1",', '-1')], '(UserName, UserDN, UserGroup, Pem, ExpirationTime)'), ('ProxyDB_CleanProxies', - [field % ('user_3', 'user_3', '', '43200')], - '(UserName, UserDN, Pem, ExpirationTime)')]: + [field % ('', 'user_3', '', '43200')], + '(UserDN, Pem, ExpirationTime)')]: result = db._update('INSERT INTO %s%s VALUES %s ;' % (table, fields, ', '.join(values))) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) # Testing 'getUsers' @@ -375,8 +386,8 @@ def test_getUsers(self): self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) usersList = [] for line in result['Value']: - if line['Name'] in ['user', 'user_2', 'user_3']: - usersList.append(line['Name']) + if line['user'] in ['user', 'user_2', 'user_3']: + usersList.append(line['user']) self.assertEqual(set(expect), set(usersList), str(usersList) + ', when expected ' + str(expect)) def test_purgeExpiredProxies(self): @@ -400,7 +411,7 @@ def test_getRemoveProxy(self): """ Testing get, store proxy """ gLogger.info('\n* Check that DB is clean..') - result = db.getProxiesContent({'UserName': ['user_ca', 'user', 'user_1' 'user_2', 'user_3']}, {}) + result = db.getProxiesContent({'UserName': userNames}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) self.assertTrue(bool(int(result['Value']['TotalRecords']) == 0), 'In DB present proxies.') @@ -412,20 +423,20 @@ def test_getRemoveProxy(self): result = db._update(cmd) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) # Try to no correct getProxy requests - for dn, group, reqtime, log in [('/C=CC/O=DN/O=DIRAC/CN=user', 'group_1', 9999, - 'No proxy provider, set request time, not valid proxy in ProxyDB_Proxies'), - ('/C=CC/O=DN/O=DIRAC/CN=user', 'group_1', 0, - 'Not valid proxy in ProxyDB_Proxies'), - ('/C=CC/O=DN/O=DIRAC/CN=no_user', 'no_valid_group', 0, - 'User not exist, proxy not in DB tables'), - ('/C=CC/O=DN/O=DIRAC/CN=user', 'no_valid_group', 0, - 'Group not valid, proxy not in DB tables'), - ('/C=CC/O=DN/O=DIRAC/CN=user', 'group_1', 0, - 'No proxy provider for user, proxy not in DB tables'), - ('/C=CC/O=DN/O=DIRAC/CN=user_4', 'group_2', 0, - 'Group has option enableToDownload = False in CS')]: + for user, group, reqtime, log in [('user', 'group_1', 9999, + 'Not valid proxy, without proxy provider, with less lifetime'), + ('user', 'group_1', 0, + 'Not valid proxy, without proxy provider, with good lifetime'), + ('no_user', 'no_valid_group', 0, + 'User not exist, proxy not in DB tables'), + ('user', 'no_valid_group', 0, + 'Group not valid, proxy not in DB tables'), + ('user', 'group_1', 0, + 'Proxy removed from DB and not have a proxy provider'), + ('user_4', 'group_2', 0, + 'Group has option enableToDownload = False in CS')]: gLogger.info('== > %s:' % log) - result = db.getProxy(dn, group, reqtime) + result = db.getProxy(usersDNs[user], group, reqtime) self.assertFalse(result['OK'], 'Must be fail.') gLogger.info('Msg: %s' % result['Message']) # In the last case method found proxy and must to delete it as not valid @@ -433,52 +444,49 @@ def test_getRemoveProxy(self): self.assertTrue(bool(db._query(cmd)['Value'][0][0] == 0), "GetProxy method didn't delete the last proxy.") gLogger.info('* Check that DB is clean..') - result = db.getProxiesContent({'UserName': ['user_ca', 'user', 'user_1', 'user_2', 'user_3']}, {}) + result = db.getProxiesContent({'UserName': userNames}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) self.assertTrue(bool(int(result['Value']['TotalRecords']) == 0), 'In DB present proxies.') gLogger.info('* Generate proxy on the fly..') - result = db.getProxy('/C=DN/O=DIRACCA/OU=None/CN=user_ca/emailAddress=user_ca@diracgrid.org', - 'group_1', 1800) + result = db.getProxy(usersDNs['user_ca'], 'group_1', 1800) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) gLogger.info('* Check that ProxyDB_CleanProxy contain generated proxy..') - result = db.getProxiesContent({'UserName': 'user_ca'}, {}) + result = db.getProxiesContent({'UserName': 'user_ca'}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) self.assertTrue(bool(int(result['Value']['TotalRecords']) == 1), 'Generated proxy must be one.') for table, count in [('ProxyDB_Proxies', 0), ('ProxyDB_CleanProxies', 1)]: - cmd = 'SELECT COUNT( * ) FROM %s WHERE UserName="user_ca"' % table + cmd = 'SELECT COUNT( * ) FROM %s WHERE UserDN="%s"' % (table, usersDNs['user_ca']) self.assertTrue(bool(db._query(cmd)['Value'][0][0] == count), table + ' must ' + (count and 'contain proxy' or 'be empty')) gLogger.info('* Check that DB is clean..') - result = db.deleteProxy('/C=DN/O=DIRACCA/OU=None/CN=user_ca/emailAddress=user_ca@diracgrid.org', - proxyProvider='DIRAC_CA') + result = db.deleteProxy([usersDNs['user_ca']]) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) - result = db.getProxiesContent({'UserName': ['user_ca', 'user', 'user_1', 'user_2', 'user_3']}, {}) + result = db.getProxiesContent({'UserName': userNames}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) self.assertTrue(bool(int(result['Value']['TotalRecords']) == 0), 'In DB present proxies.') gLogger.info('* Upload proxy..') - for user, dn, group, vo, time, res, log in [("user", '/C=CC/O=DN/O=DIRAC/CN=user', "group_1", False, 12, + for user, dn, group, vo, time, res, log in [("user", usersDNs['user'], "group_1", False, 12, False, 'With group extension'), - ("user", '/C=CC/O=DN/O=DIRAC/CN=user', False, "vo_1", 12, + ("user", usersDNs['user'], False, "vo_1", 12, False, 'With voms extension'), - ("user_1", '/C=CC/O=DN/O=DIRAC/CN=user_1', False, "vo_1", 12, + ("user_1", usersDNs['user_1'], False, "vo_1", 12, False, 'With voms extension'), - ("user", '/C=CC/O=DN/O=DIRAC/CN=user', False, False, 0, + ("user", usersDNs['user'], False, False, 0, False, 'Expired proxy'), - ("no_user", '/C=CC/O=DN/O=DIRAC/CN=no_user', False, False, 12, + ("no_user", usersDNs['no_user'], False, False, 12, False, 'Not exist user'), - ("user", '/C=CC/O=DN/O=DIRAC/CN=user', False, False, 12, + ("user", usersDNs['user'], False, False, 12, True, 'Valid proxy')]: # Clean tables with proxies for table in ['ProxyDB_Proxies', 'ProxyDB_CleanProxies']: - result = db._update('DELETE FROM %s WHERE UserName = "user"' % table) - self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) - result = db._update('DELETE FROM %s WHERE UserName = "user_1"' % table) - self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) - self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) + for _dn in [usersDNs['user'], usersDNs['user_1']]: + result = db._update('DELETE FROM %s WHERE UserDN = "%s"' % (table, dn)) + self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) + gLogger.info('== > %s:' % log) result = self.createProxy(user, group, time, vo=vo) @@ -489,7 +497,7 @@ def test_getRemoveProxy(self): if vo: self.assertTrue(bool(chain.isVOMS().get('Value')), 'Cannot create proxy with VOMS extension') - result = db.generateDelegationRequest(chain, dn) + result = db.generateDelegationRequest({'x509Chain': chain, 'DN': dn}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) resDict = result['Value'] result = chain.generateChainFromRequestString(resDict['request'], time * 3500) @@ -504,27 +512,26 @@ def test_getRemoveProxy(self): if not res: gLogger.info('Msg: %s' % (result['Message'])) cmd = 'SELECT COUNT( * ) FROM ProxyDB_Proxies WHERE UserName="%s"' % user - self.assertTrue(bool(db._query(cmd)['Value'][0][0] == 0), - 'ProxyDB_Proxies must ' + ('contain proxy' if res else 'be empty')) - cmd = 'SELECT COUNT( * ) FROM ProxyDB_CleanProxies WHERE UserName="%s"' % user + self.assertTrue(bool(db._query(cmd)['Value'][0][0] == 0), 'ProxyDB_Proxies must be empty') + cmd = 'SELECT COUNT( * ) FROM ProxyDB_CleanProxies WHERE UserDN="%s"' % usersDNs[user] self.assertTrue(bool(db._query(cmd)['Value'][0][0] == (1 if res else 0)), 'ProxyDB_CleanProxies must ' + ('contain proxy' if res else 'be empty')) # Last test test must leave proxy in DB gLogger.info('* Check that ProxyDB_CleanProxy contain generated proxy..') - result = db.getProxiesContent({'UserName': 'user'}, {}) + result = db.getProxiesContent({'UserName': 'user'}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) self.assertTrue(bool(int(result['Value']['TotalRecords']) == 1), 'Generated proxy must be one.') - cmd = 'SELECT COUNT( * ) FROM ProxyDB_CleanProxies WHERE UserName="user"' + cmd = 'SELECT COUNT( * ) FROM ProxyDB_CleanProxies WHERE UserDN="%s"' % usersDNs[user] self.assertTrue(bool(db._query(cmd)['Value'][0][0] == 1), 'ProxyDB_CleanProxies must contain proxy') gLogger.info('* Get proxy that store only in ProxyDB_CleanProxies..') # Try to get proxy that was stored to ProxyDB_CleanProxies in previous step for res, group, reqtime, log in [(False, 'group_1', 24 * 3600, 'Request time more that in stored proxy'), - (False, 'group_2', 0, 'Request group not contain user'), + (True, 'group_2', 0, 'Request group not contain user(this check on service side)'), (True, 'group_1', 0, 'Request time less that in stored proxy')]: gLogger.info('== > %s:' % log) - result = db.getProxy('/C=CC/O=DN/O=DIRAC/CN=user', group, reqtime) + result = db.getProxy(usersDNs['user'], group, reqtime) text = 'Must be ended %s%s' % (res and 'successful' or 'with error', ': %s' % result.get('Message', 'Error message is absent.')) self.assertEqual(result['OK'], res, text) @@ -533,14 +540,13 @@ def test_getRemoveProxy(self): self.assertTrue(chain.isValidProxy()['OK'], '\n' + result.get('Message', 'Error message is absent.')) result = chain.getDIRACGroup() self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) - self.assertEqual('group_1', result['Value'], 'Group must be group_1, not ' + result['Value']) else: gLogger.info('Msg: %s' % (result['Message'])) gLogger.info('* Check that DB is clean..') - result = db.deleteProxy('/C=CC/O=DN/O=DIRAC/CN=user', proxyProvider='Certificate') + result = db.deleteProxy([usersDNs['user']]) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) - result = db.getProxiesContent({'UserName': ['user_ca', 'user', 'user_2', 'user_3']}, {}) + result = db.getProxiesContent({'UserName': userNames}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) self.assertTrue(bool(int(result['Value']['TotalRecords']) == 0), 'In DB present proxies.') @@ -555,7 +561,7 @@ def test_getRemoveProxy(self): result = db._update(cmd) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) # Try to get it - result = db.getProxy(dn, group, 1800) + result = db.getProxy(usersDNs[user], group, 1800) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) # Check that proxy contain group chain = result['Value'][0] @@ -565,9 +571,9 @@ def test_getRemoveProxy(self): self.assertEqual('group_1', result['Value'], 'Group must be group_1, not ' + result['Value']) gLogger.info('* Check that DB is clean..') - result = db.deleteProxy('/C=CC/O=DN/O=DIRAC/CN=user') + result = db.deleteProxy([usersDNs['user']]) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) - result = db.getProxiesContent({'UserName': ['user_ca', 'user', 'user_1', 'user_2', 'user_3']}, {}) + result = db.getProxiesContent({'UserName': userNames}) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) self.assertTrue(bool(int(result['Value']['TotalRecords']) == 0), 'In DB present proxies.') @@ -588,33 +594,28 @@ def test_getRemoveProxy(self): self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) # Try to get proxy with VOMS extension - for dn, group, role, time, log in [('/C=CC/O=DN/O=DIRAC/CN=user_4', 'group_2', False, 9999, - 'Not exist VO for current group'), - ('/C=CC/O=DN/O=DIRAC/CN=user', 'group_1', 'role_1', 9999, - 'Stored proxy already have different VOMS extension'), - ('/C=CC/O=DN/O=DIRAC/CN=user_1', 'group_1', 'role_1', 9999, - 'Stored proxy already have different VOMS extension'), - ('/C=DN/O=DIRACCA/OU=None/CN=user_ca/emailAddress=user_ca@diracgrid.org', - 'group_1', 'role_1', 9999, 'Not correct VO configuration')]: - gLogger.info('== > %s(DN: %s):' % (log, dn)) - if not any([dn, group, role, time, log]): - gLogger.info('voms-proxy-fake command not working as expected, proxy have no VOMS extention, go to the next..') - continue - result = db.getVOMSProxy(dn, group, time, role) + for user, group, time, log in [('user_4', 'group_2', 9999, + 'Not exist VO for current group'), + ('user', 'group_1', 9999, + 'Stored proxy already have different VOMS extension'), + ('user_1', 'group_1', 9999, + 'Stored proxy already have different VOMS extension'), + ('user_ca', 'group_1', 9999, 'Not correct VO configuration')]: + gLogger.info('== > %s(DN: %s):' % (log, usersDNs[user])) + result = db.getProxy(usersDNs[user], group, time, voms=True) self.assertFalse(result['OK'], 'Must be fail.') gLogger.info('Msg: %s' % result['Message']) # Check stored proxies for table, user, count in [('ProxyDB_Proxies', 'user', 1), ('ProxyDB_CleanProxies', 'user_ca', 1)]: - cmd = 'SELECT COUNT( * ) FROM %s WHERE UserName="%s"' % (table, user) + cmd = 'SELECT COUNT( * ) FROM %s WHERE UserDN="%s"' % (table, usersDNs[user]) self.assertTrue(bool(db._query(cmd)['Value'][0][0] == count)) gLogger.info('* Delete proxies..') - for dn, table in [('/C=CC/O=DN/O=DIRAC/CN=user', 'ProxyDB_Proxies'), - ('/C=DN/O=DIRACCA/OU=None/CN=user_ca/emailAddress=user_ca@diracgrid.org', - 'ProxyDB_CleanProxies')]: - result = db.deleteProxy(dn) + for user, table in [('user', 'ProxyDB_Proxies'), + ('user_ca', 'ProxyDB_CleanProxies')]: + result = db.deleteProxy([usersDNs[user]]) self.assertTrue(result['OK'], '\n' + result.get('Message', 'Error message is absent.')) - cmd = 'SELECT COUNT( * ) FROM %s WHERE UserName="user_ca"' % table + cmd = 'SELECT COUNT( * ) FROM %s WHERE UserDN="%s"' % (table, usersDNs[user]) self.assertTrue(bool(db._query(cmd)['Value'][0][0] == 0)) diff --git a/tests/Integration/RequestManagementSystem/FIXME_IntegrationFCT.py b/tests/Integration/RequestManagementSystem/FIXME_IntegrationFCT.py index d66976e2ae1..1f1a69c0005 100644 --- a/tests/Integration/RequestManagementSystem/FIXME_IntegrationFCT.py +++ b/tests/Integration/RequestManagementSystem/FIXME_IntegrationFCT.py @@ -35,7 +35,7 @@ from DIRAC.Core.Utilities.Adler import fileAdler from DIRAC.Core.Utilities.File import makeGuid from DIRAC.Interfaces.API.DiracAdmin import DiracAdmin -from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getGroupsForUser, getDNForUsername +from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getGroupsForUser, getDNForUsernameInGroup # # from RMS and DMS from DIRAC.RequestManagementSystem.Client.Request import Request from DIRAC.RequestManagementSystem.Client.Operation import Operation @@ -56,9 +56,9 @@ class FullChainTest( object ): """ - def buildRequest( self, owner, group, sourceSE, targetSE1, targetSE2 ): + def buildRequest(self, owner, group, sourceSE, targetSE1, targetSE2): - files = self.files( owner, group ) + files = self.files(owner, group) putAndRegister = Operation() putAndRegister.Type = "PutAndRegister" @@ -71,128 +71,128 @@ def buildRequest( self, owner, group, sourceSE, targetSE1, targetSE2 ): putFile.ChecksumType = "adler32" putFile.Size = size putFile.GUID = guid - putAndRegister.addFile( putFile ) + putAndRegister.addFile(putFile) replicateAndRegister = Operation() replicateAndRegister.Type = "ReplicateAndRegister" - replicateAndRegister.TargetSE = "%s,%s" % ( targetSE1, targetSE2 ) + replicateAndRegister.TargetSE = "%s,%s" % (targetSE1, targetSE2) for fname, lfn, size, checksum, guid in files: repFile = File() repFile.LFN = lfn repFile.Size = size repFile.Checksum = checksum repFile.ChecksumType = "adler32" - replicateAndRegister.addFile( repFile ) + replicateAndRegister.addFile(repFile) removeReplica = Operation() removeReplica.Type = "RemoveReplica" removeReplica.TargetSE = sourceSE for fname, lfn, size, checksum, guid in files: - removeReplica.addFile( File( {"LFN": lfn } ) ) + removeReplica.addFile(File({"LFN": lfn})) removeFile = Operation() removeFile.Type = "RemoveFile" for fname, lfn, size, checksum, guid in files: - removeFile.addFile( File( {"LFN": lfn } ) ) + removeFile.addFile(File({"LFN": lfn})) removeFileInit = Operation() removeFileInit.Type = "RemoveFile" for fname, lfn, size, checksum, guid in files: - removeFileInit.addFile( File( {"LFN": lfn } ) ) + removeFileInit.addFile(File({"LFN": lfn})) req = Request() - req.addOperation( removeFileInit ) - req.addOperation( putAndRegister ) - req.addOperation( replicateAndRegister ) - req.addOperation( removeReplica ) - req.addOperation( removeFile ) + req.addOperation(removeFileInit) + req.addOperation(putAndRegister) + req.addOperation(replicateAndRegister) + req.addOperation(removeReplica) + req.addOperation(removeFile) return req - def files( self, userName, userGroup ): + def files(self, userName, userGroup): """ get list of files in user domain """ files = [] for i in range(10): fname = "/tmp/testUserFile-%s" % i if userGroup == "dteam_user": - lfn = "/lhcb/user/%s/%s/%s" % ( userName[0], userName, fname.split( "/" )[-1] ) + lfn = "/lhcb/user/%s/%s/%s" % (userName[0], userName, fname.split("/")[-1]) else: - lfn = "/lhcb/certification/test/rmsdms/%s" % fname.split( "/" )[-1] - fh = open( fname, "w+" ) + lfn = "/lhcb/certification/test/rmsdms/%s" % fname.split("/")[-1] + fh = open(fname, "w+") for i in range(100): - fh.write( str( random.randint( 0, i ) ) ) + fh.write(str(random.randint(0, i))) fh.close() - size = os.stat( fname ).st_size - checksum = fileAdler( fname ) - guid = makeGuid( fname ) - files.append( ( fname, lfn, size, checksum, guid ) ) + size = os.stat(fname).st_size + checksum = fileAdler(fname) + guid = makeGuid(fname) + files.append((fname, lfn, size, checksum, guid)) return files - def putRequest( self, userName, userDN, userGroup, sourceSE, targetSE1, targetSE2 ): + def putRequest(self, userName, userDN, userGroup, sourceSE, targetSE1, targetSE2): """ test case for user """ - req = self.buildRequest( userName, userGroup, sourceSE, targetSE1, targetSE2 ) + req = self.buildRequest(userName, userGroup, sourceSE, targetSE1, targetSE2) - req.RequestName = "test%s-%s" % ( userName, userGroup ) + req.RequestName = "test%s-%s" % (userName, userGroup) req.OwnerDN = userDN req.OwnerGroup = userGroup - gLogger.always( "putRequest: request '%s'" % req.RequestName ) + gLogger.always("putRequest: request '%s'" % req.RequestName) for op in req: - gLogger.always( "putRequest: => %s %s %s" % ( op.Order, op.Type, op.TargetSE ) ) + gLogger.always("putRequest: => %s %s %s" % (op.Order, op.Type, op.TargetSE)) for f in op: - gLogger.always( "putRequest: ===> file %s" % f.LFN ) + gLogger.always("putRequest: ===> file %s" % f.LFN) reqClient = ReqClient() - delete = reqClient.deleteRequest( req.RequestName ) + delete = reqClient.deleteRequest(req.RequestName) if not delete["OK"]: - gLogger.error( "putRequest: %s" % delete["Message"] ) + gLogger.error("putRequest: %s" % delete["Message"]) return delete - put = reqClient.putRequest( req ) + put = reqClient.putRequest(req) if not put["OK"]: - gLogger.error( "putRequest: %s" % put["Message"] ) + gLogger.error("putRequest: %s" % put["Message"]) return put # # test execution if __name__ == "__main__": - if len( sys.argv ) != 5: - gLogger.error( "Usage:\n python %s userGroup SourceSE TargetSE1 TargetSE2\n" ) - sys.exit( -1 ) + if len(sys.argv) != 5: + gLogger.error("Usage:\n python %s userGroup SourceSE TargetSE1 TargetSE2\n") + sys.exit(-1) userGroup = sys.argv[1] sourceSE = sys.argv[2] targetSE1 = sys.argv[3] targetSE2 = sys.argv[4] - gLogger.always( "will use '%s' group" % userGroup ) + gLogger.always("will use '%s' group" % userGroup) admin = DiracAdmin() userName = admin._getCurrentUser() if not userName["OK"]: - gLogger.error( userName["Message"] ) - sys.exit( -1 ) + gLogger.error(userName["Message"]) + sys.exit(-1) userName = userName["Value"] - gLogger.always( "current user is '%s'" % userName ) + gLogger.always("current user is '%s'" % userName) - userGroups = getGroupsForUser( userName ) + userGroups = getGroupsForUser(userName) if not userGroups["OK"]: - gLogger.error( userGroups["Message"] ) - sys.exit( -1 ) + gLogger.error(userGroups["Message"]) + sys.exit(-1) userGroups = userGroups["Value"] if userGroup not in userGroups: - gLogger.error( "'%s' is not a member of the '%s' group" % ( userName, userGroup ) ) - sys.exit( -1 ) + gLogger.error("'%s' is not a member of the '%s' group" % (userName, userGroup)) + sys.exit(-1) - userDN = getDNForUsername( userName ) - if not userDN["OK"]: - gLogger.error( userDN["Message"] ) - sys.exit( -1 ) - userDN = userDN["Value"][0] - gLogger.always( "userDN is %s" % userDN ) + result = getDNForUsernameInGroup(userName, userGroup) + if not result['OK']: + gLogger.error(result['Message']) + sys.exit(-1) + userDN = result['Value'] + gLogger.always("userDN is %s" % userDN) fct = FullChainTest() - put = fct.putRequest( userName, userDN, userGroup, sourceSE, targetSE1, targetSE2 ) + put = fct.putRequest(userName, userDN, userGroup, sourceSE, targetSE1, targetSE2) diff --git a/tests/Integration/RequestManagementSystem/Test_Client_Req.py b/tests/Integration/RequestManagementSystem/Test_Client_Req.py index ec14303e44d..7d82014396d 100644 --- a/tests/Integration/RequestManagementSystem/Test_Client_Req.py +++ b/tests/Integration/RequestManagementSystem/Test_Client_Req.py @@ -188,7 +188,7 @@ def test02Authorization(self): request = Request({"RequestName": "unauthorized"}) request.OwnerDN = 'NotMe' - request.OwnerDN = 'AnotherGroup' + request.OwnerGroup = 'AnotherGroup' op = Operation({"Type": "RemoveReplica", "TargetSE": "CERN-USER"}) op += File({"LFN": "/lhcb/user/c/cibak/foo"}) request += op diff --git a/tests/Integration/Resources/ProxyProvider/Test_DIRACCAProxyProvider.py b/tests/Integration/Resources/ProxyProvider/Test_DIRACCAProxyProvider.py index 6fcfdb63956..abe9a761e0c 100644 --- a/tests/Integration/Resources/ProxyProvider/Test_DIRACCAProxyProvider.py +++ b/tests/Integration/Resources/ProxyProvider/Test_DIRACCAProxyProvider.py @@ -12,8 +12,7 @@ from diraccfg import CFG -import DIRAC -from DIRAC import gConfig +from DIRAC import gConfig, S_OK from DIRAC.Core.Security.X509Chain import X509Chain # pylint: disable=import-error from DIRAC.Resources.ProxyProvider.ProxyProviderFactory import ProxyProviderFactory @@ -66,6 +65,15 @@ """ +class proxyManager(object): + """ Fake proxyManager + """ + def _storeProxy(self, userDN, chain): + """ Fake store method + """ + return S_OK() + + class DIRACCAPPTest(unittest.TestCase): """ Base class for the Modules test cases """ @@ -86,7 +94,7 @@ def setUp(self): cfg.loadFromBuffer(userCFG) gConfig.loadCFG(cfg) - result = ProxyProviderFactory().getProxyProvider('DIRAC_TEST_CA') + result = ProxyProviderFactory().getProxyProvider('DIRAC_TEST_CA', proxyManager=proxyManager()) self.assertTrue(result['OK'], '\n%s' % result.get('Message') or 'Error message is absent.') self.pp = result['Value'] @@ -103,8 +111,7 @@ def test_getProxy(self): ': %s' % result.get('Message', 'Error message is absent.')) self.assertEqual(result['OK'], res, text) if res: - chain = X509Chain() - chain.loadChainFromString(result['Value']) + chain = result['Value'] result = chain.getCredentials() self.assertTrue(result['OK'], '\n%s' % result.get('Message') or 'Error message is absent.') credDict = result['Value'] @@ -122,8 +129,7 @@ def test_generateProxyDN(self): self.assertTrue(result['OK'], '\n%s' % result.get('Message') or 'Error message is absent.') result = self.pp.getProxy(result['Value']) self.assertTrue(result['OK'], '\n%s' % result.get('Message') or 'Error message is absent.') - chain = X509Chain() - chain.loadChainFromString(result['Value']) + chain = result['Value'] result = chain.getCredentials() self.assertTrue(result['OK'], '\n%s' % result.get('Message') or 'Error message is absent.') issuer = result['Value']['issuer'] diff --git a/tests/Integration/WorkloadManagementSystem/Test_PilotsClient.py b/tests/Integration/WorkloadManagementSystem/Test_PilotsClient.py index a1590370433..e26e49cc02c 100644 --- a/tests/Integration/WorkloadManagementSystem/Test_PilotsClient.py +++ b/tests/Integration/WorkloadManagementSystem/Test_PilotsClient.py @@ -25,10 +25,10 @@ def test_PilotsDB(): - + realDN = '/C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser/emailAddress=lhcb-dirac-ci@cern.ch' + realGroup = 'dirac_user' pilots = PilotManagerClient() - - res = pilots.addPilotTQReference(['aPilot'], 1, '/a/ownerDN', 'a/owner/Group') + res = pilots.addPilotTQReference(['aPilot'], 1, realDN, realGroup) assert res['OK'] is True, res['Message'] res = pilots.getCurrentPilotCounters({}) assert res['OK'] is True, res['Message'] @@ -39,17 +39,19 @@ def test_PilotsDB(): assert res['OK'] is True, res['Message'] assert res['Value'] == {}, res['Value'] - res = pilots.addPilotTQReference(['anotherPilot'], 1, '/a/ownerDN', 'a/owner/Group') + res = pilots.addPilotTQReference(['anotherPilot'], 1, realDN, realGroup) assert res['OK'] is True, res['Message'] res = pilots.storePilotOutput('anotherPilot', 'This is an output', 'this is an error') assert res['OK'] is True, res['Message'] res = pilots.getPilotOutput('anotherPilot') assert res['OK'] is True, res['Message'] - assert res['Value'] == {'OwnerDN': '/a/ownerDN', - 'OwnerGroup': 'a/owner/Group', - 'StdErr': 'this is an error', - 'FileList': [], - 'StdOut': 'This is an output'} + # There are new "Owner" key ... Therefore, if the main keys match then all is well + expectedDict = {'FileList': [], + 'OwnerDN': realDN, + 'OwnerGroup': realGroup, + 'StdErr': 'this is an error', + 'StdOut': 'This is an output'} + assert all([res['Value'][k] == v for k, v in expectedDict.items()]) res = pilots.getPilotInfo('anotherPilot') assert res['OK'] is True, res['Message'] assert res['Value']['anotherPilot']['AccountingSent'] == 'False', res['Value'] @@ -66,12 +68,12 @@ def test_PilotsDB(): res = pilots.getPilotMonitorSelectors() assert res['OK'] is True, res['Message'] assert res['Value'] == {'GridType': ['DIRAC'], - 'OwnerGroup': ['a/owner/Group'], + 'OwnerGroup': [realGroup], 'DestinationSite': ['NotAssigned'], 'Broker': ['Unknown'], 'Status': ['Submitted'], - 'OwnerDN': ['/a/ownerDN'], + 'OwnerDN': [realDN], 'GridSite': ['Unknown'], - 'Owner': []}, res['Value'] + 'Owner': ['adminusername']}, res['Value'] res = pilots.getPilotSummaryWeb({}, [], 0, 100) assert res['OK'] is True, res['Message'] assert res['Value']['TotalRecords'] == 1, res['Value'] diff --git a/tests/Integration/all_integration_server_tests.sh b/tests/Integration/all_integration_server_tests.sh index a5f3bb67fef..5529389d268 100644 --- a/tests/Integration/all_integration_server_tests.sh +++ b/tests/Integration/all_integration_server_tests.sh @@ -26,6 +26,7 @@ pytest "${THIS_DIR}/Core/Test_MySQLDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( echo -e "*** $(date -u) **** FRAMEWORK TESTS (partially skipped) ****\n" pytest "${THIS_DIR}/Framework/Test_InstalledComponentsDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) python "${THIS_DIR}/Framework/Test_ProxyDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) +python "${THIS_DIR}/Framework/Test_AuthDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #pytest ${THIS_DIR}/Framework/Test_LoggingDB.py" |& tee -a "${SERVER_TEST_OUTPUT}"; (( ERR |= "${?}" )) #-------------------------------------------------------------------------------# diff --git a/tests/Jenkins/dirac_ci.sh b/tests/Jenkins/dirac_ci.sh index 8beec165e0d..e6c728fef16 100644 --- a/tests/Jenkins/dirac_ci.sh +++ b/tests/Jenkins/dirac_ci.sh @@ -181,6 +181,9 @@ installSite() { exit 1 fi + #TODO: remove it, this hack for testing + pip install authlib==0.15.3 + if ! dirac-setup-site "${DEBUG}"; then echo "ERROR: dirac-setup-site failed" >&2 exit 1