Skip to content

Commit eb88d1e

Browse files
committed
feat(bfloat16): add Phase 4 math functions
Implement BFloat16Sqrt, Exp, Log, Log2, Sin, Cos, Tanh, Sigmoid. Each function converts to float64 for computation and converts back. Add FastMode variants for Sigmoid and Tanh using polynomial approximation.
1 parent e48ff23 commit eb88d1e

2 files changed

Lines changed: 530 additions & 0 deletions

File tree

bfloat16_math.go

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package float16
2+
3+
import (
4+
"math"
5+
)
6+
7+
// Mathematical functions for BFloat16
8+
9+
// BFloat16Sqrt returns the square root of the BFloat16 value.
10+
func BFloat16Sqrt(b BFloat16) BFloat16 {
11+
if b.IsZero() {
12+
return b
13+
}
14+
if b.IsNaN() {
15+
return b
16+
}
17+
if b.IsInf(1) {
18+
return BFloat16PositiveInfinity
19+
}
20+
if b.Signbit() {
21+
return BFloat16QuietNaN
22+
}
23+
result := math.Sqrt(float64(b.ToFloat32()))
24+
return BFloat16FromFloat32(float32(result))
25+
}
26+
27+
// BFloat16Exp returns e^b.
28+
func BFloat16Exp(b BFloat16) BFloat16 {
29+
if b.IsZero() {
30+
return BFloat16One
31+
}
32+
if b.IsNaN() {
33+
return b
34+
}
35+
if b.IsInf(1) {
36+
return BFloat16PositiveInfinity
37+
}
38+
if b.IsInf(-1) {
39+
return BFloat16PositiveZero
40+
}
41+
result := math.Exp(float64(b.ToFloat32()))
42+
return BFloat16FromFloat32(float32(result))
43+
}
44+
45+
// BFloat16Log returns the natural logarithm of b.
46+
func BFloat16Log(b BFloat16) BFloat16 {
47+
if b.IsZero() {
48+
return BFloat16NegativeInfinity
49+
}
50+
if b.IsNaN() {
51+
return b
52+
}
53+
if b.IsInf(1) {
54+
return BFloat16PositiveInfinity
55+
}
56+
if b.Signbit() {
57+
return BFloat16QuietNaN
58+
}
59+
result := math.Log(float64(b.ToFloat32()))
60+
return BFloat16FromFloat32(float32(result))
61+
}
62+
63+
// BFloat16Log2 returns the base-2 logarithm of b.
64+
func BFloat16Log2(b BFloat16) BFloat16 {
65+
if b.IsZero() {
66+
return BFloat16NegativeInfinity
67+
}
68+
if b.IsNaN() {
69+
return b
70+
}
71+
if b.IsInf(1) {
72+
return BFloat16PositiveInfinity
73+
}
74+
if b.Signbit() {
75+
return BFloat16QuietNaN
76+
}
77+
result := math.Log2(float64(b.ToFloat32()))
78+
return BFloat16FromFloat32(float32(result))
79+
}
80+
81+
// BFloat16Sin returns the sine of b (in radians).
82+
func BFloat16Sin(b BFloat16) BFloat16 {
83+
if b.IsZero() {
84+
return b
85+
}
86+
if b.IsNaN() || b.IsInf(0) {
87+
return BFloat16QuietNaN
88+
}
89+
result := math.Sin(float64(b.ToFloat32()))
90+
return BFloat16FromFloat32(float32(result))
91+
}
92+
93+
// BFloat16Cos returns the cosine of b (in radians).
94+
func BFloat16Cos(b BFloat16) BFloat16 {
95+
if b.IsZero() {
96+
return BFloat16One
97+
}
98+
if b.IsNaN() || b.IsInf(0) {
99+
return BFloat16QuietNaN
100+
}
101+
result := math.Cos(float64(b.ToFloat32()))
102+
return BFloat16FromFloat32(float32(result))
103+
}
104+
105+
// BFloat16Tanh returns the hyperbolic tangent of b.
106+
func BFloat16Tanh(b BFloat16) BFloat16 {
107+
if b.IsZero() {
108+
return b
109+
}
110+
if b.IsNaN() {
111+
return b
112+
}
113+
if b.IsInf(1) {
114+
return BFloat16One
115+
}
116+
if b.IsInf(-1) {
117+
return BFloat16FromFloat32(-1)
118+
}
119+
result := math.Tanh(float64(b.ToFloat32()))
120+
return BFloat16FromFloat32(float32(result))
121+
}
122+
123+
// BFloat16Sigmoid returns 1 / (1 + exp(-b)).
124+
func BFloat16Sigmoid(b BFloat16) BFloat16 {
125+
if b.IsNaN() {
126+
return b
127+
}
128+
if b.IsInf(1) {
129+
return BFloat16One
130+
}
131+
if b.IsInf(-1) {
132+
return BFloat16PositiveZero
133+
}
134+
x := float64(b.ToFloat32())
135+
result := 1.0 / (1.0 + math.Exp(-x))
136+
return BFloat16FromFloat32(float32(result))
137+
}
138+
139+
// FastMode variants using polynomial approximations.
140+
// These trade accuracy for speed, suitable for ML inference workloads
141+
// where BFloat16 precision is already limited.
142+
143+
// BFloat16FastSigmoid computes an approximate sigmoid using a rational polynomial.
144+
// Uses the approximation: sigmoid(x) ≈ 0.5 + 0.5 * x / (1 + |x|)
145+
// which avoids exp() entirely.
146+
func BFloat16FastSigmoid(b BFloat16) BFloat16 {
147+
if b.IsNaN() {
148+
return b
149+
}
150+
if b.IsInf(1) {
151+
return BFloat16One
152+
}
153+
if b.IsInf(-1) {
154+
return BFloat16PositiveZero
155+
}
156+
x := float64(b.ToFloat32())
157+
abs := x
158+
if abs < 0 {
159+
abs = -abs
160+
}
161+
result := 0.5 + 0.5*x/(1.0+abs)
162+
return BFloat16FromFloat32(float32(result))
163+
}
164+
165+
// BFloat16FastTanh computes an approximate tanh using a rational polynomial.
166+
// Uses the approximation: tanh(x) ≈ x*(27 + x*x) / (27 + 9*x*x)
167+
// which is a Padé approximant accurate to within ~0.004 for |x| < 3.
168+
func BFloat16FastTanh(b BFloat16) BFloat16 {
169+
if b.IsZero() {
170+
return b
171+
}
172+
if b.IsNaN() {
173+
return b
174+
}
175+
if b.IsInf(1) {
176+
return BFloat16One
177+
}
178+
if b.IsInf(-1) {
179+
return BFloat16FromFloat32(-1)
180+
}
181+
x := float64(b.ToFloat32())
182+
abs := x
183+
if abs < 0 {
184+
abs = -abs
185+
}
186+
// Clamp for large values where tanh saturates
187+
if abs > 4.0 {
188+
if x > 0 {
189+
return BFloat16One
190+
}
191+
return BFloat16FromFloat32(-1)
192+
}
193+
x2 := x * x
194+
result := x * (27.0 + x2) / (27.0 + 9.0*x2)
195+
return BFloat16FromFloat32(float32(result))
196+
}

0 commit comments

Comments
 (0)