一个Spliterator能实现Stream中元素排序

21-04-11 banq

如果我们有一个List<Stream<T>>,其中每个流的元素都已经排好序,那么如何把它们合并成一个整体有序的Stream<T>,并且每次只从各个流中按需取出一个元素?javaspecialists的文章展示了如何借助Stream API编写我们自己的MergingSortedSpliterator。

import java.util.*;
import java.util.function.*;
import java.util.stream.*;

/**
 * A {@link Spliterator} that lazily merges several already-sorted streams into
 * one sorted sequence, pulling a single element at a time from each source.
 *
 * <p>Every source spliterator must report {@code SORTED} and all of them must
 * share an equal comparator ({@code null}, i.e. natural ordering, counts as a
 * comparator here). The merged characteristics are the intersection of the
 * source characteristics. If all sources are {@code DISTINCT}, elements that
 * compare equal across sources are emitted only once.
 *
 * <p>Splitting is not supported, so a stream built on top of this spliterator
 * is effectively sequential. The supplied streams are not closed by this class.
 *
 * @param <T> the element type
 */
public class MergingSortedSpliterator<T> implements Spliterator<T> {
  /** Marker: nothing has been fetched from the corresponding source yet. */
  private static final Object START_OF_STREAM = new Object();
  /** Marker: the corresponding source is exhausted. */
  private static final Object END_OF_STREAM = new Object();

  private final List<Spliterator<T>> spliterators;
  private final List<Iterator<T>> iterators;
  private final int characteristics;
  /** Head element of each source, or one of the two marker objects. */
  private final Object[] nextItem;
  private final Comparator<? super T> comparator;
  private final boolean distinct;

  /**
   * Creates a merging spliterator over the given sorted streams.
   *
   * @param streams the streams to merge; each must be sorted, and all must be
   *     sorted by an equal comparator
   * @throws IllegalArgumentException if a stream is not {@code SORTED} or the
   *     streams are sorted by differing comparators
   */
  @SuppressWarnings("unchecked") // natural-order fallback is only used when the sources declare it
  public MergingSortedSpliterator(Collection<Stream<T>> streams) {
    this.spliterators = streams.stream()
        .map(Stream::spliterator)
        .collect(Collectors.toList());
    if (!spliterators.stream().allMatch(
        spliterator -> spliterator.hasCharacteristics(SORTED))) {
      throw new IllegalArgumentException("Streams must be sorted");
    }

    // Compare every comparator against the FIRST source's comparator. Seeding
    // the comparison with null would reject any explicitly supplied comparator.
    Comparator<? super T> comparator = null;
    boolean first = true;
    for (Spliterator<T> spliterator : spliterators) {
      Comparator<? super T> candidate = spliterator.getComparator();
      if (first) {
        comparator = candidate;
        first = false;
      } else if (!Objects.equals(comparator, candidate)) {
        throw new IllegalArgumentException(
            "Mismatching comparators " + comparator + " and " + candidate);
      }
    }
    this.comparator = Objects.requireNonNullElse(comparator,
        (Comparator<? super T>) Comparator.naturalOrder());

    int merged = spliterators.stream()
        .mapToInt(Spliterator::characteristics)
        .reduce((ch1, ch2) -> ch1 & ch2)
        .orElse(0);
    this.distinct = (merged & DISTINCT) != 0;
    if (distinct && spliterators.size() > 1) {
      // Duplicates across sources are dropped, so the sum of the source sizes
      // is only an upper bound and must not be advertised as exact.
      merged &= ~(SIZED | SUBSIZED);
    }
    this.characteristics = merged;

    // setting up iterators
    this.iterators = spliterators.stream()
        .map(Spliterators::iterator)
        .collect(Collectors.toList());
    this.nextItem = new Object[spliterators.size()];
    Arrays.fill(nextItem, START_OF_STREAM);
  }

