.sum(2, keepdim=True) means to perform the sum operation over the 2nd dimension. keepdim means that after reduction, that dimension remains but has size 1.
In [1]: import torch
In [2]: a = torch.rand(2, 3, 4)
In [3]: a.sum(2, keepdim=False).size()
Out[3]: torch.Size([2, 3])
In [4]: a.sum(2, keepdim=True).size()
Out[4]: torch.Size([2, 3, 1])
/pytorch/torch/csrc/autograd/python_function.cpp:622: UserWarning: Legacy autograd function with non-static forward method is deprecated and will be removed in 1.3. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
with
class BinActive(torch.autograd.Function):
'''
Binarize the input activations and calculate the mean across channel dimension.
'''
def forward(self, input):
self.save_for_backward(input)
size = input.size()
mean = torch.mean(input.abs(), 1, keepdim=True)
input = input.sign()
return input, mean
def backward(self, grad_output, grad_output_mean):
input, = self.saved_tensors
grad_input = grad_output.clone()
grad_input[input.ge(1)] = 0
grad_input[input.le(-1)] = 0
return grad_input