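You are trying to spread dim0 across two different output dimensions (the height and the width of the final image). A plain view or reshape can't do that, because it keeps the elements in memory order: each 272x352 patch would end up laid out row after row instead of as a tile. You need to permute the dimensions first.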
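To make this concrete, here is a small sketch using a hypothetical toy input (not from the original question): four 2x3 patches, each filled with its own patch index, so the layout of the result is easy to read. Reshaping directly turns each flattened patch into one row, while permuting first places the patches as a 2x2 grid of tiles.

```python
import torch

# Hypothetical toy input: 4 patches of size 2x3, patch k is filled with the value k
patches = torch.arange(4).repeat_interleave(6).view(4, 2, 3)

# Naive reshape: every flattened patch becomes one row of the output
naive = patches.reshape(4, 6)

# Split dim0 into (grid_row, grid_col), move grid_col next to the width dim,
# then merge (grid_row, height) and (grid_col, width)
tiled = patches.reshape(2, 2, 2, 3).permute(0, 2, 1, 3).reshape(4, 6)

print(naive)
print(tiled)
```

In `tiled`, the top-left 2x3 block holds patch 0, the top-right block patch 1, and so on, while every row of `naive` contains a single patch index.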
Since your for-loop approach wasn't updated, I assume the loop below creates the desired output. The snippet also shows how to permute the original tensor before reshaping:
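```python
import torch

pred = torch.randn(4, 4, 272, 352)
pred = pred.argmax(1).contiguous()  # shape: [4, 272, 352]

# Split dim0 into a 2x2 grid, move the grid-column dim next to the width dim,
# then merge the dims; permute returns a non-contiguous tensor, so use reshape, not view
pred_ref = pred.view(2, 2, 272, 352).permute(0, 2, 1, 3).reshape(544, 704)

# Reference: copy each patch into its tile with a nested loop
pred_ = pred.view(-1, 272, 352)
img = torch.zeros(544, 704)
index = 0
for i in range(2):
    for j in range(2):
        img[i*272:(i+1)*272, j*352:(j+1)*352] = pred_[index]
        index = index + 1

print((img == pred_ref).all())
# > tensor(True)
```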
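If you need this for other grid sizes, the same permute-then-reshape pattern can be wrapped in a small helper. This is only a sketch: the name `tile_patches` and its arguments are hypothetical, and it assumes the patches are ordered row-major over the grid, as in the loop above.

```python
import torch


def tile_patches(patches: torch.Tensor, grid_rows: int, grid_cols: int) -> torch.Tensor:
    """Hypothetical helper: arrange patches of shape [grid_rows * grid_cols, H, W]
    into one [grid_rows * H, grid_cols * W] image, filling the grid row by row."""
    n, h, w = patches.shape
    if n != grid_rows * grid_cols:
        raise ValueError(f"expected {grid_rows * grid_cols} patches, got {n}")
    return (
        patches.reshape(grid_rows, grid_cols, h, w)
        .permute(0, 2, 1, 3)  # -> [grid_rows, H, grid_cols, W]
        .reshape(grid_rows * h, grid_cols * w)
    )


# Usage with the shapes from the question
pred = torch.randn(4, 4, 272, 352).argmax(1)  # [4, 272, 352]
img = tile_patches(pred, 2, 2)
print(img.shape)  # torch.Size([544, 704])
```

Using `reshape` instead of `view` for the first step also accepts non-contiguous inputs, at the cost of a possible copy.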