Weight update for a new layer

(Jaideep Valani)

Hi all,
I have defined an `nn.Module` as shown below. Because I want to normalize the weights before they are multiplied with the normalized input features, I reassign the last layer's weights with their normalized version every time. Assume that `fc1.weight` is initialized to some desired values when this module is appended to the main module, which is a DenseNet.
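For reference, the normalize-then-multiply step can also be written without reassigning any weights, by normalizing the stored parameter inside `forward` (the same idea as the commented-out `F.linear` line in the module). This is a minimal standalone sketch assuming plain PyTorch; the class name `CosineLinear` and the Xavier initialization are illustrative choices, not taken from the post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineLinear(nn.Module):
    """Hypothetical layer: cosine similarity between inputs and class weights."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Raw (unnormalized) weights; this is the tensor the optimizer updates.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        # F.normalize defaults to p=2, dim=1: each input row and each class
        # weight row is scaled to unit length, so the product is a cosine.
        return F.linear(F.normalize(x), F.normalize(self.weight))


layer = CosineLinear(1024, 5004)
cosine = layer(torch.randn(8, 1024))
print(cosine.shape)  # torch.Size([8, 5004])
```

Because `F.normalize` is part of the autograd graph, the gradient flows back to the raw `self.weight`, and every forward pass re-normalizes whatever values the optimizer last wrote.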

My question: will `fc1.weight` get updated at the end of each batch, so that the updated weights are the ones that get normalized and reassigned each time?
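Whether the stored weight changes can be checked directly on a toy tensor. This sketch assumes a hypothetical 3-class, 4-feature layer trained with plain SGD. It shows that a gradient reaches the raw parameter through `F.normalize`, and it also shows the alternative of writing normalized values back into the parameter after the optimizer step, which is closer to the reassignment described above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical raw weight: 3 classes, 4 features.
weight = torch.nn.Parameter(torch.randn(3, 4))
opt = torch.optim.SGD([weight], lr=0.1)

x = torch.randn(2, 4)
target = torch.tensor([0, 2])
before = weight.detach().clone()

# Normalize inside the forward computation; autograd tracks F.normalize.
logits = F.linear(F.normalize(x), F.normalize(weight))
loss = F.cross_entropy(logits, target)
loss.backward()  # the gradient reaches the raw, unnormalized parameter
opt.step()       # the optimizer updates that raw parameter

grad_reached = weight.grad is not None
changed = not torch.equal(before, weight.detach())

# Alternative closer to "reassigning": project the stored rows back to unit
# norm after each step, outside autograd.
with torch.no_grad():
    weight.copy_(F.normalize(weight))

row_norms = weight.norm(dim=1)
print(grad_reached, changed, row_norms)
```

The first approach keeps normalization differentiable; the second treats it as a projection applied after each update, so the loss never sees how the normalization affects the gradient.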

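```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as m  # assumption: `m` was an alias for the model zoo
# Assumption: the fastai v1 helpers used below
from fastai.vision.learner import create_body, create_head
from fastai.callbacks.hooks import num_features_model


class ArcMarginProduct(nn.Module):
    r"""Implementation of large margin arc distance:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
    """
    def __init__(self, in_features, out_features=5004):
        super(ArcMarginProduct, self).__init__()
        # torch.FloatTensor(...) allocates uninitialized memory
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        # nn.init.xavier_uniform_(self.weight)
        body = create_body(m.densenet121, True, -1)
        # body = create_head(ArcMarginProduct, pretrained, 0)
        nf = num_features_model(nn.Sequential(*body.children())) * 2
        self.head = create_head(nf, 5004, [1024], ps=0.5, bn_final=False)[:-1]
        self.fc1 = nn.Linear(1024, 5004, bias=False)  # this gets input of 1024 from last layer

    def forward(self, x):
        # Only the input is normalized here; fc1.weight is used as-is.
        cosine = self.fc1(F.normalize(x))
        # Alternative: normalize both inputs and weights on every forward pass
        # F.linear(F.normalize(x), F.normalize(self.weight.cuda()))
        return cosine
```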
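The docstring describes an additive angular margin, cos(θ + m), scaled by s, but the `forward` shown above does not apply the margin. A minimal sketch of how it could be applied to a batch of cosines, assuming integer class labels are available at training time; the helper name `arc_margin_logits` and the defaults s = 30.0 and m = 0.50 are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def arc_margin_logits(cosine, labels, s=30.0, m=0.50):
    """Hypothetical helper: add angular margin m to the target class, scale by s.

    cosine: (batch, classes) cosine similarities from a normalized linear layer.
    labels: (batch,) integer class indices.
    """
    # Clamp slightly inside [-1, 1] so acos stays finite and differentiable.
    theta = torch.acos(cosine.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    with_margin = torch.cos(theta + m)  # cos(theta + m) for every class
    one_hot = F.one_hot(labels, num_classes=cosine.size(1)).to(cosine.dtype)
    # Only the ground-truth class receives the margin; others keep cos(theta).
    logits = one_hot * with_margin + (1.0 - one_hot) * cosine
    return s * logits


cosine = torch.tensor([[0.8, 0.1], [0.2, 0.6]])
labels = torch.tensor([0, 1])
logits = arc_margin_logits(cosine, labels)
```

The resulting logits would typically go to `F.cross_entropy`. This sketch omits the extra handling that common ArcFace implementations add for θ + m > π, where cos(θ + m) no longer decreases as θ grows.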

When this module is added, its summary looks like this (the head of the DenseNet):

```
  (1): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=(1, 1))
    (mp): AdaptiveMaxPool2d(output_size=(1, 1))
  )
  (2): Flatten()
  (3): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): Dropout(p=0.5)
  (5): Linear(in_features=2048, out_features=1024, bias=True)
  (6): ReLU()
  (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): Dropout(p=0.5)
  (9): Linear(in_features=1024, out_features=5004, bias=True)
```
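The summary also explains the `* 2` in `nf`: concat pooling stacks max- and average-pooled features, so DenseNet-121's 1024 feature channels become the 2048 inputs of the first `BatchNorm1d`, and after entry (8) the head emits the 1024 features that `fc1` expects. The sketch below mirrors entries (1) through (8) in plain PyTorch; `ConcatPool2d` is a hypothetical stand-in for fastai's `AdaptiveConcatPool2d`, `nn.Flatten` replaces fastai's `Flatten`, and the 7x7 feature map assumes 224x224 inputs.

```python
import torch
import torch.nn as nn


class ConcatPool2d(nn.Module):
    """Hypothetical sketch of concat pooling: adaptive max and avg, stacked."""

    def __init__(self, size=1):
        super().__init__()
        self.ap = nn.AdaptiveAvgPool2d(size)
        self.mp = nn.AdaptiveMaxPool2d(size)

    def forward(self, x):
        # Concatenating along the channel dim doubles the channel count.
        return torch.cat([self.mp(x), self.ap(x)], dim=1)


# DenseNet-121's final feature map has 1024 channels.
feats = torch.randn(2, 1024, 7, 7)
head = nn.Sequential(
    ConcatPool2d(),          # (1): 1024 -> 2048 channels, 1x1 spatial
    nn.Flatten(),            # (2): (2, 2048, 1, 1) -> (2, 2048)
    nn.BatchNorm1d(2048),    # (3)
    nn.Dropout(0.5),         # (4)
    nn.Linear(2048, 1024),   # (5)
    nn.ReLU(),               # (6)
    nn.BatchNorm1d(1024),    # (7)
    nn.Dropout(0.5),         # (8)
)
out = head(feats)
print(out.shape)  # torch.Size([2, 1024])
```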