Thanks for your reply! You’re right: if I explain my architecture, you might see a better way to do this, and I’m probably doing something “not super intelligent”:
I have a sequence of time frames of shape (nw, nh, nt), where nw and nh are the width/height of the images and nt is the number of time samples in the sequence. What I want to do is:
(1) For a given example, I first extract features independently for each of the nw * nh pixels by analyzing the time axis with a stack of conv1d (+pool+etc.) layers. This is the most important step of my approach, because most of the information I want to extract from the sequence of frames is carried by the time axis. To pass one sequence of time frames of shape (nw, nh, nt), I reshape it into a 3d tensor of shape (nw * nh, 1, nt) (I need the dummy 1 on the second axis as the channel dimension for conv1d) and pass it through the first part of the model. So what I am really doing here is taking a bunch of 1d time signals and extracting features from each of them independently. The output of the first part is therefore something like (nw * nh, n_channels_out), where n_channels_out is the number of features I get from the time axis for each pixel.
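To make the shapes in step (1) concrete, here is a rough sketch of what I mean. The sizes and the layers in `part1` are made up for illustration (my real stack is deeper); only the reshape to (nw * nh, 1, nt) and the (nw * nh, n_channels_out) output are the point.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only to make the shapes concrete.
nw, nh, nt = 8, 8, 128
n_channels_out = 16

# Stand-in for the conv1d stack (the real one has more layers).
part1 = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.MaxPool1d(4),
    nn.Conv1d(8, n_channels_out, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),  # collapse the time axis: (N, n_channels_out, 1)
    nn.Flatten(),             # (N, n_channels_out)
)

frames = torch.randn(nw, nh, nt)          # one example
signals = frames.reshape(nw * nh, 1, nt)  # one 1d time signal per pixel
features = part1(signals)                 # each pixel is processed independently
print(features.shape)  # torch.Size([64, 16])
```

With a batch of B examples the same reshape gives (B * nw * nh, 1, nt), which is where the large effective batch size comes from.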
(2) So far, I have not introduced any “spatial” info into my network. So I take the output of part 1, reshape it into (nw, nh, n_channels_out), and pass it into a conventional network made of conv2d layers (+usual batch norm, etc.). My output is a (nw, nh, 1) tensor, to which I apply an MSE loss (I’m doing a regression). For the second part of my model, 1 example is a tensor of shape (nw, nh, n_channels_out).
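A minimal sketch of step (2), again with made-up sizes and layers. One detail I glossed over: conv2d in PyTorch wants channels first, so in practice the (nw, nh, n_channels_out) tensor gets permuted to (n_channels_out, nw, nh), and the output comes out as (1, nw, nh) per example.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
batch, nw, nh, n_channels_out = 4, 8, 8, 16

# Stand-in for the conv2d part of the model.
part2 = nn.Sequential(
    nn.Conv2d(n_channels_out, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)

# Random stand-in for the output of part 1 on a batch of examples.
features = torch.randn(batch * nw * nh, n_channels_out)

# Put the pixels back on the image grid, then move channels first.
grid = features.reshape(batch, nw, nh, n_channels_out).permute(0, 3, 1, 2)

pred = part2(grid)                      # (batch, 1, nw, nh)
target = torch.randn(batch, 1, nw, nh)  # random regression target
loss = nn.functional.mse_loss(pred, target)
```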
The reason I don’t just use conv3d, instead of decomposing the problem into a conv1d-type network followed by a conv2d one, is the cost. I might be wrong, but like I said, most of the important info is obtained from the conv1d network, and the second part of the network is really there to “smooth” things out a little bit.
Now my problem: when I tell PyTorch to use, say, batches of 64, it indeed uses 64 tensors of shape (nw, nh, nt). For the second part of the model, that’s ok, no problem (it allocates the memory I expect). However, for the first part, since I have to reshape everything into (nw * nh, 1, nt), a batch of “64” is really a batch of 64 * nw * nh for part 1, and PyTorch allocates a huge amount of memory for the conv1d layers. I can modify the forward method in part 1 of my model (basically introduce a for loop and decompose the 64 * nw * nh signals into smaller batches), so that the forward pass does not allocate too much memory on the GPU at once. However, during the backprop, PyTorch does not know that I want to compute my gradient for part 1 in smaller batches, so it allocates a ridiculously huge amount of memory. It’s quite frustrating because basically the bottleneck is the conv1d part…
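To give an idea of the scale, here is a back-of-the-envelope estimate. All the numbers are hypothetical (not my actual sizes), and I assume a first conv1d layer with 32 output channels that keeps the time length at nt. It only counts the output of that single layer, which has to be kept around for the backward pass.

```python
# Hypothetical sizes, for illustration only.
batch, nw, nh, nt = 64, 64, 64, 256
channels = 32        # output channels of the first conv1d layer (assumed)
bytes_per_float = 4  # float32

# Number of 1d signals part 1 actually sees for one "batch of 64".
effective_batch = batch * nw * nh

# Memory for the output of that one layer, assuming the length stays nt.
activation_bytes = effective_batch * channels * nt * bytes_per_float

print(effective_batch)           # 262144
print(activation_bytes / 2**30)  # 8.0 (GiB)
```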
So I was thinking that if I could split my model into two (part 1 and part 2), then I could probably control the batch size for each model separately. I would need to connect the two models through the backprop in some way. I feel this probably exists, since GANs would probably need this type of backprop structure.
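Something like the following is what I have in mind, as a rough sketch only. The sizes, layers and `chunk` value are made up. The idea: run part 1 in chunks without a graph, treat its output as a leaf tensor for part 2, then re-run part 1 chunk by chunk and pass the stored gradient into `backward`. It costs a second forward pass through part 1, and I believe it would not match a full-batch run if part 1 contained batch norm in training mode, since the statistics would be computed per chunk.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical sizes, for illustration only.
batch, nw, nh, nt, n_feat = 2, 4, 4, 32, 8
chunk = 8  # number of pixel signals part 1 processes at a time

part1 = nn.Sequential(nn.Conv1d(1, n_feat, 5, padding=2), nn.ReLU(),
                      nn.AdaptiveAvgPool1d(1), nn.Flatten())
part2 = nn.Conv2d(n_feat, 1, 3, padding=1)

x = torch.randn(batch, nw, nh, nt)
target = torch.randn(batch, 1, nw, nh)
signals = x.reshape(batch * nw * nh, 1, nt)

# Step 1: part 1 forward in chunks, without storing anything for backward.
with torch.no_grad():
    feats = torch.cat([part1(s) for s in signals.split(chunk)])

# Step 2: part 2 forward/backward, with the features as a leaf tensor.
feats.requires_grad_(True)
grid = feats.reshape(batch, nw, nh, n_feat).permute(0, 3, 1, 2)
loss = nn.functional.mse_loss(part2(grid), target)
loss.backward()  # fills the part2 gradients and feats.grad

# Step 3: redo part 1 one chunk at a time, this time with a graph,
# and push the matching slice of feats.grad back through it.
for s, g in zip(signals.split(chunk), feats.grad.split(chunk)):
    part1(s).backward(g)  # gradients accumulate in part1's parameters
```

After step 3, one optimizer step over the parameters of both parts should behave like a normal step on the full batch, while only one chunk of part 1 activations is alive at any time. As far as I can tell, `torch.utils.checkpoint` is built on the same recompute-in-backward idea.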
I was wondering if there is an easy fix for this (or maybe something using graphs, etc. - but I’m still relatively new to this).
Thanks!