Hi there,
I am attempting to implement a custom loss function that uses a novel measure of information. The loss function is below:
def MI_loss(self, inputs, E0=188, E=1000):  # TODO: MAKE THIS FASTER in pc_info
    '''
    Args:
        inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
    Returns:
        MI_loss: Avg. MI information loss for this batch
    '''
    preds = self.predict(inputs)
    Nx, batch_size, Np = preds.shape
    dist = 1 / Nx * torch.ones(Nx)  # Option to change this for more interesting distributions
    M = pc_info.MI_spike(preds, dist)
    M = M.to(self.device)
    loss = torch.norm(M + eps, p='fro')
    # ENERGY CONSTRAINT
    total_ave_frs = dist @ preds.mean(dim=1)
    energy_constraint_loss = relu(E0 * total_ave_frs - E)
    energy_loss = torch.mean(energy_constraint_loss)
    return loss + energy_loss
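The snippet above also relies on a couple of module-level names that aren't shown. Roughly, the surrounding setup in my file looks like this (the exact eps value is just what I happen to use, not an important detail):

import torch
from torch.nn.functional import relu  # the `relu` used in the energy constraint above

import pc_info  # my module containing I_spike / I_spike_joint / MI_spike (shown below)

eps = 1e-8  # small constant added inside the logs and the norm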
pc_info.MI_spike itself is defined by the following few snippets of code:
def I_spike(pc, dist):
    pc = torch.relu(pc)
    pc_nonzero_indices = pc.nonzero(as_tuple=False)
    info_matrix = torch.zeros(pc.shape)
    norm_info = torch.zeros(pc.shape)
    pc_nonzero = pc[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]]
    dist_nonzero = dist[pc_nonzero_indices[:, 0]]
    # Mean rate per cell, averaged over positions weighted by dist: shape [batch_size, Np]
    spike_rate_per_cell = (pc * dist.view(-1, 1, 1)).sum(dim=0)
    spike_rate_matrix = spike_rate_per_cell.unsqueeze(0).expand(pc.shape[0], -1, -1)
    l = spike_rate_matrix[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]]
    info = pc_nonzero * torch.log2(pc_nonzero / l + eps) * dist_nonzero
    info_matrix[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]] = info
    info_per_cell = info_matrix.sum(dim=0)
    # Normalize to bits per spike; cells that never fire get 0
    norm_info = info_per_cell * 1 / spike_rate_per_cell
    norm_info[spike_rate_per_cell == 0] = 0
    # Average the Skaggs information across batches
    batch_ave_norm_info = norm_info.mean(dim=0)
    return batch_ave_norm_info
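(For reference, the intent of I_spike is the usual Skaggs information per spike for each cell,

I = sum_x p(x) * (lambda(x) / lambda_bar) * log2(lambda(x) / lambda_bar),

where p(x) is dist, lambda(x) is the rectified activation at position x, and lambda_bar = sum_x p(x) * lambda(x) is spike_rate_per_cell.)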
def I_spike_joint(pc, dist):
    Nx = pc.shape[0]
    batch_size = pc.shape[1]
    Np = pc.shape[2]
    all_Js = torch.zeros((Np, Np))
    for obs in range(batch_size):
        J = torch.zeros((Np, Np))
        for i in range(Np):
            for j in range(i + 1, Np):
                pc1 = torch.relu(pc[:, obs, i])
                pc2 = torch.relu(pc[:, obs, j])
                if torch.std(pc1) == 0:
                    J[i, j] = I_spike(pc2.reshape((Nx, 1, 1)), dist)
                elif torch.std(pc2) == 0:
                    J[i, j] = I_spike(pc2.reshape((Nx, 1, 1)), dist)
                else:
                    r = pc_corrcoef(pc1, pc2, dist)
                    lab = pc1 * pc2
                    lab_tilde = (dist * torch.sqrt(pc1 * pc2)).sum()
                    la = (dist * pc1).sum()
                    lb = (dist * pc2).sum()
                    info_1 = dist * (r * (torch.sqrt(lab + eps) / lab_tilde) * torch.log2(torch.sqrt(lab + eps) / lab_tilde + eps))
                    info_2 = dist * ((pc1 - r * torch.sqrt(lab + eps)) / (la - r * lab_tilde) * torch.log2((pc1 - r * torch.sqrt(lab + eps)) / (la - r * lab_tilde) + eps))
                    info_3 = dist * ((pc2 - r * torch.sqrt(lab + eps)) / (lb - r * lab_tilde) * torch.log2((pc2 - r * torch.sqrt(lab + eps)) / (lb - r * lab_tilde) + eps))
                    # TODO: Question validity of below
                    info_1[(lab == 0) | (torch.sqrt(lab) / lab_tilde < 0)] = 0
                    info_2[((pc1 - r * torch.sqrt(lab)) == 0) | ((pc1 - r * torch.sqrt(lab)) / (la - r * lab_tilde) < 0)] = 0
                    info_3[((pc2 - r * torch.sqrt(lab)) == 0) | ((pc2 - r * torch.sqrt(lab)) / (lb - r * lab_tilde) < 0)] = 0
                    info = info_1 + info_2 + info_3
                    J[i, j] = info.sum()
        all_Js += J
    return 1 / batch_size * all_Js
def MI_spike(pc, dist):
    Np = pc.shape[2]
    batch_size = pc.shape[1]
    Nx = pc.shape[0]
    all_Ms = torch.zeros((Np, Np))
    for obs in range(batch_size):
        M = torch.zeros((Np, Np))
        J = I_spike_joint(pc[:, obs, :].reshape((Nx, 1, Np)), dist)
        for i in range(Np):
            for j in range(i + 1, Np):
                pc1 = pc[:, obs, i].reshape((Nx, 1, 1))
                pc2 = pc[:, obs, j].reshape((Nx, 1, 1))
                M[i, j] = I_spike(pc1, dist) + I_spike(pc2, dist) - J[i, j]
        all_Ms += M
    return 1 / batch_size * all_Ms
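So, per pair of cells (i, j), MI_spike computes

M[i, j] = I_spike(pc_i) + I_spike(pc_j) - I_spike_joint(pc_i, pc_j),

intended as an analogue of I(X; Y) = H(X) + H(Y) - H(X, Y), and MI_loss then takes the Frobenius norm of this matrix plus the energy term.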
I am attempting to use an RNN to produce outputs of size [Nx, batch_size, Np], which are then used to calculate the information measure described above. The architecture of my RNN is below (a short sketch of how I call it with concrete shapes follows the class definition):
class RNN(torch.nn.Module):
    def __init__(self, options):
        super(RNN, self).__init__()
        self.Ng = options.Ng
        self.Np = options.Np
        self.sequence_length = options.sequence_length
        self.weight_decay = options.weight_decay
        self.device = options.device
        # Input weights
        self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
        self.RNN = torch.nn.RNN(input_size=2,
                                hidden_size=self.Ng,
                                nonlinearity=options.activation,
                                batch_first=False,
                                bias=False)
        # Linear read-out weights
        self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def g(self, inputs):
        '''
        Compute grid cell activations.
        Args:
            inputs: Tuple (v, p0) of 2d velocity inputs with shape
                [sequence_length, batch_size, 2] and initial place cell activity
                with shape [batch_size, Np].
        Returns:
            g: Batch of grid cell activations with shape [sequence_length, batch_size, Ng].
        '''
        v, p0 = inputs  # TODO: Maybe change this, where does p0 come from?
        init_state = self.encoder(p0)[None]
        g, _ = self.RNN(v, init_state)
        return g

    def predict(self, inputs):
        '''
        Predict place cell code.
        Args:
            inputs: Tuple (v, p0) as in g().
        Returns:
            place_preds: Predicted place cell activations with shape
                [sequence_length, batch_size, Np].
        '''
        place_preds = self.decoder(self.g(inputs))
        return place_preds
Training results in NaN after just one iteration (the first loss value is reasonable, but everything goes to NaN immediately after that). Running anomaly detection produced the following error:
RuntimeError: Function 'LinalgVectorNormBackward0' returned nan values in its 0th output.
Is there any issue with my loss function or architecture that could be causing this? Gradient clipping hasn't helped.
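In case it matters, this is roughly how I run the training step that blows up, with anomaly detection switched on (optimizer setup omitted; model, v, p0 as in the sketch above):

torch.autograd.set_detect_anomaly(True)

optimizer.zero_grad()
loss = model.MI_loss((v, p0))
print(loss.item())   # the first value printed looks reasonable
loss.backward()      # anomaly mode flags LinalgVectorNormBackward0 here
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clipping has not helped
optimizer.step()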
Thanks