Saving full model with pytorch ignite

Hello, I was able to run this tutorial successfully with PyTorch Ignite,

but in order to apply a pruning method I need to save the whole model (the state dict alone is not enough). When I try to save it with torch.save(model, 'model1.pth'), for example, I get the following message:

File “”, line 1, in <module>
torch.save(model, ‘modelo_ejemplo1.pth’)

File “C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py”, line 260, in save
return _with_file_like(f, “wb”, lambda f: _save(obj, f, pickle_module, pickle_protocol))

File “C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py”, line 185, in _with_file_like
return body(f)

File “C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py”, line 260, in <lambda>
return _with_file_like(f, “wb”, lambda f: _save(obj, f, pickle_module, pickle_protocol))

File “C:\ProgramData\Anaconda3\lib\site-packages\torch\serialization.py”, line 332, in _save

AttributeError: Can’t pickle local object ‘_initialize.<locals>.patch_forward.<locals>.new_fwd’

I’ve been searching a lot for this error message, but hardly anyone seems to use ignite…

I can save the model like this

# Objects handed to the Ignite checkpoint handler, keyed by the name used in the checkpoint.
to_save = {'model': model}

# Raw string: in a normal string literal the "\N" in "D:\Neural_Nets" is a
# malformed named-unicode escape (a SyntaxError in Python 3), and "\E", "\P"
# and "\C" are invalid escape sequences.
handler = Checkpoint(
    to_save,
    DiskSaver(r'D:\Neural_Nets\EfficientNet\Pytorch\CIFAR10b',
              create_dir=True, require_empty=False),
    n_saved=2,
)

# Run the checkpoint handler every time an epoch starts.
trainer.add_event_handler(Events.EPOCH_STARTED, handler)

but that way I can only save the weights… I want to save the whole model, like I do with my other code that doesn’t use PyTorch Ignite. Can you help me with this issue? I think ignite is more efficient than my other way of training, which is why I want to switch to it, but if I can’t save the model there is no point :slight_smile:

Thanks in advance!

Hi @Tanya_Boone,

torch.save(model, ‘model1.pth’)
AttributeError: Can’t pickle local object ‘_initialize.<locals>.patch_forward.<locals>.new_fwd’

It seems like your model cannot be saved with torch.save.

Maybe you need to replace some lambda function in there, if there are some…
Can you share the model definition to take a look ?

I don’t think it is the model, because I can save it with other types of training; it is only when I switch to ignite that I get this kind of error…

Hi Tanya! It’s been a while, but could you share a model that you cannot save with ignite? Thank you very much for your help :slight_smile:

Yes, for example, efficientnet

class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        """Return ``x * torch.sigmoid(x)``; the result has the same shape as ``x``."""
        return x * torch.sigmoid(x)

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, x):
        """Reshape ``x`` to ``(x.shape[0], -1)``, keeping the leading dimension."""
        return x.reshape(x.shape[0], -1)

class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation gate: rescale each channel by a learned weight.

    The input is averaged over its last two dimensions, passed through two
    1x1 convolutions (``inplanes -> se_planes -> inplanes``) and multiplied
    back onto the input.

    Args:
        inplanes: number of channels of the input (and of the output).
        se_planes: number of channels in the reduced bottleneck.
    """

    def __init__(self, inplanes, se_planes):
        super(SqueezeExcitation, self).__init__()
        # NOTE(review): the posted snippet left this nn.Sequential unclosed and
        # showed only the two convolutions. The Swish between them and the
        # final Sigmoid are restored on the assumption that this is the
        # standard SE gate -- confirm against the original notebook.
        self.reduce_expand = nn.Sequential(
            nn.Conv2d(inplanes, se_planes,
                      kernel_size=1, stride=1, padding=0, bias=True),
            Swish(),
            nn.Conv2d(se_planes, inplanes,
                      kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the gate computed from its mean."""
        # Squeeze: mean over the last two dims, kept as size-1 dims so the
        # result broadcasts against ``x`` below.
        x_se = torch.mean(x, dim=(-2, -1), keepdim=True)
        x_se = self.reduce_expand(x_se)
        return x_se * x

