Getting intermediate outputs from sequential layers and using them for further processing during training

I have a problem getting the output from nn.Sequential, as shown in the code below. I want to get the output of the three stages called by the backbone in the forward method. Each stage has one transformer layer. Precisely, I want to return x_stage1, x_stage2, x_stage3. Can you please help? @ptrblck

from math import ceil

import torch.nn as nn
from einops.layers.torch import Rearrange

# cast_tuple, exists, always, and the Transformer class are defined
# elsewhere in the module

class VIT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_mult,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be tuples whose length equals the number of stages'

        self.conv_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        fmap_size = image_size // (2 ** 4)
        layers = []

        for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
            is_last = ind == (stages - 1)
            layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))

            if not is_last:
                next_dim = dims[ind + 1]
                layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.conv_embedding(img)
        x = self.backbone(x)
        x = self.pool(x)
        out = self.mlp_head(x)
        distill = self.distill_head(x)
        if exists(distill):
            return out, distill
        return out, [x, x_stage1, x_stage2, x_stage3]  # x_stage1/2/3 are never defined

It seems that x_stageX is undefined in forward, so the return statement should raise a NameError.
Based on the description of your use case, you would need to define these intermediate tensors either in the forward of your main model or in the submodules. If the latter is not easily done, you could use e.g. forward hooks as described here.
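For reference, a minimal sketch of the hook approach, using the VIT class above with stages = 3 (the dictionary activations and the factory save_activation are illustrative names, not part of the original code; it also assumes the undefined x_stageX references have been removed from forward so the model can run):

import torch

activations = {}

def save_activation(name):
    # forward hooks receive (module, input, output); stash the output by name
    def hook(module, inputs, output):
        activations[name] = output
    return hook

model = VIT(image_size = 224, num_classes = 10, dim = 128, depth = 2, heads = 4, mlp_mult = 2)

# with stages = 3 the backbone is [stage1, downsample, stage2, downsample, stage3],
# so the three stage transformers sit at indices 0, 2, and 4 of the nn.Sequential
for stage, idx in enumerate((0, 2, 4), start = 1):
    model.backbone[idx].register_forward_hook(save_activation(f'x_stage{stage}'))

out = model(torch.randn(1, 3, 224, 224))   # hooks fire during the forward pass
x_stage1 = activations['x_stage1']         # likewise x_stage2 / x_stage3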

Thank you :), I redefined these intermediate tensors in the forward of the main model. Thanks a lot.
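For anyone finding this thread later, a sketch of what that rewrite could look like with stages = 3: instead of calling self.backbone as one block, index into it so each stage output gets a name (the 0/2/4 stage and 1/3 downsample layout follows the constructor above):

def forward(self, img):
    x = self.conv_embedding(img)
    x_stage1 = self.backbone[0](x)    # stage 1 transformer
    x = self.backbone[1](x_stage1)    # downsample between stages 1 and 2
    x_stage2 = self.backbone[2](x)    # stage 2 transformer
    x = self.backbone[3](x_stage2)    # downsample between stages 2 and 3
    x_stage3 = self.backbone[4](x)    # stage 3 transformer
    x = self.pool(x_stage3)
    out = self.mlp_head(x)
    distill = self.distill_head(x)
    if exists(distill):
        return out, distill
    return out, [x, x_stage1, x_stage2, x_stage3]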