Model is not training @ PyTorch

Hi,
I am using PyTorch to train my model, but it is not learning. I have written my own loss function. Could anyone please review the code and the steps I am taking for training? A plot showing loss and accuracy is attached at the end. Thank you.

def train(EPOCHS):
    BATCH_SIZE = 5
    best_acc = 0.0
    train_dataset, train_data_loader, test_dataset, test_data_loader = load_data_set(BATCH_SIZE)  # receiving only second parameter
    print("Number of train samples: ", len(train_dataset))
    print("Number of test samples: ", len(test_dataset))
    print("Epoch       Train Accuracy         TrainLoss          Test Accuracy          Test Loss")
    for epoch in range(EPOCHS):

        #print("epoch: ", epoch)
        model.train()

        train_acc = 0.0
        train_loss = 0.0
        # 100 random points (x,y) on image for evaluation
        points = torch.randint(0, 32, ([100, 2]), dtype=torch.int)

        batch_count = 0
        for i, (images, gt) in enumerate(train_data_loader):
            batch_count += 1

            # Move images and labels to gpu if available
            if cuda_avail:
                images_tensor = Variable(images[0].cuda())

            gt_tensor = Variable(gt[0])


            # Clear all accumulated gradients
            optimizer.zero_grad()
            # Predict classes using images from the train set
            output_lines = model(images_tensor)
            # pdb.set_trace()  #TODO need to remove
            h_matrix = output_lines.view(len(images_tensor), 25, 3)
            gt_no = 0
            batch_loss = 0  # calculate loss of every batch
            batch_acc = 0   # calculate accuracy of every batch

            for hyper_lines in h_matrix:    # pick all lines of single image

                gt_img = np.array(gt_tensor[gt_no][0]) # select GT w.r.t. selected image
                gt_no += 1   # for next GT
                image_loss = 0
                image_acc = 0
                for point in range(points.size()[0]):   # pick one test point at a time
                    x = points[point]  # considering 1 point
                    estimated_value = indicator_functions.estimatedPosition(hyper_lines, x.double())
                    calculated_value = indicator_functions.calculated_Position(gt_img, x)

                    loss = (calculated_value - estimated_value) ** 2

                    # Calculate training loss and training accuracy
                    image_loss += loss.cpu().item()
                    if loss.cpu().item() < 0.5:
                        image_acc += 1.0
                    # pdb.set_trace() #TODO need to remove

                # Average the acc and loss over all sampled points of a single convex image and add to batch loss/acc
                batch_loss += image_loss / points.size()[0]
                batch_acc  += image_acc / points.size()[0]
                # pdb.set_trace() #TODO need to remove

            train_loss += batch_loss / h_matrix.size()[0]
            train_acc += batch_acc / h_matrix.size()[0]

            ### Making the loss tensor
            train_loss_tensor = Variable(torch.tensor(train_loss), requires_grad=True)

            # Backpropagate the loss
            train_loss_tensor.backward()

            # Adjust parameters according to the computed gradients
            optimizer.step()
            # pdb.set_trace()  # TODO need to remove

        # Compute the average acc and loss over all training batches
        train_loss = train_loss / batch_count
        train_acc = train_acc / batch_count
        # pdb.set_trace()  # TODO need to remove

        ### Evaluate on the test set
        test_acc, test_loss = test(test_data_loader, test_dataset)

        ### Save the model if the test acc is greater than our current best
        if test_acc > best_acc:
            save_models(epoch)
            best_acc = test_acc

        # Print the metrics
        print("{}   {}   {}    {}     {}".format(epoch, train_acc, train_loss, test_acc, test_loss))

Here is the problem:

# Backpropagate the loss
train_loss_tensor = Variable(torch.tensor(train_loss), requires_grad=True)

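Wrapping the value in a new Variable creates a fresh leaf tensor with no history (its grad_fn is None), so backward() stops at that tensor and never reaches the model's parameters. That is why you see no training improvement. In this code the history is actually gone even earlier: loss.cpu().item() turns each loss into a plain Python float, so train_loss carries no graph at all.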

Try this code:

import torch
from torch.autograd import Variable

a = torch.randn(2, 3, requires_grad=True)
a = a*2
print(a.grad_fn)
# <MulBackward0 object at 0x7f1819cf3518>
a = Variable(a, requires_grad=True)
print(a.grad_fn)
# None
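A minimal, self-contained sketch of the difference, using a hypothetical one-layer model and random data rather than the original network. `broken_loss` mirrors the posted pattern (turn the loss into a Python float with `.item()`, then build a new tensor with `requires_grad=True`), while `working_loss` returns the tensor produced from the model's output. Only the second leaves gradients on the parameters for `optimizer.step()` to use.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a single linear layer and random data.
model = nn.Linear(2, 1)
x = torch.randn(8, 2)
target = torch.randn(8, 1)


def grads_after(make_loss):
    """Run one backward pass and report which parameters received gradients."""
    model.zero_grad(set_to_none=True)  # reset grads to None so "no gradient" is visible
    loss = make_loss()
    loss.backward()
    return {name: p.grad is not None for name, p in model.named_parameters()}


def broken_loss():
    # Mirrors the original pattern: compute a float, then wrap it in a new tensor.
    value = ((model(x) - target) ** 2).mean().item()  # .item() drops the graph
    return torch.tensor(value, requires_grad=True)    # new leaf, grad_fn is None


def working_loss():
    # Keep the loss as the tensor computed from the model's output.
    return ((model(x) - target) ** 2).mean()


print(grads_after(broken_loss))   # {'weight': False, 'bias': False}
print(grads_after(working_loss))  # {'weight': True, 'bias': True}
```

In the training loop, the same idea means accumulating `loss` itself rather than `loss.cpu().item()`, calling `backward()` on that accumulated tensor, and using `.item()` only for the values that are printed.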

Thanks for the reply, Krish. Previously I was computing the loss as a Python float. I have updated that part of the code so that the loss is now a tensor. Here is how the loss is built up:

  1. First, I compute the per-point loss, which is a tensor:
tensor(1.9175e-11, dtype=torch.float64, grad_fn=<PowBackward0>)
  1. I average the 100 point losses of each image and accumulate them in batch_loss:
tensor(3.7276, dtype=torch.float64, grad_fn=<AddBackward0>)
  1. Then I average batch_loss over the images in the batch and accumulate it in train_loss:
tensor(0.7455, dtype=torch.float64, grad_fn=<AddBackward0>)

When I call train_loss.backward(), it raises this error:

RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected type TensorOptions(dtype=float, device=cuda:0, layout=Strided, requires_grad=false) but got TensorOptions(dtype=float, device=cpu, layout=Strided, requires_grad=false) (validate_outputs at /opt/conda/conda-bld/pytorch_1587428398394/work/torch/csrc/autograd/engine.cpp:484)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x4e (0x7f5692314b5e in /home/mz/anaconda3/envs/semSeg/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x2ae2834 (0x7f56bbfd8834 in /home/mz/anaconda3/envs/semSeg/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #2: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x548 (0x7f56bbfda368 in /home/mz/anaconda3/envs/semSeg/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x7f56bbfdc2f2 in /home/mz/anaconda3/envs/semSeg/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::Engine::thread_init(int) + 0x39 (0x7f56bbfd4969 in /home/mz/anaconda3/envs/semSeg/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x7f56bf31b558 in /home/mz/anaconda3/envs/semSeg/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xc819d (0x7f56c1d7e19d in /home/mz/anaconda3/envs/semSeg/lib/python3.7/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #7: <unknown function> + 0x76db (0x7f56dac2e6db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #8: clone + 0x3f (0x7f56da95788f in /lib/x86_64-linux-gnu/libc.so.6)
  1. Then I wrap it in Variable() as shown below:
train_loss_tensor = Variable(train_loss, requires_grad=True)
#tensor(0.7810, dtype=torch.float64, requires_grad=True)

print(train_loss_tensor.grad_fn)
#None

train_loss_tensor.backward()

This runs without error, but the model still does not train. Could you please explain how I can make it work? Thank you.

The first line of the error suggests a device mismatch: autograd expected a gradient on cuda:0 but got one on the CPU. Are you moving the loss to the CPU midway?

Also make sure the loss is computed by differentiable operations on the model output; otherwise no gradient reaches the parameters and training won't work. I don't know whether the indicator_functions are differentiable.
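A minimal sketch of keeping helper tensors on the same device and dtype as the network output, assuming one image's output is a `(25, 3)` tensor of lines like a row of `h_matrix` above. The names `lines`, `points_cpu` and `dist` are illustrative; `Tensor.new_zeros` and `Tensor.to(other)` are standard PyTorch methods. A bare `torch.zeros(...)`, `torch.tensor(...)` or `torch.randint(...)` allocates on the CPU by default, which is one way a CPU gradient can end up where a cuda:0 one is expected.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for one image's network output: 25 lines (n_x, n_y, d) that require grad.
lines = torch.randn(25, 3, dtype=torch.double, device=device, requires_grad=True)

# Integer sample points created on the CPU, as torch.randint does by default.
points_cpu = torch.randint(0, 32, (100, 2), dtype=torch.int)

# Convert the points to the dtype and device of `lines` before mixing them.
points = points_cpu.to(lines)

# Allocate the output buffer from `lines` so it inherits its device and dtype.
dist = lines.new_zeros(points.size(0), lines.size(0))
for h in range(lines.size(0)):
    dist[:, h] = points @ lines[h, :2] + lines[h, 2]

loss = dist.mean()
loss.backward()
print(lines.grad.device == lines.device)  # True
```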

Yes, I was moving the loss to the CPU, but now I keep it on the GPU.

Here are the indicator_functions:


import torch

# signed distance function
# H_h(x) = n_h . x + d_h
def sdf(hyperlines, x):
    dist_vector = torch.zeros([hyperlines.size()[0]], dtype=torch.double)
    i = 0
    for x_n, y_n, d_n in hyperlines:
        n_h = torch.tensor([x_n, y_n], dtype=torch.double)
        dist_vector[i] = torch.matmul(n_h, x) + d_n  # n_h[0]*x[0] + n_h[1]*x[1] + d_n
        i += 1
    return dist_vector

def logSumExp(dist_vector, delta):
    return torch.logsumexp(dist_vector, 0)*delta


def sigmoid(sigma, fai_x):
    return torch.sigmoid(sigma * fai_x)

# the Function uses softmax, sigmoid and returns probability of the point x inside object
def estimatedPosition(lines, x):
    dist_vector = sdf(lines, x)
    delta = torch.tensor(1) # 0.006
    fai_x = logSumExp(dist_vector, delta)  # softmax
    sigma = -1 #-75
    estimation = sigmoid(sigma, fai_x)
    return estimation

# function returns 1 if point (x) is inside GT object, returns 0 other wise
def calculated_Position(train_GT_image,x):
    if train_GT_image[x[0].item(), x[1].item()] == 0: return torch.tensor(1.0, dtype=torch.double)
    return torch.tensor(0.0, dtype=torch.double)
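A vectorized sketch of the same estimate (signed distances, logsumexp, sigmoid) that keeps the autograd graph intact and follows the device of its input. It assumes `hyperlines` is a `(25, 3)` tensor of `(n_x, n_y, d)` rows, as in `h_matrix` above, and keeps the posted defaults `delta = 1` and `sigma = -1`; the names `sdf_batch` and `estimated_position` are hypothetical. One detail worth checking in the posted `sdf`: `torch.tensor([x_n, y_n], ...)` copies the values into a new tensor that records no autograd history, so only `d_n` would pass gradients back, and `torch.zeros(...)` without `device=` allocates on the CPU. Slicing `hyperlines` directly avoids both.

```python
import torch


def sdf_batch(hyperlines, points):
    """Signed distance H_h(x) = n_h . x + d_h for every point x and line h.

    hyperlines: (H, 3) tensor of rows (n_x, n_y, d); points: (P, 2) tensor.
    Returns a (P, H) tensor on the device/dtype of `hyperlines`.
    """
    normals = hyperlines[:, :2]          # (H, 2); slicing keeps the graph
    offsets = hyperlines[:, 2]           # (H,)
    points = points.to(hyperlines)       # match dtype and device
    return points @ normals.T + offsets  # offsets broadcast over the P rows


def estimated_position(hyperlines, points, delta=1.0, sigma=-1.0):
    """Per-point estimate sigmoid(sigma * delta * logsumexp_h H_h(x)), shape (P,)."""
    dist = sdf_batch(hyperlines, points)
    fai = torch.logsumexp(dist, dim=1) * delta
    return torch.sigmoid(sigma * fai)


# Usage with random stand-in data (hypothetical: 25 lines, 100 points).
torch.manual_seed(0)
lines = torch.randn(25, 3, dtype=torch.double, requires_grad=True)
points = torch.randint(0, 32, (100, 2), dtype=torch.int)
targets = torch.randint(0, 2, (100,)).double()  # 1 = inside, 0 = outside

loss = ((targets - estimated_position(lines, points)) ** 2).mean()
loss.backward()
print(lines.grad[:, :2].abs().sum())  # gradient reaching the normal components
```

Computing all points in one call also replaces the inner Python loop over points with a single batched operation.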

One more thing: after loss.backward(), calling print(model.weight.grad) produces the following error:

(Pdb) print(model.weight.grad)
*** AttributeError: 'SimpleNet' object has no attribute 'weight'

Sorry for the late reply. Hope you have already solved the problem.

The indicator functions are perfect and shouldn’t cause any problems in optimization.

Regarding model.weight.grad: your model has no weight attribute of its own. The weights belong to the layers (submodules) defined inside it, so access them through those modules. For example, given:

import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 1)

...

model = SimpleModel()

you can inspect the gradient of that layer's weight with:

model.fc1.weight.grad
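To check whether gradients reach the model at all, it can be easier to loop over `named_parameters()` than to name each layer. A minimal sketch, assuming a `forward` that simply applies `fc1` (the class above elides the rest with `...`) and a hypothetical squared-output loss on random input:

```python
import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 1)

    def forward(self, x):
        # Assumed forward pass; the original class body is elided.
        return self.fc1(x)


model = SimpleModel()
loss = model(torch.randn(4, 2)).pow(2).mean()
loss.backward()

# Report the gradient norm of every parameter; None means no gradient reached it.
for name, param in model.named_parameters():
    norm = None if param.grad is None else param.grad.norm().item()
    print(f"{name}: grad norm = {norm}")
```

A parameter that still reports `None` after `backward()` is not connected to the loss, which is the symptom the rewrapped loss tensor produces.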

Hi Krish,
Thanks for your reply.
Yes, the problem is partly solved. The model now trains, but it does not yet achieve the desired results; that remaining issue is not related to the model or the learning itself.
As you suggested, the gradient information was being removed before backpropagation. After creating all the torch tensors on CUDA, the model started training. Thanks for your help. :blush:

-Zohaib

Great. Hope you get your results soon.
Good luck. :+1:

P.S. You might want to mark the solution. :slight_smile:
