Question about the output shape of grad_weight in backward

I am writing a custom activation function with one learnable parameter, and the shape of its weight is (C = 32, H = 1, W = 1). In my backward method, the shape of grad_weight comes out as (B = 30, C = 32, H = 1, W = 1). I also tried averaging grad_weight over the batch dimension to get shape (C = 32, H = 1, W = 1). In both cases my network runs without any error. This puzzles me: whether grad_weight is (B = 30, C = 32, H = 1, W = 1) or (C = 32, H = 1, W = 1), training proceeds, the weight of the custom layer still gets updated, and its shape stays (C = 32, H = 1, W = 1). What is the correct way to handle this?
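A minimal sketch of why both shapes run, assuming a recent PyTorch release: when a custom Function's backward returns a gradient whose shape differs from its input's, but the input's shape broadcasts to it, autograd appears to sum the gradient back down to the input's shape, and it raises an error only when the shapes are incompatible. The toy `ScaleByWeight` Function below is hypothetical and exists only to make that behaviour visible; it is compared against the same op built from built-in operations.

```python
import torch


class ScaleByWeight(torch.autograd.Function):
    """Hypothetical toy op: y = x * w, with w of shape (C, 1, 1)."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_out * w if ctx.needs_input_grad[0] else None
        # Deliberately return a per-sample gradient of shape (B, C, 1, 1)
        # instead of w's shape (C, 1, 1), like the PDELU backward does.
        grad_w = (grad_out * x).sum(dim=(-2, -1), keepdim=True)
        return grad_x, grad_w


torch.manual_seed(0)
x = torch.randn(30, 32, 4, 5)
w = torch.rand(32, 1, 1, requires_grad=True)

ScaleByWeight.apply(x, w).sum().backward()

# Reference: the gradient autograd computes for the same op written with
# built-in operations, which sums over the batch dimension as well.
w_ref = w.detach().clone().requires_grad_(True)
(x * w_ref).sum().backward()

print(w.grad.shape)                        # torch.Size([32, 1, 1])
print(torch.allclose(w.grad, w_ref.grad))  # True
```

Averaging over the batch dimension instead gives 1/B of this summed gradient (1/30 here), so training still runs, but the weight receives a gradient 30 times smaller than the one autograd computes for the equivalent built-in op.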
Here is my code:

```python
import torch


class PDELU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight):  # input.size() -> (30, 32, 116, 144), weight.size() -> (32, 1, 1)
        t = 0.9
        output = torch.clone(input)
        weight_ = weight.expand(input.size())
        output[input <= 0] = weight_[input <= 0] * ((1 + (1 - t) * output[input <= 0]) ** (1 / (1 - t)) - 1)
        ctx.save_for_backward(input, weight, output)
        # print(output.size())
        return output  # output.size() -> (30, 32, 116, 144)

    @staticmethod
    def backward(ctx, grad_outputs):  # grad_outputs.size() -> (30, 32, 116, 144)
        t = 0.9
        input, weight, output = ctx.saved_tensors
        weight_ = weight.expand(input.size())
        grad_input_ = torch.clone(output)  # First compute the derivative of the output with respect to the input
        grad_input_[input > 0] = 1
        grad_input_[input <= 0] = ((output[input <= 0] + weight_[input <= 0]) ** t) * (weight_[input <= 0] ** (1 - t))
        grad_input = grad_input_ * grad_outputs  # Chain rule: gradient of the loss with respect to the input
        grad_weight_ = torch.clone(input)
        grad_weight_[input > 0] = 0
        grad_weight_[input <= 0] = (1 + (1 - t) * input[input <= 0]) ** (1 / (1 - t)) - 1
        grad_weight = torch.sum((grad_weight_ * grad_outputs), dim=(-2, -1), keepdim=True)

        return grad_input, grad_weight  # grad_input.size() -> (30, 32, 116, 144), grad_weight.size() -> (30, 32, 1, 1)
```
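A sketch of a backward that returns grad_weight with exactly the weight's shape, assuming the weight has shape (C, 1, 1) and inputs stay above -1/(1 - t) = -10 (below that the base of the fractional power goes negative and the result is NaN, as in the original). It sums over the batch dimension as well as H and W, which is what autograd otherwise does implicitly. Two choices differ from the code above: `torch.where` replaces masked assignment, and the input derivative is written directly as weight * (1 + (1 - t) x)^(t / (1 - t)), which is algebraically equal to the original expression for positive weights. The class name `PDELUSketch` is hypothetical; `torch.autograd.gradcheck` compares the analytical gradients with numerical ones in double precision.

```python
import torch


class PDELUSketch(torch.autograd.Function):
    """Hypothetical PDELU variant whose grad_weight has the weight's shape."""

    @staticmethod
    def forward(ctx, input, weight):
        # input: (B, C, H, W); weight: (C, 1, 1). Assumes input > -1 / (1 - t).
        t = 0.9
        neg = input <= 0
        base = (1 + (1 - t) * input) ** (1 / (1 - t)) - 1
        output = torch.where(neg, weight * base, input)
        ctx.save_for_backward(input, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        t = 0.9
        input, weight = ctx.saved_tensors
        neg = input <= 0
        inner = 1 + (1 - t) * input
        # d output / d input: 1 for x > 0, weight * inner^(t / (1 - t)) for x <= 0.
        d_input = torch.where(neg, weight * inner ** (t / (1 - t)), torch.ones_like(input))
        # d output / d weight: inner^(1 / (1 - t)) - 1 for x <= 0, 0 otherwise.
        d_weight = torch.where(neg, inner ** (1 / (1 - t)) - 1, torch.zeros_like(input))
        grad_input = d_input * grad_output
        # Sum over batch, H and W, then reshape to the weight's shape (C, 1, 1).
        grad_weight = (d_weight * grad_output).sum(dim=(0, 2, 3)).view_as(weight)
        return grad_input, grad_weight


torch.manual_seed(0)
x = (torch.rand(2, 3, 4, 5, dtype=torch.float64) * 10 - 5).requires_grad_()
w = (torch.rand(3, 1, 1, dtype=torch.float64) + 0.5).requires_grad_()

# gradcheck raises an error if analytical and numerical gradients disagree.
ok = torch.autograd.gradcheck(PDELUSketch.apply, (x, w), eps=1e-6, atol=1e-4)
print(ok)  # True

PDELUSketch.apply(x, w).sum().backward()
print(w.grad.shape)  # torch.Size([3, 1, 1])
```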