import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda import amp
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image

IMA_SIZE, BATCH_SIZE, EPOCH = 299, 1, 5
class MeninDataset(Dataset):
    def __init__(self, root_dir, train):
        self.root_dir = root_dir
        self.images = os.listdir(self.root_dir)
        self.train = train
        self.transforms_train = transforms.Compose([
            transforms.Resize(IMA_SIZE),  # 299 for Inception, 224 for other backbones
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomApply([
                transforms.RandomRotation(degrees=(0, 360))], p=0.5),
            transforms.ToTensor()
        ])
        self.transforms_test = transforms.Compose([
            transforms.Resize(IMA_SIZE),  # 299 for Inception, 224 for other backbones
            transforms.ToTensor()
        ])
    def __getitem__(self, index):  # return the image and its label for a given index
        file_name = self.images[index]
        img = Image.open(os.path.join(self.root_dir, file_name))
        if self.train:
            img = self.transforms_train(img)
        else:
            img = self.transforms_test(img)
        img = img.expand(3, -1, -1)  # replicate a single grayscale channel to 3 channels
        # file names starting with 'l' get label 0, everything else label 1
        lab = 0 if file_name.split('.')[0][0] == 'l' else 1
        return img, lab
    def __len__(self):
        return len(self.images)
class Net(nn.Module):
    def __init__(self, model):
        super(Net, self).__init__()
        # keep the convolutional backbone, drop the original fully connected head
        self.resnet = nn.Sequential(*list(model.children())[:-1])
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(2048, 2, bias=True)  # 2048 features: ResNet-50/101/152 backbone

    def forward(self, x):
        x = self.resnet(x)
        x = self.flatten(x)
        x = self.linear(x)
        return x
train_data = MeninDataset(root_dir='./train_fig/', train=True)
test_data = MeninDataset(root_dir='./test_fig/', train=False)
ld_train = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
ld_test = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)

cuda = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = models.resnet34(pretrained=True).to(cuda)
# model = Net(model).to(cuda)
scaler = torch.cuda.amp.GradScaler()
optimizer = optim.Adam(model.parameters(), weight_decay=1e-4, lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
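# Note (sketch, not executed in the original script): the pretrained resnet34 above keeps its
# 1000-class head even though the labels here are binary; CrossEntropyLoss still runs with
# labels 0/1, but the head is not adapted to two classes. One hedged way to adapt it would be
#   model.fc = nn.Linear(model.fc.in_features, 2).to(cuda)   # resnet34: fc.in_features == 512
# The commented-out Net wrapper above uses Linear(2048, 2), which matches ResNet-50/101/152
# backbones; with resnet34 (512 features) that Linear would need adjusting.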
## Training
tape_trainAcc, tape_testAcc = [], []
for epoch in range(EPOCH):
    train_correct, test_correct, train_total, test_total = 0, 0, 0, 0
    model.train()
    for _, (train_x, train_y) in enumerate(ld_train):
        train_x, train_y = train_x.to(cuda), train_y.to(cuda)
        if True:  # mixed precision
            optimizer.zero_grad()
            with amp.autocast():
                train_pred_y = model(train_x)
                loss = loss_fn(train_pred_y, train_y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        if False:  # full precision
            optimizer.zero_grad()
            train_pred_y = model(train_x)
            loss = loss_fn(train_pred_y, train_y)
            loss.backward()
            optimizer.step()
        _, train_pred = torch.max(train_pred_y.data, 1)
        train_total += train_y.size(0)
        train_correct += (train_pred == train_y).sum().item()

    model.eval()
    with torch.no_grad():
        for _, (test_x, test_y) in enumerate(ld_test):
            test_x, test_y = test_x.to(cuda), test_y.to(cuda)
            test_pred_y = model(test_x)
            _, test_pred = torch.max(test_pred_y.data, dim=1)  # predicted class
            test_correct += (test_pred == test_y).sum().item()
            test_total += test_y.size(0)

    tape_trainAcc.append(100 * train_correct / train_total)
    tape_testAcc.append(100 * test_correct / test_total)
    print('epoch:{}'.format(epoch), ';',
          'train Acc:%.2f' % (100 * train_correct / train_total), ';',
          'test Acc:%.2f' % (100 * test_correct / test_total))

    # save a checkpoint
    checkpoint = {
        "net": model.state_dict(),
        'optimizer': optimizer.state_dict(),
        "epoch": epoch,
        "tape_trainAcc": tape_trainAcc,
        "tape_testAcc": tape_testAcc,
    }
    if epoch % 10 == 0 or epoch == (EPOCH - 1):
        torch.save(checkpoint, './checkpoint/checkpoint.pth')
    if tape_trainAcc[epoch] > 90:
        torch.save(model.state_dict(), './checkpoint/model_{}.pth'.format(epoch))
    # torch.cuda.empty_cache()
## Resume training from a checkpoint
checkpoint = torch.load('./checkpoint/checkpoint.pth')  # saved checkpoint
model.load_state_dict(checkpoint['net'])
model.to(cuda)
optimizer.load_state_dict(checkpoint['optimizer'])
tape_trainAcc, tape_testAcc = checkpoint['tape_trainAcc'], checkpoint['tape_testAcc']
start_epoch = checkpoint['epoch'] + 1
end_epoch = start_epoch + 8  # number of additional epochs to train
for epoch in range(start_epoch, end_epoch):
    train_correct, test_correct, train_total, test_total = 0, 0, 0, 0
    model.train()
    for _, (train_x, train_y) in enumerate(ld_train):
        train_x, train_y = train_x.to(cuda), train_y.to(cuda)
        if True:  # mixed precision
            optimizer.zero_grad()
            with amp.autocast():
                train_pred_y = model(train_x)
                loss = loss_fn(train_pred_y, train_y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        if False:  # full precision
            optimizer.zero_grad()
            train_pred_y = model(train_x)
            loss = loss_fn(train_pred_y, train_y)
            loss.backward()
            optimizer.step()
        _, train_pred = torch.max(train_pred_y.data, 1)
        train_total += train_y.size(0)
        train_correct += (train_pred == train_y).sum().item()

    model.eval()
    with torch.no_grad():
        for _, (test_x, test_y) in enumerate(ld_test):
            test_x, test_y = test_x.to(cuda), test_y.to(cuda)
            test_pred_y = model(test_x)
            _, test_pred = torch.max(test_pred_y.data, dim=1)  # predicted class
            test_correct += (test_pred == test_y).sum().item()
            test_total += test_y.size(0)

    tape_trainAcc.append(100 * train_correct / train_total)
    tape_testAcc.append(100 * test_correct / test_total)
    print('epoch:{}'.format(epoch), ';',
          'train Acc:%.2f' % (100 * train_correct / train_total), ';',
          'test Acc:%.2f' % (100 * test_correct / test_total))

    # save a checkpoint
    checkpoint = {
        "net": model.state_dict(),
        'optimizer': optimizer.state_dict(),
        "epoch": epoch,
        "tape_trainAcc": tape_trainAcc,
        "tape_testAcc": tape_testAcc,
    }
    if epoch % 10 == 0:
        torch.save(checkpoint, './checkpoint/checkpoint.pth')
    if tape_trainAcc[epoch] > 90:
        torch.save(model.state_dict(), './checkpoint/Spark_{}.pth'.format(epoch))
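# Later use (sketch, assuming the same model architecture is constructed first): the per-epoch
# weights saved above are bare state_dicts and can be restored for evaluation with, e.g.,
#   state = torch.load('./checkpoint/Spark_10.pth')   # pick whichever saved epoch is wanted
#   model.load_state_dict(state)
#   model.eval()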