How to understand torch.expand_as operation?

Printing out a and b makes the problem easier to see.
Example 1)
output:

a: tensor([[0.9619, 0.0384, 0.7012],
        [0.5561, 0.3637, 0.9272]])
b: tensor([[0.5986, 0.2582, 0.6261],
        [0.6928, 0.9175, 0.6737],
        [0.9951, 0.8568, 0.6015],
        [0.7922, 0.5019, 0.8162]])

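The reason is that expand_as can only stretch dimensions of size 1 (and prepend new leading dimensions); every other dimension must already match. Here a has shape (2, 3) and b has shape (4, 3). Dimension 0 of a has size 2, which is neither 1 nor 4, so a cannot be expanded to b's shape. b being larger than a is not the problem in itself. To get a valid c, you can change the size of b to (2, 2, 3) or another compatible shape, as shown below: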

import torch

a = torch.rand(2, 3)
b = torch.rand(2, 2, 3)
print('a:', a)
print('b:', b)
c = a.expand_as(b)  # a is repeated along a new leading dimension of size 2
print('c:', c)

output:

a: tensor([[0.4748, 0.5521, 0.7741],
        [0.0785, 0.2785, 0.5222]])
b: tensor([[[0.7777, 0.3046, 0.8019],
         [0.7398, 0.1424, 0.6398]],

        [[0.9034, 0.8937, 0.8674],
         [0.1737, 0.3192, 0.4451]]])
c: tensor([[[0.4748, 0.5521, 0.7741],
         [0.0785, 0.2785, 0.5222]],

        [[0.4748, 0.5521, 0.7741],
         [0.0785, 0.2785, 0.5222]]])
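It is also worth knowing that expand_as does not copy any data. According to the PyTorch docs, `self.expand_as(other)` is equivalent to `self.expand(other.size())`, and the result is a view in which the new dimension has stride 0, so both "copies" of a inside c point at the same memory. The minimal sketch below, which assumes the same shapes as the example above, checks this. If you need an independent tensor you can safely write to, call `.clone()` on the expanded result.

```python
import torch

a = torch.rand(2, 3)
b = torch.rand(2, 2, 3)
c = a.expand_as(b)

# expand_as(other) behaves like expand(other.size())
print(torch.equal(c, a.expand(b.size())))  # True

# The new leading dimension has stride 0, so no data is duplicated.
print(c.stride())  # (0, 3, 1)
print(c.data_ptr() == a.data_ptr())  # True: c is a view over a's storage

# Writing to a is visible through both "copies" inside c.
a[0, 0] = 100.0
print(c[0, 0, 0].item(), c[1, 0, 0].item())  # 100.0 100.0

# clone() materializes a real copy that no longer shares memory with a.
d = c.clone()
a[0, 0] = -1.0
print(d[0, 0, 0].item())  # 100.0
```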

Example 2 has the same problem as Example 1.
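Both failures come down to one rule: shapes are compared from the last dimension backwards, each dimension of the source must either equal the target's or be 1, and the target may have extra leading dimensions. The `can_expand` helper below is hypothetical (it is not a PyTorch function) and simply encodes that rule as a minimal sketch; the rest of the snippet compares it with what `expand_as` actually does for the shapes from Example 1.

```python
import torch

def can_expand(src_shape, target_shape):
    """Hypothetical helper: True if a tensor of src_shape can be expanded to target_shape."""
    if len(src_shape) > len(target_shape):
        return False  # expand cannot remove dimensions
    # Compare trailing dimensions; extra leading dims of the target become new dims.
    for s, t in zip(reversed(src_shape), reversed(target_shape)):
        if s != t and s != 1:
            return False  # only size-1 dimensions can be stretched
    return True

a = torch.rand(2, 3)
b = torch.rand(4, 3)  # shapes from Example 1
print(can_expand(a.shape, b.shape))  # False
try:
    a.expand_as(b)
except RuntimeError as e:
    print("RuntimeError:", e)

# With a size-1 leading dimension, the same target shape becomes reachable.
a1 = torch.rand(1, 3)
print(can_expand(a1.shape, b.shape))  # True
print(a1.expand_as(b).shape)  # torch.Size([4, 3])
```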
