GPU inference slows down when predictions are requested consecutively

The first PyTorch GPU prediction is really fast, but predictions seem to slow down when they are requested consecutively.

Please find the notebook to run the whole thing here:
https://colab.research.google.com/drive/1gqSzQqFm8HL0OwmJzSRlcRFQ3FOpnvFh?usp=sharing

%%time
# Predict hidden state features for each layer
for _ in range(10):
    with torch.no_grad():
        # See the model's docstring for the details of the inputs
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)
        # Transformers models always output tuples.
        # See the model's docstring for the details of all the outputs.
        # In our case, the first element is the hidden state of the last layer of the BERT model.
        encoded_layers = outputs[0]
    # We have encoded our input sequence in a FloatTensor of shape (batch size, sequence length, model hidden dimension)
    # assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)
CPU times: user 23.7 s, sys: 21.4 s, total: 45.1 s
Wall time: 45.2 s

%%time
# Predict hidden state features for each layer
for _ in range(1):
    with torch.no_grad():
        # See the model's docstring for the details of the inputs
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)
        # Transformers models always output tuples.
        # See the model's docstring for the details of all the outputs.
        # In our case, the first element is the hidden state of the last layer of the BERT model.
        encoded_layers = outputs[0]
    # We have encoded our input sequence in a FloatTensor of shape (batch size, sequence length, model hidden dimension)
    # assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)
CPU times: user 20.4 ms, sys: 329 µs, total: 20.8 ms
Wall time: 22.9 ms

I have seen the same behaviour in TensorFlow as well. Is this related to CUDA, perhaps?

CUDA operations are asynchronous, so you would need to synchronize the code via torch.cuda.synchronize() before starting and stopping the timer. Otherwise the timing gets accumulated into the next blocking operation and yields a wrong result: a kernel launch returns to the host immediately, and the host only waits for the GPU at the next synchronizing call.
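
A minimal sketch of how the loop could be timed correctly, assuming the model, tokens_tensor, and segments_tensors objects from the notebook above are already on the GPU:

import time
import torch

# Wait for any pending GPU work to finish before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()

for _ in range(10):
    with torch.no_grad():
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)
        encoded_layers = outputs[0]

# Wait for all queued kernels to finish before stopping the timer
torch.cuda.synchronize()
print(f"Wall time: {time.perf_counter() - start:.3f} s")

Measured this way, each iteration should take roughly the same time; the apparent "first call is fast" effect disappears because the timer no longer stops before the queued kernels have actually run.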