|
| 1 | +--- |
| 2 | +title: LoRA Fine-Tuning |
| 3 | +weight: 7 |
| 4 | +bookToc: true |
| 5 | +--- |
| 6 | + |
| 7 | +# LoRA Fine-Tuning |
| 8 | + |
| 9 | +Inject LoRA (Low-Rank Adaptation) adapters into a model and train on custom data. Only the small LoRA A and B matrices are updated -- the base model weights stay frozen, making fine-tuning fast and memory-efficient. |
| 10 | + |
| 11 | +This recipe demonstrates the LoRA + training API at the graph/layer level: |
| 12 | + |
| 13 | +- Building a model with Linear layers |
| 14 | +- Injecting LoRA adapters into target layers |
| 15 | +- Running a forward pass through LoRA-wrapped layers |
- Saving the LoRA adapter checkpoint (it can later be restored with `lora.LoadAdapter`)
| 17 | + |
| 18 | +## Full Example |
| 19 | + |
| 20 | +```go |
| 21 | +// Recipe 07: Fine-Tuning with LoRA |
| 22 | +// |
| 23 | +// Inject LoRA (Low-Rank Adaptation) adapters into a model and train on custom |
| 24 | +// data. Only the small LoRA matrices are updated -- the base model weights stay |
| 25 | +// frozen, making fine-tuning fast and memory-efficient. |
| 26 | +// |
| 27 | +// This recipe demonstrates the LoRA + training API at the graph/layer level: |
| 28 | +// - Building a model with Linear layers |
| 29 | +// - Injecting LoRA adapters into target layers |
| 30 | +// - Running a forward pass through LoRA-wrapped layers |
// - Saving the LoRA adapter checkpoint (restorable later with lora.LoadAdapter)
| 32 | +// |
| 33 | +// Usage: |
| 34 | +// |
| 35 | +// go run ./docs/cookbook/07-lora-fine-tuning/ |
| 36 | +package main |
| 37 | + |
| 38 | +import ( |
| 39 | + "context" |
| 40 | + "fmt" |
| 41 | + "math/rand/v2" |
| 42 | + "os" |
| 43 | + |
| 44 | + "github.com/zerfoo/ztensor/compute" |
| 45 | + "github.com/zerfoo/ztensor/graph" |
| 46 | + "github.com/zerfoo/ztensor/numeric" |
| 47 | + "github.com/zerfoo/ztensor/tensor" |
| 48 | + "github.com/zerfoo/ztensor/types" |
| 49 | + "github.com/zerfoo/zerfoo/training/lora" |
| 50 | +) |
| 51 | + |
// simpleLinear is a minimal Linear layer that satisfies lora.Layer[T]
// (which requires graph.Node[T] + Named). It applies a bias-free linear
// projection: output = input @ weights.
type simpleLinear struct {
	name    string                    // layer name; also used to derive the weight parameter's name
	weights *graph.Parameter[float32] // weight matrix of shape [dIn, dOut]
	engine  compute.Engine[float32]   // compute backend used for the matrix multiply
	dIn     int                       // number of input features
	dOut    int                       // number of output features
}
| 61 | + |
| 62 | +func newSimpleLinear(name string, engine compute.Engine[float32], dIn, dOut int) (*simpleLinear, error) { |
| 63 | + data := make([]float32, dIn*dOut) |
| 64 | + for i := range data { |
| 65 | + data[i] = rand.Float32()*0.02 - 0.01 |
| 66 | + } |
| 67 | + w, err := tensor.New[float32]([]int{dIn, dOut}, data) |
| 68 | + if err != nil { |
| 69 | + return nil, err |
| 70 | + } |
| 71 | + param, err := graph.NewParameter[float32](name+"_weights", w, tensor.New[float32]) |
| 72 | + if err != nil { |
| 73 | + return nil, err |
| 74 | + } |
| 75 | + return &simpleLinear{name: name, weights: param, engine: engine, dIn: dIn, dOut: dOut}, nil |
| 76 | +} |
| 77 | + |
// Name returns the layer's name.
func (l *simpleLinear) Name() string { return l.name }

// OpType identifies the operation kind of this node.
func (l *simpleLinear) OpType() string { return "Linear" }

// Attributes returns nil: this simple layer carries no extra attributes.
func (l *simpleLinear) Attributes() map[string]any { return nil }

// OutputShape reports [-1, dOut]; -1 marks the dynamic batch dimension.
func (l *simpleLinear) OutputShape() []int { return []int{-1, l.dOut} }

// Parameters exposes the weight matrix as the layer's only parameter.
func (l *simpleLinear) Parameters() []*graph.Parameter[float32] { return []*graph.Parameter[float32]{l.weights} }

// InputFeatures reports the input dimensionality (dIn).
func (l *simpleLinear) InputFeatures() int { return l.dIn }

// OutputFeatures reports the output dimensionality (dOut).
func (l *simpleLinear) OutputFeatures() int { return l.dOut }
| 85 | + |
| 86 | +func (l *simpleLinear) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error) { |
| 87 | + return l.engine.MatMul(ctx, inputs[0], l.weights.Value) |
| 88 | +} |
| 89 | + |
| 90 | +func (l *simpleLinear) Backward(_ context.Context, _ types.BackwardMode, outputGradient *tensor.TensorNumeric[float32], _ ...*tensor.TensorNumeric[float32]) ([]*tensor.TensorNumeric[float32], error) { |
| 91 | + return []*tensor.TensorNumeric[float32]{outputGradient}, nil |
| 92 | +} |
| 93 | + |
// simpleModel wraps a set of Linear layers for LoRA injection.
// It exposes Layers and ReplaceLayer, which is what lora.InjectLoRA uses
// in main to enumerate layers and swap targets for LoRA-wrapped versions.
type simpleModel struct {
	layers map[string]lora.Layer[float32] // layers keyed by name
	order  []string                       // deterministic iteration/forward order over layers
}
| 99 | + |
| 100 | +func (m *simpleModel) Layers() []lora.Layer[float32] { |
| 101 | + var out []lora.Layer[float32] |
| 102 | + for _, name := range m.order { |
| 103 | + out = append(out, m.layers[name]) |
| 104 | + } |
| 105 | + return out |
| 106 | +} |
| 107 | + |
| 108 | +func (m *simpleModel) ReplaceLayer(name string, replacement lora.Layer[float32]) error { |
| 109 | + if _, ok := m.layers[name]; !ok { |
| 110 | + return fmt.Errorf("layer %q not found", name) |
| 111 | + } |
| 112 | + m.layers[name] = replacement |
| 113 | + return nil |
| 114 | +} |
| 115 | + |
| 116 | +func main() { |
| 117 | + ops := numeric.Float32Ops{} |
| 118 | + engine := compute.NewCPUEngine[float32](ops) |
| 119 | + |
| 120 | + // Build a small model with attention-like projections. |
| 121 | + qProj, _ := newSimpleLinear("q_proj", engine, 64, 64) |
| 122 | + kProj, _ := newSimpleLinear("k_proj", engine, 64, 64) |
| 123 | + vProj, _ := newSimpleLinear("v_proj", engine, 64, 64) |
| 124 | + oProj, _ := newSimpleLinear("o_proj", engine, 64, 64) |
| 125 | + |
| 126 | + model := &simpleModel{ |
| 127 | + layers: map[string]lora.Layer[float32]{ |
| 128 | + "q_proj": qProj, "k_proj": kProj, "v_proj": vProj, "o_proj": oProj, |
| 129 | + }, |
| 130 | + order: []string{"q_proj", "k_proj", "v_proj", "o_proj"}, |
| 131 | + } |
| 132 | + |
| 133 | + // Inject LoRA adapters into Q and V projections (rank=8, alpha=16). |
| 134 | + // Only the LoRA A and B matrices are trainable; base weights are frozen. |
| 135 | + err := lora.InjectLoRA[float32]( |
| 136 | + model, |
| 137 | + 8, // rank |
| 138 | + 16.0, // alpha |
| 139 | + []string{"q_proj", "v_proj"}, |
| 140 | + engine, |
| 141 | + ) |
| 142 | + if err != nil { |
| 143 | + fmt.Fprintf(os.Stderr, "inject lora: %v\n", err) |
| 144 | + os.Exit(1) |
| 145 | + } |
| 146 | + fmt.Println("Injected LoRA into q_proj and v_proj (rank=8, alpha=16)") |
| 147 | + |
| 148 | + // Count total parameters across all layers. |
| 149 | + var totalParams int |
| 150 | + for _, layer := range model.Layers() { |
| 151 | + for _, p := range layer.Parameters() { |
| 152 | + n := 1 |
| 153 | + for _, d := range p.Value.Shape() { |
| 154 | + n *= d |
| 155 | + } |
| 156 | + totalParams += n |
| 157 | + } |
| 158 | + } |
| 159 | + fmt.Printf("Total parameters: %d\n", totalParams) |
| 160 | + |
| 161 | + // Forward pass with synthetic data through LoRA-wrapped layers. |
| 162 | + ctx := context.Background() |
| 163 | + inputData := make([]float32, 4*64) |
| 164 | + for i := range inputData { |
| 165 | + inputData[i] = rand.Float32() |
| 166 | + } |
| 167 | + input, _ := tensor.New[float32]([]int{4, 64}, inputData) |
| 168 | + |
| 169 | + out := input |
| 170 | + for _, name := range model.order { |
| 171 | + layer := model.layers[name] |
| 172 | + out, err = layer.Forward(ctx, out) |
| 173 | + if err != nil { |
| 174 | + fmt.Fprintf(os.Stderr, "forward %s: %v\n", name, err) |
| 175 | + os.Exit(1) |
| 176 | + } |
| 177 | + } |
| 178 | + fmt.Printf("Output shape: %v\n", out.Shape()) |
| 179 | + |
| 180 | + // Save the LoRA adapter checkpoint. |
| 181 | + checkpointPath := "lora-adapter.bin" |
| 182 | + if err := lora.SaveAdapter[float32](checkpointPath, model); err != nil { |
| 183 | + fmt.Fprintf(os.Stderr, "save: %v\n", err) |
| 184 | + os.Exit(1) |
| 185 | + } |
| 186 | + fmt.Printf("Saved LoRA adapter to %s\n", checkpointPath) |
| 187 | + |
| 188 | + // Clean up the checkpoint file created by this demo. |
| 189 | + os.Remove(checkpointPath) |
| 190 | + fmt.Println("Done.") |
| 191 | +} |
| 192 | +``` |
| 193 | + |
| 194 | +## How It Works |
| 195 | + |
| 196 | +1. **Build a model with Linear layers** -- The example creates four linear projection layers (Q, K, V, O) that mimic the attention projections in a transformer. Each layer implements the `lora.Layer[T]` interface, which requires `Forward`, `Parameters`, `InputFeatures`, `OutputFeatures`, and `Name` methods. |
| 197 | + |
| 198 | +2. **Inject LoRA adapters** -- `lora.InjectLoRA` wraps the specified layers with LoRA adapters. Each adapter adds two small matrices (A and B) of the given rank. The original weight matrix is frozen, and only the LoRA matrices are trainable. The scaling factor `alpha/rank` controls the magnitude of the adapter's contribution. |
| 199 | + |
| 200 | +3. **Forward pass** -- Input flows through each layer sequentially. For LoRA-wrapped layers, the output is `base_output + (alpha/rank) * (x @ A @ B)`, where A and B are the low-rank adapter matrices. |
| 201 | + |
| 202 | +4. **Save the checkpoint** -- `lora.SaveAdapter` serializes only the LoRA parameters (not the base model weights), producing a small checkpoint file that can be loaded later with `lora.LoadAdapter`. |
| 203 | + |
| 204 | +## Key Concepts |
| 205 | + |
| 206 | +- **Rank** controls the capacity of the adapter. Typical values are 4-64. Lower rank = fewer parameters = faster training, but less expressive. |
| 207 | +- **Alpha** is a scaling hyperparameter. A common default is `alpha = 2 * rank`. |
| 208 | +- **Target layers** -- LoRA is most effective when applied to the Q and V projections in attention layers, though you can target any linear layer. |
| 209 | + |
| 210 | +## Related API Reference |
| 211 | + |
| 212 | +- [Inference API](/docs/api/inference/) -- `inference.LoadFile` and model loading options |
| 213 | +- [Generate API](/docs/api/generate/) -- text generation with loaded models |
0 commit comments