Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/chat/__tests__/handleChatGenerate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({
getApiKeyDetails: vi.fn(),
}));

vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({
validateOrganizationAccess: vi.fn(),
}));

vi.mock("@/lib/chat/setupChatRequest", () => ({
setupChatRequest: vi.fn(),
}));
Expand Down
4 changes: 4 additions & 0 deletions lib/chat/__tests__/handleChatStream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({
getApiKeyDetails: vi.fn(),
}));

vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({
validateOrganizationAccess: vi.fn(),
}));

vi.mock("@/lib/chat/setupChatRequest", () => ({
setupChatRequest: vi.fn(),
}));
Expand Down
4 changes: 4 additions & 0 deletions lib/chat/__tests__/integration/chatEndToEnd.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({
getApiKeyDetails: vi.fn(),
}));

vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({
validateOrganizationAccess: vi.fn(),
}));

// Mock Supabase dependencies
vi.mock("@/lib/supabase/account_emails/selectAccountEmails", () => ({
default: vi.fn(),
Expand Down
129 changes: 129 additions & 0 deletions lib/chat/__tests__/validateChatRequest.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,21 @@ vi.mock("@/lib/keys/getApiKeyDetails", () => ({
getApiKeyDetails: vi.fn(),
}));

vi.mock("@/lib/organizations/validateOrganizationAccess", () => ({
validateOrganizationAccess: vi.fn(),
}));

import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId";
import { getAuthenticatedAccountId } from "@/lib/auth/getAuthenticatedAccountId";
import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId";
import { getApiKeyDetails } from "@/lib/keys/getApiKeyDetails";
import { validateOrganizationAccess } from "@/lib/organizations/validateOrganizationAccess";

const mockGetApiKeyAccountId = vi.mocked(getApiKeyAccountId);
const mockGetAuthenticatedAccountId = vi.mocked(getAuthenticatedAccountId);
const mockValidateOverrideAccountId = vi.mocked(validateOverrideAccountId);
const mockGetApiKeyDetails = vi.mocked(getApiKeyDetails);
const mockValidateOrganizationAccess = vi.mocked(validateOrganizationAccess);

// Helper to create mock NextRequest
function createMockRequest(body: unknown, headers: Record<string, string> = {}): Request {
Expand Down Expand Up @@ -417,4 +423,127 @@ describe("validateChatRequest", () => {
expect(result.success).toBe(false);
});
});

describe("organizationId override", () => {
it("accepts organizationId in schema", () => {
const result = chatRequestSchema.safeParse({
prompt: "test",
organizationId: "org-123",
});
expect(result.success).toBe(true);
});

it("uses provided organizationId when user is member of org (bearer token)", async () => {
mockGetAuthenticatedAccountId.mockResolvedValue("user-account-123");
mockValidateOrganizationAccess.mockResolvedValue(true);

const request = createMockRequest(
{ prompt: "Hello", organizationId: "org-456" },
{ authorization: "Bearer valid-jwt-token" },
);

const result = await validateChatRequest(request as any);

expect(result).not.toBeInstanceOf(NextResponse);
expect((result as any).orgId).toBe("org-456");
expect(mockValidateOrganizationAccess).toHaveBeenCalledWith({
accountId: "user-account-123",
organizationId: "org-456",
});
});

it("uses provided organizationId when user is member of org (API key)", async () => {
mockGetApiKeyAccountId.mockResolvedValue("api-key-account-123");
mockGetApiKeyDetails.mockResolvedValue({
accountId: "api-key-account-123",
orgId: null,
});
mockValidateOrganizationAccess.mockResolvedValue(true);

const request = createMockRequest(
{ prompt: "Hello", organizationId: "org-789" },
{ "x-api-key": "personal-api-key" },
);

const result = await validateChatRequest(request as any);

expect(result).not.toBeInstanceOf(NextResponse);
expect((result as any).orgId).toBe("org-789");
expect(mockValidateOrganizationAccess).toHaveBeenCalledWith({
accountId: "api-key-account-123",
organizationId: "org-789",
});
});

it("overwrites API key orgId with provided organizationId when user is member", async () => {
mockGetApiKeyAccountId.mockResolvedValue("org-account-123");
mockGetApiKeyDetails.mockResolvedValue({
accountId: "org-account-123",
orgId: "original-org-123",
});
mockValidateOrganizationAccess.mockResolvedValue(true);

const request = createMockRequest(
{ prompt: "Hello", organizationId: "different-org-456" },
{ "x-api-key": "org-api-key" },
);

const result = await validateChatRequest(request as any);

expect(result).not.toBeInstanceOf(NextResponse);
expect((result as any).orgId).toBe("different-org-456");
});

it("rejects organizationId when user is NOT a member of org", async () => {
mockGetAuthenticatedAccountId.mockResolvedValue("user-account-123");
mockValidateOrganizationAccess.mockResolvedValue(false);

const request = createMockRequest(
{ prompt: "Hello", organizationId: "org-not-member" },
{ authorization: "Bearer valid-jwt-token" },
);

const result = await validateChatRequest(request as any);

expect(result).toBeInstanceOf(NextResponse);
const json = await (result as NextResponse).json();
expect(json.status).toBe("error");
expect(json.message).toBe("Access denied to specified organizationId");
});

it("uses API key orgId when no organizationId is provided", async () => {
mockGetApiKeyAccountId.mockResolvedValue("org-account-123");
mockGetApiKeyDetails.mockResolvedValue({
accountId: "org-account-123",
orgId: "api-key-org-123",
});

const request = createMockRequest(
{ prompt: "Hello" },
{ "x-api-key": "org-api-key" },
);

const result = await validateChatRequest(request as any);

expect(result).not.toBeInstanceOf(NextResponse);
expect((result as any).orgId).toBe("api-key-org-123");
// Should not validate org access when no organizationId is provided
expect(mockValidateOrganizationAccess).not.toHaveBeenCalled();
});

it("returns null orgId when no organizationId provided and bearer token auth", async () => {
mockGetAuthenticatedAccountId.mockResolvedValue("user-account-123");

const request = createMockRequest(
{ prompt: "Hello" },
{ authorization: "Bearer valid-jwt-token" },
);

const result = await validateChatRequest(request as any);

expect(result).not.toBeInstanceOf(NextResponse);
expect((result as any).orgId).toBeNull();
expect(mockValidateOrganizationAccess).not.toHaveBeenCalled();
});
});
});
26 changes: 26 additions & 0 deletions lib/chat/validateChatRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { getAuthenticatedAccountId } from "@/lib/auth/getAuthenticatedAccountId"
import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId";
import { getMessages } from "@/lib/messages/getMessages";
import { getApiKeyDetails } from "@/lib/keys/getApiKeyDetails";
import { validateOrganizationAccess } from "@/lib/organizations/validateOrganizationAccess";

