How do I create a 3D indexable embedding matrix in PyTorch?
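
What I'm after is one embedding table per bit position: a parameter of shape (bit_size, 2, n_embd), so that sample s at position p gets the vector table[p, bits[s, p]]. A loop version of the lookup I want (shapes here are just for illustration):

import torch

bit_size, n_embd = 8, 4
table = torch.randn(bit_size, 2, n_embd)   # one (2, n_embd) table per bit position
bits = torch.randint(0, 2, (3, bit_size))  # batch of 3 binary numbers

# loop version of the intended lookup: emb[s, p] = table[p, bits[s, p]]
emb = torch.stack([
    torch.stack([table[p, bits[s, p]] for p in range(bit_size)])
    for s in range(bits.shape[0])
])
print(emb.shape)  # torch.Size([3, 8, 4])

Here is my full code: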

import torch
import torch.nn as nn
import torch.nn.functional as F

def generate_and_add_binary_numbers(num_samples, bit_size, num2_decimal=None):
    # Random bit strings for the first operand
    num1_bin = [''.join(str(bit) for bit in torch.randint(0, 2, (bit_size,)).tolist()) for _ in range(num_samples)]
    if num2_decimal is None:
        num2_bin = [''.join(str(bit) for bit in torch.randint(0, 2, (bit_size,)).tolist()) for _ in range(num_samples)]
    else:
        # Convert the decimal number to a binary string with padding
        num2_bin_str = format(num2_decimal, 'b').zfill(bit_size)[-bit_size:]
        num2_bin = [num2_bin_str] * num_samples  # Repeat for all samples

    # Convert binary strings to integers for addition
    int1 = [int(n, 2) for n in num1_bin]
    int2 = [int(n, 2) for n in num2_bin]
    sum_int = [i1 + i2 for i1, i2 in zip(int1, int2)]

    # Convert the sums back to binary strings with padding (any overflow bit is dropped)
    sum_bin = [format(s, 'b').zfill(bit_size)[-bit_size:] for s in sum_int]

    return num1_bin, num2_bin, sum_bin
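
For example (the random operands differ per run, so only the fixed operand is deterministic):

n1, n2, s = generate_and_add_binary_numbers(2, 8, num2_decimal=2)
print(n2)         # ['00000010', '00000010'] -- second operand fixed to decimal 2
print(len(s[0]))  # 8 -- sums are truncated back to bit_size bits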

def bin_list_to_tensor(*bin_lists):
    tensor_lists = [[] for _ in range(len(bin_lists))]
    for i, bin_list in enumerate(bin_lists):
        for binary_str in bin_list:
            binary_list = [int(bit) for bit in binary_str]
            binary_tensor = torch.tensor(binary_list, dtype=torch.float32)
            tensor_lists[i].append(binary_tensor)
        tensor_lists[i] = torch.stack(tensor_lists[i])  # Convert list of tensors to a tensor
    return tensor_lists
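
Each returned tensor has shape (num_samples, bit_size), e.g.:

a, b, c = bin_list_to_tensor(['0101'], ['0011'], ['1000'])
print(a)        # tensor([[0., 1., 0., 1.]])
print(a.shape)  # torch.Size([1, 4])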

num_samples = 1000
bit_size = 8
num_epochs = 5000
learning_rate = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class AttentionModule(nn.Module):
    def __init__(self, n_embd, bit_size=bit_size):
        super(AttentionModule, self).__init__()

        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)

        self.ln1 = nn.LayerNorm(n_embd)

    def forward(self, x):
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Scaled dot-product attention; scale by the actual head width rather than the global n_embd
        wei = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5
        wei = F.softmax(wei, dim=-1)
        out = wei @ v

        out = self.ln1(out + x)  # residual connection + LayerNorm

        return out
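
Sanity check for the attention block (sizes are arbitrary):

x = torch.randn(2, 16, 4)  # (batch, sequence length, n_embd)
att = AttentionModule(4)
print(att(x).shape)        # torch.Size([2, 16, 4]) -- shape is preserved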

class block(nn.Module):
    def __init__(self, n_embd, bit_size=bit_size, num_layers=3):
        super(block, self).__init__()
        # One (2, n_embd) embedding table per bit position, for each operand
        self.embedding_layer = nn.Parameter(torch.randn(size=(bit_size, 2, n_embd)))
        self.embedding_layer2 = nn.Parameter(torch.randn(size=(bit_size, 2, n_embd)))
        self.att_blocks = nn.ModuleList([
            nn.Sequential(
                AttentionModule(n_embd),
                nn.ReLU(),
                nn.Sigmoid()
            )
            for _ in range(num_layers)
        ])

        self.fc = nn.Linear(n_embd * bit_size * 2, bit_size, bias=False)

    def forward(self, num1, num2):
        tokenized_num1 = num1.long()
        tokenized_num2 = num2.long()
        # These are the two lookups that trigger the CUDA error (see below)
        embedded_num1 = self.embedding_layer[torch.arange(num1.shape[0]).unsqueeze(1), tokenized_num1]
        embedded_num2 = self.embedding_layer2[torch.arange(num2.shape[0]).unsqueeze(1), tokenized_num2]

        print(embedded_num1.shape)  # debug

        x = torch.cat((embedded_num1, embedded_num2), dim=1)

        for b in self.att_blocks:
            x = b(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return torch.sigmoid(x)

n_embd = 4
model = block(n_embd).to(device)

num1, num2, sum_bin = generate_and_add_binary_numbers(num_samples, bit_size , num2_decimal=2)
num1, num2, sum_bin = bin_list_to_tensor(num1, num2, sum_bin)
num1, num2, sum_bin = num1.to(device), num2.to(device), sum_bin.to(device)

out = model(num1, num2)

print("out:", out)

These two lines raise a CUDA runtime error. One of them runs normally on its own, but rerunning it, or running the second one after it, breaks:

embedded_num1 = model.embedding_layer[torch.arange(num1.shape[0]).unsqueeze(1), num1.long()]
embedded_num2 = model.embedding_layer2[torch.arange(num2.shape[0]).unsqueeze(1), num2.long()]
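
For reference, running the same lookup on CPU gives an explicit index error instead of the opaque CUDA one, which I assume is the same underlying problem: torch.arange(num1.shape[0]) runs over the batch (1000), while the table's first dimension is only bit_size = 8. Since CUDA device-side asserts surface asynchronously, that would also explain why the first line appears to succeed and only a later call blows up:

table = torch.randn(8, 2, 4)           # same shape as embedding_layer, but on CPU
idx = torch.arange(1000).unsqueeze(1)  # batch indices 0..999 -- far beyond dim 0 (size 8)
bits = torch.randint(0, 2, (1000, 8))
emb = table[idx, bits]                 # IndexError on CPU: index out of bounds for dimension 0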