See `fairseq/utils.py:367`.
```python
# fairseq/utils.py
def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources."""
    [...]
    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
        elif arg is not None:
            max_positions, arg = _match_types(max_positions, arg)
            if isinstance(arg, float) or isinstance(arg, int):
                max_positions = min(max_positions, arg)
            elif isinstance(arg, dict):
                max_positions = map_value_update(max_positions, arg)
            else:
                max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))
    return max_positions
```
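To make the combination rule concrete, here is a simplified re-implementation covering only numbers and tuples. The body of `_match_types` is elided above, so this sketch assumes it broadcasts a bare number to the shape of a tuple argument; `match_types`, `nullsafe_min` and `resolve_max_positions_sketch` are hypothetical stand-ins, not fairseq's code, and the dict branch is left out on purpose.

```python
def nullsafe_min(values):
    """Return the smallest non-None value, or None if all are None."""
    smallest = None
    for v in values:
        if v is not None and (smallest is None or v < smallest):
            smallest = v
    return smallest


def match_types(a, b):
    """Assumed behavior: broadcast a bare number to the other tuple's length."""
    if isinstance(a, (int, float)) and isinstance(b, tuple):
        return tuple([a] * len(b)), b
    if isinstance(b, (int, float)) and isinstance(a, tuple):
        return a, tuple([b] * len(a))
    return a, b


def resolve_max_positions_sketch(*args):
    """Combine limits by element-wise minimum, skipping None (no dict support)."""
    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
        elif arg is not None:
            max_positions, arg = match_types(max_positions, arg)
            if isinstance(arg, (int, float)):
                max_positions = min(max_positions, arg)
            else:
                # Element-wise minimum of two equally shaped tuples.
                max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))
    return max_positions


print(resolve_max_positions_sketch(1024, (512, None), None))  # (512, 1024)
print(resolve_max_positions_sketch((1024, 1024), 512))        # (512, 512)
```

A number is upgraded to a tuple before the element-wise minimum, and `None` entries never win a comparison.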
Now consider how the multilingual translation task defines its max positions:
```python
# fairseq/tasks/multilingual_translation.py
class MultilingualTranslationTask(FairseqTask):
    [...]
    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        if len(self.datasets.values()) == 0:
            return {'%s-%s' % (self.args.source_lang, self.args.target_lang):
                    (self.args.max_source_positions, self.args.max_target_positions)}
        return OrderedDict([
            (key, (self.args.max_source_positions, self.args.max_target_positions))
            for split in self.datasets.keys()
            for key in self.datasets[split].datasets.keys()
        ])
```
This returns a dict of the form `{lang_pair: (max_source_positions, max_target_positions), ...}`.
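For illustration, the comprehension in `max_positions` can be run on stand-in objects. The `args` values and the language pairs `de-en` and `fr-en` are invented here; in fairseq they come from the command line and the loaded datasets.

```python
from collections import OrderedDict
from types import SimpleNamespace

# Hypothetical stand-ins for self.args and self.datasets.
args = SimpleNamespace(max_source_positions=1024, max_target_positions=1024)
datasets = {
    "train": SimpleNamespace(datasets={"de-en": None, "fr-en": None}),
}

# Same comprehension as the method above: every language pair is mapped
# to the same (source, target) limit tuple.
max_positions = OrderedDict([
    (key, (args.max_source_positions, args.max_target_positions))
    for split in datasets.keys()
    for key in datasets[split].datasets.keys()
])
print(dict(max_positions))  # {'de-en': (1024, 1024), 'fr-en': (1024, 1024)}
```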
`resolve_max_positions` then tries to match types across all the `max_positions` values it is passed, each of which can be a number, a tuple, or a dict. Once it sees a dict, `max_positions` becomes a dict, and any later value causes an error, since there is no meaningful upgrade from dict to dict.
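One way this can fail is sketched below, assuming the value that follows the task's dict is a plain `(source, target)` tuple. A tuple is not a number, so nothing is upgraded, and the tuple branch zips the dict's keys against the tuple's integers. `nullsafe_min` here is a simplified stand-in, not fairseq's source.

```python
def nullsafe_min(values):
    """Simplified stand-in: smallest non-None value."""
    smallest = None
    for v in values:
        if v is not None and (smallest is None or v < smallest):
            smallest = v
    return smallest


# max_positions has already become the task's dict; the next value is
# assumed to be a tuple, so the `else` branch of the resolver runs.
max_positions = {"de-en": (1024, 1024)}
arg = (512, 512)

error = None
try:
    # zip over a dict yields its keys, so this compares 'de-en' with 512.
    max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))
except TypeError as err:
    error = err
    print("TypeError:", err)
```

The comparison between a string key and an integer limit is what raises, rather than anything specific to the limits themselves.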