@@ -252,57 +252,203 @@ func TestBFloat16FMA(t *testing.T) {
252252 }
253253}
254254
255- func TestBFloat16NaNPropagation (t * testing.T ) {
255+ func TestBFloat16NaNPropagationAllModes (t * testing.T ) {
256256 nan := BFloat16QuietNaN
257257 one := BFloat16FromFloat32 (1 )
258258
259+ type opFunc func (a , b BFloat16 , mode ArithmeticMode , rounding RoundingMode ) (BFloat16 , error )
260+
259261 ops := []struct {
260262 name string
261- fn func () (BFloat16 , error )
263+ fn opFunc
264+ a , b BFloat16
265+ }{
266+ {"add(NaN,1)" , BFloat16AddWithMode , nan , one },
267+ {"add(1,NaN)" , BFloat16AddWithMode , one , nan },
268+ {"sub(NaN,1)" , BFloat16SubWithMode , nan , one },
269+ {"sub(1,NaN)" , BFloat16SubWithMode , one , nan },
270+ {"mul(NaN,1)" , BFloat16MulWithMode , nan , one },
271+ {"mul(1,NaN)" , BFloat16MulWithMode , one , nan },
272+ {"div(NaN,1)" , BFloat16DivWithMode , nan , one },
273+ {"div(1,NaN)" , BFloat16DivWithMode , one , nan },
274+ }
275+
276+ modes := []struct {
277+ name string
278+ mode ArithmeticMode
279+ wantErr bool
262280 }{
263- {"add(NaN,1)" , func () (BFloat16 , error ) { return BFloat16AddWithMode (nan , one , ModeIEEEArithmetic , RoundNearestEven ) }},
264- {"add(1,NaN)" , func () (BFloat16 , error ) { return BFloat16AddWithMode (one , nan , ModeIEEEArithmetic , RoundNearestEven ) }},
265- {"sub(NaN,1)" , func () (BFloat16 , error ) { return BFloat16SubWithMode (nan , one , ModeIEEEArithmetic , RoundNearestEven ) }},
266- {"mul(NaN,1)" , func () (BFloat16 , error ) { return BFloat16MulWithMode (nan , one , ModeIEEEArithmetic , RoundNearestEven ) }},
267- {"div(NaN,1)" , func () (BFloat16 , error ) { return BFloat16DivWithMode (nan , one , ModeIEEEArithmetic , RoundNearestEven ) }},
268- {"div(1,NaN)" , func () (BFloat16 , error ) { return BFloat16DivWithMode (one , nan , ModeIEEEArithmetic , RoundNearestEven ) }},
281+ {"IEEE" , ModeIEEEArithmetic , false },
282+ {"fast" , ModeFastArithmetic , false },
283+ {"exact" , ModeExactArithmetic , true },
269284 }
270285
271286 for _ , op := range ops {
272- t .Run (op .name , func (t * testing.T ) {
273- got , err := op .fn ()
274- if err != nil {
275- t .Fatalf ("unexpected error: %v" , err )
276- }
277- if ! got .IsNaN () {
278- t .Errorf ("expected NaN, got 0x%04X (%v)" , got .Bits (), got )
279- }
280- })
287+ for _ , m := range modes {
288+ t .Run (op .name + "/" + m .name , func (t * testing.T ) {
289+ got , err := op .fn (op .a , op .b , m .mode , RoundNearestEven )
290+ if m .wantErr {
291+ if err == nil {
292+ t .Fatal ("expected error for NaN in exact mode, got nil" )
293+ }
294+ fe , ok := err .(* Float16Error )
295+ if ! ok {
296+ t .Fatalf ("expected *Float16Error, got %T" , err )
297+ }
298+ if fe .Code != ErrNaN {
299+ t .Errorf ("error code = %d, want %d (ErrNaN)" , fe .Code , ErrNaN )
300+ }
301+ return
302+ }
303+ if err != nil {
304+ t .Fatalf ("unexpected error: %v" , err )
305+ }
306+ if ! got .IsNaN () {
307+ t .Errorf ("expected NaN, got 0x%04X (%v)" , got .Bits (), got )
308+ }
309+ })
310+ }
281311 }
282312}
283313
284314func TestBFloat16GradualUnderflow (t * testing.T ) {
285- // Multiplying two very small normal numbers should produce a subnormal
286- // rather than flushing to zero.
287315 smallest := BFloat16SmallestPos // smallest positive normal
288316 half := BFloat16FromFloat32 (0.5 )
289317
290- got , err := BFloat16MulWithMode (smallest , half , ModeIEEEArithmetic , RoundNearestEven )
291- if err != nil {
292- t .Fatalf ("unexpected error: %v" , err )
293- }
318+ t .Run ("mul/smallest*0.5" , func (t * testing.T ) {
319+ got , err := BFloat16MulWithMode (smallest , half , ModeIEEEArithmetic , RoundNearestEven )
320+ if err != nil {
321+ t .Fatalf ("unexpected error: %v" , err )
322+ }
323+ if got .IsZero () {
324+ t .Fatal ("expected subnormal result, got zero" )
325+ }
326+ gotF := got .ToFloat32 ()
327+ wantF := smallest .ToFloat32 () * 0.5
328+ if math .Abs (float64 (gotF - wantF )) > float64 (wantF )* 0.1 {
329+ t .Errorf ("got %e, want approximately %e" , gotF , wantF )
330+ }
331+ })
332+
333+ t .Run ("mul/neg_underflow" , func (t * testing.T ) {
334+ neg := BFloat16Neg (smallest )
335+ got , err := BFloat16MulWithMode (neg , half , ModeIEEEArithmetic , RoundNearestEven )
336+ if err != nil {
337+ t .Fatalf ("unexpected error: %v" , err )
338+ }
339+ if got .IsZero () {
340+ t .Fatal ("expected negative subnormal, got zero" )
341+ }
342+ if got .ToFloat32 () >= 0 {
343+ t .Errorf ("expected negative result, got %e" , got .ToFloat32 ())
344+ }
345+ })
346+
347+ t .Run ("add/near_subnormal_boundary" , func (t * testing.T ) {
348+ // Adding the smallest positive subnormal to itself should give about twice
349+ // that value (checked to within 10% below) rather than flushing to zero.
350+ sub := BFloat16SmallestPosSubnormal
351+ got , err := BFloat16AddWithMode (sub , sub , ModeIEEEArithmetic , RoundNearestEven )
352+ if err != nil {
353+ t .Fatalf ("unexpected error: %v" , err )
354+ }
355+ if got .IsZero () {
356+ t .Fatal ("expected non-zero subnormal sum, got zero" )
357+ }
358+ wantF := BFloat16SmallestPosSubnormal .ToFloat32 () * 2
359+ gotF := got .ToFloat32 ()
360+ if math .Abs (float64 (gotF - wantF )) > float64 (wantF )* 0.1 {
361+ t .Errorf ("got %e, want approximately %e" , gotF , wantF )
362+ }
363+ })
294364
295- // The result should be subnormal (half the smallest normal)
296- if got .IsZero () {
297- t .Error ("expected subnormal result, got zero (gradual underflow not working)" )
365+ t .Run ("div/smallest/2" , func (t * testing.T ) {
366+ two := BFloat16FromFloat32 (2 )
367+ got , err := BFloat16DivWithMode (smallest , two , ModeIEEEArithmetic , RoundNearestEven )
368+ if err != nil {
369+ t .Fatalf ("unexpected error: %v" , err )
370+ }
371+ if got .IsZero () {
372+ t .Fatal ("expected subnormal result, got zero" )
373+ }
374+ gotF := got .ToFloat32 ()
375+ wantF := smallest .ToFloat32 () / 2
376+ if math .Abs (float64 (gotF - wantF )) > float64 (wantF )* 0.1 {
377+ t .Errorf ("got %e, want approximately %e" , gotF , wantF )
378+ }
379+ })
380+
381+ t .Run ("sub/subnormal_boundary" , func (t * testing.T ) {
382+ // Subtracting the smallest normal from 1.5x the smallest normal should yield a non-zero subnormal.
383+ a := BFloat16FromFloat32 (smallest .ToFloat32 () * 1.5 )
384+ got , err := BFloat16SubWithMode (a , smallest , ModeIEEEArithmetic , RoundNearestEven )
385+ if err != nil {
386+ t .Fatalf ("unexpected error: %v" , err )
387+ }
388+ // The difference should be about 0.5 * smallest normal (a subnormal); only non-zero is asserted here.
389+ if got .IsZero () {
390+ t .Fatal ("expected subnormal result, got zero" )
391+ }
392+ })
393+ }
394+
395+ func TestBFloat16FMACorrectness (t * testing.T ) {
396+ tests := []struct {
397+ name string
398+ a , b , c BFloat16
399+ wantNaN bool
400+ wantF32 float32
401+ }{
402+ {"2*3+1=7" , BFloat16FromFloat32 (2 ), BFloat16FromFloat32 (3 ), BFloat16FromFloat32 (1 ), false , 7 },
403+ {"-2*3+10=4" , BFloat16FromFloat32 (- 2 ), BFloat16FromFloat32 (3 ), BFloat16FromFloat32 (10 ), false , 4 },
404+ {"0*5+3=3" , BFloat16PositiveZero , BFloat16FromFloat32 (5 ), BFloat16FromFloat32 (3 ), false , 3 },
405+ {"5*0+3=3" , BFloat16FromFloat32 (5 ), BFloat16PositiveZero , BFloat16FromFloat32 (3 ), false , 3 },
406+ {"1*1+0=1" , BFloat16FromFloat32 (1 ), BFloat16FromFloat32 (1 ), BFloat16PositiveZero , false , 1 },
407+ {"-1*-1+0=1" , BFloat16FromFloat32 (- 1 ), BFloat16FromFloat32 (- 1 ), BFloat16PositiveZero , false , 1 },
408+ {"4*0.5+-2=0" , BFloat16FromFloat32 (4 ), BFloat16FromFloat32 (0.5 ), BFloat16FromFloat32 (- 2 ), false , 0 },
409+ // NaN in each operand position, then in all three at once
410+ {"NaN*1+0" , BFloat16QuietNaN , BFloat16FromFloat32 (1 ), BFloat16PositiveZero , true , 0 },
411+ {"1*NaN+0" , BFloat16FromFloat32 (1 ), BFloat16QuietNaN , BFloat16PositiveZero , true , 0 },
412+ {"1*1+NaN" , BFloat16FromFloat32 (1 ), BFloat16FromFloat32 (1 ), BFloat16QuietNaN , true , 0 },
413+ {"NaN*NaN+NaN" , BFloat16QuietNaN , BFloat16QuietNaN , BFloat16QuietNaN , true , 0 },
298414 }
299415
300- // Verify the result is approximately half of the smallest normal
301- gotF := got .ToFloat32 ()
302- wantF := smallest .ToFloat32 () * 0.5
303- if math .Abs (float64 (gotF - wantF )) > float64 (wantF )* 0.1 {
304- t .Errorf ("got %e, want approximately %e" , gotF , wantF )
416+ for _ , tt := range tests {
417+ t .Run (tt .name , func (t * testing.T ) {
418+ got , err := BFloat16FMA (tt .a , tt .b , tt .c )
419+ if err != nil {
420+ t .Fatalf ("unexpected error: %v" , err )
421+ }
422+ if tt .wantNaN {
423+ if ! got .IsNaN () {
424+ t .Errorf ("expected NaN, got %v (0x%04X)" , got , got .Bits ())
425+ }
426+ return
427+ }
428+ gotF32 := got .ToFloat32 ()
429+ if gotF32 != tt .wantF32 {
430+ t .Errorf ("got %v, want %v" , gotF32 , tt .wantF32 )
431+ }
432+ })
305433 }
434+
435+ // FMA precision test: verify fused multiply-add avoids intermediate rounding.
436+ // Here a*b = 10000 is not exactly representable in bfloat16 (it would round to
437+ // 9984), so rounding the product before the add would give 0, not the exact 16.
438+ t .Run ("precision/no_intermediate_rounding" , func (t * testing.T ) {
439+ a := BFloat16FromFloat32 (100 )
440+ b := BFloat16FromFloat32 (100 )
441+ c := BFloat16FromFloat32 (- 9984 ) // 100*100 = 10000; 10000 - 9984 = 16
442+ got , err := BFloat16FMA (a , b , c )
443+ if err != nil {
444+ t .Fatalf ("unexpected error: %v" , err )
445+ }
446+ gotF := got .ToFloat32 ()
447+ wantF := float32 (math .FMA (float64 (a .ToFloat32 ()), float64 (b .ToFloat32 ()), float64 (c .ToFloat32 ())))
448+ if gotF != wantF {
449+ t .Errorf ("got %v, want %v" , gotF , wantF )
450+ }
451+ })
306452}
307453
308454func TestBFloat16ArithmeticWithMode (t * testing.T ) {
0 commit comments