I tested .cpu() speed with the following code.
import torch
import torch.nn as nn
from time import time
import numpy as np
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10000, 1000)
self.fc2 = nn.Linear(1000, 1000)
self.fc3 = nn.Linear(1000, 1000)
self.fc4 = nn.Linear(1000, 1000)
def forward(self,x):
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
return x
net = Net()
net.cuda()
net.eval()
try1_time = []
try2_time = []
try3_time = []
for _ in range(10):
x = torch.FloatTensor(10,10000).random_().cuda()
with torch.no_grad():
y1 = net(x)
#try1
st = time()
y1_numpy = y1.cpu().numpy()
try1_time.append(time()-st)
y1 = y1*10
#try2
st = time()
y1_numpy = y1.cpu().numpy()
try2_time.append(time()-st)
with torch.no_grad():
y2 = net(x) # y2, not y1
#try3
st = time()
y1_numpy = y1.cpu().numpy() # y1
try3_time.append(time()-st)
print ("try1:%fs" %(np.mean(try1_time)))
print ("try2:%fs" %(np.mean(try2_time)))
print ("try3:%fs" %(np.mean(try3_time)))
result is
try1:0.001630s
try2:0.000073s
try3:0.001641s
I expected try2 time = try3 time, because y1 is not affected by y2 = net(x).
How can I speed up try3?