About device attribute of Tensor

Hi all,
I got this error: RuntimeError: get_device is not implemented for tensors with CPU backend

I have written a custom loss function (a KL loss). I declared self.kl_loss_ch in the __init__ function, but unlike the other tensors in the forward pass it does not have the device='cuda:0' attribute. How can I move this tensor onto the corresponding GPU?

I also created self.kl_loss_bs to store intermediate results. How do I add it to the computation graph?
Thanks!


import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F


class SegmentationMultiLosses(nn.CrossEntropyLoss):
    """2D Cross Entropy Loss with Multi-Loss"""
    def __init__(self, nclass=-1, weight=None, size_average=True, ignore_index=-1):
        super(SegmentationMultiLosses, self).__init__(weight, size_average, ignore_index)
        self.nclass = nclass
        self.epsilon = 1e-6
        self.dt_min = 0
        self.dt_max = 16
        self.kl_loss_bs = torch.Tensor([0])
        self.kl_loss_ch = torch.Tensor([0])


    def forward(self, *inputs):
        # out: [b, 5, h, w],  [b, 4], target: [b, h, w], [b, 4]
        out, target_img, target_exist, target_dt = tuple(inputs)
        #out_img, out_exist, out_dt_img= out
        out_img, out_dt_img, out_exist = out
        #out_img, target_img = tuple(inputs) #train deeplab
        #pdb.set_trace()
        loss_seg = super(SegmentationMultiLosses, self).forward(out_img, target_img)
        loss_exist = nn.BCELoss().forward(out_exist, target_exist)
        loss_dt = nn.MSELoss().forward(out_dt_img, target_dt)

        # kl loss: input and target
        out_dt_copy = out_dt_img.clone().detach()
        out_dt_norm = out_dt_copy.clamp(self.dt_min, self.dt_max) / self.dt_max

        out_seg_sm  = F.softmax(out_img, dim=1)
        seg_prob = out_seg_sm[:,1:,:,:]

        # kl loss
        batch_num = seg_prob.shape[0]
        channel_num = seg_prob.shape[1]
        for b in range(batch_num):
            for c in range(channel_num):
                lane_prob_dt = torch.clamp(out_dt_norm[b,c,:,:], min=self.epsilon) # p
                bg_prob_dt = 1 - out_dt_norm[b,c,:,:] # 1-p
                bg_prob_dt = torch.clamp(bg_prob_dt, min=self.epsilon)
                lane_prob_seg = torch.clamp(seg_prob[b,c,:,:], min=self.epsilon)
                bg_prob_seg = 1 - seg_prob[b,c,:,:]
                bg_prob_seg = torch.clamp(bg_prob_seg, min=self.epsilon)
                kl_loss_map = torch.mean((bg_prob_dt * torch.log(bg_prob_dt / bg_prob_seg).float()) + (lane_prob_dt * torch.log(lane_prob_dt / lane_prob_seg).float()))
                self.kl_loss_ch = torch.add(self.kl_loss_ch, kl_loss_map)
            self.kl_loss_bs = torch.add(self.kl_loss_bs, torch.div(self.kl_loss_ch, channel_num))
        pdb.set_trace()
        loss_fuse = torch.div(self.kl_loss_bs.float(), batch_num)

        print('loss_seg: {}  loss_exist: {}  loss_dt: {}  loss_fuse: {}'.format(loss_seg.item(), loss_exist.item()*0.1, loss_dt.item(), loss_fuse.item() * 10))
        loss = loss_seg + loss_dt + 0.1 * loss_exist + 10 * loss_fuse

        return loss


self.criterion = SegmentationMultiLosses(nclass=self.nclass, weight=torch.Tensor([0.4, 1, 1, 1, 1])).cuda()


Debugging with pdb shows:

(Pdb) self.kl_loss_bs.requires_grad
True
(Pdb) self.kl_loss_bs
tensor([8.1037], grad_fn=<AddBackward0>)
(Pdb) kl_loss_map
tensor(0.0588, device='cuda:0', grad_fn=<MeanBackward1>)

If kl_loss_bs and kl_loss_ch should be trainable parameters, you should wrap them in nn.Parameter:

class SegmentationMultiLosses(nn.CrossEntropyLoss):
    """2D Cross Entropy Loss with Multi-Loss"""
    def __init__(self, nclass=-1, weight=None, size_average=True, ignore_index=-1):
        super(SegmentationMultiLosses, self).__init__(weight, size_average, ignore_index)
        ...
        self.kl_loss_bs = nn.Parameter(torch.tensor([0.]))
        self.kl_loss_ch = nn.Parameter(torch.tensor([0.]))
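
(Note that nn.Parameter needs a floating-point tensor, since integer tensors cannot require gradients; hence torch.tensor([0.]) above.)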

Alternatively, if you don’t want to train them, you could register them as buffers:

class SegmentationMultiLosses(nn.CrossEntropyLoss):
    """2D Cross Entropy Loss with Multi-Loss"""
    def __init__(self, nclass=-1, weight=None, size_average=True, ignore_index=-1):
        super(SegmentationMultiLosses, self).__init__(weight, size_average, ignore_index)
        ...
        self.register_buffer('kl_loss_bs', torch.tensor([0.]))
        self.register_buffer('kl_loss_ch', torch.tensor([0.]))

Both approaches will make sure that these tensors are also pushed to the device when you call model.to(device) or model.cuda().
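
Here is a minimal sketch showing the device movement for both approaches (the DummyLoss module below is just for illustration, not your actual loss):

import torch
import torch.nn as nn

class DummyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # trainable accumulator: moves with the module and receives gradients
        self.kl_loss_bs = nn.Parameter(torch.tensor([0.]))
        # non-trainable accumulator: moves with the module, no gradients
        self.register_buffer('kl_loss_ch', torch.tensor([0.]))

criterion = DummyLoss().cuda()
print(criterion.kl_loss_bs.device)   # cuda:0
print(criterion.kl_loss_ch.device)   # cuda:0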

Also, the recommended way to create a tensor from a list or a value is torch.tensor (lowercase t). :wink:
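
For example, torch.Tensor is the legacy constructor and always returns float32, while torch.tensor infers the dtype from its input:

torch.Tensor([0])   # tensor([0.]), dtype=torch.float32
torch.tensor([0])   # tensor([0]),  dtype=torch.int64
torch.tensor([0.])  # tensor([0.]), dtype=torch.float32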