Hi, I am a beginner in deep learning. I am trying to design a patch-based network that uses a pretrained EfficientNet as the feature extractor.
Currently, I divide each image into 35 patches and stack them into a single tensor to feed to the model.
The input shape is [35, 3, 300, 300]. (Someone suggested using batch_size > 1 in the DataLoader, but that didn't work for me.)
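With batch_size=B, the DataLoader delivers each batch as [B, 35, 3, H, W], and a 2-D convolution cannot take a 5-D tensor directly. One common approach, sketched below, is to fold the patch dimension into the batch dimension before the backbone and split it back out afterwards. The small `nn.Conv2d` stands in for EfficientNet purely so the sketch runs (an assumption); the helper name `batched_patch_features` is hypothetical, and with EfficientNet-b3 the channel count C would be 1536.

```python
import torch
import torch.nn as nn

# Stand-in backbone (assumption): any module mapping [N, 3, H, W] -> [N, C, h, w].
# With EfficientNet-b3's extract_features, C would be 1536.
backbone = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)


def batched_patch_features(x, backbone):
    """x: [B, P, 3, H, W], e.g. a DataLoader batch of B images with P patches each."""
    b, p, c, h, w = x.shape
    x = x.reshape(b * p, c, h, w)          # fold patches into the batch dim
    f = backbone(x)                        # [B*P, C, h', w']
    f = f.mean(dim=(2, 3))                 # global average pool -> [B*P, C]
    return f.reshape(b, p * f.shape[1])    # one row per image -> [B, P*C]


x = torch.randn(4, 35, 3, 32, 32)          # batch_size=4, 35 small patches each
feats = batched_patch_features(x, backbone)
print(feats.shape)  # torch.Size([4, 280])
```

Because the patches of one image stay contiguous after the first reshape, the final reshape gives one row of 35*C features per image, which a `nn.Linear(35 * C, ...)` layer maps to [B, num_classes], matching labels of shape [B].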
Here is the main part of my Dataset class:
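```python
import torch
from torchvision import transforms

# inside __getitem__: `patches` holds the 35 patches of one image
b = torch.zeros(len(patches), 3, 300, 300)
for i in range(len(patches)):
    b[i] = transforms.ToTensor()(patches[i])  # each patch -> [3, 300, 300]
return b, label
```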
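The post doesn't say how the 35 patches are produced, so here is one possible way, as a sketch: cut an image tensor into a 5x7 grid with `Tensor.unfold` and get the stacked [35, 3, h, w] tensor directly. The helper name `split_into_patches` and the 5x7 grid are assumptions. Separately, if the patches are already converted to tensors, `torch.stack([transforms.ToTensor()(p) for p in patches])` produces the same result as the preallocate-and-fill loop above.

```python
import torch


def split_into_patches(img, rows, cols):
    """Hypothetical helper: cut a [3, H, W] tensor into rows*cols equal patches.

    Assumes H is divisible by rows and W by cols. Returns [rows*cols, 3, H//rows, W//cols],
    ordered row by row (left to right, top to bottom).
    """
    c, h, w = img.shape
    ph, pw = h // rows, w // cols
    # unfold(dim, size, step) extracts non-overlapping windows along one dim
    patches = img.unfold(1, ph, ph).unfold(2, pw, pw)  # [3, rows, cols, ph, pw]
    patches = patches.permute(1, 2, 0, 3, 4)           # [rows, cols, 3, ph, pw]
    return patches.reshape(rows * cols, c, ph, pw)


image = torch.rand(3, 50, 70)      # small hypothetical image for illustration
stacked = split_into_patches(image, rows=5, cols=7)
print(stacked.shape)  # torch.Size([35, 3, 10, 10])
```

For 300x300 patches in a 5x7 grid, the source image would need to be 1500x2100 pixels.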
The main part of my network is as follows:
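```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet

class Net(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.model = EfficientNet.from_pretrained('efficientnet-b3')
        # ...
        self.fc1 = nn.Linear(35 * 1536, 1000, bias=True)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(1000, num_classes, bias=True)

    def forward(self, x):
        x = self.model.extract_features(x)   # [35, 1536, h, w]
        x = x * torch.sigmoid(x)              # swish activation
        x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)  # [35, 1536]
        x = x.view(-1)                        # flatten to [35 * 1536]
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
```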
The output of extract_features has shape [35, 1536, 7, 7], so I pool it, flatten the result, and feed it into a fully connected layer to get the prediction.
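A minimal sketch of the pool, flatten, and fully-connected pipeline described above, with the layers registered in `__init__` and a leading batch dimension kept, so the output has shape [1, num_classes] instead of the [num_classes] that `view(-1)` leads to. The constructor arguments `backbone`, `feat_dim` and `num_patches` are hypothetical, and a tiny `nn.Conv2d` stands in for EfficientNet so the sketch runs offline; with efficientnet_pytorch one would call `self.model.extract_features(x)` where `self.backbone(x)` appears and use `feat_dim=1536` for b3.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, backbone, feat_dim, num_patches=35, num_classes=10):
        super().__init__()
        # Assumption: backbone maps [P, 3, H, W] -> [P, feat_dim, h, w].
        self.backbone = backbone
        # Layers are created once here, so their weights are registered and trained.
        self.classifier = nn.Sequential(
            nn.Linear(num_patches * feat_dim, 1000),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1000, num_classes),
        )

    def forward(self, x):
        # x: [P, 3, H, W], all patches of ONE image
        x = self.backbone(x)                          # [P, feat_dim, h, w]
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)    # [P, feat_dim]
        x = x.reshape(1, -1)                          # keep a batch dim: [1, P*feat_dim]
        return self.classifier(x)                     # [1, num_classes]


# Tiny stand-in backbone so the sketch runs without downloading weights
backbone = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
net = Net(backbone, feat_dim=8, num_patches=35, num_classes=4)
out = net(torch.randn(35, 3, 32, 32))
loss = nn.CrossEntropyLoss()(out, torch.tensor([2]))  # label shape [1], as from batch_size=1
print(out.shape)  # torch.Size([1, 4])
```

Keeping the batch dimension is what lets `CrossEntropyLoss` find the class scores along dimension 1.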
But when I call CrossEntropyLoss on net(input) and label, I get the error from the title. Here is the full traceback:
```
Traceback (most recent call last):
  File "models.py", line 114, in <module>
    train_loss_cur, train_acc_cur = train_fn(model, loader=train_loader)
  File "models.py", line 66, in train_fn
    loss = loss_fn(predictions, labels)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 932, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 2317, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 1535, in log_softmax
    ret = input.log_softmax(dim)
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
```
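The last frame shows what goes wrong: `cross_entropy` calls `log_softmax(input, 1)`, which needs the class scores along dimension 1, but the tensor it receives has only dimension 0 (hence the allowed range [-1, 0]). After the `view(-1)` call in forward, every following `nn.Linear` works on a 1-D tensor, so the network returns shape [num_classes] rather than [batch, num_classes]. The sketch below uses random logits in place of the network output (num_classes=5 and the label value are arbitrary) to show the shapes `nn.CrossEntropyLoss` expects: input [N, C] and target [N].

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
num_classes = 5

# Stand-in for the original network output: a 1-D tensor of class scores.
logits_1d = torch.randn(num_classes)
# A DataLoader with batch_size=1 yields the label as a tensor of shape [1].
label = torch.tensor([3])

# CrossEntropyLoss expects input [N, C] and target [N]; adding the batch
# dimension back makes the shapes line up.
logits_2d = logits_1d.unsqueeze(0)      # [1, num_classes]
loss = loss_fn(logits_2d, label)
print(logits_2d.shape, loss.shape)      # torch.Size([1, 5]) torch.Size([])
```

Newer PyTorch releases also accept an unbatched [C] input paired with a 0-d target, so the exact error can differ between versions; an [N, C] input with an [N] target works in either case.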
I have searched a lot but can't figure out how to fix it. Could someone kindly tell me where I went wrong and how to fix it? If more details are needed, I will post them ASAP.
Thanks a lot in advance, QAQ