I’m using torchtext and seeing some strange behavior I can’t explain. I’m training a CNN on the SST-2 dataset. One training epoch takes ~5 seconds. However, just iterating over the batches takes ~100 ms, and the forward passes alone total ~300 ms. I can’t work out where the remaining time goes.
import time

total = 0
base = time.time()
for b in loader.train_loader:
    now = time.time()
    res = model(b.text)
    tm = time.time() - now
    total += tm
print("Prediction total time", total)
print("Elapsed time", time.time() - base)
The output is
Prediction total time 0.36881256103515625
Elapsed time 5.046782493591309
For just iterating through the batches:
import time

base = time.time()
for b in loader.train_loader:
    len(b.text)
print("Elapsed time", time.time() - base)
The output is
Elapsed time 0.0856020450592041
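To sanity-check the measurement itself, here is a stdlib-only sketch of the same timing pattern. `fake_loader` and `fake_model` are hypothetical stand-ins (with made-up sleep durations) for `loader.train_loader` and `model(b.text)`: it shows that a timer wrapped around only the model call never counts the time spent producing each batch, so the two printed numbers can legitimately differ by a lot.

```python
import time

def fake_loader(n_batches):
    """Stand-in for loader.train_loader: each batch costs time to produce."""
    for _ in range(n_batches):
        time.sleep(0.02)  # hypothetical per-batch loading/numericalization cost
        yield "batch"

def fake_model(batch):
    """Stand-in for model(b.text): a comparatively fast forward pass."""
    time.sleep(0.002)  # hypothetical forward-pass cost
    return batch

total = 0.0
base = time.perf_counter()
for b in fake_loader(10):
    now = time.perf_counter()
    fake_model(b)
    total += time.perf_counter() - now  # only counts the "model" call

elapsed = time.perf_counter() - base   # counts loading + model
print("Prediction total time", total)
print("Elapsed time", elapsed)
```

Separately, if the model runs on a GPU, CUDA kernel launches are asynchronous, so without a `torch.cuda.synchronize()` before each timestamp the forward-pass cost can get attributed to a later part of the loop instead of to `model(b.text)`.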
Can anyone explain what I’m missing here? Thanks