I'm not sure if this is the right way to do it. I'm wrapping my model in nn.DataParallel for multi-GPU training. Part of this model is an LSTM module, which has a custom method that resets its hidden state; this method is called after each forward pass during training. It's the only custom method being used.
To access this reset method in the parallel model, I do
model.module.lstm.reset_hidden_state()
Whereas if my model is not wrapped in DataParallel, it would just be
model.lstm.reset_hidden_state()
Is this right, or do I have to write a custom DataParallel wrapper with its own scatter, gather, etc. methods? If so, how would I do that?
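One thing I considered, instead of a custom wrapper, is a small helper that branches on whether the model is wrapped, so the same call works either way. This is just my own sketch (reset_lstm_state is a name I made up, not anything from PyTorch):

import torch.nn as nn

def reset_lstm_state(model):
    # nn.DataParallel keeps the wrapped model under .module,
    # so unwrap it first if the wrapper is present.
    core = model.module if isinstance(model, nn.DataParallel) else model
    core.lstm.reset_hidden_state()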
This is the LSTM module:
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, latent_dim, num_layers, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_size=latent_dim, num_layers=num_layers, hidden_size=hidden_dim, batch_first=True, dropout=0.0)
        self.hidden_state = None

    def reset_hidden_state(self):
        # Clear the carried-over hidden state so the next forward pass starts fresh.
        self.hidden_state = None

    def forward(self, X):
        self.lstm.flatten_parameters()
        # Feed the previous hidden state back in so state carries across forward calls.
        X, self.hidden_state = self.lstm(X, self.hidden_state)
        return X
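And this is roughly how I use it in the training loop. FullModel, the dimensions, and the random data below are just placeholders to keep the example self-contained, not my actual model:

import torch
import torch.nn as nn

# Hypothetical container model; my real model is bigger, but the LSTM part is the same.
class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = LSTM(latent_dim=32, num_layers=2, hidden_dim=64)
        self.head = nn.Linear(64, 10)

    def forward(self, X):
        X = self.lstm(X)             # (batch, seq, hidden_dim)
        return self.head(X[:, -1])   # predict from the last timestep

model = nn.DataParallel(FullModel()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for step in range(10):
    X = torch.randn(8, 20, 32).cuda()        # fake (batch, seq, latent_dim) input
    y = torch.randint(0, 10, (8,)).cuda()    # fake labels
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()
    # This is the part I'm unsure about: reaching through .module to reset the state.
    model.module.lstm.reset_hidden_state()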