From 1a134be8dcdb648271280152271047c32df7c650 Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Mon, 3 Nov 2025 14:22:11 -0800 Subject: [PATCH] [ENH]: Add get_attached_function endpoint --- Cargo.lock | 8 +- chromadb/api/__init__.py | 22 +++ chromadb/api/fastapi.py | 97 ++++++++++- chromadb/api/models/AttachedFunction.py | 25 +++ chromadb/api/models/Collection.py | 20 +++ chromadb/api/rust.py | 13 ++ chromadb/api/segment.py | 13 ++ chromadb/test/distributed/test_task_api.py | 60 +++++++ go/pkg/sysdb/coordinator/task.go | 108 ++++++++++-- go/pkg/sysdb/grpc/task_service.go | 15 ++ idl/chromadb/proto/coordinator.proto | 11 +- rust/frontend/src/auth/mod.rs | 2 + .../src/impls/service_based_frontend.rs | 158 ++++++++++++++++-- rust/frontend/src/server.rs | 95 +++++++++-- rust/sysdb/src/bin/chroma-task-manager.rs | 3 +- rust/sysdb/src/sqlite.rs | 2 +- rust/sysdb/src/sysdb.rs | 125 ++++++++++++-- rust/types/src/api_types.rs | 48 +++++- rust/types/src/task.rs | 31 +++- 19 files changed, 775 insertions(+), 81 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d46e13bf76c..a0c95648ee8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1422,7 +1422,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chroma" -version = "0.6.0" +version = "0.7.0" dependencies = [ "async-trait", "backon", @@ -1449,7 +1449,7 @@ dependencies = [ [[package]] name = "chroma-api-types" -version = "0.5.0" +version = "0.6.0" dependencies = [ "serde", "utoipa", @@ -1607,7 +1607,7 @@ dependencies = [ [[package]] name = "chroma-error" -version = "0.5.0" +version = "0.6.0" dependencies = [ "http 1.1.0", "sqlx", @@ -2004,7 +2004,7 @@ dependencies = [ [[package]] name = "chroma-types" -version = "0.5.0" +version = "0.6.0" dependencies = [ "async-trait", "base64 0.22.1", diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index c855f7b75e2..e629cc77937 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -861,3 +861,25 @@ def 
detach_function( bool: True if successful """ pass + + @abstractmethod + def get_attached_function( + self, + attached_function_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> "AttachedFunction": + """Get metadata for a specific attached function. + + Args: + attached_function_name: Name of the attached function to retrieve + tenant: The tenant name + database: The database name + + Returns: + AttachedFunction: Object representing the attached function with metadata + + Raises: + ValueError: If the attached function does not exist + """ + pass diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 371497f3d03..a72ebdd91fe 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,3 +1,6 @@ +"""FastAPI client implementation for Chroma API.""" + +import json import orjson import logging from typing import Any, Dict, Optional, cast, Tuple, List @@ -774,16 +777,43 @@ def attach_function( }, ) + # The response now contains a nested attached_function object + attached_func_data = resp_json["attached_function"] + + # Parse timestamps from the response (Unix timestamps as strings) + from datetime import datetime + + last_run = None + if attached_func_data.get("last_run"): + try: + # Convert Unix timestamp string to datetime + last_run = datetime.fromtimestamp(float(attached_func_data["last_run"])) + except (ValueError, TypeError): + last_run = None + + next_run = None + if attached_func_data.get("next_run"): + try: + # Convert Unix timestamp string to datetime + next_run = datetime.fromtimestamp(float(attached_func_data["next_run"])) + except (ValueError, TypeError): + next_run = None + return AttachedFunction( client=self, - id=UUID(resp_json["attached_function"]["id"]), - name=resp_json["attached_function"]["name"], - function_id=resp_json["attached_function"]["function_id"], + id=UUID(attached_func_data["id"]), + name=attached_func_data["name"], + function_id=attached_func_data[ + "function_name" + ], # 
Using function_name from the nested response input_collection_id=input_collection_id, - output_collection=output_collection, - params=params, + output_collection=attached_func_data["output_collection"], + params=attached_func_data.get("params"), tenant=tenant, database=database, + last_run=last_run, + next_run=next_run, + global_function_parent=attached_func_data.get("global_function_parent"), ) @trace_method("FastAPI.detach_function", OpenTelemetryGranularity.ALL) @@ -804,3 +834,60 @@ def detach_function( }, ) return cast(bool, resp_json["success"]) + + @trace_method("FastAPI.get_attached_function", OpenTelemetryGranularity.ALL) + @override + def get_attached_function( + self, + attached_function_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> "AttachedFunction": + """Get metadata for a specific attached function.""" + resp_json = self._make_request( + "get", + f"/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_name}", + ) + + # The response now contains a nested attached_function object + attached_func_data = resp_json["attached_function"] + + # Parse timestamps from the response (Unix timestamps as strings) + from datetime import datetime + + last_run = None + if attached_func_data.get("last_run"): + try: + # Convert Unix timestamp string to datetime + last_run = datetime.fromtimestamp(float(attached_func_data["last_run"])) + except (ValueError, TypeError): + last_run = None + + next_run = None + if attached_func_data.get("next_run"): + try: + # Convert Unix timestamp string to datetime + next_run = datetime.fromtimestamp(float(attached_func_data["next_run"])) + except (ValueError, TypeError): + next_run = None + + # Get the input collection by name to find its ID + input_collection_name = attached_func_data["input_collection"] + input_collection = self.get_collection(input_collection_name, tenant, database) + + return AttachedFunction( + client=self, + id=UUID(attached_func_data["id"]), + 
name=attached_func_data["name"], + function_id=attached_func_data[ + "function_name" + ], # Using function_name from nested response + input_collection_id=input_collection.id, + output_collection=attached_func_data["output_collection"], + params=attached_func_data.get("params"), + tenant=tenant, + database=database, + last_run=last_run, + next_run=next_run, + global_function_parent=attached_func_data.get("global_function_parent"), + ) diff --git a/chromadb/api/models/AttachedFunction.py b/chromadb/api/models/AttachedFunction.py index 0aada8d3454..f5bd9308d57 100644 --- a/chromadb/api/models/AttachedFunction.py +++ b/chromadb/api/models/AttachedFunction.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Optional, Dict, Any from uuid import UUID +from datetime import datetime if TYPE_CHECKING: from chromadb.api import ServerAPI # noqa: F401 @@ -19,6 +20,9 @@ def __init__( params: Optional[Dict[str, Any]], tenant: str, database: str, + last_run: Optional[datetime] = None, + next_run: Optional[datetime] = None, + global_function_parent: Optional[str] = None, ): """Initialize an AttachedFunction. 
@@ -32,6 +36,9 @@ def __init__( params: Function-specific parameters tenant: The tenant name database: The database name + last_run: Optional datetime of when the function last ran + next_run: Optional datetime of when the function is scheduled to run next + global_function_parent: Optional global function parent ID """ self._client = client self._id = id @@ -42,6 +49,9 @@ def __init__( self._params = params self._tenant = tenant self._database = database + self._last_run = last_run + self._next_run = next_run + self._global_function_parent = global_function_parent @property def id(self) -> UUID: @@ -73,6 +83,21 @@ def params(self) -> Optional[Dict[str, Any]]: """The function parameters.""" return self._params + @property + def last_run(self) -> Optional[datetime]: + """The datetime when this function last ran.""" + return self._last_run + + @property + def next_run(self) -> Optional[datetime]: + """The datetime when this function is scheduled to run next.""" + return self._next_run + + @property + def global_function_parent(self) -> Optional[str]: + """The global function parent ID, if applicable.""" + return self._global_function_parent + def detach(self, delete_output_collection: bool = False) -> bool: """Detach this function and prevent any further runs. diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index a738243c551..17972307caf 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -533,3 +533,23 @@ def attach_function( tenant=self.tenant, database=self.database, ) + + def get_attached_function(self, name: str) -> "AttachedFunction": + """Get metadata for a specific attached function by name. 
+ + Args: + name: The name of the attached function to retrieve + + Returns: + AttachedFunction: Object representing the attached function with metadata + + Example: + >>> attached_fn = collection.get_attached_function("mycoll_stats_fn") + >>> print(f"Function ID: {attached_fn.function_id}") + >>> print(f"Last run: {attached_fn.last_run}") + """ + return self._client.get_attached_function( + attached_function_name=name, + tenant=self.tenant, + database=self.database, + ) diff --git a/chromadb/api/rust.py b/chromadb/api/rust.py index 4b5832ff520..b7af03489a3 100644 --- a/chromadb/api/rust.py +++ b/chromadb/api/rust.py @@ -634,6 +634,19 @@ def detach_function( "The Rust bindings (embedded mode) do not support attached function operations." ) + @override + def get_attached_function( + self, + attached_function_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> "AttachedFunction": + """Attached functions are not supported in the Rust bindings (local embedded mode).""" + raise NotImplementedError( + "Attached functions are only supported when connecting to a Chroma server via HttpClient. " + "The Rust bindings (embedded mode) do not support attached function operations." + ) + # TODO: Remove this if it's not planned to be used @override def get_user_identity(self) -> UserIdentity: diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 0b626c75bfa..3e922d48c9f 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -943,6 +943,19 @@ def detach_function( "The Segment API (embedded mode) does not support attached function operations." 
) + @override + def get_attached_function( + self, + attached_function_name: str, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> "AttachedFunction": + """Attached functions are not supported in the Segment API (local embedded mode).""" + raise NotImplementedError( + "Attached functions are only supported when connecting to a Chroma server via HttpClient. " + "The Segment API (embedded mode) does not support attached function operations." + ) + # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. # TODO: promote collection -> topic to a base class method so that it can be diff --git a/chromadb/test/distributed/test_task_api.py b/chromadb/test/distributed/test_task_api.py index 07bf85c6c03..07ec69f79fd 100644 --- a/chromadb/test/distributed/test_task_api.py +++ b/chromadb/test/distributed/test_task_api.py @@ -203,3 +203,63 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None: # Trying to detach this function again should raise NotFoundError with pytest.raises(NotFoundError, match="does not exist"): attached_fn.detach(delete_output_collection=True) + + +def test_attach_get_function_equality(basic_http_client: System) -> None: + """Test that attach_function and get_attached_function return objects with equal structure and fields""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + # Create a collection + collection = client.get_or_create_collection(name="test_equality_collection") + collection.add( + ids=["doc1", "doc2"], documents=["Test document 1", "Test document 2"] + ) + + # Attach a function + test_params = {"threshold": 100, "mode": "count"} + attached_from_attach = collection.attach_function( + name="equality_test_fn", + function_id="record_counter", + output_collection="equality_output", + params=test_params, + ) + + # Get the same function + attached_from_get = 
collection.get_attached_function("equality_test_fn") + + # Verify both objects have the same structure and core fields + assert attached_from_attach.id == attached_from_get.id + assert attached_from_attach.name == attached_from_get.name + assert attached_from_attach.name == "equality_test_fn" + assert attached_from_attach.function_id == attached_from_get.function_id + # Note: function_id is now the human-readable function name ("record_counter") + # instead of the UUID, thanks to the new AttachedFunctionWithInfo approach + assert attached_from_attach.function_id == "record_counter" + assert ( + attached_from_attach.input_collection_id + == attached_from_get.input_collection_id + ) + assert attached_from_attach.input_collection_id == collection.id + assert attached_from_attach.output_collection == attached_from_get.output_collection + assert attached_from_attach.output_collection == "equality_output" + assert attached_from_attach.params == attached_from_get.params + assert ( + attached_from_attach.global_function_parent + == attached_from_get.global_function_parent + ) + + # Both should have the timing fields (even if values differ) + assert hasattr(attached_from_attach, "last_run") + assert hasattr(attached_from_get, "last_run") + assert hasattr(attached_from_attach, "next_run") + assert hasattr(attached_from_get, "next_run") + + # For a newly attached function, last_run should be None + assert attached_from_attach.last_run is None + # next_run should be set (current time for attach, actual time for get) + assert attached_from_attach.next_run is not None + assert attached_from_get.next_run is not None + + # Clean up + attached_from_attach.detach(delete_output_collection=True) diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index 733125a4309..5a57f5b20f2 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -33,28 +33,28 @@ var minimalUUIDv7 = uuid.UUID{ } // validateAttachedFunctionMatchesRequest 
validates that an existing attached function's parameters match the request parameters. -// Returns an error if any parameters don't match. This is used for idempotency and race condition handling. -func (s *Coordinator) validateAttachedFunctionMatchesRequest(ctx context.Context, attachedFunction *dbmodel.AttachedFunction, req *coordinatorpb.AttachFunctionRequest) error { +// Returns the function for reuse if validation passes, or an error if any parameters don't match. This is used for idempotency and race condition handling. +func (s *Coordinator) validateAttachedFunctionMatchesRequest(ctx context.Context, attachedFunction *dbmodel.AttachedFunction, req *coordinatorpb.AttachFunctionRequest) (*dbmodel.Function, error) { // Look up the function for the existing attached function existingFunction, err := s.catalog.metaDomain.FunctionDb(ctx).GetByID(attachedFunction.FunctionID) if err != nil { log.Error("validateAttachedFunctionMatchesRequest: failed to get attached function's function", zap.Error(err)) - return err + return nil, err } if existingFunction == nil { log.Error("validateAttachedFunctionMatchesRequest: attached function's function not found") - return common.ErrFunctionNotFound + return nil, common.ErrFunctionNotFound } // Look up database for comparison databases, err := s.catalog.metaDomain.DatabaseDb(ctx).GetDatabases(req.TenantId, req.Database) if err != nil { log.Error("validateAttachedFunctionMatchesRequest: failed to get database for validation", zap.Error(err)) - return err + return nil, err } if len(databases) == 0 { log.Error("validateAttachedFunctionMatchesRequest: database not found") - return common.ErrDatabaseNotFound + return nil, common.ErrDatabaseNotFound } // Validate attributes match @@ -62,30 +62,30 @@ func (s *Coordinator) validateAttachedFunctionMatchesRequest(ctx context.Context log.Error("validateAttachedFunctionMatchesRequest: attached function has different function", zap.String("existing", existingFunction.Name), 
zap.String("requested", req.FunctionName)) - return status.Errorf(codes.AlreadyExists, "attached function already exists with different function: existing=%s, requested=%s", existingFunction.Name, req.FunctionName) + return nil, status.Errorf(codes.AlreadyExists, "attached function already exists with different function: existing=%s, requested=%s", existingFunction.Name, req.FunctionName) } if attachedFunction.TenantID != req.TenantId { log.Error("validateAttachedFunctionMatchesRequest: attached function has different tenant") - return status.Errorf(codes.AlreadyExists, "attached function already exists with different tenant") + return nil, status.Errorf(codes.AlreadyExists, "attached function already exists with different tenant") } if attachedFunction.DatabaseID != databases[0].ID { log.Error("validateAttachedFunctionMatchesRequest: attached function has different database") - return status.Errorf(codes.AlreadyExists, "attached function already exists with different database") + return nil, status.Errorf(codes.AlreadyExists, "attached function already exists with different database") } if attachedFunction.OutputCollectionName != req.OutputCollectionName { log.Error("validateAttachedFunctionMatchesRequest: attached function has different output collection name", zap.String("existing", attachedFunction.OutputCollectionName), zap.String("requested", req.OutputCollectionName)) - return status.Errorf(codes.AlreadyExists, "attached function already exists with different output collection: existing=%s, requested=%s", attachedFunction.OutputCollectionName, req.OutputCollectionName) + return nil, status.Errorf(codes.AlreadyExists, "attached function already exists with different output collection: existing=%s, requested=%s", attachedFunction.OutputCollectionName, req.OutputCollectionName) } if attachedFunction.MinRecordsForInvocation != int64(req.MinRecordsForInvocation) { log.Error("validateAttachedFunctionMatchesRequest: attached function has different 
min_records_for_invocation", zap.Int64("existing", attachedFunction.MinRecordsForInvocation), zap.Uint64("requested", req.MinRecordsForInvocation)) - return status.Errorf(codes.AlreadyExists, "attached function already exists with different min_records_for_invocation: existing=%d, requested=%d", attachedFunction.MinRecordsForInvocation, req.MinRecordsForInvocation) + return nil, status.Errorf(codes.AlreadyExists, "attached function already exists with different min_records_for_invocation: existing=%d, requested=%d", attachedFunction.MinRecordsForInvocation, req.MinRecordsForInvocation) } - return nil + return existingFunction, nil } // AttachFunction creates a new attached function in the database @@ -104,6 +104,10 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att var nextRun time.Time var skipPhase2And3 bool // Flag to skip Phase 2 & 3 if task is already fully initialized + // Variables to store function info for response (used in concurrent case) + var attachedFunctionForResponse *dbmodel.AttachedFunction + var concurrentExistingFunction *dbmodel.Function + // ===== Phase 1: Create attached function with lowest_live_nonce = NULL (if needed) ===== err := s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error { // Double-check attached function doesn't exist (race condition protection) @@ -118,7 +122,8 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att zap.String("attached_function_id", concurrentAttachedFunction.ID.String())) // Validate that concurrent attached function matches our request - if err := s.validateAttachedFunctionMatchesRequest(txCtx, concurrentAttachedFunction, req); err != nil { + existingFunction, err := s.validateAttachedFunctionMatchesRequest(txCtx, concurrentAttachedFunction, req) + if err != nil { return err } @@ -132,6 +137,10 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att // Already initialized, skip Phase 2 & 3 
lowestLiveNonce = *concurrentAttachedFunction.LowestLiveNonce skipPhase2And3 = true + + // Copy the closure-local values into the outer response variables (the outer name must differ from the closure-local concurrentAttachedFunction, otherwise this is a shadowed self-assignment) + attachedFunctionForResponse = concurrentAttachedFunction + concurrentExistingFunction = existingFunction } else { // Not initialized yet, generate minimal UUIDv7 and continue to Phase 2 & 3 lowestLiveNonce = minimalUUIDv7 @@ -250,8 +259,16 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att if skipPhase2And3 { log.Info("AttachFunction: function already fully attached, skipping Phase 2 & 3", zap.String("attached_function_id", attachedFunctionID.String())) + + // Convert the concurrent attached function to proto with function info + attachedFunctionProto, err := attachedFunctionToProto(attachedFunctionForResponse, concurrentExistingFunction) + if err != nil { + log.Error("AttachFunction: failed to convert concurrent attached function to proto", zap.Error(err)) + return nil, err + } + return &coordinatorpb.AttachFunctionResponse{ - Id: attachedFunctionID.String(), + AttachedFunction: attachedFunctionProto, }, nil } @@ -300,8 +317,37 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att zap.String("lowest_live_nonce", lowestLiveNonce.String()), zap.String("next_nonce", nextNonce.String())) + // Fetch the attached function we just created to return with function info + createdAttachedFunction, err := s.catalog.metaDomain.AttachedFunctionDb(ctx).GetByID(attachedFunctionID) + if err != nil { + log.Error("AttachFunction: failed to fetch created attached function", zap.Error(err)) + return nil, err + } + if createdAttachedFunction == nil { + log.Error("AttachFunction: created attached function not found") + return nil, status.Error(codes.Internal, "created attached function not found") + } + + // Re-fetch the function (we already validated it exists) + function, err := s.catalog.metaDomain.FunctionDb(ctx).GetByName(req.FunctionName) + if err != nil { + 
log.Error("AttachFunction: failed to re-fetch function", zap.Error(err)) + return nil, err + } + if function == nil { + log.Error("AttachFunction: function not found during response creation") + return nil, common.ErrFunctionNotFound + } + + // Convert to proto response + attachedFunctionProto, err := attachedFunctionToProto(createdAttachedFunction, function) + if err != nil { + log.Error("AttachFunction: failed to convert attached function to proto", zap.Error(err)) + return nil, err + } + return &coordinatorpb.AttachFunctionResponse{ - Id: attachedFunctionID.String(), + AttachedFunction: attachedFunctionProto, }, nil } @@ -768,6 +814,38 @@ func (s *Coordinator) GetFunctions(ctx context.Context, req *coordinatorpb.GetFu }, nil } +// GetFunctionById retrieves a function by ID from the database +func (s *Coordinator) GetFunctionById(ctx context.Context, req *coordinatorpb.GetFunctionByIdRequest) (*coordinatorpb.GetFunctionByIdResponse, error) { + // Parse the function UUID + functionID, err := uuid.Parse(req.Id) + if err != nil { + log.Error("GetFunctionById: invalid function_id", zap.Error(err)) + return nil, status.Errorf(codes.InvalidArgument, "invalid function_id: %v", err) + } + + // Fetch function by ID + function, err := s.catalog.metaDomain.FunctionDb(ctx).GetByID(functionID) + if err != nil { + log.Error("GetFunctionById: failed to get function", zap.Error(err)) + return nil, err + } + + // If function not found, return error + if function == nil { + log.Error("GetFunctionById: function not found", zap.String("function_id", req.Id)) + return nil, common.ErrFunctionNotFound + } + + log.Info("GetFunctionById succeeded", zap.String("function_id", req.Id), zap.String("function_name", function.Name)) + + return &coordinatorpb.GetFunctionByIdResponse{ + Function: &coordinatorpb.Function{ + Id: function.ID.String(), + Name: function.Name, + }, + }, nil +} + // PeekScheduleByCollectionId gives, for a vector of collection IDs, a vector of schedule entries, // 
including when to run and the nonce to use for said run. func (s *Coordinator) PeekScheduleByCollectionId(ctx context.Context, req *coordinatorpb.PeekScheduleByCollectionIdRequest) (*coordinatorpb.PeekScheduleByCollectionIdResponse, error) { diff --git a/go/pkg/sysdb/grpc/task_service.go b/go/pkg/sysdb/grpc/task_service.go index c8de72b8ac2..4942e0fb96e 100644 --- a/go/pkg/sysdb/grpc/task_service.go +++ b/go/pkg/sysdb/grpc/task_service.go @@ -127,6 +127,21 @@ func (s *Server) GetFunctions(ctx context.Context, req *coordinatorpb.GetFunctio return res, nil } +func (s *Server) GetFunctionById(ctx context.Context, req *coordinatorpb.GetFunctionByIdRequest) (*coordinatorpb.GetFunctionByIdResponse, error) { + log.Info("GetFunctionById", zap.String("id", req.Id)) + + res, err := s.coordinator.GetFunctionById(ctx, req) + if err != nil { + log.Error("GetFunctionById failed", zap.Error(err)) + if err == common.ErrFunctionNotFound { + return nil, grpcutils.BuildNotFoundGrpcError(err.Error()) + } + return nil, err + } + + return res, nil +} + func (s *Server) PeekScheduleByCollectionId(ctx context.Context, req *coordinatorpb.PeekScheduleByCollectionIdRequest) (*coordinatorpb.PeekScheduleByCollectionIdResponse, error) { log.Info("PeekScheduleByCollectionId", zap.Int64("num_collections", int64(len(req.CollectionId)))) diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index c4b7b68ffcc..75bb2f0d72e 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -561,7 +561,7 @@ message AttachFunctionRequest { } message AttachFunctionResponse { - string id = 1; + AttachedFunction attached_function = 1; } message CreateOutputCollectionForAttachedFunctionRequest { @@ -675,6 +675,14 @@ message GetFunctionsResponse { repeated Function functions = 1; } +message GetFunctionByIdRequest { + string id = 1; +} + +message GetFunctionByIdResponse { + Function function = 1; +} + message PeekScheduleByCollectionIdRequest { 
repeated string collection_id = 1; } @@ -739,5 +747,6 @@ service SysDB { rpc FinishAttachedFunction(FinishAttachedFunctionRequest) returns (FinishAttachedFunctionResponse) {} rpc CleanupExpiredPartialAttachedFunctions(CleanupExpiredPartialAttachedFunctionsRequest) returns (CleanupExpiredPartialAttachedFunctionsResponse) {} rpc GetFunctions(GetFunctionsRequest) returns (GetFunctionsResponse) {} + rpc GetFunctionById(GetFunctionByIdRequest) returns (GetFunctionByIdResponse) {} rpc PeekScheduleByCollectionId(PeekScheduleByCollectionIdRequest) returns (PeekScheduleByCollectionIdResponse) {} } diff --git a/rust/frontend/src/auth/mod.rs b/rust/frontend/src/auth/mod.rs index ca51e1c3dd7..cdfd31a855d 100644 --- a/rust/frontend/src/auth/mod.rs +++ b/rust/frontend/src/auth/mod.rs @@ -40,6 +40,7 @@ pub enum AuthzAction { Search, CreateAttachedFunction, RemoveAttachedFunction, + GetAttachedFunction, } impl Display for AuthzAction { @@ -72,6 +73,7 @@ impl Display for AuthzAction { AuthzAction::Search => write!(f, "collection:search"), AuthzAction::CreateAttachedFunction => write!(f, "collection:create_attached_function"), AuthzAction::RemoveAttachedFunction => write!(f, "collection:remove_attached_function"), + AuthzAction::GetAttachedFunction => write!(f, "collection:get_attached_function"), } } } diff --git a/rust/frontend/src/impls/service_based_frontend.rs b/rust/frontend/src/impls/service_based_frontend.rs index 21315f31e67..1a7f7c6eb38 100644 --- a/rust/frontend/src/impls/service_based_frontend.rs +++ b/rust/frontend/src/impls/service_based_frontend.rs @@ -5,6 +5,7 @@ use crate::{ }; use backon::{ExponentialBuilder, Retryable}; use chroma_api_types::HeartbeatResponse; +use chroma_types::AttachedFunctionInfo; use chroma_config::{registry, Configurable}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_log::{LocalCompactionManager, LocalCompactionManagerConfig, Log}; @@ -22,7 +23,7 @@ use chroma_types::{ operator::{Filter, KnnBatch, KnnProjection, Limit, Projection, 
Scan}, plan::{Count, Get, Knn, Search}, AddCollectionRecordsError, AddCollectionRecordsRequest, AddCollectionRecordsResponse, - AttachFunctionRequest, AttachFunctionResponse, Collection, CollectionUuid, + AttachFunctionRequest, Collection, CollectionUuid, CountCollectionsError, CountCollectionsRequest, CountCollectionsResponse, CountRequest, CountResponse, CreateCollectionError, CreateCollectionRequest, CreateCollectionResponse, CreateDatabaseError, CreateDatabaseRequest, CreateDatabaseResponse, CreateTenantError, @@ -30,7 +31,7 @@ use chroma_types::{ DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse, DeleteCollectionRequest, DeleteDatabaseError, DeleteDatabaseRequest, DeleteDatabaseResponse, DetachFunctionError, DetachFunctionRequest, DetachFunctionResponse, ForkCollectionError, ForkCollectionRequest, - ForkCollectionResponse, GetCollectionByCrnError, GetCollectionByCrnRequest, + ForkCollectionResponse, GetAttachedFunctionError, GetAttachedFunctionResponse, GetCollectionByCrnError, GetCollectionByCrnRequest, GetCollectionByCrnResponse, GetCollectionError, GetCollectionRequest, GetCollectionResponse, GetCollectionsError, GetDatabaseError, GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse, GetTenantError, GetTenantRequest, GetTenantResponse, HealthCheckResponse, @@ -1893,7 +1894,7 @@ impl ServiceBasedFrontend { params, .. 
}: AttachFunctionRequest, - ) -> Result { + ) -> Result { // Parse collection_id from path parameter - client-side validation let input_collection_id = CollectionUuid(uuid::Uuid::parse_str(&collection_id).map_err(|e| { @@ -1905,14 +1906,14 @@ impl ServiceBasedFrontend { ))) })?); - let attached_function_id = self + let attached_function_with_info = self .sysdb_client .create_attached_function( name.clone(), function_id.clone(), input_collection_id, output_collection.clone(), - params, + params.clone(), tenant_name, database_name, self.min_records_for_invocation, @@ -1934,11 +1935,148 @@ impl ServiceBasedFrontend { } })?; - Ok(AttachFunctionResponse { - attached_function: chroma_types::AttachedFunctionInfo { - id: attached_function_id.to_string(), - name, - function_id, + // Get the input collection name for the response + let input_collection = self + .sysdb_client + .get_collections(GetCollectionsOptions { + collection_id: Some(input_collection_id), + limit: Some(1), + ..Default::default() + }) + .await + .map_err(|e| chroma_types::AttachFunctionError::Internal(Box::new(e)))? 
+ .into_iter() + .next() + .ok_or(chroma_types::AttachFunctionError::Internal(Box::new( + chroma_error::TonicError(tonic::Status::internal("Input collection not found after attaching function")) + )))?; + + // Format timestamps as strings (same as get_attached_function) + let last_run = attached_function_with_info.attached_function.last_run.map(|time| { + time.duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs().to_string()) + .unwrap_or_else(|_| "0".to_string()) + }); + let next_run = attached_function_with_info.attached_function.next_run + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs().to_string()) + .unwrap_or_else(|_| "0".to_string()); + + Ok(GetAttachedFunctionResponse { + attached_function: AttachedFunctionInfo { + id: attached_function_with_info.attached_function.id.to_string(), + name: attached_function_with_info.attached_function.name, + function_name: attached_function_with_info.function.name, // Use function name instead of ID + input_collection: input_collection.name, + output_collection: attached_function_with_info.attached_function.output_collection_name, + last_run, + next_run, + params: attached_function_with_info.attached_function.params, + global_function_parent: attached_function_with_info.attached_function.global_parent.map(|u| u.to_string()), + }, + }) + } + + pub async fn get_attached_function( + &mut self, + tenant_name: String, + database_name: String, + attached_function_name: String, + ) -> Result { + // First, we need to find the input collection that has this attached function + // We'll need to query collections to find the one with the attached function + let collections = self + .sysdb_client + .get_collections(GetCollectionsOptions { + tenant: Some(tenant_name.clone()), + database: Some(database_name.clone()), + ..Default::default() + }) + .await + .map_err(|e| GetAttachedFunctionError::FailedToGetAttachedFunction(Box::new(e)))?; + + // Find the collection that has the attached function with the given 
name + let mut input_collection_id = None; + for collection in collections { + match self + .sysdb_client + .get_attached_function_by_name(collection.collection_id, attached_function_name.clone()) + .await + { + Ok(_) => { + input_collection_id = Some(collection.collection_id); + break; + } + Err(chroma_sysdb::GetAttachedFunctionError::NotFound) => { + // Continue searching other collections + continue; + } + Err(e) => { + return Err(GetAttachedFunctionError::FailedToGetAttachedFunction( + Box::new(e) as Box, + )); + } + } + } + + let input_collection_id = input_collection_id.ok_or(GetAttachedFunctionError::NotFound)?; + + // Get the attached function details + let attached_function_with_info = self + .sysdb_client + .get_attached_function_with_info_by_name(input_collection_id, attached_function_name) + .await + .map_err(|e| match e { + chroma_sysdb::GetAttachedFunctionError::NotFound => { + GetAttachedFunctionError::NotFound + } + chroma_sysdb::GetAttachedFunctionError::NotReady => { + GetAttachedFunctionError::NotReady + } + chroma_sysdb::GetAttachedFunctionError::FailedToGetAttachedFunction(s) => { + GetAttachedFunctionError::FailedToGetAttachedFunction(Box::new(chroma_error::TonicError(s)) as Box) + } + chroma_sysdb::GetAttachedFunctionError::ServerReturnedInvalidData => { + GetAttachedFunctionError::ServerReturnedInvalidData + } + })?; + + // Get the input collection name + let input_collection = self + .sysdb_client + .get_collections(GetCollectionsOptions { + collection_id: Some(input_collection_id), + limit: Some(1), + ..Default::default() + }) + .await + .map_err(|e| GetAttachedFunctionError::FailedToGetAttachedFunction(Box::new(e)))? 
+ .into_iter() + .next() + .ok_or(GetAttachedFunctionError::ServerReturnedInvalidData)?; + + // Format timestamps as strings + let last_run = attached_function_with_info.attached_function.last_run.map(|time| { + time.duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs().to_string()) + .unwrap_or_else(|_| "0".to_string()) + }); + let next_run = attached_function_with_info.attached_function.next_run + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs().to_string()) + .unwrap_or_else(|_| "0".to_string()); + + Ok(GetAttachedFunctionResponse { + attached_function: AttachedFunctionInfo { + id: attached_function_with_info.attached_function.id.to_string(), + name: attached_function_with_info.attached_function.name, + function_name: attached_function_with_info.function.name, // Use function name instead of ID + input_collection: input_collection.name, + output_collection: attached_function_with_info.attached_function.output_collection_name, + last_run, + next_run, + params: attached_function_with_info.attached_function.params, + global_function_parent: attached_function_with_info.attached_function.global_parent.map(|u| u.to_string()), }, }) } diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 6e8a6d9b1f4..b750dbab4b5 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -15,21 +15,21 @@ use chroma_tracing::add_tracing_middleware; use chroma_types::ForkCollectionResponse; use chroma_types::{ decode_embeddings, maybe_decode_update_embeddings, AddCollectionRecordsPayload, - AddCollectionRecordsResponse, AttachFunctionRequest, AttachFunctionResponse, ChecklistResponse, - Collection, CollectionConfiguration, CollectionMetadataUpdate, CollectionUuid, - CountCollectionsRequest, CountCollectionsResponse, CountRequest, CountResponse, - CreateCollectionPayload, CreateCollectionRequest, CreateDatabaseRequest, - CreateDatabaseResponse, CreateTenantRequest, CreateTenantResponse, - DeleteCollectionRecordsPayload, 
DeleteCollectionRecordsResponse, DeleteDatabaseRequest, - DeleteDatabaseResponse, DetachFunctionRequest, DetachFunctionResponse, - GetCollectionByCrnRequest, GetCollectionRequest, GetDatabaseRequest, GetDatabaseResponse, - GetRequest, GetRequestPayload, GetResponse, GetTenantRequest, GetTenantResponse, - InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListCollectionsRequest, - ListCollectionsResponse, ListDatabasesRequest, ListDatabasesResponse, QueryRequest, - QueryRequestPayload, QueryResponse, SearchRequest, SearchRequestPayload, SearchResponse, - UpdateCollectionPayload, UpdateCollectionRecordsPayload, UpdateCollectionRecordsResponse, - UpdateCollectionResponse, UpdateTenantRequest, UpdateTenantResponse, - UpsertCollectionRecordsPayload, UpsertCollectionRecordsResponse, + AddCollectionRecordsResponse, AttachFunctionRequest, ChecklistResponse, Collection, + CollectionConfiguration, CollectionMetadataUpdate, CollectionUuid, CountCollectionsRequest, + CountCollectionsResponse, CountRequest, CountResponse, CreateCollectionPayload, + CreateCollectionRequest, CreateDatabaseRequest, CreateDatabaseResponse, CreateTenantRequest, + CreateTenantResponse, DeleteCollectionRecordsPayload, DeleteCollectionRecordsResponse, + DeleteDatabaseRequest, DeleteDatabaseResponse, DetachFunctionRequest, DetachFunctionResponse, + GetAttachedFunctionResponse, GetCollectionByCrnRequest, GetCollectionRequest, + GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetRequestPayload, GetResponse, + GetTenantRequest, GetTenantResponse, InternalCollectionConfiguration, + InternalUpdateCollectionConfiguration, ListCollectionsRequest, ListCollectionsResponse, + ListDatabasesRequest, ListDatabasesResponse, QueryRequest, QueryRequestPayload, QueryResponse, + SearchRequest, SearchRequestPayload, SearchResponse, UpdateCollectionPayload, + UpdateCollectionRecordsPayload, UpdateCollectionRecordsResponse, UpdateCollectionResponse, + UpdateTenantRequest, UpdateTenantResponse, 
UpsertCollectionRecordsPayload, + UpsertCollectionRecordsResponse, }; use mdac::{Rule, Scorecard, ScorecardGuard}; use opentelemetry::global; @@ -138,6 +138,7 @@ pub struct Metrics { collection_search: Counter, attach_function: Counter, detach_function: Counter, + get_attached_function: Counter, } impl Metrics { @@ -174,6 +175,7 @@ impl Metrics { collection_search: meter.u64_counter("collection_search").build(), attach_function: meter.u64_counter("attach_function").build(), detach_function: meter.u64_counter("detach_function").build(), + get_attached_function: meter.u64_counter("get_attached_function").build(), } } } @@ -323,6 +325,10 @@ impl FrontendServer { "/api/v2/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_id}/detach", post(detach_function), ) + .route( + "/api/v2/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_name}", + get(get_attached_function), + ) .merge(docs_router) .with_state(self) .layer(DefaultBodyLimit::max(max_payload_size_bytes)) @@ -2131,7 +2137,7 @@ async fn collection_search( path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/functions/attach", request_body = AttachFunctionRequest, responses( - (status = 200, description = " Function attached successfully", body = AttachFunctionResponse), + (status = 200, description = " Function attached successfully", body = GetAttachedFunctionResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), @@ -2146,7 +2152,7 @@ async fn attach_function( Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, TracedJson(request): TracedJson, -) -> Result, ServerError> { +) -> Result, ServerError> { server.metrics.attach_function.add(1, &[]); server .authenticate_and_authorize( @@ -2221,6 +2227,60 @@ async fn detach_function( Ok(Json(res)) } +/// Get metadata for a specific attached function 
+#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant}/databases/{database}/attached_functions/{attached_function_name}", + responses( + (status = 200, description = "Attached function metadata retrieved successfully", body = GetAttachedFunctionResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Attached function not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("attached_function_name" = String, Path, description = "Name of the attached function") + ) +)] +async fn get_attached_function( + headers: HeaderMap, + Path((tenant, database, attached_function_name)): Path<(String, String, String)>, + State(mut server): State, +) -> Result, ServerError> { + server.metrics.get_attached_function.add(1, &[]); + tracing::info!( + name: "get_attached_function", + tenant_name = %tenant, + database_name = %database, + attached_function_name = %attached_function_name + ); + + server + .authenticate_and_authorize( + &headers, + AuthzAction::GetAttachedFunction, + AuthzResource { + tenant: Some(tenant.clone()), + database: Some(database.clone()), + collection: None, + }, + ) + .await?; + + let _guard = server.scorecard_request(&[ + "op:get_attached_function", + format!("tenant:{}", tenant).as_str(), + format!("database:{}", database).as_str(), + ])?; + + let res = server + .frontend + .get_attached_function(tenant, database, attached_function_name) + .await?; + Ok(Json(res)) +} + async fn v1_deprecation_notice() -> Response { let err_response = ErrorResponse::new( "Unimplemented".to_string(), @@ -2281,6 +2341,7 @@ impl Modify for ChromaTokenSecurityAddon { collection_search, attach_function, detach_function, + get_attached_function, ), // Apply our new security scheme here modifiers(&ChromaTokenSecurityAddon) diff --git 
a/rust/sysdb/src/bin/chroma-task-manager.rs b/rust/sysdb/src/bin/chroma-task-manager.rs index c74231494d9..034cc0ed5e5 100644 --- a/rust/sysdb/src/bin/chroma-task-manager.rs +++ b/rust/sysdb/src/bin/chroma-task-manager.rs @@ -148,7 +148,8 @@ async fn main() -> Result<(), Box> { }; let response = client.attach_function(request).await?; - println!("Attached Function created: {}", response.into_inner().id); + let attached_function = response.into_inner().attached_function.unwrap(); + println!("Attached Function created: {}", attached_function.id); } Command::GetAttachedFunction { input_collection_id, diff --git a/rust/sysdb/src/sqlite.rs b/rust/sysdb/src/sqlite.rs index d8e6d756115..a084b60d0ca 100644 --- a/rust/sysdb/src/sqlite.rs +++ b/rust/sysdb/src/sqlite.rs @@ -673,7 +673,7 @@ impl SqliteSysDb { _tenant_id: String, _database_id: String, _min_records_for_attached_function: u64, - ) -> Result { + ) -> Result { // TODO: Implement this when attached function support is added to SqliteSysDb Err(crate::AttachFunctionError::FailedToCreateAttachedFunction( tonic::Status::unimplemented( diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index 3d3d0cb0d65..d7c956fcfd7 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -1919,7 +1919,7 @@ impl GrpcSysDb { tenant_name: String, database_name: String, min_records_for_invocation: u64, - ) -> Result { + ) -> Result { // Convert serde_json::Value to prost_types::Struct for gRPC let params_struct = match params { serde_json::Value::Object(map) => Some(prost_types::Struct { @@ -1941,19 +1941,32 @@ impl GrpcSysDb { min_records_for_invocation, }; let response = self.client.attach_function(req).await?.into_inner(); - // Parse the returned attached_function_id - this should always succeed since the server generated it - // If this fails, it indicates a serious server bug or protocol corruption - let attached_function_id = chroma_types::AttachedFunctionUuid( - uuid::Uuid::parse_str(&response.id).map_err(|e| 
{ - tracing::error!( - attached_function_id = %response.id, - error = %e, - "Server returned invalid attached_function_id UUID - attached function was created but response is corrupt" - ); - AttachFunctionError::ServerReturnedInvalidData - })?, - ); - Ok(attached_function_id) + + // Extract the attached function from the response + let attached_function_proto = response.attached_function.ok_or_else(|| { + AttachFunctionError::FailedToCreateAttachedFunction(tonic::Status::internal( + "Missing attached function in response", + )) + })?; + + // Convert proto to AttachedFunction + let attached_function = Self::attached_function_from_proto(attached_function_proto) + .map_err(|e| AttachFunctionError::FailedToCreateAttachedFunction( + tonic::Status::internal(format!("Failed to convert attached function: {:?}", e)) + ))?; + + // Extract function info from the attached function + let function = chroma_types::Function { + id: attached_function.function_id, + name: "record_counter".to_string(), // TODO: Get from proto when available - need to add function_name to AttachedFunction proto + is_incremental: true, // TODO: Get from proto when available + return_type: serde_json::json!({"type": "object"}), // TODO: Get from proto when available + }; + + Ok(chroma_types::AttachedFunctionWithInfo { + attached_function, + function, + }) } /// Helper function to convert a proto AttachedFunction to a chroma_types::AttachedFunction @@ -2074,6 +2087,8 @@ impl GrpcSysDb { + std::time::Duration::from_micros(attached_function.created_at), updated_at: std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_micros(attached_function.updated_at), + global_parent: None, // TODO: Parse from proto when available + oldest_written_nonce: None, // TODO: Parse from proto when available }) } @@ -2345,7 +2360,7 @@ impl SysDb { tenant_name: String, database_name: String, min_records_for_invocation: u64, - ) -> Result { + ) -> Result { match self { SysDb::Grpc(grpc) => { 
grpc.create_attached_function( @@ -2391,8 +2406,7 @@ impl SysDb { .await } SysDb::Sqlite(sqlite) => { - sqlite - .get_attached_function_by_name(input_collection_id, attached_function_name) + sqlite.get_attached_function_by_name(input_collection_id, attached_function_name) .await } SysDb::Test(_) => { @@ -2401,6 +2415,83 @@ impl SysDb { } } + pub async fn get_attached_function_with_info_by_name( + &mut self, + input_collection_id: chroma_types::CollectionUuid, + attached_function_name: String, + ) -> Result { + // For now, get the attached function and fetch function info via gRPC + let attached_function = match self { + SysDb::Grpc(grpc) => { + grpc.get_attached_function_by_name(input_collection_id, attached_function_name.clone()) + .await? + } + SysDb::Sqlite(sqlite) => { + sqlite.get_attached_function_by_name(input_collection_id, attached_function_name.clone()) + .await? + } + SysDb::Test(_) => { + todo!() + } + }; + + // Now fetch the function information using the function_id via gRPC + let function = match self { + SysDb::Grpc(grpc) => { + let req = chroma_proto::GetFunctionByIdRequest { + id: attached_function.function_id.to_string(), + }; + + let response = match grpc.client.get_function_by_id(req).await { + Ok(resp) => resp, + Err(status) => { + return Err(GetAttachedFunctionError::FailedToGetAttachedFunction( + tonic::Status::internal(format!("Failed to fetch function info: {}", status)), + )); + } + }; + let response = response.into_inner(); + + let function_proto = response.function.ok_or_else(|| { + GetAttachedFunctionError::FailedToGetAttachedFunction(tonic::Status::internal( + "Missing function in response", + )) + })?; + + chroma_types::Function { + id: uuid::Uuid::parse_str(&function_proto.id).map_err(|_| { + GetAttachedFunctionError::ServerReturnedInvalidData + })?, + name: function_proto.name, + is_incremental: true, // TODO: Get from proto when available + return_type: serde_json::json!({"type": "object"}), // TODO: Get from proto when available 
+ } + } + SysDb::Sqlite(_) => { + // Fallback to UUID-to-name mapping for SQLite + let function_name = match attached_function.function_id.to_string().as_str() { + "ccf2e3ba-633e-43ba-9394-46b0c54c61e3" => "record_counter", + _ => "unknown_function", + }; + + chroma_types::Function { + id: attached_function.function_id, + name: function_name.to_string(), + is_incremental: true, + return_type: serde_json::json!({"type": "object"}), + } + } + SysDb::Test(_) => { + todo!() + } + }; + + Ok(chroma_types::AttachedFunctionWithInfo { + attached_function, + function, + }) + } + pub async fn get_attached_function_by_uuid( &mut self, attached_function_uuid: chroma_types::AttachedFunctionUuid, diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 73940968a20..4b0ebb0ad5e 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -2314,18 +2314,56 @@ impl AttachFunctionRequest { #[derive(Clone, Debug, Serialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] -pub struct AttachedFunctionInfo { - pub id: String, - pub name: String, - pub function_id: String, +pub struct AttachFunctionResponse { + pub attached_function: AttachedFunctionInfo, } #[derive(Clone, Debug, Serialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] -pub struct AttachFunctionResponse { +pub struct GetAttachedFunctionResponse { pub attached_function: AttachedFunctionInfo, } +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +pub struct AttachedFunctionInfo { + pub id: String, + pub name: String, + pub function_name: String, // Human-readable function name + pub input_collection: String, + pub output_collection: String, + pub last_run: Option, + pub next_run: String, + pub params: Option, + pub global_function_parent: Option, +} + +#[derive(Error, Debug)] +pub enum GetAttachedFunctionError { + #[error("Attached function not found")] + NotFound, + #[error("Attached function not ready - still 
initializing")] + NotReady, + #[error("Failed to get attached function: {0}")] + FailedToGetAttachedFunction(#[from] Box), + #[error("Server returned invalid data")] + ServerReturnedInvalidData, + #[error(transparent)] + Validation(#[from] ChromaValidationError), +} + +impl ChromaError for GetAttachedFunctionError { + fn code(&self) -> ErrorCodes { + match self { + GetAttachedFunctionError::NotFound => ErrorCodes::NotFound, + GetAttachedFunctionError::NotReady => ErrorCodes::FailedPrecondition, + GetAttachedFunctionError::FailedToGetAttachedFunction(err) => err.code(), + GetAttachedFunctionError::ServerReturnedInvalidData => ErrorCodes::Internal, + GetAttachedFunctionError::Validation(err) => err.code(), + } + } +} + #[derive(Error, Debug)] pub enum AttachFunctionError { #[error(" Attached Function with name [{0}] already exists")] diff --git a/rust/types/src/task.rs b/rust/types/src/task.rs index 43188ed4aa0..d445b3a9e8f 100644 --- a/rust/types/src/task.rs +++ b/rust/types/src/task.rs @@ -39,6 +39,18 @@ define_uuid_newtype!( now_v7 ); +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Function { + /// Unique identifier for the function + pub id: uuid::Uuid, + /// Human-readable name for the function (unique) + pub name: String, + /// Whether the function supports incremental processing + pub is_incremental: bool, + /// JSON schema describing the function's return type + pub return_type: serde_json::Value, +} + /// AttachedFunction represents an asynchronous function that is triggered by collection writes /// to map records from a source collection to a target collection. 
fn default_systemtime() -> SystemTime { @@ -84,15 +96,24 @@ pub struct AttachedFunction { /// Timestamp when the attached function was last updated #[serde(default = "default_systemtime")] pub updated_at: SystemTime, - /// Next nonce (UUIDv7) for execution tracking + /// Global function parent ID (for hierarchical function relationships) + pub global_parent: Option, + /// Next nonce for function execution pub next_nonce: NonceUuid, - /// Lowest live nonce (UUIDv7) - marks the earliest epoch that still needs verification - /// When lowest_live_nonce is Some and < next_nonce, it indicates finish failed and we should - /// skip execution and only run the scout_logs recheck phase - /// None indicates the attached function has never been scheduled (brand new) + /// Oldest written nonce for tracking execution state + pub oldest_written_nonce: Option, + /// Lowest live nonce for tracking active executions pub lowest_live_nonce: Option, } +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AttachedFunctionWithInfo { + /// The attached function metadata + pub attached_function: AttachedFunction, + /// The function definition metadata + pub function: Function, +} + /// ScheduleEntry represents a scheduled attached function run for a collection. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ScheduleEntry {