RuntimeError: The size of tensor a (3) must match the size of tensor b (0) at non-singleton dimension 0

Hi there, I want to implement semi-supervised learning. First, I want to do the slef-supervised part using SwAV from pl_bolts. I have my unlabelled dataset in a file in google drive and I am using google colab pro plus to run the code. Below you can find my code
!pip install --upgrade pip
!pip install pytorch-lightning==1.8.0
!pip install git+https://github.com/PytorchLightning/lightning-bolts.git@master --upgrade
from google.colab import drive
drive.mount(‘/content/drive’)
import os
import torch
import pytorch_lightning as pl
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
from pl_bolts.models.self_supervised import SwAV
from sklearn.metrics import accuracy_score
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, Resize, ToTensor, Normalize
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
class UnlabelledDataset(Dataset):
def init(self, root, transform=None, num_views=2):
self.root = root
self.transform = transform
self.num_views = num_views
self.image_files = [f for f in os.listdir(root) if f.endswith((‘.png’, ‘.jpg’, ‘.jpeg’))]

def __len__(self):
    return len(self.image_files)

def __getitem__(self, idx):
    image_path = os.path.join(self.root, self.image_files[idx])
    image = Image.open(image_path).convert('RGB')

    images = []
    for _ in range(self.num_views):
        transformed_image = image
        if self.transform:
            transformed_image = self.transform(image)
        images.append(transformed_image)

    return images, idx

def multi_view_collate_fn(batch):
images, indices = zip(*batch)
images = [torch.stack(view) for view in zip(*images)]
images = torch.stack(images).transpose(0, 1).contiguous().view(-1, 3, 224, 224)
indices = torch.tensor(indices)

return images, indices

def train_swav(swav_checkpoint_path, unlabelled_root, num_epochs, batch_size=8):

unlabelled_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.8, 0.8, 0.8, 0.2),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

'''unlabelled_dataset = UnlabelledDataset(root=unlabelled_root, transform=unlabelled_transforms)
unlabelled_loader = DataLoader(unlabelled_dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=multi_view_collate_fn)


# Load the checkpoint and remove the process group from the state dict
checkpoint = torch.hub.load_state_dict_from_url(swav_checkpoint_path, map_location="cuda:0")
for key in list(checkpoint['state_dict'].keys()):
    if 'queue' in key:
        del checkpoint['state_dict'][key]

swav_model = SwAV(
    gpus=1,
    num_samples=len(unlabelled_dataset),
    batch_size=batch_size * 2,  # Account for the two views per image
    dataset='imagenet',
    nmb_crops=[2]  # Add this line to specify the number of views per image
)
swav_model.load_state_dict(checkpoint['state_dict'])
torch.autograd.set_detect_anomaly(True)
# Override the lr_scheduler_step method
def custom_lr_scheduler_step(self, epoch, val_loss, unused_arg):
    _, schedulers = self.configure_optimizers()
    for scheduler_dict in schedulers:
        scheduler_dict['scheduler'].step()

swav_model.lr_scheduler_step = custom_lr_scheduler_step.__get__(swav_model, SwAV)

trainer = pl.Trainer(gpus=1, max_epochs=num_epochs)
trainer.fit(swav_model, unlabelled_loader)

return swav_model.model
'''
unlabelled_dataset = UnlabelledDataset(root=unlabelled_root, transform=unlabelled_transforms)
unlabelled_dataset = torch.utils.data.Subset(unlabelled_dataset, range(0, 16))  # Use only the first 100 samples
unlabelled_loader = DataLoader(unlabelled_dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=multi_view_collate_fn)

checkpoint = torch.hub.load_state_dict_from_url(swav_checkpoint_path, map_location="cuda:0")

swav_model = SwAV(
    gpus=1,
    dataset='imagenet',
    num_samples=len(unlabelled_dataset),
    batch_size=batch_size * 2,  # Account for the two views per image
    nmb_crops=[2, 224]  # two crops of size 224
)
swav_model.load_state_dict(checkpoint['state_dict'])
swav_model.size()
torch.autograd.set_detect_anomaly(True)

trainer = pl.Trainer(gpus=1, max_epochs=num_epochs)
trainer.fit(swav_model, unlabelled_loader)

return swav_model.model

