This model does not train. However, when I remove all of the auxiliary code (for example, the code that keeps track of the loss and other metrics), it trains well. I am using PyTorch's default weight initialization.
class Model(Base):
    """Convolutional classifier: three Conv2d layers and two linear layers.

    ``build_model`` creates the layers, ``train_one_batch`` runs the forward
    pass and scores it, and ``do_train`` drives the optimisation loop.

    Several attributes and helpers used below (``self.log``,
    ``self.metrics``, ``self.train_feed``, ``self.epochs``,
    ``self.checkpoint``, ``self.output_size``, ``self.__``,
    ``self.do_validate`` and the other ``setup_*`` hooks) are not defined in
    this class; presumably they come from ``Base`` -- TODO confirm.
    """

    def __init__(
        self,
        config,
        name,
        # . . .  (further constructor parameters are elided in this snippet)
    ):
        """Record the constructor arguments and run the setup pipeline.

        Args:
            config: Forwarded unchanged to ``Base.__init__``.
            name: Forwarded unchanged to ``Base.__init__``.
        """
        # Keep the explicit two-argument super(): the zero-argument form
        # creates an implicit ``__class__`` cell that would then leak into
        # the ``locals()`` snapshot taken below.
        super(Model, self).__init__(config, name)

        self.args = Args()
        # Copy every constructor argument except ``self`` onto ``self.args``.
        kwargs = locals()
        del kwargs['self']
        self.args.__dict__.update(kwargs)

        # This order matters
        self.setup_training_params()
        self.setup_dataset()
        self.setup_model_params()
        self.build_model()
        self.setup_stats_records()
        self.setup_optimizer_functions()

        self.loss_function = nn.NLLLoss()

    def build_model(self):
        """Create the network layers.

        Three 5x5, stride-1 convolutions (3 -> 20 -> 50 -> 50 channels) are
        followed by two linear layers (3200 -> 500 -> ``self.output_size``).
        ``fc1`` is hard-wired to 3200 input features, so the flattened output
        of ``conv3`` must have exactly that many values per sample.
        """
        self.conv1 = nn.Conv2d(3, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.conv3 = nn.Conv2d(50, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(3200, 500)
        self.fc2 = nn.Linear(500, self.output_size)

    def setup_stats_records(self, *args, **kwargs):
        """Create the metric records and the initial ``best_model`` entry.

        ``self.best_model`` is a 2-tuple ``(1e-6, state_dict_snapshot)``; the
        first element is presumably the best score seen so far -- TODO
        confirm against the code that updates it.

        Args:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.
        """
        # Local import: the file's import block is not visible from here.
        import copy

        # Model stat records
        self.__build_stats__()
        self.best_model_criteria = self.test_accuracy

        # ``state_dict()`` hands back references to the live parameter
        # tensors, so without a deep copy this "snapshot" would silently
        # follow every later optimizer step.  ``self.cpu()`` is kept from the
        # original; note that it moves the module itself to the CPU.
        self.best_model = 1e-6, copy.deepcopy(self.cpu().state_dict())

    def setup_optimizer_functions(self, *args, **kwargs):
        """Create the SGD optimizer over all parameters of this module.

        Args:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.
        """
        # Loss functions and optimizers
        self.optimizer = optim.SGD(self.parameters(), lr=0.01, momentum=0.1)

    def accuracy_function(self, output, labels):
        """Return the fraction of rows whose arg-max over dim 1 equals the label.

        Args:
            output: Per-class scores with the class axis at dim 1.
            labels: Target class indices, one per row of ``output``.

        Returns:
            A scalar float tensor holding the mean of the per-row matches.
        """
        _, predicted = output.max(dim=1)
        return (predicted == labels).float().mean()

    def train_one_batch(self, input_):
        """Run the forward pass on one batch and score it.

        Despite the name, no backward pass or optimizer step happens here;
        ``do_train`` performs those.

        Args:
            input_: An ``(ids, (x,), (label,))`` triple, as returned by
                ``self.train_feed.next_batch()`` in ``do_train``.

        Returns:
            A tuple ``(log_probabilities, loss, accuracy)``.
        """
        _ids, (x,), (label,) = input_
        # Only the batch dimension is needed (for the flatten below).
        batch_size = x.size(0)

        # ``self.__`` is an opaque helper; presumably it records the tensor
        # under the given tag and returns it unchanged -- TODO confirm.
        #
        # NOTE(review): transpose(1, 3) swaps axes 1 and 3 only.  If ``x``
        # arrives as (batch, height, width, channels) this yields
        # (batch, channels, width, height), i.e. the spatial axes are swapped
        # as well.  Confirm whether ``permute(0, 3, 1, 2)`` was intended.
        x = self.__(x.transpose(1, 3), 'x')

        x = self.__(F.relu(self.conv1(x)), 'conv1')
        x = self.__(F.max_pool2d(x, 2, 2), 'maxpool1')

        x = self.__(F.relu(self.conv2(x)), 'conv2')
        x = self.__(F.max_pool2d(x, 2, 2), 'maxpool2')

        x = self.__(F.relu(self.conv3(x)), 'conv3')

        x = self.__(x.view(batch_size, -1), 'after flattening')
        x = self.__(F.relu(self.fc1(x)), 'fc1')
        x = self.__(self.fc2(x), 'fc2')

        # NLLLoss expects log-probabilities, hence log_softmax.
        output = F.log_softmax(x, dim=1)
        loss = self.loss_function(output, label)
        accuracy = self.accuracy_function(output, label)
        return output, loss, accuracy

    def do_train(self):
        """Run the training loop for ``self.epochs`` epochs.

        Every ``max(1, self.checkpoint - 1)`` epochs (epoch 0 included)
        ``do_validate`` runs before training; if it returns
        ``FLAGS.STOP_TRAINING`` the loop ends immediately.  After each epoch
        the mean batch loss is appended to ``self.train_loss`` and every
        entry of ``self.metrics`` is written to file.

        Returns:
            ``True`` when all epochs complete, ``None`` when validation
            requested an early stop.
        """
        for epoch in range(self.epochs):
            self.log.critical('memory consumed : {}'.format(memory_consumed()))
            self.epoch = epoch

            if epoch % max(1, (self.checkpoint - 1)) == 0:
                if self.do_validate() == FLAGS.STOP_TRAINING:
                    self.log.info('loss trend suggests to stop training')
                    return

            # Switch back to training mode (validation may have changed it).
            self.train()

            losses = []
            progress = tqdm(
                range(self.train_feed.num_batch),
                desc='Trainer.{}'.format(self.name()),
            )
            for _ in progress:
                self.optimizer.zero_grad()
                input_ = self.train_feed.next_batch()
                _output, loss, _accuracy = self.train_one_batch(input_)
                print('loss: ', loss.item())

                # Keep only the detached value: storing the graph-attached
                # loss would hold every batch's autograd graph in memory
                # until the end of the epoch.
                losses.append(loss.detach())

                loss.backward()
                self.optimizer.step()

            epoch_loss = torch.stack(losses).mean()
            self.train_loss.append(epoch_loss.item())
            self.log.info('-- {} -- loss: {}\n'.format(epoch, epoch_loss))

            for m in self.metrics:
                m.write_to_file()

        return True