I’ve already seen How to find individual class accuracy. I am trying to get the same output (class accuracy) as in this. In my dataset I have three classes: Background, Class 1, Class 2. Each is mapped to 0, 1, 2. Currently use the following code to calculate the accuracy for the whole output:
def multi_acc(pred, label):
tags = torch.argmax(pred, dim=1)
corrects = (tags == label).float()
acc = corrects.sum() / corrects.numel()
acc = acc * 100
return acc
How can I modify the above to get the following output:
Class 1 loss: x, Class 1 Accuracy: y
Class 2 loss: x, Class 2 Accuracy: y
Class 3 loss: x, Class 3 Accuracy: y
Full code for reference:
def train_net(
net,
n_channels,
n_classes,
class_weights,
epochs=1,
val_precent=0.1,
batch_size=1,
lr=0.0001,
weight_decay=1e-8,
momentum=0.99,
):
print("Creating dataset for training...")
dataset = Loader(data_folder)
n_val = int(len(dataset) * val_precent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(
train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
)
val_loader = DataLoader(
val, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
)
global_step = 0
# optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.99, weight_decay=0.0005)
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
# optimizer = optim.Adam(net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(
weight=torch.Tensor(class_weights).to(device=device)
)
if wandb_track:
wandb.watch(net)
for epoch in range(epochs):
net.train()
tepoch_loss = 0
tepoch_acc = 0
vepoch_loss = 0
vepoch_acc = 0
once = True
for batch in train_loader:
imgs = batch["image"]
masks = batch["mask"]
assert imgs.shape[1] == n_channels, (
f"Network has been defined with {n_channels} input channels, "
f"but loaded images have {imgs.shape[1]} channels. Please check that "
"the images are loaded correctly."
)
imgs = imgs.to(device=device, dtype=torch.float32)
masks = masks.to(device=device, dtype=torch.long)
masks_pred = net(imgs)
loss = criterion(masks_pred, masks.squeeze(1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
tepoch_loss += loss.item()
tepoch_acc += multi_acc(masks_pred, masks)
global_step += 1
net.eval()
for batch in val_loader:
with torch.no_grad():
imgs = batch["image"]
masks = batch["mask"]
imgs = imgs.to(device=device, dtype=torch.float32)
masks = masks.to(device=device, dtype=torch.long)
masks_pred = net(imgs)
loss = criterion(masks_pred, masks.squeeze(1))
vepoch_loss += loss.item()
vepoch_acc += multi_acc(masks_pred, masks)
tepoch_loss /= n_train
tepoch_acc /= n_train
vepoch_loss /= n_val
vepoch_acc /= n_val
print(
"Epoch {0:} finished, Training loss: {1:.4f} [{2:.2f}%] Validation loss: {3:.4f} [{4:.2f}%]".format(
epoch + 1, tepoch_loss, tepoch_acc, vepoch_loss, vepoch_acc
)
)
if wandb_track:
wandb.log({"Test Accuracy": tepoch_acc, "Test Loss": tepoch_loss})
wandb.log(
{"Validation Accuracy": vepoch_acc, "Validation Loss": vepoch_loss}
)
try:
os.mkdir(model_path)
except OSError:
pass
torch.save(net.state_dict(), model_path + model_name)
if wandb_track:
torch.save(net.state_dict(), os.path.join(wandb.run.dir, model_name))
def multi_acc(pred, label):
tags = torch.argmax(pred, dim=1)
corrects = (tags == label).float()
acc = corrects.sum() / corrects.numel()
acc = acc * 100
return acc