Here is the model I trained. Everything worked perfectly until I tried to save it.
class Extractor(nn.Module):
def __init__(self):
super(Extractor, self).__init__()
self.conv0 = nn.Conv2d(1, 16, kernel_size=(3,2),stride=1) ## change with input shape
self.bn0 = nn.BatchNorm2d(16)
# Res block 1
self.conv1 = nn.Conv2d(16, 32, kernel_size=3)
self.bn1 = nn.BatchNorm2d(32)
self.conv11 = nn.Conv2d(32, 32, kernel_size=(3,3),stride=(1,1),padding=(2,2))
self.bn11 = nn.BatchNorm2d(32)
# Res block 2
self.conv2 = nn.Conv2d(32, 64, kernel_size=(3,3), stride=(1,1),padding=(1,1))
self.bn2 = nn.BatchNorm2d(64)
self.conv21 = nn.Conv2d(64, 64, kernel_size=(3,3), stride=(2,2),padding=(1,1))
self.bn21 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 64, kernel_size=(5,3), stride=(2,1),padding=(1,1)) ### change with input shape
self.bn3 = nn.BatchNorm2d(64)
self.drop = nn.Dropout2d(0.5)
def forward(self, x):
x = F.relu(self.bn0(self.conv0(x)))
#Res block 1
x1 = self.drop(F.relu(self.bn1(self.conv1(x))))
x1 = F.relu(F.max_pool2d(self.drop(self.bn11(self.conv11(x1))), 2))
x = torch.cat((x,torch.zeros_like(x)), axis=1)
x = F.max_pool2d(x,2)
x = x+x1
#Res block 2
x1 = self.drop(F.relu(self.bn2(self.conv2(x))))
x1 = F.relu(self.drop(self.bn21(self.conv21(x1))))
x = torch.cat((x,torch.zeros_like(x)), axis=1)
x = F.max_pool2d(x,2)
x = x+x1
#last conv
x = self.drop(F.relu(self.bn3(self.conv3(x))))
x = F.max_pool2d(x,(2,1)) ### change withinput
x = x.view(-1, 64*3*15)
return x
class Class_classifier(nn.Module):
def __init__(self, num_class,in_feature=64*3*15,intermediate_nodes=100):
super(Class_classifier, self).__init__()
# self.fc1 = nn.Linear(50 * 4 * 4, 100)
# self.bn1 = nn.BatchNorm1d(100)
# self.fc2 = nn.Linear(100, 100)
# self.bn2 = nn.BatchNorm1d(100)
# self.fc3 = nn.Linear(100, 10)
self.fc1 = nn.Linear(in_feature, intermediate_nodes)
self.fc2 = nn.Linear(intermediate_nodes, num_class)
self.relu = nn.ReLU()
self.soft = nn.Softmax(dim=1)
def forward(self, x):
# logits = F.relu(self.bn1(self.fc1(input)))
# logits = self.fc2(F.dropout(logits))
# logits = F.relu(self.bn2(logits))
# logits = self.fc3(logits)
logits = self.relu(self.fc1(x))
logits = self.fc2(F.dropout(logits))
logits = self.soft(logits)
return logits
class Domain_classifier(nn.Module):
def __init__(self,domain_class,in_feature=64*3*15,intermediate_nodes=100,scheme='basic'):
super(Domain_classifier, self).__init__()
self.scheme = scheme
# self.fc1 = nn.Linear(50 * 4 * 4, 100)
# self.bn1 = nn.BatchNorm1d(100)
# self.fc2 = nn.Linear(100, 2)
self.fc1 = nn.Linear(in_feature, intermediate_nodes)
self.fc2 = nn.Linear(intermediate_nodes, domain_class)
self.relu = nn.ReLU()
self.soft = nn.Softmax(dim=1)
def forward(self, x, constant):
if(self.scheme=='dann'):
x = GradReverse.grad_reverse(x, constant)
logits = self.relu(self.fc1(x))
logits = self.soft(self.fc2(logits))
return logits
class Network(nn.Module):
def __init__(self,num_class, domain_class, scheme):
super(Network, self).__init__()
self.extractor = Extractor()
self.classifier = Class_classifier(num_class=num_class)
self.domain = Domain_classifier(domain_class=domain_class)
def forward(self, x, hp_lambda=0):
x = self.extractor(x)
clss = self.classifier(x)
dom = self.domain(x,hp_lambda)
return clss,dom
net = Network(2,4,"basic")
torch.save(net, "network.pth")
This save gives the following error message
AttributeError Traceback (most recent call last)
in
----> 1 torch.save(net.modules, “network.pth”)
~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
222 >>> torch.save(x, buffer)
223 “”"
–> 224 return _with_file_like(f, “wb”, lambda f: _save(obj, f, pickle_module, pickle_protocol))
225
226
~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in _with_file_like(f, mode, body)
147 f = open(f, mode)
148 try:
–> 149 return body(f)
150 finally:
151 if new_fd:
~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in (f)
222 >>> torch.save(x, buffer)
223 “”"
–> 224 return _with_file_like(f, “wb”, lambda f: _save(obj, f, pickle_module, pickle_protocol))
225
226
~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
294 pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
295 pickler.persistent_id = persistent_id
–> 296 pickler.dump(obj)
297
298 serialized_storage_keys = sorted(serialized_storages.keys())
AttributeError: Can’t pickle local object ‘summary..register_hook..hook’
What is going wrong here?
Saving the state dict works, but I can’t save the entire model.