diff --git a/genai/embed.go b/genai/embed.go index cea90ec..a7bd2f5 100644 --- a/genai/embed.go +++ b/genai/embed.go @@ -37,7 +37,8 @@ type EmbeddingModel struct { name string fullName string // TaskType describes how the embedding will be used. - TaskType TaskType + TaskType TaskType + outputDimension *int32 } // Name returns the name of the EmbeddingModel. @@ -45,6 +46,10 @@ func (m *EmbeddingModel) Name() string { return m.name } +func (m *EmbeddingModel) SetOutputDimension(outputDim int32) { + m.outputDimension = &outputDim +} + // EmbedContent returns an embedding for the list of parts. func (m *EmbeddingModel) EmbedContent(ctx context.Context, parts ...Part) (*EmbedContentResponse, error) { return m.EmbedContentWithTitle(ctx, "", parts...) @@ -54,7 +59,7 @@ func (m *EmbeddingModel) EmbedContent(ctx context.Context, parts ...Part) (*Embe // If the given title is non-empty, it is passed to the model and // the task type is set to TaskTypeRetrievalDocument. func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string, parts ...Part) (*EmbedContentResponse, error) { - req := newEmbedContentRequest(m.fullName, m.TaskType, title, parts) + req := newEmbedContentRequest(m.fullName, m.TaskType, title, m.outputDimension, parts) res, err := m.c.gc.EmbedContent(ctx, req) if err != nil { return nil, err @@ -62,10 +67,11 @@ func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string return (EmbedContentResponse{}).fromProto(res), nil } -func newEmbedContentRequest(model string, tt TaskType, title string, parts []Part) *pb.EmbedContentRequest { +func newEmbedContentRequest(model string, tt TaskType, title string, outputDim *int32, parts []Part) *pb.EmbedContentRequest { req := &pb.EmbedContentRequest{ - Model: model, - Content: NewUserContent(parts...).toProto(), + Model: model, + Content: NewUserContent(parts...).toProto(), + OutputDimensionality: outputDim, } // A non-empty title overrides the task type. if title != "" { @@ -82,8 +88,9 @@ func newEmbedContentRequest(model string, tt TaskType, title string, parts []Par // An EmbeddingBatch holds a collection of embedding requests. type EmbeddingBatch struct { - tt TaskType - req *pb.BatchEmbedContentsRequest + tt TaskType + req *pb.BatchEmbedContentsRequest + outputDim *int32 } // NewBatch returns a new, empty EmbeddingBatch with the same TaskType as the model. @@ -96,6 +103,7 @@ func (m *EmbeddingModel) NewBatch() *EmbeddingBatch { req: &pb.BatchEmbedContentsRequest{ Model: m.fullName, }, + outputDim: m.outputDimension, } } @@ -107,7 +115,7 @@ func (b *EmbeddingBatch) AddContent(parts ...Part) *EmbeddingBatch { // AddContent adds a content to the batch with a title. func (b *EmbeddingBatch) AddContentWithTitle(title string, parts ...Part) *EmbeddingBatch { - b.req.Requests = append(b.req.Requests, newEmbedContentRequest(b.req.Model, b.tt, title, parts)) + b.req.Requests = append(b.req.Requests, newEmbedContentRequest(b.req.Model, b.tt, title, b.outputDim, parts)) return b }