I would like to train a 3D U-Net, but convergence is very slow, so I wanted to initialize the model with ImageNet weights; the problem is that I don't know how to do this properly. I found an initialization tutorial online and tried to adapt it to my code, but I get this error when training:
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [8, 1, 96, 96, 96]
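If I understand the error correctly, it happens because resnet34 is a 2D network: its first layer is an nn.Conv2d that expects a 4D batch (N, C, H, W), while my MRI batches are 5D volumes (N, C, D, H, W). As a sanity check, this minimal snippet (the shape is taken from the error message, the rest is just for illustration) reproduces the same RuntimeError for me:

import torch
from torchvision import models

# resnet34 is built from 2D layers, so its first conv only accepts 4D batches
backbone = models.resnet34(pretrained=False)
print(backbone.conv1)  # Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), ...)

# a dummy batch with the same shape as in the error: (batch, channels, D, H, W)
volume = torch.randn(8, 1, 96, 96, 96)
backbone(volume)  # raises: Expected 3D (unbatched) or 4D (batched) input to conv2d ...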
Here are the lines of code that I added to my original class in order to initialize my model with resnet34 weights (the entire class is shown further below):
def __init__(self, transfer=False):

# and

self.feature_extractor = models.resnet34(pretrained=transfer)
if transfer:
    # layers are frozen by using eval()
    self.feature_extractor.eval()
    # freeze params
    for param in self.feature_extractor.parameters():
        param.requires_grad = False

def _forward_features(self, x):
    x = self.feature_extractor(x)
    return x

# and I added the first line of code in the forward method
def forward(self, data):
    data = self._forward_features(data)
    pred = self.model(data)
    return pred
And here is my model:
class SegmentModel(pl.LightningModule):
    def __init__(self, transfer=False):
        super().__init__()
        self.model = UNet()
        self.loss_fn = DiceLoss(mode='multiclass', classes=4, from_logits=True)
        self.metric = torchmetrics.JaccardIndex(num_classes=4)
        self.lr = 5e-4
        self.feature_extractor = models.resnet34(pretrained=transfer)
        if transfer:
            # layers are frozen by using eval()
            self.feature_extractor.eval()
            # freeze params
            for param in self.feature_extractor.parameters():
                param.requires_grad = False

    def _forward_features(self, x):
        x = self.feature_extractor(x)
        return x

    def forward(self, data):
        data = self._forward_features(data)
        pred = self.model(data)
        return pred

    def training_step(self, batch, batch_idx):
        img = batch["MRI"]["data"]
        mask = batch["Label"]["data"][:, 0]
        mask = mask.long()
        pred = self(img)
        loss = self.loss_fn(pred, mask)
        iou = self.metric(pred, mask)
        self.log("Train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("Train_iou", iou, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        img = batch["MRI"]["data"]
        mask = batch["Label"]["data"][:, 0]
        mask = mask.long()
        pred = self(img)
        loss = self.loss_fn(pred, mask)
        iou = self.metric(pred, mask)
        self.log("Val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("Val_iou", iou, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr)
        # scheduler = CosineAnnealingWarmRestarts(opt, T_0=25, eta_min=1e-6)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=3, verbose=True)
        return {'optimizer': opt, 'scheduler': scheduler}
Note that UNet() is a 3D implementation of the U-Net model (using 3D convolutions, trilinear upsampling…).
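I am also not sure my wiring makes sense even shape-wise: as far as I can tell, the stock resnet34 ends with an average pool and a fully connected layer, so on a valid 4D input it returns a (batch, 1000) vector of ImageNet class logits rather than a feature volume that my 3D UNet() could consume (the input size below is just an assumption for illustration):

import torch
from torchvision import models

backbone = models.resnet34(pretrained=False)
with torch.no_grad():
    out = backbone(torch.randn(8, 3, 96, 96))  # a 4D batch that the 2D ResNet does accept
print(out.shape)  # torch.Size([8, 1000]) -- class logits, not a 3D feature map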
And here is the full text of the error, which I get when I run:

trainer.fit(model, train_loader, val_loader)
RuntimeError Traceback (most recent call last)
Input In [37], in <cell line: 3>()
1 # Train the model.
2 # This might take some hours depending on your GPU
----> 3 trainer.fit(model, train_loader)
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:770, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
751 r"""
752 Runs the full optimization routine.
753
(...)
767 datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
768 """
769 self.strategy.model = model
--> 770 self._call_and_handle_interrupt(
771 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
772 )
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:723, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
721 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
722 else:
--> 723 return trainer_fn(*args, **kwargs)
724 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
725 except KeyboardInterrupt as exception:
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:811, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
807 ckpt_path = ckpt_path or self.resume_from_checkpoint
808 self._ckpt_path = self.__set_ckpt_path(
809 ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
810 )
--> 811 results = self._run(model, ckpt_path=self.ckpt_path)
813 assert self.state.stopped
814 self.training = False
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1236, in Trainer._run(self, model, ckpt_path)
1232 self._checkpoint_connector.restore_training_state()
1234 self._checkpoint_connector.resume_end()
-> 1236 results = self._run_stage()
1238 log.detail(f"{self.__class__.__name__}: trainer tearing down")
1239 self._teardown()
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1323, in Trainer._run_stage(self)
1321 if self.predicting:
1322 return self._run_predict()
-> 1323 return self._run_train()
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1353, in Trainer._run_train(self)
1351 self.fit_loop.trainer = self
1352 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1353 self.fit_loop.run()
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
202 try:
203 self.on_advance_start(*args, **kwargs)
--> 204 self.advance(*args, **kwargs)
205 self.on_advance_end()
206 self._restarting = False
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:269, in FitLoop.advance(self)
265 self._data_fetcher.setup(
266 dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
267 )
268 with self.trainer.profiler.profile("run_training_epoch"):
--> 269 self._outputs = self.epoch_loop.run(self._data_fetcher)
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
202 try:
203 self.on_advance_start(*args, **kwargs)
--> 204 self.advance(*args, **kwargs)
205 self.on_advance_end()
206 self._restarting = False
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:208, in TrainingEpochLoop.advance(self, data_fetcher)
205 self.batch_progress.increment_started()
207 with self.trainer.profiler.profile("run_training_batch"):
--> 208 batch_output = self.batch_loop.run(batch, batch_idx)
210 self.batch_progress.increment_processed()
212 # update non-plateau LR schedulers
213 # update epoch-interval ones only when we are at the end of training epoch
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
202 try:
203 self.on_advance_start(*args, **kwargs)
--> 204 self.advance(*args, **kwargs)
205 self.on_advance_end()
206 self._restarting = False
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py:88, in TrainingBatchLoop.advance(self, batch, batch_idx)
86 if self.trainer.lightning_module.automatic_optimization:
87 optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
---> 88 outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
89 else:
90 outputs = self.manual_loop.run(split_batch, batch_idx)
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:204, in Loop.run(self, *args, **kwargs)
202 try:
203 self.on_advance_start(*args, **kwargs)
--> 204 self.advance(*args, **kwargs)
205 self.on_advance_end()
206 self._restarting = False
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:203, in OptimizerLoop.advance(self, batch, *args, **kwargs)
202 def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
--> 203 result = self._run_optimization(
204 batch,
205 self._batch_idx,
206 self._optimizers[self.optim_progress.optimizer_position],
207 self.optimizer_idx,
208 )
209 if result.loss is not None:
210 # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
211 # would be skipped otherwise
212 self._outputs[self.optimizer_idx] = result.asdict()
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:256, in OptimizerLoop._run_optimization(self, split_batch, batch_idx, optimizer, opt_idx)
249 closure()
251 # ------------------------------
252 # BACKWARD PASS
253 # ------------------------------
254 # gradient update with accumulated gradients
255 else:
--> 256 self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
258 result = closure.consume_result()
260 if result.loss is not None:
261 # if no result, user decided to skip optimization
262 # otherwise update running loss + reset accumulated loss
263 # TODO: find proper way to handle updating running loss
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:369, in OptimizerLoop._optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
366 self.optim_progress.optimizer.step.increment_ready()
368 # model hook
--> 369 self.trainer._call_lightning_module_hook(
370 "optimizer_step",
371 self.trainer.current_epoch,
372 batch_idx,
373 optimizer,
374 opt_idx,
375 train_step_and_backward_closure,
376 on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
377 using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
378 using_lbfgs=is_lbfgs,
379 )
381 if not should_accumulate:
382 self.optim_progress.optimizer.step.increment_completed()
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1595, in Trainer._call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
1592 pl_module._current_fx_name = hook_name
1594 with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
-> 1595 output = fn(*args, **kwargs)
1597 # restore current_fx when nested context
1598 pl_module._current_fx_name = prev_fx_name
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py:1646, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
1564 def optimizer_step(
1565 self,
1566 epoch: int,
(...)
1573 using_lbfgs: bool = False,
1574 ) -> None:
1575 r"""
1576 Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
1577 each optimizer.
(...)
1644
1645 """
-> 1646 optimizer.step(closure=optimizer_closure)
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:168, in LightningOptimizer.step(self, closure, **kwargs)
165 raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
167 assert self._strategy is not None
--> 168 step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
170 self._on_after_step()
172 return step_output
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:193, in Strategy.optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
183 """Performs the actual optimizer step.
184
185 Args:
(...)
190 **kwargs: Any extra arguments to ``optimizer.step``
191 """
192 model = model or self.lightning_module
--> 193 return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:155, in PrecisionPlugin.optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs)
153 if isinstance(model, pl.LightningModule):
154 closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
--> 155 return optimizer.step(closure=closure, **kwargs)
File /notebooks/envabir/lib/python3.9/site-packages/torch/optim/optimizer.py:88, in Optimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper(*args, **kwargs)
86 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
87 with torch.autograd.profiler.record_function(profile_name):
---> 88 return func(*args, **kwargs)
File /notebooks/envabir/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File /notebooks/envabir/lib/python3.9/site-packages/torch/optim/adamw.py:100, in AdamW.step(self, closure)
98 if closure is not None:
99 with torch.enable_grad():
--> 100 loss = closure()
102 for group in self.param_groups:
103 params_with_grad = []
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:140, in PrecisionPlugin._wrap_closure(self, model, optimizer, optimizer_idx, closure)
127 def _wrap_closure(
128 self,
129 model: "pl.LightningModule",
(...)
132 closure: Callable[[], Any],
133 ) -> Any:
134 """This double-closure allows makes sure the ``closure`` is executed before the
135 ``on_before_optimizer_step`` hook is called.
136
137 The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
138 consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
139 """
--> 140 closure_result = closure()
141 self._after_closure(model, optimizer, optimizer_idx)
142 return closure_result
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:148, in Closure.__call__(self, *args, **kwargs)
147 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 148 self._result = self.closure(*args, **kwargs)
149 return self._result.loss
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:134, in Closure.closure(self, *args, **kwargs)
133 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 134 step_output = self._step_fn()
136 if step_output.closure_loss is None:
137 self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:427, in OptimizerLoop._training_step(self, split_batch, batch_idx, opt_idx)
422 step_kwargs = _build_training_step_kwargs(
423 lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
424 )
426 # manually capture logged metrics
--> 427 training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
428 self.trainer.strategy.post_training_step()
430 model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1765, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
1762 return
1764 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1765 output = fn(*args, **kwargs)
1767 # restore current_fx when nested context
1768 pl_module._current_fx_name = prev_fx_name
File /notebooks/envabir/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:333, in Strategy.training_step(self, *args, **kwargs)
328 """The actual training step.
329
330 See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
331 """
332 with self.precision_plugin.train_step_context():
--> 333 return self.model.training_step(*args, **kwargs)
Input In [32], in SegmentMadisonOrgans.training_step(self, batch, batch_idx)
37 mask = batch["Label"]["data"][:,0] # Remove single channel as CrossEntropyLoss expects NxHxW
38 mask = mask.long()
---> 40 pred = self(img)
41 loss = self.loss_fn(pred, mask)
42 iou = self.metric(pred,mask)
File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
Input In [32], in SegmentMadisonOrgans.forward(self, data)
29 def forward(self, data):
---> 30 data = self._forward_features(data)
31 pred = self.model(data)
32 return pred
Input In [32], in SegmentMadisonOrgans._forward_features(self, x)
25 def _forward_features(self, x):
---> 26 x = self.feature_extractor(x)
27 return x
File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /notebooks/envabir/lib/python3.9/site-packages/torchvision/models/resnet.py:283, in ResNet.forward(self, x)
282 def forward(self, x: Tensor) -> Tensor:
--> 283 return self._forward_impl(x)
File /notebooks/envabir/lib/python3.9/site-packages/torchvision/models/resnet.py:266, in ResNet._forward_impl(self, x)
264 def _forward_impl(self, x: Tensor) -> Tensor:
265 # See note [TorchScript super()]
--> 266 x = self.conv1(x)
267 x = self.bn1(x)
268 x = self.relu(x)
File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/conv.py:447, in Conv2d.forward(self, input)
446 def forward(self, input: Tensor) -> Tensor:
--> 447 return self._conv_forward(input, self.weight, self.bias)
File /notebooks/envabir/lib/python3.9/site-packages/torch/nn/modules/conv.py:443, in Conv2d._conv_forward(self, input, weight, bias)
439 if self.padding_mode != 'zeros':
440 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
441 weight, bias, self.stride,
442 _pair(0), self.dilation, self.groups)
--> 443 return F.conv2d(input, weight, bias, self.stride,
444 self.padding, self.dilation, self.groups)
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [8, 1, 96, 96, 96]