Training with batch size 1 is very slow

I am training a simple 2-layer MLP in an online learning setting, where the batch size and the number of epochs are both 1.
The input has shape (5000000, 28) and the network is (28-100-2).
However, training is very slow (13 seconds for 5000 instances) compared to my previous implementation in Keras (5 seconds for 10000 instances).

Is there a way to improve the speed of online training?
I have put the code for the custom dataset and the training loop below. I did try training on the GPU, but it was significantly slower than on the CPU because the batch size is 1 in this case.

from torch.utils.data import Dataset, DataLoader

class Stream(Dataset):
    """Map-style dataset that serves one (x, y) pair per index."""
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.length = x.size(0)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# batch_size=1 and shuffle=False: samples arrive one at a time, in order
stream = DataLoader(Stream(x_train, y_train), batch_size=1, shuffle=False, num_workers=8)
loss, acc = 0.0, 0.0  # running totals
for j, (x, y) in enumerate(stream):
    loss_, acc_ = model.observe(x, y)
    loss += loss_
    acc += acc_

What is model.observe returning? Are loss_ and acc_ detached?
Could you post the model code? You may be storing the whole computation graph in loss.

Here is the rest of the code, thank you.

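import torch
from torch import nn
from torch.nn.functional import relu

class MLP(nn.Module):
    def __init__(self, sz):
        super(MLP, self).__init__()
        self.layer1 = nn.Linear(sz, 100)   # input -> hidden
        self.layer2 = nn.Linear(100, 2)    # hidden -> 2 class logits

    def forward(self, x):
        x = relu(self.layer1(x))
        x = self.layer2(x)
        return x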

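class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # The attribute name was lost from the post; "net" is an assumed placeholder.
        self.net = MLP(28)
        self.opt = torch.optim.SGD(self.parameters(), lr=0.1)
        self.bce = torch.nn.CrossEntropyLoss()

    def observe(self, x, y):
        # Several right-hand sides were lost from the post. They are reconstructed
        # here on the assumption that the logits come from self.net(x). The
        # backward pass and optimizer step are not shown in the post.
        y_ = self.net(x)
        loss = self.bce(y_, y)

        _, idx = torch.max(y_, 1, keepdim=False)
        acc = (idx == y).sum().item()  # assumed: number of correct predictions
        return loss, acc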

Could you change the last line to:

return loss.item(), acc

and run it again? Right now you are returning the loss tensor together with its computation graph, and accumulating it with loss += loss_ keeps every graph alive. This might slow down your application and increase memory usage.
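
To illustrate the difference, here is a minimal sketch. It is not your code: the tiny nn.Linear model, the random data, and the names total_attached and total_detached are stand-ins made up for the example. Summing the raw loss tensors gives a total that is still attached to autograd, while summing loss.item() gives a plain Python float.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(28, 2)          # stand-in model, assumed for illustration
criterion = nn.CrossEntropyLoss()
x = torch.randn(4, 28)            # four fake samples
y = torch.tensor([0, 1, 0, 1])    # fake labels

total_attached = 0.0   # accumulates tensors, so it stays attached to autograd
total_detached = 0.0   # accumulates Python floats only

for i in range(x.size(0)):
    # Slice with i:i+1 to keep a batch dimension of 1
    loss = criterion(model(x[i:i + 1]), y[i:i + 1])
    total_attached = total_attached + loss
    total_detached += loss.item()

# The first total is a tensor with a grad_fn; the second is a float
print(type(total_attached).__name__, total_attached.grad_fn is not None)
print(type(total_detached).__name__)
```

Both totals hold the same value, but only the first one keeps a reference to the graph of every step.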

I modified accordingly but there was not much improvement in the running time.
Is there anything else I can try?
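
One thing that may be worth measuring (an assumption, not something confirmed in this thread): with batch_size=1 the DataLoader collates every single sample, and with num_workers=8 each sample also has to cross a process boundary. If x_train and y_train are already in-memory tensors, they can be sliced directly instead. A rough sketch with made-up data follows. The x_train and y_train here are random stand-ins, and num_workers is left at 0 so the snippet runs on any platform.

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
n = 2000                                  # small stand-in for the 5,000,000 rows
x_train = torch.randn(n, 28)              # fake features
y_train = torch.randint(0, 2, (n,))       # fake binary labels

# Variant 1: DataLoader with batch_size=1 (single process to keep the sketch portable)
loader = DataLoader(TensorDataset(x_train, y_train), batch_size=1, shuffle=False)
start = time.perf_counter()
seen_loader = 0
for x, y in loader:
    seen_loader += x.size(0)
loader_seconds = time.perf_counter() - start

# Variant 2: slice the tensors directly; i:i+1 keeps the batch dimension
start = time.perf_counter()
seen_direct = 0
for i in range(n):
    x, y = x_train[i:i + 1], y_train[i:i + 1]
    seen_direct += x.size(0)
direct_seconds = time.perf_counter() - start

print(f"DataLoader: {loader_seconds:.3f}s, direct slicing: {direct_seconds:.3f}s")
```

The timings depend on the machine, so no numbers are listed here. If the gap is large on your setup, the data pipeline rather than the model is probably the bottleneck.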

I have a similar problem as well. Any suggestions?