PyTorch Transformers: TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'

I want to modify this model so it is made of xlm-roberta as encoder and gpt-neo as decoder.

At the beginning of training I get the following error:

Traceback (most recent call last):
  File "run.py", line 277, in <module>
    task()
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.7/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "run.py", line 166, in train
    train_model(start_epoch, eval_loss, (train_dl, valid_dl), optimizer, kwargs['base_path']+kwargs["checkpoint_path"], kwargs['base_path']+kwargs["best_model"])
  File "run.py", line 233, in train_model
    labels = labels.to(device))[0]
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py", line 628, in forward
    **kwargs_decoder,
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'

I think the problem is somewhere here, but I'm not sure, since the last line of the traceback references the forward() function, which isn't explicitly part of my code.

loss = model(input_ids = src_tensors.to(device),
             decoder_input_ids = tgt_tensors.to(device),
             attention_mask = src_attn_tensors.to(device),
             decoder_attention_mask = tgt_attn_tensors.to(device),
             labels = labels.to(device))[0]

How can I “access” the forward() function? (Where is it called?) I know the model works because I tried it with the original configuration (BERT & GPT-2).

This is the model setup:

model = EncoderDecoderModel.from_encoder_decoder_pretrained('xlm-roberta-base', 'EleutherAI/gpt-neo-1.3B', tie_encoder_decoder=True)
model.decoder.config.use_cache = False
tokenizer = Tokenizer(max_token_len)
model.config.decoder_start_token_id = tokenizer.autotokenizer.bos_token_id
model.config.eos_token_id = tokenizer.autotokenizer.eos_token_id
model.config.max_length = max_token_len
model.config.no_repeat_ngram_size = 3

Training:

def train(**kwargs):
    print("Loading datasets...")
    train_dataset = WikiDataset(kwargs['base_path']+kwargs['src_train'], kwargs['base_path']+kwargs['tgt_train'])
    valid_dataset = WikiDataset(kwargs['base_path']+kwargs['src_valid'], kwargs['base_path']+kwargs['tgt_valid'], kwargs['base_path']+kwargs['ref_valid'], ref=True)
    print("Dataset loaded successfully")

    train_dl = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
    valid_dl = DataLoader(valid_dataset, batch_size=TRAIN_BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.0}
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=3e-5)
    
    if os.path.exists(kwargs['base_path']+kwargs["checkpoint_path"]):
        optimizer, eval_loss, start_epoch = load_checkpt(kwargs['base_path']+kwargs["checkpoint_path"], optimizer)
        print(f"Loading model from checkpoint with start epoch: {start_epoch} and loss: {eval_loss}")
        logging.info(f"Model loaded from saved checkpoint with start epoch: {start_epoch} and loss: {eval_loss}")
    
    train_model(start_epoch, eval_loss, (train_dl, valid_dl), optimizer, kwargs['base_path']+kwargs["checkpoint_path"], kwargs['base_path']+kwargs["best_model"])



def train_model(start_epoch, eval_loss, loaders, optimizer, check_pt_path, best_model_path):
    best_eval_loss = eval_loss
    print("Model training started...")
    for epoch in range(start_epoch, N_EPOCH):
        print(f"Epoch {epoch} running...")
        epoch_start_time = time.time()
        epoch_train_loss = 0
        epoch_eval_loss = 0
        model.train()
        for step, batch in enumerate(loaders[0]):
            src_tensors, src_attn_tensors, tgt_tensors, tgt_attn_tensors, labels = tokenizer.encode_batch(batch)
            optimizer.zero_grad()
            model.zero_grad()
            loss = model(input_ids = src_tensors.to(device), 
                            decoder_input_ids = tgt_tensors.to(device),
                            attention_mask = src_attn_tensors.to(device),
                            decoder_attention_mask = tgt_attn_tensors.to(device),
                            labels = labels.to(device))[0]
            if step == 0:
                epoch_train_loss = loss.item()
            else:
                epoch_train_loss = (1/2.0)*(epoch_train_loss + loss.item())
            
            loss.backward()
            optimizer.step()

            if (step+1) % LOG_EVERY == 0:
                print(f'Epoch: {epoch} | iter: {step+1} | avg. train loss: {epoch_train_loss} | time elapsed: {time.time() - epoch_start_time}')
                logging.info(f'Epoch: {epoch} | iter: {step+1} | avg. train loss: {epoch_train_loss} | time elapsed: {time.time() - epoch_start_time}')
        
        eval_start_time = time.time()
        epoch_eval_loss, bleu_score, sari_score = evaluate(loaders[1], epoch_eval_loss)
        epoch_eval_loss = epoch_eval_loss/TRAIN_BATCH_SIZE
        print(f'Completed Epoch: {epoch} | avg. eval loss: {epoch_eval_loss:.5f} | BLEU score: {bleu_score} | SARI score: {sari_score} | time elapsed: {time.time() - eval_start_time}')
        logging.info(f'Completed Epoch: {epoch} | avg. eval loss: {epoch_eval_loss:.5f} | BLEU score: {bleu_score} | SARI score: {sari_score} | time elapsed: {time.time() - eval_start_time}')

        check_pt = {
            'epoch': epoch+1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'eval_loss': epoch_eval_loss,
            'sari_score': sari_score,
            'bleu_score': bleu_score
        }
        check_pt_time = time.time()
        print("Saving Checkpoint.......")
        if epoch_eval_loss < best_eval_loss:
            print("New best model found")
            logging.info(f"New best model found")
            best_eval_loss = epoch_eval_loss
            save_model_checkpt(check_pt, True, check_pt_path, best_model_path)
        else:
            save_model_checkpt(check_pt, False, check_pt_path, best_model_path)  
        print(f"Checkpoint saved successfully with time: {time.time() - check_pt_time}")
        logging.info(f"Checkpoint saved successfully with time: {time.time() - check_pt_time}")

        gc.collect()
        torch.cuda.empty_cache() 

It is called when you pass inputs to your model instance: output = model(input_ids, ...) is equivalent to output = model.forward(input_ids, ...), because nn.Module.__call__ dispatches to forward(). That is why forward() shows up in your traceback even though you never call it directly.
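
As a toy illustration (a made-up module, not your model), this shows the dispatch:

import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        # Module logic lives in forward(); users call the instance instead.
        return x * 2

m = Toy()
print(m(3))          # nn.Module.__call__ routes this to m.forward(3) -> 6
print(m.forward(3))  # same result, called explicitly -> 6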

Please refer to the documentation of the EncoderDecoderModel class to see which arguments its forward method expects. According to the error message, an unexpected keyword argument is being passed, apparently 'encoder_hidden_states'. Note that it is the EncoderDecoderModel wrapper itself that hands the encoder output to the decoder as encoder_hidden_states (that is the decoder call at line 628 of modeling_encoder_decoder.py in your traceback), so the error suggests your decoder's forward() does not accept that argument.
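
One quick way to check this on your end (assuming the model variable from your setup) is to inspect the decoder's forward signature and see whether it accepts encoder_hidden_states at all:

import inspect

# List the keyword arguments the decoder's forward() actually accepts.
# EncoderDecoderModel passes the encoder output to the decoder as
# encoder_hidden_states, so if that name is missing here the decoder
# cannot consume the encoder output and fails with exactly this error.
sig = inspect.signature(model.decoder.forward)
print('encoder_hidden_states' in sig.parameters)
print(list(sig.parameters))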

You might want to adjust the configuration to modify the default behaviour. I strongly recommend posting on the Hugging Face forums (you are using their transformers API) in case you aren't able to resolve it through configuration.
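
For example, you could print the relevant decoder config flags (a hedged check; is_decoder and add_cross_attention are standard attributes on transformers configs, and from_encoder_decoder_pretrained is supposed to set both to True):

# Both should be True if the decoder is meant to attend over the encoder.
print(model.decoder.config.is_decoder)
print(model.decoder.config.add_cross_attention)

If both flags are True and the forward call still rejects encoder_hidden_states, the decoder architecture itself may simply not implement cross-attention.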

Thank you Srishti for your reply!

I already posted on the Hugging Face forum - unfortunately no one answered :smiling_face_with_tear:

Everything that I passed seems to be expected by the forward method. At no point in my code am I passing “encoder_hidden_states”. Is there another tip you can give me? I have no idea where I can remove encoder_hidden_states since it is not part of my code…

I see. In that case, I could try to help by first reproducing the error on my end and then digging into the relevant source code.

Could you please confirm if the code included in your first post is executable and reproduces the said error? If not, please try to post a minimum executable snippet with input dimensions etc.
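
For instance, something along these lines would already help (a sketch only; I substituted the smaller EleutherAI/gpt-neo-125M checkpoint and dummy token ids so it stays self-contained):

import torch
from transformers import EncoderDecoderModel

# Same pairing as in your first post, with a smaller decoder checkpoint
# so the download stays manageable.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'xlm-roberta-base', 'EleutherAI/gpt-neo-125M', tie_encoder_decoder=True)

input_ids = torch.tensor([[0, 5, 6, 2]])          # dummy source token ids
decoder_input_ids = torch.tensor([[0, 7, 8, 2]])  # dummy target token ids

# If the incompatibility is in the model pairing itself, this minimal
# forward pass should raise the same TypeError as your training loop.
out = model(input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            labels=decoder_input_ids)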

Thank you so much for the offer! Unfortunately I don't understand the code well enough to provide a minimal example.
That would be too much work on your side and too much to ask. I will try to find a different solution. But thanks again for the kind offer :slight_smile: