Error when using EWC

I am trying to train my model using the Elastic Weight Consolidation (EWC) technique. However, with the code below I get a type-mismatch error:
def train_one_epoch(net, train_dataloader):
    """Run one EWC training pass over every batch in ``train_dataloader``.

    Args:
        net: The model to train. It is moved to CUDA when available,
            otherwise to the CPU, and switched to training mode.
        train_dataloader: Iterable yielding one input tensor per batch.

    Returns:
        None. The model is updated through
        ``ElasticWeightConsolidation.forward_backward_update``.
    """
    crit = nn.MSELoss()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net.to(device)
    # NOTE(review): a fresh ElasticWeightConsolidation wrapper is built on
    # every call, i.e. once per epoch. Confirm that it carries no state that
    # must survive across epochs.
    ewc = ElasticWeightConsolidation(net, crit=crit, lr=1e-4)
    net.train()
    # Iterate over the loader directly. Wrapping it in enumerate() yields
    # (index, batch) tuples, and feeding such a tuple to the network is what
    # made conv2d() fail with "got (tuple, Parameter, ...)".
    for batch in train_dataloader:
        # Keep the data on the same device as the model.
        batch = batch.to(device)
        reconstructed = net(batch)
        # NOTE(review): presumably forward_backward_update expects
        # (input, target) and runs the model itself, and the network's
        # forward() may return a dict rather than a tensor. Verify both
        # against the EWC helper and the model before relying on this call.
        ewc.forward_backward_update(reconstructed, batch)

# Both pipelines crop to the configured patch size and then convert the image
# to a tensor; training uses RandomCrop, evaluation uses CenterCrop.
patch_size = args.patch_size

train_transforms = transforms.Compose(
    [
        transforms.RandomCrop(patch_size),
        transforms.ToTensor(),
    ]
)
test_transforms = transforms.Compose(
    [
        transforms.CenterCrop(patch_size),
        transforms.ToTensor(),
    ]
)

# The same dataset root is opened twice, once per split.
dataset_root = args.dataset
train_dataset = ImageFolder(
    dataset_root,
    split="train",
    transform=train_transforms,
)
test_dataset = ImageFolder(
    dataset_root,
    split="test",
    transform=test_transforms,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Options shared by both loaders. Pinned host memory is only requested when
# the selected device is CUDA.
_shared_loader_kwargs = {
    "num_workers": args.num_workers,
    "pin_memory": device == "cuda",
}

# Training batches are shuffled.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    **_shared_loader_kwargs,
)

# Evaluation keeps the dataset order and uses its own batch size.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=args.test_batch_size,
    shuffle=False,
    **_shared_loader_kwargs,
)

# Build the pretrained hyperprior model (quality level 2) and place it on the
# selected device in one step.
# Alternative constructor kept for reference:
#   net = image_models["bmshj2018-hyperprior"](quality=3)
net = bmshj2018_hyperprior(quality=2, pretrained=True).to(device)

# Wrap the model for multi-GPU use only when CUDA was requested and more than
# one device is visible.
use_data_parallel = args.cuda and torch.cuda.device_count() > 1
if use_data_parallel:
    net = CustomDataParallel(net)

optimizer, aux_optimizer = configure_optimizers(net, args)
# The scheduler is stepped with the evaluation loss, hence mode "min".
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
criterion = RateDistortionLoss(lmbda=0.0035)

# Start from epoch 0 unless a checkpoint path was supplied, in which case
# training resumes from the epoch after the one stored in the checkpoint.
last_epoch = 0
if args.checkpoint:
    print("Loading", args.checkpoint)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    last_epoch = checkpoint["epoch"] + 1
    # Restore each stateful object from its checkpoint entry, in the same
    # order as before: model, optimizer, auxiliary optimizer, LR scheduler.
    _restore_plan = (
        (net, "state_dict"),
        (optimizer, "optimizer"),
        (aux_optimizer, "aux_optimizer"),
        (lr_scheduler, "lr_scheduler"),
    )
    for _stateful, _key in _restore_plan:
        _stateful.load_state_dict(checkpoint[_key])

# Main loop: train, evaluate, step the scheduler on the evaluation loss, and
# optionally write a checkpoint flagged with whether this epoch is the best.
best_loss = float("inf")
for epoch in range(last_epoch, args.epochs):
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"Learning rate: {current_lr}")

    train_one_epoch(net, train_dataloader)

    loss = test_epoch(epoch, test_dataloader, net, criterion)
    lr_scheduler.step(loss)

    # The comparison must happen before best_loss is refreshed.
    is_best = loss < best_loss
    best_loss = min(loss, best_loss)

    if not args.save:
        continue

    # Snapshot everything needed to resume training from this epoch.
    state = {
        "epoch": epoch,
        "state_dict": net.state_dict(),
        "loss": loss,
        "optimizer": optimizer.state_dict(),
        "aux_optimizer": aux_optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
    }
    save_checkpoint(state, is_best)

I get this error:
Cell In[129], line 8 6 net.train() 7 for input in enumerate(train_dataloader): ----> 8 reconstructed = net(input) 9 ewc.forward_backward_update(reconstructed,input) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don’t have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): → 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = , File ~/.local/lib/python3.10/site-packages/compressai/models/google.py:293, in ScaleHyperprior.forward(self, x) 292 def forward(self, x): → 293 y = self.g_a(x) 294 z = self.h_a(torch.abs(y)) 295 z_hat, z_likelihoods = self.entropy_bottleneck(z) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don’t have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): → 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = , File ~/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input) 215 def forward(self, input): 216 for module in self: → 217 input = module(input) 218 return input File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs) 1496 # If we don’t have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): → 1501 return forward_call(*args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = , File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:463, in Conv2d.forward(self, input) 462 def forward(self, input: Tensor) → Tensor: → 463 return self._conv_forward(input, self.weight, self.bias) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:459, in Conv2d._conv_forward(self, input, weight, bias) 455 if self.padding_mode != ‘zeros’: 456 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), 457 weight, bias, self.stride, 458 _pair(0), self.dilation, self.groups) → 459 return F.conv2d(input, weight, bias, self.stride, 460 self.padding, self.dilation, self.groups) TypeError: conv2d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int), but expected 
one of: * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups) didn’t match because some of the arguments have invalid types: (!tuple of (int, Tensor)!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int) * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups) didn’t match because some of the arguments have invalid types: (!tuple of (int, Tensor)!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)

Your code and error messages are not properly formatted, but it seems a conv2d call fails with an invalid input:


[quote="Chihoub_Chiheb_Eddin, post:1, topic:214928"]
TypeError: conv2d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int),
[/quote]
and it seems you are trying to pass a `tuple` to this layer while a tensor is expected.
1 Like

Alright, I’ll reinspect the code. Sorry for the mess.