Alright, here's the issue I'm facing. I created a model ensemble using functorch's combine_state_for_ensemble. The point of this is to run the same model with several different sets of parameter values in parallel, rather than looping over a separate forward pass for each version of the model one at a time. The model takes three inputs and produces one output: most of it is a stack of linear layers, but part of it is an LSTM, so one of the three inputs is a sequence and the other two are plain vectors. This is the error I'm getting:
File "/home/USERNAME/miniconda3/envs/ENVNAME/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 774, in forward
result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.
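In case a minimal repro helps, I believe something as small as this hits the same wall (sketch only, untested as written - TinyLSTM and all of the sizes here are placeholders, not my actual ValueNet):

import torch as pt
from functorch import combine_state_for_ensemble, vmap

class TinyLSTM( pt.nn.Module ):
    def __init__( self ):
        super().__init__()
        self.lstm = pt.nn.LSTM( input_size=8, hidden_size=16, batch_first=True )
        self.head = pt.nn.Linear( 16, 1 )
    def forward( self, seq ):
        out, _ = self.lstm( seq )           # out: (batch, time, hidden)
        return self.head( out[:, -1, :] )   # predict from the last timestep

models = [ TinyLSTM() for _ in range( 3 ) ]
fmodel, params, buffers = combine_state_for_ensemble( models )
seq = pt.randn( 4, 10, 8 )  # same batch fed to every ensemble member
# this is where I'd expect the aten::lstm.input RuntimeError to show up
out = vmap( fmodel, in_dims=(0, 0, None) )( params, buffers, seq )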
The following is how I create/call the ensemble:
import torch as pt
from functorch import combine_state_for_ensemble, vmap

class MultiModel:
    def __init__( self, mFile="", iterSpan=-1 ):
        GPU = pt.device( 'cuda:0' )
        self.IterSpan = iterSpan
        self.device = GPU
        ValNets = []
        if iterSpan < 1:
            ValNet_0 = ValueNet( modelIter=iterSpan )
            ValNet_0.to( GPU )
            ValNets.append( ValNet_0 )
        for t in range( 1, iterSpan+1 ):
            ValNet_t = ValueNet( modelIter=t, load_from_file=mFile )
            ValNet_t.to( GPU )
            ValNets.append( ValNet_t )
        multimodel, multiparameters, multibuffers = \
            combine_state_for_ensemble( ValNets )
        self.MM = multimodel
        self.MP = multiparameters
        self.MB = multibuffers

    def __call__( self, hInputs, oInputs, aInputs ):
        # in_dims is what it is because the intention is to feed the
        # same batch of data into all models in the ensemble
        return vmap( self.MM, in_dims=(0, 0, None, None, None) )(
            self.MP, self.MB, hInputs, oInputs, aInputs )
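And this is roughly how it gets called - the shapes and file name here are made up for illustration, and which of the three inputs is the sequence is just an example, not the real ValueNet signature:

mm = MultiModel( mFile="valuenet_checkpoint.pt", iterSpan=5 )  # placeholder checkpoint path
hInputs = pt.randn( 32, 20, 64, device=mm.device )  # sequence input that goes through the LSTM
oInputs = pt.randn( 32, 16, device=mm.device )      # plain vector input
aInputs = pt.randn( 32, 8, device=mm.device )       # plain vector input
out = mm( hInputs, oInputs, aInputs )  # this call is what raises the aten::lstm.input error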
If it's relevant: I don't need this for training - the individual ValNets are trained in isolation and only need to be ensembled for inference. I can post the code defining exactly what a ValueNet is as well, but I'll wait to be asked for that.
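For completeness, the fallback I'm trying to avoid is just looping over the individual nets, roughly like this (sketch only, and it assumes I also keep the list around as self.ValNets):

    # inside MultiModel: one forward pass per model, stacked afterwards - this works
    # for inference, but it's exactly the one-model-at-a-time loop I wanted to get rid of
    def call_looped( self, hInputs, oInputs, aInputs ):
        with pt.no_grad():
            return pt.stack( [ net( hInputs, oInputs, aInputs ) for net in self.ValNets ] )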