Hi, at some point in my project the copy_() operation tripped me up,
so I put together the toy example below, but I could not figure out what is going on.
import torch

class A():
    def __init__(self):
        self.tensor = torch.zeros((5, 6), dtype=torch.float)

    def check(self, b):
        self.tensor[[2, 4]].copy_(b.tensor)
        print(self.tensor)

class B():
    def __init__(self, tensor):
        self.tensor = tensor

m = torch.ones((2, 6), dtype=torch.float)
a = A()
b = B(m)
a.check(b)
As written, this simple code is meant to copy the 1.0s into specific rows of the zero tensor,
but all I get is a tensor full of zeros, which means the copy_() call had no effect:
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]])
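My current guess (which I am not sure about) is that indexing with a list like self.tensor[[2, 4]] returns a new tensor rather than a view, so copy_() writes into a temporary that is discarded. A quick check that seems to point that way (the names here are just for illustration):

t = torch.zeros((5, 6), dtype=torch.float)
sel = t[[2, 4]]                           # advanced (list) indexing
print(sel.data_ptr() == t.data_ptr())     # False for me -> sel has its own storage
sel.copy_(torch.ones((2, 6)))             # copies into sel only
print(t.sum())                            # still tensor(0.) -> t is untouched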
The code below, which uses the assignment operator = instead of copy_(), works:
def check(self, b):
    self.tensor[[2, 4]] = b.tensor
    print(self.tensor)
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.]])
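For what it is worth, the only in-place variant I have found that writes into the original tensor is index_copy_ (I assume this is the intended tool for this kind of row copy, but please correct me if not):

t = torch.zeros((5, 6), dtype=torch.float)
src = torch.ones((2, 6), dtype=torch.float)
t.index_copy_(0, torch.tensor([2, 4]), src)  # in-place copy of src into rows 2 and 4
print(t)                                     # rows 2 and 4 are now ones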
Moreover, the code below, which does not involve classes at all, also works fine:
c = torch.zeros((5, 6), dtype=torch.float)
d = torch.ones((2, 6), dtype=torch.float)
c[[2, 4]] = d
print(c)
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.]])
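To rule out the classes as the culprit, I also tried the same copy_() pattern on plain tensors, and it fails in the same way for me:

c = torch.zeros((5, 6), dtype=torch.float)
d = torch.ones((2, 6), dtype=torch.float)
c[[2, 4]].copy_(d)   # same list-indexing + copy_() pattern, no classes involved
print(c)             # still all zeros for me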
I am hoping to fully understand why the values are not copied in the first toy example.
(I tried this on PyTorch 1.6.0 and 1.7.0.)