Time spent in loss.backward() gradually increases as training goes on

I have been trying PyTorch these days and have built a simple RNN acoustic model for automatic speech recognition. The training phase goes well, and the final loss on the training set looks quite good.

However, when I measure the time spent on each batch, I find that it gradually increases, and the trend is very obvious. Digging deeper, I found that the backward pass is the only part that keeps growing. Data loading, the forward pass and the optimizer step stay flat. A Google search turned up a similar issue on the PyTorch GitHub, but it doesn't help: in my case the backward time still grows even when I use the same training data for every batch. I can't figure out the reason, so I'm asking for help here.

Here is the network definition:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):

    def __init__(self, input_dim, hidden_dim, target_dim, num_layers, batch_size):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.target_dim = target_dim
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, target_dim)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # Initial hidden state of shape (num_layers, batch_size, hidden_dim), on the GPU
        return Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_dim)).cuda()

    def forward(self, feats, lengths):
        # Pack the padded batch so the RNN skips the padding frames
        packed_data = nn.utils.rnn.pack_padded_sequence(feats, lengths, batch_first=True)
        # self.hidden is used as the initial state and overwritten with the final state
        output, self.hidden = self.rnn(packed_data, self.hidden)
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        # Flatten to (batch * time, hidden_dim) for the linear layer
        output = output.contiguous().view(-1, output.size(2))
        output = self.fc1(output)
        return output

To debug, I reuse the same training batch for the whole training phase, and it still shows the same trend. Below is the training code; I use time.time() to measure the duration of each part.

import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# weights_init and get_loader are my own helpers (not shown in this post)

if __name__ == '__main__':
    feat_dim = 13
    num_layers = 1
    hidden_dim = 32
    batch_size = 8
    target_size = 8845
    num_epochs = 10

    model = Net(feat_dim, hidden_dim, target_size, num_layers, batch_size)
    model.cuda()
    model.apply(weights_init)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    train_data = get_loader('data/train', batch_size, num_workers=0)
    # Grab one batch and reuse it at every step (debugging)
    data0, data_len0, tag0, tag_len0 = next(iter(train_data))

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (data, data_len, tag, tag_len) in enumerate(train_data):

            load_start = time.time()
            data, tag = Variable(data0).cuda(), Variable(tag0).cuda()
            t_load = time.time() - load_start

            optimizer.zero_grad()

            fwd_start = time.time()
            output = model(data, data_len0)
            t_fwd = time.time() - fwd_start

            tag = tag.view(-1)
            loss = loss_function(output, tag)

            bwd_start = time.time()
            loss.backward(retain_graph=True)
            t_bwd = time.time() - bwd_start

            upd_start = time.time()
            optimizer.step()
            t_upd = time.time() - upd_start

            print("epoch:%d, batch:%d, t_load:%.4f, t_fwd:%.4f, t_bwd:%.4f, t_upd:%.4f\n"
                  % (epoch, batch_idx, t_load, t_fwd, t_bwd, t_upd))

The printed log looks like this:

epoch:0, batch:0, t_load:0.0006, t_fwd:0.6188, t_bwd:0.0508, t_upd:0.0008

epoch:0, batch:1, t_load:0.0017, t_fwd:0.0308, t_bwd:0.0908, t_upd:0.0004

epoch:0, batch:2, t_load:0.0004, t_fwd:0.0303, t_bwd:0.1255, t_upd:0.0005

epoch:0, batch:3, t_load:0.0004, t_fwd:0.0509, t_bwd:0.1587, t_upd:0.0005

epoch:0, batch:4, t_load:0.0005, t_fwd:0.0307, t_bwd:0.1937, t_upd:0.0006

epoch:0, batch:122, t_load:0.0004, t_fwd:0.0312, t_bwd:4.2359, t_upd:0.0008

epoch:0, batch:123, t_load:0.0006, t_fwd:0.0306, t_bwd:4.2741, t_upd:0.0005

epoch:1, batch:0, t_load:0.0004, t_fwd:0.0307, t_bwd:4.3143, t_upd:0.0005

epoch:1, batch:1, t_load:0.0004, t_fwd:0.0306, t_bwd:4.3380, t_upd:0.0006

epoch:1, batch:2, t_load:0.0004, t_fwd:0.0301, t_bwd:4.3747, t_upd:0.0007

epoch:1, batch:122, t_load:0.0005, t_fwd:0.0291, t_bwd:8.4838, t_upd:0.0006

epoch:1, batch:123, t_load:0.0004, t_fwd:0.0303, t_bwd:8.5050, t_upd:0.0006

epoch:2, batch:0, t_load:0.0005, t_fwd:0.0302, t_bwd:8.5414, t_upd:0.0007

epoch:2, batch:1, t_load:0.0003, t_fwd:0.0299, t_bwd:8.5654, t_upd:0.0007

epoch:2, batch:2, t_load:0.0005, t_fwd:0.0302, t_bwd:8.6911, t_upd:0.0006
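A side note on the measurements themselves: CUDA kernels are launched asynchronously, so a `time.time()` interval around a GPU call can close before the launched work has finished, and the wait may show up in whichever later call needs the result. Below is a minimal sketch of a timing helper that synchronizes on both sides, assuming PyTorch is installed (`timed` is a hypothetical name, not a PyTorch function).

```python
import time

import torch


def timed(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds).

    Hypothetical helper: when CUDA is available, it waits for queued GPU work
    before starting the clock and again before stopping it, so the interval
    covers the kernels launched by fn rather than just their launch.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # don't let earlier queued kernels leak into this interval
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the work fn launched to finish
    return result, time.perf_counter() - start
```

In the training loop, something like `_, t_bwd = timed(loss.backward, retain_graph=True)` would replace the manual `bwd_start` bookkeeping.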

My environment:

Ubuntu 14.04
Python 3.5.4
NumPy 1.14.0
PyTorch 0.3.0
CUDA 8.0
cuDNN 6.0

Can anyone give me a hint? I can post more information if needed. Thanks a lot!

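Because self.hidden is carried from one batch to the next without being detached, each successive batch backpropagates through the timesteps of that batch AND the timesteps of all the previous batches. The retain_graph=True in your loop is what stops PyTorch from raising an error about this, and it also keeps that whole history alive.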
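To see why the cost grows linearly, here is a toy count in plain Python. It is not PyTorch internals; `Node`, `backward_steps` and `simulate` are hypothetical names for illustration. Each batch appends `timesteps` recorded steps to a chain, and a backward pass walks that chain until it reaches a cut. This matches the log in the question, where t_bwd grows by roughly 0.035 s per batch (about 4.27 s at batch 123 of epoch 0, and about twice that, 8.5 s, at batch 123 of epoch 1).

```python
class Node:
    """One recorded RNN timestep; `parent` is the step that produced its input state."""

    def __init__(self, parent):
        self.parent = parent


def backward_steps(node):
    """Count how many recorded timesteps a backward pass starting at `node` visits."""
    steps = 0
    while node is not None:
        steps += 1
        node = node.parent
    return steps


def simulate(num_batches, timesteps, detach):
    """Return the backward cost (in timesteps) of each batch."""
    hidden = None  # graph node that produced the current hidden state
    costs = []
    for _ in range(num_batches):
        if detach:
            hidden = None  # detaching cuts the link to earlier batches
        for _ in range(timesteps):
            hidden = Node(hidden)  # forward pass records one step per timestep
        costs.append(backward_steps(hidden))
    return costs


print(simulate(4, 10, detach=False))  # [10, 20, 30, 40]
print(simulate(4, 10, detach=True))   # [10, 10, 10, 10]
```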

You need to detach or repackage the hidden state before running each batch.

# detach
model.hidden.detach_()

# or repackage
model.hidden = Variable(model.hidden.data, requires_grad=True)
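For reference, here is a minimal sketch of the same fix in current PyTorch, assuming version 0.4 or later (where `Variable` is merged into `Tensor`, so `detach()` is called on the tensor directly). It runs on the CPU with random stand-in data. The sizes and `seq_len` are hypothetical, and the packing step from the original model is left out for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for the sketch (smaller target set than the real model)
feat_dim, hidden_dim, target_dim, num_layers, batch_size = 13, 32, 10, 1, 8
seq_len = 20

rnn = nn.RNN(feat_dim, hidden_dim, num_layers, batch_first=True)
fc = nn.Linear(hidden_dim, target_dim)
optimizer = torch.optim.Adam(list(rnn.parameters()) + list(fc.parameters()), lr=0.01)
loss_function = nn.CrossEntropyLoss()

# Random stand-in data; the real model reads padded batches from a DataLoader
feats = torch.randn(batch_size, seq_len, feat_dim)
targets = torch.randint(0, target_dim, (batch_size * seq_len,))

hidden = torch.zeros(num_layers, batch_size, hidden_dim)
for step in range(5):
    # Cut the graph at the batch boundary: the hidden *values* carry over,
    # but backward() stops here instead of walking into earlier batches.
    hidden = hidden.detach()

    optimizer.zero_grad()
    output, hidden = rnn(feats, hidden)              # output: (batch, seq, hidden)
    logits = fc(output.reshape(-1, hidden_dim))      # (batch * seq, target_dim)
    loss = loss_function(logits, targets)
    loss.backward()  # no retain_graph=True needed once the history is cut
    optimizer.step()
```

Because the graph is cut at every batch boundary, each `backward()` covers only `seq_len` timesteps, so its cost should stay flat across batches. For an `nn.LSTM` the hidden state is a tuple `(h, c)`, so each part needs detaching, e.g. `hidden = tuple(h.detach() for h in hidden)`.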

Thank you very much, @jpeg729. I added 'model.hidden.detach_()' to the training loop and it worked. I had never seen this in any tutorial before, so I will look into it. Thanks again for your generous help.
