Reshape selected axis (with torch.reshape)?

Hi! This might be a stupid question, but say I have a tensor of shape T x B x C x H x W and I flatten it to TB x C x H x W. Would torch.reshape with the original dimension values give back the data in the same order? Is there a way to reshape only a selected axis to expand the dimensions, and would the order also be maintained?

Yes, as long as you don’t permute the dimensions, view will yield the original tensor:

import torch

T, B, C, H, W = 2, 3, 4, 5, 6
x = torch.randn(T, B, C, H, W)
y = x.view(-1, C, H, W)  # merge T and B into one dim
z = y.view(T, B, C, H, W)  # split it back into T and B
print((z == x).all())
> tensor(True)
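On the "reshape only the selected axis" part of the question: one possible way, not mentioned in the reply above, is to use Tensor.flatten(start_dim, end_dim) and Tensor.unflatten(dim, sizes), which touch only the named dims. A minimal sketch, reusing the sizes from the snippet above:

```python
import torch

T, B, C, H, W = 2, 3, 4, 5, 6
x = torch.randn(T, B, C, H, W)

# Merge only dims 0 and 1 (T and B); C, H, W are left untouched.
y = x.flatten(0, 1)          # shape: (T*B, C, H, W)

# Split dim 0 back into (T, B); again only that axis changes.
z = y.unflatten(0, (T, B))   # shape: (T, B, C, H, W)

print(y.shape, z.shape)
print(torch.equal(z, x))     # element order is preserved
```

Since neither call permutes anything, the data order should be the same as with the view calls above.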

Gotcha, thank you! I was able to verify that is the case here.

@ptrblck
My tensor shape is 4*H,W.
If I arrange it with a loop into a bigger H,W that is a multiple of the above, I get the correct image.
But when I simply do view(BigH,BigW), I get an incorrect image; there is repetition of patches.

Could you post a code snippet showing these unexpected results, please?

pred = model(x)
# pred shape is 4 x 4 x 272 x 352
pred = pred.argmax(1).contiguous()
# pred shape is 4 x 272 x 352
pred = pred.view(2, 2, 272, 356).view(544, 704)  # gives incorrect image
# expected ordering of indices is [[0, 1], [2, 3]]

When I use a for loop I get the right image:
for i in rows:
    for j in columns:
        img[i:i+1 *272,j:j+1*352]= pred[index]
        index = index + 1

In your code you are interleaving dimensions, which I tried to mention in:

as long as you don’t permute the dimensions, view will yield the original tensor

I’m also unsure how the for loop approach would work, whether you have a typo, or whether you want to replace the values in img. Should the indexing be [i:(i+1)*272, j:(j+1)*352]?
The view operation as posted also fails, since 356 doesn’t match the tensor’s width; it seems you want (2, 2, 272, 352).

Yes, it was a typo in the for loop.
What do you mean by interleaving?
So there is no straightforward way?
I even flattened the whole tensor and then tried to reshape it, but it gave the same results.
@ptrblck
I thought PyTorch does sequential reshaping.

You are trying to split dim0 into two factors and merge each into a different dimension (one into the height, one into the width), which won’t work without a permutation.
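To make the interleaving concrete, here is a tiny sketch that is not part of the original posts. The 2x3 tile size and the names tiles, wrong and right are made up for illustration; tile k is filled with the value k so the layout is easy to read:

```python
import torch

# Four 2x3 tiles; tile k holds the value k everywhere.
tiles = torch.arange(4).view(4, 1, 1).expand(4, 2, 3).contiguous()

# Plain view keeps the memory order, so each output row is filled
# from a single tile instead of two tiles sitting side by side.
wrong = tiles.view(2, 2, 2, 3).view(4, 6)

# Reorder the dims to (grid_row, tile_row, grid_col, tile_col) first,
# then merge neighbouring dims.
right = tiles.view(2, 2, 2, 3).permute(0, 2, 1, 3).reshape(4, 6)

print(wrong)
# tensor([[0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [2, 2, 2, 2, 2, 2],
#         [3, 3, 3, 3, 3, 3]])
print(right)
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1],
#         [2, 2, 2, 3, 3, 3],
#         [2, 2, 2, 3, 3, 3]])
```

view only regroups elements in their existing memory order, so a 2x2 grid of tiles needs the permute before the final merge.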
Since your for loop approach wasn’t updated, I assume the following code snippet creates the desired output; it also shows how to permute the original tensor before reshaping:

import torch

pred = torch.randn(4, 4, 272, 352)
pred = pred.argmax(1).contiguous()  # shape: (4, 272, 352)
# split dim0 into a 2x2 grid, move the grid column next to the width, then merge
pred_ref = pred.view(2, 2, 272, 352).permute(0, 2, 1, 3).reshape(544, 704)
pred_ = pred.view(-1, 272, 352)

img = torch.zeros(544, 704)
index = 0
for i in range(2):
    for j in range(2):
        img[i*272:(i+1)*272, j*352:(j+1)*352] = pred_[index]
        index = index + 1

print((img == pred_ref).all())
> tensor(True)
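If the grid size varies, the same permute-then-reshape idea could be wrapped in a small helper. This is only a sketch: stitch_tiles is a hypothetical name (not a PyTorch function), and it assumes the tiles are stored in row-major grid order, i.e. [[0, 1], [2, 3]] for a 2x2 grid:

```python
import torch

def stitch_tiles(tiles: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
    """Arrange (rows*cols, h, w) tiles into one (rows*h, cols*w) image.

    Hypothetical helper; assumes tiles are in row-major grid order.
    """
    n, h, w = tiles.shape
    assert n == rows * cols, "number of tiles must equal rows * cols"
    return (
        tiles.reshape(rows, cols, h, w)  # (grid_row, grid_col, h, w)
        .permute(0, 2, 1, 3)             # (grid_row, h, grid_col, w)
        .reshape(rows * h, cols * w)
    )

# Usage: six 5x7 tiles laid out on a 2x3 grid.
tiles = torch.randint(0, 4, (6, 5, 7))
img = stitch_tiles(tiles, rows=2, cols=3)

# Cross-check against explicit slicing, as in the loop above.
ref = torch.zeros(2 * 5, 3 * 7, dtype=tiles.dtype)
for k in range(6):
    i, j = divmod(k, 3)
    ref[i*5:(i+1)*5, j*7:(j+1)*7] = tiles[k]

print(torch.equal(img, ref))
```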