Hi, I am using this dataset with the following dataset class and DataLoader:
import os
import numpy as np
import torch

class CamelyonDataset(torch.utils.data.Dataset):
    def __init__(self, subset='train', dataset_path=dataset_path, transform=None):  # dataset_path is defined earlier in the notebook
        self.subset = subset
        self.transform = transform
        x_filename = subset + "_img_data.npy"
        y_filename = subset + "_mask_data.npy"
        # load them from the downloaded folder
        self._x = np.load(os.path.join(dataset_path, x_filename), allow_pickle=True)
        self._y = np.load(os.path.join(dataset_path, y_filename), allow_pickle=True)
        # convert to torch tensors
        self._x = torch.from_numpy(self._x).float()
        self._y = torch.from_numpy(self._y).float()
        # change channel order from NHWC to NCHW to match PyTorch
        self._x = self._x.permute(0, 3, 1, 2)
        # add a channel dimension to the masks
        self._y = self._y.unsqueeze(1)

    def __len__(self):
        return len(self._x)

    def __getitem__(self, idx):
        x = self._x[idx]
        y = self._y[idx]
        if self.transform:
            x = self.transform(x)
        return x, y

train_loader = torch.utils.data.DataLoader(CamelyonDataset(subset='train'), batch_size=16, shuffle=True)
val_loader = torch.utils.data.DataLoader(CamelyonDataset(subset='val'), batch_size=16, shuffle=False)  # no reason to shuffle validation data
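A quick way to see where the RAM goes (the up-front .float() calls mean everything is held as 4-byte float32, whatever the dtype on disk was):

ds = CamelyonDataset(subset='train')
# element_size() * nelement() = bytes actually held by a tensor
print(ds._x.element_size() * ds._x.nelement() / 1e9, "GB for images")
print(ds._y.element_size() * ds._y.nelement() / 1e9, "GB for masks")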
And here is the training loop:
import torch.nn as nn

model = UNet(n_channels=3, n_classes=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

def train(model, train_loader, val_loader, optimizer, criterion, epochs=10):
    for epoch in range(epochs):
        # training pass
        model.train()
        for i, (x, y) in enumerate(train_loader):
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch} | Batch {i} | Loss {loss.item()}")
        # validation pass, no gradients needed
        model.eval()
        with torch.no_grad():
            for i, (x, y) in enumerate(val_loader):
                y_pred = model(x)
                loss = criterion(y_pred, y)
                print(f"Epoch {epoch} | Batch {i} | Loss {loss.item()}")

train(model, train_loader, val_loader, optimizer, criterion, epochs=1)
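(Side note: this loop keeps everything on the CPU. If I did get a GPU session to start, I assume I would also need device moves, roughly like this:)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# and inside both loops, before the forward pass:
x, y = x.to(device), y.to(device)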
Even though the loaded .npy files total less than 2 GB on disk, the training loop uses far more RAM than that, and I couldn't get this Kaggle notebook to work in a GPU session (Kaggle CPU sessions have 32 GB of RAM; GPU sessions have less).
Is there a way to partially load the data, or to better manage the memory of this Dataset/DataLoader?
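One direction I'm considering, as an untested sketch: memory-map the .npy files with mmap_mode='r' so nothing is read into RAM up front, and do the float conversion and axis reordering per sample in __getitem__ instead. This assumes the files are plain arrays rather than pickled object arrays (memory-mapping doesn't work for those), and the LazyCamelyonDataset name is just a placeholder:

import os
import numpy as np
import torch

class LazyCamelyonDataset(torch.utils.data.Dataset):
    def __init__(self, subset='train', dataset_path=dataset_path, transform=None):
        self.transform = transform
        # memory-map the arrays: data stays on disk until a slice is accessed
        self._x = np.load(os.path.join(dataset_path, subset + "_img_data.npy"), mmap_mode='r')
        self._y = np.load(os.path.join(dataset_path, subset + "_mask_data.npy"), mmap_mode='r')

    def __len__(self):
        return len(self._x)

    def __getitem__(self, idx):
        # only this one sample is copied out of the memory map and converted to float32
        x = torch.from_numpy(np.array(self._x[idx])).float().permute(2, 0, 1)  # HWC -> CHW
        y = torch.from_numpy(np.array(self._y[idx])).float().unsqueeze(0)      # add channel dim
        if self.transform:
            x = self.transform(x)
        return x, y

That way only the current batch would ever exist as float32 in RAM. Would this work, or is there a better pattern for this?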