I’m building a tool for training vision models. The main focus is to make everything “just work”: an easy-to-use CLI binary for running inference and training models, without needing to tweak many different model/training parameters.
Problem
The problem I’m having is that I can’t find a good way to load datasets. The default loader is great, but it doesn’t “adapt” well to different model sizes and system specs: depending on those, the training device will go OOM.
Maybe there’s a formula to calculate the optimal parameters for the loader, but I haven’t been able to figure it out.
Another thing is that when loading the dataset, some extra processing of the tensors (resizing, masking) is required before they can be used in training. I figured caching the tensors after they are processed would be a good idea, so I tried implementing my own dataset loader.
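Roughly what I mean, as a stripped-down sketch (the class name, the fixed resize target, and the in-memory dict cache are placeholders, not my actual implementation):

```python
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms.functional import resize

class CachedImageDataset(Dataset):
    """Preprocess each image once (resize etc.), then serve it from an in-memory cache."""

    def __init__(self, image_paths, size=(256, 256)):
        self.image_paths = image_paths
        self.size = size
        self.cache = {}  # index -> already-processed tensor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if idx not in self.cache:
            img = read_image(self.image_paths[idx]).float() / 255.0  # C x H x W, uint8 -> float
            img = resize(img, list(self.size))                       # the "extra processing" step
            self.cache[idx] = img                                    # cache the processed tensor
        return self.cache[idx]
```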
Attempt
The dataset loader I came up with has two caching methods: in-memory, and chunked on disk. If the device memory has enough space, it loads all the processed tensors into device memory. Otherwise it splits them into chunks (the chunk size is also calculated from the available memory on the device), saves them to disk, and loads them chunk by chunk during training.
But for some reason the chunked-on-disk caching method causes OOM when loading a chunk, even though the memory should have enough space for it (based on the size of the .pth chunk file). From my research, .pth files may do some compression when saving, but the chunk capacity is calculated from the width, height, channels and dtype of the tensors (along with the available memory on the device), so there shouldn’t be a calculation error there.
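For reference, the chunk sizing is essentially this kind of arithmetic (a simplified sketch; the function name and the headroom factor are made up for illustration):

```python
import torch

def chunk_capacity(height, width, channels, dtype=torch.float32, headroom=0.8):
    """How many processed tensors fit in the free device memory, with a safety margin.

    `headroom` is an arbitrary safety factor; the real chunking would also account for
    whatever else is resident on the device at load time.
    """
    free_bytes, _total_bytes = torch.cuda.mem_get_info()  # free/total memory on the current CUDA device
    bytes_per_sample = height * width * channels * torch.empty((), dtype=dtype).element_size()
    return int(free_bytes * headroom) // bytes_per_sample
```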
Conclusion
So I guess my questions are:
Can the default dataset loader achieve this level of adaptability, so that training uses as many resources as the system provides without crashing?
Is this even a reasonable approach? (I’m aware that configuring the model/training parameters automatically might not lead to the best model accuracy.)
The vanilla DataLoader class loads images “lazily” for each batch, not all at once. If you’re getting OOM, it’s likely not due to the DataLoader, but to some combination of batch size, model size, image size, optimizer, dtype, etc.
Indeed, that’s why I’m trying to find a way to make the DataLoader use as many resources as possible without causing OOM: using the model size, image size, optimizer, dtype, etc. to calculate the parameters passed to the DataLoader (if that’s even possible).
Sounds like you just want something to dynamically set the batch size.
It’s possible but just a bit challenging.
You’ll want to build a lookup table of all the possible optimizers and the multiple of the parameter count each needs as state overhead, along with possible variants. For example, RMSProp without momentum or centering is something like 1x overhead, while RMSProp with momentum and centering is roughly 3x overhead.
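Something like this (the multipliers are my reading of the torch.optim implementations, so double-check them against the optimizer you actually use):

```python
# Rough per-parameter state overhead for common optimizers, as a multiple of the
# model's parameter count (gradients add another 1x on top of these).
OPTIMIZER_STATE_OVERHEAD = {
    "sgd": 0.0,                        # no extra state
    "sgd_momentum": 1.0,               # momentum buffer
    "rmsprop": 1.0,                    # square_avg
    "rmsprop_momentum_centered": 3.0,  # square_avg + momentum buffer + grad_avg
    "adam": 2.0,                       # exp_avg + exp_avg_sq
    "adamw": 2.0,
}

def optimizer_state_bytes(param_count: int, optimizer_key: str, bytes_per_param: int = 4) -> int:
    """Extra device memory the optimizer state needs, on top of the parameters themselves."""
    return int(param_count * OPTIMIZER_STATE_OVERHEAD[optimizer_key] * bytes_per_param)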
You’ll need to determine the dtype.
You’ll need to iterate over each Conv2d layer and calculate the following as a baseline footprint: dtype × [(o × c × z × z + o) + (b × c × x × y) + (b × o × x_out × y_out) + (b × c × z × z × x_out × y_out)], where o is out_channels, c is in_channels, z is the kernel size on one dim, x and y are the layer’s local input size on each dim, x_out and y_out are the layer’s output size on each dim, and b is the batch size. Dtype is how many bytes are needed per number, e.g. 4 for float32. That doesn’t account for stride, padding, dilation, etc.
Of course, you’re looking to optimize for b (the batch size), so you’ll need to do some algebra to solve for it, and then find the largest b for which the total memory footprint still fits in the available memory.
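A rough sketch of that algebra in code, under the same simplifying assumptions as the formula (a plain stack of Conv2d layers, stride 1, no padding, and no optimizer state or framework overhead counted):

```python
import torch.nn as nn

def max_batch_size(model, input_hw, memory_budget_bytes, bytes_per_elem=4):
    """Solve the per-layer formula above for b.

    Every term except the weights scales linearly with b, so
        footprint(b) = fixed_bytes + b * per_sample_bytes
    and the largest b is (budget - fixed) // per_sample. Treat the result as a
    rough upper bound at best.
    """
    x_dim, y_dim = input_hw
    fixed_elems = 0        # weight + bias elements (independent of b)
    per_sample_elems = 0   # activation/workspace elements (scale with b)

    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            c, o, z = layer.in_channels, layer.out_channels, layer.kernel_size[0]
            x_out, y_out = x_dim - z + 1, y_dim - z + 1    # stride=1, no padding assumed
            fixed_elems += o * c * z * z + o               # weights + bias
            per_sample_elems += c * x_dim * y_dim          # layer input
            per_sample_elems += o * x_out * y_out          # layer output
            per_sample_elems += c * z * z * x_out * y_out  # unfolded-input/workspace term
            x_dim, y_dim = x_out, y_out                    # next layer sees this output size

    if per_sample_elems == 0:
        raise ValueError("No Conv2d layers found; this sketch only handles Conv2d")
    free_elems = memory_budget_bytes // bytes_per_elem - fixed_elems
    return max(free_elems // per_sample_elems, 0)
```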
Because it’s quite a convoluted calculation, most people just find the max batch size through trial and error, while also keeping it from exceeding 256, or you run into other problems with effective learning. See “On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima” by Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang (2016).
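The trial-and-error version is also easy to automate: start at 256 and halve on every CUDA OOM. A sketch, where `run_one_step(batch_size)` stands in for one forward/backward/optimizer step of your training loop:

```python
import torch

def find_max_batch_size(run_one_step, start=256):
    """Halve the batch size until one full training step fits on the device."""
    batch_size = start
    while batch_size >= 1:
        try:
            run_one_step(batch_size)
            return batch_size
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()  # release what the failed attempt allocated
            batch_size //= 2
    raise RuntimeError("Even batch_size=1 does not fit on the device")
```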
The calculation does seem pretty convoluted. After doing some more research I found a package called torchinfo, which can roughly estimate the forward/backward pass size. With its help, I also figured out why my implementation of the loader was causing OOM: the forward/backward pass size exceeds the memory capacity of the device.
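For reference, the torchinfo call I used looks roughly like this (the model and input size here are stand-ins):

```python
from torchinfo import summary
import torchvision

model = torchvision.models.resnet18()                            # stand-in for the actual model
stats = summary(model, input_size=(32, 3, 224, 224), verbose=0)  # (batch, channels, height, width)

# The report includes the parameter size and an estimated forward/backward pass size;
# the latter is what was blowing past the device memory in my case.
print(stats)
```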
After some testing I think my loader implementation has more potential to adapt and perform better across different systems.