I have successfully fine-tuned Llama-2-7b with torchtune.
I then test it with:
!tune run generate --config trained-model.yaml prompt="what is TBS?"
I get a response that reflects what the model learned during training.
Today I was looking to follow that pattern and train Llama-3-8b.
Training went fine and produced the Adapter.pt and meta-model.pt files.
When I tried to test it with the same command:
!tune run generate --config trained-model.yaml prompt="what is TBS?"
I got the following error:
Exception: Error converting the state dict. Found unexpected key: "tok_embeddings.weight". Please make sure you're loading a checkpoint with the right format.
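The key named in the error, tok_embeddings.weight, is the name Meta-format Llama checkpoints use for the token embedding, while Hugging Face-format checkpoints call it model.embed_tokens.weight. One thing that may be worth checking is whether the checkpointer in the generation config expects the same format as the file it is given. Below is a minimal sketch of such a check. The helper name guess_checkpoint_format and the checkpoint path in the usage comment are hypothetical, and it only looks at the embedding key, so treat it as a rough diagnostic rather than anything torchtune itself provides.

```python
def guess_checkpoint_format(keys):
    """Guess the key-naming style of a checkpoint from its state-dict keys.

    Hypothetical helper for illustration only. It inspects just the
    token-embedding key:
      - "model.embed_tokens.weight" -> Hugging Face-style naming
      - "tok_embeddings.weight"     -> Meta-style naming
    """
    keys = set(keys)
    if "model.embed_tokens.weight" in keys:
        return "hf"
    if "tok_embeddings.weight" in keys:
        return "meta"
    return "unknown"


# Usage sketch (assumes PyTorch is installed; the path is a placeholder):
#   import torch
#   state_dict = torch.load("/path/to/checkpoint.pt", map_location="cpu")
#   print(guess_checkpoint_format(state_dict.keys()))
```

If the result is "meta" but the generation config points a Hugging Face-style checkpointer at that file (or the other way round), a key-conversion error like the one above would be a plausible outcome.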
Any ideas would be great.
Below is my training config:
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /root/Pretrained_base_models/Llama-3-8B/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /root/Pretrained_base_models/Llama-3-8B/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /root/Trained_models/Llama-3-8b/output
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 0

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False

# Logging
output_dir: /root/Trained_models/Llama-3-8b/finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: null

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Profiler (disabled)
profiler:
  _component_: torchtune.utils.profiler
  enabled: False
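For reference, the LoRA settings in the model section above (rank 8, alpha 16, applied to q_proj and v_proj only) imply a fairly small adapter. The sketch below estimates its size. The Llama-3-8B dimensions used (embedding size 4096, 32 layers, 32 attention heads, 8 key/value heads) are assumptions taken from the published model architecture, not from this config, and the helper lora_param_count is hypothetical.

```python
def lora_param_count(in_dim, out_dim, rank):
    """Parameters in one LoRA pair: A is (rank x in_dim), B is (out_dim x rank)."""
    return rank * (in_dim + out_dim)


# Assumed Llama-3-8B architecture values (not part of the config above).
EMBED_DIM = 4096
NUM_LAYERS = 32
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = EMBED_DIM // NUM_HEADS

# Values taken from the training config.
LORA_RANK = 8
LORA_ALPHA = 16

# q_proj maps EMBED_DIM -> NUM_HEADS * HEAD_DIM,
# v_proj maps EMBED_DIM -> NUM_KV_HEADS * HEAD_DIM (grouped-query attention).
q_params = lora_param_count(EMBED_DIM, NUM_HEADS * HEAD_DIM, LORA_RANK)
v_params = lora_param_count(EMBED_DIM, NUM_KV_HEADS * HEAD_DIM, LORA_RANK)

total_params = NUM_LAYERS * (q_params + v_params)
scaling = LORA_ALPHA / LORA_RANK  # factor applied to the low-rank update

print(f"LoRA parameters: {total_params:,}")
print(f"LoRA scaling (alpha / rank): {scaling}")
```

Under those assumptions the adapter has roughly 3.4 million trainable parameters, which is why the adapter file is so much smaller than the merged model checkpoint.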