Let’s begin with a simple 2-D tensor:
import torch

t1 = torch.tensor([1, 2, 3, 5], dtype=torch.int16)
t1 = t1.reshape((2, 2))
Next I create two 3-D tensors by unsqueezing t1, and then perform an element-wise operation on them.
t11 = t1.unsqueeze(1)  # shape: (2, 1, 2)
t12 = t1.unsqueeze(2)  # shape: (2, 2, 1)
final = t12 + t11
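For reference, here is the complete snippet I ran, with shape printouts added so the broadcasting is easier to follow:

```python
import torch

t1 = torch.tensor([1, 2, 3, 5], dtype=torch.int16).reshape(2, 2)

t11 = t1.unsqueeze(1)  # inserts a new axis at dim 1 -> shape (2, 1, 2)
t12 = t1.unsqueeze(2)  # inserts a new axis at dim 2 -> shape (2, 2, 1)

# The size-1 dims are expanded against each other, so both
# operands broadcast to (2, 2, 2) before the addition.
final = t12 + t11

print(t11.shape)    # torch.Size([2, 1, 2])
print(t12.shape)    # torch.Size([2, 2, 1])
print(final.shape)  # torch.Size([2, 2, 2])
print(final)
```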
I expected the final output to have shape (2, 2, 2), but I don’t understand how PyTorch broadcast t12 and t11 to perform the element-wise addition. Could someone please explain how broadcasting works with 3-D tensors, and how it produces the final tensor?
P.S.: I know how broadcasting works with 2-D tensors; I just can’t seem to visualise it happening with 3-D tensors.
Thanks!