Training
This page covers training a new midigpt model from scratch, or fine-tuning an existing checkpoint on your own dataset.
Requirements
Training requires PyTorch Lightning, HuggingFace datasets, and pyarrow. The C++ MIDI parser is not fork-safe, so num_workers must be 0.
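Before launching a run it can help to confirm that these dependencies are importable. The sketch below uses only the standard library. The import names (`pytorch_lightning` or `lightning`, `datasets`, `pyarrow`) are assumptions about how the packages are installed in your environment, and `missing_modules` is a hypothetical helper, not part of midigpt.

```python
import importlib.util


def missing_modules(groups):
    """Return every group for which none of the alternative module names is importable."""
    missing = []
    for alternatives in groups:
        found = any(importlib.util.find_spec(name) is not None for name in alternatives)
        if not found:
            missing.append(alternatives)
    return missing


# Assumed top-level import names; PyTorch Lightning is published under two names.
REQUIRED = [
    ("pytorch_lightning", "lightning"),
    ("datasets",),
    ("pyarrow",),
]

if __name__ == "__main__":
    for alternatives in missing_modules(REQUIRED):
        print("missing:", " or ".join(alternatives))
```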
Data format
midigpt trains on MIDI files stored as parquet shards. Each row in the parquet file represents one MIDI piece. The recommended source is GigaMIDI, available on HuggingFace Datasets.
Step 1: Preprocess parquet shards
The preprocessing step builds a cache of valid row indices so that dataset initialization is near-instant on every subsequent run. It first applies a fast metadata filter (pure PyArrow, no MIDI parsing), then validates the remaining rows in an isolated subprocess. When that subprocess crashes, the batch is bisected so that the offending rows can be isolated.
python -m midigpt.training.preprocess \
    --parquet /data/train/*.parquet \
    --checkpoint models/yellow.pt
Or supply a raw encoder config JSON instead of a checkpoint:
python -m midigpt.training.preprocess \
    --parquet /data/train/*.parquet \
    --encoder-config models/yellow_encoder.json \
    --min-bars 4 --min-tracks 1
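The crash bisection mentioned in the description of this step can be sketched as follows. This is a minimal illustration of the technique, not midigpt's actual preprocessing code: `batch_ok` is a hypothetical callable standing in for "validate these rows in a separate process and report whether it exited cleanly".

```python
from typing import Callable, List, Sequence


def find_valid_rows(
    rows: Sequence[int],
    batch_ok: Callable[[Sequence[int]], bool],
) -> List[int]:
    """Return the rows that validate, isolating crashing rows by bisection."""
    if not rows:
        return []
    if batch_ok(rows):
        # The whole batch passed in one call; keep every row.
        return list(rows)
    if len(rows) == 1:
        # A single row that still fails is the culprit; drop it.
        return []
    mid = len(rows) // 2
    # Split the failing batch and examine each half independently.
    return find_valid_rows(rows[:mid], batch_ok) + find_valid_rows(rows[mid:], batch_ok)


if __name__ == "__main__":
    bad_rows = {3, 7}  # pretend these rows crash the parser
    ok = lambda batch: not any(r in bad_rows for r in batch)
    print(find_valid_rows(range(10), ok))
```

Checking whole batches first keeps the number of subprocess launches low when crashes are rare, which is the usual reason to bisect rather than validate one row per process.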
Index files are cached in ~/.midigpt/ (override with the MIDIGPT_CACHE environment variable).
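The cache location rule amounts to an environment-variable lookup with a fallback. A minimal sketch, assuming the override is read verbatim from `MIDIGPT_CACHE`; `resolve_cache_dir` is a hypothetical helper written for illustration, not a midigpt function.

```python
import os
from pathlib import Path


def resolve_cache_dir(env=None):
    """Return the directory that would hold index files under the documented rule."""
    env = os.environ if env is None else env
    override = env.get("MIDIGPT_CACHE")
    if override:
        # Explicit override wins; expand a leading "~" for convenience.
        return Path(override).expanduser()
    # Documented default location.
    return Path.home() / ".midigpt"


if __name__ == "__main__":
    print(resolve_cache_dir())
```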
Step 2: Launch training
Command line:
python -m midigpt.training.trainer \
    --config models/train_config.json \
    --train-data /data/train/*.parquet \
    --eval-data /data/valid/*.parquet \
    --output-dir checkpoints/run_001
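The file passed to `--config` is loaded into a `TrainConfig`. The sketch below generates such a file from the defaults listed in the TrainConfig reference on this page. It assumes the JSON is a flat object whose keys mirror the field names; the exact schema accepted by `TrainConfig.from_file` is not shown here, so treat the layout as an assumption. `build_config` is a hypothetical helper.

```python
import json

# Defaults copied from the TrainConfig reference tables on this page.
DOCUMENTED_DEFAULTS = {
    # Architecture
    "encoder_config_path": "", "n_embd": 512, "n_layer": 6, "n_head": 8,
    "n_positions": 2048,
    # Data
    "max_seq_len": 2048, "num_bars_choices": [4, 8], "min_tracks": 1,
    "max_tracks": 4, "min_fill_ratio": 0.75,
    # Training objective
    "infill_probability": 0.75, "infill_bar_fraction": 0.5,
    "mask_apply_probability": 0.5, "mask_mode": 2,
    # Optimisation
    "learning_rate": 1e-4, "batch_size": 16, "max_steps": 100000,
    "warmup_steps": 1000, "precision": "fp16",
    # Infrastructure
    "num_workers": 0, "save_steps": 5000, "eval_steps": 1000,
    "logger": "none", "output_dir": "checkpoints",
}


def build_config(**overrides):
    """Merge overrides into the documented defaults, rejecting unknown field names."""
    unknown = sorted(set(overrides) - set(DOCUMENTED_DEFAULTS))
    if unknown:
        raise KeyError(f"unknown TrainConfig fields: {unknown}")
    return {**DOCUMENTED_DEFAULTS, **overrides}


if __name__ == "__main__":
    config = build_config(encoder_config_path="models/yellow_encoder.json", batch_size=8)
    print(json.dumps(config, indent=2))
```

Rejecting unknown keys up front catches typos such as `warmup_step` before a long run starts.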
Python API:
from midigpt.training.trainer import TrainConfig, train
config = TrainConfig.from_file("models/train_config.json")
config.output_dir = "checkpoints/run_001"
train(
    config,
    train_path="/data/train/00000.parquet",
    eval_path="/data/valid/00000.parquet",
)
train() uses PyTorch Lightning internally. At the end of training it writes a packed .pt bundle (model_final.pt) containing the weights, architecture config, and encoder config. Intermediate checkpoints are saved every save_steps steps.
TrainConfig reference
Architecture
| Field | Default | Description |
|---|---|---|
| encoder_config_path | "" | Path to an encoder .json or a packed .pt bundle |
| n_embd | 512 | Embedding dimension |
| n_layer | 6 | Number of transformer layers |
| n_head | 8 | Number of attention heads |
| n_positions | 2048 | Maximum sequence length (positional embeddings) |
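To get a feel for what these architecture fields imply, the sketch below estimates the per-head dimension and the parameter count. It assumes a standard GPT-2 block (MLP width of 4 × `n_embd`, biases on all linear layers, input and output embeddings tied), which is how HuggingFace GPT-2 is laid out by default; whether midigpt changes any of this is not stated here. `vocab_size` depends on the encoder config, so it is left as a parameter.

```python
def head_dim(n_embd, n_head):
    """Per-head width; GPT-2 attention requires n_embd to be divisible by n_head."""
    if n_embd % n_head != 0:
        raise ValueError("n_embd must be divisible by n_head")
    return n_embd // n_head


def gpt2_param_count(n_embd, n_layer, n_positions, vocab_size):
    """Parameter count of a standard GPT-2 with tied token embeddings."""
    # Per block: 12*d^2 weights (attention 4*d^2, MLP 8*d^2),
    # 9*d biases, and 4*d for the two LayerNorms.
    per_block = 12 * n_embd**2 + 13 * n_embd
    final_norm = 2 * n_embd
    embeddings = (vocab_size + n_positions) * n_embd
    return n_layer * per_block + final_norm + embeddings


if __name__ == "__main__":
    print("head dim:", head_dim(512, 8))
    # vocab_size=0 reports everything except the token embedding table.
    print("params excluding token embeddings:", gpt2_param_count(512, 6, 2048, 0))
```

With the defaults above, each head is 64 wide and the model has roughly 20 million parameters before the token embedding table is counted.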
Data
| Field | Default | Description |
|---|---|---|
| max_seq_len | 2048 | Token sequence cap; longer sequences are truncated to this length |
| num_bars_choices | [4, 8] | Window sizes sampled during training |
| min_tracks | 1 | Minimum tracks per sample |
| max_tracks | 4 | Maximum tracks per sample |
| min_fill_ratio | 0.75 | Minimum note density required to accept a window |
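One plausible reading of how these fields shape a training sample is sketched below. Uniform sampling of the window size and track count is an assumption, and `min_fill_ratio` is deliberately left out because its exact definition is not given here. Both functions are hypothetical and only illustrate the fields' roles.

```python
import random


def sample_window(rng, num_bars_choices=(4, 8), min_tracks=1, max_tracks=4):
    """Pick a window size (in bars) and a track count for one sample."""
    num_bars = rng.choice(list(num_bars_choices))
    num_tracks = rng.randint(min_tracks, max_tracks)  # inclusive on both ends
    return num_bars, num_tracks


def truncate(tokens, max_seq_len=2048):
    """Cap a token sequence at max_seq_len, dropping anything beyond it."""
    return tokens[:max_seq_len]


if __name__ == "__main__":
    rng = random.Random(0)
    print(sample_window(rng))
    print(len(truncate(list(range(3000)))))
```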
Training objective
| Field | Default | Description |
|---|---|---|
| infill_probability | 0.75 | Fraction of samples trained with FillIn tokens |
| infill_bar_fraction | 0.5 | Max per-cell infill density (drawn from Uniform(0, this)) |
| mask_apply_probability | 0.5 | Fraction of samples with MASK_BAR applied |
| mask_mode | 2 | MaskMode: 0 = RANDOM, 1 = STRUCTURED, 2 = MIXED |
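The per-sample decisions these fields describe can be sketched as below. The enum values come from the table; treating the infill and mask decisions as independent draws is an assumption, and `ObjectivePlan` and `plan_sample` are hypothetical names used only for this illustration.

```python
import random
from dataclasses import dataclass
from enum import IntEnum


class MaskMode(IntEnum):
    """Local mirror of the documented mask_mode values."""
    RANDOM = 0
    STRUCTURED = 1
    MIXED = 2


@dataclass
class ObjectivePlan:
    use_infill: bool
    infill_density: float
    apply_mask: bool
    mask_mode: MaskMode


def plan_sample(rng, infill_probability=0.75, infill_bar_fraction=0.5,
                mask_apply_probability=0.5, mask_mode=2):
    """Draw the objective settings for one training sample."""
    use_infill = rng.random() < infill_probability
    # Density is drawn from Uniform(0, infill_bar_fraction) when infilling is on.
    density = rng.uniform(0.0, infill_bar_fraction) if use_infill else 0.0
    apply_mask = rng.random() < mask_apply_probability
    return ObjectivePlan(use_infill, density, apply_mask, MaskMode(mask_mode))


if __name__ == "__main__":
    rng = random.Random(0)
    print(plan_sample(rng))
```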
Optimisation
| Field | Default | Description |
|---|---|---|
| learning_rate | 1e-4 | Peak learning rate |
| batch_size | 16 | Per-GPU batch size |
| max_steps | 100000 | Total training steps |
| warmup_steps | 1000 | Linear LR warmup steps |
| precision | "fp16" | "fp16", "bf16", or "fp32" |
Infrastructure
| Field | Default | Description |
|---|---|---|
| num_workers | 0 | Must be 0, because the C++ MIDI parser is not fork-safe |
| save_steps | 5000 | Save a checkpoint every N steps |
| eval_steps | 1000 | Run validation every N steps |
| logger | "none" | "tensorboard", "wandb", or "none" |
| output_dir | "checkpoints" | Where to write checkpoints and the final bundle |
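The step-based fields (`warmup_steps`, `save_steps`, `eval_steps`) together define what happens at any given training step. The sketch below illustrates this with the documented defaults. Two things are assumptions: the learning rate is held at its peak after warmup (the trainer may apply a decay that is not described here), and periodic actions fire on exact multiples of their interval. `lr_at` and `events_at` are hypothetical helpers.

```python
def lr_at(step, learning_rate=1e-4, warmup_steps=1000):
    """Linear warmup from 0 to the peak rate, then constant (assumed)."""
    if warmup_steps <= 0:
        return learning_rate
    return learning_rate * min(1.0, step / warmup_steps)


def events_at(step, save_steps=5000, eval_steps=1000):
    """Report which periodic actions fire at the given step."""
    return {
        "save": step > 0 and step % save_steps == 0,
        "eval": step > 0 and step % eval_steps == 0,
    }


if __name__ == "__main__":
    for step in (500, 1000, 5000):
        print(step, lr_at(step), events_at(step))
```

Under these assumptions a default run of 100000 steps validates 100 times and saves 20 intermediate checkpoints, which is worth checking against available disk space before starting.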
Checkpoint format
Packed .pt bundles written by training embed everything needed to run inference:
{
    "format_version": 1,
    "arch": "gpt2",
    "config": { "vocab_size": ..., "n_positions": 2048, ... },
    "encoder_config": { ... },  # full encoder JSON
    "state_dict": { ... },      # HuggingFace GPT-2 key layout
}
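When handling bundles outside of midigpt, a quick structural check can catch truncated or foreign files early. The sketch below works on a plain dictionary with the layout shown above; how you obtain that dictionary is up to you (for a `.pt` file this is presumably `torch.load`, which is an assumption here). `check_bundle` is a hypothetical helper, and rejecting any `format_version` other than 1 is likewise an assumption.

```python
REQUIRED_KEYS = ("format_version", "arch", "config", "encoder_config", "state_dict")


def check_bundle(bundle):
    """Validate the top-level layout of a bundle dict and return its architecture name."""
    missing = [key for key in REQUIRED_KEYS if key not in bundle]
    if missing:
        raise ValueError(f"bundle is missing keys: {missing}")
    if bundle["format_version"] != 1:
        raise ValueError(f"unsupported format_version: {bundle['format_version']!r}")
    return bundle["arch"]


if __name__ == "__main__":
    example = {
        "format_version": 1,
        "arch": "gpt2",
        "config": {"n_positions": 2048},
        "encoder_config": {},
        "state_dict": {},
    }
    print(check_bundle(example))
```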
Load with:
from midigpt.inference import InferenceEngine
engine = InferenceEngine.from_checkpoint("checkpoints/run_001/model_final.pt")
load_checkpoint(path) also accepts a legacy directory containing config.json + model.pt.
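The two on-disk layouts can be told apart with a simple filesystem check. The sketch below is an illustration only, not how midigpt decides internally; `checkpoint_kind` is a hypothetical helper, and requiring a `.pt` suffix for bundles is an assumption.

```python
import tempfile
from pathlib import Path


def checkpoint_kind(path):
    """Classify a path as a packed bundle file or a legacy checkpoint directory."""
    p = Path(path)
    if p.is_dir():
        # Legacy layout: a directory holding config.json and model.pt.
        if (p / "config.json").is_file() and (p / "model.pt").is_file():
            return "legacy_dir"
        raise ValueError(f"{p} is a directory without config.json and model.pt")
    if p.is_file() and p.suffix == ".pt":
        return "bundle"
    raise ValueError(f"{p} is neither a .pt bundle nor a legacy directory")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        legacy = Path(tmp) / "legacy"
        legacy.mkdir()
        (legacy / "config.json").write_text("{}")
        (legacy / "model.pt").write_bytes(b"")
        print(checkpoint_kind(legacy))
```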