hello everyone here is my train code and I am trying to get train accuracy but outputs are too high for each epoch .Could you help me ?
import torch
from torch.utils.data import DataLoader
import time
import argparse
from progress.bar import IncrementalBar
import numpy as np
import logging
from dataset import Tissue
from dataset import transforms as T
from gan.generator import UnetGenerator
from gan.discriminator import ConditionalDiscriminator
from gan.criterion import GeneratorLoss, DiscriminatorLoss
from gan.utils import Logger
parser = argparse.ArgumentParser(prog=âtopâ, description=âTrain Pix2Pixâ)
parser.add_argument(ââepochsâ, type=int, default=100, help=âNumber of epochsâ)
parser.add_argument(ââdatasetâ, type=str, default=âtissueâ, help=âName of the dataset: [âtissueâ]â)
parser.add_argument(ââbatch_sizeâ, type=int, default=32, help=âSize of the batchesâ)
parser.add_argument(ââlrâ, type=float, default=0.02, help=âAdams learning rateâ)
args = parser.parse_args()
device = âcudaâ if torch.cuda.is_available() else âcpuâ
transforms = T.Compose([T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])])
models
print(âDefining models!â)
generator = UnetGenerator().to(device)
discriminator = ConditionalDiscriminator().to(device)
optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.lr)
loss functions
g_criterion = GeneratorLoss(alpha=100)
d_criterion = DiscriminatorLoss()
dataset
print(fâDownloading â{args.dataset.upper()}â dataset!â)
dataset = Tissue(root=â.', transform=transforms, download=False, mode=âtrainâ)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
Start of training process
print(âStart of training process!â)
logger = Logger(filename=args.dataset)
total_correct = 0
total_images = 0
for epoch in range(args.epochs):
ge_loss = 0.
de_loss = 0.
start = time.time()
bar = IncrementalBar(fâ[Epoch {epoch + 1}/{args.epochs}]', max=len(dataloader))
epoch_correct = 0
epoch_images = 0
for x, real in dataloader:
x = x.to(device)
real = real.to(device)
# Generator`s loss
fake = generator(x)
fake_pred = discriminator(fake, x)
g_loss = g_criterion(fake, real, fake_pred)
# Discriminator`s loss
fake = generator(x).detach()
fake_pred = discriminator(fake, x)
real_pred = discriminator(real, x)
d_loss = d_criterion(fake_pred, real_pred)
# Generator`s params update
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
# Discriminator`s params update
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# add batch losses
ge_loss += g_loss.item()
de_loss += d_loss.item()
bar.next()
# Calculate accuracy per batch
with torch.no_grad():
generator.eval()
outputs = generator(x)
_, predicted = torch.max(outputs.data, 1)
total = real.size(0)
_, labels_indices = torch.max(real, 1)
correct = (predicted == labels_indices).sum().item()
epoch_correct += correct
epoch_images += total
bar.finish()
# obtain per epoch losses
g_loss_avg = ge_loss / len(dataloader)
d_loss_avg = de_loss / len(dataloader)
# Log losses
logger.add_scalar('generator_loss', g_loss_avg, epoch + 1)
logger.add_scalar('discriminator_loss', d_loss_avg, epoch + 1)
logger.save_weights(generator.state_dict(), 'generator')
logger.save_weights(discriminator.state_dict(), 'discriminator')
# Log accuracy per epoch
epoch_accuracy = (epoch_correct / epoch_images) * 100
print("Epoch {}: Accuracy of the network on the {} Train images: {:.2f} %".format(epoch + 1, epoch_images,
epoch_accuracy))
# accumulate correct predictions and total images processed across all epochs
total_correct += epoch_correct
total_images += epoch_images
# count timeframe
end = time.time()
tm = (end - start)
Calculate overall accuracy
overall_accuracy = (total_correct / total_images) * 100
print(âOverall Accuracy of the network on the {} Train images: {:.2f} %â.format(total_images, overall_accuracy))
logger.close()
print(âEnd of training process!â)
Defining models!
Downloading âTISSUEâ dataset!
Start of training process!
Epoch 1: Accuracy of the network on the 400 Train images: 3140211.25 %
Epoch 2: Accuracy of the network on the 400 Train images: 3021189.50 %
Epoch 3: Accuracy of the network on the 400 Train images: 2870963.25 %
Epoch 4: Accuracy of the network on the 400 Train images: 2798101.00 %
Epoch 5: Accuracy of the network on the 400 Train images: 2781629.50 %
Epoch 6: Accuracy of the network on the 400 Train images: 2778091.50 %
Epoch 7: Accuracy of the network on the 400 Train images: 2777310.75 %
Epoch 8: Accuracy of the network on the 400 Train images: 2777132.50 %
Epoch 9: Accuracy of the network on the 400 Train images: 2777086.50 %
Epoch 10: Accuracy of the network on the 400 Train images: 2777074.00 %
Epoch 11: Accuracy of the network on the 400 Train images: 2777070.00 %
Epoch 12: Accuracy of the network on the 400 Train images: 2777069.00 %
Epoch 13: Accuracy of the network on the 400 Train images: 2777067.50 %
Epoch 14: Accuracy of the network on the 400 Train images: 2777068.75 %