How to do constrained optimization in PyTorch

You can do projected gradient descent by enforcing your constraint after each optimizer step: take an ordinary gradient step, then project the parameters back onto the feasible set. An example training loop (with a stand-in model and data so the snippet runs as-is):

    import torch
    from torch import optim

    # stand-in model and data for illustration; substitute your own
    model = torch.nn.Linear(4, 2)
    inputs = torch.randn(8, 4)
    labels = torch.randint(0, 2, (8,))
    loss_fn = torch.nn.CrossEntropyLoss()

    opt = optim.SGD(model.parameters(), lr=0.1)
    for i in range(1000):
        out = model(inputs)
        loss = loss_fn(out, labels)
        print(i, loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()  # unconstrained gradient step
        with torch.no_grad():
            # projection: clamp every parameter back into [-1, 1]
            for param in model.parameters():
                param.clamp_(-1, 1)

The `with torch.no_grad():` block at the end enforces the constraint that the weights fall in the range [-1, 1], projecting them back into that set after every update.
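The same pattern works for any constraint set that has a cheap projection. As a minimal sketch (not the only way to do it), projecting each parameter tensor onto the unit L2 ball instead would look like:

    with torch.no_grad():
        for param in model.parameters():
            norm = param.norm()  # L2 (Frobenius) norm of the tensor
            if norm > 1:         # only project if outside the ball
                param.div_(norm)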
