Commit 5bad345

docs(cookbooks): add cookbooks 07-12
1 parent 66292e4 commit 5bad345

6 files changed, 942 additions & 0 deletions

Lines changed: 128 additions & 0 deletions

---
title: Batch Inference
weight: 8
bookToc: true
---

# Batch Inference

Run inference over many prompts concurrently using goroutines. This pattern is useful for processing datasets, evaluations, or any batch workload where you need to generate completions for a list of inputs.

The program loads a single model and fans out generation across a configurable number of worker goroutines, collecting results in order.

## Full Example

```go
// Recipe 08: Batch Inference
//
// Run inference over many prompts concurrently using goroutines. This pattern
// is useful for processing datasets, evaluations, or any batch workload.
//
// The program loads a single model and fans out generation across a configurable
// number of worker goroutines, collecting results in order.
//
// Usage:
//
//	go run ./docs/cookbook/08-batch-inference/ --model path/to/model.gguf
package main

import (
	"context"
	"flag"
	"fmt"
	"os"
	"sync"

	"github.com/zerfoo/zerfoo"
)

// prompts is a batch of inputs to process.
var prompts = []string{
	"Summarize the Go memory model in one sentence.",
	"What is a goroutine?",
	"Explain channels in Go.",
	"What is the purpose of the context package?",
	"Describe Go's approach to error handling.",
	"What is a defer statement?",
	"Explain interfaces in Go.",
	"What are Go modules?",
}

func main() {
	modelPath := flag.String("model", "", "path to GGUF model file or HuggingFace model ID")
	workers := flag.Int("workers", 4, "number of concurrent workers")
	flag.Parse()

	if *modelPath == "" {
		fmt.Fprintln(os.Stderr, "usage: batch-inference --model <path> [--workers 4]")
		os.Exit(1)
	}

	m, err := zerfoo.Load(*modelPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "load: %v\n", err)
		os.Exit(1)
	}
	defer m.Close()

	// Results are stored in order.
	results := make([]string, len(prompts))
	errs := make([]error, len(prompts))

	// Fan out work across goroutines.
	var wg sync.WaitGroup
	sem := make(chan struct{}, *workers)

	for i, prompt := range prompts {
		wg.Add(1)
		go func(idx int, p string) {
			defer wg.Done()
			sem <- struct{}{}        // Acquire worker slot.
			defer func() { <-sem }() // Release.

			result, err := m.Generate(context.Background(), p,
				zerfoo.WithGenMaxTokens(128),
				zerfoo.WithGenTemperature(0.3),
			)
			if err != nil {
				errs[idx] = err
				return
			}
			results[idx] = result.Text
		}(i, prompt)
	}

	wg.Wait()

	// Print results.
	for i, prompt := range prompts {
		fmt.Printf("--- Prompt %d: %s\n", i+1, prompt)
		if errs[i] != nil {
			fmt.Printf("  Error: %v\n", errs[i])
		} else {
			fmt.Printf("  %s\n\n", results[i])
		}
	}
}
```

## How It Works

1. **Load the model once** -- `zerfoo.Load` loads the GGUF model into memory. The model handle is safe for concurrent use from multiple goroutines.

2. **Fan out with a semaphore** -- A buffered channel acts as a semaphore to limit the number of concurrent inference calls to `--workers` (default 4). Each goroutine acquires a slot before calling `Generate` and releases it when done. (An equivalent approach using `errgroup` is sketched after this list.)

3. **Collect results in order** -- Results and errors are written to pre-allocated slices indexed by prompt position, so the output order matches the input order regardless of which goroutine finishes first.

4. **Configurable generation** -- Each call uses `WithGenMaxTokens` and `WithGenTemperature` to control output length and randomness. You can pass any generation option supported by the API.
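
The `sync.WaitGroup` plus buffered-channel pattern in the full example is dependency-free, but the same bounded fan-out can be written with `golang.org/x/sync/errgroup`, which bundles the wait group and the concurrency limit. A minimal sketch, reusing `m`, `workers`, `prompts`, `results`, and `errs` from the example above:

```go
// Sketch: bounded fan-out with errgroup (import "golang.org/x/sync/errgroup").
g, ctx := errgroup.WithContext(context.Background())
g.SetLimit(*workers) // At most *workers Generate calls run concurrently.

for i, prompt := range prompts {
	idx, p := i, prompt // Capture loop variables (not needed on Go 1.22+).
	g.Go(func() error {
		result, err := m.Generate(ctx, p,
			zerfoo.WithGenMaxTokens(128),
			zerfoo.WithGenTemperature(0.3),
		)
		if err != nil {
			errs[idx] = err
			return nil // Record the error; keep processing other prompts.
		}
		results[idx] = result.Text
		return nil
	})
}
_ = g.Wait() // Every closure returns nil, so Wait never reports an error here.
```

Returning `nil` from each closure keeps one failed prompt from cancelling the shared context and aborting the rest of the batch.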

## Tuning the Worker Count

- **CPU inference**: Start with `workers = runtime.NumCPU()` and adjust based on memory; each concurrent generation allocates its own KV cache (see the snippet after this list).
- **GPU inference**: GPU memory is the bottleneck. Start with 1-2 workers and increase until VRAM is saturated.
- **Latency vs. throughput**: More workers increase total throughput but may also increase per-request latency due to resource contention.
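
A small variation on the example's flag setup derives the default from the machine instead of hard-coding 4; the only change is the default value (and it requires importing `runtime`):

```go
// Default the worker count to the number of logical CPUs on this machine.
workers := flag.Int("workers", runtime.NumCPU(), "number of concurrent workers")
```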

## Related API Reference

- [Generate API](/docs/api/generate/) -- generation options and result types
- [Inference API](/docs/api/inference/) -- model loading and device selection

Lines changed: 213 additions & 0 deletions

---
title: LoRA Fine-Tuning
weight: 7
bookToc: true
---

# LoRA Fine-Tuning

Inject LoRA (Low-Rank Adaptation) adapters into a model and train on custom data. Only the small LoRA A and B matrices are updated -- the base model weights stay frozen, making fine-tuning fast and memory-efficient.

This recipe demonstrates the LoRA + training API at the graph/layer level:

- Building a model with Linear layers
- Injecting LoRA adapters into target layers
- Running a forward pass through LoRA-wrapped layers
- Saving and loading the LoRA checkpoint

## Full Example

```go
// Recipe 07: Fine-Tuning with LoRA
//
// Inject LoRA (Low-Rank Adaptation) adapters into a model and train on custom
// data. Only the small LoRA matrices are updated -- the base model weights stay
// frozen, making fine-tuning fast and memory-efficient.
//
// This recipe demonstrates the LoRA + training API at the graph/layer level:
//   - Building a model with Linear layers
//   - Injecting LoRA adapters into target layers
//   - Running a forward pass through LoRA-wrapped layers
//   - Saving and loading the LoRA checkpoint
//
// Usage:
//
//	go run ./docs/cookbook/07-lora-fine-tuning/
package main

import (
	"context"
	"fmt"
	"math/rand/v2"
	"os"

	"github.com/zerfoo/ztensor/compute"
	"github.com/zerfoo/ztensor/graph"
	"github.com/zerfoo/ztensor/numeric"
	"github.com/zerfoo/ztensor/tensor"
	"github.com/zerfoo/ztensor/types"
	"github.com/zerfoo/zerfoo/training/lora"
)

// simpleLinear is a minimal Linear layer that satisfies lora.Layer[T]
// (which requires graph.Node[T] + Named).
type simpleLinear struct {
	name    string
	weights *graph.Parameter[float32]
	engine  compute.Engine[float32]
	dIn     int
	dOut    int
}

func newSimpleLinear(name string, engine compute.Engine[float32], dIn, dOut int) (*simpleLinear, error) {
	data := make([]float32, dIn*dOut)
	for i := range data {
		data[i] = rand.Float32()*0.02 - 0.01
	}
	w, err := tensor.New[float32]([]int{dIn, dOut}, data)
	if err != nil {
		return nil, err
	}
	param, err := graph.NewParameter[float32](name+"_weights", w, tensor.New[float32])
	if err != nil {
		return nil, err
	}
	return &simpleLinear{name: name, weights: param, engine: engine, dIn: dIn, dOut: dOut}, nil
}

func (l *simpleLinear) Name() string                       { return l.name }
func (l *simpleLinear) OpType() string                     { return "Linear" }
func (l *simpleLinear) Attributes() map[string]any         { return nil }
func (l *simpleLinear) OutputShape() []int                 { return []int{-1, l.dOut} }
func (l *simpleLinear) InputFeatures() int                 { return l.dIn }
func (l *simpleLinear) OutputFeatures() int                { return l.dOut }

func (l *simpleLinear) Parameters() []*graph.Parameter[float32] {
	return []*graph.Parameter[float32]{l.weights}
}

func (l *simpleLinear) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error) {
	return l.engine.MatMul(ctx, inputs[0], l.weights.Value)
}

// Backward is a pass-through stub; this demo only exercises the forward path.
func (l *simpleLinear) Backward(_ context.Context, _ types.BackwardMode, outputGradient *tensor.TensorNumeric[float32], _ ...*tensor.TensorNumeric[float32]) ([]*tensor.TensorNumeric[float32], error) {
	return []*tensor.TensorNumeric[float32]{outputGradient}, nil
}

// simpleModel wraps a set of Linear layers for LoRA injection.
type simpleModel struct {
	layers map[string]lora.Layer[float32]
	order  []string
}

func (m *simpleModel) Layers() []lora.Layer[float32] {
	var out []lora.Layer[float32]
	for _, name := range m.order {
		out = append(out, m.layers[name])
	}
	return out
}

func (m *simpleModel) ReplaceLayer(name string, replacement lora.Layer[float32]) error {
	if _, ok := m.layers[name]; !ok {
		return fmt.Errorf("layer %q not found", name)
	}
	m.layers[name] = replacement
	return nil
}

func main() {
	ops := numeric.Float32Ops{}
	engine := compute.NewCPUEngine[float32](ops)

	// Build a small model with attention-like projections.
	qProj, _ := newSimpleLinear("q_proj", engine, 64, 64)
	kProj, _ := newSimpleLinear("k_proj", engine, 64, 64)
	vProj, _ := newSimpleLinear("v_proj", engine, 64, 64)
	oProj, _ := newSimpleLinear("o_proj", engine, 64, 64)

	model := &simpleModel{
		layers: map[string]lora.Layer[float32]{
			"q_proj": qProj, "k_proj": kProj, "v_proj": vProj, "o_proj": oProj,
		},
		order: []string{"q_proj", "k_proj", "v_proj", "o_proj"},
	}

	// Inject LoRA adapters into Q and V projections (rank=8, alpha=16).
	// Only the LoRA A and B matrices are trainable; base weights are frozen.
	err := lora.InjectLoRA[float32](
		model,
		8,    // rank
		16.0, // alpha
		[]string{"q_proj", "v_proj"},
		engine,
	)
	if err != nil {
		fmt.Fprintf(os.Stderr, "inject lora: %v\n", err)
		os.Exit(1)
	}
	fmt.Println("Injected LoRA into q_proj and v_proj (rank=8, alpha=16)")

	// Count total parameters across all layers.
	var totalParams int
	for _, layer := range model.Layers() {
		for _, p := range layer.Parameters() {
			n := 1
			for _, d := range p.Value.Shape() {
				n *= d
			}
			totalParams += n
		}
	}
	fmt.Printf("Total parameters: %d\n", totalParams)

	// Forward pass with synthetic data through LoRA-wrapped layers.
	ctx := context.Background()
	inputData := make([]float32, 4*64)
	for i := range inputData {
		inputData[i] = rand.Float32()
	}
	input, _ := tensor.New[float32]([]int{4, 64}, inputData)

	out := input
	for _, name := range model.order {
		layer := model.layers[name]
		out, err = layer.Forward(ctx, out)
		if err != nil {
			fmt.Fprintf(os.Stderr, "forward %s: %v\n", name, err)
			os.Exit(1)
		}
	}
	fmt.Printf("Output shape: %v\n", out.Shape())

	// Save the LoRA adapter checkpoint.
	checkpointPath := "lora-adapter.bin"
	if err := lora.SaveAdapter[float32](checkpointPath, model); err != nil {
		fmt.Fprintf(os.Stderr, "save: %v\n", err)
		os.Exit(1)
	}
	fmt.Printf("Saved LoRA adapter to %s\n", checkpointPath)

	// Clean up the checkpoint file created by this demo.
	os.Remove(checkpointPath)
	fmt.Println("Done.")
}
```
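
The demo removes the checkpoint at the end; in a real workflow you would keep it and restore it into a freshly built model before serving or resuming training. A minimal sketch, assuming `lora.LoadAdapter` mirrors `SaveAdapter`'s signature (the recipe only names the function, so the exact parameters are an assumption):

```go
// Sketch: restore a saved adapter into a model with LoRA already injected.
// Assumes lora.LoadAdapter takes (path, model) like SaveAdapter; check the
// package docs for the actual signature.
if err := lora.LoadAdapter[float32]("lora-adapter.bin", model); err != nil {
	fmt.Fprintf(os.Stderr, "load adapter: %v\n", err)
	os.Exit(1)
}
```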

## How It Works

1. **Build a model with Linear layers** -- The example creates four linear projection layers (Q, K, V, O) that mimic the attention projections in a transformer. Each layer implements the `lora.Layer[T]` interface, which requires `Forward`, `Parameters`, `InputFeatures`, `OutputFeatures`, and `Name` methods.

2. **Inject LoRA adapters** -- `lora.InjectLoRA` wraps the specified layers with LoRA adapters. Each adapter adds two small matrices (A and B) of the given rank. The original weight matrix is frozen, and only the LoRA matrices are trainable. The scaling factor `alpha/rank` controls the magnitude of the adapter's contribution.

3. **Forward pass** -- Input flows through each layer sequentially. For LoRA-wrapped layers, the output is `base_output + (alpha/rank) * (x @ A @ B)`, where A and B are the low-rank adapter matrices. (A plain-slice sketch of this arithmetic follows this list.)

4. **Save the checkpoint** -- `lora.SaveAdapter` serializes only the LoRA parameters (not the base model weights), producing a small checkpoint file that can be loaded later with `lora.LoadAdapter`.
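
To make the adapter arithmetic in step 3 concrete, here is a dependency-free sketch of a LoRA forward pass on plain `float32` slices. It is illustrative only, not the library's implementation; `loraForward` and `matmul` are hypothetical helper names:

```go
// loraForward computes y = x@W0 + (alpha/rank) * (x@A)@B for a single row x.
// x is 1 x dIn, w0 is dIn x dOut, a is dIn x r, b is r x dOut (all row-major).
func loraForward(x, w0, a, b []float32, dIn, dOut, r int, alpha float32) []float32 {
	base := matmul(x, w0, 1, dIn, dOut) // Frozen base projection, 1 x dOut.
	xa := matmul(x, a, 1, dIn, r)       // Down-project to rank r, 1 x r.
	delta := matmul(xa, b, 1, r, dOut)  // Up-project back, 1 x dOut.
	scale := alpha / float32(r)
	for i := range base {
		base[i] += scale * delta[i] // Add the scaled low-rank update.
	}
	return base
}

// matmul multiplies an m x k matrix by a k x n matrix (row-major slices).
func matmul(x, y []float32, m, k, n int) []float32 {
	out := make([]float32, m*n)
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			var sum float32
			for t := 0; t < k; t++ {
				sum += x[i*k+t] * y[t*n+j]
			}
			out[i*n+j] = sum
		}
	}
	return out
}
```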

## Key Concepts

- **Rank** controls the capacity of the adapter. Typical values are 4-64. Lower rank means fewer parameters and faster training, but a less expressive adapter. For the 64x64 layers above with rank 8, each adapter adds 64*8 + 8*64 = 1,024 trainable parameters on top of 4,096 frozen base weights (see the helper after this list).
- **Alpha** is a scaling hyperparameter. A common default is `alpha = 2 * rank`.
- **Target layers** -- LoRA is most effective when applied to the Q and V projections in attention layers, though you can target any linear layer.
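
The rank arithmetic in the first bullet generalizes to any layer shape. A one-line helper (hypothetical, not part of the library) makes the trade-off easy to explore:

```go
// loraParamCount returns the trainable parameters a rank-r adapter adds to a
// dIn x dOut linear layer: A is dIn x r and B is r x dOut.
func loraParamCount(dIn, dOut, r int) int { return dIn*r + r*dOut }

// loraParamCount(64, 64, 8) == 1024, versus 64*64 = 4096 frozen base weights.
```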

## Related API Reference

- [Inference API](/docs/api/inference/) -- `inference.LoadFile` and model loading options
- [Generate API](/docs/api/generate/) -- text generation with loaded models
