Gradients are None for some input variables

Hi!

I’m trying to use PyTorch in a somewhat unconventional way. My goal is to design a differentiable procedural texture rendering framework. The main rendering is done in OpenGL using GLSL shaders, and I then translate this GLSL code (and the 2D texture rendering process) into PyTorch. A user can then load a texture image, and the system is supposed to find the parameters for the selected shader that match the input image, using a loss and gradient descent. Currently I’m simply using an MSE loss, and I’ve gotten it to work fairly well with simple shaders.

My problem is that I’ve set up a “BrickShader” that is a bit more complex, and I’m not getting any gradients for my float parameters, only for my color variables. There’s quite a bit of code to show, so bear with me.

This is the essential part of my gradient descent iterations:

new_loss = self.loss_torch2(*params)
grads = torch.autograd.grad(outputs=new_loss, inputs=params, create_graph=True, retain_graph=True, allow_unused=False, only_inputs=True)

with torch.no_grad():
    for p, g in zip(params, grads):
        if g is not None:
            p -= lr * g

This is where it fails, as I get the error message RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

My shader parameters are stored in params and I need gradients for all of them. I’m assuming I’m doing something that breaks backpropagation for my float variables…
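
For reference, here is a tiny standalone snippet (nothing to do with my shader code) of the behavior I mean: if an input never influences the output, allow_unused=True lets the call succeed and simply returns None for that input.

import torch

# Minimal illustration, unrelated to the shader: `b` never reaches the output,
# so with allow_unused=True its gradient simply comes back as None.
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
out = a * 3.0

grads = torch.autograd.grad(out, [a, b], allow_unused=True)
print(grads)  # (tensor(3.), None)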

This is the main shading function:

def shade_torch(self, vert_pos: Tensor, mortar_scale: Tensor, brick_scale: Tensor, brick_elongate: Tensor, brick_shift: Tensor,
                color_brick: Tensor, color_mortar: Tensor) -> Tensor:

    scale = torch.stack([torch.div(brick_scale, brick_elongate + TINY_FLOAT), brick_scale, brick_scale])
    uv3 = self._brickTileTorch(vert_pos, scale, brick_shift)
    b = box(uv3[0:2], tensor((mortar_scale, mortar_scale)), tensor(0.0))
    frag_color = mix(color_mortar, color_brick, b)
    return frag_color

vert_pos is a pixel position (like an interpolated fragment position in GLSL) and does not require a gradient. The only gradients I get are for color_brick and color_mortar. I’m thinking that something may be wrong with _brickTileTorch, as I suspect it’s not a great idea to pick out individual elements of a tensor (but I don’t know how else to do it):

def _brickTileTorch(self, tile: Tensor, scale: Tensor, shift: Tensor):
    tx = tile[0] * scale[0]
    ty = tile[1] * scale[1]
    tz = tile[2] * scale[2]

    st: Tensor = step(1.0, torch.fmod(ty, 2.0))
    tx_shifted = tx + shift * st

    return fract(torch.stack([tx_shifted, ty, tz]))

I’ll post the code for the rest of the functions below, as well as the render function I use. Can anyone spot anything that’s not right? Much appreciated!

Loss and render function

def loss_torch2(self, *args):
    render = render_funcs.render_torch2(self.width, self.height, self.f, *args)
    return F.mse_loss(render, self.truth)  # self.truth is a Tensor containing the user's loaded image

def render_torch2(width: int, height: int, f: typing.Callable, *args):
    img = torch.zeros((4, width, height), device=device)
    x_pos, y_pos = _setup_coordinates(height, width)

    for row in range(height):
        for col in range(width):
            vert_pos = torch.tensor((x_pos[col], y_pos[row], 0.), dtype=torch.float32)
            val = f(vert_pos, *args)
            img[:, height - 1 - row, col] = val

    return img

Misc functions used in the shader

def box(coord: Tensor, size: Tensor, edge_smooth: Tensor) -> Tensor:
    """Returns 1.0 if this coordinate belongs to the box, 0.0 otherwise (with some interpolation based on edge_smooth)"""

    edge_smooth = torch.clamp(edge_smooth, 0.00001, 1.0)
    size = Tensor((0.5, 0.5)) - size * 0.5
    bx = smoothstep(size, size + edge_smooth, Tensor((1.0, 1.0)) - coord)
    bx *= smoothstep(size, size+edge_smooth, coord)
    return bx[0] * bx[1]

def step(edge: Tensor, x: Tensor) -> Tensor:
    """
    For element i of the return value, 0.0 is returned if x[i] < edge[i], and 1.0 is returned otherwise.
    """
    return (x >= edge).float()

def fract(x: Tensor) -> Tensor:
    return x - torch.floor(x)

def mix(x: Tensor, y: Tensor, a: Tensor) -> Tensor:
    """Performs a linear interpolation between x and y using a to weight between them.
    A weight of 1.0 returns y and a weight of 0.0 returns x.
    """
    return x * (1.0 - a) + y * a

def smoothstep(edge0: Tensor, edge1: Tensor, x: Tensor) -> Tensor:
    edgediff = (edge1 - edge0) + TINY_FLOAT
    t = torch.clamp((x - edge0) / edgediff, 0.0, 1.0)
    return t * t * (3.0 - 2.0 * t)

If you are worried about breaking the computation graph, you could add debug print statements to the code and check all intermediate tensors for a valid .grad_fn.
If you’ve accidentally detached a tensor, its grad_fn will be None; otherwise it will show the backward version of the last operation that produced it.
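
Something along these lines, just as a standalone sketch (not your shader code), assuming a made-up parameter tensor p:

import torch

# Standalone sketch of the grad_fn check, not the BrickShader code itself.
p = torch.tensor([0.3, 0.7], requires_grad=True)

indexed = p[0] * 2.0                 # plain indexing keeps the graph
stacked = torch.stack([p[0], p[1]])  # so does torch.stack
rebuilt = torch.tensor([0.5, 0.5]) - p.detach()  # detached -> no history
stepped = (p >= 0.5).float()         # hard comparisons carry no gradient

print(indexed.grad_fn)  # e.g. <MulBackward0 ...>
print(stacked.grad_fn)  # e.g. <StackBackward0 ...>
print(rebuilt.grad_fn)  # None -> gradient flow is broken at this point
print(stepped.grad_fn)  # None -> no gradient through a step/threshold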