I’m using multiprocessing to train a model. I have two processes taking batches from queues (I’m actually using mp.Array objects, which are more efficient but a little trickier) and sharing my model weights.
Every ten batches, each process runs a validation iteration. In a trainBatch() method and in a valBatch() method, I placed calls to model.train() and model.eval() respectively. They go something like this:
    import torch

    def trainBatch(model, input, criterion, target, optimizer):
        optimizer.zero_grad()
        model.train()  # switch to training mode before the forward pass
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        return loss

    def valBatch(model, input, criterion, target):
        model.eval()  # switch to evaluation mode before the forward pass
        with torch.no_grad():  # no gradients needed for validation
            output = model(input)
            loss = criterion(output, target)
        return loss
I’ve observed that without those calls, so that every iteration (be it train or validation) runs with train(True), the metrics are “rational”. But when I enable those calls, they become really weird, as can be seen here: (out_var and out_var_val are the variances of the logits, which I like to plot for diagnosis)
(orange: the curves with the calls to model.eval() and model.train(); blue: without those calls)
I wonder whether the problem stems from the fact that, in such a framework, the two opposite calls can conflict on the layers. Almost always, model.eval() will be called by one process while the other is in the middle of a training iteration.
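To make it concrete what those two calls actually touch, here is a minimal single-process sketch. The small nn.Sequential model is only a stand-in I made up for illustration, not my real network. As far as I can tell, train() and eval() do nothing more than flip a plain Python boolean, `training`, on the module and on every submodule:

```python
import torch.nn as nn

# Hypothetical toy model, only here to show what train()/eval() change.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

model.eval()
# modules() yields the container itself followed by its two children.
flags_eval = [m.training for m in model.modules()]

model.train()
flags_train = [m.training for m in model.modules()]

print(flags_eval)   # flags after eval()
print(flags_train)  # flags after train()
```

So whichever of the two calls ran last on a given module object decides how its next forward pass behaves. Whether my two processes really see each other's flag is exactly what I'm unsure about.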
I wonder whether this screws up the training of the BatchNorm layers entirely, since the validation metrics seem to be so poor. Or maybe calling model.eval() while a training iteration is running deactivates BatchNorm entirely without altering the learning?
What do you think is happening? Is validation during training prohibited in multiprocessing?
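For reference, this is a small sketch of how I understand a BatchNorm layer to differ between the two modes, in a single process and with a made-up input batch (the layer size, the seed and the data are assumptions for illustration only). In training mode the layer normalises with the statistics of the current batch and updates its running estimates; in evaluation mode it normalises with the running estimates and leaves them untouched:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)            # running_mean starts at 0, running_var at 1
x = torch.randn(8, 3) * 5 + 10    # made-up batch, far from those initial values

bn.train()
out_train = bn(x)                 # normalised with the batch statistics
mean_after_train = bn.running_mean.clone()  # running mean moved towards the batch mean

bn.eval()
with torch.no_grad():
    out_eval = bn(x)              # normalised with the running statistics
mean_after_eval = bn.running_mean.clone()   # not modified by the eval pass

print(out_train.mean(0))  # per-feature mean of the train-mode output
print(out_eval.mean(0))   # per-feature mean of the eval-mode output
```

If the running estimates are still far from the real activation statistics when a validation pass happens, the eval-mode output is scaled very differently from the train-mode output, which could be one way to end up with a strange logit variance.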