Skip to content

Commit 49eb298

Browse files
committed
Fix logits scaling in GraniteKernels: correct scaling order for hb and output writes.
1 parent c2b91c4 commit 49eb298

2 files changed

Lines changed: 3 additions & 4 deletions

File tree

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ public static FloatTensor forwardGranite(Model model, State state, int token, in
663663
weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);
664664

665665
// Apply Granite logit scaling (multiply each logit by logitScale)
666-
state.logits.mapInPlace(v -> v / logitScale);
666+
state.logits.mapInPlace(v -> v * logitScale);
667667

668668
return state.logits;
669669
}

src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public static void matrixVectorGenericWithGraniteScale(
7777

7878
// Thread 0 in each workgroup writes the final result
7979
if (localId == 0) {
80-
hb.set(rowId, sum);
80+
hb.set(rowId, sum * logitsScale);
8181
}
8282
}
8383

@@ -156,7 +156,6 @@ public static void processHeadsFlashAttentionWithGraniteScale(KernelContext cont
156156
score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
157157
}
158158
score *= attentionScale;
159-
// score /= TornadoMath.sqrt(headSize);
160159
s_tile[score_idx_in_tile] = score;
161160
}
162161

@@ -339,7 +338,7 @@ public static void matrixVectorGenericQ8ByteWithGraniteScale(KernelContext conte
339338

340339
// Thread 0 writes the result
341340
if (localId == 0) {
342-
output.set(rowId, logitsScale * sum);
341+
output.set(rowId, sum * logitsScale);
343342
}
344343
}
345344

0 commit comments

Comments
 (0)