Hello,
I want to create a function that overrides the forward and backward pass of an nn.Module. E.g., I load a ResNet (or any other network) and automatically replace the forward and backward pass of all its layers with a custom autograd Function.
The code I have now works on CPU and on a single GPU, but it breaks when I extend it to DataParallel.
The data and the model replicas end up on different GPUs (exactly like here: Issue for DataParallel · Issue #8637 · pytorch/pytorch · GitHub, and DataParallel on modules with dynamically overwritten forwards).
I am also not sure whether this is the best way of overriding the forward/backward pass.
I define a function to be applied to every model layer:
import torch
from torch import nn, autograd

def override_backward(layer):
    # only convolution layers get the patched forward
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        def forward_conv(x):
            if layer.bias is None:
                return Conv2dFA.apply(x,
                                      layer.weight,
                                      layer.weight_fa,
                                      None,
                                      None,
                                      layer.stride,
                                      layer.padding,
                                      layer.dilation,
                                      layer.groups)
            else:
                return Conv2dFA.apply(x,
                                      layer.weight,
                                      layer.weight_fa,
                                      layer.bias,
                                      layer.bias_fa,
                                      layer.stride,
                                      layer.padding,
                                      layer.dilation,
                                      layer.groups)
        # monkey-patch the bound forward of this layer instance
        layer.forward = forward_conv
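(The weight_fa / bias_fa attributes don't exist on the layers by default; I create them beforehand with a helper roughly like this. The name add_feedback_weights and the random initialization are just illustrative:)

def add_feedback_weights(layer):
    # give each conv layer fixed random feedback weights; registering them as
    # parameters with requires_grad=False lets .cuda() move them with the model
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        layer.weight_fa = nn.Parameter(torch.randn_like(layer.weight), requires_grad=False)
        if layer.bias is not None:
            layer.bias_fa = nn.Parameter(torch.zeros_like(layer.bias), requires_grad=False)

This is applied with model_fa.apply(add_feedback_weights) before model_fa.apply(override_backward).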
The Conv2dFA autograd Function is:
class Conv2dFA(autograd.Function):
    @staticmethod
    def forward(context, input, kernels, kernels_fa, bias, bias_fa, stride, padding, dilation, groups):
        context.stride, context.padding, context.dilation, context.groups = stride, padding, dilation, groups
        context.save_for_backward(input, kernels, kernels_fa, bias, bias_fa)
        output = torch.nn.functional.conv2d(input,
                                            kernels,
                                            bias=bias,
                                            stride=stride,
                                            padding=padding,
                                            dilation=dilation,
                                            groups=groups)
        return output

    @staticmethod
    def backward(context, grad_output):
        input, kernels, kernels_fa, bias, bias_fa = context.saved_tensors
        grad_input = grad_kernels = grad_kernels_fa = grad_bias = grad_bias_fa = None
        if context.needs_input_grad[0]:
            # gradient w.r.t. the input uses the fixed feedback weights (kernels_fa)
            grad_input = torch.nn.grad.conv2d_input(input_size=input.shape,
                                                    weight=kernels_fa,
                                                    grad_output=grad_output,
                                                    stride=context.stride,
                                                    padding=context.padding,
                                                    dilation=context.dilation,
                                                    groups=context.groups)
        if context.needs_input_grad[1]:
            grad_kernels = torch.nn.grad.conv2d_weight(input=input,
                                                       weight_size=kernels_fa.shape,
                                                       grad_output=grad_output,
                                                       stride=context.stride,
                                                       padding=context.padding,
                                                       dilation=context.dilation,
                                                       groups=context.groups)
        if bias is not None and context.needs_input_grad[3]:
            grad_bias = grad_output.sum(0).sum(2).sum(1)
        # the non-tensor inputs (stride, padding, dilation, groups) still need a slot
        # in the returned tuple, so they get None gradients
        return grad_input, grad_kernels, grad_kernels_fa, grad_bias, grad_bias_fa, None, None, None, None
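On CPU and on a single GPU the Function itself behaves as expected; here is roughly the smoke test I use to check it directly (the shapes and conv arguments are arbitrary examples):

x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
w_fa = torch.randn_like(w)  # fixed feedback weights, no gradient needed

out = Conv2dFA.apply(x, w, w_fa, None, None, 1, 1, 1, 1)
out.sum().backward()

print(x.grad.shape)  # torch.Size([2, 3, 8, 8]), computed with w_fa
print(w.grad.shape)  # torch.Size([4, 3, 3, 3])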
And I apply it like this:
model_fa = resnet.resnet18()
model_fa.apply(override_backward)
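And the DataParallel wrapping, where it breaks, is just the standard pattern (device ids and input shape are only an example):

model_fa = torch.nn.DataParallel(model_fa, device_ids=[0, 1]).cuda()
images = torch.randn(8, 3, 224, 224).cuda()
output = model_fa(images)  # fails: the inputs are scattered across GPUs, but the
                           # patched forward still uses the layer it captured on GPU 0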
Is there a way of doing this dynamic forward override that works with DataParallel? I don't want to create custom layer classes, because I want the method to be applicable to any network.
Thanks in advance.