Skip to content

Commit 0699e3f

Browse files
bors[bot] and strake committed
Merge #21
21: elide pass-by-ref in root functions r=cuviper a=strake

Co-authored-by: M Farkas-Dyck <strake888@gmail.com>
2 parents 0473275 + 956315c commit 0699e3f

1 file changed

Lines changed: 144 additions & 133 deletions

File tree

src/roots.rs

Lines changed: 144 additions & 133 deletions
Original file line number | Diff line number | Diff line change
@@ -202,170 +202,181 @@ fn log2<T: PrimInt>(x: T) -> u32 {
202202
macro_rules! unsigned_roots {
203203
($T:ident) => {
204204
impl Roots for $T {
205+
#[inline]
205206
fn nth_root(&self, n: u32) -> Self {
206-
// Specialize small roots
207-
match n {
208-
0 => panic!("can't find a root of degree 0!"),
209-
1 => return *self,
210-
2 => return self.sqrt(),
211-
3 => return self.cbrt(),
212-
_ => (),
213-
}
207+
fn go(a: $T, n: u32) -> $T {
208+
// Specialize small roots
209+
match n {
210+
0 => panic!("can't find a root of degree 0!"),
211+
1 => return a,
212+
2 => return a.sqrt(),
213+
3 => return a.cbrt(),
214+
_ => (),
215+
}
214216

215-
// The root of values less than 2ⁿ can only be 0 or 1.
216-
if bits::<$T>() <= n || *self < (1 << n) {
217-
return (*self > 0) as $T;
218-
}
217+
// The root of values less than 2ⁿ can only be 0 or 1.
218+
if bits::<$T>() <= n || a < (1 << n) {
219+
return (a > 0) as $T;
220+
}
219221

220-
if bits::<$T>() > 64 {
221-
// 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
222-
return if *self <= core::u64::MAX as $T {
223-
(*self as u64).nth_root(n) as $T
224-
} else {
225-
let lo = (self >> n).nth_root(n) << 1;
226-
let hi = lo + 1;
227-
// 128-bit `checked_mul` also involves division, but we can't always
228-
// compute `hiⁿ` without risking overflow. Try to avoid it though...
229-
if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
230-
match checked_pow(hi, n as usize) {
231-
Some(x) if x <= *self => hi,
232-
_ => lo,
233-
}
222+
if bits::<$T>() > 64 {
223+
// 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
224+
return if a <= core::u64::MAX as $T {
225+
(a as u64).nth_root(n) as $T
234226
} else {
235-
if hi.pow(n) <= *self {
236-
hi
227+
let lo = (a >> n).nth_root(n) << 1;
228+
let hi = lo + 1;
229+
// 128-bit `checked_mul` also involves division, but we can't always
230+
// compute `hiⁿ` without risking overflow. Try to avoid it though...
231+
if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
232+
match checked_pow(hi, n as usize) {
233+
Some(x) if x <= a => hi,
234+
_ => lo,
235+
}
237236
} else {
238-
lo
237+
if hi.pow(n) <= a {
238+
hi
239+
} else {
240+
lo
241+
}
239242
}
243+
};
244+
}
245+
246+
#[cfg(feature = "std")]
247+
#[inline]
248+
fn guess(x: $T, n: u32) -> $T {
249+
// for smaller inputs, `f64` doesn't justify its cost.
250+
if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
251+
1 << ((log2(x) + n - 1) / n)
252+
} else {
253+
((x as f64).ln() / f64::from(n)).exp() as $T
240254
}
241-
};
242-
}
255+
}
243256

244-
#[cfg(feature = "std")]
245-
#[inline]
246-
fn guess(x: $T, n: u32) -> $T {
247-
// for smaller inputs, `f64` doesn't justify its cost.
248-
if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
257+
#[cfg(not(feature = "std"))]
258+
#[inline]
259+
fn guess(x: $T, n: u32) -> $T {
249260
1 << ((log2(x) + n - 1) / n)
250-
} else {
251-
((x as f64).ln() / f64::from(n)).exp() as $T
252261
}
253-
}
254-
255-
#[cfg(not(feature = "std"))]
256-
#[inline]
257-
fn guess(x: $T, n: u32) -> $T {
258-
1 << ((log2(x) + n - 1) / n)
259-
}
260262

261-
// https://en.wikipedia.org/wiki/Nth_root_algorithm
262-
let n1 = n - 1;
263-
let next = |x: $T| {
264-
let y = match checked_pow(x, n1 as usize) {
265-
Some(ax) => self / ax,
266-
None => 0,
263+
// https://en.wikipedia.org/wiki/Nth_root_algorithm
264+
let n1 = n - 1;
265+
let next = |x: $T| {
266+
let y = match checked_pow(x, n1 as usize) {
267+
Some(ax) => a / ax,
268+
None => 0,
269+
};
270+
(y + x * n1 as $T) / n as $T
267271
};
268-
(y + x * n1 as $T) / n as $T
269-
};
270-
fixpoint(guess(*self, n), next)
272+
fixpoint(guess(a, n), next)
273+
}
274+
go(*self, n)
271275
}
272276

277+
#[inline]
273278
fn sqrt(&self) -> Self {
274-
if bits::<$T>() > 64 {
275-
// 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
276-
// https://en.wikipedia.org/wiki/Integer_square_root#Using_bitwise_operations
277-
return if *self <= core::u64::MAX as $T {
278-
(*self as u64).sqrt() as $T
279-
} else {
280-
let lo = (self >> 2u32).sqrt() << 1;
281-
let hi = lo + 1;
282-
if hi * hi <= *self {
283-
hi
279+
fn go(a: $T) -> $T {
280+
if bits::<$T>() > 64 {
281+
// 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
282+
return if a <= core::u64::MAX as $T {
283+
(a as u64).sqrt() as $T
284284
} else {
285-
lo
286-
}
287-
};
288-
}
285+
let lo = (a >> 2u32).sqrt() << 1;
286+
let hi = lo + 1;
287+
if hi * hi <= a {
288+
hi
289+
} else {
290+
lo
291+
}
292+
};
293+
}
289294

290-
if *self < 4 {
291-
return (*self > 0) as Self;
292-
}
295+
if a < 4 {
296+
return (a > 0) as $T;
297+
}
293298

294-
#[cfg(feature = "std")]
295-
#[inline]
296-
fn guess(x: $T) -> $T {
297-
(x as f64).sqrt() as $T
298-
}
299+
#[cfg(feature = "std")]
300+
#[inline]
301+
fn guess(x: $T) -> $T {
302+
(x as f64).sqrt() as $T
303+
}
299304

300-
#[cfg(not(feature = "std"))]
301-
#[inline]
302-
fn guess(x: $T) -> $T {
303-
1 << ((log2(x) + 1) / 2)
304-
}
305+
#[cfg(not(feature = "std"))]
306+
#[inline]
307+
fn guess(x: $T) -> $T {
308+
1 << ((log2(x) + 1) / 2)
309+
}
305310

306-
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
307-
let next = |x: $T| (self / x + x) >> 1;
308-
fixpoint(guess(*self), next)
311+
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
312+
let next = |x: $T| (a / x + x) >> 1;
313+
fixpoint(guess(a), next)
314+
}
315+
go(*self)
309316
}
310317

318+
#[inline]
311319
fn cbrt(&self) -> Self {
312-
if bits::<$T>() > 64 {
313-
// 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
314-
return if *self <= core::u64::MAX as $T {
315-
(*self as u64).cbrt() as $T
316-
} else {
317-
let lo = (self >> 3u32).cbrt() << 1;
318-
let hi = lo + 1;
319-
if hi * hi * hi <= *self {
320-
hi
320+
fn go(a: $T) -> $T {
321+
if bits::<$T>() > 64 {
322+
// 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
323+
return if a <= core::u64::MAX as $T {
324+
(a as u64).cbrt() as $T
321325
} else {
322-
lo
323-
}
324-
};
325-
}
326+
let lo = (a >> 3u32).cbrt() << 1;
327+
let hi = lo + 1;
328+
if hi * hi * hi <= a {
329+
hi
330+
} else {
331+
lo
332+
}
333+
};
334+
}
326335

327-
if bits::<$T>() <= 32 {
328-
// Implementation based on Hacker's Delight `icbrt2`
329-
let mut x = *self;
330-
let mut y2 = 0;
331-
let mut y = 0;
332-
let smax = bits::<$T>() / 3;
333-
for s in (0..smax + 1).rev() {
334-
let s = s * 3;
335-
y2 *= 4;
336-
y *= 2;
337-
let b = 3 * (y2 + y) + 1;
338-
if x >> s >= b {
339-
x -= b << s;
340-
y2 += 2 * y + 1;
341-
y += 1;
336+
if bits::<$T>() <= 32 {
337+
// Implementation based on Hacker's Delight `icbrt2`
338+
let mut x = a ;
339+
let mut y2 = 0;
340+
let mut y = 0;
341+
let smax = bits::<$T>() / 3;
342+
for s in (0..smax + 1).rev() {
343+
let s = s * 3;
344+
y2 *= 4;
345+
y *= 2;
346+
let b = 3 * (y2 + y) + 1;
347+
if x >> s >= b {
348+
x -= b << s;
349+
y2 += 2 * y + 1;
350+
y += 1;
351+
}
342352
}
353+
return y;
343354
}
344-
return y;
345-
}
346355

347-
if *self < 8 {
348-
return (*self > 0) as Self;
349-
}
350-
if *self <= core::u32::MAX as $T {
351-
return (*self as u32).cbrt() as $T;
352-
}
356+
if a < 8 {
357+
return (a > 0) as $T;
358+
}
359+
if a <= core::u32::MAX as $T {
360+
return (a as u32).cbrt() as $T;
361+
}
353362

354-
#[cfg(feature = "std")]
355-
#[inline]
356-
fn guess(x: $T) -> $T {
357-
(x as f64).cbrt() as $T
358-
}
363+
#[cfg(feature = "std")]
364+
#[inline]
365+
fn guess(x: $T) -> $T {
366+
(x as f64).cbrt() as $T
367+
}
359368

360-
#[cfg(not(feature = "std"))]
361-
#[inline]
362-
fn guess(x: $T) -> $T {
363-
1 << ((log2(x) + 2) / 3)
364-
}
369+
#[cfg(not(feature = "std"))]
370+
#[inline]
371+
fn guess(x: $T) -> $T {
372+
1 << ((log2(x) + 2) / 3)
373+
}
365374

366-
// https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
367-
let next = |x: $T| (self / (x * x) + x * 2) / 3;
368-
fixpoint(guess(*self), next)
375+
// https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
376+
let next = |x: $T| (a / (x * x) + x * 2) / 3;
377+
fixpoint(guess(a), next)
378+
}
379+
go(*self)
369380
}
370381
}
371382
};

0 commit comments

Comments
 (0)