I have spent almost 3 days on this by now. All I am trying to do is a simple binary search for the learning rate in case the loss does not decrease properly. I may try TensorFlow for now (and cite that for this research instead of PyTorch, if that works). I like PyTorch, but I don't think I have the time.
For future reference, for those who are interested, this is what I have tried:
def test_CNN(CNN_model, device, test_DataLoader):
    """
    Evaluate ``CNN_model`` on every batch of ``test_DataLoader``.

    Returns a tuple ``(loss, accuracy)``:
      * ``loss`` is the SUM of the per-batch ``criterion`` values (it is not
        divided by the number of batches or samples);
      * ``accuracy`` is the number of correct predictions divided by
        ``len(test_DataLoader.dataset)``.

    Side effect: the model is switched to ``eval()`` mode and is NOT switched
    back to training mode before returning.
    """
    CNN_model.eval()
    summed_loss = 0
    n_correct = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in test_DataLoader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            scores = CNN_model(batch_inputs)
            summed_loss += criterion(scores, batch_targets).item()
            # Predicted class = index of the largest score along dim 1.
            predicted = torch.max(scores.data, 1)[1]
            n_correct += (predicted == batch_targets).sum().item()
    return summed_loss, n_correct / len(test_DataLoader.dataset)
import pandas as pd
from time import time
import torch, torchvision
from common import CNN, criterion, test_CNN, to_csv
# Directory passed to torchvision as the MNIST ``root`` below; the datasets are
# created with ``download=True``, so files are fetched into this folder.
# NOTE: several functions in this file reuse the name ``data`` for a batch tensor.
data = "C:\\Raghavendra\\research_peer_review\\LearningRate\\data\\"
def init_binary_search_tree():
    """
    Build the dict-based binary search tree over the integers -15 .. -1.

    Every internal node (the even numbers -14 .. -2) maps to
    ``{"right": child, "left": child}``; the odd numbers are leaves and have
    no entry of their own. The root is -8. In this file the node values are
    used as base-2 exponents of the learning rate (``2 ** node``).
    """
    tree = {}
    # Levels are filled bottom-up: children 1 away, then 2 away, then the root
    # whose children are 4 away. "right" is always the larger (less negative) child.
    for half_gap in (1, 2, 4):
        for node in range(-2 * half_gap, -16, -4 * half_gap):
            tree[node] = {"right": node + half_gap, "left": node - half_gap}
    return tree
def test_lr(lr, CNN_model, data, target, current_loss):
    """
    Try one plain gradient-descent step of size ``lr`` on a single batch.

    The gradient of ``criterion(CNN_model(data), target)`` is computed, the
    step ``parameter -= lr * grad`` is applied, and the loss is re-evaluated on
    the same batch.

    Returns:
        True  -- the new loss is <= ``current_loss``; the step is KEPT, so the
                 model's parameters stay modified.
        False -- the new loss is larger (or not comparable, e.g. NaN); the
                 parameters are rolled back to their values at entry.
    """
    parameters = list(CNN_model.parameters())
    loss = criterion(CNN_model(data), target)
    grads = torch.autograd.grad(loss, parameters)
    # Detached copies, so the snapshot does not carry an autograd graph.
    snapshot = [parameter.detach().clone() for parameter in parameters]

    # Nothing below needs gradients: the trial step and the re-evaluation are
    # done without building a second autograd graph.
    with torch.no_grad():
        for parameter, grad_value in zip(parameters, grads):
            parameter -= lr * grad_value
        future_loss = criterion(CNN_model(data), target).item()

        # Written as "not <=" so that a NaN loss (diverged step) is rejected;
        # "future_loss > current_loss" would be False for NaN and accept it.
        if not (future_loss <= current_loss):
            # Roll back in place so the parameters keep their own storage.
            for parameter, saved in zip(parameters, snapshot):
                parameter.copy_(saved)
            return False
    return True
