RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn (1D CNN error)

Hello,

I am running a 1D CNN and trying to train it, but I get the error `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`.
The error appears when I call `loss.backward()` in the validation step.
Here is my code, any help would be appreciated :smiley:

> import torch.nn as nn
> import torch.nn.functional as F
> 
> class MyModel(nn.Module):
>     def __init__(self):
>         super(MyModel, self).__init__()
>         self.conv1 = nn.Conv1d(1, 3, kernel_size=3) 
>         self.conv2 = nn.Conv1d(3, 6, kernel_size=6) 
>         self.conv3 = nn.Conv1d(6, 12, kernel_size=12) 
>         self.conv4 = nn.Conv1d(12,24, kernel_size=24) 
>         self.conv5 = nn.Conv1d(24,48, kernel_size = 48)
>         self.conv_drop = nn.Dropout2d()
>         self.conv6_transpose = nn.ConvTranspose1d(48,24, kernel_size = 48)
>         self.conv7_transpose = nn.ConvTranspose1d(24,12, kernel_size = 24)
>         self.conv8_transpose = nn.ConvTranspose1d(12,6, kernel_size = 12)
>         self.conv9_transpose=nn.ConvTranspose1d(6,3, kernel_size = 6)
>         self.conv10_transpose = nn.ConvTranspose1d(3,1, kernel_size = 3)
> 
> 
>         self.bn1 = nn.BatchNorm1d(3)
>         self.bn2 = nn.BatchNorm1d(6)
>         self.bn3 = nn.BatchNorm1d(12)
>         self.bn4 = nn.BatchNorm1d(24)
>         self.bn5 = nn.BatchNorm1d(48)
>         self.bn6 = nn.BatchNorm1d(24)
>         self.bn7 = nn.BatchNorm1d(12)
>         self.bn8 = nn.BatchNorm1d(6)
>         self.bn9 = nn.BatchNorm1d(3)
>         self.bn10 = nn.BatchNorm1d(1)
> 
> 
>     def forward(self, x):
>         # add a channel dimension: (batch, length) -> (batch, 1, length)
>         x = x.unsqueeze(0).permute(1, 0, 2)
> #         print (x.shape)
>         out = self.conv1(x)
>         out = F.relu(out)
>         out = self.bn1(out)
> #         print (out.shape)
>         #        out = out.view(out.shape[0], -1)
> 
>         out = self.conv2(out)
>         out = F.relu(out)
>         out = self.bn2(out)
> #         print (out.shape)
> 
>         out = self.conv3(out)
>         out = F.relu(out)
>         out = self.bn3(out)
> #         print (out.shape)
> 
>         out = self.conv4(out)
>         out = F.relu(out)
>         out = self.bn4(out)
> #         print (out.shape)
> 
>         #bottleneck
>         out = self.conv5(out)
>         out = F.relu(out)
>         out = self.bn5(out)
>         out = self.conv_drop(out)
> #         print (out.shape)
> 
> 
>         #upsample
>         out = self.conv6_transpose(out)
>         out = self.conv_drop(out)
>         out = F.relu(out)
> #         print (out.shape)
> #         out = out.unsqueeze(0)
>         out = self.bn6(out)
> 
>         out = self.conv7_transpose(out)
>         out = self.conv_drop(out)
>         out = F.relu(out)
> #         out = out.unsqueeze(0)
>         out = self.bn7(out)
> 
>         out = self.conv8_transpose(out)
>         out = self.conv_drop(out)
>         out = F.relu(out)
> #         out = out.unsqueeze(0)
>         out = self.bn8(out)
> 
>         out = self.conv9_transpose(out)
>         out = self.conv_drop(out)
>         out = F.relu(out)
> #         out = out.unsqueeze(0)
>         out = self.bn9(out)
> 
>         out = self.conv10_transpose(out)
>         out = self.conv_drop(out)
>         out = F.relu(out)
> #         out = out.unsqueeze(1)
>         out = self.bn10(out)
>         return out        
>         
> model = MyModel()
> 
> print(model)
> 
> #%% Here we train and evaluate the model
>         
> 
> import torch
> from torch.utils import data
> import numpy as np
> import torch.optim as optim
> 
> 
> # Check if CUDA is available
> use_cuda = torch.cuda.is_available()
> device = torch.device("cuda:0" if use_cuda else "cpu")
> # cudnn.benchmark = True
> 
> 
> # Set training parameters
> params = {'batch_size': 64,
>           'shuffle': True,
>           'num_workers': 6}
> max_epochs = 100
> 
> # Load all the sample IDs from the txt file
> file_IDs = open('ID_list.txt', 'r').read().split('\n')
> file_IDs = file_IDs[:-1]  # remove the trailing empty line
> complete_dataset = Dataset(file_IDs)  # custom Dataset returning (t20MHz, t100MHz, t250MHz) tuples
> 
> 
> #%% Here we define a loss function and the optimizer
> # create your loss function
> def rmse(y, y_hat):
>     """Compute root mean squared error"""
>     return torch.sqrt(torch.mean((y - y_hat).pow(2)))
> 
> # create your optimizer
> optimizer = optim.SGD(model.parameters(), lr=0.0003, momentum = 0.1)
> 
> #%% Here we train the network
> 
> # Divide the dataset into training, validation and evaluation sets (80/10/10);
> # the last split takes the rounding remainder so the lengths sum to len(complete_dataset)
> n_train = int(np.ceil(len(complete_dataset) * 0.8))
> n_val = int(np.floor(len(complete_dataset) * 0.1))
> lengths = [n_train, n_val, len(complete_dataset) - n_train - n_val]
> training_set, validation_set, evaluation_set = torch.utils.data.random_split(complete_dataset, lengths)
> training_generator = data.DataLoader(training_set, **params)
> validation_generator = data.DataLoader(validation_set, **params)
> evaluation_generator = data.DataLoader(evaluation_set, **params)
> 
> 
> # cast the model parameters to double precision to match the data and move it to the device
> forward_model = model.double().to(device)
> 
> # check that the model works for one random batch of the data
> t20, t100, t250 = next(iter(training_generator))
> one_prediction = forward_model(t20.to(device))
> 
> 
> # Loop over epochs
> for epoch in range(max_epochs):
>     for param in forward_model.parameters():
>         param.requires_grad = True
> 
>     # Training
>     with torch.set_grad_enabled(True):
>         for t20MHz, t100MHz, t250MHz in training_generator:
>             # Transfer to GPU if available
>             t20MHz, t100MHz, t250MHz = t20MHz.to(device), t100MHz.to(device), t250MHz.to(device)
> 
>             # zero the parameter gradients
>             optimizer.zero_grad()
> 
>             # forward, backward and optimize
>             prediction_training = forward_model(t20MHz)
>             loss = rmse(prediction_training, t250MHz)  # value that estimates how far the output is from the target
>             loss.backward()  # the graph is differentiated w.r.t. the loss; every tensor with requires_grad=True accumulates its gradient in .grad
>             optimizer.step()
> 
>     # Validation (needs to be adapted)
>     with torch.set_grad_enabled(False):
>         for t20MHz, t100MHz, t250MHz in validation_generator:
>             # Transfer to GPU if available
>             t20MHz, t100MHz, t250MHz = t20MHz.to(device), t100MHz.to(device), t250MHz.to(device)
> 
>             # Model computations
>             prediction_validation = forward_model(t20MHz)
>             loss = rmse(prediction_validation, t250MHz)
>             loss.backward()
> 
>             # torch.save(forward_model.state_dict(), data_path + 'conv_net_model.pth')
> 
> #%% Here we are going to evaluate the network
> 
> forward_model.eval()
> with torch.set_grad_enabled(False):
>     # Loop over epochs
>     for epoch in range(max_epochs):
> 
>         # Evaluation
>         for t20MHz, t100MHz, t250MHz in evaluation_generator:
>             # Transfer to GPU if available
>             t20MHz, t100MHz, t250MHz = t20MHz.to(device), t100MHz.to(device), t250MHz.to(device)
> 
>             # zero the parameter gradients
>             optimizer.zero_grad()
> 
>             # forward pass only (no backward or optimizer step during evaluation)
>             prediction_evaluation = forward_model(t20MHz)
> #            loss = rmse(prediction_evaluation, t250MHz)
> #            loss.backward()
> #            optimizer.step()
> 
>             torch.save(forward_model.state_dict(), data_path + 'conv_net_model.pt')

You won’t be able to call loss.backward() inside a block where gradient calculation is disabled via with torch.set_grad_enabled(False).
You could of course enable it, but calculating gradients on the validation set (and updating the model) is a bad idea, as you would be leaking the validation data into the training and would thus make the validation set unusable for validating the training.
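As a rough sketch (reusing the names from your code above, e.g. `forward_model`, `rmse`, `validation_generator`), the validation step could simply drop the backward/optimizer calls and only accumulate the loss for monitoring:

```python
# Validation: forward passes only, no gradients and no parameter updates
forward_model.eval()                   # put BatchNorm/Dropout layers into eval mode
val_loss = 0.0
with torch.no_grad():                  # same effect as torch.set_grad_enabled(False)
    for t20MHz, t100MHz, t250MHz in validation_generator:
        t20MHz, t250MHz = t20MHz.to(device), t250MHz.to(device)
        prediction = forward_model(t20MHz)
        val_loss += rmse(prediction, t250MHz).item()   # .item() returns a plain Python float

val_loss /= len(validation_generator)
print(f'epoch {epoch}: validation loss {val_loss:.4f}')
forward_model.train()                  # switch back to training mode for the next epoch
```

Tracking this number per epoch (and, for example, saving a checkpoint when it improves) is all the validation set is needed for; only the training loop should ever call loss.backward() and optimizer.step().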
