I would like to write a kind of 2D convolution layer that can also store intermediate tensors.

More precisely, I would like to store the k x k slices of the layer's input tensor in a single tensor of shape (batch, slices, ch_in, k, k), so that it can be multiplied pointwise with a filter tensor.

The snippet below yields what I need, but because of the nested for loops it is far too slow to be usable.

How would it be possible to retrieve the described large tensor in an efficient way?

import torch

ch_in = 3
bs = 4
k = 3
h, w = 512, 512
x = torch.randn(bs, ch_in, h, w, requires_grad=True)  # input tensor

# Define the start and end indices of each k x k slice along the height and width axes.
s_x = torch.arange(x.shape[2] - k + 1)
e_x = s_x + k
s_y = torch.arange(x.shape[3] - k + 1)
e_y = s_y + k

# Iterate over these indices and append each slice to a list.
z = []
for s_x_, e_x_ in zip(s_x, e_x):
    for s_y_, e_y_ in zip(s_y, e_y):
        z.append(x[:, :, s_x_ : e_x_, s_y_ : e_y_])

# Stack all slices along a new axis: (bs, slices, ch_in, k, k)
z = torch.stack(z, dim=1)
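One direction that looks promising, as a sketch rather than a definitive answer: `Tensor.unfold` can extract sliding windows along one dimension, so applying it to the height and then the width axis should give all k x k slices without any Python loop. The helper name `extract_patches` is hypothetical, and stride 1 with no padding is assumed, matching the loop version above. The sizes are kept small here so the comparison runs quickly.

```python
import torch
import torch.nn.functional as F


def extract_patches(x: torch.Tensor, k: int) -> torch.Tensor:
    """Return all k x k slices of x as (batch, slices, ch_in, k, k).

    Hypothetical helper; assumes stride 1 and no padding.
    """
    bs, ch_in = x.shape[0], x.shape[1]
    # Sliding windows over height, then width:
    # (bs, ch_in, h - k + 1, w - k + 1, k, k)
    windows = x.unfold(2, k, 1).unfold(3, k, 1)
    # Move the two window-position axes in front of the channel axis,
    # then merge them into one "slices" axis (row-major, like the loops).
    windows = windows.permute(0, 2, 3, 1, 4, 5)
    return windows.reshape(bs, -1, ch_in, k, k)


# Small sanity check: the patches reproduce an ordinary convolution.
bs, ch_in, ch_out, k = 2, 3, 5, 3
x = torch.randn(bs, ch_in, 8, 7, requires_grad=True)
filters = torch.randn(ch_out, ch_in, k, k)

patches = extract_patches(x, k)  # (bs, 6 * 5, ch_in, k, k)
# Pointwise multiply with each filter and sum over (ch_in, k, k).
out = torch.einsum("bsckl,ockl->bos", patches, filters)
out = out.reshape(bs, ch_out, 8 - k + 1, 7 - k + 1)
print(torch.allclose(out, F.conv2d(x, filters), atol=1e-4))
```

Note that the final `reshape` has to copy the data, so the result takes roughly k * k times the memory of the input; for a 512 x 512 input that is still large, but it avoids the Python-level iteration. `torch.nn.functional.unfold` should give the same slices in a flattened (bs, ch_in * k * k, slices) layout if that form is more convenient.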

Any help would be appreciated! Thanks.