Hi,
I’m using nn.DataParallel and running into an issue with how network weights are replicated. My PyTorch version is 1.3.1.
Linear_fw is a child class of nn.Linear and becomes part of the network architecture. The key feature I need is to register a member named fast on the weight and bias parameters of nn.Linear, i.e.:
import torch.nn as nn

class Linear_fw(nn.Linear):  # child class of nn.Linear
    def __init__(self, in_features, out_features):
        super(Linear_fw, self).__init__(in_features, out_features)
        self.weight.fast = None  # <-- register a member; an nn.Parameter is assigned to it from outside
        self.bias.fast = None    # <-- likewise for the bias
        self.fast_flag = True
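For context, the fast members are filled in from outside the module roughly like this (a minimal sketch of the intended usage; fc, the update rule, and the step size lr_inner are only illustrative, not from my actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

fc = Linear_fw(4, 2)
x = torch.randn(8, 4)
# First pass uses the plain weights (fast is still None at this point).
loss = F.linear(x, fc.weight, fc.bias).sum()
grad_w, grad_b = torch.autograd.grad(loss, [fc.weight, fc.bias])

lr_inner = 0.01  # illustrative step size
# Store updated "fast" weights next to the original parameters;
# detach() so the new Parameter is a graph leaf.
fc.weight.fast = nn.Parameter((fc.weight - lr_inner * grad_w).detach())
fc.bias.fast = nn.Parameter((fc.bias - lr_inner * grad_b).detach())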
The error occurs during the forward pass with nn.DataParallel. The full class:
import torch.nn as nn
import torch.nn.functional as F

class Linear_fw(nn.Linear):  # child class of nn.Linear
    def __init__(self, in_features, out_features):
        super(Linear_fw, self).__init__(in_features, out_features)
        self.weight.fast = None  # <-- register a member; an nn.Parameter is assigned to it from outside
        self.bias.fast = None    # <-- likewise for the bias
        self.fast_flag = True

    def forward(self, x):
        if self.fast_flag:
            out = F.linear(x, self.weight.fast, self.bias.fast)  # <-- AttributeError occurs here
        else:
            out = super(Linear_fw, self).forward(x)
        return out
An AttributeError is raised:
out = F.linear(x, self.weight.fast, self.bias.fast)
AttributeError: 'Tensor' object has no attribute 'fast'
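For reference, a minimal sketch that reproduces it for me, with Linear_fw defined as above (assuming at least two visible GPUs, since nn.DataParallel skips replication on a single device; the sizes are arbitrary):

import torch
import torch.nn as nn

model = nn.DataParallel(Linear_fw(4, 2)).cuda()
x = torch.randn(8, 4).cuda()
out = model(x)  # AttributeError: 'Tensor' object has no attribute 'fast'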
Without nn.DataParallel, self.weight is an nn.Parameter and the code runs fine. But under nn.DataParallel, the self.weight seen inside each replica is a plain torch.Tensor, so accessing its member self.weight.fast raises the error. The error is raised only with nn.DataParallel.
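The root cause can be seen directly with nn.parallel.replicate, which is what nn.DataParallel.forward calls internally (a small sketch assuming at least one CUDA device):

import torch
import torch.nn as nn

m = nn.Linear(4, 2).cuda()
m.weight.fast = None  # custom attribute on the original Parameter

replica = nn.parallel.replicate(m, [0])[0]
print(type(m.weight))                   # <class 'torch.nn.parameter.Parameter'>
print(type(replica.weight))             # <class 'torch.Tensor'>
print(hasattr(replica.weight, 'fast'))  # False -- the custom attribute is gone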
The easiest fix here would be to (somehow magically) have nn.DataParallel.forward replicate the network weight and bias as nn.Parameter instead of torch.Tensor.
Is that possible somehow? I would sincerely appreciate your advice.
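One direction I have been considering, in case it helps the discussion: store the fast weights on the module itself as buffers instead of as ad-hoc attributes on the parameters, since buffers are copied to each device by replicate(). This is only an untested sketch; in particular I have not verified how gradients flow through the broadcast buffers under nn.DataParallel:

import torch.nn as nn
import torch.nn.functional as F

class Linear_fw(nn.Linear):
    def __init__(self, in_features, out_features):
        super(Linear_fw, self).__init__(in_features, out_features)
        # Buffers (unlike attributes stuck onto a Parameter) survive
        # nn.DataParallel's replicate() and reach every device.
        self.register_buffer('weight_fast', None)
        self.register_buffer('bias_fast', None)

    def forward(self, x):
        if self.weight_fast is not None and self.bias_fast is not None:
            return F.linear(x, self.weight_fast, self.bias_fast)
        return super(Linear_fw, self).forward(x)

The fast weights would then be assigned as plain tensors (e.g. fc.weight_fast = fc.weight - lr_inner * grad_w), so they stay registered as buffers rather than being re-registered as parameters.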