export const chatRequestSchema = z
.object({
Expand All @@ -17,6 +18,7 @@ export const chatRequestSchema = z
roomId: z.string().optional(),
accountId: z.string().optional(),
artistId: z.string().optional(),
organizationId: z.string().optional(),
model: z.string().optional(),
excludeTools: z.array(z.string()).optional(),
})
Expand Down Expand Up @@ -138,6 +140,30 @@ export async function validateChatRequest(
accountId = accountIdOrError;
}

// Handle organizationId override from request body
if (validatedBody.organizationId) {
const hasOrgAccess = await validateOrganizationAccess({
accountId,
organizationId: validatedBody.organizationId,
});

if (!hasOrgAccess) {
return NextResponse.json(
{
status: "error",
message: "Access denied to specified organizationId",
},
{
status: 403,
headers: getCorsHeaders(),
},
);
}

// Use the provided organizationId as orgId
orgId = validatedBody.organizationId;
}

// Normalize chat content:
// - If messages are provided, keep them as-is
// - If only prompt is provided, convert it into a single user UIMessage
Expand Down
101 changes: 101 additions & 0 deletions lib/organizations/__tests__/validateOrganizationAccess.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { validateOrganizationAccess } from "../validateOrganizationAccess";

// Mock getAccountOrganizations supabase lib
vi.mock("@/lib/supabase/account_organization_ids/getAccountOrganizations", () => ({
getAccountOrganizations: vi.fn(),
}));

