Increasing batch size didn't help reduce training time

Hi, I am new to PyTorch and this forum; I only started using PyTorch three months ago. I have recently been using a few sequence models to map an earth-system time series, and one of the architectures I built behaves strangely. My other models train faster per epoch when I use a larger batch size, but this one does not. The model is a CNN connected to a GRU. I tested batch sizes of 2*20, 5*20, 10*20 and 35*20. The memory cost keeps increasing, but the training time does not go down: a batch of 2*20 finishes one epoch within 65 seconds, 10*20 takes 73 seconds, and 35*20 takes 77 seconds. The input data format is [time, batch, 9 features, 10 layers]. Here are the model and the training loop:

import torch
import torch.nn as nn
import torch.nn.functional as F


class N2OSPConv1d_GRU(nn.Module):

    def __init__(self, n_f, n_l):
        super(N2OSPConv1d_GRU, self).__init__()
        # Three parallel 1D convolutions over the layer dimension with different
        # kernel sizes; their outputs are concatenated along the channel dimension.
        self.conv1_1 = nn.Conv1d(n_f, 16, kernel_size=1)
        self.conv1_2 = nn.Conv1d(n_f, 16, kernel_size=3, padding=1)
        self.conv1_3 = nn.Conv1d(n_f, 16, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(48, 96, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(96, 128, kernel_size=3, padding=1)
        # 1x1 convolutions used as per-layer fully connected reductions.
        self.fc1 = nn.Conv1d(128, 64, kernel_size=1)
        self.fc2 = nn.Conv1d(64, 16, kernel_size=1)
        # Two-layer GRU over the time dimension; its input is the flattened conv output.
        self.gru = nn.GRU(16 * n_l, 64, 2, dropout=0.2)
        self.out = nn.Linear(64, 1)
        self.drop = nn.Dropout(0.2)
        self.nl = n_l
        self.nf = n_f
        self.nhid = 64

    def forward(self, x, hid):
        # x: [time, batch, n_f, n_l]
        tseq = x.size(0)
        bsz0 = x.size(1)
        # Merge time and batch so the convolutions run over all time steps at once.
        x = x.contiguous().view(tseq * bsz0, self.nf, self.nl)
        x1 = F.relu(self.conv1_1(x))
        x2 = F.relu(self.conv1_2(x))
        x3 = F.relu(self.conv1_3(x))
        x = torch.cat((x1, x2, x3), 1)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Split time and batch back apart and flatten channels x layers for the GRU.
        x, hid = self.gru(x.view(tseq, bsz0, 16 * self.nl), hid)
        x = self.out(self.drop(x))
        return x, hid

    def init_hidden(self, bsz):
        # bsz is the batch size; returns a zero hidden state on the right device/dtype.
        weight = next(self.parameters())
        return weight.new_zeros(2, bsz, self.nhid)
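
For reference, a minimal shape check with made-up sizes roughly matching the post (9 features, 10 layers, a 730-step window, batch 40): the convolutions run over the merged time x batch dimension, and the GRU then runs over time.

import torch

model = N2OSPConv1d_GRU(n_f=9, n_l=10)
x = torch.randn(730, 40, 9, 10)        # [time, batch, features, layers]
hid = model.init_hidden(bsz=40)        # [2 GRU layers, batch, 64 hidden units]
y, hid = model(x, hid)
print(y.shape)                          # torch.Size([730, 40, 1])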

And here is the training:

import time

import torch
import torch.optim as optim

# model1, X_train, Y_train, X_val, Y_val, my_loss, pearsonr2, device, path_save,
# train_n, val_n, fln, Tx and tyear are defined earlier in the script.

starttime = time.time()
lr_adam = 0.0001
optimizer = optim.Adam(model1.parameters(), lr=lr_adam)  # add weight_decay if needed, typically 1e-4 to 9e-4
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
bsz1 = train_n * fln          # total number of training columns
bsz_val1 = val_n * fln        # total number of validation columns
bsz = 2 * fln                 # training mini-batch size
bsz_val = 1 * fln             # validation batch size
totsq = Tx * tyear            # total sequence length
slw = 365 * 2                 # sliding-window length (time steps per forward pass)
slw05 = 365 * 2               # window stride
maxit = int((totsq - slw) / slw05 + 1)
train_losses = []
val_losses = []
maxepoch = 2
loss_val_best = float('inf')  # must be initialized before the first comparison below
R2_best = -float('inf')
model1.train()
for epoch in range(maxepoch):
    train_loss = 0.0
    val_loss = 0.0
    Y_pred_all = torch.zeros(Y_train.size(), device=device)
    for bb in range(int(bsz1 / bsz)):
        with torch.no_grad():
            hidden = model1.init_hidden(bsz)
        for it in range(maxit):
            Y_pred, hidden = model1(X_train[slw05*it:slw05*it+slw, bb*bsz:(bb+1)*bsz, :, :], hidden)
            loss = my_loss(Y_pred, Y_train[slw05*it:slw05*it+slw, bb*bsz:(bb+1)*bsz, :])
            hidden.detach_()       # detach so gradients do not flow across windows (GRU returns a tensor; an LSTM would return a tuple)
            optimizer.zero_grad()  # clear gradients left over from the previous step
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                train_loss = train_loss + loss.item()
                Y_pred_all[slw05*it:slw05*it+slw, bb*bsz:(bb+1)*bsz, 0] = Y_pred[:, :, 0]
    scheduler.step()
    # validation
    model1.eval()
    with torch.no_grad():
        train_loss = train_loss / (bsz1 / bsz) / maxit
        train_losses.append(train_loss)
        train_R2 = pearsonr2(Y_pred_all.contiguous().view(-1), Y_train.contiguous().view(-1))
        Y_val_pred = torch.zeros(Y_val.size(), device=device)
        for bb in range(int(bsz_val1 / bsz_val)):
            hidden = model1.init_hidden(bsz_val)
            for it in range(maxit):
                Y_val_pred_t, hidden = model1(X_val[slw05*it:slw05*it+slw, bb*bsz_val:(bb+1)*bsz_val, :, :], hidden)
                Y_val_pred[slw05*it:slw05*it+slw, bb*bsz_val:(bb+1)*bsz_val, 0] = Y_val_pred_t[:, :, 0]
        loss = my_loss(Y_val_pred, Y_val)
        val_loss = loss.item()
        val_losses.append(val_loss)
        val_R2 = pearsonr2(Y_val_pred.contiguous().view(-1), Y_val.contiguous().view(-1))
        if val_loss < loss_val_best and val_R2 > R2_best:
            loss_val_best = val_loss
            R2_best = val_R2
            # truncate any existing checkpoint file before saving the new best model
            f0 = open(path_save, 'w')
            f0.close()
            torch.save({'epoch': epoch,
                        'model_state_dict': model1.state_dict(),
                        'R2': train_R2,
                        'loss': train_loss,
                        'loss_val': val_loss,
                        'R2_val': val_R2,
                        }, path_save)
        print("finished training epoch", epoch + 1)
        mtime = time.time()
        print("train_loss:", train_loss, "train_R2:", train_R2, "val_loss:", val_loss, "val_R2:", val_R2,
              "loss val best:", loss_val_best, "R2 val best:", R2_best, f"Spending time: {mtime - starttime}s")
        if train_R2 > 0.99:
            break
    model1.train()
endtime = time.time()
print(f"total training time: {endtime - starttime}s")

The training loop is similar to my other models, and those behave normally: when I increase the batch size, the training time goes down. Do you have any ideas why this is happening and how to reduce the training time?

You could profile your script and check which part of the overall pipeline is the current bottleneck.
E.g. if the data loading and processing is too slow, a model speedup won’t be visible, as your training routine would have to wait for the next batch to be available.
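
For example, a minimal profiling sketch with torch.profiler (available in recent PyTorch releases), reusing the names from your snippet (model1, X_train, Y_train, my_loss, slw, bsz) and assuming training runs on a CUDA device:

import torch
from torch.profiler import profile, ProfilerActivity

hidden = model1.init_hidden(bsz)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    Y_pred, hidden = model1(X_train[:slw, :bsz], hidden)   # one training window
    loss = my_loss(Y_pred, Y_train[:slw, :bsz])
    loss.backward()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))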

Thanks @ptrblck. Data loading, data processing, and each line in the model have been timed. The bottleneck is the backward() call: its time grows much more with batch size for this model structure than for the other methods. Other models such as a pure LSTM, GRU, or CNN are fine; the time spent in backward() at each step also increases there, but the overall epoch time still decreases.
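
For reference, timing an individual call on the GPU only gives meaningful per-line numbers if you synchronize before reading the clock, since CUDA kernels run asynchronously. A rough sketch of how backward() can be timed, assuming a CUDA device:

import time
import torch

torch.cuda.synchronize()              # make sure previously queued kernels have finished
t0 = time.time()
loss.backward()
torch.cuda.synchronize()              # wait for the backward kernels before reading the clock
print(f"backward: {time.time() - t0:.4f}s")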

I have some thoughts about the model structure. I used two view() calls in the model to reshape the data, and all of my architectures that use view() show this problem, so I suspect it is related, but I don’t know why. I have also seen some of your previous answers about view() in the forward function, but I don’t know whether that is the problem here.
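
A small sketch of what those two reshapes do, with made-up sizes roughly matching the post: on a contiguous tensor, view() only reinterprets the shape and does not copy data, so by itself it should be cheap.

import torch

tseq, bsz0, n_f, n_l = 730, 40, 9, 10
x = torch.randn(tseq, bsz0, n_f, n_l)

# Merge time and batch so the convolutions see one big batch of [n_f, n_l] columns.
flat = x.contiguous().view(tseq * bsz0, n_f, n_l)
print(flat.data_ptr() == x.data_ptr())   # True: same storage, no copy

# After the convolutions (16 output channels), split time/batch back apart and
# flatten channels x layers into the GRU input features.
conv_out = torch.randn(tseq * bsz0, 16, n_l)
seq = conv_out.view(tseq, bsz0, 16 * n_l)
print(seq.shape)                          # torch.Size([730, 40, 160])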

It seems that formatting issues conceal the shapes, but in general RNN kernels can have O(timesteps) complexity, with batch size having a smaller effect.

If it is 2*20 to 35*20, those are pretty small tensors, so the size increase is almost “free”: you still have some parallelism in reserve.
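
A rough microbenchmark sketch of that point (hypothetical sizes, numbers will vary by GPU): the 2-layer GRU launches small kernels for each of the ~730 time steps, so the sequence length dominates, and growing the batch mostly fills otherwise idle parallelism.

import time
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gru = nn.GRU(160, 64, 2).to(device)      # same sizes as the GRU in the model

for bsz in (40, 200, 700):               # roughly 2*20, 10*20, 35*20
    x = torch.randn(730, bsz, 160, device=device, requires_grad=True)
    h0 = torch.zeros(2, bsz, 64, device=device)
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    out, _ = gru(x, h0)
    out.sum().backward()
    if device.type == "cuda":
        torch.cuda.synchronize()
    print(f"batch {bsz}: {time.time() - t0:.3f}s")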