You can also try adding something like this:
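```python
import torch

# Legacy PyTorch (typed tensor classes): keep the original indexing to fall back on.
_oldgetitem = torch.FloatTensor.__getitem__

def _getitem(self, slice_):
    # Intercept tuple indices that contain a LongTensor, e.g. x[:, b].
    if type(slice_) is tuple and torch.LongTensor in [type(x) for x in slice_]:
        # Dimension of the first LongTensor index (other tuple entries are ignored).
        i = next(j for j, ix in enumerate(slice_) if type(ix) is torch.LongTensor)
        # Move dim i to the front, index it there, then move it back.
        return self.transpose(0, i)[slice_[i]].transpose(i, 0)
    return _oldgetitem(self, slice_)

torch.FloatTensor.__getitem__ = _getitem
```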
With this patch, `x[b, :][:, b]` works as expected, though it is a bit of a hack.
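For comparison, here is a minimal sketch assuming a recent PyTorch, where plain tensors support `index_select` and NumPy-style advanced indexing, so no monkey-patching is needed. `select_on_dim` is a hypothetical helper name, not a PyTorch API. It checks that the transpose trick used in the patch matches `index_select` on the same dimension, and it shows that chained indexing `x[b, :][:, b]` (a submatrix) differs from `x[b, b]`, which pairs the two index tensors elementwise.

```python
import torch

def select_on_dim(x, dim, idx):
    # Hypothetical helper: same result as the patched x[..., idx, ...] for a single
    # LongTensor index at position `dim`, using index_select instead of transposes.
    return x.index_select(dim, idx)

x = torch.arange(16, dtype=torch.float32).reshape(4, 4)
b = torch.tensor([0, 2])  # int64, i.e. what legacy PyTorch called a LongTensor

# The transpose trick from the patch, applied to dim 1 (the x[:, b] case).
via_transpose = x.transpose(0, 1)[b].transpose(1, 0)
assert torch.equal(via_transpose, select_on_dim(x, 1, b))

# Chained indexing: rows b, then columns b, giving a 2x2 submatrix.
sub = x[b, :][:, b]
# Same submatrix in one step, broadcasting a column of row indices against b.
assert torch.equal(sub, x[b[:, None], b])

# x[b, b] pairs indices elementwise instead: it picks x[0, 0] and x[2, 2].
diag = x[b, b]
```

Broadcasting `b[:, None]` against `b` is what turns the paired (elementwise) lookup into an outer, row-by-column selection.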
`x[b, :][:, b]` vs. `x[b, b]`: this is something which has never been solved for