kengz (Wah Loon Keng)
- Use indices to pick samples instead of sampling from the list directly.
- Save on tensor creation by converting a whole batch at once; if your data is numpy, use torch.from_numpy:
```python
import numpy as np
import torch

def sample(self, batch_size):
    # pick random indices; self.mem is assumed to be a numpy array of transitions
    batch_idxs = np.random.randint(len(self), size=batch_size)
    batches = list(zip(*self.mem[batch_idxs]))
    # np.stack turns each tuple of rows into one array, so each field
    # is converted to a tensor in a single call
    return [torch.from_numpy(np.stack(batch)) for batch in batches]
```
- If you want to avoid the tensor-creation cost entirely, store your data directly as tensors; then you can sample without calling torch.from_numpy(batch) at all.
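A minimal sketch of that tensor-storage variant, assuming a preallocated ring buffer with one tensor per field (the class name and fields here are hypothetical, not from the original post):

```python
import torch

class TensorReplay:
    """Replay memory that stores fields directly as tensors,
    so sampling needs no numpy-to-tensor conversion."""

    def __init__(self, capacity, state_dim):
        # hypothetical fields: preallocated state and reward storage
        self.states = torch.zeros(capacity, state_dim)
        self.rewards = torch.zeros(capacity)
        self.capacity = capacity
        self.head = 0
        self.size = 0

    def add(self, state, reward):
        # overwrite the oldest slot once the buffer is full
        self.states[self.head] = torch.as_tensor(state, dtype=torch.float32)
        self.rewards[self.head] = float(reward)
        self.head = (self.head + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        # index the stored tensors directly; no torch.from_numpy round-trip
        idxs = torch.randint(self.size, (batch_size,))
        return self.states[idxs], self.rewards[idxs]
```

The trade-off is paying the conversion cost once at insertion time rather than on every sample.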