Skip to content

Commit 8024505

Browse files
committed
test(bfloat16): add Phase 2 NaN propagation and underflow tests
- NaN propagation through all 4 arithmetic ops (add, sub, mul, div) with each ArithmeticMode (IEEE, fast, exact) — 24 sub-tests - Gradual underflow at subnormal boundary for mul, div, add, sub - FMA correctness: basic arithmetic, NaN in each position, precision
1 parent 5ddc6a1 commit 8024505

1 file changed

Lines changed: 177 additions & 31 deletions

File tree

bfloat16_arithmetic_test.go

Lines changed: 177 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -252,57 +252,203 @@ func TestBFloat16FMA(t *testing.T) {
252252
}
253253
}
254254

255-
func TestBFloat16NaNPropagation(t *testing.T) {
255+
func TestBFloat16NaNPropagationAllModes(t *testing.T) {
256256
nan := BFloat16QuietNaN
257257
one := BFloat16FromFloat32(1)
258258

259+
type opFunc func(a, b BFloat16, mode ArithmeticMode, rounding RoundingMode) (BFloat16, error)
260+
259261
ops := []struct {
260262
name string
261-
fn func() (BFloat16, error)
263+
fn opFunc
264+
a, b BFloat16
265+
}{
266+
{"add(NaN,1)", BFloat16AddWithMode, nan, one},
267+
{"add(1,NaN)", BFloat16AddWithMode, one, nan},
268+
{"sub(NaN,1)", BFloat16SubWithMode, nan, one},
269+
{"sub(1,NaN)", BFloat16SubWithMode, one, nan},
270+
{"mul(NaN,1)", BFloat16MulWithMode, nan, one},
271+
{"mul(1,NaN)", BFloat16MulWithMode, one, nan},
272+
{"div(NaN,1)", BFloat16DivWithMode, nan, one},
273+
{"div(1,NaN)", BFloat16DivWithMode, one, nan},
274+
}
275+
276+
modes := []struct {
277+
name string
278+
mode ArithmeticMode
279+
wantErr bool
262280
}{
263-
{"add(NaN,1)", func() (BFloat16, error) { return BFloat16AddWithMode(nan, one, ModeIEEEArithmetic, RoundNearestEven) }},
264-
{"add(1,NaN)", func() (BFloat16, error) { return BFloat16AddWithMode(one, nan, ModeIEEEArithmetic, RoundNearestEven) }},
265-
{"sub(NaN,1)", func() (BFloat16, error) { return BFloat16SubWithMode(nan, one, ModeIEEEArithmetic, RoundNearestEven) }},
266-
{"mul(NaN,1)", func() (BFloat16, error) { return BFloat16MulWithMode(nan, one, ModeIEEEArithmetic, RoundNearestEven) }},
267-
{"div(NaN,1)", func() (BFloat16, error) { return BFloat16DivWithMode(nan, one, ModeIEEEArithmetic, RoundNearestEven) }},
268-
{"div(1,NaN)", func() (BFloat16, error) { return BFloat16DivWithMode(one, nan, ModeIEEEArithmetic, RoundNearestEven) }},
281+
{"IEEE", ModeIEEEArithmetic, false},
282+
{"fast", ModeFastArithmetic, false},
283+
{"exact", ModeExactArithmetic, true},
269284
}
270285

271286
for _, op := range ops {
272-
t.Run(op.name, func(t *testing.T) {
273-
got, err := op.fn()
274-
if err != nil {
275-
t.Fatalf("unexpected error: %v", err)
276-
}
277-
if !got.IsNaN() {
278-
t.Errorf("expected NaN, got 0x%04X (%v)", got.Bits(), got)
279-
}
280-
})
287+
for _, m := range modes {
288+
t.Run(op.name+"/"+m.name, func(t *testing.T) {
289+
got, err := op.fn(op.a, op.b, m.mode, RoundNearestEven)
290+
if m.wantErr {
291+
if err == nil {
292+
t.Fatal("expected error for NaN in exact mode, got nil")
293+
}
294+
fe, ok := err.(*Float16Error)
295+
if !ok {
296+
t.Fatalf("expected *Float16Error, got %T", err)
297+
}
298+
if fe.Code != ErrNaN {
299+
t.Errorf("error code = %d, want %d (ErrNaN)", fe.Code, ErrNaN)
300+
}
301+
return
302+
}
303+
if err != nil {
304+
t.Fatalf("unexpected error: %v", err)
305+
}
306+
if !got.IsNaN() {
307+
t.Errorf("expected NaN, got 0x%04X (%v)", got.Bits(), got)
308+
}
309+
})
310+
}
281311
}
282312
}
283313

284314
func TestBFloat16GradualUnderflow(t *testing.T) {
285-
// Multiplying two very small normal numbers should produce a subnormal
286-
// rather than flushing to zero.
287315
smallest := BFloat16SmallestPos // smallest positive normal
288316
half := BFloat16FromFloat32(0.5)
289317

290-
got, err := BFloat16MulWithMode(smallest, half, ModeIEEEArithmetic, RoundNearestEven)
291-
if err != nil {
292-
t.Fatalf("unexpected error: %v", err)
293-
}
318+
t.Run("mul/smallest*0.5", func(t *testing.T) {
319+
got, err := BFloat16MulWithMode(smallest, half, ModeIEEEArithmetic, RoundNearestEven)
320+
if err != nil {
321+
t.Fatalf("unexpected error: %v", err)
322+
}
323+
if got.IsZero() {
324+
t.Fatal("expected subnormal result, got zero")
325+
}
326+
gotF := got.ToFloat32()
327+
wantF := smallest.ToFloat32() * 0.5
328+
if math.Abs(float64(gotF-wantF)) > float64(wantF)*0.1 {
329+
t.Errorf("got %e, want approximately %e", gotF, wantF)
330+
}
331+
})
332+
333+
t.Run("mul/neg_underflow", func(t *testing.T) {
334+
neg := BFloat16Neg(smallest)
335+
got, err := BFloat16MulWithMode(neg, half, ModeIEEEArithmetic, RoundNearestEven)
336+
if err != nil {
337+
t.Fatalf("unexpected error: %v", err)
338+
}
339+
if got.IsZero() {
340+
t.Fatal("expected negative subnormal, got zero")
341+
}
342+
if got.ToFloat32() >= 0 {
343+
t.Errorf("expected negative result, got %e", got.ToFloat32())
344+
}
345+
})
346+
347+
t.Run("add/near_subnormal_boundary", func(t *testing.T) {
348+
// Adding two values that sum to something below the smallest normal
349+
// should produce a subnormal, not zero.
350+
sub := BFloat16SmallestPosSubnormal
351+
got, err := BFloat16AddWithMode(sub, sub, ModeIEEEArithmetic, RoundNearestEven)
352+
if err != nil {
353+
t.Fatalf("unexpected error: %v", err)
354+
}
355+
if got.IsZero() {
356+
t.Fatal("expected non-zero subnormal sum, got zero")
357+
}
358+
wantF := BFloat16SmallestPosSubnormal.ToFloat32() * 2
359+
gotF := got.ToFloat32()
360+
if math.Abs(float64(gotF-wantF)) > float64(wantF)*0.1 {
361+
t.Errorf("got %e, want approximately %e", gotF, wantF)
362+
}
363+
})
294364

295-
// The result should be subnormal (half the smallest normal)
296-
if got.IsZero() {
297-
t.Error("expected subnormal result, got zero (gradual underflow not working)")
365+
t.Run("div/smallest/2", func(t *testing.T) {
366+
two := BFloat16FromFloat32(2)
367+
got, err := BFloat16DivWithMode(smallest, two, ModeIEEEArithmetic, RoundNearestEven)
368+
if err != nil {
369+
t.Fatalf("unexpected error: %v", err)
370+
}
371+
if got.IsZero() {
372+
t.Fatal("expected subnormal result, got zero")
373+
}
374+
gotF := got.ToFloat32()
375+
wantF := smallest.ToFloat32() / 2
376+
if math.Abs(float64(gotF-wantF)) > float64(wantF)*0.1 {
377+
t.Errorf("got %e, want approximately %e", gotF, wantF)
378+
}
379+
})
380+
381+
t.Run("sub/subnormal_boundary", func(t *testing.T) {
382+
// Subtracting values that are very close should yield a subnormal.
383+
a := BFloat16FromFloat32(smallest.ToFloat32() * 1.5)
384+
got, err := BFloat16SubWithMode(a, smallest, ModeIEEEArithmetic, RoundNearestEven)
385+
if err != nil {
386+
t.Fatalf("unexpected error: %v", err)
387+
}
388+
// Result should be approximately 0.5 * smallest normal = subnormal
389+
if got.IsZero() {
390+
t.Fatal("expected subnormal result, got zero")
391+
}
392+
})
393+
}
394+
395+
func TestBFloat16FMACorrectness(t *testing.T) {
396+
tests := []struct {
397+
name string
398+
a, b, c BFloat16
399+
wantNaN bool
400+
wantF32 float32
401+
}{
402+
{"2*3+1=7", BFloat16FromFloat32(2), BFloat16FromFloat32(3), BFloat16FromFloat32(1), false, 7},
403+
{"-2*3+10=4", BFloat16FromFloat32(-2), BFloat16FromFloat32(3), BFloat16FromFloat32(10), false, 4},
404+
{"0*5+3=3", BFloat16PositiveZero, BFloat16FromFloat32(5), BFloat16FromFloat32(3), false, 3},
405+
{"5*0+3=3", BFloat16FromFloat32(5), BFloat16PositiveZero, BFloat16FromFloat32(3), false, 3},
406+
{"1*1+0=1", BFloat16FromFloat32(1), BFloat16FromFloat32(1), BFloat16PositiveZero, false, 1},
407+
{"-1*-1+0=1", BFloat16FromFloat32(-1), BFloat16FromFloat32(-1), BFloat16PositiveZero, false, 1},
408+
{"4*0.5+-2=0", BFloat16FromFloat32(4), BFloat16FromFloat32(0.5), BFloat16FromFloat32(-2), false, 0},
409+
// NaN in each operand position
410+
{"NaN*1+0", BFloat16QuietNaN, BFloat16FromFloat32(1), BFloat16PositiveZero, true, 0},
411+
{"1*NaN+0", BFloat16FromFloat32(1), BFloat16QuietNaN, BFloat16PositiveZero, true, 0},
412+
{"1*1+NaN", BFloat16FromFloat32(1), BFloat16FromFloat32(1), BFloat16QuietNaN, true, 0},
413+
{"NaN*NaN+NaN", BFloat16QuietNaN, BFloat16QuietNaN, BFloat16QuietNaN, true, 0},
298414
}
299415

300-
// Verify the result is approximately half of the smallest normal
301-
gotF := got.ToFloat32()
302-
wantF := smallest.ToFloat32() * 0.5
303-
if math.Abs(float64(gotF-wantF)) > float64(wantF)*0.1 {
304-
t.Errorf("got %e, want approximately %e", gotF, wantF)
416+
for _, tt := range tests {
417+
t.Run(tt.name, func(t *testing.T) {
418+
got, err := BFloat16FMA(tt.a, tt.b, tt.c)
419+
if err != nil {
420+
t.Fatalf("unexpected error: %v", err)
421+
}
422+
if tt.wantNaN {
423+
if !got.IsNaN() {
424+
t.Errorf("expected NaN, got %v (0x%04X)", got, got.Bits())
425+
}
426+
return
427+
}
428+
gotF32 := got.ToFloat32()
429+
if gotF32 != tt.wantF32 {
430+
t.Errorf("got %v, want %v", gotF32, tt.wantF32)
431+
}
432+
})
305433
}
434+
435+
// FMA precision test: verify fused multiply-add avoids intermediate rounding.
436+
// For values where a*b overflows float16 range but a*b+c is representable,
437+
// FMA via float64 should give a more accurate result.
438+
t.Run("precision/no_intermediate_rounding", func(t *testing.T) {
439+
a := BFloat16FromFloat32(100)
440+
b := BFloat16FromFloat32(100)
441+
c := BFloat16FromFloat32(-9984) // 100*100 = 10000; 10000 - 9984 = 16
442+
got, err := BFloat16FMA(a, b, c)
443+
if err != nil {
444+
t.Fatalf("unexpected error: %v", err)
445+
}
446+
gotF := got.ToFloat32()
447+
wantF := float32(math.FMA(float64(a.ToFloat32()), float64(b.ToFloat32()), float64(c.ToFloat32())))
448+
if gotF != wantF {
449+
t.Errorf("got %v, want %v", gotF, wantF)
450+
}
451+
})
306452
}
307453

308454
func TestBFloat16ArithmeticWithMode(t *testing.T) {

0 commit comments

Comments
 (0)