Skip to content

Commit 3c09bca

Browse files
[refactor] Introduce AbstractLogitsLayer to centralize shared logic for logits layers
1 parent bf6823d commit 3c09bca

1 file changed

Lines changed: 48 additions & 0 deletions

File tree

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package org.beehive.gpullama3.tornadovm.layers;
2+
3+
import org.beehive.gpullama3.inference.state.State;
4+
import org.beehive.gpullama3.inference.weights.Weights;
5+
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
6+
import org.beehive.gpullama3.model.Configuration;
7+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
8+
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
9+
import uk.ac.manchester.tornado.api.TaskGraph;
10+
11+
/**
12+
* Abstract base for all logits layers (final vocabulary projection step).
13+
*
14+
* Holds the shared fields and calls the protected buildLogitsTaskGraph() hook once
15+
* during construction. Subclasses implement buildLogitsTaskGraph() to define the
16+
* quantization-specific task sequence; Granite variants override it to swap in
17+
* their scaled kernel.
18+
*/
19+
public abstract class AbstractLogitsLayer extends AbstractLayer {
20+
21+
protected final String lastTaskGraphID;
22+
protected final SchedulerType schedulerType;
23+
private final TaskGraph logitsTaskGraph;
24+
25+
protected AbstractLogitsLayer(String name, State state, Weights weights, Configuration config,
26+
String lastTaskGraphID, SchedulerType schedulerType) {
27+
super(name, state, weights, config);
28+
this.lastTaskGraphID = lastTaskGraphID;
29+
this.schedulerType = schedulerType;
30+
TornadoWeights tornadoWeights = requireWeightsType(weights, TornadoWeights.class,
31+
getClass().getSimpleName(), "TornadoTensor");
32+
this.logitsTaskGraph = buildLogitsTaskGraph(tornadoWeights, config);
33+
}
34+
35+
/**
36+
* Builds the logits task graph. Called once from the constructor.
37+
* Subclasses define the quantization-specific task sequence here.
38+
*/
39+
protected abstract TaskGraph buildLogitsTaskGraph(TornadoWeights weights, Configuration config);
40+
41+
public final TaskGraph getTaskGraph() {
42+
return logitsTaskGraph;
43+
}
44+
45+
public final ImmutableTaskGraph getImmutableTaskGraph() {
46+
return logitsTaskGraph.snapshot();
47+
}
48+
}

0 commit comments

Comments
 (0)