Skip to content

Commit

Permalink
🤖 feat: Private Assistants (danny-avila#2881)
Browse files Browse the repository at this point in the history
* feat: add configuration for user private assistants

* filter private assistant message requests

* add test for privateAssistants

* add privateAssistants configuration to tests

* fix: destructuring error when assistants config is not added

* chore: revert chat controller changes

* chore: add payload type, add metadata types

* feat: validateAssistant

* refactor(fetchAssistants): allow for flexibility

* feat: validateAuthor

* refactor: return all assistants to ADMIN role

* feat: add assistant doc on assistant creation

* refactor(listAssistants): use `listAllAssistants` to exhaustively fetch all assistants

* chore: add suggestion to tts error

* refactor(validateAuthor): attempt database check first

* refactor: author validation when patching/deleting assistant

---------

Co-authored-by: Leon Juenemann <[email protected]>
  • Loading branch information
danny-avila and Leon Juenemann authored May 28, 2024
1 parent 9f2538f commit 5dc5d87
Show file tree
Hide file tree
Showing 20 changed files with 308 additions and 109 deletions.
4 changes: 2 additions & 2 deletions api/models/Assistant.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
*/
const updateAssistant = async (searchParams, updateData, session = null) => {
const updateAssistantDoc = async (searchParams, updateData, session = null) => {
const options = { new: true, upsert: true, session };
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
};
Expand Down Expand Up @@ -52,7 +52,7 @@ const deleteAssistant = async (searchParams) => {
};

module.exports = {
updateAssistant,
updateAssistantDoc,
deleteAssistant,
getAssistants,
getAssistant,
Expand Down
31 changes: 4 additions & 27 deletions api/server/controllers/assistants/chatV1.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const {
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
Expand All @@ -31,15 +32,14 @@ const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');

const { handleAbortError } = require('~/server/middleware');

const ten_minutes = 1000 * 60 * 10;

/**
* @route POST /
* @desc Chat with an assistant
* @access Public
* @param {Express.Request} req - The request object, containing the request data.
* @param {object} req - The request object, containing the request data.
* @param {object} req.body - The request payload.
* @param {Express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
Expand All @@ -60,30 +60,6 @@ const chatV1 = async (req, res) => {
parentMessageId: _parentId = Constants.NO_PARENT,
} = req.body;

/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals?.[endpoint];

if (assistantsConfig) {
const { supportedIds, excludedIds } = assistantsConfig;
const error = { message: 'Assistant not supported' };
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
error,
});
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
});
}
}

/** @type {OpenAIClient} */
let openai;
/** @type {string|undefined} - the current thread id */
Expand Down Expand Up @@ -311,6 +287,7 @@ const chatV1 = async (req, res) => {
});

openai = _openai;
await validateAuthor({ req, openai });

if (previousMessages.length) {
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
Expand Down
28 changes: 2 additions & 26 deletions api/server/controllers/assistants/chatV2.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const {
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { getTransactions } = require('~/models/Transaction');
Expand All @@ -30,8 +31,6 @@ const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');

const { handleAbortError } = require('~/server/middleware');

const ten_minutes = 1000 * 60 * 10;

/**
Expand Down Expand Up @@ -60,30 +59,6 @@ const chatV2 = async (req, res) => {
parentMessageId: _parentId = Constants.NO_PARENT,
} = req.body;

/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals?.[endpoint];

if (assistantsConfig) {
const { supportedIds, excludedIds } = assistantsConfig;
const error = { message: 'Assistant not supported' };
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
error,
});
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
});
}
}

/** @type {OpenAIClient} */
let openai;
/** @type {string|undefined} - the current thread id */
Expand Down Expand Up @@ -309,6 +284,7 @@ const chatV2 = async (req, res) => {
});

openai = _openai;
await validateAuthor({ req, openai });

if (previousMessages.length) {
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
Expand Down
123 changes: 117 additions & 6 deletions api/server/controllers/assistants/helpers.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
const { EModelEndpoint, CacheKeys, defaultAssistantsVersion } = require('librechat-data-provider');
const {
EModelEndpoint,
CacheKeys,
defaultAssistantsVersion,
defaultOrderQuery,
} = require('librechat-data-provider');
const {
initializeClient: initAzureClient,
} = require('~/server/services/Endpoints/azureAssistants');
Expand Down Expand Up @@ -35,6 +40,7 @@ const getCurrentVersion = async (req, endpoint) => {
* Initializes the client with the current request and response objects and lists assistants
* according to the query parameters. This function abstracts the logic for non-Azure paths.
*
* @deprecated
* @async
* @param {object} params - The parameters object.
* @param {object} params.req - The request object, used for initializing the client.
Expand All @@ -43,11 +49,65 @@ const getCurrentVersion = async (req, endpoint) => {
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
*/
const listAssistants = async ({ req, res, version, query }) => {
const _listAssistants = async ({ req, res, version, query }) => {
const { openai } = await getOpenAIClient({ req, res, version });
return openai.beta.assistants.list(query);
};

/**
* Fetches all assistants based on provided query params, until `has_more` is `false`.
*
* @async
* @param {object} params - The parameters object.
* @param {object} params.req - The request object, used for initializing the client.
* @param {object} params.res - The response object, used for initializing the client.
* @param {string} params.version - The API version to use.
* @param {Omit<AssistantListParams, 'endpoint'>} params.query - The query parameters to list assistants (e.g., limit, order).
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
*/
const listAllAssistants = async ({ req, res, version, query }) => {
/** @type {{ openai: OpenAIClient }} */
const { openai } = await getOpenAIClient({ req, res, version });
const allAssistants = [];

let first_id;
let last_id;
let afterToken = query.after;
let hasMore = true;

while (hasMore) {
const response = await openai.beta.assistants.list({
...query,
after: afterToken,
});

const { body } = response;

allAssistants.push(...body.data);
hasMore = body.has_more;

if (!first_id) {
first_id = body.first_id;
}

if (hasMore) {
afterToken = body.last_id;
} else {
last_id = body.last_id;
}
}

return {
data: allAssistants,
body: {
data: allAssistants,
has_more: false,
first_id,
last_id,
},
};
};

/**
* Asynchronously lists assistants for Azure configured groups.
*
Expand Down Expand Up @@ -82,7 +142,7 @@ const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, que
/* The specified model is only necessary to
fetch assistants for the shared instance */
req.body.model = currentModelTuples[0][0];
promises.push(listAssistants({ req, res, version, query }));
promises.push(listAllAssistants({ req, res, version, query }));
}

const resolvedQueries = await Promise.all(promises);
Expand Down Expand Up @@ -133,24 +193,75 @@ async function getOpenAIClient({ req, res, endpointOption, initAppClient, overri
return result;
}

const fetchAssistants = async (req, res) => {
const { limit = 100, order = 'desc', after, before, endpoint } = req.query;
/**
* Returns a list of assistants.
* @param {object} params
* @param {object} params.req - Express Request
* @param {AssistantListParams} [params.req.query] - The assistant list parameters for pagination and sorting.
* @param {object} params.res - Express Response
* @param {string} [params.overrideEndpoint] - The endpoint to override the request endpoint.
* @returns {Promise<AssistantListResponse>} 200 - success response - application/json
*/
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
const {
limit = 100,
order = 'desc',
after,
before,
endpoint,
} = req.query ?? {
endpoint: overrideEndpoint,
...defaultOrderQuery,
};

const version = await getCurrentVersion(req, endpoint);
const query = { limit, order, after, before };

/** @type {AssistantListResponse} */
let body;

if (endpoint === EModelEndpoint.assistants) {
({ body } = await listAssistants({ req, res, version, query }));
({ body } = await listAllAssistants({ req, res, version, query }));
} else if (endpoint === EModelEndpoint.azureAssistants) {
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
}

if (req.user.role === 'ADMIN') {
return body;
} else if (!req.app.locals[endpoint]) {
return body;
}

body.data = filterAssistants({
userId: req.user.id,
assistants: body.data,
assistantsConfig: req.app.locals[endpoint],
});
return body;
};

/**
* Filter assistants based on configuration.
*
* @param {object} params - The parameters object.
* @param {string} params.userId - The user ID to filter private assistants.
* @param {Assistant[]} params.assistants - The list of assistants to filter.
* @param {Partial<TAssistantEndpoint>} params.assistantsConfig - The assistant configuration.
* @returns {Assistant[]} - The filtered list of assistants.
*/
function filterAssistants({ assistants, userId, assistantsConfig }) {
const { supportedIds, excludedIds, privateAssistants } = assistantsConfig;
if (privateAssistants) {
return assistants.filter((assistant) => userId === assistant.metadata?.author);
} else if (supportedIds?.length) {
return assistants.filter((assistant) => supportedIds.includes(assistant.id));
} else if (excludedIds?.length) {
return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
}
return assistants;
}

module.exports = {
getOpenAIClient,
fetchAssistants,
Expand Down
25 changes: 9 additions & 16 deletions api/server/controllers/assistants/v1.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
const { FileContext } = require('librechat-data-provider');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { deleteAssistantActions } = require('~/server/services/ActionService');
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
const { uploadImageBuffer } = require('~/server/services/Files/process');
const { updateAssistant, getAssistants } = require('~/models/Assistant');
const { getOpenAIClient, fetchAssistants } = require('./helpers');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
Expand Down Expand Up @@ -40,9 +41,11 @@ const createAssistant = async (req, res) => {
};

const assistant = await openai.beta.assistants.create(assistantData);
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
if (azureModelIdentifier) {
assistant.model = azureModelIdentifier;
}
await promise;
logger.debug('/assistants/', assistant);
res.status(201).json(assistant);
} catch (error) {
Expand All @@ -61,7 +64,6 @@ const retrieveAssistant = async (req, res) => {
try {
/* NOTE: not actually being used right now */
const { openai } = await getOpenAIClient({ req, res });

const assistant_id = req.params.id;
const assistant = await openai.beta.assistants.retrieve(assistant_id);
res.json(assistant);
Expand All @@ -83,6 +85,7 @@ const retrieveAssistant = async (req, res) => {
const patchAssistant = async (req, res) => {
try {
const { openai } = await getOpenAIClient({ req, res });
await validateAuthor({ req, openai });

const assistant_id = req.params.id;
const { endpoint: _e, ...updateData } = req.body;
Expand Down Expand Up @@ -119,6 +122,7 @@ const patchAssistant = async (req, res) => {
const deleteAssistant = async (req, res) => {
try {
const { openai } = await getOpenAIClient({ req, res });
await validateAuthor({ req, openai });

const assistant_id = req.params.id;
const deletionStatus = await openai.beta.assistants.del(assistant_id);
Expand All @@ -141,19 +145,7 @@ const deleteAssistant = async (req, res) => {
*/
const listAssistants = async (req, res) => {
try {
const body = await fetchAssistants(req, res);

if (req.app.locals?.[req.query.endpoint]) {
/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals[req.query.endpoint];
const { supportedIds, excludedIds } = assistantsConfig;
if (supportedIds?.length) {
body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id));
} else if (excludedIds?.length) {
body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id));
}
}

const body = await fetchAssistants({ req, res });
res.json(body);
} catch (error) {
logger.error('[/assistants] Error listing assistants', error);
Expand Down Expand Up @@ -195,6 +187,7 @@ const uploadAssistantAvatar = async (req, res) => {

let { metadata: _metadata = '{}' } = req.body;
const { openai } = await getOpenAIClient({ req, res });
await validateAuthor({ req, openai });

const image = await uploadImageBuffer({
req,
Expand Down Expand Up @@ -229,7 +222,7 @@ const uploadAssistantAvatar = async (req, res) => {

const promises = [];
promises.push(
updateAssistant(
updateAssistantDoc(
{ assistant_id },
{
avatar: {
Expand Down
Loading

0 comments on commit 5dc5d87

Please sign in to comment.