I’ve recently coded up a version of the mixed-scale dense network from http://www.pnas.org/content/pnas/early/2017/12/21/1715832114.full.pdf.
While creating the architecture was not difficult, there’s an issue with GPU usage exceeding the card memory. Using the same model size as in the paper (200 layers, 1 channel wide) and the same input image size, 512x512, running a single example for training takes up more than the 16GB on my card. The model is very useful for my application, medical imaging, because one of its features is a very low number of parameters; this model size only has 187k params. Essentially, it takes the outputs of all the layers before layer i, concatenates them along the channel dimension, and uses them as input for layer i. There’s a lot of concatenating and reflection padding happening, and I suspect I’m doing something very memory-inefficient in torch.cat() or in the other definitions of my modules/graph.
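For what it’s worth, here’s my back-of-envelope estimate, under the assumption that each layer materializes a fresh concatenated copy of all earlier feature maps (which is what I believe a naive torch.cat per layer does). If that’s right, the concat copies alone could account for the blow-up:

```python
# Rough activation-memory estimate for the dense concatenation pattern
# described above. Assumptions: 512x512 feature maps, 1 channel per layer,
# float32, batch size 1, and layer i builds a fresh concatenated tensor
# holding all i earlier single-channel maps.
H = W = 512
BYTES = 4      # float32
LAYERS = 200

per_channel = H * W * BYTES    # 1 MiB per single-channel 512x512 map

# The individual layer outputs that autograd must keep for backward:
linear = LAYERS * per_channel

# If every layer also materializes a concatenated copy of all earlier
# outputs, those copies sum to 1 + 2 + ... + LAYERS channel-maps:
quadratic = sum(i for i in range(1, LAYERS + 1)) * per_channel

print(f"stored layer outputs alone: {linear / 2**30:.1f} GiB")    # ~0.2 GiB
print(f"plus concat copies:         {quadratic / 2**30:.1f} GiB") # ~19.6 GiB
```

So the layer outputs themselves are tiny, but 200 growing concat copies add up to roughly 20 GiB, which would match what I’m seeing on a 16GB card.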
Does anyone have an idea of the zeroth-order thing I should look at to find where this huge memory usage is coming from?