Training an isolated part of a module


(Haotian Wang) #1

Suppose I have a module that looks like this:

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = SubModule1()
        self.sub2 = SubModule2()

In the main script, I need to train only self.sub1, so I declare sub_interested = my_module.sub1. Then I run the usual optimizer / sub_interested.__call__ / loss.backward() routine on sub_interested. Will changes to the parameters of sub_interested be reflected in my_module.sub1? Do I have to worry about hooks?
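A minimal sketch of the aliasing question, assuming nn.Linear layers as hypothetical stand-ins for SubModule1 and SubModule2. In Python, `my_module.sub1` returns a reference to the same submodule object rather than a copy, so an optimizer built from `sub_interested.parameters()` updates the very tensors that `my_module.sub1` holds:

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-ins for SubModule1 / SubModule2
        self.sub1 = nn.Linear(4, 4)
        self.sub2 = nn.Linear(4, 4)

torch.manual_seed(0)
my_module = MyModule()

# Attribute access returns a reference to the submodule, not a copy.
sub_interested = my_module.sub1
before = my_module.sub1.weight.detach().clone()

# The optimizer only sees sub1's parameters.
optimizer = torch.optim.SGD(sub_interested.parameters(), lr=0.1)
x = torch.randn(8, 4)
loss = sub_interested(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

same_object = sub_interested is my_module.sub1
weight_changed = not torch.equal(before, my_module.sub1.weight)
print(same_object, weight_changed)
```

Since sub2 never takes part in this forward pass, its parameters receive no gradient and stay untouched.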

Freezing the other submodules and calling my_module.__call__ is not an option because of input/output formats, etc.


(Seonil Son) #2

So what you need is for autograd to track only one specific submodule at a time, leaving all the others untracked? Couldn’t that be achieved by freezing and unfreezing parameters in turn, or by using detach() properly inside your forward loops?
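A rough sketch of both suggestions, assuming nn.Linear layers as hypothetical stand-ins for the real submodules; `set_trainable` is a hypothetical helper, not a PyTorch API. Freezing toggles `requires_grad` on every parameter, while `detach()` cuts the graph at one specific tensor:

```python
import torch
import torch.nn as nn

def set_trainable(module: nn.Module, flag: bool) -> None:
    # Hypothetical helper: toggles requires_grad on every parameter in place.
    for p in module.parameters():
        p.requires_grad_(flag)

torch.manual_seed(0)
sub1 = nn.Linear(4, 4)  # hypothetical stand-ins
sub2 = nn.Linear(4, 4)
x = torch.randn(8, 4)

# Turn 1: train sub2 only; sub1 is frozen.
set_trainable(sub1, False)
set_trainable(sub2, True)
loss = sub2(sub1(x)).pow(2).mean()
loss.backward()
sub1_grad_none = sub1.weight.grad is None
sub2_has_grad = sub2.weight.grad is not None

# Alternative: keep sub1 trainable but detach its output, so no gradient
# flows back into sub1 through this particular path.
set_trainable(sub1, True)
sub2.zero_grad()
out1 = sub1(x).detach()
loss = sub2(out1).pow(2).mean()
loss.backward()
sub1_grad_still_none = sub1.weight.grad is None
```

Freezing affects every use of the parameters until it is undone, whereas detach() only affects the one tensor it is applied to.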

Alternatively, autograd's support for locally disabling gradient computation might be useful. (I haven’t used this functionality myself, so I’m not entirely sure about it.)
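A small sketch of locally disabling gradient computation with the `torch.no_grad()` context manager, again assuming nn.Linear stand-ins. Operations inside the block are not recorded, so their outputs act as constants for the rest of the graph:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sub1 = nn.Linear(4, 4)  # hypothetical stand-ins
sub2 = nn.Linear(4, 4)
x = torch.randn(8, 4)

# Inside no_grad, no graph is recorded, so out1 carries no history.
with torch.no_grad():
    out1 = sub1(x)

out2 = sub2(out1)  # recording resumes outside the block
out2.pow(2).mean().backward()

print(out1.requires_grad, sub1.weight.grad is None, sub2.weight.grad is not None)
```

`torch.enable_grad()` and `torch.set_grad_enabled(flag)` are the related context managers for switching recording back on or toggling it conditionally.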


(Haotian Wang) #3

The problem is that the actual model is much more complicated than this. sub1's inputs and outputs are in very different formats from my_module's. I have to train sub1 with its direct outputs.

The process looks like this:

  • train my_module in a separate script with a different loss function
  • freeze sub1 and get its output, out1
  • sub2 still requires grad; feed out1 into sub2 to get out2
  • sub1 still frozen: feed out2 into sub1 to get out3, but without going through sub1’s __call__ because of the input format issue
  • compare out3 and out1 to get a loss

In this process, setting requires_grad=False is probably only a partial fix, because I have to bypass __call__, which may interfere with the autograd mechanism.
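The steps above can be sketched as follows, assuming (hypothetically) that sub1 is a plain nn.Embedding and sub2 an nn.Linear followed by a softmax over the vocabulary. Autograd records tensor operations such as `torch.matmul`, not module calls, so using `sub1.weight` directly still builds a valid graph; what bypassing `__call__` skips is any hooks registered on sub1:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim, batch_size = 10, 6, 8

# Hypothetical stand-ins: sub1 embeds tokens, sub2 maps back to the vocabulary.
sub1 = nn.Embedding(vocab_size, dim)
sub2 = nn.Linear(dim, vocab_size)

# Freeze sub1; only sub2 is handed to the optimizer.
sub1.requires_grad_(False)
optimizer = torch.optim.Adam(sub2.parameters(), lr=1e-3)

tokens = torch.randint(0, vocab_size, (batch_size,))  # [batch_size]
out1 = sub1(tokens)                                    # [batch_size, dim]
out2 = F.softmax(sub2(out1), dim=-1)                   # [batch_size, vocab_size]
# Use sub1's weight directly instead of sub1.__call__, because out2 is a
# distribution over the vocabulary rather than a tensor of indices.
out3 = torch.matmul(out2, sub1.weight)                 # [batch_size, dim]

loss = F.mse_loss(out3, out1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Gradients reach sub2 through both out2 and the matmul, while the frozen embedding weight receives none.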

(Seonil Son) #4

That sounds a bit ambiguous. But regarding the input format: is it about in-place operations? Or about using sparse tensors in sub1 but not in sub2? (Neither sounds like enough of a reason to avoid __call__.)

Why can't you use the forward() call?


(Haotian Wang) #5

One feeds scalar indices to the Embedding layer, with shape [batch_size], while the other is already a dense tensor of shape [batch_size, vocab_size]. So I need to do something like torch.matmul(output, embedding.weight) instead of embedding(output).
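A quick check of why the matmul is a reasonable substitute: for integer indices, an embedding lookup equals multiplying the one-hot encoding by `embedding.weight`, and a dense [batch_size, vocab_size] distribution generalizes that to a weighted mix of rows. The sizes below are arbitrary assumptions for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim = 10, 6
embedding = nn.Embedding(vocab_size, dim)

indices = torch.tensor([3, 0, 7])                       # [batch_size]
one_hot = F.one_hot(indices, vocab_size).float()        # [batch_size, vocab_size]

via_lookup = embedding(indices)
via_matmul = torch.matmul(one_hot, embedding.weight)    # same rows, selected by matmul

# Gradients reach embedding.weight through the matmul as well.
via_matmul.sum().backward()
```

Only the rows selected by the one-hot vectors receive a nonzero gradient, matching what the lookup would produce.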

As for the issue with forward, I was referring to this post: “Any different between model(input) and model.forward(input)”.
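The practical difference discussed in that kind of thread is hooks: `module(input)` goes through `__call__`, which runs registered hooks around `forward`, while calling `forward` directly or using the weights functionally skips them. A minimal demonstration with a hypothetical forward hook on an nn.Linear:

```python
import torch
import torch.nn as nn

calls = []

layer = nn.Linear(4, 4)
# A forward hook that records every time it fires.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(2, 4)
layer(x)                                        # __call__: runs forward and the hook
layer.forward(x)                                # direct forward: hook is skipped
out = torch.matmul(x, layer.weight.t()) + layer.bias  # functional use: hook is skipped
```

If no hooks are registered on the module (the default unless your code or a library adds them), bypassing `__call__` makes no difference to autograd.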


(Seonil Son) #6

If I’m understanding correctly… wouldn’t detach() do the job for you? For example, embedding.weight.detach().
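A sketch of the detach() suggestion, reusing the hypothetical Embedding/Linear stand-ins. `embedding.weight.detach()` returns a tensor that shares storage with the weight but is treated as a constant, so gradients still flow into out2 (and hence sub2) while the embedding weight gets none, even though its `requires_grad` stays True:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim, batch_size = 10, 6, 8
sub1 = nn.Embedding(vocab_size, dim)    # hypothetical stand-ins
sub2 = nn.Linear(dim, vocab_size)

tokens = torch.randint(0, vocab_size, (batch_size,))
out1 = sub1(tokens).detach()            # treat out1 as a fixed target
out2 = F.softmax(sub2(out1), dim=-1)
# The detached weight shares storage with sub1.weight but carries no history.
out3 = torch.matmul(out2, sub1.weight.detach())

F.mse_loss(out3, out1).backward()
```

Unlike freezing with requires_grad_(False), this only affects the uses that are explicitly detached. Because the detached tensor shares storage, modifying it in place would also change the embedding weight.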

But I’m not sure. Your situation might need an unusual or non-trivial solution beyond what I can suggest =\


(Haotian Wang) #7

Thanks for the help!