Replace Stub Training with Full Epochs, STE Backprop, Optimizer & Perplexity Validation
Core Repository – Domain-Agnostic
Version: 1.0
Date: March 20, 2026
Status: Ready-to-execute blueprint
Dependency note: WikiText-2 validation download and tokenization are being added in PR #27. This plan assumes that dependency merges first and then reuses those repository-local artifacts.
- Executive Summary & Success Criteria
- Prerequisites & Current State
- Overall Training Architecture
- Phase 1: WikiText-2 Data Loader & Tokenization (2–3 days)
- Phase 2: Real Train Method with Epochs, Batches & STE (5–7 days)
- Phase 3: AdamW Optimizer & Gradient Updates (3–4 days)
- Phase 4: Perplexity Evaluation on WikiText-2 (2–3 days)
- Phase 5: BenchmarkDotNet Integration & Reporting (3–4 days)
- Phase 6: Final Validation & CI Integration (2 days)
- Full UML Catalog
- Risk Register & Mitigation
- Timeline & Effort Estimates
Goal: Replace the current stub training with a real, measurable training loop that performs multiple epochs, computes loss, applies STE backprop, updates weights via AdamW, and reports perplexity on WikiText-2.
- Training runs multiple epochs and visibly reduces loss
- Perplexity on WikiText-2 validation is computed and reported (BitNet vs FP16 baseline)
- BenchmarkDotNet measures training time, tokens/sec, memory, and perplexity delta
- Report includes side-by-side TinyLlama-1.1B comparison
- Training no longer finishes in seconds — realistic duration on CPU/GPU
- Existing `BitNetModel` and `BitLinear` with the STE forward pass already implemented
- WikiText-2 raw validation set downloaded and tokenized by PR #27 (one-time dependency)
- BenchmarkDotNet already added to the test project (from prior benchmark patches)
```mermaid
flowchart TD
    A[WikiText-2 Validation Tokens] --> B["DataLoader (Batching)"]
    B --> C["BitNetModel.Train(epochs)"]
    C --> D[For each epoch]
    D --> E["Forward Pass (quantized)"]
    E --> F[Cross-Entropy Loss]
    F --> G[STE Backward]
    G --> H[AdamW Optimizer Step]
    H --> I[Periodic Re-quantization]
    I --> J[Perplexity Calculation]
    J --> K[Benchmark Report]
```
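The STE Backward step in the diagram is the piece that differs from a standard backward pass: the ternary quantizer has zero gradient almost everywhere, so the estimator backpropagates through it as if it were the identity, typically masking the gradient once the latent weight leaves the clipping range. A minimal sketch, assuming a scalar ternary quantizer with scale `γ` and a latent full-precision weight (the helper names are illustrative, not existing repository APIs):

```csharp
using System;

// Minimal straight-through-estimator (STE) sketch for a ternary quantizer.
// Forward uses the quantized value; backward pretends quantization was the
// identity so the latent full-precision weight keeps receiving gradient.
// QuantizeTernary/SteGrad are illustrative names, not the repository's API.
public static class SteSketch
{
    // Forward: map the latent weight to {-γ, 0, +γ}.
    public static float QuantizeTernary(float latentWeight, float gamma)
    {
        float threshold = 0.5f * gamma;
        if (latentWeight > threshold) return gamma;
        if (latentWeight < -threshold) return -gamma;
        return 0f;
    }

    // Backward: pass the upstream gradient straight through, but mask it
    // once the latent weight leaves [-1, 1] so it cannot drift unbounded.
    public static float SteGrad(float upstreamGrad, float latentWeight)
        => Math.Abs(latentWeight) <= 1f ? upstreamGrad : 0f;
}
```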
- Consume the repository-local WikiText-2 artifacts added by PR #27.
- Add a tokenizer helper to convert raw text to token IDs by reusing the existing tokenizer where needed.
- Create a `WikiTextDataLoader` class that yields batches of shape `(batchSize, seqLen)` (see the sketch after this list).
- Cache or reuse the tokenized validation set in the test project for fast loading.
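A minimal sketch of the loader, assuming the PR #27 artifacts can be materialized as a flat array of token IDs; `IDataLoader`, `Batch`, and the constructor shape below are assumptions made for this plan, not existing repository types:

```csharp
using System.Collections.Generic;

// Illustrative batching over a flat token-ID array. Each sample is a seqLen
// window of inputs with the same window shifted by one token as targets.
public sealed record Batch(int[,] Input, int[,] Target, int Size);

public interface IDataLoader
{
    IEnumerable<Batch> GetBatches();
}

public sealed class WikiTextDataLoader : IDataLoader
{
    private readonly int[] _tokens;
    private readonly int _batchSize;
    private readonly int _seqLen;

    public WikiTextDataLoader(int[] tokens, int batchSize, int seqLen)
    {
        _tokens = tokens;
        _batchSize = batchSize;
        _seqLen = seqLen;
    }

    public IEnumerable<Batch> GetBatches()
    {
        var inputs = new int[_batchSize, _seqLen];
        var targets = new int[_batchSize, _seqLen];
        int filled = 0;

        for (int start = 0; start + _seqLen + 1 <= _tokens.Length; start += _seqLen)
        {
            for (int t = 0; t < _seqLen; t++)
            {
                inputs[filled, t] = _tokens[start + t];
                targets[filled, t] = _tokens[start + t + 1];
            }

            if (++filled == _batchSize)
            {
                // Size is the token count so loss weighting and perplexity use per-token averages.
                yield return new Batch(inputs, targets, _batchSize * _seqLen);
                inputs = new int[_batchSize, _seqLen];
                targets = new int[_batchSize, _seqLen];
                filled = 0;
            }
        }
    }
}
```

The trailing partial batch is dropped for simplicity; padding or emitting a smaller final batch is an easy extension.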
Update `BitNetModel` with a training API shaped like this:

```csharp
public TrainingReport Train(int epochs, IDataLoader dataLoader)
{
    var optimizer = new AdamWOptimizer(lr: 3e-4f, weightDecay: 0.1f);
    var report = new TrainingReport();

    for (int epoch = 0; epoch < epochs; epoch++)
    {
        double totalLoss = 0;
        int tokenCount = 0;

        foreach (var batch in dataLoader.GetBatches())
        {
            var logits = Forward(batch.Input);            // quantized forward
            var loss = CrossEntropyLoss(logits, batch.Target);
            totalLoss += loss.Value * batch.Size;
            tokenCount += batch.Size;

            loss.BackwardWithSTE();                       // straight-through estimator
            optimizer.Step(Parameters);
            optimizer.ZeroGrad();
        }

        report.AddEpoch(epoch, totalLoss / tokenCount);
        ReQuantizeAllLayers();                            // periodic re-quantization
    }

    return report;
}
```

Implement a simple `AdamWOptimizer` class, or reuse an existing one if present, with the following (a sketch follows the list):
- Momentum
- Variance
- Weight decay
- Support for ternary weight scaling (`γ`)
- In-place updates compatible with `BitLinear`
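A minimal sketch of the update rule, assuming each layer exposes its latent full-precision weights and gradients as flat `float[]` buffers; the `ParameterGroup` shape and method names are assumptions and the real `BitLinear` layout may differ. Updates are applied in place to the latent weights, and the ternary values and `γ` are refreshed later by `ReQuantizeAllLayers()`:

```csharp
using System;
using System.Collections.Generic;

// Illustrative AdamW step over flat parameter/gradient buffers.
// ParameterGroup is an assumed shape, not an existing repository type.
public sealed class ParameterGroup
{
    public float[] Values = Array.Empty<float>();     // latent full-precision weights
    public float[] Gradients = Array.Empty<float>();  // filled by the STE backward pass
}

public sealed class AdamWOptimizer
{
    private readonly float _lr, _weightDecay, _beta1, _beta2, _eps;
    private readonly Dictionary<ParameterGroup, (float[] m, float[] v)> _state = new();
    private int _step;

    public AdamWOptimizer(float lr = 3e-4f, float weightDecay = 0.1f,
                          float beta1 = 0.9f, float beta2 = 0.999f, float eps = 1e-8f)
    {
        _lr = lr; _weightDecay = weightDecay; _beta1 = beta1; _beta2 = beta2; _eps = eps;
    }

    public void Step(IReadOnlyList<ParameterGroup> parameters)
    {
        _step++;
        foreach (var p in parameters)
        {
            if (!_state.TryGetValue(p, out var s))
                _state[p] = s = (new float[p.Values.Length], new float[p.Values.Length]);

            for (int i = 0; i < p.Values.Length; i++)
            {
                float g = p.Gradients[i];
                s.m[i] = _beta1 * s.m[i] + (1 - _beta1) * g;       // first moment (momentum)
                s.v[i] = _beta2 * s.v[i] + (1 - _beta2) * g * g;   // second moment (variance)

                float mHat = s.m[i] / (1 - MathF.Pow(_beta1, _step));
                float vHat = s.v[i] / (1 - MathF.Pow(_beta2, _step));

                // Decoupled weight decay (the "W" in AdamW), applied in place to latent weights.
                p.Values[i] -= _lr * (mHat / (MathF.Sqrt(vHat) + _eps) + _weightDecay * p.Values[i]);
            }
        }
    }

    public void ZeroGrad(IReadOnlyList<ParameterGroup> parameters)
    {
        foreach (var p in parameters)
            Array.Clear(p.Gradients, 0, p.Gradients.Length);
    }
}
```

The parameterless `ZeroGrad()` used in the Phase 2 loop would simply mean the optimizer keeps a reference to the parameter list; here it is passed explicitly to keep the sketch self-contained.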
Add a validation method to `BitNetModel`:

```csharp
public double CalculatePerplexity(IDataLoader validationLoader)
{
    double totalNLL = 0;
    int tokenCount = 0;

    foreach (var batch in validationLoader.GetBatches())
    {
        var logits = Forward(batch.Input);
        var loss = CrossEntropyLoss(logits, batch.Target);
        totalNLL += loss.Value * batch.Size;
        tokenCount += batch.Size;
    }

    return Math.Exp(totalNLL / tokenCount);
}
```

Update `TinyLlamaBenchmark.cs`, or create it if it is missing, with:

```csharp
[Benchmark]
public double PerplexityBitNet() => _bitnetModel.CalculatePerplexity(wikiLoader);

[Benchmark]
public void TrainingEpoch() => _bitnetModel.Train(1, trainingLoader);
```

Enhance the report generator to include the following (a sketch of the report shape follows the list):
- Training time per epoch
- Perplexity before and after training
- BitNet vs FP16 baseline comparison
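A minimal sketch of the report shape those items imply, assuming simple record/POCO types rather than whatever the existing report generator provides (all names below are illustrative):

```csharp
using System;
using System.Collections.Generic;
using System.Text;

// Illustrative report model tying the Phase 2 training loop to the Phase 5 output.
public sealed record EpochEntry(int Epoch, double AverageLoss, TimeSpan Duration);

public sealed class TrainingReport
{
    private readonly List<EpochEntry> _epochs = new();

    public IReadOnlyList<EpochEntry> Epochs => _epochs;
    public double PerplexityBefore { get; set; }
    public double PerplexityAfter { get; set; }
    public double Fp16BaselinePerplexity { get; set; }   // TinyLlama-1.1B FP16 reference

    // Matches the report.AddEpoch(epoch, avgLoss) call in the Phase 2 sketch.
    public void AddEpoch(int epoch, double averageLoss, TimeSpan duration = default)
        => _epochs.Add(new EpochEntry(epoch, averageLoss, duration));

    // One row per epoch plus a before/after perplexity summary against the FP16 baseline.
    public string ToMarkdownTable()
    {
        var sb = new StringBuilder("| Epoch | Avg loss | Duration |\n|---|---|---|\n");
        foreach (var e in _epochs)
            sb.AppendLine($"| {e.Epoch} | {e.AverageLoss:F4} | {e.Duration} |");
        sb.AppendLine();
        sb.AppendLine($"Perplexity: {PerplexityBefore:F2} -> {PerplexityAfter:F2} " +
                      $"(FP16 baseline: {Fp16BaselinePerplexity:F2})");
        return sb.ToString();
    }
}
```

The same model can be serialized with System.Text.Json for the JSON report and templated into HTML for the nightly run.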
- Add an integration test that runs 3 epochs and verifies loss decreases (sketched after this list)
- Update CI to run the full benchmark suite on a nightly schedule
- Generate HTML and JSON reports with tables and charts
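A minimal sketch of the loss-decrease integration test, assuming xUnit and the hypothetical `WikiTextDataLoader` and `TrainingReport` shapes sketched earlier; the model constructor and token-loading helper are placeholders:

```csharp
using System;
using Xunit;

// Illustrative integration test: three epochs on a small WikiText-2 slice
// should reduce the average loss. Loader and model construction are assumptions.
public class TrainingIntegrationTests
{
    [Fact]
    public void ThreeEpochs_ReduceAverageLoss()
    {
        int[] tokens = LoadTokenizedValidationSlice();                 // hypothetical helper
        var loader = new WikiTextDataLoader(tokens, batchSize: 8, seqLen: 128);
        var model = new BitNetModel(/* small config to keep CI fast */);

        var report = model.Train(3, loader);

        double first = report.Epochs[0].AverageLoss;
        double last = report.Epochs[^1].AverageLoss;
        Assert.True(last < first, $"Loss did not decrease: {first:F4} -> {last:F4}");
    }

    private static int[] LoadTokenizedValidationSlice()
        => throw new NotImplementedException("Wire this to the PR #27 WikiText-2 artifacts.");
}
```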
```mermaid
flowchart TD
    A[WikiText-2 Loader] --> B[Epoch Loop]
    B --> C["Batch Forward (BitLinear)"]
    C --> D[Cross-Entropy Loss]
    D --> E[STE Backward]
    E --> F[AdamW Step]
    F --> G[Re-quantize]
    G --> H[Perplexity Calc]
```
| Risk | Likelihood | Impact | Mitigation |
|---|---|---|---|
| Training still too fast | High | High | Enforce a minimum of 3 epochs and a real WikiText loader |
| STE gradient issues | Medium | High | Add a unit test that verifies gradient flow on a small batch |
| Memory explosion | Low | Medium | Use a small batch size (8–32) plus gradient clipping |
| Phase | Estimate |
|---|---|
| Phase 1: WikiText-2 Data Loader & Tokenization | 2–3 days |
| Phase 2: Real Train Method with Epochs, Batches & STE | 5–7 days |
| Phase 3: AdamW Optimizer & Gradient Updates | 3–4 days |
| Phase 4: Perplexity Evaluation on WikiText-2 | 2–3 days |
| Phase 5: BenchmarkDotNet Integration & Reporting | 3–4 days |
| Phase 6: Final Validation & CI Integration | 2 days |
| Total | 17–23 days |
This plan is intentionally scoped to the core repository and remains domain-agnostic. It focuses on replacing stubbed training behavior with a measurable, benchmarked, paper-aligned training path.