Gradient for nn.Parameter

I am trying to implement this algorithm:

So I created a model like this:

class CNNCifar(nn.Module):
    """LeNet-style CIFAR-10 classifier carrying two extra L2C tensors.

    ``alpha`` is a registered ``nn.Parameter`` (so it appears in
    ``parameters()`` and ``state_dict()``); ``w`` is a plain tensor created
    with ``requires_grad=True`` and is therefore NOT registered with the
    module. Neither tensor is read by ``forward``, so a loss computed from
    this model's output has no autograd path back to them.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. Registration order (conv1, pool, conv2) fixes the
        # state_dict key order and the RNG draw order, so keep it.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Classifier head: 16 * 5 * 5 flattened features -> 10 classes.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        # Extra L2C tensors, created after the layers (same RNG order).
        self.alpha = nn.Parameter(torch.randn(100, 100), requires_grad=True)
        self.w = torch.randn((100, 100), requires_grad=True)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        The 16 * 5 * 5 flatten matches 3x32x32 inputs
        (32 -> conv 28 -> pool 14 -> conv 10 -> pool 5).
        """
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return F.log_softmax(self.fc3(hidden), dim=1)

with a training loop that goes like this:

# Training-loop fragment (the enclosing `def` line is not part of this paste).
# For each round t and node i: local SGD, change capturing, mixing weights,
# aggregation, then an Adam step on model.alpha driven by the validation loss.
k = len(neighbour_sets)
    # NOTE(review): the condition is inverted -- this picks "cuda" only when CUDA
    # is NOT available. Intended: "cuda" if torch.cuda.is_available() else "cpu".
    device = torch.device("cuda" if not torch.cuda.is_available() else "cpu")
    model = CNNCifar().to(device)
    criterion = nn.CrossEntropyLoss()
    # model.parameters() includes model.alpha (an nn.Parameter), so alpha is held
    # by this SGD optimizer as well as by l2c_optimizer below.
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    l2c_optimizer = optim.Adam([model.alpha], lr=beta, weight_decay=0.01)

    test_accuracies = [[] for _ in range(k)]

    # NOTE(review): state_dict() returns references to the live parameter tensors
    # and dict.copy() is shallow, so every entry of theta / theta_half /
    # delta_theta shares storage with `model` -- these are not independent
    # snapshots. Use {n: v.clone() for n, v in model.state_dict().items()} instead.
    theta = [model.state_dict().copy() for _ in range(k)]
    theta_half = [model.state_dict().copy() for _ in range(k)]

    # w = torch.randn(k, k, requires_grad=True)
    delta_theta = [model.state_dict().copy() for _ in range(k)]

    with tqdm_output(tqdm(range(T))) as trange:
        for t in trange:
            for i in range(k):
                # Local SGD step
                log.info(f'Started training a Local SGD at node {i + 1}')

                # theta[i] aliases the model's own tensors, so this load is
                # effectively a self-copy; all nodes share one set of weights.
                model.load_state_dict(theta[i])
                for m in range(S):
                    for _, data in enumerate(train_loaders[i]):
                        inputs, labels = data
                        inputs, labels = inputs.to(device), labels.to(device)
                        optimizer.zero_grad()
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        loss.backward()
                        optimizer.step()

                log.info(f'Finished training a Local SGD at node {i + 1}')



                # Change capturing
                log.info(f'Computing change capturing at node {i + 1}')
                # NOTE(review): theta[i] and theta_half[i] are always taken from the
                # same model state (and alias its tensors), so this difference is all
                # zeros. Presumably one of them should be a pre-SGD snapshot.
                for name, param in model.named_parameters():
                    delta_theta[i][name] = theta[i][name] - theta_half[i][name]

                log.info(f'Computing mixing weights at node {i + 1}')
                # Mixing weights calculation
                # model.w is a plain tensor (not a registered Parameter); row i is
                # overwritten in place with compute_mixing_weights(...) (defined elsewhere).
                model.w = model.w.clone()
                model.w[i] = compute_mixing_weights(model.alpha[i], neighbour_sets[i])

                # Aggregation
                log.info(f'Aggergating at node {i + 1}')
                # Start from node i's parameters; state_dict tensors are detached, so
                # theta_next carries no autograd history.
                theta_next = {}
                for name, param in model.named_parameters():
                    theta_next[name] = theta[i][name].clone()

                for j in neighbour_sets[i]:
                    for name, param in model.named_parameters():
                        # NOTE(review): .item() turns the weight into a Python float, so
                        # theta_next has no autograd dependence on model.w / model.alpha.
                        # Also, delta_theta[i][name][j] indexes position j along dim 0
                        # of node i's OWN delta tensor -- not neighbour j's delta.
                        theta_next[name] -= model.w[i][j].item() * delta_theta[i][name][j].clone()


                # Update L2C
                log.info(f'Updating L2C at node {i + 1}')
                # load_state_dict() copies values into the existing Parameters without
                # any autograd history, so the loss below cannot reach alpha this way.
                model.load_state_dict(theta_next)
                model.train()
                # a training loop to find alpha that minimizes the validation loss
                for _, data in enumerate(val_loaders[i]):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    
                    l2c_optimizer.zero_grad()
                    # alpha is already a leaf nn.Parameter with requires_grad=True.
                    model.alpha.requires_grad_(True)
                    
                    log.info(f'Forward pass check')
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    # retain_grad() only matters for non-leaf tensors; alpha is a leaf.
                    model.alpha.retain_grad()
                    loss.backward()
                    # CNNCifar.forward never reads self.alpha, so the loss has no path
                    # to alpha and alpha.grad stays None here.
                    print(f'gradient of alpha is {model.alpha.grad}')
                    # NOTE(review): leftover debugger breakpoint; halts on every batch.
                    import pdb; pdb.set_trace()
                    # With alpha.grad None, Adam skips alpha, so this does not change it.
                    l2c_optimizer.step()

                    # Update α[i]
                    # import pdb; pdb.set_trace()
                    # alpha_grad = model.alpha.grad  # Access the computed gradients
                    # model.alpha.data[i] -= beta * alpha_grad[i]
                

                # Remove edges for sparse topology
                if t == T_0:
                    for _ in range(K_0):
                        # NOTE(review): `w` is not defined in this fragment (the local
                        # definition above is commented out); presumably model.w was meant.
                        j = min(neighbour_sets[i], key=lambda x: w[i][x])
                        # NOTE(review): neither list nor set has .delete(); both use
                        # .remove(j). The type of neighbour_sets[i] is not shown here.
                        neighbour_sets[i].delete(j)

                # Both entries become shallow copies of the same state, so the next
                # delta computed for node i is zero again.
                theta[i] = model.state_dict().copy()
                theta_half[i] = model.state_dict().copy()

                # Compute test accuracy for each local model
                test_accuracies = compute_test_acc(model, test_loaders[i], device, test_accuracies, i)
            
        # NOTE(review): this runs once, after the `for t` loop (not per round).
        # test_accuracies starts as a list of lists, so sum(test_accuracies) raises
        # TypeError unless compute_test_acc (not shown) changes its structure.
        log.info(f'Test accuracies atiteration at Comm_round {t} =  {sum(test_accuracies) / k}')
    
    return theta, test_accuracies

