DDPG agent with convolutional layers for feature extraction

I’m trying to define the critic for a DDPG agent in PyTorch that uses a CNN as a feature extractor. This is pretty straightforward for the actor model, but I am not sure how to do it for the critic. Below is my code for the actor model.

from collections import OrderedDict

import torch.nn as nn


class actor(nn.Module):
    def __init__(self, obs_dim, act_dim, act_limit):
        super(actor, self).__init__()

        channels, height, width = obs_dim
        self.act_limit = act_limit

        self.actor = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=8, stride=4)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)),
            ('relu2', nn.ReLU()),
            ('conv3', nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)),
            ('relu3', nn.ReLU()),
            ('flatten', nn.Flatten()),
            # 3136 = 64 * 7 * 7, i.e. the conv output size for 84x84 inputs.
            ('linear1', nn.Linear(3136, 512)),
            ('relu4', nn.ReLU()),
            ('linear2', nn.Linear(512, act_dim))
        ]))

    def forward(self, obs):
        # Scale the network output by the action limit
        # (no tanh is applied, so the result is not bounded).
        return self.act_limit * self.actor(obs)

My confusion is about the concatenation of the action and observation vectors, and how to integrate that into my CNN model. A typical critic is an MLP that looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F


class critic(nn.Module):
    def __init__(self, obs_dim, act_dim, act_limit):
        super(critic, self).__init__()
        self.act_limit = act_limit

        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.out = nn.Linear(256, 1)

    def forward(self, x, actions):
        x = torch.cat([x, actions / self.act_limit], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        q_value = self.out(x)

        return q_value

Any ideas on how to define a critic with the same (convolutional) feature extractor?

Hi @andreasceid
Have a look at our TorchRL implementation: usually you would use a CNN to embed the observation and then concatenate the action with the embedded observation.
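As a rough illustration of that idea, here is a minimal sketch. It is not the TorchRL implementation; it just reuses the convolutional stack from the actor in the question. The names `CNNCritic`, `features`, `q_head` and the `hidden` parameter are made up for this example, and the 4x84x84 observation shape in the usage lines is an assumption.

```python
import torch
import torch.nn as nn


class CNNCritic(nn.Module):
    """Q(s, a) critic: a CNN embeds the image, then the action is concatenated."""

    def __init__(self, obs_dim, act_dim, act_limit, hidden=256):
        super().__init__()
        channels, height, width = obs_dim
        self.act_limit = act_limit

        # Same convolutional stack as the actor in the question.
        self.features = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened feature size with a dummy forward pass
        # instead of hard-coding 3136 (which only holds for 84x84 inputs).
        with torch.no_grad():
            dummy = torch.zeros(1, channels, height, width)
            feat_dim = self.features(dummy).shape[1]

        # The action enters here, after the convolutional embedding.
        self.q_head = nn.Sequential(
            nn.Linear(feat_dim + act_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs, actions):
        feats = self.features(obs)  # shape: (batch, feat_dim)
        # Normalize the action as in the MLP critic, then concatenate.
        x = torch.cat([feats, actions / self.act_limit], dim=1)
        return self.q_head(x)  # shape: (batch, 1)


# Usage with an assumed 4x84x84 observation and a 2-dimensional action.
critic = CNNCritic(obs_dim=(4, 84, 84), act_dim=2, act_limit=1.0)
q = critic(torch.zeros(8, 4, 84, 84), torch.zeros(8, 2))
print(q.shape)  # torch.Size([8, 1])
```

The actor and critic each get their own copy of the conv layers here. Sharing one feature extractor between them is also possible, but then you have to decide which loss is allowed to update it.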
Hope that helps!