Save file size doubled after loading scheduler checkpoint

This is potentially a bug in the PyTorch scheduler, caused by the way pickle handles bound methods. Here's what I found:

To make the file-size difference more significant, I used nn.Linear(4096, 4096) and ran your code snippet. Here are the sizes of the checkpoints on disk:

402755585 linear_10steps_load.pth
402755585 linear_10steps_loadtrain.pth
201378711 linear_10steps.pth
402755585 linear_15steps_load.pth
402755585 linear_15steps_loadtrain.pth
201378711 linear_5steps.pth
 67127399 linear_start.pth

The size of the model is (4096 * 4096 (weight) + 4096 (bias)) * 4 bytes/float = 67,125,248 bytes, quite close to the size of linear_start.pth, which was saved before the optimizer momentum buffers were initialized. As expected.

Once optimizer.step() has run, Adam's momentum buffers are initialized; Adam keeps two buffers (exp_avg and exp_avg_sq) per parameter, roughly twice the model's parameter count. So the total size is around 3 * 67,125,248 = 201,375,744 bytes, close to the sizes of linear_5steps.pth and linear_10steps.pth, still as expected.
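The arithmetic can be checked directly (a quick sketch; it assumes float32, i.e. 4 bytes per element, and Adam's two state buffers per parameter):

```python
# Size arithmetic for nn.Linear(4096, 4096), float32 (4 bytes/element).
n_params = 4096 * 4096 + 4096      # weight + bias
param_bytes = n_params * 4
print(param_bytes)                 # 67125248, ~ linear_start.pth

# Adam stores exp_avg and exp_avg_sq per parameter, so a checkpoint with
# model + optimizer state holds roughly 3x the parameter bytes:
print(3 * param_bytes)             # 201375744, ~ linear_5steps.pth
```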

Things become weird when we first call load_state_dict() and then save the checkpoint: the size of every loaded checkpoint grows to twice the expected size. Inspecting checkpoint['scheduler'], I found a special object:

scale_fn <class 'method'> <bound method CyclicLR._triangular_scale_fn of <torch.optim.lr_scheduler.CyclicLR object at 0x7f0e1d2b4760>>

So that's why the checkpoint size doubles after load_state_dict():

  1. When we save the scheduler to a checkpoint, its state dict contains a bound method named 'scale_fn'. Note that pickle saves the whole instance (let's denote it scheduler_old) when saving such a bound method. Since scheduler_old.optimizer points to the same object as the optimizer already being saved, I guess pickle just filters out the duplicate, so the sizes of linear_5steps.pth and linear_10steps.pth are normal.
  2. scheduler.load_state_dict() assigns the 'scale_fn' from the state dict to self.scale_fn, still attached to the instance scheduler_old.
  3. We save the scheduler to a checkpoint once again. Now pickle has to save model_new, optimizer_new, scheduler_new, and scheduler_old, where scheduler_new.optimizer is optimizer_new but scheduler_old.optimizer is NOT! So the file size finally doubles.

This can be reproduced with a simple example:

import pickle
import torch

class A:
    def __init__(self):
        self.mm = torch.randn(4096, 4096)
        self.myfunc = self.boundmethod  # a bound method; its __self__ is this instance

    def boundmethod(self):
        pass

    def state_dict(self):
        return {
            'method': self.myfunc,
            'mm': self.mm
        }


a = A()

# save the method to file
with open('method', 'wb') as fo:
    pickle.dump(a.state_dict(), fo)
# size on disk: 67109389

b = A()
with open('method', 'rb') as fi:
    b.myfunc = pickle.load(fi)['method']

# save b method to disk
with open('method_b', 'wb') as fo:
    pickle.dump(b.state_dict(), fo)
# size on disk 134218556
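The loaded method's `__self__` shows what happened: pickle re-created a whole copy of the original instance just to carry the method. Here is a self-contained sketch of that mechanism, using a bytes payload as a stand-in for the tensor:

```python
import pickle

class Holder:
    def __init__(self):
        # Stand-in for the large tensor: ~8 MB of payload.
        self.payload = bytes(8 * 1024 * 1024)

    def boundmethod(self):
        pass

a = Holder()
data = pickle.dumps(a.boundmethod)

# The pickled bound method carries its entire __self__ instance along:
m = pickle.loads(data)
print(isinstance(m.__self__, Holder))  # True: a fresh copy of Holder
print(m.__self__ is a)                 # False
print(len(data) > 8 * 1024 * 1024)     # True: the payload came too
```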

A quick fix is to change the method to a staticmethod:

import pickle
import torch

class B:
    def __init__(self):
        self.mm = torch.randn(4096, 4096)
        self.myfunc = self.boundmethod

    @staticmethod
    def boundmethod(b: "B"):
        # Now a plain function: accessing it via self returns the function
        # itself, so pickling it no longer drags the whole instance along.
        pass

    def state_dict(self):
        return {
            'method': self.myfunc,
            'mm': self.mm
        }


a = B()

# save the method to file
with open('method', 'wb') as fo:
    pickle.dump(a.state_dict(), fo)
# size on disk: 67109327

b = B()
with open('method', 'rb') as fi:
    b.myfunc = pickle.load(fi)['method']

# save b method to disk
with open('method_b', 'wb') as fo:
    pickle.dump(b.state_dict(), fo)
# size on disk 67109327
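Another way to avoid the problem is to keep the callable out of the state dict entirely and re-bind it on load. This is a hypothetical sketch, not the actual PyTorch code; `payload` again stands in for the tensor:

```python
import pickle

class C:
    def __init__(self):
        self.payload = bytes(8 * 1024 * 1024)  # stand-in for the tensor
        self.myfunc = self.boundmethod

    def boundmethod(self):
        pass

    def state_dict(self):
        # Exclude the bound method; it can be rebuilt from the class.
        return {k: v for k, v in self.__dict__.items() if k != 'myfunc'}

    def load_state_dict(self, state):
        self.__dict__.update(state)
        self.myfunc = self.boundmethod  # re-bind to *this* instance

c = C()
size_before = len(pickle.dumps(c.state_dict()))

d = C()
d.load_state_dict(pickle.loads(pickle.dumps(c.state_dict())))
size_after = len(pickle.dumps(d.state_dict()))

# The checkpoint size no longer doubles after load_state_dict():
print(size_after == size_before)  # True
```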
