Using custom method in distributed model

Not sure if this is the right way to do it. I’m wrapping my model in nn.DataParallel for multi-GPU training. The model contains an LSTM module with a custom method that resets the hidden state. I call it after every forward pass during training, and it is the only custom method I use.

To access this reset method in the parallel model, I do
model.module.lstm.reset_hidden_state()
Whereas if my model is not wrapped in DataParallel, it would just be
model.lstm.reset_hidden_state()

Is this right, or do I have to write a custom DataParallel wrapper that implements scatter, gather, etc.? If so, how would I do it?

This is the lstm module:

import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, latent_dim, num_layers, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_size=latent_dim, num_layers=num_layers, hidden_size=hidden_dim, batch_first=True, dropout=0.0)
        # (h, c) carried over between forward calls; None means zero initial state
        self.hidden_state = None

    def reset_hidden_state(self):
        self.hidden_state = None

    def forward(self, X):
        self.lstm.flatten_parameters()
        X, self.hidden_state = self.lstm(X, self.hidden_state)
        return X

It depends on what you expected reset_hidden_state to achieve. Below is what happens in EVERY forward pass when you use DataParallel.

  1. split input data
  2. replicate model to all devices
  3. feed input data splits to all model replicas
  4. gather outputs from all replicas
  5. done with forward
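The steps above can be roughly sketched with the primitives exposed in torch.nn.parallel. This is a simplified illustration, not the actual DataParallel source: the function name data_parallel_forward is made up for this example, it ignores keyword arguments, and it assumes the module already lives on device_ids[0].

```python
import torch
import torch.nn as nn
from torch.nn.parallel import scatter, replicate, parallel_apply, gather


def data_parallel_forward(module, x, device_ids):
    """Simplified sketch of one DataParallel forward call (hypothetical helper)."""
    # 1. split the input along the batch dimension, one chunk per device
    inputs = scatter((x,), device_ids)
    # 2. replicate the model; new replicas are created on every call
    replicas = replicate(module, device_ids[:len(inputs)])
    # 3. feed each input chunk to its replica
    outputs = parallel_apply(replicas, inputs)
    # 4. gather the replica outputs on the first device
    return gather(outputs, device_ids[0])


if torch.cuda.device_count() >= 2:
    layer = nn.Linear(3, 2).cuda(0)
    y = data_parallel_forward(layer, torch.randn(4, 3, device="cuda:0"), [0, 1])
    print(y.shape)
```

Because the replicas are rebuilt in step 2 each time, they do not outlive the forward call that created them.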

After the forward pass, the autograd graph actually contains multiple model replicas. It looks something like

original model ← scatter ← model replicas ← replica output ← gather ← final output.

So in your above use case, if reset_hidden_state has any side effect that you would like to apply to the backward pass, it will only apply to the original model, not to model replicas. But if you are only trying to clear some states for the next forward pass, it should work.
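Here is a minimal sketch of calling the reset method the same way whether or not the model is wrapped. The Net class, the unwrap helper, and the tensor sizes are assumptions made up for illustration; only the LSTM module follows the question. One caveat worth checking against the DataParallel documentation: with more than one GPU, attributes assigned inside forward (such as self.hidden_state) are set on the replicas, and those updates are lost once the forward call finishes, so the original module's hidden_state may never change in the first place.

```python
import torch
import torch.nn as nn


class LSTM(nn.Module):  # the module from the question
    def __init__(self, latent_dim, num_layers, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(latent_dim, hidden_dim, num_layers, batch_first=True)
        self.hidden_state = None

    def reset_hidden_state(self):
        self.hidden_state = None

    def forward(self, X):
        self.lstm.flatten_parameters()
        X, self.hidden_state = self.lstm(X, self.hidden_state)
        return X


class Net(nn.Module):  # hypothetical parent model that owns the LSTM
    def __init__(self):
        super().__init__()
        self.lstm = LSTM(latent_dim=8, num_layers=1, hidden_dim=16)

    def forward(self, X):
        return self.lstm(X)


def unwrap(model):
    """Return the underlying module, wrapped in DataParallel or not (hypothetical helper)."""
    return model.module if isinstance(model, nn.DataParallel) else model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(Net().to(device))

x = torch.randn(4, 5, 8, device=device)  # (batch, seq_len, latent_dim)
out = model(x)

# Clear the state on the original module before the next forward pass.
unwrap(model).lstm.reset_hidden_state()
```

If the hidden state really has to carry over between forward calls under DataParallel, passing it in and returning it explicitly is likely more robust than storing it on the module.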
