I have a wrapper around a Hugging Face model. The wrapper holds several prompt encoders, each of which is essentially a series of embeddings. In the wrapper's `forward`, I call `forward` on each encoder in a loop, but I get the following error:
Traceback (most recent call last):
File "/home/pouramini/mt5-comet/comet/train/train.py", line 1275, in <module>
run()
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 716, in __call__
return self.main(*args, **kwargs)
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 696, in main
rv = self.invoke(ctx)
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 1060, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 889, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/click/core.py", line 534, in invoke
return callback(*args, **kwargs)
File "/home/pouramini/mt5-comet/comet/train/train.py", line 1069, in train
result = wrapped_model(**batch)
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pouramini/mt5-comet/comet/transformers_ptuning/ptuning_wrapper.py", line 135, in forward
prompt_embeds = encoder(prompt_input_ids,\
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pouramini/mt5-comet/comet/transformers_ptuning/ptuning_wrapper.py", line 238, in forward
return self.embedding(prompt_token_ids)
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 158, in forward
return F.embedding(
File "/home/pouramini/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2043, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument index in method wrapper_index_select)
This is the code that produces the error:
def forward(self,input_ids, labels, decoder_input_ids=None,prompt_ids=None,**kwargs):
# find masks based on the range of prompt ids (offset_id < X < offset_id + prompt_length)
#Because this wrapper only deals with a single prompt, the length should be the same, you can use masked_select to reshape
prompt_masks = self.prompt_token_fn(input_ids)
if prompt_masks.any():
input_ids_ = input_ids.clone()
wlog.info("inpu ids :{}".format(input_ids))
if self.replacing_token_id is not None:
# replace prompt ids in input_ids with replacing token
input_ids_[prompt_masks]=self.replacing_token_id
# find the model embeddings of input ids except for prompt tokens
inputs_embeds = self.model_embeddings(input_ids_)
for encoder in self.prompt_encoders:
#encoder = self.prompt_encoders[0]
wlog.info("********** offset: %s, length: %s", encoder.id_offset, encoder.length)
prompt_token_fn = encoder.get_prompt_token_fn()
encoder_masks = prompt_token_fn(input_ids)
wlog.info("Encoder masks: %s", encoder_masks)
if encoder_masks.any():
#find input ids for prompt tokens
prompt_input_ids = input_ids[encoder_masks]
wlog.info("Prompt Input ids: %s", prompt_input_ids)
# call forwards on prompt encoder whose outputs are prompt embeddings
prompt_embeds = encoder(prompt_input_ids,\
prompt_ids).to(device=inputs_embeds.device)
However, the code runs fine if I use the CPU as the device. Also, with a single encoder the code runs on CUDA. With multiple encoders, though, it seems to expect all of them to be moved to the device, and I don't know how to do that.