|
7 | 7 | import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; |
8 | 8 | import uk.ac.manchester.tornado.api.GridScheduler; |
9 | 9 | import uk.ac.manchester.tornado.api.ImmutableTaskGraph; |
10 | | -import uk.ac.manchester.tornado.api.KernelContext; |
11 | 10 | import uk.ac.manchester.tornado.api.TaskGraph; |
12 | 11 | import uk.ac.manchester.tornado.api.WorkerGrid; |
13 | | -import uk.ac.manchester.tornado.api.WorkerGrid1D; |
14 | 12 | import uk.ac.manchester.tornado.api.enums.DataTransferMode; |
15 | 13 | import uk.ac.manchester.tornado.api.types.arrays.ByteArray; |
16 | 14 | import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; |
17 | 15 |
|
18 | 16 | public class Activation extends AbstractLayer { |
19 | | - private final TaskGraph activationUpdate; |
20 | | - |
21 | | - public Activation(String taskGraphHandle, State state, Weights weights, Configuration config) { |
22 | | - super(taskGraphHandle, state, weights, config); |
23 | | - |
24 | | - KernelContext kernelContext = new KernelContext(); |
25 | | - |
26 | | - // @formatter:off |
27 | | - switch (config.quantization()) { |
28 | | - case "FP16" -> { |
29 | | - this.activationUpdate = new TaskGraph(taskGraphHandle) |
30 | | - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) |
31 | | - .task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX) |
32 | | - .persistOnDevice(state.wrapX); |
33 | | - } |
34 | | - case "Q8_0" -> { |
35 | | - this.activationUpdate = new TaskGraph(taskGraphHandle) |
36 | | - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) |
37 | | - .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, kernelContext, (ByteArray) state.embeddingX, state.wrapX) |
38 | | - .persistOnDevice(state.wrapX); |
39 | | - } |
| 17 | + private final TaskGraph activationTaskGraph; |
| 18 | + |
| 19 | + public Activation(String name, State state, Weights weights, Configuration config) { |
| 20 | + super(name, state, weights, config); |
| 21 | + this.activationTaskGraph = setupActivationTaskGraph(name); |
| 22 | + } |
| 23 | + |
| 24 | + // @formatter:off |
| 25 | + protected TaskGraph setupActivationTaskGraph(String name) { |
| 26 | + return switch (config.quantization()) { |
| 27 | + case "FP16" -> new TaskGraph(name) |
| 28 | + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) |
| 29 | + .task("updateX", TransformerComputeKernels::convertFP16toFP32, context, (HalfFloatArray) state.embeddingX, state.wrapX) |
| 30 | + .persistOnDevice(state.wrapX); |
| 31 | + case "Q8_0" -> new TaskGraph(name) |
| 32 | + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) |
| 33 | + .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, context, (ByteArray) state.embeddingX, state.wrapX) |
| 34 | + .persistOnDevice(state.wrapX); |
40 | 35 | default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); |
41 | | - } |
42 | | - // @formatter:on |
| 36 | + }; |
43 | 37 | } |
| 38 | + // @formatter:on |
44 | 39 |
|
45 | 40 | @Override |
46 | 41 | public GridScheduler updateGridScheduler(GridScheduler scheduler) { |
47 | | - WorkerGrid worker = new WorkerGrid1D(config.dim()); |
48 | | - worker.setLocalWork(128, 1, 1); |
49 | | - scheduler.addWorkerGrid("activationUpdate.updateX", worker); |
| 42 | + WorkerGrid worker = WorkerGridFactory.genericWorker(config.dim(), 128); |
| 43 | + scheduler.addWorkerGrid(activationTaskGraph.getTaskGraphName() + ".updateX", worker); |
50 | 44 | return scheduler; |
51 | 45 | } |
52 | 46 |
|
53 | 47 | public TaskGraph getTaskGraph() { |
54 | | - return activationUpdate; |
| 48 | + return activationTaskGraph; |
55 | 49 | } |
56 | 50 |
|
57 | 51 | public ImmutableTaskGraph getImmutableTaskGraph() { |
58 | | - return activationUpdate.snapshot(); |
| 52 | + return activationTaskGraph.snapshot(); |
59 | 53 | } |
60 | 54 |
|
61 | 55 | } |
|
0 commit comments