Unable to export a PyTorch model with a dynamically changing **kernel** shape to ONNX

I have a PyTorch model that computes the correlation between a template image and a search image whose spatial shapes both change dynamically. For example:

The PyTorch model code:

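```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, template, search):
        # Cross-correlate the search image with the template, which is
        # passed to conv2d as the kernel (weight).
        out = torch.nn.functional.conv2d(search, template)

        return out
```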
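To make explicit what `Model.forward` computes, here is a small eager-mode check. The shapes are arbitrary illustrative values, not taken from the original setup. The template acts as a single-output-channel kernel, and the output's spatial size depends directly on the kernel size. This is why the exporter needs to know the template's shape.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (assumed, not from the original post):
# template is the conv weight: (out_channels=1, C, kh, kw)
# search is the conv input:    (batch=1, C, H, W)
C, kh, kw, H, W = 3, 4, 5, 16, 20
template = torch.randn(1, C, kh, kw)
search = torch.randn(1, C, H, W)

out = F.conv2d(search, template)
# Output spatial size is (H - kh + 1, W - kw + 1)
print(out.shape)  # torch.Size([1, 1, 13, 16])

# Each output value is the sum over all channels of window * template
# (cross-correlation, no kernel flip); checked here at position (2, 3).
expected = (search[0, :, 2:2 + kh, 3:3 + kw] * template[0]).sum()
print(torch.allclose(out[0, 0, 2, 3], expected, atol=1e-4))  # True
```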

The ONNX export code:

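```python
import torch

# Placeholder model, dummy inputs, and output path for illustration only.
model = Model().eval()
dummy_template = torch.randn(1, 3, 8, 8)
dummy_search = torch.randn(1, 3, 32, 32)
onnx_path = "model.onnx"

dummy_inputs = (dummy_template, dummy_search)
input_names = ["template", "search"]
output_names = ["outputs"]

# Template and search sizes are independent, so each gets its own symbolic
# dimension names; reusing "height"/"width" for both would declare them equal.
dynamic_axes = {
    "template": {
        2: "template_height",
        3: "template_width"
    },
    "search": {
        2: "search_height",
        3: "search_width"
    }
}

torch.onnx.export(model,
                  args=dummy_inputs,
                  f=onnx_path,
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes=dynamic_axes,
                  opset_version=11,
                  export_params=True)
```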

I get the following error: `RuntimeError: Unsupported: ONNX export of convolution for kernel of unknown shape.`

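This happens because I use `torch.nn.functional.conv2d` for the correlation, with the template passed as the kernel (weight). Since the template's height and width are marked as dynamic, the exporter cannot determine the kernel shape and fails with the error above.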

As a workaround, I tried implementing the correlation manually:

```python
import torch


def corr(input: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    # Per-channel sliding-window correlation; assumes batch size 1.
    # Unlike conv2d(search, template), this does not sum over channels.
    channel = input.shape[1]
    in_h, in_w = input.shape[-2:]
    kh, kw = kernel.shape[-2:]
    output_width = (in_w - kw) + 1
    output_height = (in_h - kh) + 1

    output = torch.zeros(1, channel, output_height, output_width,
                         device=input.device, dtype=input.dtype)

    for c in range(channel):
        for h in range(output_height):
            for w in range(output_width):
                input_window = input[0, c, h:h+kh, w:w+kw]
                kernel_window = kernel[0, c, :, :]
                # Named window_sum to avoid shadowing the function name `corr`
                window_sum = torch.sum(input_window * kernel_window)
                output[0, c, h, w] = window_sum

    return output
```
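Separately from the export problem, the manual loop does not compute the same thing as `Model.forward`. It correlates each channel with its own kernel slice and never sums over channels, which makes it a depthwise correlation. The eager-mode check below uses arbitrary illustrative shapes, and `corr_loop` is a condensed copy of the loop above. It shows that the loop matches `F.conv2d` with `groups=C`, and that summing its output over channels recovers the plain `conv2d` result.

```python
import torch
import torch.nn.functional as F


def corr_loop(input: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    # Condensed copy of the manual corr() above (batch size 1).
    channel = input.shape[1]
    in_h, in_w = input.shape[-2:]
    kh, kw = kernel.shape[-2:]
    output = torch.zeros(1, channel, in_h - kh + 1, in_w - kw + 1)
    for c in range(channel):
        for h in range(in_h - kh + 1):
            for w in range(in_w - kw + 1):
                window = input[0, c, h:h + kh, w:w + kw]
                output[0, c, h, w] = torch.sum(window * kernel[0, c])
    return output


search = torch.randn(1, 3, 10, 12)
template = torch.randn(1, 3, 4, 5)
C = search.shape[1]

manual = corr_loop(search, template)
# Depthwise (grouped) convolution: each channel uses its own (1, kh, kw) kernel.
depthwise = F.conv2d(search, template.view(C, 1, *template.shape[-2:]), groups=C)
# Plain conv2d, as in Model.forward, additionally sums over channels.
summed = F.conv2d(search, template)

print(torch.allclose(manual, depthwise, atol=1e-4))  # True
print(torch.allclose(manual.sum(dim=1, keepdim=True), summed, atol=1e-4))  # True
```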

However, the exported model gives incorrect results, possibly because ONNX doesn't support the slicing (`input[0, c, h:h+kh, w:w+kw]`).
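One plausible cause, offered as an assumption rather than a verified diagnosis: `torch.onnx.export` traces the model. Tracing replaces the Python `range(...)` loops with a fixed, unrolled sequence of operations sized for the dummy input. ONNX does have a `Slice` operator, so the slicing itself may not be at fault. The sketch below applies `torch.jit.trace` directly to the manual loop, using arbitrary illustrative shapes. Once the input shape differs from the one used while tracing, the traced version no longer agrees with eager execution.

```python
import warnings

import torch


def corr(input: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    # Manual correlation from the question (batch size 1).
    channel = input.shape[1]
    in_h, in_w = input.shape[-2:]
    kh, kw = kernel.shape[-2:]
    output = torch.zeros(1, channel, in_h - kh + 1, in_w - kw + 1)
    for c in range(channel):
        for h in range(in_h - kh + 1):
            for w in range(in_w - kw + 1):
                window = input[0, c, h:h + kh, w:w + kw]
                output[0, c, h, w] = torch.sum(window * kernel[0, c])
    return output


template = torch.randn(1, 2, 3, 3)
small_search = torch.randn(1, 2, 6, 6)  # shape seen while tracing (4x4 output)
large_search = torch.randn(1, 2, 9, 9)  # different shape later (7x7 output)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence TracerWarnings about Python ints
    traced = torch.jit.trace(corr, (small_search, template), check_trace=False)
    traced_out = traced(large_search, template)

eager_out = corr(large_search, template)
# The traced loops only visit the 4x4 positions recorded during tracing.
same = traced_out.shape == eager_out.shape and torch.allclose(traced_out, eager_out)
print(same)  # False
```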

Is there any way to export this correlation to ONNX while keeping the template shape dynamic?