I have to add a contrastive loss term in the backward pass of a convolution layer, as illustrated in the code below. What I don't know is how to compute grad_weight = grad_weight + cont_loss_weight so that cont_loss_weight has the same shape as grad_weight, and grad_bias = grad_bias + cont_loss_bias so that cont_loss_bias has the same shape as grad_bias.
class Custom_Convolution(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):  # input's shape = ([batch_size=100, 96, 8, 8])
        output = torch.nn.functional.conv2d(input, weight, bias, stride, padding)
        ctx.save_for_backward(input, weight, bias, output)
        ctx.stride, ctx.padding = stride, padding  # stashed so backward can pass them to the grad helpers
        return output  # output's shape = ([batch_size=100, 128, 4, 4])
    @staticmethod
    def backward(ctx, grad_output):  # grad_output's shape = ([batch_size=100, 128, 4, 4])
        input, weight, bias, output = ctx.saved_tensors  # input's shape = ([batch_size=100, 96, 8, 8])

        features = output
        features = features.view(features.shape[0], features.shape[1], -1)
        features = torch.nn.functional.normalize(features, dim=1)
        # Total_features = features.shape[0] * features.shape[1]

        cont_loss = torch.tensor([0.], device=output.device, requires_grad=True)  # shape: ([1])
        for ..... :
            # My code for the loss... includes some operations like torch.div, exp, sum...
            # Calculation of the loss Li for each feature 'i' out of Total_features
            # Total loss = sum of all Li (number of Li values = features.shape[0] * features.shape[1])
            cont_loss = cont_loss.add(cont_loss_img)  # total contrastive loss of the batch
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(
                input.shape, weight, grad_output, ctx.stride, ctx.padding)  # shape = ([batch_size, 96, 8, 8])
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(
                input, weight.shape, grad_output, ctx.stride, ctx.padding)  # shape = ([128, 96, 5, 5])
            # cont_loss_weight = torch.nn.grad.conv2d_weight(cont_loss, weight.shape, grad_output)
            #   ^ my attempt; it can't work because cont_loss is a 1-element tensor, not an input-shaped tensor
            # grad_weight = grad_weight + cont_loss_weight
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0, 2, 3))  # shape = ([128])
            # grad_bias = grad_bias + cont_loss_bias

        # backward must return one gradient per forward argument
        # (input, weight, bias, stride, padding), i.e. always five values:
        return grad_input, grad_weight, grad_bias, None, None
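For reference, here is the kind of approach I've been considering to get gradients of the right shapes. This is only a sketch, not something I'm sure is correct: my_cont_loss is a hypothetical stand-in for my loss loop above, and the snippet is written standalone rather than inside the Function. The idea is to rebuild the forward pass under torch.enable_grad() (needed because backward() runs with grad mode off) and let torch.autograd.grad produce cont_loss_weight and cont_loss_bias directly:

    import torch
    import torch.nn.functional as F

    # Hypothetical stand-in for the loss loop in backward(); only here to
    # make the sketch runnable end-to-end.
    def my_cont_loss(features):
        return features.pow(2).sum().reshape(1)  # 1-element tensor, like cont_loss above

    input = torch.randn(100, 96, 8, 8)
    weight = torch.randn(128, 96, 5, 5, requires_grad=True)
    bias = torch.randn(128, requires_grad=True)

    # Inside backward() this wrapper is required to build a graph;
    # at the top level of a script it is a no-op.
    with torch.enable_grad():
        output = F.conv2d(input, weight, bias, stride=1, padding=0)  # ([100, 128, 4, 4])
        features = F.normalize(output.view(output.shape[0], output.shape[1], -1), dim=1)
        cont_loss = my_cont_loss(features)  # shape: ([1])

    # autograd returns gradients with exactly the parameters' shapes:
    cont_loss_weight, cont_loss_bias = torch.autograd.grad(cont_loss, (weight, bias))
    print(cont_loss_weight.shape)  # torch.Size([128, 96, 5, 5]) -> same as grad_weight
    print(cont_loss_bias.shape)    # torch.Size([128])           -> same as grad_bias

Is this recompute-and-autograd.grad approach valid inside backward(), or is there a better way to derive cont_loss_weight and cont_loss_bias?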