def reset(d, data, target, CNN_model, current_loss):
    """
    Walk the learning-rate tree ``d`` from the root (-8) for at most 3 levels.

    At each node the "right" child (larger exponent, i.e. larger learning
    rate) is tried first via ``test_lr``, then the "left" child. The walk
    moves to the first child whose trial step is accepted; if neither is
    accepted the walk stops and the current node's rate ``2 ** node`` is used.

    Args:
        d: mapping ``node -> {"right": child, "left": child}`` of integer exponents.
        data, target: the batch used for every trial step.
        CNN_model: model whose parameters are modified by accepted trial steps.
        current_loss: loss value a trial step must not exceed to be accepted.

    Returns:
        ``(lr, CNN_model)`` -- the selected learning rate and the same model object.

    NOTE(review): after a rejected trial the parameters are restored to the
    snapshot taken when this function was entered, which also discards steps
    accepted at earlier levels of the walk -- confirm this is intended.
    """
    node = -8
    lr = 2 ** -16  # placeholder; every path through the first iteration overwrites it
    parameters = list(CNN_model.parameters())
    # Independent, detached copies of the weights at entry.
    entry_state = [parameter.detach().clone() for parameter in parameters]

    def restore_entry_state():
        # Copy values INTO the existing parameter tensors. Rebinding
        # ``parameter.data = saved.data`` would make the parameters share
        # storage with the snapshot, and the next in-place trial step would
        # then silently corrupt the snapshot itself.
        with torch.no_grad():
            for parameter, saved in zip(parameters, entry_state):
                parameter.copy_(saved)

    for _ in range(3):
        # Check right node first; check left node next; if neither works,
        # "stay" at the current node.
        lr_right = 2 ** d[node]["right"]
        lr_left = 2 ** d[node]["left"]

        if test_lr(lr_right, CNN_model, data, target, current_loss):
            lr = lr_right
            node = d[node]["right"]
            continue

        restore_entry_state()
        if test_lr(lr_left, CNN_model, data, target, current_loss):
            lr = lr_left
            node = d[node]["left"]
            continue

        restore_entry_state()
        lr = 2 ** node
        break

    print(lr)
    return lr, CNN_model
def train_momentum_check_CNN(CSV, train_set, test_set, CNN_model, device, n_epochs=1, batch_size=1, batch_plot=1):
    """
    Train ``CNN_model`` with plain gradient descent, re-selecting the learning
    rate with ``reset`` whenever a batch loss is larger than the previous one.

    Args:
        CSV: file name appended to the module-level ``to_csv`` prefix for the log.
        train_set, test_set: datasets wrapped in DataLoaders (training data is shuffled).
        CNN_model: model to train, updated in place.
        device: device every batch is moved to.
        n_epochs: number of passes over ``train_set``.
        batch_size: batch size of both DataLoaders.
        batch_plot: evaluate on the test set every ``batch_plot`` batches
            (the batch counter restarts each epoch).

    Writes a CSV with the columns ``train_loss``, ``test_loss`` and
    ``test_accuracy`` (one row per evaluation) and returns None.

    Relies on the module-level tree ``d`` being defined before it is called.
    The learning rate starts at 1 and the "previous loss" starts at 0, so the
    first batch triggers ``reset`` whenever its loss is positive.
    """
    train_DataLoader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_DataLoader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    train_loss_list = []
    test_loss_list = []
    test_accuracy_list = []
    future_loss = 0
    current_loss = 0
    CNN_model.train()
    lr = 1
    for _ in range(n_epochs):
        for j, (data, target) in enumerate(train_DataLoader):
            data = data.to(device)
            target = target.to(device)

            # The previous batch's loss becomes the reference for this batch.
            current_loss = future_loss

            # backward() ADDS into .grad; without clearing it first, every
            # update would use the sum of all gradients seen so far.
            CNN_model.zero_grad()
            loss = criterion(CNN_model(data), target)
            future_loss = loss.item()
            loss.backward()

            # Plain SGD step, kept outside of autograd tracking.
            with torch.no_grad():
                for parameter in CNN_model.parameters():
                    parameter -= lr * parameter.grad

            if future_loss > current_loss:
                lr, CNN_model = reset(d, data, target, CNN_model, current_loss)

            if j % batch_plot == 0:
                test_loss, test_accuracy = test_CNN(CNN_model, device, test_DataLoader)
                # test_CNN leaves the model in eval() mode; switch back so the
                # remaining batches are still trained in training mode.
                CNN_model.train()
                train_loss_list.append(loss.item())
                test_loss_list.append(test_loss)
                test_accuracy_list.append(test_accuracy)

    df = pd.DataFrame()
    df["train_loss"] = train_loss_list
    df["test_loss"] = test_loss_list
    df["test_accuracy"] = test_accuracy_list
    df.to_csv(to_csv + CSV, index=False)
    return
# Driver: train the CNN on MNIST with the learning-rate search and report wall time.
t0 = time()
# Module-level tree; train_momentum_check_CNN reads it as a global.
d = init_binary_search_tree()
transforms = torchvision.transforms.ToTensor()
# Use the GPU when one is available, otherwise fall back to the CPU instead of
# failing on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_set = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=transforms)
test_set = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=transforms)
train_momentum_check_CNN(
    CSV="MNIST_momentum_check_CNN.csv",
    train_set=train_set,
    test_set=test_set,
    CNN_model=CNN().to(device),
    device=device,
    n_epochs=3,
    batch_size=16,
    batch_plot=64,
)
print(time() - t0, "seconds")
Edit: The question stands, but there is a business-logic mistake in `d`.