Why does the first forward pass take a very long time for an nn.DataParallel model?

PyTorch version: 1.3.1
System configuration: AWS p3.8xlarge instance

import torch.nn as nn
import torch
import time

# Benchmark input: a 2048 x 2048 tensor of standard-normal samples, built by
# torch.randn and then moved to the "cuda" device. The 2048 columns match the
# in_features of Net.fc1 below.
x = torch.randn(2048, 2048).to("cuda")

class Net(nn.Module):
    """Minimal single-layer network used for the timing experiment.

    The model holds one fully connected layer, ``fc1``, that maps 2048 input
    features to 1024 output features.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)

    def forward(self, x):
        """Apply ``fc1`` to ``x`` and return the result."""
        return self.fc1(x)
# Benchmark: time 10 forward passes through a DataParallel-wrapped Net.
# (No backward pass or optimizer step is run; only the forward call is timed.)
device_ids = [0, 1, 2, 3]
net = Net()
net = nn.DataParallel(net, device_ids=device_ids)
net.cuda()
net.train()  # the mode never changes between iterations, so set it once
x = x.to("cuda")  # make sure the input is on the GPU before any timing starts

# NOTE(review): the first iteration is expected to be far slower than the
# rest -- presumably one-time lazy CUDA initialisation on the additional
# devices and multi-GPU communication setup; confirm with a profiler.
for _ in range(10):
    # CUDA work is queued asynchronously, so drain every device first to make
    # sure the timer brackets only this forward pass.
    for device_id in device_ids:
        torch.cuda.synchronize(device_id)
    start = time.perf_counter()  # monotonic, high-resolution interval clock
    net(x)
    # Wait for the queued kernels/copies to actually finish before stopping
    # the clock; otherwise only the host-side launch overhead is measured.
    for device_id in device_ids:
        torch.cuda.synchronize(device_id)
    print(time.perf_counter() - start)

Output (seconds per forward pass):

12.42794680595398
0.0023696422576904297
0.0019919872283935547
0.0018835067749023438
0.001916646957397461
0.001863718032836914
0.001901865005493164
0.001911163330078125
0.0018074512481689453
0.0018219947814941406