Hi there,

I am attempting to implement a custom loss function that uses a novel measure of information. The loss function is below:

```
def MI_loss(self, inputs, E0=188, E=1000):  # TODO: make this faster in pc_info
    '''
    Args:
        inputs: Tuple (v, p0) of velocity inputs with shape
            [sequence_length, batch_size, 2] and initial place cell activity.
    Returns:
        MI_loss: Average MI loss for this batch.
    '''
    preds = self.predict(inputs)
    Nx, batch_size, Np = preds.shape
    dist = 1 / Nx * torch.ones(Nx)  # option to change this for more interesting distributions
    M = pc_info.MI_spike(preds, dist)
    M = M.to(self.device)
    loss = torch.norm(M + eps, p='fro')  # eps is a small constant defined at module level
    # Energy constraint
    total_ave_frs = dist @ preds.mean(dim=1)
    energy_constraint_loss = torch.relu(E0 * total_ave_frs - E)
    energy_loss = torch.mean(energy_constraint_loss)
    return loss + energy_loss
```
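
One detail I was already unsure about here: the backward pass of `torch.norm` divides by the norm value, so it returns nan gradients whenever its argument is exactly zero. A stabilized variant I have been considering as a drop-in for the norm line (just a sketch; `eps` is the same module-level constant):

```
import torch

def safe_fro_norm(M, eps=1e-8):
    # Frobenius norm with a floor under the square root, so the backward
    # pass never divides by zero even if M is (numerically) all zeros.
    return torch.sqrt((M * M).sum() + eps)

# usage: loss = safe_fro_norm(M)  # instead of torch.norm(M + eps, p='fro')
```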

The `pc_info.MI_spike` call above is defined by the following snippets (`eps` is a small positive constant defined at module level):

```
def I_spike(pc, dist):
    pc = torch.relu(pc)
    pc_nonzero_indices = pc.nonzero(as_tuple=False)
    info_matrix = torch.zeros(pc.shape)
    pc_nonzero = pc[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]]
    dist_nonzero = dist[pc_nonzero_indices[:, 0]]
    spike_rate_per_cell = (pc * dist.view(-1, 1, 1)).sum(dim=0)
    spike_rate_matrix = spike_rate_per_cell.unsqueeze(0).expand(pc.shape[0], -1, -1)
    l = spike_rate_matrix[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]]
    info = pc_nonzero * torch.log2(pc_nonzero / l + eps) * dist_nonzero
    info_matrix[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]] = info
    info_per_cell = info_matrix.sum(dim=0)
    norm_info = info_per_cell / spike_rate_per_cell
    norm_info[spike_rate_per_cell == 0] = 0
    # Average the Skaggs information across batches
    batch_ave_norm_info = norm_info.mean(dim=0)
    return batch_ave_norm_info

def I_spike_joint(pc, dist):
    Nx = pc.shape[0]
    batch_size = pc.shape[1]
    Np = pc.shape[2]
    all_Js = torch.zeros((Np, Np))
    for obs in range(batch_size):
        J = torch.zeros((Np, Np))
        for i in range(Np):
            for j in range(i + 1, Np):
                pc1 = torch.relu(pc[:, obs, i])
                pc2 = torch.relu(pc[:, obs, j])
                # If one cell is constant, the joint info reduces to the other cell's info
                if torch.std(pc1) == 0:
                    J[i, j] = I_spike(pc2.reshape((Nx, 1, 1)), dist)
                elif torch.std(pc2) == 0:
                    J[i, j] = I_spike(pc1.reshape((Nx, 1, 1)), dist)
                else:
                    r = pc_corrcoef(pc1, pc2, dist)
                    lab = pc1 * pc2
                    lab_tilde = (dist * torch.sqrt(pc1 * pc2)).sum()
                    la = (dist * pc1).sum()
                    lb = (dist * pc2).sum()
                    info_1 = dist * (r * (torch.sqrt(lab + eps) / lab_tilde) * torch.log2(torch.sqrt(lab + eps) / lab_tilde + eps))
                    info_2 = dist * ((pc1 - r * torch.sqrt(lab + eps)) / (la - r * lab_tilde) * torch.log2((pc1 - r * torch.sqrt(lab + eps)) / (la - r * lab_tilde) + eps))
                    info_3 = dist * ((pc2 - r * torch.sqrt(lab + eps)) / (lb - r * lab_tilde) * torch.log2((pc2 - r * torch.sqrt(lab + eps)) / (lb - r * lab_tilde) + eps))
                    # TODO: question validity of below
                    info_1[(lab == 0) | (torch.sqrt(lab) / lab_tilde < 0)] = 0
                    info_2[((pc1 - r * torch.sqrt(lab)) == 0) | ((pc1 - r * torch.sqrt(lab)) / (la - r * lab_tilde) < 0)] = 0
                    info_3[((pc2 - r * torch.sqrt(lab)) == 0) | ((pc2 - r * torch.sqrt(lab)) / (lb - r * lab_tilde) < 0)] = 0
                    info = info_1 + info_2 + info_3
                    J[i, j] = info.sum()
        all_Js += J
    return 1 / batch_size * all_Js

def MI_spike(pc, dist):
    Np = pc.shape[2]
    batch_size = pc.shape[1]
    Nx = pc.shape[0]
    all_Ms = torch.zeros((Np, Np))
    for obs in range(batch_size):
        M = torch.zeros((Np, Np))
        J = I_spike_joint(pc[:, obs, :].reshape((Nx, 1, Np)), dist)
        for i in range(Np):
            for j in range(i + 1, Np):
                pc1 = pc[:, obs, i].reshape((Nx, 1, 1))
                pc2 = pc[:, obs, j].reshape((Nx, 1, 1))
                M[i, j] = I_spike(pc1, dist) + I_spike(pc2, dist) - J[i, j]
        all_Ms += M
    return 1 / batch_size * all_Ms
```
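
For reproduction, here is the kind of minimal standalone driver I use to probe `MI_spike` for nan in both the forward values and the gradients (shapes are arbitrary examples, and it assumes `pc_corrcoef` and `eps` from my module are in scope):

```
import torch

eps = 1e-8  # the module-level constant assumed by the functions above

Nx, batch_size, Np = 50, 4, 8
pc = torch.rand(Nx, batch_size, Np, requires_grad=True)  # random nonnegative "rates"
dist = torch.ones(Nx) / Nx

M = MI_spike(pc, dist)
print('forward finite:', torch.isfinite(M).all().item())

M.sum().backward()
print('grad finite:', torch.isfinite(pc.grad).all().item())
```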

I am attempting to use an RNN to produce outputs of size [Nx, batch_size, Np] (here Nx is the sequence length, since the RNN runs with batch_first=False), which are then used to calculate the information measure described above. The architecture of my RNN is below:

```
class RNN(torch.nn.Module):
    def __init__(self, options):
        super(RNN, self).__init__()
        self.Ng = options.Ng
        self.Np = options.Np
        self.sequence_length = options.sequence_length
        self.weight_decay = options.weight_decay
        self.device = options.device
        # Input weights
        self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
        self.RNN = torch.nn.RNN(input_size=2,
                                hidden_size=self.Ng,
                                nonlinearity=options.activation,
                                batch_first=False,
                                bias=False)
        # Linear read-out weights
        self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def g(self, inputs):
        '''
        Compute grid cell activations.
        Args:
            inputs: Tuple (v, p0) of velocity inputs with shape
                [sequence_length, batch_size, 2] and initial place cell
                activity with shape [batch_size, Np].
        Returns:
            g: Batch of grid cell activations with shape
                [sequence_length, batch_size, Ng].
        '''
        v, p0 = inputs  # TODO: maybe change this; where does p0 come from?
        init_state = self.encoder(p0)[None]
        g, _ = self.RNN(v, init_state)
        return g

    def predict(self, inputs):
        '''
        Predict place cell code.
        Args:
            inputs: Same tuple (v, p0) as in g().
        Returns:
            place_preds: Predicted place cell activations with shape
                [sequence_length, batch_size, Np].
        '''
        place_preds = self.decoder(self.g(inputs))
        return place_preds
```
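
For reference, this is how I sanity-check the output layout (the option values are arbitrary examples; since `batch_first=False`, the first output dimension is the sequence length, which is what `MI_loss` reads as Nx):

```
import torch
from types import SimpleNamespace

options = SimpleNamespace(Ng=128, Np=32, sequence_length=20,
                          weight_decay=1e-4, device='cpu', activation='relu')
model = RNN(options)

batch_size = 16
v = torch.randn(options.sequence_length, batch_size, 2)  # time-major velocities
p0 = torch.rand(batch_size, options.Np)                  # initial place cell activity
preds = model.predict((v, p0))
print(preds.shape)  # torch.Size([20, 16, 32]) == [Nx, batch_size, Np]
```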

Training results in nan after just one iteration (the first loss value is reasonable, but then everything immediately goes to nan). Running `torch.autograd` anomaly detection gave the following error:

`RuntimeError: Function 'LinalgVectorNormBackward0' returned nan values in its 0th output.`

Can anyone see an issue with my loss function or architecture that would cause this? Gradient clipping hasn't helped.
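
One thing I am now suspicious of: in `I_spike_joint` the invalid entries are zeroed out *after* the full expressions are computed, and as far as I understand autograd still backpropagates through any nan/inf produced before the masking (and `torch.sqrt(lab)` has an infinite gradient at `lab == 0`). Would restructuring the terms with `torch.where`, so that invalid arguments never enter the graph, be the right fix? A sketch of what I mean, with a hypothetical `safe_x_log2_x` helper standing in for each `(...) * log2(...)` term:

```
import torch

def safe_x_log2_x(num, den):
    # Hypothetical helper: (num / den) * log2(num / den), evaluated only where
    # the ratio is well-defined and strictly positive. Invalid positions are
    # replaced with a dummy 1.0 *before* the division and log, so the backward
    # pass never sees a nan or inf to propagate.
    valid = (den != 0) & (num * den > 0)  # den nonzero and ratio positive
    num_s = torch.where(valid, num, torch.ones_like(num))
    den_s = torch.where(valid, den, torch.ones_like(den))
    ratio = num_s / den_s
    return torch.where(valid, ratio * torch.log2(ratio), torch.zeros_like(ratio))
```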

Thanks