from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(64 * 3 * 8 * 8, 1000)
self.fc2 = nn.Linear(1000, 200)
self.fc3 = nn.Linear(200, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
data = data.view(-1, args.batch_size*3*8*8)
target = target.view(-1, args.batch_size)
output = model(data)
loss = F.nll_loss(output, target[0])
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
data = data.view(-1, args.test_batch_size*3*8*8)
target = target.view(-1, args.test_batch_size)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target[0], reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='FNN')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument(
'--test-batch-size', type=int, default=64, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--log-interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=True,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# Transforms
simple_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
)
# Dataset
train_dataset = datasets.ImageFolder('data/train/', simple_transform)
valid_dataset = datasets.ImageFolder('data/valid/', simple_transform)
# Data loader
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=2)
test_loader = torch.utils.data.DataLoader(
dataset=valid_dataset, batch_size=args.test_batch_size, shuffle=False,
num_workers=2)
model = Net().to(device)
optimizer = optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(), "fnn.pt")
if __name__ == '__main__':
main()
The view
operations of data
and target
look strange.
Currently you are using
data = data.view(-1, args.test_batch_size*3*8*8)
target = target.view(-1, args.test_batch_size)
Generally and also based on your model code, you should provide the data as [batch_size, in_features]
and the target as [batch_size]
containing class indices.
Could you change that and try to run your code again?
PS: I’ve formatted your code for better readability. You can add code snippets using three backticks ```
Thanks for your quick response. Here I mention what exactly I want to do to help as much as possible. I have in data
the folderstrain
and valid
and within them 9 folders that you would be the classes. They contain RGB images of 8x8 pixels. The idea is to use an MLP to classify future 8x8 images in one of these 9 classes.
I am using batch_size = 8. Here I made some changes according to your comments. You will tell me the necessary changes to make to meet the objective.
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(3 * 8 * 8, 1000)
self.fc2 = nn.Linear(1000, 200)
self.fc3 = nn.Linear(200, 9)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
data = data.view(args.batch_size, 3 * 8 * 8)
target = target.view(args.batch_size, -1)
output = model(data)
loss = F.nll_loss(output, target[0])
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * args.batch_size, len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
data = data.view(args.test_batch_size, 3 * 8 * 8)
target = target.view(args.test_batch_size, -1)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target[0], reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='FNN')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument(
'--test-batch-size', type=int, default=8, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--log-interval', type=int, default=100, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=True,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# Transforms
simple_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
)
# Dataset
train_dataset = datasets.ImageFolder('data/train/', simple_transform)
valid_dataset = datasets.ImageFolder('data/valid/', simple_transform)
# Data loader
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=2)
test_loader = torch.utils.data.DataLoader(
dataset=valid_dataset, batch_size=args.test_batch_size, shuffle=False,
num_workers=2)
model = Net().to(device)
optimizer = optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(), "fnn.pt")
if __name__ == '__main__':
main()
I think target
should be viewed as target = target.view(args.batch_size)
.
Did you try to run the code and checked for other errors?
Perfect.
Making target = target.view (args.batch_size)
resolves the problem.
Finally the code was as follows:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(3 * 8 * 8, 1000)
self.fc2 = nn.Linear(1000, 200)
self.fc3 = nn.Linear(200, 9)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
data = data.view(args.batch_size, 3 * 8 * 8)
target = target.view(args.batch_size)
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * args.batch_size, len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
data = data.view(args.test_batch_size, 3 * 8 * 8)
target = target.view(args.test_batch_size)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='FNN')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument(
'--test-batch-size', type=int, default=8, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--log-interval', type=int, default=100, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=True,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# Transforms
simple_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
)
# Dataset
train_dataset = datasets.ImageFolder('data/train/', simple_transform)
valid_dataset = datasets.ImageFolder('data/valid/', simple_transform)
# Data loader
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=2)
test_loader = torch.utils.data.DataLoader(
dataset=valid_dataset, batch_size=args.test_batch_size, shuffle=False,
num_workers=2)
model = Net().to(device)
optimizer = optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(), "fnn.pt")
if __name__ == '__main__':
main()
Again thanks.
Hi @ptrblck ,
Please could you explain in brief why these error occurs ?
Also what are the different case scenario these cases occurs , in simple codes?
Thankyou for answer in advance
I had got same error but got no idea where could it occur
Summary
model = Classifier()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)
epochs = 30
steps = 0
train_losses, test_losses = [], []
for e in range(epochs):
running_loss = 0
for images, labels in trainloader:
optimizer.zero_grad()
log_ps = model(images)
loss = criterion(log_ps, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
else: #after each loop complete
test_loss = 0
train_loss = 0
## TODO: Implement the validation pass and print out the validation accuracy
with torch.no_grad():
for images, labes in testloader:
log_ps = model(images)
test_loss += criterion(log_ps, labels) #update test loss
ps = torch.exp(log_ps) #actual probs
top_p, top_class = ps.topk(1, dim=1) #gives actual predicted classes
equals = top_class == labels.view(*top_class.shape) #whether predicted class match with actual class
accuracy += torch.mean(equals.type(torch.FloatTensor))
train_losses.append(running_loss/len(trainloader))
test_losses.append(test_loss/len(testloader))
This error is raised, if the batch sizes of the model output and target do not match.
E.g. this small code example shows the issue, where the data uses a batch size of 5, while the target uses one of 4:
nb_classes = 10
output = torch.randn(5, nb_classes, requires_grad=True) # [batch_size, nb_classes]
target = torch.randint(0, nb_classes, (4,)) # [batch_size]
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
> ValueError: Expected input batch_size (5) to match target batch_size (4).
Check the shapes of log_ps
and labels
in your code and make sure they have the same batch size.
oh my stupid fault , I was having typo of labes instead labels and was working previously written labels variable in criterion. Thank you I appretiate. .
Hello, am having a similar error ,
Am doing Phase retrieval in digital holography but the error below is messing me up. any help will be greatly appreciated.
Here is CN network, and am using a batch size of 2.
class Network(nn.Module):
def init(self, in_channels=2 ):
super(Network,self).__init__()
self.Conv1 =nn.Sequential(
nn.Conv2d(in_channels,16, 3,1,1),nn.ReLU(),
nn.MaxPool2d(3,1,1),
nn.Conv2d(16,2, 3),nn.ReLU())
self.Conv2 = nn.Sequential(
nn.Conv2d(in_channels, 16,3,1,1),nn.ReLU(),
nn.MaxPool2d(3,1,1),
nn.Conv2d(16, 4, 3),nn.ReLU())
self.Conv3 = nn.Sequential(
nn.Conv2d(in_channels, 16,3,1,1),nn.ReLU(),
nn.MaxPool2d(3,1,1),
nn.Conv2d(16, 8, 3),nn.ReLU())
self.Conv4 =nn.Sequential(
nn.Conv2d(in_channels, 16, 3,1,1),nn.ReLU(),
nn.MaxPool2d(3,1,1),
nn.Conv2d(16,16, 3),nn.ReLU())
self.Upsample1 = nn.Sequential(
nn.Upsample(size=(256, 256), scale_factor=None,align_corners=True, mode="bilinear"),
nn.Conv2d(2, 4, 3),nn.ReLU(),
nn.Upsample(size=(256, 256), scale_factor=None,align_corners=True, mode='bilinear'),
nn.Conv2d(4, 8, 3),nn.ReLU(),
nn.Upsample(size=(256, 256), scale_factor=None,align_corners=True, mode='bilinear'),
nn.Conv2d(8, 16, 3),nn.ReLU())
self.Upsample2 = nn.Sequential(
nn.Upsample(size=(256, 256), scale_factor=None,align_corners=True, mode='bilinear'),
nn.Conv2d(4, 8, 3),nn.ReLU(),
nn.Upsample(size=(256, 256), scale_factor=None,align_corners=True, mode='bilinear'),
nn.Conv2d(8, 16, 3),nn.ReLU())
self.Upsample3 = nn.Sequential(
nn.Upsample(size=(256, 256), scale_factor=None,align_corners=True, mode='bilinear'),
nn.Conv2d(8, 16, 3),nn.ReLU())
self.Conv1a = nn.Conv2d(16, 16, 3)
self.Conv2a = nn.Conv2d(16, 16, 3)
self.Conv3a = nn.Conv2d(16, 16, 3)
self.Conv4a = nn.Conv2d(16, 16, 3)
self.Conv5 = nn.Conv2d(16, 64, 3)
def forward(self,x):
# print(x.size)
w = self.Conv1(x)
w = self.Upsample1(w)
w= F.relu(self.Conv1a(w))
k = self.Conv2(x)
k = self.Upsample2(k)
k= F.relu(self.Conv2a(k))
y = self.Conv3(x)
y = self.Upsample3(y)
y= F.relu(self.Conv3a(y))
z = self.Conv4(x)
z= F.relu(self.Conv4a(z))
print('*****',w.shape,k.shape,y.shape,z.shape)
p = torch.cat([w, k, y, z], 0)
print(p.shape)
q = self.Conv5(p)
print(q.shape)
return p
net= Network()
Below is the error am getting…
***** torch.Size([1, 16, 252, 252]) torch.Size([1, 16, 252, 252]) torch.Size([1, 16, 252, 252]) torch.Size([1, 16, 252, 252])
torch.Size([4, 16, 252, 252])
torch.Size([4, 64, 250, 250])
output_shape torch.Size([4, 16, 252, 252])
Traceback (most recent call last):
File “E:\DMU\pytnewCNN.ipy”, line 201, in
loss =criterion(outputs, labels)
File “C:\Users\Hp\anaconda3\lib\site-packages\torch\nn\modules\module.py”, line 550, in call
result = self.forward(*input, **kwargs)
File “C:\Users\Hp\anaconda3\lib\site-packages\torch\nn\modules\loss.py”, line 932, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File “C:\Users\Hp\anaconda3\lib\site-packages\torch\nn\functional.py”, line 2317, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File “C:\Users\Hp\anaconda3\lib\site-packages\torch\nn\functional.py”, line 2113, in nll_loss
.format(input.size(0), target.size(0)))
ValueError: Expected input batch_size (4) to match target batch_size (2).
I’m not familiar with your model, but this line of code:
p = torch.cat([w, k, y, z], 0)
would concatenate the different activations in the batch dimension, which is most likely wrong.
If you would like to concatenate the activations instead in the channel dimension, you would have to use dim=1
.
Currently this line is changing the batch size, which is probably creating the error.
I’m encountering a similar error.
It trains for almost an entire epoch and fails towards in last few steps.
ValueError: Expected input batch_size (1) to match target batch_size (0).
I think there is something wrong with the way I’m setting up the dimensions of the linear layers (specifically linear1) but shouldn’t the model throw an error right at the beginning of training?
The inputs to the MultiheadAttention
are enocder and decoder outputs both of dims [2, 1024, 768].
Output of MultiheadAttention
is also [2, 1024, 768]
I use self.linear1
as self.linear1(multi_attn_output.reshape(-1, 1024*768))
The output of self.linear1
also looks alright : [2, 512]
alpha = 0.5
class CrossAttentionSummarizer(pl.LightningModule):
def __init__(self):
super(CrossAttentionSummarizer, self).__init__()
self.summarizer_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
self.qa_encoder = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
self.multihead_attn = nn.MultiheadAttention(embed_dim=768, num_heads=4, batch_first=True)
self.linear1 = nn.Linear(1024*768, 512)
self.linear2 = nn.Linear(512, 2, bias=False)
self.ce_loss = nn.CrossEntropyLoss()
def forward(self, question_passage_input_ids, question_passage_attention_mask, question_labels, input_ids, attention_mask, decoder_attention_mask, labels=None):
summarizer_output = self.summarizer_model(
input_ids,
attention_mask=attention_mask,
labels=labels,
decoder_attention_mask=decoder_attention_mask
)
qa_output = self.qa_encoder(
question_passage_input_ids,
question_passage_attention_mask,
question_labels
)
decoder_output = summarizer_output[3]
encoder_output = qa_output[2]
multi_attn_output, multi_attn_output_weights = self.multihead_attn(decoder_output, encoder_output, encoder_output)
lin_output = self.linear1(multi_attn_output.reshape(-1, 1024*768))
cls_outputs = self.linear2(lin_output)
cls_outputs = nn.functional.softmax(cls_outputs, dim=1)
cls_preds = torch.argmax(cls_outputs, dim=1)
cls_pred_loss = self.ce_loss(cls_outputs, question_labels.type(torch.int64).squeeze())
return summarizer_output.loss, summarizer_output.logits, cls_pred_loss, cls_preds
Check if you are using squeeze()
somewhere without specifying the dim
argument, as it seems you are accidentally dropping the batch dimension with a size of 1.
I use squeeze()
in
cls_pred_loss = self.ce_loss(cls_outputs, question_labels.type(torch.int64).squeeze())
question_labels
are of type int, 0 or 1. And batched by the DataLoader.
Removing that causes the dimension of the classification labels to botch up…
RuntimeError: 0D or 1D target tensor expected, multi-target not supported
This is a binary classification. I’m using 2 output heads and applying torch.argmax
for preds and CrossEntropyLoss()
for the loss
You shouldn’t remove the squeeze
operation, but should be explicit in which dimension should be removed.
E.g. if you want to drop dim1
use:
# batch_size 2
x = torch.randn(2, 1)
y = x.squeeze(dim=1)
print(y.shape)
# torch.Size([2])
# batch_size 1
x = torch.randn(1, 1)
y = x.squeeze(dim=1)
print(y.shape)
# torch.Size([1])
# batch_size 1
x = torch.randn(1, 1)
y = x.squeeze()
print(y.shape)
# torch.Size([]) !!!!
Trying it
I have one other doubt …
In :
cls_pred_loss = self.ce_loss(cls_outputs, question_labels.type(torch.int64).squeeze(dim=1))
the dimension of cls_outputs
is [2,2] (batch_first=True) and that of question_labels
is [2,1]
So, in CrossEntropyLoss()
I’m using the outputs of the 2 logits cls_output
and a class label 0/1.
I’m trying to conform to the definition of CELoss() as per the documentation:
input has to be a 2D Tensor of size (minibatch, C).
This criterion expects a class index (0 to C-1) as the target for each value of a 1D tensor of size minibatch
I hope this implementation is correct (?)
And finally I use dim
in the prediction and loss layer too…
cls_outputs = nn.functional.softmax(cls_outputs, dim=1)
cls_preds = torch.argmax(cls_outputs, dim=1)