Performing k-fold cross validation on an LSTM model in PyTorch

I am trying to perform k-fold cross validation on my LSTM model. I have seen a few implementations using Skorch, but none of them show how to use an optimizer, a scheduler, gradient clipping, etc. So I was wondering how to perform cross validation in PyTorch on a model that looks like this (after the training code I have also sketched the kind of fold loop I have in mind):

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict

class LSTMTagger(nn.Module):

    def __init__(self):
        super(LSTMTagger, self).__init__()
        # wv holds the pretrained word vectors (loaded elsewhere)
        self.embedding = nn.Embedding(wv.vectors.shape[0], 512)
        # note: dropout on a single-layer LSTM has no effect
        # (PyTorch only applies it between stacked layers)
        self.lstm1 = nn.LSTM(input_size=512, hidden_size=64, dropout=0.1,
                             batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(p=0.25)
        self.linear1 = nn.Linear(in_features=128, out_features=64)
        self.linear2 = nn.Linear(in_features=64, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        X_embed = self.embedding(X)
        out1, (h, c) = self.lstm1(X_embed)
        # the LSTM is bidirectional, so concatenate the final forward and
        # backward hidden states to get the 128-dim input linear1 expects
        h = torch.cat((h[-2], h[-1]), dim=1)
        x = self.dropout(h)
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return self.sigmoid(x)

model = LSTMTagger()
torch.multiprocessing.set_sharing_strategy('file_system')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs")
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3, 4])
torch.cuda.empty_cache()

model = model.to(device)

def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    model.train()
    losses = []
    correct_predictions = 0
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        targets = d["targets"].to(device)
        outputs = model(input_ids).view(-1)       # (batch,) probabilities
        preds = (outputs > 0.5).float()           # threshold, not torch.max, for binary output
        loss = loss_fn(outputs, targets.float())  # use the loss_fn argument, not a global
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # note: this steps the scheduler per batch, not per epoch
        optimizer.zero_grad()
    return correct_predictions.double() / n_examples, np.mean(losses)
EPOCHS = 6
optimizer = optim.Adam(model.parameters(), lr=2e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
loss_fn = nn.BCELoss().to(device)  # single binary loss; the model already ends in a sigmoid
history = defaultdict(list)
best_accuracy = 0

print('starting training')

for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)
    train_acc, train_loss = train_epoch(
        model,
        data_train,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(df_train)
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')
    val_acc, val_loss = eval_model(
        model,
        data_val,
        loss_fn,
        device,
        len(df_val)
    )
    print(f'Val   loss {val_loss} accuracy {val_acc}')
    checkpoint = {
        'epoch': epoch + 1,
        'valid_loss_min': val_loss,
        'valid_accuracy': val_acc,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state_g.bin')
        save_ckp(checkpoint, True, 'state.pt', 'better.pt')
        best_accuracy = val_acc
    else:
        torch.save(model.state_dict(), 'model_state_g.bin')
        save_ckp(checkpoint, False, 'state.pt', 'better.pt')
test_acc, _ = eval_model(
    model,
    data_test,
    loss_fn,
    device,
    len(df_test)
)
test_accuracy = test_acc.item()
print("Test accuracy of model is : {}\n".format(test_accuracy))

Here, I am using a data loader to generate the data tensors, and I define helper functions beforehand to save checkpoints (save_ckp above), but I am not sure how to do any of this in Skorch either. Thanks for your help in advance.
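For completeness, this is my best guess at a Skorch equivalent, moving the scheduler, gradient clipping and checkpointing into callbacks and handing the net to sklearn's cross_val_score. It assumes the final sigmoid is dropped from forward() (NeuralNetBinaryClassifier's default BCEWithLogitsLoss expects raw logits) and that X/y are the full numpy arrays of token ids and labels; I do not know whether this actually reproduces the training loop above, which is really my question:

    from sklearn.model_selection import cross_val_score
    from skorch import NeuralNetBinaryClassifier
    from skorch.callbacks import Checkpoint, GradientNormClipping, LRScheduler

    net = NeuralNetBinaryClassifier(
        LSTMTagger,
        optimizer=torch.optim.Adam,
        lr=2e-5,
        max_epochs=EPOCHS,
        batch_size=32,                   # placeholder batch size
        device=device,
        iterator_train__shuffle=True,
        callbacks=[
            # steps per epoch by default, unlike the per-batch scheduler.step() above
            LRScheduler(policy=torch.optim.lr_scheduler.StepLR, step_size=1),
            GradientNormClipping(gradient_clip_value=1.0),  # replaces clip_grad_norm_
            Checkpoint(monitor='valid_acc_best', f_params='best_model_state_g.bin'),
        ],
    )

    # X: int64 array of padded token ids, y: float32 array of 0/1 labels
    scores = cross_val_score(net, X, y, cv=5, scoring='accuracy')
    print(f'CV accuracy: {scores.mean():.4f} +/- {scores.std():.4f}')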