My main training and validation loop looks like this:
import torch
def train(net, dataloader, loss_func, optimizer, device):
    """Run one training epoch over ``dataloader`` and report loss/accuracy.

    ``net`` is never rebound inside this function. Its parameters are updated
    in place by ``optimizer.step()``, so the caller's reference already
    reflects the training done here. The same object is still returned so
    existing callers that unpack three values keep working.

    Args:
        net: model to train; it is switched to training mode.
        dataloader: iterable of ``(images, labels)`` batches.
        loss_func: callable ``loss_func(outputs, labels)`` returning a scalar
            tensor. Assumed to average over the batch (``reduction='mean'``);
            each batch loss is re-weighted by its size to get a per-sample mean.
        optimizer: optimizer over ``net``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(net, train_loss, train_acc)``. ``train_loss`` is the mean per-sample
        loss. ``train_acc`` is the fraction of label entries predicted
        correctly, always in ``[0, 1]``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    net.train()
    total_loss = 0.0
    num_samples = 0
    num_correct = 0
    num_compared = 0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = net(images)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # .item() yields a plain Python float, so the running total does not
        # keep this batch's autograd graph alive.
        total_loss += loss.item() * batch_size
        num_samples += batch_size

        # Threshold at 0 -- assumes ``outputs`` are logits for a binary
        # (BCE-with-logits style) setup, as the original comment indicates.
        correct = (outputs > 0) == labels
        num_correct += correct.sum().item()
        # Divide by the number of compared entries, not the number of batches.
        num_compared += correct.numel()

    if num_samples == 0:
        raise ValueError("dataloader yielded no batches")

    train_loss = total_loss / num_samples
    train_acc = num_correct / num_compared
    return net, train_loss, train_acc
def validate(net, dataloader, loss_func, device):
    """Evaluate ``net`` on ``dataloader`` without updating its parameters.

    Args:
        net: model to evaluate; it is switched to evaluation mode.
        dataloader: iterable of ``(images, labels)`` batches.
        loss_func: callable ``loss_func(outputs, labels)`` returning a scalar
            tensor. Assumed to average over the batch (``reduction='mean'``);
            each batch loss is re-weighted by its size to get a per-sample mean.
        device: device each batch is moved to.

    Returns:
        ``(val_loss, val_acc)``. ``val_loss`` is the mean per-sample loss.
        ``val_acc`` is the fraction of label entries predicted correctly,
        always in ``[0, 1]``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    net.eval()
    total_loss = 0.0
    num_samples = 0
    num_correct = 0
    num_compared = 0

    # Evaluation never calls backward(), so autograd is disabled for the whole
    # loop (forward pass and loss alike).
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = net(images)
            loss = loss_func(outputs, labels)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

            # Threshold at 0 -- assumes ``outputs`` are logits for a binary
            # (BCE-with-logits style) setup, as the original comment indicates.
            correct = (outputs > 0) == labels
            num_correct += correct.sum().item()
            # Divide by the number of compared entries, not the number of batches.
            num_compared += correct.numel()

    if num_samples == 0:
        raise ValueError("dataloader yielded no batches")

    val_loss = total_loss / num_samples
    val_acc = num_correct / num_compared
    return val_loss, val_acc
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Datasets and their dataloaders (placeholders -- fill in for your task).
train_dataset = ...
train_dataloader = ...
val_dataset = ...
val_dataloader = ...

# Build the network, then move it onto the selected device.
net = ...
net = net.to(device)

# Loss function, e.g. binary cross entropy.
loss_func = ...

# Optimizer, e.g. SGD.
optimizer = ...

# How many epochs of training + validation to run.
num_epochs = ...

for epoch in range(num_epochs):
    net, train_loss, train_acc = train(
        net=net,
        dataloader=train_dataloader,
        loss_func=loss_func,
        optimizer=optimizer,
        device=device,
    )
    val_loss, val_acc = validate(
        net=net,
        dataloader=val_dataloader,
        loss_func=loss_func,
        device=device,
    )
My main question is about the `train` function. Do I need to return the network `net` as well as `train_loss` and `train_acc`? In other words, is `net` mutable, so that any changes made to it inside `train` are reflected outside of it? If so, I should be able to change the `for` loop at the end to:
for epoch in range(num_epochs):
    # ``train`` updates ``net`` in place (optimizer.step() modifies its
    # parameters) and returns that very same object, so the returned network
    # can simply be discarded. It still has to be unpacked, because ``train``
    # returns three values; drop the ``_`` only if ``train`` is changed to
    # return just ``(train_loss, train_acc)``.
    _, train_loss, train_acc = train(net, train_dataloader, loss_func,
                                     optimizer, device)
    val_loss, val_acc = validate(net, val_dataloader, loss_func, device)
Also, please let me know if there are other ways to improve this code, since this is the template that I use for all my training and validation loops.