hello everyone.
I have a convlstm model that I have trained. So after that, I am trying to use that pretrained model to extract some features in order to do a feature concatenation.
the size of the input is (64, 16, 3. 224. 224).
what I did is before passing the input data to the pretrained convlstm model, I resize the input data like this:
b, d, c, h, w = rgb_data.shape
rgb_data = rgb_data.view(b*d, c, h, w)
but I have encountered dimension error
Traceback (most recent call last):
File "plot.py", line 67, in <module>
listspafeatures, listlabels = get_Features(spatialmodel, train_loader)
File "plot.py", line 56, in get_Features
spatiallsttfeat, listlabels = get_output(spatialModel, dataloader)
File "plot.py", line 42, in get_output
rgb = spatialModel(rgb_data.to(device))
File "/home/coco/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/coco/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/home/coco/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/coco/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 579, in forward
self.check_forward_args(input, hx, batch_sizes)
File "/home/coco/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 530, in check_forward_args
self.check_input(input, batch_sizes)
File "/home/coco/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 174, in check_input
raise RuntimeError(
RuntimeError: input must have 3 dimensions, got 2
below the entire code:
the Convlstm model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
# this encoder only supports resnet
class Encoder(nn.Module):
def __init__(self, backbone_name:str):
super(Encoder, self).__init__()
# select a model
if backbone_name == "resnet18":
resnet = resnet18(pretrained=True)
elif backbone_name == "resnet34":
resnet = resnet34(pretrained=True)
elif backbone_name == "resnet50":
resnet = resnet50(pretrained=True)
elif backbone_name == "resnet101":
resnet = resnet101(pretrained=True)
elif backbone_name == "resnet152":
resnet = resnet152(pretrained=True)
else:
assert False, f"'{backbone_name}' backbone is not supported"
self.out_features = resnet.fc.in_features
# remove a fully connected layer
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
# freeze all updatable weights of the encoder
self._freeze_all(self.encoder)
def _freeze_all(self, model:nn.Module):
for param in model.parameters():
param.requires_grad = False
def forward(self, x):
x = self.encoder(x).squeeze()
return x
# this convlstm only supports lstm
class ConvLSTM(nn.Module):
def __init__(self, backbone_name:str, num_classes:int, hidden_size:int = 1024, num_layers:int = 1, bidirectional:bool = True, drop_val = 0.5):
super(ConvLSTM, self).__init__()
# freeze
self.encoder = Encoder(backbone_name)
# updateable
self.lstm = nn.LSTM(self.encoder.out_features, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
self.classifier = nn.Sequential(
nn.Linear(2 * hidden_size if bidirectional else hidden_size, hidden_size),
#nn.BatchNorm1d(hidden_size),
nn.ReLU(),
#nn.Dropout(drop_val),
nn.Linear(hidden_size, num_classes),
)
self.dropout = nn.Dropout(drop_val)
def forward(self, x):
# get shape
b, d, c, h, w = x.shape
# get (spatial)feature of frames
x = x.view(b * d, c, h, w)
x = self.encoder(x)
# get (temporal)feature of frames
x = x.view(b, d, -1)
x = self.dropout(x)
x = self.lstm(x)[0][:, -1]
# get classifier scores
x = self.classifier(x)
return x
the module for the feature extraction
import torch
import os
import argparse
import torch.nn as nn
import numpy as np
import torchvision
from torch.utils.data import DataLoader
from VideoDataset import VideoDataset
import os
import csv
import json
from torch.backends import cudnn
parser = argparse.ArgumentParser(description="Toyota Smart Home spatial stream on ConvLSTM")
parser.add_argument("--frames-path", default="./Data/mp4_frames/", type=str)
parser.add_argument("--csv-path", default="./Data/Labels/cross_subject/", type=str)
parser.add_argument("--cross-view", action="store_true")
parser.add_argument("--frame-size", default=224, type=int)
parser.add_argument("--sequence-length", default=16, type=int)
parser.add_argument("--epochs", default=200, type=int)
parser.add_argument("--batch-size", default=64, type=int)
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--gpu-number", default=0, type=int)
args = parser.parse_args()
# get a device
if torch.cuda.is_available():
cudnn.benchmark = True
device = torch.device(f"cuda:{args.gpu_number}")
else:
device = torch.device(f"cpu")
def get_output(spatialModel:nn.Module, dataloader):
listrgbfeature = []; listlabels = []
for i, (rgb_data, labels) in enumerate(dataloader):
# penses a faire le control sur le feature dimension
b, d, c, h, w = rgb_data.shape
rgb_data = rgb_data.view(b*d, c, h, w)
rgb = spatialModel(rgb_data.to(device))
listrgbfeature.append(rgb)
listlabels.append(labels)
return listrgbfeature, listlabels
def get_Features(spatialModel:nn.Module, dataloader):
spatialModel = nn.Sequential(*list(spatialModel.children())[:-1])
for param in spatialModel.parameters():
param.requires_grad = False
spatiallsttfeat, listlabels = get_output(spatialModel, dataloader)
return spatiallsttfeat, listlabels
train = VideoDataset(frames_path=args.frames_path, csv_path=args.csv_path + "train.csv", frame_size=args.frame_size, sequence_length=args.sequence_length)
val = VideoDataset(frames_path=args.frames_path, csv_path=os.path.join(args.csv_path, "val.csv" if args.cross_view else "test.csv"), frame_size=args.frame_size, sequence_length=args.sequence_length)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val, batch_size=args.batch_size, shuffle=False, num_workers=0)
spatialmodel = torch.load("model_save/spatial_model.pth")
listspafeatures, listlabels = get_Features(spatialmodel, train_loader)