I am experiencing a strange behaviour: after I assign an `nn.Parameter` to a module attribute, that attribute seems to become an `nn.Parameter` itself, which leads to a `TypeError` in future assignments:

    Cannot assign 'torch.FloatTensor' as parameter 'weight_current' (torch.nn.Parameter or None expected)
This problem appears to be caused by the following code snippet in one of my modules:
```python
class MyLayer(torch.nn.Linear):
    def forward(self, input, sample):
        if sample:
            self.weight_current = self.weight + torch.randn_like(self.weight)
            self.bias_current = self.bias + torch.randn_like(self.bias)
        else:
            self.weight_current = self.weight
            self.bias_current = self.bias
        return torch.nn.functional.linear(input, self.weight_current, self.bias_current)
```
That is, after calling `forward` with `sample=False`, I can no longer call it with `sample=True`, as doing so raises the `TypeError` above.
It seems odd to me that this happens at all. Can somebody explain why PyTorch behaves this way and what I can do about it?