PyTorch equivalent of noise adaptation layer Keras code


I have written a PyTorch version of a noise adaptation layer (Paper), which was originally written in Keras.
The Keras version is as follows:

```python
from __future__ import print_function
from keras.layers import Dense
from keras import backend as K

class Channel(Dense):

    def __init__(self, units=None, **kwargs):
        # no bias; softmax is the default activation
        kwargs['use_bias'] = False
        if 'activation' not in kwargs:
            kwargs['activation'] = 'softmax'
        super(Channel, self).__init__(units, **kwargs)

    def build(self, input_shape):
        # default to a square (input_dim x input_dim) kernel
        if self.units is None:
            self.units = input_shape[-1]
        super(Channel, self).build(input_shape)

    def call(self, x, mask=None):
        # softmax over the kernel turns it into a row-stochastic channel matrix
        channel_matrix = self.activation(self.kernel)
        return K.dot(x, channel_matrix)
```
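As a reference point, the computation in the Keras `call` can be reproduced in plain NumPy: `softmax` is applied to the square kernel along its last axis, which makes every row sum to 1, and the final matrix product with `x` mixes the baseline class probabilities. A minimal sketch, assuming `x` holds softmax outputs of a baseline classifier; `softmax_rows` and `channel_forward` are hypothetical helper names:

```python
import numpy as np


def softmax_rows(w):
    # Softmax along the last axis, which is what Keras's 'softmax' does for a 2-D tensor
    z = w - w.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def channel_forward(x, kernel):
    # x: (batch, n) baseline class probabilities; kernel: (n, n) unconstrained weights
    channel_matrix = softmax_rows(kernel)  # row-stochastic: each row sums to 1
    # out[b, j] = sum_i x[b, i] * channel_matrix[i, j]
    return x @ channel_matrix


rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3))
x = softmax_rows(rng.normal(size=(4, 3)))  # stand-in for baseline softmax outputs
out = channel_forward(x, kernel)
print(out.sum(axis=1))  # each entry is 1 up to floating-point error
```

Because both `x` and the channel matrix are row-stochastic, the layer's output is again a probability distribution over the (noisy) labels.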

The PyTorch equivalent I have written is as follows:

```python
import math

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class Channel(nn.Module):
    def __init__(self, input_dim, bias=False, *argv):
        super(Channel, self).__init__()
        self.input_dim = input_dim
        # for a 2-D weight, dim=1 normalizes each row
        self.activation = nn.Softmax(dim=1)
        if len(argv) == 0:
            # construct the proper layer as it is not initialized
            # from some previously learned models
            # (torch.Tensor(...) allocates uninitialized memory;
            # reset_parameters() is not called here)
            self.weight = Parameter(torch.Tensor(input_dim, input_dim))
            if bias:
                self.bias = Parameter(torch.Tensor(input_dim))
            else:
                self.register_parameter('bias', None)
        else:
            # use the pre-initialized weights
            self.weight = Parameter(argv[0])
            if bias:
                self.bias = Parameter(torch.zeros(input_dim))
            else:
                self.register_parameter('bias', None)

    def forward(self, x):
        # channel_matrix[i, j]: probability that baseline class i maps to noisy class j
        # (the bias, if any, is not used here)
        channel_matrix = self.activation(self.weight)
        return torch.matmul(x, channel_matrix)

    def reset_parameters(self):
        n = self.input_dim
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
```
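For comparison, a minimal sketch of a PyTorch layer that mirrors the Keras `call` more directly. One thing it makes explicit: in the PyTorch code in the question, `torch.Tensor(input_dim, input_dim)` allocates uninitialized memory and `reset_parameters()` is not called from `__init__`. The sketch assumes a fresh weight should start from Xavier/Glorot uniform values (the default `kernel_initializer` of a Keras `Dense` layer); the class name `NoiseChannel` and the `init_weight` argument are hypothetical:

```python
import torch
import torch.nn as nn


class NoiseChannel(nn.Module):
    """Hypothetical sketch of a PyTorch layer mirroring the Keras Channel layer.

    Assumption: a freshly created weight uses Xavier/Glorot uniform, the default
    kernel_initializer of a Keras Dense layer. There is no bias, matching
    use_bias=False in the Keras version.
    """

    def __init__(self, input_dim, init_weight=None):
        super().__init__()
        if init_weight is None:
            # torch.empty only allocates memory, so fill it explicitly
            weight = torch.empty(input_dim, input_dim)
            nn.init.xavier_uniform_(weight)
        else:
            # pre-initialized weights (treated as logits by the softmax below)
            weight = init_weight.detach().clone().float()
        self.weight = nn.Parameter(weight)

    def forward(self, x):
        # Row-wise softmax, like Keras's softmax over the last axis of the kernel
        channel_matrix = torch.softmax(self.weight, dim=1)
        # out[b, j] = sum_i x[b, i] * channel_matrix[i, j], the same product as the Keras call
        return x @ channel_matrix


torch.manual_seed(0)
layer = NoiseChannel(5)
probs = torch.softmax(torch.randn(8, 5), dim=1)  # stand-in baseline probabilities
out = layer(probs)
print(out.shape)  # torch.Size([8, 5])
```

Filling the weight inside `__init__` means no separate initialization call can be forgotten.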

But it is not behaving as expected. Is there any difference between the two scripts? Any info would be helpful. Thanks.
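On the normalization itself, the two versions agree: for a 2-D weight, `nn.Softmax(dim=1)` normalizes each row, which is the last axis that Keras's softmax uses. A consequence for the pre-initialized path (`argv[0]`) is that the passed tensor is treated as logits, so starting from a given row-stochastic matrix `P` means passing `log(P)`. A small check, where the 90%/5% channel below is a hypothetical example:

```python
import torch

# Hypothetical target channel: 90% of labels kept, the rest spread evenly
n = 3
P = torch.full((n, n), 0.05)
P.fill_diagonal_(0.9)  # each row sums to 0.9 + 2 * 0.05 = 1.0

init_weight = torch.log(P)  # what would be passed as the pre-initialized weight

# For a 2-D tensor, dim=1 and dim=-1 are the same axis (the one Keras normalizes)
row_softmax = torch.softmax(init_weight, dim=1)
assert torch.allclose(row_softmax, torch.softmax(init_weight, dim=-1))

# softmax(log P) along rows gives back P because each row of P already sums to 1
print(torch.allclose(row_softmax, P))  # True
```

If `P` contains exact zeros, `log(P)` becomes `-inf`; adding a small epsilon before the log avoids that.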

Hi @Yashas_Annadani,

Did you fix the code?

@Longlong_Jing Unfortunately, I have not been able to fix it yet. Do you have any clue as to why this might be wrong?
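A PyTorch-specific point that may be worth ruling out, assuming the model is trained with categorical cross-entropy on the channel's output as is usual in Keras: Keras's `categorical_crossentropy` expects probabilities by default, whereas `nn.CrossEntropyLoss` expects unnormalized logits and applies log-softmax itself. Since `Channel` returns probabilities, one option is `nn.NLLLoss` on their log. The tensors below are random stand-ins, not the original model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, n = 4, 3
logits = torch.randn(batch, n)                  # baseline classifier logits
baseline_probs = torch.softmax(logits, dim=1)   # baseline class probabilities
channel_matrix = torch.softmax(torch.randn(n, n), dim=1)
noisy_probs = baseline_probs @ channel_matrix   # channel output: probabilities, not logits
noisy_labels = torch.tensor([0, 2, 1, 0])

# The channel output is already a distribution, so take its log and use NLLLoss;
# the clamp guards against log(0)
loss = nn.NLLLoss()(torch.log(noisy_probs.clamp_min(1e-12)), noisy_labels)

# Explicit categorical cross-entropy on probabilities, for comparison
manual = -torch.log(noisy_probs[torch.arange(batch), noisy_labels]).mean()
print(torch.allclose(loss, manual))  # True
```

Passing `noisy_probs` straight into `nn.CrossEntropyLoss` would instead apply a second softmax to values that are already probabilities.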