diff --git a/test/checkmodel_test.go b/test/checkmodel_test.go index 3880b20..4b7df61 100644 --- a/test/checkmodel_test.go +++ b/test/checkmodel_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "regexp" "strings" "testing" "time" @@ -207,3 +208,114 @@ func TestCheckModelCombinations(t *testing.T) { } } } + +func TestCrossTypeChecks(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + mk := usecase.NewModelKit(nil) + + apiKeyBZ := strings.TrimSpace(os.Getenv("baizhiapikey")) + if apiKeyBZ == "" { + t.Fatalf("missing baizhiapikey") + } + + listBaseURL := "https://model-square.app.baizhi.cloud/v1" + + // 1. 获取模型列表 + ml, err := mk.ModelList(ctx, &domain.ModelListReq{ + Provider: string(consts.ModelProviderBaiZhiCloud), + BaseURL: listBaseURL, + APIKey: apiKeyBZ, + Type: "", + }) + if err != nil { + t.Fatalf("list model error: %v", err) + } + if ml == nil || len(ml.Models) == 0 { + if ml != nil && ml.Error != "" { + t.Fatalf("list model error: %s", ml.Error) + } + t.Fatalf("no models returned") + } + + // 2. 测试所有模型 + // Helper functions to guess model type + getLowerBaseModelName := func(id string) string { + parts := strings.Split(id, "/") + return strings.ToLower(parts[len(parts)-1]) + } + + isRerank := func(modelID string) bool { + mid := getLowerBaseModelName(modelID) + re := regexp.MustCompile(`(?i)(?:rerank|re-rank|re-ranker|re-ranking|retrieval|retriever)`) + return re.MatchString(mid) + } + + isEmbedding := func(modelID string) bool { + if isRerank(modelID) { + return false + } + mid := getLowerBaseModelName(modelID) + re := regexp.MustCompile(`(?i)(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings|voyage-)`) + return re.MatchString(mid) + } + + checkTypes := []string{"embedding", "rerank"} + for _, m := range ml.Models { + for _, ct := range checkTypes { + testName := fmt.Sprintf( + "provider=%s base=%s model=%s type=%s apiKey=present", + string(consts.ModelProviderBaiZhiCloud), + listBaseURL, + m.Model, + ct, + ) + + // Determine expected outcome + expectSuccess := false + if ct == "embedding" && isEmbedding(m.Model) { + expectSuccess = true + } else if ct == "rerank" && isRerank(m.Model) { + expectSuccess = true + } + + t.Run(testName, func(t *testing.T) { + // t.Parallel() // Optional: User didn't ask for parallel, and it might rate limit. Safer to run sequential. + resp, err := mk.CheckModel(ctx, &domain.CheckModelReq{ + Provider: string(consts.ModelProviderBaiZhiCloud), + Model: m.Model, + BaseURL: listBaseURL, + APIKey: apiKeyBZ, + Type: ct, + }) + + respError := "" + respContent := "" + if resp != nil { + respError = resp.Error + respContent = resp.Content + } + + logMsg := fmt.Sprintf("RespError: %q, RespContent: %q", respError, respContent) + + if expectSuccess { + // Expect Success + if err != nil { + t.Errorf("FAIL (Expected Success): %s; error: %v; %s", testName, err, logMsg) + } else if respError != "" { + t.Errorf("FAIL (Expected Success): %s; %s", testName, logMsg) + } else { + t.Logf("PASS: %s; %s", testName, logMsg) + } + } else { + // Expect Failure + if err == nil && respError == "" { + t.Errorf("FAIL (Expected Failure for mismatched type): %s; got success but expected error; %s", testName, logMsg) + } else { + fmt.Printf("PASS: %s; correctly failed as expected for mismatched type. Error: %v; %s\n", testName, err, logMsg) + } + } + }) + } + } +}