Derivative for aten::scatter_ is not implemented

I am having problems backpropagating (loss.backward()) when my model uses the aten::scatter_ function to compute the loss.

First, I define my model. In the forward function, I use the aten::scatter_ function to compute the product across individuals with the same id.

import torch
import torch.nn as nn
import torch.optim as optim
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.beta = nn.Linear(1, 1, bias=False)

    def forward(self, x, id):
        b = self.beta(torch.ones(1, 1))
        xb = x*b
        N = torch.unique(id).shape[0]
        scattering = torch.ones(N, 1, dtype=x.dtype)
        # Here I am computing the product across rows with the same id
        scattering_res = scattering.scatter_(0, id-1, xb, reduce='multiply')
        loss = torch.sum(scattering_res)
        return loss

Here I create some tensors to apply the model to and reproduce the error message I am getting:

id = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5], dtype=torch.int64).reshape(10, 1)
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=torch.float).reshape(10, 1)
y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=torch.float).reshape(10, 1)
net = Net() 
optimizer = optim.SGD(net.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Update params
optimizer.zero_grad()
loss = net(x, id)
loss.backward()
## RuntimeError: derivative for aten::scatter_ is not implemented

Do you know how to solve this issue? Additionally, I came across a GitHub issue (Derivative issue when using scatter_max · Issue #63 · rusty1s/pytorch_scatter · GitHub) where a similar problem was posted, but I couldn't adapt it to my case.

Below you can see the whole traceback error message.

RuntimeError                              Traceback (most recent call last)
Untitled-2 in <cell line: 32>()
     32 optimizer.zero_grad()
     33 loss = net(x, id)
---> 34 loss.backward()

File c:\Users\u0133260\Anaconda3\envs\pyt\lib\site-packages\torch\_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    387 if has_torch_function_unary(self):
    388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
    394         create_graph=create_graph,
    395         inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File c:\Users\u0133260\Anaconda3\envs\pyt\lib\site-packages\torch\autograd\__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
    171 # some Python versions print out the first line of a multi-line function
    172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: derivative for aten::scatter_ is not implemented

Cross-posted at python - PyTorch - derivative for aten::scatter_ is not implemented - Stack Overflow

It’s the in-place variant that does not have a backward. If you really need to scatter onto a “background” of ones, you could use torch.scatter (the out-of-place version) on xb - 1 and then add 1 (`scattering_res += 1`); that should work.
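In case it helps, here is a minimal sketch of that suggestion applied to the forward above (an illustration, not a tested drop-in; the names follow the original post):

def forward(self, x, id):
    b = self.beta(torch.ones(1, 1))
    xb = x * b
    N = torch.unique(id).shape[0]
    background = torch.zeros(N, 1, dtype=x.dtype)
    # torch.scatter (out-of-place) has a backward; scattering xb - 1 onto a
    # zero background and adding 1 afterwards yields xb at the indexed rows
    # and 1 everywhere else, i.e. the "background" of ones
    scattering_res = torch.scatter(background, 0, id - 1, xb - 1)
    scattering_res = scattering_res + 1
    return torch.sum(scattering_res)

One caveat: a plain (replace) scatter keeps only one source value per target index, so with several rows per id, as in the example data, this does not reproduce reduce='multiply'. If the per-id product itself is what you need and all entries are positive, one differentiable alternative is to scatter-add the logarithms and exponentiate, e.g. torch.exp(torch.scatter_add(torch.zeros(N, 1, dtype=x.dtype), 0, id - 1, torch.log(xb))), since torch.scatter_add does have an implemented derivative.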

Best regards