Implement total variation loss in PyTorch

Hello all, this is the definition of total variation loss:

So, I implemented it as follows:

    def compute_total_variation_loss(img, weight):      
        tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
        tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()    
        return weight * (tv_h + tv_w)
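Written out, the function above computes the following, assuming img has shape (B, C, H, W), x_{b,c,i,j} is one element of it, and w is the weight argument. This is the squared-difference form of total variation. Some definitions use absolute differences |·| instead of squares, so whether the code is "correct" depends on which definition you are matching.

```latex
\mathcal{L}_{TV}(x) = w \left(
  \sum_{b,c} \sum_{i=1}^{H-1} \sum_{j=1}^{W} \left( x_{b,c,i+1,j} - x_{b,c,i,j} \right)^2
  + \sum_{b,c} \sum_{i=1}^{H} \sum_{j=1}^{W-1} \left( x_{b,c,i,j+1} - x_{b,c,i,j} \right)^2
\right)
```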

Is it correct? What does img[:,:,1:,:] - img[:,:,:-1,:] mean?
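A small worked example may make the slicing clearer. It uses a hand-made 3x3 single-channel "image" (the values are arbitrary, chosen only for illustration). img[:,:,1:,:] drops the first row and img[:,:,:-1,:] drops the last row, so the two slices have the same shape and subtracting them pairs every pixel with the pixel directly above it. The width version does the same with left/right neighbours.

```python
import torch

# Shape (batch, channel, height, width) = (1, 1, 3, 3); values are illustrative only
img = torch.tensor([[[[1., 2., 4.],
                      [3., 5., 9.],
                      [6., 7., 8.]]]])

# Rows 1..H-1 minus rows 0..H-2: each pixel minus the pixel above it
dh = img[:, :, 1:, :] - img[:, :, :-1, :]
# Columns 1..W-1 minus columns 0..W-2: each pixel minus the pixel to its left
dw = img[:, :, :, 1:] - img[:, :, :, :-1]

print(dh.shape)            # torch.Size([1, 1, 2, 3])
print(dw.shape)            # torch.Size([1, 1, 3, 2])
print(dh[0, 0].tolist())   # [[2.0, 3.0, 5.0], [3.0, 2.0, -1.0]]
print(dw[0, 0].tolist())   # [[1.0, 2.0], [2.0, 4.0], [1.0, 1.0]]
```

Squaring and summing these differences gives tv_h and tv_w, so the loss is zero only when neighbouring pixels are equal.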

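    import torch

    def total_variation_loss(img, weight):
        bs_img, c_img, h_img, w_img = img.size()
        # Squared differences between vertically and horizontally adjacent pixels
        tv_h = torch.pow(img[:,:,1:,:]-img[:,:,:-1,:], 2).sum()
        tv_w = torch.pow(img[:,:,:,1:]-img[:,:,:,:-1], 2).sum()
        # Normalise by the total number of elements in the batch
        return weight*(tv_h+tv_w)/(bs_img*c_img*h_img*w_img)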
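The only difference between the two versions is the division by bs*c*h*w. Below is a minimal sketch that checks this on random data; the function names tv_loss_sum and tv_loss_normalized are hypothetical, chosen for this comparison only.

```python
import torch

def tv_loss_sum(img, weight):
    # Un-normalised version from the question: sum of squared neighbour differences
    tv_h = (img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2).sum()
    tv_w = (img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2).sum()
    return weight * (tv_h + tv_w)

def tv_loss_normalized(img, weight):
    # Version from the reply: the same sum divided by the number of elements in img
    bs, c, h, w = img.size()
    return tv_loss_sum(img, weight) / (bs * c * h * w)

torch.manual_seed(0)
img = torch.rand(2, 3, 8, 8, requires_grad=True)

a = tv_loss_sum(img, 1.0)
b = tv_loss_normalized(img, 1.0)
# The two losses differ only by the constant factor img.numel()
print(torch.allclose(a / img.numel(), b))  # True

# A constant image has no variation, so the loss is exactly zero
flat = torch.full((1, 1, 4, 4), 0.5)
print(tv_loss_sum(flat, 1.0).item())  # 0.0

# Both versions are differentiable and can be added to another loss
b.backward()
print(img.grad.shape)  # torch.Size([2, 3, 8, 8])
```

Dividing by the element count keeps the loss on a similar scale across image and batch sizes, so weight does not have to be retuned when those change. Note that it is not an exact mean of the difference terms, since there are c*(h-1)*w vertical and c*h*(w-1) horizontal differences per image rather than c*h*w.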

which is not equal to img[:,:,1:,:] - img[:,:,:-1,:].
