Subtraction along axis

Hi everyone, I am a bit of a newbie in PyTorch and I have a very basic issue that I am having trouble with.

I have two matrices A and B, with different numbers of rows but the same number of columns.
Basically, A and B are different collections of same-sized vectors.
What I am trying to do is to subtract each vector in B from each vector in A.

This minimal example does exactly what I’m trying to accomplish:

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[1, 1, 1], [2, 2, 2]])

c = torch.stack([a-x for x in b])

but of course I would like to get rid of that ugly “for” loop, which I guess is not good for performance.
I think I need some sort of torch.diff(a, b, axis=1).
Does something like that exist?

Broadcasting should work, but would trade memory for a potential speedup:

c = a.unsqueeze(0) - b.unsqueeze(1)
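
To make it easier to see why this matches the loop, here is a small sketch that runs both versions on the tensors from the question and compares them. The names c_loop, c_bcast and c_index are just illustrative, and the None-indexing line is an alternative spelling of the same unsqueeze calls, not something from the original post.

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # shape (3, 3)
b = torch.tensor([[1, 1, 1], [2, 2, 2]])              # shape (2, 3)

# Loop version from the question: one (3, 3) slice per row of b.
c_loop = torch.stack([a - x for x in b])

# Broadcast version: shapes (1, 3, 3) and (2, 1, 3) expand to (2, 3, 3),
# so c_bcast[i, j] == a[j] - b[i].
c_bcast = a.unsqueeze(0) - b.unsqueeze(1)

# Same operation written with None indexing instead of unsqueeze.
c_index = a[None, :, :] - b[:, None, :]

print(c_bcast.shape)
print(torch.equal(c_loop, c_bcast), torch.equal(c_loop, c_index))
```

The first dimension of the result indexes rows of b and the second indexes rows of a, which is the same layout the torch.stack loop produces. If you would rather have rows of a first, swapping the two unsqueeze dimensions (a.unsqueeze(1) - b.unsqueeze(0)) should give the transposed layout.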