Hi i want to write code for
i have written the code , but for the weight file i’m getting keys as :
conv.conv.bias
conv.conv.weight_g
conv.conv.weight_v
conv.parametrizations.weight.bias
conv.parametrizations.weight.weight_g
conv.parametrizations.weight.weight_v
but i need only first 3 keys of the saved weight in code also.
what changes do i need to make
{
conv.conv.bias
conv.conv.weight_g
conv.conv.weight_v
}
import torch
import torch.nn as nnclass _WeightNorm(nn.Module):
def init(self):
super(_WeightNorm, self).init()def forward(self, x): return x
class ParametrizedConv1d(nn.Module):
def init(self, in_channels, out_channels, kernel_size, stride, padding, groups):
super(ParametrizedConv1d, self).init()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
self.parametrizations = nn.ModuleDict({‘weight’: nn.utils.weight_norm(self.conv, name=‘weight’)})def forward(self, x): return self.conv(x)
class ConvolutionalPositionalEmbedding(nn.Module):
def init(self, input_dim=768, output_dim=768, kernel_size=128, stride=1, padding=64, groups=16):
super(ConvolutionalPositionalEmbedding, self).init()
self.conv = ParametrizedConv1d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)def forward(self, x): return self.conv(x)