Is it possible to shortcut input in forward method?

I want to create a custom loss function; unfortunately, I need the inputs from multiple previous layers.

Input_x -> conv1 -> relu1 -> conv2 -> relu2 -> conv3 -> relu3 -> conv4 -> relu4 -> output

I attached my custom loss function between relu3 and conv4, but I also need to evaluate the output of relu1. Normally the forward method accepts two arguments, self and x (the output of the previous layer). If my understanding is correct, this is not supported?

Thank you

This is supported, but you have to write your own module for it. A simple example:

import torch

class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 3, 3)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(3, 1, 3)

    def forward(self, input_batch):
        conv1 = self.conv1(input_batch)
        relu = self.relu(conv1)
        conv2 = self.conv2(relu)
        # return every intermediate result, not only the final output
        return conv1, relu, conv2
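Applied to the four-layer chain from the question, the same pattern might look like the sketch below. The class name FourLayerNet, the channel counts, the padding and the custom_loss function are hypothetical choices for illustration; the point is that forward returns relu1 and relu3 alongside the output, so the loss is computed outside the model and autograd still reaches every layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FourLayerNet(nn.Module):
    """Hypothetical version of the chain from the question:
    conv1 -> relu1 -> conv2 -> relu2 -> conv3 -> relu3 -> conv4 -> relu4."""

    def __init__(self):
        super().__init__()
        # padding=1 keeps the spatial size, so relu1 and relu3 have equal shapes
        self.conv1 = nn.Conv2d(1, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
        self.conv3 = nn.Conv2d(8, 8, 3, padding=1)
        self.conv4 = nn.Conv2d(8, 1, 3, padding=1)

    def forward(self, x):
        relu1 = F.relu(self.conv1(x))
        relu2 = F.relu(self.conv2(relu1))
        relu3 = F.relu(self.conv3(relu2))
        out = F.relu(self.conv4(relu3))
        # Return the intermediates the loss needs alongside the final output
        return out, relu1, relu3


def custom_loss(out, relu1, relu3, target):
    # Hypothetical loss: task term on the output plus a term relating relu3 to relu1
    return F.mse_loss(out, target) + 0.1 * F.mse_loss(relu3, relu1)


model = FourLayerNet()
x = torch.randn(2, 1, 16, 16)
target = torch.randn(2, 1, 16, 16)
out, relu1, relu3 = model(x)
loss = custom_loss(out, relu1, relu3, target)
loss.backward()  # gradients also flow through the relu1 and relu3 terms
```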

If you do the whole loss calculation afterwards (which is what I would always recommend), you have access to all intermediate values. Alternatively, you could use forward hooks to obtain the intermediate values you need.
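A minimal sketch of the forward-hook route, assuming the chain is built as an nn.Sequential (the layer sizes and the loss are hypothetical placeholders). register_forward_hook is standard PyTorch API; here each hook stores a ReLU's output in a dict keyed by name, without detaching, so the stored tensor can still be used in the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in for the question's chain, built from standard layers
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),  # conv1, relu1 (indices 0, 1)
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),  # conv2, relu2 (indices 2, 3)
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),  # conv3, relu3 (indices 4, 5)
    nn.Conv2d(8, 1, 3, padding=1), nn.ReLU(),  # conv4, relu4 (indices 6, 7)
)

activations = {}


def save_output(name):
    # A forward hook is called as hook(module, inputs, output) after forward runs;
    # the output is kept as-is (not detached) so it stays in the autograd graph.
    def hook(module, inputs, output):
        activations[name] = output
    return hook


handles = [
    model[1].register_forward_hook(save_output("relu1")),
    model[5].register_forward_hook(save_output("relu3")),
]

x = torch.randn(2, 1, 16, 16)
target = torch.randn(2, 1, 16, 16)
out = model(x)
loss = F.mse_loss(out, target) + 0.1 * F.mse_loss(activations["relu3"], activations["relu1"])
loss.backward()

# Remove the hooks once they are no longer needed
for h in handles:
    h.remove()
```

Hooks leave the model's forward method untouched, which helps when the model comes from code you would rather not modify; the trade-off is that the activations live in external state you have to manage yourself.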


Thank you for the reply

I came up with a somewhat ad hoc idea:

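import torch.nn as nn

class OutputProxy(nn.Module):
    # class attribute: one list shared by every OutputProxy instance
    states = []

    def __init__(self):
        super(OutputProxy, self).__init__()

    def forward(self, x):
        # identity layer: records x (detached from the graph) and passes x through
        OutputProxy.push(x.detach())
        return x

    @staticmethod
    def push(item):
        OutputProxy.states.append(item)

    @staticmethod
    def pop():
        # returns the most recently pushed tensor first
        return OutputProxy.states.pop()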

My plan is to add OutputProxy layers after the desired outputs and then call OutputProxy.pop() in the custom loss function. Does this have any caveats, such as only working on the CPU?
And if it does work, would you recommend it?

You should not detach the values if you want to use them for the loss calculation later. Other than that, I don't see any problems, even though I prefer accessing by name rather than by index. But this is just a personal preference.
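Combining both suggestions, a variant of OutputProxy that keeps the tensors attached to the graph and stores them by name might look like this. NamedProxy and the shared store dict are hypothetical names, not a PyTorch API, and the layer sizes and loss are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NamedProxy(nn.Module):
    """Hypothetical identity layer that records its input under a name."""

    def __init__(self, name, store):
        super().__init__()
        self.name = name
        self.store = store  # plain dict shared by the proxies of one model

    def forward(self, x):
        # No detach: the stored tensor stays in the autograd graph
        self.store[self.name] = x
        return x


store = {}
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), NamedProxy("relu1", store),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(), NamedProxy("relu3", store),
    nn.Conv2d(8, 1, 3, padding=1), nn.ReLU(),
)

x = torch.randn(2, 1, 16, 16)
target = torch.randn(2, 1, 16, 16)
out = model(x)
loss = F.mse_loss(out, target) + 0.1 * F.mse_loss(store["relu3"], store["relu1"])
loss.backward()
```

Because each forward pass overwrites the dict entries, the store holds one tensor per name instead of relying on every push() being matched by a pop(), and it belongs to one model rather than being shared by every proxy in the process.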
