Is there any way to use Tensor.scatter_reduce_(reduce='amax') under PyTorch 1.10.1?

I know that Tensor.scatter_reduce with reduce='amax' requires PyTorch >= 1.12.0. However, I can't use anything newer than PyTorch 1.10.1 in my development environment. Is it possible to implement the same behavior as tensor.scatter_reduce(reduce='amax') under PyTorch 1.10.1?

You might be able to use torch_scatter instead of the built-in utilities for your use case.

Thanks for your reply. I tried your suggestion a few hours ago, but ran into an error. I call this function in every ViT encoder block. Strangely, when I debug block by block and test each block's output individually, the results are correct and no error occurs, but when I run all the blocks together I get the following error:

Traceback (most recent call last):
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/trainers/base_trainer.py", line 85, in train
    self.train_epoch()
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/trainers/ltr_trainer.py", line 133, in train_epoch
    self.cycle_dataset(loader)
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/trainers/ltr_trainer.py", line 86, in cycle_dataset
    loss, stats = self.actor(data)
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/actors/ostrack.py", line 31, in __call__
    out_dict = self.forward_pass(data)
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/actors/ostrack.py", line 73, in forward_pass
    return_last_attn=False)
  File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/ostrack.py", line 50, in forward
    return_last_attn=return_last_attn, )
  File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/base_backbone.py", line 154, in forward
    x, aux_dict = self.forward_features(z, x,)
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/vit_tome.py", line 85, in forward_features
    x, source, size = blk(x, global_index_t, source_=source, size_=size)
  File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/layers/attn_blocks.py", line 175, in forward
    merge, x_s, source_
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/tome/merge.py", line 248, in merge_source
    source = merge(source, mode="amax")
  File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/tome/merge.py", line 82, in merge
    dst, _ = scatter(src, dst_idx.expand(n, r, c), dim=-2, out=dst, reduce='max')
  File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch_scatter/scatter.py", line 160, in scatter
    return scatter_max(src, index, dim, out, dim_size)[0]
  File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch_scatter/scatter.py", line 72, in scatter_max
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)
RuntimeError: Some elements marked as dirty during the forward method were not returned as output. The inputs that are modified inplace must all be outputs of the Function.

The relevant code is as follows:

from torch_scatter import scatter

...

elif mode == 'amax':
    dst = scatter(src, dst_idx.expand(n, r, c), dim=-2, out=dst, reduce='max')
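
The error message refers to inputs modified in place, and the call above passes out=dst, so one thing worth trying is a version that never writes into dst. Below is a minimal sketch in plain PyTorch, not the torch_scatter implementation. It only uses torch.where, Tensor.amax and torch.maximum, which should all be available in 1.10.1. The helper name scatter_amax is made up for illustration, and the sketch assumes floating-point tensors of shape (n, t, c) for dst and (n, r, c) for src and index, scattering along dim=-2. Existing values in dst take part in the max, which as far as I know matches the default behavior of scatter_reduce with reduce='amax' in newer releases.

```python
import torch


def scatter_amax(dst, index, src):
    """Out-of-place scatter-max along dim=-2 (hypothetical helper).

    dst:   (n, t, c) float tensor, current destination values
    index: (n, r, c) int64 tensor, target row in dst for every src element
    src:   (n, r, c) float tensor, values to scatter
    Returns a new (n, t, c) tensor; dst itself is not modified.
    """
    t = dst.shape[-2]
    # positions[0, 0, k, 0] == k, used to compare against every index entry
    positions = torch.arange(t, device=index.device).view(1, 1, t, 1)
    # hits[b, i, k, j] is True where src[b, i, j] should land in row k
    hits = index.unsqueeze(-2) == positions  # (n, r, t, c)
    # -inf acts as the neutral element of max (assumes a floating dtype)
    neg_inf = torch.full((), float("-inf"), dtype=src.dtype, device=src.device)
    candidates = torch.where(hits, src.unsqueeze(-2), neg_inf)  # (n, r, t, c)
    # Reduce over the r source rows, then combine with the existing dst values
    scattered = candidates.amax(dim=1)  # (n, t, c)
    return torch.maximum(dst, scattered)
```

In the snippet above, the call would then look like dst = scatter_amax(dst, dst_idx.expand(n, r, c), src). Note that this builds an intermediate tensor of shape (n, r, t, c), so it trades memory for simplicity and may be too heavy for long token sequences.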