Efficient slicing and stacking of tensors

Hello,
I would like to write a kind of 2D convolution layer that can store intermediate tensors.

More precisely, I would like to store slices of the layer's input tensor in a tensor of shape (batch, slices, ch_in, k, k), so that it can be multiplied pointwise with a filters tensor.

The snippet below yields what I need, but it is intractable due to the for loops.
How could I build this large tensor in an efficient way?

import torch

ch_in = 3
bs = 4
k = 3
h, w = 512, 512

x = torch.randn(bs, ch_in, h, w, requires_grad=True) # input tensor

Define the start and end indices of each 3x3 slice along the height and width axes.

s_x = torch.arange(x.shape[2] - k + 1)
e_x = s_x + k
s_y = torch.arange(x.shape[3] - k + 1)
e_y = s_y + k

Iterate over these indices and append each retrieved slice to a list.

z = []
for s_x_, e_x_ in zip(s_x, e_x):
    for s_y_, e_y_ in zip(s_y, e_y):
        z.append(x[:, :, s_x_:e_x_, s_y_:e_y_])

z = torch.stack(z, dim=1)
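
For context, the pointwise multiplication I have in mind would look roughly like this (the filters shape is just a placeholder for illustration):

filters = torch.randn(ch_in, k, k)  # placeholder filters tensor, shape chosen only for illustration
weighted = z * filters              # broadcasts to (bs, slices, ch_in, k, k)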

Any help would be appreciated! Thanks.

That’s not the case, as for loops don’t break the computation graph.
For your use case you might want to consider using unfold to avoid the loops, but as mentioned, the for loop itself should not break your code, and the error might come from somewhere else.
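
Something along these lines should give you the same (batch, slices, ch_in, k, k) tensor without the Python loops; a minimal sketch, assuming the variables from your snippet and the same slice ordering as your stacked z:

# take sliding kxk windows with stride 1 along the height and width dims
patches = x.unfold(2, k, 1).unfold(3, k, 1)   # (bs, ch_in, h-k+1, w-k+1, k, k)
# move the patch grid in front of the channel dimension
patches = patches.permute(0, 2, 3, 1, 4, 5)   # (bs, h-k+1, w-k+1, ch_in, k, k)
# flatten the patch grid into a single "slices" dimension
patches = patches.reshape(bs, -1, ch_in, k, k)  # (bs, slices, ch_in, k, k)

# sanity check against the loop-based result
print(torch.equal(patches, z))

Note that Tensor.unfold and permute only create views of x, so the actual copy happens in the final reshape.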

Sorry, that was poor wording on my part.
It’s not about breaking the computation graph but about speed.
Looping over large tensors is not feasible in my case.

Anyways, unfold looks like what I need.

Thank you!

Ah, thanks for clarifying; I misunderstood the “intractable” wording.