Issue with using ramdisk and map_location w/ torch.load()

I’m using a Colab notebook to generate samples from OpenAI’s Jukebox model. The notebook includes some code to reduce the amount of system RAM used when loading the model. With the default model provided by OpenAI, it generates audio just fine. With my finetuned model, the same notebook and code generate “garbage” audio. However, if I remove the RAM-saving code, the finetuned model works fine.

So, given the same notebook and code, whether it works depends only on how the checkpoint is loaded (see the function below).
It works if loaded via:

checkpoint = t.load(restore, map_location=t.device('cpu'))

and doesn’t with this code:

checkpoint = t.load(restore)

In the non-working version, the model is loaded via:

checkpoint = t.load(restore, map_location=memory_map)

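memory_map_idx = 0
def memory_map(storage, location):
    # Back each storage with its own file under disk_tensors/
    global memory_map_idx
    memory_map_idx += 1
    s = 'disk_tensors/' + str(memory_map_idx) + '.bint'
    # Preallocate the file: seek to the last byte and write a single zero
    f = open(s,"wb")
    f.seek(storage.size()*storage.element_size()-1)
    f.write(b"\0")
    f.close()
    # Map the file as a shared storage with the same number of elements
    new_storage = storage.__class__.from_file(s, 1, storage.size())
    del storage
    return new_storage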
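
For reference, memory_map backs each storage with a file under disk_tensors/. One common way to reserve a file of a given size is to seek to the last byte and write a single zero. The sketch below uses only the standard library (no torch) and is an illustration under assumptions, not the notebook's exact code: `preallocate` and `demo` are hypothetical names, and the 1024 x 4-byte size is just an example. It suggests that a file reserved this way has the requested size and reads back as all zeros until something is written into it.

```python
import mmap
import os
import tempfile


def preallocate(path, num_bytes):
    """Create a file of exactly num_bytes (must be > 0) by seeking to the
    last byte and writing a single zero byte."""
    with open(path, "wb") as f:
        f.seek(num_bytes - 1)
        f.write(b"\0")


def demo():
    """Reserve a file for 1024 four-byte elements and inspect its contents."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "1.bint")  # hypothetical file name
        num_elements, element_size = 1024, 4  # example sizes only
        preallocate(path, num_elements * element_size)
        size = os.path.getsize(path)
        # Memory-map the whole file and check whether every byte is zero
        with open(path, "r+b") as f:
            with mmap.mmap(f.fileno(), 0) as mapped:
                all_zero = mapped[:] == b"\0" * size
        return size, all_zero


if __name__ == "__main__":
    print(demo())
```

So whatever ends up in a storage mapped onto such a file has to be written into it after the file is created.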

I’ve tried a bunch of different things, all with the same results. I’m hoping someone has ideas about what I could try or what the issue might be.

Things I’ve tried:

  • Many different versions of PyTorch.
  • I’m not running in a parallel environment, but I tried handling concurrency just in case. No difference.
  • Removed the logic that deleted the on-disk files after torch.load(). No difference. I would have thought this was the issue, but no.