Implement a custom function inside the model

Hello everyone!

I am trying to call a custom function inside my model, but I have made no progress so far… :confused:
I have a function that takes three arguments: an input (an image converted to a PyTorch tensor) and two numbers. It returns the processed image as a PyTorch tensor:

import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

def process_image(image_in, variable_1, variable_2):

  img = image_in.cpu().detach().numpy()

  #do whatever to get img_out

  img_out = transforms.ToTensor()(img_out)

  return img_out

Then I try to call the function in a simple NN:

class MyNet(nn.Module):
    def __init__(self):
        # nn.Module must be initialized before submodules can be assigned
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.fc1 = nn.Linear(1024, 300)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(300, 10)

    def forward(self, x):
        x = process_image(x, variable_1, variable_2)
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 1024)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = self.fc2(x)
        return x

I guess I am doing something wrong, because so far Colab just crashes without any explanation… :confounded:
Also, I’d like to make variable_1 and variable_2 trainable. How can I do that?

Thanks in advance for helping a newbie!

Could you run the code locally on your machine to get a proper error message?
Note that operations applied to a numpy array won’t be tracked by Autograd unless you implement the backward method manually.
However, if you apply variable_1 and variable_2 to your tensor using PyTorch operations, it should work.
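
To illustrate the point about staying in PyTorch, here is a minimal sketch of how the two numbers could be made trainable. The actual processing step is not shown in the thread, so the `ScaleShift` module and its scale-and-shift operation are purely hypothetical stand-ins. The only assumption is that the real processing can be written with tensor operations.

```python
import torch
import torch.nn as nn


class ScaleShift(nn.Module):
    """Hypothetical stand-in for process_image, written with torch ops only."""

    def __init__(self):
        super().__init__()
        # nn.Parameter registers the scalars so that optimizers update them.
        self.variable_1 = nn.Parameter(torch.tensor(1.0))
        self.variable_2 = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        # Illustrative operation only; the real processing is not given in the thread.
        return x * self.variable_1 + self.variable_2


layer = ScaleShift()
x = torch.rand(4, 3, 32, 32)  # a batch of 32x32 RGB images
out = layer(x)
out.sum().backward()  # Autograd fills in .grad for both parameters
```

A module like this could be created in `MyNet.__init__` and called at the start of `forward`, so that `model.parameters()` would include both values.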


Thanks a lot!
In this case, should I implement the backward method as described here: Defining backward() function in nn.module?

If you cannot use PyTorch methods and need to use numpy functions, then you would need to implement the backward pass manually.
However, if all the needed operations are available in PyTorch, the backward pass will just work.
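
For the numpy case, one common way to write the backward pass by hand is a custom `torch.autograd.Function`. The sketch below is only an illustration: `NumpyScale` is a hypothetical name, and a plain multiplication stands in for the real processing, which the thread does not show.

```python
import numpy as np
import torch


class NumpyScale(torch.autograd.Function):
    """Computes x * scale in NumPy, with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        # Save the inputs needed to compute gradients later.
        ctx.save_for_backward(x, scale)
        # NumPy operations are not tracked by Autograd.
        result = x.detach().cpu().numpy() * scale.item()
        return torch.from_numpy(result).to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        x, scale = ctx.saved_tensors
        grad_x = grad_output * scale          # d(x * scale) / dx
        grad_scale = (grad_output * x).sum()  # d(x * scale) / dscale
        # Return one gradient per input of forward, in the same order.
        return grad_x, grad_scale


x = torch.rand(2, 3, requires_grad=True)
scale = torch.tensor(2.0, requires_grad=True)
out = NumpyScale.apply(x, scale)
out.sum().backward()
```

The function is invoked through `NumpyScale.apply(...)` rather than by calling `forward` directly. For anything more complex than this toy operation, the gradients have to be derived by hand, which is why staying with PyTorch operations is usually simpler.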