I'm training a multi-task model (a shared ResNet-18 trunk feeding two classification heads, each with its own optimizer), and my train_step fails with:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 2]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
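As I understand the message, autograd saved some tensor for the backward pass and that tensor was then modified in place before backward() actually ran. A tiny standalone sketch like the one below (toy layers and made-up names, nothing to do with my real model) raises the same kind of error, because the second backward() still needs the head weights that the first optimizer has already updated in place:

import torch
import torch.nn as nn

# 'trunk' stands in for shared layers, 'head' for one task head (hypothetical toy model)
trunk = nn.Linear(4, 3)
head = nn.Linear(3, 2)
opt_head = torch.optim.SGD(head.parameters(), lr=0.1)
opt_trunk = torch.optim.SGD(trunk.parameters(), lr=0.1)

x = torch.randn(8, 4)
out = head(trunk(x))
loss_head = out.pow(2).mean()
loss_joint = out.abs().mean()

opt_head.zero_grad()
loss_head.backward(retain_graph=True)
opt_head.step()        # in-place update of head.weight / head.bias

opt_trunk.zero_grad()
loss_joint.backward()  # still needs the old head.weight to reach trunk -> RuntimeError
opt_trunk.step()

My actual model and training loop are below.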
import math
import torch
import torch.nn as nn
import torchvision
import torchmetrics
from tqdm import tqdm

# train_dataset, val_dataset, train_loader, val_loader, BATCH_SIZE and device are defined elsewhere

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.alpha = 0.7
        self.base = torchvision.models.resnet18(weights='IMAGENET1K_V1')
        # freeze everything except the last few backbone layers
        for param in list(self.base.parameters())[:-15]:
            param.requires_grad = False
        self.base.classifier = nn.Sequential()
        self.base.fc = nn.Sequential()  # strip the classification head, keep the 512-d features
        self.block1 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
        )
        self.block2 = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 9)
        )
        self.block3 = nn.Sequential(
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 2)
        )
        self.optimizer1 = torch.optim.Adam([
            {'params': self.base.parameters(), 'lr': 1e-5},
            {'params': self.block1.parameters(), 'lr': 3e-4}
        ])
        self.optimizer2 = torch.optim.Adam(self.block2.parameters(), lr=3e-4)
        self.optimizer3 = torch.optim.Adam(self.block3.parameters(), lr=3e-4)
        self.loss_fxn = nn.CrossEntropyLoss()
        self.fruit_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=9)
        self.fresh_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=2)
        self.TRAIN_BATCHES = math.ceil(len(train_dataset) / BATCH_SIZE)
        self.VAL_BATCHES = math.ceil(len(val_dataset) / BATCH_SIZE)
        self.history = {'train_loss': [], 'val_loss': [],
                        'train_acc_fruit': [], 'train_acc_fresh': [],
                        'val_acc_fruit': [], 'val_acc_fresh': []}
    def forward(self, x):
        x = self.base(x)
        x = self.block1(x)
        y1, y2 = self.block2(x), self.block3(x)
        return y1, y2
    def train_step(self, x, y1, y2):
        pred1, pred2 = self.forward(x)
        l2 = self.loss_fxn(pred2, y2)
        self.optimizer3.zero_grad()
        l2.backward(retain_graph=True)
        self.optimizer3.step()
        l1 = self.loss_fxn(pred1, y1)
        self.optimizer2.zero_grad()
        l1.backward(retain_graph=True)
        self.optimizer2.step()
        print(l1, l2, self.alpha)
        loss = self.alpha * l1 + (1 - self.alpha) * l2
        print(loss)
        self.optimizer1.zero_grad()
        loss.backward()
        self.optimizer1.step()
        fruit_acc = self.fruit_accuracy(torch.argmax(pred1, axis=1), y1)
        fresh_acc = self.fresh_accuracy(torch.argmax(pred2, axis=1), y2)
        return loss, fruit_acc, fresh_acc
    def val_step(self, x, y1, y2):
        with torch.no_grad():
            pred1, pred2 = self.forward(x)
            loss = self.alpha * self.loss_fxn(pred1, y1) + (1 - self.alpha) * self.loss_fxn(pred2, y2)
            fruit_acc = self.fruit_accuracy(torch.argmax(pred1, axis=1), y1)
            fresh_acc = self.fresh_accuracy(torch.argmax(pred2, axis=1), y2)
        return loss, fruit_acc, fresh_acc
    def update_history(self, train_loss, train_fruit, train_fresh, val_loss, val_fruit, val_fresh):
        self.history['train_loss'].append(train_loss)
        self.history['val_loss'].append(val_loss)
        self.history['train_acc_fresh'].append(train_fresh)
        self.history['train_acc_fruit'].append(train_fruit)
        self.history['val_acc_fresh'].append(val_fresh)
        self.history['val_acc_fruit'].append(val_fruit)
    def train(self, epochs=5):
        torch.autograd.set_detect_anomaly(True)
        for epoch in tqdm(range(epochs)):
            train_loss, train_fruit, train_fresh = 0, 0, 0
            val_loss, val_fruit, val_fresh = 0, 0, 0
            for X, y1, y2 in tqdm(train_loader):
                X, y1, y2 = [v.to(device) for v in (X, y1, y2)]
                loss, fruit_acc, fresh_acc = self.train_step(X, y1, y2)
                train_loss = train_loss + loss.item()
                train_fruit = train_fruit + fruit_acc.item()
                train_fresh = train_fresh + fresh_acc.item()
            for X, y1, y2 in tqdm(val_loader):
                X, y1, y2 = [v.to(device) for v in (X, y1, y2)]
                loss, fruit_acc, fresh_acc = self.val_step(X, y1, y2)
                val_loss = val_loss + loss.item()
                val_fruit = val_fruit + fruit_acc.item()
                val_fresh = val_fresh + fresh_acc.item()
            train_loss, train_fruit, train_fresh = [x / self.TRAIN_BATCHES for x in (train_loss, train_fruit, train_fresh)]
            val_loss, val_fruit, val_fresh = [x / self.VAL_BATCHES for x in (val_loss, val_fruit, val_fresh)]
            self.update_history(train_loss, train_fruit, train_fresh, val_loss, val_fruit, val_fresh)
            print("[Epoch: {}] Train: [loss: {:.3f}, fruit: {:.3f} fresh: {:.3f}] Val: [loss: {:.3f}, fruit: {:.3f} fresh: {:.3f}]".format(
                epoch, train_loss, train_fruit, train_fresh, val_loss, val_fruit, val_fresh))
I can't seem to find the error myself. Can someone help me with this?