I am having trouble updating a custom nn.Parameter(). I have tried running on both GPU and CPU, with both Adam and SGD, but the self.mu parameter does not update at all over batches/epochs. The code below builds a custom weight kernel and passes it to F.conv1d. Below is my network class definition. Also, I found that print(model.mu.data.grad) gives None. I would appreciate any suggestions for debugging this. I am new to NNs and PyTorch.
class CNN(nn.Module):
    def __init__(self, h_in, w_in, h_out):
        super(CNN,self).__init__()
        self.filter_dur = 8e-3
        self.srate = 16000
        self.f_max = 6500
        self.C_out = 64
        self.filter_length = int(self.filter_dur*self.srate)-1
        self.K = (self.filter_length,1)
        self.m = torch.arange(0,self.K[0],1)/self.srate
        self.freq_scale = 100
        self.mu = Parameter(torch.rand((self.C_out,1)))
        self.center_freq = (self.freq_scale*self.mu.repeat(1,self.K[0]))
        self.time_grid = self.m.repeat(self.C_out, 1)
        self.dims = (self.C_out, h_in, w_in) # dimensions of pool layer output
        self.fc_1 = nn.Linear(np.prod(self.dims), h_out) # flattened pool output

    def forward(self, x):
        torch.pi = torch.acos(torch.zeros(1)).item() * 2
        gauss = torch.exp(-(self.time_grid - self.filter_dur/2)**2/(torch.sigmoid(self.center_freq/self.f_max -0.5)**2)*1000)
        tones = torch.cos(2*(torch.pi)*torch.relu(self.center_freq) * self.time_grid)
        filters = gauss * tones
        filters = filters.reshape(self.C_out,self.K[1],self.K[0])
        x = x.view(x.shape[0],x.shape[2],x.shape[1])
        x = F.conv1d(x, filters, stride=1, padding=self.K[0]//2)
        x = x.view(-1, np.prod(self.dims)) # flatten pooling layer outputs into a vector
        x = torch.relu(self.fc_1(x)) # output of fully connected layer
        return x
First, you should never use .data.
When you do model.mu.data, it returns a new Tensor, so its .grad field will always be None.
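To make the .data point concrete, here is a minimal sketch. The standalone parameter mu and the toy loss are made up for illustration and are not the model from the question.

```python
import torch
from torch import nn

# Toy stand-in for model.mu (illustrative only)
mu = nn.Parameter(torch.rand(4, 1))

# Any scalar that depends on mu will do; here d(loss)/d(mu) is 3 everywhere
loss = (mu * 3.0).sum()
loss.backward()

print(mu.grad)       # the gradient stored on the parameter itself
print(mu.data.grad)  # None: .data is a separate tensor with no .grad of its own
```

So when inspecting gradients, read model.mu.grad directly.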
Also, you want to move self.center_freq = (self.freq_scale*self.mu.repeat(1,self.K[0])) to the forward function: every computation that is part of the forward pass should happen in forward.
Otherwise, part of the computation is shared across iterations instead of being rebuilt from the current value of mu, and that will lead to errors.
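A minimal sketch of that pattern, assuming a much smaller filter bank than the one in the question. TinyFilterBank, its sizes, the toy loss and the learning rate are all hypothetical. The constant time grid is registered as a buffer so it follows the module across devices, and everything derived from mu is recomputed on each call.

```python
import math
import torch
from torch import nn

class TinyFilterBank(nn.Module):
    """Hypothetical, reduced version of the filter construction."""
    def __init__(self, c_out=4, k=8, srate=16000.0, freq_scale=100.0):
        super().__init__()
        self.k = k
        self.freq_scale = freq_scale
        self.mu = nn.Parameter(torch.rand(c_out, 1))
        # Constant (not learned): a buffer moves with .to(device)
        self.register_buffer("time_grid", (torch.arange(k) / srate).repeat(c_out, 1))

    def forward(self):
        # Derived from mu, so it is rebuilt on every forward pass
        center_freq = self.freq_scale * self.mu.repeat(1, self.k)
        return torch.cos(2 * math.pi * center_freq * self.time_grid)

torch.manual_seed(0)
model = TinyFilterBank()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mu_before = model.mu.detach().clone()

for _ in range(2):
    optimizer.zero_grad()
    loss = model().pow(2).sum()  # toy loss, only there to drive an update
    loss.backward()              # no retain_graph needed
    optimizer.step()

print((model.mu.detach() - mu_before).abs().max())
```

Because the graph from mu to the loss is rebuilt each iteration, backward() can free it every time and the optimizer sees a fresh gradient on mu.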
Thank you @albanD. Yes, your suggestion makes sense. I moved self.center_freq into the forward function, but self.mu still does not update. Only the fc_1 params are updating. To give a better picture of the code, I am putting my training block below.
# train the CNN for num_epochs
train_loss = []
val_loss = []
mu_updates = []
start_time = time.time()
retain_graph = True
for epoch in range(num_epochs):
    loss_accum = []
    for dataIns in dataloader:
        # read one sample
        inputRef, outputRef = dataIns
        inputRef = inputRef.reshape(inputRef.shape[0],h_in,w_in)
        outputRef = outputRef.reshape(outputRef.shape[0], h_out)
        inputRef, outputRef = inputRef.to(device), outputRef.to(device)
        # forward pass
        output = model(inputRef.float())
        # compute loss
        loss = custom_loss(output.float(), outputRef.float())
        loss_accum.append(loss)
        # backward pass
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        # store mu
        mu_updates.append(model.mu.data.detach().cpu().numpy()[:,0])
    # compute val loss
    output = model(valData.float())
    vloss = custom_loss(output.float(), valOutput.float())
    # ===================store loss==================
    train_loss.append(sum(loss_accum).data.cpu().numpy()/len(loss_accum))
    val_loss.append(vloss.cpu().data.numpy())
    if epoch % 50 == 0:
        print('epoch [{}/{}], loss:{:.4f}, validation_loss:{:.4f}, time_elapsed:{:.4f}'
              .format(epoch + 1, num_epochs, train_loss[epoch], val_loss[epoch], time.time() - start_time))
    if vloss.data <= 0.0001:
        print(val_loss[-1])
        break
Also, you can print(model.mu.grad) after the backward pass to see whether you actually get gradients for it.
In particular, you want to make sure that the structure of your function does not make the gradients always 0 here.
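As a rough sketch of that check, the snippet below reports the largest absolute gradient for every parameter after one backward pass. The helper grad_report and the Toy module are hypothetical. Toy is built to show three cases: a parameter with a useful gradient, one whose gradient is exactly 0 because relu is flat for negative inputs, and one that never enters the forward pass.

```python
import torch
from torch import nn

def grad_report(module):
    """Hypothetical helper: max |grad| per parameter, or None if no gradient exists."""
    return {
        name: (None if p.grad is None else p.grad.abs().max().item())
        for name, p in module.named_parameters()
    }

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.live = nn.Parameter(torch.tensor([0.5]))
        self.dead = nn.Parameter(torch.tensor([-0.5]))   # relu is flat at this value
        self.unused = nn.Parameter(torch.tensor([1.0]))  # never used in forward

    def forward(self, x):
        return (self.live * x + torch.relu(self.dead) * x).sum()

toy = Toy()
toy(torch.ones(3)).backward()
report = grad_report(toy)
print(report)
```

A None entry means the parameter is not connected to the loss at all, while an exact 0 points at the function structure, for example a relu or a saturated nonlinearity sitting between the parameter and the loss.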
Thanks again @albanD. I see the output of print(model.mu.grad) is in the range of tensor([[ 5.7729e-07], ... These values are too low to show up in an update, right? Does this have to do with the dynamic range of my initialization or that of the inputs?
It can have many causes, so it is hard to say without more details.
But the initial value can be the issue, or being in a region of your loss that is too flat.
A good way to debug these is to add some_tensor.register_hook(print) on Tensors in the forward pass. It will print the gradient for that Tensor during the backward pass. That way you can see where in the backward pass the gradients are vanishing.
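Here is a small sketch of the hook idea. Instead of printing whole gradient tensors, it records the largest absolute gradient at each stage. The save_grad helper and the two-step computation are made up for illustration. The division by 6500 only mirrors f_max from the question and is not a diagnosis of the actual model.

```python
import torch
from torch import nn

mu = nn.Parameter(torch.rand(4, 1))
seen = {}  # stage name -> max |grad| observed during backward

def save_grad(name):
    # Hypothetical helper: builds a hook that records the gradient magnitude.
    def hook(grad):
        seen[name] = grad.abs().max().item()
    return hook

# Forward pass with hooks on the intermediate tensors
center_freq = 100.0 * mu.repeat(1, 8)
center_freq.register_hook(save_grad("center_freq"))

scaled = torch.sigmoid(center_freq / 6500 - 0.5)
scaled.register_hook(save_grad("scaled"))

scaled.sum().backward()

# Hooks fire in backward order, so this lists stages from the loss back to mu
for name, magnitude in seen.items():
    print(f"{name}: {magnitude:.3e}")
print(f"mu: {mu.grad.abs().max().item():.3e}")
```

In this toy case the gradient is 1 at scaled and several orders of magnitude smaller at center_freq, which shows the sigmoid and the division to be where the signal shrinks. Running the same kind of hooks on the real forward pass shows which step is responsible there.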