1
+ import sys
2
+ from pyGeno .tools .parsers .CSVTools import CSVFile
3
+
4
class Recorder_ABC(object):
    """Interface shared by all training recorders.

    Child classes must override both commit() and __len__().
    """

    def commit(self, store, model):
        """Does something with the current state of the trainer's store and the model"""
        # NotImplemented is a comparison sentinel, not an exception class: calling it
        # raises a confusing TypeError. NotImplementedError is the exception meant
        # for abstract methods.
        raise NotImplementedError("Should be implemented in child")

    def __len__(self):
        """returns the number of commits performed"""
        raise NotImplementedError("Should be implemented in child")
13
+
14
class GGPlot2(Recorder_ABC):
    """This training recorder will create a nice CSV file fit for using with ggplot2. It will also print regular
    reports if you tell it to be verbose and save the best models"""

    def __init__(self, filename, verbose=True):
        """Set up an empty recorder.

        :param filename: base name of the output file; a trailing ".csv" is
            dropped and ".ggplot2.csv" is appended.
        :param verbose: when true, commit() prints a report after every commit.
        """
        # Only strip the extension itself: str.replace() would also eat a ".csv"
        # that appears earlier in the path (e.g. in a directory name).
        if filename.endswith(".csv"):
            filename = filename[:-len(".csv")]
        self.filename = filename + ".ggplot2.csv"
        self.verbose = verbose

        # bestScores[set][output] -> (score, commit number at which it was reached)
        self.bestScores = {}
        # currentScores[set][output] -> score seen at the latest commit
        self.currentScores = {}

        # Both are created lazily by the first commit(), once the hyper-parameter
        # names are known.
        self.csvLegend = None
        self.csvFile = None

        self.length = 0

    def _fillLine(self, hyperParameters, score, bestScore, setName, setLen, outputName):
        """Write one CSV row for a given (set, output) pair.

        The hyper-parameters are taken as a plain dict rather than **kwargs so
        that a hyper-parameter called e.g. "score" cannot collide with this
        method's own parameter names.
        """
        line = self.csvFile.newLine()
        for key, value in hyperParameters.items():
            line[key] = value
        line["score"] = score
        line["best_score"] = bestScore[0]
        line["best_score_commit"] = bestScore[1]
        line["set"] = "%s(%s)" % (setName, setLen)
        line["output"] = outputName
        line.commit()

    def commit(self, store, model):
        """Appends the current state of the store to the CSV

        :param store: mapping read through the keys "hyperParameters",
            "scores" (set name -> output name -> score) and "setSizes"
            (set name -> size).
        :param model: object whose save(filename) is called whenever a score
            improves on the best one recorded so far.
        """
        self.length += 1
        hyperParameters = store["hyperParameters"]

        if self.csvLegend is None:
            # list() keeps this working when keys() returns a view instead of a list.
            self.csvLegend = list(hyperParameters.keys())
            # "best_score_commit" is written on every row, so it has to be part of
            # the legend that is handed to the CSV file before streaming starts.
            self.csvLegend.extend(["score", "best_score", "best_score_commit", "set", "output"])

            self.csvFile = CSVFile(legend=self.csvLegend)
            self.csvFile.streamToFile(self.filename, writeRate=1)

        for theSet, scores in store["scores"].items():
            self.currentScores[theSet] = {}
            bestOfSet = self.bestScores.setdefault(theSet, {})
            for outputName, score in scores.items():
                self.currentScores[theSet][outputName] = score
                # A strictly lower score counts as an improvement.
                if outputName not in bestOfSet or score < bestOfSet[outputName][0]:
                    bestOfSet[outputName] = (score, self.length)
                    # NOTE(review): the name is built from self.filename as-is, so a
                    # filename containing a directory yields "best-<set>-<dir>/..." and
                    # all outputs of a set share one name - confirm this is intended.
                    model.save("best-%s-%s" % (theSet, self.filename))

                self._fillLine(
                    hyperParameters,
                    score,
                    bestOfSet[outputName],
                    theSet,
                    store["setSizes"][theSet],
                    outputName,
                )

        if self.verbose:
            self.printCurrentState()

    def printCurrentState(self):
        """prints the current state stored in the recorder"""
        # print(<single string>) behaves the same as a statement (Python 2) and as
        # a function call (Python 3).
        if self.length > 0:
            print("\n ==>rec: ggplot2, commit %s:" % self.length)
            for setName, scores in self.bestScores.items():
                print(" |-%s set" % setName)
                for outputName in scores:
                    current = self.currentScores[setName][outputName]
                    bestScore, bestCommit = scores[outputName]
                    if current == bestScore:
                        highlight = "+best+"
                    else:
                        highlight = "(best: %s @ commit: %s)" % (bestScore, bestCommit)

                    print(" |->%s: %s %s" % (outputName, current, highlight))
        else:
            print("==>rec: ggplot2, nothing to show yet")

        sys.stdout.flush()

    def __len__(self):
        """returns the number of commits performed"""
        return self.length
0 commit comments