It seems that type_as moves a tensor to cuda:0 if it is originally on the CPU, but leaves it where it is if it is already on a different CUDA device. Is that intended behavior?
import torch
# Case 1: a starts on the CPU and gets moved to cuda:0
a = torch.randint(0,5,(2,2),device='cpu')
b = torch.randn(2,2,device='cuda:0')
a = a.type_as(b)
a.device
-> device(type='cuda', index=0)
b.device
-> device(type='cuda', index=0)
# Case 2: a starts on cuda:0 and is not moved to cuda:1
a = torch.randint(0,5,(2,2),device='cuda:0')
b = torch.randn(2,2,device='cuda:1')
a = a.type_as(b)
a.device
-> device(type='cuda', index=0)
b.device
-> device(type='cuda', index=1)
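
If the goal is for a to end up with both b's dtype and b's device in every case, one option is to request both explicitly rather than rely on how type_as handles devices. The following is only a minimal sketch: match_dtype_and_device is a hypothetical helper name (not part of PyTorch), and the device selection is there just so the snippet also runs on machines with fewer than two GPUs.

```python
import torch


def match_dtype_and_device(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return `a` with b's dtype, placed on b's device."""
    # Passing dtype and device explicitly to Tensor.to avoids depending on
    # whatever type_as does with the device.
    return a.to(device=b.device, dtype=b.dtype)


# Pick a target device that actually exists on this machine.
if torch.cuda.device_count() > 1:
    target = torch.device("cuda:1")
elif torch.cuda.is_available():
    target = torch.device("cuda:0")
else:
    target = torch.device("cpu")

a = torch.randint(0, 5, (2, 2), device="cpu")
b = torch.randn(2, 2, device=target)
a = match_dtype_and_device(a, b)

# a now matches b in both dtype and device.
assert a.dtype == b.dtype and a.device == b.device
```

This does not answer whether the type_as behavior above is intended; it only sidesteps the question when both properties need to match.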