Per-batch loss for semi-supervised learning

I have a semi-supervised problem as follows:
I only know the ground truth for batches of examples. For example, for batch 1 with examples b1=(e1,e2,…) there should be at least one high value among its outputs (o1,o2,…), while for batch 2 there shouldn’t be any high outputs. Is there a way to set up a per-batch loss such as

L=(max(o1,o2,...)-E(b))**2

or

L=E(b)*log(max(o1,o2,...))

or the like, where E is the known expectation for batch b?
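Both candidate losses can be written as plain functions of the whole batch's output tensor, so standard autograd is enough. This is a minimal sketch, assuming the model emits one score per example in (0, 1) (e.g. through a sigmoid) and that E(b) is 1.0 for a batch that should contain a high output and 0.0 otherwise. The second form as written is zero whenever E(b)=0, so negative batches would contribute no gradient, and minimizing it for E(b)=1 would push the max down. For that reason the sketch swaps in binary cross-entropy on the batch max, a common substitute, rather than the literal formula. The function names are hypothetical.

```python
import torch

def squared_max_loss(outs: torch.Tensor, expected: float) -> torch.Tensor:
    # L = (max(o_1, o_2, ...) - E(b))**2, one scalar for the whole batch
    return (outs.max() - expected) ** 2

def bce_max_loss(outs: torch.Tensor, expected: float, eps: float = 1e-7) -> torch.Tensor:
    # Binary cross-entropy on the batch max (swapped in for E(b)*log(max(...))):
    # -[E log m + (1 - E) log(1 - m)], where m is the largest output in the batch
    m = outs.max().clamp(eps, 1 - eps)  # keep log() finite
    return -(expected * torch.log(m) + (1 - expected) * torch.log(1 - m))

outs = torch.tensor([0.1, 0.9, 0.2])  # hypothetical scores o_i for one batch
print(squared_max_loss(outs, 1.0))  # approx 0.01
print(bce_max_loss(outs, 1.0))      # approx -log(0.9)
print(bce_max_loss(outs, 0.0))      # approx -log(0.1)
```

With the cross-entropy form, a negative batch is penalized in proportion to how high its best score is, and a positive batch is penalized when even its best score is low.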

At inference time I’d like to send individual instances rather than batches and get back individual outputs, so I believe each example must generate its own prediction.
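Training on batch-level targets does not stop the model from scoring single examples. If the network maps each example to its own output, the max over the batch only appears inside the loss. This is a minimal sketch, assuming a hypothetical `ScoreNet` with 4 input features and a sigmoid output. The architecture is illustrative and is not prescribed by the question.

```python
import torch
import torch.nn as nn

class ScoreNet(nn.Module):
    # Hypothetical per-example scorer: each row of the input gets its own score
    def __init__(self, n_features: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_features) -> (batch,) scores in (0, 1)
        return self.net(x).squeeze(-1)

model = ScoreNet()

# Training time: a batch of 5 examples, reduced to one number only in the loss
batch = torch.randn(5, 4)
batch_max = model(batch).max()

# Inference time: a single instance (batch of size 1) gives one individual output
single = torch.randn(1, 4)
score = model(single)
print(score.shape)  # torch.Size([1])
```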

Hi,

The two losses you’ve written above are differentiable, so you can use either of them as the loss for your network. Do you have any issue with doing that?

The losses shown e.g. here don’t seem to have access to the entire batch but rather only single x,y values, if I understood correctly.
Another example here, and the forum won’t let me post a third link as I am too new.
Anyway, while I am happy to try writing the loss (which, if I understand correctly, would require implementing forward and backward methods as in the first link), I think it may not have access to the full batch, which I need in this case. I’ll try implementing a simple example; perhaps the x,y are batches of inputs and outputs rather than individual examples.

Also, now that I think of it, differentiating max(f(x1),f(x2),…) is not obvious to me. I suppose I can just select the example where f is largest and evaluate the derivative at that point only.
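That intuition matches what autograd does. The gradient of the max flows only to the element that attained it, and from there the chain rule carries it back to whatever produced that element. Ties are a special case whose handling depends on the operation. Here is a small check, with made-up scores standing in for the o_i:

```python
import torch

# Hypothetical scores o_1, o_2, o_3 for one batch; requires_grad so autograd tracks them
outs = torch.tensor([0.2, 0.7, 0.1], requires_grad=True)

outs.max().backward()

# Only the arg-max entry (o_2) receives gradient; the others get zero
print(outs.grad)  # tensor([0., 1., 0.])
```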

On a math forum I did see an approach where max(f(x),g(x)) = ( f(x)+g(x) + abs(f(x)-g(x)) )/2, so one need only consider where the abs() changes sign, but I believe this is a different case, since my inputs x_i are different for each output.
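The identity can be checked directly, and autograd differentiates it without special handling. Different inputs per output do not change this. The max is taken over the outputs, and the chain rule then carries the gradient back through whichever example produced the winning output. This sketch uses two arbitrary scalar values for f and g:

```python
import torch

f = torch.tensor(2.0, requires_grad=True)
g = torch.tensor(5.0, requires_grad=True)

# max(f, g) written as (f + g + |f - g|) / 2
m = (f + g + torch.abs(f - g)) / 2
m.backward()

print(m.item())                      # 5.0, same as torch.maximum(f, g)
print(f.grad.item(), g.grad.item())  # 0.0 1.0: only the larger value gets gradient
```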

Hi,

The first link your posted is about extending the autograd engine for things it does not support.
Everything you need here is supported, so no need for you to implement any backward.

Usually, given a batch of samples, your model will give you one output score (?) for each sample (the o_i).
Then you can write a function that, given these, computes the scalar loss for that batch.
For the first one, that would be:

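def my_loss(outs, expected):
  # squared error between the batch's largest output and its known target E(b)
  return (outs.max() - expected) ** 2

Then you can do as usual:

outs = model(inp)               # one output per example in the batch
loss = my_loss(outs, expected)  # expected is E(b) for this batch
opt.zero_grad()
loss.backward()
opt.step()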
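The loop below puts these pieces together, with each batch carrying its own known E(b). It is a toy sketch and everything in it is hypothetical. It uses a single linear layer with a sigmoid and hand-made data in which the positive batch contains one example with a large first feature. The weights start at small constants so the run is deterministic.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
nn.init.constant_(model[0].weight, 0.1)  # deterministic start for the toy example
nn.init.zeros_(model[0].bias)
opt = torch.optim.SGD(model.parameters(), lr=0.5)

def my_loss(outs, expected):
    # per-batch loss: (max over the batch - E(b))**2
    return (outs.max() - expected) ** 2

# Hypothetical data: (batch of examples, E(b)) pairs
pos = torch.zeros(3, 4)
pos[2, 0] = 3.0          # the one "high" example in the positive batch
neg = torch.zeros(3, 4)  # no example should score high here
batches = [(pos, 1.0), (neg, 0.0)]

def total_loss():
    # sum of per-batch losses, evaluated without tracking gradients
    with torch.no_grad():
        return sum(my_loss(model(x).squeeze(-1), e).item() for x, e in batches)

initial = total_loss()
for step in range(200):
    for inp, expected in batches:
        outs = model(inp).squeeze(-1)  # one score per example
        loss = my_loss(outs, expected)
        opt.zero_grad()
        loss.backward()
        opt.step()
final = total_loss()
print(initial, final)  # the total per-batch loss should drop
```

Each optimizer step here uses one whole batch, because the target only exists at the batch level. Shuffling examples across batches would break the link between a batch and its E(b).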

Thanks, I will give it a whirl.