I am trying to train my model using Elastic Weight Consolidation technique, However using the code below, Miss match error
def train_one_epoch(net, train_dataloader):
crit = nn.MSELoss()
device = “cuda” if torch.cuda.is_available() else “cpu”
net.to(device)
ewc = ElasticWeightConsolidation(net, crit=crit, lr=1e-4)
net.train()
for input in enumerate(train_dataloader):
reconstructed = net(input)
ewc.forward_backward_update(reconstructed,input)
train_transforms = transforms.Compose(
[transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
)
test_transforms = transforms.Compose(
[transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
)
train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms) #args.dataset
test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
device = "cuda" if torch.cuda.is_available() else "cpu"
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
pin_memory=(device == "cuda"),
)
test_dataloader = DataLoader(
test_dataset,
batch_size=args.test_batch_size,
num_workers=args.num_workers,
shuffle=False,
pin_memory=(device == "cuda"),
)
net = bmshj2018_hyperprior(quality=2, pretrained= True)
#net = image_models["bmshj2018-hyperprior"](quality=3)
net = net.to(device)
if args.cuda and torch.cuda.device_count() > 1:
net = CustomDataParallel(net)
optimizer, aux_optimizer = configure_optimizers(net, args)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
criterion = RateDistortionLoss(lmbda=0.0035)
last_epoch = 0
if args.checkpoint: # load from previous checkpoint
print("Loading", args.checkpoint)
checkpoint = torch.load(args.checkpoint, map_location=device)
last_epoch = checkpoint["epoch"] + 1
net.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
best_loss = float("inf")
for epoch in range(last_epoch, args.epochs):
print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
train_one_epoch(
net,
train_dataloader
)
loss = test_epoch(epoch, test_dataloader, net, criterion)
lr_scheduler.step(loss)
is_best = loss < best_loss
best_loss = min(loss, best_loss)
if args.save:
save_checkpoint(
{
"epoch": epoch,
"state_dict": net.state_dict(),
"loss": loss,
"optimizer": optimizer.state_dict(),
"aux_optimizer": aux_optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
},
is_best,
)
I get this error:
Cell In[129], line 8 6 net.train() 7 for input in enumerate(train_dataloader): ----> 8 reconstructed = net(input) 9 ewc.forward_backward_update(reconstructed,input) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don’t have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): → 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = , File ~/.local/lib/python3.10/site-packages/compressai/models/google.py:293, in ScaleHyperprior.forward(self, x) 292 def forward(self, x): → 293 y = self.g_a(x) 294 z = self.h_a(torch.abs(y)) 295 z_hat, z_likelihoods = self.entropy_bottleneck(z) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don’t have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): → 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = , File ~/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input) 215 def forward(self, input): 216 for module in self: → 217 input = module(input) 218 return input File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don’t have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): → 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = , File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:463, in Conv2d.forward(self, input) 462 def forward(self, input: Tensor) → Tensor: → 463 return self._conv_forward(input, self.weight, self.bias) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:459, in Conv2d._conv_forward(self, input, weight, bias) 455 if self.padding_mode != ‘zeros’: 456 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), 457 weight, bias, self.stride, 458 _pair(0), self.dilation, self.groups) → 459 return F.conv2d(input, weight, bias, self.stride, 460 self.padding, self.dilation, self.groups) TypeError: conv2d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of: * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups) didn’t match because some of the arguments have invalid types: (!tuple of (int, Tensor)!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int) * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups) didn’t match because some of the arguments have invalid types: (!tuple of (int, Tensor)!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)