Task Description:
I've built a model that is trained in two stages. The first stage is a regular training loop, and the second is extra training for bias correction. During the second phase, I create a variable (self.rotated_weight in this case) inside forward() to correct the bias.
Question:
Can a variable created in forward() be accessed in subsequent training/inference, or will it be destroyed once the call to forward() completes?
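To make the question concrete, here is a stripped-down sketch of the pattern I mean (ToyModule and the dimensions are just placeholders, not my actual model):

import torch
import torch.nn as nn

class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # attribute assigned only inside forward(), never in __init__
        self.cached = self.weight @ x
        return self.cached

m = ToyModule()
out = m(torch.randn(4))
# this is what I am unsure about: is m.cached still around here,
# or does it disappear once forward() has returned?
print(hasattr(m, "cached"))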
Code:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class AddWeightProduct(nn.Module):
    def __init__(self, in_features, out_features, s=30, m=0.5, easy_margin=True):
        super(AddWeightProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # self.rotation = nn.Parameter(torch.eye(in_features, in_features))
        self.v1 = nn.ParameterList([nn.Parameter(torch.rand(1, in_features)) for i in range(out_features)])
        self.v2 = nn.ParameterList([nn.Parameter(torch.rand(1, in_features)) for i in range(out_features)])
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.rotation = False

    def torch_HT(self, VEC):
        # Householder reflection H = I - 2 * u * u^H for the direction VEC
        device = VEC.device
        u = torch.t(VEC)                        # (in_features, 1)
        uH = torch.adjoint(u)                   # (1, in_features)
        E = torch.eye(u.shape[0]).to(device)
        H = torch.sub(E, 2 * torch.mul(u, uH))  # broadcasting (n,1)*(1,n) yields the outer product
        return H

    def forward(self, input, label=None):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.rotation:
            templist = []
            for cls in range(self.out_features):
                flp_dir1 = F.normalize(self.v1[cls])
                H1 = self.torch_HT(flp_dir1)
                flp_dir2 = F.normalize(self.v2[cls])
                H2 = self.torch_HT(flp_dir2)
                H = F.linear(H1, H2)            # compose the two reflections (H2 is symmetric, so this is H1 @ H2)
                weight = torch.unsqueeze(self.weight[cls], dim=0)
                w_prim = F.linear(weight, H)
                w_rotate = torch.squeeze(w_prim)
                templist.append(w_rotate)
            # attribute created here, inside forward(), during the second training phase
            self.rotated_weight = torch.stack(templist, dim=0)
            cosine = F.linear(F.normalize(input), F.normalize(self.rotated_weight))
        else:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # trig identity: cos(A+B) = cos(A)cos(B) - sin(A)sin(B)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        output = phi * self.s
        return output
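
For reference, this is roughly how I plan to use the module after the second training phase (continuing from the code above); the feature size, class count, and batch size below are made up just for illustration:

metric_fc = AddWeightProduct(in_features=512, out_features=10)

# first stage: regular training with metric_fc.rotation left at False ...

# second stage: switch on the bias-correction path
metric_fc.rotation = True
logits = metric_fc(torch.randn(8, 512))

# the access I am unsure about: does rotated_weight survive the call?
print(metric_fc.rotated_weight.shape)           # expected: torch.Size([10, 512])
# I assume a plain tensor attribute is neither a Parameter nor a registered buffer,
# so it would not show up in state_dict() or be saved with the model
print("rotated_weight" in metric_fc.state_dict())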