Hello all, this is the definition of total variation loss
So, I implemented it as follows:
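def compute_total_variation_loss(img, weight):
    # Squared differences between vertically adjacent pixels (dim 2 = height)
    tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
    # Squared differences between horizontally adjacent pixels (dim 3 = width)
    tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()
    return weight * (tv_h + tv_w)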
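One way to sanity-check the vectorized version is to compare it with a direct loop over neighbouring pixels. The sketch below assumes img is a 4-D tensor laid out as (N, C, H, W), which is what the indexing implies. tv_loop is a hypothetical reference helper written only for this check and is not part of PyTorch. The check confirms that the code computes the sum of squared neighbour differences. It does not compare the code against any particular published formula.

```python
import torch

def compute_total_variation_loss(img, weight):
    # Squared differences between vertically adjacent pixels (dim 2 = height)
    tv_h = ((img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2)).sum()
    # Squared differences between horizontally adjacent pixels (dim 3 = width)
    tv_w = ((img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2)).sum()
    return weight * (tv_h + tv_w)

def tv_loop(img, weight):
    # Hypothetical reference: visit every pixel and add the squared
    # difference to its lower and right neighbours, where they exist.
    n, c, h, w = img.shape
    total = 0.0
    for b in range(n):
        for ch in range(c):
            for i in range(h):
                for j in range(w):
                    if i + 1 < h:
                        total += (img[b, ch, i + 1, j] - img[b, ch, i, j]).item() ** 2
                    if j + 1 < w:
                        total += (img[b, ch, i, j + 1] - img[b, ch, i, j]).item() ** 2
    return weight * total

torch.manual_seed(0)
img = torch.rand(2, 3, 5, 4, dtype=torch.float64)
fast = compute_total_variation_loss(img, weight=0.5).item()
slow = tv_loop(img, weight=0.5)
print(abs(fast - slow) < 1e-9)  # True
```

The loop is far too slow for real images. It is useful only as a readable reference for what the slicing version computes.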
Is it correct? And what does img[:,:,1:,:] - img[:,:,:-1,:] mean?
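On the slicing question: along one dimension of size H, 1: keeps indices 1 to H-1 and :-1 keeps indices 0 to H-2. Both slices therefore have H-1 entries, and subtracting them pairs each row with the row just above it. The same pattern on the last dimension pairs each column with the column to its left. The small sketch below illustrates this with an arbitrary single-channel 3x3 image and assumes the (N, C, H, W) layout.

```python
import torch

# A single 3x3 image shaped (N=1, C=1, H=3, W=3)
img = torch.tensor([[[[1., 2., 4.],
                      [3., 5., 9.],
                      [6., 7., 8.]]]])

# Rows 1..2 minus rows 0..1: each row minus the row above it
dh = img[:, :, 1:, :] - img[:, :, :-1, :]
# Columns 1..2 minus columns 0..1: each column minus the column to its left
dw = img[:, :, :, 1:] - img[:, :, :, :-1]

print(dh.shape)  # torch.Size([1, 1, 2, 3])
print(dw.shape)  # torch.Size([1, 1, 3, 2])
print(dh[0, 0])  # holds [[2, 3, 5], [3, 2, -1]]
print(dw[0, 0])  # holds [[1, 2], [2, 4], [1, 1]]
```

Squaring these difference tensors and summing them gives tv_h and tv_w in the function.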