class MBConv(nn.Module):
    """Inverted-bottleneck block: expand, depthwise conv, squeeze-excite, project.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size: kernel size of the depthwise convolution; the padding is
            ``kernel_size // 2``.
        stride: stride of the depthwise convolution. The identity skip is
            only considered when this is 1.
        expand_rate: channel multiplier for the expansion stage; the 1x1
            expansion convolution is created only when it is greater than 1.
        se_rate: ratio of ``inplanes`` used as the squeeze-excitation
            bottleneck width (at least 1 channel).
        drop_connect_rate: drop probability used by ``_drop_connect`` on the
            residual branch during training; ``None`` disables it.
            NOTE(review): this parameter and its default were cut from the
            posted signature; 0.2 is an assumption -- confirm.
    """

    def __init__(self, inplanes, planes, kernel_size, stride,
                 expand_rate=1.0, se_rate=0.25,
                 drop_connect_rate=0.2):
        super(MBConv, self).__init__()

        expand_planes = int(inplanes * expand_rate)
        se_planes = max(1, int(inplanes * se_rate))

        self.expansion_conv = None
        if expand_rate > 1.0:
            # NOTE(review): the trailing Swish() was missing from the posted
            # (unclosed) nn.Sequential and is restored by assumption.
            self.expansion_conv = nn.Sequential(
                nn.Conv2d(inplanes, expand_planes,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(expand_planes, momentum=0.01, eps=1e-3),
                Swish(),
            )
            inplanes = expand_planes

        # groups == channels makes this a depthwise convolution.
        # NOTE(review): ``bias=False`` and the trailing Swish() were cut from
        # the post; restored to match the other conv + batch-norm pairs.
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(inplanes, expand_planes,
                      kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=expand_planes,
                      bias=False),
            nn.BatchNorm2d(expand_planes, momentum=0.01, eps=1e-3),
            Swish(),
        )

        self.squeeze_excitation = SqueezeExcitation(expand_planes, se_planes)

        self.project_conv = nn.Sequential(
            nn.Conv2d(expand_planes, planes,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(planes, momentum=0.01, eps=1e-3),
        )

        self.with_skip = stride == 1
        # Stored as None when disabled so the ``is not None`` check in
        # forward() is meaningful; torch.tensor(None) would raise.
        if drop_connect_rate is None:
            self.drop_connect_rate = None
        else:
            self.drop_connect_rate = torch.tensor(drop_connect_rate, requires_grad=False)

    def _drop_connect(self, x):
        """Zero whole samples of ``x`` with probability ``drop_connect_rate``.

        Kept samples are divided by the keep probability. ``x`` is indexed as
        a 4-D tensor whose first dimension selects the sample.
        """
        keep_prob = 1.0 - self.drop_connect_rate
        drop_mask = torch.rand(x.shape[0], 1, 1, 1) + keep_prob
        drop_mask = drop_mask.type_as(x)
        # floor() turns the values in [keep_prob, 1 + keep_prob) into a 0/1
        # mask; without it the "mask" is only a random rescaling.
        drop_mask.floor_()
        return drop_mask * x / keep_prob

    def forward(self, x):
        """Run the block; add the input back when shapes match and stride is 1."""
        z = x
        if self.expansion_conv is not None:
            x = self.expansion_conv(x)

        x = self.depthwise_conv(x)
        x = self.squeeze_excitation(x)
        x = self.project_conv(x)

        # Add identity skip
        if x.shape == z.shape and self.with_skip:
            if self.training and self.drop_connect_rate is not None:
                # The result must be assigned: _drop_connect is not in-place.
                x = self._drop_connect(x)
            x += z
        return x

def init_weights(module):
    """Initialise the weight of a single ``nn.Conv2d`` or ``nn.Linear`` in place.

    Conv2d weights get Kaiming-normal initialisation in ``fan_out`` mode.
    Linear weights are drawn uniformly from ``[-r, r]`` with
    ``r = 1 / sqrt(in_features)``. Any other module type is left untouched,
    and biases are not modified.
    """
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, a=0, mode='fan_out')
    elif isinstance(module, nn.Linear):
        # weight has shape (out_features, in_features); shape[1] is the fan-in.
        init_range = 1.0 / math.sqrt(module.weight.shape[1])
        nn.init.uniform_(module.weight, a=-init_range, b=init_range)

