I am currently implementing a gradient-based optimization over a receding horizon: at each time step I sample a number of candidate action sequences, and each candidate is then optimized to minimize the cumulative cost by updating every action in its sequence. I have a sequential version that works but is quite slow, so I wanted to use per-sample gradients to batch the gradient computation across all candidates. However, the dimensions of the computed gradients come out wrong. Does anyone have an idea what I need to change?
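To illustrate the per-sample gradient pattern I am trying to apply, here is a minimal toy sketch (toy_loss and the sizes are made up for illustration only, they are not my actual cost model):

import torch
from torch.func import vmap, grad

def toy_loss(a):              # a: (horizon, action_size) for a single candidate
    return (a ** 2).sum()     # scalar loss per candidate

A = torch.randn(12, 1000, 1)  # (horizon, candidates, action_size)
# one gradient per candidate, laid out like A itself
per_candidate_grads = vmap(grad(toy_loss), in_dims=1, out_dims=1)(A)
print(per_candidate_grads.shape)  # torch.Size([12, 1000, 1])

That is the kind of output I would like to feed into Adam for all candidates at once. My full code is below.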
import numpy as np
import torch
from torch.optim import Adam
import torch.nn as nn
from torch import jit
from torch.func import functional_call, vmap, grad

class CostModel(nn.Module):
    def __init__(self):
        super().__init__()

    def step(self, u):
        # u: (action_size, N); the state rows are (x^2, x, y), each of shape (N,)
        x_squared, x, y = self.state
        u = torch.clamp(u, -2, 2)[0]
        newy = y - 3 * x_squared + 0.1 * u.detach()
        newy = torch.clamp(newy, -2, 2)
        newx = x + 0.02 * newy
        self.state = torch.stack((newx**2, newx, newy), dim=0)
        costs = x**2 + 0.1 * y**2 + 0.001 * (u**2)
        return self._get_obs(), -costs

    def reset(self):
        high = torch.tensor([2.0, 1.0])
        low = -high  # We enforce symmetric limits.
        x, y = torch.distributions.uniform.Uniform(low, high).sample((1,))[0]
        self.state = torch.stack((x**2, x, y), dim=0).unsqueeze(dim=0)
        return self._get_obs(), {}

    def _get_obs(self):
        return self.state

    def set_state(self, state):
        self.state = state
        return self._get_obs()

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cost_model = CostModel()

    def forward(self, state: torch.Tensor, actions: torch.Tensor):
        # actions: (horizon, B, action_size); only the first B rows of state[0] are used as the initial state
        H = actions.size(0) + 1
        B = actions.size(1)
        states = [torch.empty(0)] * H
        costs = [torch.empty(0)] * (H - 1)
        states[0] = state[0]
        self.cost_model.reset()
        self.cost_model.set_state(state[0].narrow(0, 0, B).permute(1, 0))
        # Loop over the time sequence
        for t in range(H - 1):
            next_state, cost = self.cost_model.step(actions[t].permute(1, 0))
            states[t + 1] = next_state.permute(1, 0)
            costs[t] = cost
        # return the rolled-out states and the associated per-step costs
        return torch.stack(states[1:], dim=0), torch.stack(costs, dim=0)

class Optimizer(jit.ScriptModule):
    __constants__ = [
        "action_size",
        "horizon",
        "optimisation_iters",
        "candidates",
        "learning_rate",
        "device",
        "min_action",
        "max_action",
    ]

    def __init__(
        self,
        action_size,
        horizon,
        optimisation_iters,
        candidates,
        learning_rate,
        device,
        min_action,
        max_action,
    ):
        super().__init__()
        self.action_size = action_size
        self.horizon = horizon
        self.optimisation_iters = optimisation_iters
        self.candidates = candidates
        self.learning_rate = learning_rate
        self.device = device
        self.min_action = min_action
        self.max_action = max_action

    def forward_sequential(self, model, state):
        """Works, but is quite slow."""
        B, Z = state.size(0), state.size(1)
        state = state.unsqueeze(dim=1).expand(B * self.horizon, self.candidates, Z)
        candidate_trajectories = torch.clamp(
            torch.randn(
                self.horizon,
                B,
                self.candidates,
                self.action_size,
                device=self.device,
            )
            * (self.max_action - self.min_action)
            / 2
            + (self.min_action + self.max_action) / 2,
            min=self.min_action,
            max=self.max_action,
        ).view(self.horizon, B * self.candidates, self.action_size)
        candidate_costs = torch.zeros(self.candidates)
        for j in range(self.candidates):
            actions = nn.Parameter(candidate_trajectories[:, j, :].unsqueeze(dim=1))
            optimizer = Adam([actions], lr=self.learning_rate)
            for _ in range(self.optimisation_iters):
                states, costs = model(state, actions)
                cumulative_costs = costs.sum(dim=0)
                # we want to minimize the cumulative costs over the horizon
                loss = -cumulative_costs
                optimizer.zero_grad()
                loss.backward()  # actions.grad is of size [horizon x 1 x action_size]
                optimizer.step()
            candidate_costs[j] = cumulative_costs.detach()
            candidate_trajectories[:, j, :] = actions.detach().squeeze(-1)
        best_actions = candidate_trajectories[:, candidate_costs.argmax(dim=0), :]
        return torch.clamp(best_actions.detach()[0], min=self.min_action, max=self.max_action)

    def forward(self, model, state):
        """Works, but the gradient dimensions are not correct."""
        B, Z = state.size(0), state.size(1)
        state = state.unsqueeze(dim=1).expand(B * self.horizon, self.candidates, Z)
        candidate_trajectories = torch.clamp(
            torch.randn(
                self.horizon,
                B,
                self.candidates,
                self.action_size,
                device=self.device,
            )
            * (self.max_action - self.min_action)
            / 2
            + (self.min_action + self.max_action) / 2,
            min=self.min_action,
            max=self.max_action,
        ).view(self.horizon, B * self.candidates, self.action_size)

        def compute_loss(state, actions):
            # called once per candidate via vmap: state is (horizon, Z), actions is (horizon, action_size)
            state = state.unsqueeze(dim=1)
            actions = actions.unsqueeze(dim=1)
            _, costs = model(state, actions)
            loss = -costs.sum(dim=0)[0]
            return loss

        actions = nn.Parameter(candidate_trajectories)
        optimizer = Adam([actions], lr=self.learning_rate)
        for _ in range(self.optimisation_iters):
            states, costs = model(state, actions)
            cumulative_costs = costs.sum(dim=0)
            # current dimension of the gradients is [horizon x candidates x state_size] instead of [horizon x candidates x action_size]?
            ft_per_candidate_grads = vmap(grad(compute_loss), in_dims=(1, 1), randomness="same")(
                state, actions
            )
            optimizer.zero_grad()
            # runtime error here, since the gradient dimensions are not correct
            optimizer.param_groups[0]["params"][0].grad = ft_per_candidate_grads.permute(1, 0, 2)
            optimizer.step()
        best_actions = candidate_trajectories[:, cumulative_costs.detach().argmax(dim=0), :]
        return torch.clamp(best_actions.detach()[0], min=self.min_action, max=self.max_action)

if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    horizon = 12
    candidates = 1000
    optimisation_iters = 10
    state = np.random.randn(3).astype(np.float32)
    min_action = -2.0
    max_action = 2.0
    optimizer = Optimizer(
        action_size=1,
        horizon=horizon,
        optimisation_iters=optimisation_iters,
        candidates=candidates,
        learning_rate=0.5,
        device=device,
        min_action=min_action,
        max_action=max_action,
    )
    model = Model()
    for _ in range(200):
        state = torch.from_numpy(state).to(device).unsqueeze(dim=0)
        # action = optimizer.forward_sequential(model, state)
        action = optimizer(model, state)
        # keep float32, otherwise the next torch.from_numpy() produces a float64 state
        state = np.random.randn(3).astype(np.float32)
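For reference, these are the shapes involved (horizon=12, candidates=1000, action_size=1, state size 3):

# actions parameter in forward():                        (12, 1000, 1) -> the shape the gradient has to have for Adam
# actions.grad per candidate in forward_sequential():    (12, 1, 1), i.e. [horizon x 1 x action_size]
# ft_per_candidate_grads.permute(1, 0, 2) in forward():  (12, 1000, 3), i.e. [horizon x candidates x state_size]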