Question about the shape of the gradient returned by backward

I am customizing an activation function with one learnable parameter; the shape of weight is (C = 32, H = 1, W = 1). In my implementation of the backward method, the shape of grad_weight comes out as (B = 30, C = 32, H = 1, W = 1). I also tried averaging grad_weight over the batch dimension to reduce it to shape (C = 32, H = 1, W = 1). In both cases my network runs without any error. This puzzles me: whether grad_weight has shape (B = 30, C = 32, H = 1, W = 1) or (C = 32, H = 1, W = 1), training proceeds normally, the weight of the custom layer still gets updated, and its shape stays (C = 32, H = 1, W = 1). How should I deal with this situation, and which shape is the correct one to return?
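For context, when I write out the chain rule for a parameter that is shared by every sample in the batch (my own derivation, so this may be exactly where I am going wrong), I get

$$\frac{\partial L}{\partial w_c} = \sum_{b} \sum_{h, w} \frac{\partial L}{\partial y_{b,c,h,w}} \cdot \frac{\partial y_{b,c,h,w}}{\partial w_c},$$

which would suggest summing over the batch dimension B rather than keeping it or averaging it.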
Here is my code:

import torch


class PDELU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight):     # input.size() -> (30, 32, 116, 144)   weight.size() -> (32, 1, 1)
        t = 0.9
        output = torch.clone(input)
        weight_ = weight.expand(input.size())   # broadcast the per-channel weight to the input shape
        # PDELU: f(x) = x for x > 0, and w*((1 + (1-t)*x)**(1/(1-t)) - 1) for x <= 0
        output[input <= 0] = weight_[input <= 0]*((1+(1-t)*output[input <= 0])**(1/(1-t))-1)
        ctx.save_for_backward(input, weight, output)
        return output    # output.size() -> (30, 32, 116, 144)

    @staticmethod
    def backward(ctx, grad_outputs):  # grad_outputs.size() -> (30, 32, 116, 144)
        t = 0.9
        input, weight, output = ctx.saved_tensors
        weight_ = weight.expand(input.size())
        grad_input_ = torch.clone(output)   # derivative of the output with respect to the input
        grad_input_[input > 0] = 1
        grad_input_[input <= 0] = ((output[input <= 0] + weight_[input <= 0])**t)*(weight_[input <= 0]**(1-t))
        grad_input = grad_input_*grad_outputs   # gradient of the loss with respect to the input
        grad_weight_ = torch.clone(input)   # derivative of the output with respect to the weight
        grad_weight_[input > 0] = 0
        grad_weight_[input <= 0] = (1+(1-t)*input[input <= 0])**(1/(1-t))-1
        grad_weight = torch.sum((grad_weight_ * grad_outputs), dim=(-2, -1), keepdim=True)  # reduces only H and W, so the batch dimension survives

        return grad_input, grad_weight    # grad_input.size() -> (30, 32, 116, 144)   grad_weight.size() -> (30, 32, 1, 1)
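One sanity check I could add is torch.autograd.gradcheck, which compares the gradients returned by backward against finite-difference estimates. This is only a minimal sketch with made-up small sizes (not my real network, which uses (30, 32, 116, 144) inputs):

import torch

# Made-up small sizes so the numerical check runs quickly; gradcheck needs
# double precision and inputs that require grad.
x = torch.randn(4, 3, 5, 5, dtype=torch.double, requires_grad=True)
w = (torch.rand(3, 1, 1, dtype=torch.double) + 0.1).requires_grad_()  # keep the weight positive

try:
    # Passes only if the analytic gradients match the numerical Jacobian,
    # which requires grad_weight to have exactly the shape of weight,
    # i.e. summed over the batch dimension as well, e.g.
    #   torch.sum(grad_weight_ * grad_outputs, dim=(0, -2, -1), keepdim=True).squeeze(0)
    print("gradcheck passed:", torch.autograd.gradcheck(
        PDELU.apply, (x, w), eps=1e-6, atol=1e-4))
except RuntimeError as err:
    # Depending on the PyTorch version, keeping the batch dimension can make
    # autograd itself complain about an "invalid gradient" shape instead.
    print("gradcheck failed:", err)

If I understand the chain rule above correctly, only the fully summed version should pass this check, since averaging over the batch scales the gradient by 1/B.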