Multiple networks with one loss function and a non-differentiable step: will backpropagation work?

Hello, I’m trying to make an augmentation function that chooses whether or not to augment based on the contents of the image.

To do this, I’m thinking about using two different networks.

One is to classify whether the image should be augmented or not. The other is to classify the image against its label (the main task).

Network A outputs False or True (0 or 1) to decide whether the input image should be augmented. Based on the result of Network A, the image is augmented (or not) and fed into Network B, which classifies it by comparing against the ground-truth label.

I’m wondering whether this model will backpropagate correctly and learn to improve the performance of image classification.

My concern is that Network A and Network B are separated by the image augmentation function, which is non-differentiable.

Sorry for my poor explanation, but I want to know whether this method will theoretically work.

Hi Chan!

You are correct that because the information that Network A passes to
Network B is a discrete variable (zero or one), that connection will not
be usefully differentiable, and you will not be able to use backpropagation
and gradient descent to train Network A.
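To see this concretely, here is a minimal sketch (assuming PyTorch) of why a hard 0/1 decision blocks learning: rounding a probability to a discrete decision has zero gradient everywhere it is defined, so no useful signal reaches Network A's parameters.

```python
import torch

# Stand-in for Network A's raw (scalar) output.
logit = torch.tensor(0.7, requires_grad=True)

# Hard 0/1 "augment or not" decision. torch.round is defined to have
# zero gradient, so the graph is effectively cut here.
decision = torch.round(torch.sigmoid(logit))

# Stand-in for the downstream loss from Network B.
loss = decision * 2.0
loss.backward()

print(logit.grad)  # tensor(0.) -- no useful gradient flows back to Network A
```

The same problem occurs with any discrete sampling or thresholding step placed between the two networks.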

Here are two possible approaches:

Plan A: Have Network A pass a continuous value (say, any value between
zero and one, or perhaps from -inf to inf) to Network B, and have Network
B not augment the image if the value is zero, fully augment the image if the
value is one, and “partially” augment the image if the value is in between. Of
course such a scheme only makes sense if your use case has a sensible
notion of “partial augmentation.”
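A sketch of Plan A (assuming PyTorch, and assuming the augmentation is something blendable, e.g. additive noise): Network A emits a continuous gate in (0, 1), and the image fed to Network B is a weighted blend of the original and fully augmented versions. The names here (`partially_augment`, the noise augmentation) are illustrative stand-ins, not a prescribed implementation.

```python
import torch

def partially_augment(image, alpha):
    """Blend the original and a fully augmented image by the continuous gate alpha."""
    augmented = image + 0.1 * torch.randn_like(image)  # example: noise augmentation
    return (1.0 - alpha) * image + alpha * augmented

image = torch.randn(1, 3, 8, 8)
logit = torch.tensor(0.3, requires_grad=True)  # stand-in for Network A's raw output
alpha = torch.sigmoid(logit)                   # continuous "augment or not" in (0, 1)

blended = partially_augment(image, alpha)
loss = blended.mean()                          # stand-in for Network B plus its loss
loss.backward()

print(logit.grad is not None)  # True -- gradients now reach Network A
```

Because every step is differentiable, Network A's parameters receive a gradient telling it how much augmentation would have reduced the loss.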

Plan B: Have Network B always classify both the unaugmented image and
the augmented image, compute the loss function for both images separately,
and combine the two loss-function values together using the continuous
value passed to Network B by Network A (say, by taking a weighted average
of the two loss values). Then backpropagate that combined loss value through
Network A (and use the resulting gradients to train Network A).
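A sketch of Plan B (again assuming PyTorch; the tiny linear classifier and noise augmentation are placeholder stand-ins for Network B and the real augmentation): Network B classifies both versions of the image, and the two losses are combined by a weighted average using Network A's continuous output, so gradients flow back to Network A through the weighting.

```python
import torch
import torch.nn.functional as F

image = torch.randn(1, 3, 8, 8)
target = torch.tensor([1])

# Placeholder for Network B: a tiny linear classifier over two classes.
net_b = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 2))

logit_a = torch.tensor(0.0, requires_grad=True)  # stand-in for Network A's output
w = torch.sigmoid(logit_a)                       # continuous weight in (0, 1)

augmented = image + 0.1 * torch.randn_like(image)  # example augmentation
loss_plain = F.cross_entropy(net_b(image), target)
loss_aug = F.cross_entropy(net_b(augmented), target)

# Weighted average of the two losses; w carries the gradient to Network A.
loss = (1.0 - w) * loss_plain + w * loss_aug
loss.backward()

print(logit_a.grad is not None)  # True -- Network A gets a training signal
```

Note that the gradient Network A receives is proportional to the difference between the two loss values, which is exactly the signal "did augmenting help or hurt on this image."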

Of course, whether or not such schemes will work in practice will depend on
the details of your actual problem, but they illustrate plausible ways to make the
connection between Networks A and B continuous and usefully differentiable.


K. Frank