How to register a buffer without polluting the state_dict?

I have a fixed tensor that needs to be moved to the GPU according to DataParallel’s whims, but I don’t want to bloat the state_dict with it (since it never changes). I can’t figure out a good solution. [For a little more context: this tensor is the weight passed to F.conv_transpose2d.] Here are the options I’ve considered:

  1. override _apply - on a single GPU this seems to work fine, but with multiple GPUs DataParallel relies on replicate… which doesn’t use _apply and instead iterates over all buffers and parameters [the same things that go into the state_dict]
  2. register_buffer (what I’m currently using) - the code runs, but the state_dict is bloated with unnecessary crap; also, if this tensor is ever modified (accidentally or otherwise), the entire network is silently corrupted
  3. custom checkpoint-saving code that manually removes this tensor from the state_dict every single time - ugly
  4. move the tensor to the appropriate GPU during every forward pass - probably slow
  5. construct the tensor during every forward pass - probably even slower
  6. any other ideas?
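For reference, option 2 might look like the minimal sketch below. The thread never says what the fixed weight contains, so the all-ones kernel (which, with `stride=2` and `groups=channels`, turns `F.conv_transpose2d` into a nearest-neighbour 2x upsampler) and the `Upsample2x` name are hypothetical. The point is only that a regular buffer ends up as a key in `state_dict()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample2x(nn.Module):
    """Option 2: the fixed weight is a regular (persistent) buffer."""

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Hypothetical fixed kernel: one all-ones 2x2 filter per channel.
        # With stride=2 and groups=channels this is nearest-neighbour 2x upsampling.
        self.register_buffer("weight", torch.ones(channels, 1, 2, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv_transpose2d(x, self.weight, stride=2, groups=self.channels)


model = Upsample2x(channels=3)
# The buffer follows .to()/.cuda(), but it also lands in every checkpoint.
print(list(model.state_dict().keys()))  # ['weight']
```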

Just create a flag to check whether the tensor has been allocated or not. Allocate it once.
Another option is just to delete it when you stop training.
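One way to read the "allocate it once" suggestion is a lazily built weight cached per device and dtype, so each device allocates it the first time it sees an input. This is a sketch under assumptions: `LazyUpsample2x`, `_weight_cache`, and `_make_weight` are hypothetical names, and the all-ones kernel stands in for whatever the real weight is. Because the cache is a plain dict rather than a parameter or buffer, `state_dict()` never sees it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LazyUpsample2x(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Plain dict attribute, not a parameter or buffer: state_dict() skips it
        # and .to()/.cuda() do not touch it.
        self._weight_cache = {}

    def _make_weight(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # Hypothetical fixed kernel: all ones -> nearest-neighbour 2x upsampling.
        return torch.ones(self.channels, 1, 2, 2, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        key = (x.device, x.dtype)
        if key not in self._weight_cache:  # the "already allocated?" flag, per device
            self._weight_cache[key] = self._make_weight(x.device, x.dtype)
        w = self._weight_cache[key]
        return F.conv_transpose2d(x, w, stride=2, groups=self.channels)


model = LazyUpsample2x(channels=3)
y = model(torch.randn(2, 3, 4, 4))
print(y.shape)                  # torch.Size([2, 3, 8, 8])
print(len(model.state_dict()))  # 0
```

Under DataParallel the replicas are rebuilt on every forward pass, so whether the cache actually survives between calls depends on how replication copies plain attributes; that is worth verifying before relying on it.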

Anyway, I don’t think it’s bad that it stays a buffer. It’s effectively a parameter you need to run inference. What would you do when exporting the model, keep telling everyone those values? That seems a bit annoying to me.

The whole point of allocating every time is that it would allocate on the correct GPU. If it’s only allocated once, then that would be option 4). From my tests, option 4) seems to be about as fast as 2), so maybe I can just do that (assuming it works in a multi-GPU context…)
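Option 4 could be as small as keeping the weight as a plain tensor attribute and moving it to the input's device inside `forward`. `Tensor.to()` returns the tensor itself when the device and dtype already match, so a copy only happens when the input lives elsewhere. The `MovingUpsample2x` and `fixed_weight` names and the all-ones kernel are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MovingUpsample2x(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Plain tensor attribute (not registered): stays where it was created
        # and never appears in state_dict().
        self.fixed_weight = torch.ones(channels, 1, 2, 2)  # hypothetical kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns the same tensor if device and dtype already match;
        # otherwise a copy is made on every call.
        w = self.fixed_weight.to(device=x.device, dtype=x.dtype)
        return F.conv_transpose2d(x, w, stride=2, groups=self.channels)
```

Note that `model.cuda()` will not move a plain attribute, so under DataParallel every replica would copy the weight from wherever it was created on every forward pass; that per-call copy is what option 4 trades for a clean `state_dict`.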

A checkpoint is saved every epoch, so deleting it before saving is a variation of 3). I guess that’s an option.
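Option 3 and "delete it before saving" both come down to filtering the fixed keys out of the checkpoint and loading with `strict=False`. A minimal sketch, assuming the fixed tensor is still a regular buffer; `Net`, `up_weight`, `FIXED_KEYS`, `save_checkpoint`, and `load_checkpoint` are hypothetical names:

```python
import io

import torch
import torch.nn as nn

# Hypothetical: suffixes of state_dict keys that hold fixed, never-trained tensors.
FIXED_KEYS = ("up_weight",)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        # Fixed tensor kept as a regular (persistent) buffer, as in option 2.
        self.register_buffer("up_weight", torch.ones(3, 1, 2, 2))


def save_checkpoint(model: nn.Module, f) -> None:
    # Drop the fixed buffers so they never reach the checkpoint file.
    sd = {k: v for k, v in model.state_dict().items() if not k.endswith(FIXED_KEYS)}
    torch.save(sd, f)


def load_checkpoint(model: nn.Module, f):
    result = model.load_state_dict(torch.load(f), strict=False)
    # strict=False tolerates any missing key, so re-check that only the fixed ones are absent.
    bad_missing = [k for k in result.missing_keys if not k.endswith(FIXED_KEYS)]
    if bad_missing or result.unexpected_keys:
        raise RuntimeError(f"checkpoint mismatch: missing={bad_missing}, "
                           f"unexpected={result.unexpected_keys}")
    return result


buf = io.BytesIO()
save_checkpoint(Net(), buf)
buf.seek(0)
print(load_checkpoint(Net(), buf).missing_keys)  # ['up_weight']
```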

Well, ideally this buffer is instantiated along with the rest of the net in the *.py file, so “keep telling everyone those values” isn’t much of an issue.

With recent PyTorch versions, you can use register_buffer(..., persistent=False) to keep the buffer out of the state_dict. It is available in 1.7 but not in 1.4; I’m not 100% sure which version introduced it.
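With `persistent=False`, the tensor keeps its buffer behaviour (it shows up in `named_buffers()` and is moved by `.to()`/`.cuda()`, and since DataParallel replicates what `buffers()` returns it should be broadcast like any other buffer), but it is left out of `state_dict()`. A minimal sketch, reusing a hypothetical all-ones upsampling kernel:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample2x(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Non-persistent buffer: moved with the module, excluded from state_dict().
        self.register_buffer("weight", torch.ones(channels, 1, 2, 2), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv_transpose2d(x, self.weight, stride=2, groups=self.channels)


model = Upsample2x(channels=3)
print(list(model.state_dict().keys()))              # []
print([name for name, _ in model.named_buffers()])  # ['weight']
```

One thing to watch, as far as I can tell: checkpoints saved while the buffer was still persistent contain the `weight` key, so loading them with the default `strict=True` should report it as an unexpected key.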
