Training a model via a train method?

I’m new to PyTorch, and I was wondering whether there is an existing implementation of an nn.Module subclass with a method that trains the model, or a reason why I wouldn’t want one. I was thinking of something like:

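```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

class TrainModule(nn.Module):
    def __init__(self):
        super().__init__()

    def train_model(self, traindata, valdata, epochs, batch_size,
                    criterion, optimizer, scheduler=None):
        trainloader = DataLoader(traindata, shuffle=True,
                                 batch_size=batch_size)
        # Validation order does not matter, so no shuffling here
        valloader = DataLoader(valdata, shuffle=False,
                               batch_size=batch_size)

        for epoch in range(epochs):
            self.train()  # enable training-mode layers (dropout, batch norm)
            acc_loss = 0.0
            with tqdm(total=len(traindata)) as epoch_pbar:
                epoch_pbar.set_description(f'Epoch {epoch}')
                for idx, (inputs, target) in enumerate(trainloader):

                    # Reset the gradients
                    optimizer.zero_grad()

                    # Do a forward pass, calculate loss and backpropagate
                    outputs = self(inputs)
                    loss = criterion(outputs, target)
                    loss.backward()
                    optimizer.step()

                    # Add loss to accumulated loss (.item() avoids keeping the graph alive)
                    acc_loss += loss.item()

                    # Update progress bar description and position
                    avg_loss = acc_loss / (idx + 1)
                    desc = f'Epoch {epoch} - loss {avg_loss:.4f}'
                    epoch_pbar.set_description(desc)
                    epoch_pbar.update(len(inputs))

            # Learning rate decay
            if scheduler is not None:
                scheduler.step()
            # len(trainloader) is the number of batches, including a final partial one
            avg_loss = acc_loss / len(trainloader)

            # Deal with validation data
            self.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, target in valloader:
                    val_loss += criterion(self(inputs), target).item()
            val_loss /= len(valloader)
            tqdm.write(f'Epoch {epoch} - train loss {avg_loss:.4f} - val loss {val_loss:.4f}')
```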

I’d also like it to accept string inputs such as ‘adam’ for the optimizer or ‘mse’ for the criterion, and so on.

Not having to write these out every time I build a model would be great, and it would stay compatible with current usage, since you could simply not call the train_model method, or override it.
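A minimal sketch of the string-input idea, assuming a hand-written lookup table. The dictionaries `OPTIMIZERS` and `CRITERIA` and the helpers `resolve_optimizer` / `resolve_criterion` are hypothetical names, not part of PyTorch; only the classes they map to (`torch.optim.Adam`, `nn.MSELoss`, etc.) are real.

```python
import torch
from torch import nn

# Hypothetical lookup tables; the chosen names and entries are illustrative only.
OPTIMIZERS = {
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW,
    'sgd': torch.optim.SGD,
}
CRITERIA = {
    'mse': nn.MSELoss,
    'l1': nn.L1Loss,
    'cross_entropy': nn.CrossEntropyLoss,
}

def resolve_optimizer(optimizer, params, **kwargs):
    """Build an optimizer from a name like 'adam', or pass an instance through."""
    if isinstance(optimizer, str):
        try:
            cls = OPTIMIZERS[optimizer.lower()]
        except KeyError:
            raise ValueError(f'Unknown optimizer: {optimizer!r}') from None
        return cls(params, **kwargs)
    return optimizer

def resolve_criterion(criterion):
    """Build a loss module from a name like 'mse', or pass an instance through."""
    if isinstance(criterion, str):
        try:
            return CRITERIA[criterion.lower()]()
        except KeyError:
            raise ValueError(f'Unknown criterion: {criterion!r}') from None
    return criterion

# Usage: strings and ready-made objects are both accepted.
model = nn.Linear(3, 1)
optimizer = resolve_optimizer('adam', model.parameters(), lr=1e-3)
criterion = resolve_criterion('mse')
```

Passing non-string arguments straight through keeps the method usable with an optimizer or loss you have already constructed yourself.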


So what you want is both training and validation inside a single method, right?

In some other ML libraries (Keras, scikit-learn), that method would be called something like fit.
Calling it fit makes a lot of sense, because “fit” is the root of the terms overfitting and underfitting, and we all know those relate to both the training and validation losses.

The problem, I think, is that you could only create this as an abstract method, since models differ.

Python supports abstract methods (for example through the abc module), and I don’t see why this wouldn’t be possible.
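A rough sketch of the abstract-method version, assuming Python’s standard `abc` module. `FitModule` and `LinearModel` are hypothetical names. `nn.Module` does not define its own metaclass, so it can be mixed with `ABC`; a subclass that does not implement `fit()` then cannot be instantiated.

```python
from abc import ABC, abstractmethod
from torch import nn

class FitModule(nn.Module, ABC):
    """Hypothetical base class: every concrete subclass must provide fit()."""

    @abstractmethod
    def fit(self, traindata, valdata, epochs):
        """Train on traindata and validate on valdata after each epoch."""

class LinearModel(FitModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)

    def fit(self, traindata, valdata, epochs):
        ...  # model-specific training loop goes here

model = LinearModel()  # fine: fit() is implemented
# FitModule() would raise TypeError, because fit() is still abstract.
```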

Sure, we could call it fit; that would make sense :slight_smile:

And yes, I’d like it to both train and validate, with validation happening after each epoch.

Regarding the different models: yes, but isn’t there a “template” that a typical model would follow? Again, if you need anything else, you can simply write a custom training loop as usual.
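One possible “template”, sketched under the assumption that most models differ only in how a batch turns into a loss: `fit()` owns the loop and the per-epoch validation, and subclasses override a single hook. `TemplateModule`, `compute_loss` and `Regressor` are hypothetical names; this is the template-method pattern, not an existing PyTorch API.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

class TemplateModule(nn.Module):
    """Hypothetical base class: a default fit() loop with one overridable hook."""

    def compute_loss(self, inputs, target, criterion):
        # Default: a single forward pass plus loss. Override for unusual models.
        return criterion(self(inputs), target)

    def fit(self, traindata, valdata, epochs, batch_size, criterion, optimizer):
        trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=True)
        valloader = DataLoader(valdata, batch_size=batch_size)
        history = []  # mean validation loss after each epoch
        for epoch in range(epochs):
            self.train()
            for inputs, target in trainloader:
                optimizer.zero_grad()
                loss = self.compute_loss(inputs, target, criterion)
                loss.backward()
                optimizer.step()
            # Validate after each epoch
            self.eval()
            with torch.no_grad():
                val_loss = sum(self.compute_loss(x, y, criterion).item()
                               for x, y in valloader) / len(valloader)
            history.append(val_loss)
        return history

class Regressor(TemplateModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)

# Usage on a toy regression problem (target = sum of the inputs).
torch.manual_seed(0)
x = torch.randn(64, 3)
y = x.sum(dim=1, keepdim=True)
data = TensorDataset(x, y)
model = Regressor()
history = model.fit(data, data, epochs=5, batch_size=16,
                    criterion=nn.MSELoss(),
                    optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
```

A model that needs something the hook cannot express can still ignore `fit()` and use its own loop.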


Nice, @saattrupdan,

This is really more a question for the PyTorch staff and admins.
I just drafted a few thoughts to shape your original idea.

Oh, so this might be the wrong forum? I could start a thread on the PyTorch GitHub if you think that would be more appropriate!

The problem is that there are many ways to train a model, and a single fixed loop wouldn’t fit every case. In any case, you can code it yourself. The most general pipeline I’ve found is to pass the ground truth (gt), a list of inputs, and miscellaneous extra arguments. That way you can fit any forward-backward model.
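A minimal sketch of that “gt, list of inputs, miscellanea” pipeline, as one reading of the description above. `forward_backward` and `TwoInputModel` are hypothetical names. The helper unpacks the input list into the model’s `forward()` and passes any extra keyword arguments through, so models with different signatures can share the same step.

```python
import torch
from torch import nn

def forward_backward(model, criterion, optimizer, gt, inputs, **misc):
    """One generic training step (hypothetical helper).

    gt:     ground truth passed to the criterion
    inputs: list of positional inputs, unpacked into model(*inputs)
    misc:   extra keyword arguments forwarded to model.forward()
    """
    optimizer.zero_grad()
    outputs = model(*inputs, **misc)
    loss = criterion(outputs, gt)
    loss.backward()
    optimizer.step()
    return loss.item()

class TwoInputModel(nn.Module):
    """Example model whose forward() takes two tensors and an optional flag."""

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(3, 1)
        self.b = nn.Linear(2, 1)

    def forward(self, x1, x2, use_b=True):
        out = self.a(x1)
        if use_b:
            out = out + self.b(x2)
        return out

# Usage: two positional inputs plus one "miscellaneous" keyword argument.
torch.manual_seed(0)
model = TwoInputModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss = forward_backward(model, nn.MSELoss(), optimizer,
                        gt=torch.zeros(4, 1),
                        inputs=[torch.randn(4, 3), torch.randn(4, 2)],
                        use_b=False)
```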

My guess is that creating a framework that fits every case would require so many options that, in the end, it would be easier to code it yourself.