Could I get some help, please? I am stuck on a problem where the model cannot be trained: every run fails during backpropagation with the error below. I have tried several approaches but still cannot resolve it.
Traceback (most recent call last):
  File "tools/train.py", line 196, in <module>
    main()
  File "tools/train.py", line 185, in main
    train_model(
  File "/home/boot/STU/workspaces/zjf/ViTPose/mmpose/apis/train.py", line 200, in train_model
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/epoch_based_runner.py", line 51, in train
    self.call_hook('after_train_iter')
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/base_runner.py", line 307, in call_hook
    getattr(hook, fn_name)(self)
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/hooks/optimizer.py", line 36, in after_train_iter
    runner.outputs['loss'].backward()
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 144, 28, 28]], which is output 0 of PermuteBackward, is at version 3; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I then used torch.autograd.set_detect_anomaly(True) to track down the forward operation that causes the failure.
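In case it matters, this is roughly how I enabled it (a minimal sketch; I simply turned anomaly detection on before the training call in tools/train.py, so the exact placement in my script may differ):

    import torch

    # Report the forward-pass operation responsible for a failing backward pass
    torch.autograd.set_detect_anomaly(True)

With anomaly detection enabled, the following warning and forward-call traceback are printed: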
[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnConvolutionBackward. Traceback of forward call that caused the error:
  File "tools/train.py", line 196, in <module>
    main()
  File "tools/train.py", line 185, in main
    train_model(
  File "/home/boot/STU/workspaces/zjf/ViTPose/mmpose/apis/train.py", line 200, in train_model
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/epoch_based_runner.py", line 29, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/parallel/distributed.py", line 53, in train_step
    output = self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/boot/STU/workspaces/zjf/ViTPose/mmpose/models/detectors/base.py", line 104, in train_step
    losses = self.forward(**data_batch)
  File "/home/boot/STU/workspaces/zjf/mmcv/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/home/boot/STU/workspaces/zjf/ViTPose/mmpose/models/detectors/top_down.py", line 138, in forward
    return self.forward_train(img, target, target_weight, img_metas,
  File "/home/boot/STU/workspaces/zjf/ViTPose/mmpose/models/detectors/top_down.py", line 145, in forward_train
    output = self.backbone(img)
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/boot/STU/workspaces/zjf/ViTPose/models/dilateformer.py", line 452, in forward
    x = self.forward_features(x)
  File "/home/boot/STU/workspaces/zjf/ViTPose/models/dilateformer.py", line 441, in forward_features
    x = stage(x)
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/boot/STU/workspaces/zjf/ViTPose/models/dilateformer.py", line 331, in forward
    x = blk(x)
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/boot/STU/workspaces/zjf/ViTPose/models/dilateformer.py", line 138, in forward
    x = x + self.drop_path(self.attn(self.norm1(x)))
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/boot/STU/workspaces/zjf/ViTPose/models/dilateformer.py", line 90, in forward
    qkv = self.qkv(x)
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/boot/anaconda3/envs/vitpose/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 395, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
 (function _print_stack)
According to this output, the error is traced to "/home/boot/STU/workspaces/zjf/ViTPose/models/dilateformer.py", line 90, in forward, which is the line "qkv = self.qkv(x)". However, I have not been able to fix the error in the following code:
class MultiDilatelocalAttention(nn.Module):
    "Implementation of Dilate-attention"

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., kernel_size=3, dilation=[1, 2, 3]):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.scale = qk_scale or head_dim ** -0.5
        self.num_dilation = len(dilation)
        assert num_heads % self.num_dilation == 0, \
            f"num_heads {num_heads} must be a multiple of num_dilation {self.num_dilation}!!"
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.dilate_attention = nn.ModuleList(
            [DilateAttention(head_dim, qk_scale, attn_drop, kernel_size, dilation[i])
             for i in range(self.num_dilation)])
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        """
        Only here does the gradient disappear
        """
        print('x before reshape:', x._version)
        qkv = self.qkv(x).reshape(B, 3, self.num_dilation, C // self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)
        # num_dilation, 3, B, C//num_dilation, H, W
        print('qkv:', qkv._version)
        x = x.reshape(B, self.num_dilation, C // self.num_dilation, H, W).permute(1, 0, 3, 4, 2)
        print('x after reshape:', x._version)
        # num_dilation, B, H, W, C//num_dilation
        for i in range(self.num_dilation):
            x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])  # B, H, W, C//num_dilation
        print('x after qkv:', x._version)
        x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
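My current guess is that the in-place assignment x[i] = ... inside the loop is the culprit: reshape and permute return views, so x is still backed by the same storage as the tensor that was passed to self.qkv, and writing into x[i] changes a value that F.conv2d saved for its backward pass (which would match the PermuteBackward message and the rising x._version). Below is a minimal sketch of the workaround I am considering, assuming each DilateAttention block returns a tensor of shape (B, H, W, C//num_dilation) and that torch is imported as in dilateformer.py: collect the per-branch outputs in a list and stack them instead of writing into x in place. I have not confirmed this is the correct fix, so any advice is appreciated.

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        qkv = self.qkv(x).reshape(B, 3, self.num_dilation, C // self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)
        # num_dilation, 3, B, C//num_dilation, H, W

        # Gather each dilation branch's output in a new list instead of assigning
        # into a view of x, so the tensor saved by the convolution is never
        # modified in place.
        outs = [self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])
                for i in range(self.num_dilation)]  # each: B, H, W, C//num_dilation
        x = torch.stack(outs, dim=0)  # num_dilation, B, H, W, C//num_dilation
        x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

Does this look like the right direction, or is the real problem elsewhere?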