from torchsummary import summary
Summarize the given PyTorch model. Summarized information includes:
1) Layer names,
2) Input/output shapes,
3) Kernel shapes,
4) Number of parameters,
5) Number of operations (Mult-Adds)
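To make the last two items concrete, here is a minimal stdlib sketch of the usual per-layer formulas. It assumes one Mult-Add means one multiply-accumulate and does not count bias additions. The helpers conv2d_stats and linear_stats are hypothetical and not part of torchsummary, whose exact counting may differ in details.

```python
# Hypothetical helpers (not torchsummary API): parameter and Mult-Add counts
# for two common layer types, using the standard formulas.

def conv2d_stats(in_ch, out_ch, kernel, out_hw, groups=1, bias=True):
    """Return (num_params, mult_adds) for a 2-D convolution.

    kernel is (kh, kw); out_hw is the spatial size (h, w) of the output.
    """
    kh, kw = kernel
    weight = out_ch * (in_ch // groups) * kh * kw
    params = weight + (out_ch if bias else 0)
    # Every weight is used once per output spatial position.
    mult_adds = weight * out_hw[0] * out_hw[1]
    return params, mult_adds


def linear_stats(in_features, out_features, bias=True):
    """Return (num_params, mult_adds) for a fully connected layer, per sample."""
    weight = in_features * out_features
    params = weight + (out_features if bias else 0)
    return params, weight


# A 3x3 conv from 3 to 64 channels with padding 1 keeps a 32x32 image at 32x32.
print(conv2d_stats(3, 64, (3, 3), (32, 32)))  # (1792, 1769472)
print(linear_stats(512, 10))                  # (5130, 5120)
```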
Args:
model (nn.Module):
PyTorch model to summarize. The model should be fully in either train()
or eval() mode. If layers are not all in the same mode, running summary
may have side effects on batchnorm or dropout statistics. If you
encounter an issue with this, please open a GitHub issue.
input_data (Sequence of Sizes or Tensors):
Example input tensor of the model (dtypes are taken from the tensor).
- OR -
Shape of input data as a List/Tuple/torch.Size
(dtypes must match the model input; the default is FloatTensor).
Do NOT include the batch size in the shape unless batch_dim is None.
- OR -
If input_data is not provided, no forward pass through the network is
performed, and the model information shown is limited to layer names and
parameter counts.
Default: None
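A rough sketch of how the shape forms of input_data might be normalized into one shape per model input. normalize_input_sizes is a hypothetical helper; tensor inputs are left out so the example runs without PyTorch, and torchsummary's own handling may differ.

```python
from typing import List, Optional, Sequence, Tuple, Union

Shape = Tuple[int, ...]


def normalize_input_sizes(
    input_data: Optional[Union[Sequence[int], Sequence[Sequence[int]]]],
) -> List[Shape]:
    """Turn the accepted shape forms of input_data into a list of per-input shapes."""
    if input_data is None:
        return []  # no forward pass will be run
    if all(isinstance(dim, int) for dim in input_data):
        return [tuple(input_data)]  # a single input, e.g. (3, 224, 224)
    return [tuple(shape) for shape in input_data]  # one shape per input


print(normalize_input_sizes((3, 224, 224)))         # [(3, 224, 224)]
print(normalize_input_sizes([(1, 28, 28), [10]]))   # [(1, 28, 28), (10,)]
```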
batch_dim (int):
Batch dimension of input data. If batch_dim is None, the input data
is assumed to contain the batch dimension.
WARNING: in a future version, the default will change to None.
Default: 0
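The batch_dim rule can be sketched as inserting a batch axis into each shape. with_batch_dim and its batch_size argument are hypothetical (the batch size torchsummary actually uses is not stated here), and the sketch assumes a non-negative batch_dim, because list.insert treats negative indices differently from torch.unsqueeze.

```python
from typing import Optional, Tuple


def with_batch_dim(shape: Tuple[int, ...], batch_dim: Optional[int],
                   batch_size: int = 1) -> Tuple[int, ...]:
    """Return the shape of the tensor that would actually be fed to the model."""
    if batch_dim is None:
        return tuple(shape)  # the shape already contains the batch dimension
    full = list(shape)
    full.insert(batch_dim, batch_size)  # assumes batch_dim >= 0
    return tuple(full)


print(with_batch_dim((3, 224, 224), 0))        # (1, 3, 224, 224)
print(with_batch_dim((5, 3, 224, 224), None))  # (5, 3, 224, 224)
```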
branching (bool):
Whether to use the branching layout for the printed output.
Default: True
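As a rough illustration of what branching changes, this sketch renders a nested layer tree either with tree connectors or as a flat list. The (name, children) tuples and render() are hypothetical, and the connector characters only approximate torchsummary's real output.

```python
from typing import List, Tuple

# Hypothetical tree format: (layer name, list of child nodes).
Node = Tuple[str, list]


def render(nodes: List[Node], branching: bool = True, prefix: str = "") -> List[str]:
    """Return one text line per layer, with or without tree connectors."""
    lines = []
    for i, (name, children) in enumerate(nodes):
        if branching:
            last = i == len(nodes) - 1
            lines.append(prefix + ("└─" if last else "├─") + name)
            # Keep a vertical bar under non-last siblings so the tree stays connected.
            child_prefix = prefix + ("  " if last else "│ ")
        else:
            lines.append(name)  # flat layout: no indentation or connectors
            child_prefix = ""
        lines.extend(render(children, branching, child_prefix))
    return lines


model = [("Sequential", [("Conv2d", []), ("ReLU", [])]), ("Linear", [])]
print("\n".join(render(model)))
# ├─Sequential
# │ ├─Conv2d
# │ └─ReLU
# └─Linear
```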
col_names (Iterable[str]):
Specify which columns to show in the output. Currently supported:
("input_size", "output_size", "num_params", "kernel_size", "mult_adds")
If input_data is not provided, only "num_params" is used.
Default: ("output_size", "num_params")
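The column rules can be sketched as a small validation step. resolve_columns and the choice to raise ValueError on unknown names are assumptions for illustration, not torchsummary's API.

```python
from typing import Iterable, Tuple

SUPPORTED_COLUMNS = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds")


def resolve_columns(col_names: Iterable[str], has_input_data: bool) -> Tuple[str, ...]:
    """Validate col_names and apply the no-input-data fallback (hypothetical helper)."""
    cols = tuple(col_names)
    unknown = [c for c in cols if c not in SUPPORTED_COLUMNS]
    if unknown:
        raise ValueError(f"Unsupported column(s): {unknown}")
    if not has_input_data:
        # Without a forward pass, shapes and Mult-Adds are unknown.
        return ("num_params",)
    return cols


print(resolve_columns(("output_size", "num_params"), has_input_data=False))  # ('num_params',)
```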
col_width (int):
Width of each column.
Default: 25
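Fixed-width columns amount to left-justifying each cell to col_width characters. format_row is a hypothetical helper; a value longer than col_width is not truncated here and pushes later columns to the right.

```python
from typing import Iterable


def format_row(values: Iterable[object], col_width: int = 25) -> str:
    """Left-justify each value in a col_width-wide column (hypothetical helper)."""
    # Values longer than col_width are not truncated; they shift later columns right.
    return "".join(f"{str(v):<{col_width}}" for v in values).rstrip()


print(format_row(["Conv2d: 1-1", "[-1, 64, 32, 32]", "1,792"], col_width=20))
```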
depth (int):
Number of nested layers to traverse (e.g. Sequentials).
Default: 3
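One way to picture depth is a recursive walk that stops descending once it passes the limit. The tree format and layers_up_to_depth are hypothetical, and numbering the model's direct children as level 1 is an assumption of this sketch rather than a documented torchsummary convention.

```python
from typing import List, Tuple

Node = Tuple[str, list]  # hypothetical: (layer name, child nodes)


def layers_up_to_depth(nodes: List[Node], depth: int, level: int = 1) -> List[Tuple[int, str]]:
    """Collect (level, name) pairs, omitting layers nested deeper than depth."""
    if level > depth:
        return []
    out = []
    for name, children in nodes:
        out.append((level, name))
        out.extend(layers_up_to_depth(children, depth, level + 1))
    return out


tree = [("Sequential", [("Sequential", [("Conv2d", [])]), ("ReLU", [])])]
print(layers_up_to_depth(tree, depth=2))
# [(1, 'Sequential'), (2, 'Sequential'), (2, 'ReLU')]
```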
device (torch.device):
Device on which to place the model and input_data.
If not specified, uses "cuda" if torch.cuda.is_available() returns True,
otherwise "cpu".
Default: None
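The default device rule reduces to a one-line choice. In real code the flag would come from torch.cuda.is_available(); here it is a plain parameter (cuda_available, hypothetical) so the sketch runs without PyTorch.

```python
from typing import Optional


def resolve_device(device: Optional[str] = None, cuda_available: bool = False) -> str:
    """Mirror the documented default: an explicit device wins, else CUDA if available."""
    if device is not None:
        return device
    return "cuda" if cuda_available else "cpu"

# With PyTorch installed, the equivalent default would be:
# torch.device("cuda" if torch.cuda.is_available() else "cpu")
```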
dtypes (List[torch.dtype]):
For multiple inputs, give the shape of each input in input_data and
the dtype of each input here, in the same order.
Default: None
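For a model with several inputs, each shape in input_data is paired with the dtype at the same position in dtypes. pair_dtypes is a hypothetical helper, and the dtype names are plain strings standing in for torch.long, torch.bool, and torch.float32 (the dtype of a FloatTensor).

```python
from typing import List, Optional, Sequence, Tuple


def pair_dtypes(input_sizes: Sequence[Tuple[int, ...]],
                dtypes: Optional[Sequence[str]] = None,
                default: str = "float32") -> List[Tuple[Tuple[int, ...], str]]:
    """Pair each input shape with its dtype, defaulting every input to float32."""
    if dtypes is None:
        dtypes = [default] * len(input_sizes)
    if len(dtypes) != len(input_sizes):
        raise ValueError("dtypes needs exactly one entry per input")
    return list(zip(input_sizes, dtypes))


# e.g. token ids plus an attention mask for a two-input model
print(pair_dtypes([(128,), (128,)], ["int64", "bool"]))
# [((128,), 'int64'), ((128,), 'bool')]
```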
verbose (int):
0 (quiet): No output
1 (default): Print model summary
2 (verbose): Show weight and bias layers in full detail
Default: 1
*args, **kwargs:
Additional arguments passed through to `model.forward`.
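A minimal sketch, assuming the extra arguments are passed after the generated inputs. TinyModel is a plain-Python stand-in for an nn.Module and run_forward is a hypothetical helper; the sketch only shows argument order, not how torchsummary registers its hooks.

```python
class TinyModel:
    """Stand-in for an nn.Module whose forward takes an extra argument."""

    def forward(self, x, scale=1.0):
        return [v * scale for v in x]


def run_forward(model, inputs, *args, **kwargs):
    # Generated inputs come first, followed by any user-supplied extras.
    return model.forward(*inputs, *args, **kwargs)


print(run_forward(TinyModel(), [[1, 2, 3]], scale=2))  # [2, 4, 6]
```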
Returns:
ModelStatistics object.
See torchsummary/model_statistics.py for more information.