I want to run `torchinfo.summary` on the `BertClassifier` model below, but I can't get it to run without errors:
class BertClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout, a linear layer and ReLU.

    Args:
        dropout: Drop probability applied to BERT's pooled output before
            the linear layer.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        # Read the encoder width from the loaded checkpoint's config rather
        # than hard-coding it (it is 768 for 'bert-base-cased').
        self.linear = nn.Linear(self.bert.config.hidden_size, 3)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Run the encoder and classification head.

        Args:
            input_id: Token ids passed to BERT as ``input_ids``. They feed an
                embedding lookup, so they must be an integer tensor.
            mask: Attention mask passed to BERT as ``attention_mask``.

        Returns:
            Output of the ReLU applied to the linear layer; the last
            dimension has size 3.
        """
        # With return_dict=False BERT returns a tuple; only the second item
        # (the pooled output) is used here, the first is discarded.
        _, pooled_output = self.bert(
            input_ids=input_id, attention_mask=mask, return_dict=False
        )
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer
import torch

# When only input_size is given, torchinfo generates float tensors. BERT's
# word-embedding lookup requires integer token ids and otherwise raises
# "Expected tensor for argument #1 'indices' to have ... Long, Int", so the
# dtype of each of the two inputs is declared explicitly as long.
torchinfo.summary(
    BertClassifier(),
    input_size=((4, 512), (4, 1, 512)),
    dtypes=[torch.long, torch.long],
)
Calling `torchinfo.summary` with only the input sizes (and no `dtypes`) raises the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torchinfo/torchinfo.py:272, in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs)
271 if isinstance(x, (list, tuple)):
--> 272 _ = model.to(device)(*x, **kwargs)
273 elif isinstance(x, dict):
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
Input In [24], in BertClassifier.forward(self, input_id, mask)
12 def forward(self, input_id, mask):
13
14 #
15 # pooled_output - embedding vector of [CLS] token
16 #
---> 17 vEmbeddingToken, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
18 dropout_output = self.dropout(pooled_output)
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1128, in Module._call_impl(self, *input, **kwargs)
1126 input = bw_hook.setup_input_hook(input)
-> 1128 result = forward_call(*input, **kwargs)
1129 if _global_forward_hooks or self._forward_hooks:
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py:1010, in BertModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
1008 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-> 1010 embedding_output = self.embeddings(
1011 input_ids=input_ids,
1012 position_ids=position_ids,
1013 token_type_ids=token_type_ids,
1014 inputs_embeds=inputs_embeds,
1015 past_key_values_length=past_key_values_length,
1016 )
1017 encoder_outputs = self.encoder(
1018 embedding_output,
1019 attention_mask=extended_attention_mask,
(...)
1027 return_dict=return_dict,
1028 )
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1128, in Module._call_impl(self, *input, **kwargs)
1126 input = bw_hook.setup_input_hook(input)
-> 1128 result = forward_call(*input, **kwargs)
1129 if _global_forward_hooks or self._forward_hooks:
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py:235, in BertEmbeddings.forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
234 if inputs_embeds is None:
--> 235 inputs_embeds = self.word_embeddings(input_ids)
236 token_type_embeddings = self.token_type_embeddings(token_type_ids)
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/module.py:1128, in Module._call_impl(self, *input, **kwargs)
1126 input = bw_hook.setup_input_hook(input)
-> 1128 result = forward_call(*input, **kwargs)
1129 if _global_forward_hooks or self._forward_hooks:
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/modules/sparse.py:158, in Embedding.forward(self, input)
157 def forward(self, input: Tensor) -> Tensor:
--> 158 return F.embedding(
159 input, self.weight, self.padding_idx, self.max_norm,
160 self.norm_type, self.scale_grad_by_freq, self.sparse)
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torch/nn/functional.py:2183, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2182 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2183 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Input In [25], in <cell line: 1>()
----> 1 torchinfo.summary(BertClassifier(), ((4, 512),(4, 1, 512)))
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torchinfo/torchinfo.py:201, in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, row_settings, verbose, **kwargs)
196 validate_user_params(input_data, input_size, columns, col_width, verbose)
198 x, correct_input_size = process_input(
199 input_data, input_size, batch_dim, device, dtypes
200 )
--> 201 summary_list = forward_pass(
202 model, x, batch_dim, cache_forward_pass, device, **kwargs
203 )
204 formatting = FormattingOptions(depth, verbose, columns, col_width, rows)
205 results = ModelStatistics(
206 summary_list, correct_input_size, get_total_memory_used(x), formatting
207 )
File ~/abWrk/abWrkVenv/lib/python3.8/site-packages/torchinfo/torchinfo.py:281, in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs)
279 except Exception as e:
280 executed_layers = [layer for layer in summary_list if layer.executed]
--> 281 raise RuntimeError(
282 "Failed to run torchinfo. See above stack traces for more details. "
283 f"Executed layers up to: {executed_layers}"
284 ) from e
285 finally:
286 if hooks is not None:
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []
How can I use `torchinfo` on `BertClassifier`?