Commit 080fea4

[refactor] Simplify and unify layer planners by centralizing inference plan creation logic in layerplanner package
1 parent dc76fde commit 080fea4

15 files changed

Lines changed: 113 additions & 246 deletions
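In short: the base class previously forced a template-method lifecycle (each subclass called validateQuantizationType() and setupTornadoForwardPlan() in its constructor and overrode initializeLayerComponents()), while the Mistral planners used a separate buildForwardPlan(). After this commit, every planner wires its activation, FFN, and logits layers directly in its constructor and calls the final createTornadoInferencePlan() inherited from QuantizedLayerPlanner, which assembles the task graphs and grid scheduler once and caches them behind getImmutableTaskGraphs() and getGridScheduler().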

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java

Lines changed: 62 additions & 23 deletions
@@ -4,21 +4,28 @@
 import org.beehive.gpullama3.inference.weights.Weights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
 import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.KernelContext;

+import java.util.ArrayList;
+import java.util.List;
+
 /**
  * Abstract base for all quantization-specific planners.
  *
- * Contains shared logic that works regardless of model type but depends on quantization. Subclasses: FP16LayerPlanner, Q8_0LayerPlanner, etc.
+ * Extracts common state from the model, detects the hardware scheduler type,
+ * and assembles the full execution plan via createTornadoInferencePlan().
+ * Subclasses (FP16LayerPlanner, Q8_0LayerPlanner) only provide quantization validation.
  */
-public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights> implements GenericLayerPlanner {
-
-    // Common state for all quantizations
-    protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;
-    protected static final int THREAD_SCALE_FOR_LOGITS = 8;
+public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights>
+        implements GenericLayerPlanner {

     protected final S state;
     protected final C config;
@@ -27,9 +34,14 @@ public abstract class QuantizedLayerPlanner<S extends State, C extends Configura
     protected final Model model;
     protected final SchedulerType schedulerType;

-    /**
-     * Constructor: validate quantization type, extract model components
-     */
+    protected Activation activationLayer;
+    protected AbstractFFNLayers<W, C> ffnLayers;
+    protected AbstractLogitsLayer logitsLayer;
+
+    private List<ImmutableTaskGraph> immutableTaskGraphs;
+    private GridScheduler gridScheduler;
+
+    @SuppressWarnings("unchecked")
     protected QuantizedLayerPlanner(S state, Model model) {
         this.state = state;
         this.model = model;
@@ -40,26 +52,53 @@ protected QuantizedLayerPlanner(S state, Model model) {
         validateQuantizationType();
     }

-    /**
-     * Override in subclasses to validate correct quantization format. E.g., FP16LayerPlanner checks: weights instanceof FP16Weights
-     */
+    /** Validates that the model weights match the expected quantization type. */
     protected abstract void validateQuantizationType();

     /**
-     * Override in subclasses for model-specific initialization
+     * Creates the TornadoVM inference execution pipeline.
+     * It represents the entire Feed-Forward Network (FFN) and consists of:
+     * <ul>
+     * <li>Activation layer</li>
+     * <li>FFN layers (N transformer layers, model-specific)</li>
+     * <li>Logits layer</li>
+     * </ul>
+     * <p>
+     * Each component is represented as an {@link ImmutableTaskGraph}, along with a
+     * corresponding {@link GridScheduler} configuration that defines how tasks are
+     * mapped on the GPU.
+     * </p>
+     * This method assembles all components into a unified execution pipeline and
+     * caches the resulting task graphs and scheduler for reuse across inference runs.
      */
-    protected abstract void initializeLayerComponents();
+    protected final void createTornadoInferencePlan() {
+        List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
+        GridScheduler masterScheduler = new GridScheduler();
+
+        // 1. Activation layer (common to all models)
+        allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
+        activationLayer.updateGridScheduler(masterScheduler);
+
+        // 2. FFN layers (N transformer layers - model-specific)
+        allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs());
+        ffnLayers.updateGridScheduler(masterScheduler);
+
+        // 3. Logits layer (common to all models)
+        allTaskGraphs.add(logitsLayer.getImmutableTaskGraph());
+        logitsLayer.updateGridScheduler(masterScheduler);

-    // Common helper methods for all quantizations
-    protected C getConfig() {
-        return config;
+        // Cache for future retrievals
+        this.immutableTaskGraphs = allTaskGraphs;
+        this.gridScheduler = masterScheduler;
     }

-    protected W getWeights() {
-        return weights;
+    @Override
+    public final List<ImmutableTaskGraph> getImmutableTaskGraphs() {
+        return this.immutableTaskGraphs;
     }

-    protected S getState() {
-        return state;
+    @Override
+    public final GridScheduler getGridScheduler() {
+        return this.gridScheduler;
     }
-}
+}
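Since getImmutableTaskGraphs() and getGridScheduler() carry @Override, the GenericLayerPlanner interface evidently exposes the cached plan to callers. Below is a minimal consumption sketch assuming TornadoVM's standard TornadoExecutionPlan API; the driver class is hypothetical, and how the repository actually executes the plan is not shown in this commit.

import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;

import java.util.List;

// Hypothetical driver; any concrete planner from this commit would work here.
class PlannerDriverSketch {

    static void run(GenericLayerPlanner planner) {
        // Both getters return the plan cached by createTornadoInferencePlan().
        List<ImmutableTaskGraph> graphs = planner.getImmutableTaskGraphs();
        GridScheduler scheduler = planner.getGridScheduler();

        // TornadoExecutionPlan takes the immutable task graphs as varargs;
        // withGridScheduler(...) attaches the worker grids assembled above,
        // and execute() launches the activation -> FFN -> logits pipeline.
        TornadoExecutionPlan executionPlan =
                new TornadoExecutionPlan(graphs.toArray(new ImmutableTaskGraph[0]));
        executionPlan.withGridScheduler(scheduler).execute();
    }
}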

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java

Lines changed: 5 additions & 10 deletions
@@ -10,17 +10,12 @@
 import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer;

 public class GraniteFP16LayerPlanner extends FP16LayerPlanner<GraniteState, GraniteConfiguration, GraniteTornadoWeights> {
+
     public GraniteFP16LayerPlanner(GraniteState state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
-    }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType);
+        this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config);
+        this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsGraniteFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
-
 }

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java

Lines changed: 5 additions & 11 deletions
@@ -13,15 +13,9 @@ public class LlamaFP16LayerPlanner extends FP16LayerPlanner<LlamaState, LlamaCon

     public LlamaFP16LayerPlanner(LlamaState state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType);
-    }
-
-}
+}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java

Lines changed: 1 addition & 1 deletion
@@ -16,6 +16,6 @@ public MistralFP16LayerPlanner(LlamaState state, Model model) {
         this.activationLayer = new Activation("activationUpdate", state, weights, config);
         this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType);
         this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
-        buildForwardPlan();
+        createTornadoInferencePlan();
     }
 }
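Note the much smaller Mistral diffs (here and in the Q8_0 variant below): those planners already wired their layers in the constructor, so their only change is swapping buildForwardPlan() for the shared createTornadoInferencePlan().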

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java

Lines changed: 4 additions & 10 deletions
@@ -20,15 +20,9 @@ public class Phi3FP16LayerPlanner extends FP16LayerPlanner<Phi3State, Phi3Config

     public Phi3FP16LayerPlanner(Phi3State state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(),this.schedulerType);
-    }
-
 }

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java

Lines changed: 4 additions & 9 deletions
@@ -20,14 +20,9 @@ public class Qwen2FP16LayerPlanner extends FP16LayerPlanner<Qwen2State, Qwen2Con

     public Qwen2FP16LayerPlanner(Qwen2State state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
-    }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType);
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
 }

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java

Lines changed: 5 additions & 11 deletions
@@ -20,15 +20,9 @@ public class Qwen3FP16LayerPlanner extends FP16LayerPlanner<Qwen3State, Qwen3Con

     public Qwen3FP16LayerPlanner(Qwen3State state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType);
-    }
-
-}
+}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java

Lines changed: 4 additions & 9 deletions
@@ -13,14 +13,9 @@ public class GraniteQ8_0LayerPlanner extends Q8_0LayerPlanner<GraniteState, Gran

     public GraniteQ8_0LayerPlanner(GraniteState state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
-    }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsGraniteQ8_0Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType);
+        this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config);
+        this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsGraniteQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
 }

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java

Lines changed: 5 additions & 11 deletions
@@ -13,15 +13,9 @@ public class LlamaQ8_0LayerPlanner extends Q8_0LayerPlanner<LlamaState, LlamaCon

     public LlamaQ8_0LayerPlanner(LlamaState state, Model model) {
         super(state, model);
-        validateQuantizationType();
-        setupTornadoForwardPlan();
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
     }
-
-    @Override
-    protected void initializeLayerComponents() {
-        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
-        this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType);
-    }
-
-}
+}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java

Lines changed: 1 addition & 1 deletion
@@ -16,6 +16,6 @@ public MistralQ8_0LayerPlanner(LlamaState state, Model model) {
         this.activationLayer = new Activation("activationUpdate", state, weights, config);
         this.ffnLayers = new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType);
         this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
-        buildForwardPlan();
+        createTornadoInferencePlan();
     }
 }
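With plan assembly centralized, adding a planner for another model/quantization pair reduces to a single constructor. A sketch of the pattern, extrapolated from the planners above (imports omitted; the Qwen3 Q8_0 class names are assumptions, not code from this commit, though some of the five unshown changed files may well look like this):

// Illustrative only: Qwen3Q8_0FFNLayers and the generic parameters are
// assumptions extrapolated from the planners shown above.
public class Qwen3Q8_0LayerPlanner extends Q8_0LayerPlanner<Qwen3State, Qwen3Configuration, Qwen3TornadoWeights> {

    public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) {
        super(state, model); // extracts weights/config, validates Q8_0, detects scheduler
        this.activationLayer = new Activation("activationUpdate", state, weights, config);
        this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType);
        this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config,
                ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
        createTornadoInferencePlan(); // assembles and caches the task graphs + scheduler
    }
}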
