-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathextract-weights.py
executable file
·65 lines (56 loc) · 1.71 KB
/
extract-weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
# [email protected] (Jason Riesa)
#
# Extract the weight vector saved for a particular training iteration
# from a weights file.
# Usage:
# ./extract-weights.py <weights_file> <iteration_number> <output_prefix>
# The vector is written to <output_prefix>.weights-<iteration_number>.
# For example, to extract the weights from the 7th epoch of training
# into training.weights-7:
# ./extract-weights.py training.weights 7 training
import svector
import sys
import cPickle
def load_iteration(wf, iteration):
    """Return the object pickled for 1-based training *iteration* in *wf*.

    *wf* must be a binary file object positioned at the start of a weights
    file, which is expected to hold one pickle per training iteration,
    written back to back.  The first ``iteration - 1`` pickles are read and
    discarded; the next one is returned.

    Raises whatever ``cPickle.load`` raises -- e.g. ``EOFError`` when the
    file holds fewer than *iteration* records.

    Security note: unpickling can execute arbitrary code, so only use this
    on weights files from a trusted source.
    """
    for _ in range(iteration - 1):
        cPickle.load(wf)
    return cPickle.load(wf)


def main(argv):
    """Extract one training iteration's weights into a new pickle file.

    argv -- the full command line: [program, weights_file, iteration,
            output_prefix].  ``iteration`` is 1-based; the result is written
            to ``<output_prefix>.weights-<iteration>`` using the highest
            pickle protocol.

    Returns the process exit status: 0 on success, 1 on any error (after a
    message has been written to stderr).
    """
    if len(argv) != 4:
        sys.stderr.write("Usage: %s %s %s %s\n" % (argv[0], "<weights>", "<iter>", "<output-name>"))
        return 1
    weights_path, iter_arg, output_prefix = argv[1], argv[2], argv[3]

    # Validate the iteration number before touching any files.  Values < 1
    # previously fell through and silently extracted iteration 1 under a
    # misleading output name (e.g. ".weights-0").
    try:
        iteration = int(iter_arg)
    except ValueError:
        sys.stderr.write("Argument <iter> must be an integer. Received: %s\n" % iter_arg)
        return 1
    if iteration < 1:
        sys.stderr.write("Argument <iter> must be a positive integer. Received: %s\n" % iter_arg)
        return 1

    # Pickles (especially HIGHEST_PROTOCOL ones) are binary data, so the
    # files must be opened in 'rb'/'wb' mode to avoid newline translation.
    try:
        wf = open(weights_path, 'rb')
    except (IOError, OSError):
        sys.stderr.write("Could not open weights file %s for reading.\n" % weights_path)
        return 1
    with wf:
        try:
            w = load_iteration(wf, iteration)
        except Exception:
            # cPickle.load can fail in many ways on a short or corrupt file
            # (EOFError, UnpicklingError, ValueError, ...).  Report them all,
            # but let KeyboardInterrupt/SystemExit propagate.
            sys.stderr.write("Could not read file %s at iteration %d.\n" % (weights_path, iteration))
            return 1

    # Create the output file only after the weights were read successfully,
    # so a failed run does not leave an empty output file behind.
    filename = output_prefix + '.weights-%d' % iteration
    try:
        out = open(filename, 'wb')
    except (IOError, OSError):
        sys.stderr.write("Could not open output file %s for writing.\n" % filename)
        return 1
    with out:
        sys.stderr.write("%d components\n" % len(w))
        cPickle.dump(w, out, protocol=cPickle.HIGHEST_PROTOCOL)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))