Hi, I am new to PyTorch and this forum; I just started using PyTorch three months ago. I have recently been using a few sequence models to fit an earth-system time series, and one architecture behaves very strangely. My other models train faster per epoch when I use a larger batch size, but this one does not. It is a model that connects a CNN to a GRU. I tested batch sizes of 220, 520, 1020, and 3520: the memory cost keeps increasing, but the training time barely changes. For example, a batch size of 220 finishes one epoch within 65 seconds, 1020 takes 73 seconds, and 3520 takes 77 seconds. The input data format is [time, batch, 9 features, 10 layers]. Here are the model and the training loop:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class N2OSPConv1d_GRU(nn.Module):
    def __init__(self, n_f, n_l):
        super(N2OSPConv1d_GRU, self).__init__()
        # three parallel conv branches with different receptive fields over the layer dimension
        self.conv1_1 = nn.Conv1d(n_f, 16, kernel_size=1)
        self.conv1_2 = nn.Conv1d(n_f, 16, kernel_size=3, padding=1)
        self.conv1_3 = nn.Conv1d(n_f, 16, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(48, 96, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(96, 128, kernel_size=3, padding=1)
        # 1x1 convolutions acting as per-layer fully connected heads
        self.fc1 = nn.Conv1d(128, 64, kernel_size=1)
        self.fc2 = nn.Conv1d(64, 16, kernel_size=1)
        self.gru = nn.GRU(16 * n_l, 64, 2, dropout=0.2)
        self.out = nn.Linear(64, 1)
        self.drop = nn.Dropout(0.2)
        self.nl = n_l
        self.nf = n_f
        self.nhid = 64

    def forward(self, x, hid):
        tseq = x.size(0)
        bsz0 = x.size(1)
        # fold time and batch together so the convolutions see [tseq*bsz0, n_f, n_l]
        x = x.contiguous().view(tseq * bsz0, self.nf, self.nl)
        x1 = F.relu(self.conv1_1(x))
        x2 = F.relu(self.conv1_2(x))
        x3 = F.relu(self.conv1_3(x))
        x = torch.cat((x1, x2, x3), 1)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # unfold back to [tseq, bsz0, 16*n_l] for the GRU (sequence-first layout)
        x, hid = self.gru(x.view(tseq, bsz0, 16 * self.nl), hid)
        x = self.out(self.drop(x))
        return x, hid

    # bsz should be the batch size
    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return weight.new_zeros(2, bsz, self.nhid)
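To make the shapes concrete, here is a minimal smoke test I ran, with hypothetical numbers matching the description above (n_f=9 features and n_l=10 layers as in my data, a two-year window of 730 timesteps, and a tiny batch just for illustration):

model = N2OSPConv1d_GRU(n_f=9, n_l=10)
x = torch.randn(730, 4, 9, 10)   # [time, batch, feature, layer]
hid = model.init_hidden(4)
y, hid = model(x, hid)
print(y.shape, hid.shape)        # torch.Size([730, 4, 1]) torch.Size([2, 4, 64])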
And here is the training:
starttime = time.time()
lr_adam = 0.0001
optimizer = optim.Adam(model1.parameters(), lr=lr_adam)  # could add weight decay, normally 1e-4 to 9e-4
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
bsz1 = train_n * fln        # total training batch dimension
bsz_val1 = val_n * fln      # total validation batch dimension
bsz = 2 * fln               # training mini-batch size
bsz_val = 1 * fln           # validation mini-batch size
totsq = Tx * tyear          # total sequence length in timesteps
slw = 365 * 2               # sliding-window length (two years)
slw05 = 365 * 2             # sliding-window stride
maxit = int((totsq - slw) / slw05 + 1)   # number of windows per sequence
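# worked example with hypothetical numbers: for 10 years of daily data,
# totsq = 3650 and slw = slw05 = 730, so maxit = int((3650 - 730)/730 + 1) = 5
# non-overlapping two-year windows per batch chunk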
train_losses = []
val_losses = []
maxepoch = 2
model1.train()
for epoch in range(maxepoch):
    train_loss = 0.0
    val_loss = 0.0
    Y_pred_all = torch.zeros(Y_train.size(), device=device)
    model1.zero_grad()
    for bb in range(int(bsz1 / bsz)):
        with torch.no_grad():
            hidden = model1.init_hidden(bsz)
        for it in range(maxit):
            optimizer.zero_grad()  # clear gradients each window; otherwise they accumulate across steps
            Y_pred, hidden = model1(X_train[slw05*it:slw05*it+slw, bb*bsz:(bb+1)*bsz, :, :], hidden)
            loss = my_loss(Y_pred, Y_train[slw05*it:slw05*it+slw, bb*bsz:(bb+1)*bsz, :])
            hidden.detach_()  # detach hidden state (a single tensor for a GRU) so the graph does not grow across windows
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                train_loss = train_loss + loss.item()
                Y_pred_all[slw05*it:slw05*it+slw, bb*bsz:(bb+1)*bsz, 0] = Y_pred[:, :, 0]
    scheduler.step()
    # validation
    model1.eval()
    with torch.no_grad():
        train_loss = train_loss / (bsz1 / bsz) / maxit
        train_losses.append(train_loss)
        train_R2 = pearsonr2(Y_pred_all.contiguous().view(-1), Y_train.contiguous().view(-1))
        Y_val_pred = torch.zeros(Y_val.size(), device=device)
        for bb in range(int(bsz_val1 / bsz_val)):
            hidden = model1.init_hidden(bsz_val)
            for it in range(maxit):
                Y_val_pred_t, hidden = model1(X_val[slw05*it:slw05*it+slw, bb*bsz_val:(bb+1)*bsz_val, :, :], hidden)
                Y_val_pred[slw05*it:slw05*it+slw, bb*bsz_val:(bb+1)*bsz_val, 0] = Y_val_pred_t[:, :, 0]
        loss = my_loss(Y_val_pred, Y_val)
        val_loss = loss.item()
        val_losses.append(val_loss)
        val_R2 = pearsonr2(Y_val_pred.contiguous().view(-1), Y_val.contiguous().view(-1))
    if val_loss < loss_val_best and val_R2 > R2_best:
        loss_val_best = val_loss
        R2_best = val_R2
        # truncate any existing checkpoint before saving
        f0 = open(path_save, 'w')
        f0.close()
        # os.remove(path_save)
        torch.save({'epoch': epoch,
                    'model_state_dict': model1.state_dict(),
                    'R2': train_R2,
                    'loss': train_loss,
                    'loss_val': val_loss,
                    'R2_val': val_R2,
                    }, path_save)
print("finished training epoch", epoch+1)
mtime=time.time()
print("train_loss: ", train_loss, "train_R2", train_R2,"val_loss:",val_loss,"val_R2", val_R2,\
"loss val best:",loss_val_best,"R2 val best:",R2_best, f"Spending time: {mtime - starttime}s")
if train_R2 > 0.99:
break
model1.train()
endtime=time.time()
print(f"total Training time: {endtime - starttime}s")
The training loop is essentially the same as for my other models, and for those it behaves normally: when I increase the batch size, the training time decreases. Does anyone have an idea why this is happening, and how I can reduce the training time?