001/*
002 * Copyright (C) 2007 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkState;
021import static com.google.common.collect.CollectPreconditions.checkNonnegative;
022import static com.google.common.collect.CollectPreconditions.checkRemove;
023
024import com.google.common.annotations.GwtCompatible;
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.base.MoreObjects;
027import com.google.common.primitives.Ints;
028import com.google.errorprone.annotations.CanIgnoreReturnValue;
029import java.io.IOException;
030import java.io.ObjectInputStream;
031import java.io.ObjectOutputStream;
032import java.io.Serializable;
033import java.util.Comparator;
034import java.util.ConcurrentModificationException;
035import java.util.Iterator;
036import java.util.NoSuchElementException;
037import javax.annotation.Nullable;
038
039/**
040 * A multiset which maintains the ordering of its elements, according to either their natural order
041 * or an explicit {@link Comparator}. In all cases, this implementation uses
042 * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
043 * determine equivalence of instances.
044 *
045 * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
046 * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
047 * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
048 *
049 * <p>See the Guava User Guide article on <a href=
050 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">
051 * {@code Multiset}</a>.
052 *
053 * @author Louis Wasserman
054 * @author Jared Levy
055 * @since 2.0
056 */
057@GwtCompatible(emulated = true)
058public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
059
060  /**
061   * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
062   * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
063   * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
064   * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
065   * user attempts to add an element to the multiset that violates this constraint (for example,
066   * the user attempts to add a string element to a set whose elements are integers), the
067   * {@code add(Object)} call will throw a {@code ClassCastException}.
068   *
069   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
070   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
071   */
072  public static <E extends Comparable> TreeMultiset<E> create() {
073    return new TreeMultiset<E>(Ordering.natural());
074  }
075
076  /**
077   * Creates a new, empty multiset, sorted according to the specified comparator. All elements
078   * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
079   * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
080   * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
081   * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
082   * ClassCastException}.
083   *
084   * @param comparator the comparator that will be used to sort this multiset. A null value
085   *     indicates that the elements' <i>natural ordering</i> should be used.
086   */
087  @SuppressWarnings("unchecked")
088  public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
089    return (comparator == null)
090        ? new TreeMultiset<E>((Comparator) Ordering.natural())
091        : new TreeMultiset<E>(comparator);
092  }
093
094  /**
095   * Creates an empty multiset containing the given initial elements, sorted according to the
096   * elements' natural order.
097   *
098   * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
099   *
100   * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
101   * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
102   */
103  public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
104    TreeMultiset<E> multiset = create();
105    Iterables.addAll(multiset, elements);
106    return multiset;
107  }
108
109  private final transient Reference<AvlNode<E>> rootReference;
110  private final transient GeneralRange<E> range;
111  private final transient AvlNode<E> header;
112
113  TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
114    super(range.comparator());
115    this.rootReference = rootReference;
116    this.range = range;
117    this.header = endLink;
118  }
119
120  TreeMultiset(Comparator<? super E> comparator) {
121    super(comparator);
122    this.range = GeneralRange.all(comparator);
123    this.header = new AvlNode<E>(null, 1);
124    successor(header, header);
125    this.rootReference = new Reference<>();
126  }
127
128  /**
129   * A function which can be summed across a subtree.
130   */
131  private enum Aggregate {
132    SIZE {
133      @Override
134      int nodeAggregate(AvlNode<?> node) {
135        return node.elemCount;
136      }
137
138      @Override
139      long treeAggregate(@Nullable AvlNode<?> root) {
140        return (root == null) ? 0 : root.totalCount;
141      }
142    },
143    DISTINCT {
144      @Override
145      int nodeAggregate(AvlNode<?> node) {
146        return 1;
147      }
148
149      @Override
150      long treeAggregate(@Nullable AvlNode<?> root) {
151        return (root == null) ? 0 : root.distinctElements;
152      }
153    };
154
155    abstract int nodeAggregate(AvlNode<?> node);
156
157    abstract long treeAggregate(@Nullable AvlNode<?> root);
158  }
159
160  private long aggregateForEntries(Aggregate aggr) {
161    AvlNode<E> root = rootReference.get();
162    long total = aggr.treeAggregate(root);
163    if (range.hasLowerBound()) {
164      total -= aggregateBelowRange(aggr, root);
165    }
166    if (range.hasUpperBound()) {
167      total -= aggregateAboveRange(aggr, root);
168    }
169    return total;
170  }
171
172  private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
173    if (node == null) {
174      return 0;
175    }
176    int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
177    if (cmp < 0) {
178      return aggregateBelowRange(aggr, node.left);
179    } else if (cmp == 0) {
180      switch (range.getLowerBoundType()) {
181        case OPEN:
182          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
183        case CLOSED:
184          return aggr.treeAggregate(node.left);
185        default:
186          throw new AssertionError();
187      }
188    } else {
189      return aggr.treeAggregate(node.left)
190          + aggr.nodeAggregate(node)
191          + aggregateBelowRange(aggr, node.right);
192    }
193  }
194
195  private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
196    if (node == null) {
197      return 0;
198    }
199    int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
200    if (cmp > 0) {
201      return aggregateAboveRange(aggr, node.right);
202    } else if (cmp == 0) {
203      switch (range.getUpperBoundType()) {
204        case OPEN:
205          return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
206        case CLOSED:
207          return aggr.treeAggregate(node.right);
208        default:
209          throw new AssertionError();
210      }
211    } else {
212      return aggr.treeAggregate(node.right)
213          + aggr.nodeAggregate(node)
214          + aggregateAboveRange(aggr, node.left);
215    }
216  }
217
218  @Override
219  public int size() {
220    return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
221  }
222
223  @Override
224  int distinctElements() {
225    return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
226  }
227
228  @Override
229  public int count(@Nullable Object element) {
230    try {
231      @SuppressWarnings("unchecked")
232      E e = (E) element;
233      AvlNode<E> root = rootReference.get();
234      if (!range.contains(e) || root == null) {
235        return 0;
236      }
237      return root.count(comparator(), e);
238    } catch (ClassCastException | NullPointerException e) {
239      return 0;
240    }
241  }
242
243  @CanIgnoreReturnValue
244  @Override
245  public int add(@Nullable E element, int occurrences) {
246    checkNonnegative(occurrences, "occurrences");
247    if (occurrences == 0) {
248      return count(element);
249    }
250    checkArgument(range.contains(element));
251    AvlNode<E> root = rootReference.get();
252    if (root == null) {
253      comparator().compare(element, element);
254      AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
255      successor(header, newRoot, header);
256      rootReference.checkAndSet(root, newRoot);
257      return 0;
258    }
259    int[] result = new int[1]; // used as a mutable int reference to hold result
260    AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
261    rootReference.checkAndSet(root, newRoot);
262    return result[0];
263  }
264
265  @CanIgnoreReturnValue
266  @Override
267  public int remove(@Nullable Object element, int occurrences) {
268    checkNonnegative(occurrences, "occurrences");
269    if (occurrences == 0) {
270      return count(element);
271    }
272    AvlNode<E> root = rootReference.get();
273    int[] result = new int[1]; // used as a mutable int reference to hold result
274    AvlNode<E> newRoot;
275    try {
276      @SuppressWarnings("unchecked")
277      E e = (E) element;
278      if (!range.contains(e) || root == null) {
279        return 0;
280      }
281      newRoot = root.remove(comparator(), e, occurrences, result);
282    } catch (ClassCastException | NullPointerException e) {
283      return 0;
284    }
285    rootReference.checkAndSet(root, newRoot);
286    return result[0];
287  }
288
289  @CanIgnoreReturnValue
290  @Override
291  public int setCount(@Nullable E element, int count) {
292    checkNonnegative(count, "count");
293    if (!range.contains(element)) {
294      checkArgument(count == 0);
295      return 0;
296    }
297
298    AvlNode<E> root = rootReference.get();
299    if (root == null) {
300      if (count > 0) {
301        add(element, count);
302      }
303      return 0;
304    }
305    int[] result = new int[1]; // used as a mutable int reference to hold result
306    AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
307    rootReference.checkAndSet(root, newRoot);
308    return result[0];
309  }
310
311  @CanIgnoreReturnValue
312  @Override
313  public boolean setCount(@Nullable E element, int oldCount, int newCount) {
314    checkNonnegative(newCount, "newCount");
315    checkNonnegative(oldCount, "oldCount");
316    checkArgument(range.contains(element));
317
318    AvlNode<E> root = rootReference.get();
319    if (root == null) {
320      if (oldCount == 0) {
321        if (newCount > 0) {
322          add(element, newCount);
323        }
324        return true;
325      } else {
326        return false;
327      }
328    }
329    int[] result = new int[1]; // used as a mutable int reference to hold result
330    AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
331    rootReference.checkAndSet(root, newRoot);
332    return result[0] == oldCount;
333  }
334
335  private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
336    return new Multisets.AbstractEntry<E>() {
337      @Override
338      public E getElement() {
339        return baseEntry.getElement();
340      }
341
342      @Override
343      public int getCount() {
344        int result = baseEntry.getCount();
345        if (result == 0) {
346          return count(getElement());
347        } else {
348          return result;
349        }
350      }
351    };
352  }
353
354  /**
355   * Returns the first node in the tree that is in range.
356   */
357  @Nullable
358  private AvlNode<E> firstNode() {
359    AvlNode<E> root = rootReference.get();
360    if (root == null) {
361      return null;
362    }
363    AvlNode<E> node;
364    if (range.hasLowerBound()) {
365      E endpoint = range.getLowerEndpoint();
366      node = rootReference.get().ceiling(comparator(), endpoint);
367      if (node == null) {
368        return null;
369      }
370      if (range.getLowerBoundType() == BoundType.OPEN
371          && comparator().compare(endpoint, node.getElement()) == 0) {
372        node = node.succ;
373      }
374    } else {
375      node = header.succ;
376    }
377    return (node == header || !range.contains(node.getElement())) ? null : node;
378  }
379
380  @Nullable
381  private AvlNode<E> lastNode() {
382    AvlNode<E> root = rootReference.get();
383    if (root == null) {
384      return null;
385    }
386    AvlNode<E> node;
387    if (range.hasUpperBound()) {
388      E endpoint = range.getUpperEndpoint();
389      node = rootReference.get().floor(comparator(), endpoint);
390      if (node == null) {
391        return null;
392      }
393      if (range.getUpperBoundType() == BoundType.OPEN
394          && comparator().compare(endpoint, node.getElement()) == 0) {
395        node = node.pred;
396      }
397    } else {
398      node = header.pred;
399    }
400    return (node == header || !range.contains(node.getElement())) ? null : node;
401  }
402
403  @Override
404  Iterator<Entry<E>> entryIterator() {
405    return new Iterator<Entry<E>>() {
406      AvlNode<E> current = firstNode();
407      Entry<E> prevEntry;
408
409      @Override
410      public boolean hasNext() {
411        if (current == null) {
412          return false;
413        } else if (range.tooHigh(current.getElement())) {
414          current = null;
415          return false;
416        } else {
417          return true;
418        }
419      }
420
421      @Override
422      public Entry<E> next() {
423        if (!hasNext()) {
424          throw new NoSuchElementException();
425        }
426        Entry<E> result = wrapEntry(current);
427        prevEntry = result;
428        if (current.succ == header) {
429          current = null;
430        } else {
431          current = current.succ;
432        }
433        return result;
434      }
435
436      @Override
437      public void remove() {
438        checkRemove(prevEntry != null);
439        setCount(prevEntry.getElement(), 0);
440        prevEntry = null;
441      }
442    };
443  }
444
445  @Override
446  Iterator<Entry<E>> descendingEntryIterator() {
447    return new Iterator<Entry<E>>() {
448      AvlNode<E> current = lastNode();
449      Entry<E> prevEntry = null;
450
451      @Override
452      public boolean hasNext() {
453        if (current == null) {
454          return false;
455        } else if (range.tooLow(current.getElement())) {
456          current = null;
457          return false;
458        } else {
459          return true;
460        }
461      }
462
463      @Override
464      public Entry<E> next() {
465        if (!hasNext()) {
466          throw new NoSuchElementException();
467        }
468        Entry<E> result = wrapEntry(current);
469        prevEntry = result;
470        if (current.pred == header) {
471          current = null;
472        } else {
473          current = current.pred;
474        }
475        return result;
476      }
477
478      @Override
479      public void remove() {
480        checkRemove(prevEntry != null);
481        setCount(prevEntry.getElement(), 0);
482        prevEntry = null;
483      }
484    };
485  }
486
487  @Override
488  public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
489    return new TreeMultiset<E>(
490        rootReference,
491        range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
492        header);
493  }
494
495  @Override
496  public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
497    return new TreeMultiset<E>(
498        rootReference,
499        range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
500        header);
501  }
502
503  static int distinctElements(@Nullable AvlNode<?> node) {
504    return (node == null) ? 0 : node.distinctElements;
505  }
506
507  private static final class Reference<T> {
508    @Nullable private T value;
509
510    @Nullable
511    public T get() {
512      return value;
513    }
514
515    public void checkAndSet(@Nullable T expected, T newValue) {
516      if (value != expected) {
517        throw new ConcurrentModificationException();
518      }
519      value = newValue;
520    }
521  }
522
523  private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
524    @Nullable private final E elem;
525
526    // elemCount is 0 iff this node has been deleted.
527    private int elemCount;
528
529    private int distinctElements;
530    private long totalCount;
531    private int height;
532    private AvlNode<E> left;
533    private AvlNode<E> right;
534    private AvlNode<E> pred;
535    private AvlNode<E> succ;
536
537    AvlNode(@Nullable E elem, int elemCount) {
538      checkArgument(elemCount > 0);
539      this.elem = elem;
540      this.elemCount = elemCount;
541      this.totalCount = elemCount;
542      this.distinctElements = 1;
543      this.height = 1;
544      this.left = null;
545      this.right = null;
546    }
547
548    public int count(Comparator<? super E> comparator, E e) {
549      int cmp = comparator.compare(e, elem);
550      if (cmp < 0) {
551        return (left == null) ? 0 : left.count(comparator, e);
552      } else if (cmp > 0) {
553        return (right == null) ? 0 : right.count(comparator, e);
554      } else {
555        return elemCount;
556      }
557    }
558
559    private AvlNode<E> addRightChild(E e, int count) {
560      right = new AvlNode<E>(e, count);
561      successor(this, right, succ);
562      height = Math.max(2, height);
563      distinctElements++;
564      totalCount += count;
565      return this;
566    }
567
568    private AvlNode<E> addLeftChild(E e, int count) {
569      left = new AvlNode<E>(e, count);
570      successor(pred, left, this);
571      height = Math.max(2, height);
572      distinctElements++;
573      totalCount += count;
574      return this;
575    }
576
577    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
578      /*
579       * It speeds things up considerably to unconditionally add count to totalCount here,
580       * but that destroys failure atomicity in the case of count overflow. =(
581       */
582      int cmp = comparator.compare(e, elem);
583      if (cmp < 0) {
584        AvlNode<E> initLeft = left;
585        if (initLeft == null) {
586          result[0] = 0;
587          return addLeftChild(e, count);
588        }
589        int initHeight = initLeft.height;
590
591        left = initLeft.add(comparator, e, count, result);
592        if (result[0] == 0) {
593          distinctElements++;
594        }
595        this.totalCount += count;
596        return (left.height == initHeight) ? this : rebalance();
597      } else if (cmp > 0) {
598        AvlNode<E> initRight = right;
599        if (initRight == null) {
600          result[0] = 0;
601          return addRightChild(e, count);
602        }
603        int initHeight = initRight.height;
604
605        right = initRight.add(comparator, e, count, result);
606        if (result[0] == 0) {
607          distinctElements++;
608        }
609        this.totalCount += count;
610        return (right.height == initHeight) ? this : rebalance();
611      }
612
613      // adding count to me!  No rebalance possible.
614      result[0] = elemCount;
615      long resultCount = (long) elemCount + count;
616      checkArgument(resultCount <= Integer.MAX_VALUE);
617      this.elemCount += count;
618      this.totalCount += count;
619      return this;
620    }
621
622    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
623      int cmp = comparator.compare(e, elem);
624      if (cmp < 0) {
625        AvlNode<E> initLeft = left;
626        if (initLeft == null) {
627          result[0] = 0;
628          return this;
629        }
630
631        left = initLeft.remove(comparator, e, count, result);
632
633        if (result[0] > 0) {
634          if (count >= result[0]) {
635            this.distinctElements--;
636            this.totalCount -= result[0];
637          } else {
638            this.totalCount -= count;
639          }
640        }
641        return (result[0] == 0) ? this : rebalance();
642      } else if (cmp > 0) {
643        AvlNode<E> initRight = right;
644        if (initRight == null) {
645          result[0] = 0;
646          return this;
647        }
648
649        right = initRight.remove(comparator, e, count, result);
650
651        if (result[0] > 0) {
652          if (count >= result[0]) {
653            this.distinctElements--;
654            this.totalCount -= result[0];
655          } else {
656            this.totalCount -= count;
657          }
658        }
659        return rebalance();
660      }
661
662      // removing count from me!
663      result[0] = elemCount;
664      if (count >= elemCount) {
665        return deleteMe();
666      } else {
667        this.elemCount -= count;
668        this.totalCount -= count;
669        return this;
670      }
671    }
672
673    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
674      int cmp = comparator.compare(e, elem);
675      if (cmp < 0) {
676        AvlNode<E> initLeft = left;
677        if (initLeft == null) {
678          result[0] = 0;
679          return (count > 0) ? addLeftChild(e, count) : this;
680        }
681
682        left = initLeft.setCount(comparator, e, count, result);
683
684        if (count == 0 && result[0] != 0) {
685          this.distinctElements--;
686        } else if (count > 0 && result[0] == 0) {
687          this.distinctElements++;
688        }
689
690        this.totalCount += count - result[0];
691        return rebalance();
692      } else if (cmp > 0) {
693        AvlNode<E> initRight = right;
694        if (initRight == null) {
695          result[0] = 0;
696          return (count > 0) ? addRightChild(e, count) : this;
697        }
698
699        right = initRight.setCount(comparator, e, count, result);
700
701        if (count == 0 && result[0] != 0) {
702          this.distinctElements--;
703        } else if (count > 0 && result[0] == 0) {
704          this.distinctElements++;
705        }
706
707        this.totalCount += count - result[0];
708        return rebalance();
709      }
710
711      // setting my count
712      result[0] = elemCount;
713      if (count == 0) {
714        return deleteMe();
715      }
716      this.totalCount += count - elemCount;
717      this.elemCount = count;
718      return this;
719    }
720
721    AvlNode<E> setCount(
722        Comparator<? super E> comparator,
723        @Nullable E e,
724        int expectedCount,
725        int newCount,
726        int[] result) {
727      int cmp = comparator.compare(e, elem);
728      if (cmp < 0) {
729        AvlNode<E> initLeft = left;
730        if (initLeft == null) {
731          result[0] = 0;
732          if (expectedCount == 0 && newCount > 0) {
733            return addLeftChild(e, newCount);
734          }
735          return this;
736        }
737
738        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
739
740        if (result[0] == expectedCount) {
741          if (newCount == 0 && result[0] != 0) {
742            this.distinctElements--;
743          } else if (newCount > 0 && result[0] == 0) {
744            this.distinctElements++;
745          }
746          this.totalCount += newCount - result[0];
747        }
748        return rebalance();
749      } else if (cmp > 0) {
750        AvlNode<E> initRight = right;
751        if (initRight == null) {
752          result[0] = 0;
753          if (expectedCount == 0 && newCount > 0) {
754            return addRightChild(e, newCount);
755          }
756          return this;
757        }
758
759        right = initRight.setCount(comparator, e, expectedCount, newCount, result);
760
761        if (result[0] == expectedCount) {
762          if (newCount == 0 && result[0] != 0) {
763            this.distinctElements--;
764          } else if (newCount > 0 && result[0] == 0) {
765            this.distinctElements++;
766          }
767          this.totalCount += newCount - result[0];
768        }
769        return rebalance();
770      }
771
772      // setting my count
773      result[0] = elemCount;
774      if (expectedCount == elemCount) {
775        if (newCount == 0) {
776          return deleteMe();
777        }
778        this.totalCount += newCount - elemCount;
779        this.elemCount = newCount;
780      }
781      return this;
782    }
783
784    private AvlNode<E> deleteMe() {
785      int oldElemCount = this.elemCount;
786      this.elemCount = 0;
787      successor(pred, succ);
788      if (left == null) {
789        return right;
790      } else if (right == null) {
791        return left;
792      } else if (left.height >= right.height) {
793        AvlNode<E> newTop = pred;
794        // newTop is the maximum node in my left subtree
795        newTop.left = left.removeMax(newTop);
796        newTop.right = right;
797        newTop.distinctElements = distinctElements - 1;
798        newTop.totalCount = totalCount - oldElemCount;
799        return newTop.rebalance();
800      } else {
801        AvlNode<E> newTop = succ;
802        newTop.right = right.removeMin(newTop);
803        newTop.left = left;
804        newTop.distinctElements = distinctElements - 1;
805        newTop.totalCount = totalCount - oldElemCount;
806        return newTop.rebalance();
807      }
808    }
809
810    // Removes the minimum node from this subtree to be reused elsewhere
811    private AvlNode<E> removeMin(AvlNode<E> node) {
812      if (left == null) {
813        return right;
814      } else {
815        left = left.removeMin(node);
816        distinctElements--;
817        totalCount -= node.elemCount;
818        return rebalance();
819      }
820    }
821
822    // Removes the maximum node from this subtree to be reused elsewhere
823    private AvlNode<E> removeMax(AvlNode<E> node) {
824      if (right == null) {
825        return left;
826      } else {
827        right = right.removeMax(node);
828        distinctElements--;
829        totalCount -= node.elemCount;
830        return rebalance();
831      }
832    }
833
834    private void recomputeMultiset() {
835      this.distinctElements =
836          1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
837      this.totalCount = elemCount + totalCount(left) + totalCount(right);
838    }
839
840    private void recomputeHeight() {
841      this.height = 1 + Math.max(height(left), height(right));
842    }
843
844    private void recompute() {
845      recomputeMultiset();
846      recomputeHeight();
847    }
848
849    private AvlNode<E> rebalance() {
850      switch (balanceFactor()) {
851        case -2:
852          if (right.balanceFactor() > 0) {
853            right = right.rotateRight();
854          }
855          return rotateLeft();
856        case 2:
857          if (left.balanceFactor() < 0) {
858            left = left.rotateLeft();
859          }
860          return rotateRight();
861        default:
862          recomputeHeight();
863          return this;
864      }
865    }
866
867    private int balanceFactor() {
868      return height(left) - height(right);
869    }
870
871    private AvlNode<E> rotateLeft() {
872      checkState(right != null);
873      AvlNode<E> newTop = right;
874      this.right = newTop.left;
875      newTop.left = this;
876      newTop.totalCount = this.totalCount;
877      newTop.distinctElements = this.distinctElements;
878      this.recompute();
879      newTop.recomputeHeight();
880      return newTop;
881    }
882
883    private AvlNode<E> rotateRight() {
884      checkState(left != null);
885      AvlNode<E> newTop = left;
886      this.left = newTop.right;
887      newTop.right = this;
888      newTop.totalCount = this.totalCount;
889      newTop.distinctElements = this.distinctElements;
890      this.recompute();
891      newTop.recomputeHeight();
892      return newTop;
893    }
894
895    private static long totalCount(@Nullable AvlNode<?> node) {
896      return (node == null) ? 0 : node.totalCount;
897    }
898
899    private static int height(@Nullable AvlNode<?> node) {
900      return (node == null) ? 0 : node.height;
901    }
902
903    @Nullable
904    private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
905      int cmp = comparator.compare(e, elem);
906      if (cmp < 0) {
907        return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
908      } else if (cmp == 0) {
909        return this;
910      } else {
911        return (right == null) ? null : right.ceiling(comparator, e);
912      }
913    }
914
915    @Nullable
916    private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
917      int cmp = comparator.compare(e, elem);
918      if (cmp > 0) {
919        return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
920      } else if (cmp == 0) {
921        return this;
922      } else {
923        return (left == null) ? null : left.floor(comparator, e);
924      }
925    }
926
927    @Override
928    public E getElement() {
929      return elem;
930    }
931
932    @Override
933    public int getCount() {
934      return elemCount;
935    }
936
937    @Override
938    public String toString() {
939      return Multisets.immutableEntry(getElement(), getCount()).toString();
940    }
941  }
942
943  private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
944    a.succ = b;
945    b.pred = a;
946  }
947
948  private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
949    successor(a, b);
950    successor(b, c);
951  }
952
953  /*
954   * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
955   * calls the comparator to compare the two keys. If that change is made,
956   * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
957   */
958
959  /**
960   * @serialData the comparator, the number of distinct elements, the first element, its count, the
961   *             second element, its count, and so on
962   */
963  @GwtIncompatible // java.io.ObjectOutputStream
964  private void writeObject(ObjectOutputStream stream) throws IOException {
965    stream.defaultWriteObject();
966    stream.writeObject(elementSet().comparator());
967    Serialization.writeMultiset(this, stream);
968  }
969
970  @GwtIncompatible // java.io.ObjectInputStream
971  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
972    stream.defaultReadObject();
973    @SuppressWarnings("unchecked")
974    // reading data stored by writeObject
975    Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
976    Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
977    Serialization.getFieldSetter(TreeMultiset.class, "range")
978        .set(this, GeneralRange.all(comparator));
979    Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
980        .set(this, new Reference<AvlNode<E>>());
981    AvlNode<E> header = new AvlNode<E>(null, 1);
982    Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
983    successor(header, header);
984    Serialization.populateMultiset(this, stream);
985  }
986
987  @GwtIncompatible // not needed in emulated source
988  private static final long serialVersionUID = 1;
989}