ONNX export of a SwitchTransformer fails in a weird place

Hi,

I am trying to convert a SwitchTransformer to an ONNX graph with the code below:

from transformers import SwitchTransformersEncoderModel, SwitchTransformersConfig, AutoTokenizer
import torch

if __name__ == "__main__":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vocab_size = 1024
    batch_size = 10
    switch_transformer_config = SwitchTransformersConfig(
        vocab_size=vocab_size,
        d_model=256,
        d_kv=64,
        d_ff=512,
        num_layers=2,
        num_sparse_encoder_layers=1,
        num_experts=3,
    )
    # - Create a small switch transformer
    model = SwitchTransformersEncoderModel(switch_transformer_config).to(device).eval()
    random_inputs = (vocab_size * torch.rand(batch_size, vocab_size)).int().to(device)  # random token ids; note seq_len == vocab_size here

    # - Check if the model runs on the input
    model(random_inputs)

    torch.onnx.export(
        model,
        random_inputs,
        "onnx-models/switch_transformer.onnx",
        verbose=True
    )

I get the following error:

input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1 INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":554, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.

However, the error disappears when I comment out the dropout call on line 379 of transformers/models/switch_transformers/modeling_switch_transformers.py, which reads

output = hidden_states + self.dropout(forwarded_states)

I don’t understand why this happens, since the model runs on the input just fine in eager mode.

Any help would be appreciated.

Ok, actually it makes sense that the export works when I comment out the dropout line: that line is the only place forwarded_states is used, so dropping it effectively takes the MLP output out of the computation. For reference, this is the forward that contains it (a monkey-patch sketch with the same effect follows the snippet):

def forward(self, hidden_states, output_router_logits):
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.mlp(forwarded_states)

    if isinstance(forwarded_states, tuple):
        forwarded_states, router_tuple = forwarded_states
    else:
        router_tuple = None

    output = hidden_states + self.dropout(forwarded_states)

    if output_router_logits and router_tuple is not None:
        output = (output, router_tuple)

    return output
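
For anyone who wants to try the same experiment without editing the library file, a monkey-patch along these lines should be equivalent. This is only a rough sketch: _bypass_moe_output is a name I made up, and I keep output = hidden_states so the variable is still defined while the MoE contribution is discarded.

from transformers.models.switch_transformers import modeling_switch_transformers as mst

def _bypass_moe_output(self, hidden_states, output_router_logits):
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.mlp(forwarded_states)
    if isinstance(forwarded_states, tuple):
        forwarded_states, router_tuple = forwarded_states
    else:
        router_tuple = None
    # Residual only: forwarded_states is computed but never added back, which
    # mirrors the effect of commenting out the dropout line in the library file.
    output = hidden_states
    if output_router_logits and router_tuple is not None:
        output = (output, router_tuple)
    return output

mst.SwitchTransformersLayerFF.forward = _bypass_moe_output

Running the export script above with this patch applied should behave the same way as my local edit.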

So the problem arises in the MLP layer, which in this case is the sparse MoE layer:

def forward(self, hidden_states):
    r"""
    Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

    1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_experts)`
    and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
    hidden states: they are broadcast to the hidden states values (can be interpreted as a scaling factor).

    2- Dispatches the tokens to their associated experts. This is a plain for loop over the experts that assigns
    each expert its corresponding hidden states.

    """
    # Step 1: Get the router_mask from the router as well as the probabilities
    router_mask, router_probs, router_logits = self.router(hidden_states)
    expert_index = torch.argmax(router_mask, dim=-1)

    # The router might not always map every token to an expert, which means that some hidden states
    # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

    next_states = hidden_states.clone()
    for idx, expert in enumerate(self.experts.values()):
        token_indices = router_mask[:, :, idx].bool()
        next_states[token_indices] = expert(hidden_states[token_indices])

    hidden_states = router_probs * next_states
    return hidden_states, (router_logits, expert_index)
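
My current suspicion is the in-place boolean-mask assignment inside the expert loop: next_states[token_indices] = expert(hidden_states[token_indices]) touches a data-dependent number of tokens, which looks like exactly the kind of dynamic shape the ONNX shape-inference pass complains about ("Expand input shape constraint not satisfied"). A stripped-down module that only does this masked select/scatter should let me test that in isolation. This is purely a hypothetical repro, all names below are mine, and I have not confirmed it triggers the same assert:

import torch
import torch.nn as nn

class MaskedScatter(nn.Module):
    """Only the masked select/scatter pattern used in the MoE layer."""

    def __init__(self, d_model=256):
        super().__init__()
        self.expert = nn.Linear(d_model, d_model)

    def forward(self, hidden_states, router_mask):
        # router_mask: (batch, seq_len) bool with a data-dependent number of True entries
        next_states = hidden_states.clone()
        token_indices = router_mask.bool()
        next_states[token_indices] = self.expert(hidden_states[token_indices])
        return next_states

if __name__ == "__main__":
    model = MaskedScatter().eval()
    hidden = torch.randn(10, 32, 256)
    mask = torch.rand(10, 32) > 0.5
    model(hidden, mask)  # runs fine eagerly, just like the full model
    torch.onnx.export(model, (hidden, mask), "masked_scatter.onnx", verbose=True)

If this small module hits the same INTERNAL ASSERT, the issue is the masked scatter itself rather than anything SwitchTransformers-specific; if it exports cleanly, the interaction must be more subtle (e.g. the router_probs * next_states broadcast on the last line).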