How to add a submodule's parameters to model.parameters()?

I’m currently trying to implement a deep neural decision forest, but I’ve run into a problem.
It seems that the submodules’ parameters haven’t been added to the model’s parameters. I wonder if it is because I keep the submodules in a plain Python list.
Here is my definition of the model:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter


class DeepNeuralDecisionForest(nn.Module):

    def __init__(self, p_keep_conv, p_keep_hidden, n_leaf, n_label, n_tree, n_depth):
        super(DeepNeuralDecisionForest, self).__init__()

        self.conv = nn.Sequential()
        self.conv.add_module('conv1', nn.Conv2d(1, 32, kernel_size=3, padding=1))
        self.conv.add_module('relu1', nn.ReLU())
        self.conv.add_module('pool1', nn.MaxPool2d(kernel_size=2))
        self.conv.add_module('drop1', nn.Dropout(1 - p_keep_conv))
        self.conv.add_module('conv2', nn.Conv2d(32, 64, kernel_size=3, padding=1))
        self.conv.add_module('relu2', nn.ReLU())
        self.conv.add_module('pool2', nn.MaxPool2d(kernel_size=2))
        self.conv.add_module('drop2', nn.Dropout(1 - p_keep_conv))
        self.conv.add_module('conv3', nn.Conv2d(64, 128, kernel_size=3, padding=1))
        self.conv.add_module('relu3', nn.ReLU())
        self.conv.add_module('pool3', nn.MaxPool2d(kernel_size=2))
        self.conv.add_module('drop3', nn.Dropout(1 - p_keep_conv))

        self._nleaf = n_leaf
        self._nlabel = n_label
        self._ntree = n_tree
        self._ndepth = n_depth
        self._batchsize = 100

        self.treelayers = []
        self.pi_e = []
        for i in xrange(self._ntree):
            treelayer = nn.Sequential()
            treelayer.add_module('sub_linear1', nn.Linear(1152, 625))
            treelayer.add_module('sub_relu', nn.ReLU())
            treelayer.add_module('sub_drop1', nn.Dropout(1 - p_keep_hidden))
            treelayer.add_module('sub_linear2', nn.Linear(625, self._nleaf))
            treelayer.add_module('sub_sigmoid', nn.Sigmoid())
            pi = Parameter(self.init_pi())
            self.treelayers.append(treelayer)
            self.pi_e.append(nn.Softmax()(pi))

    def init_pi(self):
        return torch.ones(self._nleaf, self._nlabel) / float(self._nlabel)

    def init_weights(self, shape):
        return torch.randn(shape) * 0.01

    def init_prob_weights(self, shape, minval=-5, maxval=5):
        return torch.Tensor(shape[0], shape[1]).uniform_(minval, maxval)

    def compute_mu(self, flat_decision_p_e):
        n_batch = self._batchsize
        batch_0_indices = torch.range(0, n_batch * self._nleaf - 1, self._nleaf).unsqueeze(1).repeat(1, self._nleaf).long()

        in_repeat = self._nleaf / 2
        out_repeat = n_batch

        batch_complement_indices = torch.LongTensor(
            np.array([[0] * in_repeat, [n_batch * self._nleaf] * in_repeat] * out_repeat).reshape(n_batch, self._nleaf))

        # First define the routing probabilities d for the root nodes
        mu_e = []
        indices_var = Variable((batch_0_indices + batch_complement_indices).view(-1))
        indices_var = indices_var.cuda()
        # indices_var = indices_var.type_as(flat_decision_p_e[0])
        # iterate over each tree
        for i, flat_decision_p in enumerate(flat_decision_p_e):
            mu = torch.gather(flat_decision_p, 0, indices_var).view(n_batch, self._nleaf)
            mu_e.append(mu)

        # from the second layer to the last layer, we build the decision nodes
        for d in xrange(1, self._ndepth + 1):
            indices = torch.range(2 ** d, 2 ** (d + 1) - 1) - 1
            tile_indices = indices.unsqueeze(1).repeat(1, 2 ** (self._ndepth - d + 1)).view(1, -1)
            batch_indices = batch_0_indices + tile_indices.repeat(n_batch, 1).long()

            in_repeat = in_repeat / 2
            out_repeat = out_repeat * 2

            # Again define the indices that pick d and 1 - d for the nodes
            batch_complement_indices = torch.LongTensor(
                np.array([[0] * in_repeat, [n_batch * self._nleaf] * in_repeat] * out_repeat).reshape(n_batch, self._nleaf))

            mu_e_update = []
            indices_var = Variable((batch_indices + batch_complement_indices).view(-1))
            indices_var = indices_var.cuda()
            for mu, flat_decision_p in zip(mu_e, flat_decision_p_e):
                mu = torch.mul(mu, torch.gather(flat_decision_p, 0, indices_var).view(
                    n_batch, self._nleaf))
                mu_e_update.append(mu)
            mu_e = mu_e_update
        return mu_e

    def compute_py_x(self, mu_e):
        py_x_e = []
        n_batch = self._batchsize

        for mu, leaf_p in zip(mu_e, self.pi_e):
            py_x_tree = mu.unsqueeze(2).repeat(1, 1, self._nlabel).mul(leaf_p.unsqueeze(0).repeat(n_batch, 1, 1)).mean(1)
            py_x_e.append(py_x_tree)

        py_x_e = torch.cat(py_x_e, 1)
        py_x = py_x_e.mean(1).squeeze()
        return py_x

    def forward(self, x):
        feat = self.conv.forward(x)
        feat = feat.view(-1, 1152)
        self._batchsize = x.size(0)
        # py_x = self.fc.forward(feat)
        flat_decision_p_e = []
        for i in xrange(len(self.treelayers)):
            decision_p = self.treelayers[i].forward(feat)
            decision_p_comp = 1 - decision_p
            decision_p_pack = torch.cat((decision_p, decision_p_comp), 1)
            flat_decision_p = decision_p_pack.view(-1)
            flat_decision_p_e.append(flat_decision_p)

        mu_e = self.compute_mu(flat_decision_p_e)

        py_x = self.compute_py_x(mu_e)
        return py_x
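
A quick check along these lines (the constructor arguments are just placeholder values) illustrates the problem; only the conv block's parameters are returned:

model = DeepNeuralDecisionForest(p_keep_conv=0.8, p_keep_hidden=0.5,
                                 n_leaf=16, n_label=10, n_tree=5, n_depth=4)

# the three conv layers contribute 6 parameter tensors,
# but nothing from treelayers or pi_e shows up
print(len(list(model.parameters())))  # prints 6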

You need to use nn.ModuleList.
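
For example, something along these lines (a minimal sketch, not your full model; it assumes a PyTorch version that provides nn.ParameterList, otherwise self.register_parameter works as well):

import torch
import torch.nn as nn

class Forest(nn.Module):
    def __init__(self, n_tree, n_leaf, n_label):
        super(Forest, self).__init__()
        # nn.ModuleList registers every submodule it holds,
        # so their parameters show up in .parameters()
        self.treelayers = nn.ModuleList()
        for i in range(n_tree):
            self.treelayers.append(nn.Sequential(
                nn.Linear(1152, 625),
                nn.ReLU(),
                nn.Linear(625, n_leaf),
                nn.Sigmoid()))
        # the pi tensors are Parameters, not Modules, so they need
        # nn.ParameterList (or register_parameter) instead
        self.pi_e = nn.ParameterList(
            [nn.Parameter(torch.ones(n_leaf, n_label) / n_label) for _ in range(n_tree)])

model = Forest(n_tree=5, n_leaf=16, n_label=10)
print(len(list(model.parameters())))  # 20 linear tensors + 5 pi = 25

Note also that your loop appends nn.Softmax()(pi) rather than pi itself, so pi_e ends up holding the softmax output instead of a leaf Parameter; you probably want to store the raw pi and apply the softmax inside compute_py_x.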

Thanks for the reminder.

I ran into a similar, though not identical, problem. I defined a new module as follows:

class RecursiveNN(nn.Module):
    def __init__(self, word_embedding, hidden_dim):
        super(RecursiveNN, self).__init__()
        self.word_dim = word_embedding.embeddings.size(1)
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(word_embedding.embeddings.size(0),
                                      self.word_dim)
        self.embedding.weight = nn.Parameter(word_embedding.embeddings)
        self.word2hidden = nn.Linear(self.word_dim, self.hidden_dim)
        self.hidden2hidden = nn.Linear(2 * self.hidden_dim, self.hidden_dim)

    def forward(self, node):
        if node.val is not None:
            node.calculate_result = self.word2hidden(self.embedding(Variable(torch.LongTensor([node.word_id]))))
            return node.calculate_result
        else:
            assert len(node.children) == 2
            node.calculate_result = self.hidden2hidden(torch.cat((node.children[0].calculate_result,
                                                          node.children[1].calculate_result), 1))
            return node.calculate_result

This module is used by another module, whose definition is shown below:

class RootAlign(nn.Module):
    def __init__(self, word_embedding, config):
        super(RootAlign, self).__init__()
        self.rnn = RecursiveNN(word_embedding, config['hidden_dim'])
        self.linear = nn.Linear(config['hidden_dim'] * 2, config['relation_num'])

    def forward(self, p_tree, h_tree):
        p_tree.postorder_traverse(self.rnn)
        h_tree.postorder_traverse(self.rnn)

        out = F.softmax(self.linear(F.sigmoid(torch.cat((p_tree.calculate_result, h_tree.calculate_result), 1))))
        return out

What I’d like to know is how to add the parameters of RecursiveNN to RootAlign so that their parameters can be trained together.

I would be very grateful if you could help me.

The code seems correct. Can’t you see the parameters from RecursiveNN in .parameters()?
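
For instance, a quick check like this (word_embedding and config being whatever you already construct RootAlign from) should list seven parameter tensors, everything inside RecursiveNN included:

model = RootAlign(word_embedding, config)

# submodules assigned as attributes (self.rnn, self.linear) are registered
# automatically, so their parameters are all visible here:
# rnn.embedding.weight, rnn.word2hidden.{weight,bias},
# rnn.hidden2hidden.{weight,bias}, linear.{weight,bias}
print(model)
print(len(list(model.parameters())))  # 7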

Also, I’d recommend against caching node.calculate_result if you don’t need it. It will prevent PyTorch from freeing the graph that created node.calculate_result until it is overwritten or manually deleted.

Thank you very much for your reply. After checking RootAlign.parameters(), I can confirm that RootAlign does have the parameters of RecursiveNN. Yet I don’t understand why the model isn’t training well. The train function is as follows:

for _data in snli.train:
    p_tree = _data['p_tree']
    h_tree = _data['h_tree']
    target = Variable(torch.LongTensor([_data['label']]))
    optimizer.zero_grad()
    output = root_align(p_tree, h_tree)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    train_loss += loss

The whole project can be obtained at [here]. The train function can be found in “train.py”.

By the way, I don’t quite understand the point about caching node.calculate_result. Could you please show me an example?

Thank you again for your help.

There’s one problem with your training loop, but it shouldn’t affect correctness. Don’t do train_loss += loss, because you’ll be keeping graphs for each iteration around. Do train_loss += loss.data[0], so that you only accumulate the value, not the Variable that records each iteration.
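
Concretely, only the last line of your loop changes (loss.data[0] pulls out the plain number, so no graph is kept alive):

train_loss = 0.0
for _data in snli.train:
    p_tree = _data['p_tree']
    h_tree = _data['h_tree']
    target = Variable(torch.LongTensor([_data['label']]))
    optimizer.zero_grad()
    output = root_align(p_tree, h_tree)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    train_loss += loss.data[0]  # accumulate the value, not the Variable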

The project is quite large, so I’m afraid I won’t be able to go through all of it. There’s probably a bug somewhere. Maybe this example could help you somehow.
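
As for node.calculate_result: the point is that you don’t have to cache it at all. A sketch of one way to do that, keeping your __init__ as it is, is to recurse inside forward() and return the result directly, so nothing is stored on the tree nodes:

class RecursiveNN(nn.Module):
    # __init__ unchanged

    def forward(self, node):
        # leaf node: embed the word and project it into the hidden space
        if node.val is not None:
            return self.word2hidden(
                self.embedding(Variable(torch.LongTensor([node.word_id]))))
        # internal node: combine the two children recursively
        assert len(node.children) == 2
        left = self(node.children[0])
        right = self(node.children[1])
        return self.hidden2hidden(torch.cat((left, right), 1))

RootAlign.forward would then call self.rnn directly on the root node (however your trees expose it) instead of going through postorder_traverse and reading calculate_result back.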

Now I know where the problem is. In RootAlign.forward(), F.softmax should be replaced by F.log_softmax, since I use F.nll_loss to calculate the loss for each sample.
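
In other words, the forward becomes:

def forward(self, p_tree, h_tree):
    p_tree.postorder_traverse(self.rnn)
    h_tree.postorder_traverse(self.rnn)

    joint = F.sigmoid(torch.cat((p_tree.calculate_result, h_tree.calculate_result), 1))
    return F.log_softmax(self.linear(joint))  # log-probabilities, as F.nll_loss expects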

Anyway, thanks for your help.