Hi. I will try to clarify this as I understand it myself.
- The problem is not `expand`. If you do not change `a`, the code works:
import torch
import torch.nn as nn

# nn.Parameter already has requires_grad=True by default
a = nn.Parameter(torch.ones(3, 1))
b = a.expand(3,4)
c = a.repeat(1,4)
print(a.data_ptr(),b.data_ptr(),c.data_ptr())
print(b)
print(c)
d = torch.sum(3*b)
d.backward()
print(a.grad)
- The error message gives a small clue about what is going on: it mentions “unreachable”. The problem is in the line `a[0,0] = 3`. You are modifying `a` in place, which basically makes part of `a` stop requiring grad, so that part becomes unreachable for the backward computation. But if you change the line to `a.data[0,0] = 3`, only the underlying data changes; `requires_grad` and the computational graph are left alone, and all the code works as expected.
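To make the point above concrete, here is a minimal sketch of the failing write next to the `.data` write. The `torch.no_grad()` variant is my own addition and not part of the original code; as far as I know it is the currently recommended way to do the same thing. The values 3 and 5 are just for illustration.

```python
import torch
import torch.nn as nn

a = nn.Parameter(torch.ones(3, 1))

# Direct in-place write on a leaf tensor that requires grad:
# autograd refuses it and raises a RuntimeError.
try:
    a[0, 0] = 3
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("in-place write rejected:", err)

# Variant from the post: write through .data, which autograd does not track.
a.data[0, 0] = 3

# Alternative (assumption, not in the original post): do the write
# inside torch.no_grad(), so autograd does not record it either.
with torch.no_grad():
    a[1, 0] = 5

# The rest of the original code still works after these writes.
b = a.expand(3, 4)
d = torch.sum(3 * b)
d.backward()
print(a.data)
print(a.grad)
```

Each row of `a` is used four times in `b` and scaled by 3, so every entry of `a.grad` should come out as 12, regardless of the values written into `a`.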
I may not be 100% correct in the details of my explanation, but the main idea should be right.
As for the difference between `expand` and `repeat`, this post may be of some help to you.
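In short, and only as a rough sketch: `expand` returns a view that shares memory with the original tensor, while `repeat` allocates a copy. The snippet below uses a plain tensor (no `requires_grad`) so that the in-place write is allowed; the variable names mirror the code above.

```python
import torch

a = torch.ones(3, 1)
b = a.expand(3, 4)  # view: no new memory is allocated
c = a.repeat(1, 4)  # copy: new memory is allocated

shares_memory_b = b.data_ptr() == a.data_ptr()
shares_memory_c = c.data_ptr() == a.data_ptr()
print(shares_memory_b, shares_memory_c)

# The expanded dimension has stride 0: all four columns read the same element.
print(b.stride(), c.stride())

# A write to a shows up in the view b but not in the copy c.
a[0, 0] = 3
print(b[0])
print(c[0])
```

This is also why the `data_ptr()` values printed in the original code match for `a` and `b` but differ for `c`.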
Cheers.