diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..bb33fe1 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,7 +6,7 @@ import numpy as np -def summary(model, input_size, batch_size=-1, device="cuda"): +def summary(model, input_size, batch_size=-1, device="cuda", force_dtype=None): def register_hook(module): @@ -51,6 +51,8 @@ def hook(module, input, output): dtype = torch.cuda.FloatTensor else: dtype = torch.FloatTensor + if force_dtype: + dtype = force_dtype # multiple inputs to the network if isinstance(input_size, tuple):