Training function good practices

Hi, after reading through multiple PyTorch implementations, I noticed that there isn’t really a standard signature for the train() and test() functions.

At one extreme, and this is what I see most often, people simply put the train() and test() functions inside main.py and then start training with something like:
python main.py --mode=train ...
Here, train() takes only a single parameter (usually epoch) or none at all.
Basically, the model, dataset, DataLoader, loss and optimizer are global variables in the scope of main.py.
If I follow this convention in my case, where I compare the performance of two different custom losses that require two different custom DataLoaders (and Samplers), I have to write two different train() functions.

At the other extreme, if I pass everything into train(), the arguments would be: epoch, model, train_set (or dataloader), loss, optimizer.
That would make the train() function extremely bloated.
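For concreteness, here is a minimal sketch of that fully explicit signature in PyTorch. The synthetic regression data, the `nn.Linear` model, the extra `device` argument and the SGD settings are illustrative assumptions, not taken from any particular project:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train(epoch, model, dataloader, loss_fn, optimizer, device="cpu"):
    """Run one training epoch; every dependency is passed in explicitly."""
    model.train()
    running_loss = 0.0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per sample.
        running_loss += loss.item() * inputs.size(0)
    avg_loss = running_loss / len(dataloader.dataset)
    print(f"epoch {epoch}: train loss {avg_loss:.4f}")
    return avg_loss


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical synthetic regression data, purely for illustration.
    x = torch.randn(256, 10)
    y = x.sum(dim=1, keepdim=True)
    loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    for epoch in range(5):
        train(epoch, model, loader, nn.MSELoss(), optimizer)
```

Because nothing inside train() reads module-level state, the same function can be called with a different model, DataLoader or loss without modification.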

So I’d like to ask you all: what is your preference for the train function, train(epoch) or train(epoch, model, dataloader, …)? Or is it some middle ground between these two extremes?

> At the other extreme, if I pass everything into train(), the arguments would be: epoch, model, train_set (or dataloader), loss, optimizer.
> That would make the train() function extremely bloated.

I would actually argue the other way. Since train is a function, it is expected to perform one task, i.e. to train the model with the given parameters. Therefore, I would suggest passing everything as a parameter instead of relying on global references (which is also a basic principle of functional programming).
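One way to apply this to the two-loss comparison from the question is sketched below, assuming each run's components are bundled in a small dataclass so the call stays short while nothing is global. `Experiment`, `make_experiment` and `train_one_epoch` are hypothetical names, and `nn.MSELoss`/`nn.L1Loss` with `RandomSampler`/`WeightedRandomSampler` merely stand in for the custom losses and samplers:

```python
from dataclasses import dataclass
from typing import Callable

import torch
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset,
                              WeightedRandomSampler)


@dataclass
class Experiment:
    """Hypothetical container bundling everything one training run needs."""
    name: str
    model: nn.Module
    dataloader: DataLoader
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    optimizer: torch.optim.Optimizer


def train_one_epoch(exp: Experiment, device: str = "cpu") -> float:
    """One epoch over exp.dataloader; no global state is read."""
    exp.model.train()
    total, count = 0.0, 0
    for inputs, targets in exp.dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        exp.optimizer.zero_grad()
        loss = exp.loss_fn(exp.model(inputs), targets)
        loss.backward()
        exp.optimizer.step()
        total += loss.item() * inputs.size(0)
        count += inputs.size(0)  # a sampler may draw more or fewer than len(dataset)
    return total / count


def make_experiment(name, loss_fn, sampler_factory, dataset):
    """Hypothetical helper: builds a fresh model/loader/optimizer per run."""
    model = nn.Linear(dataset.tensors[0].shape[1], 1)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler_factory(dataset))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    return Experiment(name, model, loader, loss_fn, optimizer)


def weighted_sampler(ds):
    # Stand-in "custom" sampler: oversample examples with larger |target|.
    weights = ds.tensors[1].abs().squeeze(1) + 0.1
    return WeightedRandomSampler(weights, num_samples=len(ds))


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(256, 10)
    dataset = TensorDataset(x, x.sum(dim=1, keepdim=True))
    experiments = [
        make_experiment("mse", nn.MSELoss(), RandomSampler, dataset),
        make_experiment("l1", nn.L1Loss(), weighted_sampler, dataset),
    ]
    for epoch in range(5):
        for exp in experiments:
            loss = train_one_epoch(exp)
            print(f"[{exp.name}] epoch {epoch}: loss {loss:.4f}")
```

Adding a third variant then means building another `Experiment`, not writing another train function.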

Also, I generally use PyTorch Lightning to manage this boilerplate for me. I define the LightningModule and LightningDataModule separately and then, depending on the choice of dataset and hyperparameters, pass them to the Lightning Trainer accordingly.

Hope this helps :smiley: