Skip to content

Commit 4be811a

Browse files
[refactor] Unify task graph setup for Logits layers and centralize shared logic into AbstractLogitsLayer
1 parent a26f2a9 commit 4be811a

5 files changed

Lines changed: 114 additions & 231 deletions

File tree

src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,10 @@ protected AbstractLogitsLayer(String name, State state, Weights weights, Configu
2929
this.schedulerType = schedulerType;
3030
TornadoWeights tornadoWeights = requireWeightsType(weights, TornadoWeights.class,
3131
getClass().getSimpleName(), "TornadoTensor");
32-
this.logitsTaskGraph = buildLogitsTaskGraph(tornadoWeights, config);
32+
this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
3333
}
3434

35-
/**
36-
* Builds the logits task graph. Called once from the constructor.
37-
* Subclasses define the quantization-specific task sequence here.
38-
*/
39-
protected abstract TaskGraph buildLogitsTaskGraph(TornadoWeights weights, Configuration config);
35+
protected abstract TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config);
4036

4137
public final TaskGraph getTaskGraph() {
4238
return logitsTaskGraph;

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,29 @@
22

33
import org.beehive.gpullama3.inference.state.State;
44
import org.beehive.gpullama3.inference.weights.Weights;
5-
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
65
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
76
import org.beehive.gpullama3.model.Configuration;
87
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
98
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
9+
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
1010
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
1111
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
12-
import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
12+
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
1313
import uk.ac.manchester.tornado.api.GridScheduler;
14-
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
1514
import uk.ac.manchester.tornado.api.TaskGraph;
16-
import uk.ac.manchester.tornado.api.WorkerGrid;
1715
import uk.ac.manchester.tornado.api.WorkerGrid1D;
1816
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
1917

20-
public class LogitsFP16Layer extends AbstractLayer {
21-
22-
private String lastTaskGraphID;
23-
private TaskGraph logitsTaskGraph;
24-
private ImmutableTaskGraph immutableLogitsGraph;
25-
private GridScheduler scheduler;
26-
private SchedulerType schedulerType;
18+
public class LogitsFP16Layer extends AbstractLogitsLayer {
2719

28-
public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
29-
super(name, state, weights, config);
30-
this.lastTaskGraphID = lastTaskGraphID;
31-
this.schedulerType = schedulerType;
32-
var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
33-
this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
20+
public LogitsFP16Layer(String name, State state, Weights weights, Configuration config,
21+
String lastTaskGraphID, SchedulerType schedulerType) {
22+
super(name, state, weights, config, lastTaskGraphID, schedulerType);
3423
}
3524

3625
// @formatter:off
37-
private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
26+
@Override
27+
protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
3828
var logits = new TaskGraph("logits");
3929
// === Data Setup ===
4030
logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
@@ -96,7 +86,7 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con
9686

9787
@Override
9888
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
99-
WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
89+
var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), rmsLocalSize());
10090
var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
10191
var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
10292
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
@@ -106,18 +96,8 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
10696
return tornadoForwardScheduler;
10797
}
10898

109-
@Override
110-
public GridScheduler getGridScheduler() {
111-
return scheduler;
112-
}
113-
114-
@Override
115-
public TaskGraph getTaskGraph() {
116-
return logitsTaskGraph;
117-
}
118-
119-
@Override
120-
public ImmutableTaskGraph getImmutableTaskGraph() {
121-
return immutableLogitsGraph;
99+
/** Local workgroup size for RMS norm. Qwen2 requires a smaller group (32 vs 256). */
100+
protected int rmsLocalSize() {
101+
return weights instanceof Qwen2TornadoWeights ? 32 : 256;
122102
}
123103
}

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java

Lines changed: 29 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,124 +2,84 @@
22

33
import org.beehive.gpullama3.inference.state.State;
44
import org.beehive.gpullama3.inference.weights.Weights;
5-
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
65
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
76
import org.beehive.gpullama3.model.Configuration;
87
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
98
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
109
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
1110
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
12-
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
1311
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
14-
import uk.ac.manchester.tornado.api.GridScheduler;
15-
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
1612
import uk.ac.manchester.tornado.api.TaskGraph;
17-
import uk.ac.manchester.tornado.api.WorkerGrid;
18-
import uk.ac.manchester.tornado.api.WorkerGrid1D;
1913
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
2014

15+
/**
16+
* Granite-specific FP16 logits layer.
17+
* Identical to LogitsFP16Layer except vocab_proj uses a scaled kernel (logitScale).
18+
*/
2119
public class LogitsGraniteFP16Layer extends LogitsFP16Layer {
22-
private String lastTaskGraphID;
23-
private TaskGraph logitsTaskGraph;
24-
private ImmutableTaskGraph immutableLogitsGraph;
25-
private GridScheduler scheduler;
26-
private SchedulerType schedulerType;
2720

28-
public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
21+
public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config,
22+
String lastTaskGraphID, SchedulerType schedulerType) {
2923
super(name, state, weights, config, lastTaskGraphID, schedulerType);
30-
this.lastTaskGraphID = lastTaskGraphID;
31-
this.schedulerType = schedulerType;
32-
var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
33-
this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config);
3424
}
3525

3626
// @formatter:off
37-
private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) {
27+
@Override
28+
protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
29+
GraniteConfiguration graniteCfg = (GraniteConfiguration) config;
3830
var logits = new TaskGraph("logits");
31+
3932
// === Data Setup ===
4033
logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
4134
logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
4235
logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
43-
// Kernel context
4436
context,
45-
// Output buffer
4637
state.wrapLogits,
47-
// Intermediate FP16 buffer
4838
state.wrapXbFP16,
49-
// Weights
5039
weights.wclsByteArray.asHalfFloatArray(),
5140
weights.rms_final_weight_as_floatArray.asFloatArray());
5241

5342
// === Final RMS Normalization ===
5443
logits.task("rms_reduce",
5544
TransformerComputeKernels::reductionOneBlockWithLayer,
5645
context,
57-
state.tempLogits, // output: partial sums + final scale factor
58-
state.wrapX, // input: hidden state
59-
config.dim(), // dimension
60-
config.rmsNormEps(), // epsilon for numerical stability
61-
state.localSize); // local workgroup size
46+
state.tempLogits,
47+
state.wrapX,
48+
config.dim(),
49+
config.rmsNormEps(),
50+
state.localSize);
6251

6352
if (schedulerType == SchedulerType.NON_NVIDIA) {
6453
logits.task("rms_finalize",
6554
TransformerComputeKernelsLayered::reductionFinalNormalization,
6655
context,
67-
state.tempLogits, // in/out: combines partial sums
68-
config.dim(), // dimension
69-
config.rmsNormEps()); // epsilon
56+
state.tempLogits,
57+
config.dim(),
58+
config.rmsNormEps());
7059
}
7160

7261
logits.task("rms_apply_fp16",
7362
TransformerComputeKernels::mapContextWithQuantizeLogits,
7463
context,
75-
state.wrapXbFP16, // output: normalized (FP16)
76-
state.wrapX, // input: hidden state
77-
weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights
78-
state.tempLogits); // scale factor from reduction
64+
state.wrapXbFP16,
65+
state.wrapX,
66+
weights.rms_final_weight_as_floatArray.asFloatArray(),
67+
state.tempLogits);
7968

80-
// === Vocabulary Projection ===
69+
// === Vocabulary Projection (Granite: scaled by logitScale) ===
8170
logits.task("vocab_proj",
8271
GraniteKernels::matrixVectorGenericWithGraniteScale,
8372
context,
84-
state.wrapXbFP16, // input (FP16)
85-
state.wrapLogits, // output
86-
weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights
87-
config.dim(), // input dimension
88-
config.vocabularySize(), // output dimension
73+
state.wrapXbFP16,
74+
state.wrapLogits,
75+
weights.wclsByteArray.asHalfFloatArray(),
76+
config.dim(),
77+
config.vocabularySize(),
8978
LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS,
90-
config.logitScale()); // granite logit scaling
79+
graniteCfg.logitScale());
9180

92-
// === Transfer Results to Host ===
9381
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
9482
return logits;
9583
}
9684
// @formatter:on
97-
98-
@Override
99-
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
100-
WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
101-
var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
102-
var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
103-
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
104-
tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
105-
tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS);
106-
tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
107-
return tornadoForwardScheduler;
108-
}
109-
110-
@Override
111-
public GridScheduler getGridScheduler() {
112-
return scheduler;
113-
}
114-
115-
@Override
116-
public TaskGraph getTaskGraph() {
117-
return logitsTaskGraph;
118-
}
119-
120-
@Override
121-
public ImmutableTaskGraph getImmutableTaskGraph() {
122-
return immutableLogitsGraph;
123-
}
12485
}
125-

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java

Lines changed: 25 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,69 +2,51 @@
22

