Custom Backward Pass for Trilinear Interpolation

I am trying to implement an autograd module for trilinear interpolation from features on a hash grid.
Basically, I have several grids at different resolutions, and at each level I trilinearly interpolate between the features stored at the 8 corners of the voxel containing each query point. The feature grid is indexed through a hash function.

Specifically, I am interested in writing out the backward pass that computes the derivative of the output features w.r.t. the input coordinates.

The forward pass is as follows (spc_ops.points_to_corners expands each voxel's lower corner into its 8 integer corner coordinates, and PRIMES is a list of primes used by the spatial hash):

def forward(ctx, coords, resolutions, codebook_bitwidth, lod_idx, codebook):

    _, feature_dim = codebook[0].shape
    batch, num_samples, _ = coords.shape
    codebook_size = 2**codebook_bitwidth
    feats = []
    for i, res in enumerate(resolutions[:lod_idx+1]):
        tf_coords = torch.clip(((coords + 1.0) / 2.0) * res, 0, res-1-1e-5).reshape(-1, 3)
        cc000 = torch.floor(tf_coords).short()  # this kills the gradient, right?
        cc = spc_ops.points_to_corners(cc000).long()
        num_pts = res**3
        if num_pts > codebook_size:
            cidx = ((cc[...,0] * PRIMES[(i*3+0)%len(PRIMES)]) ^ \
                        (cc[...,1] * PRIMES[(i*3+1)%len(PRIMES)]) ^ \
                        (cc[...,2] * PRIMES[(i*3+2)%len(PRIMES)])) % codebook_size
        else:
              cidx = cc[...,0] + cc[...,1] * res + cc[...,2] * res * res
        fs = codebook[i][cidx]

        num = coords.size(0) * coords.size(1)
        coeffs = torch.zeros(num, 8, device=coords.device, dtype=coords.dtype)
        x = tf_coords - cc000
        _x = 1.0 - x
        coeffs[...,0] = _x[...,0] * _x[...,1] * _x[...,2]
        coeffs[...,1] = _x[...,0] * _x[...,1] * x[...,2]
        coeffs[...,2] = _x[...,0] * x[...,1] * _x[...,2]
        coeffs[...,3] = _x[...,0] * x[...,1] * x[...,2]
        coeffs[...,4] = x[...,0] * _x[...,1] * _x[...,2]
        coeffs[...,5] = x[...,0] * _x[...,1] * x[...,2]
        coeffs[...,6] = x[...,0] * x[...,1] * _x[...,2]
        coeffs[...,7] = x[...,0] * x[...,1] * x[...,2]
        coeffs = coeffs.reshape(-1, 8, 1)

        fs_coeffs = (fs * coeffs).sum(1)  # what is the derivative of this?

        feats.append(fs_coeffs)

    feats = torch.cat(feats, -1)  # what is the derivative of this?

    ctx.save_for_backward(coords)
    ctx.codebook = codebook
    ctx.resolutions = resolutions
    ctx.num_lods = len(resolutions)
    ctx.codebook_shapes = [_c.shape for _c in codebook]
    ctx.codebook_size = 2**codebook_bitwidth
    ctx.codebook_bitwidth = codebook_bitwidth
    ctx.feature_dim = codebook[0].shape[-1]

    return feats.reshape(batch, num_samples, -1)
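
To spell out the quantity I am after: assuming (as in my comments below) that the corner features do not depend on the coordinates, because the floor and the hash lookup are piecewise constant, the per-point, per-level gradient should be

grad_coords_level[k] = sum_c dcoeffs[k][c] * dot(g, fs[c])

where g is the upstream gradient for this level's features, fs[c] is the feature vector gathered at corner c, and dcoeffs[k][c] is the derivative of the c-th trilinear weight w.r.t. coordinate k (including the res / 2 factor from mapping coords to tf_coords). The dc0, ..., dc7 in the backward pass below are my hand-derived expressions for these dcoeffs.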

And my current implementation of the backward pass is

def backward(ctx, grad_output):
    # grad_output: batch x num_samples x (num_lods * feature_dim)
    coords = ctx.saved_tensors[0]
    codebook = ctx.codebook
    resolutions = ctx.resolutions
    num_lods = ctx.num_lods
    codebook_size = ctx.codebook_size
    feature_dim = ctx.feature_dim
    codebook_shapes = ctx.codebook_shapes
    codebook_bitwidth = ctx.codebook_bitwidth

    def indicator(x, a, b):
        return torch.gt(x, a - 1e-5).float() * torch.gt(-x, - (b + 1e-5)).float()

    grad_coords = torch.zeros_like(coords)   # B x 3
    res_grads = []
    for i, res in enumerate(resolutions[:num_lods]):
        dt = indicator((coords + 1) / 2 * res, 0, res-1) * res / 2  # B x 3
        dt0 = torch.cat([dt[...,0:1], torch.zeros_like(dt[...,1:3])], -1)  # derivative of tf_coords w.r.t. the coordinate at position 0
        dt1 = torch.cat([torch.zeros_like(dt[...,0:1]), dt[...,1:2], torch.zeros_like(dt[...,2:3])], -1)  # w.r.t. the coordinate at position 1
        dt2 = torch.cat([torch.zeros_like(dt[...,0:2]), dt[...,2:3]], -1)  # w.r.t. the coordinate at position 2

        # Compute some intermediate values as in the forward pass
        tf_coords = torch.clip(((coords + 1.0) / 2.0) * res, 0, res-1-1e-5).reshape(-1, 3)
        t0, t1, t2 = tf_coords[..., 0].unsqueeze(-1), tf_coords[..., 1].unsqueeze(-1), tf_coords[..., 2].unsqueeze(-1)
        cc000 = torch.floor(tf_coords).short()
        c0, c1, c2 = cc000[..., 0].unsqueeze(-1), cc000[..., 1].unsqueeze(-1), cc000[..., 2].unsqueeze(-1)
        cc = spc_ops.points_to_corners(cc000).long()

        num_pts = res**3
        if num_pts > codebook_size:
            cidx = ((cc[...,0] * PRIMES[(i*3+0)%len(PRIMES)]) ^ \
                        (cc[...,1] * PRIMES[(i*3+1)%len(PRIMES)]) ^ \
                        (cc[...,2] * PRIMES[(i*3+2)%len(PRIMES)])) % codebook_size
        else:
            cidx = cc[...,0] + cc[...,1] * res + cc[...,2] * res * res
        fs = codebook[i][cidx]  # B x feature_dim x 8

        # this looks like a mess, but I obtained it by hand and double-checked it; it should be correct.
        # dci: B x 3
        dc0 = - dt2 - dt1 + dt1*t2+dt2*t1 -c2*dt1 -c1*dt2 - dt0 + dt0*t2+dt2*t0 -c2*dt0 + dt0*t1+dt1*t0 \
                  - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) + c2*(dt0*t1 + dt1*t0) - c1*dt0 + c1*(dt0*t2 + dt2*t0) \
                  - c1*c2*dt0 - c0*dt2 - c0*dt1 + c0*(dt1*t2 + dt2*t1) - c0*c2*dt1 - c0*c1*dt2

        dc1 = + dt2 - (dt1*t2 + dt2*t1) + c2*dt1 + c1*dt2 - (dt0*t2 + dt2*t0) + c2*dt0 + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
                - c2*(dt0*t1 + dt1*t0) - c1*(dt0*t2 + dt2*t0) + c1*c2*dt0 + c0*dt2 - c0*(dt1*t2 + dt2*t1) + c0*c2*dt1 + c0*c1*dt2

        dc2 = + dt1 - (dt1*t2 + dt2*t1) + c2*dt1 + c1*dt2 - (dt0*t1 + dt1*t0) + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
                  - c2*(dt0*t1 + dt1*t0) + c1*dt0 - c1*(dt0*t2 + dt2*t0) + c1*c2*dt0 + c0*dt1 - c0*(dt1*t2 + dt2*t1) \
                  + c0*c2*dt1 + c0*c1*dt2

        dc3 = + (dt1*t2 + dt2*t1) - c2*dt1 - c1*dt2 - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
                  + c2*(dt0*t1 + dt1*t0) + c1*(dt0*t2 + dt2*t0) - c1*c2*dt0 + c0*(dt1*t2 + dt2*t1) \
                  - c0*c2*dt1 - c0*c1*dt2

        dc4 = + dt0 - (dt0*t2 + dt2*t0) + c2*dt0 - (dt0*t1 + dt1*t0) + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
                  - c2*(dt0*t1 + dt1*t0) + c1*dt0 - c1*(dt0*t2 + dt2*t0) + c1*c2*dt0 + c0*dt2 \
                  + c0*dt1 - c0*(dt1*t2 + dt2*t1) + c0*c2*dt1 + c0*c1*dt2

        dc5 = + (dt0*t2 + dt2*t0) - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) + c1*(dt0*t2 + dt2*t0) \
                  - c2*dt0 + c2*(dt0*t1 + dt1*t0) - c1*c2*dt0 - c0*dt2 + c0*(dt1*t2 + dt2*t1) - c0*c1*dt2 - c0*c2*dt1

        dc6 = + (dt0*t1 + dt1*t0) - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) + c2*(dt0*t1 + dt1*t0) - c1*dt0 \
                  + c1*(dt0*t2 + dt2*t0) - c1*c2*dt0 - c0*dt1 + c0*(dt1*t2 + dt2*t1) - c0*c2*dt1 - c0*c1*dt2

        dc7 = + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) - c2*(dt0*t1 + dt1*t0) - c1*(dt0*t2 + dt2*t0) \
                  + c1*c2*dt0 - c0*(dt1*t2 + dt2*t1) + c0*c2*dt1 + c0*c1*dt2

        dc = torch.stack([dc0, dc1, dc2, dc3, dc4, dc5, dc6, dc7], dim=-1)  # B x 3 x 8

        grad_output_res = grad_output[..., feature_dim*i:feature_dim*(i+1)].unsqueeze(-1)

        fs = fs.reshape(-1, feature_dim, 8)
        # in the forward, there is an elementwise product fs * coeffs. Is this how the gradient should be handled? fs does not depend on the coordinates (because of the floor operation), only coeffs does
        # here I multiply the upstream gradient by the features and aggregate over the dimension that was broadcasted in the forward (feature_dim)
        grad_output_res = (grad_output_res * fs).sum(-2).reshape(-1, 1, 8)

        fs_dc = grad_output_res * dc
        fs_dc = fs_dc.sum(-1)  # aggregate over the dim of size 8 (the corners of the grid voxel)

        res_grads.append(fs_dc)

    res_grads = torch.stack(res_grads, dim=-2)
    res_grads = res_grads.reshape(-1, num_lods, 3)

    grad_coords = res_grads.sum(-2)  # aggregate the gradients of the concatenated outputs
    
    # grad_codebook (the gradients w.r.t. the codebook tensors) is computed in code I have omitted here
    return (grad_coords, None, None, None, *grad_codebook)

This does not seem to be correct, as it produces gradients that are up to 8 orders of magnitude greater than the gradients automatically produced by PyTorch’s autograd engine.
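
For context, the comparison is run roughly like this (a sketch only; HashGridInterpolate stands for the autograd Function above, and forward_reference stands for the same forward math written as a plain function so that autograd differentiates it; both names are placeholders):

import torch

resolutions, bitwidth, feature_dim, lod_idx = [4, 8, 16], 4, 2, 2
coords = (torch.rand(1, 128, 3, dtype=torch.double) * 2 - 1).requires_grad_(True)
codebook = [torch.rand(2 ** bitwidth, feature_dim, dtype=torch.double) for _ in resolutions]

# gradient from the hand-written backward
out = HashGridInterpolate.apply(coords, resolutions, bitwidth, lod_idx, codebook)
out.sum().backward()
grad_custom = coords.grad.clone()

# reference gradient from autograd on the plain-PyTorch forward
coords.grad = None
forward_reference(coords, resolutions, bitwidth, lod_idx, codebook).sum().backward()
grad_autograd = coords.grad.clone()

print(grad_custom.abs().max().item(), grad_autograd.abs().max().item())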

For now you can safely assume that the derivatives dc0, dc1, …, dc7 of the coefficients coeffs w.r.t. the coordinates coords are correct.
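
If it helps, the coefficient derivatives can also be sanity-checked numerically in isolation; here is a rough, standalone finite-difference sketch (independent of the hash grid and not part of my actual code):

import torch

def trilinear_coeffs(tf_coords):
    # tf_coords: (N, 3) grid-space coordinates; returns the (N, 8) interpolation weights,
    # in the same corner order as coeffs in the forward pass
    x = tf_coords - torch.floor(tf_coords)
    _x = 1.0 - x
    return torch.stack([
        _x[:, 0] * _x[:, 1] * _x[:, 2],
        _x[:, 0] * _x[:, 1] *  x[:, 2],
        _x[:, 0] *  x[:, 1] * _x[:, 2],
        _x[:, 0] *  x[:, 1] *  x[:, 2],
         x[:, 0] * _x[:, 1] * _x[:, 2],
         x[:, 0] * _x[:, 1] *  x[:, 2],
         x[:, 0] *  x[:, 1] * _x[:, 2],
         x[:, 0] *  x[:, 1] *  x[:, 2],
    ], dim=-1)

res, eps = 16, 1e-5
coords = torch.rand(32, 3, dtype=torch.double) * 1.8 - 0.9  # stay away from the clip boundary
for k in range(3):
    offset = torch.zeros(1, 3, dtype=torch.double)
    offset[0, k] = eps
    plus = trilinear_coeffs((coords + offset + 1.0) / 2.0 * res)
    minus = trilinear_coeffs((coords - offset + 1.0) / 2.0 * res)
    fd = (plus - minus) / (2 * eps)  # (N, 8): numerical d coeffs / d coords[k]
    # compare fd against the k-th component of the analytic dc0, ..., dc7
    # (ignoring the rare points whose perturbation crosses a voxel face)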

I am mainly suspicious of the way I handle the broadcasting from the forward pass, where fs (feature_dim x 8) is multiplied elementwise with coeffs (1 x 8). In the backward I multiply each coefficient by its corresponding features in fs and then sum over the feature_dim dimension.
I am also not sure about how to handle the concatenation: should I just pass the gradient straight through it? And is it true that floor kills off the gradients? I have added comments in the code, but I am happy to clarify any details.
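
To make the broadcasting question concrete, here is the pattern from the forward pass in isolation (a toy check with made-up shapes, no hashing involved): the per-corner features are multiplied by per-corner weights carrying a trailing singleton dimension, and the result is summed over the corner dimension.

import torch

N, F = 4, 2
fs = torch.rand(N, 8, F, dtype=torch.double)                           # per-corner feature vectors
coeffs = torch.rand(N, 8, 1, dtype=torch.double, requires_grad=True)   # per-corner weights
out = (fs * coeffs).sum(1)                                             # (N, F), as in the forward pass
grad_out = torch.rand_like(out)
out.backward(grad_out)

# what I believe the gradient w.r.t. the weights should be: multiply the upstream
# gradient by the features and sum over the feature dimension
manual = (grad_out.unsqueeze(1) * fs).sum(-1, keepdim=True)            # (N, 8, 1)
print(torch.allclose(coeffs.grad, manual))                             # expected: True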

Again, the problem is not with PyTorch itself, but rather with my understanding of how these derivatives should be computed. Still, I am at a loss as to why the gradients produced by PyTorch are so different from the ones I am getting.