I want to implement a max pooling function, and I have come across a strange bug. I simplified my code and reproduced the problem as follows:
MaxPooling.py:
# MaxPooling.py
import torch
from torch.autograd import Function

import max_pooling  # compiled extension built from max_pooling.cpp / max_pooling_kernel.cu


class MaxPoolFunction(Function):
    @staticmethod
    def forward(ctx, features):
        n_features, n_channels = features.size()
        output = torch.zeros(n_channels, device=features.device, dtype=torch.float)
        argmax = torch.empty(n_channels, device=features.device, dtype=torch.long).fill_(-1)
        max_pooling.forward(features, output, argmax)
        ctx.argmax = argmax
        ctx.n_features = n_features
        ctx.n_channels = n_channels
        return output

    @staticmethod
    def backward(ctx, grad_output):
        print('grad_output:', grad_output, '\n')
        grad_input = torch.zeros(ctx.n_features, ctx.n_channels, device=grad_output.device, dtype=grad_output.dtype)
        # If I pass the original grad_output as the argument, the result is wrong.
        max_pooling.backward(grad_output, ctx.argmax, grad_input)
        # But if I pass a copy of grad_output instead, the result is right,
        # even though both tensors compare equal element-wise.
        new_grad_output = grad_output.clone().detach()
        assert torch.sum(grad_output != new_grad_output) == 0
        grad_input_1 = torch.zeros(ctx.n_features, ctx.n_channels, device=grad_output.device, dtype=grad_output.dtype)
        max_pooling.backward(new_grad_output, ctx.argmax, grad_input_1)
        print("grad_input: ", grad_input)
        print("grad_input_1: ", grad_input_1)
        return grad_input


# test
if __name__ == '__main__':
    features = torch.rand((2, 8)).cuda().requires_grad_()
    result = MaxPoolFunction.apply(features)
    print("features:", features)
    print("max pooling results:", result)
    torch.sum(result).backward()
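As a point of reference (my addition, not part of the original script), the same column-wise pooling can be expressed with torch.max, whose autograd gradient is what I expect the extension to reproduce:

# reference_check.py (comparison helper only)
import torch

feats = torch.rand((2, 8)).cuda().requires_grad_()
values, indices = feats.max(dim=0)    # column-wise max, same semantics as the kernel
values.sum().backward()
print("expected grads:", feats.grad)  # ones at each column's argmax row, zeros elsewhere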
max_pooling.cpp:
// max_pooling.cpp
#include <torch/extension.h>

void forwardLaucher(int n_channels, int n_features, float* features, float* output, long* argmax);

// Launch the forward kernel over an (n_features, n_channels) feature matrix.
void forward(at::Tensor features, at::Tensor output, at::Tensor argmax)
{
    int n_features = features.size(0);
    int n_channels = features.size(1);
    forwardLaucher(n_channels, n_features, features.data<float>(), output.data<float>(), argmax.data<long>());
}

void backwardLaucher(int n_channels, float* grad, long* argmax, float* output);

// Scatter the incoming gradient back to the recorded argmax positions.
void backward(at::Tensor grad, at::Tensor argmax, at::Tensor output)
{
    int n_channels = grad.size(0);
    backwardLaucher(n_channels, grad.data<float>(), argmax.data<long>(), output.data<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &forward, "forward");
    m.def("backward", &backward, "backward");
}
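Note that data<float>() hands the kernels a raw pointer and performs no stride handling, so this C++ layer silently assumes every tensor is contiguous. A defensive pattern on the Python side (my suggestion, not in the original code) is to force contiguity at the boundary:

# suggested change inside MaxPoolFunction.backward:
# .contiguous() is a no-op for tensors that are already densely laid out
max_pooling.backward(grad_output.contiguous(), ctx.argmax, grad_input)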
max_pooling_kernel.cu:
// max_pooling_kernel.cu
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>

// One thread per channel: scan all rows and keep the column-wise maximum.
__global__ void forwardKernel(int n_channels, int n_features, float* features, float* output, long* argmax)
{
    int c = threadIdx.x;
    for (int i = 0; i < n_features; i++)
    {
        if (argmax[c] == -1 || output[c] < features[i * n_channels + c])
        {
            output[c] = features[i * n_channels + c];
            argmax[c] = i;
        }
    }
}

void forwardLaucher(int n_channels, int n_features, float* features, float* output, long* argmax)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    forwardKernel<<<1, n_channels, 0, stream>>>(n_channels, n_features, features, output, argmax);
}

// Route each channel's gradient to the row that produced the maximum.
__global__ void backwardKernel(int n_channels, float* grad, long* argmax, float* output)
{
    int c = threadIdx.x;
    long idx = argmax[c];
    if (idx != -1)
    {
        atomicAdd(&output[idx * n_channels + c], grad[c]);
    }
}

void backwardLaucher(int n_channels, float* grad, long* argmax, float* output)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    backwardKernel<<<1, n_channels, 0, stream>>>(n_channels, grad, argmax, output);
}
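Both kernels index raw memory as i * n_channels + c, which is only valid when the tensor is dense and row-major. A small demonstration (my addition) that two tensors can compare equal element-wise while having completely different layouts:

# layout_check.py (illustrative only)
import torch

dense = torch.ones(8)
expanded = torch.ones(1).expand(8)        # same values, but stride 0 over one storage element
print(dense.stride(), expanded.stride())  # (1,) vs (0,)
print(torch.equal(dense, expanded))       # True: element-wise equality ignores layout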
When I run MaxPooling.py, I get the following results:
features: tensor([[0.1430, 0.9068, 0.1484, 0.9407, 0.7852, 0.3469, 0.0498, 0.9610],
        [0.1713, 0.9893, 0.2032, 0.9011, 0.1194, 0.1983, 0.1744, 0.2393]],
       device='cuda:0', requires_grad=True)
max pooling results: tensor([0.1713, 0.9893, 0.2032, 0.9407, 0.7852, 0.3469, 0.1744, 0.9610],
       device='cuda:0', grad_fn=<MaxPoolFunctionBackward>)
grad_output: tensor([1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

grad_input:  tensor([[0.0000, 0.0000, 0.0000, 0.9407, 0.7852, 0.3469, 0.0000, 0.9610],
        [1.0000, 1.5463, 0.2032, 0.0000, 0.0000, 0.0000, 0.1744, 0.0000]],
       device='cuda:0')
grad_input_1:  tensor([[0., 0., 0., 1., 1., 1., 0., 1.],
        [1., 1., 1., 0., 0., 0., 1., 0.]], device='cuda:0')
We can see that grad_output in the backward function is element-wise equal to new_grad_output, yet using the original grad_output gives a wrong backward result. Why?
The code should guarantee that grad_input depends only on grad_output and argmax. But the results show that some values in grad_input are actually values from the max pooling output, which is very strange.
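One more check that seems relevant (equal values do not imply equal memory layout, which is what the raw-pointer kernels depend on) would be to print the layouts inside backward, right after the assert:

# diagnostic added inside MaxPoolFunction.backward, after the assert
print(grad_output.is_contiguous(), grad_output.stride())
print(new_grad_output.is_contiguous(), new_grad_output.stride())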