Fixing Spark’s RDD.zip

I’ve been using Apache Spark recently and like it quite a lot, although it still has several rough edges. One that I ran into is a quirk in RDD.zip.

I had two RDDs of equal length, but when I zipped them together, the zipped RDD had fewer elements than its parents. Looking at the documentation for zip, it says:

Assumes that the two RDDs have the *same number of partitions* and the *same number of elements in each partition* (e.g. one was made through a map on the other).

So having equal-length RDDs is not sufficient.  For example, an RDD with the elements (a, b, c) partitioned as ([a, b], [c]) cannot be correctly zipped with (A, B, C) partitioned as ([A], [B, C]). I’d like to point out that this is a horrible violation of the principle of least astonishment.  Unfortunately, Spark does not provide a way to zip equal-length but unequally-partitioned RDDs. (If any Spark developers are reading this, it would be a much-appreciated feature.)
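To see why elements go missing, the per-partition semantics can be sketched in plain Java (this is a simulation of the behavior, not Spark code; the author observed silent truncation, while newer Spark versions raise an error when partition sizes differ):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ZipDemo {
    // Mimics RDD.zip: partition i of the first dataset is zipped with
    // partition i of the second, stopping at the shorter partition.
    static <A, B> List<String> partitionZip(List<List<A>> p1, List<List<B>> p2) {
        List<String> out = new ArrayList<>();
        for (int i = 0; i < p1.size(); i++) {
            List<A> a = p1.get(i);
            List<B> b = p2.get(i);
            for (int j = 0; j < Math.min(a.size(), b.size()); j++) {
                out.add(a.get(j) + "-" + b.get(j));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // (a, b, c) partitioned as ([a, b], [c])
        List<List<String>> left = Arrays.asList(Arrays.asList("a", "b"), Arrays.asList("c"));
        // (A, B, C) partitioned as ([A], [B, C])
        List<List<String>> right = Arrays.asList(Arrays.asList("A"), Arrays.asList("B", "C"));
        // Only a-A and c-B survive; b and C are silently dropped.
        System.out.println(partitionZip(left, right)); // prints [a-A, c-B]
    }
}
```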

I wrote an implementation of a general zip function based on what Spark provides; it is incredibly inefficient, but it works. It’s quite a mess in Java; I imagine it’s a bit nicer in Scala. I tested it by zipping two RDDs of 1000 elements each, with each element a 1000-byte array. It takes about 400 ms on my older, dual-core laptop. I haven’t yet had the chance to test this in a distributed setting.

The index function may be useful for other things, too, such as creating an RDD containing the first k elements, rather than calling RDD.take(k), which returns the results to the client instead of as an RDD.

Again, the performance is terrible, but this may still be useful if other work dominates the computation.
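The index-then-filter idea can be sketched outside Spark as well. This is a plain-Java illustration of the concept, not Spark code; in Spark the filter would run on the indexed pair RDD, so the result stays distributed (firstK is a hypothetical name, not a Spark API):

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class FirstKDemo {
    // Pairs each element with its index, mimicking the index() function below.
    static <T> List<Map.Entry<Integer, T>> index(List<T> data) {
        List<Map.Entry<Integer, T>> out = new ArrayList<>();
        for (int i = 0; i < data.size(); i++) {
            out.add(new SimpleEntry<>(i, data.get(i)));
        }
        return out;
    }

    // Keeps elements whose index is < k. In Spark this would be a filter
    // on the indexed pairs, so the first k elements remain an RDD rather
    // than being pulled back to the client as take(k) does.
    static <T> List<T> firstK(List<T> data, int k) {
        List<T> out = new ArrayList<>();
        for (Map.Entry<Integer, T> e : index(data)) {
            if (e.getKey() < k) {
                out.add(e.getValue());
            }
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(firstK(List.of("a", "b", "c", "d"), 2)); // prints [a, b]
    }
}
```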


// Imports needed by the snippet (old, pre-1.0 Spark Java API, hence ClassManifest):
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

import scala.Tuple2;
import scala.reflect.ClassManifest;

public static <V1, V2> JavaPairRDD<V1, V2> zip(JavaRDD<V1> rdd1, JavaRDD<V2> rdd2) {
    return JavaPairRDD.fromJavaRDD(index(rdd1).join(index(rdd2)).values());
}

/**
 * Keys an RDD by the index of each element.
 */
public static <T> JavaPairRDD<Integer, T> index(JavaRDD<T> rdd) {
    // The RDD consists of a single entry, which is a list of the sizes of each partition
    JavaRDD<List<Integer>> partitionSizesRDD = rdd.mapPartitions(new FlatMapFunction<Iterator<T>, Integer>() {
        public Iterable<Integer> call(Iterator<T> arg0) throws Exception {
            int count = 0;
            while(arg0.hasNext()) {
                arg0.next();
                count++;
            }
            return Collections.singleton(count);
        }
    }).coalesce(1).glom(); // merge the per-partition counts, in partition order, into one list

    // The RDD consists of a single entry, which is a list of the start indices of each partition
    JavaRDD<List<Integer>> partitionOffsetsRDD = partitionSizesRDD
            .map(new Function<List<Integer>, List<Integer>>() {
                public List<Integer> call(List<Integer> arg0) throws Exception {
                    int offset = 0;
                    List<Integer> out = new ArrayList<>();
                    for(Integer i : arg0) {
                        out.add(offset);
                        offset += i;
                    }
                    return out;
                }
            });

    @SuppressWarnings("unchecked")
    ClassManifest<Tuple2<Integer, T>> cm = (ClassManifest) scala.reflect.ClassManifest$.MODULE$
            .fromClass(Tuple2.class);
    // Since partitionOffsetsRDD has a single element, this has the effect of pairing each partition with the list of all partition offsets
    JavaPairRDD<List<T>, List<Integer>> cartesian = rdd.glom().cartesian(partitionOffsetsRDD);
    // Since the RDD was already glommed, each partition has a single element.
    // We use the partition index to find the correct partition offset in the list of offsets.
    // Each individual element is then output with its index.
    JavaRDD<Tuple2<Integer, T>> result = cartesian.mapPartitionsWithIndex(
            new Function2<Object, Iterator<Tuple2<List<T>, List<Integer>>>, Iterator<Tuple2<Integer, T>>>() {
                public Iterator<Tuple2<Integer, T>> call(Object arg0, Iterator<Tuple2<List<T>, List<Integer>>> arg1)
                        throws Exception {
                    int partition = (Integer) arg0;
                    Tuple2<List<T>, List<Integer>> t = arg1.next();
                    assert !arg1.hasNext();
                    final int initialOffset = t._2().get(partition);
                    final Iterator<T> itr = t._1().iterator();
                    return new Iterator<Tuple2<Integer, T>>() {
                        int offset = initialOffset;

                        public boolean hasNext() {
                            return itr.hasNext();
                        }

                        public Tuple2<Integer, T> next() {
                            return new Tuple2<Integer, T>(offset++, itr.next());
                        }

                        public void remove() {
                            throw new UnsupportedOperationException();
                        }
                    };
                }
            }, true, cm);
    return JavaPairRDD.fromJavaRDD(result);
}

