def total_variation_loss(img, weight):
    """Return the weighted, element-normalized total variation of ``img``.

    The loss is the sum of squared differences between vertically adjacent
    values (second-to-last axis) plus the sum of squared differences between
    horizontally adjacent values (last axis). That sum is multiplied by
    ``weight`` and divided by the total number of elements in ``img``.

    Note that the normalizer is ``img.numel()`` (for a 4-D tensor this is
    ``N * C * H * W``), not the number of difference terms, which is
    ``...*(H-1)*W`` for the vertical term and ``...*H*(W-1)`` for the
    horizontal one.

    Args:
        img: Tensor whose last two axes are treated as height and width,
            e.g. ``(N, C, H, W)``. Any number of leading axes is accepted.
        weight: Scalar multiplier applied to the summed variation.

    Returns:
        A 0-d tensor with the loss value.
    """
    # ``...`` keeps all leading axes so the same code serves (N, C, H, W),
    # (C, H, W), or (H, W) inputs.
    tv_h = torch.pow(img[..., 1:, :] - img[..., :-1, :], 2).sum()
    tv_w = torch.pow(img[..., :, 1:] - img[..., :, :-1], 2).sum()
    # For 4-D input, numel() == bs * c * h * w, matching the original normalizer.
    return weight * (tv_h + tv_w) / img.numel()
# (truncated fragment) ... which is not equal to `img[:, :, 1:, :] - img[:, :, :-1, :]`.