33
import org.beehive.gpullama3.inference.state.State;
44
import org.beehive.gpullama3.inference.weights.Weights;
5-
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
65
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
76
import org.beehive.gpullama3.model.Configuration;
87
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
98
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
109
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
1110
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
12-
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
1311
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
14-
import uk.ac.manchester.tornado.api.GridScheduler;
15-
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
1612
import uk.ac.manchester.tornado.api.TaskGraph;
17-
import uk.ac.manchester.tornado.api.WorkerGrid1D;
1813
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
1914

20-
public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer{
21-
private String lastTaskGraphID;
22-
private TaskGraph logitsTaskGraph;
23-
private ImmutableTaskGraph immutableLogitsGraph;
24-
private GridScheduler scheduler;
25-
private SchedulerType schedulerType;
15+
/**
16+
* Granite-specific Q8_0 logits layer.
17+
* Identical to LogitsQ8_0Layer except vocab_proj uses a scaled kernel (logitScale).
18+
*/
19+
public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer {
2620

27-
public LogitsGraniteQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
28-
super(taskGraphName, state, weights, config, lastTaskGraphID, schedulerType);
29-
this.lastTaskGraphID = lastTaskGraphID;
30-
var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor");
31-
this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config);
32-
this.schedulerType = schedulerType;
33-
}
34-
35-
@Override
36-
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
37-
var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
38-
var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
39-
var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
40-
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
41-
tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
42-
tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
43-
tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
44-
return tornadoForwardScheduler;
21+
public LogitsGraniteQ8_0Layer(String name, State state, Weights weights, Configuration config,
22+
String lastTaskGraphID, SchedulerType schedulerType) {
23+
super(name, state, weights, config, lastTaskGraphID, schedulerType);
4524
}
4625

4726
// @formatter:off
48-
private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) {
27+
@Override
28+
protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
29+
GraniteConfiguration graniteCfg = (GraniteConfiguration) config;
4930
var logits = new TaskGraph("logits");
31+
5032
// === Data Setup ===
5133
logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
5234
logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
5335
logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
54-
context, //
55-
state.wrapLogits, //
56-
weights.wclsByteArray.asByteArray(), //
36+
context,
37+
state.wrapLogits,
38+
weights.wclsByteArray.asByteArray(),
5739
weights.rms_final_weight_as_floatArray);
5840

5941
// === Final RMS Normalization ===
6042
logits.task("rms_reduce",
6143
TransformerComputeKernels::reductionOneBlockWithLayer,
6244
context,
63-
state.tempLogits, // output: partial sums + final scale factor
64-
state.wrapX, // input: hidden state
65-
config.dim(), // dimension
66-
config.rmsNormEps(), // epsilon for numerical stability
67-
state.localSize); // local workgroup size
45+
state.tempLogits,
46+
state.wrapX,
47+
config.dim(),
48+
config.rmsNormEps(),
49+
state.localSize);
6850

6951
if (schedulerType == SchedulerType.NON_NVIDIA) {
7052
logits.task("rms_finalize",
@@ -74,45 +56,28 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat
7456
config.dim(),
7557
config.rmsNormEps());
7658
}
59+
7760
logits.task("mapContextLogits",
7861
TransformerComputeKernels::reductionOneBlock2WithLogits,
7962
context,
8063
state.wrapX,
8164
weights.rms_final_weight_as_floatArray.asFloatArray(),
8265
state.tempLogits);
8366

84-
// === Vocabulary vocab_proj ===
85-
logits.task("vocab_proj", GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, //
67+
// === Vocabulary Projection (Granite: scaled by logitScale) ===
68+
logits.task("vocab_proj",
69+
GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale,
8670
context,
8771
state.wrapX,
8872
state.wrapLogits,
8973
weights.wclsByteArray.asByteArray(),
9074
config.dim(),
9175
config.vocabularySize(),
9276
LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS,
93-
config.logitScale()
77+
graniteCfg.logitScale());
9478

95-
);
96-
97-
// === Transfer Results to Host ===
9879
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
9980
return logits;
10081
}
10182
// @formatter:on
102-
103-
@Override
104-
public GridScheduler getGridScheduler() {
105-
return scheduler;
106-
}
107-
108-
@Override
109-
public TaskGraph getTaskGraph() {
110-
return logitsTaskGraph;
111-
}
112-
113-
@Override
114-
public ImmutableTaskGraph getImmutableTaskGraph() {
115-
return immutableLogitsGraph;
116-
}
117-
11883
}

0 commit comments

Comments
 (0)