Dealing with multiple sequences in T5ForConditionalGeneration

Here I remark that the output of individual sequences are different from batched sequences using T5ForConditionalGeneration

Here is an example to reproduce the result: batch_size=1 vs 2

import torch

import torch.nn as nn

import torch.optim as optim

import torch.nn.functional as F

import pandas as pd

from torchtext.legacy.data import Field, BucketIterator, TabularDataset

# prepare data
data = {"text": ["summarize: i am very happy. i am very happy",
        "summarize: i am very safe. i am very safe"],
        "summary": ["i am very happy", "i am very safe"]}
df = pd.DataFrame(data)
df.to_csv("debug.csv", index=False)

# set tokenizer of T5-small

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")

pad_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

unk_index = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)

eos_index = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

model.resize_token_embeddings(len(tokenizer))

model.to("cuda")

from transformers import T5Tokenizer, T5ForConditionalGeneration

SRC = Field(tokenize = tokenizer.encode, 
            use_vocab=False,
            lower = False,
            init_token = None, 
            eos_token = eos_index, 
            pad_token=pad_index,
            unk_token=unk_index,
            include_lengths = True)

TRG = Field(tokenize = tokenizer.encode, 
            use_vocab=False,
            init_token = None, 
            eos_token = eos_index, 
            pad_token=pad_index,
            unk_token=unk_index,
            lower = False)

fields = {"text": ("src", SRC), "summary": ("trg", TRG)}
train_data, valid_data, test_data = TabularDataset.splits(
    path="./",
    train="debug.csv",
    validation="debug.csv",
    test="debug.csv",
    format='csv',
    fields=fields)


BATCH_SIZE = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size = BATCH_SIZE,
     sort_within_batch = True,
     sort_key = lambda x : len(x.src),
     device = device)

for i, batch in enumerate(train_iterator):
    src, _ = batch.src
    trg = batch.trg
    logits = model(input_ids=src.view(src.shape[1], src.shape[0]),
                   labels=trg.view(trg.shape[1], trg.shape[0])).logits
    X = logits.view(logits.size(1), logits.size(0), logits.size(-1))
    X = F.softmax(X, dim=-1)
    ids = X.argmax(dim=-1)
    y = tokenizer.batch_decode(sequences=ids, skip_special_tokens=False)
    z = tokenizer.batch_decode(sequences=trg, skip_special_tokens=False)
    print(" ".join(y))
    print("*********")
    print(" ".join(z))
    print("*********")