How to make a dataloader from a frozen network?

I’d like to use a frozen network as a type of preprocessing for a later network which is being trained. Instead of freezing layers, I’d like to combine my original train_dataloader and the frozen network into a new wrapper-dataloader. How should I best approach that?

I’ve written

import torch

class NetLoader:
    def __init__(self, loader, net):
        self.loader = loader
        self.net = net    # frozen network used as a preprocessing step

    def __iter__(self):
        with torch.no_grad():
            for X, Y in self.loader:
                yield self.net(X), Y

and I use this as an iterable “dataloader” later when I train my proper network. I see multiple issues:

  • it’s not a proper DataLoader class
  • I’m not sure if/where to use torch.no_grad(), net.requires_grad_(False), and/or self.net(X).detach() (a rough sketch of how I picture these is just below)
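
For concreteness, this is roughly how I picture the three options (just a sketch with a placeholder net and batch):

import torch
import torch.nn as nn

net = nn.Linear(4, 4)    # placeholder for my frozen network
X = torch.randn(2, 4)    # placeholder batch

# (a) no_grad: no graph is recorded for the frozen forward pass
with torch.no_grad():
    out_a = net(X)

# (b) requires_grad_(False): mark all of the frozen net's parameters as frozen
net.requires_grad_(False)
out_b = net(X)

# (c) detach: cut the output off from whatever graph was built
out_c = net(X).detach()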

What is the proper way to do that?

The main goal is to build a network layer-wise (instead of picking out parts from a full specification).

Any other suggestions?

Why don’t you just create an nn.Module that wraps both? You can parallelize that, and you don’t have to worry about devices, since a single .cuda() call will affect all the submodules properly.
The dataloader approach seems like a headache to me.

Anyway, torch.no_grad() is enough for that purpose.

You mean something like

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, fixed_net, train_net):
        super().__init__()
        self.fixed_net = fixed_net
        self.train_net = train_net

    def forward(self, X):
        # Run the frozen part without building a graph.
        with torch.no_grad():
            X = self.fixed_net(X)
        return self.train_net(X)

    def parameters(self, recurse=True):
        # Expose only the trainable part to the optimizer.
        return self.train_net.parameters(recurse=recurse)

?
Am I missing something here?

So no_grad is enough and no future operation will somehow leak into my fixed net and perform unnecessary computation?

Indeed, that context manager disables the autograd engine. See: no_grad — PyTorch 1.12 documentation
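
A minimal illustration (the tensor here is just a stand-in):

import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = x * 2           # no graph is recorded for this operation
print(y.requires_grad)  # False
print(y.grad_fn)        # None, so nothing can backprop into x through y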

You may also want to override load_state_dict and state_dict if you are not interested in saving/loading both networks at the same time. That way, when you allocate the network or run it in parallel, everything will work smoothly.
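
For example, something along these lines (a rough, untested sketch that simply delegates to train_net):

import torch.nn as nn

class Net(nn.Module):
    # ... __init__, forward and parameters exactly as in your post ...

    def state_dict(self, *args, **kwargs):
        # Checkpoint only the trainable part; the frozen net is saved/loaded elsewhere.
        return self.train_net.state_dict(*args, **kwargs)

    def load_state_dict(self, state_dict, *args, **kwargs):
        return self.train_net.load_state_dict(state_dict, *args, **kwargs)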

Thanks, I see. I’ll try that.

For no_grad, I understand that it does what I want during the forward pass. I was concerned that unneeded computation could happen during a later optim.zero_grad(), out.backward(), or optim.step().

Well, optim.step() only updates the parameters you pass to the optimizer; in your case, as you are rewriting the parameters method, you will never pass the frozen ones.
Besides, there will be no graph to backprop through the fixed net, so nothing will be computed or stored for it. Lastly, you would need a leaf node before the first network, and there is none.
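
Concretely, a training step with your wrapper would look roughly like this (a sketch that assumes the Net class from your earlier post; the two linear layers, shapes, and loss are just placeholders):

import torch
import torch.nn as nn

fixed_net = nn.Linear(10, 10)   # stand-in for the real frozen network
train_net = nn.Linear(10, 2)    # stand-in for the network being trained
model = Net(fixed_net, train_net)

# Only train_net's parameters are handed to the optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

X = torch.randn(8, 10)
Y = torch.randint(0, 2, (8,))

optimizer.zero_grad()
out = model(X)           # fixed_net runs under no_grad: no graph, no stored activations
loss = loss_fn(out, Y)
loss.backward()          # gradients only reach train_net
optimizer.step()         # updates only the parameters given to the optimizer

print(any(p.grad is not None for p in fixed_net.parameters()))  # False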