Gradient Ascent and Gradient Modification/Modifying Optimizer instead of Grad_weight

I know this is probably a bit late, but the solution I came up with is to use a custom autograd function that reverses the gradient direction. Like @Tamal_Chowdhury, I have a Lagrangian optimization problem for which this function works perfectly. A small working example:

import torch


class AscentFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Identity in the forward pass: the loss value is unchanged.
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Negate the incoming gradient so everything upstream
        # receives the ascent direction instead of the descent direction.
        return -grad_output


def make_ascent(loss):
    # Wrap a loss so that calling backward() on the result
    # produces gradients for ascent rather than descent.
    return AscentFunction.apply(loss)


x = torch.normal(10, 3, size=(10,))
w = torch.ones_like(x, requires_grad=True)

loss = (x * w).sum()
print(f'descent loss: {loss.item():.2f}')

loss.backward()
print(w.grad)  # descent gradients: d(loss)/dw = x

# clear the accumulated gradient before the ascent pass
w.grad = None

loss = (x * w).sum()
m_loss = make_ascent(loss)
print(f'ascent loss: {m_loss.item():.2f}')

m_loss.backward()
print(w.grad)  # same magnitudes, opposite sign: the ascent direction

Its output:

descent loss: 96.13
tensor([12.7093, 11.2243,  6.4265,  7.6572, 14.2737, 15.1144,  8.0099,  6.2517,
         7.6352,  6.8274])
ascent loss: 96.13
tensor([-12.7093, -11.2243,  -6.4265,  -7.6572, -14.2737, -15.1144,  -8.0099,
         -6.2517,  -7.6352,  -6.8274])
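
For anyone wiring this into a training loop: below is a minimal, hypothetical sketch (my own example, not part of the original post; the objective, learning rate, and step count are arbitrary) showing how make_ascent lets a standard torch.optim.SGD step climb an objective instead of descending it, reusing the AscentFunction defined above.

import torch

# Maximize a concave objective with a plain SGD optimizer.
# make_ascent negates the gradient in backward, so optimizer.step()
# moves theta uphill on the objective rather than downhill.
theta = torch.tensor(0.0, requires_grad=True)
optimizer = torch.optim.SGD([theta], lr=0.1)

for step in range(50):
    optimizer.zero_grad()
    objective = -((theta - 2.0) ** 2)   # concave, maximum at theta = 2
    make_ascent(objective).backward()   # gradients are negated, so SGD ascends
    optimizer.step()

print(f'theta after ascent: {theta.item():.4f}')  # approaches 2.0

The upside of this approach is that you keep a single optimizer and a single backward pass; only the term wrapped in make_ascent contributes ascent gradients, without touching param.grad or the optimizer itself.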