44
55from .shapes import *
66
7+ import pdb
8+
79LAYER_DESCRIPTORS = {
810
911 # Caffe Types
@@ -103,6 +105,7 @@ class LayerAdapter(object):
103105 def __init__ (self , layer , kind ):
104106 self .layer = layer
105107 self .kind = kind
108+ self ._input_shape = None
106109
107110 @property
108111 def parameters (self ):
@@ -114,7 +117,7 @@ def parameters(self):
114117 raise NodeDispatchError ('Caffe parameters not found for layer kind: %s' % (self .kind ))
115118
116119 @staticmethod
117- def get_kernel_value (scalar , repeated , idx , default = None ):
120+ def get_kernel_value (scalar , repeated , idx , default = None , params = None ):
118121 if scalar :
119122 return scalar
120123 if repeated :
@@ -127,15 +130,26 @@ def get_kernel_value(scalar, repeated, idx, default=None):
127130 # Extract the value for the given spatial dimension
128131 return repeated [idx ]
129132 if default is None :
133+ #pdb.set_trace()
130134 raise ValueError ('Unable to determine kernel parameter!' )
131135 return default
132136
137+ def set_input_shape (self , input_shape ):
138+ self ._input_shape = input_shape
139+
133140 @property
134141 def kernel_parameters (self ):
135142 assert self .kind in (NodeKind .Convolution , NodeKind .Pooling )
136143 params = self .parameters
137- k_h = self .get_kernel_value (params .kernel_h , params .kernel_size , 0 )
138- k_w = self .get_kernel_value (params .kernel_w , params .kernel_size , 1 )
144+ global_pool = hasattr (params , 'global_pooling' )
145+ if params .kernel_size :
146+ k_h = self .get_kernel_value (params .kernel_h , params .kernel_size , 0 )
147+ k_w = self .get_kernel_value (params .kernel_w , params .kernel_size , 1 )
148+ elif self ._input_shape :
149+ k_h , k_w = [self ._input_shape .height , self ._input_shape .width ]
150+ else : #errors out in get_kernel_value function
151+ k_h = self .get_kernel_value (params .kernel_h , params .kernel_size , 0 )
152+ k_w = self .get_kernel_value (params .kernel_w , params .kernel_size , 1 )
139153 s_h = self .get_kernel_value (params .stride_h , params .stride , 0 , default = 1 )
140154 s_w = self .get_kernel_value (params .stride_w , params .stride , 1 , default = 1 )
141155 p_h = self .get_kernel_value (params .pad_h , params .pad , 0 , default = 0 )
0 commit comments