Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions test/checkmodel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"regexp"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -207,3 +208,114 @@ func TestCheckModelCombinations(t *testing.T) {
}
}
}

func TestCrossTypeChecks(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
mk := usecase.NewModelKit(nil)

apiKeyBZ := strings.TrimSpace(os.Getenv("baizhiapikey"))
if apiKeyBZ == "" {
t.Fatalf("missing baizhiapikey")
}

listBaseURL := "https://model-square.app.baizhi.cloud/v1"

// 1. 获取模型列表
ml, err := mk.ModelList(ctx, &domain.ModelListReq{
Provider: string(consts.ModelProviderBaiZhiCloud),
BaseURL: listBaseURL,
APIKey: apiKeyBZ,
Type: "",
})
if err != nil {
t.Fatalf("list model error: %v", err)
}
if ml == nil || len(ml.Models) == 0 {
if ml != nil && ml.Error != "" {
t.Fatalf("list model error: %s", ml.Error)
}
t.Fatalf("no models returned")
}

// 2. 测试所有模型
// Helper functions to guess model type
getLowerBaseModelName := func(id string) string {
parts := strings.Split(id, "/")
return strings.ToLower(parts[len(parts)-1])
}

isRerank := func(modelID string) bool {
mid := getLowerBaseModelName(modelID)
re := regexp.MustCompile(`(?i)(?:rerank|re-rank|re-ranker|re-ranking|retrieval|retriever)`)
return re.MatchString(mid)
}

isEmbedding := func(modelID string) bool {
if isRerank(modelID) {
return false
}
mid := getLowerBaseModelName(modelID)
re := regexp.MustCompile(`(?i)(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings|voyage-)`)
return re.MatchString(mid)
}

checkTypes := []string{"embedding", "rerank"}
for _, m := range ml.Models {
for _, ct := range checkTypes {
testName := fmt.Sprintf(
"provider=%s base=%s model=%s type=%s apiKey=present",
string(consts.ModelProviderBaiZhiCloud),
listBaseURL,
m.Model,
ct,
)

// Determine expected outcome
expectSuccess := false
if ct == "embedding" && isEmbedding(m.Model) {
expectSuccess = true
} else if ct == "rerank" && isRerank(m.Model) {
expectSuccess = true
}

t.Run(testName, func(t *testing.T) {
// t.Parallel() // Optional: User didn't ask for parallel, and it might rate limit. Safer to run sequential.
resp, err := mk.CheckModel(ctx, &domain.CheckModelReq{
Provider: string(consts.ModelProviderBaiZhiCloud),
Model: m.Model,
BaseURL: listBaseURL,
APIKey: apiKeyBZ,
Type: ct,
})

respError := ""
respContent := ""
if resp != nil {
respError = resp.Error
respContent = resp.Content
}

logMsg := fmt.Sprintf("RespError: %q, RespContent: %q", respError, respContent)

if expectSuccess {
// Expect Success
if err != nil {
t.Errorf("FAIL (Expected Success): %s; error: %v; %s", testName, err, logMsg)
} else if respError != "" {
t.Errorf("FAIL (Expected Success): %s; %s", testName, logMsg)
} else {
t.Logf("PASS: %s; %s", testName, logMsg)
}
} else {
// Expect Failure
if err == nil && respError == "" {
t.Errorf("FAIL (Expected Failure for mismatched type): %s; got success but expected error; %s", testName, logMsg)
} else {
fmt.Printf("PASS: %s; correctly failed as expected for mismatched type. Error: %v; %s\n", testName, err, logMsg)
}
}
})
}
}
}
Loading