Defining my model, I encounter a RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I am defining my model. However, I encounter "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation" while running backward().

import torch
import torch.nn as nn

class Detail(nn.Module):

    def __init__(self, N, D):
        super(Detail, self).__init__()
        self.N, self.D = N, D

    def forward(self, input):
        for i in range(self.N):
            temp = input[0][i]
            # temp = temp.numpy()
            # temp = temp.tolist()
            for m in range(self.D):
                for n in range(self.D):
                    if m < self.D - 1:
                        # overwrites temp in place
                        temp[n][m] = temp[n][m] - temp[n][m + 1]
                    if m == self.D - 1:
                        # overwrites temp in place
                        temp[n][m] = temp[n][m - 1]

        return input

class View(nn.Module):

    def __init__(self, *args):
        super(View, self).__init__()
        if len(args) == 1 and isinstance(args[0], torch.Size):
            self.size = args[0]
        else:
            self.size = torch.Size(args)

    def forward(self, input):
        return input.view(self.size)

class Net(nn.Module):

    def __init__(self, nclass, backbone='resnet18'):
        super(Net, self).__init__()
        self.backbone = backbone
        # copying modules from pretrained models
        if backbone == 'resnet18':
            self.pretrained = resnet.resnet50(pretrained=True)
        self.detail = nn.Sequential(
            Detail(512, 7),
            nn.AvgPool2d(7),
            View(-1, 512),
            nn.Linear(512, 64),
            Normalize()
        )
        self.pool = nn.Sequential(
            nn.AvgPool2d(7),
            View(-1, 512),
            nn.Linear(512, 64),
            Normalize()
        )
        self.fc = nn.Sequential(
            Normalize(),
            nn.Linear(64 * 64, 128),
            Normalize(),
            nn.Linear(128, nclass)
        )
    def forward(self, x):
        if self.backbone == 'resnet18' or self.backbone == 'resnet101' \
                or self.backbone == 'resnet152':
            # pre-trained ResNet features
            x = self.pretrained.conv1(x)
            x = self.pretrained.bn1(x)
            x = self.pretrained.relu(x)
            x = self.pretrained.maxpool(x)
            x = self.pretrained.layer1(x)
            x = self.pretrained.layer2(x)
            x = self.pretrained.layer3(x)
            x = self.pretrained.layer4(x)

            x1 = self.detail(x)
            print(x1.size())
            x2 = self.pool(x)
            print(x2.size())
            x1 = x1.unsqueeze(1).expand(x1.size(0), x2.size(1), x1.size(-1))
            print(x1.size())
            x = x1 * x2.unsqueeze(-1)
            print(x.size())
            x = x.view(-1, x1.size(-1) * x2.size(1))
            out = self.fc(x)

        return out

Hi,

Your Detail module modifies the temp tensor inplace.
Assuming that temp is a DxD matrix, your for loops can be replaced by:

part1 = temp.narrow(1, 0, D-1) - temp.narrow(1, 1, D-1)
part2 = part1.narrow(1, -1, 1)
out = torch.cat([part1, part2], 1)

Also, I am not sure this gives the output you want. Maybe you were expecting part2 = temp.narrow(1, -1, 1)?
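For reference, here is what the whole forward could look like without any in-place write. This is only a sketch assuming input has shape (batch, N, D, D) as your loops suggest; note that, unlike your loop, it is applied to the whole batch rather than to input[0] only:

    def forward(self, input):
        D = self.D
        # column-wise differences: out[..., m] = input[..., m] - input[..., m+1]
        part1 = input.narrow(-1, 0, D - 1) - input.narrow(-1, 1, D - 1)
        # duplicate the last difference column to restore width D
        part2 = part1.narrow(-1, D - 2, 1)
        # torch.cat allocates a fresh tensor, so nothing autograd saved is overwritten
        return torch.cat([part1, part2], -1)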

Thank you very much! I followed your approach and the error no longer appears, but I am very curious about the reason. At the same time, I have another error to ask about:
RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #4 'mat1'
If you don't mind sharing your contact information, I would like to keep asking you questions.

The thing is that your original code was modifying the temp tensor inplace (when doing temp[n][m] = xxx). The original value of this tensor was needed for gradient computation, so inplace modification of it is not allowed!
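A minimal example that reproduces the same error (all names here are just for illustration): sigmoid saves its own output for the backward pass, so writing into that output in place invalidates the graph:

import torch

x = torch.randn(3, 3, requires_grad=True)
y = x.sigmoid()      # autograd saves y itself to compute sigmoid's gradient
y[0, 0] = 0.0        # indexing assignment is an in-place write, like temp[n][m] = ...
y.sum().backward()   # RuntimeError: one of the variables needed for gradient
                     # computation has been modified by an inplace operation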

Your other error just means that a tensor is not on the GPU. Make sure to only perform computations between tensors of the same type and on the same device.
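For example, the usual fix is to move both the model and the input to the same device before the forward pass (a minimal sketch; Net is your module above and nclass=10 is just a placeholder):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(nclass=10).to(device)  # moves every parameter and buffer to the GPU
x = x.to(device)                   # the input batch must be moved as well
out = model(x)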

Thank you very much!
I still have a doubt about the earlier problem: in my own defined module, no gradient computation is needed, so why is it wrong to modify the temp tensor inplace?