I have this function; please help me convert it to use AMP (automatic mixed precision). Thanks so much!
def step(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None,
         target_lpsoft=None, target_image_full=None, target_weak_params=None, scaler=None):
    """Run one source + target training step, optionally with mixed precision (AMP).

    Steps, in order:
      1. The source batch is trained with ``cross_entropy2d`` against
         ``source_label``, and backpropagated immediately.
      2. The target batch is trained against pseudo-labels (optionally
         prototype-rectified), plus optional RCE, regularisation and
         prototype-consistency terms.
      3. Gradients from both parts are accumulated and applied with a single
         ``self.BaseOpti`` step.
      4. The prototypes and the EMA model are then optionally updated.

    Mixed precision:
      - Create ``scaler = torch.cuda.amp.GradScaler()`` (or
        ``torch.amp.GradScaler('cuda')`` on recent PyTorch) once, outside the
        training loop. Pass it as ``scaler`` on every call.
      - Only the network forward passes run under autocast. Their outputs are
        cast to float32, so interpolation, losses, prototype distances and
        pseudo-label rectification all run in float32.
      - Backward and the optimizer step go through the scaler.
      - Do not wrap the call itself in ``autocast``.
      - With ``scaler=None`` (the default) the step runs exactly as before, in
        whatever precision the caller's context dictates.

    NOTE(review): the ``_DP`` suffix suggests ``nn.DataParallel`` wrappers —
    confirm. Some older PyTorch releases did not propagate autocast into
    DataParallel's worker threads. On those versions the forwards below would
    silently stay float32, and autocast has to be applied inside the model's
    ``forward`` instead.

    Args:
        source_x: source inputs for ``self.BaseNet_DP``.
        source_label: targets for the source ``cross_entropy2d`` loss.
        target_x: target inputs for the student forward when
            ``opt.S_pseudo <= 0``.
        target_imageS: target inputs (presumably strongly augmented) for the
            student forward when ``opt.S_pseudo > 0``, and for the consistency
            branch.
        target_params: parameters forwarded to ``self.label_strong_T``.
        target_lp: label map, downsampled by 4, used as pseudo-labels when
            ``opt.proto_rectify`` is off.
        target_lpsoft: per-class map, downsampled by 4, rectified with
            prototype weights when ``opt.proto_rectify`` is on.
        target_image_full: input to the EMA model ``self.BaseNet_ema_DP``.
        target_weak_params: parameters forwarded to
            ``get_prototype_weight`` / ``full2weak``.
        scaler: optional ``GradScaler``. AMP is used when it is given and
            enabled.

    Returns:
        ``(loss, loss_CTS, loss_consist)`` as Python floats.
        ``loss = loss_CTS + opt.proto_consistW * loss_consist``; the source
        loss is not included. The values are unscaled.

    Raises:
        ValueError: if ``opt.proto_rectify``, ``opt.proto_consistW > 0`` or
            ``opt.moving_prototype`` is set without ``opt.ema``. All of them
            read the EMA model's output.
    """
    import contextlib  # local import: null context for the non-AMP path

    use_amp = scaler is not None and scaler.is_enabled()

    def _autocast():
        # Autocast context for network forwards only. With AMP off, use a null
        # context so any caller-level autocast state is left untouched, as in
        # the original code.
        # NOTE(review): assumes the networks run on CUDA (self.default_gpu) — confirm.
        if not use_amp:
            return contextlib.nullcontext()
        if hasattr(torch, 'autocast'):  # current API (PyTorch >= 1.10)
            return torch.autocast(device_type='cuda')
        return torch.cuda.amp.autocast()  # older spelling

    def _backward(value):
        # Both backward calls in this step use the same scale factor: the
        # scaler only changes its scale in update(), which runs after step().
        (scaler.scale(value) if scaler is not None else value).backward()

    def _fp32_resize(out, size):
        # Cast logits/features to float32 *before* resizing, so everything
        # downstream of the network is computed in float32.
        out['out'] = F.interpolate(out['out'].float(), size=size, mode='bilinear', align_corners=True)
        out['feat'] = F.interpolate(out['feat'].float(), size=size, mode='bilinear', align_corners=True)
        return out

    needs_ema = self.opt.proto_rectify or self.opt.proto_consistW > 0 or self.opt.moving_prototype
    if needs_ema and not self.opt.ema:
        raise ValueError('opt.proto_rectify, opt.proto_consistW > 0 and opt.moving_prototype require opt.ema')

    # ---- Source: supervised loss, backpropagated right away so its graph is freed.
    with _autocast():
        source_out = self.BaseNet_DP(source_x, ssl=True)
    source_outputUp = F.interpolate(source_out['out'].float(), size=source_x.size()[2:],
                                    mode='bilinear', align_corners=True)
    loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)
    _backward(loss_GTA)

    # ---- Target pseudo-labels, downsampled by 4.
    if self.opt.proto_rectify:
        threshold_arg = F.interpolate(target_lpsoft, scale_factor=0.25, mode='bilinear', align_corners=True)
    else:
        threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()

    # ---- EMA (teacher) forward: no gradients.
    if self.opt.ema:
        ema_input = target_image_full
        with torch.no_grad(), _autocast():
            ema_out = self.BaseNet_ema_DP(ema_input)
        ema_out = _fp32_resize(ema_out, (ema_input.shape[2] // 4, ema_input.shape[3] // 4))

    # ---- Student forward on the target batch.
    with _autocast():
        target_out = self.BaseNet_DP(target_imageS) if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_x)
    target_out = _fp32_resize(target_out, threshold_arg.shape[2:])

    loss = torch.Tensor([0]).to(self.default_gpu)
    batch, _, height, width = threshold_arg.shape
    if self.opt.proto_rectify:
        # Reweight the soft labels by prototype weights, take the argmax, and
        # mark pixels whose normalised confidence is below train_thred with
        # 250. (250 is presumably cross_entropy2d's ignore index — confirm.)
        weights = self.get_prototype_weight(ema_out['feat'], target_weak_params=target_weak_params)
        rectified = weights * threshold_arg
        threshold_arg = rectified.max(1, keepdim=True)[1]
        rectified = rectified / rectified.sum(1, keepdim=True)
        argmax = rectified.max(1, keepdim=True)[0]
        threshold_arg[argmax < self.opt.train_thred] = 250
    if self.opt.S_pseudo > 0:
        # Transform the pseudo-labels with target_params (padding value 250).
        # The original also transformed an undefined `cluster_arg` here, which
        # raised NameError. Its result was never used, so it has been removed.
        threshold_arg = self.label_strong_T(threshold_arg.clone().float(), target_params,
                                            padding=250, scale=4).to(torch.int64)

    loss_CTS = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, height, width]))
    if self.opt.rce:
        rce = self.rce(target_out['out'], threshold_arg.reshape([batch, height, width]).clone())
        loss_CTS = self.opt.rce_alpha * loss_CTS + self.opt.rce_beta * rce
    if self.opt.regular_w > 0:
        regular_loss = self.regular_loss(target_out['out'])
        loss_CTS = loss_CTS + regular_loss * self.opt.regular_w

    loss_consist = torch.Tensor([0]).to(self.default_gpu)
    if self.opt.proto_consistW > 0:
        ema2weak_feat = self.full2weak(ema_out['feat'], target_weak_params)  # N*256*H*W
        ema2weak_feat_proto_distance = self.feat_prototype_distance(ema2weak_feat)  # N*19*H*W
        ema2strong_feat_proto_distance = self.label_strong_T(ema2weak_feat_proto_distance, target_params,
                                                             padding=250, scale=4)
        mask = (ema2strong_feat_proto_distance != 250).float()
        teacher = F.softmax(-ema2strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

        if self.opt.S_pseudo > 0:
            # Already computed from target_imageS and resized to this size above.
            targetS_out = target_out
        else:
            with _autocast():
                targetS_out = self.BaseNet_DP(target_imageS)
            targetS_out = _fp32_resize(targetS_out, threshold_arg.shape[2:])
        prototype_tmp = self.objective_vectors.expand(4, -1, -1)  # gpu memory limitation
        # Deliberately outside autocast: feature/prototype distances may
        # overflow float16, so the float32 features are kept here.
        strong_feat_proto_distance = self.feat_prototype_distance_DP(targetS_out['feat'], prototype_tmp,
                                                                     self.class_numbers)
        student = F.log_softmax(-strong_feat_proto_distance * self.opt.proto_temperature, dim=1)
        loss_consist = F.kl_div(student, teacher, reduction='none')
        loss_consist = (loss_consist * mask).sum() / mask.sum()
        loss = loss + self.opt.proto_consistW * loss_consist

    loss = loss + loss_CTS
    _backward(loss)
    if scaler is not None:
        # Unscales the gradients and skips the update if they contain inf/NaN.
        # update() then adjusts the scale for the next iteration.
        scaler.step(self.BaseOpti)
        scaler.update()
    else:
        self.BaseOpti.step()
    self.BaseOpti.zero_grad()

    if self.opt.moving_prototype:  # update prototype
        ema_vectors, ema_ids = self.calculate_mean_vector(ema_out['feat'].detach(), ema_out['out'].detach())
        for ema_id, ema_vector in zip(ema_ids, ema_vectors):
            self.update_objective_SingleVector(ema_id, ema_vector.detach(), start_mean=False)

    if self.opt.ema:  # update ema model
        decay = 0.999
        for param_q, param_k in zip(self.BaseNet.parameters(), self.BaseNet_ema.parameters()):
            param_k.data = param_k.data * decay + param_q.data * (1. - decay)
        for buffer_q, buffer_k in zip(self.BaseNet.buffers(), self.BaseNet_ema.buffers()):
            buffer_k.data = buffer_q.data.clone()

    return loss.item(), loss_CTS.item(), loss_consist.item()