Skip to content

Commit 5baba0e

Browse files
[refactor] Generalize AbstractFFNLayers and unify task graph setup logic across all subclasses
1 parent 16659c7 commit 5baba0e

18 files changed

Lines changed: 633 additions & 414 deletions

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
1212
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner;
1313
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
14+
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner;
1415
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner;
1516
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
1617
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner;
1718
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner;
1819
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner;
20+
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner;
1921
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner;
2022
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
2123
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
@@ -54,7 +56,8 @@ public static GenericLayerPlanner create(GGMLType quantization, State state, Mod
5456
// ============ FP16 Planners ============
5557
private static GenericLayerPlanner createFP16Planner(State state, Model model) {
5658
return switch (model.getModelType()) {
57-
case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model);
59+
case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model);
60+
case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model);
5861
case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
5962
case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
6063
case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
@@ -67,7 +70,8 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) {
6770
// ============ Q8_0 Planners ============
6871
private static GenericLayerPlanner createQ8_0Planner(State state, Model model) {
6972
return switch (model.getModelType()) {
70-
case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
73+
case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
74+
case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model);
7175
case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
7276
case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
7377
case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
2+
3+
import org.beehive.gpullama3.inference.state.LlamaState;
4+
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
5+
import org.beehive.gpullama3.model.Model;
6+
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
7+
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
8+
import org.beehive.gpullama3.tornadovm.layers.Activation;
9+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers;
10+
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
11+
12+
public class MistralFP16LayerPlanner extends FP16LayerPlanner<LlamaState, MistralConfiguration, LlamaTornadoWeights> {
13+
14+
public MistralFP16LayerPlanner(LlamaState state, Model model) {
15+
super(state, model);
16+
this.activationLayer = new Activation("activationUpdate", state, weights, config);
17+
this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType);
18+
this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
19+
buildForwardPlan();
20+
}
21+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
2+
3+
import org.beehive.gpullama3.inference.state.LlamaState;
4+
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
5+
import org.beehive.gpullama3.model.Model;
6+
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
7+
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
8+
import org.beehive.gpullama3.tornadovm.layers.Activation;
9+
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers;
10+
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
11+
12+
public class MistralQ8_0LayerPlanner extends Q8_0LayerPlanner<LlamaState, MistralConfiguration, LlamaTornadoWeights> {
13+
14+
public MistralQ8_0LayerPlanner(LlamaState state, Model model) {
15+
super(state, model);
16+
this.activationLayer = new Activation("activationUpdate", state, weights, config);
17+
this.ffnLayers = new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType);
18+
this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
19+
buildForwardPlan();
20+
}
21+
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
public abstract class FP16LayerPlanner<S extends State, C extends Configuration, W extends TornadoWeights> extends QuantizedLayerPlanner<S, C, W> {
2626

2727
protected Activation activationLayer;
28-
protected AbstractFFNLayers ffnLayers;
28+
protected AbstractFFNLayers<?,?> ffnLayers;
2929
protected LogitsFP16Layer logitsLayer;
3030

3131
protected List<ImmutableTaskGraph> immutableTaskGraphs;
@@ -56,7 +56,7 @@ protected final void setupTornadoForwardPlan() {
5656
activationLayer.updateGridScheduler(masterScheduler);
5757

5858
// 2. FFN layers (N transformer layers - model-specific)
59-
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
59+
allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs());
6060
ffnLayers.updateGridScheduler(masterScheduler);
6161

6262
// 3. Logits layer (common to all models)

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
public abstract class Q8_0LayerPlanner<S extends State, C extends Configuration, W extends TornadoWeights> extends QuantizedLayerPlanner<S, C, W> {
2727

2828
protected Activation activationLayer;
29-
protected AbstractFFNLayers ffnLayers;
29+
protected AbstractFFNLayers<?,?> ffnLayers;
3030
protected LogitsQ8_0Layer logitsLayer;
3131

3232
// Cache for task graphs and scheduler (set once, reused)
@@ -59,7 +59,7 @@ protected final void setupTornadoForwardPlan() {
5959
activationLayer.updateGridScheduler(masterScheduler);
6060

6161
// 2. FFN layers (N transformer layers - model-specific)
62-
allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
62+
allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs());
6363
ffnLayers.updateGridScheduler(masterScheduler);
6464

6565
// 3. Logits layer (common to all models)

src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,55 +5,70 @@
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
77
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
8+
import uk.ac.manchester.tornado.api.TaskGraph;
89

910
import java.util.List;
11+
import java.util.stream.IntStream;
1012

1113
/**
1214
* Abstract base class for all FFN (Feed-Forward Network) layer implementations.
13-
*
14-
* Each subclass builds N ImmutableTaskGraphs (one per FFN layer) via
15-
* {@link #setupFFNLayerTaskGraphs}, covering RMSNorm, Attention, and FFN computations.
16-
*
17-
* Model-specific subclasses: Llama, Qwen2, Qwen3, Phi3, Granite — each in FP16 and Q8_0 variants.
15+
* Extended by model and quantization-specific subclasses that provide specific implementations.
1816
*/
19-
public abstract class AbstractFFNLayers extends AbstractLayer {
17+
public abstract class AbstractFFNLayers<W extends Weights, C extends Configuration> extends AbstractLayer {
18+
19+
/**
20+
* List of TornadoVM {@link ImmutableTaskGraph}s, one per FFN layer.
21+
* Build by {@link #setupFFNLayers()}.
22+
*/
23+
private List<ImmutableTaskGraph> ffnLayerITGs;
24+
protected final W weights;
25+
protected final C config;
2026

2127
protected String lastFFNLayerTaskGraphID;
2228
protected final SchedulerType schedulerType;
2329

30+
protected AbstractFFNLayers(String taskGraphName, State state, W weights, C config, SchedulerType schedulerType) {
31+
super(taskGraphName, state, weights, config);
32+
this.weights = weights;
33+
this.config = config;
34+
this.schedulerType = schedulerType;
35+
// the ffnLayerITGs list is initialized in subclasses
36+
// due to model-specific values (e.g. in Qwen3)
37+
}
2438

2539
/**
26-
* Constructor for FFN layers.
27-
*
28-
* @param taskGraphName
29-
* Name for the task graph
30-
* @param state
31-
* Runtime state (LlamaState, Qwen2State, etc.)
32-
* @param weights
33-
* Model weights (FP16Weights, Q8_0Weights, etc.)
34-
* @param config
35-
* Model configuration
40+
 * Creates the list of {@link ImmutableTaskGraph}s, one per FFN layer.
3641
*/
37-
protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
38-
super(taskGraphName, state, weights, config);
39-
this.schedulerType = schedulerType;
42+
protected void setupFFNLayers() {
43+
int numLayers = config.numberOfLayers();
44+
45+
this.ffnLayerITGs = IntStream.range(0, numLayers)
46+
.mapToObj(this::setupFFNLayer)
47+
.toList();
4048
}
4149

4250
/**
43-
* Creates the TornadoVM {@link uk.ac.manchester.tornado.api.TaskGraph} for the FFN layers.
44-
* It creates one TaskGraph per layer and snapshots it to produce an {@link ImmutableTaskGraph} per layer.
45-
* It also stores the TaskGraph ID of the last FFN layer for use by the {@link AbstractLogitsLayer}.
51+
* Creates the TaskGraph for a specific FFN layer and produces the {@link ImmutableTaskGraph}.
52+
* In addition, it stores the TaskGraph ID of the last FFN layer for use by the {@link AbstractLogitsLayer}.
4653
*/
47-
protected abstract List<ImmutableTaskGraph> setupFFNLayerTaskGraphs();
54+
private ImmutableTaskGraph setupFFNLayer(int layerIndex) {
55+
TaskGraph tg = createFFNLayerTaskGraph(layerIndex);
56+
57+
if (layerIndex == config.numberOfLayers() - 1) {
58+
lastFFNLayerTaskGraphID = tg.getTaskGraphName();
59+
}
60+
61+
return tg.snapshot();
62+
}
4863

4964
/**
50-
* Returns all task graphs for the FFN layers.
51-
*
52-
* For a model with N transformer layers, this returns N ImmutableTaskGraphs, one for each layer (containing RMSNorm, Attention, FFN computations).
53-
*
54-
* @return List of immutable task graphs (one per transformer layer)
65+
 * Model- and quantization-specific implementation of the FFN layer task graph.
5566
*/
56-
public abstract List<ImmutableTaskGraph> getFFNLayerTaskGraphs();
67+
protected abstract TaskGraph createFFNLayerTaskGraph(int layerIndex);
68+
69+
public List<ImmutableTaskGraph> getFFNLayerImmutableTaskGraphs() {
70+
return ffnLayerITGs;
71+
}
5772

5873
/**
5974
* Returns the TaskGraph ID of the last FFN layer.

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,15 @@
1111
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
1212
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
1313
import uk.ac.manchester.tornado.api.GridScheduler;
14-
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
1514
import uk.ac.manchester.tornado.api.TaskGraph;
1615
import uk.ac.manchester.tornado.api.WorkerGrid;
1716
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
1817

19-
import java.util.List;
20-
import java.util.stream.IntStream;
18+
public class GraniteFP16FFNLayers extends AbstractFFNLayers<GraniteTornadoWeights, GraniteConfiguration> {
2119

22-
public class GraniteFP16FFNLayers extends AbstractFFNLayers {
23-
24-
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
25-
26-
public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) {
20+
public GraniteFP16FFNLayers(String taskGraph, State state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) {
2721
super(taskGraph, state, weights, config, schedulerType);
28-
this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
22+
setupFFNLayers();
2923
}
3024

3125
@Override
@@ -62,21 +56,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
6256
return tornadoForwardScheduler;
6357
}
6458

65-
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
66-
return ffnLayerTaskGraphs;
67-
}
68-
69-
@Override
70-
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
71-
return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
72-
var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i);
73-
if (i == config.numberOfLayers() - 1) {
74-
this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName();
75-
}
76-
return ffnLayer.snapshot();
77-
}).toList();
78-
}
79-
8059
// @formatter:off
8160
/**
8261
* Transformer Layer Task Flow (LlamaFP16FFNLayers)
@@ -163,7 +142,8 @@ protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
163142
* • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel)
164143
*
165144
*/
166-
TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) {
145+
@Override
146+
protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
167147
var layerTaskGraphName = "layer_" + layerIndex;
168148
TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
169149

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,23 @@
11
package org.beehive.gpullama3.tornadovm.layers.type.fp16;
22

33
import org.beehive.gpullama3.inference.state.State;
4-
import org.beehive.gpullama3.inference.weights.Weights;
54
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
6-
import org.beehive.gpullama3.model.Configuration;
5+
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
76
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
87
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
98
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
109
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
1110
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
1211
import uk.ac.manchester.tornado.api.GridScheduler;
13-
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
1412
import uk.ac.manchester.tornado.api.TaskGraph;
1513
import uk.ac.manchester.tornado.api.WorkerGrid;
1614
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
1715

18-
import java.util.List;
19-
import java.util.stream.IntStream;
16+
public class LlamaFP16FFNLayers extends AbstractFFNLayers<LlamaTornadoWeights, LlamaConfiguration> {
2017

21-
public class LlamaFP16FFNLayers extends AbstractFFNLayers {
22-
23-
private List<ImmutableTaskGraph> ffnLayerTaskGraphs;
24-
25-
public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
18+
public LlamaFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) {
2619
super(taskGraph, state, weights, config, schedulerType);
27-
this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
20+
setupFFNLayers();
2821
}
2922

3023
@Override
@@ -61,22 +54,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
6154
return tornadoForwardScheduler;
6255
}
6356

64-
@Override
65-
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
66-
return ffnLayerTaskGraphs;
67-
}
68-
69-
@Override
70-
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
71-
return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
72-
var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
73-
if (i == config.numberOfLayers() - 1) {
74-
this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName();
75-
}
76-
return ffnLayer.snapshot();
77-
}).toList();
78-
}
79-
8057
// @formatter:off
8158
/**
8259
* Transformer Layer Task Flow (LlamaFP16FFNLayers)
@@ -163,7 +140,8 @@ protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
163140
* • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel)
164141
*
165142
*/
166-
TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
143+
@Override
144+
protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
167145
var layerTaskGraphName = "layer_" + layerIndex;
168146
TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
169147

0 commit comments

Comments
 (0)