Torch.save Can't pickle local object

Here is the model I trained. Everything worked perfectly until I tried to save it.

class Extractor(nn.Module):
    """Convolutional feature extractor with two residual-style blocks.

    Takes a single-channel 4-D input (``conv0`` has one input channel) and
    returns a 2-D tensor whose rows hold ``64 * 3 * 15`` features each.
    The kernel sizes of ``conv0``/``conv3`` and the final pooling window are
    tied to the expected input shape and must be adapted if it changes.
    """

    def __init__(self):
        super().__init__()

        # Stem (kernel size depends on the input shape).
        self.conv0 = nn.Conv2d(1, 16, kernel_size=(3, 2), stride=1)
        self.bn0 = nn.BatchNorm2d(16)

        # Residual block 1: 16 -> 32 channels.
        self.conv1 = nn.Conv2d(16, 32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv11 = nn.Conv2d(
            32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2)
        )
        self.bn11 = nn.BatchNorm2d(32)

        # Residual block 2: 32 -> 64 channels, downsampled by the strided conv.
        self.conv2 = nn.Conv2d(
            32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.bn2 = nn.BatchNorm2d(64)
        self.conv21 = nn.Conv2d(
            64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
        )
        self.bn21 = nn.BatchNorm2d(64)

        # Final convolution (kernel/stride depend on the input shape).
        self.conv3 = nn.Conv2d(
            64, 64, kernel_size=(5, 3), stride=(2, 1), padding=(1, 1)
        )
        self.bn3 = nn.BatchNorm2d(64)

        self.drop = nn.Dropout2d(0.5)

    @staticmethod
    def _double_channels(t):
        """Append an all-zero copy of ``t`` along dim 1, doubling its channels."""
        return torch.cat((t, torch.zeros_like(t)), dim=1)

    def forward(self, x):
        """Run the conv stack and flatten the result to ``(-1, 64*3*15)``."""
        stem = F.relu(self.bn0(self.conv0(x)))

        # Residual block 1: conv branch plus a zero-padded, pooled shortcut.
        branch = self.drop(F.relu(self.bn1(self.conv1(stem))))
        branch = self.drop(self.bn11(self.conv11(branch)))
        branch = F.relu(F.max_pool2d(branch, 2))
        shortcut = F.max_pool2d(self._double_channels(stem), 2)
        out = shortcut + branch

        # Residual block 2: the strided conv downsamples the branch, the
        # shortcut is downsampled by max-pooling instead.
        branch = self.drop(F.relu(self.bn2(self.conv2(out))))
        branch = F.relu(self.drop(self.bn21(self.conv21(branch))))
        shortcut = F.max_pool2d(self._double_channels(out), 2)
        out = shortcut + branch

        # Last convolution; the pooling window depends on the input shape.
        out = self.drop(F.relu(self.bn3(self.conv3(out))))
        out = F.max_pool2d(out, (2, 1))

        return out.view(-1, 64 * 3 * 15)

class Class_classifier(nn.Module):
    """Two-layer fully connected head that outputs class probabilities.

    Args:
        num_class: Number of output classes.
        in_feature: Size of each flattened input feature vector.
        intermediate_nodes: Width of the hidden layer.
    """

    def __init__(self, num_class,in_feature=64*3*15,intermediate_nodes=100):
        super(Class_classifier, self).__init__()
        self.fc1 = nn.Linear(in_feature, intermediate_nodes)
        self.fc2 = nn.Linear(intermediate_nodes, num_class)

        self.relu = nn.ReLU()
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        """Return softmax probabilities of shape ``(batch, num_class)``.

        Note that the result has already passed through softmax, so it is a
        probability distribution per row rather than raw logits.
        """
        hidden = self.relu(self.fc1(x))
        # F.dropout defaults to training=True, which would keep dropping
        # activations after model.eval(); tie it to the module's mode instead.
        hidden = F.dropout(hidden, training=self.training)
        return self.soft(self.fc2(hidden))

class Domain_classifier(nn.Module):
    """Two-layer fully connected head that outputs domain probabilities.

    Args:
        domain_class: Number of domains to discriminate between.
        in_feature: Size of each flattened input feature vector.
        intermediate_nodes: Width of the hidden layer.
        scheme: When equal to ``'dann'`` the input is first passed through
            ``GradReverse.grad_reverse``; any other value skips that step.
    """

    def __init__(self, domain_class, in_feature=64*3*15, intermediate_nodes=100, scheme='basic'):
        super().__init__()
        self.scheme = scheme

        self.fc1 = nn.Linear(in_feature, intermediate_nodes)
        self.fc2 = nn.Linear(intermediate_nodes, domain_class)

        self.relu = nn.ReLU()
        self.soft = nn.Softmax(dim=1)

    def forward(self, x, constant):
        """Return softmax probabilities of shape ``(batch, domain_class)``.

        ``constant`` is only used in the ``'dann'`` scheme, where it is
        forwarded to ``GradReverse.grad_reverse`` (defined outside this class;
        presumably a gradient-reversal scale -- confirm against its definition).
        """
        features = x
        if self.scheme == 'dann':
            features = GradReverse.grad_reverse(features, constant)

        hidden = self.relu(self.fc1(features))
        scores = self.fc2(hidden)
        return self.soft(scores)
class Network(nn.Module):
    """Feature extractor feeding a class head and a domain head.

    Args:
        num_class: Number of classes for the class classifier.
        domain_class: Number of domains for the domain classifier.
        scheme: Training scheme forwarded to ``Domain_classifier``
            (e.g. ``"basic"`` or ``"dann"``).
    """

    def __init__(self, num_class, domain_class, scheme):
        super(Network, self).__init__()
        self.extractor = Extractor()
        self.classifier = Class_classifier(num_class=num_class)
        # Forward the requested scheme; previously it was dropped, so the
        # domain head always ran with its 'basic' default.
        self.domain = Domain_classifier(domain_class=domain_class, scheme=scheme)

    def forward(self, x, hp_lambda=0):
        """Return ``(class_probabilities, domain_probabilities)`` for ``x``.

        ``hp_lambda`` is passed to the domain classifier as its ``constant``
        argument.
        """
        x = self.extractor(x)
        clss = self.classifier(x)
        dom = self.domain(x, hp_lambda)
        return clss, dom

# Build the model (2 classes, 4 domains, "basic" scheme) and pickle the whole
# module object to disk, as opposed to saving only its state_dict.
net = Network(num_class=2, domain_class=4, scheme="basic")
torch.save(net, "network.pth")

Saving the model this way produces the following error message:

AttributeError Traceback (most recent call last)
in
----> 1 torch.save(net.modules, "network.pth")

~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
222 >>> torch.save(x, buffer)
223 """
--> 224 return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
225
226

~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in _with_file_like(f, mode, body)
147 f = open(f, mode)
148 try:
--> 149 return body(f)
150 finally:
151 if new_fd:

~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in <lambda>(f)
222 >>> torch.save(x, buffer)
223 """
--> 224 return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
225
226

~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
294 pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
295 pickler.persistent_id = persistent_id
--> 296 pickler.dump(obj)
297
298 serialized_storage_keys = sorted(serialized_storages.keys())

AttributeError: Can't pickle local object 'summary.<locals>.register_hook.<locals>.hook'

What is going wrong here?
Saving the state dict works, but I can’t save the entire model.