-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathgenerator.go
173 lines (138 loc) · 2.91 KB
/
generator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package randtxt
import (
"bytes"
"fmt"
"io"
"math/rand"
"strings"
"github.com/pboyd/markov"
)
// Generator generates random text from a model built by ModelBuilder.
//
// Use NewGenerator to create one; it validates the chain and sets TagSet to
// PennTreebankTagSet.
type Generator struct {
// chain is the Markov chain the generator reads its seed and subsequent
// tags from.
chain markov.Chain
// TagSet is the language and tagset specific rules. This should match
// the TagSet used when the model was built. Its Join method is used to
// render each generated tag relative to the tag that preceded it.
TagSet TagSet
}
// NewGenerator builds a Generator that reads from chain, using
// PennTreebankTagSet as its tag set.
//
// The chain is validated with inspectChain first; if that fails the error is
// returned unchanged and no Generator is created.
func NewGenerator(chain markov.Chain) (*Generator, error) {
	if _, err := inspectChain(chain); err != nil {
		return nil, err
	}

	g := &Generator{
		TagSet: PennTreebankTagSet,
		chain:  chain,
	}
	return g, nil
}
// inspectChain sanity-checks the format of chain by looking at the value
// stored under ID 0.
//
// That value must be a string of space-separated tags, each of which parseTag
// decodes into a tag with a non-empty POS and a non-empty Text. On success it
// returns the number of space-separated tags in that value. It returns an
// error if the lookup fails, if the value is not a string, or if any tag does
// not parse.
func inspectChain(chain markov.Chain) (int, error) {
	first, err := chain.Get(0)
	if err != nil {
		return 0, err
	}

	text, isString := first.(string)
	if !isString {
		return 0, fmt.Errorf("chain has type %T, want string", first)
	}

	grams := strings.Split(text, " ")
	for i := range grams {
		parsed := parseTag(grams[i])
		if parsed.POS != "" && parsed.Text != "" {
			continue
		}
		return 0, fmt.Errorf("unrecognized tag format %q", grams[i])
	}

	return len(grams), nil
}
// Paragraph generates a paragraph and returns it as a string. It is a
// convenience wrapper around WriteParagraph, to which "min" and "max" (the
// bounds on the number of sentences) are passed through unchanged.
//
// If WriteParagraph fails, the empty string and its error are returned.
func (g *Generator) Paragraph(min, max int) (string, error) {
	var buf bytes.Buffer
	if err := g.WriteParagraph(&buf, min, max); err != nil {
		return "", err
	}
	return buf.String(), nil
}
// WriteParagraph writes a paragraph of random text to "out".
//
// The number of sentences is chosen at random: at least "min" and fewer than
// "max". If "max" is not greater than "min", exactly "min" sentences are
// written. If the chosen count is less than one, nothing is written and nil
// is returned.
//
// A sentence is counted each time a tag whose POS is "." is written.
//
// It returns the first error reported by the tag generator or by a write to
// "out".
func (g *Generator) WriteParagraph(out io.Writer, min, max int) error {
	// rand.Intn panics when its argument is not positive, so only
	// randomize when the range is non-empty.
	total := min
	if max > min {
		total = rand.Intn(max-min) + min
	}
	if total < 1 {
		// Nothing to generate. Without this guard the sentence counter
		// could never equal total and the loop below would not stop.
		return nil
	}

	// Closing done on return tells the generator goroutine to stop.
	done := make(chan struct{})
	defer close(done)

	gen := g.generate(done)

	// Discard everything up to and including the first "." tag so that the
	// output begins right after a sentence terminator.
	for te := range gen {
		if te.Err != nil {
			return te.Err
		}
		if te.Tag.POS == "." {
			break
		}
	}

	first, ok := <-gen
	if !ok {
		// The generator only closes its channel after reporting an error
		// or after done is closed, so this is not expected; fail rather
		// than write a zero-valued tag.
		return io.ErrUnexpectedEOF
	}
	if first.Err != nil {
		return first.Err
	}

	// The first tag has no predecessor, so it is joined against an empty Tag.
	if _, err := io.WriteString(out, g.TagSet.Join(first.Tag, Tag{})); err != nil {
		return err
	}

	generated := 0
	last := first.Tag
	for te := range gen {
		if te.Err != nil {
			return te.Err
		}

		tag := te.Tag
		if _, err := io.WriteString(out, g.TagSet.Join(tag, last)); err != nil {
			return err
		}

		if tag.POS == "." {
			generated++
			if generated == total {
				break
			}
		}

		last = tag
	}

	return nil
}
// generate starts a goroutine that produces an endless stream of tags and
// returns the channel they are delivered on.
//
// The goroutine first obtains a seed string from randomSeed and sends each of
// its space-separated tags, then builds a model with NewModel and sends
// model.Current() after every successful model.Step(). If any of those calls
// fails, a single item carrying the error is sent and the stream ends.
//
// The returned channel is closed when the goroutine exits. Closing "done"
// makes the goroutine exit the next time it tries to send.
func (g *Generator) generate(done chan struct{}) <-chan tagOrError {
	stream := make(chan tagOrError)

	// emit hands one item to the consumer. It reports true when done was
	// selected instead, meaning the producer should stop.
	emit := func(item tagOrError) bool {
		select {
		case stream <- item:
			return false
		case <-done:
			return true
		}
	}

	go func() {
		defer close(stream)

		seed, err := randomSeed(g.chain)
		if err != nil {
			emit(tagOrError{Err: err})
			return
		}

		for _, raw := range strings.Split(seed, " ") {
			if emit(tagOrError{Tag: parseTag(raw)}) {
				return
			}
		}

		model, err := NewModel(g.chain, seed)
		if err != nil {
			emit(tagOrError{Err: err})
			return
		}

		for {
			if err := model.Step(); err != nil {
				emit(tagOrError{Err: err})
				return
			}

			if emit(tagOrError{Tag: model.Current()}) {
				return
			}
		}
	}()

	return stream
}
// tagOrError is the item type sent on the channel returned by generate. Each
// item carries either a generated tag or the error that ended generation.
type tagOrError struct {
// Tag is the generated tag. It is the zero Tag when Err is set.
Tag Tag
// Err is non-nil when generation failed; it is nil for ordinary tags.
Err error
}