PyTorch: setting elements to zero with a "tensor index"

Hello PyTorch coders,

I’ve been using PyTorch for a few months, and I recently wanted to create a customized pooling layer similar to a “Max-Pooling Dropout” layer. I think PyTorch provides the tools I need to build such a layer. Here is my approach:

  1. use MaxPool2d with indices returned
  2. set tensor[indices] to zero
  3. I want it to behave like torch.take (without flattening) if possible.

Here is how I get the “index tensor” (I think that is what it is called; correct me if I’m wrong):

import torch
import torch.nn as nn

input1 = torch.randn(1, 1, 6, 6)
m = nn.MaxPool2d(2,2, return_indices=True)
val, indx = m(input1)

indx is the “index tensor”, which can be used as easily as

torch.take(input1, indx)

No flattening needed, no dimension argument needed. I think it makes sense, since indx is generated from input1.
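
As a quick check (a sketch; note it relies on this example having a single batch and a single channel, so MaxPool2d’s per-channel spatial indices happen to coincide with global flat indices), torch.take reproduces the pooled values:

import torch
import torch.nn as nn

input1 = torch.randn(1, 1, 6, 6)
m = nn.MaxPool2d(2,2, return_indices=True)
val, indx = m(input1)

# torch.take treats input1 as flattened; with a (1, 1, H, W) input the
# pooling indices are valid flat indices, so this recovers the pooled output
assert torch.equal(torch.take(input1, indx), val)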

Question: how do I set the values of input1 pointed to by indx to 0, in the “torch.take” style? I have seen answers like input1.flatten().scatter_(dim=-1, index=indx.flatten(), value=0.).reshape_as(input1), but I don’t think PyTorch would return an “index tensor” that cannot be applied directly. (Maybe I’m wrong.)

Is there something like

torch.set_value(input1, indx, 0) ?
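
As far as I can tell there is no torch.set_value; the closest I can sketch is a small helper (hypothetical name) built on an out-of-place scatter. It keeps torch.take’s flat-index semantics, so like the torch.take call above it only lines up with MaxPool2d’s indices for a single-batch, single-channel input:

import torch

def set_at_flat_indices(t, idx, value):
    # hypothetical helper: treat idx as flat indices (torch.take style) and
    # write `value` at those positions, returning a new tensor (out-of-place)
    return t.flatten().scatter(dim=0, index=idx.flatten(), value=value).reshape_as(t)

# e.g. zeroed = set_at_flat_indices(input1, indx, 0.)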

Update: if I use input1.flatten().scatter_(dim=-1, index=indx.flatten(), value=-5.).reshape_as(input1) on the GPU (torch version 1.12.1+cu113), I get the following in-place error:

Traceback (most recent call last):
  File "v1.1c.py", line 120, in <module>
    loss.backward()
  File "/opt/software/anaconda3/lib/python3.7/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/software/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 175, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 512, 6, 6]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
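
Digging a bit, this looks like the generic autograd version-counter check: flatten() returns a view, so the in-place scatter_ modifies a tensor that autograd had saved for the backward pass, the ReLU output in my case. A minimal sketch that triggers the same failure mode:

import torch

x = torch.randn(2, 2, requires_grad=True)
y = torch.relu(x)                                # autograd saves y for relu's backward
y.flatten().scatter_(0, torch.tensor([0]), -5.)  # in-place write through a view bumps y's version
y.sum().backward()                               # RuntimeError: ... modified by an inplace operation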

Does it work if you simply replace the inplace scatter_ with the out-of-place version?

import torch
from torch import nn

input1 = torch.randn(1, 1, 6, 6, requires_grad=True)
m = nn.MaxPool2d(2,2, return_indices=True)
val, indx = m(input1)
out = input1.flatten().scatter(dim=-1, index=indx.flatten(), value=-5.).reshape_as(input1)
print(out)
loss = out.sum()
loss.backward()

Thanks, eqy, for your solution! It does solve the GPU error. But I am still wondering if there is a nice way to use the “tensor index”.
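
One possibility (a sketch, assuming indx follows MaxPool2d’s documented convention of indexing into the flattened H*W positions of each (batch, channel) slice, the same layout nn.MaxUnpool2d consumes) is to flatten only the spatial dimensions, so the index tensor lines up with dim=-1 and the batch and channel dims are left alone:

import torch
import torch.nn as nn

input1 = torch.randn(8, 512, 6, 6, requires_grad=True)
m = nn.MaxPool2d(2,2, return_indices=True)
val, indx = m(input1)

# flatten H and W only, so the per-channel spatial indices in indx match dim=-1
out = (input1.flatten(start_dim=2)
             .scatter(dim=-1, index=indx.flatten(start_dim=2), value=0.)
             .reshape_as(input1))

out.sum().backward()   # out-of-place scatter, so no version-counter error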