Question about in-place operations

I want to implement a binary activation function to replace ReLU. To save memory, I would like to use an in-place operation, but the in-place version causes problems and I don't know why. I know the original input is changed; what I don't know is where else that input is used. Could someone explain this to me?

"code-1" is the in-place version, which reaches much lower accuracy than "code-2".

>     class binarize(torch.autograd.Function):
>         def forward(self, input):
>             # Remember where to zero the gradient: outside [-1, 1]
>             # (straight-through estimator).
>             noback_indicator = (input <= -1.0) | (input >= 1.0)
>             self.noback_indicator = noback_indicator
> 
>             ############ code-1 ###########
>             # input[input >= 0] = 1
>             # input[input < 0] = -1
>             # return input
>             ############ code-1 ###########
> 
>             ############ code-2 ###########
>             output = input.clone()
>             output[output >= 0] = 1
>             output[output < 0] = -1
>             return output
>             ############ code-2 ###########
> 
>         def backward(self, grad_output):
>             # Pass the gradient straight through, except where the mask applies.
>             grad_input = grad_output.clone()
>             grad_input[self.noback_indicator] = 0
>             return grad_input
> 
>     class Binary(torch.nn.Module):
>         def forward(self, input):
>             output = binarize()(input)
>             return output

My net definition is:

>         self.features = nn.Sequential(
>             nn.Conv2d(3, 128, 3, 1, 1, bias=False),
>             nn.BatchNorm2d(128, affine=False),
>             mm.Binary(),
>             nn.Conv2d(128, 128, 3, 1, 1, bias=False),
>             nn.MaxPool2d(kernel_size=2, stride=2),
>             nn.BatchNorm2d(128, affine=False),
>             mm.Binary(),
>             
>             nn.Conv2d(128, 256, 3, 1, 1, bias=False),
>             nn.BatchNorm2d(256, affine=False),
>             mm.Binary(),
>             nn.Conv2d(256, 256, 3, 1, 1, bias=False),
>             nn.MaxPool2d(kernel_size=2, stride=2),
>             nn.BatchNorm2d(256, affine=False),
>             mm.Binary(),
>             
>             nn.Conv2d(256, 512, 3, 1, 1, bias=False),
>             nn.BatchNorm2d(512, affine=False),
>             mm.Binary(),
>             nn.Conv2d(512, 512, 3, 1, 1, bias=False),
>             nn.MaxPool2d(kernel_size=2, stride=2),
>             nn.BatchNorm2d(512, affine=False),
>             mm.Binary(),
>         )

Why not use

    binarize = lambda x: F.relu(x) * 2 - 1
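A quick check of what this expression returns on a few sample values (my own example, not from the thread): negative inputs clamp to -1, but a positive input x maps to 2x - 1, so the output is only binary at the saturation points.

```python
import torch
import torch.nn.functional as F

# The suggested expression: negatives become -1, a positive x becomes 2x - 1.
binarize = lambda x: F.relu(x) * 2 - 1

x = torch.tensor([-2.0, -0.5, 0.5, 2.0])
print(binarize(x))  # tensor([-1., -1.,  0.,  3.])
```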

The problem might be that you modify the input in place without telling PyTorch that the input is dirty. In `forward` you need to do this…

    self.mark_dirty(input)

For more info, see here…
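On current PyTorch, where `autograd.Function` uses static methods and a `ctx` object instead of `self`, the in-place variant with `mark_dirty` looks roughly like this (a sketch, not the poster's original code):

```python
import torch

class BinarizeInplace(torch.autograd.Function):
    """In-place sign binarization with a straight-through gradient."""

    @staticmethod
    def forward(ctx, input):
        # Compute the mask *before* overwriting the input buffer.
        ctx.noback_indicator = (input <= -1.0) | (input >= 1.0)
        # Modify the input in place...
        input[input >= 0] = 1
        input[input < 0] = -1
        # ...and tell autograd the buffer was changed, so it can keep
        # the graph consistent instead of using stale values.
        ctx.mark_dirty(input)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        grad_input[ctx.noback_indicator] = 0
        return grad_input
```

Note that the mask must be computed before the in-place writes, and the modified tensor must be among the values returned from `forward` for `mark_dirty` to take effect. In-place modification of a leaf tensor that requires grad is still disallowed, so the input should be the output of some earlier operation (as it is after `BatchNorm2d` in the network above).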

@jpeg729 I get it now. Thank you very much!