Skip to content

Commit 08647ab

Browse files
committed
an rpc endpoint to mark a task "done".
1 parent a57cb4b commit 08647ab

File tree

6 files changed

+247
-0
lines changed

6 files changed

+247
-0
lines changed

go/pkg/sysdb/coordinator/task.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,39 @@ func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteT
239239
}, nil
240240
}
241241

242+
// Mark a task run as complete and set the nonce for the next task run.
243+
func (s *Coordinator) DoneTask(ctx context.Context, req *coordinatorpb.DoneTaskRequest) (*coordinatorpb.DoneTaskResponse, error) {
244+
if req.TaskId == nil {
245+
log.Error("DoneTask: task_id is required")
246+
return nil, status.Errorf(codes.InvalidArgument, "task_id is required")
247+
}
248+
249+
if req.TaskRunNonce == nil {
250+
log.Error("DoneTask: task_run_nonce is required")
251+
return nil, status.Errorf(codes.InvalidArgument, "task_run_nonce is required")
252+
}
253+
254+
taskID, err := uuid.Parse(*req.TaskId)
255+
if err != nil {
256+
log.Error("DoneTask: invalid task_id", zap.Error(err))
257+
return nil, status.Errorf(codes.InvalidArgument, "invalid task_id: %v", err)
258+
}
259+
260+
taskRunNonce, err := uuid.Parse(*req.TaskRunNonce)
261+
if err != nil {
262+
log.Error("DoneTask: invalid task_run_nonce", zap.Error(err))
263+
return nil, status.Errorf(codes.InvalidArgument, "invalid task_run_nonce: %v", err)
264+
}
265+
266+
err = s.catalog.metaDomain.TaskDb(ctx).DoneTask(taskID, taskRunNonce)
267+
if err != nil {
268+
log.Error("DoneTask failed", zap.Error(err), zap.String("task_id", taskID.String()))
269+
return nil, err
270+
}
271+
272+
return &coordinatorpb.DoneTaskResponse{}, nil
273+
}
274+
242275
// GetOperators retrieves all operators from the database
243276
func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
244277
operators, err := s.catalog.metaDomain.OperatorDb(ctx).GetAll()
@@ -263,6 +296,8 @@ func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOp
263296
}, nil
264297
}
265298

