Decomposing one model into two and training them with different batch sizes


I created one model composed of two “parts” (though it is wrapped as one model). I would like to decompose this model into two models because I want to train each model with different batch sizes. So for example, the first model/part would use a batch size of 64, and the second/part one would use a batch size of 128.
For the forward pass, there is no issue because I can run the “forward” method for the first model by decomposing it into batches. The issue is the gradient. Autograd in PyTorch seems to compute (and allocate memory) according to the batch size of the second model, and I do not know how to “tell” PyTorch to treat the first part with a different batch size. Hence, it is allocating far more memory than I would like it to.

Thanks for the help!

I don’t have a suggestion on how to fix this, sorry - but I am curious, what task are you using this sort of approach for? Are the inputs to the two models different (apart from the batch size)?

Thanks for your reply! You’re right, maybe if I explain my architecture, you might realize there is a better way to do it and that I’m probably doing something “not super intelligent” :sweat_smile::

I have a sequence of time frames (nw, nh, nt) where nw and nh are the width/height of the images, and nt is the number of time samples for the sequence. What I want to do is:

(1) For a given example, I first extract some features independently for each of the nw * nh pixels by analyzing the time axis using a bunch of conv1d (+pool+etc.) layers. This is the most important step of my approach because most of the information I want to extract from the sequences of frames is carried by the time axis. So to pass one sequence of time frames of shape (nw, nh, nt), I reshape it into a 3d tensor of shape (nw * nh, 1, nt) (I need the dummy 1 on the second axis for conv1d), and I pass it through the first part of the model. So here, what I am really doing is just taking a bunch of 1d time signals and independently extracting features from them. The output of my first model is therefore something like (nw * nh, n_channels_out), where n_channels_out is the number of features I got from the time axis for each pixel.
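A minimal sketch of the reshape described in step (1), written only to make the shapes concrete. The sizes, the layer choices, and the name `part1` are assumptions for illustration, not the actual network from this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed small sizes, for illustration only.
nw, nh, nt = 8, 6, 32
n_channels_out = 4

# Hypothetical stand-in for "part 1": conv1d layers acting along the time axis.
part1 = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool1d(2),
    nn.Conv1d(8, n_channels_out, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool1d(1),  # collapse what is left of the time axis
    nn.Flatten(),             # (N, n_channels_out, 1) -> (N, n_channels_out)
)

frames = torch.randn(nw, nh, nt)          # one example: a sequence of frames
signals = frames.reshape(nw * nh, 1, nt)  # one 1d time signal per pixel
features = part1(signals)                 # (nw * nh, n_channels_out)
```

Each pixel is treated as an independent sample here, which is why the effective batch size for this part is nw * nh per example.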

(2) So far, I have not introduced any “spatial” info into my network. So I take the output of part 1, I reshape it into (nw, nh, n_channels_out), and I pass it into a conventional network made of conv2d layers (+usual batch norm, etc.), and my output is a (nw, nh, 1) tensor (and then I apply an MSE loss - I’m doing a regression). For the second part of my model, 1 example is a tensor of shape (nw, nh, n_channels_out).
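A rough sketch of step (2), assuming the per-pixel features are already available. Note that `nn.Conv2d` expects channels before the spatial axes, so the (nw, nh, n_channels_out) tensor is permuted to (1, n_channels_out, nw, nh) first. The layers, sizes, and the name `part2` are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed small sizes, for illustration only.
nw, nh, n_channels_out = 8, 6, 4

# Stand-in for the output of part 1: one feature vector per pixel.
features = torch.randn(nw * nh, n_channels_out)

# Hypothetical stand-in for "part 2": a small conv2d network.
part2 = nn.Sequential(
    nn.Conv2d(n_channels_out, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)

# (nw * nh, C) -> (nw, nh, C) -> (C, nw, nh) -> (1, C, nw, nh)
image = features.reshape(nw, nh, n_channels_out).permute(2, 0, 1).unsqueeze(0)
pred = part2(image)                 # (1, 1, nw, nh)

target = torch.randn(1, 1, nw, nh)  # placeholder regression target
loss = nn.functional.mse_loss(pred, target)
```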

The reason why I don’t just do a conv3d instead of decomposing the sequence into a conv1d-type network followed by a conv2d is because of the cost. I might be wrong, but like I said, most of the important info is obtained from the conv1d network, and the second part of the network is really to “smooth” things out a little bit.

Now my problem: when I tell PyTorch to use, say, batches of 64, it indeed uses 64 tensors of shape (nw, nh, nt). For the second part of the model, that’s ok, no problem (it allocates what I expect for the memory). However, for the first part, since I have to reshape everything into (nw * nh, 1, nt), if I have batches of “64”, in fact for part 1 it’s really a batch of 64 * nw * nh. And then PyTorch allocates a huge amount of memory for the conv1d layers. I can modify the forward method in part 1 of my model (basically introduce a for loop and decompose the 64 * nw * nh into smaller batches), so the forward does not allocate too much memory on the GPU. However, during the backprop, PyTorch does not know that I want to compute my gradient for part 1 in smaller batches, so it allocates a ridiculously huge amount of memory. It’s quite frustrating because basically the bottleneck is the conv1d part…
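One possible way to keep the chunked forward and still limit memory during backprop is activation checkpointing with `torch.utils.checkpoint.checkpoint`: each chunk's intermediate activations are discarded after the forward pass and recomputed during backward, at the cost of extra compute. The sketch below assumes a recent PyTorch version that accepts `use_reentrant=False`, and assumes part 1 contains no batch norm (chunking would change its batch statistics). The helper `chunked_forward` and the tiny `part1` are hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def chunked_forward(module, x, chunk_size):
    """Run `module` over `x` in chunks along dim 0, checkpointing each chunk.

    Intermediate activations inside `module` are not kept; they are
    recomputed chunk by chunk when backward reaches them.
    """
    outs = [
        checkpoint(module, chunk, use_reentrant=False)
        for chunk in x.split(chunk_size, dim=0)
    ]
    return torch.cat(outs, dim=0)

torch.manual_seed(0)

# Hypothetical stand-in for part 1 (no batch norm, see the note above).
part1 = nn.Sequential(
    nn.Conv1d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
)

signals = torch.randn(48, 1, 32)  # e.g. nw * nh = 48 time signals
features = chunked_forward(part1, signals, chunk_size=16)
features.sum().backward()         # placeholder loss, just to trigger backward
```

The output tensors of every chunk are still stored, but those are usually much smaller than the conv1d activations inside part 1.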

So I was thinking that if I could split my model into two (part 1 and part 2), then I could probably control the batch size for each model separately. I would need to connect the two models throughout the backprop in some way. I feel this probably exists, since GANs would probably need this type of backprop structure.
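A sketch of how two separately batched parts could be connected by hand during backprop, under some assumptions: part 1 is first run without building a graph, part 2 is backpropagated down to the detached features, and that feature gradient is then pushed through part 1 one chunk at a time. Parameter gradients accumulate across the `backward` calls. The function `two_stage_backward` and both tiny modules are hypothetical, and part 1 is assumed to have no batch norm or dropout, since it is run twice.

```python
import torch
import torch.nn as nn

def two_stage_backward(part1, part2, loss_fn, x, target, chunk_size):
    # Stage 1: chunked forward of part 1 with no graph, so no activations are kept.
    with torch.no_grad():
        feats = torch.cat([part1(c) for c in x.split(chunk_size)], dim=0)
    feats.requires_grad_(True)

    # Stage 2: part 2 on the full batch; this fills part2's grads and feats.grad.
    loss = loss_fn(part2(feats), target)
    loss.backward()

    # Stage 3: recompute part 1 chunk by chunk and inject the feature gradient.
    for chunk, grad in zip(x.split(chunk_size), feats.grad.split(chunk_size)):
        part1(chunk).backward(grad)
    return loss.detach()

torch.manual_seed(0)

# Hypothetical stand-ins for the two parts.
part1 = nn.Sequential(
    nn.Conv1d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
)
part2 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

x = torch.randn(48, 1, 32)
target = torch.randn(48, 1)
loss = two_stage_backward(part1, part2, nn.functional.mse_loss, x, target, chunk_size=16)
```

After this call, an optimizer step over the parameters of both parts should behave like a step on the end-to-end gradient, while only one chunk of part 1 activations is alive at a time.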

I was wondering if there was an easy fix to this (or maybe using graphs, etc. - but I’m still relatively new at this).