Hello everybody,
I am currently trying to implement a loop for training and testing over many experiments and stumbled across a peculiar problem!
I am working with a triplet training model (an embedding layer plus the Triplett_V0 wrapper), and when I build the model step by step in a Jupyter notebook, the training loop (triplett_train_and_test) works just fine!
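For context, the wrapper basically just applies the shared embedding model to the anchor, positive, and negative samples. Here is a simplified sketch of that idea (not my exact class; the forward signature is illustrative):

import torch
from torch import nn

class TriplettWrapperSketch(nn.Module):
    # Illustrative only: wraps one embedding model and applies it to all three inputs
    def __init__(self, embedding_model):
        super().__init__()
        self.embedding_model = embedding_model

    def forward(self, anchor, positive, negative):
        # The same shared-weight embedding model embeds all three inputs
        return (self.embedding_model(anchor),
                self.embedding_model(positive),
                self.embedding_model(negative))

This is the step-by-step setup from the notebook: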
# Batch size, epochs, etc.
TRAIN_BATCH_SIZE = BATCH_SIZE
TEST_BATCH_SIZE = int(BATCH_SIZE * VAL_SPLIT)
EPOCHS = 100
NUM_HARD = NUM_HARD
EMB_SIZE = 10
DROPOUT_PROB = 0.1
EMBEDDING_DIM = 14

# Instantiate the embedding model
EMBEDDING_MODEL = Simple_Embedding_V1(embedding_dim=EMBEDDING_DIM,
                                      num_embeddings=EMB_SIZE,
                                      dropout_prob=DROPOUT_PROB)

# Instantiate the Triplett wrapper with the underlying embedding model
MODEL = Triplett_V0(EMBEDDING_MODEL)

# Learning rate, optimizer, scheduler, loss_fn for training
lr = 0.001
optimizer = torch.optim.Adam(params=MODEL.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
loss_fn = nn.TripletMarginLoss(margin=1.0, eps=1e-7)
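For reference, the per-batch step inside triplett_train_and_test is roughly the following (simplified sketch; a train_loader that yields (anchor, positive, negative) batches is an assumption here, and metric logging is left out):

# Simplified sketch of one training epoch
# (train_loader yielding (anchor, positive, negative) batches is assumed)
MODEL.train()
for anchor, positive, negative in train_loader:
    optimizer.zero_grad()
    emb_anchor, emb_positive, emb_negative = MODEL(anchor, positive, negative)
    loss = loss_fn(emb_anchor, emb_positive, emb_negative)
    loss.backward()
    optimizer.step()
scheduler.step()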
However, when I try to use the following function to create the model,
def create_Triplett_V0(embedding_model, input_size, output_size, dropout_prob=0):
    # 1. Get the base model with random weights and send to device
    model = embedding_model(embedding_dim=input_size,
                            num_embeddings=output_size,
                            dropout_prob=dropout_prob)
    model.apply(init_weights)
    MODEL = Triplett_V0(model)

    # 2. Make sure that the parameters are trainable
    for param in MODEL.parameters():
        param.requires_grad = True

    # 3. Set the seeds (just sets the random seed to 42 to have more comparable results)
    set_seeds()

    # 4. Give the model a name
    model_name = f"Triplett_V0_{model.name}"
    print(f"[INFO] Created new {model_name} model.")

    return MODEL
the train and test loss stay stagnant.
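For reference, an illustrative call to the factory, with the optimizer, scheduler, and loss set up from the returned model in the same way as above:

# Illustrative call to the factory (hyperparameter values as above)
MODEL = create_Triplett_V0(embedding_model=Simple_Embedding_V1,
                           input_size=EMBEDDING_DIM,
                           output_size=EMB_SIZE,
                           dropout_prob=DROPOUT_PROB)
optimizer = torch.optim.Adam(params=MODEL.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
loss_fn = nn.TripletMarginLoss(margin=1.0, eps=1e-7)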
For what it's worth, I have already checked and compared requires_grad, grad_fn, and is_leaf of the underlying embedding layer and of the Triplett model between the two approaches and could not find any difference.
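The comparison was roughly this (simplified; model_notebook and model_factory are placeholders for the two builds):

# Simplified sketch of the comparison
# (model_notebook / model_factory stand for the model built in the notebook
#  vs. the one returned by create_Triplett_V0)
for (name_nb, p_nb), (name_fc, p_fc) in zip(model_notebook.named_parameters(),
                                            model_factory.named_parameters()):
    print(name_nb, p_nb.requires_grad, p_nb.is_leaf, p_nb.grad_fn,
          "|", name_fc, p_fc.requires_grad, p_fc.is_leaf, p_fc.grad_fn)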
If anybody could help me out, I would be very grateful! If you need any further information, let me know!
Happy Coding to everyone!