-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnetFactory.hpp
More file actions
152 lines (125 loc) · 4.2 KB
/
netFactory.hpp
File metadata and controls
152 lines (125 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
/**
* @file netFactory.hpp
* @brief I'm not a fan of factories, but here's one - it makes
* a network of the appropriate type which conforms to an example
* set, and can also load and save networks of any type. NetFactory
* is a class used purely as a namespace.
*
*/
#ifndef __NETFACTORY_HPP
#define __NETFACTORY_HPP
#include <cstdint>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <vector>

#include "bpnet.hpp"
#include "obnet.hpp"
#include "hinet.hpp"
#include "uesnet.hpp"
/**
 * \brief
 * This class - really a namespace - contains functions which create,
 * load or save networks of all types.
 *
 * Errors are reported by throwing a heap-allocated std::runtime_error,
 * so callers catch `std::runtime_error *` and own (must delete) the
 * pointer. NOTE(review): throwing by value would be more idiomatic,
 * but existing callers depend on the pointer type, so it is kept.
 */
class NetFactory { // not a namespace because Doxygen gets confused.
public:
    /**
     * \brief
     * Construct a single hidden layer network of a given type
     * which conforms to the example set.
     * \param t      the type of network to build
     * \param e      example set supplying the input and output layer sizes
     * \param hnodes number of nodes in the single hidden layer
     * \return a new network owned by the caller, or nullptr if \p t is
     *         not a type this factory knows how to build
     */
    static Net *makeNet(NetType t,ExampleSet &e,int hnodes){
        int layers[3];
        layers[0] = e.getInputCount();
        layers[1] = hnodes;
        layers[2] = e.getOutputCount();
        return makeNet(t,3,layers);
    }

    /**
     * \brief
     * Construct a network of a given type with an arbitrary layer layout.
     * \param t          the type of network to build
     * \param layercount number of entries in \p layers
     * \param layers     node count of each layer, input layer first
     * \return a new network owned by the caller, or nullptr if \p t is
     *         not one of the NetType values handled below
     */
    static Net *makeNet(NetType t,int layercount, int *layers){
        switch(t){
        case NetType::PLAIN:
            return new BPNet(layercount,layers);
        case NetType::OUTPUTBLENDING:
            return new OutputBlendingNet(layercount,layers);
        case NetType::HINPUT:
            return new HInputNet(layercount,layers);
        case NetType::UESMANN:
            return new UESNet(layercount,layers);
        default:
            break;
        }
        // Unhandled type (e.g. a corrupt magic number read by load()):
        // return an explicit null rather than falling off the end of a
        // non-void function, which is undefined behaviour.
        return nullptr;
    }

    /**
     * \brief Load a network of any type from a file - note, endianness not checked!
     *
     * The file layout, as written by save(), is: a 32-bit type magic
     * number, a 32-bit layer count, one 32-bit size per layer, then the
     * network's parameter data as raw doubles.
     * \param fn path of the file to read
     * \return a new network owned by the caller
     * \throw std::runtime_error* if the file cannot be opened, is
     *        truncated, or names a network type makeNet() cannot build
     */
    inline static Net *load(const char *fn){
        FileHandle a(fopen(fn,"rb"),&fclose);
        if(!a)
            throw new std::runtime_error("cannot open file");
        // get type
        NetType t = static_cast<NetType>(readU32(a.get()));
        // Build the layer specification by reading the layer count and then
        // the layer sizes. The vector grows only as sizes are actually read,
        // so a corrupt (huge) count fails at end-of-file instead of forcing
        // a giant up-front allocation.
        uint32_t layercount = readU32(a.get());
        std::vector<int> layers;
        for(uint32_t i=0;i<layercount;i++)
            layers.push_back(static_cast<int>(readU32(a.get())));
        // Build the net. Holding it in a unique_ptr means it is freed if a
        // later read fails (assumes Net has a virtual destructor, which
        // callers deleting a returned Net* already rely on).
        std::unique_ptr<Net> n(makeNet(t,static_cast<int>(layers.size()),layers.data()));
        if(!n)
            throw new std::runtime_error("bad net save file");
        // get the parameter data and read it
        int size = n->getDataSize();
        std::vector<double> buf(size);
        if(size>0 && fread(buf.data(),sizeof(double),buf.size(),a.get())!=buf.size())
            throw new std::runtime_error("bad net save file");
        n->load(buf.data());
        // success: hand ownership to the caller
        return n.release();
    }

    /**
     * \brief Save a net of any type to a file - note, endianness not checked!
     *
     * Writes the layout that load() reads: type magic number, layer
     * count and layer sizes (all 32-bit), then the parameter data.
     * \param fn path of the file to create or overwrite
     * \param n  the network to save
     * \throw std::runtime_error* if the file cannot be opened or any
     *        write (including the final flush on close) fails
     */
    inline static void save(const char *fn,Net *n) {
        FileHandle a(fopen(fn,"wb"),&fclose);
        if(!a)
            throw new std::runtime_error("cannot open file");
        // get and write the magic number
        writeU32(a.get(),static_cast<uint32_t>(n->type));
        // write the layer count and layer sizes, all as 32-bit.
        uint32_t layercount = n->getLayerCount();
        writeU32(a.get(),layercount);
        for(uint32_t i=0;i<layercount;i++)
            writeU32(a.get(),static_cast<uint32_t>(n->getLayerSize(i)));
        // get the parameter data and write it
        int size = n->getDataSize();
        std::vector<double> buf(size);
        n->save(buf.data());
        if(size>0 && fwrite(buf.data(),sizeof(double),buf.size(),a.get())!=buf.size())
            throw new std::runtime_error("error writing net save file");
        // Close explicitly and check the result: buffered data is flushed
        // here, so this is where a write failure may first be reported.
        if(fclose(a.release())!=0)
            throw new std::runtime_error("error writing net save file");
    }

private:
    /// Owning stdio stream handle: fclose() runs automatically on scope
    /// exit, including when an exception is thrown part-way through.
    using FileHandle = std::unique_ptr<FILE,int(*)(FILE*)>;

    /// Read one native-endian 32-bit value, throwing if the file is truncated.
    static uint32_t readU32(FILE *f){
        uint32_t v;
        if(fread(&v,sizeof(uint32_t),1,f)!=1)
            throw new std::runtime_error("bad net save file");
        return v;
    }

    /// Write one native-endian 32-bit value, throwing on a short write.
    static void writeU32(FILE *f,uint32_t v){
        if(fwrite(&v,sizeof(uint32_t),1,f)!=1)
            throw new std::runtime_error("error writing net save file");
    }
};
#endif /* __NETFACTORY_HPP */