Help with derivatives inside loss function

I’m trying to include a physics-informed term in a DeepONet model in PyTorch. The model consists of two DNNs: the branch network takes a single input vector (the sensor values) and the trunk network takes 3 inputs (the coordinates x, y and t). The model then computes the dot product between the outputs of these two networks. The scalar output of the model is supposed to predict the value of a physical quantity at a certain point in space at a certain instant. The code below implements both DNNs (branch and trunk) and the model:

import torch
import torch.nn as nn
import torch.nn.functional as F

networkOutputDim = 100  # shared output dimension of branch and trunk (placeholder value; set elsewhere in my code)

class branchNet(nn.Module):
    """Branch network definition"""
    def __init__(self, inDim: int, nnDepth: int, nnWidth: int):
        super().__init__()
        # Input layer. Resizes input to desired network width
        self.inputLayer = nn.Linear(inDim, nnWidth)
        # intermediate dense layers. constant dimension
        self.MLPstack = nn.ModuleList([nn.Linear(nnWidth, nnWidth) for _ in range(nnDepth - 2)])
        # output layer. resizes network intermediate representation to networkOutputDim
        self.outputLayer = nn.Linear(nnWidth, networkOutputDim)

    def forward(self, x): # forward pass
        x = F.relu(self.inputLayer(x))
        for l in self.MLPstack:
            x = F.relu(l(x))
        return self.outputLayer(x)
    
class trunkNet(nn.Module):
    """Trunk network definition"""
    def __init__(self, nnDepth: int, nnWidth: int):
        super().__init__()
        # Input layer. Resizes input to desired network width.
        # Consider trunk network as receiving individual inputs
        # for each dimension (x, y, t)
        self.xCoord = nn.Linear(1, nnWidth)
        self.yCoord = nn.Linear(1, nnWidth)
        self.tCoord = nn.Linear(1, nnWidth)
        # intermediate dense layers. constant dimension
        self.MLPstack = nn.ModuleList([nn.Linear(nnWidth, nnWidth) for _ in range(nnDepth - 2)])
        # output layer. resizes network intermediate representation to networkOutputDim
        self.outputLayer = nn.Linear(nnWidth, networkOutputDim)

    def forward(self, x, y, t): # forward pass
        x = F.relu(self.xCoord(x))
        y = F.relu(self.yCoord(y))
        t = F.relu(self.tCoord(t))
        o = x + y + t
        for l in self.MLPstack:
            o = F.relu(l(o))
        return self.outputLayer(o)

class PI_deepONet(nn.Module):
    """Class for physics-informed DeepONet"""
    def __init__(self, branch: nn.Module, trunk: nn.Module):
        super().__init__()
        self.branch = branch
        self.trunk = trunk

    def forward(self, case: torch.Tensor, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor):
        # transpose trunk output and multiply both for
        # sample-by-sample dot product in diagonal of result
        return F.relu(torch.diagonal(
            torch.matmul(
                self.branch(case),
                self.trunk(x, y, t).T
            )
        ).unsqueeze(1))
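For what it’s worth, I believe the matmul-plus-diagonal is just a per-sample dot product between the branch and trunk outputs, so the forward could equivalently be written like this (a sketch, not what I’m actually running):

def forward(self, case, x, y, t):
    # elementwise product, then sum over the latent dimension: same
    # per-sample dot product, without building the full [batch, batch]
    # matrix that torch.matmul + torch.diagonal creates
    return F.relu((self.branch(case) * self.trunk(x, y, t)).sum(dim = 1, keepdim = True))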

The physics-informed term I want to include is the residual of the PDE, which is computed by differentiating the model’s output with respect to its inputs. The PDE in question is p_xx + p_yy - p_tt / c^2 = 0, where ‘_xx’ denotes the second derivative with respect to x, p is the output of the model, and c is a constant.
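Once I have p_xx, p_yy and p_tt, the physics term I plan to add to the loss would look roughly like this (just a sketch; c is the constant from the PDE):

# PDE residual at the collocation points; it should be driven towards zero
residual = p_xx + p_yy - p_tt / c ** 2
physicsLoss = (residual ** 2).mean()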

I already tried a bunch of things. Currently, I’ve been testing the gradients with the code below:

branchNetwork = branchNet(nSensor, branchDepth, branchWidth)
trunkNetwork = trunkNet(trunkDepth, trunkWidth)
model = PI_deepONet(branchNetwork, trunkNetwork)
# the coordinate inputs have to track gradients for autograd.grad
xBatch.requires_grad_(True)
yBatch.requires_grad_(True)
tBatch.requires_grad_(True)
outM = model(branchBatch, xBatch, yBatch, tBatch)
p_x = torch.autograd.grad(
    outM, xBatch, 
    grad_outputs = torch.ones_like(outM),
    retain_graph = True,
    create_graph = True
)[0]
p_xx = torch.autograd.grad(
    p_x, xBatch, 
    grad_outputs = torch.ones_like(p_x),
    retain_graph = True,
    create_graph = True
)[0]

But both gradients are entirely zero. With a batch of size 256, xBatch, yBatch and tBatch have shape [256, 1]; the branch network input is [256, 1000]; the model output is [256, 1]; and the gradients have shape [256, 1].

Why are my gradients zero? How do I fix this?

Check the outputs of your model and make sure they were not zeroed out by the final F.relu.
Here is a small example showing how the initialization could produce the zero gradients you are seeing:

# negative init
x = torch.randn(4, 4) - 10.
x.requires_grad_()

y = torch.diagonal(x)
z = F.relu(y)
z.mean().backward()
x.grad
# tensor([[0., 0., 0., 0.],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.]])

# positive init
x = torch.randn(4, 4) + 10.
x.requires_grad_()

y = torch.diagonal(x)
z = F.relu(y)
z.mean().backward()
x.grad
# tensor([[0.2500, 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.2500, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.2500, 0.0000],
#         [0.0000, 0.0000, 0.0000, 0.2500]])

It’s not using your full model, just the last few operations, but it shows the effect.
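A quick way to check this in your setup (a sketch using your variable names) would be something like:

outM = model(branchBatch, xBatch, yBatch, tBatch)
# fraction of outputs clamped to exactly zero by the last F.relu
print((outM == 0).float().mean())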

Thanks for the tip! Removing the ReLUs fixed the all-zero first gradient. But now the second gradient raises an error: RuntimeError: One of the differentiated Tensors appears to not have been used in the graph.

Do I need to somehow include the backward() method in these calculations?

I found an example here of computing second derivatives with .backward(), and tried this:

branchBatch, xBatch, yBatch, tBatch, targetBatch = next(iter(dataloader))
xBatch.requires_grad_(True)
yBatch.requires_grad_(True)
tBatch.requires_grad_(True)
xBatch.retain_grad()
yBatch.retain_grad()
tBatch.retain_grad()
model = PI_deepONet(branchNetwork, trunkNetwork)
outM = model(branchBatch, xBatch, yBatch, tBatch)
outM.sum().backward(create_graph = True)

# first derivatives, accumulated into .grad by the backward call above
p_x = xBatch.grad
p_y = yBatch.grad
p_t = tBatch.grad
# reset .grad before backpropagating through the first derivatives
xBatch.grad.data.zero_()
yBatch.grad.data.zero_()
tBatch.grad.data.zero_()
# backward again, hoping to accumulate the second derivatives into .grad
p_x.sum().backward(retain_graph = True)
p_y.sum().backward(retain_graph = True)
p_t.sum().backward(retain_graph = True)
print(xBatch.grad.abs().sum())
print(yBatch.grad.abs().sum())
print(tBatch.grad.abs().sum())

There are no problems with the first derivatives (they match the values returned by the previous implementation). But now the second gradients are zero. Any suggestions?

@ptrblck Sorry for the trouble, but do you have any more suggestions? I think I’m pretty close to the solution. Or do you know anyone else who could help?

I’m unsure what exactly you are trying to achieve, as it seems you are setting the gradients to zero and then calling .sum().backward() on these, so I would expect to see zero gradients.
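For the double differentiation itself, torch.autograd.grad with create_graph=True is usually a cleaner pattern than reusing .grad. A minimal, self-contained sketch on a toy function (not your model):

import torch

x = torch.linspace(-1., 1., 5, requires_grad = True)
y = (x ** 3).sum()
dy_dx = torch.autograd.grad(y, x, create_graph = True)[0]   # 3 * x**2
d2y_dx2 = torch.autograd.grad(dy_dx.sum(), x)[0]            # 6 * x
print(d2y_dx2)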