Hi,
I am trying to convert a SwitchTransformer to an ONNX graph using this code:
from transformers import SwitchTransformersEncoderModel, SwitchTransformersConfig, AutoTokenizer
import torch

if __name__ == "__main__":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vocab_size = 1024
    batch_size = 10
    switch_transformer_config = SwitchTransformersConfig(
        vocab_size=vocab_size,
        d_model=256,
        d_kv=64,
        d_ff=512,
        num_layers=2,
        num_sparse_encoder_layers=1,
        num_experts=3,
    )
    # - Create a small switch transformer
    model = SwitchTransformersEncoderModel(switch_transformer_config).to(device).eval()
    random_inputs = (vocab_size * torch.rand(batch_size, vocab_size)).int().to(device)
    # - Check if the model runs on the input
    model(random_inputs)
    torch.onnx.export(
        model,
        random_inputs,
        "onnx-models/switch_transformer.onnx",
        verbose=True,
    )
I get the following error:
input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1 INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":554, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.
However, the error disappears when I comment out the dropout call on line 379 of transformers/models/switch_transformers/modeling_switch_transformers.py, which reads
output = hidden_states + self.dropout(forwarded_states)
I don’t understand why this is happening, since the forward pass over the input runs successfully.
Any help would be appreciated.
Ok, actually it makes sense that the export works when I comment out the dropout call, since that effectively makes forwarded_states redundant. Here is the full forward of the layer that contains that line:
def forward(self, hidden_states, output_router_logits):
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.mlp(forwarded_states)
    if isinstance(forwarded_states, tuple):
        forwarded_states, router_tuple = forwarded_states
    else:
        router_tuple = None
    output = hidden_states + self.dropout(forwarded_states)
    if output_router_logits and router_tuple is not None:
        output = (output, router_tuple)
    return output
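With the self.dropout(forwarded_states) term removed, the forward effectively reduces to something like this (my own simplification for illustration, not the actual library code), so the MoE branch never contributes to the output and is presumably pruned before the failing shape-inference pass:

# Simplified view once the residual term is commented out (illustration only):
def forward(self, hidden_states, output_router_logits):
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.mlp(forwarded_states)  # still called, but the result is never used
    output = hidden_states                         # the MoE output no longer reaches the graph output
    return output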
So the problem must arise in the MLP layer, which in this case is the MoE layer:
def forward(self, hidden_states):
    r"""
    Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:
    1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
    and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
    hidden states: they are broadcasted to the hidden states values (can be interpreted as a scaling factor).
    2- Dispatch the tokens to their associated experts. We do a classic for loop over the experts and assign for each
    expert the corresponding hidden states.
    """
    # Step 1: Get the router_mask from the router as well as the probabilities
    router_mask, router_probs, router_logits = self.router(hidden_states)
    expert_index = torch.argmax(router_mask, dim=-1)
    # The routers introduced might not always map all the tokens to a router, which means that some hidden states
    # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
    next_states = hidden_states.clone()
    for idx, expert in enumerate(self.experts.values()):
        token_indices = router_mask[:, :, idx].bool()
        next_states[token_indices] = expert(hidden_states[token_indices])
    hidden_states = router_probs * next_states
    return hidden_states, (router_logits, expert_index)
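If I had to guess, the culprit is the boolean-mask indexing: the number of tokens routed to each expert depends on the input, so hidden_states[token_indices] has a data-dependent shape that tracing bakes into the graph. A minimal sketch of that pattern in isolation (my own toy module, not the library code) would be something like the following; I would expect it to trip the same kind of Expand/shape-inference constraint during export:

import torch

class MaskedExpertUpdate(torch.nn.Module):
    """Toy module mimicking the MoE dispatch: update only the rows selected by a boolean mask."""
    def __init__(self, d_model=8):
        super().__init__()
        self.expert = torch.nn.Linear(d_model, d_model)

    def forward(self, hidden_states, router_scores):
        # Boolean mask whose number of True entries depends on the data
        token_indices = router_scores.argmax(dim=-1) == 0
        next_states = hidden_states.clone()
        # hidden_states[token_indices] has a data-dependent shape, which the tracer hard-codes
        next_states[token_indices] = self.expert(hidden_states[token_indices])
        return next_states

if __name__ == "__main__":
    model = MaskedExpertUpdate().eval()
    x = torch.randn(2, 5, 8)
    scores = torch.randn(2, 5, 3)
    torch.onnx.export(model, (x, scores), "masked_expert_update.onnx", verbose=True)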