The problem is: in step 18 of the algorithm, the gradient of the loss is computed with respect to alpha, but when I access model.alpha.grad it’s None, and no gradient is available.

What am I doing wrong here?

Hi Ahmed!

First, although model.w does depend (differentiably) on model.alpha
(depending, of course, on what compute_mixing_weights() actually
does), calling .item() “breaks the computation graph” because .item()
returns a python scalar that is no longer tracked by autograd. As a result,
theta_next[name] does not, in the sense of autograd, depend on alpha.

Second, the more serious problem is that load_state_dict() also
breaks the computation graph. The short explanation is that pytorch is
very fastidious about not letting you modify Parameters (in a way that
is tracked by autograd).

Consider:

>>> import torch
>>> torch.__version__
'2.0.0'
>>> lin1 = torch.nn.Linear (2, 3)
>>> lin2 = torch.nn.Linear (2, 3)
>>> loss = lin2 (torch.randn (2)).sum()
>>> lin1.load_state_dict (lin2.state_dict())
<All keys matched successfully>
>>> loss.backward()
>>> lin2.weight.grad
tensor([[0.4859, 0.7030],
        [0.4859, 0.7030],
        [0.4859, 0.7030]])
>>> lin1.weight.grad
>>>

If I understand what you are trying to do, the best approach might be
something like:

class CNNCifar(nn.Module):
    """Sketch: hold fc1's weight and bias as ordinary tensors instead of an nn.Linear.

    Because they are plain attributes rather than Parameters, they can later be
    replaced by tensors computed from the mixing weights, and autograd will
    track that computation. The ``...`` lines stand for the rest of the model.
    """
    def __init__(self):
        ...
        # Plain tensors, not Parameters; requires_grad is not set here.
        self.fc1_weight = torch.randn (120, 16 * 5 * 5)
        self.fc1_bias = torch.randn (120)
    ...
    def forward(self, x):
        ...
        # Functional linear() applied to the plain tensors above.
        x = F.relu (torch.nn.functional.linear (x, self.fc1_weight, self.fc1_bias))
        ...

