Ignite - dataloader with more than an input-target pair and a custom evaluator

My dataloader returns an image, a target and a weight map for my loss function. For this, I’ve written a custom update function for the trainer engine (because I have to pass the weight map to the loss function). But during evaluation, I don’t know how (or whether) it’s possible to include this weight map and attach the loss as a metric to the evaluator. prepare_batch expects an input-target pair (I can’t return more than 2 items), and I don’t think I can pass the weight map through output_transform.

Any ideas?
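For reference, a dataset of this kind might look like the sketch below. The class name, image size and number of classes are made up for illustration. PyTorch’s default collate function stacks each of the three items separately, so every batch unpacks into three tensors rather than an input-target pair.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ImageTargetWeightDataset(Dataset):
    """Hypothetical dataset: each sample is (image, target, weight_map)."""

    def __init__(self, num_samples=8, num_classes=2, size=16):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.size = size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = torch.rand(3, self.size, self.size)  # C x H x W
        target = torch.randint(0, self.num_classes, (self.size, self.size))  # H x W class ids
        weight_map = torch.rand(self.size, self.size)  # per-pixel loss weights
        return image, target, weight_map


loader = DataLoader(ImageTargetWeightDataset(), batch_size=4)
# default_collate stacks the three items separately -> three batched tensors
inputs, targets, weight_map = next(iter(loader))
print(inputs.shape, targets.shape, weight_map.shape)
# torch.Size([4, 3, 16, 16]) torch.Size([4, 16, 16]) torch.Size([4, 16, 16])
```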

Hi @xen0f0n
So, does the output_transform argument of the Loss metric (https://pytorch.org/ignite/metrics.html#ignite.metrics.Loss) not work in your case?

If it doesn’t, please provide a minimal code snippet so that I can understand the issue. Thanks!

@vfdev-5 No, I’m not sure I can use Loss from metrics.
My update_model function for the trainer engine is:

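# net, optimizer, criterion and device are defined elsewhere in my script
def update_model(engine, batch):
    net.train()
    optimizer.zero_grad()
    inputs, targets, weight_map = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    weight_map = weight_map.to(device)
    outputs = net(inputs)
    loss = criterion(outputs, targets, weight_map)
    loss.backward()
    optimizer.step()
    return loss.item()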
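The criterion itself isn’t shown here. If the weight map holds per-pixel weights (as in U-Net-style segmentation), one common choice is a weighted cross-entropy. The sketch below assumes N x C x H x W logits, N x H x W integer targets and an N x H x W weight map; WeightedCrossEntropy is a hypothetical name. It returns a scalar because Ignite’s Loss metric expects loss_fn to return the average loss of the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedCrossEntropy(nn.Module):
    """Hypothetical pixel-wise weighted cross-entropy."""

    def forward(self, outputs, targets, weight_map):
        # outputs: N x C x H x W logits, targets: N x H x W class ids,
        # weight_map: N x H x W per-pixel weights
        per_pixel = F.cross_entropy(outputs, targets, reduction="none")  # N x H x W
        # Return a 0-dim tensor: ignite's Loss metric expects the *average* loss
        return (per_pixel * weight_map).mean()


criterion = WeightedCrossEntropy()
outputs = torch.randn(2, 3, 4, 4)
targets = torch.randint(0, 3, (2, 4, 4))
weight_map = torch.ones(2, 4, 4)
# With all-ones weights this reduces to the ordinary mean cross-entropy
print(torch.allclose(criterion(outputs, targets, weight_map), F.cross_entropy(outputs, targets)))
```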

I need to unpack the batch, which has 3 items, and follow the same scheme for evaluation. If I could unpack the batch inside loss_fn (in metrics.Loss), or in the prepare_batch of create_supervised_evaluator, that would do the trick.

If I understand your case correctly, the following should work:

import torch
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.metrics import Loss

data = [
    # inputs, targets, weight_map
    [torch.rand(1), torch.rand(1), torch.rand(2)],
    [torch.rand(1), torch.rand(1), torch.rand(2)],
    [torch.rand(1), torch.rand(1), torch.rand(2)],
    [torch.rand(1), torch.rand(1), torch.rand(2)],
]

net = nn.Linear(1, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

class WMapLoss(nn.Module):
    # Toy loss that takes the weight map as a third argument
    def forward(self, y_pred, y, weight_map):
        return (y_pred - y).sum() + weight_map.sum()

criterion = WMapLoss()

def update_model(engine, batch):
    net.train()
    optimizer.zero_grad()
    inputs, targets, weight_map = batch
    outputs = net(inputs)
    loss = criterion(outputs, targets, weight_map)
    loss.backward()
    optimizer.step()
    return loss.item()

def eval_fn(engine, batch):
    net.eval()
    with torch.no_grad():
        inputs, targets, weight_map = batch
        outputs = net(inputs)
        # (prediction, target, kwargs): Loss passes kwargs on to criterion
        return outputs, targets, {"weight_map": weight_map}

trainer = Engine(update_model)
evaluator = Engine(eval_fn)

metrics = {"val loss": Loss(criterion)}

for name, metric in metrics.items():
    metric.attach(evaluator, name)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(data)
    print(evaluator.state.metrics)

trainer.run(data, max_epochs=2)
# prints e.g. {'val loss': 1.027790516614914} after each epoch (values vary with the random data)

Tell me if this works or does not work for you.


It works! Thanks!
Now, could you explain why?! :stuck_out_tongue:
Why does eval_fn return a dictionary as its third element? (Are those kwargs for some function?)

According to the docs:
process_function (callable): A function receiving a handle to the engine and the current batch in each iteration, and returns data to be stored in the engine’s state.
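To make that data flow concrete: in Ignite, a metric attached to an engine resets at the start of each epoch, applies its output_transform to engine.state.output and calls update after every iteration, and stores the result of compute() in engine.state.metrics at the end of the epoch. The toy, pure-Python mimic below shows the same flow; MiniEngine and MeanOfFirst are hypothetical names, not Ignite classes.

```python
class MiniEngine:
    """Hypothetical, heavily simplified stand-in for ignite.engine.Engine."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.output = None  # plays the role of engine.state.output
        self.metrics = {}   # plays the role of engine.state.metrics
        self._attached = []

    def attach(self, name, metric):
        self._attached.append((name, metric))

    def run(self, data):
        for _, metric in self._attached:
            metric.reset()
        for batch in data:
            # Whatever process_function returns is stored as the engine's output...
            self.output = self.process_function(self, batch)
            # ...and every attached metric reads it after each iteration
            for _, metric in self._attached:
                metric.update(metric.output_transform(self.output))
        for name, metric in self._attached:
            self.metrics[name] = metric.compute()
        return self.metrics


class MeanOfFirst:
    """Toy metric: averages the first element of each transformed output."""

    def __init__(self, output_transform=lambda out: out):
        self.output_transform = output_transform
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def update(self, output):
        self.total += output[0]
        self.count += 1

    def compute(self):
        return self.total / self.count


# process_function returns (prediction, target, kwargs), like eval_fn above
engine = MiniEngine(lambda engine, batch: (batch * 2, batch, {"extra": batch}))
engine.attach("mean_pred", MeanOfFirst())
engine.attach("mean_target", MeanOfFirst(output_transform=lambda out: (out[1],)))
print(engine.run([1, 2, 3]))  # {'mean_pred': 4.0, 'mean_target': 2.0}
```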

How does Loss receive the weight_map?

And finally, for other metrics that only need the outputs and the target, do I have to run a second evaluator to log them?

There is no magic there :slight_smile:

To pass additional arguments to Loss, the output needs to be in the format (prediction, target, kwargs), as described in the docs:

output_transform (callable): a callable that is used to transform the Engine's process_function's output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. The output is expected to be a tuple (prediction, target) or (prediction, target, kwargs) where kwargs is a dictionary of extra keyword arguments. If extra keyword arguments are provided they are passed to loss_fn.

You can also see how this output is unpacked in the update method of Loss: https://pytorch.org/ignite/_modules/ignite/metrics/loss.html#Loss

Since eval_fn is the evaluator’s process function, whatever it returns is stored in evaluator.state.output, and it returns exactly the format that Loss expects.
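The unpacking that Loss performs can be sketched in plain Python. loss_update and weighted_abs_error below are hypothetical, simplified stand-ins written for illustration, not Ignite’s actual source: a 2-tuple gets no extra arguments, while the dict in a 3-tuple is expanded into keyword arguments for the loss function.

```python
def loss_update(loss_fn, output):
    """Simplified version of the unpacking done in Loss.update (illustrative only)."""
    if len(output) == 2:
        y_pred, y = output
        kwargs = {}
    else:
        y_pred, y, kwargs = output
    # The dict becomes keyword arguments, e.g. weight_map=...
    return loss_fn(y_pred, y, **kwargs)


def weighted_abs_error(y_pred, y, weight_map=1.0):
    return abs(y_pred - y) * weight_map


print(loss_update(weighted_abs_error, (3.0, 1.0)))                       # 2.0
print(loss_update(weighted_abs_error, (3.0, 1.0, {"weight_map": 0.5})))  # 1.0
```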

For other metrics, use output_transform to filter out the last element (the dict) and attach them to the same evaluator (no need for another one):

from ignite.metrics import Accuracy

# Keep only (y_pred, y) and drop the kwargs dict
acc = Accuracy(output_transform=lambda out: (out[0], out[1]))
acc.attach(evaluator, "Accuracy")