Evaluating multiple models with one evaluator results in weird metrics

Hello.

I have several similar models (with small variations in some hyperparameters).
I am trying to evaluate these models on the same dataset with one evaluator, to reduce evaluation time.

So I made a custom evaluator like the one below:

metrics_group = {
    experiment_name: {
        'Top-1': Accuracy(output_transform=lambda output: (output[0][experiment_name], output[1][experiment_name])),
        'Top-5': TopKCategoricalAccuracy(output_transform=lambda output: (output[0][experiment_name], output[1][experiment_name]))
    }
    for experiment_name in experiment_settings.keys()
}

def _inference(engine, batch):
    data = {experiment_name: _prepare_batch(batch,
                                            device=experiment_settings[experiment_name]['device_config']['host'],
                                            non_blocking=True)
            for experiment_name in experiment_components.keys()}
    y_preds = {}
    with torch.no_grad():
        for experiment_name, experiment_component in experiment_components.items():
            experiment_component['model'].eval()
            y_preds[experiment_name] = experiment_component['model'](data[experiment_name][0])  # Does this actually run concurrently? I'm not sure...
        return y_preds, {experiment_name: xy[1] for experiment_name, xy in data.items()}

evaluator = Engine(_inference)

for experiment_name, metrics in metrics_group.items():
    for metric_name, metric in metrics.items():
        metric.attach(evaluator, '{0}/{1}'.format(experiment_name, metric_name))
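
For reference, I then run the evaluator over the validation set (val_loader is just a placeholder name for my validation DataLoader):

evaluator.run(val_loader)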

As you can see, I attached 'Top-1' and 'Top-5' metrics for every model, but the results say:

print(evaluator.state.metrics)
{'A/Top-1': 0.1328125,
 'A/Top-5': 0.6015625,
 'B/Top-1': 0.1328125,
 'B/Top-5': 0.6015625,
 'C/Top-1': 0.1328125,
 'C/Top-5': 0.6015625,
 'D/Top-1': 0.1328125,
 'D/Top-5': 0.6015625}

I was confused by this result, so I checked the evaluator's output, but:

print(evaluator.state.output[0]['A'])
tensor([[-0.3352, -2.0896, -0.1886,  ...,  0.3233,  0.5214, -0.0945],
        [-0.3316, -2.0419, -0.1706,  ...,  0.2498,  0.5802, -0.0909],
        [-0.3056, -2.0395, -0.2438,  ...,  0.2266,  0.5328, -0.1037],
        ...,
        [-0.3001, -2.0332, -0.2573,  ...,  0.3248,  0.5248, -0.0653],
        [-0.3233, -2.0452, -0.1502,  ...,  0.2362,  0.5626, -0.0756],
        [-0.3304, -2.0787, -0.1769,  ...,  0.2427,  0.5589, -0.0379]],
       device='cuda:0')

print(evaluator.state.output[0]['B'])
tensor([[ 0.9059, -0.0701,  2.3905,  ...,  0.9909,  2.3744,  0.6785],
        [ 0.8492, -0.0840,  2.3493,  ...,  0.9192,  2.3537,  0.6840],
        [ 0.8844, -0.1049,  2.3237,  ...,  0.8993,  2.3218,  0.6488],
        ...,
        [ 0.8758, -0.1526,  2.3674,  ...,  1.0029,  2.3428,  0.6295],
        [ 0.8461, -0.0652,  2.3746,  ...,  0.9019,  2.2553,  0.6627],
        [ 0.9023, -0.0956,  2.3242,  ...,  0.9092,  2.2694,  0.6629]],
       device='cuda:0')

print(torch.eq(evaluator.state.output[0]['A'], evaluator.state.output[0]['B']))
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]], device='cuda:0')

I think there's no problem with the 'predicting' phase.
Is there a problem with how I attach the metrics? Any suggestions are welcome.

@FruitVinegar I think the problem is related to the output_transform definitions with lambdas. A lambda does not store experiment_name internally (it captures the variable, not its value at definition time), so all of your output_transform functions end up fetching the last experiment_name.
For example, take a look:

ot_list = []

for n in ["a", "b", "c", "d"]:
    ot_list.append(lambda _: print(n))

for o in ot_list:
    o(None)
> d
> d
> d
> d

In order to do what you would like, I'd use functools.partial:

from functools import partial

ot_list = []

def ot_func(output, exp_name):
    print(output, exp_name)


for n in ["a", "b", "c", "d"]:
    ot_list.append(partial(ot_func, exp_name=n))

for o in ot_list:
    o(1)

> 1 a
> 1 b
> 1 c
> 1 d
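
Applied to your metrics_group, that would look roughly like this (just a sketch, assuming the same Accuracy / TopKCategoricalAccuracy setup and output structure; select_experiment is only a helper name I made up):

from functools import partial

from ignite.metrics import Accuracy, TopKCategoricalAccuracy

def select_experiment(output, experiment_name):
    # Pick this experiment's predictions and targets out of the shared output dicts
    return output[0][experiment_name], output[1][experiment_name]

metrics_group = {
    experiment_name: {
        'Top-1': Accuracy(output_transform=partial(select_experiment, experiment_name=experiment_name)),
        'Top-5': TopKCategoricalAccuracy(output_transform=partial(select_experiment, experiment_name=experiment_name)),
    }
    for experiment_name in experiment_settings.keys()
}

A default-argument lambda (e.g. lambda output, name=experiment_name: (output[0][name], output[1][name])) would also work, since default arguments are evaluated when the lambda is defined.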

Oh my. The problem was not related to ignite at all.
Thanks for your help!