Hi all,
I created a new SuperModule class which adds a lot of great high-level functionality without sacrificing any model flexibility. You define your models exactly as you would with nn.Module, except now you have access to fit(), evaluate(), and predict() functions and can use a ton of nice Callbacks, Constraints, and Regularizers. There's also a sweet tqdm progress bar.
It inherits directly from nn.Module, so you can still do manual training if necessary and access all of its members. Also, there is a fit_loader() function to fit directly on DataLoader objects. Both are sketched after the example below.
The code is available in the third-party torchsample repository. My motivation is that people can take this code and tailor it to their liking or expand on it.
Here’s a small example of the main functionality (full example in the torchsample README):
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchsample.modules import SuperModule

class Network(SuperModule):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1600)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# constraints
# -> NonNeg on Conv layers applied every batch
# -> UnitNorm on FC layers applied every 4 batches
from torchsample.constraints import NonNeg, UnitNorm
constraints = [NonNeg(frequency=1, unit='batch', module_filter='*conv*'),
               UnitNorm(frequency=4, unit='batch', module_filter='*fc*')]
# regularizers
# -> L1 on Conv layers
# -> L2 on FC layers
from torchsample.regularizers import L1Regularizer, L2Regularizer
regularizers = [L1Regularizer(scale=1e-6, module_filter='*conv*'),
                L2Regularizer(scale=1e-6, module_filter='*fc*')]
# callbacks
# -> LambdaCallback that prints a message when training ends
from torchsample.callbacks import LambdaCallback
callbacks = [LambdaCallback(on_train_end=lambda logs: print('TRAINING FINISHED'))]
model = Network()
model.set_loss(F.nll_loss)
model.set_optimizer(optim.Adadelta, lr=1.0)
model.set_regularizers(regularizers)
model.set_constraints(constraints)
model.set_callbacks(callbacks)
# fit model
model.fit(x_train, y_train,
          validation_data=(x_test, y_test),
          nb_epoch=5,
          batch_size=128,
          verbose=1)
# evaluate on test data
val_loss = model.evaluate(x_test, y_test)
# predict on input data
y_pred = model.predict(x_test)
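For data that already lives in DataLoader objects, here's a minimal sketch of fit_loader(). I'm assuming it mirrors fit() but takes loaders, and the keyword names here are my guesses; check the torchsample README for the exact signature:

from torch.utils.data import TensorDataset, DataLoader

# wrap the tensors from above in standard PyTorch loaders
train_loader = DataLoader(TensorDataset(x_train, y_train),
                          batch_size=128, shuffle=True)
val_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=128)

# fit directly on the loaders (val_loader/nb_epoch names are assumptions)
model.fit_loader(train_loader, val_loader=val_loader, nb_epoch=5, verbose=1)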
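And because SuperModule inherits directly from nn.Module, you can always drop down to a plain PyTorch loop on the same model. A minimal sketch, reusing train_loader from above:

# manual training loop -- SuperModule is still just an nn.Module
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
model.train()
for x_batch, y_batch in train_loader:
    # on very old PyTorch versions, wrap batches in torch.autograd.Variable
    optimizer.zero_grad()
    loss = F.nll_loss(model(x_batch), y_batch)
    loss.backward()
    optimizer.step()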
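A note on the module_filter arguments: I take patterns like '*conv*' to be shell-style wildcards matched against module names. Purely as an illustration of that assumption (not torchsample's actual matching code):

from fnmatch import fnmatch

# which modules a '*conv*' filter would select on the Network above
pattern = '*conv*'
matched = [name for name, _ in model.named_modules()
           if fnmatch(name, pattern)]
print(matched)  # expected: ['conv1', 'conv2']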
Example of the progress bar: