Skip to content

Commit 4641aaa

Browse files
committed
ggplot2 recorder now stores the commit of the last best score into the csv
1 parent c0b5508 commit 4641aaa

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

Mariana/training/recoders.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import sys
2+
from pyGeno.tools.parsers.CSVTools import CSVFile
3+
4+
class Recorder_ABC(object) :
5+
6+
def commit(self, store, model) :
7+
"""Does something with the currenty state of the trainer's store and the model"""
8+
raise NotImplemented("Should be implemented in child")
9+
10+
def __len__(self) :
11+
"""returns the number of commits performed"""
12+
raise NotImplemented("Should be implemented in child")
13+
14+
class GGPlot2(Recorder_ABC):
15+
"""This training recorder will create a nice CSV file fit for using with ggplot2. It will also print regular
16+
reports if you tell it to be verbose and save the best models"""
17+
def __init__(self, filename, verbose = True):
18+
19+
self.filename = filename.replace(".csv", "") + ".ggplot2.csv"
20+
self.verbose = verbose
21+
22+
self.bestScores = {}
23+
self.currentScores = {}
24+
25+
self.csvLegend = None
26+
self.csvFile = None
27+
28+
self.length = 0
29+
30+
def commit(self, store, model) :
31+
"""Appends the current state of the store to the CSV """
32+
def _fillLine(csvFile, score, bestScore, setName, setLen, outputName, **csvValues) :
33+
line = csvFile.newLine()
34+
for k, v in csvValues.iteritems() :
35+
line[k] = v
36+
line["score"] = score
37+
line["best_score"] = bestScore[0]
38+
line["best_score_commit"] = bestScore[1]
39+
line["set"] = "%s(%s)" %(setName, setLen)
40+
line["output"] = outputName
41+
line.commit()
42+
43+
self.length += 1
44+
if self.csvLegend is None :
45+
self.csvLegend = store["hyperParameters"].keys()
46+
self.csvLegend.extend( ["score", "best_score", "set", "output"] )
47+
48+
self.csvFile = CSVFile(legend = self.csvLegend)
49+
self.csvFile.streamToFile( self.filename, writeRate = 1 )
50+
51+
for theSet, scores in store["scores"].iteritems() :
52+
self.currentScores[theSet] = {}
53+
if theSet not in self.bestScores :
54+
self.bestScores[theSet] = {}
55+
for outputName, score in scores.iteritems() :
56+
self.currentScores[theSet][outputName] = score
57+
if outputName not in self.bestScores[theSet] or score < self.bestScores[theSet][outputName][0] :
58+
self.bestScores[theSet][outputName] = (score, self.length)
59+
model.save("best-%s-%s" % (theSet, self.filename))
60+
61+
_fillLine(
62+
self.csvFile,
63+
self.currentScores[theSet][outputName],
64+
self.bestScores[theSet][outputName],
65+
theSet,
66+
store["setSizes"][theSet],
67+
outputName,
68+
**store["hyperParameters"]
69+
)
70+
71+
72+
if self.verbose :
73+
self.printCurrentState()
74+
75+
def printCurrentState(self) :
76+
"""prints the current state stored in the recorder"""
77+
if self.length > 0 :
78+
print "\n==>rec: ggplot2, commit %s:" % self.length
79+
for setName, scores in self.bestScores.iteritems() :
80+
print " |-%s set" % setName
81+
for outputName in scores :
82+
if self.currentScores[setName][outputName] == self.bestScores[setName][outputName][0] :
83+
highlight = "+best+"
84+
else :
85+
score, epoch = self.bestScores[setName][outputName]
86+
highlight = "(best: %s @ commit: %s)" % (score, epoch)
87+
88+
print " |->%s: %s %s" % (outputName, self.currentScores[setName][outputName], highlight)
89+
else :
90+
print "==>rec: ggplot2, nothing to show yet"
91+
92+
sys.stdout.flush()
93+
94+
def __len__(self) :
95+
"""returns the number of commits performed"""
96+
return self.length

0 commit comments

Comments
 (0)