Here, you replace Linear (and the other layers in model) whose
Parameters are difficult to modify in a way that autograd can track
with ordinary tensors (that can be tracked by autograd) that you then
pass to the functional version of linear().

Your compute_mixing_weights() would then mix together the fc1_weight
and fc1_bias tensors from the various neighbour_sets (and similarly for
the other modifiable tensors that replace self.conv1, etc.).

You would then do something like:

                # Mix in each neighbour's change. model.w[i][j] stays a tensor (no
                # .item()), so, provided model.w was computed from alpha, theta_next
                # stays connected to alpha in the autograd graph.
                for j in neighbour_sets[i]:
                    # delta_theta[i, j] is sketch notation for the stored change used for
                    # neighbour j -- presumably a container indexed by (i, j); adapt it to
                    # whatever structure actually holds the tensor copies.
                    theta_next.fc1_weight -= model.w[i][j] * delta_theta[i, j].fc1_weight.clone()
                    theta_next.fc1_bias -= model.w[i][j] * delta_theta[i, j].fc1_bias.clone()
                    ...

where theta_next and delta_theta are no longer state dictionaries, but
just something that holds copies of the model tensors (i.e., fc1_weight
and fc1_bias).

Note that fc1_weight is not a Parameter and doesn’t have requires_grad
explicitly set to True. It will “inherit” requires_grad = True when it is
updated using w, which does have requires_grad = True.

(If I understand your use case correctly, you want to separately optimize
both the “ordinary” parameters() of model and model.alpha. So you
may, indeed, need to endow fc1_weight with requires_grad = True.
Note that it’s fine to set up an Optimizer to optimize non-Parameter
tensors (such as fc1_weight), as long as they are leaf tensors.)

Best.

K. Frank

1 Like

Thank you K. Frank very much for this detailed explanation.

So I adjusted everything, as much as I understood from your comment, as follows:

