Skip to content

Commit 16659c7

Browse files
[refactor] Remove unused fields and methods across FFN layer implementations
1 parent 4ff7919 commit 16659c7

14 files changed

Lines changed: 18 additions & 211 deletions

src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
/**
1212
* Abstract base class for all FFN (Feed-Forward Network) layer implementations.
1313
*
14-
* Extends AbstractLayer and adds FFN-specific methods: - getFfnLayerTaskGraphs(): Returns task graphs for all transformer layers - getLastTaskGraphID(): Tracks the ID of the last task graph
14+
* Each subclass builds N ImmutableTaskGraphs (one per FFN layer) via
15+
* {@link #setupFFNLayerTaskGraphs}, covering RMSNorm, Attention, and FFN computations.
1516
*
16-
* All model-specific FFN layers extend this: - LlamaFP16FFNLayers, Qwen2FP16FFNLayers, Qwen3FP16FFNLayers, Phi3FP16FFNLayers - LlamaQ8_0FFNLayers, Qwen2Q8_0FFNLayers, Qwen3Q8_0FFNLayers,
17-
* Phi3Q8_0FFNLayers
18-
*
19-
* Used by FP16LayerPlanner and Q8_0LayerPlanner template methods for type-safe polymorphic access to any FFN layer implementation.
17+
* Model-specific subclasses: Llama, Qwen2, Qwen3, Phi3, Granite — each in FP16 and Q8_0 variants.
2018
*/
2119
public abstract class AbstractFFNLayers extends AbstractLayer {
2220

@@ -55,7 +53,7 @@ protected AbstractFFNLayers(String taskGraphName, State state, Weights weights,
5553
*
5654
* @return List of immutable task graphs (one per transformer layer)
5755
*/
58-
public abstract List<ImmutableTaskGraph> getFfnLayerTaskGraphs();
56+
public abstract List<ImmutableTaskGraph> getFFNLayerTaskGraphs();
5957

6058
/**
6159
* Returns the TaskGraph ID of the last FFN layer.

src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,11 @@
44
import org.beehive.gpullama3.inference.weights.Weights;
55
import org.beehive.gpullama3.model.Configuration;
66
import uk.ac.manchester.tornado.api.GridScheduler;
7-
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
87
import uk.ac.manchester.tornado.api.KernelContext;
98
import uk.ac.manchester.tornado.api.TaskGraph;
109

11-
import java.util.ArrayList;
12-
import java.util.List;
13-
1410
/**
15-
* Minimal base with common fields/utilities so subclasses compile cleanly. Adjust or remove fields if they already exist in your project.
11+
* Abstract base class for Activations, FFN Layers, and Logits.
1612
*/
1713
public abstract class AbstractLayer {
1814

@@ -22,17 +18,10 @@ public abstract class AbstractLayer {
2218

2319
protected final Weights weights;
2420
protected final Configuration config;
25-
/** Often a small context/config buffer passed into kernels. Use your real type if available. */
21+
protected final State state;
2622
protected final KernelContext context = new KernelContext();
27-
/** Collected snapshots for scheduling / debugging. */
28-
protected final List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
29-
/** Optional: track the "main" task graph for the layer if one exists. */
30-
protected TaskGraph taskGraph;
31-
/** Shared runtime objects (exposed because kernels expect them). */
32-
protected State state;
3323

3424
protected AbstractLayer(String taskGraphName, State state, Weights weights, Configuration config) {
35-
this.taskGraph = null;
3625
this.state = state;
3726
this.weights = weights;
3827
this.config = config;
@@ -48,13 +37,7 @@ protected static <T> T requireWeightsType(Object weights, Class<T> expectedType,
4837

4938
public abstract GridScheduler updateGridScheduler(GridScheduler scheduler);
5039

51-
public abstract GridScheduler getGridScheduler();
52-
53-
public abstract TaskGraph getTaskGraph();
54-
55-
public abstract ImmutableTaskGraph getImmutableTaskGraph();
56-
57-
/** Allow subclasses to override if they need custom transfers. */
40+
/** Allow subclasses to override if they need custom data transfers. */
5841
protected TaskGraph configureLayerDataTransfers(TaskGraph tg, int layerIndex) {
5942
return tg;
6043
}

src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,10 @@ public GridScheduler updateGridScheduler(GridScheduler scheduler) {
5050
return scheduler;
5151
}
5252

53-
@Override
54-
public GridScheduler getGridScheduler() {
55-
return null;
56-
}
57-
58-
@Override
5953
public TaskGraph getTaskGraph() {
6054
return activationUpdate;
6155
}
6256

63-
@Override
6457
public ImmutableTaskGraph getImmutableTaskGraph() {
6558
return activationUpdate.snapshot();
6659
}

src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,28 +41,3 @@ public ActivationGranite(String taskGraphHandle, State state, Weights weights, G
4141
}
4242
// @formatter:on
4343
}
44-
45-
@Override
46-
public GridScheduler updateGridScheduler(GridScheduler scheduler) {
47-
WorkerGrid worker = new WorkerGrid1D(config.dim());
48-
worker.setLocalWork(128, 1, 1);
49-
scheduler.addWorkerGrid("activationUpdate.updateX", worker);
50-
return scheduler;
51-
}
52-
53-
@Override
54-
public GridScheduler getGridScheduler() {
55-
return null;
56-
}
57-
58-
@Override
59-
public TaskGraph getTaskGraph() {
60-
return activationUpdate;
61-
}
62-
63-
@Override
64-
public ImmutableTaskGraph getImmutableTaskGraph() {
65-
return activationUpdate.snapshot();
66-
}
67-
68-
}

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
public class GraniteFP16FFNLayers extends AbstractFFNLayers {
2323

24-
TaskGraph ffnTaskGraphs;
25-
GridScheduler scheduler;
2624
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
2725

2826
public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) {
@@ -64,22 +62,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
6462
return tornadoForwardScheduler;
6563
}
6664

67-
@Override
68-
public GridScheduler getGridScheduler() {
69-
return scheduler;
70-
}
71-
72-
@Override
73-
public TaskGraph getTaskGraph() {
74-
return ffnTaskGraphs;
75-
}
76-
77-
@Override
78-
public ImmutableTaskGraph getImmutableTaskGraph() {
79-
return null;
80-
}
81-
82-
public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
65+
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
8366
return ffnLayerTaskGraphs;
8467
}
8568

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020

2121
public class LlamaFP16FFNLayers extends AbstractFFNLayers {
2222

23-
TaskGraph ffnTaskGraphs;
24-
GridScheduler scheduler;
25-
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
23+
private List<ImmutableTaskGraph> ffnLayerTaskGraphs;
2624

2725
public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
2826
super(taskGraph, state, weights, config, schedulerType);
@@ -64,21 +62,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
6462
}
6563

6664
@Override
67-
public GridScheduler getGridScheduler() {
68-
return scheduler;
69-
}
70-
71-
@Override
72-
public TaskGraph getTaskGraph() {
73-
return ffnTaskGraphs;
74-
}
75-
76-
@Override
77-
public ImmutableTaskGraph getImmutableTaskGraph() {
78-
return null;
79-
}
80-
81-
public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
65+
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
8266
return ffnLayerTaskGraphs;
8367
}
8468

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ public class Phi3FP16FFNLayers extends AbstractFFNLayers {
3333
private final Phi3Configuration phi3Config;
3434
// Phi3-specific dimension for combined QKV buffer
3535
private final int opSize;
36-
TaskGraph ffnLayerTaskGraph;
37-
GridScheduler scheduler;
3836
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
3937

4038
public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) {
@@ -88,22 +86,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
8886
return gridScheduler;
8987
}
9088

91-
@Override
92-
public GridScheduler getGridScheduler() {
93-
return scheduler;
94-
}
95-
96-
@Override
97-
public TaskGraph getTaskGraph() {
98-
return ffnLayerTaskGraph;
99-
}
100-
101-
@Override
102-
public ImmutableTaskGraph getImmutableTaskGraph() {
103-
return null;
104-
}
105-
106-
public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
89+
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
10790
return ffnLayerTaskGraphs;
10891
}
10992

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ public class Qwen2FP16FFNLayers extends AbstractFFNLayers {
3434
// Typed references to Qwen2-specific state and config
3535
private final Qwen2State qwen2State;
3636
private final Qwen2Configuration qwen2Config;
37-
TaskGraph ffnLayerTaskGraph;
38-
GridScheduler scheduler;
3937
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
4038

4139
public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) {
@@ -116,22 +114,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
116114
return tornadoForwardScheduler;
117115
}
118116

119-
@Override
120-
public GridScheduler getGridScheduler() {
121-
return scheduler;
122-
}
123-
124-
@Override
125-
public TaskGraph getTaskGraph() {
126-
return ffnLayerTaskGraph;
127-
}
128-
129-
@Override
130-
public ImmutableTaskGraph getImmutableTaskGraph() {
131-
return null;
132-
}
133-
134-
public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
117+
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
135118
return ffnLayerTaskGraphs;
136119
}
137120

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers {
3838
private final int nEmbdHead;
3939
private final int nEmbdGqa;
4040
private final int gqa;
41-
TaskGraph ffnLayerTaskGraph;
42-
GridScheduler scheduler;
4341
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
4442

4543
public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) {
@@ -102,22 +100,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
102100
return gridScheduler;
103101
}
104102

