Hey guys,
I have a general question about running nn.Module instances in for loops. Should I be concerned about memory leaks, slow gradients, or prohibitive memory usage? Is there a limit to the size of an nn.Module I can iterate over in a for loop? I've heard varying things from other PyTorch users, and I feel like this is the right place to get a clear answer.
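By "running nn.Modules in for loops" I mean a pattern like the toy example below, where a module (either a single reused layer or one layer per timestep from an nn.ModuleList) is called once per iteration and every output stays in the autograd graph; none of the names here come from my real model.

import torch
import torch.nn as nn

shared = nn.Linear(64, 64)                                      # reused on every timestep
per_step = nn.ModuleList(nn.Linear(64, 64) for _ in range(10))  # one layer per timestep

x = torch.randn(8, 64)
outs = []
for t in range(10):
    # both calls add new nodes to the autograd graph each iteration
    x = torch.relu(shared(x) + per_step[t](x))
    outs.append(x)
out = torch.stack(outs, dim=1)                                  # (batch=8, T=10, dim=64)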
Let's say, for example:
- I have an encoded question of size (batch, seq_length, question_dim).
- I then want to apply a time-dependent linear transformation to each time step of the question.
- I want to concatenate the result of the previous step with another vector called “context_previous” of size (batch, 1, question_dim) and linearly transform that to yield “context_present” of size (batch, 1, question_dim).
- “context_present” then becomes “context_previous” in the next time step.
The only way I can see to express this kind of computation is with a for loop. However, I'm worried that repeatedly calling the same nn.Module inside the loop will cause the issues mentioned above.
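To make the recurrence concrete, here is a minimal sketch of the pattern described above; the dimensions and the layer names time_linear and ctx_linear are placeholders, not my actual model.

import torch
import torch.nn as nn

B, T, D = 32, 10, 512                                            # batch, timesteps, question_dim
q_enc = torch.randn(B, T, D)                                     # encoded question
time_linear = nn.ModuleList(nn.Linear(D, D) for _ in range(T))   # time-dependent linear transform
ctx_linear = nn.Linear(2 * D, D)                                 # maps [w_t ; context_previous] -> context_present

c_prev = torch.zeros(B, D)                                       # context_previous at the first timestep
contexts = []
for t in range(T):
    w_t = time_linear[t](q_enc[:, t])                            # transform of timestep t, (B, D)
    c_present = ctx_linear(torch.cat([w_t, c_prev], dim=1))      # (B, D)
    contexts.append(c_present)
    c_prev = c_present                                           # becomes context_previous for the next timestep
contexts = torch.stack(contexts, dim=1)                          # (B, T, D)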
Here is some example code from a project I have been working on:
def forward(self, q, q_lens, v=None, n_objs=None):
    '''Controller: process all time steps at once.'''
    ### get question embedding ###
    bs = q.size(0)
    embeddings = self.embedding(q)
    if self.lstm:
        q, (o, h), q_lens = self.encoder(embeddings, q_lens)
    else:
        q, o, q_lens = self.encoder(embeddings, q_lens)
    q_mask = generate_mask(q, q_lens, reverse=True, device=self.device).squeeze(1)
    q = q.masked_fill(q_mask, 0)
    ### get contexts, module probs, and obj probs for each timestep ###
    probs = []
    c_list = []
    obj_probs = []
    c_prev = torch.zeros(bs, self.h_dim).to(self.device)     # running context, shape (bs, h_dim)
    v_prev = v if self.v_dim is not None else None
    for timestep in range(self.t_controller):
        w1_t = self.w1[timestep](o).to(self.device)          # time-dependent transform for this step
        u = self.w2(torch.cat((w1_t, c_prev), dim=1))
        ### get module scores ###
        if self.n_modules is not None:
            module_weights = self.mlp(u)
            module_scores = self.softmax(module_weights)
            probs.append(module_scores)
        ### question attention for context ###
        elem_prod = u.unsqueeze(1).repeat(1, q.size(1), 1) * q
        q_weights = self.w3(elem_prod.masked_fill(q_mask, 0))
        q_weights[q_weights == 0] = float("-inf")             # send exactly-zero weights to -inf so the softmax ignores them
        q_att = self.softmax_context(q_weights)
        ### context ###
        c_prev = q * q_att
        c_prev = torch.sum(c_prev, dim=1)
        ### get obj probs ###
        if self.v_dim is not None:
            c_logits = self.contx_v_att(c_prev, v, n_objs)
            v_prev = c_logits * v_prev
            obj_probs.append(c_logits)
        ### append context to the list ###
        c_list.append(c_prev)
    contexts = torch.stack(c_list, dim=1)
    module_probs = torch.stack(probs, dim=1) if probs else None
    obj_probs = torch.stack(obj_probs, dim=1) if obj_probs else None
    return contexts, q, o, module_probs, obj_probs
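In case it helps frame what I'm asking: is something like the sketch below a reasonable way to check whether the loop is actually leaking or accumulating memory across training steps? It assumes a CUDA device, and "model" and "loader" are just stand-ins for this module and my data loader.

import torch

def check_peak_memory(model, loader, n_steps=20):
    # model / loader are stand-ins for the controller above and my data loader
    torch.cuda.reset_peak_memory_stats()
    for step, batch in enumerate(loader):
        if step == n_steps:
            break
        contexts, q, o, module_probs, obj_probs = model(*batch)
        loss = contexts.sum()                    # dummy loss, only to run backward
        loss.backward()
        model.zero_grad(set_to_none=True)
        print(f"step {step}: peak {torch.cuda.max_memory_allocated() / 1e6:.1f} MB")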
Any insights would be appreciated! Thank you!!!