Thanks, Patrick. I edited your code to make the device either cpu or mps, and I post below the output from each on my MBP M1 Max. You can see that it reverses the order in one case: using mps with batch_first=True. If you use mps with batch_first=False, the order is correct.
# Demonstrate LSTM output shapes on CPU for both batch_first settings.
# For batch_first=False the input is (seq_len, batch, input_size);
# for batch_first=True it is (batch, seq_len, input_size). The output
# should mirror the input layout with hidden_size as the last dim.
device = torch.device("cpu")
print(device)
batch_size = 2
seq_len = 7
input_size = 10
for batch_first in (False, True):
    model = nn.LSTM(
        input_size=input_size,
        hidden_size=5,
        num_layers=1,
        batch_first=batch_first,
    ).to(device)
    # Arrange the two leading dims to match the batch_first convention.
    dims = (batch_size, seq_len) if batch_first else (seq_len, batch_size)
    x = torch.randn(*dims, input_size).to(device)
    print(x.shape)
    out, (h, c) = model(x)
    print(out.shape)
cpu
torch.Size([7, 2, 10])
torch.Size([7, 2, 5])
torch.Size([2, 7, 10])
torch.Size([2, 7, 5])
--------------------------------------------------------
# Same experiment as above but on the Apple MPS backend, to show that
# the output layout for batch_first=True comes back transposed there.
device = torch.device("mps")
print(device)
batch_size = 2
seq_len = 7
input_size = 10
for batch_first in (False, True):
    model = nn.LSTM(
        input_size=input_size,
        hidden_size=5,
        num_layers=1,
        batch_first=batch_first,
    ).to(device)
    # Arrange the two leading dims to match the batch_first convention.
    dims = (batch_size, seq_len) if batch_first else (seq_len, batch_size)
    x = torch.randn(*dims, input_size).to(device)
    print(x.shape)
    out, (h, c) = model(x)
    print(out.shape)
mps
torch.Size([7, 2, 10])
torch.Size([7, 2, 5])
torch.Size([2, 7, 10])
torch.Size([7, 2, 5])