Net forward succeeds, but backward fails. Does anybody know why?

I built a net and used cross-entropy loss. The forward pass succeeds, but the backward pass fails! I don't know why it doesn't work.

RuntimeError    Traceback (most recent call last)
<ipython-input-3-c3211f22ae0b> in <module>()
132             print "loss: {}, train_acc: {}".format(loss.data[0], train_acc)
133 
--> 134         loss.backward()
135         opt.step()
136 

/root//lib/python2.7/site-packages/torch/autograd/variable.pyc in backward(self, gradient, retain_variables)
    144                     'or with gradient w.r.t. the variable')
    145             gradient = self.data.new().resize_as_(self.data).fill_(1)
--> 146         self._execution_engine.run_backward((self,), (gradient,), retain_variables)
    147 
    148     def register_hook(self, hook):

/root//lib/python2.7/site-packages/torch/autograd/_functions/tensor.pyc in backward(self, grad_output)
    307     def backward(self, grad_output):
    308         return tuple(grad_output.narrow(self.dim, end - size, size) for size, end
--> 309                      in zip(self.input_sizes, _accumulate(self.input_sizes)))
    310 
    311 

/root//lib/python2.7/site-packages/torch/autograd/_functions/tensor.pyc in <genexpr>((size, end))
    306 
    307     def backward(self, grad_output):
--> 308         return tuple(grad_output.narrow(self.dim, end - size, size) for size, end
    309                      in zip(self.input_sizes, _accumulate(self.input_sizes)))
    310 

RuntimeError: out of range at /data/users/soumith/miniconda2/conda-bld/pytorch-cuda80-0.1.10_1488756735684/work/torch/lib/TH/generic/THTensor.c:367
opt = O.Adagrad(filter(lambda p: p.requires_grad, model.parameters()), lr=config["lr"])

I removed one parameter in the code because I want to use pretrained vectors (GloVe). Could that be the reason?
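For context, the filter(...) in the optimizer line above is the usual way to keep a frozen pretrained embedding out of the optimizer. A minimal sketch of that pattern (hypothetical sizes, not my real model):

import torch
import torch.nn as nn
import torch.optim as O

embed = nn.Embedding(10000, 300)
# embed.weight.data would be filled with the pretrained GloVe matrix here
embed.weight.requires_grad = False   # freeze the pretrained vectors
proj = nn.Linear(300, 200)           # this layer should still be trained

# only parameters that still require gradients go to the optimizer
params = [p for m in (embed, proj) for p in m.parameters() if p.requires_grad]
opt = O.Adagrad(params, lr=0.01)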

I tested it; it's not about the parameters.

Hi,

Unfortunately, without more information about what you are running, it's hard to help you.
The error indicates that one of the gradients passed back to your Concat operation does not have the expected size.
You should make sure that you never change the content of a Variable by accessing its .data.
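To illustrate, here is a toy sketch (hypothetical code, not your model) of the kind of .data mutation that can make the backward of a Concat see sizes it does not expect:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4, 3), requires_grad=True)
y = torch.cat((x, x), 1)   # Concat records the input sizes (3 and 3) for backward
y.data.resize_(4, 2)       # mutating .data behind autograd's back
loss = y.sum()
loss.backward()            # the recorded sizes no longer match, so backward can fail here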

Thank you for your time. I do use the concat operation, but I never change the content of a Variable.

Can you share a small example that would allow us to reproduce this problem?
Or share the code corresponding to your concat operation and what you do with the concat output.

OK, I will write a small example; the original code is too long. Wait a minute.

This is my example. I finally reproduced it!

import torch
import torch.nn as nn
import torch.optim as O
from torch.autograd import Variable
import numpy as np

class Fnn2D3(nn.Module):
    def __init__(self,input_dim, hidden_dim, dp_ratio):
        super(Fnn2D3,self).__init__()
        self.out = nn.Sequential(
            nn.Dropout(dp_ratio),
            nn.Linear(input_dim, hidden_dim,bias=False),
            nn.ReLU(),
            nn.Dropout(dp_ratio),
            nn.Linear(hidden_dim, hidden_dim,bias=False),
            nn.ReLU())
     
    def forward(self, inputs):
        a,b,c = inputs.size()
        inputs = inputs.view(-1,c)
        outputs = self.out(inputs)
        outputs = outputs.view(a,b,-1)
        return outputs
    
    
class Mlp2(nn.Module):
    def __init__(self,input_dim, hidden_dim, output_dim,dp_ratio):
        super(Mlp2,self).__init__()
        self.out = nn.Sequential(
            nn.Dropout(dp_ratio),
            nn.Linear(input_dim,hidden_dim,bias=False),
            nn.ReLU(),
            nn.Dropout(dp_ratio),
            nn.Linear(hidden_dim,output_dim,bias=False)
        )
        
    def forward(self, inputs):
        return self.out(inputs)


class Example(nn.Module):
    def __init__(self):
        super(Example,self).__init__()
        self.cmp1 = Fnn2D3(600,200,0.2)
        self.cmp2 = Fnn2D3(600,200,0.2)
        self.mlp = Mlp2(400,200,3,0.2)
    
    def forward(self,a,b,c,d):
        a = self.cmp1(torch.cat((a, c), 2))
        b = self.cmp2(torch.cat((b, d), 2))
        
        a = torch.mean(a,1)
        b = torch.mean(b,1)
        print a.size()
        print b.size()
        # hypo_mpool:(batch,1,cmp_dim),prem_mpool..
        out = self.mlp(torch.cat((a,b), -1).view(5,-1))
        
        return out
    
a = Variable(torch.randn(5,30,300))
b = Variable(torch.randn(5,23,300))
c = Variable(torch.randn(5,30,300))
d = Variable(torch.randn(5,23,300))
e = Variable(torch.from_numpy(np.array([1,0,1,2,1])))

model = Example()
# the optimizer must be created after the model; 0.01 is a placeholder learning rate
opt = O.Adagrad(model.parameters(), lr=0.01)

output = model(a,b,c,d)
print output.size()
criterion = nn.CrossEntropyLoss()

loss = criterion(output,e)
print loss
loss.backward()
opt.step()

It looks like there is a problem in the backward of the cat operation when the given dimension is negative.
You can replace torch.cat((a,b), -1) by torch.cat((a,b), a.dim()-1) as a temporary fix.
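Applied to the Example.forward in your repro, only the last cat call needs to change, roughly like this:

    def forward(self, a, b, c, d):
        a = self.cmp1(torch.cat((a, c), 2))
        b = self.cmp2(torch.cat((b, d), 2))
        a = torch.mean(a, 1)
        b = torch.mean(b, 1)
        # use an explicit non-negative dimension instead of -1
        out = self.mlp(torch.cat((a, b), a.dim() - 1).view(5, -1))
        return out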

I will look into what is causing this bug when I have more time (cc @apaszke )


Thank you so much! It works!