I’m new to PyTorch, and I was wondering if there’s any implementation of (or a reason why I wouldn’t want) an extension of nn.Module which contains a method that trains the model? I was thinking something like:
class TrainModule(nn.Module):
    """nn.Module subclass bundling a generic supervised training loop.

    Subclasses implement ``forward``; ``train_model`` then runs a standard
    minibatch training loop with progress reporting, per-epoch learning-rate
    scheduling, and a validation pass after every epoch.
    """

    def __init__(self):
        super().__init__()

    def train_model(self, traindata, valdata, epochs, batch_size,
                    criterion, optimizer, scheduler):
        """Train this module on *traindata*, validating on *valdata*.

        Args:
            traindata: Dataset of ``(inputs, target)`` training pairs.
            valdata: Dataset of ``(inputs, target)`` validation pairs.
            epochs: number of full passes over the training data.
            batch_size: minibatch size for both loaders.
            criterion: loss callable, e.g. ``nn.MSELoss()``.
            optimizer: torch optimizer built over ``self.parameters()``.
            scheduler: LR scheduler whose ``step`` accepts the epoch's mean
                training loss (e.g. ``ReduceLROnPlateau``).
        """
        import torch
        # tqdm is a purely cosmetic dependency; degrade gracefully to a
        # no-op progress bar when it is not installed.
        try:
            from tqdm import tqdm
        except ImportError:
            class tqdm:  # minimal stand-in with the interface we use
                def __init__(self, total=None):
                    pass
                def __enter__(self):
                    return self
                def __exit__(self, *exc):
                    return False
                def set_description(self, desc):
                    pass
                def update(self, n):
                    pass

        trainloader = DataLoader(traindata, shuffle=True,
                                 batch_size=batch_size)
        # Shuffling validation data has no effect on the loss; keep order.
        valloader = DataLoader(valdata, shuffle=False,
                               batch_size=batch_size)
        # len(loader) already rounds up to the number of batches, so the
        # manual // and % ceiling computation is unnecessary.
        num_batches = len(trainloader)

        for epoch in range(epochs):
            self.train()  # enable dropout / batch-norm training behavior
            acc_loss = 0.0
            with tqdm(total=len(traindata)) as epoch_pbar:
                epoch_pbar.set_description(f'Epoch {epoch}')
                for idx, (inputs, target) in enumerate(trainloader):
                    # Reset the gradients
                    optimizer.zero_grad()
                    # Call self(...) rather than self.forward(...) so that
                    # registered forward hooks run.
                    outputs = self(inputs)
                    loss = criterion(outputs, target)
                    loss.backward()
                    optimizer.step()
                    # .item() detaches the scalar from the graph; summing
                    # the tensor itself would keep every batch's autograd
                    # graph alive and leak memory across the epoch.
                    acc_loss += loss.item()
                    # Update progress bar description
                    avg_loss = acc_loss / (idx + 1)
                    desc = f'Epoch {epoch} - loss {avg_loss:.4f}'
                    epoch_pbar.set_description(desc)
                    epoch_pbar.update(inputs.shape[0])
            # Learning-rate decay driven by the epoch's mean training loss.
            scheduler.step(acc_loss / num_batches)

            # Validation pass (the original sketch left this as "(...)"):
            # no gradients, eval-mode layers, mean per-batch loss.
            self.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, target in valloader:
                    val_loss += criterion(self(inputs), target).item()
            val_loss /= max(len(valloader), 1)
It could also allow string inputs such as ‘adam’ for the optimizer, or ‘mse’ for the criterion, and so on.
Not having to write this loop out every time I build a model would be great, and it’d be consistent with current usage, since you could simply not use the train_model method, or override it.