From 13032e564d10ab85c23061124dc1e9117b9a602a Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 25 Apr 2025 15:02:34 +0300 Subject: [PATCH 1/7] refactor: return all values from body in the providers --- internal/handlers/handlers.go | 19 +++++++++---- internal/providers/generic.go | 21 ++++++-------- internal/providers/github.go | 16 +++++++---- internal/providers/google.go | 23 ++++++---------- internal/providers/providers.go | 49 +++++++++++++++++---------------- 5 files changed, 67 insertions(+), 61 deletions(-) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 6e80851f..c873d9fb 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -613,14 +613,23 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { return } - // Get email - email, err := h.Providers.GetUser(providerName.Provider) - - log.Debug().Str("email", email).Msg("Got email") + // Get user + user, err := h.Providers.GetUser(providerName.Provider) // Handle error if err != nil { - log.Error().Msg("Failed to get email") + log.Error().Msg("Failed to get user") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Msg("Got user") + + // Get email + email, ok := user["email"].(string) + + if !ok { + log.Error().Msg("Failed to get email from user") c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) return } diff --git a/internal/providers/generic.go b/internal/providers/generic.go index 798b039d..084d5fbc 100644 --- a/internal/providers/generic.go +++ b/internal/providers/generic.go @@ -8,18 +8,16 @@ import ( "github.com/rs/zerolog/log" ) -// We are assuming that the generic provider will return a JSON object with an email field -type GenericUserInfoResponse struct { - Email string `json:"email"` -} +func GetGenericUser(client *http.Client, url string) (map[string]interface{}, error) { + // Create user struct + user := make(map[string]interface{}) -func GetGenericEmail(client *http.Client, url string) (string, error) { // Using the oauth client get the user info url res, err := client.Get(url) // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Got response from generic provider") @@ -29,24 +27,21 @@ func GetGenericEmail(client *http.Client, url string) (string, error) { // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Read body from generic provider") - // Parse the body into a user struct - var user GenericUserInfoResponse - // Unmarshal the body into the user struct err = json.Unmarshal(body, &user) // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Parsed user from generic provider") - // Return the email - return user.Email, nil + // Return the user + return user, nil } diff --git a/internal/providers/github.go b/internal/providers/github.go index 010e799c..3c1a59c8 100644 --- a/internal/providers/github.go +++ b/internal/providers/github.go @@ -20,13 +20,16 @@ func GithubScopes() []string { return []string{"user:email"} } -func GetGithubEmail(client *http.Client) (string, error) { +func GetGithubUser(client *http.Client) (map[string]interface{}, error) { + // Create user struct + user := make(map[string]interface{}) + // Get the user emails from github using the oauth http client res, err := client.Get("https://api.github.com/user/emails") // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Got response from github") @@ -36,7 +39,7 @@ func GetGithubEmail(client *http.Client) (string, error) { // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Read body from github") @@ -49,7 +52,7 @@ func GetGithubEmail(client *http.Client) (string, error) { // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Parsed emails from github") @@ -57,10 +60,11 @@ func GetGithubEmail(client *http.Client) (string, error) { // Find and return the primary email for _, email := range emails { if email.Primary { - return email.Email, nil + user["email"] = email.Email + return user, nil } } // User does not have a primary email? - return "", errors.New("no primary email found") + return user, errors.New("no primary email found") } diff --git a/internal/providers/google.go b/internal/providers/google.go index ba5c8b44..9785586a 100644 --- a/internal/providers/google.go +++ b/internal/providers/google.go @@ -8,23 +8,21 @@ import ( "github.com/rs/zerolog/log" ) -// Google works the same as the generic provider -type GoogleUserInfoResponse struct { - Email string `json:"email"` -} - // The scopes required for the google provider func GoogleScopes() []string { return []string{"https://www.googleapis.com/auth/userinfo.email"} } -func GetGoogleEmail(client *http.Client) (string, error) { +func GetGoogleUser(client *http.Client) (map[string]interface{}, error) { + // Create user struct + user := make(map[string]interface{}) + // Get the user info from google using the oauth http client res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Got response from google") @@ -34,24 +32,21 @@ func GetGoogleEmail(client *http.Client) (string, error) { // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Read body from google") - // Parse the body into a user struct - var user GoogleUserInfoResponse - // Unmarshal the body into the user struct err = json.Unmarshal(body, &user) // Check if there was an error if err != nil { - return "", err + return user, err } log.Debug().Msg("Parsed user from google") - // Return the email - return user.Email, nil + // Return the user + return user, nil } diff --git a/internal/providers/providers.go b/internal/providers/providers.go index c1bad5ec..3c4f4fc9 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -93,14 +93,17 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth { } } -func (providers *Providers) GetUser(provider string) (string, error) { - // Get the email from the provider +func (providers *Providers) GetUser(provider string) (map[string]interface{}, error) { + // Create user struct + user := make(map[string]interface{}) + + // Get the user from the provider switch provider { case "github": // If the github provider is not configured, return an error if providers.Github == nil { log.Debug().Msg("Github provider not configured") - return "", nil + return user, nil } // Get the client from the github provider @@ -108,23 +111,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { log.Debug().Msg("Got client from github") - // Get the email from the github provider - email, err := GetGithubEmail(client) + // Get the user from the github provider + user, err := GetGithubUser(client) // Check if there was an error if err != nil { - return "", err + return user, err } - log.Debug().Msg("Got email from github") + log.Debug().Msg("Got user from github") - // Return the email - return email, nil + // Return the user + return user, nil case "google": // If the google provider is not configured, return an error if providers.Google == nil { log.Debug().Msg("Google provider not configured") - return "", nil + return user, nil } // Get the client from the google provider @@ -132,23 +135,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { log.Debug().Msg("Got client from google") - // Get the email from the google provider - email, err := GetGoogleEmail(client) + // Get the user from the google provider + user, err := GetGoogleUser(client) // Check if there was an error if err != nil { - return "", err + return user, err } - log.Debug().Msg("Got email from google") + log.Debug().Msg("Got user from google") - // Return the email - return email, nil + // Return the user + return user, nil case "generic": // If the generic provider is not configured, return an error if providers.Generic == nil { log.Debug().Msg("Generic provider not configured") - return "", nil + return user, nil } // Get the client from the generic provider @@ -156,20 +159,20 @@ func (providers *Providers) GetUser(provider string) (string, error) { log.Debug().Msg("Got client from generic") - // Get the email from the generic provider - email, err := GetGenericEmail(client, providers.Config.GenericUserURL) + // Get the user from the generic provider + user, err := GetGenericUser(client, providers.Config.GenericUserURL) // Check if there was an error if err != nil { - return "", err + return user, err } - log.Debug().Msg("Got email from generic") + log.Debug().Msg("Got user from generic") // Return the email - return email, nil + return user, nil default: - return "", nil + return user, nil } } From 5e4e2ddbd97e31d7ed21f10b6638d748f6b0c370 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 25 Apr 2025 15:28:24 +0300 Subject: [PATCH 2/7] refactor: only accept claims following the OIDC spec --- cmd/root.go | 2 +- internal/constants/constants.go | 11 +++++++++++ internal/handlers/handlers.go | 16 +++++++--------- internal/providers/generic.go | 5 +++-- internal/providers/github.go | 7 ++++--- internal/providers/google.go | 5 +++-- internal/providers/providers.go | 5 +++-- 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index beaa5d26..fe10166f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -189,7 +189,7 @@ func init() { rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") - rootCmd.Flags().String("generic-name", "Other", "Generic OAuth provider name.") + rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 37aa55d0..b6b5598b 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -7,3 +7,14 @@ var TinyauthLabels = []string{ "tinyauth.allowed", "tinyauth.headers", } + +// Claims are the OIDC supported claims +type Claims struct { + Name string `json:"name"` + FamilyName string `json:"family_name"` + GivenName string `json:"given_name"` + MiddleName string `json:"middle_name"` + Nickname string `json:"nickname"` + Picture string `json:"picture"` + Email string `json:"email"` +} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index c873d9fb..33908d90 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -625,22 +625,20 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { log.Debug().Msg("Got user") - // Get email - email, ok := user["email"].(string) - - if !ok { - log.Error().Msg("Failed to get email from user") + // Check that email is not empty + if user.Email == "" { + log.Warn().Msg("Email is empty") c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) return } // Email is not whitelisted - if !h.Auth.EmailWhitelisted(email) { - log.Warn().Str("email", email).Msg("Email not whitelisted") + if !h.Auth.EmailWhitelisted(user.Email) { + log.Warn().Str("email", user.Email).Msg("Email not whitelisted") // Build query queries, err := query.Values(types.UnauthorizedQuery{ - Username: email, + Username: user.Email, }) // Handle error @@ -658,7 +656,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { // Create session cookie (also cleans up redirect cookie) h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: email, + Username: user.Email, Provider: providerName.Provider, }) diff --git a/internal/providers/generic.go b/internal/providers/generic.go index 084d5fbc..8307600f 100644 --- a/internal/providers/generic.go +++ b/internal/providers/generic.go @@ -4,13 +4,14 @@ import ( "encoding/json" "io" "net/http" + "tinyauth/internal/constants" "github.com/rs/zerolog/log" ) -func GetGenericUser(client *http.Client, url string) (map[string]interface{}, error) { +func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { // Create user struct - user := make(map[string]interface{}) + var user constants.Claims // Using the oauth client get the user info url res, err := client.Get(url) diff --git a/internal/providers/github.go b/internal/providers/github.go index 3c1a59c8..2f8862e7 100644 --- a/internal/providers/github.go +++ b/internal/providers/github.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net/http" + "tinyauth/internal/constants" "github.com/rs/zerolog/log" ) @@ -20,9 +21,9 @@ func GithubScopes() []string { return []string{"user:email"} } -func GetGithubUser(client *http.Client) (map[string]interface{}, error) { +func GetGithubUser(client *http.Client) (constants.Claims, error) { // Create user struct - user := make(map[string]interface{}) + var user constants.Claims // Get the user emails from github using the oauth http client res, err := client.Get("https://api.github.com/user/emails") @@ -60,7 +61,7 @@ func GetGithubUser(client *http.Client) (map[string]interface{}, error) { // Find and return the primary email for _, email := range emails { if email.Primary { - user["email"] = email.Email + user.Email = email.Email return user, nil } } diff --git a/internal/providers/google.go b/internal/providers/google.go index 9785586a..f8ba30d1 100644 --- a/internal/providers/google.go +++ b/internal/providers/google.go @@ -4,6 +4,7 @@ import ( "encoding/json" "io" "net/http" + "tinyauth/internal/constants" "github.com/rs/zerolog/log" ) @@ -13,9 +14,9 @@ func GoogleScopes() []string { return []string{"https://www.googleapis.com/auth/userinfo.email"} } -func GetGoogleUser(client *http.Client) (map[string]interface{}, error) { +func GetGoogleUser(client *http.Client) (constants.Claims, error) { // Create user struct - user := make(map[string]interface{}) + var user constants.Claims // Get the user info from google using the oauth http client res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 3c4f4fc9..a22e83e9 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -2,6 +2,7 @@ package providers import ( "fmt" + "tinyauth/internal/constants" "tinyauth/internal/oauth" "tinyauth/internal/types" @@ -93,9 +94,9 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth { } } -func (providers *Providers) GetUser(provider string) (map[string]interface{}, error) { +func (providers *Providers) GetUser(provider string) (constants.Claims, error) { // Create user struct - user := make(map[string]interface{}) + var user constants.Claims // Get the user from the provider switch provider { From dca09a3d9dce9697af0e97f76484b67829d3aada Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 25 Apr 2025 16:41:45 +0300 Subject: [PATCH 3/7] feat: map info from OIDC claims to headers --- cmd/root.go | 7 ++- frontend/src/pages/logout-page.tsx | 6 +- frontend/src/schemas/user-context-schema.ts | 2 + internal/api/api_test.go | 7 ++- internal/auth/auth.go | 19 +++++- internal/constants/constants.go | 12 ++-- internal/docker/docker.go | 11 ++-- internal/handlers/handlers.go | 34 ++++++++++- internal/hooks/hooks.go | 66 +++++++++------------ internal/types/api.go | 2 + internal/types/config.go | 5 ++ internal/types/types.go | 4 ++ internal/utils/utils.go | 5 ++ 13 files changed, 117 insertions(+), 63 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index fe10166f..d4167022 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -111,6 +111,11 @@ var rootCmd = &cobra.Command{ LoginMaxRetries: config.LoginMaxRetries, } + // Create hooks config + hooksConfig := types.HooksConfig{ + Domain: domain, + } + // Create docker service docker := docker.NewDocker() @@ -128,7 +133,7 @@ var rootCmd = &cobra.Command{ providers.Init() // Create hooks service - hooks := hooks.NewHooks(auth, providers) + hooks := hooks.NewHooks(hooksConfig, auth, providers) // Create handlers handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx index 5ef60417..e05c362a 100644 --- a/frontend/src/pages/logout-page.tsx +++ b/frontend/src/pages/logout-page.tsx @@ -10,7 +10,7 @@ import { useAppContext } from "../context/app-context"; import { Trans, useTranslation } from "react-i18next"; export const LogoutPage = () => { - const { isLoggedIn, username, oauth, provider } = useUserContext(); + const { isLoggedIn, oauth, provider, email } = useUserContext(); const { genericName } = useAppContext(); const { t } = useTranslation(); @@ -56,7 +56,7 @@ export const LogoutPage = () => { values={{ provider: provider === "generic" ? genericName : capitalize(provider), - username: username, + username: email, }} /> ) : ( @@ -65,7 +65,7 @@ export const LogoutPage = () => { t={t} components={{ Code: }} values={{ - username: username, + username: email, }} /> )} diff --git a/frontend/src/schemas/user-context-schema.ts b/frontend/src/schemas/user-context-schema.ts index 4d43d7b6..6ebe567b 100644 --- a/frontend/src/schemas/user-context-schema.ts +++ b/frontend/src/schemas/user-context-schema.ts @@ -3,6 +3,8 @@ import { z } from "zod"; export const userContextSchema = z.object({ isLoggedIn: z.boolean(), username: z.string(), + name: z.string(), + email: z.string(), oauth: z.boolean(), provider: z.string(), totpPending: z.boolean(), diff --git a/internal/api/api_test.go b/internal/api/api_test.go index e0cb6e53..23c1baf6 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -45,6 +45,11 @@ var authConfig = types.AuthConfig{ LoginMaxRetries: 0, } +// Simple hooks config for tests +var hooksConfig = types.HooksConfig{ + Domain: "localhost", +} + // Cookie var cookie string @@ -83,7 +88,7 @@ func getAPI(t *testing.T) *api.API { providers.Init() // Create hooks service - hooks := hooks.NewHooks(auth, providers) + hooks := hooks.NewHooks(hooksConfig, auth, providers) // Create handlers service handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index d6ed5f38..cf1d0f03 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -160,6 +160,8 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) // Set data session.Values["username"] = data.Username + session.Values["name"] = data.Name + session.Values["email"] = data.Email session.Values["provider"] = data.Provider session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() session.Values["totpPending"] = data.TotpPending @@ -211,14 +213,23 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) return types.SessionCookie{}, err } + log.Debug().Interface("session", session).Msg("Got session") + // Get data from session username, usernameOk := session.Values["username"].(string) + email, emailOk := session.Values["email"].(string) + name, nameOk := session.Values["name"].(string) provider, providerOK := session.Values["provider"].(string) expiry, expiryOk := session.Values["expiry"].(int64) totpPending, totpPendingOk := session.Values["totpPending"].(bool) - if !usernameOk || !providerOK || !expiryOk || !totpPendingOk { - log.Warn().Msg("Session cookie is missing data") + if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk { + log.Warn().Msg("Session cookie is invalid") + + // If any data is missing, delete the session cookie + auth.DeleteSessionCookie(c) + + // Return empty cookie return types.SessionCookie{}, nil } @@ -233,11 +244,13 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) return types.SessionCookie{}, nil } - log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie") + log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Msg("Parsed cookie") // Return the cookie return types.SessionCookie{ Username: username, + Name: name, + Email: email, Provider: provider, TotpPending: totpPending, }, nil diff --git a/internal/constants/constants.go b/internal/constants/constants.go index b6b5598b..f3b02d31 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -8,13 +8,9 @@ var TinyauthLabels = []string{ "tinyauth.headers", } -// Claims are the OIDC supported claims +// Claims are the OIDC supported claims (including preferd username for some reason) type Claims struct { - Name string `json:"name"` - FamilyName string `json:"family_name"` - GivenName string `json:"given_name"` - MiddleName string `json:"middle_name"` - Nickname string `json:"nickname"` - Picture string `json:"picture"` - Email string `json:"email"` + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` } diff --git a/internal/docker/docker.go b/internal/docker/docker.go index 07962e07..43807cec 100644 --- a/internal/docker/docker.go +++ b/internal/docker/docker.go @@ -6,8 +6,7 @@ import ( "tinyauth/internal/types" "tinyauth/internal/utils" - apiTypes "github.com/docker/docker/api/types" - containerTypes "github.com/docker/docker/api/types/container" + container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" "github.com/rs/zerolog/log" ) @@ -38,9 +37,9 @@ func (docker *Docker) Init() error { return nil } -func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { +func (docker *Docker) GetContainers() ([]container.Summary, error) { // Get the list of containers - containers, err := docker.Client.ContainerList(docker.Context, containerTypes.ListOptions{}) + containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) // Check if there was an error if err != nil { @@ -51,13 +50,13 @@ func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { return containers, nil } -func (docker *Docker) InspectContainer(containerId string) (apiTypes.ContainerJSON, error) { +func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { // Inspect the container inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) // Check if there was an error if err != nil { - return apiTypes.ContainerJSON{}, err + return container.InspectResponse{}, err } // Return the inspect diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 33908d90..accc2de6 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -10,6 +10,7 @@ import ( "tinyauth/internal/hooks" "tinyauth/internal/providers" "tinyauth/internal/types" + "tinyauth/internal/utils" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" @@ -183,8 +184,9 @@ func (h *Handlers) AuthHandler(c *gin.Context) { return } - // Set the user header c.Header("Remote-User", userContext.Username) + c.Header("Remote-Name", userContext.Name) + c.Header("Remote-Email", userContext.Email) // Set the rest of the headers for key, value := range labels.Headers { @@ -310,6 +312,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { // Set totp pending cookie h.Auth.CreateSessionCookie(c, &types.SessionCookie{ Username: login.Username, + Name: utils.Capitalize(login.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), Provider: "username", TotpPending: true, }) @@ -328,6 +332,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { // Create session cookie with username as provider h.Auth.CreateSessionCookie(c, &types.SessionCookie{ Username: login.Username, + Name: utils.Capitalize(login.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), Provider: "username", }) @@ -402,6 +408,8 @@ func (h *Handlers) TotpHandler(c *gin.Context) { // Create session cookie with username as provider h.Auth.CreateSessionCookie(c, &types.SessionCookie{ Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), Provider: "username", }) @@ -465,6 +473,8 @@ func (h *Handlers) UserHandler(c *gin.Context) { Status: 200, IsLoggedIn: userContext.IsLoggedIn, Username: userContext.Username, + Name: userContext.Name, + Email: userContext.Email, Provider: userContext.Provider, Oauth: userContext.OAuth, TotpPending: userContext.TotpPending, @@ -654,9 +664,29 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { log.Debug().Msg("Email whitelisted") + // Get username + var username string + + if user.PreferredUsername != "" { + username = user.PreferredUsername + } else { + username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) + } + + // Get name + var name string + + if user.Name != "" { + name = user.Name + } else { + name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) + } + // Create session cookie (also cleans up redirect cookie) h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: user.Email, + Username: username, + Name: name, + Email: user.Email, Provider: providerName.Provider, }) diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 5e9a6890..29c67867 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -1,22 +1,27 @@ package hooks import ( + "fmt" + "strings" "tinyauth/internal/auth" "tinyauth/internal/providers" "tinyauth/internal/types" + "tinyauth/internal/utils" "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" ) -func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { +func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { return &Hooks{ + Config: config, Auth: auth, Providers: providers, } } type Hooks struct { + Config types.HooksConfig Auth *auth.Auth Providers *providers.Providers } @@ -36,11 +41,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { // Return user context since we are logged in with basic auth return types.UserContext{ - Username: basic.Username, - IsLoggedIn: true, - OAuth: false, - Provider: "basic", - TotpPending: false, + Username: basic.Username, + Name: utils.Capitalize(basic.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), + IsLoggedIn: true, + Provider: "basic", } } @@ -50,13 +55,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { if err != nil { log.Error().Err(err).Msg("Failed to get session cookie") // Return empty context - return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - TotpPending: false, - } + return types.UserContext{} } // Check if session cookie has totp pending @@ -65,8 +64,8 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { // Return empty context since we are pending totp return types.UserContext{ Username: cookie.Username, - IsLoggedIn: false, - OAuth: false, + Name: cookie.Name, + Email: cookie.Email, Provider: cookie.Provider, TotpPending: true, } @@ -82,11 +81,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { // It exists so we are logged in return types.UserContext{ - Username: cookie.Username, - IsLoggedIn: true, - OAuth: false, - Provider: "username", - TotpPending: false, + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + IsLoggedIn: true, + Provider: "username", } } } @@ -108,33 +107,22 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { hooks.Auth.DeleteSessionCookie(c) // Return empty context - return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - TotpPending: false, - } + return types.UserContext{} } log.Debug().Msg("Email is whitelisted") // Return user context since we are logged in with oauth return types.UserContext{ - Username: cookie.Username, - IsLoggedIn: true, - OAuth: true, - Provider: cookie.Provider, - TotpPending: false, + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + IsLoggedIn: true, + OAuth: true, + Provider: cookie.Provider, } } // Neither basic auth or oauth is set so we return an empty context - return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - TotpPending: false, - } + return types.UserContext{} } diff --git a/internal/types/api.go b/internal/types/api.go index 144bb56a..0e12634f 100644 --- a/internal/types/api.go +++ b/internal/types/api.go @@ -33,6 +33,8 @@ type UserContextResponse struct { Message string `json:"message"` IsLoggedIn bool `json:"isLoggedIn"` Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` Provider string `json:"provider"` Oauth bool `json:"oauth"` TotpPending bool `json:"totpPending"` diff --git a/internal/types/config.go b/internal/types/config.go index 88e9169f..cc35de1e 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -78,3 +78,8 @@ type AuthConfig struct { LoginTimeout int LoginMaxRetries int } + +// HooksConfig is the configuration for the hooks service +type HooksConfig struct { + Domain string +} diff --git a/internal/types/types.go b/internal/types/types.go index dc6f9c1d..f9344c05 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -25,6 +25,8 @@ type OAuthProviders struct { // SessionCookie is the cookie for the session (exculding the expiry) type SessionCookie struct { Username string + Name string + Email string Provider string TotpPending bool } @@ -40,6 +42,8 @@ type TinyauthLabels struct { // UserContext is the context for the user type UserContext struct { Username string + Name string + Email string IsLoggedIn bool OAuth bool Provider string diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 25830153..528ad233 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -323,3 +323,8 @@ func CheckWhitelist(whitelist string, str string) bool { // Return false if no match was found return false } + +// Capitalize just the first letter of a string +func Capitalize(str string) string { + return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) +} From 065b9eaf3de6f82fefd1290c5195a6a3b2385f3a Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 28 Apr 2025 22:49:56 +0300 Subject: [PATCH 4/7] feat: add support for required oauth groups --- frontend/src/lib/i18n/locales/en-US.json | 1 + frontend/src/lib/i18n/locales/en.json | 1 + frontend/src/pages/unauthorized-page.tsx | 63 ++++++++++++++------ internal/auth/auth.go | 65 ++++++++++----------- internal/constants/constants.go | 8 ++- internal/handlers/handlers.go | 73 ++++++++++++++++-------- internal/hooks/hooks.go | 13 +++-- internal/types/api.go | 1 + internal/types/types.go | 3 + internal/utils/utils.go | 2 + 10 files changed, 147 insertions(+), 83 deletions(-) diff --git a/frontend/src/lib/i18n/locales/en-US.json b/frontend/src/lib/i18n/locales/en-US.json index 12135fe3..acf2a5ff 100644 --- a/frontend/src/lib/i18n/locales/en-US.json +++ b/frontend/src/lib/i18n/locales/en-US.json @@ -42,6 +42,7 @@ "unauthorizedTitle": "Unauthorized", "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.", "unaothorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.", + "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.", "unauthorizedButton": "Try again", "untrustedRedirectTitle": "Untrusted redirect", "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{domain}}). Are you sure you want to continue?", diff --git a/frontend/src/lib/i18n/locales/en.json b/frontend/src/lib/i18n/locales/en.json index 12135fe3..acf2a5ff 100644 --- a/frontend/src/lib/i18n/locales/en.json +++ b/frontend/src/lib/i18n/locales/en.json @@ -42,6 +42,7 @@ "unauthorizedTitle": "Unauthorized", "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.", "unaothorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.", + "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.", "unauthorizedButton": "Try again", "untrustedRedirectTitle": "Untrusted redirect", "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{domain}}). Are you sure you want to continue?", diff --git a/frontend/src/pages/unauthorized-page.tsx b/frontend/src/pages/unauthorized-page.tsx index 6825bd21..bcb35902 100644 --- a/frontend/src/pages/unauthorized-page.tsx +++ b/frontend/src/pages/unauthorized-page.tsx @@ -3,11 +3,13 @@ import { Layout } from "../components/layouts/layout"; import { Navigate } from "react-router"; import { isQueryValid } from "../utils/utils"; import { Trans, useTranslation } from "react-i18next"; +import React from "react"; export const UnauthorizedPage = () => { const queryString = window.location.search; const params = new URLSearchParams(queryString); const username = params.get("username") ?? ""; + const groupErr = params.get("groupErr") ?? ""; const resource = params.get("resource") ?? ""; const { t } = useTranslation(); @@ -16,6 +18,47 @@ export const UnauthorizedPage = () => { return ; } + if (isQueryValid(resource) && !isQueryValid(groupErr)) { + return ( + + }} + values={{ resource, username }} + /> + + ); + } + + if (isQueryValid(groupErr) && isQueryValid(resource)) { + return ( + + }} + values={{ username, resource }} + /> + + ) + } + + return ( + + }} + values={{ username }} + /> + + ); +}; + +const UnauthorizedLayout = ({ children }: { children: React.ReactNode }) => { + const { t } = useTranslation(); + return ( @@ -23,25 +66,7 @@ export const UnauthorizedPage = () => { {t("Unauthorized")} - {isQueryValid(resource) ? ( - - }} - values={{ resource, username }} - /> - - ) : ( - - }} - values={{ username }} - /> - - )} + {children}