torch.save causes TypeError: cannot pickle 'WeakMethod' object

Hello. I got the error below.
I have no idea what is wrong. I tried to save my model with torch.save, but this error came out.
My training environment is based on nvcr 21.11-py3.
What causes this error?

Traceback (most recent call last):
  File "main.py", line 231, in <module>
    main()
  File "main.py", line 152, in main
    torch.save({
  File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 423, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 635, in _save
    pickler.dump(obj)
TypeError: cannot pickle 'WeakMethod' object

Could you post a minimal, executable code snippet reproducing this error, please?

@ptrblck I found what causes the error.
If I don't include the scheduler in the checkpoint, torch.save works fine.
But when I include the scheduler, the 'WeakMethod' TypeError is raised.
Why does saving the scheduler cause the error?
Please let me know in detail.

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': e,
            'lr_scheduler': scheduler.state_dict(),
            'args': args,
}, checkpoint_path)
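A quick way to narrow this down is to pickle each value of the checkpoint dictionary on its own. This is a minimal diagnostic sketch, assuming the same model, optimizer, scheduler, e, and args as in the full snippet below:

import pickle

# Try pickling each checkpoint entry separately to see which one raises.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': e,
    'lr_scheduler': scheduler.state_dict(),
    'args': args,
}
for key, value in checkpoint.items():
    try:
        pickle.dumps(value)
    except TypeError as err:
        print(key, '->', err)   # expected: lr_scheduler -> cannot pickle 'WeakMethod' object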

My full snippet is below.

checkpoint_path = './ckpt/' + start_time + '_rank_' + str(int(os.environ['RANK'])) + '.tar'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = resnet50(weights="IMAGENET1K_V2")
criterion = nn.BCELoss(reduction='sum') #  nn.BCEWithLogitsLoss()

if args.mode=='DDP':
    args.distributed = True
model = with_distributed(model, device, mode=args.mode)
if args.resume:
    checkpoint_path ='./ckpt/'+args.resume+'.tar'
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
optimizer = optim.Adam(model.parameters(), lr = args.lr) #LAMB(model.parameters(), lr = args.lr)
scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.lr, max_lr=args.lr*10, \
                                                    step_size_up=10, cycle_momentum=False)
acc_metric = AccuracyMetric(device).to(device)

data_set_start = time.time()
train_bsz = args.train_batch_size
test_bsz = args.test_batch_size
train_data = LymphDataset(transformer=transformer_select(train=True), td_data=True)
valid_data = LymphDataset(transformer=transformer_select(train=True), td_data=True, valid=True)

if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
else:
    train_sampler = None
data_set_end = time.time()
print ("Data SET TIME", data_set_end-data_set_start)
data_load_start = time.time()
train_loader = torch.utils.data.DataLoader(
    dataset = train_data,
    batch_size = train_bsz,
    num_workers = args.workers,
    collate_fn = train_data.collate_fn,
    pin_memory=True,
    drop_last=True,
    shuffle=False,
    sampler=train_sampler
)
valid_loader = torch.utils.data.DataLoader(
    dataset = valid_data,
    batch_size = test_bsz,
    num_workers = args.workers,
    collate_fn = valid_data.collate_fn,
    pin_memory=True,
    drop_last=True,
)
data_load_end = time.time()
print ("Data Loading TIME", data_load_end-data_load_start)
epoch_loss , grad_norm, param_norm = [], [], []
f1_init = 0
acc_init = 0

for e in range(args.epochs):
    start_time = time.time()
    #train(model, optimizer, criterion, device, train_loader, train_bsz, epoch_loss, grad_norm , param_norm)
    #scheduler.step()
    end_time = time.time()
    log.info(f'Time for 1 Epoch {end_time - start_time} sec')
    if args.mode == 'DDP':
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    if e % 1 == 0:
        print ("DOING")
        #f1, acc = test(model, criterion, device, valid_loader, test_bsz, acc_metric )
        #log.info(f"F1 SCORE : {f1} , ACCURACY : {acc} ")
        #print ("f1")
        #if f1_init < f1:
        #    f1_init = f1
        #    if args.mode == 'DDP':
        #        model_without_ddp = model.module
        #    else:
        #        model_without_ddp = model
        #    log.info(f'Save the weight of model with f1 score : {f1}')
        torch.save({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': e,
                        'lr_scheduler': scheduler.state_dict(),
                        #'args': args,
                    }, checkpoint_path)
    scheduler.step()

I can reproduce the issue using:

import torch
import torchvision.models as models

model = models.resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=1.)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1., max_lr=10, step_size_up=10, cycle_momentum=False)

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': 0,
            'lr_scheduler': scheduler.state_dict(),
}, "tmp.pt")

The issue is also tracked here with a proposed workaround.
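For reference, a commonly suggested workaround is to drop the non-picklable entry from the scheduler's state dict before saving. The following is a sketch, assuming the offending entry is CyclicLR's internal '_scale_fn_ref' key, which holds a WeakMethod:

import torch
import torchvision.models as models

model = models.resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=1.)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=1., max_lr=10, step_size_up=10, cycle_momentum=False)

# Assumption: the WeakMethod lives in the '_scale_fn_ref' entry of the state dict.
# Dropping it before saving avoids the pickling error; the reference is rebuilt
# in CyclicLR.__init__ when the scheduler is re-created before load_state_dict.
sched_state = scheduler.state_dict()
sched_state.pop('_scale_fn_ref', None)

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': 0,
            'lr_scheduler': sched_state,
}, "tmp.pt")

Alternatively, upgrading to a PyTorch release that includes the fix removes the need for this manual step.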

@ptrblck Really appreciate your help. It works. Thank you.
