Ensembled LSTM fails on inference

Alright, here’s the issue I’m facing. I created a model ensemble using functorch’s combine_state_for_ensemble so that I can run the same model with several different sets of parameter values in parallel, instead of looping over one forward pass per version of the model. The model takes three inputs and produces one output. Most of it is a stack of linear layers, but part of it is an LSTM: one of the three inputs is a sequence, and the other two are plain vectors. This is the error I’m getting:

  File "/home/USERNAME/miniconda3/envs/ENVNAME/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 774, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.
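The message says vmap has no batching rule for the fused `aten::lstm` op that `nn.LSTM.forward` calls through `_VF.lstm`, and it cannot fall back to running that op per model. A possible workaround, sketched here rather than offered as the definitive fix, is to write the LSTM recurrence out with ops vmap can batch (`F.linear`, `sigmoid`, `tanh`, `chunk`). It assumes a single-layer, unidirectional, sequence-first `nn.LSTM` with biases; `manual_lstm` and the sizes below are hypothetical. It checks the vmapped result against each `nn.LSTM` run on its own.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.func import vmap

torch.manual_seed(0)
SEQ, BATCH, IN, HID, N_MODELS = 5, 3, 4, 6, 2  # hypothetical sizes

def manual_lstm(x, w_ih, w_hh, b_ih, b_hh):
    """Single-layer LSTM over x of shape (seq, batch, in) built from ops
    vmap can batch. Gate order i, f, g, o follows nn.LSTM's weight layout."""
    # Under vmap, w_hh.shape excludes the model dim: (4*HID, HID)
    h = x.new_zeros(x.shape[1], w_hh.shape[1])
    c = x.new_zeros(x.shape[1], w_hh.shape[1])
    outputs = []
    for x_t in x.unbind(0):  # step through the sequence dimension
        gates = F.linear(x_t, w_ih, b_ih) + F.linear(h, w_hh, b_hh)
        i, f, g, o = gates.chunk(4, dim=-1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        outputs.append(h)
    return torch.stack(outputs)  # (seq, batch, hidden)

lstms = [nn.LSTM(IN, HID) for _ in range(N_MODELS)]
names = ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"]
# Stack each weight across models along a new leading dim
stacked = [torch.stack([getattr(m, n).detach() for m in lstms]) for n in names]

x = torch.randn(SEQ, BATCH, IN)  # one sequence batch shared by all models
with torch.no_grad():
    # x is broadcast (None); each weight is sliced per model (0)
    out = vmap(manual_lstm, in_dims=(None, 0, 0, 0, 0))(x, *stacked)
    ref = torch.stack([m(x)[0] for m in lstms])  # per-model reference

print(out.shape)  # torch.Size([2, 5, 3, 6])
```

The Python loop over time steps gives up the fused kernel's speed, so whether this beats simply looping over the ensemble members depends on sequence length and ensemble size and is worth measuring.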

Here is how I create and call the ensemble:

import torch as pt
from functorch import combine_state_for_ensemble, vmap

class MultiModel:

	def __init__( self, mFile="", iterSpan=-1 ):

		GPU = pt.device( 'cuda:0' )
		self.IterSpan = iterSpan; self.device = GPU

		ValNets = []

		if iterSpan < 1:
			ValNet_0 = ValueNet( modelIter=iterSpan )
			ValNet_0.to( GPU ); ValNets.append( ValNet_0 )

		for t in range( 1,iterSpan+1 ):
			ValNet_t = ValueNet( modelIter=t, load_from_file=mFile )
			ValNet_t.to( GPU ); ValNets.append( ValNet_t )

		multimodel, multiparameters, multibuffers =\
			combine_state_for_ensemble( ValNets )

		self.MM = multimodel
		self.MP = multiparameters
		self.MB = multibuffers

	def __call__( self, hInputs, oInputs, aInputs ):

		# in_dims=(0, 0, None, None, None): slice the stacked params and
		# buffers along dim 0 (one slice per model), but feed the same
		# batch of inputs, unbatched, to every model in the ensemble
		return vmap( self.MM, in_dims=(0,0,None,None,None) )(
			self.MP, self.MB, hInputs,oInputs,aInputs )
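For reference, here is a minimal sketch of what the ensemble above is meant to compute, assuming PyTorch 2.0 or later, where functorch’s combine_state_for_ensemble is superseded by `torch.func.stack_module_state` plus `functional_call`. The small linear net is a hypothetical stand-in for ValueNet with no LSTM in it, which is why vmap has no trouble with it. The vmapped result is checked against the explicit loop it replaces.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)

def make_net():
    # Hypothetical stand-in for ValueNet: linear layers only, no LSTM
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

models = [make_net() for _ in range(4)]  # four independently initialised nets

# Stack every parameter and buffer along a new leading dim of size 4
params, buffers = stack_module_state(models)

# Weightless skeleton on the meta device; real weights come from params
base = copy.deepcopy(models[0]).to("meta")

def run_one(p, b, x):
    # Run the skeleton with one model's slice of the stacked weights
    return functional_call(base, (p, b), (x,))

x = torch.randn(32, 8)  # one batch, shared by every model

# in_dims=(0, 0, None): slice params/buffers per model, broadcast x
batched = vmap(run_one, in_dims=(0, 0, None))(params, buffers, x)

# The explicit Python loop that vmap replaces
looped = torch.stack([m(x) for m in models])

print(batched.shape)  # torch.Size([4, 32, 1])
```

The leading dimension of the output indexes the model, matching the order of the list passed to `stack_module_state`.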

If it’s relevant, I don’t need this for training: the individual ValNets are trained in isolation and only need to be ensembled for inference. I can also post the code defining exactly what a ValNet is, but I’ll wait until someone asks for it.