import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.models import resnet18

bs = 4  # batch size
model = resnet18(pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
input = torch.rand(bs, 3, 256, 256)
target = torch.rand(bs, 1000)
# target = target.long()
# before we learn
p1 = model.fc.weight
od1 = model.state_dict()
output = model(input)
loss_fn = torch.nn.CrossEntropyLoss()
# loss_fn = torch.nn.BCEWithLogitsLoss()
loss = loss_fn(output, target)
if optimizer is not None:
    # backward pass + optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
p2 = model.fc.weight
od2 = model.state_dict()
torch.nn.CrossEntropyLoss() should be the main loss function (criterion) for ImageNet classification. When I tried to run the code I got an error that I could not fix quickly.
What I expect at the end is that print(torch.equal(p1, p2)) will return False, meaning the parameters before and after the single optimization step are different (the model learned something).
If I use torch.nn.BCEWithLogitsLoss() instead, the code runs, but p1 and p2 end up equal.
I assume that by “got this error” you mean the error message
in your post title, “Expected object of scalar type Long but got
scalar type Float.” This was to be expected.
CrossEntropyLoss takes integer class labels for its target,
specifically of type LongTensor. In your code you are passing
it a FloatTensor.
For random “toy” data, you probably want something like target = torch.randint(nClass, (bs,)), where nClass is the number of classes in your classification problem.
(I see that you commented out the line target = target.long(). Didn’t that fix this particular error for you?)
(As a further note, BCEWithLogitsLoss() does take a FloatTensor target, so it would not give you this error. But,
yes, for a multiclass – i.e., not binary – classification problem
you would want CrossEntropyLoss.)
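To make that concrete, here is a minimal sketch of the correct target for CrossEntropyLoss (nClass = 1000 and bs = 4 are just example values, not from your model):

```python
import torch

bs, nClass = 4, 1000  # example batch size and number of classes

logits = torch.randn(bs, nClass)        # raw (unnormalized) model outputs
target = torch.randint(nClass, (bs,))   # integer class indices, dtype torch.int64 (LongTensor)

loss = torch.nn.CrossEntropyLoss()(logits, target)
print(target.dtype)   # torch.int64
print(loss.dim())     # 0 -- a scalar loss
```

Note that the target has shape (bs,), not (bs, nClass) -- one class index per sample, not one value per class.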
p1 is a reference to model.fc.weight. (In Python, essentially
all variables are references to objects.) So when your optimizer
updates model.fc.weight, p1 also “changes,” in that the object p1 refers to has changed.
Try setting p1 = model.fc.weight.clone(). Now p1 will refer
to a new copy of the data in model.fc.weight that won’t be
changed when model.fc.weight itself is changed.
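Here is a quick illustration of the reference-vs-clone distinction (using a toy Linear layer as a stand-in for model.fc so it runs instantly):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)   # toy stand-in for model.fc
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

p_ref = model.weight            # reference: tracks future updates
p_copy = model.weight.clone()   # snapshot: frozen at this point

loss = model(torch.randn(8, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(p_ref, model.weight))   # True  -- p_ref is the same tensor object
print(torch.equal(p_copy, model.weight))  # False -- the weights moved, the clone did not
```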
Or you could print out (with enough precision to see a small
change) a couple of elements of model.fc.weight before and
after the optimization step and note that they change.
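A sketch of that printing approach (again with a toy layer; the exact values will vary with your seed and data):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)   # toy stand-in for model.fc
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

w = model.weight
print(f"before: {w[0, 0].item():.8f}  {w[0, 1].item():.8f}")

loss = model(torch.randn(8, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# eight decimal places is enough precision to see even a small update
print(f"after:  {w[0, 0].item():.8f}  {w[0, 1].item():.8f}")
```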