PyTorch equivalent to Caffe Resample with `propagate_down=False`?

In Caffe protobufs, I have seen resampling layers defined as:

layer{
	name: "resample_layer"
	type: "Resample"
	bottom: "input_layer"
	top: "resample_layer"
	propagate_down: false
	resample_param{
		height: 300
		width: 400
	}
}

That propagate_down parameter is supposed to prevent backpropagation down certain paths. From this question, it seems the PyTorch equivalent operation is detach().
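
For reference, a minimal sketch of how detach() cuts a tensor out of the autograd graph (my own toy example, not from the linked question):

import torch

# detach() returns a tensor that shares storage with the original but is
# removed from the autograd graph, so no gradient flows back through it.
x = torch.rand(3, requires_grad=True)
y = (2 * x).detach()      # y.requires_grad is False
loss = (x + y).sum()      # only the x term contributes to x.grad
loss.backward()
print(x.grad)             # tensor([1., 1., 1.]) -- the 2 * x branch is blocked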

However, the PyTorch equivalent of Caffe’s Resample layer appears to be torch.nn.functional.interpolate. When I backpropagate through interpolate, the gradients seem to be summed over each expanded index. For example:

import torch
import torch.nn as nn
import torch.nn.functional as F

# With no interpolation
z = torch.rand(1,1,3,3)
z.requires_grad=True
zz = 3 * z
zz.register_hook(print)  # print the gradient that flows into zz during backward()
ll = (1.5 * zz).sum()
ll.backward()


>> tensor([[[[1.5000, 1.5000, 1.5000],
          [1.5000, 1.5000, 1.5000],
          [1.5000, 1.5000, 1.5000]]]])

# With interpolate
z = torch.rand(1,1,3,3)
z.requires_grad=True
zz = 3 * z
zz.register_hook(print)
ww = F.interpolate(zz, scale_factor=2)  # default mode is 'nearest'
ll = (1.5 * ww).sum()
ll.backward()

>> tensor([[[[6., 6., 6.],
          [6., 6., 6.],
          [6., 6., 6.]]]])

The factor of 4 seems to come from the number of elements quadrupling with scale_factor=2: under the default nearest-neighbor mode, each input element is copied into a 2×2 block of the output, and the backward pass sums the gradients from those four positions. I suspect this affects my gradients somewhat differently than the Resample layer in Caffe with propagate_down: false.
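
As a sanity check (just a sketch of my own), using a non-uniform upstream gradient makes it visible which output positions fold back into each input element:

import torch
import torch.nn.functional as F

# Each input element of a nearest-neighbor upsample is copied into a 2x2
# block of the output, so its gradient is the sum of the four corresponding
# upstream gradients.
z = torch.zeros(1, 1, 2, 2, requires_grad=True)
w = F.interpolate(z, scale_factor=2)    # shape (1, 1, 4, 4), default mode='nearest'

upstream = torch.arange(16.).reshape(1, 1, 4, 4)
w.backward(upstream)

# z.grad[0, 0, 0, 0] == upstream[0, 0, :2, :2].sum() == 0 + 1 + 4 + 5 == 10
print(z.grad)   # tensor([[[[10., 18.], [42., 50.]]]])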

How can I get the equivalent behavior in PyTorch? Using detach() or torch.no_grad() only screws up the backward() pass.
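
To be concrete about what I mean by that, here is a sketch (my own, assuming the resampled output is the only path to the loss) of why a naive detach() breaks things:

import torch
import torch.nn.functional as F

# Detaching before interpolate removes the whole branch from the autograd
# graph, so if this is the only path to the loss there is nothing left for
# backward() to propagate.
z = torch.rand(1, 1, 3, 3, requires_grad=True)
zz = 3 * z
ww = F.interpolate(zz.detach(), scale_factor=2)   # gradient path is cut here
ll = (1.5 * ww).sum()
# ll.backward()  # RuntimeError: element 0 of tensors does not require grad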

@marcman411: Have you found a solution? I’m also interested in this question.