PyTorch Lightning module cannot be initialized with pre-trained weights

I would like to train a 3D U-Net, but convergence is very slow, so I want to initialize the model with ImageNet weights, and I don't know how. I found a tutorial online about initializing a model with pre-trained weights and tried to adapt it to my code, but I get this error during training:

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [8, 1, 96, 96, 96]

Here are the lines of code I added to my original class so that I can initialize my model with ResNet-34 weights (the entire class is shown below):

def __init__(self,transfer=False):
# And
self.feature_extractor = models.resnet34(pretrained=transfer)
# NOTE(review): resnet34 is a 2D ImageNet classifier (Conv2d stem expecting 3 channels,
# 1000-way fc head). It cannot process the 5D [N, C, D, H, W] volumes a 3D UNet uses.
if transfer:
            # layers are frozen by using eval()
            # NOTE(review): Lightning calls model.train() when fitting starts, which
            # switches this sub-module back to training mode unless train() is overridden.
            self.feature_extractor.eval()
            # freeze params
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
 def _forward_features(self, x):
        # Returns resnet34's classification logits, not a spatial feature map.
        x = self.feature_extractor(x)
        return x

# I also added the first line of the forward method:

def forward(self, data):
        # NOTE(review): this passes the 5D volume batch through the 2D resnet34,
        # which is what raises the conv2d RuntimeError shown in the traceback.
        data = self._forward_features(data)
        pred = self.model(data)
        return pred

And here is my model:

class SegmentModel(pl.LightningModule):
    """3D multiclass segmentation module built around a volumetric ``UNet``.

    Batches are dicts. ``batch["MRI"]["data"]`` holds the input volumes (the
    reported error shows a batch of size ``[8, 1, 96, 96, 96]``), and
    ``batch["Label"]["data"]`` holds the label volumes, whose channel
    dimension is dropped with ``[:, 0]``.

    Args:
        transfer: If True, ``feature_extractor`` is a ResNet-34 loaded with
            ImageNet weights and frozen. Otherwise it is randomly initialised.

    NOTE: torchvision's ``resnet34`` is a 2D image *classifier*. Its stem is a
    ``Conv2d`` expecting 4D ``(N, 3, H, W)`` input, and it returns ``(N, 1000)``
    class logits rather than a spatial feature map. It therefore cannot sit in
    front of the 3D UNet: 5D volumes crash inside ``conv2d``, and even 2D
    slices would produce a vector the UNet cannot consume. ``forward`` thus
    feeds volumes straight to the UNet.
    TODO: to benefit from pretraining, use an encoder pretrained on 5D inputs
    (e.g. a 3D/video ResNet), or apply a 2D backbone without its
    ``avgpool``/``fc`` head slice by slice and adapt the UNet accordingly.
    """

    def __init__(self, transfer=False):
        super().__init__()

        self.model = UNet()
        # NOTE(review): in some DiceLoss implementations (e.g.
        # segmentation_models_pytorch) ``classes`` is a list of class indices to
        # include, not a class count. Confirm that ``classes=4`` does what is intended.
        self.loss_fn = DiceLoss(mode='multiclass', classes=4, from_logits=True)

        # torchmetrics metrics are stateful, so train and validation each get
        # their own instance to avoid mixing accumulated state.
        self.metric = torchmetrics.JaccardIndex(num_classes=4)
        self.val_metric = self.metric.clone()
        self.lr = 5e-4

        # 2D classifier kept for interface compatibility. It is not used by
        # forward() (see the class NOTE).
        self.feature_extractor = models.resnet34(pretrained=transfer)
        self._freeze_extractor = transfer

        if transfer:
            # Inference mode for BatchNorm/Dropout. train() below keeps it that way.
            self.feature_extractor.eval()
            # Freeze the pretrained weights.
            for param in self.feature_extractor.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode while keeping a frozen extractor in eval mode.

        Lightning calls ``model.train()`` when fitting starts and after each
        validation run. Without this override, that call would undo the
        ``eval()`` from ``__init__`` and let the frozen BatchNorm layers update
        their running statistics.
        """
        super().train(mode)
        if getattr(self, "_freeze_extractor", False):
            self.feature_extractor.eval()
        return self

    def _forward_features(self, x):
        """Run the 2D ResNet-34 on a 4D image batch and return its output.

        The output is the ResNet's ``(N, 1000)`` classification logits, not a
        feature map.

        Raises:
            ValueError: If ``x`` is not a 4D ``(N, C, H, W)`` tensor, for
                example a 5D volume batch.
        """
        if x.dim() != 4:
            raise ValueError(
                "feature_extractor is a 2D resnet34 and needs a 4D (N, C, H, W) "
                f"input, got shape {tuple(x.shape)}; apply it per 2D slice or use "
                "an encoder pretrained on 3D/5D inputs."
            )
        return self.feature_extractor(x)

    def forward(self, data):
        """Return UNet logits for a batch of volumes ``data``.

        The 2D ``feature_extractor`` is deliberately not applied (see the
        class NOTE). Applying it to 5D volumes raises the conv2d error.
        """
        return self.model(data)

    def _shared_step(self, batch, metric, prefix):
        """Compute the loss and IoU for one batch and log them as ``{prefix}_*``."""
        img = batch["MRI"]["data"]
        # Drop the singleton channel dimension and use integer class labels.
        mask = batch["Label"]["data"][:, 0].long()

        pred = self(img)
        loss = self.loss_fn(pred, mask)
        iou = metric(pred, mask)
        self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{prefix}_iou", iou, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss, logging ``Train_loss`` and ``Train_iou``."""
        return self._shared_step(batch, self.metric, "Train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss, logging ``Val_loss`` and ``Val_iou``."""
        return self._shared_step(batch, self.val_metric, "Val")

    def configure_optimizers(self):
        """Create AdamW over trainable parameters plus a plateau LR scheduler.

        Lightning reads schedulers from the ``"lr_scheduler"`` key and ignores
        a top-level ``"scheduler"`` key. ``ReduceLROnPlateau`` also has to be
        told which logged metric to monitor.
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        opt = torch.optim.AdamW(trainable, lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, mode='min', factor=0.1, patience=3, verbose=True
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "Val_loss",
                # Warn instead of raising when "Val_loss" is unavailable,
                # e.g. when fit() is called without a validation loader.
                "strict": False,
            },
        }

Note that `UNet()` is a 3D implementation of the U-Net model (using 3D convolutions, trilinear upsampling, etc.).

Here is the full text of the error I get
when I run:

# NOTE: the traceback below was produced by trainer.fit(model, train_loader), i.e. without val_loader.
trainer.fit(model, train_loader, val_loader)
RuntimeError                              Traceback (most recent call last)
Input In [37], in <cell line: 3>()
      1 # Train the model.
      2 # This might take some hours depending on your GPU
----> 3 trainer.fit(model, train_loader)

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:770, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    751 r"""
    752 Runs the full optimization routine.
    753 
   (...)
    767     datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
    768 """
    769 self.strategy.model = model
--> 770 self._call_and_handle_interrupt(
    771     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    772 )

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:723, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    721         return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    722     else:
--> 723         return trainer_fn(*args, **kwargs)
    724 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    725 except KeyboardInterrupt as exception:

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:811, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    807 ckpt_path = ckpt_path or self.resume_from_checkpoint
    808 self._ckpt_path = self.__set_ckpt_path(
    809     ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    810 )
--> 811 results = self._run(model, ckpt_path=self.ckpt_path)
    813 assert self.state.stopped
    814 self.training = False

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1236, in Trainer._run(self, model, ckpt_path)
   1232 self._checkpoint_connector.restore_training_state()
   1234 self._checkpoint_connector.resume_end()
-> 1236 results = self._run_stage()
   1238 log.detail(f"{self.__class__.__name__}: trainer tearing down")
   1239 self._teardown()

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1323, in Trainer._run_stage(self)
   1321 if self.predicting:
   1322     return self._run_predict()
-> 1323 return self._run_train()

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1353, in Trainer._run_train(self)
   1351 self.fit_loop.trainer = self
   1352 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1353     self.fit_loop.run()

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
    202 try:
    203     self.on_advance_start(*args, **kwargs)
--> 204     self.advance(*args, **kwargs)
    205     self.on_advance_end()
    206     self._restarting = False

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:269, in FitLoop.advance(self)
    265 self._data_fetcher.setup(
    266     dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
    267 )
    268 with self.trainer.profiler.profile("run_training_epoch"):
--> 269     self._outputs = self.epoch_loop.run(self._data_fetcher)

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
    202 try:
    203     self.on_advance_start(*args, **kwargs)
--> 204     self.advance(*args, **kwargs)
    205     self.on_advance_end()
    206     self._restarting = False

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:208, in TrainingEpochLoop.advance(self, data_fetcher)
    205     self.batch_progress.increment_started()
    207     with self.trainer.profiler.profile("run_training_batch"):
--> 208         batch_output = self.batch_loop.run(batch, batch_idx)
    210 self.batch_progress.increment_processed()
    212 # update non-plateau LR schedulers
    213 # update epoch-interval ones only when we are at the end of training epoch

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
    202 try:
    203     self.on_advance_start(*args, **kwargs)
--> 204     self.advance(*args, **kwargs)
    205     self.on_advance_end()
    206     self._restarting = False

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py:88, in TrainingBatchLoop.advance(self, batch, batch_idx)
     86 if self.trainer.lightning_module.automatic_optimization:
     87     optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
---> 88     outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
     89 else:
     90     outputs = self.manual_loop.run(split_batch, batch_idx)

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
    202 try:
    203     self.on_advance_start(*args, **kwargs)
--> 204     self.advance(*args, **kwargs)
    205     self.on_advance_end()
    206     self._restarting = False

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:203, in OptimizerLoop.advance(self, batch, *args, **kwargs)
    202 def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
--> 203     result = self._run_optimization(
    204         batch,
    205         self._batch_idx,
    206         self._optimizers[self.optim_progress.optimizer_position],
    207         self.optimizer_idx,
    208     )
    209     if result.loss is not None:
    210         # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
    211         # would be skipped otherwise
    212         self._outputs[self.optimizer_idx] = result.asdict()

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:256, in OptimizerLoop._run_optimization(self, split_batch, batch_idx, optimizer, opt_idx)
    249         closure()
    251 # ------------------------------
    252 # BACKWARD PASS
    253 # ------------------------------
    254 # gradient update with accumulated gradients
    255 else:
--> 256     self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
    258 result = closure.consume_result()
    260 if result.loss is not None:
    261     # if no result, user decided to skip optimization
    262     # otherwise update running loss + reset accumulated loss
    263     # TODO: find proper way to handle updating running loss

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:369, in OptimizerLoop._optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    366     self.optim_progress.optimizer.step.increment_ready()
    368 # model hook
--> 369 self.trainer._call_lightning_module_hook(
    370     "optimizer_step",
    371     self.trainer.current_epoch,
    372     batch_idx,
    373     optimizer,
    374     opt_idx,
    375     train_step_and_backward_closure,
    376     on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
    377     using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
    378     using_lbfgs=is_lbfgs,
    379 )
    381 if not should_accumulate:
    382     self.optim_progress.optimizer.step.increment_completed()

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1595, in Trainer._call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
   1592 pl_module._current_fx_name = hook_name
   1594 with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
-> 1595     output = fn(*args, **kwargs)
   1597 # restore current_fx when nested context
   1598 pl_module._current_fx_name = prev_fx_name

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py:1646, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1564 def optimizer_step(
   1565     self,
   1566     epoch: int,
   (...)
   1573     using_lbfgs: bool = False,
   1574 ) -> None:
   1575     r"""
   1576     Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
   1577     each optimizer.
   (...)
   1644 
   1645     """
-> 1646     optimizer.step(closure=optimizer_closure)

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:168, in LightningOptimizer.step(self, closure, **kwargs)
    165     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    167 assert self._strategy is not None
--> 168 step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
    170 self._on_after_step()
    172 return step_output

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:193, in Strategy.optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
    183 """Performs the actual optimizer step.
    184 
    185 Args:
   (...)
    190     **kwargs: Any extra arguments to ``optimizer.step``
    191 """
    192 model = model or self.lightning_module
--> 193 return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:155, in PrecisionPlugin.optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs)
    153 if isinstance(model, pl.LightningModule):
    154     closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
--> 155 return optimizer.step(closure=closure, **kwargs)

File /notebooks/envabir/lib/python3.9/site-packages/torch/optim/optimizer.py:88, in Optimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper(*args, **kwargs)
     86 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     87 with torch.autograd.profiler.record_function(profile_name):
---> 88     return func(*args, **kwargs)

File /notebooks/envabir/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /notebooks/envabir/lib/python3.9/site-packages/torch/optim/adamw.py:100, in AdamW.step(self, closure)
     98 if closure is not None:
     99     with torch.enable_grad():
--> 100         loss = closure()
    102 for group in self.param_groups:
    103     params_with_grad = []

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:140, in PrecisionPlugin._wrap_closure(self, model, optimizer, optimizer_idx, closure)
    127 def _wrap_closure(
    128     self,
    129     model: "pl.LightningModule",
   (...)
    132     closure: Callable[[], Any],
    133 ) -> Any:
    134     """This double-closure allows makes sure the ``closure`` is executed before the
    135     ``on_before_optimizer_step`` hook is called.
    136 
    137     The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
    138     consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    139     """
--> 140     closure_result = closure()
    141     self._after_closure(model, optimizer, optimizer_idx)
    142     return closure_result

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:148, in Closure.__call__(self, *args, **kwargs)
    147 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 148     self._result = self.closure(*args, **kwargs)
    149     return self._result.loss

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:134, in Closure.closure(self, *args, **kwargs)
    133 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 134     step_output = self._step_fn()
    136     if step_output.closure_loss is None:
    137         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:427, in OptimizerLoop._training_step(self, split_batch, batch_idx, opt_idx)
    422 step_kwargs = _build_training_step_kwargs(
    423     lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
    424 )
    426 # manually capture logged metrics
--> 427 training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
    428 self.trainer.strategy.post_training_step()
    430 model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1765, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
   1762     return
   1764 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1765     output = fn(*args, **kwargs)
   1767 # restore current_fx when nested context
   1768 pl_module._current_fx_name = prev_fx_name

File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:333, in Strategy.training_step(self, *args, **kwargs)
    328 """The actual training step.
    329 
    330 See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
    331 """
    332 with self.precision_plugin.train_step_context():
--> 333     return self.model.training_step(*args, **kwargs)

Input In [32], in SegmentMadisonOrgans.training_step(self, batch, batch_idx)
     37 mask = batch["Label"]["data"][:,0]  # Remove single channel as CrossEntropyLoss expects NxHxW
     38 mask = mask.long()
---> 40 pred = self(img)
     41 loss = self.loss_fn(pred, mask)
     42 iou = self.metric(pred,mask)

File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

Input In [32], in SegmentMadisonOrgans.forward(self, data)
     29 def forward(self, data):
---> 30     data = self._forward_features(data)
     31     pred = self.model(data)
     32     return pred

Input In [32], in SegmentMadisonOrgans._forward_features(self, x)
     25 def _forward_features(self, x):
---> 26     x = self.feature_extractor(x)
     27     return x

File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /notebooks/envabir/lib/python3.9/site-packages/torchvision/models/resnet.py:283, in ResNet.forward(self, x)
    282 def forward(self, x: Tensor) -> Tensor:
--> 283     return self._forward_impl(x)

File /notebooks/envabir/lib/python3.9/site-packages/torchvision/models/resnet.py:266, in ResNet._forward_impl(self, x)
    264 def _forward_impl(self, x: Tensor) -> Tensor:
    265     # See note [TorchScript super()]
--> 266     x = self.conv1(x)
    267     x = self.bn1(x)
    268     x = self.relu(x)

File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/conv.py:447, in Conv2d.forward(self, input)
    446 def forward(self, input: Tensor) -> Tensor:
--> 447     return self._conv_forward(input, self.weight, self.bias)

File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/conv.py:443, in Conv2d._conv_forward(self, input, weight, bias)
    439 if self.padding_mode != 'zeros':
    440     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    441                     weight, bias, self.stride,
    442                     _pair(0), self.dilation, self.groups)
--> 443 return F.conv2d(input, weight, bias, self.stride,
    444                 self.padding, self.dilation, self.groups)

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [8, 1, 96, 96, 96]

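Your `self.feature_extractor` is a ResNet-34, which works on 2D image inputs (4D batched tensors of shape `[N, C, H, W]`), not on 3D volumes (5D tensors of shape `[N, C, D, H, W]`, like your `[8, 1, 96, 96, 96]` batch). It also expects 3-channel input and returns `[N, 1000]` ImageNet class logits rather than a feature map, so even with 2D input its output could not be passed to your UNet unless you remove the `avgpool`/`fc` head.
You could either apply this feature extractor (without its classification head) to each slice of your volume, or use another pretrained model that accepts 5D inputs.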