# Second version of the training-loop fragment (enclosing `def` line not pasted),
# written for the CNNCifar variant whose layer weights are plain tensors.
k = len(neighbour_sets)
    # NOTE(review): the condition is inverted -- this picks "cuda" only when CUDA
    # is NOT available. Intended: "cuda" if torch.cuda.is_available() else "cpu".
    device = torch.device("cuda" if not torch.cuda.is_available() else "cpu")
    # NOTE(review): the conv/fc tensors of this CNNCifar are plain attributes, not
    # registered Parameters/buffers, so .to(device) only moves alpha and w.
    model = CNNCifar().to(device)
    criterion = nn.CrossEntropyLoss()
    # optimizer for all params except alpha and w
    # The optimizer keeps these exact tensor objects; if the attributes are later
    # reassigned to new tensors, the optimizer still updates the old ones.
    params = [model.conv1_weight, model.conv1_bias, model.conv2_weight, model.conv2_bias, model.fc1_weight, model.fc1_bias, model.fc2_weight, model.fc2_bias, model.fc3_weight, model.fc3_bias]
    optimizer = optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=0.01)
    l2c_optimizer = optim.Adam([model.alpha], lr=beta, weight_decay=0.01)

    test_accuracies = [[] for _ in range(k)]

    # NOTE(review): every list entry is the SAME `model` object (no copy), so
    # theta, theta_half, delta_theta and theta_next all alias one module.
    # Independent per-node copies would need e.g. copy.deepcopy(model).
    theta = [model for _ in range(k)]
    theta_half = [model for _ in range(k)]
    delta_theta = [model for _ in range(k)]
    theta_next = [model for _ in range(k)]

    # Mixing-weight matrix; starts as zeros with requires_grad=False.
    w = torch.zeros(k, k, dtype=theta[0].alpha.dtype, device=theta[0].alpha.device)

    with tqdm(range(T)) as trange:
        for t in trange:
            for i in range(k):
                # Local SGD step
                log.info(f'Started training a Local SGD at node {i + 1}')

                # NOTE(review): local SGD is disabled (commented out) in this version,
                # so no training happens before the change is captured.
                # theta_half[i] = theta[i]
                # for m in range(S):
                #     for _, data in enumerate(train_loaders[i]):
                #         inputs, labels = data
                #         inputs, labels = inputs.to(device), labels.to(device)
                #         optimizer.zero_grad()
                #         outputs = theta_half[i](inputs)
                #         loss = criterion(outputs, labels)
                #         loss.backward()
                #         optimizer.step()

                log.info(f'Finished training a Local SGD at node {i + 1}')

                # Change capturing
                log.info(f'Computing change capturing at node {i + 1}')
                # NOTE(review): leftover debugger breakpoint.
                import pdb; pdb.set_trace()
                # NOTE(review): delta_theta[i], theta[i] and theta_half[i] are all `model`,
                # so each line computes x - x.clone() (all zeros) and OVERWRITES the
                # model's own tensor with it, detaching it from `optimizer` above.
                delta_theta[i].fc1_weight = theta[i].fc1_weight - theta_half[i].fc1_weight.clone()
                delta_theta[i].fc1_bias = theta[i].fc1_bias - theta_half[i].fc1_bias.clone()
                delta_theta[i].fc2_weight = theta[i].fc2_weight - theta_half[i].fc2_weight.clone()
                delta_theta[i].fc2_bias = theta[i].fc2_bias - theta_half[i].fc2_bias.clone()
                delta_theta[i].fc3_weight = theta[i].fc3_weight - theta_half[i].fc3_weight.clone()
                delta_theta[i].fc3_bias = theta[i].fc3_bias - theta_half[i].fc3_bias.clone()
                delta_theta[i].conv1_weight = theta[i].conv1_weight - theta_half[i].conv1_weight.clone()
                delta_theta[i].conv1_bias = theta[i].conv1_bias - theta_half[i].conv1_bias.clone()
                delta_theta[i].conv2_weight = theta[i].conv2_weight - theta_half[i].conv2_weight.clone()
                delta_theta[i].conv2_bias = theta[i].conv2_bias - theta_half[i].conv2_bias.clone()
                               
                log.info(f'Computing mixing weights at node {i + 1}')
                # Mixing weights (softmax of alpha over node i and its neighbours):
                #   w[i, j] = exp(alpha[i, j]) / sum_{l in {i} ∪ N(i)} exp(alpha[i, l])
                # NOTE(review): only j in neighbour_sets[i] are filled, so unless that set
                # contains i, the self term l = i is missing (w[i][i] stays 0).
                for j in neighbour_sets[i]:
                    w[i][j] = torch.exp(theta[i].alpha[i][j])
                w[i] /= w[i].sum()
                
                # NOTE(review): assigning through .data swaps the Parameter's storage
                # without recording anything in autograd; theta[i].w stays a leaf, so
                # nothing computed from it can send gradients back to alpha. This is
                # why alpha.grad is None. Use `w` itself in the aggregation instead.
                theta[i].w.data = w

                # Aggregation
                log.info(f'Aggergating at node {i + 1}')

                # NOTE(review): delta_theta is a list, so delta_theta[i, j] raises
                # TypeError (list indices must be integers or slices, not tuple).
                # Each line also REPLACES theta_next[i].X rather than accumulating; it
                # only accumulates because theta_next[i] and theta[i] are the same object.
                for j in neighbour_sets[i]:
                    theta_next[i].fc1_weight = theta[i].fc1_weight -  theta[i].w[i][j] * delta_theta[i, j].fc1_weight.clone()
                    theta_next[i].fc1_bias = theta[i].fc1_bias - theta[i].w[i][j] * delta_theta[i, j].fc1_bias.clone()
                    theta_next[i].fc2_weight = theta[i].fc2_weight - theta[i].w[i][j] * delta_theta[i, j].fc2_weight.clone()
                    theta_next[i].fc2_bias = theta[i].fc2_bias - theta[i].w[i][j] * delta_theta[i, j].fc2_bias.clone()
                    theta_next[i].fc3_weight = theta[i].fc3_weight - theta[i].w[i][j] * delta_theta[i, j].fc3_weight.clone()
                    theta_next[i].fc3_bias = theta[i].fc3_bias - theta[i].w[i][j] * delta_theta[i, j].fc3_bias.clone()
                    theta_next[i].conv1_weight = theta[i].conv1_weight - theta[i].w[i][j] * delta_theta[i, j].conv1_weight.clone()
                    theta_next[i].conv1_bias = theta[i].conv1_bias - theta[i].w[i][j] * delta_theta[i, j].conv1_bias.clone()
                    theta_next[i].conv2_weight = theta[i].conv2_weight - theta[i].w[i][j] * delta_theta[i, j].conv2_weight.clone()
                    theta_next[i].conv2_bias = theta[i].conv2_bias - theta[i].w[i][j] * delta_theta[i, j].conv2_bias.clone()

                # Update L2C
                log.info(f'Updating L2C at node {i + 1}')
                # a training loop to find alpha that minimizes the validation loss
                for _, data in enumerate(val_loaders[i]):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    
                    l2c_optimizer.zero_grad()
                    # alpha is already a leaf nn.Parameter with requires_grad=True.
                    model.alpha.requires_grad_(True)
                    # forward() never reads alpha or w, so the loss depends on alpha only
                    # through the tensors assigned during aggregation.
                    outputs = theta_next[i](inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    print(f'gradient of alpha is {theta_next[i].alpha.grad}')
                    l2c_optimizer.step()

                # Remove edges for sparse topology
                if t == T_0:
                    for _ in range(K_0):
                        j = min(neighbour_sets[i], key=lambda x: w[i][x])
                        # NOTE(review): neither list nor set has .delete(); both use
                        # .remove(j). The type of neighbour_sets[i] is not shown here.
                        neighbour_sets[i].delete(j)

                # theta[i] = model.state_dict().copy()
                # theta_half[i] = model.state_dict().copy()

                # Compute test accuracy for each local model
                # (`model` is the same object as theta_next[i] here.)
                test_accuracies = compute_test_acc(model, test_loaders[i], device, test_accuracies, i)
            
        # NOTE(review): this runs once, after the `for t` loop (not per round).
        # test_accuracies starts as a list of lists, so sum(test_accuracies) raises
        # TypeError unless compute_test_acc (not shown) changes its structure.
        log.info(f'Test accuracies atiteration at Comm_round {t} =  {sum(test_accuracies) / k}')
    
    return theta, test_accuracies

But I still get None from this line:

print(f'gradient of alpha is {theta_next[i].alpha.grad}')

My model (based on your suggestions) is as follows:


from torch import nn
import torch.nn.functional as F
import torch
from utils import compute_mixing_weights

class CNNCifar(nn.Module):
    """LeNet-style CIFAR-10 classifier whose layer weights are plain tensors.

    The conv/fc weights and biases are ordinary leaf tensors created with
    ``requires_grad=True`` (not ``nn.Parameter``s), and ``forward`` applies
    them through ``F.conv2d`` / ``F.linear``. Only ``alpha`` and ``w`` are
    registered Parameters.

    NOTE(review): because the plain tensors are not registered with the
    module, ``parameters()``, ``state_dict()`` and ``.to(device)`` do not see
    them.
    """

    def __init__(self):
        super().__init__()

        def leaf(*shape):
            # Fresh N(0, 1) leaf tensor tracked by autograd.
            return torch.randn(*shape, requires_grad=True)

        self.pool = nn.MaxPool2d(2, 2)

        # Each leaf() call draws from the global RNG; the order below
        # (weight before bias, conv1 -> fc3) is the creation order.
        # Conv weights are (out_channels, in_channels, kH, kW).
        self.conv1_weight, self.conv1_bias = leaf(6, 3, 5, 5), leaf(6)
        self.conv2_weight, self.conv2_bias = leaf(16, 6, 5, 5), leaf(16)

        # Linear weights are (out_features, in_features).
        self.fc1_weight, self.fc1_bias = leaf(120, 16 * 5 * 5), leaf(120)
        self.fc2_weight, self.fc2_bias = leaf(84, 120), leaf(84)
        self.fc3_weight, self.fc3_bias = leaf(10, 84), leaf(10)

        # Registered Parameters, drawn after all layer tensors.
        self.alpha = nn.Parameter(torch.randn(100, 100), requires_grad=True)
        self.w = nn.Parameter(torch.randn(100, 100), requires_grad=True)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        The 16 * 5 * 5 flatten matches 3x32x32 inputs
        (32 -> conv 28 -> pool 14 -> conv 10 -> pool 5).
        """
        conv_stages = (
            (self.conv1_weight, self.conv1_bias),
            (self.conv2_weight, self.conv2_bias),
        )
        for weight, bias in conv_stages:
            x = self.pool(F.relu(F.conv2d(x, weight, bias)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(F.linear(x, self.fc1_weight, self.fc1_bias))
        x = F.relu(F.linear(x, self.fc2_weight, self.fc2_bias))
        logits = F.linear(x, self.fc3_weight, self.fc3_bias)
        return F.log_softmax(logits, dim=1)


Hi Ahmed!

You have quite a few things going on here. I would suggest that you work
with a simplified, “toy” model where you figure out how to optimize both
alpha and some simple model parameter in the presence of “weight mixing.”

Get that working, and then try to extend it to your more-complicated use case.

One immediate problem is that at this point theta, theta_half, etc.,
are all lists that contain references to the same instance of model. So,
if you modify, say, theta[2], you have also modified theta[0] and
theta_half[3], because these are all actually references to the same
object, namely model. This surely isn’t what you want and could well
be the cause of your problem.

Best.

K. Frank