299+
// PeekScheduleByCollectionId gives, for a vector of collection IDs, a vector of schedule entries,
300+
// including when to run and the nonce to use for said run.
266301
func (s *Coordinator) PeekScheduleByCollectionId(ctx context.Context, req *coordinatorpb.PeekScheduleByCollectionIdRequest) (*coordinatorpb.PeekScheduleByCollectionIdResponse, error) {
267302
tasks, err := s.catalog.metaDomain.TaskDb(ctx).PeekScheduleByCollectionId(req.CollectionId)
268303
if err != nil {

go/pkg/sysdb/grpc/task_service.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRe
4444
return res, nil
4545
}
4646

47+
func (s *Server) DoneTask(ctx context.Context, req *coordinatorpb.DoneTaskRequest) (*coordinatorpb.DoneTaskResponse, error) {
48+
log.Info("DoneTask", zap.String("collection_id", req.GetCollectionId()), zap.String("task_id", req.GetTaskId()))
49+
50+
res, err := s.coordinator.DoneTask(ctx, req)
51+
if err != nil {
52+
log.Error("DoneTask failed", zap.Error(err))
53+
return nil, err
54+
}
55+
56+
return res, nil
57+
}
58+
4759
func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
4860
log.Info("GetOperators")
4961

go/pkg/sysdb/metastore/db/dao/task.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package dao
22

33
import (
44
"errors"
5+
"time"
56

67
"github.com/chroma-core/chroma/go/pkg/common"
78
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
9+
"github.com/google/uuid"
810
"github.com/jackc/pgx/v5/pgconn"
911
"github.com/pingcap/log"
1012
"go.uber.org/zap"
@@ -58,6 +60,55 @@ func (s *taskDb) GetByName(inputCollectionID string, taskName string) (*dbmodel.
5860
return &task, nil
5961
}
6062

63+
func (s *taskDb) GetByID(taskID uuid.UUID) (*dbmodel.Task, error) {
64+
var task dbmodel.Task
65+
err := s.db.
66+
Where("task_id = ?", taskID).
67+
Where("is_deleted = ?", false).
68+
First(&task).Error
69+
70+
if err != nil {
71+
if errors.Is(err, gorm.ErrRecordNotFound) {
72+
return nil, nil
73+
}
74+
log.Error("GetByID failed", zap.Error(err), zap.String("task_id", taskID.String()))
75+
return nil, err
76+
}
77+
return &task, nil
78+
}
79+
80+
func (s *taskDb) DoneTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error {
81+
nextNonce, err := uuid.NewV7()
82+
if err != nil {
83+
log.Error("DoneTask: failed to generate next nonce", zap.Error(err))
84+
return err
85+
}
86+
87+
now := time.Now()
88+
result := s.db.Exec(`
89+
UPDATE tasks
90+
SET last_run = ?,
91+
next_nonce = ?,
92+
current_attempts = 0,
93+
updated_at = ?
94+
WHERE task_id = ?
95+
AND next_nonce = ?
96+
AND is_deleted = false
97+
`, now, nextNonce, now, taskID, taskRunNonce)
98+
99+
if result.Error != nil {
100+
log.Error("DoneTask failed", zap.Error(result.Error), zap.String("task_id", taskID.String()))
101+
return result.Error
102+
}
103+
104+
if result.RowsAffected == 0 {
105+
log.Warn("DoneTask: no rows affected", zap.String("task_id", taskID.String()), zap.String("task_run_nonce", taskRunNonce.String()))
106+
return common.ErrTaskNotFound
107+
}
108+
109+
return nil
110+
}
111+
61112
func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error {
62113
// Update task name and is_deleted in a single query
63114
// Format: _deleted_<original_name>_<input_collection_id>_<task_id>

go/pkg/sysdb/metastore/db/dao/task_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,144 @@ func (suite *TaskDbTestSuite) TestTaskDb_DeleteAll() {
312312
}
313313
}
314314

315+
func (suite *TaskDbTestSuite) TestTaskDb_GetByID() {
316+
taskID := uuid.New()
317+
operatorID := dbmodel.OperatorRecordCounter
318+
nextNonce, _ := uuid.NewV7()
319+
320+
task := &dbmodel.Task{
321+
ID: taskID,
322+
Name: "test-get-by-id-task",
323+
OperatorID: operatorID,
324+
InputCollectionID: "input_col_id",
325+
OutputCollectionName: "output_col_name",
326+
OperatorParams: "{}",
327+
TenantID: "tenant1",
328+
DatabaseID: "db1",
329+
MinRecordsForTask: 100,
330+
NextNonce: nextNonce,
331+
}
332+
333+
err := suite.Db.Insert(task)
334+
suite.Require().NoError(err)
335+
336+
retrieved, err := suite.Db.GetByID(taskID)
337+
suite.Require().NoError(err)
338+
suite.Require().NotNil(retrieved)
339+
suite.Require().Equal(task.ID, retrieved.ID)
340+
suite.Require().Equal(task.Name, retrieved.Name)
341+
suite.Require().Equal(task.OperatorID, retrieved.OperatorID)
342+
343+
suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
344+
}
345+
346+
func (suite *TaskDbTestSuite) TestTaskDb_GetByID_NotFound() {
347+
retrieved, err := suite.Db.GetByID(uuid.New())
348+
suite.Require().NoError(err)
349+
suite.Require().Nil(retrieved)
350+
}
351+
352+
func (suite *TaskDbTestSuite) TestTaskDb_GetByID_IgnoresDeleted() {
353+
taskID := uuid.New()
354+
operatorID := dbmodel.OperatorRecordCounter
355+
nextNonce, _ := uuid.NewV7()
356+
357+
task := &dbmodel.Task{
358+
ID: taskID,
359+
Name: "test-get-by-id-deleted",
360+
OperatorID: operatorID,
361+
InputCollectionID: "input1",
362+
OutputCollectionName: "output1",
363+
OperatorParams: "{}",
364+
TenantID: "tenant1",
365+
DatabaseID: "db1",
366+
MinRecordsForTask: 100,
367+
NextNonce: nextNonce,
368+
}
369+
370+
err := suite.Db.Insert(task)
371+
suite.Require().NoError(err)
372+
373+
err = suite.Db.SoftDelete("input1", "test-get-by-id-deleted")
374+
suite.Require().NoError(err)
375+
376+
retrieved, err := suite.Db.GetByID(taskID)
377+
suite.Require().NoError(err)
378+
suite.Require().Nil(retrieved)
379+
380+
suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
381+
}
382+
383+
func (suite *TaskDbTestSuite) TestTaskDb_DoneTask() {
384+
taskID := uuid.New()
385+
operatorID := dbmodel.OperatorRecordCounter
386+
originalNonce, _ := uuid.NewV7()
387+
388+
task := &dbmodel.Task{
389+
ID: taskID,
390+
Name: "test-done-task",
391+
OperatorID: operatorID,
392+
InputCollectionID: "input_col_id",
393+
OutputCollectionName: "output_col_name",
394+
OperatorParams: "{}",
395+
TenantID: "tenant1",
396+
DatabaseID: "db1",
397+
MinRecordsForTask: 100,
398+
NextNonce: originalNonce,
399+
CurrentAttempts: 3,
400+
}
401+
402+
err := suite.Db.Insert(task)
403+
suite.Require().NoError(err)
404+
405+
err = suite.Db.DoneTask(taskID, originalNonce)
406+
suite.Require().NoError(err)
407+
408+
retrieved, err := suite.Db.GetByID(taskID)
409+
suite.Require().NoError(err)
410+
suite.Require().NotNil(retrieved)
411+
suite.Require().NotEqual(originalNonce, retrieved.NextNonce)
412+
suite.Require().NotNil(retrieved.LastRun)
413+
suite.Require().Equal(int32(0), retrieved.CurrentAttempts)
414+
415+
suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
416+
}
417+
418+
func (suite *TaskDbTestSuite) TestTaskDb_DoneTask_InvalidNonce() {
419+
taskID := uuid.New()
420+
operatorID := dbmodel.OperatorRecordCounter
421+
correctNonce, _ := uuid.NewV7()
422+
wrongNonce, _ := uuid.NewV7()
423+
424+
task := &dbmodel.Task{
425+
ID: taskID,
426+
Name: "test-done-task-wrong-nonce",
427+
OperatorID: operatorID,
428+
InputCollectionID: "input_col_id",
429+
OutputCollectionName: "output_col_name",
430+
OperatorParams: "{}",
431+
TenantID: "tenant1",
432+
DatabaseID: "db1",
433+
MinRecordsForTask: 100,
434+
NextNonce: correctNonce,
435+
}
436+
437+
err := suite.Db.Insert(task)
438+
suite.Require().NoError(err)
439+
440+
err = suite.Db.DoneTask(taskID, wrongNonce)
441+
suite.Require().Error(err)
442+
suite.Require().Equal(common.ErrTaskNotFound, err)
443+
444+
suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
445+
}
446+
447+
func (suite *TaskDbTestSuite) TestTaskDb_DoneTask_NotFound() {
448+
err := suite.Db.DoneTask(uuid.New(), uuid.Must(uuid.NewV7()))
449+
suite.Require().Error(err)
450+
suite.Require().Equal(common.ErrTaskNotFound, err)
451+
}
452+
315453
// TestOperatorConstantsMatchSeededDatabase verifies that operator constants in
316454
// dbmodel/constants.go match what we seed in the test database (which should match migrations).
317455
// This catches drift between constants and migrations at test time.

go/pkg/sysdb/metastore/db/dbmodel/task.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ func (v Task) TableName() string {
3838
type ITaskDb interface {
3939
Insert(task *Task) error
4040
GetByName(inputCollectionID string, taskName string) (*Task, error)
41+
GetByID(taskID uuid.UUID) (*Task, error)
42+
DoneTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error
4143
SoftDelete(inputCollectionID string, taskName string) error
4244
DeleteAll() error
4345
PeekScheduleByCollectionId(collectionIDs []string) ([]*Task, error)

idl/chromadb/proto/coordinator.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,14 @@ message DeleteTaskResponse {
566566
bool success = 1;
567567
}
568568

569+
// DoneTaskRequest asks SysDB to mark one run of a task as complete.
message DoneTaskRequest {
570+
// Collection identifier. The Go gRPC handler in this change only logs it.
optional string collection_id = 1;
571+
// Task UUID in string form. The coordinator rejects the request with
// InvalidArgument when it is unset or does not parse as a UUID.
optional string task_id = 2;
572+
// Nonce (UUID string) of the run being completed. Required and validated
// like task_id; the update only applies when it equals the task's stored
// next_nonce.
optional string task_run_nonce = 3;
573+
}
574+
575+
// DoneTaskResponse is the reply to DoneTask; it carries no fields.
message DoneTaskResponse {}
576+
569577
message Operator {
570578
string id = 1;
571579
string name = 2;
@@ -634,6 +642,7 @@ service SysDB {
634642
rpc CreateTask(CreateTaskRequest) returns (CreateTaskResponse) {}
635643
rpc GetTaskByName(GetTaskByNameRequest) returns (GetTaskByNameResponse) {}
636644
rpc DeleteTask(DeleteTaskRequest) returns (DeleteTaskResponse) {}
645+
rpc DoneTask(DoneTaskRequest) returns (DoneTaskResponse) {}
637646
rpc GetOperators(GetOperatorsRequest) returns (GetOperatorsResponse) {}
638647
rpc PeekScheduleByCollectionId(PeekScheduleByCollectionIdRequest) returns (PeekScheduleByCollectionIdResponse) {}
639648
}

0 commit comments

Comments
 (0)