I have a 4-D (batch, channel, W, H) tensor that I’d like to split into equal-sized tensors while preserving the batch and channel dimensions. Is there a better way of doing this than nesting two torch.chunk (or torch.split) calls? My ultimate goal is to apply the same kind of transformation to each of these chunks (the transformation is not a convolution). Maybe there’s a way to avoid the “explicit” tensor splitting and apply the transformation directly in a chunk-based fashion? This is what I’ve done:
```python
import torch

a = torch.randn(64, 3, 16, 16)
# I'd like to split it into [64, 3, 4, 4] chunks
chunk_dim = 4  # number of chunks along each spatial dim (16 / 4 -> chunks of size 4)
a_x_split = torch.chunk(a, chunk_dim, dim=2)
chunks = []
for cnk in a_x_split:
    cnks = torch.chunk(cnk, chunk_dim, dim=3)
    for c_ in cnks:
        chunks.append(c_)
print(len(chunks))      # prints '16'
print(chunks[0].shape)  # prints 'torch.Size([64, 3, 4, 4])'
# then apply the same transform to each chunk individually
```
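A minimal sketch of doing this without the Python loop, assuming both spatial dims are divisible by the chunk size: reshape each spatial dim into a (grid position, offset within chunk) pair and permute, so every chunk sits in the last two dims. A transform written to act on those two dims then runs on all 16 chunks in one call. `to_patches`, `from_patches` and `per_chunk_standardize` are hypothetical helper names, and the per-chunk standardization is only a stand-in for the real (non-convolutional) transform.

```python
import torch

def to_patches(x: torch.Tensor, kh: int, kw: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, C, H//kh, W//kw, kh, kw), no Python loop."""
    B, C, H, W = x.shape
    assert H % kh == 0 and W % kw == 0, "spatial dims must be divisible by the chunk size"
    # Split H into (H//kh, kh) and W into (W//kw, kw), then move the two
    # grid axes in front of the two within-chunk axes.
    return x.reshape(B, C, H // kh, kh, W // kw, kw).permute(0, 1, 2, 4, 3, 5)

def from_patches(p: torch.Tensor) -> torch.Tensor:
    """Inverse of to_patches: (B, C, nh, nw, kh, kw) -> (B, C, nh*kh, nw*kw)."""
    B, C, nh, nw, kh, kw = p.shape
    return p.permute(0, 1, 2, 4, 3, 5).reshape(B, C, nh * kh, nw * kw)

def per_chunk_standardize(p: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical example transform: standardize each chunk using
    # statistics computed over its own last two dims only.
    mean = p.mean(dim=(-2, -1), keepdim=True)
    std = p.std(dim=(-2, -1), keepdim=True)
    return (p - mean) / (std + eps)

a = torch.randn(64, 3, 16, 16)
patches = to_patches(a, 4, 4)                        # (64, 3, 4, 4, 4, 4)
out = from_patches(per_chunk_standardize(patches))   # back to (64, 3, 16, 16)
```

Here `patches[:, :, i, j]` is the chunk covering rows `4*i:4*i+4` and columns `4*j:4*j+4`. This works directly when the transform can be written over the trailing dims (or broadcasts over the leading ones); if it cannot, you can still flatten the chunks to `(-1, 4, 4)` and loop over that single batch dimension, or try `torch.func.vmap` in recent PyTorch versions.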
Hi, I think in that way the internal order is preserved, no?
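A quick way to check the ordering question empirically: the sketch below compares `Tensor.unfold` (with window size equal to step, i.e. non-overlapping windows) against the nested `torch.chunk` loop from the original post. It assumes the same `[64, 3, 16, 16]` input and 4x4 chunks.

```python
import torch

a = torch.randn(64, 3, 16, 16)
k = 4          # chunk height/width
n = 16 // k    # number of chunks along each spatial dim

# Reference: the nested torch.chunk loop (outer over dim 2, inner over dim 3).
ref = [c for row in torch.chunk(a, n, dim=2)
         for c in torch.chunk(row, n, dim=3)]

# unfold(dim, size, step) replaces `dim` by the number of windows and appends
# the window contents as a new last dim: result is (64, 3, 4, 4, 4, 4).
patches = a.unfold(2, k, k).unfold(3, k, k)

# Flatten the (row, col) grid in row-major order: index = row * n + col,
# which is the same order the nested loop produces.
flat = patches.permute(2, 3, 0, 1, 4, 5).reshape(-1, 64, 3, k, k)  # (16, 64, 3, 4, 4)

for idx, c in enumerate(ref):
    assert torch.equal(flat[idx], c)
print("order matches")
```

So, as long as the grid is flattened row-major (dim 2 outer, dim 3 inner), the chunk order matches the explicit loop.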
In my case, imagine it’s a high-dimensional image batch of shape [8, 64, 64, 64]. I want to crop this “image” tensor into 4 pieces, just like cropping an image into 4 parts, so each partial “image” tensor will have shape [8, 64, 32, 32]. I think torch.unfold does this correctly, but I’m not 100% sure. @ptrblck could you please clarify? Thanks
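A sketch of the quadrant case, assuming the `[8, 64, 64, 64]` tensor from the question and treating the last two dims as the spatial ones: `Tensor.unfold` with size and step both equal to 32 gives a `2 x 2` grid of non-overlapping 32x32 windows, and the result is compared against plain nested `Tensor.split` slicing.

```python
import torch

x = torch.randn(8, 64, 64, 64)  # (batch, channels, H, W)
half = 32                       # quadrant size along each spatial dim

# Non-overlapping 32x32 windows over dims 2 and 3 (a view, no copy):
# shape (8, 64, 2, 2, 32, 32); q[:, :, i, j] is the quadrant in grid row i, column j.
q = x.unfold(2, half, half).unfold(3, half, half)

# The four quadrants in row-major order: top-left, top-right, bottom-left, bottom-right.
quads = [q[:, :, i, j] for i in range(2) for j in range(2)]

# Same quadrants obtained by nested splitting along dims 2 and 3.
split_quads = [t for row in x.split(half, dim=2) for t in row.split(half, dim=3)]

for u, s in zip(quads, split_quads):
    assert u.shape == (8, 64, 32, 32)
    assert torch.equal(u, s)

assert torch.equal(quads[0], x[:, :, :half, :half])  # top-left
assert torch.equal(quads[3], x[:, :, half:, half:])  # bottom-right
```

Note that this is `Tensor.unfold`. `torch.nn.functional.unfold` (`nn.Unfold`) is a different operation (im2col): with a 32x32 kernel and stride 32 it would return a tensor of shape `(8, 64 * 32 * 32, 4)`, with each window flattened into the channel dimension, so it needs an extra reshape to recover `[8, 64, 32, 32]` pieces.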