diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/realm/LdapRealm.java b/zeppelin-server/src/main/java/org/apache/zeppelin/realm/LdapRealm.java index dc10749c5b1..b41ac68402d 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/realm/LdapRealm.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/realm/LdapRealm.java @@ -33,6 +33,7 @@ import org.apache.shiro.realm.ldap.JndiLdapRealm; import org.apache.shiro.realm.ldap.LdapContextFactory; import org.apache.shiro.realm.ldap.LdapUtils; +import org.apache.shiro.session.Session; import org.apache.shiro.subject.MutablePrincipalCollection; import org.apache.shiro.subject.PrincipalCollection; import org.apache.shiro.util.StringUtils; @@ -178,7 +179,7 @@ public class LdapRealm extends JndiLdapRealm { private String groupIdAttribute = "cn"; - private String memberAttributeValuePrefix = "uid={0}"; + private String memberAttributeValuePrefix = "uid="; private String memberAttributeValueSuffix = ""; private final Map rolesByGroup = new LinkedHashMap(); @@ -246,7 +247,7 @@ protected AuthenticationInfo queryForAuthenticationInfo(AuthenticationToken toke * if any LDAP errors occur during the search. */ @Override - protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals, + public AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals, final LdapContextFactory ldapContextFactory) throws NamingException { if (!isAuthorizationEnabled()) { return null; @@ -286,7 +287,8 @@ private Set getRoles(PrincipalCollection principals, LdapContext systemLdapCtx = null; try { systemLdapCtx = ldapContextFactory.getSystemLdapContext(); - return rolesFor(principals, username, systemLdapCtx, ldapContextFactory); + return rolesFor(principals, username, systemLdapCtx, + ldapContextFactory, SecurityUtils.getSubject().getSession()); } catch (AuthenticationException ae) { ae.printStackTrace(); return Collections.emptySet(); @@ -295,9 +297,9 @@ private Set getRoles(PrincipalCollection principals, } } - private Set rolesFor(PrincipalCollection principals, + protected Set rolesFor(PrincipalCollection principals, String userNameIn, final LdapContext ldapCtx, - final LdapContextFactory ldapContextFactory) throws NamingException { + final LdapContextFactory ldapContextFactory, Session session) throws NamingException { final Set roleNames = new HashSet<>(); final Set groupNames = new HashSet<>(); final String userName; @@ -308,14 +310,7 @@ private Set rolesFor(PrincipalCollection principals, userName = userNameIn; } - String userDn; - if (userSearchAttributeName == null || userSearchAttributeName.isEmpty()) { - // memberAttributeValuePrefix and memberAttributeValueSuffix - // were computed from memberAttributeValueTemplate - userDn = memberAttributeValuePrefix + userName + memberAttributeValueSuffix; - } else { - userDn = getUserDn(userName); - } + String userDn = getUserDnForSearch(userName); // Activate paged results int pageSize = getPagingSize(); @@ -364,8 +359,7 @@ private Set rolesFor(PrincipalCollection principals, // If group search filter is defined in Shiro config, then use it if (groupSearchFilter != null) { - Matcher matchedPrincipal = matchPrincipal(userDn); - searchFilter = expandTemplate(groupSearchFilter, matchedPrincipal); + searchFilter = expandTemplate(groupSearchFilter, userName); //searchFilter = String.format("%1$s", groupSearchFilter); } if (log.isDebugEnabled()) { @@ -402,8 +396,8 @@ private Set rolesFor(PrincipalCollection principals, } // save role names and group names in session so that they can be // easily looked up outside of this object - SecurityUtils.getSubject().getSession().setAttribute(SUBJECT_USER_ROLES, roleNames); - SecurityUtils.getSubject().getSession().setAttribute(SUBJECT_USER_GROUPS, groupNames); + session.setAttribute(SUBJECT_USER_ROLES, roleNames); + session.setAttribute(SUBJECT_USER_GROUPS, groupNames); if (!groupNames.isEmpty() && (principals instanceof MutablePrincipalCollection)) { ((MutablePrincipalCollection) principals).addAll(groupNames, getName()); } @@ -413,7 +407,17 @@ private Set rolesFor(PrincipalCollection principals, return roleNames; } - private void addRoleIfMember(final String userDn, final SearchResult group, + protected String getUserDnForSearch(String userName) { + if (userSearchAttributeName == null || userSearchAttributeName.isEmpty()) { + // memberAttributeValuePrefix and memberAttributeValueSuffix + // were computed from memberAttributeValueTemplate + return memberDn(userName); + } else { + return getUserDn(userName); + } + } + + private void addRoleIfMember(final String userDn, final SearchResult group, final Set roleNames, final Set groupNames, final LdapContextFactory ldapContextFactory) throws NamingException { @@ -446,8 +450,9 @@ private void addRoleIfMember(final String userDn, final SearchResult group, } } } else { + // posix groups' members don' include the entire dn if (groupObjectClass.equalsIgnoreCase(POSIX_GROUP)) { - attrValue = memberAttributeValuePrefix + attrValue + memberAttributeValueSuffix; + attrValue = memberDn(attrValue); } if (userLdapDn.equals(new LdapName(attrValue))) { groupNames.add(groupName); @@ -474,7 +479,11 @@ private void addRoleIfMember(final String userDn, final SearchResult group, } } } - + + private String memberDn(String attrValue) { + return memberAttributeValuePrefix + attrValue + memberAttributeValueSuffix; + } + public Map getListRoles() { Map groupToRoles = getRolesByGroup(); Map roles = new HashMap<>(); @@ -804,7 +813,7 @@ private SearchControls getUserSearchControls() { return searchControls; } - private SearchControls getGroupSearchControls() { + protected SearchControls getGroupSearchControls() { SearchControls searchControls = SUBTREE_SCOPE; if ("onelevel".equalsIgnoreCase(groupSearchScope)) { searchControls = ONELEVEL_SCOPE; @@ -819,13 +828,13 @@ public void setUserDnTemplate(final String template) throws IllegalArgumentExcep userDnTemplate = template; } - private Matcher matchPrincipal(final String principal) { + private String matchPrincipal(final String principal) { Matcher matchedPrincipal = principalPattern.matcher(principal); if (!matchedPrincipal.matches()) { throw new IllegalArgumentException("Principal " + principal + " does not match " + principalRegex); } - return matchedPrincipal; + return matchedPrincipal.group(); } /** @@ -856,7 +865,7 @@ private Matcher matchPrincipal(final String principal) { protected String getUserDn(final String principal) throws IllegalArgumentException, IllegalStateException { String userDn; - Matcher matchedPrincipal = matchPrincipal(principal); + String matchedPrincipal = matchPrincipal(principal); String userSearchBase = getUserSearchBase(); String userSearchAttributeName = getUserSearchAttributeName(); @@ -938,16 +947,7 @@ protected AuthenticationInfo createAuthenticationInfo(AuthenticationToken token, getName()); } - private static final String expandTemplate(final String template, final Matcher input) { - String output = template; - Matcher matcher = TEMPLATE_PATTERN.matcher(output); - while (matcher.find()) { - String lookupStr = matcher.group(1); - int lookupIndex = Integer.parseInt(lookupStr); - String lookupValue = input.group(lookupIndex); - output = matcher.replaceFirst(lookupValue == null ? "" : lookupValue); - matcher = TEMPLATE_PATTERN.matcher(output); - } - return output; + protected static final String expandTemplate(final String template, final String input) { + return template.replace(MEMBER_SUBSTITUTION_TOKEN, input); } } diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/SecurityUtils.java b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/SecurityUtils.java index b2029ecf6fd..439101f5324 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/SecurityUtils.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/SecurityUtils.java @@ -26,10 +26,14 @@ import java.util.Iterator; import java.util.Map; +import javax.naming.NamingException; + +import org.apache.shiro.authz.AuthorizationInfo; import org.apache.shiro.config.IniSecurityManagerFactory; import org.apache.shiro.mgt.SecurityManager; import org.apache.shiro.realm.Realm; import org.apache.shiro.realm.text.IniRealm; +import org.apache.shiro.subject.SimplePrincipalCollection; import org.apache.shiro.subject.Subject; import org.apache.shiro.util.ThreadContext; import org.apache.shiro.web.mgt.DefaultWebSecurityManager; @@ -129,7 +133,15 @@ public static HashSet getRoles() { allRoles = ((IniRealm) realm).getIni().get("roles"); break; } else if (name.equals("org.apache.zeppelin.realm.LdapRealm")) { - allRoles = ((LdapRealm) realm).getListRoles(); + try { + AuthorizationInfo auth = ((LdapRealm) realm).queryForAuthorizationInfo( + new SimplePrincipalCollection(subject.getPrincipal(), realm.getName()), + ((LdapRealm) realm).getContextFactory() + ); + roles = new HashSet<>(auth.getRoles()); + } catch (NamingException e) { + log.error("Can't fetch roles", e); + } break; } else if (name.equals("org.apache.zeppelin.realm.ActiveDirectoryGroupRealm")) { allRoles = ((ActiveDirectoryGroupRealm) realm).getListRoles(); diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/realm/LdapRealmTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/realm/LdapRealmTest.java new file mode 100644 index 00000000000..9070c5f4e33 --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/realm/LdapRealmTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.zeppelin.realm; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; +import javax.naming.NamingEnumeration; +import javax.naming.NamingException; +import javax.naming.directory.BasicAttributes; +import javax.naming.directory.SearchControls; +import javax.naming.directory.SearchResult; +import javax.naming.ldap.LdapContext; + +import org.apache.shiro.realm.ldap.LdapContextFactory; +import org.apache.shiro.session.Session; +import org.apache.shiro.subject.SimplePrincipalCollection; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; + + +public class LdapRealmTest { + + @Test + public void testGetUserDn() { + LdapRealm realm = new LdapRealm(); + + // without a user search filter + realm.setUserSearchFilter(null); + assertEquals( + "foo ", + realm.getUserDn("foo ") + ); + + // with a user search filter + realm.setUserSearchFilter("memberUid={0}"); + assertEquals( + "foo", + realm.getUserDn("foo") + ); + } + + @Test + public void testExpandTemplate() { + assertEquals( + "uid=foo,cn=users,dc=ods,dc=foo", + LdapRealm.expandTemplate("uid={0},cn=users,dc=ods,dc=foo", "foo") + ); + } + + @Test + public void getUserDnForSearch() { + LdapRealm realm = new LdapRealm(); + + realm.setUserSearchAttributeName("uid"); + assertEquals( + "foo", + realm.getUserDnForSearch("foo") + ); + + // using a template + realm.setUserSearchAttributeName(null); + realm.setMemberAttributeValueTemplate("cn={0},ou=people,dc=hadoop,dc=apache"); + assertEquals( + "cn=foo,ou=people,dc=hadoop,dc=apache", + realm.getUserDnForSearch("foo") + ); + } + + @Test + public void testRolesFor() throws NamingException { + LdapRealm realm = new LdapRealm(); + realm.setGroupSearchBase("cn=groups,dc=apache"); + realm.setGroupObjectClass("posixGroup"); + realm.setMemberAttributeValueTemplate("cn={0},ou=people,dc=apache"); + HashMap rolesByGroups = new HashMap<>(); + rolesByGroups.put("group-three", "zeppelin-role"); + realm.setRolesByGroup(rolesByGroups); + + LdapContextFactory ldapContextFactory = mock(LdapContextFactory.class); + LdapContext ldapCtx = mock(LdapContext.class); + Session session = mock(Session.class); + + + // expected search results + BasicAttributes group1 = new BasicAttributes(); + group1.put(realm.getGroupIdAttribute(), "group-one"); + group1.put(realm.getMemberAttribute(), "principal"); + + // user doesn't belong to this group + BasicAttributes group2 = new BasicAttributes(); + group2.put(realm.getGroupIdAttribute(), "group-two"); + group2.put(realm.getMemberAttribute(), "someoneelse"); + + // mapped to a different Zeppelin role + BasicAttributes group3 = new BasicAttributes(); + group3.put(realm.getGroupIdAttribute(), "group-three"); + group3.put(realm.getMemberAttribute(), "principal"); + + NamingEnumeration results = enumerationOf(group1, group2, group3); + when(ldapCtx.search(any(String.class), any(String.class), any(SearchControls.class))).thenReturn(results); + + + Set roles = realm.rolesFor( + new SimplePrincipalCollection("principal", "ldapRealm"), + "principal", + ldapCtx, + ldapContextFactory, + session + ); + + verify(ldapCtx).search( + "cn=groups,dc=apache", + "(objectclass=posixGroup)", + realm.getGroupSearchControls() + ); + + assertEquals( + new HashSet(Arrays.asList("group-one", "zeppelin-role")), + roles + ); + } + + private NamingEnumeration enumerationOf(BasicAttributes... attrs) { + final Iterator iterator = Arrays.asList(attrs).iterator(); + return new NamingEnumeration() { + @Override + public SearchResult next() throws NamingException { + return nextElement(); + } + + @Override + public boolean hasMore() throws NamingException { + return iterator.hasNext(); + } + + @Override + public void close() throws NamingException { + } + + @Override + public boolean hasMoreElements() { + return iterator.hasNext(); + } + + @Override + public SearchResult nextElement() { + final BasicAttributes attrs = iterator.next(); + return new SearchResult(null, null, attrs); + } + }; + } +}