Hey, could you please explain a bit more about how .detach() is used when accumulating the loss? Sorry if my question is too basic, I’m still new to PyTorch.
epoch_loss = 0
n_train = len(train_loader)
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
    for batch in train_loader:
        net.train()
        imgs = batch['image']
        true_masks = batch['labels']
        imgs = imgs.to(device=device, dtype=torch.float32)
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        true_masks = true_masks.to(device=device, dtype=mask_type)
        logits, probs, masks_pred = net(imgs)
        logits = torch.squeeze(logits, 1)
        loss = criterion(logits, true_masks)
        epoch_loss += loss.item() / n_train
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # to be continued
If I understood it correctly, in this case setting epoch_loss = 0 prevents the loss from being carried over from one iteration to the next, and this is what .detach() does. Is that right?
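To make the question concrete, here is a minimal sketch of three ways to accumulate a loss and what each one keeps. The single weight w and the three one-element batches are hypothetical stand-ins for the network and train_loader above, not the real model. As far as I can tell, the relevant part is not the epoch_loss = 0 initialisation but what gets added to it: loss.item() gives a plain Python float, loss.detach() gives a tensor without autograd history, and adding loss itself keeps the history.

```python
import torch

# Hypothetical stand-ins for the real network and data loader:
# a single trainable weight and three one-element "batches".
w = torch.ones(1, requires_grad=True)
batches = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])]

sum_tensor = 0      # accumulates the loss tensor itself (keeps autograd history)
sum_detached = 0    # accumulates loss.detach() (a tensor, but without history)
sum_float = 0.0     # accumulates loss.item() (a plain Python float)

for x in batches:
    loss = ((w * x) ** 2).sum()
    sum_tensor = sum_tensor + loss
    sum_detached = sum_detached + loss.detach()
    sum_float += loss.item()

# The running sum of raw losses is still attached to the graph of every batch.
print(sum_tensor.requires_grad, sum_tensor.grad_fn is not None)
# The detached sum and the float carry only the value.
print(sum_detached.requires_grad, sum_detached.grad_fn is None)
print(type(sum_float).__name__, sum_float)
```

In the training loop above, epoch_loss += loss.item() / n_train already falls into the third case, so an extra .detach() should not be needed there.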
Thank you.