This is the code:

```python
import torch
from pytorch_block_sparse import BlockSparseLinear

x = torch.randn(32, 128).to('cuda')
y = torch.randn(32, 64).to('cuda')

model = torch.nn.Sequential(
    BlockSparseLinear(128, 64)
).to('cuda')

y_pred = model(x)
loss = torch.nn.MSELoss()(y_pred, y)
loss.backward()
```
I get this error on the last line:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-656a0631d816> in <module>
     11 y_pred = model(x)
     12 loss = torch.nn.MSELoss()(y_pred, y)
---> 13 loss.backward()

~/anaconda3/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219 retain_graph=retain_graph,
    220 create_graph=create_graph)
--> 221 torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222
    223 def register_hook(self, hook):

~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130 Variable._execution_engine.run_backward(
    131 tensors, grad_tensors_, retain_graph, create_graph,
--> 132 allow_unreachable=True)  # allow_unreachable flag
    133
    134

~/anaconda3/lib/python3.7/site-packages/torch/autograd/function.py in apply(self, *args)
     87 def apply(self, *args):
     88 # _forward_cls is defined by derived class
---> 89 return self._forward_cls.backward(self, *args)  # type: ignore
     90
     91

~/anaconda3/lib/python3.7/site-packages/pytorch_block_sparse/block_sparse_linear.py in backward(ctx, grad_output)
    112
    113 assert(not (grad_weight1 == 0).all())
--> 114 assert(grad_input1.shape == input.shape)
    115 return grad_input1, grad_weight1, None
    116

AttributeError: 'NoneType' object has no attribute 'shape'
```
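One plausible explanation (an assumption, not verified against the library's source): inside a custom `torch.autograd.Function`, `ctx.needs_input_grad[0]` is `False` when the first input does not require grad, and a backward that skips the input gradient in that case returns `None` for it. Here `x` is created with `torch.randn` and never has `requires_grad` set, so an unconditional `grad_input1.shape` check would fail exactly like this. The sketch below uses a hypothetical `ToyLinear` function (not `BlockSparseLinear`'s actual code) to show the pattern, with the shape check guarded:

```python
import torch


class ToyLinear(torch.autograd.Function):
    """Hypothetical dense stand-in for a custom linear op (not the library's code)."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()  # (batch, in) @ (in, out) -> (batch, out)

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        # Autograd only asks for gradients of inputs that require grad.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t() @ inp
        # Guarded sanity check: grad_input is None when inp does not require grad.
        if grad_input is not None:
            assert grad_input.shape == inp.shape
        return grad_input, grad_weight


x = torch.randn(32, 128)                      # like the issue: requires_grad not set
w = torch.randn(64, 128, requires_grad=True)
ToyLinear.apply(x, w).sum().backward()
print(x.grad)        # None: autograd never requested an input gradient
print(w.grad.shape)  # torch.Size([64, 128])
```

If `BlockSparseLinear`'s backward follows this pattern, a possible (untested) workaround is to create the input with `requires_grad=True`, e.g. `torch.randn(32, 128, device='cuda', requires_grad=True)`, at the cost of computing an input gradient that is never used.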
I also get this warning at the `model = ...` line:

```
.../pytorch_block_sparse/block_sparse.py:88: UserWarning: This overload of nonzero is deprecated:
    nonzero()
Consider using one of the following signatures instead:
    nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
  nnz = block_mask.nonzero()
```
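The warning appears to be unrelated to the crash: in this PyTorch version, calling `Tensor.nonzero()` with no arguments triggers the deprecation notice, and passing `as_tuple` explicitly, as the message itself suggests, selects one of the supported signatures. A minimal sketch with a toy `block_mask` (hypothetical values, not the library's actual mask):

```python
import torch

# Toy 2x2 block mask; True marks a nonzero block.
block_mask = torch.tensor([[True, False], [False, True]])

# Same result as the bare nonzero() call: an (N, 2) tensor of [row, col] indices.
nnz = block_mask.nonzero(as_tuple=False)
print(nnz.tolist())  # [[0, 0], [1, 1]]

# Alternative form: one 1-D index tensor per dimension.
rows, cols = block_mask.nonzero(as_tuple=True)
print(rows.tolist(), cols.tolist())  # [0, 1] [0, 1]
```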
Output of `pip freeze | grep torch`:

```
pytorch-block-sparse==0.1.2
torch==1.7.0+cu101
torchaudio==0.7.0
torchvision==0.8.1+cu101
```