I know that I need PyTorch >= 1.12.0 to use Tensor.scatter_reduce with reduce='amax'. However, for various reasons, I can't use any PyTorch version newer than 1.10.1 in my development environment. Is it possible to implement the same behavior as tensor.scatter_reduce(reduce='amax') under PyTorch 1.10.1?
You might be able to use torch_scatter instead of the built-in utilities for your use case.
Thanks for your reply. I tried your suggestion a few hours ago, but ran into errors. I use this function in every ViT encoder block. Strangely, when I debug block by block and test each block's output individually, it produces the correct result without any error, but when I run all blocks together, I get the following error:
Traceback (most recent call last):
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/trainers/base_trainer.py", line 85, in train
self.train_epoch()
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/trainers/ltr_trainer.py", line 133, in train_epoch
self.cycle_dataset(loader)
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/trainers/ltr_trainer.py", line 86, in cycle_dataset
loss, stats = self.actor(data)
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/actors/ostrack.py", line 31, in __call__
out_dict = self.forward_pass(data)
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/train/actors/ostrack.py", line 73, in forward_pass
return_last_attn=False)
File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/ostrack.py", line 50, in forward
return_last_attn=return_last_attn, )
File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/base_backbone.py", line 154, in forward
x, aux_dict = self.forward_features(z, x,)
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/vit_tome.py", line 85, in forward_features
x, source, size = blk(x, global_index_t, source_=source, size_=size)
File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/layers/attn_blocks.py", line 175, in forward
merge, x_s, source_
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/tome/merge.py", line 248, in merge_source
source = merge(source, mode="amax")
File "/data/new/SOT/OSTrack_new/lib/train/../../lib/models/ostrack/tome/merge.py", line 82, in merge
dst, _ = scatter(src, dst_idx.expand(n, r, c), dim=-2, out=dst, reduce='max')
File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch_scatter/scatter.py", line 160, in scatter
return scatter_max(src, index, dim, out, dim_size)[0]
File "/opt/conda/envs/lwh/lib/python3.7/site-packages/torch_scatter/scatter.py", line 72, in scatter_max
return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)
RuntimeError: Some elements marked as dirty during the forward method were not returned as output. The inputs that are modified inplace must all be outputs of the Function.
The relevant code is as follows:
from torch_scatter import scatter
...
elif mode == 'amax':
    dst = scatter(src, dst_idx.expand(n, r, c), dim=-2, out=dst, reduce='max')
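For the original question (emulating scatter_reduce with reduce='amax' on PyTorch 1.10.1 without torch_scatter), here is a minimal sketch that uses only built-in ops available in 1.10. It assumes 3-D floating-point tensors reduced along dim=-2 as in the snippet above, with dst of shape (n, t, c) and an int64 index of the same shape as src, (n, r, c). The function name scatter_amax is hypothetical. It builds a dense one-hot mask, so memory grows as n*r*c*t. That is fine for modest token counts but may be costly for large ones. Because it uses only out-of-place operations (no out= argument and no custom autograd Function), it should not raise the "marked as dirty" error shown above. However, this has not been checked against the model in the traceback.

```python
import torch
import torch.nn.functional as F


def scatter_amax(dst: torch.Tensor, index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out-of-place max-scatter along dim=-2.

    Intended to match dst.scatter_reduce(-2, index, src, reduce="amax")
    (include_self=True) from PyTorch >= 1.12.
    Assumes dst: (n, t, c) float; index, src: (n, r, c); index is int64 in [0, t).
    """
    t = dst.shape[-2]
    # mask[b, i, j, k] is True where src[b, i, j] targets dst[b, k, j]
    mask = F.one_hot(index, num_classes=t).bool()  # (n, r, c, t)
    neg_inf = torch.tensor(float("-inf"), dtype=src.dtype, device=src.device)
    # Keep each src value only in the slot it targets; all other slots get -inf
    candidates = torch.where(mask, src.unsqueeze(-1), neg_inf)  # (n, r, c, t)
    # Reduce over the r source rows, then move the t axis back to dim=-2
    per_slot_max = candidates.amax(dim=1).transpose(1, 2)  # (n, t, c)
    # include_self=True: also compare with the existing dst values.
    # Slots that received nothing are -inf here, so they keep their dst value.
    return torch.maximum(dst, per_slot_max)
```

In the merge code above, the call would then become dst = scatter_amax(dst, dst_idx.expand(n, r, c), src), assuming dst_idx is an int64 tensor of shape (n, r, 1).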