Hi,
I am creating a model for image classification but getting this error: Given groups=1, weight of size [40, 3, 3, 3], expected input[16, 32, 33, 5] to have 3 channels, but got 32 channels instead
I am using an efficientnet_b3 model.
Here is my dataset class:
The error is raised by an nn.Conv2d layer, which expects an input with 3 channels, while it seems you are trying to pass an input tensor with 5 channels in a channels-last memory layout (the shape [16, 32, 33, 5] has the channel dimension last) to this layer.
If that's correct and your input indeed contains 5 channels, you would have to first permute it to the channels-first format via input = input.permute(0, 3, 1, 2) and then either make sure the input tensor only uses 3 channels or change the first conv layer to accept 5 channels.
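Something like this minimal sketch, built only from the shapes in your error message (the conv layers are standalone placeholders, not the actual EfficientNet stem), illustrates both options:

```python
import torch
import torch.nn as nn

# Channels-last batch as in the error message: [batch_size, height, width, channels]
x = torch.randn(16, 32, 33, 5)
x = x.permute(0, 3, 1, 2)  # -> [16, 5, 32, 33], channels-first as nn.Conv2d expects

conv = nn.Conv2d(in_channels=3, out_channels=40, kernel_size=3)  # weight [40, 3, 3, 3], as in your model
# conv(x) would still fail, since x now has 5 channels; either feed 3-channel images
# or replace the first conv with one that accepts 5 input channels:
conv_5ch = nn.Conv2d(in_channels=5, out_channels=40, kernel_size=3)
out = conv_5ch(x)  # works: weight is [40, 5, 3, 3]
```

If you are using timm, timm.create_model('efficientnet_b3', in_chans=5) should also adapt the first conv layer for you.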
I don’t know which loss function you are using, but I assume you’ve solved the shape mismatch error.
Could you post the complete stack trace of the new error, as I also don’t know which part of the code raises it?
If this error is raised from the DataLoader, make sure that all images are 3-channel images in the channels-first format. Currently it seems entry 0 is still in channels-last while entry 9 is a grayscale image without a channel dimension.
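Since I don't know how your dataset loads the images, here is a minimal sketch of a __getitem__ that would guarantee this (self.paths and self.labels are placeholder names):

```python
from PIL import Image
import numpy as np
import torch

def __getitem__(self, idx):
    # convert("RGB") forces 3 channels even for grayscale files
    img = Image.open(self.paths[idx]).convert("RGB")
    img = np.asarray(img, dtype=np.float32) / 255.0  # [H, W, 3]
    img = torch.from_numpy(img).permute(2, 0, 1)     # -> [3, H, W], channels-first
    return {"img": img, "label": self.labels[idx]}
```

If you return channels-first tensors from the dataset, you also wouldn't need to permute the batch inside the training loop anymore.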
I am using nn.CrossEntropyLoss()
Here is the full stack trace:
RuntimeError Traceback (most recent call last)
<ipython-input-16-2f18cc448813> in <module>
----> 1 run_training()
<ipython-input-15-8e23b7e5c3d9> in run_training()
12 print(f'training on: {device}')
13 for epoch in tqdm(range(CFG["epochs"])):
---> 14 trn_loss = train_epoch(train_loader, model, loss_fn, optimizer, device)
15 with torch.no_grad():
16 val_loss, valid_preds = valid_epoch(val_loader, model, loss_fn, device)
<ipython-input-13-c839bbc89ccd> in train_epoch(dataloader, model, loss_fn, optimizer, device, scheduler)
2 model.train()
3 final_loss = 0
----> 4 for data in dataloader:
5 img = data["img"].to(device)
6 img = img.permute(0, 3, 1, 2)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
433 if self._sampler_iter is None:
434 self._reset()
--> 435 data = self._next_data()
436 self._num_yielded += 1
437 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
1083 else:
1084 del self._task_info[idx]
-> 1085 return self._process_data(data)
1086
1087 def _try_put_index(self):
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1109 self._try_put_index()
1110 if isinstance(data, ExceptionWrapper):
-> 1111 data.reraise()
1112 return data
1113
/opt/conda/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
426 # have message field
427 raise self.exc_type(message=msg)
--> 428 raise self.exc_type(msg)
429
430
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 73, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 73, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [64, 64, 3] at entry 0 and [64, 64] at entry 9
I think my assumption was correct: you would have to make sure all images have the same shape so that the DataLoader can stack them into a batch.
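You can reproduce what default_collate does internally with the two shapes from your error (a tiny sketch with random tensors):

```python
import torch

imgs = [torch.randn(64, 64, 3),   # entry 0: channels-last RGB image
        torch.randn(64, 64)]      # entry 9: grayscale image without a channel dimension
batch = torch.stack(imgs, 0)      # RuntimeError: stack expects each tensor to be equal size
```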
For a multi-class classification use case, nn.CrossEntropyLoss expects the model output in the shape [batch_size, nb_classes] and the target in the shape [batch_size] containing class indices in the range [0, nb_classes-1].
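A minimal example of the expected shapes (the sizes are placeholders):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
batch_size, nb_classes = 16, 5                                      # placeholder sizes
output = torch.randn(batch_size, nb_classes, requires_grad=True)    # raw logits, no softmax needed
target = torch.randint(0, nb_classes, (batch_size,))                # class indices in [0, nb_classes-1]
loss = criterion(output, target)
loss.backward()
```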
Thanks, I fixed it. It turns out the way I was opening the images was wrong: I was using PIL.Image.open. Now I am using cv2 to open them and convert from BGR to RGB.
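The loading code now looks roughly like this (path stands for whatever image path my dataset uses):

```python
import cv2

img = cv2.imread(path)                       # default IMREAD_COLOR flag returns HxWx3 BGR, even for grayscale files
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # BGR -> RGB
```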