MAML Implementation failing to adapt to new tasks

Hello!

I’m very new to Machine learning, I’ve created my own PPO implementation and am trying to implement MAML in the RL domain.

I am trying to replicate something like the following results shown in the original MAML paper:

However, I’ve been using MetaWorld as my testing environments, specifically ML10.

My task distribution was 50 environments with randomly assigned tasks from ML10.

Hyperparameters:

  • tasks per meta update = 5
  • metaLr = 1e-2
  • meta iterations = 100
  • inner trajectory length = 500 (1 episode)
  • innerLr = 5e-2
  • inner trajectories per gradient step = 20

After meta training, I’ve been validating by comparing a PPO agent using my meta-trained weights, to a freshly initialised PPO agent, on the ML10 test tasks and observing their average return after 0-4 Episodes->NetworkUpdates. I did 30 episodes of evaluation after each Network update.

Here are some of the results I’ve been getting (note the legend is wrong from - the red is the meta agent:

I’ve tried my hardest to work this out myself but I’m pretty stumped, can someone more knowledgeable than me spot the errors in implementation? I’ve been stuck for weeks trying to work this out.

class MAMLMetaLearner():
    def __init__(self, envs, tasks=5, metaLr=3e-3, metaSteps=30, innerTimeSteps=500, innerLr=1e-2, innerRollouts=20):

        self.envs = envs
        self.env = None
        self.metaLr = metaLr
        self.innerTimeSteps = innerTimeSteps
        self.innerLr = innerLr # inner learning rate is generally higher than the meta learning rate
        self.metaSteps = metaSteps
        self.tasks = tasks # number of tasks to train on per meta iteration
        self.innerRollouts = innerRollouts # number of rollouts for inner loop training

        self.metaAgent = PPO("MLP", self.envs[0], lrAnealling=False, progressBar=False) # meta agent
        self.metaOptimizer = torch.optim.Adam(self.metaAgent.actorCritic.parameters(), lr=metaLr)

        self.metaLoss = []

    # train the agent on a single task
    def ComputeTaskLoss(self, env):

        self.metaAgent.env = env # set the environment for inner loop learning
        
        # adapt the task agent parameters using the task model
        state = None

        # in order to not break the computational graph I will create a dictionary of the inner loop parameters
        adaptedParametersDictionary = {name: parameter for (name, parameter) in self.metaAgent.actorCritic.named_parameters()}

        for step in range(self.innerRollouts):

            state, states, actions, oldLogProbabilities, returns, advantages, _, _ = self.metaAgent.CollectTrajectories(totalTimesteps=self.innerTimeSteps, state=state, parameters = adaptedParametersDictionary, batchSize=self.innerTimeSteps) # collect trajectories using the task agent

            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / torch.max(advantages.std(), torch.tensor(1e-9, device=self.metaAgent.device))

            if self.metaAgent.discreteActionSpace:
                actionLogits, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
                newProbabilityDistribution = Categorical(logits=actionLogits) # using categorical is good because it provides functions to sample and calculate log probabilities
                newLogProbabilities = actionDistribution.log_prob(actions)
            else:
                mean, standardDeviation, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
                newProbabilityDistribution = torch.distributions.Normal(mean, standardDeviation)
                newLogProbabilities = (newProbabilityDistribution.log_prob(actions)).sum(dim=-1)

            probabilityRatio = torch.exp(newLogProbabilities - oldLogProbabilities) # this is a measure of how much the policy has changed from the old policy

            # compute loss
            loss, _, _, _ = self.metaAgent.ComputeLoss(values, returns, newProbabilityDistribution, probabilityRatio, advantages)

            self.metaAgent.optimizer.zero_grad() # zero the gradients

            # compute gradients with respect to adapted parameters
            gradients = torch.autograd.grad(loss, adaptedParametersDictionary.values(), create_graph=True)           

            # update the adapted parameters using the gradients
            adaptedParametersDictionary = {name: parameter - self.innerLr * gradient for ((name, parameter), gradient) in zip(adaptedParametersDictionary.items(), gradients)}


        # compute loss after adaptation
        _, states, actions, newLogProbabilities, returns, advantages, _, _ = self.metaAgent.CollectTrajectories(totalTimesteps=self.innerTimeSteps, parameters=adaptedParametersDictionary, batchSize=self.innerTimeSteps) # collect trajectories using the updated task agent

        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / torch.max(advantages.std(), torch.tensor(1e-9, device=self.metaAgent.device))

        if self.metaAgent.actorCritic.discreteActionSpace:
            actionLogits, _ = self.metaAgent.actorCritic(states)
            actionDistribution = Categorical(logits=actionLogits) # using categorical is good because it provides functions to sample and calculate log probabilities
            oldLogProbabilities = actionDistribution.log_prob(actions)
        else:
            mean, standardDeviation, _ = self.metaAgent.actorCritic(states)
            oldProbabilityDistribution = torch.distributions.Normal(mean, standardDeviation)
            oldLogProbabilities = (oldProbabilityDistribution.log_prob(actions)).sum(dim=-1)

        # calculate newProbabilityDistribution
        if self.metaAgent.actorCritic.discreteActionSpace:
            actionLogits, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
            newProbabilityDistribution = Categorical(logits=actionLogits)
        else:
            mean, standardDeviation, values = self.metaAgent.ForwardPass(states, parameters=adaptedParametersDictionary)
            newProbabilityDistribution = torch.distributions.Normal(mean, standardDeviation)
           
        probabilityRatio = torch.exp(newLogProbabilities - oldLogProbabilities) # this is a measure of how much the policy has changed from the original policy

        adaptedLoss, _, _, _ = self.metaAgent.ComputeLoss(values, returns, newProbabilityDistribution, probabilityRatio, advantages)

        return adaptedLoss

    def MetaTrain(self):
        print("Meta Training Started...")
        count = 0
        smallestLoss = -float("inf")
        for metaIteration in tqdm(range(self.metaSteps), desc="Meta Training Progress", unit = "iterations"):
            metaLoss = 0.0
            tasks = [random.choice(self.envs) for task in range(self.tasks)] # randomly select tasks from the environment list

            for task in tasks:
                taskLoss = self.ComputeTaskLoss(task) # inner loop learning
                metaLoss = metaLoss + taskLoss

            metaLoss = metaLoss / self.tasks
            
            # meta update
            self.metaOptimizer.zero_grad() # zero the gradients

            torch.nn.utils.clip_grad_norm_(self.metaAgent.actorCritic.parameters(), 0.5) # clip the gradients

            # apply the meta gradients
            metaGradients = torch.autograd.grad(metaLoss, self.metaAgent.actorCritic.parameters()) # compute gradients

            for parameter, gradient in zip(self.metaAgent.actorCritic.parameters(), metaGradients):
                parameter.grad = gradient

            self.metaOptimizer.step()

            count += 1
            print(f"Iteration {len(self.metaLoss)} , Loss: {metaLoss.item()}. {count} training iterations completed this session")

        print("Meta Training Complete")

The MAML class borrows functions from my PPO implementation, eg: CollectTrajectories, from my PPO implementation. The MAML class uses it for collecting rollouts

CollectTrajectories
def CollectTrajectories(self, totalTimesteps, progressBar=None, timesteps=0, state=None, parameters=None, batchSize=None):

        self.actorCritic.eval()  # Set the network to eval mode to disable dropout during rollout

        # sanity check
        if state is None:
            state, _info = self.env.reset()

        # if the batch size is not specified, use the default batch size
        if batchSize is None: 
            batchSize = self.batchSize // self.env.num_envs

        # if the environment is vectorised, we need to handle the batch tensors differently
        if self.env.is_vector_env:
            observationSpace = self.env.unwrapped.single_observation_space.shape
            actionSpace = self.env.unwrapped.single_action_space.shape
        else:
            observationSpace = self.env.observation_space.shape
            actionSpace = self.env.action_space.shape

        # tensors for each environment's trajectory (dimensions: [batch_size, num_envs, ...])
        states = torch.zeros((batchSize, self.env.num_envs, *observationSpace), dtype=torch.float, device=self.device)
        actions = torch.zeros((batchSize, self.env.num_envs, *actionSpace), dtype=torch.float, device=self.device)
        rewards = torch.zeros(batchSize, self.env.num_envs, dtype=torch.float, device=self.device)
        logProbabilities = torch.zeros(batchSize, self.env.num_envs, dtype=torch.float, device=self.device)
        values = torch.zeros(batchSize, self.env.num_envs, dtype=torch.float, device=self.device)
        dones = torch.zeros(batchSize, self.env.num_envs, dtype=torch.bool, device=self.device)

        if timesteps == 0:
            state, _info = self.env.reset() # [num_envs, [single_observation_space]]
            done = torch.zeros(self.env.num_envs).to(self.device)
            truncated = torch.zeros(self.env.num_envs).to(self.device)

        with torch.no_grad():
            # collect a batch of trajectories to train the network
            for step in range(0, (batchSize)):

                timesteps += 1 * self.env.num_envs

                state = torch.tensor(state, dtype=torch.float, device=self.device)
                states[step] = state
                
                if self.discreteActionSpace:
                    actionLogits, value = self.ForwardPass(state, parameters)
                    actionDistribution = Categorical(logits=actionLogits)
                    action = actionDistribution.sample()
                    if not self.env.is_vector_env:
                        action = action.squeeze()  # a single environment would expect a value without the extra dimension
                    logProbability = actionDistribution.log_prob(action)
                else:
                    mean, standardDeviation, value = self.ForwardPass(state, parameters)
                    actionDistribution = torch.distributions.Normal(mean, standardDeviation)
                    action = actionDistribution.sample()          
                    logProbability = (actionDistribution.log_prob(action)).sum(dim=-1)

                actions[step] = action

                nextState, reward, done, truncated, info = self.env.step(action.cpu().numpy())
        
                done = torch.tensor(done, dtype=torch.bool, device=self.device)
                truncated = torch.tensor(truncated, dtype=torch.bool, device=self.device)

                dones[step] = torch.logical_or(done, truncated)
                rewards[step] = torch.tensor(reward, dtype=torch.float, device=self.device)
                values[step] = value.squeeze() # remove the extra dimension
                logProbabilities[step] = logProbability

                state = nextState               

                # episode termination handling
                for i in range(self.env.num_envs):
                    if info:
                        if self.env.is_vector_env:
                            if "_episode" in info: # check if the key "_episode" exists
                                if info["_episode"][i]: # if THE episode has ended
                                    episodeMetrics = info.get('episode', {}) # get the metrics of the episode that ended
                                    self.episodeLengths.append(episodeMetrics.get('l', [0])[i])
                                    self.totalEpisodeRewards.append(episodeMetrics.get('r', [0])[i]) 

                        elif (done or truncated):
                            episodeMetrics = info.get('episode')
                            self.episodeLengths.append(episodeMetrics.get('l', 0).item())
                            self.totalEpisodeRewards.append(episodeMetrics.get('r', 0).item())
                            state, _info = self.env.reset() # single environments are not automatically reset

                # update the progress bar
                if self.progressBar:
                    progressBar.update(self.env.num_envs)

                if timesteps >= totalTimesteps: # if the last batch exceeds the total timesteps end early
                    break

            # we need to know the value of the next state for the last state in the batch and if it was terminal
            with torch.no_grad():
                if self.discreteActionSpace:
                    _actionLogits, lastNextValues = self.actorCritic(torch.tensor(state, dtype=torch.float, device=self.device)) # get the value of the next state
                else:
                    _mean, _standardDeviation, lastNextValues = self.actorCritic(torch.tensor(state, dtype=torch.float, device=self.device))
                lastNextValues = lastNextValues.squeeze()

            if self.profiling:
                self.logMemoryUsage("after_batch/before_advantages&returns")
            # calculate the advantages for the batch
            returns = self.GetReturns(rewards, dones, lastNextValues)
            advantages = self.GetAdvantages(rewards, values, dones, lastNextValues)

            # flatten the trajectories
            states = states.view(-1, *states.shape[2:])
            logProbabilities = logProbabilities.view(-1)
            returns = returns.view(-1)
            advantages = advantages.view(-1)
            if self.discreteActionSpace:
                actions = actions.view(-1)
            else:
                actions = actions.view(-1, actions.size(-1) if actions.ndim > 2 else 1)

        return state, states, actions, logProbabilities, returns, advantages, timesteps, progressBar

Here is my ForwardPass function, used to collect inner loop rollouts without breaking the computational graph. When I want to use the actual networks weights, i just don’t pass a parameter input.

ForwardPass
def ForwardPass(self, state, parameters):

        if parameters is None:
            return self.actorCritic(state)
        # if the parameters are not None, we are using the network for MAML and need to pass the adapted parameters to the forward function
        else:
            return func.functional_call(self.actorCritic, parameters, state)

The I calculate my loss in the MAML class with this ComputeLoss function.

ComputeLoss
def ComputeLoss(self, values, returns, probabilityDistribution, probabilityRatio, advantages):

        # calculate the value loss under the current critic network
        values = values.squeeze() # remove the extra dimension
        valueLoss = functional.smooth_l1_loss(returns, values)  # smooth l1 loss is less sensitive to outliers

        # calculate the entropy loss   
        entropyLoss = -probabilityDistribution.entropy().mean() 
       
        # calculate policy loss
        surrogate1 = probabilityRatio * advantages
        surrogate2 = torch.clamp(probabilityRatio, 1 - self.epsilon, 1 + self.epsilon) * advantages

        # PPO attempts to maximise the clipped objective function but ADAM minimises the loss so we add a negative sign
        policyLoss = -torch.min(surrogate1, surrogate2).mean()

        # combine all the losses
        loss = policyLoss + (valueLoss * self.valueLossCoef)  - (entropyLoss * self.entropyCoef) # the entropy loss is subtracted because we want to maximise entropy

        return loss, valueLoss, policyLoss , entropyLoss