I was following AssemblyAI’s tutorial,
code had some bugs but I fixed them,except this
full code in colab
def training_step(self, batch, batch_index, optimizer_idx):
real_imgs, _ = batch
real_imgs = real_imgs.cuda()
# sample noise
z = torch.randn(real_imgs.shape[0], self.hparams.latent_dim, device=real_imgs.device)
z = z.type_as(real_imgs)
# train generator: maximize log(D(G(z))), D:discriminator, G:generator, z: random noise
if optimizer_idx == 0:
fake_imgs = self(z) # will execute forward()
y_hat = self.discriminator(fake_imgs).cuda()
y = torch.ones(real_imgs.size(0), 1).cuda()
y = y.type_as(real_imgs)
g_loss = self.adversarial_loss(y_hat, y).cuda()
log_dict = {"g_loss": g_loss}
return {"loss":g_loss, "progress_bar":log_dict, "log":log_dict}
# train discriminator: maximize log(D(x)) + log(1 - D(G(z)))
if optimizer_idx == 1:
# how well can it label as real
y_hat_real = self.discriminator(real_imgs).cuda()
y_real = torch.ones(real_imgs.size(0), 1).cuda()
y_real = y_real.type_as(real_imgs)
real_loss = self.adversarial_loss(y_hat_real, y_hat).cuda()
# how well can it label as fake
y_hat_fake = self.discriminator(self(z).detach())
y_fake = torch.zeros(real_imgs.size(0), 1) .cuda()
y_fake = y_fake.type_as(real_imgs)
fake_loss = self.adversarial_loss(y_hat_fake, y_fake).cuda()
d_loss = (real_loss + fake_loss / 2)
log_dict = {"d_loss": d_loss}
return {"loss":d_loss, "progress_bar":log_dict, "log":log_dict}
trainer = pl.Trainer(max_epochs=20, accelerator="gpu")
trainer.fit(model, dm)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[11], line 2
1 trainer = pl.Trainer(max_epochs=20, accelerator="gpu")
----> 2 trainer.fit(model, dm)
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\trainer\trainer.py:529, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
527 model = _maybe_unwrap_optimized(model)
528 self.strategy._lightning_module = model
--> 529 call._call_and_handle_interrupt(
530 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
531 )
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\trainer\call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
40 if trainer.strategy.launcher is not None:
41 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42 return trainer_fn(*args, **kwargs)
44 except _TunerExitException:
45 _call_teardown_hook(trainer)
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\trainer\trainer.py:568, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
558 self._data_connector.attach_data(
559 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
560 )
562 ckpt_path = self._checkpoint_connector._select_ckpt_path(
563 self.state.fn,
564 ckpt_path,
565 model_provided=True,
566 model_connected=self.lightning_module is not None,
567 )
--> 568 self._run(model, ckpt_path=ckpt_path)
570 assert self.state.stopped
571 self.training = False
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\trainer\trainer.py:973, in Trainer._run(self, model, ckpt_path)
968 self._signal_connector.register_signal_handlers()
970 # ----------------------------
971 # RUN THE TRAINER
972 # ----------------------------
--> 973 results = self._run_stage()
975 # ----------------------------
976 # POST-Training CLEAN UP
977 # ----------------------------
978 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1016, in Trainer._run_stage(self)
1014 self._run_sanity_check()
1015 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1016 self.fit_loop.run()
1017 return None
1018 raise RuntimeError(f"Unexpected state {self.state}")
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:201, in _FitLoop.run(self)
199 try:
200 self.on_advance_start()
--> 201 self.advance()
202 self.on_advance_end()
203 self._restarting = False
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:354, in _FitLoop.advance(self)
352 self._data_fetcher.setup(combined_loader)
353 with self.trainer.profiler.profile("run_training_epoch"):
--> 354 self.epoch_loop.run(self._data_fetcher)
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\loops\training_epoch_loop.py:133, in _TrainingEpochLoop.run(self, data_fetcher)
131 while not self.done:
132 try:
--> 133 self.advance(data_fetcher)
134 self.on_advance_end()
135 self._restarting = False
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\loops\training_epoch_loop.py:220, in _TrainingEpochLoop.advance(self, data_fetcher)
218 batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
219 else:
--> 220 batch_output = self.manual_optimization.run(kwargs)
222 self.batch_progress.increment_processed()
224 # update non-plateau LR schedulers
225 # update epoch-interval ones only when we are at the end of training epoch
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\loops\optimization\manual.py:90, in _ManualOptimization.run(self, kwargs)
88 self.on_run_start()
89 with suppress(StopIteration): # no loop to break at this level
---> 90 self.advance(kwargs)
91 self._restarting = False
92 return self.on_run_end()
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\loops\optimization\manual.py:109, in _ManualOptimization.advance(self, kwargs)
106 trainer = self.trainer
108 # manually capture logged metrics
--> 109 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
110 del kwargs # release the batch from memory
111 self.trainer.strategy.post_training_step()
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\trainer\call.py:291, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
288 return None
290 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 291 output = fn(*args, **kwargs)
293 # restore current_fx when nested context
294 pl_module._current_fx_name = prev_fx_name
File ~\OneDrive\Masaüstü\pytorch\MLvenv\Lib\site-packages\pytorch_lightning\strategies\strategy.py:367, in Strategy.training_step(self, *args, **kwargs)
365 with self.precision_plugin.train_step_context():
366 assert isinstance(self.model, TrainingStep)
--> 367 return self.model.training_step(*args, **kwargs)
TypeError: GAN.training_step() missing 1 required positional argument: 'optimizer_idx'