pretrained_backbone = train_swav(‘https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar’, ‘/content/drive/MyDrive/ThesisData/unlabeled/train05/’, num_epochs=1)

ANY HELP WILL BE MUCH APPRECIATED!

Here is the full Traceback of the error

RuntimeError Traceback (most recent call last)
in <cell line: 1>()
----> 1 pretrained_backbone = train_swav(‘https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar’, ‘/content/drive/MyDrive/ThesisData/unlabeled/train05/’, num_epochs=1)

35 frames
in train_swav(swav_checkpoint_path, unlabelled_root, num_epochs, batch_size)
59
60 trainer = pl.Trainer(gpus=1, max_epochs=num_epochs)
—> 61 trainer.fit(swav_model, unlabelled_loader)
62
63 return swav_model.model

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
606 model = self._maybe_unwrap_optimized(model)
607 self.strategy._lightning_module = model
→ 608 call._call_and_handle_interrupt(
609 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
610 )

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
36 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
37 else:
—> 38 return trainer_fn(*args, **kwargs)
39
40 except _TunerExitException:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
648 model_connected=self.lightning_module is not None,
649 )
→ 650 self._run(model, ckpt_path=self.ckpt_path)
651
652 assert self.state.stopped

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1110 self._checkpoint_connector.resume_end()
1111
→ 1112 results = self._run_stage()
1113
1114 log.detail(f"{self.class.name}: trainer tearing down")

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
1189 if self.predicting:
1190 return self._run_predict()
→ 1191 self._run_train()
1192
1193 def _pre_training_routine(self) → None:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1212
1213 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
→ 1214 self.fit_loop.run()
1215
1216 def _run_evaluate(self) → _EVALUATE_OUTPUT:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
197 try:
198 self.on_advance_start(*args, **kwargs)
→ 199 self.advance(*args, **kwargs)
200 self.on_advance_end()
201 self._restarting = False

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py in advance(self)
265 self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device)
266 with self.trainer.profiler.profile(“run_training_epoch”):
→ 267 self._outputs = self.epoch_loop.run(self._data_fetcher)
268
269 def on_advance_end(self) → None:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
197 try:
198 self.on_advance_start(*args, **kwargs)
→ 199 self.advance(*args, **kwargs)
200 self.on_advance_end()
201 self._restarting = False

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in advance(self, data_fetcher)
211
212 with self.trainer.profiler.profile(“run_training_batch”):
→ 213 batch_output = self.batch_loop.run(kwargs)
214
215 self.batch_progress.increment_processed()

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
197 try:
198 self.on_advance_start(*args, **kwargs)
→ 199 self.advance(*args, **kwargs)
200 self.on_advance_end()
201 self._restarting = False

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py in advance(self, kwargs)
86 self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get(“batch_idx”, 0)
87 )
—> 88 outputs = self.optimizer_loop.run(optimizers, kwargs)
89 else:
90 outputs = self.manual_loop.run(kwargs)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
197 try:
198 self.on_advance_start(*args, **kwargs)
→ 199 self.advance(*args, **kwargs)
200 self.on_advance_end()
201 self._restarting = False

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in advance(self, optimizers, kwargs)
200 kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
201
→ 202 result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
203 if result.loss is not None:
204 # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _run_optimization(self, kwargs, optimizer)
247 else:
248 # the batch_idx is optional with inter-batch parallelism
→ 249 self._optimizer_step(optimizer, opt_idx, kwargs.get(“batch_idx”, 0), closure)
250
251 result = closure.consume_result()

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
368 )
369 kwargs[“using_native_amp”] = isinstance(self.trainer.precision_plugin, MixedPrecisionPlugin)
→ 370 self.trainer._call_lightning_module_hook(
371 “optimizer_step”,
372 self.trainer.current_epoch,

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
1354
1355 with self.profiler.profile(f"[LightningModule]{pl_module.class.name}.{hook_name}"):
→ 1356 output = fn(*args, **kwargs)
1357
1358 # restore current_fx when nested context

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/module.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_lbfgs)
1752
1753 “”"
→ 1754 optimizer.step(closure=optimizer_closure)
1755
1756 def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) → None:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/optimizer.py in step(self, closure, **kwargs)
167
168 assert self._strategy is not None
→ 169 step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
170
171 self._on_after_step()

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py in optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
232 # TODO(fabric): remove assertion once strategy’s optimizer_step typing is fixed
233 assert isinstance(model, pl.LightningModule)
→ 234 return self.precision_plugin.optimizer_step(
235 optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
236 )

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py in optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs)
117 “”“Hook to run the optimizer step.”“”
118 closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
→ 119 return optimizer.step(closure=closure, **kwargs)
120
121 def _track_grad_norm(self, trainer: “pl.Trainer”) → None:

