Code for ParametrizedConv1d in transformer code for wav2vec2 for ctc

Hi i want to write code for


i have written the code , but for the weight file i’m getting keys as :

conv.conv.bias
conv.conv.weight_g
conv.conv.weight_v
conv.parametrizations.weight.bias
conv.parametrizations.weight.weight_g
conv.parametrizations.weight.weight_v

but i need only first 3 keys of the saved weight in code also.
what changes do i need to make
{
conv.conv.bias
conv.conv.weight_g
conv.conv.weight_v
}

import torch
import torch.nn as nn

class _WeightNorm(nn.Module):
def init(self):
super(_WeightNorm, self).init()

def forward(self, x):
    return x

class ParametrizedConv1d(nn.Module):
def init(self, in_channels, out_channels, kernel_size, stride, padding, groups):
super(ParametrizedConv1d, self).init()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
self.parametrizations = nn.ModuleDict({‘weight’: nn.utils.weight_norm(self.conv, name=‘weight’)})

def forward(self, x):
    return self.conv(x)

class ConvolutionalPositionalEmbedding(nn.Module):
def init(self, input_dim=768, output_dim=768, kernel_size=128, stride=1, padding=64, groups=16):
super(ConvolutionalPositionalEmbedding, self).init()
self.conv = ParametrizedConv1d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)

def forward(self, x):
    return self.conv(x)