Convert a PyTorch training function to automatic mixed precision (AMP)

I have this function. Could you help me convert it to AMP? Thanks so much!

def step(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None,
            target_lpsoft=None, target_image_full=None, target_weak_params=None):
        """Run one full-precision training step on a source batch and a target batch.

        The step back-propagates two losses before a single optimizer update:
        a cross-entropy loss on the source batch, and a target loss made of a
        pseudo-label cross-entropy term (optionally combined with ``self.rce``
        and ``self.regular_loss``) plus an optional prototype-consistency
        KL term.

        Note: ``ema_out`` is only assigned when ``self.opt.ema`` is true, but it
        is read when ``opt.proto_rectify``, ``opt.proto_consistW > 0`` or
        ``opt.moving_prototype`` is set, so those options require ``opt.ema``.

        Args:
            source_x: Source input fed to ``self.BaseNet_DP``; the prediction is
                upsampled to its spatial size (``source_x.size()[2:]``).
            source_label: Cross-entropy target for the source prediction.
            target_x: Target input used for the forward pass when
                ``opt.S_pseudo <= 0``.
            target_imageS: Target input used for the forward pass when
                ``opt.S_pseudo > 0``, and for the consistency branch otherwise
                (presumably the strongly augmented view -- confirm with caller).
            target_params: Passed through to ``self.label_strong_T``.
            target_lp: Hard pseudo labels, used when ``opt.proto_rectify`` is off.
            target_lpsoft: Soft pseudo labels, used when ``opt.proto_rectify`` is on.
            target_image_full: Input to the EMA network ``self.BaseNet_ema_DP``.
            target_weak_params: Passed through to ``self.get_prototype_weight``
                and ``self.full2weak``.

        Returns:
            Tuple ``(loss, loss_CTS, loss_consist)`` as Python numbers obtained
            with ``.item()``. ``loss`` does not include the source loss.
        """
        # Supervised source branch: its backward runs immediately, before the
        # target forward pass is built.
        source_out = self.BaseNet_DP(source_x, ssl=True)
        source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)

        loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)
        loss_GTA.backward()

        # Pseudo labels are brought down to 1/4 resolution.
        if self.opt.proto_rectify:
            threshold_arg = F.interpolate(target_lpsoft, scale_factor=0.25, mode='bilinear', align_corners=True)
        else:
            threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()

        if self.opt.ema:
            ema_input = target_image_full
            # The EMA network is evaluated without building an autograd graph.
            with torch.no_grad():
                ema_out = self.BaseNet_ema_DP(ema_input)
            ema_out['feat'] = F.interpolate(ema_out['feat'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)
            ema_out['out'] = F.interpolate(ema_out['out'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)

        target_out = self.BaseNet_DP(target_imageS) if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_x)
        target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        target_out['feat'] = F.interpolate(target_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

        loss = torch.Tensor([0]).to(self.default_gpu)
        batch, _, w, h = threshold_arg.shape
        if self.opt.proto_rectify:
            # Re-weight the soft labels, take the per-pixel argmax as the hard
            # label, and overwrite labels whose normalised maximum is below
            # opt.train_thred with 250 (presumably the ignore index of
            # cross_entropy2d -- confirm).
            weights = self.get_prototype_weight(ema_out['feat'], target_weak_params=target_weak_params)
            rectified = weights * threshold_arg
            threshold_arg = rectified.max(1, keepdim=True)[1]
            rectified = rectified / rectified.sum(1, keepdim=True)
            argmax = rectified.max(1, keepdim=True)[0]
            threshold_arg[argmax < self.opt.train_thred] = 250
        if self.opt.S_pseudo > 0:
            # The original also transformed an undefined `cluster_arg` here and
            # never used the result; that statement was removed.
            threshold_arg = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)

        loss_CTS = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]))

        if self.opt.rce:
            rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone())
            loss_CTS = self.opt.rce_alpha * loss_CTS + self.opt.rce_beta * rce

        if self.opt.regular_w > 0:
            regular_loss = self.regular_loss(target_out['out'])
            loss_CTS = loss_CTS + regular_loss * self.opt.regular_w

        loss_consist = torch.Tensor([0]).to(self.default_gpu)
        if self.opt.proto_consistW > 0:
            ema2weak_feat = self.full2weak(ema_out['feat'], target_weak_params)         #N*256*H*W
            ema2weak_feat_proto_distance = self.feat_prototype_distance(ema2weak_feat)  #N*19*H*W
            ema2strong_feat_proto_distance = self.label_strong_T(ema2weak_feat_proto_distance, target_params, padding=250, scale=4)
            # Positions equal to the padding value 250 are excluded from the loss.
            mask = (ema2strong_feat_proto_distance != 250).float()
            teacher = F.softmax(-ema2strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

            targetS_out = target_out if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_imageS)
            targetS_out['out'] = F.interpolate(targetS_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
            targetS_out['feat'] = F.interpolate(targetS_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

            prototype_tmp = self.objective_vectors.expand(4, -1, -1)  #gpu memory limitation
            strong_feat_proto_distance = self.feat_prototype_distance_DP(targetS_out['feat'], prototype_tmp, self.class_numbers)
            student = F.log_softmax(-strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

            # Masked mean of the element-wise KL divergence.
            loss_consist = F.kl_div(student, teacher, reduction='none')
            loss_consist = (loss_consist * mask).sum() / mask.sum()
            loss = loss + self.opt.proto_consistW * loss_consist

        loss = loss + loss_CTS
        # Gradients of this loss accumulate on top of those from loss_GTA.
        loss.backward()
        self.BaseOpti.step()
        self.BaseOpti.zero_grad()

        if self.opt.moving_prototype: #update prototype
            ema_vectors, ema_ids = self.calculate_mean_vector(ema_out['feat'].detach(), ema_out['out'].detach())
            for t in range(len(ema_ids)):
                self.update_objective_SingleVector(ema_ids[t], ema_vectors[t].detach(), start_mean=False)

        if self.opt.ema: #update ema model
            # Parameters follow an exponential moving average with factor 0.999;
            # buffers are copied over unchanged.
            for param_q, param_k in zip(self.BaseNet.parameters(), self.BaseNet_ema.parameters()):
                param_k.data = param_k.data.clone() * 0.999 + param_q.data.clone() * (1. - 0.999)
            for buffer_q, buffer_k in zip(self.BaseNet.buffers(), self.BaseNet_ema.buffers()):
                buffer_k.data = buffer_q.data.clone()

        return loss.item(), loss_CTS.item(), loss_consist.item()

The automatic mixed-precision tutorials give you some examples of how to add the amp utilities to the “standard” training routine. 🙂

Thanks for answering. But I have a problem because this training function has two loss functions and also applies other transformations.

Working with multiple models, losses, and optimizers is also explained in the tutorial a bit further down.
What kind of transformations are you concerned about?

def step(self, source_x, source_label, target_x, target_imageS=None, target_params=None, target_lp=None,
            target_lpsoft=None, target_image_full=None, target_weak_params=None):
        """Run one mixed-precision (AMP) training step on a source and a target batch.

        Forward passes and loss computations run inside
        ``torch.cuda.amp.autocast()``; each backward pass is issued on the
        scaled loss outside the autocast region. The source loss is
        back-propagated right after its own forward pass so that its graph is
        not kept alive during the target forward pass. Because the two losses
        come from separate forward passes in separate autocast regions,
        ``retain_graph=True`` is not needed.

        Note: ``ema_out`` is only assigned when ``self.opt.ema`` is true, but it
        is read when ``opt.proto_rectify``, ``opt.proto_consistW > 0`` or
        ``opt.moving_prototype`` is set, so those options require ``opt.ema``.

        Args:
            source_x: Source input fed to ``self.BaseNet_DP``; the prediction is
                upsampled to its spatial size (``source_x.size()[2:]``).
            source_label: Cross-entropy target for the source prediction.
            target_x: Target input used for the forward pass when
                ``opt.S_pseudo <= 0``.
            target_imageS: Target input used for the forward pass when
                ``opt.S_pseudo > 0``, and for the consistency branch otherwise
                (presumably the strongly augmented view -- confirm with caller).
            target_params: Passed through to ``self.label_strong_T``.
            target_lp: Hard pseudo labels, used when ``opt.proto_rectify`` is off.
            target_lpsoft: Soft pseudo labels, used when ``opt.proto_rectify`` is on.
            target_image_full: Input to the EMA network ``self.BaseNet_ema_DP``.
            target_weak_params: Passed through to ``self.get_prototype_weight``
                and ``self.full2weak``.

        Returns:
            Tuple ``(loss, loss_CTS, loss_consist)`` as Python numbers obtained
            with ``.item()``. ``loss`` does not include the source loss.
        """
        # NOTE(review): `scaler` is not defined in this function; presumably a
        # torch.cuda.amp.GradScaler created elsewhere -- confirm it is in scope.
        self.BaseOpti.zero_grad()

        # Supervised source branch.
        with torch.cuda.amp.autocast():
            source_out = self.BaseNet_DP(source_x, ssl=True)
            source_outputUp = F.interpolate(source_out['out'], size=source_x.size()[2:], mode='bilinear', align_corners=True)

            loss_GTA = cross_entropy2d(input=source_outputUp, target=source_label)

        # Backward is called outside autocast, before the target graph is built.
        scaler.scale(loss_GTA).backward()

        with torch.cuda.amp.autocast():
            # Pseudo labels are brought down to 1/4 resolution.
            if self.opt.proto_rectify:
                threshold_arg = F.interpolate(target_lpsoft, scale_factor=0.25, mode='bilinear', align_corners=True)
            else:
                threshold_arg = F.interpolate(target_lp.unsqueeze(1).float(), scale_factor=0.25).long()

            if self.opt.ema:
                ema_input = target_image_full
                # The EMA network is evaluated without building an autograd graph.
                with torch.no_grad():
                    ema_out = self.BaseNet_ema_DP(ema_input)
                ema_out['feat'] = F.interpolate(ema_out['feat'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)
                ema_out['out'] = F.interpolate(ema_out['out'], size=(int(ema_input.shape[2]/4), int(ema_input.shape[3]/4)), mode='bilinear', align_corners=True)

            target_out = self.BaseNet_DP(target_imageS) if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_x)
            target_out['out'] = F.interpolate(target_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
            target_out['feat'] = F.interpolate(target_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

            loss = torch.Tensor([0]).to(self.default_gpu)
            batch, _, w, h = threshold_arg.shape
            if self.opt.proto_rectify:
                # Re-weight the soft labels, take the per-pixel argmax as the
                # hard label, and overwrite labels whose normalised maximum is
                # below opt.train_thred with 250 (presumably the ignore index
                # of cross_entropy2d -- confirm).
                weights = self.get_prototype_weight(ema_out['feat'], target_weak_params=target_weak_params)
                rectified = weights * threshold_arg
                threshold_arg = rectified.max(1, keepdim=True)[1]
                rectified = rectified / rectified.sum(1, keepdim=True)
                argmax = rectified.max(1, keepdim=True)[0]
                threshold_arg[argmax < self.opt.train_thred] = 250
            if self.opt.S_pseudo > 0:
                # The original also transformed an undefined `cluster_arg` here
                # and never used the result; that statement was removed.
                threshold_arg = self.label_strong_T(threshold_arg.clone().float(), target_params, padding=250, scale=4).to(torch.int64)

            loss_CTS = cross_entropy2d(input=target_out['out'], target=threshold_arg.reshape([batch, w, h]))

            if self.opt.rce:
                rce = self.rce(target_out['out'], threshold_arg.reshape([batch, w, h]).clone())
                loss_CTS = self.opt.rce_alpha * loss_CTS + self.opt.rce_beta * rce

            if self.opt.regular_w > 0:
                regular_loss = self.regular_loss(target_out['out'])
                loss_CTS = loss_CTS + regular_loss * self.opt.regular_w

            loss_consist = torch.Tensor([0]).to(self.default_gpu)
            if self.opt.proto_consistW > 0:
                ema2weak_feat = self.full2weak(ema_out['feat'], target_weak_params)         #N*256*H*W
                ema2weak_feat_proto_distance = self.feat_prototype_distance(ema2weak_feat)  #N*19*H*W
                ema2strong_feat_proto_distance = self.label_strong_T(ema2weak_feat_proto_distance, target_params, padding=250, scale=4)
                # Positions equal to the padding value 250 are excluded from the loss.
                mask = (ema2strong_feat_proto_distance != 250).float()
                teacher = F.softmax(-ema2strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

                targetS_out = target_out if self.opt.S_pseudo > 0 else self.BaseNet_DP(target_imageS)
                targetS_out['out'] = F.interpolate(targetS_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
                targetS_out['feat'] = F.interpolate(targetS_out['feat'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)

                prototype_tmp = self.objective_vectors.expand(4, -1, -1)  #gpu memory limitation
                strong_feat_proto_distance = self.feat_prototype_distance_DP(targetS_out['feat'], prototype_tmp, self.class_numbers)
                student = F.log_softmax(-strong_feat_proto_distance * self.opt.proto_temperature, dim=1)

                # Masked mean of the element-wise KL divergence.
                loss_consist = F.kl_div(student, teacher, reduction='none')
                loss_consist = (loss_consist * mask).sum() / mask.sum()
                loss = loss + self.opt.proto_consistW * loss_consist

            loss = loss + loss_CTS

        # Both losses are scaled by the same scaler before the single
        # step/update, so their gradients accumulate with a common scale.
        scaler.scale(loss).backward()
        scaler.step(self.BaseOpti)
        scaler.update()

        if self.opt.moving_prototype: #update prototype
            ema_vectors, ema_ids = self.calculate_mean_vector(ema_out['feat'].detach(), ema_out['out'].detach())
            for t in range(len(ema_ids)):
                self.update_objective_SingleVector(ema_ids[t], ema_vectors[t].detach(), start_mean=False)

        if self.opt.ema: #update ema model
            # Parameters follow an exponential moving average with factor 0.999;
            # buffers are copied over unchanged.
            for param_q, param_k in zip(self.BaseNet.parameters(), self.BaseNet_ema.parameters()):
                param_k.data = param_k.data.clone() * 0.999 + param_q.data.clone() * (1. - 0.999)
            for buffer_q, buffer_k in zip(self.BaseNet.buffers(), self.BaseNet_ema.buffers()):
                buffer_k.data = buffer_q.data.clone()

        return loss.item(), loss_CTS.item(), loss_consist.item()

I converted it to this. Is this the right way to do it?

The usage of the amp utilities looks alright.
Based on the code it seems that you explicitly want to use the float data type in some operations, so you could also use nested autocast statements and disable the casting for these operations.

1 Like