I have a few approaches, but here are two. Let’s start with the one I like the most:
- Loss contribution masking: apply all the functions to the incoming input to get the resulting vector `[func1(input), func2(input), ..., func10(input)]`, then take the inner product (dot product) with the one-hot encoded label vector. This leaves you with just the value of the matching function as the loss. Example: `class = 2 -> [0, 1, 0, ...]`, so the loss will be `loss = 0 * func1(input) + 1 * func2(input) + 0 * func3(input) + ...`
- Hash label to select function: compute a hash of the label tensor from its values, and use it to index a dictionary that maps each label to the loss function you want to compute.
For both you might need to play a bit with concatenation and map.
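The second approach can be sketched in plain Python; the label values are turned into a hashable key that indexes a dictionary of (here hypothetical) loss functions:

```python
def label_key(label):
    # Convert the label's values into a hashable key, e.g. a tuple.
    return tuple(label)

# Hypothetical table mapping each one-hot label to its loss function.
loss_table = {
    (1, 0): lambda x: sum(v * v for v in x),      # class 1: sum of squares
    (0, 1): lambda x: sum(abs(v) for v in x),     # class 2: sum of abs values
}

def dispatched_loss(x, label):
    # Look up the loss function for this label and apply it.
    return loss_table[label_key(label)](x)

print(dispatched_loss([1.0, -2.0], [1, 0]))  # 5.0
print(dispatched_loss([1.0, -2.0], [0, 1]))  # 3.0
```

This only evaluates the one loss that is actually needed, but it requires hard labels that map exactly onto the dictionary keys.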
The first approach wastes some compute, but it's much more flexible (for instance, you can use soft labels).
I hope this will help you, let me know if you need any further clarification.