Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const (
// It provides a contract for initializing, managing context,
// and performing chat completions with different language models.
type Llm interface {
SetName(string)
// Init initializes the LLM provider with the given adapter configuration.
// It is called once when the provider is added to the adapter.
Init(llm internal.Adapter) error
Expand All @@ -35,6 +36,10 @@ type Llm interface {
// request options struct. This is used for type checking and reflection
// when processing custom request options.
RequestOptionsType() reflect.Type

SubmitBatch(context.Context, internal.Adapter, ...Requester) (*UntypedBatchPromise, error)
Check(context.Context, *UntypedBatchPromise) (BatchStatus, error)
Wait(ctx context.Context, pr *UntypedBatchPromise) <-chan BatchWaitResponse
}

// LlmAdapter is the main entrypoint for interacting with different LLM providers.
Expand Down Expand Up @@ -109,6 +114,28 @@ func (llm *LlmAdapter) GetProvider(requestProvider *string) (Llm, error) {
return provider, nil
}

func (llm *LlmAdapter) SubmitBatch(ctx context.Context, providerName string, reqs ...Requester) (*UntypedBatchPromise, error) {
p, ok := llm.providers[providerName]
if !ok {
return nil, errors.Newf("unknown provider '%s'", providerName)
}

return p.SubmitBatch(ctx, llm, reqs...)
}

func (llm *LlmAdapter) BatchPromise(providerName string, id string) (*UntypedBatchPromise, error) {
provider, ok := llm.providers[providerName]
if !ok {
return nil, errors.New("cannot find the provider that created this promise")
}

return &UntypedBatchPromise{
ProviderName: providerName,
Provider: provider,
Id: id,
}, nil
}

// LlmAdapter implementation of Adapter

func (llm LlmAdapter) DefaultModel() string {
Expand Down
93 changes: 93 additions & 0 deletions batches.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package llmadapter

import (
"context"

"github.com/checkmarble/marble-llm-adapter/internal"
"github.com/cockroachdb/errors"
)

type (
BatchStatus int
)

const (
BatchPending BatchStatus = iota
BatchRunning
BatchFinished
BatchError
)

type BatchUnsupported struct{}

func (BatchUnsupported) SubmitBatch(ctx context.Context, llm internal.Adapter, reqs ...Requester) (*UntypedBatchPromise, error) {
return nil, errors.New("provider does not support batch mode")
}

func (BatchUnsupported) Check(context.Context, *UntypedBatchPromise) (BatchStatus, error) {
return BatchError, errors.New("provider does not support batch mode")
}

func (BatchUnsupported) Wait(ctx context.Context, pr *UntypedBatchPromise) <-chan BatchWaitResponse {
return nil
}

type Batch[T any] struct {
Requests []Request[T]
}

func (b Batch[T]) Batch(ctx context.Context, llm *LlmAdapter, providerName string) (*BatchPromise[T], error) {
requesters := make([]Requester, len(b.Requests))

for idx, r := range b.Requests {
requesters[idx] = Requester(r)
}

promise, err := llm.SubmitBatch(ctx, providerName, requesters...)

if err != nil {
return nil, err
}

return &BatchPromise[T]{promise}, nil
}

type UntypedBatchPromise struct {
Provider Llm
ProviderName string
Id string
}

type BatchPromise[T any] struct {
*UntypedBatchPromise
}

func (p BatchPromise[T]) Check(ctx context.Context) (BatchStatus, error) {
return p.Provider.Check(ctx, p.UntypedBatchPromise)
}

func (p BatchPromise[T]) Wait(ctx context.Context) (map[string]Response[T], error) {
inners := <-p.Provider.Wait(ctx, p.UntypedBatchPromise)

if inners.Error != nil {
return nil, inners.Error
}

responses := make(map[string]Response[T], len(inners.Responses))

for id, resp := range inners.Responses {
responses[id] = Response[T]{
InnerResponse: resp,
}
}

return responses, nil
}

type BatchWaitResponse struct {
Status BatchStatus
Filename string
Error error

Responses map[string]InnerResponse
}
50 changes: 50 additions & 0 deletions examples/batch/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package main

import (
"context"
"fmt"
"log"
"os"

llmadapter "github.com/checkmarble/marble-llm-adapter"
"github.com/checkmarble/marble-llm-adapter/llms/aistudio"
"google.golang.org/genai"
)

func main() {
ctx := context.Background()

provider, _ := aistudio.New(
aistudio.WithBackend(genai.BackendVertexAI),
aistudio.WithProject(os.Getenv("GOOGLE_CLOUD_PROJECT")),
aistudio.WithLocation("europe-west1"),
aistudio.WithApiKey(os.Getenv("LLM_API_KEY")),
aistudio.WithBucket(os.Getenv("LLM_BATCH_BUCKET")),
)

llm, _ := llmadapter.New(
llmadapter.WithProvider("vertex", provider),
llmadapter.WithDefaultModel("gemini-2.5-flash"),
)

reqs := llmadapter.Batch[string]{
Requests: []llmadapter.Request[string]{
llmadapter.NewUntypedRequest().WithProvider("vertex").WithId("how").WithText(llmadapter.RoleUser, "How are you?"),
llmadapter.NewUntypedRequest().WithProvider("vertex").WithId("addition").WithText(llmadapter.RoleUser, "What is 1 + 1?"),
},
}

// promise, err := llm.SubmitBatch(ctx, "vertex", []llmadapter.Requester(reqs)...)
promise, err := reqs.Batch(ctx, llm, "vertex")
if err != nil {
log.Fatal(err)
}

result, err := promise.Wait(ctx)

if err != nil {
log.Fatal(err)
}

fmt.Printf("%#v\n", result)
}
67 changes: 51 additions & 16 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,87 @@ module github.com/checkmarble/marble-llm-adapter
go 1.24.4

require (
cloud.google.com/go/storage v1.55.0
github.com/cockroachdb/errors v1.12.0
github.com/fatih/structs v1.1.0
github.com/h2non/gock v1.2.0
github.com/invopop/jsonschema v0.13.0
github.com/openai/openai-go v1.9.0
github.com/samber/lo v1.51.0
github.com/simonfrey/jsonl v0.0.0-20240904112901-935399b9a740
github.com/stretchr/testify v1.10.0
github.com/tidwall/gjson v1.18.0
google.golang.org/genai v1.15.0
)

require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/auth v0.9.3 // indirect
cloud.google.com/go/compute/metadata v0.5.0 // indirect
cel.dev/expr v0.20.0 // indirect
cloud.google.com/go v0.121.1 // indirect
cloud.google.com/go/auth v0.16.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.7.0 // indirect
cloud.google.com/go/iam v1.5.2 // indirect
cloud.google.com/go/monitoring v1.24.2 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42 // indirect
github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect
github.com/cockroachdb/redact v1.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fatih/structs v1.1.0 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.32.4 // indirect
github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/getsentry/sentry-go v0.27.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
github.com/googleapis/gax-go/v2 v2.14.2 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.9.0 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
github.com/zeebo/errs v1.4.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/detectors/gcp v1.36.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
go.opentelemetry.io/otel v1.36.0 // indirect
go.opentelemetry.io/otel/metric v1.36.0 // indirect
go.opentelemetry.io/otel/sdk v1.36.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.36.0 // indirect
go.opentelemetry.io/otel/trace v1.36.0 // indirect
golang.org/x/crypto v0.38.0 // indirect
golang.org/x/net v0.40.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/grpc v1.66.2 // indirect
google.golang.org/protobuf v1.34.2 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/api v0.235.0 // indirect
google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250512202823-5a2f75b736a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9 // indirect
google.golang.org/grpc v1.72.1 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading
Loading