Cannot convert self-defined conv2d to onnx

Hi,

I am using torch1.6 from conda on linux platform.
My code is like this:

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2dWS(nn.Conv2d):

    def __init__(self,
                 in_chan,
                 out_chan,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5
                ):
        super(Conv2dWS, self).__init__(
            in_chan,
            out_chan,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.std_eps = eps

    def forward(self, x):
        weight = self.get_weight() if self.training else self.weight.detach()
        return F.conv2d(x,
                        weight,
                        self.bias,
                        self.stride,
                        self.padding,
                        self.dilation,
                        self.groups)

    def get_weight(self):
        N, _, _, _ = self.weight.size()
        weight = self.weight
        mean = weight.mean(dim=(1, 2, 3), keepdim=True)
        weight = weight - mean
        std = weight.std(dim=(1, 2, 3), keepdim=True) + self.std_eps
        weight = torch.div(weight, std)
        return weight

    @torch.no_grad()
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        for name, param in self._parameters.items():
            if param is not None:
                if name == 'weight':
                    param = self.get_weight().detach()
                destination[prefix + name] = param if keep_vars else param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf if keep_vars else buf.data


net = Conv2dWS(32, 64, 3, 1, 1)
net.eval()

dummy = torch.randn(1, 32, 768, 768)
torch.onnx.export(net, dummy, 'model.on

The error message is:

Traceback (most recent call last):
  File "play.py", line 69, in <module>
    verbose=False, opset_version=11)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/onnx/__init__.py", line 208, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 530, in _export
    fixed_batch_size=fixed_batch_size)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 366, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/onnx/utils.py", line 319, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/jit/__init__.py", line 338, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/jit/__init__.py", line 426, in forward
    self._force_outplace,
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/jit/__init__.py", line 412, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 720, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 704, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "play.py", line 34, in forward
    weight = self.get_weight() if self.training else self.weight.detach()
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

How could I make it work please ?