Tensor split across multiple dimensions

I have a 4D (batch, channel, W, H) tensor that I’d like to split into equal-sized tensors while preserving the batch and channel dimensionality. I was wondering if there’s a better way of doing this than nesting two torch.chunk calls. My ultimate goal is to apply the same transformation to each of these chunks (the transformation is not a convolution). Maybe there’s a way to avoid the “explicit” tensor splitting and apply the transformation directly in a chunk-based fashion? This is what I’ve done:

import torch

a = torch.randn(64, 3, 16, 16)

# I'd like to split it into 16 chunks of shape [64, 3, 4, 4]

n_chunks = 4  # number of chunks along each spatial dimension
a_x_split = torch.chunk(a, n_chunks, dim=2)

chunks = []
for cnk in a_x_split:
    cnks = torch.chunk(cnk, n_chunks, dim=3)
    for c_ in cnks:
        chunks.append(c_)

print(len(chunks))     # prints '16'
print(chunks[0].shape) # prints 'torch.Size([64, 3, 4, 4])'

# then apply the same transform to each chunk individually
...

tensor.unfold should yield the same result:

import torch

a = torch.randn(64, 3, 16, 16)
kernel_size = 4
kernel_stride = 4
a = a.unfold(2, kernel_size, kernel_stride).unfold(3, kernel_size, kernel_stride)  # [64, 3, 4, 4, 4, 4]
a = a.contiguous().view(a.size(0), a.size(1), -1, a.size(4), a.size(5))  # merge the 4x4 patch grid into one dim
print(a.shape)
> torch.Size([64, 3, 16, 4, 4])

Note that the patches are now in dim 2. You can of course permute the tensor as you wish.
chunks[0] now corresponds to a[:, :, 0].
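
If your transformation can operate on a flattened batch, you can also avoid the Python loop entirely and apply it to all chunks at once. A minimal sketch, assuming the op preserves the chunk shape (my_transform is a hypothetical placeholder for your per-chunk transformation):

b, c, n, kh, kw = a.shape                                  # [64, 3, 16, 4, 4]
flat = a.permute(0, 2, 1, 3, 4).reshape(b * n, c, kh, kw)  # [1024, 3, 4, 4]
flat = my_transform(flat)                                  # hypothetical, shape-preserving op
a = flat.reshape(b, n, c, kh, kw).permute(0, 2, 1, 3, 4)   # back to [64, 3, 16, 4, 4]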


Thank you for the above. What would be the best way to revert this operation while preserving the internal ordering of the data? Sounds silly, but I’m struggling with this o.O

Hi, I think the internal order is preserved that way, no?
In my case, imagine it’s a high-dimensional image batch of shape [8, 64, 64, 64]. I want to crop this “image” tensor into 4 pieces, just like cropping an image into 4 quadrants, so each partial “image” tensor would have shape [8, 64, 32, 32]. I think tensor.unfold is doing the right thing, but I’m also not 100% sure. @ptrblck could you please clarify that? Thanks!
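
For what it’s worth, this is the sanity check I tried, following the unfold recipe above (just a sketch of my attempt, with the sizes from my use case):

import torch

x = torch.randn(8, 64, 64, 64)
p = x.unfold(2, 32, 32).unfold(3, 32, 32)                  # [8, 64, 2, 2, 32, 32]
p = p.contiguous().view(x.size(0), x.size(1), -1, 32, 32)  # [8, 64, 4, 32, 32]
print(torch.equal(p[:, :, 0], x[:, :, :32, :32]))          # True -> top-left crop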

Have a look at this post to see how to recreate the original input.
As @windson said, the internal order should be preserved.
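
For completeness, here’s a minimal sketch of how the reversal could look for the [64, 3, 16, 4, 4] layout above (it assumes non-overlapping patches on a square grid, i.e. stride == kernel_size):

b, c, n, kh, kw = a.shape                # a: [64, 3, 16, 4, 4]
g = int(n ** 0.5)                        # patches per side, assumes a square grid
out = a.reshape(b, c, g, g, kh, kw)      # split the patch dim back into (row, col)
out = out.permute(0, 1, 2, 4, 3, 5)      # reorder to (b, c, row, kh, col, kw)
out = out.reshape(b, c, g * kh, g * kw)  # [64, 3, 16, 16], original spatial layout

Comparing out against the tensor from before the unfold calls with torch.equal should then return True.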
