AvgPool2d is not differentiable twice

I am trying to implement the ResNet generator from improved WGAN in PyTorch. The original TensorFlow code defines MeanPoolConv like this:
    def MeanPoolConv(name, input_dim, output_dim, filter_size, inputs, he_init=True, biases=True):
        output = inputs
        output = tf.add_n([output[:,:,::2,::2], output[:,:,1::2,::2], output[:,:,::2,1::2], output[:,:,1::2,1::2]]) / 4.
        output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, output, he_init=he_init, biases=biases)
        return output
I ported it to PyTorch as follows:
    class MeanPoolConv(nn.Module):
        def __init__(self, in_channels, out_channels, filter_size, biases=True):
            super(MeanPoolConv, self).__init__()
            self.avepool = nn.AvgPool2d(2, 2)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=filter_size, stride=1,
                                  bias=biases)

        def forward(self, input):
            output = self.avepool(input)
            output = self.conv(output)
            return output

When I run the code, I get the following error:
    Traceback (most recent call last):
      File "/home/hpc-126/remote-host/wgan-gp-res.py", line 347, in
        gradient_penalty.backward()
      File "/home/hpc-126/Downloads/pytorch-master/torch/autograd/variable.py", line 152, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
      File "/home/hpc-126/Downloads/pytorch-master/torch/autograd/__init__.py", line 98, in backward
        variables, grad_variables, retain_graph)
    RuntimeError: AvgPool2d is not differentiable twice

Hi,

This error is expected: calling backward on the gradient penalty requires second-order gradients, and unfortunately higher-order gradients have not yet been implemented for all functions.
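For context, a gradient penalty differentiates through a gradient, so every op in the network needs a second derivative. Below is a minimal sketch of that pattern, assuming a recent PyTorch release (plain tensors instead of `Variable`). The toy `critic`, the input `x`, and the exact penalty formula are illustrative assumptions, not taken from the original script.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy critic, used only to illustrate the double-backward pattern.
critic = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.Tanh())

x = torch.rand(2, 3, 8, 8, requires_grad=True)
score = critic(x).sum()

# First derivative w.r.t. the input; create_graph=True keeps it differentiable.
(grad_x,) = torch.autograd.grad(score, x, create_graph=True)

# Penalise deviation of the per-sample gradient norm from 1.
penalty = ((grad_x.flatten(1).norm(2, dim=1) - 1) ** 2).mean()

# This backward pass differentiates through grad_x, i.e. takes a second derivative.
penalty.backward()
```

If any layer in the critic lacks a double-backward implementation, the last line is where an error like the one above would be raised.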

Thank you for your answer.
Is there another way to implement what the avepool layer does, similar to tf.add_n(...) / 4 in the TensorFlow code?

You can use the advanced indexing that was recently added to master to do this:

import torch
from torch import nn
from torch.autograd import Variable

avgpool=nn.AvgPool2d(2,2)

a = Variable(torch.rand(10, 20, 50, 50))

# nn implementation
out_nn = avgpool(a)

# Manual implementation
out_man = (a[:,:,::2,::2] + a[:,:,1::2,::2] + a[:,:,::2,1::2] + a[:,:,1::2,1::2]) / 4

print((out_nn - out_man).abs().mean())
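Applied to the module from the question, the same slicing could replace the `nn.AvgPool2d` layer. This is only a sketch, assuming a recent PyTorch release (no `Variable` wrapper) and inputs with even height and width; the layer sizes in the usage lines are arbitrary illustrative values.

```python
import torch
from torch import nn


class MeanPoolConv(nn.Module):
    """2x2 mean pooling via strided slicing, followed by a convolution."""

    def __init__(self, in_channels, out_channels, filter_size, biases=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=filter_size, stride=1, bias=biases)

    def forward(self, x):
        # Assumes even height and width, as in the comparison above.
        pooled = (x[:, :, ::2, ::2] + x[:, :, 1::2, ::2]
                  + x[:, :, ::2, 1::2] + x[:, :, 1::2, 1::2]) / 4
        return self.conv(pooled)


# Illustrative usage: differentiate through a gradient of the layer's output.
layer = MeanPoolConv(3, 8, filter_size=3)
x = torch.rand(2, 3, 16, 16, requires_grad=True)
out = layer(x)

(grad_x,) = torch.autograd.grad(out.sum(), x, create_graph=True)
grad_x.pow(2).sum().backward()
```

Slicing and addition both support double backward, so the second-order pass in the last two lines goes through without relying on `nn.AvgPool2d`.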

This method works. Thank you very much! :grin: