Ran into a potential torch.btrifact bug when tensor size is (nb, 1, 1)

I ran into a weird bug that I have since worked around, but I'd still like to report it here for reference: torch.btrifact returns a dim-transposed LU when the input tensor has size (num_batch, 1, 1) on the GPU.

Here is a code snippet to reproduce it:

import torch

print('PyTorch version: {}'.format(torch.__version__))

print('---below is an incorrect case---')
aa = torch.rand((2, 1, 1))   # batch of 1x1 matrices
aa_LU, _ = aa.btrifact()     # CPU factorization: correct shape
bb = aa.cuda()
bb_LU, _ = bb.btrifact()     # GPU factorization: dims come back transposed
print(aa_LU.size())
print(bb_LU.size())

print('---below is a correct case---')
aa = torch.rand((2, 3, 3))   # batch of 3x3 matrices
aa_LU, _ = aa.btrifact()
bb = aa.cuda()
bb_LU, _ = bb.btrifact()     # GPU factorization: correct shape here
print(aa_LU.size())
print(bb_LU.size())

I got:

PyTorch version: 0.4.1
---below is an incorrect case---
torch.Size([2, 1, 1])
torch.Size([1, 2, 1])
---below is a correct case---
torch.Size([2, 3, 3])
torch.Size([2, 3, 3])
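
For reference, one simple way to get around this for now is to fall back to the CPU for the (num_batch, 1, 1) case and move the result back to the GPU. Here is a minimal sketch (the safe_btrifact helper name is just for illustration, not part of PyTorch):

import torch

def safe_btrifact(x):
    # Work around the (num_batch, 1, 1) GPU case by factorizing on CPU
    # and moving the result back to the original device.
    if x.is_cuda and x.size(-2) == 1 and x.size(-1) == 1:
        lu, pivots = x.cpu().btrifact()
        return lu.to(x.device), pivots.to(x.device)
    return x.btrifact()

aa = torch.rand((2, 1, 1)).cuda()
aa_LU, aa_pivots = safe_btrifact(aa)
print(aa_LU.size())  # torch.Size([2, 1, 1])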

It works for me on a master build, so I think this has been fixed!

>>> import torch
>>>
>>> print('PyTorch version: {}'.format(torch.__version__))
PyTorch version: 0.5.0a0+7b905e4
>>> print('---below is an incorrect case---')
---below is an incorrect case---
>>> aa = torch.rand((2,1,1))
>>> aa_LU,_ = aa.btrifact()
>>> bb = aa.cuda()
>>> bb_LU,_ = bb.btrifact()
>>> print(aa_LU.size())
torch.Size([2, 1, 1])
>>> print(bb_LU.size())
torch.Size([2, 1, 1])
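
The shapes match now; if you also want to compare the values between CPU and GPU, a quick check along these lines should work (a hypothetical snippet, assuming a CUDA device is available):

# Compare the CPU and GPU factorizations elementwise, not just their shapes.
aa = torch.rand((2, 1, 1))
aa_LU, aa_piv = aa.btrifact()
bb_LU, bb_piv = aa.cuda().btrifact()
print(torch.allclose(aa_LU, bb_LU.cpu()))  # expected: True on a fixed build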

Good to know. Thanks!