diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..6b92925 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -33,6 +33,14 @@ def hook(module, input, output): summary[m_key]["output_shape"] = [ [-1] + list(o.size())[1:] for o in output ] + elif isinstance(output,OrderedDict): + keys = list(output.keys()) + if 'out' in keys: + key='out' + else: + key = keys[-1] + summary[m_key]["output_shape"] = list(output[key].size()) + summary[m_key]["output_shape"][0] = batch_size else: summary[m_key]["output_shape"] = list(output.size()) summary[m_key]["output_shape"][0] = batch_size