EDIT: don't call `detach_()` on parameters without also resetting `param.requires_grad = True` afterwards.
I have a very weird case of ballooning memory with a custom-built RNN unit. During each batch the memory usage balloons if I don't call `.detach_()` on the RNN unit's parameters between epochs. The same also occurs if I put a Linear layer before the RNN unit.
I run my models on the CPU using pytorch version 0.3.1.post2. I “measure” the memory use by eyeballing my memory monitor widget.
During a normal run, my full script uses ~500 MB of memory, but under the above conditions the memory usage rises by a further ~4 GB during the first batch. It drops back at the end of each batch, only to climb again during the next batch.
I tried to code a minimal version, but I can’t get the memory usage to balloon. However the training time doubles or triples under the same conditions that cause the memory to balloon with my full code.
Here it is…
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
class TRNN(nn.Module):
    """Strongly typed RNN from https://arxiv.org/abs/1602.02218 with the bias removed.

    The layer keeps its recurrent state in ``self.hidden`` between calls to
    ``forward``; use ``reset_hidden`` to clear it and ``detach_hidden`` to cut
    the autograd history of the carried state.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in the hidden state / each output step.
        detach: If true, ``reset_hidden`` also calls ``detach_()`` on
            ``weight_ih`` (and then re-enables gradients on it).
    """

    def __init__(self, input_size, hidden_size, detach=True):
        super(TRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.detach = detach
        # One weight matrix produces both the candidate state and the forget
        # gate, hence the factor of two on the output dimension.
        self.weight_ih = nn.Parameter(torch.Tensor(2 * hidden_size, input_size))
        self.reset_parameters()

    def __repr__(self):
        # Pass the fields explicitly instead of splatting ``self.__dict__``,
        # so unrelated module attributes can never collide with ``name``.
        return '{name}({input_size}, {hidden_size}, detach={detach})'.format(
            name=self.__class__.__name__,
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            detach=self.detach,
        )

    def reset_parameters(self):
        """Re-initialise ``weight_ih`` uniformly in ``[-stdv, stdv]`` and clear the state."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        self.weight_ih.data.uniform_(-stdv, stdv)
        self.hidden = None

    def reset_hidden(self):
        """Forget the carried hidden state; optionally detach the weight."""
        self.hidden = None
        if self.detach:
            self.weight_ih.detach_()
            # detach_() turns requires_grad off on the parameter, which would
            # silently stop it from being trained. Turn it back on.
            self.weight_ih.requires_grad = True

    def detach_hidden(self):
        """Cut the autograd history of the carried hidden state, if there is one."""
        # Before the first forward pass (or right after a reset) there is no
        # state to detach.
        if self.hidden is not None:
            self.hidden.detach_()

    def forward(self, input_data):
        """Run the recurrence over every step of ``input_data``.

        Args:
            input_data: Tensor unpacked as ``(timesteps, batch_size, features)``.

        Returns:
            Tensor of shape ``(timesteps, batch_size, hidden_size)`` holding
            the hidden state after each step. ``self.hidden`` is left at the
            final step's state.
        """
        timesteps, batch_size, features = input_data.size()
        if self.hidden is None:
            self.hidden = Variable(torch.zeros(batch_size, self.hidden_size))
        steps = []
        for input_t in input_data.split(1):
            gi = F.linear(input_t.view(batch_size, features), self.weight_ih, None)
            # First half of the projection is the candidate state, second
            # half feeds the forget gate.
            newgate, i_f = gi.chunk(2, 1)
            forgetgate = F.sigmoid(i_f)
            self.hidden = newgate + forgetgate * (self.hidden - newgate)
            steps.append(self.hidden)
        # Stack the per-step states along a new leading (time) dimension
        # rather than writing in place into a preallocated buffer.
        return torch.stack(steps)
def reset_hidden(layer):
    """Call ``layer.reset_hidden()`` when the layer provides it; otherwise do nothing.

    Intended to be passed to ``Module.apply`` so every submodule is visited.
    """
    try:
        hook = layer.reset_hidden
    except AttributeError:
        # Layers without recurrent state simply have nothing to reset.
        return
    hook()
def detach_hidden(layer):
    """Call ``layer.detach_hidden()`` when the layer provides it; otherwise do nothing.

    Intended to be passed to ``Module.apply`` so every submodule is visited.
    """
    try:
        hook = layer.detach_hidden
    except AttributeError:
        # Layers without recurrent state simply have nothing to detach.
        return
    hook()
def train(model, optimizer, batches, epochs=5):
    """Train ``model`` on ``batches`` with an MSE loss and print the elapsed time.

    Args:
        model: Module to train; ``reset_hidden`` / ``detach_hidden`` are
            applied to each of its submodules.
        optimizer: Optimizer already bound to the model's parameters.
        batches: Iterable of ``(inputs, targets)`` tensor pairs, iterated once
            per epoch.
        epochs: Number of passes over ``batches`` (defaults to the previously
            hard-coded 5).
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    for _ in range(epochs):
        # Start every epoch from a fresh recurrent state.
        model.apply(reset_hidden)
        for inputs, targets in batches:
            output = model(Variable(inputs))
            optimizer.zero_grad()
            loss = F.mse_loss(output, Variable(targets))
            loss.backward()
            optimizer.step()
            # The state is carried into the next batch; cut its history so
            # the next backward pass does not reach into this batch's graph.
            model.apply(detach_hidden)
    print("Training done in", time.perf_counter() - start, "seconds")
    print()
# Benchmark configuration.
batch_size = 32
seq_len = 500
features = 100
targets = 10

# Ten fixed (inputs, targets) pairs of random data, shared by every experiment.
# NOTE(review): the tensors are laid out (batch_size, seq_len, ...), whereas
# TRNN.forward unpacks its input as (timesteps, batch_size, features) --
# confirm which layout is intended.
batches = [
    (
        torch.randn(batch_size, seq_len, features),
        torch.randn(batch_size, seq_len, targets),
    )
    for _ in range(10)
]

# Each experiment is a printed label plus a factory for the layer stack.
# Factories are used so that every model is only built right before it is
# trained, in the same order as the experiments are listed.
# NOTE(review): the third label says "do detach" but its TRNN is built with
# detach=False -- confirm which of the two is intended.
# NOTE(review): in the third stack the first Linear emits 100 features and
# the TRNN is sized with `features`; these agree only because features == 100.
experiments = [
    (
        "Don't detach parameters",
        lambda: [TRNN(features, 200, detach=False), nn.Linear(200, targets)],
    ),
    (
        "Do detach parameters",
        lambda: [TRNN(features, 200, detach=True), nn.Linear(200, targets)],
    ),
    (
        "Linear first and do detach",
        lambda: [
            nn.Linear(features, 100),
            TRNN(features, 200, detach=False),
            nn.Linear(200, targets),
        ],
    ),
]

for label, make_layers in experiments:
    print(label)
    model = nn.Sequential(*make_layers())
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    train(model, optimizer, batches)
Output on my lowly old laptop…
Don't detach parameters
Training done in 33.29167413711548 seconds

Do detach parameters
Training done in 13.432651281356812 seconds

Linear first and do detach
Training done in 40.65597105026245 seconds
Any ideas?