Reshape selected axis (with torch.reshape)?

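You are trying to split dim0 into two parts and merge each part into a different spatial dimension (rows and columns). view or reshape alone can't do that, because they never change the order of the elements, so you need a permutation first.
Since your for-loop approach wasn't updated, I assume the loop in this snippet creates the desired output. The snippet also shows how to permute the original tensor before reshaping: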

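import torch

pred = torch.randn(4, 4, 272, 352)
pred = pred.argmax(1).contiguous()  # [4, 272, 352]: four 272x352 tiles
# split dim0 into a 2x2 grid (i, j), move j next to the width dim,
# then merge (i, h) into 544 rows and (j, w) into 704 columns
pred_ref = pred.view(2, 2, 272, 352).permute(0, 2, 1, 3).reshape(544, 704)
pred_ = pred.view(-1, 272, 352)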

# reference: copy each tile into its slot of the 2x2 grid with a loop
img = torch.zeros(544, 704)
index = 0
for i in range(2):
    for j in range(2):
        img[i*272:(i+1)*272, j*352:(j+1)*352] = pred_[index]
        index += 1

print((img == pred_ref).all())
> tensor(True)
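To use other grid sizes, the view/permute/reshape pattern from the snippet can be wrapped in a small helper. This is a minimal sketch, assuming the tiles are stored in row-major grid order (tile index = i * grid_cols + j, as in the loop) and that all tiles share the same height and width. tiles_to_image and image_to_tiles are hypothetical helper names, not PyTorch functions. The example also shows that a plain reshape without the permute yields a different layout.

```python
import torch


def tiles_to_image(tiles: torch.Tensor, grid_rows: int, grid_cols: int) -> torch.Tensor:
    """Hypothetical helper: arrange tiles [grid_rows * grid_cols, H, W] into
    one [grid_rows * H, grid_cols * W] image, filling the grid row by row."""
    n, h, w = tiles.shape
    assert n == grid_rows * grid_cols, "number of tiles must match the grid"
    return (
        tiles.reshape(grid_rows, grid_cols, h, w)  # split dim0 into (row, col)
        .permute(0, 2, 1, 3)                       # -> (row, h, col, w)
        .reshape(grid_rows * h, grid_cols * w)     # merge (row, h) and (col, w)
    )


def image_to_tiles(img: torch.Tensor, grid_rows: int, grid_cols: int) -> torch.Tensor:
    """Hypothetical inverse: cut an image back into row-major tiles."""
    H, W = img.shape
    h, w = H // grid_rows, W // grid_cols
    return (
        img.reshape(grid_rows, h, grid_cols, w)    # split rows and columns
        .permute(0, 2, 1, 3)                       # -> (row, col, h, w)
        .reshape(grid_rows * grid_cols, h, w)      # merge (row, col) into tile index
    )


# four 2x3 tiles with distinct values, arranged as a 2x2 grid
tiles = torch.arange(4 * 2 * 3).reshape(4, 2, 3)
img = tiles_to_image(tiles, 2, 2)
naive = tiles.reshape(4, 6)  # no permute: just regroups elements in memory order

print(img.tolist())
# [[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11], [12, 13, 14, 18, 19, 20], [15, 16, 17, 21, 22, 23]]
print(torch.equal(naive, img))                          # False
print(torch.equal(image_to_tiles(img, 2, 2), tiles))    # True
```

The first step uses reshape rather than view so the helper also accepts non-contiguous inputs. The inverse applies the same permute(0, 2, 1, 3), since swapping dims 1 and 2 is its own inverse.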