How to save output of a layer during training

Hi, here again,

I have a question. Say that in my training code I want to save the output of a module or layer for use in another layer (while training), and make sure the output of that layer or module does not change during training. How do I do that? Do I define it in the forward pass? In the code below, say I want to use the output at D: do I save the output, or call self.block4 again? Apologies if this question is stupid; I am just learning. Thank you!

```python
import torch.nn as nn


class block1(nn.Module):
    ...


class block2(nn.Module):
    ...


class block3(nn.Module):
    ...


class block4(nn.Module):
    ...


class exped(nn.Module):
    ...


class RDCDH(nn.Module):
    def __init__(self):
        super(RDCDH, self).__init__()
        self.conv = conv
        # Instantiate each block; assigning the class itself would not create a layer
        self.block1 = block1()
        self.block2 = block2()
        self.block3 = block1()
        self.block4 = block1()
        self.trans = exped()

    def forward(self, x):
        out = self.conv(x)
        out = self.block1(out, before_Trans=False)  # A
        out = self.block2(out, before_Trans=False)  # B
        out = self.block3(out, before_Trans=False)  # C
        out = self.block4(out, before_Trans=False)  # D
        out = self.upsam_trans(out)                 # P
        out = self.block4(out, before_Trans=True)   # Q
        out = self.trans(out)
        return out
```

Many thanks.

Sorry, can you please clarify what you mean? When you say 'save the output of a module or layer for use in another layer', do you mean save it and use it within the same model? If so, you can just save it in forward, as you suggest. Give it a name other than out and use it later on; in your example, you could call the output at D out_D. Then wherever else you need that output, you just pass out_D in. PyTorch takes care of the rest, during both training and inference.
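A minimal sketch of that suggestion, assuming small nn.Linear layers in place of the real blocks (TinyNet and head are hypothetical names): the output at D is kept in out_D and reused further down the forward pass, here as a skip connection.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for RDCDH: each block is a small Linear layer."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.block1 = nn.Linear(dim, dim)
        self.block4 = nn.Linear(dim, dim)
        self.head = nn.Linear(2 * dim, dim)  # consumes the current output and out_D together

    def forward(self, x):
        out = self.block1(x)
        out_D = self.block4(out)   # save the output at D under its own name
        out = torch.relu(out_D)    # keep computing with `out` as usual
        # Reuse the saved tensor later, here as a skip connection
        out = self.head(torch.cat([out, out_D], dim=1))
        return out


net = TinyNet()
y = net(torch.randn(4, 8))
print(y.shape)  # torch.Size([4, 8])
```

Because out_D is an ordinary tensor in the autograd graph, calling backward on a loss computed from y sends gradients through both places it was used.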


Hi Karthik,

Thank you for your response.

This is exactly what I meant: using it in the same model during training. So with out_D, does that mean that anywhere I use it in the forward pass (after defining it), it gives me exactly the same values, unchanged?

Thank you so much for your help!

Hi, yes, exactly. Anywhere you use out_D, it holds the same value that self.block4 produced.
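Within one forward pass, a tensor stored in a variable such as out_D keeps its values: a layer call, torch.cat, or arithmetic like out_D * 2 builds a new tensor instead of modifying it. The main way it can change is an in-place operation (methods ending in `_`, such as add_) applied to it or to another name bound to the same tensor. A small sketch, with a plain tensor standing in for out_D:

```python
import torch

out_D = torch.ones(2, 3)   # stands in for the output of self.block4 at D

# Out-of-place ops create new tensors, so out_D keeps its value
doubled = out_D * 2
print(out_D[0, 0].item())  # 1.0

# Plain assignment does not copy: `alias` is the same tensor object
alias = out_D
alias.add_(1)              # an in-place op through the alias...
print(out_D[0, 0].item())  # 2.0  ...changes out_D as well

# clone() gives an independent copy that can be modified safely
safe = out_D.clone()
safe.add_(1)
print(out_D[0, 0].item())  # 2.0 (unchanged)
print(safe[0, 0].item())   # 3.0
```

This is also why the .clone() calls in the later addition model make sense when a tensor is passed to code that might modify it. Note that this holds within one forward call: out_D is recomputed every time forward runs, so its values change from one training step to the next as block4's weights are updated.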

But keep in mind, you are calling block4 twice in forward. Is that what you want to do? If you want block4 to have different parameters each time, you should instantiate a new block.
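To illustrate the difference, here is a sketch comparing a model that calls one block twice with a model that holds two separate instances. nn.Linear(8, 8) stands in for block4, and the class and attribute names are hypothetical.

```python
import torch
import torch.nn as nn


def n_params(module: nn.Module) -> int:
    # Total number of trainable values registered on the module
    return sum(p.numel() for p in module.parameters())


class Reused(nn.Module):
    """Calls one block twice: both calls share the same weights."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.block4 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.block4(self.block4(x))


class Separate(nn.Module):
    """Two instances of the same layer type: each has its own weights."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.block4_a = nn.Linear(dim, dim)  # hypothetical names
        self.block4_b = nn.Linear(dim, dim)

    def forward(self, x):
        return self.block4_b(self.block4_a(x))


print(n_params(Reused()))    # 72  (one 8x8 weight + 8 biases)
print(n_params(Separate()))  # 144 (two independent sets)
```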

Got that.

Yes, I called block4 twice in forward, but the output is not the same: it depends on whether before_trans is True or False. Is this allowed?
Many thanks

```python
class block4(nn.Module):
    def __init__(self, before_Trans=True):
        super(block4, self).__init__()
        self.before_Trans = before_Trans

    def forward(self, x):
        out = self.conv(x)

        if self.before_Trans:
            return out
        else:
            return self.block1(out, before_Trans=True)
```

Yes, it will work just fine since block4 just seems to be a wrapper around block1. But now it seems like you are calling block1 multiple times. What does block1 have inside?
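In the block4 posted above, the branch is chosen by self.before_Trans, which is fixed when the block is constructed, while RDCDH.forward passes the flag as a keyword argument on each call. If the flag is meant to change from call to call, one option is to make it a forward argument. A sketch, with nn.Linear layers standing in for self.conv and the block1 branch (FlagBlock and extra are hypothetical names):

```python
import torch
import torch.nn as nn


class FlagBlock(nn.Module):
    """Hypothetical block whose branch is chosen per call, not per instance."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.conv = nn.Linear(dim, dim)   # stand-in for the real conv layer
        self.extra = nn.Linear(dim, dim)  # stand-in for the block1 branch

    def forward(self, x, before_Trans: bool = True):
        out = self.conv(x)
        if before_Trans:
            return out
        return self.extra(out)


torch.manual_seed(0)
block = FlagBlock()
x = torch.randn(4, 8)
a = block(x, before_Trans=True)   # conv only
b = block(x, before_Trans=False)  # conv followed by the extra branch
```

Both calls go through the same self.conv weights; only the second one also uses self.extra.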

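The concern is this: if block1 has trainable layers (e.g., Conv2d, Linear, etc.), then calling the same instance multiple times means every call uses the same weights. Autograd handles this without complaint (during backpropagation, the gradients from each call are summed into those shared weights), but it is probably not what you want if each call is supposed to learn its own weights.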
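A small sketch of what happens when one module is called twice: autograd adds the gradients from both calls into the same weight. It compares the shared weight's gradient with the gradients of two independent copies of that weight (w1 and w2 are illustrative names).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
block = nn.Linear(3, 3, bias=False)
x = torch.randn(2, 3)

# One module, called twice: both calls read the same block.weight
block(block(x)).sum().backward()
grad_shared = block.weight.grad.clone()

# Same computation with two independent copies of the weight,
# so each call's contribution shows up in its own .grad
w1 = block.weight.detach().clone().requires_grad_(True)
w2 = block.weight.detach().clone().requires_grad_(True)
F.linear(F.linear(x, w1), w2).sum().backward()

# The shared weight's gradient is the sum of both contributions
print(torch.allclose(grad_shared, w1.grad + w2.grad))  # True
```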

If each block is meant to have a different set of weights at the end, then it is probably going to be safer for you to create new blocks instead of re-using block1 multiple times.

Thank you Karthik, noted.

One final question.

For an nn.Module, must forward take x or any other argument(s)? What if I give it none? For example, in the code below I used the submodule create_addition in the main model addition, but the forward pass of create_addition takes no argument because I didn't need one. Is this allowed?

```python
import torch
import torch.nn as nn


class create_addition(nn.Module):
    def __init__(self, depth, d1_before, d2_before, u1_before, u1_after, up_shape=(64, 64), before_Trans=True):
        super(create_addition, self).__init__()
        self.before_Trans = before_Trans
        self.d1_before = d1_before
        self.d2_before = d2_before
        self.u1_before = u1_before
        self.u1_after = u1_after
        self.in_planes = (2 * growth_rate)
        self.up_shape = up_shape
        self.block4 = block4()  # an instance, not the class

    def forward(self):
        u2_out = self.block4(self.u1_before)  # I

        if self.before_Trans:
            return u2_out  # I
        else:
            after_trans = torch.cat([self.d1_before, self.d2_before, self.d2_before, u2_out], 1)  # J
            after_trans = nn.functional.interpolate(after_trans, self.up_shape, mode='bilinear')  # K
            return after_trans
```

Main model:

```python
class addition(nn.Module):
    def __init__(self, depth, reduction=.5, bottleneck=True, dropRate=0.0,
                 before_Trans=True):
        super(addition, self).__init__()
        self.d1_before = None
        self.d1_after = None
        self.d2_before = None
        self.d2_after = None
        self.u1_before = None
        self.u1_after = None
        self.before_Trans = before_Trans
        self.create_addition = create_addition1(up_shape=(64, 64), before_Trans=True)
        self.create_addition2 = create_addition2(self.d1_before, self.d1_after)
        self.create_addition3 = create_addition3(self.d1_before, self.d1_after, self.d2_before, self.d2_after)

    def forward(self, x):
        out = self.conv(x)
        self.d1_before = self.create_addition(out.clone(), before_Trans=True)
        self.d1_after = self.create_addition(out, before_Trans=False)
        self.d2_before = self.create_addition2(self.d1_after.clone(), before_Trans=True)
        self.d2_after = self.create_addition2(self.d1_after, before_Trans=False)
        self.u1_before = self.create_addition2(self.d2_after.clone(), before_Trans=True)
        self.u1_after = self.create_addition2(self.d2_after, before_Trans=False)
        u2_before = self.create_addition3(self.u1_after.clone(), before_Trans=True)
        u2_after = self.create_addition3(self.u1_after, before_Trans=False)
        return u2_after
```
Many thanks.
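As far as nn.Module itself is concerned, forward does not need an input: calling module() simply calls forward with whatever arguments are passed, including none. The sketch below shows that first; NoInput, Child and Parent are hypothetical names. The second half illustrates one thing to watch with this pattern in the code above: a tensor handed to a submodule's __init__ is stored as it was at construction time, so if the parent later reassigns its own attribute (as addition.forward does with self.d1_before), the submodule does not see the new value.

```python
import torch
import torch.nn as nn


class NoInput(nn.Module):
    """forward() takes no arguments; it only uses state stored on the module."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))

    def forward(self):
        return self.weight * 2


m = NoInput()
out = m()                         # calling the module with no arguments runs forward()
print(out.detach().tolist())      # [2.0, 2.0, 2.0]


class Child(nn.Module):
    """Hypothetical stand-in for a create_addition-style submodule."""

    def __init__(self, d1_before):
        super().__init__()
        self.d1_before = d1_before  # captured once, when the child is built

    def forward(self):
        return self.d1_before


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.d1_before = None
        self.child = Child(self.d1_before)  # the child receives None here

    def forward(self, x):
        self.d1_before = x * 2              # rebinds the parent's attribute only
        return self.child()                 # the child still returns its own None


print(Parent()(torch.ones(2)))  # None
```

Passing the tensors the submodule needs as forward arguments (for example forward(self, d1_before, d2_before, u1_before)) avoids this, since each call then works on the current values.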