Hello, I am using WSL Ubuntu 22.04 and the compressai library. I am trying to fine-tune a model on my machine and have prepared a dataset of 256×256 images. However, once I load the checkpoint file, I get the error shown below. PS: the same checkpoint file works fine for testing, and I can also train the model from epoch 0 without making any changes; the problem only appears when resuming from the checkpoint.
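For context, train_transforms and test_transforms (which I did not paste) follow the compressai training example; roughly something like this, with the 256 patch size (a sketch, not the exact code):

from torchvision import transforms

# Sketch of the transform setup: random 256x256 crops for training,
# center crops for evaluation.
train_transforms = transforms.Compose(
    [transforms.RandomCrop(256), transforms.ToTensor()]
)
test_transforms = transforms.Compose(
    [transforms.CenterCrop(256), transforms.ToTensor()]
)

The relevant part of my training script: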
train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=True,
    pin_memory=(device == "cuda"),
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=args.test_batch_size,
    num_workers=args.num_workers,
    shuffle=False,
    pin_memory=(device == "cuda"),
)
net = MLICPlusPlus(config=config)
if args.cuda and torch.cuda.device_count() > 1:
    net = CustomDataParallel(net)
net = net.to(device)
#optimizer, aux_optimizer = configure_optimizers(net, args)
parameters = set(p for n, p in net.named_parameters() if not n.endswith(".quantiles"))
aux_parameters = [p for n, p in net.named_parameters() if n.endswith(".quantiles")]
optimizer = optim.Adam(parameters, lr=1e-4)
aux_optimizer = optim.Adam(aux_parameters, lr=1e-3)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 100], gamma=0.1)
criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics)
if args.checkpoint != None:
    checkpoint = torch.load(args.checkpoint)
    #new_ckpt = modify_checkpoint(checkpoint['state_dict'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    aux_optimizer.load_state_dict(checkpoint['aux_optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) #4
    # lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450,550], gamma=0.1)
    # lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count']
    # lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch']
    # print(lr_scheduler.state_dict())
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['loss']
    current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size)
    checkpoint = None
else:
    start_epoch = 0
    best_loss = 1e10
    current_step = 0
# start_epoch = 0
# best_loss = 1e10
# current_step = 0
logger_train.info(args)
logger_train.info(config)
logger_train.info(net)
logger_train.info(optimizer)
optimizer.param_groups[0]['lr'] = args.learning_rate
for epoch in range(start_epoch, args.epochs):
    logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
    current_step = train_one_epoch(
        net,
        criterion,
        train_dataloader,
        optimizer,
        aux_optimizer,
        epoch,
        args.clip_max_norm,
        logger_train,
        tb_logger,
        current_step
    )
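For reference, the configure_optimizers helper that I commented out above (I build the two optimizers inline instead) comes from the compressai training example; reproduced roughly from memory, so it may differ slightly:

def configure_optimizers(net, args):
    # Split the parameters: main network parameters for `optimizer`,
    # entropy-bottleneck quantiles for `aux_optimizer`.
    parameters = {
        n for n, p in net.named_parameters()
        if not n.endswith(".quantiles") and p.requires_grad
    }
    aux_parameters = {
        n for n, p in net.named_parameters()
        if n.endswith(".quantiles") and p.requires_grad
    }
    params_dict = dict(net.named_parameters())

    # Sanity checks: no overlap, and every parameter is covered.
    assert len(parameters & aux_parameters) == 0
    assert len(parameters | aux_parameters) == len(params_dict)

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(parameters)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_parameters)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer

For completeness, the checkpoint I am resuming from was saved by the training script with the same keys I read back above; roughly like this (a sketch, key names taken from my loading code):

# Sketch of how the checkpoint was written; the key names match the loading code above.
torch.save(
    {
        "epoch": epoch,
        "state_dict": net.state_dict(),
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "aux_optimizer": aux_optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
    },
    "checkpoint.pth.tar",
)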
This is train_one_epoch (from utilss/training.py):

def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm, logger_train, tb_logger, current_step
):
    model.train()
    device = next(model.parameters()).device

    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)
        out_criterion = criterion(out_net, d)
        out_criterion["loss"].backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()
Here is the traceback I get when resuming:

  File "/home/projetdoc/MLIC2ICIP/MLIC2/train.py", line 167, in <module>
    main()
  File "/home/projetdoc/MLIC2ICIP/MLIC2/train.py", line 129, in main
    current_step = train_one_epoch(
  File "/home/projetdoc/MLIC2ICIP/MLIC2/utilss/training.py", line 24, in train_one_epoch
    optimizer.step()
  File "/home/projetdoc/.local/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 75, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/projetdoc/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
  File "/home/projetdoc/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/projetdoc/.local/lib/python3.10/site-packages/torch/optim/adam.py", line 166, in step
    adam(
  File "/home/projetdoc/.local/lib/python3.10/site-packages/torch/optim/adam.py", line 316, in adam
    func(params,
  File "/home/projetdoc/.local/lib/python3.10/site-packages/torch/optim/adam.py", line 520, in _multi_tensor_adam
    torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
RuntimeError: The size of tensor a (3) must match the size of tensor b (192) at non-singleton dimension 1.
I could not make sense of this error.
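In case it is useful, this is a small check I was planning to run right after the load_state_dict calls, to see which optimizer state tensors do not line up with their parameters (just a debugging sketch of mine, not from the repo):

# Debugging sketch: after optimizer.load_state_dict(...), optimizer.state is keyed
# by parameter, so the shapes of the Adam moment buffers can be compared directly
# with the parameters they are attached to.
for group in optimizer.param_groups:
    for p in group["params"]:
        state = optimizer.state.get(p, {})
        if "exp_avg" in state and state["exp_avg"].shape != p.shape:
            print("mismatch:", tuple(p.shape), "vs", tuple(state["exp_avg"].shape))

Thank you for any help.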