001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.math;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.math.MathPreconditions.checkNoOverflow;
022import static com.google.common.math.MathPreconditions.checkNonNegative;
023import static com.google.common.math.MathPreconditions.checkPositive;
024import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
025import static java.lang.Math.abs;
026import static java.lang.Math.min;
027import static java.math.RoundingMode.HALF_EVEN;
028import static java.math.RoundingMode.HALF_UP;
029
030import com.google.common.annotations.GwtCompatible;
031import com.google.common.annotations.GwtIncompatible;
032import com.google.common.annotations.VisibleForTesting;
033import com.google.common.primitives.UnsignedLongs;
034
035import java.math.BigInteger;
036import java.math.RoundingMode;
037
038/**
039 * A class for arithmetic on values of type {@code long}. Where possible, methods are defined and
040 * named analogously to their {@code BigInteger} counterparts.
041 *
042 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
043 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
044 *
045 * <p>Similar functionality for {@code int} and for {@link BigInteger} can be found in
046 * {@link IntMath} and {@link BigIntegerMath} respectively.  For other common operations on
047 * {@code long} values, see {@link com.google.common.primitives.Longs}.
048 *
049 * @author Louis Wasserman
050 * @since 11.0
051 */
052@GwtCompatible(emulated = true)
053public final class LongMath {
054  // NOTE: Whenever both tests are cheap and functional, it's faster to use &, | instead of &&, ||
055
056  /**
057   * Returns {@code true} if {@code x} represents a power of two.
058   *
059   * <p>This differs from {@code Long.bitCount(x) == 1}, because
060   * {@code Long.bitCount(Long.MIN_VALUE) == 1}, but {@link Long#MIN_VALUE} is not a power of two.
061   */
062  public static boolean isPowerOfTwo(long x) {
063    return x > 0 & (x & (x - 1)) == 0;
064  }
065
066  /**
067   * Returns 1 if {@code x < y} as unsigned longs, and 0 otherwise.  Assumes that x - y fits into a
068   * signed long.  The implementation is branch-free, and benchmarks suggest it is measurably
069   * faster than the straightforward ternary expression.
070   */
071  @VisibleForTesting
072  static int lessThanBranchFree(long x, long y) {
073    // Returns the sign bit of x - y.
074    return (int) (~~(x - y) >>> (Long.SIZE - 1));
075  }
076
077  /**
078   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
079   *
080   * @throws IllegalArgumentException if {@code x <= 0}
081   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
082   *         is not a power of two
083   */
084  @SuppressWarnings("fallthrough")
085  // TODO(kevinb): remove after this warning is disabled globally
086  public static int log2(long x, RoundingMode mode) {
087    checkPositive("x", x);
088    switch (mode) {
089      case UNNECESSARY:
090        checkRoundingUnnecessary(isPowerOfTwo(x));
091        // fall through
092      case DOWN:
093      case FLOOR:
094        return (Long.SIZE - 1) - Long.numberOfLeadingZeros(x);
095
096      case UP:
097      case CEILING:
098        return Long.SIZE - Long.numberOfLeadingZeros(x - 1);
099
100      case HALF_DOWN:
101      case HALF_UP:
102      case HALF_EVEN:
103        // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
104        int leadingZeros = Long.numberOfLeadingZeros(x);
105        long cmp = MAX_POWER_OF_SQRT2_UNSIGNED >>> leadingZeros;
106        // floor(2^(logFloor + 0.5))
107        int logFloor = (Long.SIZE - 1) - leadingZeros;
108        return logFloor + lessThanBranchFree(cmp, x);
109
110      default:
111        throw new AssertionError("impossible");
112    }
113  }
114
115  /** The biggest half power of two that fits into an unsigned long */
116  @VisibleForTesting static final long MAX_POWER_OF_SQRT2_UNSIGNED = 0xB504F333F9DE6484L;
117
118  /**
119   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
120   *
121   * @throws IllegalArgumentException if {@code x <= 0}
122   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
123   *         is not a power of ten
124   */
125  @GwtIncompatible("TODO")
126  @SuppressWarnings("fallthrough")
127  // TODO(kevinb): remove after this warning is disabled globally
128  public static int log10(long x, RoundingMode mode) {
129    checkPositive("x", x);
130    int logFloor = log10Floor(x);
131    long floorPow = powersOf10[logFloor];
132    switch (mode) {
133      case UNNECESSARY:
134        checkRoundingUnnecessary(x == floorPow);
135        // fall through
136      case FLOOR:
137      case DOWN:
138        return logFloor;
139      case CEILING:
140      case UP:
141        return logFloor + lessThanBranchFree(floorPow, x);
142      case HALF_DOWN:
143      case HALF_UP:
144      case HALF_EVEN:
145        // sqrt(10) is irrational, so log10(x)-logFloor is never exactly 0.5
146        return logFloor + lessThanBranchFree(halfPowersOf10[logFloor], x);
147      default:
148        throw new AssertionError();
149    }
150  }
151
152  @GwtIncompatible("TODO")
153  static int log10Floor(long x) {
154    /*
155     * Based on Hacker's Delight Fig. 11-5, the two-table-lookup, branch-free implementation.
156     *
157     * The key idea is that based on the number of leading zeros (equivalently, floor(log2(x))),
158     * we can narrow the possible floor(log10(x)) values to two.  For example, if floor(log2(x))
159     * is 6, then 64 <= x < 128, so floor(log10(x)) is either 1 or 2.
160     */
161    int y = maxLog10ForLeadingZeros[Long.numberOfLeadingZeros(x)];
162    /*
163     * y is the higher of the two possible values of floor(log10(x)). If x < 10^y, then we want the
164     * lower of the two possible values, or y - 1, otherwise, we want y.
165     */
166    return y - lessThanBranchFree(x, powersOf10[y]);
167  }
168
169  // maxLog10ForLeadingZeros[i] == floor(log10(2^(Long.SIZE - i)))
170  @VisibleForTesting static final byte[] maxLog10ForLeadingZeros = {
171      19, 18, 18, 18, 18, 17, 17, 17, 16, 16, 16, 15, 15, 15, 15, 14, 14, 14, 13, 13, 13, 12, 12,
172      12, 12, 11, 11, 11, 10, 10, 10, 9, 9, 9, 9, 8, 8, 8, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 4, 4, 4,
173      3, 3, 3, 3, 2, 2, 2, 1, 1, 1, 0, 0, 0 };
174
175  @GwtIncompatible("TODO")
176  @VisibleForTesting
177  static final long[] powersOf10 = {
178    1L,
179    10L,
180    100L,
181    1000L,
182    10000L,
183    100000L,
184    1000000L,
185    10000000L,
186    100000000L,
187    1000000000L,
188    10000000000L,
189    100000000000L,
190    1000000000000L,
191    10000000000000L,
192    100000000000000L,
193    1000000000000000L,
194    10000000000000000L,
195    100000000000000000L,
196    1000000000000000000L
197  };
198
199  // halfPowersOf10[i] = largest long less than 10^(i + 0.5)
200  @GwtIncompatible("TODO")
201  @VisibleForTesting
202  static final long[] halfPowersOf10 = {
203    3L,
204    31L,
205    316L,
206    3162L,
207    31622L,
208    316227L,
209    3162277L,
210    31622776L,
211    316227766L,
212    3162277660L,
213    31622776601L,
214    316227766016L,
215    3162277660168L,
216    31622776601683L,
217    316227766016837L,
218    3162277660168379L,
219    31622776601683793L,
220    316227766016837933L,
221    3162277660168379331L
222  };
223
224  /**
225   * Returns {@code b} to the {@code k}th power. Even if the result overflows, it will be equal to
226   * {@code BigInteger.valueOf(b).pow(k).longValue()}. This implementation runs in {@code O(log k)}
227   * time.
228   *
229   * @throws IllegalArgumentException if {@code k < 0}
230   */
231  @GwtIncompatible("TODO")
232  public static long pow(long b, int k) {
233    checkNonNegative("exponent", k);
234    if (-2 <= b && b <= 2) {
235      switch ((int) b) {
236        case 0:
237          return (k == 0) ? 1 : 0;
238        case 1:
239          return 1;
240        case (-1):
241          return ((k & 1) == 0) ? 1 : -1;
242        case 2:
243          return (k < Long.SIZE) ? 1L << k : 0;
244        case (-2):
245          if (k < Long.SIZE) {
246            return ((k & 1) == 0) ? 1L << k : -(1L << k);
247          } else {
248            return 0;
249          }
250        default:
251          throw new AssertionError();
252      }
253    }
254    for (long accum = 1;; k >>= 1) {
255      switch (k) {
256        case 0:
257          return accum;
258        case 1:
259          return accum * b;
260        default:
261          accum *= ((k & 1) == 0) ? 1 : b;
262          b *= b;
263      }
264    }
265  }
266
267  /**
268   * Returns the square root of {@code x}, rounded with the specified rounding mode.
269   *
270   * @throws IllegalArgumentException if {@code x < 0}
271   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
272   *         {@code sqrt(x)} is not an integer
273   */
274  @GwtIncompatible("TODO")
275  @SuppressWarnings("fallthrough")
276  public static long sqrt(long x, RoundingMode mode) {
277    checkNonNegative("x", x);
278    if (fitsInInt(x)) {
279      return IntMath.sqrt((int) x, mode);
280    }
281    /*
282     * Let k be the true value of floor(sqrt(x)), so that
283     *
284     *            k * k <= x          <  (k + 1) * (k + 1)
285     * (double) (k * k) <= (double) x <= (double) ((k + 1) * (k + 1))
286     *          since casting to double is nondecreasing.
287     *          Note that the right-hand inequality is no longer strict.
288     * Math.sqrt(k * k) <= Math.sqrt(x) <= Math.sqrt((k + 1) * (k + 1))
289     *          since Math.sqrt is monotonic.
290     * (long) Math.sqrt(k * k) <= (long) Math.sqrt(x) <= (long) Math.sqrt((k + 1) * (k + 1))
291     *          since casting to long is monotonic
292     * k <= (long) Math.sqrt(x) <= k + 1
293     *          since (long) Math.sqrt(k * k) == k, as checked exhaustively in
294     *          {@link LongMathTest#testSqrtOfPerfectSquareAsDoubleIsPerfect}
295     */
296    long guess = (long) Math.sqrt(x);
297    // Note: guess is always <= FLOOR_SQRT_MAX_LONG.
298    long guessSquared = guess * guess;
299    // Note (2013-2-26): benchmarks indicate that, inscrutably enough, using if statements is
300    // faster here than using lessThanBranchFree.
301    switch (mode) {
302      case UNNECESSARY:
303        checkRoundingUnnecessary(guessSquared == x);
304        return guess;
305      case FLOOR:
306      case DOWN:
307        if (x < guessSquared) {
308          return guess - 1;
309        }
310        return guess;
311      case CEILING:
312      case UP:
313        if (x > guessSquared) {
314          return guess + 1;
315        }
316        return guess;
317      case HALF_DOWN:
318      case HALF_UP:
319      case HALF_EVEN:
320        long sqrtFloor = guess - ((x < guessSquared) ? 1 : 0);
321        long halfSquare = sqrtFloor * sqrtFloor + sqrtFloor;
322        /*
323         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
324         * x and halfSquare are integers, this is equivalent to testing whether or not x <=
325         * halfSquare. (We have to deal with overflow, though.)
326         *
327         * If we treat halfSquare as an unsigned long, we know that
328         *            sqrtFloor^2 <= x < (sqrtFloor + 1)^2
329         * halfSquare - sqrtFloor <= x < halfSquare + sqrtFloor + 1
330         * so |x - halfSquare| <= sqrtFloor.  Therefore, it's safe to treat x - halfSquare as a
331         * signed long, so lessThanBranchFree is safe for use.
332         */
333        return sqrtFloor + lessThanBranchFree(halfSquare, x);
334      default:
335        throw new AssertionError();
336    }
337  }
338
339  /**
340   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
341   * {@code RoundingMode}.
342   *
343   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
344   *         is not an integer multiple of {@code b}
345   */
346  @GwtIncompatible("TODO")
347  @SuppressWarnings("fallthrough")
348  public static long divide(long p, long q, RoundingMode mode) {
349    checkNotNull(mode);
350    long div = p / q; // throws if q == 0
351    long rem = p - q * div; // equals p % q
352
353    if (rem == 0) {
354      return div;
355    }
356
357    /*
358     * Normal Java division rounds towards 0, consistently with RoundingMode.DOWN. We just have to
359     * deal with the cases where rounding towards 0 is wrong, which typically depends on the sign of
360     * p / q.
361     *
362     * signum is 1 if p and q are both nonnegative or both negative, and -1 otherwise.
363     */
364    int signum = 1 | (int) ((p ^ q) >> (Long.SIZE - 1));
365    boolean increment;
366    switch (mode) {
367      case UNNECESSARY:
368        checkRoundingUnnecessary(rem == 0);
369        // fall through
370      case DOWN:
371        increment = false;
372        break;
373      case UP:
374        increment = true;
375        break;
376      case CEILING:
377        increment = signum > 0;
378        break;
379      case FLOOR:
380        increment = signum < 0;
381        break;
382      case HALF_EVEN:
383      case HALF_DOWN:
384      case HALF_UP:
385        long absRem = abs(rem);
386        long cmpRemToHalfDivisor = absRem - (abs(q) - absRem);
387        // subtracting two nonnegative longs can't overflow
388        // cmpRemToHalfDivisor has the same sign as compare(abs(rem), abs(q) / 2).
389        if (cmpRemToHalfDivisor == 0) { // exactly on the half mark
390          increment = (mode == HALF_UP | (mode == HALF_EVEN & (div & 1) != 0));
391        } else {
392          increment = cmpRemToHalfDivisor > 0; // closer to the UP value
393        }
394        break;
395      default:
396        throw new AssertionError();
397    }
398    return increment ? div + signum : div;
399  }
400
401  /**
402   * Returns {@code x mod m}, a non-negative value less than {@code m}.
403   * This differs from {@code x % m}, which might be negative.
404   *
405   * <p>For example:
406   *
407   * <pre> {@code
408   *
409   * mod(7, 4) == 3
410   * mod(-7, 4) == 1
411   * mod(-1, 4) == 3
412   * mod(-8, 4) == 0
413   * mod(8, 4) == 0}</pre>
414   *
415   * @throws ArithmeticException if {@code m <= 0}
416   * @see <a href="http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.17.3">
417   *      Remainder Operator</a>
418   */
419  @GwtIncompatible("TODO")
420  public static int mod(long x, int m) {
421    // Cast is safe because the result is guaranteed in the range [0, m)
422    return (int) mod(x, (long) m);
423  }
424
425  /**
426   * Returns {@code x mod m}, a non-negative value less than {@code m}.
427   * This differs from {@code x % m}, which might be negative.
428   *
429   * <p>For example:
430   *
431   * <pre> {@code
432   *
433   * mod(7, 4) == 3
434   * mod(-7, 4) == 1
435   * mod(-1, 4) == 3
436   * mod(-8, 4) == 0
437   * mod(8, 4) == 0}</pre>
438   *
439   * @throws ArithmeticException if {@code m <= 0}
440   * @see <a href="http://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.17.3">
441   *      Remainder Operator</a>
442   */
443  @GwtIncompatible("TODO")
444  public static long mod(long x, long m) {
445    if (m <= 0) {
446      throw new ArithmeticException("Modulus must be positive");
447    }
448    long result = x % m;
449    return (result >= 0) ? result : result + m;
450  }
451
452  /**
453   * Returns the greatest common divisor of {@code a, b}. Returns {@code 0} if
454   * {@code a == 0 && b == 0}.
455   *
456   * @throws IllegalArgumentException if {@code a < 0} or {@code b < 0}
457   */
458  public static long gcd(long a, long b) {
459    /*
460     * The reason we require both arguments to be >= 0 is because otherwise, what do you return on
461     * gcd(0, Long.MIN_VALUE)? BigInteger.gcd would return positive 2^63, but positive 2^63 isn't
462     * an int.
463     */
464    checkNonNegative("a", a);
465    checkNonNegative("b", b);
466    if (a == 0) {
467      // 0 % b == 0, so b divides a, but the converse doesn't hold.
468      // BigInteger.gcd is consistent with this decision.
469      return b;
470    } else if (b == 0) {
471      return a; // similar logic
472    }
473    /*
474     * Uses the binary GCD algorithm; see http://en.wikipedia.org/wiki/Binary_GCD_algorithm.
475     * This is >60% faster than the Euclidean algorithm in benchmarks.
476     */
477    int aTwos = Long.numberOfTrailingZeros(a);
478    a >>= aTwos; // divide out all 2s
479    int bTwos = Long.numberOfTrailingZeros(b);
480    b >>= bTwos; // divide out all 2s
481    while (a != b) { // both a, b are odd
482      // The key to the binary GCD algorithm is as follows:
483      // Both a and b are odd.  Assume a > b; then gcd(a - b, b) = gcd(a, b).
484      // But in gcd(a - b, b), a - b is even and b is odd, so we can divide out powers of two.
485
486      // We bend over backwards to avoid branching, adapting a technique from
487      // http://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
488
489      long delta = a - b; // can't overflow, since a and b are nonnegative
490
491      long minDeltaOrZero = delta & (delta >> (Long.SIZE - 1));
492      // equivalent to Math.min(delta, 0)
493
494      a = delta - minDeltaOrZero - minDeltaOrZero; // sets a to Math.abs(a - b)
495      // a is now nonnegative and even
496
497      b += minDeltaOrZero; // sets b to min(old a, b)
498      a >>= Long.numberOfTrailingZeros(a); // divide out all 2s, since 2 doesn't divide b
499    }
500    return a << min(aTwos, bTwos);
501  }
502
503  /**
504   * Returns the sum of {@code a} and {@code b}, provided it does not overflow.
505   *
506   * @throws ArithmeticException if {@code a + b} overflows in signed {@code long} arithmetic
507   */
508  @GwtIncompatible("TODO")
509  public static long checkedAdd(long a, long b) {
510    long result = a + b;
511    checkNoOverflow((a ^ b) < 0 | (a ^ result) >= 0);
512    return result;
513  }
514
515  /**
516   * Returns the difference of {@code a} and {@code b}, provided it does not overflow.
517   *
518   * @throws ArithmeticException if {@code a - b} overflows in signed {@code long} arithmetic
519   */
520  @GwtIncompatible("TODO")
521  public static long checkedSubtract(long a, long b) {
522    long result = a - b;
523    checkNoOverflow((a ^ b) >= 0 | (a ^ result) >= 0);
524    return result;
525  }
526
527  /**
528   * Returns the product of {@code a} and {@code b}, provided it does not overflow.
529   *
530   * @throws ArithmeticException if {@code a * b} overflows in signed {@code long} arithmetic
531   */
532  @GwtIncompatible("TODO")
533  public static long checkedMultiply(long a, long b) {
534    // Hacker's Delight, Section 2-12
535    int leadingZeros = Long.numberOfLeadingZeros(a) + Long.numberOfLeadingZeros(~a)
536        + Long.numberOfLeadingZeros(b) + Long.numberOfLeadingZeros(~b);
537    /*
538     * If leadingZeros > Long.SIZE + 1 it's definitely fine, if it's < Long.SIZE it's definitely
539     * bad. We do the leadingZeros check to avoid the division below if at all possible.
540     *
541     * Otherwise, if b == Long.MIN_VALUE, then the only allowed values of a are 0 and 1. We take
542     * care of all a < 0 with their own check, because in particular, the case a == -1 will
543     * incorrectly pass the division check below.
544     *
545     * In all other cases, we check that either a is 0 or the result is consistent with division.
546     */
547    if (leadingZeros > Long.SIZE + 1) {
548      return a * b;
549    }
550    checkNoOverflow(leadingZeros >= Long.SIZE);
551    checkNoOverflow(a >= 0 | b != Long.MIN_VALUE);
552    long result = a * b;
553    checkNoOverflow(a == 0 || result / a == b);
554    return result;
555  }
556
557  /**
558   * Returns the {@code b} to the {@code k}th power, provided it does not overflow.
559   *
560   * @throws ArithmeticException if {@code b} to the {@code k}th power overflows in signed
561   *         {@code long} arithmetic
562   */
563  @GwtIncompatible("TODO")
564  public static long checkedPow(long b, int k) {
565    checkNonNegative("exponent", k);
566    if (b >= -2 & b <= 2) {
567      switch ((int) b) {
568        case 0:
569          return (k == 0) ? 1 : 0;
570        case 1:
571          return 1;
572        case (-1):
573          return ((k & 1) == 0) ? 1 : -1;
574        case 2:
575          checkNoOverflow(k < Long.SIZE - 1);
576          return 1L << k;
577        case (-2):
578          checkNoOverflow(k < Long.SIZE);
579          return ((k & 1) == 0) ? (1L << k) : (-1L << k);
580        default:
581          throw new AssertionError();
582      }
583    }
584    long accum = 1;
585    while (true) {
586      switch (k) {
587        case 0:
588          return accum;
589        case 1:
590          return checkedMultiply(accum, b);
591        default:
592          if ((k & 1) != 0) {
593            accum = checkedMultiply(accum, b);
594          }
595          k >>= 1;
596          if (k > 0) {
597            checkNoOverflow(-FLOOR_SQRT_MAX_LONG <= b && b <= FLOOR_SQRT_MAX_LONG);
598            b *= b;
599          }
600      }
601    }
602  }
603
604  @VisibleForTesting static final long FLOOR_SQRT_MAX_LONG = 3037000499L;
605
606  /**
607   * Returns {@code n!}, that is, the product of the first {@code n} positive
608   * integers, {@code 1} if {@code n == 0}, or {@link Long#MAX_VALUE} if the
609   * result does not fit in a {@code long}.
610   *
611   * @throws IllegalArgumentException if {@code n < 0}
612   */
613  @GwtIncompatible("TODO")
614  public static long factorial(int n) {
615    checkNonNegative("n", n);
616    return (n < factorials.length) ? factorials[n] : Long.MAX_VALUE;
617  }
618
619  static final long[] factorials = {
620      1L,
621      1L,
622      1L * 2,
623      1L * 2 * 3,
624      1L * 2 * 3 * 4,
625      1L * 2 * 3 * 4 * 5,
626      1L * 2 * 3 * 4 * 5 * 6,
627      1L * 2 * 3 * 4 * 5 * 6 * 7,
628      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8,
629      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9,
630      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10,
631      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11,
632      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12,
633      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13,
634      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14,
635      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15,
636      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16,
637      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17,
638      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18,
639      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19,
640      1L * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10 * 11 * 12 * 13 * 14 * 15 * 16 * 17 * 18 * 19 * 20
641  };
642
643  /**
644   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
645   * {@code k}, or {@link Long#MAX_VALUE} if the result does not fit in a {@code long}.
646   *
647   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
648   */
649  public static long binomial(int n, int k) {
650    checkNonNegative("n", n);
651    checkNonNegative("k", k);
652    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
653    if (k > (n >> 1)) {
654      k = n - k;
655    }
656    switch (k) {
657      case 0:
658        return 1;
659      case 1:
660        return n;
661      default:
662        if (n < factorials.length) {
663          return factorials[n] / (factorials[k] * factorials[n - k]);
664        } else if (k >= biggestBinomials.length || n > biggestBinomials[k]) {
665          return Long.MAX_VALUE;
666        } else if (k < biggestSimpleBinomials.length && n <= biggestSimpleBinomials[k]) {
667          // guaranteed not to overflow
668          long result = n--;
669          for (int i = 2; i <= k; n--, i++) {
670            result *= n;
671            result /= i;
672          }
673          return result;
674        } else {
675          int nBits = LongMath.log2(n, RoundingMode.CEILING);
676
677          long result = 1;
678          long numerator = n--;
679          long denominator = 1;
680
681          int numeratorBits = nBits;
682          // This is an upper bound on log2(numerator, ceiling).
683
684          /*
685           * We want to do this in long math for speed, but want to avoid overflow. We adapt the
686           * technique previously used by BigIntegerMath: maintain separate numerator and
687           * denominator accumulators, multiplying the fraction into result when near overflow.
688           */
689          for (int i = 2; i <= k; i++, n--) {
690            if (numeratorBits + nBits < Long.SIZE - 1) {
691              // It's definitely safe to multiply into numerator and denominator.
692              numerator *= n;
693              denominator *= i;
694              numeratorBits += nBits;
695            } else {
696              // It might not be safe to multiply into numerator and denominator,
697              // so multiply (numerator / denominator) into result.
698              result = multiplyFraction(result, numerator, denominator);
699              numerator = n;
700              denominator = i;
701              numeratorBits = nBits;
702            }
703          }
704          return multiplyFraction(result, numerator, denominator);
705        }
706    }
707  }
708
709  /**
710   * Returns (x * numerator / denominator), which is assumed to come out to an integral value.
711   */
712  static long multiplyFraction(long x, long numerator, long denominator) {
713    if (x == 1) {
714      return numerator / denominator;
715    }
716    long commonDivisor = gcd(x, denominator);
717    x /= commonDivisor;
718    denominator /= commonDivisor;
719    // We know gcd(x, denominator) = 1, and x * numerator / denominator is exact,
720    // so denominator must be a divisor of numerator.
721    return x * (numerator / denominator);
722  }
723
724  /*
725   * binomial(biggestBinomials[k], k) fits in a long, but not
726   * binomial(biggestBinomials[k] + 1, k).
727   */
728  static final int[] biggestBinomials =
729      {Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 3810779, 121977, 16175, 4337, 1733,
730          887, 534, 361, 265, 206, 169, 143, 125, 111, 101, 94, 88, 83, 79, 76, 74, 72, 70, 69, 68,
731          67, 67, 66, 66, 66, 66};
732
733  /*
734   * binomial(biggestSimpleBinomials[k], k) doesn't need to use the slower GCD-based impl,
735   * but binomial(biggestSimpleBinomials[k] + 1, k) does.
736   */
737  @VisibleForTesting static final int[] biggestSimpleBinomials =
738      {Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, 2642246, 86251, 11724, 3218, 1313,
739          684, 419, 287, 214, 169, 139, 119, 105, 95, 87, 81, 76, 73, 70, 68, 66, 64, 63, 62, 62,
740          61, 61, 61};
741  // These values were generated by using checkedMultiply to see when the simple multiply/divide
742  // algorithm would lead to an overflow.
743
744  static boolean fitsInInt(long x) {
745    return (int) x == x;
746  }
747
748  /**
749   * Returns the arithmetic mean of {@code x} and {@code y}, rounded toward
750   * negative infinity. This method is resilient to overflow.
751   *
752   * @since 14.0
753   */
754  public static long mean(long x, long y) {
755    // Efficient method for computing the arithmetic mean.
756    // The alternative (x + y) / 2 fails for large values.
757    // The alternative (x + y) >>> 1 fails for negative values.
758    return (x & y) + ((x ^ y) >> 1);
759  }
760
761  /*
762   * If n <= millerRabinBases[i][0], then testing n against bases millerRabinBases[i][1..]
763   * suffices to prove its primality.  Values from miller-rabin.appspot.com.
764   *
765   * NOTE: We could get slightly better bases that would be treated as unsigned, but benchmarks
766   * showed negligible performance improvements.
767   */
768  private static final long[][] millerRabinBaseSets = {
769    {291830, 126401071349994536L},
770    {885594168, 725270293939359937L, 3569819667048198375L},
771    {273919523040L, 15, 7363882082L, 992620450144556L},
772    {47636622961200L, 2, 2570940, 211991001, 3749873356L},
773    {7999252175582850L,
774      2, 4130806001517L, 149795463772692060L, 186635894390467037L, 3967304179347715805L},
775    {585226005592931976L,
776      2, 123635709730000L, 9233062284813009L, 43835965440333360L, 761179012939631437L,
777      1263739024124850375L},
778    {Long.MAX_VALUE,
779        2, 325, 9375, 28178, 450775, 9780504, 1795265022}
780  };
781
782  private enum MillerRabinTester {
783    /**
784     * Works for inputs <= FLOOR_SQRT_MAX_LONG.
785     */
786    SMALL {
787      @Override
788      long mulMod(long a, long b, long m) {
789        /*
790         * NOTE(lowasser, 2015-Feb-12): Benchmarks suggest that changing this to
791         * UnsignedLongs.remainder and increasing the threshold to 2^32 doesn't pay for itself,
792         * and adding another enum constant hurts performance further -- I suspect because
793         * bimorphic implementation is a sweet spot for the JVM.
794         */
795        return (a * b) % m;
796      }
797
798      @Override
799      long squareMod(long a, long m) {
800        return (a * a) % m;
801      }
802    },
803    /**
804     * Works for all nonnegative signed longs.
805     */
806    LARGE {
807      /**
808       * Returns (a + b) mod m.  Precondition: 0 <= a, b < m < 2^63.
809       */
810      private long plusMod(long a, long b, long m) {
811        return (a >= m - b) ? (a + b - m) : (a + b);
812      }
813
814      /**
815       * Returns (a * 2^32) mod m.  a may be any unsigned long.
816       */
817      private long times2ToThe32Mod(long a, long m) {
818        int remainingPowersOf2 = 32;
819        do {
820          int shift = Math.min(remainingPowersOf2, Long.numberOfLeadingZeros(a));
821          // shift is either the number of powers of 2 left to multiply a by, or the biggest shift
822          // possible while keeping a in an unsigned long.
823          a = UnsignedLongs.remainder(a << shift, m);
824          remainingPowersOf2 -= shift;
825        } while (remainingPowersOf2 > 0);
826        return a;
827      }
828
829      @Override
830      long mulMod(long a, long b, long m) {
831        long aHi = a >>> 32; // < 2^31
832        long bHi = b >>> 32; // < 2^31
833        long aLo = a & 0xFFFFFFFFL; // < 2^32
834        long bLo = b & 0xFFFFFFFFL; // < 2^32
835
836        /*
837         * a * b == aHi * bHi * 2^64 + (aHi * bLo + aLo * bHi) * 2^63 + aLo * bLo.
838         *       == (aHi * bHi * 2^32 + aHi * bLo + aLo * bHi) * 2^32 + aLo * bLo
839         *
840         * We carry out this computation in modular arithmetic.  Since times2ToThe32Mod accepts
841         * any unsigned long, we don't have to do a mod on every operation, only when intermediate
842         * results can exceed 2^63.
843         */
844        long result = times2ToThe32Mod(aHi * bHi /* < 2^62 */, m); // < m < 2^63
845        result += aHi * bLo; // aHi * bLo < 2^63, result < 2^64
846        if (result < 0) {
847          result = UnsignedLongs.remainder(result, m);
848        }
849        // result < 2^63 again
850        result += aLo * bHi; // aLo * bHi < 2^63, result < 2^64
851        result = times2ToThe32Mod(result, m); // result < m < 2^63
852        return plusMod(
853            result,
854            UnsignedLongs.remainder(aLo * bLo /* < 2^64 */, m),
855            m);
856      }
857
858      @Override
859      long squareMod(long a, long m) {
860        long aHi = a >>> 32; // < 2^31
861        long aLo = a & 0xFFFFFFFFL; // < 2^32
862
863        /*
864         * a^2 == aHi^2 * 2^64 + aHi * aLo * 2^33 + aLo^2
865         *     == (aHi^2 * 2^32 + aHi * aLo * 2) * 2^32 + aLo^2
866         * We carry out this computation in modular arithmetic.  Since times2ToThe32Mod accepts
867         * any unsigned long, we don't have to do a mod on every operation, only when intermediate
868         * results can exceed 2^63.
869         */
870        long result = times2ToThe32Mod(aHi * aHi /* < 2^62 */, m); // < m < 2^63
871        long hiLo = aHi * aLo * 2;
872        if (hiLo < 0) {
873          hiLo = UnsignedLongs.remainder(hiLo, m);
874        }
875        // hiLo < 2^63
876        result += hiLo; // result < 2^64
877        result = times2ToThe32Mod(result, m); // result < m < 2^63
878        return plusMod(
879            result,
880            UnsignedLongs.remainder(aLo * aLo /* < 2^64 */, m),
881            m);
882      }
883    };
884
885    static boolean test(long base, long n) {
886      // Since base will be considered % n, it's okay if base > FLOOR_SQRT_MAX_LONG,
887      // so long as n <= FLOOR_SQRT_MAX_LONG.
888      return ((n <= FLOOR_SQRT_MAX_LONG) ? SMALL : LARGE).testWitness(base, n);
889    }
890
891    /**
892     * Returns a * b mod m.
893     */
894    abstract long mulMod(long a, long b, long m);
895
896    /**
897     * Returns a^2 mod m.
898     */
899    abstract long squareMod(long a, long m);
900
901    /**
902     * Returns a^p mod m.
903     */
904    private long powMod(long a, long p, long m) {
905      long res = 1;
906      for (; p != 0; p >>= 1) {
907        if ((p & 1) != 0) {
908          res = mulMod(res, a, m);
909        }
910        a = squareMod(a, m);
911      }
912      return res;
913    }
914
915    /**
916     * Returns true if n is a strong probable prime relative to the specified base.
917     */
918    private boolean testWitness(long base, long n) {
919      int r = Long.numberOfTrailingZeros(n - 1);
920      long d = (n - 1) >> r;
921      base %= n;
922      if (base == 0) {
923        return true;
924      }
925      // Calculate a := base^d mod n.
926      long a = powMod(base, d, n);
927      // n passes this test if
928      //    base^d = 1 (mod n)
929      // or base^(2^j * d) = -1 (mod n) for some 0 <= j < r.
930      if (a == 1) {
931        return true;
932      }
933      int j = 0;
934      while (a != n - 1) {
935        if (++j == r) {
936          return false;
937        }
938        a = squareMod(a, n);
939      }
940      return true;
941    }
942  }
943
944  private LongMath() {}
945}