-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnormalize.go
More file actions
155 lines (139 loc) · 4.46 KB
/
normalize.go
File metadata and controls
155 lines (139 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package ml
import (
"fmt"
"io"
"strings"
coreerr "dappco.re/go/core/log"
)
// NormalizeConfig holds the tunable options for the seed normalization
// process performed by NormalizeSeeds.
type NormalizeConfig struct {
	// MinLength is the inclusive lower bound applied to the SQL
	// length(prompt) of each seed; shorter seeds are not copied into
	// expansion_prompts. Zero (the default) keeps every seed.
	MinLength int
}
// NormalizeSeeds deduplicates seeds into the expansion_prompts table.
//
// Steps:
// 1. Verify the seeds table exists and report its row count.
// 2. Drop and recreate expansion_prompts using deduplicated seeds,
// excluding prompts already present in the prompts or golden_set tables.
// 3. Assign priority based on domain coverage (underrepresented domains
// receive higher priority via RANK).
// 4. Print a region distribution summary.
func NormalizeSeeds(db *DB, cfg NormalizeConfig, w io.Writer) error {
// 1. Check seeds table exists and get count.
var seedCount int
if err := db.conn.QueryRow("SELECT count(*) FROM seeds").Scan(&seedCount); err != nil {
return coreerr.E("ml.NormalizeSeeds", "no seeds table (run import-all first)", err)
}
fmt.Fprintf(w, "Seeds table: %d rows\n", seedCount)
if seedCount == 0 {
return coreerr.E("ml.NormalizeSeeds", "seeds table is empty, nothing to normalize", nil)
}
// 2. Drop and recreate expansion_prompts.
if _, err := db.conn.Exec("DROP TABLE IF EXISTS expansion_prompts"); err != nil {
return coreerr.E("ml.NormalizeSeeds", "drop expansion_prompts", err)
}
createSQL := fmt.Sprintf(`
CREATE TABLE expansion_prompts AS
WITH unique_seeds AS (
SELECT
ROW_NUMBER() OVER (ORDER BY region, domain, seed_id) AS idx,
seed_id, region, domain, prompt
FROM (
SELECT DISTINCT ON (prompt)
seed_id, region, domain, prompt
FROM seeds
WHERE length(prompt) >= %d
ORDER BY prompt, seed_id
)
),
existing_prompts AS (
SELECT prompt FROM prompts
UNION ALL
SELECT prompt FROM golden_set
)
SELECT
us.idx, us.seed_id, us.region, us.domain,
'en' AS language, us.prompt, '' AS prompt_en,
0 AS priority, 'pending' AS status
FROM unique_seeds us
WHERE NOT EXISTS (
SELECT 1 FROM existing_prompts ep WHERE ep.prompt = us.prompt
)
`, cfg.MinLength)
if _, err := db.conn.Exec(createSQL); err != nil {
return coreerr.E("ml.NormalizeSeeds", "create expansion_prompts", err)
}
var epCount int
if err := db.conn.QueryRow("SELECT count(*) FROM expansion_prompts").Scan(&epCount); err != nil {
return coreerr.E("ml.NormalizeSeeds", "count expansion_prompts", err)
}
fmt.Fprintf(w, "Expansion prompts created: %d (min length %d, deduped, excluding existing)\n", epCount, cfg.MinLength)
if epCount == 0 {
fmt.Fprintln(w, "No new expansion prompts to process.")
return nil
}
// 3. Assign priority based on domain coverage.
prioritySQL := `
UPDATE expansion_prompts SET priority = sub.rnk
FROM (
SELECT domain, RANK() OVER (ORDER BY cnt ASC) AS rnk
FROM (
SELECT domain, count(*) AS cnt
FROM expansion_prompts
GROUP BY domain
) domain_counts
) sub
WHERE expansion_prompts.domain = sub.domain
`
if _, err := db.conn.Exec(prioritySQL); err != nil {
return coreerr.E("ml.NormalizeSeeds", "assign priority", err)
}
fmt.Fprintln(w, "Priority assigned (underrepresented domains ranked higher).")
// 4. Region distribution summary.
fmt.Fprintln(w)
fmt.Fprintln(w, "Region distribution:")
rows, err := db.conn.Query(`
SELECT
CASE
WHEN region LIKE 'cn%' THEN 'cn'
WHEN region LIKE 'en%' THEN 'en'
WHEN region LIKE 'ru%' THEN 'ru'
WHEN region LIKE 'de%' THEN 'de'
WHEN region LIKE 'es%' THEN 'es'
WHEN region LIKE 'fr%' THEN 'fr'
WHEN region LIKE 'latam%' THEN 'latam'
WHEN region LIKE 'africa%' THEN 'africa'
WHEN region LIKE 'eu%' THEN 'eu'
WHEN region LIKE 'me%' THEN 'me'
ELSE 'other'
END AS region_group,
count(*) AS cnt
FROM expansion_prompts
GROUP BY region_group
ORDER BY cnt DESC
`)
if err != nil {
return coreerr.E("ml.NormalizeSeeds", "region distribution query", err)
}
defer rows.Close()
var totalFromRegions int
var lines []string
for rows.Next() {
var region string
var cnt int
if err := rows.Scan(®ion, &cnt); err != nil {
return coreerr.E("ml.NormalizeSeeds", "scan region row", err)
}
totalFromRegions += cnt
lines = append(lines, fmt.Sprintf(" %-10s %6d", region, cnt))
}
if err := rows.Err(); err != nil {
return coreerr.E("ml.NormalizeSeeds", "iterate region rows", err)
}
for _, line := range lines {
fmt.Fprintln(w, line)
}
fmt.Fprintf(w, " %-10s %6d\n", strings.Repeat("-", 10), totalFromRegions)
fmt.Fprintf(w, " %-10s %6d\n", "total", totalFromRegions)
return nil
}