Loss does not improve on training

Hi,

I implemented a model which fails to learn. Loss calculated for every epoch is exactly the same and so is the sequence of losses over batches for every epoch. I have been fumbling around with this for a couple of days and apparently stumbled over the reason for this behaviour now.

It seems that all tensors I use in my C++ code are from namespace at, even though I am explicitly creating them with torch::Tensor … . Can anyone please explain why this is so and how to fix it?

I’m including <torch/torch.h> and no other PyTorch headers, and I’m never creating any at::Tensors.

I only noticed this because of a compilation error I introduced where the compiler among other things complained about some at::Tensors which I didn’t realize I was using until then.
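
As far as I know, torch::Tensor is not a separate class: namespace torch pulls in the names from namespace at, so both spellings refer to the same type and the compiler simply prints the original name. Here is a minimal plain C++ sketch of that mechanism. It does not use libtorch, and the namespaces backend and frontend are hypothetical stand-ins for at and torch:

```cpp
#include <type_traits>

namespace backend {        // hypothetical stand-in for namespace at
struct Tensor {};
}

namespace frontend {       // hypothetical stand-in for namespace torch
using namespace backend;   // makes backend::Tensor reachable as frontend::Tensor
}

// True when both spellings name one and the same type.
constexpr bool sameTensorType() {
    return std::is_same<frontend::Tensor, backend::Tensor>::value;
}

static_assert(sameTensorType(), "frontend::Tensor is backend::Tensor");
```

If this holds for your libtorch version, seeing at::Tensor in compiler messages is expected and is not by itself the reason for the constant loss.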

Any help would be greatly appreciated.

Edit 1:
Calling requires_grad() on any Tensor in my code so far yields false. Guess it should not be that way…

Edit 2: Here’s some code from the training part, maybe it’s just obvious what I’m doing wrong…

torch::optim::SGD optimizer(classifier->parameters(), 0.001);
const unsigned int epochs = 10;
torch::Tensor loss, output, runningLoss;

for(unsigned int epoch=0; epoch<epochs; epoch++) {
	runningLoss = torch::zeros(1);
			
	for(vector<torch::data::Example<>>& batch : *dataLoader) {
		for(torch::data::Example<torch::Tensor, torch::Tensor>& sample : batch) {
			optimizer.zero_grad();
			output = classifier->forward(sample.data);
			loss = torch::mse_loss(output, sample.target);
			runningLoss += loss;
			loss.backward();					
			optimizer.step();
		}
	}			
}

Edit 3:
Maybe I’m just doing something wrong when passing the data to the functions…

I played around with the learning rate of the SGD optimizer, setting it to values from 10000000 to 0.0000001, without any effect…

I’ve built an executable example now, stripped of all unnecessary stuff. Basically I’m trying to build an autoencoder, but the model is not converging. I have pretty much the same model in Python and it works like a charm. I really do not see where I might be going wrong here… In the example program I’m passing one Tensor to the autoencoder, which I expect the network to be able to learn and reproduce after some number of iterations. The loss stays constant, though, no matter how many iterations I run…

#include <torch/torch.h>

using namespace std;

struct Autoencoder : torch::nn::Module {
	
	unsigned int inputSize;
	const unsigned int encodingSize = 35;
	
	// Network layers
	torch::nn::Linear encode1;
	torch::nn::Linear encode2;
	torch::nn::Linear encode3;
	torch::nn::Linear decode1;
	torch::nn::Linear decode2;
	torch::nn::Linear decode3;
	
	Autoencoder() = delete;
	
	Autoencoder(unsigned int inputs)
		: inputSize(inputs),
		  encode1(register_module("encode1", torch::nn::Linear(inputSize, 77))),
		  encode2(register_module("encode2", torch::nn::Linear(77, 44))),
		  encode3(register_module("encode3", torch::nn::Linear(44, encodingSize))),
		  decode1(register_module("decode1", torch::nn::Linear(encodingSize, 44))),
		  decode2(register_module("decode2", torch::nn::Linear(44, 77))),
		  decode3(register_module("decode3", torch::nn::Linear(77, inputSize)))
	{}
	
	torch::Tensor forward(torch::Tensor& input) {
		torch::Tensor x;
		
		// Encode
		x = torch::relu(encode1->forward(input));
		x = torch::relu(encode2->forward(x));
		x = torch::relu(encode3->forward(x));
		
		// Decode
		x = torch::relu(decode1->forward(x));
		x = torch::relu(decode2->forward(x));
		x = decode3->forward(x);
		
		return x;
	}
};

int main(int argc, char* argv[]) {
	
	torch::Tensor data = torch::zeros(10, torch::requires_grad(true).dtype(torch::kFloat32));
	
	data[0] = 0;
	data[1] = 1;
	data[2] = 0;
	data[3] = 1;
	data[4] = 0;
	data[5] = 1;
	data[6] = 0;
	data[7] = 1;
	data[8] = 0;
	data[9] = 1;
	
	Autoencoder autoencoder(10);
	
	torch::Tensor loss = torch::zeros(1);
	torch::Tensor output = torch::zeros(10);
	torch::optim::SGD optimizer(autoencoder.parameters(), torch::optim::SGDOptions(10));
	
	for(int i=0; i<1000000; i++) {
		optimizer.zero_grad();
		
		output = autoencoder.forward(data);
		
		loss = torch::mse_loss(output.detach(), data.detach());
		
		cout << "\rLoss: " << loss.item().toFloat();
		
		loss.backward();
		optimizer.step();
	}
	
	cout << endl;
	
	return 0;
}

Edit 1: Tried to explicitly initialize layer. Didn’t help :frowning:
Edit 2: Tried to use different optimizer (Adam). Didn’t help :frowning:
Edit 3: Moved detach() out of forward() to make it more visible…

Why do I have to detach those two tensors when passing to loss function? Is that correct?

No. Detaching output cuts it off from the autograd graph, so no gradient reaches the parameters, and that means no training. You actually want to detach data (if it requires gradient) but not output.
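
To make the effect concrete, here is a minimal sketch in plain C++ with a hand-derived gradient. It does not use libtorch: the detachOutput flag is a hypothetical stand-in for calling detach() on the output, and the one-weight model, input, target and learning rate are made-up values for illustration only.

```cpp
// Trains the one-weight model output = w * x to reproduce a target with
// plain gradient descent and returns the final squared error.
// detachOutput mimics output.detach(): the loss value is still computed,
// but no gradient flows back from the loss to w.
double trainScalar(bool detachOutput, int steps, double learningRate = 0.1) {
    const double x = 1.0;       // made-up input
    const double target = 2.0;  // made-up target
    double w = 0.0;
    for (int i = 0; i < steps; ++i) {
        const double output = w * x;
        // Chain rule: d(loss)/dw = d(loss)/d(output) * d(output)/dw
        //                        = 2 * (output - target) * x
        // Detaching the output drops the second factor, so w gets no gradient.
        const double gradW = detachOutput ? 0.0 : 2.0 * (output - target) * x;
        w -= learningRate * gradW;
    }
    const double error = w * x - target;
    return error * error;
}
```

With detachOutput set to true the weight never moves and the loss stays at its initial value, which is the constant-loss symptom from the first post. With it set to false the error shrinks towards zero.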

Best regards

Thomas

Hi tom, thanks for your reply! This gets my sample program working :slight_smile: I will check my initial program to see if it has the same issue.

I’m wondering if the following could be problematic…

for(auto& sample : batch) {
    . . .
    output = classifier->forward(sample.data);
    loss = torch::mse_loss(output, sample.data.detach());
    . . .				

input and target for the network are the same tensor. So if I detach sample.data do I also implicitly detach output? Or does this in some other way mess up gradients / learning?

Edit: Sorry, forget this. The example prog does just the same and is working…

In my initial program it makes no difference if I detach the output variable or not… :thinking: Loss remains constant either way.

The code is as follows:

torch::Tensor output = torch::zeros(127, torch::requires_grad(true));
for(unsigned int epoch=0; epoch<epochs; epoch++) {
	optimizer.zero_grad();			
	output = classifier->forward(data_x);
	loss = torch::mse_loss(output, data_x.detach_());
	loss.backward();
	optimizer.step();
}

Is it possible that the forward() method creates an output Tensor without gradient information?

… yep. Had the same detach() in my forward method as in the example above before I moved it… Now my loss turns out to be -nan …:weary:

The -nan values are introduced by the Linear layers. The first layer introduces a handful of those nans, and it gets worse as the data passes through each additional layer.

Most of the input data consists of zeros, usually with a handful of values differing from zero. The input data seems fine otherwise. No nan values to be found there.

Any suggestions/ideas on this?

Edit 1: Turns out that some bias values in the first layer are nan … How come?
Edit 2: Bias values are fine upon initialization. Nans seem to appear when updating the network…
Edit 3: After learning on a first batch of size 64 a couple of nans already appear in the layer’s bias and weight values among some impressively big negative numbers. I’m afraid bias values are overflowing … My input data has not been scaled/normalized yet. I’ll try that and see if it helps…

Did you move the learning rate back to some non-ridiculous value?

Thanks tom for your reply. The learning rate was indeed set to 1.0, which is a bit high I guess, but maybe not yet ridiculous. Switching the learning rate back to 0.01 alone is not enough, it seems, because I’m still getting nans then.
While I still have trouble getting a proper standard deviation to work, scaling by only subtracting the mean from the input data seems to mitigate the problem. So far the loss is decreasing and the network is converging in the couple of tests I did with the current settings.
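
A rough sketch of why unscaled inputs can blow up even at a learning rate of 0.01, in plain C++ without libtorch. The one-weight "autoencoder" and the input values 100 and 1 are assumptions chosen for illustration, not the poster's data:

```cpp
#include <cmath>

// Gradient descent on loss = (w * x - x)^2, a one-weight "autoencoder"
// that should learn w = 1. Returns the weight after the given number of steps.
// Each step multiplies the distance (w - 1) by (1 - 2 * learningRate * x * x),
// so the update diverges once learningRate * x * x exceeds 1.
float trainWeight(float x, float learningRate, int steps) {
    float w = 0.0f;
    for (int i = 0; i < steps; ++i) {
        const float grad = 2.0f * (w * x - x) * x;  // d(loss)/dw
        w -= learningRate * grad;
    }
    return w;
}
```

With x = 100 and a learning rate of 0.01 the weight overflows to a non-finite value (inf or nan) within a few dozen steps, while the same learning rate with x scaled down to 1 converges to a weight close to 1. That at least matches the observation above that lowering the learning rate alone was not enough.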

So the solution to the specific problem in my case seemed to be

  1. Do not detach output
  2. Scale input data

As for the standard deviation problem I will open another thread.

Thanks tom for your help!

Edit 1: Guess I was cheering too early once again… In my scaling method I am subtracting the input data’s mean from input and dividing by the input data’s standard deviation which happens to be broken. As soon as I remove the division by the broken std-Tensor I keep getting nans again :sob::sob::sob:

Edit 2: Interestingly when printing the result of
torch::std(data, 0);
I get

[ Variable[CPUType]{} ]

Which I interpret as an empty Tensor with no values (?). Nevertheless division by that Tensor seems to scale values of the data Tensor which seems to fix the nan issue… weird. Is this correct expected behaviour?

No, that is a tensor with empty shape, i.e. a scalar.

Best regards

Thomas

Ah great, thanks again! So it seems I have to pass a different second argument to the std() function to get the standard deviation for every column separately. I’ll find out :slight_smile:
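
For reference, this is roughly what per-column standardization looks like when written out by hand in plain C++, without libtorch. The Matrix alias, the function name and the eps value are assumptions for this sketch. The eps guard is there because a column that is constant (for example all zeros) has a standard deviation of 0, and dividing 0 by 0 would produce nan:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // rows of samples

// Standardizes every column in place: (x - mean) / (std + eps).
// Uses the unbiased (n - 1) estimator for the standard deviation.
// eps is an assumed small constant that keeps constant columns from
// dividing by zero.
void standardizeColumns(Matrix& rows, double eps = 1e-8) {
    if (rows.size() < 2) return;  // the unbiased estimator needs two rows
    const std::size_t n = rows.size();
    const std::size_t cols = rows[0].size();
    for (std::size_t c = 0; c < cols; ++c) {
        double mean = 0.0;
        for (const auto& row : rows) mean += row[c];
        mean /= static_cast<double>(n);

        double sumSq = 0.0;
        for (const auto& row : rows) sumSq += (row[c] - mean) * (row[c] - mean);
        const double stdDev = std::sqrt(sumSq / static_cast<double>(n - 1));

        for (auto& row : rows) row[c] = (row[c] - mean) / (stdDev + eps);
    }
}
```

A column that is zero in every sample stays zero after this, instead of turning into nan.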

Interestingly,
torch::std(tensor, integer)
does not call the function one would expect but (at least in my case) instead calls
torch::std(Tensor t, bool b)
because the provided integer value is implicitly converted to a bool. I think this could be considered a bug or a design flaw, since the behaviour is counterintuitive and I expect more people will call it the way I did. To provide the dimension along which to calculate the standard deviation, one has to explicitly call torch::std(tensor, integer, false)
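
The underlying C++ rule is that overload resolution prefers a standard conversion (int to bool) over a user-defined conversion (int to a class type through a constructor). Below is a minimal plain C++ sketch of that rule. FakeTensor, DimList and stdDev are hypothetical stand-ins, and I am assuming that the dimension parameter of the real function is a list type that can be implicitly constructed from a single integer, which would match the behaviour described above:

```cpp
#include <cstdint>
#include <string>
#include <vector>

struct FakeTensor {};  // hypothetical stand-in for a tensor

// Hypothetical dimension list, implicitly constructible from one integer.
struct DimList {
    std::vector<std::int64_t> dims;
    DimList(std::int64_t single) : dims{single} {}
};

// Overload 1: the second argument is a flag.
std::string stdDev(const FakeTensor&, bool /*unbiased*/) {
    return "bool overload";
}

// Overload 2: the second argument is a dimension list.
std::string stdDev(const FakeTensor&, DimList /*dim*/, bool /*unbiased*/ = true) {
    return "dim overload";
}
```

With these declarations, stdDev(t, 0) selects the bool overload because int to bool is a standard conversion, while stdDev(t, 0, false) and stdDev(t, DimList(0)) both select the dim overload.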