Hello to all,
Suppose I have defined my model as a standard class that inherits torch.nn.Module.
I was wondering whether it is possible to encapsulate the training routine of my model as a method of this class, and to
perform training by creating a model object and then calling that method.
In other words, is the following possible (or good practice)?
""" Perform training for a number of epochs """
my_model = MyModel()
Thanks a lot
Yes, it is possible. In general, nearly everything that can be done with classes can be done by inheriting
torch.nn.Module and defining the missing things yourself.
Whether it is good practice is hard to tell. If it helps, I can assure you that I have done something similar quite often, as it is one of the easiest ways to define an API. Another approach would be to only define a closure for each network. I’ve done this to write a generic trainer class.
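As a rough illustration of the closure idea, here is a minimal sketch under one possible reading of "closure": a plain function that takes the model and a batch and returns the loss, so the trainer itself stays network-agnostic. The names `Trainer` and `regression_closure` are hypothetical and not taken from any library or from the poster's code.

```python
import torch
import torch.nn as nn


class Trainer:
    """Hypothetical generic trainer: it knows nothing about the network itself."""

    def __init__(self, model, optimizer, closure):
        self.model = model
        self.optimizer = optimizer
        self.closure = closure  # callable: (model, batch) -> scalar loss tensor

    def fit(self, batches, epochs=1):
        self.model.train()  # switch to training mode
        last_loss = None
        for _ in range(epochs):
            for batch in batches:
                self.optimizer.zero_grad()
                loss = self.closure(self.model, batch)  # network-specific part
                loss.backward()
                self.optimizer.step()
                last_loss = loss.item()
        return last_loss


def regression_closure(model, batch):
    """Hypothetical closure for a plain regression network."""
    inputs, targets = batch
    return nn.functional.mse_loss(model(inputs), targets)


torch.manual_seed(0)
net = nn.Linear(2, 1)
before = net.weight.detach().clone()
trainer = Trainer(net, torch.optim.SGD(net.parameters(), lr=0.1), regression_closure)
final_loss = trainer.fit([(torch.randn(4, 2), torch.randn(4, 1))], epochs=2)
```

With this split, supporting a different network would only mean writing another closure; the loop in `Trainer.fit` stays the same.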
Hello, I am new to the community. Excited to be here.
As I am a new PyTorch user facing a similar problem, could someone provide an example of what the train method would look like inside the class?
For example, in my code below I call self.forward() in the fit method, but is that correct? Calling self.model throws the following error: TypeError: forward() takes 1 positional argument but 2 were given
    def __init__(self, h=[500, 300, 100, 300, 500], x_dim=None, dropout=0.1, lr=0.1):
        super().__init__()  # required before assigning submodules
        self.x_dim = x_dim
        self.h = h
        self.n_layers = len(h)
        self.dropout = dropout
        self.lr = lr
        self.model = nn.ModuleDict()
        x = self.x_dim  # input width of the first layer
        for i in np.arange(self.n_layers):
            n_nodes = self.h[i]
            if i == int(self.n_layers/2):
                self.model['latent_layer'] = nn.Linear(x, n_nodes, bias=True)
            else:
                self.model['hidden%d'%i] = nn.Linear(x, n_nodes, bias=True)
            x = n_nodes
        self.output = nn.Linear(x, self.x_dim, bias=True)

    def forward(self, x):
        for name, layer in self.model.items():
            x = F.relu(layer(x))
        output = self.output(x)  # linear output
        return output

    def fit(self, loader, epochs=10, batch_size=128):
        loss = RMSELoss()
        opt = self.optimizer()
        n_minibatches = len(loader)
        print('training on %d samples'%(n_minibatches*batch_size))
        for epoch in range(epochs):
            epoch_loss = 0
            for batch_x, batch_y in loader:
                prediction = self.forward(batch_x.float())
                batch_loss = loss(prediction, batch_y.float())
                opt.zero_grad()  # clear gradients for next train
                batch_loss.backward()  # backpropagation, compute gradients
                opt.step()  # apply gradients
                epoch_loss += batch_loss.item()
            print('[Epoch %d] loss: %.3f'%(epoch+1, epoch_loss/n_minibatches))
Any help would be appreciated!
Thank you for your time
You could call
self(batch_x.float()) directly, which would be similar to the
model(data) call, as it will call into
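To make that concrete, here is a minimal sketch of a model that trains itself through a `fit` method and calls `self(...)` rather than `self.forward(...)`. Calling the module goes through `nn.Module.__call__`, which runs any registered hooks before dispatching to `forward`. `TinyRegressor`, its layer sizes, and the toy data are all assumptions for illustration, not the network from the post above. Note that the method is named `fit` on purpose: `nn.Module` already defines `train(mode=True)` to switch training mode, so a method called `train` would shadow it.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class TinyRegressor(nn.Module):
    """Hypothetical model used only to illustrate the pattern."""

    def __init__(self, in_dim=4, hidden=8, lr=0.01):
        super().__init__()  # must run before submodules are assigned
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )
        self.lr = lr

    def forward(self, x):
        return self.net(x)

    def fit(self, loader, epochs=3):
        loss_fn = nn.MSELoss()
        opt = torch.optim.SGD(self.parameters(), lr=self.lr)
        self.train()  # nn.Module.train(): enable training mode
        history = []
        for _ in range(epochs):
            epoch_loss = 0.0
            for batch_x, batch_y in loader:
                prediction = self(batch_x.float())  # goes through __call__
                batch_loss = loss_fn(prediction, batch_y.float())
                opt.zero_grad()
                batch_loss.backward()
                opt.step()
                epoch_loss += batch_loss.item()
            history.append(epoch_loss / len(loader))  # mean loss per batch
        return history


torch.manual_seed(0)
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)  # toy target: sum of the inputs
loader = DataLoader(TensorDataset(x, y), batch_size=16)
model = TinyRegressor()
history = model.fit(loader, epochs=3)
```

As for the TypeError in the question: `self.model` there is an `nn.ModuleDict`, which is only a container and does not define its own `forward`, so it cannot be called on a batch the way the outer module can.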
I’m quite new to pytorch and I have a question, regarding this topic.
If, as I read from @justusschock and @ptrblck, one could implement a training routine inside the model definition class, is there a reason why you generally see only the
forward function defined and not the
evaluate, etc.? Is it related somehow to performance?
I think it depends on your coding style and which abstractions you would like to use.
E.g. I personally prefer to isolate the model, optimizer, loss function, and write a training routine using these objects.
You could pass the optimizer and loss function to the model and use a class function for training, if you think making these objects part of the model is clearer.
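For the first style (model, optimizer, and loss function kept as separate objects), a minimal sketch could look like the following. The function name `train_one_epoch` and the toy data are assumptions for illustration; only standard `torch.nn` and `torch.optim` calls are used.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one pass over `loader` and return the mean batch loss."""
    model.train()
    total = 0.0
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


torch.manual_seed(0)
model = nn.Linear(3, 1)  # the model knows nothing about training
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()
# Any sized iterable of (input, target) batches works; here, a list of two batches.
loader = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(2)]
losses = [train_one_epoch(model, loader, optimizer, loss_fn) for _ in range(3)]
```

Keeping the pieces separate makes it easy to swap the optimizer or loss without touching the model class, which is probably one reason this layout is so common in examples.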
@ptrblck I am trying to do this now, do you have any examples? I’m having some serious trouble calling an instance of my model in the training script and giving it the attributes I’d like
I found this guy: DiscoGAN/image_translation.py at master · ptrblck/DiscoGAN · GitHub
Instead could you write this as a class
Train(model) where in the constructor
self.args = args just gets the arguments from
argparse and you can then write a method that uses a similar training loop to the one you have?
model can be any of the objects in
I am trying to do this, and keep getting an AttributeError after calling an instance of my model using
super().__init__(args) within the constructor of my
I don’t have a code snippet ready, as I’m not using the approach of putting everything in a “global” class.
What kind of errors are you seeing using your approach and could you post a code snippet, which would reproduce this error?