I am writing a BatchNorm layer, and I want to know whether PyTorch provides a forward- and back-propagation interface for it. I found such an interface for conv. Here is the conv code:
from torch.autograd import Function
from torch.nn.grad import conv2d_input, conv2d_weight
class my_conv2d(Function):
    """2-D convolution as a custom autograd ``Function`` with explicit backward.

    The forward pass delegates to ``torch.nn.functional.conv2d``; the backward
    pass computes the input and weight gradients with
    ``torch.nn.grad.conv2d_input`` / ``conv2d_weight`` and the bias gradient by
    reducing ``grad_output`` over every non-channel dimension.

    Use it through ``my_conv2d.apply(input, weight, bias, stride, padding,
    dilation, groups)``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        """Compute the convolution and stash what ``backward`` needs on ``ctx``.

        Args:
            ctx: Autograd context object supplied by ``Function.apply``.
            input: Input tensor passed to ``conv2d``.
            weight: Convolution kernel tensor.
            bias: Optional bias tensor, or ``None``.
            stride, padding, dilation, groups: Convolution hyperparameters,
                forwarded unchanged to ``conv2d`` and reused in ``backward``.

        Returns:
            The result of ``conv2d(input, weight, bias, stride, padding,
            dilation, groups)``.
        """
        # Local import keeps this block self-contained; the module's top-level
        # imports only bring in ``Function`` and the ``torch.nn.grad`` helpers.
        from torch.nn.functional import conv2d

        # Plain (non-tensor) hyperparameters can live directly on ctx.
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        # Tensors must go through save_for_backward rather than ctx attributes.
        # The bias is not saved at all: its gradient depends on grad_output only.
        ctx.save_for_backward(input, weight)
        return conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per ``forward`` argument.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the loss with respect to the forward output.

        Returns:
            A 7-tuple ``(input_grad, weight_grad, bias_grad, None, None, None,
            None)``. Each tensor gradient is ``None`` when the corresponding
            input does not require grad; the four trailing ``None`` values
            correspond to ``stride``, ``padding``, ``dilation`` and ``groups``.
        """
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
        # deprecated.
        input, weight = ctx.saved_tensors
        input_grad = weight_grad = bias_grad = None
        if ctx.needs_input_grad[0]:
            input_grad = conv2d_input(
                input.shape, weight, grad_output,
                ctx.stride, ctx.padding, ctx.dilation, ctx.groups,
            )
        if ctx.needs_input_grad[1]:
            weight_grad = conv2d_weight(
                input, weight.shape, grad_output,
                ctx.stride, ctx.padding, ctx.dilation, ctx.groups,
            )
        # needs_input_grad[2] is False when bias is None, so this branch only
        # runs for a real bias tensor that requires grad.
        if ctx.needs_input_grad[2]:
            # The bias is added once per output channel, so its gradient is
            # grad_output summed over batch and spatial dims.
            # NOTE(review): assumes batched 4-D (N, C_out, H, W) grad_output.
            bias_grad = grad_output.sum(dim=(0, 2, 3))
        return input_grad, weight_grad, bias_grad, None, None, None, None