Skip to content

Commit 3aa399b

Browse files
[refactor] Simplify and unify Activation task graph setup logic
1 parent 4be811a commit 3aa399b

3 files changed

Lines changed: 48 additions & 54 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ protected AbstractLayer(String taskGraphName, State state, Weights weights, Conf
2727
this.config = config;
2828
}
2929

30+
/**
31+
* Ensures weights are of the expected type.
32+
*/
3033
@SuppressWarnings("unchecked")
3134
protected static <T> T requireWeightsType(Object weights, Class<T> expectedType, String layerName, String layout) {
3235
if (expectedType.isInstance(weights)) {

src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,55 +7,49 @@
77
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
88
import uk.ac.manchester.tornado.api.GridScheduler;
99
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
10-
import uk.ac.manchester.tornado.api.KernelContext;
1110
import uk.ac.manchester.tornado.api.TaskGraph;
1211
import uk.ac.manchester.tornado.api.WorkerGrid;
13-
import uk.ac.manchester.tornado.api.WorkerGrid1D;
1412
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
1513
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
1614
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
1715

1816
public class Activation extends AbstractLayer {
19-
private final TaskGraph activationUpdate;
20-
21-
public Activation(String taskGraphHandle, State state, Weights weights, Configuration config) {
22-
super(taskGraphHandle, state, weights, config);
23-
24-
KernelContext kernelContext = new KernelContext();
25-
26-
// @formatter:off
27-
switch (config.quantization()) {
28-
case "FP16" -> {
29-
this.activationUpdate = new TaskGraph(taskGraphHandle)
30-
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
31-
.task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX)
32-
.persistOnDevice(state.wrapX);
33-
}
34-
case "Q8_0" -> {
35-
this.activationUpdate = new TaskGraph(taskGraphHandle)
36-
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
37-
.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, kernelContext, (ByteArray) state.embeddingX, state.wrapX)
38-
.persistOnDevice(state.wrapX);
39-
}
17+
private final TaskGraph activationTaskGraph;
18+
19+
public Activation(String name, State state, Weights weights, Configuration config) {
20+
super(name, state, weights, config);
21+
this.activationTaskGraph = setupActivationTaskGraph(name);
22+
}
23+
24+
// @formatter:off
25+
protected TaskGraph setupActivationTaskGraph(String name) {
26+
return switch (config.quantization()) {
27+
case "FP16" -> new TaskGraph(name)
28+
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
29+
.task("updateX", TransformerComputeKernels::convertFP16toFP32, context, (HalfFloatArray) state.embeddingX, state.wrapX)
30+
.persistOnDevice(state.wrapX);
31+
case "Q8_0" -> new TaskGraph(name)
32+
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
33+
.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, context, (ByteArray) state.embeddingX, state.wrapX)
34+
.persistOnDevice(state.wrapX);
4035
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
41-
}
42-
// @formatter:on
36+
};
4337
}
38+
// @formatter:on
4439

4540
@Override
4641
public GridScheduler updateGridScheduler(GridScheduler scheduler) {
47-
WorkerGrid worker = new WorkerGrid1D(config.dim());
48-
worker.setLocalWork(128, 1, 1);
49-
scheduler.addWorkerGrid("activationUpdate.updateX", worker);
42+
WorkerGrid worker = WorkerGridFactory.genericWorker(config.dim(), 128);
43+
scheduler.addWorkerGrid(activationTaskGraph.getTaskGraphName() + ".updateX", worker);
5044
return scheduler;
5145
}
5246

5347
public TaskGraph getTaskGraph() {
54-
return activationUpdate;
48+
return activationTaskGraph;
5549
}
5650

5751
public ImmutableTaskGraph getImmutableTaskGraph() {
58-
return activationUpdate.snapshot();
52+
return activationTaskGraph.snapshot();
5953
}
6054

6155
}

src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,37 @@
44
import org.beehive.gpullama3.inference.weights.Weights;
55
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
66
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
7-
import uk.ac.manchester.tornado.api.GridScheduler;
8-
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
9-
import uk.ac.manchester.tornado.api.KernelContext;
107
import uk.ac.manchester.tornado.api.TaskGraph;
11-
import uk.ac.manchester.tornado.api.WorkerGrid;
12-
import uk.ac.manchester.tornado.api.WorkerGrid1D;
138
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
149
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
1510
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
1611

12+
/**
13+
* Granite-specific activation: applies an embedding scale factor during the FP32 conversion.
14+
* Overrides only the task graph builder; all other behaviour is inherited from Activation.
15+
*/
1716
public class ActivationGranite extends Activation {
18-
private final TaskGraph activationUpdate;
1917

2018
// Granite is a special case: activation X is scaled by the embedding scale, a float value defined inside the model.
2119
public ActivationGranite(String taskGraphHandle, State state, Weights weights, GraniteConfiguration config) {
2220
super(taskGraphHandle, state, weights, config);
21+
}
2322

24-
KernelContext kernelContext = new KernelContext();
25-
26-
// @formatter:off
27-
switch (config.quantization()) {
28-
case "FP16" -> {
29-
this.activationUpdate = new TaskGraph(taskGraphHandle)
30-
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
31-
.task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX, config.embeddingScale())
32-
.persistOnDevice(state.wrapX);
33-
}
34-
case "Q8_0" -> {
35-
this.activationUpdate = new TaskGraph(taskGraphHandle)
36-
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
37-
.task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, kernelContext, (ByteArray) state.embeddingX, state.wrapX, config.embeddingScale())
38-
.persistOnDevice(state.wrapX);
39-
}
23+
// @formatter:off
24+
@Override
25+
protected TaskGraph setupActivationTaskGraph(String handle) {
26+
GraniteConfiguration cfg = (GraniteConfiguration) config;
27+
return switch (config.quantization()) {
28+
case "FP16" -> new TaskGraph(handle)
29+
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
30+
.task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, context, (HalfFloatArray) state.embeddingX, state.wrapX, cfg.embeddingScale())
31+
.persistOnDevice(state.wrapX);
32+
case "Q8_0" -> new TaskGraph(handle)
33+
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
34+
.task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, context, (ByteArray) state.embeddingX, state.wrapX, cfg.embeddingScale())
35+
.persistOnDevice(state.wrapX);
4036
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
41-
}
42-
// @formatter:on
37+
};
4338
}
39+
// @formatter:on
40+
}

0 commit comments

Comments
 (0)