import { getAccountOrganizations } from "@/lib/supabase/account_organization_ids/getAccountOrganizations";

const mockGetAccountOrganizations = vi.mocked(getAccountOrganizations);

describe("validateOrganizationAccess", () => {
beforeEach(() => {
vi.clearAllMocks();
});

describe("account is the organization", () => {
it("returns true when accountId equals organizationId", async () => {
const result = await validateOrganizationAccess({
accountId: "org-123",
organizationId: "org-123",
});

expect(result).toBe(true);
// Should not query database when account IS the org
expect(mockGetAccountOrganizations).not.toHaveBeenCalled();
});
});

describe("account is a member of the organization", () => {
it("returns true when account is a member of the organization", async () => {
mockGetAccountOrganizations.mockResolvedValue([
{
account_id: "member-account-456",
organization_id: "org-123",
created_at: new Date().toISOString(),
organization: null,
},
]);

const result = await validateOrganizationAccess({
accountId: "member-account-456",
organizationId: "org-123",
});

expect(result).toBe(true);
expect(mockGetAccountOrganizations).toHaveBeenCalledWith({
accountId: "member-account-456",
organizationId: "org-123",
});
});

it("returns false when account is NOT a member of the organization", async () => {
mockGetAccountOrganizations.mockResolvedValue([]);

const result = await validateOrganizationAccess({
accountId: "non-member-account-789",
organizationId: "org-123",
});

expect(result).toBe(false);
expect(mockGetAccountOrganizations).toHaveBeenCalledWith({
accountId: "non-member-account-789",
organizationId: "org-123",
});
});
});

describe("invalid inputs", () => {
it("returns false when accountId is empty", async () => {
const result = await validateOrganizationAccess({
accountId: "",
organizationId: "org-123",
});

expect(result).toBe(false);
expect(mockGetAccountOrganizations).not.toHaveBeenCalled();
});

it("returns false when organizationId is empty", async () => {
const result = await validateOrganizationAccess({
accountId: "account-123",
organizationId: "",
});

expect(result).toBe(false);
expect(mockGetAccountOrganizations).not.toHaveBeenCalled();
});

it("returns false when both are empty", async () => {
const result = await validateOrganizationAccess({
accountId: "",
organizationId: "",
});

expect(result).toBe(false);
expect(mockGetAccountOrganizations).not.toHaveBeenCalled();
});
});
});
41 changes: 41 additions & 0 deletions lib/organizations/validateOrganizationAccess.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { getAccountOrganizations } from "@/lib/supabase/account_organization_ids/getAccountOrganizations";

export interface ValidateOrganizationAccessParams {
accountId: string;
organizationId: string;
}

/**
* Validates if an account can operate on behalf of an organization.
*
* Access rules:
* - If accountId equals organizationId (account IS the org), access is granted
* - Otherwise, checks if accountId is a member of the organization
*
* @param params - The validation parameters
* @param params.accountId - The account ID to validate
* @param params.organizationId - The organization ID to check access for
* @returns true if access is allowed, false otherwise
*/
export async function validateOrganizationAccess(
params: ValidateOrganizationAccessParams,
): Promise<boolean> {
const { accountId, organizationId } = params;

if (!accountId || !organizationId) {
return false;
}

// Account IS the organization
if (accountId === organizationId) {
return true;
}

// Check if account is a member of the organization
const memberships = await getAccountOrganizations({
accountId,
organizationId,
});

return memberships.length > 0;
}