My input is high-dimensional. I’m using a simple NN on a small partition of the data (f0) to reduce its dimension, then concatenating the result with the rest of the data (f1), which is left untouched. I then pass the combined input to a frozen model M (a pretrained Keras model), get the binary output, and calculate the loss. I only want to train the NN and backpropagate through its parameters. Is this possible?

The workflow would be possible in PyTorch.
However, it seems you would like to mix PyTorch with a pretrained Keras model afterwards, which won’t work (at least I’m not aware of such a workflow or of anyone who has used it before).

Would it be possible to port the Keras parameters to a PyTorch model?

Some layers in my Keras model are custom-defined, and the existing conversion libraries don’t support custom layers. But how would I do this if M were a pretrained PyTorch model?

Also, is this not possible because the output of the Keras model is a numpy array, which doesn’t keep grad_fn and requires_grad?

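import torch
import torch.nn as nn
from torchvision import models


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # small trainable layer that reduces the first 4 features to 2
        self.fc = nn.Linear(4, 2)
        # stand-in for the pretrained model M
        self.base = models.resnet50()

    def forward(self, x):
        # split x into x0 (to be reduced) and x1 (passed through unchanged)
        x0, x1 = x[:, :4], x[:, 4:]
        x0 = self.fc(x0)
        # concatenate along the feature dimension
        x = torch.cat((x0, x1), 1)
        # reshape to fit the resnet50 input shape
        x = x.view(x.size(0), 3, 224, 224)
        x = self.base(x)
        return x


model = MyModel()
# 4 features for x0 plus 3*224*224 - 2 features for x1
x = torch.randn(1, 3*224*224 + 2)
output = model(x)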

Note that I’ve used resnet50, which expects image tensors as the input, thus I had to reshape the concatenated tensor to [batch_size, 3, 224, 224].
If you are dealing with a pretrained model using only linear layers, you wouldn’t have to reshape it.
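
Since M should stay frozen, one way to sketch this, assuming M is a PyTorch module, is to turn off requires_grad for its parameters and pass only the small network's parameters to the optimizer. In the sketch below, `Reducer` and the tiny `frozen` stack are hypothetical stand-ins (not the resnet50 from the example), sized so it runs quickly, and the loss choice is an assumption based on the binary output you mentioned:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Reducer(nn.Module):
    """Trainable front end: shrinks the first 4 features (f0) to 2."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x0, x1 = x[:, :4], x[:, 4:]
        return torch.cat((self.fc(x0), x1), dim=1)


# Hypothetical stand-in for the pretrained model M (10 features -> 1 logit)
frozen = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 1))
for p in frozen.parameters():
    p.requires_grad_(False)   # no gradients are stored for M
frozen.eval()

reducer = Reducer()
# Only the reducer's parameters are given to the optimizer
optimizer = torch.optim.SGD(reducer.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()

x = torch.randn(16, 12)                        # 4 features for f0 + 8 for f1
target = torch.randint(0, 2, (16, 1)).float()

frozen_before = [p.detach().clone() for p in frozen.parameters()]
fc_before = reducer.fc.weight.detach().clone()

loss = criterion(frozen(reducer(x)), target)
optimizer.zero_grad()
loss.backward()    # gradients flow *through* M back to reducer.fc
optimizer.step()

frozen_unchanged = all(
    torch.equal(a, b) for a, b in zip(frozen_before, frozen.parameters())
)
fc_changed = not torch.equal(fc_before, reducer.fc.weight)
```

The frozen model still takes part in the backward pass, so the gradient reaches `reducer.fc`, but its own weights are never updated.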

Yes, basically you would have to convert the concatenated PyTorch tensor to a numpy array, which will detach the computation graph, so that you won’t be able to calculate the gradients using the final loss.
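
To illustrate the detaching, here is a minimal sketch (the tensor names are made up for the example). It shows that the history is gone once the data has taken a round trip through numpy:

```python
import torch

w = torch.ones(3, requires_grad=True)
y = w * 2                     # y has a grad_fn and is part of the graph

# y.numpy() would raise a RuntimeError here, so an explicit detach is needed
arr = y.detach().numpy()      # plain ndarray: no grad_fn, no requires_grad
z = torch.from_numpy(arr)     # a tensor again, but without any history

has_history_before = y.grad_fn is not None
has_history_after = z.grad_fn is not None or z.requires_grad
```

Any loss computed from `z` cannot backpropagate to `w`, which is the situation you would be in after passing through the Keras model.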

It depends on what kind of custom layers there are, but as long as there are equivalent PyTorch methods, it should work.
There might be some pitfalls, e.g. flipped kernels, but users in this forum have successfully ported models before.
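
As a rough sketch of what porting a single layer could look like: a Keras Dense layer stores its kernel as (in_features, out_features), while nn.Linear stores its weight as (out_features, in_features), so the kernel has to be transposed. The random arrays below are stand-ins for what you would get from the Keras layer (e.g. via `layer.get_weights()`); I'm assuming a plain Dense layer here, and your custom layers may need different handling:

```python
import numpy as np
import torch
import torch.nn as nn

rng = np.random.default_rng(0)

# Stand-ins for the weights of a Keras Dense layer with 5 inputs and 3 units
kernel = rng.standard_normal((5, 3)).astype(np.float32)   # (in, out)
bias = rng.standard_normal(3).astype(np.float32)

linear = nn.Linear(5, 3)
with torch.no_grad():
    linear.weight.copy_(torch.from_numpy(kernel.T))   # PyTorch expects (out, in)
    linear.bias.copy_(torch.from_numpy(bias))

# Check the ported layer against the plain numpy computation x @ kernel + bias
x = rng.standard_normal((2, 5)).astype(np.float32)
expected = x @ kernel + bias
actual = linear(torch.from_numpy(x)).detach().numpy()
match = np.allclose(expected, actual, atol=1e-5)
```

Comparing the outputs layer by layer like this on the same input is a cheap way to catch layout mistakes early.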

Sounds like a good project. I’ll research a bit more (this forum and the web) to see what I can do. Thanks again for the help. I’ll reach out if I need help.