Hi All,

Hi,

The simplest way to do gradient ascent on a loss `L` is to do gradient descent on `-L`.
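To make that concrete, here is a minimal sketch on a toy problem. The objective `L(w) = -(w - 3)^2` and the names `w` and `opt` are made up for illustration and are not from the original question; the point is only that calling `backward()` on `-L` turns an ordinary SGD step into an ascent step on `L`.

```python
import torch

# Hypothetical objective to maximize: L(w) = -(w - 3)^2, with its maximum at w = 3.
w = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    L = -(w - 3.0) ** 2    # quantity we want to maximize
    (-L).sum().backward()  # gradient descent on -L is gradient ascent on L
    opt.step()

print(w.item())  # should end up close to 3
```

Note that this flips the update for every parameter that `L` depends on, which is why it does not directly cover the case where only some parameters should ascend.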

That is an interesting solution. I think I need to clarify my original question further. I would like to put a negative sign on the weight updates only, which corresponds to changing grad_weight to -grad_weight while leaving grad_input and grad_bias untouched. However, I am wary of unintended consequences of modifying the gradients like this. Is there an easy way to make the optimizer perform gradient ascent (W + dW) specifically on the weights of the non-last layers, while leaving the other parameters alone?

In that case I guess you will have to create a custom optimizer to handle that, for example with one parameter group for the descent part and one group for the ascent part.
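One way to approximate the two-group idea without writing a full optimizer class is to keep two lists of parameters and flip the sign of `.grad` for the ascent list just before `step()`. The sketch below is only an illustration under assumptions: the toy saddle objective and the names `a`, `b`, `descent_params` and `ascent_params` are hypothetical, and plain SGD is used because the sign flip is easy to reason about there.

```python
import torch

a = torch.zeros(1, requires_grad=True)  # should be updated by descent
b = torch.zeros(1, requires_grad=True)  # should be updated by ascent
descent_params, ascent_params = [a], [b]

opt = torch.optim.SGD(descent_params + ascent_params, lr=0.1)

for _ in range(200):
    opt.zero_grad()
    # Toy saddle objective: minimized over a at a = 1, maximized over b at b = 2.
    loss = ((a - 1.0) ** 2 - (b - 2.0) ** 2).sum()
    loss.backward()
    # Negate the gradients of the ascent group only; other gradients are untouched.
    for p in ascent_params:
        if p.grad is not None:
            p.grad.neg_()
    opt.step()

print(a.item(), b.item())  # should end up close to 1 and 2
```

The negation happens after `backward()` has finished, so the gradients flowing to earlier layers are not affected. If weight decay is enabled, the decay term is still applied in the usual direction, which may or may not be what you want for the ascent group.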


I’m working on a similar problem where I need to optimize the following loss function:

Here `w` (omega) is the model parameter and the `lambda`s are Lagrange multipliers. I need to perform gradient descent w.r.t. `omega` and, simultaneously, gradient ascent w.r.t. `lambda`. `lambda` is not a model parameter and only appears in the loss term.
Will your solution of updating `lambda` by gradient descent on `-L` work in this case? If it does, then using a negative learning rate for the `lambda`s should be equivalent. If it doesn't, what is the PyTorch way to do this (without changing the optimizer source code)? Or do I need to create a custom optimizer?
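As a rough illustration of simultaneous descent/ascent on a Lagrangian, here is a sketch on a made-up problem, since the loss from the post above is not reproduced here: minimize `w^2` subject to `w >= 1`, with Lagrangian `w^2 + lam * (1 - w)`. The names `w`, `lam`, `opt_w` and `opt_lam`, the learning rates, and the clamping of `lam` to be non-negative are all assumptions for this toy case.

```python
import torch

w = torch.zeros(1, requires_grad=True)    # primal variable: gradient descent
lam = torch.zeros(1, requires_grad=True)  # Lagrange multiplier: gradient ascent

opt_w = torch.optim.SGD([w], lr=0.05)
opt_lam = torch.optim.SGD([lam], lr=0.05)

for _ in range(2000):
    opt_w.zero_grad()
    opt_lam.zero_grad()
    # Toy Lagrangian for: minimize w^2 subject to 1 - w <= 0.
    lagrangian = (w ** 2 + lam * (1.0 - w)).sum()
    lagrangian.backward()
    lam.grad.neg_()  # flip the sign so the SGD step on lam becomes an ascent step
    opt_w.step()
    opt_lam.step()
    with torch.no_grad():
        lam.clamp_(min=0.0)  # keep the inequality multiplier non-negative

print(w.item(), lam.item())  # should end up close to w = 1, lam = 2
```

Flipping the sign of `lam.grad` has the same effect as a negative learning rate for `lam`. As far as I know the built-in optimizers reject a negative `lr`, so negating the gradient (or the loss seen by `lam`) is the more practical route.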


I think this is a bit too late, but the solution I came up with is to use a custom autograd function that reverses the gradient direction. Like @Tamal_Chowdhury, I have a Lagrangian optimization problem, for which this function works perfectly. A small working example would be:

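```python
import torch


class AscentFunction(torch.autograd.Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # Flip the sign so that an optimizer's descent step becomes an ascent step.
        return -grad_input


def make_ascent(loss):
    return AscentFunction.apply(loss)


x = torch.normal(10, 3, size=(10,))
# `w` is not defined in the snippet as posted; a vector of ones is an assumption.
w = torch.ones(10, requires_grad=True)

# Ordinary loss: backward() stores d(loss)/dw = x in w.grad.
loss = (x * w).sum()
print(f'descent loss: {loss.item():.2f}')

loss.backward()

# Same loss value, but backward() now stores -x in w.grad.
w.grad = None  # reset so the two gradients do not accumulate
loss = (x * w).sum()
m_loss = make_ascent(loss)
print(f'ascent loss: {m_loss.item():.2f}')

m_loss.backward()
```

```
descent loss: 96.13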