import torch
from torch.nn.utils.rnn import pad_packed_sequence


def train_loop(
    epochs,
    lr,
    loss_fn,
    train_dataset,
    val_dataset,
    pre_trained_model=False,
    **model_args
):
    if not pre_trained_model:
        # Build a fresh model from the supplied hyperparameters.
        model = MyLSTM(
            model_args["input_size"],
            model_args["embedding_size"],
            model_args["hidden_size"],
            model_args["num_layers"],
            model_args["dense_output_layers"],
            model_args["input_sequence_length"],
            model_args["output_sequence_length"]
        )
    else:
        model = model_args["model"]
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
#if torch.cuda.device_count() > 1:
# model = torch.nn.DataParallel(model)
model = model.float()
model = model.to(device)
if next(model.parameters()).is_cuda:
print("RUNNING cuda")
else:
print("RUNNING cpu")
    optimizer = torch.optim.Adam(
        model.parameters(), lr, betas=(0.9, 0.999),
        eps=1e-08, weight_decay=0
    )
    for epoch in range(epochs):
        model.train()
        # Anomaly detection is useful while debugging but slows training down.
        torch.autograd.set_detect_anomaly(True)
        running_loss = 0.0
        train_loss = 0.0
        epoch_steps = 0
        # The hidden state is created once per epoch and then carried
        # across all batches inside the loop below.
        state_h, state_c = model.zero_state(train_dataset.batch_size)
        state_h, state_c = state_h.to(device), state_c.to(device)
        for i, data in enumerate(train_dataset, 0):
            inputs, labels = data
            inputs, labels = inputs.float().to(device), labels.float().to(device)
            optimizer.zero_grad()
            outputs, (state_h, state_c) = model(inputs, (state_h, state_c))
            padded_label, _ = pad_packed_sequence(labels, batch_first=True)
            loss = loss_fn(reduction='mean')(outputs, padded_label.transpose(1, 2))
            loss.backward(retain_graph=True)
            optimizer.step()
            running_loss += loss.item()
            train_loss += loss.item()
            epoch_steps += 1
            if i % 10 == 9:
                print("{epoch: %d, batch: %5d} running_loss: %.4f"
                      % (epoch + 1, i + 1, running_loss / 10))
                running_loss = 0.0
        try:
            train_loss /= epoch_steps
        except ZeroDivisionError:
            continue
        print("{epoch: %d} train_loss: %.4f" % (epoch + 1, train_loss))
        model.eval()
        val_loss = 0.0
        val_steps = 0
        # Fresh states for validation so the forward call matches the
        # stateful signature used during training.
        state_h, state_c = model.zero_state(val_dataset.batch_size)
        state_h, state_c = state_h.to(device), state_c.to(device)
        for i, data in enumerate(val_dataset, 0):
            with torch.no_grad():
                X, y = data
                y, _ = pad_packed_sequence(y, batch_first=True)
                X, y = X.float().to(device), y.float().to(device)
                pred, (state_h, state_c) = model(X, (state_h, state_c))
                loss = loss_fn(reduction='sum')(pred, y).item()
                val_loss += loss
                val_steps += 1
        try:
            # Normalize the summed loss by the number of elements seen.
            val_loss /= (val_steps * y.size(0) * y.size(1))
        except ZeroDivisionError:
            continue
        print("{epoch: %d} val_loss: %.4f" % (epoch + 1, val_loss))

    return model
You are most likely trying to backpropagate through your hidden states across batches and might want to .detach() them in the training loop.
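Something along these lines should work as a minimal sketch, assuming the same model, zero_state, and loss_fn interface as in your code:

    for i, data in enumerate(train_dataset, 0):
        inputs, labels = data
        inputs, labels = inputs.float().to(device), labels.float().to(device)
        # Detaching cuts the autograd history: the states keep their values,
        # but gradients no longer flow back into earlier batches.
        state_h = state_h.detach()
        state_c = state_c.detach()
        optimizer.zero_grad()
        outputs, (state_h, state_c) = model(inputs, (state_h, state_c))
        padded_label, _ = pad_packed_sequence(labels, batch_first=True)
        loss = loss_fn(reduction='mean')(outputs, padded_label.transpose(1, 2))
        loss.backward()  # no retain_graph needed once the states are detached
        optimizer.step()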
Also, using retain_graph=True
is usually wrong: users often add it to silence other (valid) errors, and it can then lead to exactly these issues.
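For context: if you remove retain_graph=True without detaching, you would typically see a "Trying to backward through the graph a second time" RuntimeError, since the graph built in the previous batch was already freed. And if you don't actually need the hidden state to persist across batches, you could instead re-initialize it inside the batch loop (again just a sketch, assuming your zero_state returns CPU tensors):

    for i, data in enumerate(train_dataset, 0):
        # Fresh states every batch: nothing to detach, nothing to retain.
        state_h, state_c = model.zero_state(train_dataset.batch_size)
        state_h, state_c = state_h.to(device), state_c.to(device)
        ...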