[fairseq] Possible bug: Multilingual training with max_tokens blows up in utils.resolve_max_positions

See fairseq/utils.py:367.

# fairseq/utils.py
def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources."""
    # (nested helpers map_value_update and nullsafe_min elided)

    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
        elif arg is not None:
            max_positions, arg = _match_types(max_positions, arg)
            if isinstance(arg, float) or isinstance(arg, int):
                max_positions = min(max_positions, arg)
            elif isinstance(arg, dict):
                max_positions = map_value_update(max_positions, arg)
            else:
                max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))

    return max_positions
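
For context, here are simplified stand-ins for the two helpers the loop calls (sketches of their behavior, not the exact fairseq implementations):

```python
def nullsafe_min(items):
    """min() over an iterable that skips None; None if everything is None."""
    non_null = [x for x in items if x is not None]
    return min(non_null) if non_null else None

def map_value_update(d1, d2):
    """Merge two dicts of limits, keeping the smaller value for shared keys."""
    merged = dict(d1)
    for key, value in d2.items():
        merged[key] = min(merged[key], value) if key in merged else value
    return merged
```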

Now consider what the multilingual translation task returns for max positions:

# fairseq/tasks/multilingual_translation.py
# class MultilingualTranslationTask(FairseqTask):
    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        if len(self.datasets.values()) == 0:
            return {'%s-%s' % (self.args.source_lang, self.args.target_lang):
                    (self.args.max_source_positions, self.args.max_target_positions)}
        return OrderedDict([
            (key, (self.args.max_source_positions, self.args.max_target_positions))
            for split in self.datasets.keys()
            for key in self.datasets[split].datasets.keys()
        ])

So the task's max_positions() returns a dict mapping each lang pair to a tuple, like {lang_pairs...: (max_source_positions, max_target_positions)}.
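
Concretely, with a hypothetical en-de / en-fr setup and the default limit of 1024, the structure looks like this (lang pairs and limits are made up for illustration):

```python
from collections import OrderedDict

# Hypothetical values standing in for args.max_source_positions etc.
max_source_positions = max_target_positions = 1024
lang_pairs = ['en-de', 'en-fr']

task_max_positions = OrderedDict(
    (pair, (max_source_positions, max_target_positions)) for pair in lang_pairs
)
# OrderedDict([('en-de', (1024, 1024)), ('en-fr', (1024, 1024))])
```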

Then resolve_max_positions tries to match types across all the max_positions values it is passed, each of which can be a number, a tuple, or a dict. Once it sees a dict, max_positions becomes a dict, and any later non-dict value errors: _match_types only upgrades numbers, so there's no meaningful upgrade between a dict and a tuple.