RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument index in method wrapper_index_select)

I have a wrapper for a huggingface model. In this wrapper I have some encoders, which are mainly a series of embeddings. In the forward of the wrapped model, I want to call the forward of each encoder in a loop, but I get this error:

Traceback (most recent call last):
  File "/home/pouramini/mt5-comet/comet/train/train.py", line 1275, in <module>
    run()
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 716, in __call__
    return self.main(*args, **kwargs)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 696, in main
    rv = self.invoke(ctx)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 1060, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 889, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 534, in invoke
    return callback(*args, **kwargs)
  File "/home/pouramini/mt5-comet/comet/train/train.py", line 1069, in train
    result = wrapped_model(**batch)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pouramini/mt5-comet/comet/transformers_ptuning/ptuning_wrapper.py", line 135, in forward
    prompt_embeds = encoder(prompt_input_ids,\
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pouramini/mt5-comet/comet/transformers_ptuning/ptuning_wrapper.py", line 238, in forward
    return self.embedding(prompt_token_ids)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2043, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument index in method wrapper_index_select)

This is the code that produces the error:

def forward(self, input_ids, labels, decoder_input_ids=None, prompt_ids=None, **kwargs):
        # find masks based on the range of prompt ids (offset_id < X < offset_id + prompt_length)
        #Because this wrapper only deals with a single prompt, the length should be the same, you can use masked_select to reshape 
        prompt_masks = self.prompt_token_fn(input_ids)
        if prompt_masks.any():
            input_ids_ = input_ids.clone()
            wlog.info("inpu ids :{}".format(input_ids))
            if self.replacing_token_id is not None:
                # replace prompt ids in input_ids with replacing token
                input_ids_[prompt_masks] = self.replacing_token_id
            # find the model embeddings of input ids except for prompt tokens
            inputs_embeds = self.model_embeddings(input_ids_)

            for encoder in self.prompt_encoders:
                #encoder = self.prompt_encoders[0]
                wlog.info("********** offset: %s, length: %s", encoder.id_offset, encoder.length)
                prompt_token_fn = encoder.get_prompt_token_fn()
                encoder_masks = prompt_token_fn(input_ids)
                wlog.info("Encoder masks: %s", encoder_masks)
                if encoder_masks.any():
                    #find input ids for prompt tokens
                    prompt_input_ids = input_ids[encoder_masks]
                    wlog.info("Prompt Input ids: %s", prompt_input_ids)
                    # call forwards on prompt encoder whose outputs are prompt embeddings
                    prompt_embeds = encoder(prompt_input_ids,\
                        prompt_ids).to(device=inputs_embeds.device)

However, the code runs if I just use cpu as the device. It also runs on cuda when there is a single encoder, but when there are multiple encoders it seems to expect all of them to be transferred to the device, and I don’t know how to do that.

Assuming the error is raised in one of the prompt_encoders, check if their internal modules are already pushed to the device, and if not, either call model.to('cuda') on the entire model or on each of the prompt_encoders.
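For example, you could check where each encoder's parameters currently live and then move them (assuming prompt_encoders is accessible on the wrapped model, as in your snippet):

import torch

device = torch.device('cuda:0')

# check which device each encoder's parameters are on
for i, encoder in enumerate(wrapped_model.prompt_encoders):
    print(i, {p.device for p in encoder.parameters()})

# either move the whole wrapper at once ...
wrapped_model.to(device)
# ... or move each encoder individually
for encoder in wrapped_model.prompt_encoders:
    encoder.to(device)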

@ptrblck Thank you very much! I added:

wrapped_model.to(device=device)
for encoder in wrapped_model.prompt_encoders:
    encoder.to(device=device)

Interestingly, when there was a single encoder, or a list containing only one encoder, I didn’t need to put it on the device explicitly, but for a list of several encoders I must! I wonder why. Probably because, in the single-encoder case, I put it on the device in the first iteration of the loop.

I don’t know why this was needed, as I’m not familiar enough with the model implementation, but it does indeed sound unexpected.
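One possible explanation (an assumption about the wrapper implementation, which isn’t shown in full here): if prompt_encoders is stored as a plain Python list, model.to('cuda') will not move those encoders, because PyTorch only recurses into registered submodules; keeping them in an nn.ModuleList makes .to() move them along with the rest of the model. A minimal sketch:

import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, encoders):
        super().__init__()
        # a plain Python list would NOT be registered as submodules,
        # so wrapper.to('cuda') would leave these embeddings on the CPU:
        #   self.prompt_encoders = encoders
        # nn.ModuleList registers them, so .to() moves them as well:
        self.prompt_encoders = nn.ModuleList(encoders)

encoders = [nn.Embedding(10, 8) for _ in range(3)]
wrapper = Wrapper(encoders).to('cuda')
print({p.device for p in wrapper.prompt_encoders[1].parameters()})  # {device(type='cuda', index=0)}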