My setup:
GPU: NVIDIA A100 (40 GB memory)
RAM: 500 GB
DataLoader:
pin_memory = True
num_workers = tried 2, 4, 8, 12 and 16
batch_size = 32
Data shape per sample: two input tensors and one target tensor (the first dimension of x1 varies from sample to sample, shown as -1):
torch.Size([-1, 3024]), torch.Size([1, 768]), torch.Size([1, 3792])
I am trying to use the DataLoader for prediction, but I get the following error. I trained the same model with batch_size = 80 without any problems.
I am also using torch.no_grad() before the prediction loop.
Traceback (most recent call last):
File "auto_encoder_embedding.py", line 485, in <module>
predictions = model.predict(
File "auto_encoder_embedding.py", line 411, in predict
File "./.venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 563, in _next_data
File "./.venv/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 58, in pin_memory
return [pin_memory(sample) for sample in data]
File "./.venv/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 58, in <listcomp>
return [pin_memory(sample) for sample in data]
File "./.venv/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 58, in pin_memory
return [pin_memory(sample) for sample in data]
File "./.venv/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 58, in <listcomp>
return [pin_memory(sample) for sample in data]
File "./.venv/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 50, in pin_memory
return data.pin_memory()
RuntimeError: cuda runtime error (2) : out of memory at ../aten/src/THC/THCCachingHostAllocator.cpp:278
My host RAM keeps filling up, yet at most 3.5 GB of GPU memory is being used!
Here is my dataset code:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class MyDataSet(Dataset):
    """Map-style dataset yielding ``(x1, x2, y)`` float32 tensors per index label.

    Each item is looked up by index label in three pandas objects:
    ``dfe_x1`` (first input), ``dfe_x2`` (second input) and ``dfe_y``
    (target). The labels served are taken from ``dfe_y.index``.
    """

    def __init__(self, dfe_x1, dfe_x2, dfe_y, sample=None):
        """Store the frames and choose which index labels to serve.

        Args:
            dfe_x1: Source of the first input, indexed by label. A label may
                presumably map to several rows (the item is forced to 2-D).
            dfe_x2: Source of the second input, indexed by label.
            dfe_y: Source of the target; its index defines the dataset.
            sample: If truthy, draw this many labels at random (without
                replacement) from ``dfe_y.index``; otherwise use every label.
        """
        self.dfe_x1 = dfe_x1
        self.dfe_x2 = dfe_x2
        self.dfe_y = dfe_y
        if sample:
            self.index = random.sample(dfe_y.index.tolist(), sample)
        else:
            # ``None`` and ``0`` both fall through here: use the full index.
            self.index = dfe_y.index.tolist()

    def __len__(self):
        """Return the number of selected index labels."""
        return len(self.index)

    def __getitem__(self, idx):
        """Return ``(x1, x2, y)`` for the ``idx``-th selected label.

        ``x1`` goes through ``np.atleast_2d``, so a single-row lookup yields
        shape ``(1, n)``. ``x2`` and ``y`` get a leading axis of size 1 via
        ``unsqueeze(0)``, i.e. ``(1, n)`` for a single-row lookup.
        """
        article_id = self.index[idx]
        x1 = torch.from_numpy(
            np.atleast_2d(self.dfe_x1.loc[article_id].to_numpy(dtype=np.float32))
        )
        x2 = torch.from_numpy(
            self.dfe_x2.loc[article_id].to_numpy(dtype=np.float32)
        ).unsqueeze(0)
        y = torch.from_numpy(
            self.dfe_y.loc[article_id].to_numpy(dtype=np.float32)
        ).unsqueeze(0)
        return x1, x2, y

    @staticmethod
    def collate_fn(batch):
        """Collate a list of ``(x1, x2, y)`` samples into one batch.

        Args:
            batch: List of samples as returned by ``__getitem__``.

        Returns:
            ``(x1, x2, y)``. ``x1`` and ``x2`` are plain Python lists of
            per-sample tensors. ``x1`` has a variable number of rows per sample,
            so it is not stacked. ``y`` is a single tensor of shape
            ``(batch_size, n_targets)``.
        """
        x1, x2, y = (list(column) for column in zip(*batch))
        # x1 - each sample: shape (-1, n_x1), variable first dimension
        # x2 - each sample: shape (1, n_x2)
        # y  - each sample: shape (1, n_targets)
        # Concatenate along dim 0 to get (batch_size, n_targets).
        # ``torch.stack(y).squeeze()`` would also drop the batch axis when the
        # batch holds a single sample (e.g. a final partial batch), returning
        # (n_targets,) instead of (1, n_targets).
        y = torch.cat(y, dim=0)
        return x1, x2, y

    @staticmethod
    def get_data_loader(dfe_x1, dfe_x2, dfe_y, sample=None, batch_size=32,
                        shuffle=True, pin_memory=True, **kwargs):
        """Build a DataLoader over a new ``MyDataSet`` using ``collate_fn``.

        Args:
            dfe_x1, dfe_x2, dfe_y, sample: Forwarded to ``MyDataSet``.
            batch_size, shuffle, pin_memory: Forwarded to ``DataLoader``.
            **kwargs: Any other ``DataLoader`` options (e.g. ``num_workers``).

        NOTE(review): with ``pin_memory=True`` the DataLoader pins every tensor
        in the batch individually, recursing into the x1/x2 lists. Because the
        x1 tensors vary in size, this may cause heavy page-locked host-memory
        use. If host RAM grows or pinning fails with "out of memory" (e.g.
        during prediction), try ``pin_memory=False``. TODO confirm.
        """
        return DataLoader(
            MyDataSet(dfe_x1, dfe_x2, dfe_y, sample=sample),
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=pin_memory,
            collate_fn=MyDataSet.collate_fn,
            **kwargs,
        )
Additionally, I am using PyTorch Lightning to avoid writing a lot of boilerplate code.