Mysterious "Trying to backward through the graph a second time" issue


I am trying to run a simple enc/dec architecture, consisting of

  1. a frozen encoder
  2. a simple feed-forward on the outputs of (1)
  3. a trainable decoder on the outputs of (2)
  4. a simple feed-forward on the outputs of (3)

The forward/backward pipeline works fine on the first batch, yet breaks on the next backward() call.
My train_batch() function returns nothing, and no intermediate results are stored.
Help would be much appreciated.

    def __init__(self, num_embeddings: int, dim: int = 768, device: str = 'cpu'):
        """Build the encoder, its projection, the decoder and the shared embedding table.

        Args:
            num_embeddings: Number of output symbols; the embedding table is created with
                ``num_embeddings + 1`` rows.
            dim: Width passed to the projection FFN (``d_model``), to the decoder
                (``d_model``, ``d_k``, ``d_v``, ``d_intermediate``) and used as the
                second dimension of the embedding table.
            device: Device the projection, the decoder and the embedding table are
                created on / moved to.
        """
        super(Parser, self).__init__()
        self.dim = dim
        self.num_embeddings = num_embeddings
        self.device = device
        # Pretrained encoder; encode() only ever calls it under torch.no_grad().
        # NOTE(review): unlike the modules below, the encoder is not moved to `device` here.
        self.encoder = RobertaModel.from_pretrained('pdelobelle/robBERT-base')
        self.encoder_projection = FFN(d_model=dim, d_ff=2 * dim).to(device)
        self.decoder = make_decoder(num_layers=2, num_heads=6, d_model=self.dim, d_k=self.dim, d_v=self.dim,
                                    d_intermediate=self.dim, dropout=0.1).to(device)
        # NOTE(review): the trailing `* 0.02` is applied to the Parameter after it has been
        # constructed, so this attribute holds the *result of a multiplication* (a plain,
        # non-leaf tensor) instead of the Parameter itself. That tensor is not registered as a
        # module parameter, and the multiplication node is created once here and then shared by
        # every later batch -- the likely cause of the "backward through the graph a second
        # time" error on the second backward() call. The random init is also scaled by 0.02
        # twice. Dropping the trailing multiplication would leave a real Parameter.
        self.embedding_matrix = Parameter(torch.rand(num_embeddings + 1, dim, device=device) * 0.02,
                                          requires_grad=True) * 0.02
        # Looks symbol ids up in the shared table, with index 0 as the padding index.
        # NOTE(review): this line is cut off in the excerpt; the remaining arguments are not shown.
        self.atom_embedder = lambda x: functional.embedding(x, self.embedding_matrix, padding_idx=0,
        # Output scores reuse the same table: x @ (embedding_matrix^T + 1e-10).
        self.predictor = lambda x: x @ (self.embedding_matrix.transpose(1, 0) + 1e-10)
        self.pos_encoder = PositionalEncoder(0.1)

    def forward(self, lexical_token_ids: LongTensor, symbol_ids: LongTensor, pos_idxes: List[List[LongTensor]],
                neg_idxes: List[List[LongTensor]]) -> Any:
        # NOTE(review): the body of forward() is not included in this excerpt.

    def encode(self, lexical_token_ids: LongTensor, encoder_mask: LongTensor) -> Tensor:
        """Run the encoder without gradient tracking, then apply the projection.

        Args:
            lexical_token_ids: Token ids passed to ``self.encoder``.
            encoder_mask: Mask passed to ``self.encoder`` as ``attention_mask``.

        Returns:
            ``self.encoder_projection`` applied to the encoder output.
        """
        # The encoder call sits inside no_grad, so no gradients reach the encoder;
        # the projection on the return line is outside the block and is tracked.
        with torch.no_grad():
            encoder_output, _ = self.encoder(lexical_token_ids, attention_mask=encoder_mask)
            # NOTE(review): the right-hand side of this assignment is missing from the excerpt.
            encoder_output =
        return self.encoder_projection(encoder_output)

    def make_output_repr(self, lexical_token_ids: LongTensor, symbol_ids: LongTensor) -> Tensor:
        """Encode the input tokens and run the decoder over the symbol sequence.

        Args:
            lexical_token_ids: 2-D tensor of input token ids, unpacked as ``(b, s_in)``.
            symbol_ids: Output symbol ids; its dimension 1 is used as the output length.

        Returns:
            Element ``[2]`` of the decoder's result -- presumably the decoded symbol
            representations; TODO confirm against ``make_decoder``.
        """
        b, s_in = lexical_token_ids.shape
        s_out = symbol_ids.shape[1]

        encoder_mask = self.make_encoder_mask(lexical_token_ids)

        encoder_output = self.encode(lexical_token_ids, encoder_mask)

        # Same length as s_out above; the shape is simply read a second time.
        decoder_mask = self.make_decoder_mask(b=b, n=symbol_ids.shape[1])
        # NOTE(review): the arguments of this call are cut off in the excerpt.
        atom_embeddings = self.atom_embedder(
        pos_encodings = self.pos_encoder(b, s_out, self.dim, 1024).to(self.device)
        atom_embeddings = atom_embeddings + pos_encodings

        # Expand the (b, s_in) encoder mask to (b, s_out, s_in): one copy per output position.
        extended_encoder_mask = encoder_mask.view(b, 1, s_in).repeat(1, s_out, 1).to(self.device)

        # The decoder takes a single 4-tuple; only index 2 of what it returns is kept.
        return self.decoder((encoder_output, extended_encoder_mask, atom_embeddings, decoder_mask))[2]

    def predict(self, symbol_reprs: Tensor) -> Tensor:
        """Score symbol representations by delegating to ``self.predictor``.

        Args:
            symbol_reprs: Representations to score.

        Returns:
            Whatever ``self.predictor`` returns for ``symbol_reprs``.
        """
        scores = self.predictor(symbol_reprs)
        return scores

    @staticmethod
    def make_mask(inps: LongTensor, padding_id: int) -> LongTensor:
        """Return a mask with 0 at padding positions and 1 everywhere else.

        Declared as a static method because it takes no ``self`` but is invoked as
        ``self.make_mask(...)``; without the decorator the instance would be bound
        to ``inps`` and the call would fail.

        Args:
            inps: Tensor of ids.
            padding_id: Id whose positions are masked out (set to 0).

        Returns:
            A new tensor with the same shape, dtype and device as ``inps``;
            ``inps`` itself is not modified.
        """
        # Comparison yields a bool tensor; cast back so the dtype matches the input.
        return (inps != padding_id).to(inps.dtype)

    def make_encoder_mask(self, lexical_ids: LongTensor) -> LongTensor:
        """Build the encoder mask for a batch of lexical ids.

        Args:
            lexical_ids: Input token ids.

        Returns:
            The result of ``self.make_mask`` with padding id 1.
        """
        padding_id = 1  # id treated as padding in the lexical input
        encoder_mask = self.make_mask(lexical_ids, padding_id)
        return encoder_mask

    def make_decoder_mask(self, b: int, n: int) -> LongTensor:
        """Build a batch of strictly-upper-triangular ``n x n`` masks.

        Args:
            b: Batch size (number of identical masks).
            n: Sequence length (side of each square mask).

        Returns:
            A ``(b, n, n)`` tensor on ``self.device`` holding 1 where the column
            index is greater than the row index and 0 elsewhere.
        """
        all_ones = torch.ones(b, n, n)
        # diagonal=1 zeroes the main diagonal and everything below it.
        strictly_upper = torch.triu(all_ones, diagonal=1)
        return strictly_upper.to(self.device)

    def train_batch(self, samples: List[Sample], atom_map: Mapping[str, int], tokenizer: Tokenizer,
                    optimizer: Optimizer, max_difficulty: int = 20):
        """Compute the supertagging cross-entropy loss for one batch.

        As written here the method stops once the loss has been computed: it
        returns nothing, and ``optimizer`` and ``max_difficulty`` are not used.

        Args:
            samples: Batch of samples handed to ``samples_to_batch``.
            atom_map: Symbol-to-index mapping handed to ``samples_to_batch``.
            tokenizer: Tokenizer handed to ``samples_to_batch``.
            optimizer: Unused in the visible code.
            max_difficulty: Unused in the visible code.
        """
        # The two index lists are produced by the batching helper but not consumed below.
        token_ids, symbol_ids, _pos_idxes, _neg_idxes = samples_to_batch(samples, atom_map, tokenizer)

        symbol_reprs = self.make_output_repr(token_ids, symbol_ids)

        # supertagging
        logits = self.predict(symbol_reprs)

        # Predictions drop their last position and targets drop their first, so each
        # position's prediction is scored against the following symbol.
        # NOTE(review): the logits are permuted to (0, 2, 1) and only then flattened to
        # (-1, vocab_size); if the class dimension is last before the permute, this
        # interleaves positions and classes -- confirm this is intended.
        vocab_size = self.num_embeddings + 1
        shifted_logits = logits[:, :-1]
        flat_logits = shifted_logits.permute(0, 2, 1).reshape(-1, vocab_size)
        flat_targets = symbol_ids[:, 1:].to(self.device).reshape(-1)

        supertagging_loss = functional.cross_entropy(flat_logits, flat_targets)


    def train_epoch(self, dataloader: DataLoader, atom_map: Mapping[str, int], tokenizer: Tokenizer,
                    optimizer: Optimizer):
        """Run ``train_batch`` once for every batch yielded by ``dataloader``.

        Args:
            dataloader: Iterable yielding batches of samples.
            atom_map: Symbol-to-index mapping forwarded to ``train_batch``.
            tokenizer: Tokenizer forwarded to ``train_batch``.
            optimizer: Optimizer forwarded to ``train_batch``.
        """
        # The batch index was never used, so iterate the loader directly.
        for samples in dataloader:
            self.train_batch(samples, atom_map, tokenizer, optimizer)


It is a bit hard to say given that we cannot really run the code.
Most likely something that you do when creating the net is a differentiable op and gets reused every time you call train_batch().

A simple way to check this is to use torchviz.
You can print the graph of the supertagging_loss for the first batch. Then print the graph for the second batch. If the second one contains the first, then you have something that links one iteration to the next.

Otherwise, you want to save in a global variable supertagging_loss from the first batch and then at the second one do make_dot([first_batch_supertagging_loss, second_batch_supertagging_loss]) and make sure you get two disjoint graphs except for the parameters (the blue ovals) that are used in both. If you have other nodes that are shared, that means that some computations are re-used from one iteration to the next.

Just replying in case this affects anyone else.

Calling self.register_parameter() on the embedding matrix resolves the issue, which seems to point at some other underlying bug.

1 Like

If you do self.bar = nn.Parameter(some_tensor), it should be the same as self.register_parameter("bar", nn.Parameter(some_tensor)).