Hello all.
I have a main model class like this:
class KDNN(nn.Module):
def __init__(self):
super(KDNN, self).__init__()
self.EnE = torch.nn.Sequential(
MaskedLinear(IE_dim, h_dim1, torch.LongTensor(1- Mask1.values)),
nn.BatchNorm1d(h_dim1),
nn.ReLU(),
MaskedLinear(h_dim1, h_dim2, torch.LongTensor(1-Mask2.values)),
nn.BatchNorm1d(h_dim2),
nn.ReLU(),
MaskedLinear(h_dim2, Out_dim, torch.LongTensor(1-Mask3.values)))
def forward(self, x):
output = self.EnE(x)
return output
And with some helps from posts here a MaskedLinear class like this (to design a layer with masked weights):
class MaskedLinear(nn.Module):
def __init__(self, in_dim, out_dim, mask):
super(MaskedLinear, self).__init__()
def backward_hook(grad):
# Clone due to not being allowed to modify in-place gradients
out = grad.clone()
out[torch.t(self.mask)] = 0
return out
self.linear = nn.Linear(in_dim, out_dim)
self.mask = mask.byte()
self.linear.weight.data[torch.t(self.mask)] = 0 # zero out bad weights
self.linear.weight.register_hook(backward_hook) # hook to zero out bad gradients
def forward(self, input):
return self.linear(input)
When I want to save an object of the KDNN class, I am getting an error: "Can’t pickle local object ‘MaskedLinear.init..backward_hook’ "
Any suggestions?
Thank you!