I am working on adding a loss item rot_loss
on existed loss in my programe. However, the error of modified by inplace operation is occured. I checked the whole programme, but still could not fix it. Please help with this bug. Thank you very very much.
The code for loss is
def train(self, args, logger=None):
"""
Train function of FixMatch.
From data_loader, it inference training data, computes losses, and update the networks.
"""
ngpus_per_node = torch.cuda.device_count()
#lb: labeled, ulb: unlabeled
self.train_model.train()
# for gpu profiling
start_batch = torch.cuda.Event(enable_timing=True)
end_batch = torch.cuda.Event(enable_timing=True)
start_run = torch.cuda.Event(enable_timing=True)
end_run = torch.cuda.Event(enable_timing=True)
start_batch.record()
best_eval_acc, best_it = 0.0, 0
scaler = GradScaler()
amp_cm = autocast if args.amp else contextlib.nullcontext
p_target=[]
p_target_idx=0
for _,x_lb, y_lb in self.loader_dict['train_lb']:
p_target_idx += 1
p_target.append(one_hot(y_lb.cuda(args.gpu),args.num_classes,args.gpu).mean(dim=0))
if p_target_idx * args.batch_size > args.batch_size * args.num_labels : break #batch size could be bigger than num labels, args.batch_size for stable estimation
p_target = torch.stack(p_target).mean(dim=0)
print('p_target:',p_target)
p_model_list=[]
for (_,x_lb, y_lb), (_, x_ulb_w, x_ulb_s1, x_ulb_s2, x_ulb_s1_rot,rot_v, _) in zip(self.loader_dict['train_lb'], self.loader_dict['train_ulb']):
# prevent the training iterations exceed args.num_train_iter
if self.it > args.num_train_iter:
break
end_batch.record()
torch.cuda.synchronize()
start_run.record()
num_lb = x_lb.shape[0]
num_ulb = x_ulb_w.shape[0]
assert num_ulb == x_ulb_s1.shape[0]
x_lb, x_ulb_w, x_ulb_s1, x_ulb_s2, x_ulb_s1_rot = x_lb.cuda(args.gpu), x_ulb_w.cuda(args.gpu), x_ulb_s1.cuda(args.gpu),x_ulb_s2.cuda(args.gpu),x_ulb_s1_rot.cuda(args.gpu)
rot_v = rot_v.cuda(args.gpu)
y_lb = y_lb.cuda(args.gpu)
inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s1, x_ulb_s2))
# inference and calculate sup/unsup losses
with amp_cm():
logits,embeds = self.train_model(inputs)
logits_x_lb = logits[:num_lb]
logits_x_ulb_w, logits_x_ulb_s1,logits_x_ulb_s2 = logits[num_lb:].chunk(3)
del logits
logits_rot = self.rot_classifier(x_ulb_s1_rot)
rot_loss = ce_loss(logits_rot, rot_v, reduction='mean')
rot_loss = (args.ulb_loss_ratio / 2) * rot_loss
# hyper-params for update
T = self.t_fn(self.it)
prob_x_ulb = torch.softmax(logits_x_ulb_w,dim=1)
if len(p_model_list) < 128:
p_model_list.append(prob_x_ulb.mean(dim=0).detach())
else:
p_model_list.pop(0)
p_model_list.append(prob_x_ulb.mean(dim=0).detach())
p_model = torch.stack(p_model_list).mean(dim=0)
prob_x_ulb = prob_x_ulb * p_target / p_model
prob_x_ulb = (prob_x_ulb / prob_x_ulb.sum(dim=-1,keepdim=True))
sharpen_prob_x_ulb = prob_x_ulb ** (1/T)
sharpen_prob_x_ulb = (sharpen_prob_x_ulb / sharpen_prob_x_ulb.sum(dim=-1,keepdim=True)).detach()
mixed_inputs = torch.cat((x_lb, x_ulb_s1, x_ulb_s2))
input_labels = torch.cat([one_hot(y_lb,args.num_classes,args.gpu), sharpen_prob_x_ulb, sharpen_prob_x_ulb], dim=0)
mixed_x, mixed_y,_ = mixup_one_target(mixed_inputs, input_labels,
args.gpu,
args.alpha,
is_bias=True)
mixed_logits,_ = self.train_model(mixed_x)
#
sup_loss = -torch.mean(torch.sum(mixed_y[:num_lb]* F.log_softmax(mixed_logits[:num_lb],dim=1), dim=1))
unsup_loss = (args.ulb_loss_ratio/2)*consistency_loss(mixed_logits[num_lb:], mixed_y[num_lb:])\
+ (args.ulb_loss_ratio/2)*rot_loss
total_loss = sup_loss + self.lambda_u * unsup_loss
# parameter updates
if args.amp:
scaler.scale(total_loss).backward()
scaler.step(self.optimizer)
scaler.update()
else:
total_loss.backward()
self.optimizer.step()
self.scheduler.step()
self.train_model.zero_grad()
with torch.no_grad():
self._eval_model_update()
end_run.record()
torch.cuda.synchronize()
#tensorboard_dict update
tb_dict = {}
tb_dict['train/sup_loss'] = sup_loss.detach()
tb_dict['train/unsup_loss'] = unsup_loss.detach()
tb_dict['train/total_loss'] = total_loss.detach()
tb_dict['lr'] = self.optimizer.param_groups[0]['lr']
tb_dict['train/prefecth_time'] = start_batch.elapsed_time(end_batch)/1000.
tb_dict['train/run_time'] = start_run.elapsed_time(end_run)/1000.
if self.it % self.num_eval_iter == 0:
eval_dict = self.evaluate(args=args)
tb_dict.update(eval_dict)
save_path = os.path.join(args.save_dir, args.save_name)
if tb_dict['eval/top-1-acc'] > best_eval_acc:
best_eval_acc = tb_dict['eval/top-1-acc']
best_it = self.it
self.print_fn(f"{self.it} iteration, USE_EMA: {hasattr(self, 'eval_model')}, {tb_dict}, BEST_EVAL_ACC: {best_eval_acc}, at {best_it} iters")
if not args.multiprocessing_distributed or \
(args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
if self.it == best_it:
self.save_model('model_best.pth', save_path)
if not self.tb_log is None:
self.tb_log.update(tb_dict, self.it)
self.it +=1
del tb_dict
start_batch.record()
if self.it > 2**19:
self.num_eval_iter = 1000
eval_dict = self.evaluate(args=args)
eval_dict.update({'eval/best_acc': best_eval_acc, 'eval/best_it': best_it})
return eval_dict
The model defination of self.train_model
and self.rot_classifier
is :
class WideResNet(nn.Module):
def __init__(self, depth, num_classes, widen_factor=1, bn_momentum=0.1, leaky_slope=0.0, dropRate=0.0, use_embed=False, is_remix=False):
super(WideResNet, self).__init__()
nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
assert ((depth - 4) % 6 == 0)
n = (depth - 4) // 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, bn_momentum, leaky_slope, dropRate)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, bn_momentum, leaky_slope, dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, bn_momentum, leaky_slope, dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=bn_momentum)
self.relu = nn.LeakyReLU(negative_slope=leaky_slope, inplace=False)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]
self.use_embed = use_embed
# rot_classifier for Remix Match
if is_remix:
self.rot_classifier = nn.Linear(self.nChannels, 4)
# init bias
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x, ood_test=False):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
output = self.fc(out)
if ood_test:
return output, out
else:
if self.use_embed:
return output, out
else:
return output
def rot_classify(self, rot_embeds):
out = self.conv1(rot_embeds)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
output = self.rot_classifier(out)
return output
Here is the TraceBack of the error:
Traceback (most recent call last):
File "/home/lr/wuhao/anaconda3/envs/ssl/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/home/lr/wuhao/ssl-consistency-pytorch/remixmatch.py", line 220, in main_worker
trainer(args, logger=logger)
File "/home/lr/wuhao/ssl-consistency-pytorch/models/remixmatch/remixmatch.py", line 189, in train
scaler.scale(total_loss).backward()
File "/home/lr/wuhao/anaconda3/envs/ssl/lib/python3.6/site-packages/torch/tensor.py", line 245, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/lr/wuhao/anaconda3/envs/ssl/lib/python3.6/site-packages/torch/autograd/__init__.py", line 147, in backward
allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128]] is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).