Skip to content

Commit 4ff7919

Browse files
[refactor] Rename setupFFNLayered() to setupFFNLayerTaskGraphs() and abstractify it for visibility and consistency across all FFN layers
1 parent 877550a commit 4ff7919

11 files changed

Lines changed: 35 additions & 48 deletions

src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ protected AbstractFFNLayers(String taskGraphName, State state, Weights weights,
4141
this.schedulerType = schedulerType;
4242
}
4343

44+
/**
45+
* Creates the TornadoVM {@link uk.ac.manchester.tornado.api.TaskGraph} for the FFN layers.
46+
* It creates one TaskGraph per layer and snapshots it to produce an {@link ImmutableTaskGraph} per layer.
47+
* It also stores the TaskGraph ID of the last FFN layer for use by the {@link AbstractLogitsLayer}.
48+
*/
49+
protected abstract List<ImmutableTaskGraph> setupFFNLayerTaskGraphs();
50+
4451
/**
4552
* Returns all task graphs for the FFN layers.
4653
*

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class GraniteFP16FFNLayers extends AbstractFFNLayers {
2727

2828
public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) {
2929
super(taskGraph, state, weights, config, schedulerType);
30-
this.ffnLayerTaskGraphs = setupFFNLayered();
30+
this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
3131
}
3232

3333
@Override
@@ -83,7 +83,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
8383
return ffnLayerTaskGraphs;
8484
}
8585

86-
List<ImmutableTaskGraph> setupFFNLayered() {
86+
@Override
87+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
8788
return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
8889
var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i);
8990
if (i == config.numberOfLayers() - 1) {

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers {
2626

2727
public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
2828
super(taskGraph, state, weights, config, schedulerType);
29-
this.ffnLayerTaskGraphs = setupFFNLayered();
29+
this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
3030
}
3131

3232
@Override
@@ -82,7 +82,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
8282
return ffnLayerTaskGraphs;
8383
}
8484

85-
List<ImmutableTaskGraph> setupFFNLayered() {
85+
@Override
86+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
8687
return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
8788
var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
8889
if (i == config.numberOfLayers() - 1) {

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeigh
4242
this.phi3State = state;
4343
this.phi3Config = config;
4444
this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
45-
ffnLayerTaskGraphs = setupFFNLayered();
45+
ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
4646
}
4747

4848
@Override
@@ -110,7 +110,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
110110
/**
111111
* Setup all FFN layers for all transformer layers
112112
*/
113-
List<ImmutableTaskGraph> setupFFNLayered() {
113+
@Override
114+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
114115
List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
115116
for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) {
116117
TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex);

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe
4242
super(taskGraphName, state, weights, config, schedulerType);
4343
this.qwen2State = state;
4444
this.qwen2Config = config;
45-
ffnLayerTaskGraphs = setupFFNLayered();
45+
ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
4646
}
4747

4848
@Override
@@ -135,7 +135,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
135135
return ffnLayerTaskGraphs;
136136
}
137137

138-
List<ImmutableTaskGraph> setupFFNLayered() {
138+
@Override
139+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
139140
List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>(qwen2Config.numberOfLayers());
140141
for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) {
141142
TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex);

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe
5555
this.nEmbdHead = nEmbdHeadV;
5656
this.nEmbdGqa = nEmbdVGqa;
5757
this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads();
58-
ffnLayerTaskGraphs = setupFFNLayered();
58+
ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
5959
}
6060

6161
@Override
@@ -124,7 +124,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
124124
/**
125125
* Setup all FFN layers for all transformer layers
126126
*/
127-
List<ImmutableTaskGraph> setupFFNLayered() {
127+
@Override
128+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
128129
return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> {
129130
var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i);
130131
if (i == qwen3Config.numberOfLayers() - 1) {

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,11 @@ public class GraniteQ8_0FFNLayers extends AbstractFFNLayers {
2828

2929
public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) {
3030
super(taskGraphName, state, weights, config, schedulerType);
31-
ffnLayerTaskGraphs = setupFFNLayered();
31+
ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
3232
}
3333

3434
@Override
35-
public GridScheduler getGridScheduler() {
36-
return scheduler;
37-
}
38-
39-
@Override
40-
public TaskGraph getTaskGraph() {
41-
return null;
42-
}
43-
44-
@Override
45-
public ImmutableTaskGraph getImmutableTaskGraph() {
46-
return null;
47-
}
48-
49-
List<ImmutableTaskGraph> setupFFNLayered() {
35+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
5036
return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
5137
var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i);
5238
if (i == config.numberOfLayers() - 1) {

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,11 @@ public class LlamaQ8_0FFNLayers extends AbstractFFNLayers {
2323

2424
public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, Configuration config, SchedulerType schedulerType) {
2525
super(taskGraphName, state, weights, config, schedulerType);
26-
ffnLayerTaskGraphs = setupFFNLayered();
26+
ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
2727
}
2828

2929
@Override
30-
public GridScheduler getGridScheduler() {
31-
return scheduler;
32-
}
33-
34-
@Override
35-
public TaskGraph getTaskGraph() {
36-
return null;
37-
}
38-
39-
@Override
40-
public ImmutableTaskGraph getImmutableTaskGraph() {
41-
return null;
42-
}
43-
44-
List<ImmutableTaskGraph> setupFFNLayered() {
30+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
4531
return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
4632
var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
4733
if (i == config.numberOfLayers() - 1) {

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeigh
4141
this.phi3State = state;
4242
this.phi3Config = config;
4343
this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
44-
ffnLayerTaskGraphs = setupFFNLayered();
44+
ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
4545
}
4646

4747
@Override
@@ -98,7 +98,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
9898
/**
9999
* Setup all FFN layers for all transformer layers
100100
*/
101-
List<ImmutableTaskGraph> setupFFNLayered() {
101+
@Override
102+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
102103
List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
103104
for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) {
104105
TaskGraph ffnLayer = setupSinglePhi3Q8_0FFNLayer((Phi3TornadoWeights) weights, layerIndex);

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public Qwen2Q8_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe
4646
super(taskGraphName, state, weights, config, schedulerType);
4747
this.qwen2State = state;
4848
this.qwen2Config = config;
49-
ffnLayerTaskGraphs = setupFFNLayered();
49+
ffnLayerTaskGraphs = setupFFNLayerTaskGraphs();
5050
}
5151

5252
@Override
@@ -137,7 +137,8 @@ public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
137137
/**
138138
* Setup all FFN layers for all transformer layers
139139
*/
140-
List<ImmutableTaskGraph> setupFFNLayered() {
140+
@Override
141+
protected List<ImmutableTaskGraph> setupFFNLayerTaskGraphs() {
141142
List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
142143
qwen2State.temp.init(0.0f);
143144
qwen2State.tempFFN.init(0.0f);

0 commit comments

Comments
 (0)