Wrong dimension of per-sample gradients

I am currently implementing a gradient-based optimization over a receding horizon. At each time step, we sample a number of candidates (action sequences), and each candidate should be optimized to minimize the cost by updating every action in its sequence. So far I have a sequential version that works but is quite slow, so I wanted to use per-sample gradients to batch the gradient computation across candidates. However, the dimension of the computed gradients is wrong. Does anyone have an idea what I need to change?
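For context, the per-sample gradient pattern I am trying to follow is roughly the one below (a minimal, self-contained sketch with a toy quadratic loss; the names toy_loss, theta and xs are made up for illustration and are not part of my actual model):

import torch
from torch.func import vmap, grad

def toy_loss(theta, x):
    # scalar loss for a single sample
    return ((theta * x) ** 2).sum()

theta = torch.randn(5)   # shared parameters
xs = torch.randn(8, 5)   # batch of 8 samples

# grad(toy_loss) differentiates with respect to the first argument (theta);
# vmap maps the computation over the batch dimension of xs only
per_sample_grads = vmap(grad(toy_loss), in_dims=(None, 0))(theta, xs)
print(per_sample_grads.shape)  # torch.Size([8, 5]): one gradient per sample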

import numpy as np
import torch
from torch.optim import Adam
import torch.nn as nn
from torch import jit
from torch.func import functional_call, vmap, grad


class CostModel(nn.Module):
    def __init__(self):
        super().__init__()

    def step(self, u):
        x_squared, x, y = self.state

        u = torch.clamp(u, -2, 2)[0]

        newy = y - 3 * x_squared + 0.1 * u.detach()
        newy = torch.clamp(newy, -2, 2)
        newx = x + 0.02 * newy

        self.state = torch.stack((newx**2, newx, newy), dim=0)

        costs = x**2 + 0.1 * y**2 + 0.001 * (u**2)

        return self._get_obs(), -costs

    def reset(self):
        high = torch.tensor([2.0, 1.0])
        low = -high  # We enforce symmetric limits.
        x, y = torch.distributions.uniform.Uniform(low, high).sample((1,))[0]

        self.state = torch.stack((x**2, x, y), dim=0).unsqueeze(dim=0)

        return self._get_obs(), {}

    def _get_obs(self):
        return self.state

    def set_state(self, state):
        self.state = state
        return self._get_obs()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cost_model = CostModel()

    def forward(self, state: torch.Tensor, actions: torch.Tensor):
        H = actions.size(0) + 1
        B = actions.size(1)
        states = [torch.empty(0)] * H
        costs = [torch.empty(0)] * (H - 1)
        states[0] = state[0]
        self.cost_model.reset()
        self.cost_model.set_state(state[0].narrow(0, 0, B).permute(1, 0))
        # Loop over time sequence
        for t in range(H - 1):
            next_state, cost = self.cost_model.step(actions[t].permute(1, 0))

            states[t + 1] = next_state.permute(1, 0)
            costs[t] = cost

        # return next state and associated costs
        return torch.stack(states[1:], dim=0), torch.stack(costs, dim=0)


class Optimizer(jit.ScriptModule):
    __constants__ = [
        "action_size",
        "horizon",
        "optimisation_iters",
        "candidates",
        "learning_rate",
        "device",
        "min_action",
        "max_action",
    ]

    def __init__(
        self,
        action_size,
        horizon,
        optimisation_iters,
        candidates,
        learning_rate,
        device,
        min_action,
        max_action,
    ):
        super().__init__()
        self.action_size = action_size
        self.horizon = horizon
        self.optimisation_iters = optimisation_iters
        self.candidates = candidates
        self.learning_rate = learning_rate
        self.device = device

        self.min_action = min_action
        self.max_action = max_action

    def forward_sequential(self, model, state):
        """Works, but is quite slow."""
        B, Z = state.size(0), state.size(1)
        state = state.unsqueeze(dim=1).expand(B * self.horizon, self.candidates, Z)

        candidate_trajectories = torch.clamp(
            torch.randn(
                self.horizon,
                B,
                self.candidates,
                self.action_size,
                device=self.device,
            )
            * (self.max_action - self.min_action)
            / 2
            + (self.min_action + self.max_action) / 2,
            min=self.min_action,
            max=self.max_action,
        ).view(self.horizon, B * self.candidates, self.action_size)

        candidate_costs = torch.zeros(self.candidates)
        for j in range(self.candidates):
            actions = nn.Parameter(candidate_trajectories[:, j, :].unsqueeze(dim=1))
            optimizer = Adam([actions], lr=self.learning_rate)
            for _ in range(self.optimisation_iters):
                states, costs = model(state, actions)
                cumulative_costs = costs.sum(dim=0)
                # we want to minimize the cumulative costs over the horizon
                loss = -cumulative_costs
                optimizer.zero_grad()
                loss.backward()  # actions.grad is of size [horizon x 1 x action_size]
                optimizer.step()

            candidate_costs[j] = cumulative_costs.detach()
            candidate_trajectories[:, j, :] = actions.detach().squeeze(-1)

        best_actions = candidate_trajectories[:, candidate_costs.argmax(dim=0), :]

        return torch.clamp(best_actions.detach()[0], min=self.min_action, max=self.max_action)

    def forward(self, model, state):
        """Works, but the gradient dimensions are not correct."""
        B, Z = state.size(0), state.size(1)
        state = state.unsqueeze(dim=1).expand(B * self.horizon, self.candidates, Z)

        candidate_trajectories = torch.clamp(
            torch.randn(
                self.horizon,
                B,
                self.candidates,
                self.action_size,
                device=self.device,
            )
            * (self.max_action - self.min_action)
            / 2
            + (self.min_action + self.max_action) / 2,
            min=self.min_action,
            max=self.max_action,
        ).view(self.horizon, B * self.candidates, self.action_size)

        def compute_loss(state, actions):
            state = state.unsqueeze(dim=1)
            actions = actions.unsqueeze(dim=1)
            _, costs = model(state, actions)
            loss = -costs.sum(dim=0)[0]

            return loss

        actions = nn.Parameter(candidate_trajectories)
        optimizer = Adam([actions], lr=self.learning_rate)
        for _ in range(self.optimisation_iters):
            states, costs = model(state, actions)
            cumulative_costs = costs.sum(dim=0)
            # current dimension of gradients is [horizon x candidates x state_size] instead of [horizon x candidates x action_size]?
            ft_per_candidate_grads = vmap(grad(compute_loss), in_dims=(1, 1), randomness="same")(
                state, actions
            )
            optimizer.zero_grad()
            # run time error since gradient dimensions are not correct
            optimizer.param_groups[0]["params"][0].grad = ft_per_candidate_grads.permute(1, 0, 2)
            optimizer.step()

        best_actions = candidate_trajectories[:, cumulative_costs.detach().argmax(dim=0), :]

        return torch.clamp(best_actions.detach()[0], min=self.min_action, max=self.max_action)


if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    horizon = 12
    candidates = 1000
    optimisation_iters = 10

    state = np.random.randn(3).astype(np.float32)

    min_action = -2.0
    max_action = 2.0

    optimizer = Optimizer(
        action_size=1,
        horizon=horizon,
        optimisation_iters=optimisation_iters,
        candidates=candidates,
        learning_rate=0.5,
        device=device,
        min_action=min_action,
        max_action=max_action,
    )
    model = Model()

    for _ in range(200):
        state = torch.from_numpy(state).to(device).unsqueeze(dim=0)
        # action = optimizer.forward_sequential(model, state)
        action = optimizer(model, state)
        # keep the dtype consistent with the first iteration
        state = np.random.randn(3).astype(np.float32)

Okay, so basically swapping the order of state and actions, both in the compute_loss definition and in the vmap call, fixes the issue. Does anyone know whether some implicit argument order is assumed? I couldn't find anything about it in the documentation of vmap.
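For reference, the working version of the batched part looks roughly like this (just a sketch of the swap described above; as far as I can tell, torch.func.grad differentiates with respect to its first positional argument by default, i.e. argnums=0, which would explain why the gradient previously came out with the state's shape rather than the actions' shape):

        def compute_loss(actions, state):
            # actions now comes first so that grad differentiates with respect to it
            state = state.unsqueeze(dim=1)
            actions = actions.unsqueeze(dim=1)
            _, costs = model(state, actions)
            loss = -costs.sum(dim=0)[0]
            return loss

        # vmap still maps over the candidate dimension (dim 1) of both tensors;
        # the result has shape [candidates x horizon x action_size], so the
        # subsequent permute(1, 0, 2) yields [horizon x candidates x action_size]
        ft_per_candidate_grads = vmap(grad(compute_loss), in_dims=(1, 1), randomness="same")(
            actions, state
        )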