  /** Returns the next element of {@code iterator}, or the end marker. */
  private Object fetchNext(Iterator<T> iterator) {
    return iterator.hasNext() ? iterator.next() : END_OF_STREAM;
  }

  /** Casts a buffered head item (never a marker object) back to {@code T}. */
  @SuppressWarnings("unchecked")
  private T asElement(Object item) {
    return (T) item;
  }

  /**
   * Emits the smallest head element among all sources, if any remain.
   *
   * <p>On ties the element from the earliest source in the list is chosen.
   *
   * @param action receives the next merged element
   * @return {@code false} when every source is exhausted, else {@code true}
   * @throws NullPointerException if {@code action} is null
   */
  @Override
  public boolean tryAdvance(Consumer<? super T> action) {
    Objects.requireNonNull(action, "action==null");
    T smallest = null;
    int smallestIndex = -1;
    for (int i = 0; i < nextItem.length; i++) {
      Object o = nextItem[i];
      if (o == START_OF_STREAM) {
        nextItem[i] = o = fetchNext(iterators.get(i));
      }
      if (o != END_OF_STREAM) {
        T t = asElement(o);
        // Track "nothing chosen yet" through the index, not through
        // smallest == null: null may be a legitimate element.
        if (smallestIndex < 0 || comparator.compare(t, smallest) < 0) {
          smallest = t;
          smallestIndex = i;
        }
      }
    }

    if (smallestIndex < 0) {
      return false;
    }

    if (distinct) {
      // Skip every head, in every source, that compares equal to the element
      // about to be emitted, so it is produced exactly once.
      for (int i = 0; i < nextItem.length; i++) {
        Iterator<T> iterator = iterators.get(i);
        while (nextItem[i] != END_OF_STREAM
            && comparator.compare(smallest, asElement(nextItem[i])) == 0) {
          nextItem[i] = fetchNext(iterator);
        }
      }
    } else {
      nextItem[smallestIndex] = fetchNext(iterators.get(smallestIndex));
    }

    action.accept(smallest);
    return true;
  }

  /**
   * Always returns {@code null}: this spliterator cannot be split.
   */
  @Override
  public Spliterator<T> trySplit() {
    // never split - parallel not supported
    return null;
  }

  /**
   * Returns the sum of the source spliterators' size estimates, saturating at
   * {@code Long.MAX_VALUE} on overflow.
   */
  @Override
  public long estimateSize() {
    long total = 0;
    for (Spliterator<T> spliterator : spliterators) {
      total += spliterator.estimateSize();
      if (total < 0) {
        return Long.MAX_VALUE;
      }
    }
    return total;
  }

  /** Returns the merged characteristics computed at construction time. */
  @Override
  public int characteristics() {
    return characteristics;
  }

  /**
   * Returns the comparator used for merging; this is a natural-order
   * comparator instance (never {@code null}) when the sources are naturally
   * ordered.
   */
  @Override
  public Comparator<? super T> getComparator() {
    return comparator;
  }
}
  

下面是调用测试客户端代码:

import java.util.*;
import java.util.concurrent.*;
import java.util.stream.*;

public class SortedStreamOfSortedStreams {
  private static final int SIZE = 5;

  public static void main(String... args) {
    List<Stream<Integer>> streams = List.of(
        generateSortedRandom(SIZE),
        generateSortedRandom(SIZE),
        generateSortedRandom(SIZE),
        generateSortedRandom(SIZE)
    );

    Stream<Integer> numbers = StreamSupport.stream(
        new MergingSortedSpliterator<>(streams), false
    );
    numbers.forEach(System.out::println);
  }

  private static Stream<Integer> generateSortedRandom(int size) {
    return ThreadLocalRandom.current().ints(size, 0, size * 4)
        .parallel()
        .sorted()
        .boxed();
  }
}

输出结果(数值为随机生成,每次运行可能不同):

0
0
2
4
4
5
6
6
7
10
10
11
12
15
16
17
18
18
19
19
  

它甚至比先flatMap再并行排序的做法还要快。

猜你喜欢