Brenier maps in PyTorch

I’m trying to train a Brenier map using PyTorch. A Brenier map is the gradient of a convex function, so I’m using an Input Convex Neural Network (ICNN) to implement the convex function. Here is my code.

class ICNN(nn.Module):
    def __init__(self, dim, layers):
        super(ICNN, self).__init__()
        self.dim = dim
        self.layers = layers
        self.Wbx = []
        for _ in range(0,layers-1):
            self.Wbx.append(nn.Linear(dim, dim, bias=True))
        self.Wbx.append(nn.Linear(dim, 1, bias=True))

        self.Wz = []
        for _ in range(1,layers-1):
            self.Wz.append(nn.Linear(dim, dim, bias=False))
        self.Wz.append(nn.Linear(dim, 1, bias=False))
        self.clip()

    def clip(self):
        for W in self.Wz:
            with torch.no_grad():
                W.weight[W.weight<0.0] = 0.0

    def forward(self, x):
        z = torch.pow(nn.functional.leaky_relu(self.Wbx[0](x), 0.2), 2.0)
        for i in range(1,self.layers):
            z = nn.functional.leaky_relu(self.Wz[i-1](z) + self.Wbx[i](x), 0.2)
        return z


class BRENIER(nn.Module):
    def __init__(self, dim, layers):
        super(BRENIER, self).__init__()
        self.dim = dim
        self.layers = layers
        self.icnn = ICNN(dim, layers)

    def clip(self):
        self.icnn.clip()
    
    def forward(self, x):
        z = self.icnn(x)
        dzdx = torch.autograd.grad(outputs=z, inputs=x, create_graph=True)[0]
        return dzdx

When processing a batch of data, I get the following error from the torch.autograd.grad call in forward:

“RuntimeError: grad can be implicitly created only for scalar outputs”

I understand that autograd.grad expects a scalar output, so how should I set this up so that automatic differentiation works on a batch of inputs?

You need to either reduce the output tensor to a scalar (e.g., by summing it) or pass grad_outputs explicitly with the same shape as the output tensor. Since each row of z depends only on the corresponding row of x, summing over the batch still gives you the correct per-sample gradients in a single call.
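For example, here is a minimal sketch of how your BRENIER.forward could be rewritten (assuming x is the batched input; note that x also has to require gradients, or autograd.grad fails with a different error):

    def forward(self, x):
        # The input itself must be differentiable; tensors coming
        # straight from a DataLoader usually are not.
        x = x.requires_grad_(True)
        z = self.icnn(x)  # shape (batch, 1)
        # Summing reduces z to a scalar. Each row of z depends only on
        # the matching row of x, so the gradient of the sum contains the
        # per-sample gradients. Equivalently, keep z un-reduced and pass
        # grad_outputs=torch.ones_like(z).
        dzdx = torch.autograd.grad(outputs=z.sum(), inputs=x,
                                   create_graph=True)[0]
        return dzdx

create_graph=True is kept so that the returned gradient can itself be differentiated during training.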

Besides that, your code has another issue: the linear layers are stored in plain Python lists, so they are never registered as submodules:

class ICNN(nn.Module):
    def __init__(self, dim, layers):
        super(ICNN, self).__init__()
        self.dim = dim
        self.layers = layers
        self.Wbx = []
        for _ in range(0,layers-1):
            self.Wbx.append(nn.Linear(dim, dim, bias=True))
        self.Wbx.append(nn.Linear(dim, 1, bias=True))

        self.Wz = []
        for _ in range(1,layers-1):
            self.Wz.append(nn.Linear(dim, dim, bias=False))
        self.Wz.append(nn.Linear(dim, 1, bias=False))

Use nn.ModuleList instead so that the linear layers are registered, their parameters show up in model.parameters(), and they move with the model on .to(device).
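A sketch of the relevant part of __init__ with nn.ModuleList, keeping your layer structure unchanged:

    # Registered containers: nn.ModuleList tracks each nn.Linear as a
    # submodule, unlike a plain Python list.
    self.Wbx = nn.ModuleList()
    for _ in range(layers - 1):
        self.Wbx.append(nn.Linear(dim, dim, bias=True))
    self.Wbx.append(nn.Linear(dim, 1, bias=True))

    self.Wz = nn.ModuleList()
    for _ in range(1, layers - 1):
        self.Wz.append(nn.Linear(dim, dim, bias=False))
    self.Wz.append(nn.Linear(dim, 1, bias=False))

With the layers registered, an optimizer built from model.parameters() will actually train them. Also remember to call clip() after every optimizer step, since keeping the Wz weights nonnegative is what makes the ICNN convex in its input.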