is the code attached below bug free?
def train_model():
global train_dataset, valid_dataset
torch.manual_seed(42)
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
sampler=train_sampler,
num_workers=0,
drop_last=True) # print(len(train_loader))
'''valid_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
)'''
valid_loader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=BATCH_SIZE,
#sampler=valid_sampler,
shuffle=False,
num_workers=0,
drop_last=True)
#xm.master_print(f"Train for {len(train_loader)} steps per epoch")
LOGGER.debug(f"Train for {len(train_loader)} steps per epoch")
# Scale learning rate to num cores
lr = 0.0001 * xm.xrt_world_size()
# Get loss function, optimizer, and model
device = xm.xla_device()
#model = model()
'''
for param in model.base_model.parameters(): # freeze some layers
param.requires_grad = False'''
global model
model = model.to(device)
criterion = torch.nn.BCEWithLogitsLoss() # MSELoss
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
scheduler = OneCycleLR(optimizer,
lr,
div_factor=10.0,
final_div_factor=50.0,
epochs=NUM_EPOCH,
steps_per_epoch=len(train_loader))
def train_loop_fn(loader):
tracker = xm.RateTracker()
model.train()
#xm.master_print('Epoch {}/{}'.format(epoch, num_epochs - 1))
LOGGER.debug('Epoch {}/{}'.format(epoch, num_epochs - 1))
#xm.master_print('-' * 10)
LOGGER.debug('-' * 10)
scheduler.step()
running_loss = 0.0
tk0 = tqdm(loader, total=int(len(train_loader)))
counter = 0
for bi, d in enumerate(tk0):
inputs = d["image"]
labels = d["label"].view(-1, 1)
inputs = inputs.to(device, dtype=torch.float)
labels = labels.to(device, dtype=torch.float)
optimizer.zero_grad()
#with torch.set_grad_enabled(True):
outputs = model(inputs)
loss = criterion(outputs, labels)
#loss = criterion(outputs, torch.max(labels, 1)[1])
loss.backward()
xm.optimizer_step(optimizer)
running_loss += loss.item() * inputs.size(0)
#print(running_loss)
counter += 1
tk0.set_postfix(loss=(running_loss / (counter * BATCH_SIZE)))
epoch_loss = running_loss / len(train_loader)
#xm.master_print('Training Loss: {:.8f}'.format(epoch_loss))
LOGGER.debug('Training Loss: {:.8f}'.format(epoch_loss))
def test_loop_fn(loader):
tk0 = tqdm(loader, total=int(len(valid_loader)))
counter = 0
total_samples, correct = 0, 0
for bi, d in enumerate(tk0):
inputs = d["image"]
labels = d["label"].view(-1, 1)
inputs = inputs.to(device, dtype=torch.float)
labels = labels.to(device, dtype=torch.float)
optimizer.zero_grad()
with torch.no_grad():
output = model(inputs)
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(labels.view_as(pred)).sum().item()
total_samples += inputs.size()[0]
accuracy = 100.0 * correct / total_samples
#print('[xla:{}] Accuracy={:.4f}%'.format(xm.get_ordinal(), accuracy), flush=True)
model.train()
return accuracy
# Train - valid loop
accuracy = []
for epoch in range(1, num_epochs + 1):
start = time.time()
para_loader = pl.ParallelLoader(train_loader, [device])
train_loop_fn(para_loader.per_device_loader(device))
para_loader = pl.ParallelLoader(valid_loader, [device])
accuracy.append(test_loop_fn(para_loader.per_device_loader(device)))
#xm.master_print("Finished training epoch {} Val-Acc {:.4f} in {:.4f} sec".format(epoch, accuracy[-1], time.time() - start))
LOGGER.debug("Finished training epoch {} Val-Acc {:.4f} in {:.4f} sec".format(epoch, accuracy[-1], time.time() - start))
valauc = accuracy[-1]
if(epoch>4):
xm.save(model.state_dict(), f"./epoch{epoch}valauc{valauc}.bin")
return accuracy
def _mp_fn(rank, flags):
global acc_list
torch.set_default_tensor_type('torch.FloatTensor')
res = train_model()
FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
1st epochs train log looks like this :
2020-05-09 12:21:29,371 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:29,710 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:29,721 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:29,911 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:30,561 DEBUG Epoch 1/6
2020-05-09 12:21:30,564 DEBUG ----------
2020-05-09 12:21:31,065 DEBUG Epoch 1/6
2020-05-09 12:21:31,076 DEBUG ----------
2020-05-09 12:21:31,120 DEBUG Epoch 1/6
2020-05-09 12:21:31,130 DEBUG ----------
2020-05-09 12:21:31,390 DEBUG Epoch 1/6
2020-05-09 12:21:31,426 DEBUG ----------
2020-05-09 12:21:32,629 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:33,573 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:33,748 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:33,883 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:34,889 DEBUG Epoch 1/6
2020-05-09 12:21:34,914 DEBUG ----------
2020-05-09 12:21:35,573 DEBUG Epoch 1/6
2020-05-09 12:21:35,613 DEBUG ----------
2020-05-09 12:21:35,823 DEBUG Epoch 1/6
2020-05-09 12:21:35,845 DEBUG ----------
2020-05-09 12:21:36,128 DEBUG Epoch 1/6
2020-05-09 12:21:36,171 DEBUG ----------
2020-05-09 12:35:08,162 DEBUG Training Loss: 11.22450873
2020-05-09 12:35:08,172 DEBUG Training Loss: 11.19612112
2020-05-09 12:35:08,309 DEBUG Training Loss: 11.18398799
2020-05-09 12:35:08,352 DEBUG Training Loss: 11.16665337
2020-05-09 12:35:08,362 DEBUG Training Loss: 11.20103131
2020-05-09 12:35:08,357 DEBUG Training Loss: 11.19919075
2020-05-09 12:35:08,368 DEBUG Training Loss: 11.19310062
2020-05-09 12:35:08,386 DEBUG Training Loss: 11.21970569
2020-05-09 12:39:31,562 DEBUG Finished training epoch 1 Val-Acc 50.5348 in 1080.4523 sec
the validation accuracy calculation is slow,somehow validation accuracy is using 1 core for calculation where in train phase it is using 8 cores,how do i solve this issue? i need to make the validation calculation fast, also in all epoch i see same validation accuracy, maybe i have bug in my code? another thing is,if i train this model for 8-10 epoch then kaggle kernel doesn’t finish commit,it gives error that’s not visible,so maybe somewhere in my code i am requesting more memory and getting OOM for that? also in my code if i try sampler=valid_sampler for valid_loader then i get error. please help me find bugs in my code,thank you a lot in advance