Alright, here’s the issue I’m facing. I created a model ensemble using functorch’s `combine_state_for_ensemble`. The point of this is to run the same model with several different sets of parameter values in parallel, instead of looping over a separate forward pass for each version of the model one at a time. The model takes three different inputs and produces one output: most of the model is just a stack of linear layers, but part of it is an LSTM. One of the three inputs is a sequence; the other two are plain vectors. This is the error I’m getting:
```
File "/home/USERNAME/miniconda3/envs/ENVNAME/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 774, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.
```
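The error doesn’t seem specific to my setup; here’s a stripped-down sketch of just the vmap-over-LSTM part (assumption: plain `torch.vmap` over a bare `nn.LSTM`, with made-up sizes, rather than my full ensemble) that hits the same missing batching rule:

```python
import torch

# A single LSTM; sizes here are arbitrary, just for illustration.
lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

# Leading dim of 3 plays the role of the "ensemble member" dim.
seq = torch.randn(3, 2, 5, 4)

def forward(x):
    out, _ = lstm(x)  # calls aten::lstm under the hood
    return out

try:
    torch.vmap(forward)(seq)
except RuntimeError as e:
    # vmap has no batching rule for aten::lstm.input
    print("vmap failed:", str(e).splitlines()[0])
```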
The following is how I create/call the ensemble:
```python
class MultiModel:
    def __init__( self, mFile="", iterSpan=-1 ):
        GPU = pt.device( 'cuda:0' )
        self.IterSpan = iterSpan
        self.device = GPU
        ValNets = []
        if iterSpan < 1:
            ValNet_0 = ValueNet( modelIter=iterSpan )
            ValNet_0.to( GPU )
            ValNets.append( ValNet_0 )
        for t in range( 1, iterSpan+1 ):
            ValNet_t = ValueNet( modelIter=t, load_from_file=mFile )
            ValNet_t.to( GPU )
            ValNets.append( ValNet_t )
        multimodel, multiparameters, multibuffers = \
            combine_state_for_ensemble( ValNets )
        self.MM = multimodel
        self.MP = multiparameters
        self.MB = multibuffers

    def __call__( self, hInputs, oInputs, aInputs ):
        # in_dims is what it is because the intention is to feed the
        # same batch of data into all models in the ensemble
        return vmap( self.MM, in_dims=(0,0,None,None,None) )(
            self.MP, self.MB, hInputs, oInputs, aInputs )
```
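For what it’s worth, the ensemble pattern itself seems fine when the model is linear-only; this minimal sketch of the same idea (using the `torch.func` equivalents `stack_module_state`/`functional_call` instead of functorch, with made-up layer sizes) runs without the batching-rule error:

```python
import torch
from torch.func import stack_module_state, functional_call

# Three copies of a purely linear model; vmap has batching rules for these.
models = [torch.nn.Linear(4, 2) for _ in range(3)]
params, buffers = stack_module_state(models)

# A "template" module whose parameters get swapped out per ensemble member.
base = torch.nn.Linear(4, 2)

def fmodel(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(5, 4)  # same batch fed to every model, hence in_dims=None
out = torch.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)  # -> torch.Size([3, 5, 2]): one output slice per model
```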
If it’s relevant: I have no need to use this for training. The individual ValNets are trained in isolation and only need to be ensembled for inference. I can post the code defining exactly what a ValNet is too, but I’ll hold off until someone asks for it.
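In the meantime, here’s a rough stand-in with the same overall shape I described: an LSTM on the sequence input and linear layers on the two vector inputs, combined into one scalar output. Every layer size here is invented for illustration; this is not the real ValNet:

```python
import torch
import torch.nn as nn

class ValueNet(nn.Module):
    """Hypothetical stand-in for the real ValNet, sizes made up."""
    def __init__(self, modelIter=-1, load_from_file=""):
        super().__init__()
        self.lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
        self.fc_o = nn.Linear(4, 16)   # first vector input
        self.fc_a = nn.Linear(4, 16)   # second vector input
        self.head = nn.Linear(48, 1)   # 16 + 16 + 16 -> one output

    def forward(self, hInputs, oInputs, aInputs):
        # Sequence input goes through the LSTM; keep the last hidden state.
        _, (h_n, _) = self.lstm(hInputs)
        z = torch.cat([h_n[-1], self.fc_o(oInputs), self.fc_a(aInputs)], dim=-1)
        return self.head(z)

net = ValueNet()
y = net(torch.randn(2, 7, 8), torch.randn(2, 4), torch.randn(2, 4))
print(y.shape)  # -> torch.Size([2, 1])
```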