.to(device) takes far longer for test data than train data

I am currently trying to improve the efficiency of my code, but I have run into an issue I can't figure out: the time taken to send tensors to my GPU is far greater for the test data than for the training data, despite there being far more training data.

I am using a custom dataset and random_split to create the train and test datasets:

    train_dataset, test_dataset = random_split(dataset, [0.8, 0.2])

before passing the datasets into DataLoaders:

    train_dataloader = DataLoader(train_dataset,
    test_dataloader = DataLoader(test_dataset,
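For completeness, the setup looks roughly like this end to end (the TensorDataset and the batch size here are placeholders standing in for my custom dataset and my real arguments):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Placeholder dataset: three tensors per sample, like my real data
dataset = TensorDataset(torch.randn(100, 8), torch.randn(100, 8), torch.randn(100, 8))

# 80/20 split, as in my actual code
train_dataset, test_dataset = random_split(dataset, [0.8, 0.2])

# batch_size=16 is a placeholder, not my real setting
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)
```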

During training and testing, the data is then sent to the device:

    for data1, data2, data3 in dataloader:
        data1, data2, data3 = data1.to(device), data2.to(device), data3.to(device)

This is exactly the same for both the training and testing loops; the only difference is that during testing the for loop runs inside with torch.inference_mode().

I am using cProfile to identify where the inefficiencies in my program are coming from, and I noticed that .to() was taking up ~1/4 of my total runtime. To investigate further, I wrapped each group of .to(device) calls in a function, e.g.:

    def send_to_device_train(data1, data2, data3, device):
        return data1.to(device), data2.to(device), data3.to(device)

and a similar function for the test data.

When I run cProfile I get this unusual result:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     6602    5.352    0.001    5.352    0.001 {method 'to' of 'torch._C._TensorBase' objects}
      152    0.001    0.000    2.373    0.016 /home/user/engine.py:13(send_to_device_test)
      581    0.002    0.000    0.179    0.000 /home/user/engine.py:8(send_to_device_train)

Despite being called roughly a quarter as many times (152 vs. 581), send_to_device_test takes about 13 times longer in cumulative time.

I tried moving the .to() calls out of torch.inference_mode() (entering the context manager only after the tensors have been sent to the device), and also set shuffle=True on the test DataLoader, in the hope that one of these might be the culprit, but the problem persists.
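One thing I have not ruled out is that CUDA calls run asynchronously, so cProfile may be attributing blocked time to whichever call happens to synchronize. A rough sketch of how I could measure a transfer with explicit synchronization (torch.cuda.synchronize() is the real API; the timing harness itself is my own):

```python
import time
import torch

def timed_to_device(*tensors, device):
    # Flush any queued GPU work first, so the measurement below reflects
    # only the transfer itself; synchronize is a no-op concern on CPU
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    moved = tuple(t.to(device) for t in tensors)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return moved, time.perf_counter() - start
```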

Would anyone be able to help me work out what is going on here?