Saving/loading a pretrained model when pre-hook operations are added

Hello everyone,

I am wondering whether, when we save the parameters of a trained model containing layers with custom pre-hook operations (such as spectral normalization), the state dictionary also contains the parameters related to those pre-hook operations, and whether we can recover those parameters with the load_state_dict function.

I made a very simple example using spectral normalization (the PyTorch implementation, available here) as the pre-hook operation:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
import numpy as np 
from torch.utils.data import Dataset, DataLoader
import os
import gzip
import struct 
from tqdm import tqdm
import cv2
from collections import OrderedDict
from spectral_norm import spectral_norm


class MNISTDataset(Dataset):
    """Map-style ``Dataset`` that serves items of an indexable container as-is.

    No transformation is applied: ``dataset[i]`` is exactly ``data[i]`` and
    ``len(dataset)`` is ``len(data)``.
    """

    def __init__(self, data):
        # Stored by reference; nothing is copied or preprocessed here.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class TestModel1(nn.Module):
    """Small residual convolutional auto-encoder with spectral-normalized convs.

    Same topology as ``TestModel0``, except that every 3x3 conv inside the
    residual stacks, and the stride-2 downsampling conv ``conv3``, are wrapped
    in ``spectral_norm``.

    ``forward`` divides its input by 255 (presumably raw pixel values in
    [0, 255] -- confirm with the caller) and returns ``255 * sigmoid(...)``, so
    outputs lie in (0, 255).
    """

    def __init__(self):
        super(TestModel1, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = self._sn_stack(32)
        self.conv3 = spectral_norm(nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1))
        self.conv4 = self._sn_stack(64)
        self.conv9 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
        self.conv10 = self._sn_stack(32)
        self.conv11 = nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _sn_stack(channels):
        """Return 3 x (spectral-normalized 3x3 conv ``channels``->``channels``, ReLU).

        Layers are created in the same order as the explicit
        ``nn.Sequential(sn(conv), ReLU(), sn(conv), ReLU(), sn(conv), ReLU())``,
        so the state-dict keys (convs at indices 0, 2, 4) are unchanged.
        """
        layers = []
        for _ in range(3):
            layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        y_ = x / 255
        y_ = self.conv1(y_)
        # Each stack is used as a residual branch: out = in + stack(in).
        y = self.conv2(y_)
        y_ = y_ + y
        y_ = self.conv3(y_)           # stride 2: spatial downsampling
        y = self.conv4(y_)
        y_ = y_ + y
        y_ = self.conv9(y_)           # transposed conv (k=4, s=2, p=1): upsampling x2
        y = self.conv10(y_)
        y_ = y_ + y
        y_ = self.conv11(y_)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        y_ = torch.sigmoid(y_)
        y_ = 255 * y_
        return y_

    def load(self, path):
        """Load weights from ``<path>/intermediate_state.pth`` (as written by ``train``).

        The file holds ``{"state_dict": ...}``. When saved through
        ``nn.DataParallel`` every key starts with ``"module."``. That prefix is
        stripped only when the whole checkpoint carries it, so checkpoints saved
        with ``parallelize=False`` load too. The state dict's ``_metadata``
        (per-module version info that load-time hooks, e.g. torch's spectral
        norm, may consult) is carried over with the same renaming instead of
        being dropped. Tensors are mapped to CPU first; ``load_state_dict`` then
        copies them into this module's parameters wherever they live.

        Raises:
            RuntimeError: from ``load_state_dict`` on missing or unexpected keys.
        """
        prefix = "module."
        saved = torch.load(os.path.join(path, "intermediate_state.pth"),
                           map_location="cpu")["state_dict"]
        wrapped = len(saved) > 0 and all(k.startswith(prefix) for k in saved)

        state = OrderedDict()
        for k, v in saved.items():
            state[k[len(prefix):] if wrapped else k] = v

        metadata = getattr(saved, "_metadata", None)
        if metadata is not None:
            state._metadata = OrderedDict()
            for key, value in metadata.items():
                if not wrapped:
                    state._metadata[key] = value
                elif key == "module" or key.startswith(prefix):
                    # "module" (the wrapped model's root) becomes ""; the
                    # wrapper's own "" entry is dropped.
                    state._metadata[key[len(prefix):]] = value
        self.load_state_dict(state)
        

