Build A Trainable LLM
Create an LLM Trainer for training
The parameters of this LLMTrainer are essentially the same as those of transformers.TrainingArguments, with some additional parameters added to make training setup easier.
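The trainer below also references key, select, transform, and training_kwargs, which are assumed to have been defined earlier when preparing the training data. A minimal sketch of what these might look like (the collection name, field name, and query style are assumptions and depend on your data and superduper version):

key = "text"  # assumption: each training document stores its text under "text"
select = db["datas"].select()  # assumption: training data lives in a collection called "datas"
transform = None  # optional callable applied to each selected row before training
training_kwargs = dict(dataset_text_field="text")  # assumption: forwarded to the underlying SFT trainer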
from superduper_transformers import LLM, LLMTrainer

trainer = LLMTrainer(
    identifier="llm-finetune-trainer",
    output_dir="output/finetune",
    overwrite_output_dir=True,
    num_train_epochs=3,
    save_total_limit=3,
    logging_steps=10,
    evaluation_strategy="steps",
    save_steps=100,
    eval_steps=100,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    max_seq_length=512,
    key=key,  # field(s) of each document used for training
    select=select,  # query selecting the training data
    transform=transform,  # optional per-row preprocessing
    training_kwargs=training_kwargs,  # extra kwargs forwarded to the underlying trainer
)
The trainer supports several parameter-efficient and distributed training setups, each enabled by setting a few attributes on the trainer (shown in the snippets below):
- LoRA
- QLoRA
- DeepSpeed
- Multi-GPU
# LoRA: train low-rank adapters instead of the full model weights
trainer.use_lora = True
# QLoRA: LoRA on top of a quantized base model (here 4-bit)
trainer.use_lora = True
trainer.bits = 4
# DeepSpeed: requires the deepspeed package
!pip install deepspeed

deepspeed = {
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
    },
}
trainer.use_lora = True
trainer.bits = 4
trainer.deepspeed = deepspeed
# Multi-GPU: distribute training across the available GPUs
trainer.use_lora = True
trainer.bits = 4
trainer.num_gpus = 2
Create a trainable LLM model and add it to the database; the training task will then run automatically.
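The snippet below assumes that model_name, model_kwargs, and tokenizer_kwargs were set earlier. A minimal sketch with illustrative values (the model name and kwargs are assumptions; they are passed through to the Hugging Face from_pretrained calls):

model_name = "facebook/opt-125m"  # assumption: any causal LM from the Hugging Face Hub
model_kwargs = dict(device_map="auto")  # assumption: forwarded when loading the model
tokenizer_kwargs = dict()  # assumption: forwarded when loading the tokenizer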
llm = LLM(
    identifier="llm",
    model_name_or_path=model_name,
    trainer=trainer,
    model_kwargs=model_kwargs,
    tokenizer_kwargs=tokenizer_kwargs,
)
db.apply(llm)
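Once applied, superduper schedules the fine-tuning job. To confirm that the model and its trainer were registered, you can list the stored models; a small sketch using db.show (the same helper used for checkpoints later in this guide):

db.show("model")  # should include the "llm" model added above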
Load the trained model
There are two methods to load a trained model:
- Load the model directly: this loads the model with the best metrics (if the transformers best-model saving strategy is configured) or otherwise the latest version of the model.
- Use a specified checkpoint: this method downloads the specified checkpoint, initializes the base model, and then merges the checkpoint into it. It supports custom operations during initialization, such as enabling flash attention or quantizing the model.
# Method 1: load the trained model directly
llm = db.load("model", "llm")
# Method 2: load from a specified checkpoint
from superduper_transformers import LLM

experiment_id = db.show("checkpoint")[-1]  # pick the last checkpoint experiment listed
version = None  # None means the latest checkpoint
checkpoint = db.load("checkpoint", experiment_id, version=version)
llm = LLM(
    identifier="llm",
    model_name_or_path=model_name,
    adapter_id=checkpoint,  # merge the fine-tuned adapter into the base model
    model_kwargs=dict(load_in_4bit=True),  # e.g. quantize the base model while loading
)