Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions genai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@ type EmbeddingModel struct {
name string
fullName string
// TaskType describes how the embedding will be used.
TaskType TaskType
TaskType TaskType
outputDimension *int32
}

// Name reports the name stored on the EmbeddingModel.
func (m *EmbeddingModel) Name() string {
	return m.name
}

// SetOutputDimension sets the output dimension sent with embedding requests
// made through m. The value is forwarded as the request's
// OutputDimensionality field by EmbedContent and EmbedContentWithTitle.
// NewBatch copies the current setting into the batch it returns, so batches
// created before this call are unaffected.
func (m *EmbeddingModel) SetOutputDimension(outputDim int32) {
	// outputDim is this call's own copy, so the stored pointer never aliases
	// a variable owned by the caller.
	m.outputDimension = &outputDim
}

// EmbedContent returns an embedding for the list of parts.
func (m *EmbeddingModel) EmbedContent(ctx context.Context, parts ...Part) (*EmbedContentResponse, error) {
return m.EmbedContentWithTitle(ctx, "", parts...)
Expand All @@ -54,18 +59,19 @@ func (m *EmbeddingModel) EmbedContent(ctx context.Context, parts ...Part) (*Embe
// If the given title is non-empty, it is passed to the model and
// the task type is set to TaskTypeRetrievalDocument.
func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string, parts ...Part) (*EmbedContentResponse, error) {
	// Build the request exactly once, passing along the model's optional
	// output dimension (nil unless one has been set on the model).
	req := newEmbedContentRequest(m.fullName, m.TaskType, title, m.outputDimension, parts)
	res, err := m.c.gc.EmbedContent(ctx, req)
	if err != nil {
		return nil, err
	}
	return (EmbedContentResponse{}).fromProto(res), nil
}

func newEmbedContentRequest(model string, tt TaskType, title string, parts []Part) *pb.EmbedContentRequest {
func newEmbedContentRequest(model string, tt TaskType, title string, outputDim *int32, parts []Part) *pb.EmbedContentRequest {
req := &pb.EmbedContentRequest{
Model: model,
Content: NewUserContent(parts...).toProto(),
Model: model,
Content: NewUserContent(parts...).toProto(),
OutputDimensionality: outputDim,
}
// A non-empty title overrides the task type.
if title != "" {
Expand All @@ -82,8 +88,9 @@ func newEmbedContentRequest(model string, tt TaskType, title string, parts []Par

// An EmbeddingBatch holds a collection of embedding requests.
type EmbeddingBatch struct {
	// tt is the task type handed to newEmbedContentRequest for every content
	// added to the batch.
	tt TaskType
	// req accumulates the individual requests in its Requests field.
	req *pb.BatchEmbedContentsRequest
	// outputDim is the optional output dimension handed to
	// newEmbedContentRequest for every content added to the batch.
	outputDim *int32
}

// NewBatch returns a new, empty EmbeddingBatch with the same TaskType and output dimension as the model.
Expand All @@ -96,6 +103,7 @@ func (m *EmbeddingModel) NewBatch() *EmbeddingBatch {
req: &pb.BatchEmbedContentsRequest{
Model: m.fullName,
},
outputDim: m.outputDimension,
}
}

Expand All @@ -107,7 +115,7 @@ func (b *EmbeddingBatch) AddContent(parts ...Part) *EmbeddingBatch {

// AddContentWithTitle adds a content to the batch with a title.
// It returns b so that calls can be chained.
func (b *EmbeddingBatch) AddContentWithTitle(title string, parts ...Part) *EmbeddingBatch {
	// Append exactly one request per call, built with the batch's task type
	// and output dimension.
	req := newEmbedContentRequest(b.req.Model, b.tt, title, b.outputDim, parts)
	b.req.Requests = append(b.req.Requests, req)
	return b
}

Expand Down