.to(device) takes far longer for test data than train data

I am currently trying to improve the efficiency of my code, but I have hit an issue I can't figure out: the time taken to send tensors to my GPU is far greater for the test data than for the training data, despite there being far more training data.

I am using a custom dataset and random_split to create the train_dataset and test_dataset:

    train_dataset, test_dataset = random_split(dataset, [0.8, 0.2])
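
For illustration, a minimal stand-in for the custom dataset passed to random_split might look like this (the class name, item count, and tensor shapes are made up; the real dataset just returns three tensors per item, which is why the loops below unpack data1, data2, data3):

    import torch
    from torch.utils.data import Dataset, random_split

    class MyDataset(Dataset):
        # Hypothetical stand-in for the real custom dataset: each item is three tensors.
        def __init__(self, n_items=4000):
            self.x1 = torch.randn(n_items, 64)
            self.x2 = torch.randn(n_items, 64)
            self.x3 = torch.randn(n_items, 1)

        def __len__(self):
            return len(self.x1)

        def __getitem__(self, idx):
            return self.x1[idx], self.x2[idx], self.x3[idx]

    dataset = MyDataset()
    train_dataset, test_dataset = random_split(dataset, [0.8, 0.2])  # 80/20 split, as above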

The datasets are then passed into DataLoaders:

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=True)
    
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 num_workers=num_workers)
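
The batch_size, num_workers, and device variables are defined earlier in the script; the exact values shouldn't matter, but for concreteness something like:

    import torch

    batch_size = 64   # illustrative value
    num_workers = 4   # illustrative value
    device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed; I am running on a GPU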

During training and testing, the data is then sent to the device:

    for data1, data2, data3 in dataloader:
        data1, data2, data3 = data1.to(device), data2.to(device), data3.to(device)

This loop is exactly the same in the training and testing loops; the only difference is that during testing the loop runs inside a with torch.inference_mode(): block.
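
So the testing loop looks roughly like this (the forward pass and metrics are omitted, and this is a simplification of my actual code):

    with torch.inference_mode():
        for data1, data2, data3 in test_dataloader:
            # during testing, the transfers happen inside inference_mode
            data1, data2, data3 = data1.to(device), data2.to(device), data3.to(device)
            # ... forward pass, loss, metrics ...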

I am using cProfile to identify where the inefficiencies in my program are coming from, and I noticed that .to() was taking up roughly a quarter of my total runtime. To investigate further, I wrapped each group of .to(device) calls in its own function, i.e.:

    def send_to_device_train(data1, data2, data3, device):
        return data1.to(device), data2.to(device), data3.to(device)

and a similar function for test data.
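That test-side function differs only in its name, i.e. something like:

    def send_to_device_test(data1, data2, data3, device):
        return data1.to(device), data2.to(device), data3.to(device)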

When I run cProfile I get this unusual result:

       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
         6602    5.352    0.001    5.352    0.001 {method 'to' of 'torch._C._TensorBase' objects}
          152    0.001    0.000    2.373    0.016 /home/user/engine.py:13(send_to_device_test)
          581    0.002    0.000    0.179    0.000 /home/user/engine.py:8(send_to_device_train)

Despite calling .to() roughly a quarter as many times, send_to_device_test takes about 13 times longer than send_to_device_train (at least in cumulative time).

I tried moving the .to() calls out of torch.inference_mode() (entering the context manager only after the tensors have been sent to the device) and also set shuffle=True on the test dataloader, in the hope that one of these might be the culprit, but the problem persists.
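
For reference, the inference_mode variant I tried looked roughly like this (simplified, with the forward pass omitted):

    for data1, data2, data3 in test_dataloader:
        # transfers now happen before entering inference_mode
        data1, data2, data3 = data1.to(device), data2.to(device), data3.to(device)
        with torch.inference_mode():
            ...  # forward pass and metrics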

I wonder if anyone would be able to help me solve this issue?