I am wondering whether there is a way to extract a Swin-ViT backbone (feature extractor), similar to what can be done with a ResNet.
I am training a few self-supervised learning algorithms, and I need to take just the backbone (feature extractor) and pass it to the self-supervised learning algorithm (i.e. attach the SimCLR/SimSiam/DINO heads to the backbone).
With a ResNet this can be done quite easily:
import torch
from torchvision.models import resnet50

resnet = resnet50()
# Load pretrained weights: the function below just loads the weights using torch.load
resnet = load_model_weights(..., pretrained_weight_file, resnet, num_classes=51)
# Extract the backbone without the MLP head
resnet_bb = torch.nn.Sequential(*list(resnet.children())[:-1])
# Incorporate the backbone into the self-supervised model
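For clarity, this is a minimal sketch of how I use the extracted backbone (assuming torchvision's resnet50, whose truncated output is (b, 2048, 1, 1) and needs a flatten):

img_tensor = torch.randn(4, 3, 256, 256)    # dummy batch
feats = resnet_bb(img_tensor)               # (4, 2048, 1, 1)
feats = torch.flatten(feats, start_dim=1)   # (4, 2048)
# resnet_bb carries no classifier parameters, so every parameter it
# registers ends up being used when the loss is computed.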
There doesn't seem to be a simple solution like this with a Swin-ViT. I have been using the Swin Transformer from the RSP repository, as it provides weights pretrained on a large remote sensing dataset.
To my knowledge, you can use swin_vit.forward_features() to get the output of the Swin-ViT backbone:
import sys
sys.path.append("RSP/Scene Recognition/models")

from swin_transformer import SwinTransformer

swin_vit = SwinTransformer()
# Load pretrained weights: just using torch.load to load the model weights
swin_vit = load_model_weights(..., pretrained_weight_file, swin_vit, num_classes=51)
# At this point you can get the features using
out = swin_vit.forward_features(img_tensor)
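If I read the implementation correctly (assuming the RSP code mirrors the official Swin Transformer repo, where the classifier lives in swin_vit.head), forward_features() already returns the pooled feature vector, but the classification head's parameters are still registered on the model and never touched:

print(out.shape)   # (b, 768) for the default/tiny configuration
# The head is still attached even though forward_features() never calls it:
unused = sum(p.numel() for p in swin_vit.head.parameters())
print(unused)      # parameters that never receive a gradient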
However, I am wondering if there is a way to get just the forward_features part as a separate module that mimics the ResNet backbone (resnet_bb above).
The reason is that when I pass the full model to my self-supervised learning algorithm like this …
import pytorch_lightning as pl

class SimSiam(pl.LightningModule):
    def __init__(self, ..., backbone_model):
        super().__init__()
        self.backbone_model = backbone_model
        ...

    def forward(self, x):
        f = self.backbone_model.forward_features(x)  # (b,3,256,256) -> ... -> (b,768)
        z = self.projection_head(f)                  # (b,768) -> ... -> (b,2048)
        p = self.prediction_head(z)                  # (b,2048) -> (b,512) -> (b,2048)
        z = z.detach()  # SimSiam stops the gradient to prevent collapse
        return z, p
    ...
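To be concrete about how the backbone is passed in (the trainer/dataloader names below are just placeholders for my setup):

model = SimSiam(..., backbone_model=swin_vit)   # full SwinTransformer, head included
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
trainer.fit(model, train_dataloader)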
… then during training of the self-supervised algorithm (SimSiam), I get the following error. The reason is that the entire backbone model is passed to the SimSiam class, but only the model.forward_features() part is used by it (i.e. the rest of the Swin-ViT, such as the MLP head, is never used).
[rank0]: RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value strategy='ddp_find_unused_parameters_true' or by setting the flag in the strategy with strategy=DDPStrategy(find_unused_parameters=True).
This can be avoided in PyTorch Lightning by passing the strategy mentioned in the error. But with a ResNet you never get this issue, since you can pass just the backbone (minus the MLP head) instead of the entire model (with the MLP head).
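For reference, this is the workaround I mean, taken straight from the error message (the accelerator/devices arguments are just placeholders for my setup):

trainer = pl.Trainer(accelerator="gpu", devices=2,
                     strategy="ddp_find_unused_parameters_true")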
I might be incorrect, but I feel there could be issues when loading the weights once training is complete (the self-supervised model together with the trained backbone), since only part of the model's weights are being updated (if we use strategy='ddp_find_unused_parameters_true'). It might also be perfectly fine, but to avoid all of this I am wondering if there is a way to pass just the forward_features part (i.e. the Swin-ViT backbone), similar to how it's done with ResNet models.
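What I have in mind is something like the thin wrapper below, which would register only the backbone parameters, but I am not sure whether this is the recommended approach or whether there is a cleaner built-in way. This is just a sketch, and it assumes the classification head lives in swin_vit.head (as in the official Swin implementation):

class SwinBackbone(torch.nn.Module):
    # Mimics resnet_bb = nn.Sequential(*list(resnet.children())[:-1])
    def __init__(self, swin_model):
        super().__init__()
        # Drop the classification head so its (otherwise unused) parameters
        # are never registered -- assumes the head attribute is named "head".
        swin_model.head = torch.nn.Identity()
        self.swin_model = swin_model

    def forward(self, x):
        return self.swin_model.forward_features(x)   # (b,3,256,256) -> (b,768)

swin_bb = SwinBackbone(swin_vit)
# swin_bb(img_tensor) would then behave like resnet_bb(img_tensor) (already
# flattened), so SimSiam.forward could call self.backbone_model(x) directly.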