Commit 01f66ee

Fine-tune pipeline for classification (#320)
* Classification: db models and migration script (#305)
* db models and migration script
* Classification: Fine tuning Initiation and retrieve endpoint (#315)
* Fine-tuning core, initiation, and retrieval
* separate session for bg task, and formatting fixes
* fixing alembic revision
* Classification : Model evaluation of fine tuned models (#326)
* Model evaluation of fine tuned models
* fixing alembic revision
* alembic revision fix
* Classification : train and test data to s3 (#343)
* alembic file for adding and removing columns
* train and test s3 url column
* updating alembic revision
* formatting fix
* Classification : retaining prediction and fetching data from s3 for model evaluation (#359)
* adding new columns to model eval table
* test data and prediction data s3 url changes
* single migration file
* status enum columns
* document seeding
* Classification : small fixes and storage related changes (#365)
* first commit covering all
* changing model name to fine tuned model in model eval
* error handling in get cloud storage and document not found error handling
* fixing alembic revision
* uv lock
* new uv lock file
* updated uv lock file
* coderabbit suggestions and removing unused imports
* changes in uv lock file
* making csv a supported file format, changing uv lock and pyproject toml
1 parent 98295d6 commit 01f66ee

28 files changed: 4,257 additions, 1,425 deletions
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
```python
"""add fine tuning and model evaluation table

Revision ID: 6ed6ed401847
Revises: 40307ab77e9f
Create Date: 2025-09-01 14:54:03.553608

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "6ed6ed401847"
down_revision = "9f8a4af9d6fd"
branch_labels = None
depends_on = None


finetuning_status_enum = postgresql.ENUM(
    "pending",
    "running",
    "completed",
    "failed",
    name="finetuningstatus",
    create_type=False,
)

modelevaluation_status_enum = postgresql.ENUM(
    "pending",
    "running",
    "completed",
    "failed",
    name="modelevaluationstatus",
    create_type=False,
)


def upgrade():
    finetuning_status_enum.create(op.get_bind(), checkfirst=True)
    modelevaluation_status_enum.create(op.get_bind(), checkfirst=True)
    op.create_table(
        "fine_tuning",
        sa.Column("base_model", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("split_ratio", sa.Float(), nullable=False),
        sa.Column("document_id", sa.Uuid(), nullable=False),
        sa.Column(
            "training_file_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True
        ),
        sa.Column("system_prompt", sa.Text(), nullable=False),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("provider_job_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column(
            "status",
            finetuning_status_enum,
            nullable=False,
            server_default="pending",
        ),
        sa.Column(
            "fine_tuned_model", sqlmodel.sql.sqltypes.AutoString(), nullable=True
        ),
        sa.Column(
            "train_data_s3_object", sqlmodel.sql.sqltypes.AutoString(), nullable=True
        ),
        sa.Column(
            "test_data_s3_object", sqlmodel.sql.sqltypes.AutoString(), nullable=True
        ),
        sa.Column("error_message", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column("project_id", sa.Integer(), nullable=False),
        sa.Column("organization_id", sa.Integer(), nullable=False),
        sa.Column("is_deleted", sa.Boolean(), nullable=False),
        sa.Column("inserted_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), nullable=False),
        sa.Column("deleted_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["document_id"],
            ["document.id"],
        ),
        sa.ForeignKeyConstraint(
            ["organization_id"], ["organization.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "model_evaluation",
        sa.Column("fine_tuning_id", sa.Integer(), nullable=False),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.Uuid(), nullable=False),
        sa.Column(
            "fine_tuned_model", sqlmodel.sql.sqltypes.AutoString(), nullable=False
        ),
        sa.Column(
            "test_data_s3_object", sqlmodel.sql.sqltypes.AutoString(), nullable=False
        ),
        sa.Column("base_model", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column("split_ratio", sa.Float(), nullable=False),
        sa.Column("system_prompt", sa.Text(), nullable=False),
        sa.Column("score", postgresql.JSON(astext_type=sa.Text()), nullable=True),
        sa.Column(
            "prediction_data_s3_object",
            sqlmodel.sql.sqltypes.AutoString(),
            nullable=True,
        ),
        sa.Column(
            "status",
            modelevaluation_status_enum,
            nullable=False,
            server_default="pending",
        ),
        sa.Column("error_message", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column("project_id", sa.Integer(), nullable=False),
        sa.Column("organization_id", sa.Integer(), nullable=False),
        sa.Column("is_deleted", sa.Boolean(), nullable=False),
        sa.Column("inserted_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), nullable=False),
        sa.Column("deleted_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["document_id"],
            ["document.id"],
        ),
        sa.ForeignKeyConstraint(
            ["fine_tuning_id"], ["fine_tuning.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(
            ["organization_id"], ["organization.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade():
    op.drop_table("model_evaluation")
    op.drop_table("fine_tuning")
```
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
This endpoint initiates fine-tuning of an OpenAI model on a custom dataset that you have previously uploaded via the upload document endpoint. The uploaded dataset must include:

- A column named `query`, `question`, or `message` containing the user inputs or messages.
- A column named `label` indicating whether each message is a genuine query or something else (e.g., casual conversation or small talk).
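For example, a minimal accepted CSV might look like this (the rows and label values here are hypothetical):

```csv
query,label
"How do I reset my password?",query
"haha yeah, nice weather today",small_talk
```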

The `split_ratio` in the request body determines how your data is divided between training and testing. For example, a split ratio of 0.5 means 50% of your data is used for training and the remaining 50% for testing. You can also provide multiple split ratios, for instance `[0.7, 0.9]`; this triggers one fine-tuning job per ratio, effectively training multiple models on different portions of your dataset. You must also specify the base model you want to fine-tune.
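For intuition, a minimal sketch of how a list of ratios could map to train/test splits (the function and variable names here are illustrative, not the service's actual code):

```python
# One train/test split per requested ratio.
def split_dataset(rows: list[dict], split_ratio: float) -> tuple[list[dict], list[dict]]:
    cutoff = int(len(rows) * split_ratio)
    return rows[:cutoff], rows[cutoff:]  # (train, test)

rows = [{"query": f"message {i}", "label": "query"} for i in range(100)]
for ratio in [0.7, 0.9]:
    train, test = split_dataset(rows, ratio)
    print(ratio, len(train), len(test))  # 0.7 -> 70/30, 0.9 -> 90/10
```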

The `system_prompt` field in the request body lets you define an initial instruction or context-setting message that is included in the training data. This message helps the model learn how it is expected to behave when responding to user inputs. It is prepended as the first message in each training example during fine-tuning.
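OpenAI's chat fine-tuning format expects each training example as a JSON line with a `messages` array. The sketch below shows how the system prompt could be prepended to each example (the exact mapping of dataset rows to messages is an assumption here):

```python
import json

# Build one chat-format training example, prepending the system prompt
# as the first message. Each returned string is one JSONL line.
def to_training_line(system_prompt: str, user_message: str, label: str) -> str:
    return json.dumps(
        {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": label},
            ]
        }
    )
```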

The system handles the fine-tuning process by interacting with OpenAI's APIs under the hood (see the sketch after this list). These include:

- [OpenAI Files create, to upload your training and testing files](https://platform.openai.com/docs/api-reference/files/create)
- [OpenAI Fine-tuning Jobs create, to initiate each fine-tuning job](https://platform.openai.com/docs/api-reference/fine_tuning/create)
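A minimal sketch of those two calls with the official `openai` Python SDK (the file name and base model below are placeholders):

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# 1. Upload the prepared JSONL training file.
training_file = client.files.create(
    file=open("train_0.7.jsonl", "rb"), purpose="fine-tune"
)

# 2. Start a fine-tuning job on the chosen base model.
job = client.fine_tuning.jobs.create(
    training_file=training_file.id, model="gpt-4o-mini-2024-07-18"
)
print(job.id, job.status)
```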

If successful, the response includes a message along with a list of the fine-tuning jobs that were initiated. Each job object includes:

- `id`: the internal ID of the fine-tuning job
- `document_id`: the ID of the document used for fine-tuning
- `split_ratio`: the data split used for that job
- `status`: the initial status of the job (usually "pending")
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
Refreshes the status of a fine-tuning job by retrieving the latest information from OpenAI. If the status, fine-tuned model, or error message has changed, the local job record is updated accordingly, and the latest state of the job is returned.

OpenAI's job status is retrieved using the [Fine-tuning Job Retrieve API](https://platform.openai.com/docs/api-reference/fine_tuning/retrieve).
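In outline, a refresh of this kind could look as follows (a sketch, assuming the `openai` Python SDK; `record` stands in for the local fine-tuning row):

```python
from openai import OpenAI

client = OpenAI()

def refresh_job(record) -> None:
    # Pull the latest job state from OpenAI.
    job = client.fine_tuning.jobs.retrieve(record.provider_job_id)
    # Mirror only the fields that can change upstream.
    record.status = job.status
    record.fine_tuned_model = job.fine_tuned_model
    record.error_message = job.error.message if job.error else None
```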

backend/app/api/main.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -18,6 +18,8 @@
     utils,
     onboarding,
     credentials,
+    fine_tuning,
+    model_evaluation,
 )
 from app.core.config import settings

@@ -38,6 +40,9 @@
 api_router.include_router(threads.router)
 api_router.include_router(users.router)
 api_router.include_router(utils.router)
+api_router.include_router(fine_tuning.router)
+api_router.include_router(model_evaluation.router)
+

 if settings.ENVIRONMENT in ["development", "testing"]:
     api_router.include_router(private.router)
```
