In Caffe protobufs, I have seen resampling layers defined as:
layer {
  name: "resample_layer"
  type: "Resample"
  bottom: "input_layer"
  top: "resample_layer"
  propagate_down: false
  resample_param {
    height: 300
    width: 400
  }
}
That propagate_down parameter is supposed to prevent backpropagation down certain paths. From this question, it seems the PyTorch equivalent operation is detach().
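As I understand it, detach() returns a tensor that is cut out of the autograd graph, so gradients stop flowing past it. A minimal sketch of my own to illustrate:

import torch

x = torch.rand(3, 3, requires_grad=True)
w = torch.rand(3, 3, requires_grad=True)
y = x.detach() * w   # x is removed from the graph; w is not
y.sum().backward()
print(x.grad)        # None: no gradient flows back through detach()
print(w.grad)        # populated as usual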
However, the PyTorch equivalent to Caffe's Resample layer type is torch.nn.functional.interpolate. When I backpropagate through interpolate, it seems the gradients are summed for each expanded index. For example:
import torch
import torch.nn.functional as F

# With no interpolation
z = torch.rand(1, 1, 3, 3, requires_grad=True)
zz = 3 * z
zz.register_hook(print)  # print the gradient arriving at zz
ll = (1.5 * zz).sum()
ll.backward()
>> tensor([[[[1.5000, 1.5000, 1.5000],
             [1.5000, 1.5000, 1.5000],
             [1.5000, 1.5000, 1.5000]]]])
# With interpolation
z = torch.rand(1, 1, 3, 3, requires_grad=True)
zz = 3 * z
zz.register_hook(print)
ww = F.interpolate(zz, scale_factor=2)  # nearest-neighbour by default
ll = (1.5 * ww).sum()
ll.backward()
>> tensor([[[[6., 6., 6.],
             [6., 6., 6.],
             [6., 6., 6.]]]])
The factor of 4 seems to come from the number of elements quadrupling with scale_factor=2: each input element is copied to four output positions, and the gradients from those four positions are summed back into it (4 × 1.5 = 6). I suspect this treats my gradients differently than the Resample layer in Caffe with propagate_down: false.
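A quick check of my own (assuming the default nearest-neighbour mode) confirms that the gradient scales with the square of the scale factor:

import torch
import torch.nn.functional as F

for s in (2, 3):
    z = torch.rand(1, 1, 3, 3, requires_grad=True)
    zz = 3 * z
    # print the scale factor and one element of the gradient arriving at zz
    zz.register_hook(lambda g, s=s: print(s, g[0, 0, 0, 0].item()))
    ll = (1.5 * F.interpolate(zz, scale_factor=s)).sum()
    ll.backward()
>> 2 6.0
>> 3 13.5

In each case the gradient is 1.5 * s**2, consistent with the gradients of the s**2 copies being summed.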
How can I get the equivalent behavior in PyTorch? Using detach() or wrapping the call in torch.no_grad() only screws up the backward() pass.
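For reference, this is the failure I mean: detaching before the interpolation severs the graph entirely, so nothing upstream ever receives a gradient (snippet mine):

import torch
import torch.nn.functional as F

z = torch.rand(1, 1, 3, 3, requires_grad=True)
zz = 3 * z
ww = F.interpolate(zz.detach(), scale_factor=2)  # graph is cut here
ll = (1.5 * ww).sum()
ll.backward()
>> RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn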