Are there any methods (or tools) to track (or debug) tensor.size?

Hi guys!

I keep running into dimension mismatch problems after a lot of view/transpose/unsqueeze operations, and it's time-consuming to find where I got a tensor size wrong.

Are there any methods (or tools) to track (or debug) tensor sizes?


einops is the solution to all your problems :)

I’m not quite sure what your goal is, but you could look into hooks. Below is a complete example that prints the shapes of the input and output tensors for each module.

That being said, it's not only the shape that matters. Particularly when using view(), it's easy to get the right shape but scramble the underlying data; cf. this post.

import torch
import torch.nn as nn

class Hook():
    def __init__(self, module, backward=False):
        self.module = module
        if not backward:
            self.hook = module.register_forward_hook(self.hook_fn)
        else:
            self.hook = module.register_full_backward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Forward hooks receive (inputs, output); backward hooks receive (grad_input, grad_output)
        self.input = input
        self.output = output

    def close(self):
        self.hook.remove()

class SimpleLSTM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=False)
        # Classifier head (the attribute name "linears" is assumed)
        self.linears = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.Linear(64, output_size),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, X):
        out = self.embed(X)              # (batch, seq_len, embed_size)
        out = out.transpose(0, 1)        # (seq_len, batch, embed_size), since batch_first=False
        out, (h, c) = self.lstm(out)     # h: (num_layers, batch, hidden_size)
        log_probs = self.linears(h[-1])  # (batch, output_size)
        return log_probs

# Instantiate the model (simple LSTM classifier)
vocab_size, batch_size, seq_len = 100, 32, 20
model = SimpleLSTM(vocab_size, 300, 512, 3)

# Attach a forward hook to each layer
forward_hooks = [Hook(layer[1]) for layer in list(model.named_modules())]

# Create a random input batch and pump it through the model
x = torch.randint(vocab_size, (batch_size, seq_len))
out = model(x)

# Print input and output shapes of each layer
# (inputs always arrive as a tuple; LSTM returns a tuple (out, (h, c)))
for hook in forward_hooks:
    if isinstance(hook.input, torch.Tensor):
        input_shape = hook.input.shape
    else:
        input_shape = hook.input[0].shape
    if isinstance(hook.output, torch.Tensor):
        output_shape = hook.output.shape
    else:
        output_shape = hook.output[0].shape
    print("Module:", hook.module)
    print("Input shape:\t", input_shape)
    print("Output shape:\t", output_shape)
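To illustrate the view() pitfall mentioned above, here is a minimal sketch: transpose() and view() can produce tensors of the same shape whose contents are arranged completely differently, so a shape check alone would not catch the bug.

```python
import torch

# A (2, 3) tensor whose values make the memory layout easy to follow
x = torch.arange(6).reshape(2, 3)

swapped = x.transpose(0, 1)  # actually swaps the axes: shape (3, 2)
reshaped = x.view(3, 2)      # same shape (3, 2), but just re-reads memory row by row

print(swapped)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])
print(reshaped)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])
```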

This is a very interesting and ambitious package! I'll take a closer look. Thanks!

The script is helpful, thank you!

And here's a small fix for handling nn.ModuleList, as used in the Transformer modules:

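from torch.nn import ModuleList

class Hook(object):
    def __init__(self, module, backward=False):
        # A ModuleList is never called itself, so hook each of its sub-modules instead
        if isinstance(module, ModuleList):
            self.module = list(module)
        else:
            self.module = [module]
        self.hooks = []
        for sub_module in self.module:
            if not backward:
                self.hooks.append(sub_module.register_forward_hook(self.hook_fn))
            else:
                self.hooks.append(sub_module.register_full_backward_hook(self.hook_fn))

    def hook_fn(self, module, input, output):
        # For a ModuleList, this keeps the tensors of the last sub-module that ran
        self.input = input
        self.output = output

    def close(self):
        for hook in self.hooks:
            hook.remove()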
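To see why the ModuleList case needs special handling, here is a minimal sketch (layer sizes are arbitrary): nn.TransformerEncoder keeps its layers in a ModuleList (`encoder.layers`) and calls each layer directly, so a forward hook registered on the ModuleList itself never fires, while hooks on its sub-modules do.

```python
import torch
import torch.nn as nn

# Arbitrary sizes, chosen only for illustration
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4)
encoder = nn.TransformerEncoder(layer, num_layers=2)

calls = []

def record(module, inputs, output):
    # Store which module fired and the shape of its output
    calls.append((type(module).__name__, tuple(output.shape)))

# Hook on the ModuleList itself: never fires, since nothing calls the ModuleList
handles = [encoder.layers.register_forward_hook(record)]
# Hooks on each sub-module: fire once per layer
handles += [sub.register_forward_hook(record) for sub in encoder.layers]

x = torch.randn(5, 3, 16)  # (seq_len, batch, d_model), since batch_first=False by default
encoder(x)

for handle in handles:
    handle.remove()

print(calls)
# [('TransformerEncoderLayer', (5, 3, 16)), ('TransformerEncoderLayer', (5, 3, 16))]
```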

Cool! Yeah, I’m sure that my minimal example did not cover all syntactic sugar of PyTorch. Thanks for posting your code!

Answering my own question so that others who come after me can check it out.

After some searching, the repo that best fits my needs is tsalib. Here is an example (gist: tsalib demo · GitHub):

import math

import torch.nn as nn
import torch.nn.functional as F
from torch import matmul
from tsalib import dim_var
from einops import rearrange

# Format: "Name(abbreviation):default_size"
Batch = dim_var("Batch(b):64")
Dimension = dim_var("Dimension(d):128")
Heads = dim_var("Heads(h):8")
MaxLength = dim_var("MaxLength(l):80")
SrcVocabSize = dim_var("SrcVocabSize(sv)")
TargetVocabSize = dim_var("TargetVocabSize(tv)")

class MultiHeadAttention(nn.Module):
    def __init__(self, heads, dimension, dropout=0.1):
        super().__init__()
        self.dimension = dimension
        self.d_k = dimension // heads
        self.heads = heads

        self.q_linear = nn.Linear(dimension, dimension)
        self.k_linear = nn.Linear(dimension, dimension)
        self.v_linear = nn.Linear(dimension, dimension)
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(dimension, dimension)

    def self_attention(self, q, k, v, mask=None):
        # Scaled dot-product attention over the last two axes
        scores: (Batch, Heads, MaxLength, MaxLength) = matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            mask: (Batch, 1, 1, MaxLength) = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)

        if self.dropout is not None:
            scores = self.dropout(scores)

        output: (Batch, Heads, MaxLength, Dimension // Heads) = matmul(scores, v)
        return output

    def forward(self,
                q: (Batch, MaxLength, Dimension),
                k: (Batch, MaxLength, Dimension),
                v: (Batch, MaxLength, Dimension),
                mask: (Batch, 1, MaxLength) = None):
        # Split the model dimension into heads: (b, l, h*d) -> (b, h, l, d)
        q: (Batch, Heads, MaxLength, Dimension // Heads) = rearrange(self.q_linear(q),
                                                                     'b l (h d) -> b h l d', h=self.heads)
        k: (Batch, Heads, MaxLength, Dimension // Heads) = rearrange(self.k_linear(k),
                                                                     'b l (h d) -> b h l d', h=self.heads)
        v: (Batch, Heads, MaxLength, Dimension // Heads) = rearrange(self.v_linear(v),
                                                                     'b l (h d) -> b h l d', h=self.heads)

        scores: (Batch, Heads, MaxLength, Dimension // Heads) = self.self_attention(q, k, v, mask)

        # Merge the heads back: (b, h, l, d) -> (b, l, h*d)
        concat: (Batch, MaxLength, Dimension) = rearrange(scores, 'b h l d -> b l (h d)')

        output: (Batch, MaxLength, Dimension) = self.output(concat)
        return output

It also uses einops, which was mentioned in the second reply. tsalib actually comes with a similar warp operation, but since tsalib has not been maintained for a long time (last commit 4 years ago), I use the actively maintained and stable einops for the tensor operations instead and keep only tsalib's shape annotations.

Since I have just started using it, I am not sure whether coding like this is the best practice. I am still figuring it out.

Hmmm, there is also jaxtyping, which allows runtime type checking of tensors, including their shapes. But of course it can only check the shapes of a function's inputs and outputs.