I am trying to implement an autograd module for trilinear interpolation of features stored on a multiresolution hash grid.
Basically, I have several grids at different resolutions, and at each level I interpolate the features at the 8 corners of the cell surrounding each point. The feature grid is indexed through a hash function.
Specifically, I am interested in writing out the backward pass that computes the derivative w.r.t. the input coordinates.
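In symbols, at a single level: writing t = clip((coords + 1) / 2 * res), c = floor(t) and x = t - c, the interpolated feature is

f(coords) = Σ_{b ∈ {0,1}^3} w_b(x) · codebook[h(c + b)], with w_b(x) = Π_i (x_i if b_i = 1 else 1 - x_i),

so, if I understand correctly, the only differentiable path from coords to the output runs through the weights w_b, since the hashed corner index h(c + b) is piecewise constant in coords.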
The forward pass is
def forward(ctx, coords, resolutions, codebook_bitwidth, lod_idx, codebook):
    _, feature_dim = codebook[0].shape
    batch, num_samples, _ = coords.shape
    codebook_size = 2**codebook_bitwidth
    feats = []
    for i, res in enumerate(resolutions[:lod_idx+1]):
        # Map coords from [-1, 1] to grid space [0, res)
        tf_coords = torch.clip(((coords + 1.0) / 2.0) * res, 0, res-1-1e-5).reshape(-1, 3)
        cc000 = torch.floor(tf_coords).short()  # this kills the gradient, right?
        cc = spc_ops.points_to_corners(cc000).long()
        num_pts = res**3
        if num_pts > codebook_size:
            # Hash the 8 corner coordinates into the codebook
            cidx = ((cc[...,0] * PRIMES[(i*3+0)%len(PRIMES)]) ^ \
                    (cc[...,1] * PRIMES[(i*3+1)%len(PRIMES)]) ^ \
                    (cc[...,2] * PRIMES[(i*3+2)%len(PRIMES)])) % codebook_size
        else:
            # Grid is small enough: dense (collision-free) indexing
            cidx = cc[...,0] + cc[...,1] * res + cc[...,2] * res * res
        fs = codebook[i][cidx]
        num = coords.size(0) * coords.size(1)
        coeffs = torch.zeros(num, 8, device=coords.device, dtype=coords.dtype)
        x = tf_coords - cc000
        _x = 1.0 - x
        # Trilinear interpolation weights for the 8 corners
        coeffs[...,0] = _x[...,0] * _x[...,1] * _x[...,2]
        coeffs[...,1] = _x[...,0] * _x[...,1] * x[...,2]
        coeffs[...,2] = _x[...,0] * x[...,1] * _x[...,2]
        coeffs[...,3] = _x[...,0] * x[...,1] * x[...,2]
        coeffs[...,4] = x[...,0] * _x[...,1] * _x[...,2]
        coeffs[...,5] = x[...,0] * _x[...,1] * x[...,2]
        coeffs[...,6] = x[...,0] * x[...,1] * _x[...,2]
        coeffs[...,7] = x[...,0] * x[...,1] * x[...,2]
        coeffs = coeffs.reshape(-1, 8, 1)
        fs_coeffs = (fs * coeffs).sum(1)  # what is the derivative of this?
        feats.append(fs_coeffs)
    feats = torch.cat(feats, -1)  # what is the derivative of this?
    ctx.save_for_backward(coords)
    ctx.codebook = codebook
    ctx.resolutions = resolutions
    ctx.num_lods = len(resolutions)
    ctx.codebook_shapes = [_c.shape for _c in codebook]
    ctx.codebook_size = 2**codebook_bitwidth
    ctx.codebook_bitwidth = codebook_bitwidth
    ctx.feature_dim = codebook[0].shape[-1]
    return feats.reshape(batch, num_samples, -1)
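To make the question on the fs * coeffs line concrete, here is a minimal standalone check (made-up shapes, no hash grid involved) of what I believe the local derivative of (fs * coeffs).sum(1) w.r.t. coeffs should be, compared against autograd:

import torch

N, F = 5, 4                       # made-up batch and feature sizes
fs = torch.randn(N, 8, F)         # corner features; no grad flows into them (floor)
coeffs = torch.randn(N, 8, 1, requires_grad=True)

out = (fs * coeffs).sum(1)        # N x F, same pattern as the forward pass
grad_out = torch.randn(N, F)      # arbitrary upstream gradient
out.backward(grad_out)

# Hand rule: multiply the upstream gradient by fs and sum over the
# dimension that was broadcast in the forward (feature_dim)
manual = (grad_out.unsqueeze(1) * fs).sum(-1, keepdim=True)  # N x 8 x 1
print(torch.allclose(coeffs.grad, manual))  # True

If this rule is right, the broadcasting itself is probably not the culprit.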
And my current implementation of the backward pass is
def backward(ctx, grad_output):
    # grad_output: B x (num_lods * feature_dim)
    coords = ctx.saved_tensors[0]
    codebook = ctx.codebook
    resolutions = ctx.resolutions
    num_lods = ctx.num_lods
    codebook_size = ctx.codebook_size
    feature_dim = ctx.feature_dim
    codebook_shapes = ctx.codebook_shapes
    codebook_bitwidth = ctx.codebook_bitwidth

    def indicator(x, a, b):
        # 1 where a <= x <= b (within a small tolerance), 0 elsewhere
        return torch.gt(x, a - 1e-5).float() * torch.gt(-x, - (b + 1e-5)).float()

    grad_coords = torch.zeros_like(coords)  # B x 3
    res_grads = []
    for i, res in enumerate(resolutions[:num_lods]):
        # Derivative of tf_coords w.r.t. coords: res/2 inside the clip range, 0 outside
        dt = indicator((coords + 1) / 2 * res, 0, res-1) * res / 2  # B x 3
        dt0 = torch.cat([dt[...,0:1], torch.zeros_like(dt[...,1:3])], -1)  # derivative of tf_coords w.r.t. coordinate 0
        dt1 = torch.cat([torch.zeros_like(dt[...,0:1]), dt[...,1:2], torch.zeros_like(dt[...,2:3])], -1)  # w.r.t. coordinate 1
        dt2 = torch.cat([torch.zeros_like(dt[...,0:2]), dt[...,2:3]], -1)  # w.r.t. coordinate 2
        # Recompute some intermediate values as in the forward pass
        tf_coords = torch.clip(((coords + 1.0) / 2.0) * res, 0, res-1-1e-5).reshape(-1, 3)
        t0, t1, t2 = tf_coords[..., 0].unsqueeze(-1), tf_coords[..., 1].unsqueeze(-1), tf_coords[..., 2].unsqueeze(-1)
        cc000 = torch.floor(tf_coords).short()
        c0, c1, c2 = cc000[..., 0].unsqueeze(-1), cc000[..., 1].unsqueeze(-1), cc000[..., 2].unsqueeze(-1)
        cc = spc_ops.points_to_corners(cc000).long()
        num_pts = res**3
        if num_pts > codebook_size:
            cidx = ((cc[...,0] * PRIMES[(i*3+0)%len(PRIMES)]) ^ \
                    (cc[...,1] * PRIMES[(i*3+1)%len(PRIMES)]) ^ \
                    (cc[...,2] * PRIMES[(i*3+2)%len(PRIMES)])) % codebook_size
        else:
            cidx = cc[...,0] + cc[...,1] * res + cc[...,2] * res * res
        fs = codebook[i][cidx]  # B x feature_dim x 8
        # This looks like a mess, but I obtained it by hand and double-checked it. Should be correct.
        # dci: B x 3
        dc0 = - dt2 - dt1 + dt1*t2+dt2*t1 -c2*dt1 -c1*dt2 - dt0 + dt0*t2+dt2*t0 -c2*dt0 + dt0*t1+dt1*t0 \
              - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) + c2*(dt0*t1 + dt1*t0) - c1*dt0 + c1*(dt0*t2 + dt2*t0) \
              - c1*c2*dt0 - c0*dt2 - c0*dt1 + c0*(dt1*t2 + dt2*t1) - c0*c2*dt1 - c0*c1*dt2
        dc1 = + dt2 - (dt1*t2 + dt2*t1) + c2*dt1 + c1*dt2 - (dt0*t2 + dt2*t0) + c2*dt0 + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
              - c2*(dt0*t1 + dt1*t0) - c1*(dt0*t2 + dt2*t0) + c1*c2*dt0 + c0*dt2 - c0*(dt1*t2 + dt2*t1) + c0*c2*dt1 + c0*c1*dt2
        dc2 = + dt1 - (dt1*t2 + dt2*t1) + c2*dt1 + c1*dt2 - (dt0*t1 + dt1*t0) + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
              - c2*(dt0*t1 + dt1*t0) + c1*dt0 - c1*(dt0*t2 + dt2*t0) + c1*c2*dt0 + c0*dt1 - c0*(dt1*t2 + dt2*t1) \
              + c0*c2*dt1 + c0*c1*dt2
        dc3 = + (dt1*t2 + dt2*t1) - c2*dt1 - c1*dt2 - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
              + c2*(dt0*t1 + dt1*t0) + c1*(dt0*t2 + dt2*t0) - c1*c2*dt0 + c0*(dt1*t2 + dt2*t1) \
              - c0*c2*dt1 - c0*c1*dt2
        dc4 = + dt0 - (dt0*t2 + dt2*t0) + c2*dt0 - (dt0*t1 + dt1*t0) + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) \
              - c2*(dt0*t1 + dt1*t0) + c1*dt0 - c1*(dt0*t2 + dt2*t0) + c1*c2*dt0 + c0*dt2 \
              + c0*dt1 - c0*(dt1*t2 + dt2*t1) + c0*c2*dt1 + c0*c1*dt2
        dc5 = + (dt0*t2 + dt2*t0) - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) + c1*(dt0*t2 + dt2*t0) \
              - c2*dt0 + c2*(dt0*t1 + dt1*t0) - c1*c2*dt0 - c0*dt2 + c0*(dt1*t2 + dt2*t1) - c0*c1*dt2 - c0*c2*dt1
        dc6 = + (dt0*t1 + dt1*t0) - (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) + c2*(dt0*t1 + dt1*t0) - c1*dt0 \
              + c1*(dt0*t2 + dt2*t0) - c1*c2*dt0 - c0*dt1 + c0*(dt1*t2 + dt2*t1) - c0*c2*dt1 - c0*c1*dt2
        dc7 = + (dt0*t1*t2 + dt1*t0*t2 + dt2*t0*t1) - c2*(dt0*t1 + dt1*t0) - c1*(dt0*t2 + dt2*t0) \
              + c1*c2*dt0 - c0*(dt1*t2 + dt2*t1) + c0*c2*dt1 + c0*c1*dt2
        dc = torch.stack([dc0, dc1, dc2, dc3, dc4, dc5, dc6, dc7], dim=-1)  # B x 3 x 8
        grad_output_res = grad_output[..., feature_dim*i:feature_dim*(i+1)].unsqueeze(-1)
        fs = fs.reshape(-1, feature_dim, 8)
        # In the forward there is an elementwise product fs * coeffs. Is this how the gradient
        # should be handled? fs does not depend on the coordinates (because of the floor
        # operation), only coeffs does.
        # Here I multiply the upstream gradient by the features and aggregate over the
        # dimension that was broadcast in the forward (feature_dim).
        grad_output_res = (grad_output_res * fs).sum(-2).reshape(-1, 1, 8)
        fs_dc = grad_output_res * dc
        fs_dc = fs_dc.sum(-1)  # aggregate over the dim of size 8 (the corners of the grid voxel)
        res_grads.append(fs_dc)
    res_grads = torch.stack(res_grads, dim=-2)
    res_grads = res_grads.reshape(-1, num_lods, 3)
    grad_coords = res_grads.sum(-2)  # aggregate the gradients of the concatenated outputs
    return (grad_coords, None, None, None, *grad_codebook)
This does not seem to be correct, as it produces gradients that are up to 8 orders of magnitude greater than the gradients automatically produced by PyTorch’s autograd engine.
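For reference, this is roughly how I compare the two (a sketch; HashGridInterpolate stands for my autograd.Function above, and forward_reference is a hypothetical re-implementation of the same forward math using only differentiable PyTorch ops):

# Hypothetical names: HashGridInterpolate is the Function above,
# forward_reference is the same math written with plain PyTorch ops.
coords = (torch.rand(4, 16, 3) * 2 - 1).requires_grad_(True)

out_custom = HashGridInterpolate.apply(coords, resolutions, codebook_bitwidth, lod_idx, codebook)
g_custom, = torch.autograd.grad(out_custom.sum(), coords)

out_ref = forward_reference(coords, resolutions, codebook_bitwidth, lod_idx, codebook)
g_ref, = torch.autograd.grad(out_ref.sum(), coords)

print(g_custom.abs().max() / g_ref.abs().max())  # up to ~1e8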
For now you can safely assume that the derivatives dc0, dc1, …, dc7 of the coefficients coeffs w.r.t. the coordinates coords are correct.
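The underlying trilinear weight derivatives are in any case easy to spot-check against finite differences; e.g., for the last coefficient:

import torch

# Spot check of one coefficient: coeff7(t) = x0*x1*x2 with x = t - floor(t);
# away from integer boundaries, d coeff7 / d t0 = x1 * x2
def coeff7(t):
    x = t - torch.floor(t)
    return x[0] * x[1] * x[2]

t = torch.tensor([1.3, 2.6, 0.4], dtype=torch.float64)
eps = 1e-6
e0 = torch.tensor([eps, 0.0, 0.0], dtype=torch.float64)
num = (coeff7(t + e0) - coeff7(t - e0)) / (2 * eps)
x = t - torch.floor(t)
print(num.item(), (x[1] * x[2]).item())  # both ~0.24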
I am mainly suspicious of the way I am handling broadcasting in the forward pass in the multiplication of fs (feature_dim x 8) and coeffs (1 x 8). In the backward I am multiplying each coefficient with its corresponding features in fs and then summing over the feature_dim dimension.
I am not sure either about the way to handle concatenation: should I just pass the gradient backwards through it? Besides, is it true that floor kills off the gradients? I have added comments in the code, but I am happy to clarify any details.
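To check my assumptions, I ran this quick standalone experiment, which seems to confirm that floor has a zero gradient and that cat just splits the upstream gradient back to its inputs:

import torch

# floor: piecewise constant, so the gradient is zero (almost everywhere)
x = torch.tensor([0.3, 1.7], requires_grad=True)
torch.floor(x).sum().backward()
print(x.grad)  # tensor([0., 0.])

# cat: the upstream gradient is simply split back along the cat dimension
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 5, requires_grad=True)
out = torch.cat([a, b], dim=-1)
g = torch.randn_like(out)
out.backward(g)
print(torch.equal(a.grad, g[:, :3]), torch.equal(b.grad, g[:, 3:]))  # True True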
Again, the problem is not with PyTorch, but rather with my understanding of how derivatives should be computed. However, I am at a loss as to why the gradients produced by PyTorch are so different from those I am getting.