diff --git a/README.md b/README.md index 427a3d3..f1a8684 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,36 @@ export const loader = (args: LoaderFunctionArgs) => }); ``` +### Authenticated actions + +AuthKit now provides helpers for React Router **actions** so you can enforce authentication in React Router's actions just as easily as in loaders. + +#### `authkitAction` + +Use `authkitAction` when you need the full flexibility of optional or required auth, custom storage, or refresh callbacks. Its API mirrors `authkitLoader`, and you receive the same `{ auth, getAccessToken }` arguments inside your action. + +```tsx +import type { ActionFunctionArgs } from 'react-router'; +import { authkitAction } from '@workos-inc/authkit-react-router'; + +export const action = (args: ActionFunctionArgs) => + authkitAction(args, async ({ auth, getAccessToken }) => { + if (!auth.user) { + return { status: 'anonymous' }; + } + + const token = getAccessToken(); + await fetch('https://api.example.com/secure', { + method: 'POST', + headers: { Authorization: `Bearer ${token}` }, + }); + + return { status: 'ok', sessionId: auth.sessionId }; + }); +``` + +When the action must always have an authenticated user, pass `{ ensureSignedIn: true }` as the final argument. That guarantees `getAccessToken()` will never return `null` and unauthenticated users will be redirected automatically. + #### Security Considerations By default, access tokens are not included in the data sent to React components. This helps prevent unintentional token exposure in: diff --git a/src/index.ts b/src/index.ts index 7d038c1..e136cb7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,11 +1,12 @@ import { getSignInUrl, getSignUpUrl, signOut, switchToOrganization, withAuth } from './auth.js'; import { authLoader } from './authkit-callback-route.js'; import { configure, getConfig } from './config.js'; -import { authkitLoader, refreshSession } from './session.js'; +import { authkitAction, authkitLoader, refreshSession } from './session.js'; import { getWorkOS } from './workos.js'; export { authLoader, + authkitAction, authkitLoader, configure, withAuth, diff --git a/src/interfaces.ts b/src/interfaces.ts index 3d5a3d5..c8965d9 100644 --- a/src/interfaces.ts +++ b/src/interfaces.ts @@ -108,6 +108,8 @@ export type AuthKitLoaderOptions = { } ); +export type AuthKitActionOptions = AuthKitLoaderOptions; + export interface AuthorizedData { user: User; sessionId: string; diff --git a/src/session.spec.ts b/src/session.spec.ts index 70bcbeb..04c824a 100644 --- a/src/session.spec.ts +++ b/src/session.spec.ts @@ -1,4 +1,4 @@ -import { LoaderFunctionArgs, Session as ReactRouterSession, redirect } from 'react-router'; +import { ActionFunctionArgs, LoaderFunctionArgs, Session as ReactRouterSession, redirect } from 'react-router'; import { AuthenticationResponse } from '@workos-inc/node'; import * as ironSession from 'iron-session'; import * as jose from 'jose'; @@ -7,7 +7,7 @@ import { getSessionStorage as getSessionStorageMock, } from './sessionStorage.js'; import { Session } from './interfaces.js'; -import { authkitLoader, encryptSession, terminateSession, refreshSession } from './session.js'; +import { authkitAction, authkitLoader, encryptSession, terminateSession, refreshSession } from './session.js'; import { assertIsResponse } from './test-utils/test-helpers.js'; import { getWorkOS } from './workos.js'; import { getConfig } from './config.js'; @@ -526,6 +526,318 @@ describe('session', () => { }); }); + describe('authkitAction', () => { + const createActionArgs = (request: Request): ActionFunctionArgs => ({ + request, + params: {}, + context: {}, + }); + + describe('unauthenticated flows', () => { + beforeEach(() => { + const mockSession = createMockSession({ + has: jest.fn().mockReturnValue(false), + get: jest.fn(), + }); + getSession.mockResolvedValue(mockSession); + unsealData.mockResolvedValue(null); + }); + + it('should return unauthorized data when no session exists', async () => { + const { data } = await authkitAction(createActionArgs(createMockRequest())); + + expect(data).toEqual({ + user: null, + impersonator: null, + organizationId: null, + permissions: null, + entitlements: null, + featureFlags: null, + role: null, + roles: null, + sessionId: null, + }); + }); + + it('should redirect when ensureSignedIn is true', async () => { + try { + await authkitAction(createActionArgs(createMockRequest()), { ensureSignedIn: true }); + fail('Expected redirect response to be thrown'); + } catch (response: unknown) { + assertIsResponse(response); + expect(response.status).toBe(302); + expect(response.headers.get('Location')).toMatch(/^https:\/\/auth\.workos\.com\/oauth/); + expect(response.headers.get('Set-Cookie')).toBe('destroyed-session-cookie'); + } + }); + }); + + describe('authenticated flows', () => { + const mockSessionData = { + accessToken: 'action.jwt.token', + refreshToken: 'refresh.token', + user: { + id: 'user-1', + email: 'test@example.com', + }, + impersonator: null, + }; + + beforeEach(() => { + const mockSession = createMockSession({ + has: jest.fn().mockReturnValue(true), + get: jest.fn().mockReturnValue('encrypted-jwt'), + set: jest.fn(), + }); + getSession.mockResolvedValue(mockSession); + unsealData.mockResolvedValue({ + ...mockSessionData, + headers: { + 'Set-Cookie': 'action-session-cookie', + }, + }); + jwtVerify.mockResolvedValue({ + payload: {}, + protectedHeader: {}, + key: new TextEncoder().encode('test-key'), + } as jose.JWTVerifyResult & jose.ResolvedKey); + (jose.decodeJwt as jest.Mock).mockReturnValue({ + sid: 'test-session-id', + org_id: 'org-123', + role: 'member', + roles: ['member'], + permissions: ['read'], + entitlements: ['basic'], + feature_flags: ['flag-1'], + }); + }); + + it('should merge custom action data with auth data', async () => { + const customAction = jest.fn().mockResolvedValue({ + actionData: 'test', + }); + + const { data } = await authkitAction(createActionArgs(createMockRequest()), customAction); + + expect(customAction).toHaveBeenCalled(); + expect(data).toEqual( + expect.objectContaining({ + actionData: 'test', + user: mockSessionData.user, + sessionId: 'test-session-id', + }), + ); + }); + + it('should provide getAccessToken to custom actions', async () => { + const customAction = jest.fn().mockImplementation(({ getAccessToken }) => { + return { retrievedToken: getAccessToken() }; + }); + + const { data } = await authkitAction(createActionArgs(createMockRequest()), customAction); + + expect(customAction).toHaveBeenCalledWith( + expect.objectContaining({ + auth: expect.objectContaining({ user: mockSessionData.user }), + getAccessToken: expect.any(Function), + }), + ); + expect(data).toEqual( + expect.objectContaining({ + retrievedToken: mockSessionData.accessToken, + user: mockSessionData.user, + }), + ); + }); + + it('should pass through custom Responses and append cookies', async () => { + const customAction = jest.fn().mockReturnValue( + new Response(JSON.stringify({ ok: true }), { + headers: { 'Content-Type': 'application/json' }, + }), + ); + + const { data, init } = await authkitAction(createActionArgs(createMockRequest()), customAction); + + expect(getHeaderValue(init?.headers, 'Set-Cookie')).toBe('action-session-cookie'); + expect(data).toEqual( + expect.objectContaining({ + ok: true, + user: mockSessionData.user, + }), + ); + }); + + describe('session refresh during action', () => { + beforeEach(() => { + // Setup session with expired token + const mockSession = createMockSession({ + has: jest.fn().mockReturnValue(true), + get: jest.fn().mockReturnValue('encrypted-jwt'), + set: jest.fn(), + }); + getSession.mockResolvedValue(mockSession); + + const expiredSessionData = { + accessToken: 'expired.token', + refreshToken: 'refresh.token', + user: { id: 'user-1' }, + impersonator: null, + }; + unsealData.mockResolvedValue(expiredSessionData); + sealData.mockResolvedValue('new-encrypted-jwt'); + commitSession.mockResolvedValue('new-action-session-cookie'); + + // Token verification fails + jwtVerify.mockRejectedValue(new Error('Token expired')); + + // But refresh succeeds + authenticateWithRefreshToken.mockResolvedValue({ + accessToken: 'new.valid.action.token', + refreshToken: 'new.refresh.token', + user: { + object: 'user', + id: 'user-1', + email: 'test@example.com', + emailVerified: true, + profilePictureUrl: null, + firstName: 'Test', + lastName: 'User', + lastSignInAt: '2021-01-01T00:00:00Z', + createdAt: '2021-01-01T00:00:00Z', + updatedAt: '2021-01-01T00:00:00Z', + externalId: null, + }, + } as AuthenticationResponse); + + // Mock different JWT decoding results for expired vs new token + (jose.decodeJwt as jest.Mock).mockImplementation((token: string) => { + if (token === 'expired.token') { + return { + sid: 'test-session-id', + org_id: 'org-123', + role: null, + roles: [], + permissions: [], + entitlements: [], + feature_flags: [], + }; + } + if (token === 'new.valid.action.token') { + return { + sid: 'new-session-id', + org_id: 'org-123', + role: 'user', + roles: ['user'], + permissions: ['read'], + entitlements: ['basic'], + feature_flags: ['flag-1'], + }; + } + return {}; // fallback + }); + }); + + it('should refresh session when access token is invalid', async () => { + const { data, init } = await authkitAction(createActionArgs(createMockRequest())); + + // Verify the refresh token flow was triggered + expect(authenticateWithRefreshToken).toHaveBeenCalledWith({ + clientId: expect.any(String), + refreshToken: 'refresh.token', + organizationId: 'org-123', + }); + + // Verify the response contains the new token data + expect(data).toEqual( + expect.objectContaining({ + sessionId: 'new-session-id', + organizationId: 'org-123', + role: 'user', + roles: ['user'], + permissions: ['read'], + entitlements: ['basic'], + featureFlags: ['flag-1'], + }), + ); + + // Verify cookie was set + expect(getHeaderValue(init?.headers, 'Set-Cookie')).toBe('new-action-session-cookie'); + }); + + it('should redirect to authorization URL preserving returnPathname when action refresh fails', async () => { + authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid')); + + // Setup the mock to return a URL with state parameter + getAuthorizationUrlMock.mockResolvedValue('https://auth.workos.com/oauth/authorize?state=abc123'); + + try { + const mockRequest = createMockRequest('test-cookie', 'https://app.example.com/dashboard/data'); + await authkitAction(createActionArgs(mockRequest)); + fail('Expected redirect response to be thrown'); + } catch (response: unknown) { + assertIsResponse(response); + expect(response.status).toBe(302); + expect(response.headers.get('Location')).toBe('https://auth.workos.com/oauth/authorize?state=abc123'); + expect(response.headers.get('Set-Cookie')).toBe('destroyed-session-cookie'); + + // Verify getAuthorizationUrl was called with the correct returnPathname + expect(getAuthorizationUrlMock).toHaveBeenCalledWith({ + returnPathname: '/dashboard/data', + }); + } + }); + + it('calls onSessionRefreshSuccess when provided for action', async () => { + const onSessionRefreshSuccess = jest.fn(); + await authkitAction(createActionArgs(createMockRequest()), { + onSessionRefreshSuccess, + }); + + expect(onSessionRefreshSuccess).toHaveBeenCalledWith({ + accessToken: 'new.valid.action.token', + user: expect.objectContaining({ id: 'user-1' }), + impersonator: null, + organizationId: 'org-123', + }); + }); + + it('calls onSessionRefreshError when provided and action refresh fails', async () => { + authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid')); + const onSessionRefreshError = jest.fn().mockReturnValue(redirect('/error')); + + await authkitAction(createActionArgs(createMockRequest()), { + onSessionRefreshError, + }); + + expect(onSessionRefreshError).toHaveBeenCalledWith( + expect.objectContaining({ + error: expect.any(Error), + request: expect.any(Request), + }), + ); + }); + + it('allows redirect from onSessionRefreshError callback in action', async () => { + authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid')); + + try { + await authkitAction(createActionArgs(createMockRequest()), { + onSessionRefreshError: () => { + throw redirect('/custom-error'); + }, + }); + fail('Expected redirect response to be thrown'); + } catch (response: unknown) { + assertIsResponse(response); + expect(response.status).toBe(302); + expect(response.headers.get('Location')).toBe('/custom-error'); + } + }); + }); + }); + }); + describe('session refresh', () => { beforeEach(() => { // Setup session with expired token diff --git a/src/session.ts b/src/session.ts index 25f0be9..8bd72b8 100644 --- a/src/session.ts +++ b/src/session.ts @@ -1,7 +1,8 @@ -import { data, redirect, type LoaderFunctionArgs, type SessionData } from 'react-router'; +import { data, redirect, type ActionFunctionArgs, type LoaderFunctionArgs, type SessionData } from 'react-router'; import { getAuthorizationUrl } from './get-authorization-url.js'; import type { AccessToken, + AuthKitActionOptions, AuthKitLoaderOptions, AuthorizedData, DataWithResponseInit, @@ -175,6 +176,23 @@ type AuthorizedAuthLoader = ( args: LoaderFunctionArgs & { auth: AuthorizedData; getAccessToken: () => string }, ) => LoaderReturnValue; +type ActionValue = Response | DataWithResponseInit | NonNullable | null; +type ActionReturnValue = Promise> | ActionValue; + +type AuthAction = ( + args: ActionFunctionArgs & { + auth: AuthorizedData | UnauthorizedData; + getAccessToken: () => string | null; + }, +) => ActionReturnValue; + +type AuthorizedAuthAction = ( + args: ActionFunctionArgs & { + auth: AuthorizedData; + getAccessToken: () => string; + }, +) => ActionReturnValue; + /** * This loader handles authentication state, session management, and access token refreshing * automatically, making it easier to build authenticated routes. @@ -413,6 +431,149 @@ export async function authkitLoader( } } +export async function authkitAction( + actionArgs: ActionFunctionArgs, + options: AuthKitActionOptions & { ensureSignedIn: true }, +): Promise>; + +export async function authkitAction( + actionArgs: ActionFunctionArgs, + options?: AuthKitActionOptions, +): Promise>; + +export async function authkitAction( + actionArgs: ActionFunctionArgs, + action: AuthorizedAuthAction, + options: AuthKitActionOptions & { ensureSignedIn: true }, +): Promise & AuthorizedData>>; + +export async function authkitAction( + actionArgs: ActionFunctionArgs, + action: AuthAction, + options?: AuthKitActionOptions, +): Promise & (AuthorizedData | UnauthorizedData)>>; + +export async function authkitAction( + actionArgs: ActionFunctionArgs, + actionOrOptions?: AuthAction | AuthorizedAuthAction | AuthKitActionOptions, + options: AuthKitActionOptions = {}, +) { + const action = typeof actionOrOptions === 'function' ? actionOrOptions : undefined; + const { + ensureSignedIn = false, + debug = false, + onSessionRefreshSuccess, + onSessionRefreshError, + storage, + cookie, + } = typeof actionOrOptions === 'object' ? actionOrOptions : options; + + const cookieName = cookie?.name ?? getConfig('cookieName'); + const { getSession, destroySession } = await configureSessionStorage({ + storage, + cookieName, + }); + + const { request } = actionArgs; + + try { + const session = await updateSession(request, debug); + + if (!session) { + if (ensureSignedIn) { + const returnPathname = getReturnPathname(request.url); + const cookieSession = await getSession(request.headers.get('Cookie')); + + throw redirect(await getAuthorizationUrl({ returnPathname }), { + headers: { + 'Set-Cookie': await destroySession(cookieSession), + }, + }); + } + + const auth: UnauthorizedData = { + user: null, + impersonator: null, + organizationId: null, + permissions: null, + entitlements: null, + featureFlags: null, + role: null, + roles: null, + sessionId: null, + }; + + return await handleAuthAction(action, actionArgs, auth); + } + + const { + sessionId, + organizationId = null, + role = null, + roles = null, + permissions = [], + entitlements = [], + featureFlags = [], + } = getClaimsFromAccessToken(session.accessToken); + + const { impersonator = null } = session; + + if (onSessionRefreshSuccess && 'headers' in session) { + await onSessionRefreshSuccess({ + accessToken: session.accessToken, + user: session.user, + impersonator, + organizationId, + }); + } + + const auth: AuthorizedData = { + user: session.user, + sessionId, + organizationId, + role, + roles, + permissions, + entitlements, + featureFlags, + impersonator, + }; + + return await handleAuthAction(action, actionArgs, auth, session); + } catch (error) { + if (error instanceof SessionRefreshError) { + const cookieSession = await getSession(request.headers.get('Cookie')); + + if (onSessionRefreshError) { + try { + const result = await onSessionRefreshError({ + error: error.cause, + request, + sessionData: cookieSession, + }); + + if (result instanceof Response) { + return result; + } + } catch (callbackError) { + if (callbackError instanceof Response) { + throw callbackError; + } + } + } + + const returnPathname = getReturnPathname(request.url); + throw redirect(await getAuthorizationUrl({ returnPathname }), { + headers: { + 'Set-Cookie': await destroySession(cookieSession), + }, + }); + } + + throw error; + } +} + async function handleAuthLoader( loader: AuthLoader | AuthorizedAuthLoader | undefined, args: LoaderFunctionArgs, @@ -472,6 +633,76 @@ async function handleAuthLoader( return data({ ...loaderResult, ...auth }, session ? { headers: { ...session.headers } } : undefined); } +async function handleAuthAction( + action: AuthAction | AuthorizedAuthAction | undefined, + args: ActionFunctionArgs, + auth: AuthorizedData | UnauthorizedData, + session?: Session, +) { + if (!action) { + return data(auth, session ? { headers: { ...session.headers } } : undefined); + } + + let actionResult; + + if (auth.user) { + const getAccessToken = () => { + if (!session?.accessToken) { + throw new Error('No access token available'); + } + return session.accessToken; + }; + + actionResult = await (action as AuthorizedAuthAction)({ + ...args, + auth: auth as AuthorizedData, + getAccessToken, + }); + } else { + const getAccessToken = () => null; + + actionResult = await (action as AuthAction)({ + ...args, + auth, + getAccessToken, + }); + } + + if (isResponse(actionResult)) { + if (isRedirect(actionResult)) { + throw actionResult; + } + + const newResponse = new Response(actionResult.body, actionResult); + + if (session) { + newResponse.headers.append('Set-Cookie', session.headers['Set-Cookie']); + } + + if (!isJsonResponse(newResponse)) { + return newResponse; + } + + const responseData = await newResponse.json(); + + return data({ ...responseData, ...auth }, newResponse); + } + + const actualData = isDataWithResponseInit(actionResult) ? actionResult.data : actionResult; + + const mergedHeaders = isDataWithResponseInit(actionResult) ? new Headers(actionResult.init?.headers) : new Headers(); + + if (session?.headers) { + Object.entries(session.headers).forEach(([key, value]) => { + mergedHeaders.set(key, value); + }); + } + + const mergedData = actualData && typeof actualData === 'object' ? { ...actualData, ...auth } : { ...auth }; + + return data(mergedData, { headers: mergedHeaders }); +} + export async function terminateSession(request: Request, { returnTo }: { returnTo?: string } = {}) { const { getSession, destroySession } = await getSessionStorage(); const encryptedSession = await getSession(request.headers.get('Cookie'));