Do nn.Modules in a for loop cause issues?

Hey guys,

I have a general question about calling nn.Modules inside for loops. Should I be concerned about memory leaks, slow gradient computation, or prohibitive memory usage? Is there a limit to the size of an nn.Module that I can call repeatedly in a for loop? I have heard varying things from PyTorch users, and I feel like this is a good place to get the question properly answered.

Let's say, for example:

  1. I have an encoded question of size (batch, seq_length, question_dim).
  2. I then want to apply a time-dependent linear transformation to each time step of the question.
  3. I want to concatenate the result from (2) with another vector called “context_previous” of size (batch, 1, question_dim) and linearly transform that to yield “context_present” of size (batch, 1, question_dim).
  4. “context_present” then becomes “context_previous” in the next time step.

The only way I can see to do this kind of computation is with a for loop. However, I am worried that repeatedly calling the same nn.Module inside a loop might cause issues.
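To make that concrete, here is a stripped-down sketch of the pattern I have in mind (the names W_t, W_c and the mean-pooling are just placeholders, not my actual code, which is below):

    import torch
    import torch.nn as nn

    batch, seq_length, question_dim, T = 32, 20, 128, 5

    question = torch.randn(batch, seq_length, question_dim)       # step 1: encoded question
    W_t = nn.ModuleList([nn.Linear(question_dim, question_dim)    # step 2: one Linear per time step
                         for _ in range(T)])
    W_c = nn.Linear(2 * question_dim, question_dim)               # step 3: mixes the transform with the previous context

    context_previous = torch.zeros(batch, 1, question_dim)
    contexts = []
    for t in range(T):
        transformed = W_t[t](question)                            # (batch, seq_length, question_dim)
        pooled = transformed.mean(dim=1, keepdim=True)            # placeholder pooling down to (batch, 1, question_dim)
        context_present = W_c(torch.cat((pooled, context_previous), dim=-1))
        contexts.append(context_present)
        context_previous = context_present                        # step 4: feed the context forward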

Here is some example code from a project I have been working on:

    def forward(self, q, q_lens, v = None, n_objs = None):
        '''controller processes all time steps at once'''
        ### get question embedding ###
        bs = q.size(0)
        embeddings = self.embedding(q)
        if self.lstm:
            q, (o, h), q_lens = self.encoder(embeddings, q_lens)
        else:
            q, o, q_lens = self.encoder(embeddings, q_lens)
        q_mask = generate_mask(q, q_lens,reverse = True, device = self.device).squeeze(1)
        q = q.masked_fill(q_mask, 0)
    
        ### accumulators for contexts, module probs, and obj probs ###
        probs = []
        c_list = []
        obj_probs = []
        c_prev = torch.zeros(bs, self.h_dim).to(self.device)
        v_prev = v if self.v_dim is not None else None
        ### get contexts, module probs, and obj_probs for each timestep ###
        for timestep in range(self.t_controller):
            w1_t = self.w1[timestep](o).to(self.device)
            u = self.w2(torch.cat((w1_t, c_prev), dim = 1)) 
            ### get module scores ###
            if self.n_modules is not None:
                module_weights = self.mlp(u)
                module_scores = self.softmax(module_weights)
                probs.append(module_scores)
            ### question attention for context ###
            elem_prod = u.unsqueeze(1).repeat(1,q.size(1),1)  * q 
            q_weights = self.w3(elem_prod.masked_fill(q_mask, 0)) 
            q_weights[q_weights == 0] = float("-inf")  # zeroed (padded) positions get -inf so the softmax ignores them
            q_att = self.softmax_context(q_weights)
            ### context ###
            c_prev = q * q_att
            c_prev = torch.sum(c_prev, dim = 1)
            ### get obj probs ###          
            if self.v_dim is not None:
               c_logits = self.contx_v_att(c_prev, v, n_objs)
               v_prev = c_logits * v_prev
               obj_probs.append(c_logits) 
            ### append context and module_probs to list ###
            c_list.append(c_prev)
        contexts = torch.stack(c_list, dim = 1)
        module_probs = torch.stack(probs, dim = 1) if probs else None
        obj_probs = torch.stack(obj_probs, dim = 1) if obj_probs else None
        return contexts, q, o, module_probs, obj_probs

Any insights would be appreciated! Thank you!!!

For loops hurt performance (this is not specific to PyTorch; the same is true for NumPy in general). If you can vectorize some of the operations (which may be a little difficult in this case), you can expect a drastic performance gain.
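For example, if the per-timestep linears are a ModuleList applied to the same tensor at every iteration, that part does not depend on the loop state, so it can be pulled out and computed for all timesteps at once; only the truly recurrent part (the dependence on the previous context) has to stay inside the loop. A rough sketch under those assumptions (shapes and names are made up, not taken from your code):

    import torch
    import torch.nn as nn

    T, batch, h_dim = 5, 32, 256
    o = torch.randn(batch, h_dim)                                  # same input reused at every timestep
    w1 = nn.ModuleList([nn.Linear(h_dim, h_dim) for _ in range(T)])

    # loop version: T small matmuls launched one by one from Python
    loop_out = torch.stack([w1[t](o) for t in range(T)], dim=1)    # (batch, T, h_dim)

    # vectorized version: stack the weights and biases once, then do a single einsum
    W = torch.stack([m.weight for m in w1], dim=0)                 # (T, h_dim, h_dim)
    bias = torch.stack([m.bias for m in w1], dim=0)                # (T, h_dim)
    vec_out = torch.einsum('bh,toh->bto', o, W) + bias             # (batch, T, h_dim)

    print(torch.allclose(loop_out, vec_out, atol=1e-6))            # True

The recurrent update of the context still needs the loop, but moving the timestep-independent work (and things like the repeated .to(self.device) calls) out of it usually recovers most of the cost.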

Hmm, interesting. Do you have any idea or experience as to why that is the case? I would have liked to vectorize the entire process and eliminate the loops entirely, but in the case I described above I think that is impossible, so I am not sure how to get around this issue.