import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
# Print tensors with 10 decimal places (debug aid for numerical issues).
torch.set_printoptions(precision=10)
class LargeMarginSoftmaxLinear(nn.Module):
    """Linear layer with a large-margin softmax (L-Softmax) target logit.

    For the target class y, the plain logit x.w_y is replaced by
    ||x|| * ||w_y|| * psi(theta), where psi(theta) = (-1)^k cos(m*theta) - 2k
    and theta is the angle between x and w_y; `_lambda` anneals the result
    toward the plain inner product ("Large-Margin Softmax Loss for
    Convolutional Neural Networks", Liu et al., 2016).
    """

    def __init__(self, in_features, out_features, m, _lambda, use_cuda, device):
        """
        Args:
            in_features: size of each input sample.
            out_features: number of classes.
            m: integer margin multiplier (m >= 1).
            _lambda: annealing weight; larger values keep the output closer
                to the plain logit.
            use_cuda: kept for interface compatibility; lookup tables are
                moved to `device` regardless ('cpu' is a valid device).
            device: torch device (or device string) the layer runs on.
        """
        super(LargeMarginSoftmaxLinear, self).__init__()
        self.w = Parameter(torch.zeros(out_features, in_features).to(device))
        self.reset_parameters()
        self.m = int(m)
        self._lambda = float(_lambda)
        # Lookup tables: binomial coefficients C(m, i) for the cos(m*theta)
        # expansion, and interval boundaries cos(i*pi/m) for determine_k.
        self.m_choose_n_map = torch.zeros(self.m + 1)
        self.k_map = torch.zeros(self.m + 1)
        for i in range(self.m + 1):
            self.m_choose_n_map[i] = math.factorial(self.m) / math.factorial(i) / math.factorial(self.m - i)
            self.k_map[i] = math.cos(i * math.pi / self.m)
        self.use_cuda = use_cuda
        self.device = device
        # Move unconditionally: .to('cpu') is a no-op, so no branching needed.
        self.m_choose_n_map = self.m_choose_n_map.to(self.device)
        self.k_map = self.k_map.to(self.device)

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)], like nn.Linear."""
        stdv = 1. / math.sqrt(self.w.size(1))
        self.w.data.uniform_(-stdv, stdv)

    def determine_k(self, cos_theta):
        """Return k such that theta lies in [k*pi/m, (k+1)*pi/m].

        Since cosine is decreasing on [0, pi], k is the number of boundaries
        cos(i*pi/m), i = 1..m-1, that are >= cos_theta.
        """
        k = torch.zeros_like(cos_theta)
        for i in range(self.m - 1):
            k += (self.k_map[i + 1] >= cos_theta).type(cos_theta.dtype)
        return k

    def evaluate_cos_m_theta(self, cos_theta):
        """Compute cos(m*theta) from cos(theta) via the expansion
        cos(m*t) = sum_n (-1)^n C(m, 2n) cos(t)^(m-2n) sin(t)^(2n).
        """
        sin_square_theta = 1 - cos_theta.pow(2)
        # torch.range is deprecated and endpoint-inclusive; arange with an
        # explicit +1 reproduces n = 0, 1, ..., floor(m/2). Creating n on
        # cos_theta's device replaces the manual use_cuda branching.
        n = torch.arange(0, self.m // 2 + 1, dtype=cos_theta.dtype,
                         device=cos_theta.device).view(-1, 1)
        terms = (pow(-1, n) * self.m_choose_n_map[2 * n.long()]
                 * cos_theta.pow(self.m - 2 * n) * sin_square_theta.pow(n))
        return torch.sum(terms, 0)

    def forward(self, x, y=None):
        """Compute logits; if `y` is given, replace each target logit with
        its large-margin version.

        Args:
            x: (batch, in_features) input features.
            y: optional (batch,) long tensor of target class indices.

        Returns:
            (batch, out_features) logits, connected to the autograd graph.
        """
        x_dot_wT = x.mm(self.w.transpose(0, 1))
        if y is None:
            # BUG FIX: the original built torch.tensor(x_dot_wT), which both
            # copies and *detaches* the logits from the autograd graph.
            return x_dot_wT
        # clone() copies while staying in the autograd graph.
        f_y_i = x_dot_wT.clone()
        batch_size = y.size(0)
        w_norm = self.w.norm(p=2, dim=1)
        x_norm = x.norm(p=2, dim=1)
        y_i = x_dot_wT.gather(1, y.view(-1, 1)).squeeze(dim=1)
        # NaN FIX: round-off can push |cos_theta| marginally above 1, making
        # sin^2(theta) negative and the expansion NaN — clamp to [-1, 1].
        # The epsilon guards against a zero-norm input or weight row.
        cos_theta = (y_i / (x_norm * w_norm.index_select(0, y) + 1e-12)).clamp(-1.0, 1.0)
        cos_m_theta = self.evaluate_cos_m_theta(cos_theta)
        k = self.determine_k(cos_theta)
        idxs = torch.arange(0, batch_size, dtype=torch.long, device=y.device)
        # psi(theta) = (-1)^k cos(m*theta) - 2k, blended with the plain logit.
        f_y_i[idxs, y] = ((self._lambda * y_i)
                          + ((pow(-1, k) * cos_m_theta - 2 * k)
                             * x_norm * w_norm.index_select(0, y))) / (1 + self._lambda)
        return f_y_i
I couldn't figure out the problem.
With double-precision tensors I can train longer before hitting NaN, which suggests the failure is numerical.
Before the NaN appears, the loss does decrease in the double-precision case.