From de8c5aa3349d99f8f715ba95d4804e5f88f670ef Mon Sep 17 00:00:00 2001 From: Sweets Sweetman Date: Thu, 15 Jan 2026 20:52:37 -0500 Subject: [PATCH] feat(api): allow frontend to pass organizationId in chat requests - Add organizationId field to chatRequestSchema - Create validateOrganizationAccess function to check if an account can operate on behalf of an organization (either IS the org or is a member) - Validate organizationId in validateChatRequest and use it to override orgId when user is authorized - Return 403 error when user is not a member of the specified org - Add comprehensive tests for all scenarios Co-Authored-By: Claude Opus 4.5 --- lib/chat/__tests__/handleChatGenerate.test.ts | 4 + lib/chat/__tests__/handleChatStream.test.ts | 4 + .../integration/chatEndToEnd.test.ts | 4 + .../__tests__/validateChatRequest.test.ts | 129 ++++++++++++++++++ lib/chat/validateChatRequest.ts | 26 ++++ .../validateOrganizationAccess.test.ts | 101 ++++++++++++++ .../validateOrganizationAccess.ts | 41 ++++++ 7 files changed, 309 insertions(+) create mode 100644 lib/organizations/__tests__/validateOrganizationAccess.test.ts create mode 100644 lib/organizations/validateOrganizationAccess.ts diff --git a/lib/chat/__tests__/handleChatGenerate.test.ts b/lib/chat/__tests__/handleChatGenerate.test.ts index 4274e96b..04b0a05d 100644 --- a/lib/chat/__tests__/handleChatGenerate.test.ts +++ b/lib/chat/__tests__/handleChatGenerate.test.ts @@ -18,6 +18,10 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({ getApiKeyDetails: vi.fn(), })); +vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({ + validateOrganizationAccess: vi.fn(), +})); + vi.mock("@/lib/chat/setupChatRequest", () => ({ setupChatRequest: vi.fn(), })); diff --git a/lib/chat/__tests__/handleChatStream.test.ts b/lib/chat/__tests__/handleChatStream.test.ts index 734964ed..b78918e3 100644 --- a/lib/chat/__tests__/handleChatStream.test.ts +++ b/lib/chat/__tests__/handleChatStream.test.ts @@ -18,6 +18,10 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({ getApiKeyDetails: vi.fn(), })); +vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({ + validateOrganizationAccess: vi.fn(), +})); + vi.mock("@/lib/chat/setupChatRequest", () => ({ setupChatRequest: vi.fn(), })); diff --git a/lib/chat/__tests__/integration/chatEndToEnd.test.ts b/lib/chat/__tests__/integration/chatEndToEnd.test.ts index 86ef311e..27206c0a 100644 --- a/lib/chat/__tests__/integration/chatEndToEnd.test.ts +++ b/lib/chat/__tests__/integration/chatEndToEnd.test.ts @@ -31,6 +31,10 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({ getApiKeyDetails: vi.fn(), })); +vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({ + validateOrganizationAccess: vi.fn(), +})); + // Mock Supabase dependencies vi.mock("@/lib/supabase/account_emails/selectAccountEmails", () => ({ default: vi.fn(), diff --git a/lib/chat/__tests__/validateChatRequest.test.ts b/lib/chat/__tests__/validateChatRequest.test.ts index 6057b55e..b15ebf4e 100644 --- a/lib/chat/__tests__/validateChatRequest.test.ts +++ b/lib/chat/__tests__/validateChatRequest.test.ts @@ -19,15 +19,21 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({ getApiKeyDetails: vi.fn(), })); +vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({ + validateOrganizationAccess: vi.fn(), +})); + import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId"; import { getAuthenticatedAccountId } from "@/lib/auth/getAuthenticatedAccountId"; import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; import { getApiKeyDetails } from "@/lib/keys/getApiKeyDetails"; +import { validateOrganizationAccess } from "@/lib/organizations/validateOrganizationAccess"; const mockGetApiKeyAccountId = vi.mocked(getApiKeyAccountId); const mockGetAuthenticatedAccountId = vi.mocked(getAuthenticatedAccountId); const mockValidateOverrideAccountId = vi.mocked(validateOverrideAccountId); const mockGetApiKeyDetails = vi.mocked(getApiKeyDetails); +const mockValidateOrganizationAccess = vi.mocked(validateOrganizationAccess); // Helper to create mock NextRequest function createMockRequest(body: unknown, headers: Record = {}): Request { @@ -417,4 +423,127 @@ describe("validateChatRequest", () => { expect(result.success).toBe(false); }); }); + + describe("organizationId override", () => { + it("accepts organizationId in schema", () => { + const result = chatRequestSchema.safeParse({ + prompt: "test", + organizationId: "org-123", + }); + expect(result.success).toBe(true); + }); + + it("uses provided organizationId when user is member of org (bearer token)", async () => { + mockGetAuthenticatedAccountId.mockResolvedValue("user-account-123"); + mockValidateOrganizationAccess.mockResolvedValue(true); + + const request = createMockRequest( + { prompt: "Hello", organizationId: "org-456" }, + { authorization: "Bearer valid-jwt-token" }, + ); + + const result = await validateChatRequest(request as any); + + expect(result).not.toBeInstanceOf(NextResponse); + expect((result as any).orgId).toBe("org-456"); + expect(mockValidateOrganizationAccess).toHaveBeenCalledWith({ + accountId: "user-account-123", + organizationId: "org-456", + }); + }); + + it("uses provided organizationId when user is member of org (API key)", async () => { + mockGetApiKeyAccountId.mockResolvedValue("api-key-account-123"); + mockGetApiKeyDetails.mockResolvedValue({ + accountId: "api-key-account-123", + orgId: null, + }); + mockValidateOrganizationAccess.mockResolvedValue(true); + + const request = createMockRequest( + { prompt: "Hello", organizationId: "org-789" }, + { "x-api-key": "personal-api-key" }, + ); + + const result = await validateChatRequest(request as any); + + expect(result).not.toBeInstanceOf(NextResponse); + expect((result as any).orgId).toBe("org-789"); + expect(mockValidateOrganizationAccess).toHaveBeenCalledWith({ + accountId: "api-key-account-123", + organizationId: "org-789", + }); + }); + + it("overwrites API key orgId with provided organizationId when user is member", async () => { + mockGetApiKeyAccountId.mockResolvedValue("org-account-123"); + mockGetApiKeyDetails.mockResolvedValue({ + accountId: "org-account-123", + orgId: "original-org-123", + }); + mockValidateOrganizationAccess.mockResolvedValue(true); + + const request = createMockRequest( + { prompt: "Hello", organizationId: "different-org-456" }, + { "x-api-key": "org-api-key" }, + ); + + const result = await validateChatRequest(request as any); + + expect(result).not.toBeInstanceOf(NextResponse); + expect((result as any).orgId).toBe("different-org-456"); + }); + + it("rejects organizationId when user is NOT a member of org", async () => { + mockGetAuthenticatedAccountId.mockResolvedValue("user-account-123"); + mockValidateOrganizationAccess.mockResolvedValue(false); + + const request = createMockRequest( + { prompt: "Hello", organizationId: "org-not-member" }, + { authorization: "Bearer valid-jwt-token" }, + ); + + const result = await validateChatRequest(request as any); + + expect(result).toBeInstanceOf(NextResponse); + const json = await (result as NextResponse).json(); + expect(json.status).toBe("error"); + expect(json.message).toBe("Access denied to specified organizationId"); + }); + + it("uses API key orgId when no organizationId is provided", async () => { + mockGetApiKeyAccountId.mockResolvedValue("org-account-123"); + mockGetApiKeyDetails.mockResolvedValue({ + accountId: "org-account-123", + orgId: "api-key-org-123", + }); + + const request = createMockRequest( + { prompt: "Hello" }, + { "x-api-key": "org-api-key" }, + ); + + const result = await validateChatRequest(request as any); + + expect(result).not.toBeInstanceOf(NextResponse); + expect((result as any).orgId).toBe("api-key-org-123"); + // Should not validate org access when no organizationId is provided + expect(mockValidateOrganizationAccess).not.toHaveBeenCalled(); + }); + + it("returns null orgId when no organizationId provided and bearer token auth", async () => { + mockGetAuthenticatedAccountId.mockResolvedValue("user-account-123"); + + const request = createMockRequest( + { prompt: "Hello" }, + { authorization: "Bearer valid-jwt-token" }, + ); + + const result = await validateChatRequest(request as any); + + expect(result).not.toBeInstanceOf(NextResponse); + expect((result as any).orgId).toBeNull(); + expect(mockValidateOrganizationAccess).not.toHaveBeenCalled(); + }); + }); }); diff --git a/lib/chat/validateChatRequest.ts b/lib/chat/validateChatRequest.ts index 1808f771..5805dba2 100644 --- a/lib/chat/validateChatRequest.ts +++ b/lib/chat/validateChatRequest.ts @@ -7,6 +7,7 @@ import { getAuthenticatedAccountId } from "@/lib/auth/getAuthenticatedAccountId" import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; import { getMessages } from "@/lib/messages/getMessages"; import { getApiKeyDetails } from "@/lib/keys/getApiKeyDetails"; +import { validateOrganizationAccess } from "@/lib/organizations/validateOrganizationAccess"; export const chatRequestSchema = z .object({ @@ -17,6 +18,7 @@ export const chatRequestSchema = z roomId: z.string().optional(), accountId: z.string().optional(), artistId: z.string().optional(), + organizationId: z.string().optional(), model: z.string().optional(), excludeTools: z.array(z.string()).optional(), }) @@ -138,6 +140,30 @@ export async function validateChatRequest( accountId = accountIdOrError; } + // Handle organizationId override from request body + if (validatedBody.organizationId) { + const hasOrgAccess = await validateOrganizationAccess({ + accountId, + organizationId: validatedBody.organizationId, + }); + + if (!hasOrgAccess) { + return NextResponse.json( + { + status: "error", + message: "Access denied to specified organizationId", + }, + { + status: 403, + headers: getCorsHeaders(), + }, + ); + } + + // Use the provided organizationId as orgId + orgId = validatedBody.organizationId; + } + // Normalize chat content: // - If messages are provided, keep them as-is // - If only prompt is provided, convert it into a single user UIMessage diff --git a/lib/organizations/__tests__/validateOrganizationAccess.test.ts b/lib/organizations/__tests__/validateOrganizationAccess.test.ts new file mode 100644 index 00000000..27f26fdb --- /dev/null +++ b/lib/organizations/__tests__/validateOrganizationAccess.test.ts @@ -0,0 +1,101 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { validateOrganizationAccess } from "../validateOrganizationAccess"; + +// Mock getAccountOrganizations supabase lib +vi.mock("@/lib/supabase/account_organization_ids/getAccountOrganizations", () => ({ + getAccountOrganizations: vi.fn(), +})); + +import { getAccountOrganizations } from "@/lib/supabase/account_organization_ids/getAccountOrganizations"; + +const mockGetAccountOrganizations = vi.mocked(getAccountOrganizations); + +describe("validateOrganizationAccess", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + describe("account is the organization", () => { + it("returns true when accountId equals organizationId", async () => { + const result = await validateOrganizationAccess({ + accountId: "org-123", + organizationId: "org-123", + }); + + expect(result).toBe(true); + // Should not query database when account IS the org + expect(mockGetAccountOrganizations).not.toHaveBeenCalled(); + }); + }); + + describe("account is a member of the organization", () => { + it("returns true when account is a member of the organization", async () => { + mockGetAccountOrganizations.mockResolvedValue([ + { + account_id: "member-account-456", + organization_id: "org-123", + created_at: new Date().toISOString(), + organization: null, + }, + ]); + + const result = await validateOrganizationAccess({ + accountId: "member-account-456", + organizationId: "org-123", + }); + + expect(result).toBe(true); + expect(mockGetAccountOrganizations).toHaveBeenCalledWith({ + accountId: "member-account-456", + organizationId: "org-123", + }); + }); + + it("returns false when account is NOT a member of the organization", async () => { + mockGetAccountOrganizations.mockResolvedValue([]); + + const result = await validateOrganizationAccess({ + accountId: "non-member-account-789", + organizationId: "org-123", + }); + + expect(result).toBe(false); + expect(mockGetAccountOrganizations).toHaveBeenCalledWith({ + accountId: "non-member-account-789", + organizationId: "org-123", + }); + }); + }); + + describe("invalid inputs", () => { + it("returns false when accountId is empty", async () => { + const result = await validateOrganizationAccess({ + accountId: "", + organizationId: "org-123", + }); + + expect(result).toBe(false); + expect(mockGetAccountOrganizations).not.toHaveBeenCalled(); + }); + + it("returns false when organizationId is empty", async () => { + const result = await validateOrganizationAccess({ + accountId: "account-123", + organizationId: "", + }); + + expect(result).toBe(false); + expect(mockGetAccountOrganizations).not.toHaveBeenCalled(); + }); + + it("returns false when both are empty", async () => { + const result = await validateOrganizationAccess({ + accountId: "", + organizationId: "", + }); + + expect(result).toBe(false); + expect(mockGetAccountOrganizations).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/lib/organizations/validateOrganizationAccess.ts b/lib/organizations/validateOrganizationAccess.ts new file mode 100644 index 00000000..d4317377 --- /dev/null +++ b/lib/organizations/validateOrganizationAccess.ts @@ -0,0 +1,41 @@ +import { getAccountOrganizations } from "@/lib/supabase/account_organization_ids/getAccountOrganizations"; + +export interface ValidateOrganizationAccessParams { + accountId: string; + organizationId: string; +} + +/** + * Validates if an account can operate on behalf of an organization. + * + * Access rules: + * - If accountId equals organizationId (account IS the org), access is granted + * - Otherwise, checks if accountId is a member of the organization + * + * @param params - The validation parameters + * @param params.accountId - The account ID to validate + * @param params.organizationId - The organization ID to check access for + * @returns true if access is allowed, false otherwise + */ +export async function validateOrganizationAccess( + params: ValidateOrganizationAccessParams, +): Promise { + const { accountId, organizationId } = params; + + if (!accountId || !organizationId) { + return false; + } + + // Account IS the organization + if (accountId === organizationId) { + return true; + } + + // Check if account is a member of the organization + const memberships = await getAccountOrganizations({ + accountId, + organizationId, + }); + + return memberships.length > 0; +}