Storing values in metric or returning them in engine's dict for plotting purposes

Hi! I have a question about how to best plot things. Suppose I have a model and that during inference I am getting the error per example in a batch. I then want to correlate the error of the predicted y with a property of the input x. So ultimately I'm plotting a huge number of (property(x), error(y, y_pred)) points on a scatterplot, and then calculating their correlation. I assume that the best way to do this would be to log every batch's x properties and y errors using a custom metric, and then output the correlation between property(x) and error(y, y_pred) for all (x, y) pairs at the end through the compute() function.

However, what if I also wanted a scatterplot of all the property(x) and error(y, y_pred) values for all pairs (x, y)? I understand that Metric stores the history, but would it be too convoluted to also plot all the property(x) and error(y, y_pred) values from within the Metric's compute() function? It seems complicated because I create my visdom Visualizer in a file called "train.py" (where I also run my evaluator), while the custom Metric lives in its own file called metric.py.

The only other option I can think of is to store all the property(x) and error(y, y_pred) values outside of metric.py by returning them in the dict returned by the evaluator engine in train.py, and then creating an event handler that saves those values every iteration until the end of the validation set. After that, all the property(x) and error(y, y_pred) values are saved in a file, and I just pull them all back to do a scatter plot.

I was wondering what the best / most ignite-idiomatic / most elegant way to do this is.

Thanks!

Hi @pytorchnewbie,

IMO, it would be better to separate the logic of metrics computation and logging for visualization.

from ignite.engine import Engine, Events


def inference_step(e, batch):
    # ...
    x = batch["x"]
    y_pred = model(x)
    return {
        "y_pred": y_pred,
    }

infer_engine = Engine(inference_step)
infer_engine.state.viz_data = {}  # data storage for visualization

@infer_engine.on(Events.ITERATION_COMPLETED)
def compute_iterationwise_error(engine):
    batch = engine.state.batch
    x = batch["x"]
    y = batch["y"]
    y_pred = engine.state.output["y_pred"]

    # store the current batch's errors and input properties for plotting
    err = compute_error(y, y_pred)
    engine.state.viz_data["err"] = err
    engine.state.viz_data["x_properties"] = get_properties(x)


plot_every = 1  # or 2, or 10

@infer_engine.on(Events.ITERATION_COMPLETED(every=plot_every))
def plot_infer_data(engine):
    viz_data = engine.state.viz_data
    visdom_plot_data(viz_data["x_properties"], viz_data["err"])
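visdom_plot_data is left as a placeholder above; a minimal sketch of what it could look like with visdom's scatter API (assuming a running visdom server; the window name and axis labels are my own):

import numpy as np
import visdom

vis = visdom.Visdom()  # assumes a visdom server is running

def visdom_plot_data(x_properties, errors, win="err_vs_property"):
    # stack the two 1D sequences into an (N, 2) array of scatter points
    points = np.column_stack([np.asarray(x_properties), np.asarray(errors)])
    # "append" keeps adding points to the same window across iterations
    vis.scatter(
        X=points,
        win=win,
        update="append" if vis.win_exists(win) else None,
        opts=dict(xlabel="property(x)", ylabel="error(y, y_pred)", title="error vs property"),
    )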

For me, it is not clear from the description how often you would like to compute the metric:

  • epoch-wise (once per pass over the dataset)
  • batch-wise (once per batch)

Thanks! To answer your question: I want to update per epoch, since this is the validation set (so there's only one epoch, and I just want the relationship over the whole validation set).

It seems strange to use a function like compute_iterationwise_error in which you get the error and the properties and store them in a dict every iteration, since this is extremely similar to using a Metric. Is the only reason you're doing it this way because Metric doesn't allow you to return dicts? Does compute() need to give a scalar?

Additionally, at the end of accumulating all these errors and properties, I want to compute the correlation between them. This is a final scalar based on all the collected data, which seems like exactly the case where a Metric could be especially useful.

Finally, if you plot the values in real time, so one property and one error per iteration, that's great, but what if I wanted to then update the title of the plot at the very end to include the correlation between the properties and the errors? I don't think visdom has a function for just updating the title of a plot on its own without adding an additional data point…

It seems strange to use a function like compute_iterationwise_error in which you get the error and the properties and store them in a dict every iteration, since this is extremely similar to using a Metric. Is the only reason you're doing it this way because Metric doesn't allow you to return dicts? Does compute() need to give a scalar?

Yes, true. But as I didn't understand what you wanted to plot, I thought about batch-wise logging…

Additionally, at the end of accumulating all these errors and properties, I want to compute the correlation between them. This is a final scalar based on all the collected data, which seems like exactly the case where a Metric could be especially useful.

Maybe you can use EpochMetric to accumulate errors and properties and set up compute_fn to compute the correlation. See the docs.
Then you can log the computed value separately with visdom…

Finally, if you plot the values in real time, so one property and one error per iteration, that's great, but what if I wanted to then update the title of the plot at the very end to include the correlation between the properties and the errors? I don't think visdom has a function for just updating the title of a plot on its own without adding an additional data point…

You can do it like that:

import numpy as np
import visdom

vis = visdom.Visdom()

w = vis.line(
    X=np.array([1, 2, 3]),
    Y=np.array([1, 2, 3]),
    opts=dict(xlabel="x", ylabel="y", title="title-to-change"),
)
vis.update_window_opts(w, opts={"title": "new_title"})
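So one way to fold the final correlation into the scatter plot would be something like this sketch (the window handle w, the vis instance and the "overall_correlation" metric name are assumptions; the handler should be registered after whatever computes the correlation):

@infer_engine.on(Events.COMPLETED)
def add_corr_to_title(engine):
    corr = engine.state.metrics["overall_correlation"]  # assumed to be computed already
    vis.update_window_opts(w, opts={"title": "error vs property (corr={:.3f})".format(corr)})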

"Yes, true. But as I didn’t understand what you wanted to plot, I thought about batchwise logging… "

—> OK, but if I use a Metric, then I won't have access to the list to plot. So I'm still stuck, because I won't be able to plot the (error, property) data points (i.e. "separately log with visdom" would require me to use your solution above anyway). It wouldn't make sense to use both a metric based on EpochMetric and all the code you wrote above that stores infer_engine.state.viz_data, since then both the metric and this dictionary would be storing the data points…

However, the solution you provided works, because I can just add a function that computes the correlation on infer_engine.state.viz_data.
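For instance, a minimal sketch of that idea (assuming compute_error and get_properties return per-example lists of floats, and that these handlers are registered after compute_iterationwise_error):

from scipy.stats import pearsonr

@infer_engine.on(Events.ITERATION_COMPLETED)
def accumulate_viz_data(engine):
    viz_data = engine.state.viz_data
    # hypothetical accumulators over the whole validation run
    viz_data.setdefault("all_x_properties", []).extend(viz_data["x_properties"])
    viz_data.setdefault("all_errors", []).extend(viz_data["err"])

@infer_engine.on(Events.COMPLETED)
def compute_overall_correlation(engine):
    viz_data = engine.state.viz_data
    corr, _p_value = pearsonr(viz_data["all_x_properties"], viz_data["all_errors"])
    engine.state.metrics["overall_correlation"] = corr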

I was just hoping for you to confirm that the only reason we are not using a Metric is specifically because we want to plot the (error, property) data points; there is no way to do this if we use Metric, right?

Well, you can manually pick the collected errors/targets from EpochMetric, i.e. EpochMetric._predictions and EpochMetric._targets. But this may look hacky…

Probably, do it manually, as there is a mix of logic in what you would like to do.

What do you mean by "manually" in your third sentence? Is it different from "manually" in your first? Why is it hacky?

Why is it hacky?

Because we access private members (EpochMetric._*)…

What do you mean by "manually"?

The idea is to set up EpochMetric like this:

from ignite.engine import Engine, Events
from ignite.metrics import EpochMetric


def inference_step(e, batch):
    # ...
    x = batch["x"]
    y = batch["y"]
    y_pred = model(x)

    # compute error and properties
    err = compute_error(y, y_pred)
    x_properties = get_properties(x)

    return {
        "y_pred": y_pred,
        "err": err,
        "x_properties": x_properties,
    }

infer_engine = Engine(inference_step)

em = EpochMetric(
    compute_fn=compute_corr,
    output_transform=lambda out: (out["x_properties"], out["err"])
)
# we set em._predictions as x_properties
# and em._targets as err

em.attach(infer_engine, "overall_correlation")

@infer_engine.on(Events.COMPLETED)
def plot_infer_data(engine):
    visdom_plot_data(em._predictions, em._targets)

I didn’t check if this works, but here is the idea…
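compute_corr is also left open above; a possible sketch, given that EpochMetric calls compute_fn with the accumulated tensors (here x_properties as the "predictions" and err as the "targets"):

from scipy.stats import pearsonr

def compute_corr(x_properties, errors):
    # both arguments arrive as tensors accumulated over the whole epoch
    corr, _p_value = pearsonr(
        x_properties.cpu().numpy().ravel(),
        errors.cpu().numpy().ravel(),
    )
    return corr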

What do you mean by manually in your third sentence? Is it different from manually in your first?

There are two options to do almost the same thing. The first option is to collect the data manually, as in my first answer. The second option is to use EpochMetric and plot the data after the computation.
I would prefer the 1st option, as I can control everything…

Yeah, I see, but where would the plotting go in here? Do you recommend using em._predictions and em._targets?

Also, I'm concerned about the warning on EpochMetric: "Current implementation does not work with distributed computations. Results are not gathered across all devices and computed results are valid for a single device only."

  • Does "device" here mean machine or GPU? Can I stay on one machine and use multiple GPUs?

  • Does this warning apply if I create my own custom metric like the following?

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce


class Mape(Metric):
    def __init__(self, output_transform=lambda x: x):

        self._num_examples = None
        self._sum_percentages = None
        super(Mape, self).__init__(output_transform=output_transform)

    @reinit__is_reduced
    def reset(self):
        self._num_examples = 0
        self._sum_percentages = 0
        super(Mape, self).reset()

    @reinit__is_reduced
    def update(self, output):
        y_pred, y = output

        # absolute percentage error per element
        errors = torch.abs(y_pred - y.view_as(y_pred))
        errors_divided = errors / y.view_as(y_pred) * 100

        self._num_examples += y.shape[0] * y.shape[1]
        self._sum_percentages += torch.sum(errors_divided).item()

    @sync_all_reduce("_num_examples", "_sum_percentages")
    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError(
                "Mape must have at least one example before it can be computed."
            )
        return self._sum_percentages / self._num_examples
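For reference, a minimal sketch of how such a metric would be attached (the evaluator engine and the output dict keys are assumptions on my side):

mape = Mape(output_transform=lambda out: (out["y_pred"], out["y"]))
mape.attach(evaluator, "mape")

@evaluator.on(Events.COMPLETED)
def log_mape(engine):
    # metrics are computed at EPOCH_COMPLETED, so they are available here
    print("MAPE:", engine.state.metrics["mape"])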

The problem with EpochMetric in distributed data parallelism is that each process collects its own list of _predictions and _targets. There is no synchronization of the whole seen history across participating processes. We can also imagine that this could potentially raise OOM when we do the all_gather operation (which is missing currently, but we still plan to add it).

In your Mape implementation, _num_examples and _sum_percentages look like scalars, so sync_all_reduce will collect the data across the processes, and in the compute method all processes will manipulate the "total" self._sum_percentages and self._num_examples. So, it is OK.

Does "device" here mean machine or GPU?

Device means a single GPU here.

Got it. So is there a difference between EpochMetric and custom metrics with respect to distributed training? And in what scenarios (i.e. non-scalar state) can I not use ignite with distributed training? For example, I'm assuming the following won't work?

from scipy.stats import pearsonr

from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce


class Correlation(Metric):
    def __init__(self, output_transform=lambda x: x):

        self._list_of_ETC = []
        self._list_of_errors = []
        super(Correlation, self).__init__(output_transform=output_transform)

    @reinit__is_reduced
    def reset(self):
        self._list_of_ETC = []
        self._list_of_errors = []
        super(Correlation, self).reset()

    @reinit__is_reduced
    def update(self, output):

        mini_list_of_ETC, mini_list_of_errors = output

        # extend (i.e. concatenate) so the lists stay flat for pearsonr
        self._list_of_ETC.extend(mini_list_of_ETC)
        self._list_of_errors.extend(mini_list_of_errors)

    @sync_all_reduce("_list_of_ETC", "_list_of_errors")
    def compute(self):

        if len(self._list_of_ETC) == 0:
            raise NotComputableError(
                "Correlation must have at least one ETC value before it can be computed."
            )

        if len(self._list_of_errors) == 0:
            raise NotComputableError(
                "Correlation must have at least one error value before it can be computed."
            )

        if len(self._list_of_ETC) != len(self._list_of_errors):
            raise NotComputableError(
                "Number of ETC values must be equal to the number of errors!"
            )

        corr, p_value = pearsonr(self._list_of_ETC, self._list_of_errors)

        return corr

Is there any way to use ignite with distributed training in this case? I'm going to be using more than 1 GPU 99% of the time in the near future.
Thanks!

And in what scenarios (i.e. non-scalar state) can I not use ignite with distributed training?

There are some abstractions that do not work yet in DDP: EpochMetric, LRFinder, and all contrib regression metrics that use EpochMetric.

so is there a difference between EpochMetric and custom metrics with respect to distributed training?

Yes, in the way data is collected across processes. By default we collect the internal metric's data via an all-reduce operation (sum across processes). In the case of non-scalar metric data, it is not defined how to perform a sum of those objects, and it certainly won't give a correct result (think about computing a median, for example).

For example, I’m assuming the following won’t work?

No, it won't work, as you need to collect _list_of_ETC and _list_of_errors across processes, so for example the total _list_of_errors should be a concatenation of _list_of_errors by rank, right?

You have to use an all_gather op to do that:

import torch
import torch.distributed as dist


class Correlation(Metric):

    def compute(self):
        # roughly something like that: gather the locally collected lists from all processes
        local_etc = torch.tensor(self._list_of_ETC, device=this_current_device)
        # all_gather assumes the per-process tensors have the same length
        gathered_etc = [torch.zeros_like(local_etc) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_etc, local_etc)
        self._list_of_ETC = torch.cat(gathered_etc).tolist()
        # same for self._list_of_errors, etc.

PS: btw, if you would like to help us with this issue: Make EpochMetric work in DDP · Issue #978 · pytorch/ignite · GitHub
so that everyone could benefit from using EpochMetric in DDP…
If so, you can submit a draft PR and we can guide you on how to better implement and test it…

I see, but the example that I gave you above with the Correlation class is not based on EpochMetric. Therefore, can I modify your sentence to say that currently there is no easy way of getting Metric to work across multiple GPUs in all scenarios?

And is it safe to say that to fix this, the user needs to write a custom metric that does not use @sync_all_reduce and instead uses dist.all_gather?

Sure, I'll look into it. It'll depend on time and whether someone can help me, because I'm still just learning :slight_smile:

Therefore, can I modify your sentence to say that currently there is no easy way of getting Metric to work across multiple GPUs in all scenarios?

Well, the majority of ignite's metrics, like Accuracy, Precision, Recall, ConfusionMatrix, etc., work in DDP for their use-cases :slight_smile:

And is it safe to say that to fix this, the user needs to write a custom metric that does not use @sync_all_reduce and instead uses dist.all_gather?

It depends on the metric and how it should be computed, i.e. how the accumulators are updated at each iteration. If in the update function you do += on tensors, scalars, etc., @sync_all_reduce will work without problems and the result will be correct. If you do something else, like concatenating lists, @sync_all_reduce won't work, and at compute time we have to collect all the data ourselves.

Sure, I'll look into it. It'll depend on time and whether someone can help me, because I'm still just learning

We can help with that, if someone could initialize the work and put up some code, etc…

OK, I see. I'm currently working 2 jobs, so I'm not sure if I'll have time, but if you outline what needs to be done I can see if I can squeeze it in! Can you let me know concretely what needs to be done?

Thanks

I see. Anyway, I'll update the related issue, where I'll describe what to do in detail…

Thanks! One last thing:

I find it strange that there is this problem with distributed training, because apparently when you put a model on multiple GPUs in PyTorch, all the distributed stuff disappears once it returns its prediction, according to: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html#create-model-and-dataparallel .

Therefore, once I return y_pred in the infer_dict of the training engine, nothing is distributed anymore?

Yes, there are at least 2 kinds of data parallelism you can do with the PyTorch-provided classes: DataParallel (DP) and DistributedDataParallel (DDP).

What we discussed before about collecting data across participating processes applies only to DDP. With DP there is a single process running, so there is no need to collect data, etc.

However, the PyTorch docs suggest using DDP, and it is faster than DP.

PS: See also how ignite simplifies usage of DDP : https://pytorch.org/ignite/distributed.html
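As an aside, a rough sketch of how the manual all_gather above could look with the ignite.distributed helpers (assuming a recent ignite version where idist.all_gather is available; reset/update as in the Correlation class above):

import torch
import ignite.distributed as idist
from scipy.stats import pearsonr

from ignite.metrics import Metric


class Correlation(Metric):

    # reset / update as in the Correlation class above

    def compute(self):
        # idist.all_gather concatenates the per-process 1D tensors
        # (it assumes they have the same length on every process)
        all_etc = idist.all_gather(torch.tensor(self._list_of_ETC))
        all_errors = idist.all_gather(torch.tensor(self._list_of_errors))
        corr, _p_value = pearsonr(all_etc.cpu().numpy(), all_errors.cpu().numpy())
        return corr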

Thanks! But before, you said you can't use the current implementation of ignite on multiple GPUs. But you can, by using DataParallel!

Yes, you are right, you can use ignite and its API with DP on multiple GPUs.

PS: I was taking a shortcut, equating multiple devices (GPUs) with DDP…
