Save file size doubled after loading scheduler checkpoint

After loading the model, optimizer, and scheduler state_dicts and then saving all three again, the file is twice as large as when the three are saved without first loading the state_dicts.

Running:

import torch
import torch.nn as nn
import os

device_id = 'cpu'
device = torch.device(device_id)

model = nn.Linear(100, 1000).to(device)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, 1e-7, 1e-4, 500, cycle_momentum=False
)

f = open('filesize.log', 'w')
f.write(torch.__version__ + '\n')

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict()
}
fname = 'linear_start.pth'
torch.save(checkpoint, fname)
f.write(fname + ' ' + str(os.path.getsize(fname)) + '\n')

for step in range(5):
    optimizer.zero_grad()
    
    x = torch.randn(64, 100).to(device)
    out = model(x)
    out.backward(torch.randn(64, 1000).to(device))
    
    optimizer.step()

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict()
}
fname = 'linear_5steps.pth'
torch.save(checkpoint, fname)
f.write(fname + ' ' + str(os.path.getsize(fname)) + '\n')

for step in range(5):
    optimizer.zero_grad()
    
    x = torch.randn(64, 100).to(device)
    out = model(x)
    out.backward(torch.randn(64, 1000).to(device))
    
    optimizer.step()

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict()
}
fname = 'linear_10steps.pth'
torch.save(checkpoint, fname)
f.write(fname + ' ' + str(os.path.getsize(fname)) + '\n')

checkpoint = torch.load('linear_5steps.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict()
}
fname = 'linear_10steps_load.pth'
torch.save(checkpoint, fname)
f.write(fname + ' ' + str(os.path.getsize(fname)) + '\n')

for step in range(5):
    optimizer.zero_grad()
    
    x = torch.randn(64, 100).to(device)
    out = model(x)
    out.backward(torch.randn(64, 1000).to(device))
    
    optimizer.step()

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict()
}
fname = 'linear_10steps_loadtrain.pth'
torch.save(checkpoint, fname)
f.write(fname + ' ' + str(os.path.getsize(fname)) + '\n')

checkpoint = torch.load('linear_10steps_loadtrain.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict()
}
fname = 'linear_15steps_load.pth'
torch.save(checkpoint, fname)
f.write(fname + ' ' + str(os.path.getsize(fname)) + '\n')

for step in range(5):
    optimizer.zero_grad()
    
    x = torch.randn(64, 100).to(device)
    out = model(x)
    out.backward(torch.randn(64, 1000).to(device))
    
    optimizer.step()

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict()
}
fname = 'linear_15steps_loadtrain.pth'
torch.save(checkpoint, fname)
f.write(fname + ' ' + str(os.path.getsize(fname)) + '\n')

f.close()

filesize.log’s output is:

1.10.2+cu113
linear_start.pth 406119
linear_5steps.pth 1214999
linear_10steps.pth 1214999
linear_10steps_load.pth 2428097
linear_10steps_loadtrain.pth 2428097
linear_15steps_load.pth 2428097
linear_15steps_loadtrain.pth 2428097

I get similar results on both Ubuntu and Windows. Does anyone know why this happens? Is this a bug?

I don’t think it’s a bug; you are probably seeing the following:

  • linear_start contains an empty optimizer state, since no training was done yet, so its size is “small”
  • linear_5steps was saved after the optimizer steps filled all of Adam’s state buffers, so the size increases
  • linear_10/15 were saved after training the model for more steps. Since the parameters changed, my guess is that the data now contains more entropy, so compression cannot reduce the size as much anymore (that’s my best guess; you could check it manually)
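One way to check the compression guess manually is to look inside a checkpoint. Assuming the file was written in the zip-based format that torch.save uses by default since PyTorch 1.6, the standard-library zipfile module reports, for every entry, its uncompressed size and the size actually stored; if the two totals match, nothing is being compressed. The helper name `zip_entry_sizes` is hypothetical.

```python
import zipfile


def zip_entry_sizes(source):
    """Return (uncompressed_total, stored_total, entries) for a zip archive.

    `source` is a path or a binary file object. This assumes the checkpoint was
    written in torch.save's zip-based format; other files raise zipfile.BadZipFile.
    Each entry is a tuple (name, uncompressed_size, size_as_stored).
    """
    with zipfile.ZipFile(source) as zf:
        entries = [(info.filename, info.file_size, info.compress_size) for info in zf.infolist()]
    uncompressed_total = sum(size for _, size, _ in entries)
    stored_total = sum(stored for _, _, stored in entries)
    return uncompressed_total, stored_total, entries
```

For example, `zip_entry_sizes('linear_10steps_load.pth')` could be compared with `zip_entry_sizes('linear_5steps.pth')`: if the stored and uncompressed totals are equal for both, compression cannot explain the difference, and the entry lists should show which records account for the extra bytes.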

linear_start being small makes sense.
But linear_10steps_load was saved immediately after loading linear_5steps (so the name is a bit of a misnomer), which means more training can’t explain it. And the size of linear_10steps_load.pth is suspiciously close to double that of linear_5steps.pth. Thanks.

This is potentially a bug in the PyTorch scheduler, caused by the way pickle works. Here’s what I found:

To make the difference in file size more pronounced, I used nn.Linear(4096, 4096) and ran your code snippet. Here are the sizes of the checkpoints on disk, in bytes:

402755585 linear_10steps_load.pth
402755585 linear_10steps_loadtrain.pth
201378711 linear_10steps.pth
402755585 linear_15steps_load.pth
402755585 linear_15steps_loadtrain.pth
201378711 linear_5steps.pth
 67127399 linear_start.pth

The size of the model is (4096 * 4096 (weight) + 4096 (bias)) * 4 bytes/float = 67,125,248 bytes, quite close to the size of linear_start.pth, which was saved before the optimizer’s momentum buffers were initialized, as expected.

Once optimizer.step() has run, Adam’s momentum buffers are initialized. Adam keeps two of them (exp_avg and exp_avg_sq), each the same size as its parameter, so the optimizer state is about twice the size of the model. The total is therefore around 3 * 67,125,248 = 201,375,744 bytes, close to the sizes of linear_5steps.pth and linear_10steps.pth, again as expected.
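The two estimates above can be written out as a small calculation. This is a sketch that assumes float32 everywhere (4 bytes per element) and ignores serialization overhead and Adam’s small per-parameter step counter; the function names are illustrative only.

```python
BYTES_PER_FLOAT32 = 4


def linear_param_bytes(in_features: int, out_features: int) -> int:
    """Bytes for nn.Linear's weight (out x in) plus bias (out), stored as float32."""
    return (out_features * in_features + out_features) * BYTES_PER_FLOAT32


def expected_checkpoint_bytes(in_features: int, out_features: int, adam_initialized: bool) -> int:
    """Model parameters, plus Adam's exp_avg and exp_avg_sq once step() has run."""
    params = linear_param_bytes(in_features, out_features)
    adam_state = 2 * params if adam_initialized else 0  # two buffers per parameter
    return params + adam_state


print(expected_checkpoint_bytes(4096, 4096, adam_initialized=False))  # 67125248
print(expected_checkpoint_bytes(4096, 4096, adam_initialized=True))   # 201375744
```

For the original nn.Linear(100, 1000) the same formula gives 404,000 and 1,212,000 bytes, a couple of kilobytes below the logged 406,119 and 1,214,999; that gap is presumably serialization overhead.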

Things get weird when we first call load_state_dict() and then save the checkpoint: every checkpoint saved after loading is twice the expected size. I found that checkpoint[‘scheduler’] contains a special object:

scale_fn <class 'method'> <bound method CyclicLR._triangular_scale_fn of <torch.optim.lr_scheduler.CyclicLR object at 0x7f0e1d2b4760>>

So here is why the checkpoint size doubles after load_state_dict():

  1. When we save the scheduler to the checkpoint, its state contains a bound method object named ‘scale_fn’. When pickle saves a bound method, it also saves the whole instance the method is bound to (let’s call it scheduler_old). However, scheduler_old.optimizer is the same object as the current optimizer, and within a single save pickle writes each object only once (later occurrences become back-references to its memo), just as torch.save writes each tensor storage only once. That’s why the sizes of linear_5steps.pth and linear_10steps.pth are normal.
  2. scheduler.load_state_dict() assigns the ‘scale_fn’ from the state dict to self.scale_fn, and that method is still bound to the unpickled instance scheduler_old.
  3. When we save the scheduler to a checkpoint once again, pickle has to save model_new, optimizer_new, scheduler_new and scheduler_old, where scheduler_new.optimizer is optimizer_new but scheduler_old.optimizer is NOT: it is a separate copy that was loaded from the file. So the file size doubles.
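Two pickle behaviours carry the argument above, and both can be checked with the standard library alone: a bound method is reduced to `getattr(instance, name)`, so pickling it pickles the whole instance, and pickle memoizes objects within a single dump, so an object reachable twice is written only once. In this sketch `Holder` is a hypothetical stand-in for the scheduler, and its `blob` attribute stands in for the optimizer’s tensors.

```python
import pickle


class Holder:
    """Hypothetical stand-in for a scheduler; `blob` plays the optimizer's tensors."""

    def __init__(self, size):
        self.blob = bytes(size)

    def fn(self):
        pass


h = Holder(1_000_000)

# 1. A bound method reduces to getattr(instance, name), dragging the instance along.
func, args = h.fn.__reduce__()
print(func is getattr, args[0] is h, args[1])  # True True fn
print(len(pickle.dumps(h.fn)) > 1_000_000)     # True: h.blob was serialized too

# 2. Within one dump, an object reached twice is written once (later uses hit the memo).
shared = {"method": h.fn, "blob": h.blob}   # same blob reached twice
old = Holder(1_000_000)                     # like scheduler_old after loading
mixed = {"method": old.fn, "blob": h.blob}  # two distinct blobs
print(len(pickle.dumps(shared)), len(pickle.dumps(mixed)))
```

The `shared` dump is roughly one blob in size while `mixed` is roughly two, which mirrors linear_5steps.pth versus linear_10steps_load.pth.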

This can be reproduced with a simple example:

import pickle
import torch

class A:
    def __init__(self):
        self.mm = torch.randn(4096, 4096)
        self.myfunc = self.boundmethod

    def boundmethod(self):
        pass

    def state_dict(self):
        return {
            'method': self.myfunc,
            'mm': self.mm
        }


a = A()

# save the method to file
with open('method', 'wb') as fo:
    pickle.dump(a.state_dict(), fo)
# size on disk: 67109389

b = A()
with open('method', 'rb') as fi:
    b.myfunc = pickle.load(fi)['method']

# save b method to disk
with open('method_b', 'wb') as fo:
    pickle.dump(b.state_dict(), fo)
# size on disk 134218556

A quick fix is to turn the method into a staticmethod:

import pickle
import torch

class B:
    def __init__(self):
        self.mm = torch.randn(4096, 4096)
        self.myfunc = self.boundmethod

    @staticmethod
    def boundmethod(b: "B"):
        pass

    def state_dict(self):
        return {
            'method': self.myfunc,
            'mm': self.mm
        }


a = B()

# save the method to file
with open('method', 'wb') as fo:
    pickle.dump(a.state_dict(), fo)
# size on disk: 67109327

b = B()
with open('method', 'rb') as fi:
    b.myfunc = pickle.load(fi)['method']

# save b method to disk
with open('method_b', 'wb') as fo:
    pickle.dump(b.state_dict(), fo)
# size on disk 67109327


I believe you’re right. Good explanation, thanks.

I found a similar (possibly the same) issue for OneCycleLR, pytorch/pytorch issue #42376 on GitHub, “torch.save() serializes a bound method in OneCycleLR's state_dict, causing CUDA problems”, where the bound method is anneal_func instead of scale_fn.

Switching to staticmethod should fix it for _annealing_cos and _annealing_linear in OneCycleLR, and for _triangular_scale_fn and _triangular2_scale_fn in CyclicLR, but not for _exp_range_scale_fn in CyclicLR, because that one reads self.gamma.

I would fix it by defining

class CyclicLR(_LRScheduler):
    # ... other stuff

    @staticmethod
    def _triangular_scale_fn(x, gamma=1.0):
        return 1.

    @staticmethod
    def _triangular2_scale_fn(x, gamma=1.0):
        return 1 / (2. ** (x - 1))

    @staticmethod
    def _exp_range_scale_fn(x, gamma=1.0):
        return gamma**(x)

    def get_lr(self):
        # ... other stuff

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle, self.gamma) # changed
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch, self.gamma) # changed
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle, self.gamma) # changed
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch, self.gamma) # changed
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['momentum'] = momentum

and then simply

class OneCycleLR(_LRScheduler):
    # ... other stuff

    @staticmethod
    def _annealing_cos(start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    @staticmethod
    def _annealing_linear(start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start
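To see that the staticmethod approach also covers exp_range once gamma is passed in explicitly, here is a hypothetical miniature of the proposed CyclicLR change (not PyTorch code). `ToyCyclic` and its `blob` attribute, a stand-in for the optimizer that state_dict() leaves out, are assumptions for illustration. Because a staticmethod looked up on an instance is a plain function, pickle stores it by qualified name and never captures the old instance.

```python
import pickle


class ToyCyclic:
    """Hypothetical miniature of the proposed CyclicLR change (not PyTorch code)."""

    def __init__(self, mode="exp_range", gamma=0.5):
        self.blob = bytes(1_000_000)  # stands in for the optimizer
        self.gamma = gamma
        # Staticmethods looked up on self are plain functions, not bound methods.
        self.scale_fn = {
            "triangular": self._triangular_scale_fn,
            "triangular2": self._triangular2_scale_fn,
            "exp_range": self._exp_range_scale_fn,
        }[mode]

    @staticmethod
    def _triangular_scale_fn(x, gamma=1.0):
        return 1.0

    @staticmethod
    def _triangular2_scale_fn(x, gamma=1.0):
        return 1 / (2.0 ** (x - 1))

    @staticmethod
    def _exp_range_scale_fn(x, gamma=1.0):
        return gamma ** x

    def scale(self, x):
        # gamma is passed in explicitly instead of being read from self.
        return self.scale_fn(x, self.gamma)

    def state_dict(self):
        return {k: v for k, v in self.__dict__.items() if k != "blob"}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)


old = ToyCyclic()
saved = pickle.dumps(old.state_dict())

new = ToyCyclic(mode="triangular")
new.load_state_dict(pickle.loads(saved))
resaved = pickle.dumps(new.state_dict())

print(len(saved) == len(resaved))  # True: no stale instance is carried along
print(new.scale(3))                # 0.125
```

One trade-off this sketch does not address: a user-supplied scale_fn that takes a single argument would need to accept gamma too.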

Not sure if there’s a better way to fix it.

One option, I think, is to keep the mode in the state_dict and customize the state_dict() and load_state_dict() methods: filter out the scale_fn attribute in state_dict() and restore scale_fn from mode in load_state_dict(). See pytorch/lr_scheduler.py at commit dd1121435b14b48d862b5f16e4ce8b0d71f6dd5d in pytorch/pytorch on GitHub.

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer or scale_fn.
        """
        return {key: value for key, value in self.__dict__.items() if key not in ['optimizer', 'scale_fn']}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.
        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
        # restore scale_fn from mode
        if self.mode == 'triangular':
            self.scale_fn = self._triangular_scale_fn
        elif self.mode == 'triangular2':
            self.scale_fn = self._triangular2_scale_fn
        elif self.mode == 'exp_range':
            self.scale_fn = self._exp_range_scale_fn
        else:
            raise ValueError(f"Unknown mode: {self.mode!r}")
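The filter-and-restore idea can be tried end to end on a hypothetical miniature (not PyTorch code). Here scale_fn stays an ordinary bound method, but it is excluded from state_dict() and rebuilt from `mode` on load, so the restored method is bound to the new instance. `ToyCyclicByMode`, its `_set_scale_fn` helper and the `blob` stand-in for the optimizer are all assumptions for illustration.

```python
import pickle


class ToyCyclicByMode:
    """Hypothetical miniature of the filter-and-restore idea (not PyTorch code)."""

    def __init__(self, mode="exp_range", gamma=0.5):
        self.blob = bytes(1_000_000)  # stands in for the optimizer
        self.mode = mode
        self.gamma = gamma
        self._set_scale_fn()

    def _set_scale_fn(self):
        # scale_fn is rebuilt from `mode` rather than serialized.
        if self.mode == "triangular":
            self.scale_fn = self._triangular_scale_fn
        elif self.mode == "triangular2":
            self.scale_fn = self._triangular2_scale_fn
        elif self.mode == "exp_range":
            self.scale_fn = self._exp_range_scale_fn
        else:
            raise ValueError(f"Unknown mode: {self.mode!r}")

    def _triangular_scale_fn(self, x):
        return 1.0

    def _triangular2_scale_fn(self, x):
        return 1 / (2.0 ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma ** x

    def state_dict(self):
        return {k: v for k, v in self.__dict__.items() if k not in ("blob", "scale_fn")}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._set_scale_fn()


old = ToyCyclicByMode()
saved = pickle.dumps(old.state_dict())

new = ToyCyclicByMode(mode="triangular")
new.load_state_dict(pickle.loads(saved))
print(new.scale_fn.__self__ is new)                       # True: bound to the new instance
print(new.scale_fn(2))                                    # 0.25
print(len(pickle.dumps(new.state_dict())) == len(saved))  # True
```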

I think your solution is better. Perhaps:

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer or a bound method.
        """
        return {key: value for key, value in self.__dict__.items() if key not in ['optimizer'] and not hasattr(value, '__self__')}

That way bound methods are never copied into the state_dict. Would that be a more general solution?

That would certainly filter out the bound method, but remember that we also need to restore the method in load_state_dict(). I think we have to handle each scheduler’s state dict case by case.

Right. Checking for bound methods in the state_dict is probably better as a unit test.