I have a problem getting the output from nn. sequential as shown in the code below. I want to get the output of the three stages in the forward method as given below. These stages are called by the backbone in forward method. Each stage has one transformer layer. Precisely, I want to return x_stage1,x_stage2,x_stage3. Can you please help? @ptrblck
class VIT(nn.Module):
def __init__(
self,
*,
image_size,
num_classes,
dim,
depth,
heads,
mlp_mult,
stages = 3,
dim_key = 32,
dim_value = 64,
dropout = 0.,
num_distill_classes = None
):
super().__init__()
dims = cast_tuple(dim, stages)
depths = cast_tuple(depth, stages)
layer_heads = cast_tuple(heads, stages)
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
self.conv_embedding = nn.Sequential(
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
)
fmap_size = image_size // (2 ** 4)
layers = []
for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
is_last = ind == (stages - 1)
layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))
if not is_last:
next_dim = dims[ind + 1]
layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
fmap_size = ceil(fmap_size / 2)
self.backbone = nn.Sequential(*layers)
self.pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...')
)
self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
x = self.conv_embedding(img)
x = self.backbone(x)
x = self.pool(x)
out = self.mlp_head(x)
distill = self.distill_head(x)
if exists(distill):
return out, distill
return out, [x,x_stage1,x_stage2,x_stage3]