Missing methods to easily access reset_state and states within keras.Model #61

Open
@mergian

System information.

TensorFlow version (you are using): 2.11.0
Are you willing to contribute it (Yes/No) : yes

Describe the feature and the current behavior/state.

Stateful RNN layers provide layer.reset_state(), and their states can be read through layer.states. However, when a model consists of many RNN layers mixed with other layers, looping through all layers to reset them manually becomes tedious. The current workaround looks something like this:

for l in model.layers:
	if hasattr(l, 'reset_state'):
		l.reset_state()

This becomes really cumbersome when you use bidirectional RNNs, because you then also need to check whether the layer has l.forward_layer and l.backward_layer and reset their states as well.

Therefore, my proposal is to add reset_state, get_states and set_states to keras.Model. The last two would work similarly to get_weights() and set_weights(). A possible implementation could be:

def reset_state(self):
	def reset_state(l):
		if hasattr(l, 'reset_state'):
			l.reset_state()
		if hasattr(l, 'forward_layer'):
			reset_state(l.forward_layer)
		if hasattr(l, 'backward_layer'):
			reset_state(l.backward_layer)
			
	for l in self.layers:
		reset_state(l)
			
def get_states(self):
	states = []
	def get_states(l):
		if hasattr(l, 'states'):
			# extend() mutates the outer list; `states += ...` would rebind it locally
			states.extend(l.states)
		if hasattr(l, 'forward_layer'):
			get_states(l.forward_layer)
		if hasattr(l, 'backward_layer'):
			get_states(l.backward_layer)

	for l in self.layers:
		get_states(l)

	return states
	
def set_states(self, states):
	it = iter(states)
	def set_states(l):
		if hasattr(l, 'states'):
			for s in l.states:
				s.assign(next(it))
		if hasattr(l, 'forward_layer'):
			set_states(l.forward_layer)
		if hasattr(l, 'backward_layer'):
			set_states(l.backward_layer)
	
	for l in self.layers:
		set_states(l)
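
The three methods above repeat the same recursive walk over the layers. One way to reduce the duplication might be a single generator that yields every layer, descending into forward_layer and backward_layer, with the three operations built on top of it. The following is only a minimal sketch that relies on duck typing, so it runs without TensorFlow. The names iter_layers, FakeVar, FakeRNN, FakeBidirectional, FakeDense and FakeModel are hypothetical and introduced for illustration; they are not part of Keras.

```python
def iter_layers(layers):
    """Yield each layer, then any wrapped forward/backward layers (depth-first)."""
    for layer in layers:
        yield layer
        for attr in ("forward_layer", "backward_layer"):
            inner = getattr(layer, attr, None)
            if inner is not None:
                yield from iter_layers([inner])


def reset_state(model):
    """Reset every layer that exposes reset_state()."""
    for layer in iter_layers(model.layers):
        if hasattr(layer, "reset_state"):
            layer.reset_state()


def get_states(model):
    """Collect the state objects of all layers, in traversal order."""
    states = []
    for layer in iter_layers(model.layers):
        # Layers without a `states` attribute contribute nothing.
        states.extend(getattr(layer, "states", None) or [])
    return states


def set_states(model, values):
    """Assign values to the states in the same order get_states() returns them."""
    it = iter(values)
    for state in get_states(model):
        state.assign(next(it))


# Hypothetical stand-ins that mimic only the attributes used above.
class FakeVar:
    def __init__(self, value):
        self.value = value

    def assign(self, value):
        self.value = value


class FakeRNN:
    def __init__(self):
        self.states = [FakeVar(1.0)]

    def reset_state(self):
        for s in self.states:
            s.assign(0.0)


class FakeBidirectional:
    def __init__(self):
        self.forward_layer = FakeRNN()
        self.backward_layer = FakeRNN()


class FakeDense:
    pass


class FakeModel:
    def __init__(self):
        self.layers = [FakeRNN(), FakeDense(), FakeBidirectional()]
```

The traversal order matches the proposal above (the layer itself, then its forward layer, then its backward layer), so a list passed to set_states lines up with the list returned by get_states.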

Will this change the current api? How?
Yes, it adds the methods reset_state, get_states and set_states to class keras.Model, so people don't need to loop through the Keras data structures.

Who will benefit from this feature?
All people that use stateful RNN layers ;)

Contributing

  • Do you want to contribute a PR? (yes/no): yes, if my employer is willing to sign the CLA
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing): see above