105-
@Override
106-
public GridScheduler getGridScheduler() {
107-
return scheduler;
108-
}
109-
110-
@Override
111-
public TaskGraph getTaskGraph() {
112-
return ffnLayerTaskGraph;
113-
}
114-
115-
@Override
116-
public ImmutableTaskGraph getImmutableTaskGraph() {
117-
return null;
118-
}
119-
120-
public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
103+
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
121104
return ffnLayerTaskGraphs;
122105
}
123106

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
package org.beehive.gpullama3.tornadovm.layers.type.q8_0;
22

33
import org.beehive.gpullama3.inference.state.GraniteState;
4-
import org.beehive.gpullama3.inference.state.LlamaState;
54
import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
6-
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
7-
import org.beehive.gpullama3.model.Configuration;
8-
import org.beehive.gpullama3.model.granite.Granite;
95
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
106
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
117
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
@@ -23,7 +19,6 @@
2319

2420
public class GraniteQ8_0FFNLayers extends AbstractFFNLayers {
2521

26-
GridScheduler scheduler;
2722
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
2823

2924
public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) {
@@ -314,7 +309,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
314309
return tornadoForwardScheduler;
315310
}
316311

317-
public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
312+
public List<ImmutableTaskGraph> getFFNLayerTaskGraphs() {
318313
return ffnLayerTaskGraphs;
319314
}
320315

0 commit comments

Comments
 (0)