SuperModule for Keras-like training with Callbacks, Constraints, and Progress Bar

Hi all,

I created a new SuperModule class which allows for a lot of great high-level functionality without sacrificing ANY model flexibility. You define your models exactly as you would with nn.Module, except now you have access to fit(), evaluate(), and predict() functions, can use a ton of nice Callbacks, Constraints, and Regularizers, and there’s a sweet tqdm Progress Bar.

It inherits directly from nn.Module, so you can still do manual training if necessary and access all of its members. Also, there is a fit_loader() function to fit directly on DataLoader objects.

The code is available at the 3rd party torchsample repository. My motivation is that people can take this code and tailor it to their liking or expand on it.

Here’s a small example of the main functionality (full example in the torchsample README):

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsample.modules import SuperModule

class Network(SuperModule):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1600)  # 64 channels * 5 * 5 for a 28x28 input
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# constraints
# -> NonNeg on Conv layers applied at end of every epoch
# -> UnitNorm on FC layers applied every 4 batches
from torchsample.constraints import NonNeg, UnitNorm
constraints = [NonNeg(frequency=1, unit='epoch', module_filter='*conv*'),
               UnitNorm(frequency=4, unit='batch', module_filter='*fc*')]

# regularizers 
# -> L1 on Conv layers
# -> L2 on FC layers
from torchsample.regularizers import L1Regularizer, L2Regularizer
regularizers = [L1Regularizer(scale=1e-6, module_filter='*conv*'),
                L2Regularizer(scale=1e-6, module_filter='*fc*')]

# callbacks
# lambda callback
from torchsample.callbacks import LambdaCallback
callbacks = [LambdaCallback(on_train_end=lambda logs: print('TRAINING FINISHED'))]

model = Network()
model.set_optimizer(optim.Adadelta, lr=1.0)

# fit model
model.fit(x_train, y_train,
          validation_data=(x_test, y_test),
          # … (further fit() arguments)
          )

# evaluate on test data
val_loss = model.evaluate(x_test, y_test)

# predict on input data
y_pred = model.predict(x_test)
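The module_filter strings look like shell-style glob patterns matched against layer names, and the two mechanisms then act on the matched layers differently: regularizers add a penalty to the loss, while constraints modify the weights directly. A framework-free sketch of that split, assuming fnmatch-style matching, L1 = scale·Σ|w| and L2 = scale·Σw² (torchsample’s exact scaling isn’t confirmed here), and whole-layer normalization for UnitNorm (a real implementation may normalize per output unit instead). Plain lists stand in for weight tensors, and every helper name is hypothetical:

```python
from fnmatch import fnmatch

# Hypothetical stand-in for named parameters: layer name -> flat list of weights.
weights = {
    'conv1': [0.5, -1.5],
    'conv2': [2.0, -0.25],
    'fc1': [3.0, -4.0],
    'fc2': [0.6, 0.8],
}

def select(names, module_filter):
    """Names matching a shell-style glob such as '*conv*'."""
    return [n for n in names if fnmatch(n, module_filter)]

def l1_penalty(values, scale):
    return scale * sum(abs(v) for v in values)

def l2_penalty(values, scale):
    return scale * sum(v * v for v in values)

def nonneg(values):
    """Project onto w >= 0 (the effect a NonNeg constraint enforces)."""
    return [max(v, 0.0) for v in values]

def unit_norm(values):
    """Rescale to unit L2 norm (assumed whole-layer here, see lead-in)."""
    norm = sum(v * v for v in values) ** 0.5
    return [v / norm for v in values] if norm > 0 else values

conv_names = select(weights, '*conv*')  # ['conv1', 'conv2']
fc_names = select(weights, '*fc*')      # ['fc1', 'fc2']

# Regularizers: extra terms added to the data loss.
penalty = (sum(l1_penalty(weights[n], 1e-6) for n in conv_names)
           + sum(l2_penalty(weights[n], 1e-6) for n in fc_names))

# Constraints: applied to the weights themselves after an update step.
for n in conv_names:
    weights[n] = nonneg(weights[n])
for n in fc_names:
    weights[n] = unit_norm(weights[n])
```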

Example of the progress bar:


Just as an update, I implemented the following callbacks, which I know people have been asking for:

  • ModelCheckpoint - saves model weights during training
  • EarlyStopping - terminates training if the loss doesn’t improve
  • LearningRateScheduler - schedules the LR according to the current epoch/LR/loss
  • ReduceLROnPlateau - reduces the LR if the loss doesn’t improve
  • CSVLogger - logs train/val loss and other metrics to a CSV file during training

To note, these can all be used in manual training by instantiating the class and calling the appropriate function, e.g. callback.on_batch_begin() or callback.on_epoch_begin().
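In a hand-written loop, using a callback manually just means calling its hooks at the matching points. A sketch of where each call goes, assuming Keras-style hook names and a logs dict (only on_batch_begin(), on_epoch_begin() and LambdaCallback’s on_train_end(logs) appear in this thread; the other names and the exact signatures are assumptions). RecordingCallback and train_step are hypothetical stand-ins so the loop runs without PyTorch:

```python
class RecordingCallback:
    """Hypothetical callback that records which hooks fired."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append('train_begin')

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f'epoch_begin {epoch}')

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end {epoch} loss={logs['loss']:.2f}")

    def on_train_end(self, logs=None):
        self.events.append('train_end')


def train_step(batch):
    """Hypothetical stand-in for forward/backward/optimizer.step(); returns a loss."""
    return sum(batch) / len(batch)


def manual_fit(batches, num_epochs, callbacks):
    for cb in callbacks:
        cb.on_train_begin(logs={})
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch, logs={})
        epoch_loss = 0.0
        for i, batch in enumerate(batches):
            for cb in callbacks:
                cb.on_batch_begin(i, logs={})
            loss = train_step(batch)
            epoch_loss += loss
            for cb in callbacks:
                cb.on_batch_end(i, logs={'loss': loss})
        logs = {'loss': epoch_loss / len(batches)}
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs=logs)
    for cb in callbacks:
        cb.on_train_end(logs={})


recorder = RecordingCallback()
manual_fit([[1.0, 3.0], [2.0, 2.0]], num_epochs=2, callbacks=[recorder])
print(recorder.events)
```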

Happy to answer any questions or take requests.

Examples of all the callbacks:

from torchsample.callbacks import ModelCheckpoint
callbacks = [ModelCheckpoint(file='/users/ncullen/desktop/test/model_{epoch}_{loss}.pt',
                             # … (remaining arguments not shown)
                             )]

from torchsample.callbacks import CSVLogger
callbacks = [CSVLogger(file='/users/ncullen/desktop/test/logger.csv',append=True)]

from torchsample.callbacks import EarlyStopping
callbacks = [EarlyStopping(monitor='val_loss',
                           # … (remaining arguments not shown)
                           )]

from torchsample.callbacks import LearningRateScheduler
save_lrs = []
def lr_schedule(epoch, lr, **kwargs):
    """exponential decay"""
    new_lr = lr[0] * 1e-5**(epoch / 200)
    return new_lr
callbacks = [LearningRateScheduler(lr_schedule)]
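The schedule above decays the learning rate exponentially: at epoch e it returns lr[0]·(1e-5)^(e/200), so the rate falls by a factor of 10^5 every 200 epochs. The lr argument is indexed like a list (presumably one value per optimizer parameter group; that is an inference from lr[0], not confirmed). Evaluating the same function for a few epochs, with a one-element list [1.0] as an assumed input:

```python
def lr_schedule(epoch, lr, **kwargs):
    """exponential decay (same formula as the snippet above)"""
    new_lr = lr[0] * 1e-5**(epoch / 200)
    return new_lr

# epoch 0 -> 1.0, epoch 100 -> about 3.16e-3, epoch 200 -> 1e-05
for epoch in (0, 100, 200):
    print(epoch, lr_schedule(epoch, [1.0]))
```

If lr holds the current rather than the initial learning rate, applying this every epoch compounds the decay, so the result depends on what the callback passes in.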

from torchsample.callbacks import ReduceLROnPlateau
callbacks = [ReduceLROnPlateau(monitor='val_loss',
                               # … (remaining arguments not shown)
                               )]
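EarlyStopping and ReduceLROnPlateau rest on the same bookkeeping: track the best value of the monitored loss and count epochs without improvement. A minimal sketch of that logic, using Keras-style parameter names (patience, min_delta, factor, min_lr) as assumptions rather than torchsample’s actual signatures; PlateauTracker is a hypothetical helper:

```python
class PlateauTracker:
    """Counts consecutive epochs in which the monitored loss fails to improve."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.wait = 0

    def update(self, loss):
        """Record one epoch's loss; return True once patience is exhausted."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.50]

# Early stopping: stop the first time the tracker reports a plateau.
stopper = PlateauTracker(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.update(loss):
        stopped_at = epoch
        break

# LR reduction: shrink the LR by `factor` (bounded by `min_lr`), then reset
# the counter so training gets `patience` more epochs at the new rate.
lr, factor, min_lr = 0.1, 0.5, 1e-4
reducer = PlateauTracker(patience=3)
for loss in val_losses:
    if reducer.update(loss):
        lr = max(lr * factor, min_lr)
        reducer.wait = 0
```

With these losses, training would stop at epoch 5, while the LR variant halves the rate once at epoch 5 and then sees an improvement at epoch 6.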

Hi Nick,

Thanks for your contribution! The code is pretty neat!

  1. Have you chosen a license? I’d love to contribute to it.

If I may suggest some features I’d like to have re: model trainer:

  1. GPU support
  2. Ability to use PyTorch pretrained models (SuperModule can’t be used as a subclass for the official VGG, ResNet, etc. models)
  3. Accuracy logging

I’ve already implemented 1 and 2 in my fork. I’d be happy to send a PR if you’re OK with those features.



Thanks, yes haha, I’ll add a license right now.

For features:

  • GPU support is easy and I’ll implement it now, but I’ll look at yours if you already did it.
  • pretrained models are not my area, but I’m happy to accept your code
  • yes, accuracy logging and other metrics are definitely on my to-do list… will probably implement accuracy in the next few days

Note that a lot of the code is changing rapidly (for instance, I just removed th_gather2d/th_gather3d in favor of th_gather_nd… and made th_meshgrid work for any number of dimensions, e.g. th_meshgrid(2,3,4)), and I have been adding a lot to the transforms. Hopefully the SuperModule will remain quite stable, though, so I’m definitely happy to have contributions there.

By point 2, I meant being able to use models written by someone else, including PyTorch’s pretrained models. In that case, I don’t think inheritance is possible, as it would involve changing PyTorch’s code. I’ve replaced the inheritance strategy with object composition by adding a class field _model to SuperModule/Trainer, which can be used for training and inference.
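Composition here means the trainer holds a reference to an arbitrary model instead of being its base class, so any callable (e.g. a torchvision pretrained network) can be trained without changing its class. A minimal, framework-free sketch of that shape: the _model field follows the description above, but the methods and the toy Doubler model are hypothetical, and a real version would wrap an nn.Module and run an optimizer:

```python
class ModuleTrainer:
    """Wraps any model object instead of inheriting from it."""

    def __init__(self, model):
        self._model = model  # e.g. a pretrained nn.Module in practice

    def predict(self, inputs):
        # Delegate the forward pass to the wrapped model.
        return [self._model(x) for x in inputs]

    def evaluate(self, inputs, targets, loss_fn):
        # Mean loss over the samples, computed from the wrapped model's outputs.
        preds = self.predict(inputs)
        return sum(loss_fn(p, t) for p, t in zip(preds, targets)) / len(targets)


class Doubler:
    """Toy stand-in for a third-party model we cannot subclass."""

    def __call__(self, x):
        return 2 * x


trainer = ModuleTrainer(Doubler())
print(trainer.predict([1, 2, 3]))                                   # [2, 4, 6]
print(trainer.evaluate([1, 2], [2, 5], lambda p, t: (p - t) ** 2))  # 0.5
```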

I’ll push my code to Github so you can have a look at my changes. Let me know what you think.
Happy to see the code evolve so quickly :slight_smile:

Thanks again!


This is what I meant by object composition:

PS: I’ve also replaced the DataLoader and TensorDataset imports with PyTorch’s implementations (I manually create DataLoaders with pin_memory=True). Not sure whether this affects anything else in your repo, though, and I’m also not sure what the differences are between PyTorch’s DataLoader and yours.

Edit 2: I’ve implemented accuracy logging. I haven’t entirely finished testing it, but I’d love to know if this is what you had in mind.

Happy to use a different approach that doesn’t break backwards compatibility.


If you look, I’ve changed a lot of the code: it now supports multiple inputs and targets, optional targets, and CUDA. I also made it so fit() doesn’t default to fit_loader().

I LOVE the metrics class… I’m thinking about how to integrate it and how to deal with the History callback.

Also, breaking compatibility is not an issue haha… this is mostly code for my own shit that I hacked together in a few days about two weeks ago… compatibility SHOULD be broken.

For ModuleTrainer, I don’t like the way you have to pass a model in to the trainer… It just adds another layer of composition (like Dataset and DataLoader), which I think is cumbersome. I’d be willing to add ModuleTrainer as an ADDITIONAL class, so people could either choose SuperModule as a drop-in for nn.Module or use ModuleTrainer as an extra layer on top of their actual nn.Module class. Do you think that would add any benefit? I’m happy to merge ModuleTrainer and let you develop it separately and concurrently for now.


I rather see the code changing fast than not at all :stuck_out_tongue:

Thanks for adding CUDA support. I’ll try it out as soon as a model I’m working on finishes training.

Happy you like the Metrics class! I saw that you’ve implemented it, so I guess no need for a PR? It also looks like you’re handling the interaction with the History callback?

Everything moves so fast in ML that I bet all I do now will be outdated a year from now, so that kinda validates your point about backwards compatibility, haha.

I don’t think object composition is wrong per se (I’m actually a big fan haha). But I see your point about it being cumbersome to pass a model to the trainer. The issue I have, though, is that I can’t use SuperModule, as I’m using pretrained models from the PyTorch vision repo that do not extend SuperModule. That’s why passing a model to ModuleTrainer is my only option. I’ve also thought of merging SuperModule’s and a Module’s methods at runtime, but this seems like a hacky idea and probably not worth it.

Adding an extra class sounds like a good compromise. However, if we go that way, I’d love to have all training code in a single class, e.g. TrainerModule implements all the fit, evaluate, etc. methods and SuperModule simply forwards those calls to a TrainerModule (maybe stored as a class attribute of SuperModule). Less room for errors. Is this what you had in mind?
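The proposal amounts to delegation: the training engine lives in one class, and SuperModule only forwards calls to it. A rough sketch under that assumption; the class names mirror the discussion, but the method bodies are placeholders, and a real SuperModule would also subclass nn.Module (omitted so the sketch runs without PyTorch). The trainer is stored per instance rather than as a class attribute, since a class attribute would be shared by every model:

```python
class TrainerModule:
    """Single home for the training engine (fit/evaluate/predict)."""

    def __init__(self, model):
        self.model = model
        self.fit_calls = 0

    def fit(self, inputs, targets):
        self.fit_calls += 1  # placeholder for the real training loop
        return self

    def predict(self, inputs):
        return [self.model.forward(x) for x in inputs]


class SuperModule:
    """Would subclass nn.Module in practice; here it only forwards calls."""

    def __init__(self):
        # One trainer per instance so models never share training state.
        self._trainer = TrainerModule(self)

    def forward(self, x):
        raise NotImplementedError

    def fit(self, *args, **kwargs):
        return self._trainer.fit(*args, **kwargs)

    def predict(self, *args, **kwargs):
        return self._trainer.predict(*args, **kwargs)


class Square(SuperModule):
    def forward(self, x):
        return x * x


model = Square()
model.fit([1, 2], [1, 4])
print(model.predict([3, 4]))  # [9, 16]
```

A plain nn.Module (e.g. a pretrained network) could be handed to TrainerModule directly, which covers the composition use case with the same engine.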

I’d also like to add some tests to the metrics classes. Do you have any thoughts on how you’d structure them?


Yeah, it makes sense that the engine should only be implemented once, and I get that use case. The class you’re proposing might also make training GANs a lot easier.

Tests are obviously much needed… The metrics are sort of their own thing and only interact with the TQDM class, not History. Here’s how to use them:

from torchsample.metrics import CategoricalAccuracy

# get a prediction -> shape = (samples, classes)
y_pred = model.predict(x_train)
# y_train is ground truth  -> shape = (samples,)

acc = CategoricalAccuracy(top_k=2)
score = acc(y_pred, y_train)

So they act just like a loss function… testing them should be straightforward, then.
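Because a metric is just a function of (predictions, targets), a unit test can feed it small hand-made inputs with a known answer. A sketch in plain Python: top_k_accuracy is a hypothetical stand-in for CategoricalAccuracy(top_k=...) (torchsample’s version works on tensors and isn’t reproduced here), and the tests follow the pytest convention of plain assert statements inside test_* functions:

```python
def top_k_accuracy(y_pred, y_true, top_k=1):
    """Fraction of samples whose true class is among the top_k highest scores.

    y_pred: per-class scores, shape (samples, classes)
    y_true: class indices, shape (samples,)
    """
    hits = 0
    for scores, label in zip(y_pred, y_true):
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        if label in ranked[:top_k]:
            hits += 1
    return hits / len(y_true)


def test_top1_matches_argmax():
    y_pred = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
    # sample 0 is right (class 1 scores highest), sample 1 is wrong
    assert top_k_accuracy(y_pred, [1, 2], top_k=1) == 0.5


def test_top2_counts_runner_up():
    y_pred = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
    # class 2 is the runner-up for sample 0, class 1 for sample 1
    assert top_k_accuracy(y_pred, [2, 1], top_k=2) == 1.0
```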

Right now, I’m focusing on finishing the predict and evaluate functions, as well as what to do with the *_loader() functions because it gets a lot more complex to handle multiple inputs/targets and especially to handle an optional target.
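One way to keep multiple inputs/targets and an optional target manageable is to normalize everything to tuples up front and let None mean "no target", so the batching code has a single path. A sketch in plain Python with lists standing in for tensors; as_tuple and iter_batches are hypothetical helpers, and the convention that a tuple means "several arrays" is an assumption made for this sketch:

```python
def as_tuple(x):
    """None stays None; a tuple means 'several arrays'; anything else is one array."""
    if x is None or isinstance(x, tuple):
        return x
    return (x,)


def iter_batches(inputs, targets=None, batch_size=2):
    """Yield (input_batch_tuple, target_batch_tuple_or_None) pairs."""
    inputs = as_tuple(inputs)
    targets = as_tuple(targets)
    n = len(inputs[0])
    if any(len(arr) != n for arr in inputs + (targets or ())):
        raise ValueError('all inputs and targets must have the same length')
    for start in range(0, n, batch_size):
        stop = start + batch_size
        x = tuple(arr[start:stop] for arr in inputs)
        y = None if targets is None else tuple(arr[start:stop] for arr in targets)
        yield x, y


# one input, one target
single = list(iter_batches([1, 2, 3], [0, 1, 0]))
# -> [(([1, 2],), ([0, 1],)), (([3],), ([0],))]

# two inputs, no target (e.g. for predict)
multi = list(iter_batches(([1, 2, 3], [4, 5, 6])))
# -> [(([1, 2], [4, 5]), None), (([3], [6]), None)]
```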

Another reason I like your idea is because it might make hyper-param optimization (something like a MetaModelTrainer) much easier … some way to train a model over multiple datasets/splits, or train multiple models over a single dataset, and keep track of those experiments through a CSV-type logger would be great.
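One way to sketch the MetaModelTrainer idea: generate train/validation splits, train each model on each split, and append one CSV row per experiment. Everything here is hypothetical: kfold_indices and run_experiments aren’t part of torchsample, fake_train_and_score stands in for calling fit() and evaluate(), folds are contiguous and unshuffled for simplicity, and the CSV goes to an in-memory buffer via the standard csv module:

```python
import csv
import io


def kfold_indices(n_samples, n_folds):
    """Yield (train_idx, val_idx) for contiguous, unshuffled folds."""
    fold_sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
                  for i in range(n_folds)]
    start = 0
    for size in fold_sizes:
        val_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, val_idx
        start += size


def run_experiments(model_names, n_samples, n_folds, train_and_score, out):
    """Train every model on every fold and log one CSV row per run."""
    writer = csv.writer(out)
    writer.writerow(['model', 'fold', 'n_train', 'n_val', 'val_loss'])
    for name in model_names:
        for fold, (train_idx, val_idx) in enumerate(kfold_indices(n_samples, n_folds)):
            val_loss = train_and_score(name, train_idx, val_idx)
            writer.writerow([name, fold, len(train_idx), len(val_idx), val_loss])


def fake_train_and_score(name, train_idx, val_idx):
    """Hypothetical placeholder for model.fit(...) followed by model.evaluate(...)."""
    return round(1.0 / len(train_idx), 4)


buf = io.StringIO()
run_experiments(['small_net', 'big_net'], n_samples=10, n_folds=3,
                train_and_score=fake_train_and_score, out=buf)
print(buf.getvalue())
```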


I’ll review how other python projects structure their test code and infrastructure, there may be a few ideas there that could be useful to this project.

Maybe the non-*_loader() functions are not needed? To simplify usage, there could be some helper methods like:

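import torch
from torch.utils.data import DataLoader, TensorDataset

def loader_from_tensor(tensor, batch_size=64, shuffle=False, pin_memory=False):
    # dummy targets so TensorDataset yields (input, target) pairs
    y_unused = torch.zeros(tensor.size(0))
    loader = DataLoader(TensorDataset(tensor, y_unused), batch_size=batch_size,
                        shuffle=shuffle, pin_memory=pin_memory)
    return loader

loader = loader_from_tensor(tensor)
loss = trainer.evaluate_loader(loader)
y = trainer.predict_loader(loader)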

I like your MetaModelTrainer idea. Combined with KFold, sounds like a sweet thing to have! The CSV-like logger sounds pretty cool too.

Unrelated question: do you have any plans to merge your image transforms into pytorch/vision? I think this would avoid the need to have 2 implementations of ImageFolder and TensorDataset, and it would make your code more accessible to PyTorch users.


I’ll make an issue on github to discuss this stuff further.

Re: transforms and the datasets, no plans right now, unfortunately. I would, but I don’t know what the general plans are for and torchvision.transforms