Tooling for checkpointing and restarting multi-model training?

Can we discuss tooling btw? (If there is a better channel for this, let me know, or maybe we should even create one?)

The general question:
I wrote a pure PyTorch prototype last night with wandb logging, and saved JUST the model checkpoint as an artifact. It crashed in the middle of the night, and I am now manually restarting from that checkpoint.
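
Saving only the model weights makes a restart lossy: the optimizer state (for example SGD momentum buffers), the step counter, any LR scheduler position, and the RNG state are all lost. Below is a minimal sketch of a fuller checkpoint in plain PyTorch. The helper names `save_checkpoint`/`load_checkpoint` and the dict keys are hypothetical; the resulting file can still be logged as a W&B artifact the same way the weights-only file was.

```python
import os
import tempfile

import torch


def save_checkpoint(path, model, optimizer, step, scheduler=None):
    """Save everything needed to resume training (not just the weights), atomically."""
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # e.g. SGD momentum buffers
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "step": step,
        "torch_rng": torch.get_rng_state(),
    }
    # Write to a temp file in the same directory, then rename it over `path`,
    # so a crash mid-write never leaves a truncated checkpoint behind.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def load_checkpoint(path, model, optimizer, scheduler=None):
    """Restore state in place; return the step to resume from (0 if no checkpoint)."""
    if not os.path.exists(path):
        return 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    # Optimizer.load_state_dict moves the loaded state onto each parameter's device.
    optimizer.load_state_dict(state["optimizer"])
    if scheduler is not None and state["scheduler"] is not None:
        scheduler.load_state_dict(state["scheduler"])
    torch.set_rng_state(state["torch_rng"])
    return state["step"]
```

In the training loop, call `load_checkpoint` once at startup and `save_checkpoint` every N steps; a restarted script then continues from the saved step instead of from scratch.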

Should I adopt pytorch-lightning? I've used it in the past, but I used to run into complications with it when using stranger models like GANs. On the other hand, it was nice for DDP across multiple GPUs on a single node. Are there other tools people prefer for this sort of thing?

The more specific setting I’m working in:
I adapted the Facebook VICReg code to PRETRAIN a non-contrastive VICReg model that trains torchsynth synth1b1 audio synthesis PARAMETER embeddings to match torchsynth synth1b1 audio synthesis AUDIO embeddings. It uses a somewhat odd LARS optimizer implementation from Facebook that I couldn't get working, so I used good old SGD instead. I am currently training it on a single GPU. It crashed after 60K batches for reasons I'm still investigating. The training loss is still going down.
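
For reference, the VICReg objective combines an invariance term (MSE between the two branches' embeddings), a variance term (a hinge keeping each dimension's standard deviation above 1), and a covariance term (squared off-diagonal covariance, pushing dimensions to decorrelate). The sketch below assumes the parameter branch and the audio branch each output a `(batch, dim)` tensor; the 25/25/1 coefficients are the defaults commonly used with VICReg, not values tuned for synth1b1.

```python
import torch
import torch.nn.functional as F


def off_diagonal(m: torch.Tensor) -> torch.Tensor:
    """Return the off-diagonal entries of a square matrix as a flat tensor."""
    n = m.shape[0]
    return m[~torch.eye(n, dtype=torch.bool, device=m.device)]


def vicreg_loss(x, y, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=1e-4):
    """VICReg loss between two (batch, dim) embeddings, e.g. parameter vs. audio."""
    batch, dim = x.shape
    # Invariance: pull each parameter embedding toward its paired audio embedding.
    inv = F.mse_loss(x, y)
    x = x - x.mean(dim=0)
    y = y - y.mean(dim=0)
    # Variance: hinge on the per-dimension std to prevent collapse.
    std_x = torch.sqrt(x.var(dim=0) + eps)
    std_y = torch.sqrt(y.var(dim=0) + eps)
    var = F.relu(1 - std_x).mean() / 2 + F.relu(1 - std_y).mean() / 2
    # Covariance: penalize correlation between dimensions within each branch.
    cov_x = (x.T @ x) / (batch - 1)
    cov_y = (y.T @ y) / (batch - 1)
    cov = off_diagonal(cov_x).pow(2).sum() / dim + off_diagonal(cov_y).pow(2).sum() / dim
    return sim_coeff * inv + std_coeff * var + cov_coeff * cov
```

With plain SGD the step is the usual `loss = vicreg_loss(param_encoder(params), audio_encoder(audio))`, then `loss.backward()` and `optimizer.step()`, where `param_encoder` and `audio_encoder` are hypothetical names for the two branches.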

Next, I am going to write a DOWNSTREAM model that uses held-out dev and test batches: it takes the VICReg model and learns a parameter-embedding-to-parameter model on dev, either freezing or fine-tuning the VICReg model, and is then evaluated on test. For QUALITATIVE evaluation, I will listen to torchsynth audio rendered from the predicted parameters versus torchsynth audio rendered from the true parameters. For QUANTITATIVE evaluation, I think I'll use the L1 mel-spectrogram distance between those same two renderings. (I'd prefer MSSTFT distance from auraloss, but the Lambda Labs setup has a funny numba version, and auraloss depends on librosa and thus on numba, which is a huge pain. So I think I'll code my own audio distance with torchaudio instead.) I will pretrain VICReg first and train the downstream model afterwards. Later I might interleave the two (even though they have separate optimizers etc.).

I'll share code and models later when it's slightly more mature.

You may want to narrow down your question to make it more specific and easier to answer.

I work for W&B, and using it to save logs and models and to resume crashed runs is very common.
Here is some documentation: Resume Runs - Documentation
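
A minimal sketch of the resume pattern: persist the run id to a file so that a relaunched script passes the same `id` to `wandb.init` together with `resume="allow"` (resume the run if it exists, otherwise start it). The helper names and the `wandb_run_id.txt` file name are hypothetical; pair this with a checkpoint that also stores the optimizer state and step counter so training, not just logging, picks up where it stopped.

```python
import os
import uuid


def get_or_create_run_id(id_file: str) -> str:
    """Return the run id stored in `id_file`, creating and saving one on first launch."""
    if os.path.exists(id_file):
        with open(id_file) as f:
            return f.read().strip()
    run_id = uuid.uuid4().hex[:8]
    with open(id_file, "w") as f:
        f.write(run_id)
    return run_id


def init_resumable_run(project: str, id_file: str = "wandb_run_id.txt"):
    """Start a W&B run, or resume the same run if the script is relaunched."""
    import wandb  # imported lazily so the id helper works without wandb installed

    run_id = get_or_create_run_id(id_file)
    # resume="allow": continue the run with this id if it exists, else create it.
    return wandb.init(project=project, id=run_id, resume="allow")
```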
Let me know if you have any other issues.

As for pytorch-lightning, it is a pretty mature library, and whether to adopt it really depends on whether you like using its abstractions to make your code cleaner.
Both wandb and pytorch-lightning (including when using the WandbLogger) will work fine for training GANs and resuming crashed runs.