From 5bd5fa13894eeb90e78dc56e32830a992cbf9d7c Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Thu, 1 Dec 2022 17:51:31 +0100 Subject: [PATCH 1/2] fix direct_returnn_layer_call() for multiple bases --- pytorch_to_returnn/torch/nn/modules/module.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_to_returnn/torch/nn/modules/module.py b/pytorch_to_returnn/torch/nn/modules/module.py index 1f9e920..06e8fba 100644 --- a/pytorch_to_returnn/torch/nn/modules/module.py +++ b/pytorch_to_returnn/torch/nn/modules/module.py @@ -526,15 +526,16 @@ def direct_returnn_layer_call(cls) -> bool: """ if not cls.has_torch_forward(): return True - base = cls - while base is not object: + if cls is object: + return True + for base in cls.__bases__: + if not issubclass(base, Module): + continue if cls.create_returnn_layer_dict != base.create_returnn_layer_dict: return True elif cls.forward != base.forward: return False - assert len(base.__bases__) == 1, "Not implemented otherwise" - base = base.__bases__[0] - return True + return any(base.direct_returnn_layer_call() for base in cls.__bases__ if issubclass(base, Module)) def check_returnn_layer(self, layer: LayerBase): """ From 89ef0e88d99f871f640e51fcbb3f90a496eb30b1 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Wed, 7 Dec 2022 16:46:34 +0100 Subject: [PATCH 2/2] queue instead of recursion --- pytorch_to_returnn/torch/nn/modules/module.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_to_returnn/torch/nn/modules/module.py b/pytorch_to_returnn/torch/nn/modules/module.py index 06e8fba..312dc24 100644 --- a/pytorch_to_returnn/torch/nn/modules/module.py +++ b/pytorch_to_returnn/torch/nn/modules/module.py @@ -526,16 +526,25 @@ def direct_returnn_layer_call(cls) -> bool: """ if not cls.has_torch_forward(): return True - if cls is object: - return True - for base in cls.__bases__: + + queue = [cls] + visited = set() + while len(queue) > 0: + base = queue.pop(0) + if base in visited: + continue + visited.add(base) + + if cls is object: + return True if not issubclass(base, Module): continue if cls.create_returnn_layer_dict != base.create_returnn_layer_dict: return True elif cls.forward != base.forward: return False - return any(base.direct_returnn_layer_call() for base in cls.__bases__ if issubclass(base, Module)) + queue += [base for base in cls.__bases__ if issubclass(base, Module)] + return True def check_returnn_layer(self, layer: LayerBase): """