I’m trying to convert the `y` labels in the MNIST data into one-hot format.
Since I’m not very familiar with PyTorch yet, in each iteration I convert `y` to a NumPy array, turn it into one-hot format, and then convert it back to a PyTorch tensor, like this:
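```python
import numpy as np
import torch

for batch_idx, (x, y) in enumerate(train_loader):
    # Convert the labels to NumPy, build a one-hot matrix by comparing
    # each label against [0, ..., num_labels - 1], then go back to PyTorch.
    y_onehot = y.numpy()
    y_onehot = (np.arange(num_labels) == y_onehot[:, None]).astype(np.float32)
    y_onehot = torch.from_numpy(y_onehot)
```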
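To make the broadcasting step in that loop concrete, here is the same NumPy round trip on a hypothetical toy batch of three labels (`num_labels = 10` for the ten MNIST digits; the label values are made up for illustration):

```python
import numpy as np
import torch

num_labels = 10  # MNIST has 10 digit classes
y = torch.tensor([3, 0, 9])  # hypothetical toy batch of labels

# Compare each label (shape [3, 1]) with [0, ..., 9] (shape [10]);
# broadcasting yields a [3, 10] boolean matrix with exactly one True per row.
y_np = y.numpy()
y_onehot = (np.arange(num_labels) == y_np[:, None]).astype(np.float32)
y_onehot = torch.from_numpy(y_onehot)

print(y_onehot.shape)          # torch.Size([3, 10])
print(y_onehot.argmax(dim=1))  # tensor([3, 0, 9])
```

Each row has a single 1.0 in the column given by its label, so `argmax` along dimension 1 recovers the original labels.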
However, I notice that training gets slower with each iteration, and I suspect this code is the cause, since it may allocate new memory on every iteration.
So my question is: is there a more PyTorch-native way to do this that avoids the round trip through NumPy?
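A minimal sketch of two ways to build the one-hot matrix without leaving PyTorch, assuming PyTorch 1.1 or later (where `torch.nn.functional.one_hot` is available) and integer (`int64`) labels as returned by the MNIST loader. The toy labels are hypothetical; this only removes the NumPy conversion and does not by itself establish what causes the slowdown.

```python
import torch
import torch.nn.functional as F

num_labels = 10
y = torch.tensor([3, 0, 9])  # stand-in for a batch of MNIST labels (int64)

# Option 1: built-in helper (PyTorch >= 1.1).
# Returns an int64 tensor of shape [batch, num_labels]; cast to float for losses.
onehot_a = F.one_hot(y, num_classes=num_labels).float()

# Option 2: scatter 1.0 into a preallocated zero tensor.
# scatter_(dim, index, value) writes 1.0 at column y[i] of row i.
onehot_b = torch.zeros(y.size(0), num_labels, device=y.device)
onehot_b.scatter_(1, y.unsqueeze(1), 1.0)

print(torch.equal(onehot_a, onehot_b))  # True
```

Inside the training loop, either option replaces the three NumPy lines, e.g. `y_onehot = F.one_hot(y, num_classes=num_labels).float()`. Both keep the data in PyTorch, so they also work when `y` is already on the GPU.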