|
2 | 2 |
|
3 | 3 | import org.beehive.gpullama3.inference.state.State; |
4 | 4 | import org.beehive.gpullama3.inference.weights.Weights; |
5 | | -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; |
6 | 5 | import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; |
7 | 6 | import org.beehive.gpullama3.model.Configuration; |
8 | 7 | import org.beehive.gpullama3.model.granite.GraniteConfiguration; |
9 | 8 | import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; |
10 | 9 | import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; |
11 | 10 | import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; |
12 | | -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; |
13 | 11 | import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; |
14 | | -import uk.ac.manchester.tornado.api.GridScheduler; |
15 | | -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; |
16 | 12 | import uk.ac.manchester.tornado.api.TaskGraph; |
17 | | -import uk.ac.manchester.tornado.api.WorkerGrid; |
18 | | -import uk.ac.manchester.tornado.api.WorkerGrid1D; |
19 | 13 | import uk.ac.manchester.tornado.api.enums.DataTransferMode; |
20 | 14 |
|
| 15 | +/** |
| 16 | + * Granite-specific FP16 logits layer. |
| 17 | + * Identical to LogitsFP16Layer except vocab_proj uses a scaled kernel (logitScale). |
| 18 | + */ |
21 | 19 | public class LogitsGraniteFP16Layer extends LogitsFP16Layer { |
22 | | - private String lastTaskGraphID; |
23 | | - private TaskGraph logitsTaskGraph; |
24 | | - private ImmutableTaskGraph immutableLogitsGraph; |
25 | | - private GridScheduler scheduler; |
26 | | - private SchedulerType schedulerType; |
27 | 20 |
|
28 | | - public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { |
| 21 | + public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, |
| 22 | + String lastTaskGraphID, SchedulerType schedulerType) { |
29 | 23 | super(name, state, weights, config, lastTaskGraphID, schedulerType); |
30 | | - this.lastTaskGraphID = lastTaskGraphID; |
31 | | - this.schedulerType = schedulerType; |
32 | | - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); |
33 | | - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); |
34 | 24 | } |
35 | 25 |
|
36 | 26 | // @formatter:off |
37 | | - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { |
| 27 | + @Override |
| 28 | + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { |
| 29 | + GraniteConfiguration graniteCfg = (GraniteConfiguration) config; |
38 | 30 | var logits = new TaskGraph("logits"); |
| 31 | + |
39 | 32 | // === Data Setup === |
40 | 33 | logits.consumeFromDevice(lastTaskGraphID, state.wrapX); |
41 | 34 | logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); |
42 | 35 | logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, |
43 | | - // Kernel context |
44 | 36 | context, |
45 | | - // Output buffer |
46 | 37 | state.wrapLogits, |
47 | | - // Intermediate FP16 buffer |
48 | 38 | state.wrapXbFP16, |
49 | | - // Weights |
50 | 39 | weights.wclsByteArray.asHalfFloatArray(), |
51 | 40 | weights.rms_final_weight_as_floatArray.asFloatArray()); |
52 | 41 |
|
53 | 42 | // === Final RMS Normalization === |
54 | 43 | logits.task("rms_reduce", |
55 | 44 | TransformerComputeKernels::reductionOneBlockWithLayer, |
56 | 45 | context, |
57 | | - state.tempLogits, // output: partial sums + final scale factor |
58 | | - state.wrapX, // input: hidden state |
59 | | - config.dim(), // dimension |
60 | | - config.rmsNormEps(), // epsilon for numerical stability |
61 | | - state.localSize); // local workgroup size |
| 46 | + state.tempLogits, |
| 47 | + state.wrapX, |
| 48 | + config.dim(), |
| 49 | + config.rmsNormEps(), |
| 50 | + state.localSize); |
62 | 51 |
|
63 | 52 | if (schedulerType == SchedulerType.NON_NVIDIA) { |
64 | 53 | logits.task("rms_finalize", |
65 | 54 | TransformerComputeKernelsLayered::reductionFinalNormalization, |
66 | 55 | context, |
67 | | - state.tempLogits, // in/out: combines partial sums |
68 | | - config.dim(), // dimension |
69 | | - config.rmsNormEps()); // epsilon |
| 56 | + state.tempLogits, |
| 57 | + config.dim(), |
| 58 | + config.rmsNormEps()); |
70 | 59 | } |
71 | 60 |
|
72 | 61 | logits.task("rms_apply_fp16", |
73 | 62 | TransformerComputeKernels::mapContextWithQuantizeLogits, |
74 | 63 | context, |
75 | | - state.wrapXbFP16, // output: normalized (FP16) |
76 | | - state.wrapX, // input: hidden state |
77 | | - weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights |
78 | | - state.tempLogits); // scale factor from reduction |
| 64 | + state.wrapXbFP16, |
| 65 | + state.wrapX, |
| 66 | + weights.rms_final_weight_as_floatArray.asFloatArray(), |
| 67 | + state.tempLogits); |
79 | 68 |
|
80 | | - // === Vocabulary Projection === |
| 69 | + // === Vocabulary Projection (Granite: scaled by logitScale) === |
81 | 70 | logits.task("vocab_proj", |
82 | 71 | GraniteKernels::matrixVectorGenericWithGraniteScale, |
83 | 72 | context, |
84 | | - state.wrapXbFP16, // input (FP16) |
85 | | - state.wrapLogits, // output |
86 | | - weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights |
87 | | - config.dim(), // input dimension |
88 | | - config.vocabularySize(), // output dimension |
| 73 | + state.wrapXbFP16, |
| 74 | + state.wrapLogits, |
| 75 | + weights.wclsByteArray.asHalfFloatArray(), |
| 76 | + config.dim(), |
| 77 | + config.vocabularySize(), |
89 | 78 | LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, |
90 | | - config.logitScale()); // granite logit scaling |
| 79 | + graniteCfg.logitScale()); |
91 | 80 |
|
92 | | - // === Transfer Results to Host === |
93 | 81 | logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); |
94 | 82 | return logits; |
95 | 83 | } |
96 | 84 | // @formatter:on |
97 | | - |
98 | | - @Override |
99 | | - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { |
100 | | - WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); |
101 | | - var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; |
102 | | - var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); |
103 | | - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); |
104 | | - tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); |
105 | | - tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS); |
106 | | - tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); |
107 | | - return tornadoForwardScheduler; |
108 | | - } |
109 | | - |
110 | | - @Override |
111 | | - public GridScheduler getGridScheduler() { |
112 | | - return scheduler; |
113 | | - } |
114 | | - |
115 | | - @Override |
116 | | - public TaskGraph getTaskGraph() { |
117 | | - return logitsTaskGraph; |
118 | | - } |
119 | | - |
120 | | - @Override |
121 | | - public ImmutableTaskGraph getImmutableTaskGraph() { |
122 | | - return immutableLogitsGraph; |
123 | | - } |
124 | 85 | } |
125 | | - |
0 commit comments