Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import type { OauthTokens, User } from '@workos-inc/node';

export type DataWithResponseInit<T> = ReturnType<typeof data<T>>;

export type UnwrapData<T> = T extends DataWithResponseInit<infer U> ? U : T;

export type HandleAuthOptions = {
returnPathname?: string;
onSuccess?: (data: AuthLoaderSuccessData) => void | Promise<void>;
Expand Down
48 changes: 35 additions & 13 deletions src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import type {
DataWithResponseInit,
Session,
UnauthorizedData,
UnwrapData,
} from './interfaces.js';
import { getWorkOS } from './workos.js';

import { sealData, unsealData } from 'iron-session';
import { createRemoteJWKSet, decodeJwt, jwtVerify } from 'jose';
import { getConfig } from './config.js';
import { configureSessionStorage, getSessionStorage } from './sessionStorage.js';
import { isJsonResponse, isRedirect, isResponse } from './utils.js';
import { isDataWithResponseInit, isJsonResponse, isRedirect, isResponse } from './utils.js';

// must be a type since this is a subtype of response
// interfaces must conform to the types they extend
Expand Down Expand Up @@ -168,11 +169,17 @@ type LoaderValue<Data> = Response | TypedResponse<Data> | NonNullable<Data> | nu
type LoaderReturnValue<Data> = Promise<LoaderValue<Data>> | LoaderValue<Data>;

type AuthLoader<Data> = (
args: LoaderFunctionArgs & { auth: AuthorizedData | UnauthorizedData; getAccessToken: () => string | null },
args: LoaderFunctionArgs & {
auth: AuthorizedData | UnauthorizedData;
getAccessToken: () => string | null;
},
) => LoaderReturnValue<Data>;

type AuthorizedAuthLoader<Data> = (
args: LoaderFunctionArgs & { auth: AuthorizedData; getAccessToken: () => string },
args: LoaderFunctionArgs & {
auth: AuthorizedData;
getAccessToken: () => string;
},
) => LoaderReturnValue<Data>;

/**
Expand All @@ -181,9 +188,6 @@ type AuthorizedAuthLoader<Data> = (
*
* Creates an authentication-aware loader function for React Router.
*
* This loader handles authentication state, session management, and access token refreshing
* automatically, making it easier to build authenticated routes.
*
* @overload
* Basic usage with enforced authentication that redirects unauthenticated users to sign in.
*
Expand Down Expand Up @@ -252,7 +256,7 @@ export async function authkitLoader<Data = unknown>(
loaderArgs: LoaderFunctionArgs,
loader: AuthorizedAuthLoader<Data>,
options: AuthKitLoaderOptions & { ensureSignedIn: true },
): Promise<DataWithResponseInit<Data & AuthorizedData>>;
): Promise<DataWithResponseInit<UnwrapData<Data> & AuthorizedData>>;

/**
* This loader handles authentication state, session management, and access token refreshing
Expand Down Expand Up @@ -287,7 +291,7 @@ export async function authkitLoader<Data = unknown>(
loaderArgs: LoaderFunctionArgs,
loader: AuthLoader<Data>,
options?: AuthKitLoaderOptions,
): Promise<DataWithResponseInit<Data & (AuthorizedData | UnauthorizedData)>>;
): Promise<DataWithResponseInit<UnwrapData<Data> & (AuthorizedData | UnauthorizedData)>>;

export async function authkitLoader<Data = unknown>(
loaderArgs: LoaderFunctionArgs,
Expand All @@ -305,7 +309,10 @@ export async function authkitLoader<Data = unknown>(
} = typeof loaderOrOptions === 'object' ? loaderOrOptions : options;

const cookieName = cookie?.name ?? getConfig('cookieName');
const { getSession, destroySession } = await configureSessionStorage({ storage, cookieName });
const { getSession, destroySession } = await configureSessionStorage({
storage,
cookieName,
});

const { request } = loaderArgs;

Expand Down Expand Up @@ -443,7 +450,11 @@ async function handleAuthLoader(
} else {
// Unauthorized case
const getAccessToken = () => null;
loaderResult = await (loader as AuthLoader<unknown>)({ ...args, auth, getAccessToken });
loaderResult = await (loader as AuthLoader<unknown>)({
...args,
auth,
getAccessToken,
});
}

if (isResponse(loaderResult)) {
Expand All @@ -467,9 +478,20 @@ async function handleAuthLoader(
return data({ ...responseData, ...auth }, newResponse);
}

// If the loader returns a non-Response, assume it's a data object
// istanbul ignore next
return data({ ...loaderResult, ...auth }, session ? { headers: { ...session.headers } } : undefined);
const actualData = isDataWithResponseInit(loaderResult) ? loaderResult.data : loaderResult;

const mergedHeaders = isDataWithResponseInit(loaderResult) ? new Headers(loaderResult.init?.headers) : new Headers();

if (session?.headers) {
Object.entries(session.headers).forEach(([key, value]) => {
mergedHeaders.set(key, value);
});
}

const mergedData = actualData && typeof actualData === 'object' ? { ...actualData, ...auth } : { ...auth };

// Always pass headers (empty headers object is valid)
return data(mergedData, { headers: mergedHeaders });
}

export async function terminateSession(request: Request, { returnTo }: { returnTo?: string } = {}) {
Expand Down