Autograd with 2 branches

I have an architecture like this:
ChainEncoder encodes an input into a vector; Predictor takes two encoded vectors (both produced by the same encoder) and generates a softmax output, which is then used to compute the loss and gradients.

input_1 ->           -> encoded_vec_1 ->
            encoder                      predictor(vec_1, vec_2) -> loss(output, y)
input_2 ->           -> encoded_vec_2 ->

Training code:

optimizer = optim.Adam(list(encoder.parameters())+list(predictor.parameters()))

for train_iter in range(num_iter):
    chains_A, chains_B, y = dataset.get_train_pairs(N)
    output_A = encoder(chains_A)
    output_B = encoder(chains_B)
    logSoftmax_output = predictor(output_A, output_B)

    optimizer.zero_grad()
    loss_val = loss(logSoftmax_output, y)
    loss_val.backward()
    optimizer.step()

And I get an error from PyTorch like this:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1000, 20]], which is output 0 of ReluBackward0, is at version 6; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The error happens when backprop starts on input_2, so I suspect there might be some conflict there…

Thanks for your help!

Using retain_graph=True is usually wrong and is often added as a workaround for another issue.
Based on your code snippet, I would guess you are running into this issue when backward is called a second time, on a computation graph that mixes the first iteration (whose forward activations are stale by now, since the parameters were updated) with the current second iteration.
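For reference, here is a minimal sketch of that failure mode (a toy model, not your code): backward is called a second time on a retained graph after the optimizer has already updated the parameters in-place.

import torch

net = torch.nn.Sequential(
    torch.nn.Linear(2, 2), torch.nn.ReLU(), torch.nn.Linear(2, 2))
opt = torch.optim.SGD(net.parameters(), lr=0.1)

loss = net(torch.randn(1, 2)).mean()
loss.backward(retain_graph=True)  # keeps the graph and its saved tensors alive
opt.step()  # the in-place parameter update bumps the saved tensors' versions

loss.backward()  # raises the same "modified by an inplace operation" RuntimeError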
Could you explain why you’ve used this argument and whether it’s really needed?

retain_graph=True was just a workaround I tried; the error is the same without it.

FYI, the Python 2 / PyTorch 0.4.1 version of the model code is here (model.py and learn.py):

I largely followed the structure of the original code.
(I am trying to refactor it to be compatible with Python 3 and the latest PyTorch for my project.)

The code works fine using:

encoder = ChainEncoder([1], [1], 1, 1)
predictor = Predictor(1)
optimizer = torch.optim.Adam(list(encoder.parameters())+list(predictor.parameters()))
criterion = nn.NLLLoss()

for _ in range(10):
    chains_A = (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1))
    chains_B = (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1))
    y = torch.zeros(1).long()
    output_A = encoder(chains_A)
    output_B = encoder(chains_B)
    logSoftmax_output = predictor(output_A, output_B)
    
    optimizer.zero_grad()
    loss_val = criterion(logSoftmax_output, y)
    loss_val.backward()
    optimizer.step()

so please post a minimal, executable code snippet to reproduce the issue.

Hi! Thanks for your reply.
I traced the issue to the forward method of class ChainEncoder.
After googling around I found a workaround using .detach().clone(), but I am not sure whether it affects my code’s correctness (I am new to PyTorch, so I am not familiar with this).
Could you please check whether I am using .detach().clone() the right way? The workaround is in the forward method of class ChainEncoder:
combined_encs = torch.stack(combined_encs, dim=0).detach().clone()

class ChainEncoder(nn.Module):
    '''
    encodes N chains at the same time
    assumes that each of the chains are of the same length
    '''

    def __init__(self, v_feature_lengths, e_feature_lengths, out_length, pooling):
        super(ChainEncoder, self).__init__()
        self.out_length = feature_enc_length = out_length
        num_layers = 1
        self.rnn_type = 'LSTM'
        self.pooling = pooling
        self.v_feature_lengths = v_feature_lengths
        self.e_feature_lengths = e_feature_lengths

        self.v_feature_encoders = nn.ModuleList()
        self.e_feature_encoders = nn.ModuleList()
        for d_in in self.v_feature_lengths:
            self.v_feature_encoders.append(
                FeatureTransformer(d_in, feature_enc_length))
        for d_in in self.e_feature_lengths:
            self.e_feature_encoders.append(
                FeatureTransformer(d_in, feature_enc_length))

        # RNN family layer: input (seq_len, batch_size, d_in), output (seq_len, batch_size, d_out * D) where D=2 for bidirectional, D=1 otherwise
        if self.rnn_type == 'RNN':
            self.rnn = nn.RNN(input_size=feature_enc_length,
                              hidden_size=out_length, num_layers=num_layers)
        elif self.rnn_type == 'LSTM':
            self.lstm = nn.LSTM(input_size=feature_enc_length,
                                hidden_size=out_length, num_layers=num_layers)

    def forward(self, input):
        '''
        input is a list of v_features, and e_features
        v_features is a list of num_vertices tuples
        each tuple is an N x d_in Variable, in which N is the batch size, and d_in is the feature length
        e_features is structured similarly
        '''
        v_features, e_features = input
        # v_features.shape = (num_vertices, batch_size, variable feature_len)
        # e_features.shape = (num_edges, batch_size, variable feature_len)

        v_encodes = []
        for i in range(len(v_features)):  # 4 vertices 
            v_enc = None
            for j in range(len(v_features[i])):  # feature in each vertex
                curr_encoder = self.v_feature_encoders[j]
                if v_enc is None:
                    v_enc = curr_encoder(v_features[i][j])
                else:
                    v_enc += curr_encoder(v_features[i][j])
            v_enc = v_enc / len(v_features[i])  # each feature encode is of shape (batch_size, out_length)
            v_encodes.append(v_enc)

        e_encodes = []
        for i in range(len(e_features)):  # 3 edges
            e_enc = None
            for j in range(len(e_features[i])):
                curr_encoder = self.e_feature_encoders[j]
                if e_enc is None:
                    e_enc = curr_encoder(e_features[i][j])
                else:
                    e_enc += curr_encoder(e_features[i][j])
            e_enc = e_enc / len(e_features[i])
            e_encodes.append(e_enc)

        combined_encs = [0] * (len(v_encodes)+len(e_encodes))
        # interleave vertices and edges
        combined_encs[::2] = v_encodes
        combined_encs[1::2] = e_encodes
        combined_encs = torch.stack(combined_encs, dim=0).detach().clone()
        # combined_encs has shape (#V+#E) x batch_size x out_length

        if self.rnn_type == 'RNN':
            output, hidden = self.rnn(combined_encs)
        elif self.rnn_type == 'LSTM':
            output, (hidden, cell) = self.lstm(combined_encs)
        if self.pooling == 'last':
            return output[-1]
        else:
            return torch.mean(output, dim=0)


class Predictor(nn.Module):
    '''
    takes two feature vectors and produces a prediction
    '''

    def __init__(self, feature_len):
        super(Predictor, self).__init__()
        self.linear = nn.Linear(feature_len, 1)
        self.logsoftmax = nn.LogSoftmax(dim=1)  # will use NLLLoss in learn.py

    def forward(self, vec1, vec2):
        # combined = torch.cat((vec1, vec2), dim=1)
        a = self.linear(vec1)
        b = self.linear(vec2)
        combined = torch.cat((a, b), dim=1)
        return self.logsoftmax(combined)

Thanks for your help!

Calling .detach().clone() will detach the tensor from the computation graph, so the previously used modules won’t be trained.
I would guess this is not what you want.
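A minimal sketch of the effect, with stand-in modules (not your actual model):

import torch

enc = torch.nn.Linear(2, 2)   # stands in for the part of the model before the detach
head = torch.nn.Linear(2, 1)  # stands in for everything after it

feat = enc(torch.randn(1, 2)).detach().clone()  # the graph is cut here
head(feat).sum().backward()

print(head.weight.grad is None)  # False: the head still receives gradients
print(enc.weight.grad is None)   # True: enc gets no gradients, so it won't train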
As previously described, I cannot reproduce the error using your code snippet, so could you post a minimal, executable code snippet (without the detach()) to reproduce the issue?

Indeed, .detach().clone() is not what I want, as I do want to train the earlier modules.
This is the configuration I used for training:

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

features = ['v_enc_dim300', 'v_freq_freq', 'v_deg', 'v_sense', 'e_vertexsim',
            'e_dir', 'e_rel', 'e_weightsource', 'e_srank_rel', 'e_trank_rel', 'e_sense']
encode_len = 20
split_frac = 0.8
N=1024  # batch size
dataset = Dataset(features, split_frac, device)

encoder = ChainEncoder(dataset.get_v_fea_len(),
                           dataset.get_e_fea_len(), encode_len, 'last')
predictor = Predictor(encode_len)
loss = nn.NLLLoss()

encoder.to(device=device)
predictor.to(device=device)
loss.to(device=device)

optimizer = optim.Adam(list(encoder.parameters()) +
                       list(predictor.parameters()))

print('Start training')
for train_iter in range(10):
    chains_A, chains_B, y = dataset.get_train_pairs(N)
    output_A = encoder(chains_A)
    output_B = encoder(chains_B)
    logSoftmax_output = predictor(output_A, output_B)

    optimizer.zero_grad()
    loss_val = loss(logSoftmax_output, y)
    loss_val.backward()
    optimizer.step()

In case it is not identical to the Python 2 / PyTorch 0.4.1 version, I also attach the refactored version of my code here (with .detach().clone() removed in model.py). Thanks for your help!

dataset.py

from __future__ import division

import pickle
import random
import numpy as np
from itertools import cycle

import torch
from torch.autograd import Variable

all_feature_lengths = {'v_enc_onehot': 100,
                       'v_enc_embedding': 300,
                       'v_enc_dim300': 300,
                       'v_enc_dim2': 2,
                       'v_enc_dim10': 10,
                       'v_enc_dim50': 50,
                       'v_enc_dim100': 100,
                       'v_freq_freq': 1,
                       'v_freq_rank': 1,
                       'v_deg': 1,
                       'v_sense': 1,
                       'e_vertexsim': 1,
                       'e_dir': 3,
                       'e_rel': 46,
                       'e_weight': 1,
                       'e_source': 6,
                       'e_weightsource': 6,
                       'e_srank_abs': 1,
                       'e_srank_rel': 1,
                       'e_trank_abs': 1,
                       'e_trank_rel': 1,
                       'e_sense': 1}


class Dataset:
    def __init__(self, feature_names, train_test_split_fraction, device):
        self.feature_names = feature_names
        self.cached_features = dict()
        self.device = device

        for feature in feature_names:
            print('loading ' + feature)
            self.cached_features[feature] = pickle.load(
                open(f'features/{feature}.pkl', 'rb'), encoding='latin1')

        sampled_problems = pickle.load(open(
            '../../data/science/paths.pkl', 'rb'), encoding='latin1')

        self.texts = dict()
        print('loading problem plain texts')
        for id_num in sampled_problems:
            f_short = sampled_problems[id_num]['forward']['short']
            r_short = sampled_problems[id_num]['reverse']['short']
            self.texts[id_num+'f'] = f_short
            self.texts[id_num+'r'] = r_short

        print('loading labeled pairs')
        self.all_pairs = []  # list of id tuples (good, bad)
        for line in open('../../data/science/answers.txt'):
            first, second, good = line.strip().split('_')
            if first == good:
                bad = second
            elif second == good:
                bad = first
            g_len = (len(self.texts[good].strip().split(' '))+1)/2
            b_len = (len(self.texts[bad].strip().split(' '))+1)/2
            if g_len != 4 or b_len != 4:
                continue
            self.all_pairs.append((good, bad))
        random.shuffle(self.all_pairs)

        split = int(train_test_split_fraction * len(self.all_pairs))
        self.train_pairs = self.all_pairs[:split]
        self.test_pairs = self.all_pairs[split:]

        self.train_pairs = self.train_pairs[:len(self.train_pairs)]
        self.cycled_train_pairs = cycle(self.train_pairs)

    def get_fea_len(self):
        return [all_feature_lengths[f] for f in self.feature_names]

    def get_v_fea_len(self):
        return [all_feature_lengths[f] for f in self.feature_names if f.startswith('v')]

    def get_e_fea_len(self):
        return [all_feature_lengths[f] for f in self.feature_names if f.startswith('e')]

    def get_chain_len(self, id):
        return len(self.get_features(id)[0])

    def get_features(self, id):
        v_features = []
        e_features = []
        for f in self.feature_names:
            if f.startswith('v'):
                v_features.append(self.cached_features[f][id])
            else:
                e_features.append(self.cached_features[f][id])
        v_features = zip(*v_features)
        e_features = zip(*e_features)
        return list(v_features), list(e_features)

    def prepare_feature_placeholder(self, N):
        # 4 vertices and 3 edges for each path
        v_features = [[], [], [], []]
        e_features = [[], [], []]

        for feature in v_features:
            for f in self.feature_names:
                if f.startswith('v'):
                    feature.append(
                        np.zeros((N, all_feature_lengths[f]), dtype='float32')
                    )

        for feature in e_features:
            for f in self.feature_names:
                if f.startswith('e'):
                    feature.append(
                        np.zeros((N, all_feature_lengths[f]), dtype='float32')
                    )

        return v_features, e_features

    def get_train_pairs(self, N, randomize_dir=True):
        '''
        return a list of two lists, X_A and X_B, as well as a list y
        each list consists of two lists, which are vertex and edge representations
        each list consists of #V or #E lists, which are individual vertices/edges
        each list consists of several N x feature_len torch Variables, which are individual features
        currently only keeping chains of length 4
        if for i-th problem, the good chain is in X_A, then y[i]==1, else y[i]==0
        '''
        v_features_A, e_features_A = self.prepare_feature_placeholder(N)
        v_features_B, e_features_B = self.prepare_feature_placeholder(N)
        y = np.zeros(N, dtype='int64')

        for instance_idx in range(N):
            good, bad = next(self.cycled_train_pairs)
            if randomize_dir:
                good = good[:-1]+random.choice(['f', 'r'])
                bad = bad[:-1]+random.choice(['f', 'r'])
            v_good, e_good = self.get_features(good)
            v_bad, e_bad = self.get_features(bad)

            label = random.random() > 0.5
            y[instance_idx] = label
            for v_idx in range(4):
                for v_fea_idx in range(len(v_good[v_idx])):
                    if label:
                        v_features_A[v_idx][v_fea_idx][instance_idx] = v_good[v_idx][v_fea_idx]
                        v_features_B[v_idx][v_fea_idx][instance_idx] = v_bad[v_idx][v_fea_idx]
                    else:
                        v_features_B[v_idx][v_fea_idx][instance_idx] = v_good[v_idx][v_fea_idx]
                        v_features_A[v_idx][v_fea_idx][instance_idx] = v_bad[v_idx][v_fea_idx]

            for e_idx in range(3):
                for e_fea_idx in range(len(e_good[e_idx])):
                    if label:
                        e_features_A[e_idx][e_fea_idx][instance_idx] = e_good[e_idx][e_fea_idx]
                        e_features_B[e_idx][e_fea_idx][instance_idx] = e_bad[e_idx][e_fea_idx]
                    else:
                        e_features_B[e_idx][e_fea_idx][instance_idx] = e_good[e_idx][e_fea_idx]
                        e_features_A[e_idx][e_fea_idx][instance_idx] = e_bad[e_idx][e_fea_idx]

        for features in [v_features_A, e_features_A, v_features_B, e_features_B]:
            for feature in features:
                for i in range(len(feature)):
                    feature[i] = Variable(torch.from_numpy(feature[i]))
                    feature[i] = feature[i].to(device=self.device)

        y = Variable(torch.from_numpy(y))
        y = y.to(device=self.device)

        return ((v_features_A, e_features_A), (v_features_B, e_features_B), y)

    def get_test_pairs(self, randomize_dir=True, return_id=False):
        '''
        return a list of two lists, X_A and X_B, as well as a list y
        each list consists of two lists, which are vertex and edge representations
        each list consists of #V or #E lists, which are individual vertices/edges
        each list consists of several N x feature_len torch Variables, which are individual features
        currently only keeping chains of length 4
        if for i-th problem, the good chain is in X_A, then y[i]==1, else y[i]==0
        '''
        N = len(self.test_pairs)
        v_features_A, e_features_A = self.prepare_feature_placeholder(N)
        v_features_B, e_features_B = self.prepare_feature_placeholder(N)
        y = np.zeros(N, dtype='int64')
        if return_id:
            ids = [[], []]

        for instance_idx in range(N):
            good, bad = self.test_pairs[instance_idx]
            if randomize_dir:
                good = good[:-1]+random.choice(['f', 'r'])
                bad = bad[:-1]+random.choice(['f', 'r'])
            v_good, e_good = self.get_features(good)
            v_bad, e_bad = self.get_features(bad)

            label = random.random() > 0.5
            y[instance_idx] = label
            if return_id:
                if label:
                    ids[0].append(good)
                    ids[1].append(bad)
                else:
                    ids[0].append(bad)
                    ids[1].append(good)
            for v_idx in range(4):
                for v_fea_idx in range(len(v_good[v_idx])):
                    if label:
                        v_features_A[v_idx][v_fea_idx][instance_idx] = v_good[v_idx][v_fea_idx]
                        v_features_B[v_idx][v_fea_idx][instance_idx] = v_bad[v_idx][v_fea_idx]
                    else:
                        v_features_B[v_idx][v_fea_idx][instance_idx] = v_good[v_idx][v_fea_idx]
                        v_features_A[v_idx][v_fea_idx][instance_idx] = v_bad[v_idx][v_fea_idx]

            for e_idx in range(3):
                for e_fea_idx in range(len(e_good[e_idx])):
                    if label:
                        e_features_A[e_idx][e_fea_idx][instance_idx] = e_good[e_idx][e_fea_idx]
                        e_features_B[e_idx][e_fea_idx][instance_idx] = e_bad[e_idx][e_fea_idx]
                    else:
                        e_features_B[e_idx][e_fea_idx][instance_idx] = e_good[e_idx][e_fea_idx]
                        e_features_A[e_idx][e_fea_idx][instance_idx] = e_bad[e_idx][e_fea_idx]

        for features in [v_features_A, e_features_A, v_features_B, e_features_B]:
            for feature in features:
                for i in range(len(feature)):
                    feature[i] = Variable(torch.from_numpy(feature[i]))
                    feature[i] = feature[i].to(device=self.device)

        y = Variable(torch.from_numpy(y))
        y = y.to(device=self.device)

        if not return_id:
            return (v_features_A, e_features_A), (v_features_B, e_features_B), y
        else:
            return (v_features_A, e_features_A), (v_features_B, e_features_B), y, ids

    def get_pairs_for_ids(self, ids):
        '''
        ids are list of (first_chain, second_chain) tuples
        return a list of two lists, X_A and X_B
        each list consists of two lists, which are vertex and edge representations
        each list consists of #V or #E lists, which are individual vertices/edges
        each list consists of several N x feature_len torch Variables, which are individual features
        currently only keeping chains of length 4
        '''
        N = len(ids)
        v_features_A, e_features_A = self.prepare_feature_placeholder(N)
        v_features_B, e_features_B = self.prepare_feature_placeholder(N)

        for instance_idx, (first, second) in enumerate(ids):
            v_first, e_first = self.get_features(first)
            v_second, e_second = self.get_features(second)

            for v_idx in range(4):
                for v_fea_idx in range(len(v_first[v_idx])):
                    v_features_A[v_idx][v_fea_idx][instance_idx] = v_first[v_idx][v_fea_idx]
                    v_features_B[v_idx][v_fea_idx][instance_idx] = v_second[v_idx][v_fea_idx]

            for e_idx in range(3):
                for e_fea_idx in range(len(e_first[e_idx])):
                    e_features_A[e_idx][e_fea_idx][instance_idx] = e_first[e_idx][e_fea_idx]
                    e_features_B[e_idx][e_fea_idx][instance_idx] = e_second[e_idx][e_fea_idx]

        for features in [v_features_A, e_features_A, v_features_B, e_features_B]:
            for feature in features:
                for i in range(len(feature)):
                    feature[i] = Variable(torch.from_numpy(feature[i]))
                    feature[i] = feature[i].to(device=self.device)
        return ((v_features_A, e_features_A), (v_features_B, e_features_B))


# if __name__ == '__main__':
#     feature_names = ['v_freq_freq', 'v_sense', 'e_source', 'e_dir', 'e_sense']
#     split_frac = 0.9
#     if_gpu = True
#     d = Dataset(feature_names, split_frac, if_gpu)

#     batch_size = 1000
#     good, bad, y = d.get_train_pairs(batch_size)
#     v_good, e_good = good
#     v_bad, e_bad = bad

#     print(len(v_good))
#     for feature in v_good[0]:  # features in first vertex
#         print(feature.shape)

#     print(len(e_good))
#     print(len(v_bad))
#     print(len(e_bad))

model.py

from __future__ import division

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.functional import relu

class FeatureTransformer(nn.Module):
    '''
    take an n x d_in matrix and transform it into a n x d_out matrix
    where the n x d_in matrix is the n examples each with d_in dimensions
    '''

    def __init__(self, d_in, d_out):
        super(FeatureTransformer, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        # takes in (*, d_in) tensors and outputs (*, d_out) tensors
        self.linear = nn.Linear(d_in, d_out)
        # self.relu = nn.ReLU()

    def forward(self, input):
        return relu(self.linear(input))


class ChainEncoder(nn.Module):
    '''
    encodes N chains at the same time
    assumes that each of the chains are of the same length
    '''

    def __init__(self, v_feature_lengths, e_feature_lengths, out_length, pooling):
        super(ChainEncoder, self).__init__()
        self.out_length = feature_enc_length = out_length
        num_layers = 1
        self.rnn_type = 'LSTM'
        self.pooling = pooling
        self.v_feature_lengths = v_feature_lengths
        self.e_feature_lengths = e_feature_lengths

        self.v_feature_encoders = nn.ModuleList()
        self.e_feature_encoders = nn.ModuleList()
        for d_in in self.v_feature_lengths:
            self.v_feature_encoders.append(
                FeatureTransformer(d_in, feature_enc_length))
        for d_in in self.e_feature_lengths:
            self.e_feature_encoders.append(
                FeatureTransformer(d_in, feature_enc_length))

        # RNN family layer: input (seq_len, batch_size, d_in), output (seq_len, batch_size, d_out * D) where D=2 for bidirectional, D=1 otherwise
        if self.rnn_type == 'RNN':
            self.rnn = nn.RNN(input_size=feature_enc_length,
                              hidden_size=out_length, num_layers=num_layers)
        elif self.rnn_type == 'LSTM':
            self.lstm = nn.LSTM(input_size=feature_enc_length,
                                hidden_size=out_length, num_layers=num_layers)

    def forward(self, input):
        '''
        input is a list of v_features, and e_features
        v_features is a list of num_vertices tuples
        each tuple is an N x d_in Variable, in which N is the batch size, and d_in is the feature length
        e_features is structured similarly
        '''
        v_features, e_features = input
        # v_features.shape == (num_vertices, batch_size, variable feature_len)
        # e_features.shape == (num_edges, batch_size, variable feature_len)

        v_encodes = []
        for i in range(len(v_features)):  # 4 vertices 
            v_enc = None
            for j in range(len(v_features[i])):  # feature in each vertex
                curr_encoder = self.v_feature_encoders[j]
                if v_enc is None:
                    v_enc = curr_encoder(v_features[i][j])
                else:
                    v_enc += curr_encoder(v_features[i][j])
            v_enc = v_enc / len(v_features[i])  # each feature encode is of shape (batch_size, out_length)
            v_encodes.append(v_enc)

        e_encodes = []
        for i in range(len(e_features)):  # 3 edges
            e_enc = None
            for j in range(len(e_features[i])):
                curr_encoder = self.e_feature_encoders[j]
                if e_enc is None:
                    e_enc = curr_encoder(e_features[i][j])
                else:
                    e_enc += curr_encoder(e_features[i][j])
            e_enc = e_enc / len(e_features[i])
            e_encodes.append(e_enc)

        combined_encs = [0] * (len(v_encodes)+len(e_encodes))
        # interleave vertices and edges
        combined_encs[::2] = v_encodes
        combined_encs[1::2] = e_encodes
        # combined_encs = torch.stack(combined_encs, dim=0).detach().clone()
        combined_encs = torch.stack(combined_encs, dim=0)
        # combined_encs has shape (#V+#E) x batch_size x out_length

        if self.rnn_type == 'RNN':
            output, hidden = self.rnn(combined_encs)
        elif self.rnn_type == 'LSTM':
            output, (hidden, cell) = self.lstm(combined_encs)
        if self.pooling == 'last':
            return output[-1]
        else:
            return torch.mean(output, dim=0)


class Predictor(nn.Module):
    '''
    takes two feature vectors and produces a prediction
    '''

    def __init__(self, feature_len):
        super(Predictor, self).__init__()
        self.linear = nn.Linear(feature_len, 1)
        self.logsoftmax = nn.LogSoftmax(dim=1)  # will use NLLLoss in learn.py

    def forward(self, vec1, vec2):
        # combined = torch.cat((vec1, vec2), dim=1)
        a = self.linear(vec1)
        b = self.linear(vec2)
        combined = torch.cat((a, b), dim=1)
        return self.logsoftmax(combined)

class JointModel(nn.Module):
    '''
    Combine ChainEncoder and Predictor together
    '''

    def __init__(self,  v_feature_lengths, e_feature_lengths, out_length, pooling):
        super(JointModel, self).__init__()
        self.encoder = ChainEncoder(v_feature_lengths, e_feature_lengths, out_length, pooling)
        self.predictor = Predictor(out_length)

    def forward(self, input1, input2):
        vec1 = self.encoder(input1)
        vec2 = self.encoder(input2)
        return self.predictor(vec1, vec2)

learn.py

from __future__ import division

import numpy as np
import torch
import sys
import os
from torch import nn, optim
from torch.autograd import Variable
from model import ChainEncoder, Predictor, JointModel
from dataset import Dataset
from multiprocessing import Pool
import time
from datetime import datetime


def train(dataset, fea_len, num_iter=4000, N=1000, out_file='train.log'):
    if isinstance(out_file, str):
        out_file = open(out_file, 'w')

    print('defining architecture')
    encoder = ChainEncoder(dataset.get_v_fea_len(),
                           dataset.get_e_fea_len(), fea_len, 'last')
    predictor = Predictor(fea_len)
    # model = JointModel(dataset.get_v_fea_len(), dataset.get_e_fea_len(), fea_len, 'last')
    loss = nn.NLLLoss()

    encoder.to(device=device)
    predictor.to(device=device)
    # model.to(device=device)
    loss.to(device=device)

    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(predictor.parameters()))
    # optimizer = optim.Adam(model.parameters())

    print('Start training')
    start = time.time()
    for train_iter in range(num_iter):
        chains_A, chains_B, y = dataset.get_train_pairs(N)
        output_A = encoder(chains_A)
        output_B = encoder(chains_B)
        logSoftmax_output = predictor(output_A, output_B)
        optimizer.zero_grad()
        # logSoftmax_output = model(chains_A, chains_B)
        loss_val = loss(logSoftmax_output, y)
        loss_val.backward()
        optimizer.step()

        if train_iter % 100 == 0:
            print(
                f"Progress: {100*train_iter/num_iter:.2f}%, loss: {loss_val.item()}, time spent: {(time.time() - start)/60:.2f} minutes")

            out_file.write(f"{num_iter}, loss: {loss_val.item()}\n")
            torch.save(encoder.state_dict(),
                       f'ckpt/{train_iter}_encoder.model')
            torch.save(predictor.state_dict(),
                       f'ckpt/{train_iter}_predictor.model')

    print(f'Finish training, time spent: {(time.time()-start)/60:.2f} minutes')
    out_file.close()
    return encoder, predictor, loss


def test(dataset, encoder, predictor, loss, out_file='test.log'):
    if isinstance(out_file, str):
        out_file = open(out_file, 'a')

    print("Start testing")
    chains_A, chains_B, y = dataset.get_test_pairs(randomize_dir=True, return_id=False)

    with torch.no_grad():
        output_test_A = encoder(chains_A)
        output_test_B = encoder(chains_B)
        logSoftmax_output = predictor(
            output_test_A, output_test_B).to(device='cpu').numpy()

        pred = logSoftmax_output.argmax(axis=1)
        y = y.to(device='cpu').numpy()

        cur_acc = (pred == y).sum() / len(y)

        print(f'test acc: {cur_acc}')
        out_file.write("\nTest time: {}\n".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S")))
        out_file.write(f'{cur_acc}\n')

    out_file.close()


#torch.autograd.set_detect_anomaly(True)
use_gpu = True
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

features = ['v_enc_dim300', 'v_freq_freq', 'v_deg', 'v_sense', 'e_vertexsim',
            'e_dir', 'e_rel', 'e_weightsource', 'e_srank_rel', 'e_trank_rel', 'e_sense']
feature_len = 20
split_frac = 0.8
dataset = Dataset(features, split_frac, device)

num_iter = 6000
N = 1000
print(f'Batch size: {N}')

encoder, predictor, loss = train(dataset, feature_len, num_iter, N)
test(dataset, encoder, predictor, loss)

Your code is unfortunately not executable, as it depends on a locally stored dataset, and I don’t easily see the difference between the new code and the previous snippet, which worked fine in my setup.
Maybe try updating PyTorch to the latest release (or the nightly) and see if that helps?

Hi, the problem is finally resolved, without using detach().clone().
I changed the slice assignments combined_encs[::2] = v_encodes and combined_encs[1::2] = e_encodes to appending each tensor in v_encodes and e_encodes to combined_encs one by one, and then the backprop magically worked well…
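Roughly, the changed part of the forward method now looks like this (a sketch from memory, assuming the 4-vertex / 3-edge chains from my data, so the interleaved sequence ends with a vertex):

combined_encs = []
for v_enc, e_enc in zip(v_encodes, e_encodes):  # interleave vertex and edge encodings
    combined_encs.append(v_enc)
    combined_encs.append(e_enc)
combined_encs.append(v_encodes[-1])  # 4 vertices vs. 3 edges: the last vertex closes the chain
combined_encs = torch.stack(combined_encs, dim=0)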
Thanks a lot for your help!