Hi, I have a simple model below that runs on CPU and returns the expected output. The moment I move this model to CUDA, I get an integer multiplication overflow error. This model only embeds a set of integers, so it is not clear to me where the overflow is coming from.
import torch
import numpy
# Pick the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class SmallModel(torch.nn.Module):
    """Embed integer skill ids.

    Only the skill-embedding lookup is active; the interaction/positional
    path in ``forward`` is commented out.
    """

    def __init__(self, skill_num, emb_size, max_seq_length):
        super().__init__()
        self.emb_size = emb_size
        self.skill_num = skill_num
        self.max_seq_length = max_seq_length
        # One embedding row per skill id; valid indices are [0, skill_num).
        self.skill_embeddings = torch.nn.Embedding(self.skill_num, self.emb_size)
        # Interaction table: a skill id is offset by skill_num when its label is 1.
        self.inter_embeddings = torch.nn.Embedding(self.skill_num * 2, self.emb_size)
        # Learned positional embedding, one row per sequence position.
        self.embd_pos = torch.nn.Embedding(self.max_seq_length, self.emb_size)

    def forward(self, x, y):
        """Look up skill embeddings: (bs, seq_len) -> (bs, seq_len, emb_size).

        ``y`` is accepted but unused while the key/value/positional path
        below stays disabled.
        """
        query = self.skill_embeddings(x)
        # Disabled key/value/positional path (kept for reference):
        # mask_labels = y * (y > -1).long()
        # key = self.inter_embeddings(x + mask_labels * self.skill_num)
        # values = self.inter_embeddings(x + mask_labels * self.skill_num)
        # pos = self.embd_pos(torch.arange(x.shape[1]))
        # NOTE(review): torch.arange above would allocate on the CPU — it would
        # need .to(x.device) before being added to CUDA tensors; confirm before
        # re-enabling.
        # key = key + pos
        # query = query + pos
        return query
# --- Mock data: 5 students, sequence length 10 -------------------------------
# nn.Embedding index tensors should be int64 (LongTensor). The original script
# built them as int32 via .type(torch.int), which is the likely trigger of the
# CUDA-side failure; .long() also avoids shadowing the builtin `input`.
inputs = torch.ceil(torch.randn(5, 10) * 100).abs().long()
labels = torch.zeros(5, 10, dtype=torch.long)
labels[1, 3] = 1
labels[2, 5] = 1
labels[4, 9] = -1

# --- Model hyper-parameters --------------------------------------------------
# .item() extracts the Python int directly; no numpy round-trip needed.
skill_num = int(inputs.max().item()) + 1
emb_size = 12
max_seq_length = 10

# --- Build the model and run a forward pass ----------------------------------
test_mod = SmallModel(skill_num=skill_num, emb_size=emb_size,
                      max_seq_length=max_seq_length).to(device)
# query = test_mod(inputs, labels)  # CPU run
query = test_mod(inputs.to(device), labels.to(device))  # GPU (or CPU fallback)
query
Error message:
RuntimeError Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/IPython/core/formatters.py in __call__(self, obj)
700 type_pprinters=self.type_printers,
701 deferred_pprinters=self.deferred_printers)
--> 702 printer.pretty(obj)
703 printer.flush()
704 return stream.getvalue()
/opt/conda/lib/python3.7/site-packages/IPython/lib/pretty.py in pretty(self, obj)
392 if cls is not object \
393 and callable(cls.__dict__.get('__repr__')):
--> 394 return _repr_pprint(obj, self, cycle)
395
396 return _default_pprint(obj, self, cycle)
/opt/conda/lib/python3.7/site-packages/IPython/lib/pretty.py in _repr_pprint(obj, p, cycle)
698 """A pprint that just redirects to the normal repr function."""
699 # Find newlines and replace them with p.break_()
--> 700 output = repr(obj)
701 lines = output.splitlines()
702 with p.group():
/opt/conda/lib/python3.7/site-packages/torch/_tensor.py in __repr__(self, tensor_contents)
425 )
426 # All strings are unicode in Python 3.
--> 427 return torch._tensor_str._str(self, tensor_contents=tensor_contents)
428
429 def backward(
/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py in _str(self, tensor_contents)
635 with torch.no_grad():
636 guard = torch._C._DisableFuncTorch()
--> 637 return _str_intern(self, tensor_contents=tensor_contents)
/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py in _str_intern(inp, tensor_contents)
566 tensor_str = _tensor_str(self.to_dense(), indent)
567 else:
--> 568 tensor_str = _tensor_str(self, indent)
569
570 if self.layout != torch.strided:
/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py in _tensor_str(self, indent)
326 )
327 else:
--> 328 formatter = _Formatter(get_summarized_data(self) if summarize else self)
329 return _tensor_str_with_formatter(self, indent, summarize, formatter)
330
/opt/conda/lib/python3.7/site-packages/torch/_tensor_str.py in __init__(self, tensor)
114 else:
115 nonzero_finite_vals = torch.masked_select(
--> 116 tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
117 )
118
RuntimeError: numel: integer multiplication overflow