/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py in wrapper(*args, **kwargs)
67 instance._step_count += 1
68 wrapped = func.get(instance, cls)
—> 69 return wrapped(*args, **kwargs)
70
71 # Note that the returned function here is no longer a bound method,

/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
278 f"but got {result}.")
279
→ 280 out = func(*args, **kwargs)
281 self._optimizer_step_code()
282

/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py in _use_grad(self, *args, **kwargs)
31 try:
32 torch.set_grad_enabled(self.defaults[‘differentiable’])
—> 33 ret = func(self, *args, **kwargs)
34 finally:
35 torch.set_grad_enabled(prev_grad)

/usr/local/lib/python3.10/dist-packages/torch/optim/adam.py in step(self, closure)
119 if closure is not None:
120 with torch.enable_grad():
→ 121 loss = closure()
122
123 for group in self.param_groups:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py in _wrap_closure(self, model, optimizer, optimizer_idx, closure)
103 consistent with the PrecisionPlugin subclasses that cannot pass optimizer.step(closure) directly.
104 “”"
→ 105 closure_result = closure()
106 self._after_closure(model, optimizer, optimizer_idx)
107 return closure_result

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in call(self, *args, **kwargs)
147
148 def call(self, *args: Any, **kwargs: Any) → Optional[Tensor]:
→ 149 self._result = self.closure(*args, **kwargs)
150 return self._result.loss
151

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in closure(self, *args, **kwargs)
133
134 def closure(self, *args: Any, **kwargs: Any) → ClosureResult:
→ 135 step_output = self._step_fn()
136
137 if step_output.closure_loss is None:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _training_step(self, kwargs)
417 “”"
418 # manually capture logged metrics
→ 419 training_step_output = self.trainer._call_strategy_hook(“training_step”, *kwargs.values())
420 self.trainer.strategy.post_training_step()
421

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
1492
1493 with self.profiler.profile(f"[Strategy]{self.strategy.class.name}.{hook_name}"):
→ 1494 output = fn(*args, **kwargs)
1495
1496 # restore current_fx when nested context

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py in training_step(self, *args, **kwargs)
376 with self.precision_plugin.train_step_context():
377 assert isinstance(self.model, TrainingStep)
→ 378 return self.model.training_step(*args, **kwargs)
379
380 def post_training_step(self) → None:

/usr/local/lib/python3.10/dist-packages/pl_bolts/models/self_supervised/swav/swav_module.py in training_step(self, batch, batch_idx)
230
231 def training_step(self, batch, batch_idx):
→ 232 loss = self.shared_step(batch)
233
234 self.log(“train_loss”, loss, on_step=True, on_epoch=False)

/usr/local/lib/python3.10/dist-packages/pl_bolts/models/self_supervised/swav/swav_module.py in shared_step(self, batch)
217
218 # SWAV loss computation
→ 219 loss, queue, use_queue = self.criterion(
220 output=output,
221 embedding=embedding,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/pl_bolts/models/self_supervised/swav/loss.py in forward(self, output, embedding, prototype_weights, batch_size, queue, use_queue)
74 for v in np.delete(np.arange(np.sum(self.num_crops)), crop_id):
75 p = self.softmax(output[batch_size * v : batch_size * (v + 1)] / self.temperature)
—> 76 subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
77 loss += subloss / (np.sum(self.num_crops) - 1)
78 loss /= len(self.crops_for_assign) # type: ignore

RuntimeError: The size of tensor a (3) must match the size of tensor b (0) at non-singleton dimension 0

Based on the error message and stacktrace the loss calculation fails in:

subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))

with

RuntimeError: The size of tensor a (3) must match the size of tensor b (0) at non-singleton dimension 0

Check the shapes of both tensors (q and p) and make sure you can properly multiply them as the shape mismatch causes the issue currently.