Fixing Spark’s RDD.zip()

I’ve been using Apache Spark recently and like it quite a lot, although it still has several rough edges.  One that I ran into is a quirk in RDD.zip().

I had two RDDs of equal length, but when I zipped them together, the zipped RDD had fewer elements than its parents.  The documentation for JavaRDD.zip() says:

Assumes that the two RDDs have the *same number of partitions* and the *same number of elements in each partition* (e.g. one was made through a map on the other).

So having equal-length RDDs is not sufficient.  For example, an RDD with the elements (a, b, c) partitioned as ([a, b], [c]) cannot be correctly zipped with (A, B, C) partitioned as ([A], [B, C]). This is a horrible violation of the principle of least astonishment.  Unfortunately, Spark does not provide a way to zip equal-length but unequally partitioned RDDs. (If any Spark developers are reading this, it would be a much-appreciated feature.)
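To make the failure mode concrete, here is a minimal local sketch. The partition contents in the comments assume Spark's default even slicing; newer Spark releases throw an exception instead of silently dropping elements.

import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;

JavaSparkContext sc = new JavaSparkContext("local", "zip-demo");
// Two partitions, sliced evenly: ([1, 2, 3], [4, 5, 6]).
JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
// filter() preserves partitioning: ([1, 3], [5]) -- per-partition sizes 2 and 1.
JavaRDD<Integer> odds = nums.filter(new Function<Integer, Boolean>() {
    @Override
    public Boolean call(Integer x) { return x % 2 == 1; }
});
// ([], [4, 5, 6]) -- per-partition sizes 0 and 3.
JavaRDD<Integer> high = nums.filter(new Function<Integer, Boolean>() {
    @Override
    public Boolean call(Integer x) { return x >= 4; }
});
// Both RDDs have three elements and two partitions, but zip() pairs elements
// partition-by-partition, so it yields a single pair, (5, 4), instead of three.
odds.zip(high).collect();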

I wrote a general zip function based on what Spark provides. It is incredibly inefficient, and it's quite a mess in Java; I imagine it would be a bit nicer in Scala. I tested it by zipping two RDDs of 1000 elements each, where each element is a 1000-byte array; that takes about 400ms on my older, dual-core laptop. I haven't yet had the chance to test it in a distributed setting.

The index function may be useful for other things, too, such as creating an RDD containing only the first k elements, rather than calling RDD.take(k), which collects the results on the client instead of leaving them in an RDD.
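For example, here is a sketch of that idea; firstK is a hypothetical helper, not part of Spark's API:

public static <T> JavaRDD<T> firstK(JavaRDD<T> rdd, final int k) {
    // Keep only the elements whose global index is below k, then drop the indices.
    return index(rdd).filter(new Function<Tuple2<Integer, T>, Boolean>() {
        @Override
        public Boolean call(Tuple2<Integer, T> pair) throws Exception {
            return pair._1() < k;
        }
    }).values();
}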

Again, the performance is terrible, but this may still be useful if other work dominates the computation.


import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

// Package names as of Spark's pre-1.0 Java API (Scala 2.9 / ClassManifest era).
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

import scala.Tuple2;
import scala.reflect.ClassManifest;

/**
 * Zips two equal-length RDDs, regardless of how they are partitioned, by
 * keying each element with its global index and joining on the index.
 */
public static <V1, V2> JavaPairRDD<V1, V2> zip(JavaRDD<V1> rdd1, JavaRDD<V2> rdd2) {
    return JavaPairRDD.fromJavaRDD(index(rdd1).join(index(rdd2)).values());
}

/**
 * Keys an RDD by the index of each element
 */
public static <T> JavaPairRDD<Integer, T> index(JavaRDD<T> rdd) {
    // The RDD consists of a single entry, which is a list of the sizes of each partition
    JavaRDD<List<Integer>> partitionSizesRDD = rdd.mapPartitions(new FlatMapFunction<Iterator<T>, Integer>() {
        @Override
        public Iterable<Integer> call(Iterator<T> arg0) throws Exception {
            int count = 0;
            while(arg0.hasNext()) {
                arg0.next();
                count++;
            }
            return Collections.singleton(count);
        }
    }).coalesce(1).glom();
    // The RDD consists of a single entry, which is a list of the start indices of each partition
    JavaRDD<List<Integer>> partitionOffsetsRDD = partitionSizesRDD
            .map(new Function<List<Integer>, List<Integer>>() {
                @Override
                public List<Integer> call(List<Integer> arg0) throws Exception {
                    int offset = 0;
                    List<Integer> out = new ArrayList<>();
                    for(Integer i : arg0) {
                        out.add(offset);
                        offset += i;
                    }
                    return out;
                }
            });
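    // Spark's pre-1.0 Java API needs an explicit ClassManifest for the result element type.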
    ClassManifest<Tuple2<Integer, T>> cm = (ClassManifest) scala.reflect.ClassManifest$.MODULE$
            .fromClass(Tuple2.class);
    // Since partitionOffsetsRDD has a single element, this has the effect of pairing each partition with the list of all partition offsets
    JavaPairRDD<List<T>, List<Integer>> cartesian = rdd.glom().cartesian(partitionOffsetsRDD);
    // Since the RDD was already glommed, each partition has a single element.
    // We use the partition index to find the correct partition offset in the list of offsets.
    // Each individual element is then output with its index.
    JavaRDD<Tuple2<Integer, T>> result = cartesian.mapPartitionsWithIndex(
            new Function2<Object, Iterator<Tuple2<List<T>, List<Integer>>>, Iterator<Tuple2<Integer, T>>>() {
                @Override
                public Iterator<Tuple2<Integer, T>> call(Object arg0, Iterator<Tuple2<List<T>, List<Integer>>> arg1)
                        throws Exception {
                    int partition = (Integer) arg0;
                    Tuple2<List<T>, List<Integer>> t = arg1.next();
                    assert !arg1.hasNext();
                    final int initialOffset = t._2().get(partition);
                    final Iterator<T> itr = t._1().iterator();
                    return new Iterator<Tuple2<Integer, T>>() {
                        int offset = initialOffset;

                        @Override
                        public boolean hasNext() {
                            return itr.hasNext();
                        }

                        @Override
                        public Tuple2<Integer, T> next() {
                            return new Tuple2<Integer, T>(offset++, itr.next());
                        }

                        @Override
                        public void remove() {
                            throw new UnsupportedOperationException();
                        }
                    };
                }
            }, true, cm);
    return JavaPairRDD.fromJavaRDD(result);
}
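
Usage is then a one-liner; lines and lineLengths here stand for any two equal-length RDDs, however they are partitioned:

JavaPairRDD<String, Integer> zipped = zip(lines, lineLengths);

Since the pairing goes through join() on the element index, the result does not depend on how the inputs are partitioned, though, as with any join, it costs a shuffle.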