class TestModel0(nn.Module):
    """Small residual convolutional auto-encoder without spectral normalization.

    Same topology as ``TestModel1`` but with plain ``nn.Conv2d`` layers.
    ``forward`` divides its input by 255 (presumably raw pixel values in
    [0, 255] -- confirm with the caller) and returns ``255 * sigmoid(...)``, so
    outputs lie in (0, 255).
    """

    def __init__(self):
        super(TestModel0, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = self._conv_stack(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = self._conv_stack(64)
        self.conv9 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
        self.conv10 = self._conv_stack(32)
        self.conv11 = nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _conv_stack(channels):
        """Return 3 x (3x3 conv ``channels``->``channels``, ReLU) as an ``nn.Sequential``.

        Layer order matches the explicit six-element ``nn.Sequential``, so the
        state-dict keys (convs at indices 0, 2, 4) are unchanged.
        """
        layers = []
        for _ in range(3):
            layers.append(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        y_ = x / 255
        y_ = self.conv1(y_)
        # Each stack is used as a residual branch: out = in + stack(in).
        y = self.conv2(y_)
        y_ = y_ + y
        y_ = self.conv3(y_)           # stride 2: spatial downsampling
        y = self.conv4(y_)
        y_ = y_ + y
        y_ = self.conv9(y_)           # transposed conv (k=4, s=2, p=1): upsampling x2
        y = self.conv10(y_)
        y_ = y_ + y
        y_ = self.conv11(y_)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        y_ = torch.sigmoid(y_)
        y_ = 255 * y_
        return y_

    def load(self, path):
        """Load weights from ``<path>/intermediate_state.pth`` (as written by ``train``).

        The file holds ``{"state_dict": ...}``. When saved through
        ``nn.DataParallel`` every key starts with ``"module."``. That prefix is
        stripped only when the whole checkpoint carries it, so checkpoints saved
        with ``parallelize=False`` load too. The state dict's ``_metadata``
        (per-module version info used by load-time hooks) is carried over with
        the same renaming instead of being dropped. Tensors are mapped to CPU
        first; ``load_state_dict`` copies them into this module's parameters.

        Raises:
            RuntimeError: from ``load_state_dict`` on missing or unexpected keys.
        """
        prefix = "module."
        saved = torch.load(os.path.join(path, "intermediate_state.pth"),
                           map_location="cpu")["state_dict"]
        wrapped = len(saved) > 0 and all(k.startswith(prefix) for k in saved)

        state = OrderedDict()
        for k, v in saved.items():
            state[k[len(prefix):] if wrapped else k] = v

        metadata = getattr(saved, "_metadata", None)
        if metadata is not None:
            state._metadata = OrderedDict()
            for key, value in metadata.items():
                if not wrapped:
                    state._metadata[key] = value
                elif key == "module" or key.startswith(prefix):
                    # "module" (the wrapped model's root) becomes ""; the
                    # wrapper's own "" entry is dropped.
                    state._metadata[key[len(prefix):]] = value
        self.load_state_dict(state)

def train(model, loader, path, epochs=10, parallelize=True):
    """Train ``model`` as an auto-encoder and checkpoint it after every epoch.

    The model is optionally wrapped in ``nn.DataParallel``, moved to CUDA and
    put in training mode. It is then optimized with Adam (lr 1e-5) on the MSE
    between its reconstruction and the input; the learning rate is multiplied
    by 0.95 once every 10 epochs.

    Side effects:
      * the caller's ``model`` is moved to CUDA and left in training mode;
      * up to 5 input/reconstruction PNG pairs from the first batch of each
        epoch are written to ``<path>/Training``;
      * ``<path>/intermediate_state.pth`` is overwritten every epoch with
        ``{"state_dict": ...}``; with ``parallelize`` the keys carry the
        ``nn.DataParallel`` prefix ``"module."``.

    Args:
        model: the ``nn.Module`` to train.
        loader: iterable of input batches; each batch gets a channel axis via
            ``unsqueeze(1)`` (presumably (N, H, W) image tensors -- confirm).
        path: output directory for samples and checkpoints.
        epochs: number of passes over ``loader``.
        parallelize: wrap the model in ``nn.DataParallel`` when true.
    """
    if parallelize:
        model = nn.DataParallel(model)
    model = model.cuda()
    # The caller may pass a model left in eval mode (eval_model does that);
    # layers such as spectral norm or dropout behave differently in eval mode.
    model.train()

    path_epoch = os.path.join(path, "Training")
    if not os.path.exists(path_epoch):
        os.makedirs(path_epoch)

    optimizer = optim.Adam(model.parameters(), lr=0.00001)

    def updater(epoch):
        # LambdaLR multiplier: step decay of 0.95 every 10 epochs.
        return 0.95 ** (epoch // 10)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=updater)

    criterion_reco = nn.MSELoss()

    for epoch in range(epochs):
        print("Begining epoch {}/{}.".format(epoch+1, epochs))

        loss_reco = 0
        count = 0

        # enumerate(tqdm(...)) lets tqdm see len(loader) and show a real total.
        for num, batch in enumerate(tqdm(loader)):

            batch = batch.unsqueeze(1).cuda()

            count += 1
            rec = model(batch)

            loss2 = criterion_reco(rec, batch)
            loss_reco += loss2.item()

            optimizer.zero_grad()
            loss2.backward()
            optimizer.step()

            if num == 0:
                for i in range(min(5, rec.size(0))):
                    im = np.transpose(batch[i].detach().cpu().numpy(), (1, 2, 0))
                    reco = np.transpose(rec[i].detach().cpu().numpy(), (1, 2, 0))
                    cv2.imwrite(os.path.join(path_epoch, "epoch_"+str(epoch)+"_"+str(i)+"_image.png"), im)
                    cv2.imwrite(os.path.join(path_epoch, "epoch_"+str(epoch)+"_"+str(i)+"_reco.png"), reco)

        # Mean per-batch loss; an empty loader reports NaN instead of dividing by zero.
        loss_reco = loss_reco / count if count else float("nan")

        print("Reconstruction loss: {}".format(loss_reco))

        torch.save({"state_dict": model.state_dict()}, os.path.join(path, "intermediate_state.pth"))

        scheduler.step()
    

def eval_model(model, loader, parallelize=True, path=None):
    """Print the mean per-batch L1 reconstruction error of ``model`` over ``loader``.

    This is L1, whereas ``train`` optimizes and reports MSE, so the two printed
    losses are not directly comparable.

    The model is optionally wrapped in ``nn.DataParallel``, moved to CUDA and
    switched to eval mode. This mutates the caller's module: it stays on the
    GPU and in eval mode afterwards. The forward passes run under
    ``torch.no_grad()``. Up to 5 input/reconstruction PNG pairs from the first
    batch are written to ``<path>/Val``.

    Args:
        model: the ``nn.Module`` to evaluate.
        loader: iterable of input batches; each batch gets a channel axis via
            ``unsqueeze(1)``.
        parallelize: wrap the model in ``nn.DataParallel`` when true.
        path: output directory for the sample images. When omitted, a
            module-level ``path`` (set by this script's ``__main__`` block) is
            used, as the original code did implicitly; if there is none, no
            images are written instead of raising ``NameError``.

    Returns:
        The (possibly ``nn.DataParallel``-wrapped) model.
    """
    if path is None:
        path = globals().get("path")

    if parallelize:
        model = nn.DataParallel(model)
    model = model.cuda()
    model.eval()

    path_epoch = None
    if path is not None:
        path_epoch = os.path.join(path, "Val")
        if not os.path.exists(path_epoch):
            os.makedirs(path_epoch)

    criterion_reco = nn.L1Loss()

    loss_reco = 0
    count = 0

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for num, batch in enumerate(tqdm(loader)):

            count += 1
            batch = batch.unsqueeze(1).cuda()
            rec = model(batch)

            loss_reco += criterion_reco(rec, batch).item()

            if num == 0 and path_epoch is not None:
                for i in range(min(5, rec.size(0))):
                    im = np.transpose(batch[i].detach().cpu().numpy(), (1, 2, 0))
                    reco = np.transpose(rec[i].detach().cpu().numpy(), (1, 2, 0))
                    cv2.imwrite(os.path.join(path_epoch, str(i)+"_image.png"), im)
                    cv2.imwrite(os.path.join(path_epoch, str(i)+"_reco.png"), reco)

    # An empty loader reports NaN instead of dividing by zero.
    loss_reco = loss_reco / count if count else float("nan")

    print("Reconstruction loss: {}".format(loss_reco))

    return model


def read_idx(filename):
    """Read a gzip-compressed IDX file (the MNIST distribution format) into NumPy.

    IDX layout: two zero bytes, a one-byte element-type code, a one-byte
    number of dimensions, one big-endian uint32 per dimension, then the
    big-endian element data in row-major order.

    Args:
        filename: path to the gzip-compressed IDX file.

    Returns:
        A writable array with the stored shape, in native byte order. MNIST
        image and label files use type code 0x08, which yields ``uint8``.

    Raises:
        ValueError: if the leading magic bytes are not zero or the element-type
            code is not one defined by the IDX format.
    """
    # Element-type codes defined by the IDX format; multi-byte types are big-endian.
    idx_dtypes = {
        0x08: np.dtype(">u1"),  # unsigned byte
        0x09: np.dtype(">i1"),  # signed byte
        0x0B: np.dtype(">i2"),  # short
        0x0C: np.dtype(">i4"),  # int
        0x0D: np.dtype(">f4"),  # float
        0x0E: np.dtype(">f8"),  # double
    }
    with gzip.open(filename) as f:
        zero, data_type, dims = struct.unpack(">HBB", f.read(4))
        if zero != 0 or data_type not in idx_dtypes:
            raise ValueError(
                "{}: not an IDX file (magic 0x{:04x}, type code 0x{:02x})".format(
                    filename, zero, data_type))
        shape = struct.unpack(">" + "I" * dims, f.read(4 * dims))
        dtype = idx_dtypes[data_type]
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        # It returns a read-only view, so astype() below makes the writable,
        # native-byte-order copy that np.fromstring used to return.
        data = np.frombuffer(f.read(), dtype=dtype).reshape(shape)
    return data.astype(dtype.newbyteorder("="))

path_mnist = "brain_anomaly_detection/MNIST/images_train.gz"
path_labels = "brain_anomaly_detection/MNIST/label_train.gz"

if __name__ == "__main__":

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    path = "brain_anomaly_detection/test_save_model"
    if not os.path.exists(path):
        os.makedirs(path)

    mnist = read_idx(path_mnist).astype(np.float32)
    data = MNISTDataset(mnist)
    loader = DataLoader(dataset=data, batch_size=128, shuffle=True)

    # Same experiment for both variants: baseline eval, train, eval, then
    # rebuild a fresh model from the checkpoint and eval it again. Both
    # variants write the same checkpoint file, so each reload happens right
    # after its own training run.
    experiments = [
        (TestModel1, "Test when using Spectral normalization"),
        (TestModel0, "Test when NOT using Spectral normalization"),
    ]
    for model_cls, title in experiments:
        test_model = model_cls()
        eval_model(test_model, loader)
        test_model.train(True)
        train(test_model, loader, path, epochs=20)
        print(title)
        print("Evaluation right after training")
        eval_model(test_model, loader)

        test_model = model_cls()
        test_model.load(path)
        print("Evaluation after save/load parameter")
        eval_model(test_model, loader)

First, a few explanations about this code:
- TestModel1 and TestModel0 are two identical, very simple convolutional auto-encoders, except that TestModel1 includes spectral normalization at several layers.
- They are both trained on the MNIST dataset (I downloaded the original version into the folder “brain_anomaly_detection/MNIST”, so you should adapt that part) with exactly the same hyperparameters for 20 epochs (the idea here is just to train a little, to obtain a better score than the random one corresponding to the parameter initialization).
- One last important thing: I slightly modified the PyTorch spectral normalization code.
First, line 29:

weight_mat = weight_mat.reshape(height, -1)
becomes
weight_mat = weight_mat.contiguous().view(height, -1)
because, in my experience, spectral_norm otherwise does not support data parallelism across multiple GPUs.

Then, lines 51–57:

else:
    r_g = getattr(module, self.name + '_orig').requires_grad
    getattr(module, self.name).detach_().requires_grad_(r_g)

becomes

else:
    weight, u = self.compute_weight(module)
    setattr(module, self.name, weight)

because otherwise the accuracy drops when switching to evaluation mode, and this makes more sense to me (I am not an expert, though).

To summarize, to run this script you should:

  1. Modify the path to the MNIST training set.
  2. Copy-paste the PyTorch spectral norm script and name it “spectral_norm.py”, OR change line 14 of this script (the `from spectral_norm import spectral_norm` import) to import your own version of spectral normalization.
  3. Set the GPUs you have available (line 265, the `CUDA_VISIBLE_DEVICES` assignment in the `__main__` block).
    Then everything should work fine.

After running this script, I get the following output, which summarizes the scores:

Test when using Spectral normalization
Evaluation right after training
Reconstruction loss: 1.65
Evaluation after save/load parameter
Reconstruction loss: 110.02
Test when NOT using Spectral normalization
Evaluation right after training
Reconstruction loss: 1.88
Evaluation after save/load parameter
Reconstruction loss: 1.88

110 (the score obtained by TestModel1 after saving/loading its parameters) corresponds to the kind of score obtained when the network is randomly initialized without any training, which shows that there is a problem with the pre-hook operation when loading the model's parameters.

Sorry for the long post.

I had the same problem. I was using DataParallel(), and loading the model in eval() mode did not work.

@el_samou_samou, could you elaborate more on weight_mat = weight_mat.contiguous().view(height, -1)? I don’t think I had a problem with that.

Hey Tae and @el_samou_samou … This is my bad. SN with DP is currently broken: they don’t work together in either training or eval mode. A fix is in the works, but a workaround is:

  1. Use this DP:
class DataParallel(nn.parallel.DataParallel):
    """``nn.DataParallel`` whose first replica is the wrapped module itself.

    Workaround for spectral norm + DataParallel: the copy normally made for the
    first device is replaced by the original ``module`` object. Anything that
    replica changes in place on itself therefore lands on the real module
    (presumably aimed at spectral norm's power-iteration state -- confirm).
    """

    def replicate(self, module, device_ids):
        copies = super(DataParallel, self).replicate(module, device_ids)
        # Run the genuine module on the first device instead of a throwaway copy.
        copies[0] = module
        return copies
  2. Use this SN: https://gist.github.com/SsnL/8e638bcfd49e71d6b1930db0df87d970
    Note that only line 56 changed.
1 Like
weight_mat = weight_mat.contiguous().view(height, -1)

It was just a way to stop the error messages, but it probably doesn’t solve the problem. If you did not encounter problems with that line, it might just be because of your PyTorch version (I am on 0.4.0, so maybe they made some changes in 0.4.1 that solve this).

Anyway, I solved this using another SN implementation found on GitHub, but @SimonW’s solution looks nice as well! Thank you for that, Simon.

Thank you for your answer. I still have to change .reshape to .contiguous().view on line 29. In your opinion, is that because I am on PyTorch 0.4.0, or something else?

It’s not necessary. Both would work!

A fix is at https://github.com/pytorch/pytorch/pull/12671

1 Like