The result of following code seems not right:
device = 'mps'
list_ = torch.zeros(3, 2)
print('init:\n{}'.format(list_))
for i in range(3):
print('---------------%i---------------' %i)
ten = torch.tensor([i,i], device=device)
print('tensor:\n{}'.format(ten))
list_[i] = ten
print("ten:\n{}\nlist:\n{}".format(ten, list_), end='\n\n')
result:
init:
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
---------------0---------------
tensor:
tensor([0, 0], device='mps:0')
ten:
tensor([0, 0], device='mps:0')
list:
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
---------------1---------------
tensor:
tensor([1, 1], device='mps:0')
ten:
tensor([4575657222473777152, 1], device='mps:0')
list:
tensor([[1.0000e+00, 1.0000e+00],
[1.4013e-45, 0.0000e+00],
[0.0000e+00, 0.0000e+00]])
---------------2---------------
tensor:
tensor([2, 2], device='mps:0')
ten:
tensor([4611686019501129728, 2], device='mps:0')
list:
tensor([[2.0000e+00, 2.0000e+00],
[2.8026e-45, 0.0000e+00],
[0.0000e+00, 0.0000e+00]])
move list_ to mps:
device = 'mps'
list_ = torch.zeros(3, 2).to(device)
print('init:\n{}'.format(list_))
for i in range(3):
print('---------------%i---------------' %i)
ten = torch.tensor([i,i], device=device)
print('tensor:\n{}'.format(ten))
list_[i] = ten
print("ten:\n{}\nlist:\n{}".format(ten, list_), end='\n\n')
result:
init:
tensor([[0., 0.],
[0., 0.],
[0., 0.]], device='mps:0')
---------------0---------------
tensor:
tensor([0, 0], device='mps:0')
ten:
tensor([0, 0], device='mps:0')
list:
tensor([[0., 0.],
[0., 0.],
[0., 0.]], device='mps:0')
---------------1---------------
tensor:
tensor([1, 1], device='mps:0')
ten:
tensor([1, 1], device='mps:0')
list:
tensor([[1., 1.],
[0., 0.],
[0., 0.]], device='mps:0')
---------------2---------------
tensor:
tensor([2, 2], device='mps:0')
ten:
tensor([2, 2], device='mps:0')
list:
tensor([[2., 2.],
[0., 0.],
[0., 0.]], device='mps:0')