How to run torchinfo on BertClassifier?

I want to run torchinfo on BertClassifier, but I can’t get it to run without errors:

class BertClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout, a linear layer and a ReLU.

    Args:
        dropout: Drop probability applied to BERT's pooled output.
        num_classes: Number of outputs produced by the linear layer.
        model_name: Checkpoint name handed to ``BertModel.from_pretrained``.
    """

    def __init__(self, dropout=0.5, num_classes=3, model_name='bert-base-cased'):
        super().__init__()

        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Read the input width from the loaded checkpoint rather than
        # hard-coding 768, so checkpoints with another hidden size also work.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Run BERT on the inputs and map its pooled output to class scores.

        Args:
            input_id: Token ids. Must be an integer tensor (Long or Int),
                because BERT uses them as indices into its embedding table.
            mask: Attention mask forwarded to BERT as ``attention_mask``.

        Returns:
            The ReLU-activated output of the linear layer.
        """
        # return_dict=False makes BertModel return a tuple. Only the second
        # item is used: pooled_output, the embedding vector of the [CLS] token.
        _, pooled_output = self.bert(
            input_ids=input_id, attention_mask=mask, return_dict=False
        )
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        return self.relu(linear_output)


# Summarize the model for its two forward() inputs:
# input_id of shape (4, 512) and mask of shape (4, 1, 512).
torchinfo.summary(
    BertClassifier(),
    input_size=((4, 512), (4, 1, 512)),
)

Getting error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torchinfo/torchinfo.py:272, in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs)
    271 if isinstance(x, (list, tuple)):
--> 272     _ = model.to(device)(*x, **kwargs)
    273 elif isinstance(x, dict):

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used

Input In [24], in BertClassifier.forward(self, input_id, mask)
     12 def forward(self, input_id, mask):
     13     
     14     #
     15     # pooled_output - embedding vector of [CLS] token
     16     #
---> 17     vEmbeddingToken, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
     18     dropout_output                 = self.dropout(pooled_output)

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1128, in Module._call_impl(self, *input, **kwargs)
   1126     input = bw_hook.setup_input_hook(input)
-> 1128 result = forward_call(*input, **kwargs)
   1129 if _global_forward_hooks or self._forward_hooks:

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py:1010, in BertModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1008 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-> 1010 embedding_output = self.embeddings(
   1011     input_ids=input_ids,
   1012     position_ids=position_ids,
   1013     token_type_ids=token_type_ids,
   1014     inputs_embeds=inputs_embeds,
   1015     past_key_values_length=past_key_values_length,
   1016 )
   1017 encoder_outputs = self.encoder(
   1018     embedding_output,
   1019     attention_mask=extended_attention_mask,
   (...)
   1027     return_dict=return_dict,
   1028 )

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1128, in Module._call_impl(self, *input, **kwargs)
   1126     input = bw_hook.setup_input_hook(input)
-> 1128 result = forward_call(*input, **kwargs)
   1129 if _global_forward_hooks or self._forward_hooks:

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py:235, in BertEmbeddings.forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
    234 if inputs_embeds is None:
--> 235     inputs_embeds = self.word_embeddings(input_ids)
    236 token_type_embeddings = self.token_type_embeddings(token_type_ids)

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1128, in Module._call_impl(self, *input, **kwargs)
   1126     input = bw_hook.setup_input_hook(input)
-> 1128 result = forward_call(*input, **kwargs)
   1129 if _global_forward_hooks or self._forward_hooks:

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/sparse.py:158, in Embedding.forward(self, input)
    157 def forward(self, input: Tensor) -> Tensor:
--> 158     return F.embedding(
    159         input, self.weight, self.padding_idx, self.max_norm,
    160         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/functional.py:2183, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2182     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2183 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Input In [25], in <cell line: 1>()
----> 1 torchinfo.summary(BertClassifier(), ((4, 512),(4, 1, 512)))

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torchinfo/torchinfo.py:201, in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, row_settings, verbose, **kwargs)
    196 validate_user_params(input_data, input_size, columns, col_width, verbose)
    198 x, correct_input_size = process_input(
    199     input_data, input_size, batch_dim, device, dtypes
    200 )
--> 201 summary_list = forward_pass(
    202     model, x, batch_dim, cache_forward_pass, device, **kwargs
    203 )
    204 formatting = FormattingOptions(depth, verbose, columns, col_width, rows)
    205 results = ModelStatistics(
    206     summary_list, correct_input_size, get_total_memory_used(x), formatting
    207 )

File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torchinfo/torchinfo.py:281, in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs)
    279 except Exception as e:
    280     executed_layers = [layer for layer in summary_list if layer.executed]
--> 281     raise RuntimeError(
    282         "Failed to run torchinfo. See above stack traces for more details. "
    283         f"Executed layers up to: {executed_layers}"
    284     ) from e
    285 finally:
    286     if hooks is not None:

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

How can I use torchinfo on BertClassifier?

It seems this model expects integer inputs (token ids for the embedding layer), but torchinfo fed it float tensors, so it fails with:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

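I took a quick look at the repository, and I would guess you could either pass real tensors with the right dtypes via input_data, or keep input_size and add the dtypes argument, e.g. torchinfo.summary(BertClassifier(), input_size=((4, 512), (4, 1, 512)), dtypes=[torch.long, torch.long]). From the docstring of summary: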

    dtypes (List[torch.dtype]):
            If you use input_size, torchinfo assumes your input uses FloatTensors.
            If your model use a different data type, specify that dtype.
            For multiple inputs, specify the size of both inputs, and
            also specify the types of each parameter here.
            Default: None