class EfficientNet(nn.Module):
    """EfficientNet-style classifier: a stem conv, seven MBConv stages and a head.

    Args:
        num_classes: size of the final linear layer's output.
        width_coefficient: multiplier applied to every base channel count.
        depth_coefficient: multiplier applied to every stage's repeat count.
        se_rate: squeeze-excitation ratio forwarded to each ``MBConv``.
        dropout_rate: dropout probability before the final linear layer.
        drop_connect_rate: maximum drop-connect rate; each block receives
            ``drop_connect_rate * block_index / num_blocks``.

    NOTE(review): the posted ``__init__`` signature was cut after
    ``num_classes``. The remaining parameter names come from how the body and
    the instantiation use them, except ``dropout_rate``, which is restored
    along with the missing head layers. All default values are assumptions --
    confirm against the original notebook.
    """

    def _setup_repeats(self, num_repeats):
        """Scale a stage's repeat count by ``depth_coefficient``, rounding up."""
        return int(math.ceil(self.depth_coefficient * num_repeats))

    def _setup_channels(self, num_channels):
        """Scale a channel count by ``width_coefficient``, rounded to a multiple of ``divisor``."""
        num_channels *= self.width_coefficient
        new_num_channels = math.floor(num_channels / self.divisor + 0.5) * self.divisor
        new_num_channels = max(self.divisor, new_num_channels)
        # Rounding must not shrink the scaled width by more than 10%.
        if new_num_channels < 0.9 * num_channels:
            new_num_channels += self.divisor
        return new_num_channels

    def __init__(self, num_classes,
                 width_coefficient=1.0,
                 depth_coefficient=1.0,
                 se_rate=0.25,
                 dropout_rate=0.2,
                 drop_connect_rate=0.2):
        super(EfficientNet, self).__init__()
        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.divisor = 8

        # Base configuration: 9 channel counts (stem, 7 stage outputs, head)
        # and 7 per-stage settings.
        list_channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280]
        list_channels = [self._setup_channels(c) for c in list_channels]

        list_num_repeats = [1, 2, 2, 3, 3, 4, 1]
        list_num_repeats = [self._setup_repeats(r) for r in list_num_repeats]

        expand_rates = [1, 6, 6, 6, 6, 6, 6]
        strides = [1, 2, 2, 2, 1, 2, 1]
        kernel_sizes = [3, 3, 5, 3, 5, 5, 3]

        # Define stem:
        # NOTE(review): the trailing Swish() was missing from the posted
        # (unclosed) nn.Sequential and is restored by assumption.
        self.stem = nn.Sequential(
            nn.Conv2d(3, list_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(list_channels[0], momentum=0.01, eps=1e-3),
            Swish(),
        )

        # Define MBConv blocks
        blocks = []
        counter = 0
        num_blocks = sum(list_num_repeats)
        for idx in range(7):
            num_channels = list_channels[idx]
            next_num_channels = list_channels[idx + 1]
            num_repeats = list_num_repeats[idx]
            expand_rate = expand_rates[idx]
            kernel_size = kernel_sizes[idx]
            stride = strides[idx]

            # First block of the stage: changes channel count and applies the
            # stage's stride.
            drop_rate = drop_connect_rate * counter / num_blocks
            name = "MBConv{}_{}".format(expand_rate, counter)
            blocks.append((
                name,
                MBConv(num_channels, next_num_channels,
                       kernel_size=kernel_size, stride=stride, expand_rate=expand_rate,
                       se_rate=se_rate, drop_connect_rate=drop_rate),
            ))
            counter += 1

            # Remaining repeats keep the channel count and use stride 1.
            for _ in range(1, num_repeats):
                name = "MBConv{}_{}".format(expand_rate, counter)
                drop_rate = drop_connect_rate * counter / num_blocks
                blocks.append((
                    name,
                    MBConv(next_num_channels, next_num_channels,
                           kernel_size=kernel_size, stride=1, expand_rate=expand_rate,
                           se_rate=se_rate, drop_connect_rate=drop_rate),
                ))
                counter += 1

        self.blocks = nn.Sequential(OrderedDict(blocks))

        # Define head
        # NOTE(review): only the Conv2d, BatchNorm2d and Linear layers were
        # visible in the post. Pooling and Flatten are needed for the Linear
        # layer to receive a (batch, channels) tensor; Swish and Dropout are
        # restored by assumption.
        self.head = nn.Sequential(
            nn.Conv2d(list_channels[-2], list_channels[-1],
                      kernel_size=1, bias=False),
            nn.BatchNorm2d(list_channels[-1], momentum=0.01, eps=1e-3),
            Swish(),
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(list_channels[-1], num_classes),
        )

        # NOTE(review): init_weights is defined in this file but nothing in
        # the posted text calls it; applying it here is an assumption.
        self.apply(init_weights)

    def forward(self, x):
        """Run stem, MBConv blocks and head in sequence and return the head's output."""
        f = self.stem(x)
        f = self.blocks(f)
        y = self.head(f)
        return y

# NOTE(review): the posted call was never closed, so further keyword
# arguments may have been cut here -- confirm against the original notebook.
model = EfficientNet(num_classes=10,
                     width_coefficient=1.4, depth_coefficient=1.8)
# Presumably the input image side length in pixels -- confirm where it is used.
resolution = 380
# Presumably per-channel [mean, std] used for input normalisation -- confirm.
img_stats = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]

but I managed to save it with this library called dill

For example,

import dill

# Dill routine
torch.save(model, model_name, pickle_module=dill)


and I don’t get any errors…

when I want to load it I do it like this

model1 = torch.load(model_name)

Sorry for the late response. I will look at your code more closely, thank you very much!

