Inference getting slower with TensorDataset training?

Hello, I am using Google Colab to train a generative model and found that moving the data to the GPU was the bottleneck of training. Since Colab provides enough RAM, I did the following:

  1. put all the training images into one NumPy array
  2. create a tensor from it
  3. move the tensor to the GPU
  4. create a TensorDataset from the tensor
  5. feed the TensorDataset to a DataLoader (see the sketch right after this list)
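
A minimal sketch of that setup (the file names, the .float() conversion, and the batch size are placeholders, not my exact code):

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda")

# all training images stacked into one array, shape (N, 3, 256, 256)
images_np = np.load("train_images.npy")    # placeholder path
targets_np = np.load("train_targets.npy")  # placeholder path

# create tensors and move them to the GPU once, up front
images_gpu = torch.from_numpy(images_np).float().to(device)
targets_gpu = torch.from_numpy(targets_np).float().to(device)

# wrap the GPU tensors in a TensorDataset and hand it to a DataLoader
train_set = TensorDataset(images_gpu, targets_gpu)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

Since everything already lives on the GPU, the DataLoader only has to index into the tensors, so the per-batch host-to-device copy disappears.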

As you can guess, the cost of moving data to the GPU decreased significantly, but the inference step started to cost almost the same amount of time instead. What is even weirder is that at the beginning of each epoch the inference runs at almost its original speed, and within 3 batches it slows down completely…

Okay, so I managed to narrow it down to one operation in the loss construction. Its own time cost does not change after I switched to the TensorDataset, but if I simplify it, the inference time strangely drops back to the previous level, which is even more confusing.

To explain this better, here is a piece of sample code:

# shape of input and target: (batch_size, 3, 256, 256)
for input, target in train_loader:

    Generative_Model.zero_grad()
    predict = Generative_Model(input)

    # sum the three channels, then restore the channel dimension
    # shape of m_predict and m_target: (batch_size, 1, 256, 256)
    m_predict = predict[:, 0, :, :] + predict[:, 1, :, :] + predict[:, 2, :, :]
    m_predict = torch.reshape(m_predict, (predict.shape[0], 1, predict.shape[2], predict.shape[3]))
    m_target = target[:, 0, :, :] + target[:, 1, :, :] + target[:, 2, :, :]
    m_target = torch.reshape(m_target, (target.shape[0], 1, target.shape[2], target.shape[3]))

    # kernels is a fixed, pre-built tensor
    # shape of kernels: (30, 1, 200, 200)
    m_predict = torch.nn.functional.conv2d(m_predict, kernels, padding="same")
    m_target = torch.nn.functional.conv2d(m_target, kernels, padding="same")

    loss = torch.sum(m_predict) + torch.sum(m_target)

    loss.backward()
    Generative_Model_optimizer.step()

So with the dataset behind train_loader changed to a TensorDataset, the cost of predict = Generative_Model(input) increases from 0.01s to 0.88s,

while if I change
m_predict = torch.nn.functional.conv2d(m_predict, kernels, padding="same")
m_target = torch.nn.functional.conv2d(m_target, kernels, padding="same")
to
m_predict = torch.nn.functional.conv2d(m_predict, kernels)
m_target = torch.nn.functional.conv2d(m_target, kernels)
then, although the computation no longer gives the result I intend, the cost of predict = Generative_Model(input) falls back to 0.01s, and in both cases the two conv2d operations themselves cost less than 0.01s…
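
A minimal sketch of how a single step can be timed in isolation on the GPU (illustrative only, not my exact measurement code; torch.cuda.synchronize() is used because CUDA kernels launch asynchronously, so timings taken without it can fold earlier queued work into a later line):

import time
import torch

def timed(label, fn):
    # wait for all previously queued GPU work, run fn, then wait again,
    # so the interval only covers the work that fn itself launched
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn()
    torch.cuda.synchronize()
    print(f"{label}: {time.perf_counter() - start:.4f}s")
    return out

# usage inside the training loop, for example:
# predict = timed("forward", lambda: Generative_Model(input))
# m_predict = timed("conv2d", lambda: torch.nn.functional.conv2d(m_predict, kernels, padding="same"))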