# Code:
import torch
import torchvision.models
from torchinfo import summary
import torch.nn as nn
class Identity(nn.Module):
    """Pass-through module that prints the shape of its input.

    Used below as a stand-in for parts of the model that are being removed,
    so that the shape of the tensor reaching that point is printed on every
    forward pass. It has no parameters and returns its input unchanged.
    """

    def __init__(self):
        # Must be the real ``__init__`` (with double underscores) and must call
        # nn.Module's constructor, otherwise submodule/parameter bookkeeping
        # is not set up.
        super().__init__()

    def forward(self, x):
        """Print ``x.shape`` and return ``x`` unchanged (same object)."""
        print(x.shape)
        return x
# Load ViT-H/14 with the "IMAGENET1K_SWAG_E2E_V1" weights.
# Plain ASCII quotes are required: typographic quotes are a SyntaxError.
model1 = torchvision.models.vit_h_14(weights="IMAGENET1K_SWAG_E2E_V1")

# Keep only the first KEEP_LAYERS transformer blocks; every later block is
# replaced by a parameter-free nn.Identity. The upper bound is taken from the
# model itself rather than hard-coded, so this works for any depth.
KEEP_LAYERS = 12
for index in range(KEEP_LAYERS, len(model1.encoder.layers)):
    model1.encoder.layers[index] = nn.Identity()

# Swap the encoder's final `ln` module and the `heads` module for the printing
# Identity so a forward pass prints the tensor shapes at those points.
# NOTE(review): per torchvision's VisionTransformer.forward, `heads` receives
# the class-token features, so the model should now return those (pre-LayerNorm)
# features instead of class logits — confirm against the installed version.
model1.encoder.ln = Identity()
model1.heads = Identity()
# VisionTransformer has no `head` or `fc` attribute (its classifier is
# `heads`); assigning those only registered unused submodules, so they are
# intentionally not set here.
print(model1)