Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.shiro.realm.ldap.JndiLdapRealm;
import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.realm.ldap.LdapUtils;
import org.apache.shiro.session.Session;
import org.apache.shiro.subject.MutablePrincipalCollection;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.util.StringUtils;
Expand Down Expand Up @@ -178,7 +179,7 @@ public class LdapRealm extends JndiLdapRealm {

private String groupIdAttribute = "cn";

private String memberAttributeValuePrefix = "uid={0}";
private String memberAttributeValuePrefix = "uid=";
private String memberAttributeValueSuffix = "";

private final Map<String, String> rolesByGroup = new LinkedHashMap<String, String>();
Expand Down Expand Up @@ -246,7 +247,7 @@ protected AuthenticationInfo queryForAuthenticationInfo(AuthenticationToken toke
* if any LDAP errors occur during the search.
*/
@Override
protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals,
public AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals,
final LdapContextFactory ldapContextFactory) throws NamingException {
if (!isAuthorizationEnabled()) {
return null;
Expand Down Expand Up @@ -286,7 +287,8 @@ private Set<String> getRoles(PrincipalCollection principals,
LdapContext systemLdapCtx = null;
try {
systemLdapCtx = ldapContextFactory.getSystemLdapContext();
return rolesFor(principals, username, systemLdapCtx, ldapContextFactory);
return rolesFor(principals, username, systemLdapCtx,
ldapContextFactory, SecurityUtils.getSubject().getSession());
} catch (AuthenticationException ae) {
ae.printStackTrace();
return Collections.emptySet();
Expand All @@ -295,9 +297,9 @@ private Set<String> getRoles(PrincipalCollection principals,
}
}

private Set<String> rolesFor(PrincipalCollection principals,
protected Set<String> rolesFor(PrincipalCollection principals,
String userNameIn, final LdapContext ldapCtx,
final LdapContextFactory ldapContextFactory) throws NamingException {
final LdapContextFactory ldapContextFactory, Session session) throws NamingException {
final Set<String> roleNames = new HashSet<>();
final Set<String> groupNames = new HashSet<>();
final String userName;
Expand All @@ -308,14 +310,7 @@ private Set<String> rolesFor(PrincipalCollection principals,
userName = userNameIn;
}

String userDn;
if (userSearchAttributeName == null || userSearchAttributeName.isEmpty()) {
// memberAttributeValuePrefix and memberAttributeValueSuffix
// were computed from memberAttributeValueTemplate
userDn = memberAttributeValuePrefix + userName + memberAttributeValueSuffix;
} else {
userDn = getUserDn(userName);
}
String userDn = getUserDnForSearch(userName);

// Activate paged results
int pageSize = getPagingSize();
Expand Down Expand Up @@ -364,8 +359,7 @@ private Set<String> rolesFor(PrincipalCollection principals,

// If group search filter is defined in Shiro config, then use it
if (groupSearchFilter != null) {
Matcher matchedPrincipal = matchPrincipal(userDn);
searchFilter = expandTemplate(groupSearchFilter, matchedPrincipal);
searchFilter = expandTemplate(groupSearchFilter, userName);
//searchFilter = String.format("%1$s", groupSearchFilter);
}
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -402,8 +396,8 @@ private Set<String> rolesFor(PrincipalCollection principals,
}
// save role names and group names in session so that they can be
// easily looked up outside of this object
SecurityUtils.getSubject().getSession().setAttribute(SUBJECT_USER_ROLES, roleNames);
SecurityUtils.getSubject().getSession().setAttribute(SUBJECT_USER_GROUPS, groupNames);
session.setAttribute(SUBJECT_USER_ROLES, roleNames);
session.setAttribute(SUBJECT_USER_GROUPS, groupNames);
if (!groupNames.isEmpty() && (principals instanceof MutablePrincipalCollection)) {
((MutablePrincipalCollection) principals).addAll(groupNames, getName());
}
Expand All @@ -413,7 +407,17 @@ private Set<String> rolesFor(PrincipalCollection principals,
return roleNames;
}

private void addRoleIfMember(final String userDn, final SearchResult group,
protected String getUserDnForSearch(String userName) {
if (userSearchAttributeName == null || userSearchAttributeName.isEmpty()) {
// memberAttributeValuePrefix and memberAttributeValueSuffix
// were computed from memberAttributeValueTemplate
return memberDn(userName);
} else {
return getUserDn(userName);
}
}

private void addRoleIfMember(final String userDn, final SearchResult group,
final Set<String> roleNames, final Set<String> groupNames,
final LdapContextFactory ldapContextFactory) throws NamingException {

Expand Down Expand Up @@ -446,8 +450,9 @@ private void addRoleIfMember(final String userDn, final SearchResult group,
}
}
} else {
// posix groups' members don' include the entire dn
if (groupObjectClass.equalsIgnoreCase(POSIX_GROUP)) {
attrValue = memberAttributeValuePrefix + attrValue + memberAttributeValueSuffix;
attrValue = memberDn(attrValue);
}
if (userLdapDn.equals(new LdapName(attrValue))) {
groupNames.add(groupName);
Expand All @@ -474,7 +479,11 @@ private void addRoleIfMember(final String userDn, final SearchResult group,
}
}
}


private String memberDn(String attrValue) {
return memberAttributeValuePrefix + attrValue + memberAttributeValueSuffix;
}

public Map<String, String> getListRoles() {
Map<String, String> groupToRoles = getRolesByGroup();
Map<String, String> roles = new HashMap<>();
Expand Down Expand Up @@ -804,7 +813,7 @@ private SearchControls getUserSearchControls() {
return searchControls;
}

private SearchControls getGroupSearchControls() {
protected SearchControls getGroupSearchControls() {
SearchControls searchControls = SUBTREE_SCOPE;
if ("onelevel".equalsIgnoreCase(groupSearchScope)) {
searchControls = ONELEVEL_SCOPE;
Expand All @@ -819,13 +828,13 @@ public void setUserDnTemplate(final String template) throws IllegalArgumentExcep
userDnTemplate = template;
}

private Matcher matchPrincipal(final String principal) {
private String matchPrincipal(final String principal) {
Matcher matchedPrincipal = principalPattern.matcher(principal);
if (!matchedPrincipal.matches()) {
throw new IllegalArgumentException("Principal "
+ principal + " does not match " + principalRegex);
}
return matchedPrincipal;
return matchedPrincipal.group();
}

/**
Expand Down Expand Up @@ -856,7 +865,7 @@ private Matcher matchPrincipal(final String principal) {
protected String getUserDn(final String principal) throws IllegalArgumentException,
IllegalStateException {
String userDn;
Matcher matchedPrincipal = matchPrincipal(principal);
String matchedPrincipal = matchPrincipal(principal);
String userSearchBase = getUserSearchBase();
String userSearchAttributeName = getUserSearchAttributeName();

Expand Down Expand Up @@ -938,16 +947,7 @@ protected AuthenticationInfo createAuthenticationInfo(AuthenticationToken token,
getName());
}

private static final String expandTemplate(final String template, final Matcher input) {
String output = template;
Matcher matcher = TEMPLATE_PATTERN.matcher(output);
while (matcher.find()) {
String lookupStr = matcher.group(1);
int lookupIndex = Integer.parseInt(lookupStr);
String lookupValue = input.group(lookupIndex);
output = matcher.replaceFirst(lookupValue == null ? "" : lookupValue);
matcher = TEMPLATE_PATTERN.matcher(output);
}
return output;
protected static final String expandTemplate(final String template, final String input) {
return template.replace(MEMBER_SUBSTITUTION_TOKEN, input);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@
import java.util.Iterator;
import java.util.Map;

import javax.naming.NamingException;

import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.config.IniSecurityManagerFactory;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.realm.Realm;
import org.apache.shiro.realm.text.IniRealm;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
Expand Down Expand Up @@ -129,7 +133,15 @@ public static HashSet<String> getRoles() {
allRoles = ((IniRealm) realm).getIni().get("roles");
break;
} else if (name.equals("org.apache.zeppelin.realm.LdapRealm")) {
allRoles = ((LdapRealm) realm).getListRoles();
try {
AuthorizationInfo auth = ((LdapRealm) realm).queryForAuthorizationInfo(
new SimplePrincipalCollection(subject.getPrincipal(), realm.getName()),
((LdapRealm) realm).getContextFactory()
);
roles = new HashSet<>(auth.getRoles());
} catch (NamingException e) {
log.error("Can't fetch roles", e);
}
break;
} else if (name.equals("org.apache.zeppelin.realm.ActiveDirectoryGroupRealm")) {
allRoles = ((ActiveDirectoryGroupRealm) realm).getListRoles();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.zeppelin.realm;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.BasicAttributes;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;

import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.session.Session;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;


public class LdapRealmTest {

@Test
public void testGetUserDn() {
LdapRealm realm = new LdapRealm();

// without a user search filter
realm.setUserSearchFilter(null);
assertEquals(
"foo ",
realm.getUserDn("foo ")
);

// with a user search filter
realm.setUserSearchFilter("memberUid={0}");
assertEquals(
"foo",
realm.getUserDn("foo")
);
}

@Test
public void testExpandTemplate() {
assertEquals(
"uid=foo,cn=users,dc=ods,dc=foo",
LdapRealm.expandTemplate("uid={0},cn=users,dc=ods,dc=foo", "foo")
);
}

@Test
public void getUserDnForSearch() {
LdapRealm realm = new LdapRealm();

realm.setUserSearchAttributeName("uid");
assertEquals(
"foo",
realm.getUserDnForSearch("foo")
);

// using a template
realm.setUserSearchAttributeName(null);
realm.setMemberAttributeValueTemplate("cn={0},ou=people,dc=hadoop,dc=apache");
assertEquals(
"cn=foo,ou=people,dc=hadoop,dc=apache",
realm.getUserDnForSearch("foo")
);
}

@Test
public void testRolesFor() throws NamingException {
LdapRealm realm = new LdapRealm();
realm.setGroupSearchBase("cn=groups,dc=apache");
realm.setGroupObjectClass("posixGroup");
realm.setMemberAttributeValueTemplate("cn={0},ou=people,dc=apache");
HashMap<String, String> rolesByGroups = new HashMap<>();
rolesByGroups.put("group-three", "zeppelin-role");
realm.setRolesByGroup(rolesByGroups);

LdapContextFactory ldapContextFactory = mock(LdapContextFactory.class);
LdapContext ldapCtx = mock(LdapContext.class);
Session session = mock(Session.class);


// expected search results
BasicAttributes group1 = new BasicAttributes();
group1.put(realm.getGroupIdAttribute(), "group-one");
group1.put(realm.getMemberAttribute(), "principal");

// user doesn't belong to this group
BasicAttributes group2 = new BasicAttributes();
group2.put(realm.getGroupIdAttribute(), "group-two");
group2.put(realm.getMemberAttribute(), "someoneelse");

// mapped to a different Zeppelin role
BasicAttributes group3 = new BasicAttributes();
group3.put(realm.getGroupIdAttribute(), "group-three");
group3.put(realm.getMemberAttribute(), "principal");

NamingEnumeration<SearchResult> results = enumerationOf(group1, group2, group3);
when(ldapCtx.search(any(String.class), any(String.class), any(SearchControls.class))).thenReturn(results);


Set<String> roles = realm.rolesFor(
new SimplePrincipalCollection("principal", "ldapRealm"),
"principal",
ldapCtx,
ldapContextFactory,
session
);

verify(ldapCtx).search(
"cn=groups,dc=apache",
"(objectclass=posixGroup)",
realm.getGroupSearchControls()
);

assertEquals(
new HashSet(Arrays.asList("group-one", "zeppelin-role")),
roles
);
}

private NamingEnumeration<SearchResult> enumerationOf(BasicAttributes... attrs) {
final Iterator<BasicAttributes> iterator = Arrays.asList(attrs).iterator();
return new NamingEnumeration<SearchResult>() {
@Override
public SearchResult next() throws NamingException {
return nextElement();
}

@Override
public boolean hasMore() throws NamingException {
return iterator.hasNext();
}

@Override
public void close() throws NamingException {
}

@Override
public boolean hasMoreElements() {
return iterator.hasNext();
}

@Override
public SearchResult nextElement() {
final BasicAttributes attrs = iterator.next();
return new SearchResult(null, null, attrs);
}
};
}
}