TypeError: forward() takes 2 positional arguments but 3 were given in JIT trace (but works fine otherwise)

I have this model:

class PriorBox4JIT(torch.nn.Module):
    def __init__(self, cfg):
        super(PriorBox4JIT, self).__init__()
        self.cfg=cfg
        self.idf = torch.nn.Identity()

    def forward(self,image_size):
        self_min_sizes = self.cfg['min_sizes']
        self_steps = self.cfg['steps']
        self_clip = self.cfg['clip']
        self_image_size = image_size
        self_feature_maps = self_image_size[1].to(dtype=torch.int64).tolist()
        self_name = "s"
        anchors = []
        for k, f in enumerate(self_feature_maps):
            min_sizes = self_min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self_image_size[0][1].item()
                    s_ky = min_size / self_image_size[0][0].item()
                    dense_cx = [x * self_steps[k] / self_image_size[0][1].item() for x in [j + 0.5]]
                    dense_cy = [y * self_steps[k] / self_image_size[0][0].item() for y in [i + 0.5]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self_clip:
            output.clamp_(max=1, min=0)
        return output

I pass inputs in vanilla mode (i.e no tracing) like so:

        priorbox_jit = PriorBox4JIT(cfg)
        inp0=torch.Tensor([im_height, im_width])
        inp1=torch.Tensor(   [  [ceil(im_height/step), ceil(im_width/step)] for step in cfg['steps']    ]  )
        inputs=(inp0, inp1)

        priors_jit = priorbox_jit(  inputs )
        prior_data_jit = priors_jit.data
        print("prior_data_jit.shape", prior_data_jit.shape)

and this returns me what I expect: prior_data_jit.shape torch.Size([8142, 4])

However, when I try to trace the same model like so:

traced_script_module2 = torch.jit.trace(priorbox_jit.eval(), inputs )

I get thrown an error:

TypeError: forward() takes 2 positional arguments but 3 were given

I am confused why this is - it seems to do the forward() correctly when I do not traceā€¦

Any pointers would be great :frowning: