Reset LSTM hidden states efficiently

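Is there an efficient way in PyTorch to reset an LSTM's hidden state partway through an unroll, something like haiku.ResetCore()? This is useful when implementing a recurrent policy in RL, where an episode can end in the middle of a rollout and the state of that sequence has to be cleared before the next episode starts.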
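Here is a minimal sketch under a few assumptions. Inputs are time-major with shape (T, B, input_size). `reset` is a boolean mask of shape (T, B) where True means step t begins a new episode for that sequence. Resetting means zeroing the state, which is the usual LSTM initial state. `ResetLSTM` is a hypothetical helper, not a PyTorch API.

A naive version loops one step at a time with `nn.LSTMCell`. This sketch instead cuts the sequence only at time steps where at least one batch element resets. It runs the fused `nn.LSTM` on each segment and zeroes the state of only the affected sequences at each cut. As in Haiku's ResetCore, the reset is applied before step t's input is consumed.

```python
import torch
import torch.nn as nn


class ResetLSTM(nn.Module):
    """LSTM whose (h, c) state is zeroed wherever reset[t, b] is True.

    Hypothetical helper (not part of PyTorch). The reset is applied
    before consuming input t, mirroring the idea of haiku.ResetCore.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)  # time-major: (T, B, F)

    def forward(self, x, reset, state):
        # x: (T, B, input_size); reset: (T, B) bool; state: (h, c), each (1, B, H)
        h, c = state
        T = x.shape[0]
        # Time steps where at least one sequence resets. The unroll is split
        # only there, so every segment still runs through the fused nn.LSTM.
        cut_points = reset.any(dim=1).nonzero().flatten().tolist()
        bounds = sorted(set([0] + cut_points + [T]))
        outputs = []
        for start, end in zip(bounds[:-1], bounds[1:]):
            # keep[b] == 0.0 where sequence b starts a new episode at `start`
            keep = (~reset[start]).to(x.dtype).view(1, -1, 1)
            h, c = h * keep, c * keep
            out, (h, c) = self.lstm(x[start:end], (h, c))
            outputs.append(out)
        return torch.cat(outputs, dim=0), (h, c)


# Usage: 6 steps, batch of 2; sequence 1 starts a new episode at t=3.
lstm = ResetLSTM(input_size=4, hidden_size=8)
x = torch.randn(6, 2, 4)
reset = torch.zeros(6, 2, dtype=torch.bool)
reset[3, 1] = True
state = (torch.zeros(1, 2, 8), torch.zeros(1, 2, 8))
out, state = lstm(x, reset, state)
print(out.shape)  # torch.Size([6, 2, 8])
```

The cost grows with the number of distinct time steps that contain a reset, not with T. In the worst case, with a reset at every step, it degrades to a per-step loop. If a rollout stores `done[t]` for the transition taken at step t, the reset mask is typically `done` shifted forward by one step.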