How do I get gradients of a CNN one time only (without making it sticky)?

I am asking this to confirm that my intuition is correct.

If I have a CNN like this:

class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 5, padding=2)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.fc1 = torch.nn.Linear(8 * 7 * 7, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.conv1(x))
        x = self.pool1(x)
        x = torch.nn.functional.relu(self.conv2(x))
        x = self.pool2(x)
        x = torch.nn.functional.relu(self.fc1(x.view(-1, 8 * 7 * 7)))
        x = self.fc2(x)
        return x

I just want to compute the gradients once and manually update the weights if needed (like weight -= lr * weight_grad).

Now, loss.backward() retains the gradients in each parameter's .grad. Let us say I don't want to retain any gradients or play with autograd configurations, etc.

So, to get the gradients one time only, is this correct?

criterion = torch.nn.CrossEntropyLoss()
output = CNN_model(data); loss = criterion(output, target)
grad = torch.autograd.grad(loss, CNN_model.parameters())

for parameter, grad_value in zip(CNN_model.parameters(), grad): parameter.data -= lr * grad_value
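A minimal, self-contained sketch of this pattern (with a placeholder model and data; torch.autograd.grad returns the gradients as a tuple without populating parameter.grad, so nothing sticks to the model, and the update can live under torch.no_grad() instead of touching .data):

import torch

model = torch.nn.Linear(4, 2)  # placeholder standing in for the CNN
criterion = torch.nn.CrossEntropyLoss()
data, target = torch.randn(8, 4), torch.randint(0, 2, (8,))
lr = 0.1

loss = criterion(model(data), target)
grads = torch.autograd.grad(loss, model.parameters())  # one grad per parameter

with torch.no_grad():  # keep the manual update out of the graph
    for p, g in zip(model.parameters(), grads):
        p -= lr * g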

I have spent almost three days on this by now. All I am trying to do is a simple binary search for the learning rate in case the loss does not decrease properly. I may try TensorFlow for now (and cite that for this research instead of PyTorch if it works; I like PyTorch, but I don't think I have the time).

For future reference, for those who are interested, this is what I have tried:


# common.py (excerpt)
def test_CNN(CNN_model, device, test_DataLoader):
    with torch.no_grad():
        CNN_model.eval()

        loss = 0; accuracy = 0
        for (data, target) in test_DataLoader:
            data = data.to(device); target = target.to(device)
            output = CNN_model(data)
            _, predicted = torch.max(output.data, 1)
            batch_loss = criterion(output, target)
            loss += batch_loss.item()
            accuracy += (predicted == target).sum().item()

    return loss, accuracy / len(test_DataLoader.dataset)

# main script
import pandas as pd
from time import time
import torch, torchvision
from common import CNN, criterion, test_CNN, to_csv

data = "C:\\Raghavendra\\research_peer_review\\LearningRate\\data\\"

def init_binary_search_tree():
    """
    I hope this gets integrated into PyTorch torch optim. This dict version is bad. I am making it simple as of now.
    """
    d = dict()

    # keys and child values are base-2 exponents: lr = 2 ** node
    d[-2] = {"right": -1, "left": -3}
    d[-6] = {"right": -5, "left": -7}
    d[-10] = {"right": -9, "left": -11}
    d[-14] = {"right": -13, "left": -15}

    d[-4] = {"right": -2, "left": -6}
    d[-12] = {"right": -10, "left": -14}

    d[-8] = {"right": -4, "left": -12}
    return d
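If the hard-coded dict gets unwieldy, the same tree can be generated with a midpoint split over the exponent range. A sketch that reproduces the dict above exactly (so the business logic mistake noted in the edit below carries over to it as well):

def init_binary_search_tree_generated(lo=-15, hi=-1):
    # internal nodes are interval midpoints; leaves are the single exponents
    d = dict()

    def build(lo, hi):
        mid = (lo + hi) // 2
        if lo == hi:
            return mid
        d[mid] = {"right": build(mid + 1, hi), "left": build(lo, mid - 1)}
        return mid

    build(lo, hi)
    return d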

def test_lr(lr, CNN_model, data, target, current_loss):
    output = CNN_model(data); loss = criterion(output, target)
    grad = torch.autograd.grad(loss, CNN_model.parameters())

    original_parameters = [parameter.clone() for parameter in CNN_model.parameters()]

    for parameter, grad_value in zip(CNN_model.parameters(), grad): parameter.data -= lr * grad_value
    output = CNN_model(data); loss = criterion(output, target); future_loss = loss.item()

    if future_loss > current_loss:
        test_result = False
        for parameter, original_parameter in zip(CNN_model.parameters(), original_parameters): parameter.data = original_parameter.data
    else:
        test_result = True

    return test_result

def reset(d, data, target, CNN_model, current_loss):
    node = -8; lr = 2 ** -16
    original_parameters = [parameter.clone() for parameter in CNN_model.parameters()]

    for _ in range(3):
        # Check right node first; Check left node next; If none of those work, "stay" at the current node
        lr_right = 2 ** d[node]["right"]; lr_left = 2 ** d[node]["left"]
        test_right_result = test_lr(lr_right, CNN_model, data, target, current_loss)

        if test_right_result == False:
            for parameter, original_parameter in zip(CNN_model.parameters(), original_parameters): parameter.data = original_parameter.data
            test_left_result = test_lr(lr_left, CNN_model, data, target, current_loss)
            if test_left_result == False:
                for parameter, original_parameter in zip(CNN_model.parameters(), original_parameters): parameter.data = original_parameter.data
                lr = 2 ** node; break
            else:
                lr = lr_left; node = d[node]["left"]
        else:
            lr = lr_right; node = d[node]["right"]

    print(lr)
    return lr, CNN_model

def train_momentum_check_CNN(CSV, train_set, test_set, CNN_model, device, n_epochs=1, batch_size=1, batch_plot=1):
    train_DataLoader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_DataLoader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    train_loss_list = list(); test_loss_list = list(); test_accuracy_list = list()
    future_loss = 0; current_loss = 0

    CNN_model.train(); lr = 1
    for i in range(n_epochs):
        for j, (data, target) in enumerate(train_DataLoader):
            data = data.to(device); target = target.to(device)
            output = CNN_model(data)
            _, predicted = torch.max(output.data, 1)

            current_loss = future_loss
            loss = criterion(output, target)
            future_loss = loss.item()

            CNN_model.zero_grad(); loss.backward()  # zero first so grads don't accumulate across batches
            for parameter in CNN_model.parameters(): parameter.data -= lr * parameter.grad

            if future_loss > current_loss:
                lr, CNN_model = reset(d, data, target, CNN_model, current_loss)

            if j % batch_plot == 0:
                test_loss, test_accuracy = test_CNN(CNN_model, device, test_DataLoader)

                train_loss_list.append(loss.item())
                test_loss_list.append(test_loss)
                test_accuracy_list.append(test_accuracy)

    df = pd.DataFrame()
    df["train_loss"] = train_loss_list; df["test_loss"] = test_loss_list
    df["test_accuracy"] = test_accuracy_list
    df.to_csv(to_csv + CSV, index=False)

    return

t0 = time()

d = init_binary_search_tree()
transforms = torchvision.transforms.ToTensor(); device = torch.device("cuda")

train_set = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=transforms)
test_set = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=transforms)

train_momentum_check_CNN(CSV="MNIST_momentum_check_CNN.csv", train_set=train_set, test_set=test_set, CNN_model=CNN().to(device), device=device,
                         n_epochs=3, batch_size=16, batch_plot=64)

print(time() - t0, "seconds")

Edit: the question stands, but there is a business-logic mistake in d.

To disconnect the gradients or parameters from the graph, you can use tensor.detach(). Is that what you’re looking for?

No. I need to try out different learning rates. So, I need the following:

For each batch

  1. Get the gradient
  2. For a given learning rate, get the loss of the model
  3. For another learning rate, get the loss of the model again
  4. Pick the learning rate that minimises loss better for that batch
  5. Repeat

For this, I was thinking about getting the gradients just one time per batch, doing weight -= lr1 * weight_grad, and testing out that learning rate. If I don't like the loss, I may want to revert the update with weight += lr1 * weight_grad and test out some other learning rate lr2 in the same way.

Right now, there are so many coupled-up mechanisms like optimiser.zero_grad(), loss.backward(), etc. I need a very simple manual update that does not expect me to think about all previous updates of the weights, just the current one.
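In other words, something like this round trip per learning rate (a toy sketch with a single tensor standing in for a weight):

import torch

w = torch.randn(3, requires_grad=True)  # stand-in for one weight tensor
loss = (w ** 2).sum()
(g,) = torch.autograd.grad(loss, [w])   # gradients computed exactly once

lr1 = 0.1
with torch.no_grad():
    w -= lr1 * g  # try lr1
    # ... measure the new loss here ...
    w += lr1 * g  # revert (exact up to floating-point rounding)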

Help me out here: suppose you have data x going through your model at batch t0. You calculate the loss and backpropagate, and gradient values get stored on the parameters.

The learning rate has nothing to do with the gradients or the calculated loss, yet. Take SGD, for instance, the update rule is w = w - learning_rate * gradient. In other words, you will have identical gradients and loss with both learning rates. The learning rate just determines how much the parameters get moved after loss is calculated.

That said, you could store the model, step with one learning rate and see how the loss for batch t1 is impacted, then reload the model, try the other learning rate, and compare the loss on batch t1 for both.
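Concretely, the store-and-reload idea could look something like this (a sketch with a placeholder model; state_dict/load_state_dict snapshot and restore the parameters):

import copy
import torch

model = torch.nn.Linear(4, 2)  # placeholder model
criterion = torch.nn.CrossEntropyLoss()
data, target = torch.randn(8, 4), torch.randint(0, 2, (8,))

snapshot = copy.deepcopy(model.state_dict())  # store the model at t0

losses = {}
for lr in (0.01, 0.1):
    model.load_state_dict(snapshot)  # reload before each trial
    loss = criterion(model(data), target)
    grads = torch.autograd.grad(loss, model.parameters())
    with torch.no_grad():
        for p, g in zip(model.parameters(), grads):
            p -= lr * g
        losses[lr] = criterion(model(data), target).item()

best_lr = min(losses, key=losses.get)  # proceed with the better of the two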

Let me know if I am misunderstanding what you’re trying to accomplish overall, here.

I think now you have a better idea of what I am trying to do. I am trying to build my own optimiser that uses binary search. Let me make it really simple: let us stick with batch t0 only for now. Let us say I have no clue about machine learning and I just guess a couple of learning rates for batch t0. Now, I need to test and check which learning rate is better for t0.

Yes. I pass data x (from batch t0; so, x = t0) through the model once. I get some loss. I compute the gradients. Now, what I do not know is the (locally) optimal learning rate. [I am keeping it simple; I am not going to use Hessians.] So, now I need to see which learning rate gives me a lower loss than the loss I got the first time. I have the same gradients for both learning rates, but in practice I won't get the same post-update loss for both. So, then, I try to pick the one that minimises the loss.

The problem is that the first time I try with the first learning rate, things seem to work OK. But when I try out the second learning rate, the computation graph seems to run into some really bad problems.

In fact, the binary search code that I am trying is pretty simple:

def reset(data, target, CNN_model, current_loss):
    lr = 2 ** -16; left = -15; right = -1

    for _ in range(32):
        mid = (left + right) / 2; lr = 2 ** mid
        test_lr_mid = test_lr(lr, CNN_model, data, target, current_loss)
    
        if test_lr_mid:
            left = mid
        else:
            right = mid

    print(lr)
    return lr, CNN_model

where the test_lr function is supposed to say whether the given learning rate reduces the loss (True or False).

For reference, I am also adding the test_lr function. I think test_lr is the problem, but I have no idea how to resolve it.

def test_lr(lr, CNN_model, data, target, current_loss):
    output = CNN_model(data); loss = criterion(output, target)
    grad = torch.autograd.grad(loss, CNN_model.parameters())

    original_parameters = [parameter.clone() for parameter in CNN_model.parameters()]

    for parameter, grad_value in zip(CNN_model.parameters(), grad): parameter.data -= lr * grad_value
    output = CNN_model(data); loss = criterion(output, target); future_loss = loss.item()

    if future_loss > current_loss:
        test_result = False
        for parameter, original_parameter in zip(CNN_model.parameters(), original_parameters): parameter.data = original_parameter.data
    else:
        test_result = True

    return test_result
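One variant I am considering keeps the update, evaluation, and revert all under torch.no_grad() and restores with copy_ instead of assigning .data (a sketch of the same logic, not verified):

def test_lr_v2(lr, CNN_model, data, target, current_loss):
    loss = criterion(CNN_model(data), target)
    grad = torch.autograd.grad(loss, CNN_model.parameters())  # does not touch .grad

    with torch.no_grad():
        original_parameters = [p.clone() for p in CNN_model.parameters()]
        for p, g in zip(CNN_model.parameters(), grad):
            p -= lr * g
        future_loss = criterion(CNN_model(data), target).item()

        if future_loss > current_loss:
            # revert: copy the saved values back into the live parameters
            for p, p0 in zip(CNN_model.parameters(), original_parameters):
                p.copy_(p0)
            return False
    return True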

I am also thinking about creating new models for both learning rates and continuing to train using the new models. But I am not sure whether that works because of the computation graphs, etc. I am not using any optimiser at all. I am not sure whether I should set the parameters directly on the model or just revert the weight update (weight -= lr * weight_grad followed by weight += lr * weight_grad). I am not sure whether there are other options.

If you run that same batch of data through the model after backprop, the higher learning rate will always produce lower loss (assuming lr < 1.0 here). But that may result in over-fitting, so is not necessarily ideal.

Doing the comparison per epoch may be more beneficial and could accomplish what you're aiming to do. But in that case, it would be better to store the model, say m0, and then compute m1_a (lr = a) and m1_b (lr = b) after an epoch of data, then compare both by calculating loss in inference mode on all of the data (much faster) and proceed with the better of the two.


Very nice perspective! Or instead of a complete epoch, introduce a meta-batch (yes, "meta-batch" sounds nice 🙂). A meta-batch is a batch of batches.
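Carving meta-batches out of a DataLoader should work with itertools.islice over one persistent iterator per pass (a toy sketch; each islice call hands out the next chunk):

from itertools import islice

loader = range(10)  # stand-in for a DataLoader
it = iter(loader)   # one iterator per epoch, so successive slices advance through it
meta_batch = 4

while True:
    chunk = list(islice(it, meta_batch))
    if not chunk:
        break
    print(chunk)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]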

And thanks for confirming that I can train multiple models and continue training from them.

If that’s along the lines of what you’re aiming for, you’ll need the following:

import torch
import copy  # we need deepcopy to give the model copy its own parameters in memory

model_0 = copy.deepcopy(model)

Not as straightforward as copying a variable or tensor, but this is the easiest way to make a copy of a model in this case.
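A quick way to convince yourself the copy is independent (toy model):

import copy
import torch

m = torch.nn.Linear(2, 2)
m2 = copy.deepcopy(m)

with torch.no_grad():
    m2.weight.zero_()  # mutate only the copy

print(torch.equal(m.weight, m2.weight))  # False: the original is untouched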


I replaced the binary search with a very simple (even better) logic. Now, I train two models for one meta-batch (or chunk) only: one with LR and the other with 2 times LR. I am not sure why loss2_bool is not working; loss_bool works properly.

I am also using deepcopy.

import pandas as pd
from time import time
import torch, torchvision
from copy import deepcopy
from itertools import islice
from common import CNN, criterion, test_CNN, to_csv

data = "C:\\Raghavendra\\research_peer_review\\LearningRate\\data\\"

def train_LR(lr, i_slice, CNN_model):

    loss_item = 2 ** 32; loss_init = 2 ** 32 # inf
    CNN_model = deepcopy(CNN_model); CNN_model.train()

    for j, (data, target) in enumerate(i_slice):
        data = data.to(device); target = target.to(device)
        output = CNN_model(data)
        _, predicted = torch.max(output.data, 1)

        loss = criterion(output, target)
        loss_item = loss.item()

        if j == 0: loss_init = loss_item

        CNN_model.zero_grad(); loss.backward()
        with torch.no_grad():
            for parameter in CNN_model.parameters():
                parameter -= lr * parameter.grad

    loss_bool = True if loss_item < loss_init else False
    return CNN_model, loss_item, loss_bool

def train_momentum_check_CNN(CSV, train_set, test_set, CNN_model, device, n_epochs=1, batch_size=1, meta_batch=1):
    train_DataLoader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_DataLoader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    train_loss_list = list(); test_loss_list = list(); test_accuracy_list = list(); lr = 1

    for i in range(n_epochs):
        for _ in range(0, len(train_set), meta_batch * batch_size):
            i_slice = islice(iter(train_DataLoader), meta_batch)
            CNN_model, loss_item, loss_bool = train_LR(lr, i_slice, CNN_model)
            CNN2_model, loss2_item, loss2_bool = train_LR(2 * lr, i_slice, CNN_model)

            """
            Lesser loss 
            LR      double_LR   output      copy
            --      ---------   ------      ----
            True    True        lr * 2      CNN2_model
            True    False       lr          -
            False   True        lr / 2      -
            False   False       lr / 2      -
            """

            print(lr, loss_bool, loss2_bool)

            if loss_bool and loss2_bool: lr *= 2; CNN_model = deepcopy(CNN2_model)
            if loss_bool == False: lr /= 2

            test_loss, test_accuracy = test_CNN(CNN_model, device, test_DataLoader)

            train_loss_list.append(loss_item)
            test_loss_list.append(test_loss)
            test_accuracy_list.append(test_accuracy)

    df = pd.DataFrame()
    df["train_loss"] = train_loss_list; df["test_loss"] = test_loss_list
    df["test_accuracy"] = test_accuracy_list
    df.to_csv(to_csv + CSV, index=False)

    return

t0 = time()

transforms = torchvision.transforms.ToTensor(); device = torch.device("cuda")

train_set = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=transforms)
test_set = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=transforms)

train_momentum_check_CNN(CSV="MNIST_momentum_check_CNN.csv", train_set=train_set, test_set=test_set, CNN_model=CNN().to(device), device=device,
                         n_epochs=3, batch_size=8, meta_batch=64)

print(time() - t0, "seconds")

I am planning to look into why the copy is not working properly. But this is strange.

I’m not seeing where you’re defining an optimizer.

For example:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss = loss_fn(model(input), target)
loss.backward()
optimizer.step()

Additionally, it seems you are naming the new model with the same name. That might not work:

CNN_model = deepcopy(CNN_model); CNN_model.train()

Maybe try CNN_model_a.


Yes. I am not using an optimiser because I am creating one myself: it takes one LR and double that LR, and tries out the one that works.

Even after changing the name (that is a great idea), I do not see any change in loss2_bool (it is not working at all); loss_bool works.

import pandas as pd
from time import time
import torch, torchvision
from copy import deepcopy
from itertools import islice
from common import CNN, criterion, test_CNN, to_csv

data = "C:\\Raghavendra\\research_peer_review\\LearningRate\\data\\"

def train_LR(lr, i_slice, CNN_model):

    loss_item = 2 ** 32; loss_init = 2 ** 32 # inf
    CNN_model_a = deepcopy(CNN_model); CNN_model_a.train()

    for j, (data, target) in enumerate(i_slice):
        data = data.to(device); target = target.to(device)
        output = CNN_model_a(data)
        _, predicted = torch.max(output.data, 1)

        loss = criterion(output, target)
        loss_item = loss.item()

        if j == 0: loss_init = loss_item

        CNN_model_a.zero_grad(); loss.backward()
        with torch.no_grad():
            for parameter in CNN_model_a.parameters():
                parameter -= lr * parameter.grad

    loss_bool = True if loss_item < loss_init else False
    return CNN_model_a, loss_item, loss_bool

def train_momentum_check_CNN(CSV, train_set, test_set, CNN_model, device, n_epochs=1, batch_size=1, meta_batch=1):
    train_DataLoader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_DataLoader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    train_loss_list = list(); test_loss_list = list(); test_accuracy_list = list(); lr = 1

    for i in range(n_epochs):
        for _ in range(0, len(train_set), meta_batch * batch_size):
            i_slice = islice(iter(train_DataLoader), meta_batch)
            CNN_model, loss_item, loss_bool = train_LR(lr, i_slice, CNN_model)
            CNN2_model, loss2_item, loss2_bool = train_LR(2 * lr, i_slice, CNN_model)

            """
            Lesser loss 
            LR      double_LR   output      copy
            --      ---------   ------      ----
            True    True        lr * 2      CNN2_model
            True    False       lr          -
            False   True        lr / 2      -
            False   False       lr / 2      -
            """

            print(lr, loss_bool, loss2_bool)

            if loss2_bool: lr *= 2; CNN_model = deepcopy(CNN2_model)
            if loss_bool == False: lr /= 2

            test_loss, test_accuracy = test_CNN(CNN_model, device, test_DataLoader)

            train_loss_list.append(loss_item)
            test_loss_list.append(test_loss)
            test_accuracy_list.append(test_accuracy)

    df = pd.DataFrame()
    df["train_loss"] = train_loss_list; df["test_loss"] = test_loss_list
    df["test_accuracy"] = test_accuracy_list
    df.to_csv(to_csv + CSV, index=False)

    return

t0 = time()

transforms = torchvision.transforms.ToTensor(); device = torch.device("cuda")

train_set = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=transforms)
test_set = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=transforms)

train_momentum_check_CNN(CSV="MNIST_momentum_check_CNN.csv", train_set=train_set, test_set=test_set, CNN_model=CNN().to(device), device=device,
                         n_epochs=3, batch_size=8, meta_batch=64)

print(time() - t0, "seconds")

This is the final idea I am testing for this peer-review research. But I am not sure why the training just fails when train_LR is called the second time (it doesn't seem to train at all).
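One thing worth ruling out before blaming autograd: i_slice is a single-use iterator, so after the first train_LR consumes it, the second call iterates over nothing; loss_item then stays equal to loss_init, so loss2_bool is always False. The behaviour in isolation:

from itertools import islice

it = islice(iter(range(100)), 4)
print(list(it))  # [0, 1, 2, 3]
print(list(it))  # [] -- exhausted; a second pass over the same slice sees no items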

I also thought it might have to do with torch.no_grad(). So, I refactored the code to train both models in a single pass over the slice (which also avoids iterating i_slice twice); it looks like this now:

import pandas as pd
from time import time
import torch, torchvision
from copy import deepcopy
from itertools import islice
from common import CNN, criterion, test_CNN, to_csv

data = "C:\\Raghavendra\\research_peer_review\\LearningRate\\data\\"

def train_LR(lr, i_slice, CNN_model):

    loss_x_item = 2 ** 32; loss_2x_item = 2 ** 32
    loss_x_init = 2 ** 32; loss_2x_init = 2 ** 32 # inf

    CNN_model_x = deepcopy(CNN_model); CNN_model_2x = deepcopy(CNN_model)
    CNN_model_x.train(); CNN_model_2x.train()

    for j, (data, target) in enumerate(i_slice):
        data = data.to(device); target = target.to(device)
        output_x = CNN_model_x(data); output_2x = CNN_model_2x(data)

        loss_x = criterion(output_x, target); loss_2x = criterion(output_2x, target)
        loss_x_item = loss_x.item(); loss_2x_item = loss_2x.item()

        if j == 0: loss_x_init = loss_x_item; loss_2x_init = loss_2x_item

        CNN_model_x.zero_grad(); loss_x.backward()
        CNN_model_2x.zero_grad(); loss_2x.backward()

        with torch.no_grad():

            for parameter_x in CNN_model_x.parameters():
                parameter_x -= lr * parameter_x.grad

            for parameter_2x in CNN_model_2x.parameters():
                parameter_2x -= (2 * lr) * parameter_2x.grad

    loss_x_bool = loss_x_item < loss_x_init; loss_2x_bool = loss_2x_item < loss_2x_init

    return CNN_model_x, loss_x_item, loss_x_bool, CNN_model_2x, loss_2x_item, loss_2x_bool

def train_momentum_check_CNN(CSV, train_set, test_set, CNN_model, device, n_epochs=1, batch_size=1, meta_batch=1):
    train_DataLoader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_DataLoader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    train_loss_list = list(); test_loss_list = list(); test_accuracy_list = list(); lr = 2 ** -8

    for i in range(n_epochs):
        for _ in range(0, len(train_set), meta_batch * batch_size):
            i_slice = islice(iter(train_DataLoader), meta_batch)
            CNN_model_x, loss_x_item, loss_x_bool, CNN_model_2x, loss_2x_item, loss_2x_bool = train_LR(lr, i_slice, CNN_model)

            print(lr, loss_x_bool, loss_2x_bool)

            """
            Lesser loss 
            LR      double_LR   output      copy
            --      ---------   ------      ----
            True    True        lr * 2      CNN2_model
            True    False       lr          -
            False   True        lr / 2      -
            False   False       lr / 2      -
            """

            if loss_2x_bool: lr *= 2; CNN_model = deepcopy(CNN_model_2x)
            if loss_x_bool == False: lr /= 2

            test_loss, test_accuracy = test_CNN(CNN_model, device, test_DataLoader)

            train_loss_list.append(loss_x_item)
            test_loss_list.append(test_loss)
            test_accuracy_list.append(test_accuracy)

    df = pd.DataFrame()
    df["train_loss"] = train_loss_list; df["test_loss"] = test_loss_list
    df["test_accuracy"] = test_accuracy_list
    df.to_csv(to_csv + CSV, index=False)

    return

t0 = time()

transforms = torchvision.transforms.ToTensor(); device = torch.device("cuda")

train_set = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=transforms)
test_set = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=transforms)

train_momentum_check_CNN(CSV="MNIST_momentum_check_CNN.csv", train_set=train_set, test_set=test_set, CNN_model=CNN().to(device), device=device,
                         n_epochs=3, batch_size=8, meta_batch=64)

print(time() - t0, "seconds")

This almost solves the original question (though I still don't know for sure why the code before refactoring was not working). Also, it now seems to depend on the initial lr. If it is 1, the lr can go as high as 4; but if it is 2 ** -8, it either converges towards 0 or settles on some number like 0.125.

It depends on the run. I would like to make it stable (regardless of the initial lr, almost all of the time), and my guess is it should settle around 0.125. So, in summary:

  1. Why wasn't the previous code (before refactoring) working properly?
  2. Why does it depend so much on the initial lr?
  3. How can I make it stable?

EDIT: It also seems to overfit a lot (test accuracy is often about 80% and then falls back after a few iterations). Is this because of the business logic, or is it because of some problem with my code?

Instability is likely a statistical issue given the small "meta-batch" size of 64. You should probably increase that to 512 or more, with a batch size of 32 or 64. See how that does.
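For example (plugging those numbers into the existing call; untested):

train_momentum_check_CNN(CSV="MNIST_momentum_check_CNN.csv", train_set=train_set, test_set=test_set,
                         CNN_model=CNN().to(device), device=device,
                         n_epochs=3, batch_size=64, meta_batch=512)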
