How to handle for loops where the loop count is a matrix Tensor


I know that the title might not make much sense, so let me explain. I'm writing a differentiable renderer in PyTorch, much like PyTorch3D. I want to model how OpenGL and GLSL render shaders, which I first did using nested for loops over pixel positions, calling a shader function on each pixel. This was ridiculously slow in Python, so I switched over to a matrix-based approach, where all my pixel positions are pre-computed. This works very well and is about 1000 times faster. However, it now gives me a problem when writing a cloud noise shader. I'm using the concept of fractal Brownian motion, which adds octaves of noise in a for loop. I want the user to be able to control how many octaves are added through a parameter called detail. The problem is that, since switching to matrix-based rendering, the function is supplied with parameters that are matrices and contain the parameter values for each pixel. Let me show the code:

def shade_mat(self, scale: Tensor, detail: Tensor) -> Tensor:
    w, h = Shader.width, Shader.height
    uv = Shader.frag_pos[:, :, :2]  # Pre-calculated fragment coordinates (WxHx3)
    color = torch.tensor((0., 0., 0.)).repeat(w, h, 1)
    color = color + noise.fractalBrownianMotion(uv * scale, detail)

    return color  # This is the full rendered image


def fractalBrownianMotion(p: Tensor, detail: Tensor) -> Tensor:
    """
    Adds octaves of smooth noise to create "fractal brownian motion" or "fractal noise".
    :param p: 2D Tensor of coordinates
    :param detail: Integer 2D tensor dictating the number of octaves to add
    :return: A pseudo-random 2D noise Tensor
    """
    w, h = p.shape[0], p.shape[1]

    # Initial values
    value = torch.tensor((0.0)).repeat(w, h, 1)
    amplitude = torch.tensor((0.5)).repeat(w, h, 1)

    # Loop and add octaves of more and more detailed noise
    for i in range(4):  # detail is supposed to be used here but is a matrix...
        value = value + (amplitude * smoothNoise2D(p))
        p = p * 2.0
        amplitude = amplitude * 0.5

    return value

I need to use the detail matrix parameter in the for loop, so that pixel (0,0) uses detail[0,0,:] as its octave count, and likewise for every other pixel. I could do this if PyTorch had some sort of map function. Any idea how I could solve this?
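One possible direction, sketched here only to make the requirement concrete: run the loop up to the largest value in detail and weight each octave per pixel, so that pixels with a smaller detail value stop accumulating early. The names fbm_masked and smooth_noise_stub are hypothetical, the stub merely stands in for smoothNoise2D, and detail is assumed to have shape (W, H, 1).

```python
import torch
from torch import Tensor


def smooth_noise_stub(p: Tensor) -> Tensor:
    """Hypothetical stand-in for smoothNoise2D: maps (W, H, 2) to (W, H, 1)."""
    return 0.5 + 0.5 * torch.sin(p[..., :1] * 12.9898 + p[..., 1:2] * 78.233)


def fbm_masked(p: Tensor, detail: Tensor) -> Tensor:
    """Fractal noise where every pixel adds its own number of octaves.

    p: (W, H, 2) coordinates; detail: (W, H, 1) per-pixel octave counts (assumed shape).
    """
    detail = detail.to(p.dtype)
    value = torch.zeros_like(p[..., :1])
    amplitude = 0.5
    # Loop up to the largest octave count requested by any pixel.
    max_octaves = int(torch.ceil(detail.max()).item())
    for i in range(max_octaves):
        # Weight is 1 while i < detail and 0 once a pixel has used all its octaves.
        # For non-integer detail, the last octave is blended in fractionally.
        weight = (detail - i).clamp(min=0.0, max=1.0)
        value = value + weight * amplitude * smooth_noise_stub(p)
        p = p * 2.0
        amplitude *= 0.5
    return value
```

For integer-valued detail the weight is a plain 0/1 mask, so each pixel gets exactly detail octaves. Every pixel still pays for max_octaves noise evaluations, so this trades some wasted work for staying fully vectorized.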