Custom BatchNorm layer

I am writing a custom BatchNorm layer and want to know whether PyTorch exposes forward and backward propagation interfaces for it. I found such an interface for convolution. Here is my conv code:

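import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.grad import conv2d_input, conv2d_weight

class my_conv2d(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # Keep non-tensor arguments on ctx for use in backward
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.save_for_backward(input, weight)
        # forward must actually compute and return the output
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(ctx, grad_output):
        # saved_variables is deprecated; use saved_tensors
        input, weight = ctx.saved_tensors

        input_grad = weight_grad = bias_grad = None
        if ctx.needs_input_grad[0]:
            input_grad = conv2d_input(input.shape, weight, grad_output, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
        if ctx.needs_input_grad[1]:
            weight_grad = conv2d_weight(input, weight.shape, grad_output, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
        if ctx.needs_input_grad[2]:
            # Bias gradient: sum over batch and spatial dimensions
            bias_grad = grad_output.sum(dim=(0, 2, 3))

        # One gradient per forward input; non-tensor arguments get None
        return input_grad, weight_grad, bias_grad, None, None, None, None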

You can use autograd.Function for any custom function as described here; it is not specific to convolutions. :wink:
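To illustrate that the same forward/backward pattern works beyond convolutions, here is a minimal sketch of a ReLU written as an autograd.Function. MyReLU is a hypothetical name chosen for this example, not a PyTorch class. The check runs torch.autograd.gradcheck in double precision, with inputs kept away from 0 because ReLU is not differentiable at the kink and a finite-difference check can fail there.

```python
import torch
from torch.autograd import Function


class MyReLU(Function):
    # Hypothetical example class: ReLU with a hand-written backward
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # d relu(x)/dx is 1 where x > 0 and 0 elsewhere
        grad_input = grad_output.clone()
        grad_input[input <= 0] = 0
        return grad_input


# Eight points in [-2, 2], none of them exactly 0
x = torch.linspace(-2, 2, 8, dtype=torch.double, requires_grad=True)
print(MyReLU.apply(x))
print(torch.autograd.gradcheck(MyReLU.apply, (x,)))
```

Custom functions are always invoked through .apply, never by calling forward directly.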

In this case, conv2d_input and conv2d_weight are provided by PyTorch itself (torch.nn.grad). Does PyTorch also provide similar backward helpers for BatchNorm or ReLU?
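As far as I know, torch.nn.grad only ships convolution helpers (the conv*_input / conv*_weight functions), so for BatchNorm one option is to write the backward by hand. Below is a minimal sketch, assuming 4D NCHW input in training mode only (batch statistics, no running mean/variance, no momentum). MyBatchNorm2d is a hypothetical name. It uses the standard closed-form batch-norm gradient, where m is the number of elements per channel and x_hat is the normalized input:

dx = inv_std / m * (m * dx_hat - sum(dx_hat) - x_hat * sum(dx_hat * x_hat))

The sketch is compared against F.batch_norm and checked with gradcheck.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


class MyBatchNorm2d(Function):
    # Hypothetical example class: training-mode BatchNorm over NCHW input
    @staticmethod
    def forward(ctx, input, weight, bias, eps):
        dims = (0, 2, 3)  # reduce over batch and spatial dims, per channel
        mean = input.mean(dim=dims, keepdim=True)
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)  # biased variance
        inv_std = torch.rsqrt(var + eps)
        x_hat = (input - mean) * inv_std
        ctx.save_for_backward(x_hat, inv_std, weight)
        return x_hat * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        x_hat, inv_std, weight = ctx.saved_tensors
        dims = (0, 2, 3)
        m = grad_output.numel() // grad_output.size(1)  # elements per channel
        grad_weight = (grad_output * x_hat).sum(dim=dims)
        grad_bias = grad_output.sum(dim=dims)
        dx_hat = grad_output * weight.view(1, -1, 1, 1)
        # Closed-form input gradient through mean and variance
        grad_input = inv_std / m * (
            m * dx_hat
            - dx_hat.sum(dim=dims, keepdim=True)
            - x_hat * (dx_hat * x_hat).sum(dim=dims, keepdim=True)
        )
        # eps is a non-tensor argument, so its gradient is None
        return grad_input, grad_weight, grad_bias, None


torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)

out = MyBatchNorm2d.apply(x, w, b, 1e-5)
ref = F.batch_norm(x, None, None, w, b, training=True, eps=1e-5)
print(torch.allclose(out, ref))
print(torch.autograd.gradcheck(lambda x, w, b: MyBatchNorm2d.apply(x, w, b, 1e-5), (x, w, b)))
```

Running statistics for eval mode would have to be tracked separately, for example in an nn.Module that wraps this function.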