Hi, I am a beginner at DL. I am trying to design a patch-based net that uses a pretrained EfficientNet as the feature extractor.
Currently, I divide each image into 35 patches and concatenate them to feed the model.
The input shape is [35, 3, 300, 300]. (I tried batch_size > 1 in the DataLoader as someone suggested, but it didn't work for me.)
The main part of my Dataset class:
    b = torch.zeros(len(patches), 3, 300, 300)
    for i in range(len(patches)):
        b[i] = transforms.ToTensor()(patches[i])
    return b, label
The main part of my net is as follows:
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = EfficientNet.from_pretrained('efficientnet-b3')
            ...

        def forward(self, x):
            x = self.model.extract_features(x)
            x = x * torch.sigmoid(x)
            x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
            x = x.view(-1)
            x = nn.Linear(35*1536, 1000, bias=True)
            x = nn.ReLU()
            x = nn.Dropout(p=0.5)
            x = nn.Linear(1000, num_classes, bias=True)
            return x
The output shape of extract_features is [35, 1536, 7, 7], so I flatten the tensor and feed it into an FC layer to get the result.
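The flattening step described above could be sketched as follows. This is a minimal sketch, not the original model: `PatchHead` and `num_classes = 4` are hypothetical, a random tensor stands in for the `extract_features` output, and the layers are created once in `__init__` rather than inside `forward`. It reshapes to `[1, 35 * 1536]` so the output keeps a batch dimension, which `cross_entropy` reads as one sample with `num_classes` scores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_PATCHES = 35      # patches per image, from the post
FEAT_CHANNELS = 1536  # channels of the [35, 1536, 7, 7] feature map, from the post


class PatchHead(nn.Module):
    """Hypothetical classifier head; layers are built once in __init__."""

    def __init__(self, num_classes):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(NUM_PATCHES * FEAT_CHANNELS, 1000),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1000, num_classes),
        )

    def forward(self, feats):
        # feats: [35, 1536, 7, 7] -> pooled: [35, 1536]
        pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)
        # Keep a batch dimension of 1: [1, 35 * 1536] instead of [35 * 1536]
        flat = pooled.reshape(1, -1)
        return self.classifier(flat)  # [1, num_classes]


num_classes = 4  # assumed value, not from the post
head = PatchHead(num_classes).eval()

# Random stand-in for self.model.extract_features(x)
feats = torch.randn(NUM_PATCHES, FEAT_CHANNELS, 7, 7)
logits = head(feats)

label = torch.tensor([2])  # one label for the whole image, shape [1]
loss = F.cross_entropy(logits, label)
print(logits.shape, loss.dim())
```

With a 2-D output of shape `[1, num_classes]` and a label of shape `[1]`, the loss comes out as a scalar tensor.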
When I call CrossEntropy(net(input), label), I get the error in the title. The full traceback is below:
Traceback (most recent call last):
  File "models.py", line 114, in <module>
    train_loss_cur, train_acc_cur = train_fn(model, loader=train_loader)
  File "models.py", line 66, in train_fn
    loss = loss_fn(predictions, labels)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 932, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 2317, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/harold/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 1535, in log_softmax
    ret = input.log_softmax(dim)
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
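The last frame of the traceback can be reproduced in isolation, which may help narrow things down. This is a minimal sketch, assuming the tensor that reaches the loss is 1-D (which is what `view(-1)` would produce); `num_classes = 4` and the random logits are made up for illustration.

```python
import torch
import torch.nn.functional as F

num_classes = 4  # assumed value, not from the post

# 1-D tensor of shape [num_classes], as a flattened output would be
flat_logits = torch.randn(num_classes)

try:
    # Same call as in the traceback: softmax over dim 1, which a 1-D tensor lacks
    F.log_softmax(flat_logits, dim=1)
    raised = False
except IndexError as err:
    raised = True
    print(err)

# Adding a leading batch dimension gives shape [1, num_classes]
batched_logits = flat_logits.unsqueeze(0)
label = torch.tensor([1])  # shape [1], one class index
loss = F.cross_entropy(batched_logits, label)
print(batched_logits.shape, loss.dim())
```

If the same thing is happening in the real model, keeping the output 2-D (for example `x.view(1, -1)` before the linear layers) would be one way to avoid it.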
I searched a lot, but I just can't figure out how to fix it. Could someone be kind enough to tell me where I went wrong and how to fix it? If more detailed information is needed, I will post it ASAP.
Thanks a lot in advance, QAQ