
Commit 2e8ed31

Merge pull request #107 from AdamBien/main
Devstral 2 support (Mistral 3 architecture, Tekken tokenizer, YaRN …
2 parents 5324c09 + 36043f5 commit 2e8ed31

19 files changed: 1583 additions & 0 deletions

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 90 additions & 0 deletions
@@ -12,6 +12,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
@@ -179,6 +180,95 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
        return state.logits;
    }

    /**
     * Forward pass for Devstral 2 models where head_dim != dim/num_heads.
     * Q projection outputs qDim (num_heads * head_dim) instead of dim.
     */
    public static FloatTensor forwardJavaDevstral(Model model, State state, int token, int position) {
        final DevstralConfiguration config = (DevstralConfiguration) model.configuration();
        final StandardWeights weights = (StandardWeights) model.weights();
        int dim = config.dim();
        int headSize = config.headSize(); // 128 (independent head_dim)
        int qDim = config.qDim();         // 4096 = 32 * 128
        int kvDim = config.kvDim();       // 1024 = 8 * 128
        int kvMul = config.kvMul();
        float sqrtHeadSize = (float) Math.sqrt(headSize);

        weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);

        for (int l = 0; l < config.numberOfLayers(); l++) {
            rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());

            weights.wq[l].matmul(state.xb, state.q, qDim, dim);
            weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
            weights.wv[l].matmul(state.xb, state.v, kvDim, dim);

            // RoPE over qDim (not dim)
            for (int i = 0; i < qDim; i += 2) {
                int head_dim = i % headSize;
                float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
                float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
                int rotn = i < kvDim ? 2 : 1;
                for (int v = 0; v < rotn; v++) {
                    FloatTensor vec = v == 0 ? state.q : state.k;
                    float v0 = vec.getFloat(i);
                    float v1 = vec.getFloat(i + 1);
                    vec.setFloat(i, v0 * fcr - v1 * fci);
                    vec.setFloat(i + 1, v0 * fci + v1 * fcr);
                }
            }

            state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
            state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);

            int curLayer = l;

            Parallel.parallelFor(0, config.numberOfHeads(), h -> {
                int qOffset = h * headSize;
                int attOffset = h * config.contextLength();

                for (int t = 0; t <= position; t++) {
                    int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
                    float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
                    score /= sqrtHeadSize;
                    state.att.setFloat(attOffset + t, score);
                }

                state.att.softmaxInPlace(attOffset, position + 1);

                int xbOffset = h * headSize;
                state.xb.fillInPlace(xbOffset, headSize, 0f);

                for (int t = 0; t <= position; t++) {
                    int vOffset = t * kvDim + (h / kvMul) * headSize;
                    float a = state.att.getFloat(attOffset + t);
                    state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
                }
            });

            // O projection: input qDim, output dim
            weights.wo[l].matmul(state.xb, state.xb2, dim, qDim);

            state.x.addInPlace(state.xb2);

            rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());

            weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
            weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

            state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
            state.hb.multiplyInPlace(state.hb2);

            weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());
            state.x.addInPlace(state.xb);
        }

        rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());
        weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

        return state.logits;
    }

    public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) {
        final Qwen2Configuration config = (Qwen2Configuration) model.configuration();
        final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights();
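
For orientation, here is a tiny standalone sketch (not part of this commit) of the grouped-query offset arithmetic that forwardJavaDevstral uses when reading the KV cache. The head counts are taken from the inline comments above (32 query heads, 8 KV heads, head_dim 128); the position value is purely illustrative.

public class GqaOffsetSketch {
    public static void main(String[] args) {
        int numberOfHeads = 32;        // assumed from the "4096 = 32 * 128" comment
        int numberOfKvHeads = 8;       // assumed from the "1024 = 8 * 128" comment
        int headSize = 128;
        int kvDim = numberOfKvHeads * headSize;      // 1024
        int kvMul = numberOfHeads / numberOfKvHeads; // 4 query heads share each KV head

        int position = 2;              // example timestep in the cache
        for (int h = 0; h < numberOfHeads; h += 8) {
            // Same formula as forwardJavaDevstral: query head h reads KV head h / kvMul
            int keyCacheOffset = position * kvDim + (h / kvMul) * headSize;
            System.out.printf("query head %2d -> kv head %d, cache offset %d%n",
                    h, h / kvMul, keyCacheOffset);
        }
    }
}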

src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java

Lines changed: 46 additions & 0 deletions
@@ -35,4 +35,50 @@ public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int h
        assert contextLength * (headSize / 2) == n;
        return new Pair<>(cr, ci);
    }

    public static Pair<float[], float[]> precomputeFreqsCisYaRN(int contextLength, int headSize, double theta,
            float factor, float betaFast, float betaSlow, float logMultiplier, int originalContextLength) {
        assert headSize % 2 == 0;
        float[] cr = new float[contextLength * (headSize / 2)];
        float[] ci = new float[contextLength * (headSize / 2)];

        float freqScale = 1.0f / factor;

        // Compute correlation dimensions for ramp interpolation
        float corrDim0 = yarnCorrDim(headSize, originalContextLength, betaFast, (float) theta);
        float corrDim1 = yarnCorrDim(headSize, originalContextLength, betaSlow, (float) theta);

        // Compute mscale (attention scaling for extended context)
        // Formula: mscale = 0.1 * logMultiplier * log(factor) + 1.0
        float mscale = logMultiplier > 0
                ? 1.0f + 0.1f * logMultiplier * (float) Math.log(1.0f / freqScale)
                : 1.0f;

        int n = 0;
        for (int pos = 0; pos < contextLength; ++pos) {
            for (int i = 0; i < headSize; i += 2) {
                float freqExtrap = (float) (1.0 / Math.pow(theta, i / (double) headSize));
                float freqInterp = freqScale * freqExtrap;

                float rampMix = yarnRamp(corrDim0, corrDim1, i / 2);
                float freq = freqInterp * (1.0f - rampMix) + freqExtrap * rampMix;

                float val = pos * freq;
                cr[n] = (float) Math.cos(val) * mscale;
                ci[n] = (float) Math.sin(val) * mscale;
                n++;
            }
        }
        assert contextLength * (headSize / 2) == n;
        return new Pair<>(cr, ci);
    }

    private static float yarnCorrDim(int nDims, int nCtxOrig, float nRot, float base) {
        return nDims * (float) Math.log(nCtxOrig / (nRot * 2.0f * (float) Math.PI)) / (2.0f * (float) Math.log(base));
    }

    private static float yarnRamp(float low, float high, int i0) {
        float y = (i0 - low) / Math.max(0.001f, high - low);
        return 1.0f - Math.min(1.0f, Math.max(0.0f, y));
    }
}
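
A minimal usage sketch for the new helper follows; it is not taken from the commit, and the concrete YaRN parameters (4x scale factor, betaFast 32, betaSlow 1, theta 1e6, original context 32768) are illustrative placeholders rather than values read from a Devstral GGUF header.

import org.beehive.gpullama3.inference.operation.RoPE;

public class YarnRopeSketch {
    public static void main(String[] args) {
        // Precompute cos/sin tables for a context extended 4x beyond the original window.
        var freqs = RoPE.precomputeFreqsCisYaRN(
                131_072,       // contextLength after extension (assumed)
                128,           // headSize, i.e. Devstral's head_dim
                1_000_000.0,   // theta (assumed rope_theta)
                4.0f,          // factor: 4x extension, so freqScale = 0.25
                32.0f,         // betaFast (assumed)
                1.0f,          // betaSlow (assumed)
                1.0f,          // logMultiplier
                32_768);       // originalContextLength (assumed)
        // With factor = 4 and logMultiplier = 1, the helper scales every cos/sin entry by
        // mscale = 1 + 0.1 * ln(4), roughly 1.139.
        System.out.println("precomputed YaRN tables: " + freqs);
    }
}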

src/main/java/org/beehive/gpullama3/inference/state/DevstralState.java

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
package org.beehive.gpullama3.inference.state;

import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;

/**
 * State for Devstral 2 models where head_dim != dim/num_heads.
 * Allocates Q with qDim (num_heads * head_dim) and K/V with kvDim (num_kv_heads * head_dim).
 */
public final class DevstralState extends State {

    public DevstralState(Configuration config, int batchsize) {
        super(config, batchsize);
    }

    @Override
    protected StateFields createStateFields(Configuration config) {
        DevstralConfiguration dc = (DevstralConfiguration) config;
        StateFields fields = new StateFields();

        int qDim = dc.qDim();
        int kvDim = dc.kvDim();

        fields.x = ArrayFloatTensor.allocate(dc.dim());
        fields.xb = ArrayFloatTensor.allocate(dc.dim());
        fields.xb2 = ArrayFloatTensor.allocate(dc.dim());
        fields.hb = ArrayFloatTensor.allocate(dc.hiddenDim());
        fields.hb2 = ArrayFloatTensor.allocate(dc.hiddenDim());
        fields.q = ArrayFloatTensor.allocate(qDim);
        fields.k = ArrayFloatTensor.allocate(kvDim);
        fields.v = ArrayFloatTensor.allocate(kvDim);
        fields.att = ArrayFloatTensor.allocate(dc.numberOfHeads(), dc.contextLength());
        fields.logits = ArrayFloatTensor.allocate(dc.vocabularySize());

        fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(dc.contextLength(), kvDim)).limit(dc.numberOfLayers()).toArray(FloatTensor[]::new);
        fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(dc.contextLength(), kvDim)).limit(dc.numberOfLayers()).toArray(FloatTensor[]::new);

        // TornadoVM wrappers
        fields.wrapX = new FloatArray(dc.dim());
        fields.wrapXb = new FloatArray(dc.dim());
        fields.wrapXb2 = new FloatArray(dc.dim());
        fields.wrapHb = new FloatArray(dc.hiddenDim());
        fields.wrapHb2 = new FloatArray(dc.hiddenDim());

        switch (dc.quantization()) {
            case "FP16" -> fields.createActivationFP16(dc.dim());
            case "Q8_0" -> fields.createActivationQ8_0(dc.dim());
            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + dc.quantization());
        }
        fields.wrapLogits = new FloatArray(dc.vocabularySize());
        fields.wrapQ = new FloatArray(qDim);
        fields.wrapK = new FloatArray(kvDim);
        fields.wrapV = new FloatArray(kvDim);

        fields.wrapXFP16 = new HalfFloatArray(dc.dim());
        fields.wrapXbFP16 = new HalfFloatArray(dc.dim());
        fields.wrapKeyCache = new FloatArray(dc.contextLength() * kvDim * dc.numberOfLayers());
        fields.wrapValueCache = new FloatArray(dc.contextLength() * kvDim * dc.numberOfLayers());
        fields.wrapValueCache.init(0.f);
        fields.wrapKeyCache.init(0.f);
        fields.wrapAtt = new FloatArray(dc.numberOfHeads() * dc.contextLength());
        fields.positionHolder = new IntArray(1);

        fields.temp = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize));
        fields.tempFFN = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize));
        fields.tempLogits = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize));

        return fields;
    }
}

src/main/java/org/beehive/gpullama3/model/ModelType.java

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,6 @@
package org.beehive.gpullama3.model;

import org.beehive.gpullama3.model.loader.DevstralModelLoader;
import org.beehive.gpullama3.model.loader.GraniteLoader;
import org.beehive.gpullama3.tensor.GGUF;
import org.beehive.gpullama3.model.loader.LlamaModelLoader;
@@ -37,6 +38,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
        }
    },

    DEVSTRAL_2 {
        @Override
        public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
            return new DevstralModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
        }
    },

    QWEN_2 {
        @Override
        public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {

src/main/java/org/beehive/gpullama3/model/devstral/Devstral.java

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
package org.beehive.gpullama3.model.devstral;

import org.beehive.gpullama3.inference.InferenceCore;
import org.beehive.gpullama3.inference.InferenceEngine;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.DevstralState;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.tokenizer.DevstralTokenizer;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;

import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;

public class Devstral extends AbstractModel {

    DevstralConfiguration configuration;

    public Devstral(DevstralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
        super(tokenizer, weights, chatFormat, null);
        this.configuration = configuration;
    }

    @Override
    public DevstralConfiguration configuration() {
        return configuration;
    }

    @Override
    public DevstralTokenizer tokenizer() {
        return (DevstralTokenizer) tokenizer;
    }

    @Override
    public ModelType getModelType() {
        return ModelType.DEVSTRAL_2;
    }

    public State createNewState() {
        State state = new DevstralState(configuration(), -1);
        state.latestToken = tokenizer.getSpecialTokens().get("<s>");
        return state;
    }

    public State createNewState(int batchsize) {
        State state = new DevstralState(configuration(), batchsize);
        state.latestToken = tokenizer.getSpecialTokens().get("<s>");
        return state;
    }

    @Override
    public void forward(State state, int token, int position) {
        InferenceCore.forwardJavaDevstral(this, state, token, position);
    }

    @Override
    public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
            IntConsumer onTokenGenerated) {
        return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
    }

    @Override
    public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
        return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
    }

}

src/main/java/org/beehive/gpullama3/model/devstral/DevstralConfiguration.java

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
package org.beehive.gpullama3.model.devstral;

import org.beehive.gpullama3.model.Configuration;

/**
 * Configuration for Devstral 2 models (Mistral 3 architecture).
 * Unlike standard Mistral, Devstral 2 has an independent head dimension
 * (head_dim != dim / num_heads), requiring explicit key_length/value_length.
 */
// @formatter:off
public record DevstralConfiguration(String quantization,
                                    int dim,
                                    int hiddenDim,
                                    int numberOfLayers,
                                    int numberOfHeads,
                                    int numberOfKeyValueHeads,
                                    int headDim,
                                    int vocabularySize,
                                    int contextLength,
                                    float rmsNormEps,
                                    float ropeTheta) implements Configuration {

    @Override public String quantization() {
        return quantization;
    }

    /**
     * Q projection output dimension = numberOfHeads * headDim.
     * This differs from dim when headDim != dim/numberOfHeads.
     */
    public int qDim() {
        return numberOfHeads * headDim;
    }

    public int kvDim() {
        return numberOfKeyValueHeads * headDim;
    }

    public int kvMul() {
        return numberOfHeads / numberOfKeyValueHeads;
    }

    @Override
    public int numberOfHeadsKey() {
        throw new UnsupportedOperationException("Not supported for Devstral.");
    }

    @Override
    public int contextLengthModel() {
        throw new UnsupportedOperationException("Not supported for Devstral.");
    }

    public int headSize() {
        return headDim;
    }
}
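
To make the dimension relationships concrete, here is a small sketch (not from the commit) that instantiates the record with Devstral-Small-like hyperparameters; dim, hiddenDim, layer count, vocabulary size, context length, rmsNormEps, and ropeTheta are assumptions for illustration, while the 32/8/128 head layout matches the comments in forwardJavaDevstral above.

import org.beehive.gpullama3.model.devstral.DevstralConfiguration;

public class DevstralDimsSketch {
    public static void main(String[] args) {
        DevstralConfiguration config = new DevstralConfiguration(
                "FP16",       // quantization
                5_120,        // dim (assumed)
                14_336,       // hiddenDim (assumed)
                40,           // numberOfLayers (assumed)
                32,           // numberOfHeads
                8,            // numberOfKeyValueHeads
                128,          // headDim
                131_072,      // vocabularySize (assumed Tekken vocabulary)
                32_768,       // contextLength (assumed)
                1e-5f,        // rmsNormEps (assumed)
                1_000_000f);  // ropeTheta (assumed)

        System.out.println(config.qDim());   // 32 * 128 = 4096, not equal to dim
        System.out.println(config.kvDim());  // 8 * 128 = 1024
        System.out.println(config.kvMul());  // 32 / 8 = 4 query heads per KV head
    }
}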
