
Add a strategy to fall back to Vanilla Spark shuffle manager #1047

Status: Open · wants to merge 15 commits into base: main
New file: IteratorWrapper.java
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.oap.vectorized;


import java.util.Iterator;
import java.util.List;

/**
 * Thin adapter that exposes an Iterator<List<Long>> through plain
 * hasNext()/next() calls, e.g. so it can be driven across the JNI boundary.
 */
public class IteratorWrapper {

    private final Iterator<List<Long>> in;

    public IteratorWrapper(Iterator<List<Long>> in) {
        this.in = in;
    }

    public boolean hasNext() {
        return in.hasNext();
    }

    public List<Long> next() {
        return in.next();
    }
}
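For orientation, a minimal usage sketch in Scala (the plugin mixes Scala and Java). The address values and the idea that the consumer is native code are assumptions, not confirmed by this diff:

import java.util.{Arrays, List => JList}
import java.lang.{Long => JLong}
import scala.collection.JavaConverters._

// Wrap a Scala iterator of buffer-address lists so a caller (e.g. JNI code)
// only needs the raw hasNext()/next() pair.
val addresses: Iterator[JList[JLong]] = Iterator(
  Arrays.asList(JLong.valueOf(0x1000L)), // hypothetical buffer addresses
  Arrays.asList(JLong.valueOf(0x2000L)))
val wrapper = new IteratorWrapper(addresses.asJava)
while (wrapper.hasNext) {
  val bufAddrs = wrapper.next()
  // hand bufAddrs to the native side here
}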
Changed file: ShuffleSplitterJniWrapper.java
@@ -83,6 +83,28 @@ public native long nativeMake(
long memoryPoolId,
boolean writeSchema);

/**
 * Create a native splitter for the given partitioning and return its id.
 */
public long make(
    NativePartitioning part,
    long offheapPerTask,
    int bufferSize) {
  return initSplit(
      part.getShortName(),
      part.getNumPartitions(),
      part.getSchema(),
      part.getExprList(),
      offheapPerTask,
      bufferSize);
}

public native long initSplit(
    String shortName,
    int numPartitions,
    byte[] schema,
    byte[] exprList,
    long offheapPerTask,
    int bufferSize);

/**
*
* Spill partition data to disk.
@@ -113,6 +135,11 @@ public native long split(
long splitterId, int numRows, long[] bufAddrs, long[] bufSizes, boolean firstRecordBatch)
throws IOException;

/**
* Collect the record batch after splitting.
*/
public native void collect(long splitterId, int numRows) throws IOException;

/**
* Update the compress type.
*/
@@ -127,6 +154,21 @@ public native long split(
*/
public native SplitResult stop(long splitterId) throws IOException;

/**
 * Cache the split buffers for the given rows and return them as serialized bytes.
 */
public native byte[][] cacheBuffer(
    long splitterId,
    int numRows)
    throws RuntimeException;

/**
 * Clear the buffers and stop splitting.
 *
 * @param splitterId splitter instance id
 * @return SplitResult
 */
public native SplitResult clear(long splitterId) throws IOException;

/**
* Release resources associated with designated splitter instance.
*
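Read together, the new entry points imply a lifecycle for the fallback path. A minimal sketch under stated assumptions: nativePartitioning, offheapPerTask, bufferSize, numRows, bufAddrs, and bufSizes stand in for real task state, and clear is assumed to be the terminal call:

// Hedged lifecycle sketch of the new JNI methods; all inputs are assumed.
val jni = new ShuffleSplitterJniWrapper()
val splitterId = jni.make(nativePartitioning, offheapPerTask, bufferSize)
jni.split(splitterId, numRows, bufAddrs, bufSizes, false) // split one record batch
jni.collect(splitterId, numRows)                          // gather the split output natively
val cached: Array[Array[Byte]] = jni.cacheBuffer(splitterId, numRows)
jni.clear(splitterId)                                     // drop buffers and stop splitting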
New file: SplitIterator.java
@@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.oap.vectorized;


import com.intel.oap.expression.ConverterUtils;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.ipc.message.ArrowBuffer;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.spark.sql.vectorized.ColumnarBatch;

import java.io.IOException;
import java.io.Serializable;
import java.util.Iterator;

public class SplitIterator implements Iterator<ColumnarBatch> {

    public static class IteratorOptions implements Serializable {
        private static final long serialVersionUID = -1L;

        private int partitionNum;

        private String name;

        private long offheapPerTask;

        private int bufferSize;

        private String expr;

        private NativePartitioning nativePartitioning;

        public NativePartitioning getNativePartitioning() {
            return nativePartitioning;
        }

        public void setNativePartitioning(NativePartitioning nativePartitioning) {
            this.nativePartitioning = nativePartitioning;
        }

        public int getPartitionNum() {
            return partitionNum;
        }

        public void setPartitionNum(int partitionNum) {
            this.partitionNum = partitionNum;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public long getOffheapPerTask() {
            return offheapPerTask;
        }

        public void setOffheapPerTask(long offheapPerTask) {
            this.offheapPerTask = offheapPerTask;
        }

        public int getBufferSize() {
            return bufferSize;
        }

        public void setBufferSize(int bufferSize) {
            this.bufferSize = bufferSize;
        }

        public String getExpr() {
            return expr;
        }

        public void setExpr(String expr) {
            this.expr = expr;
        }
    }

    private final ShuffleSplitterJniWrapper jniWrapper;

    private long nativeSplitter = 0;
    private final Iterator<ColumnarBatch> iterator;
    private final IteratorOptions options;

    public SplitIterator(ShuffleSplitterJniWrapper jniWrapper,
                         Iterator<ColumnarBatch> iterator, IteratorOptions options) {
        this.jniWrapper = jniWrapper;
        this.iterator = iterator;
        this.options = options;
    }

    private void nativeCreateInstance() {
        ColumnarBatch cb = iterator.next();
        ArrowRecordBatch recordBatch = ConverterUtils.createArrowRecordBatch(cb);
        try {
            // Create a native splitter for this record batch.
            nativeSplitter = jniWrapper.make(
                    options.getNativePartitioning(),
                    options.getOffheapPerTask(),
                    options.getBufferSize());
            // Collect the buffer addresses and sizes of the Arrow record batch.
            int len = recordBatch.getBuffers().size();
            long[] bufAddrs = new long[len];
            long[] bufSizes = new long[len];
            int i = 0, j = 0;
            for (ArrowBuf buffer : recordBatch.getBuffers()) {
                bufAddrs[i++] = buffer.memoryAddress();
            }
            for (ArrowBuffer buffer : recordBatch.getBuffersLayout()) {
                bufSizes[j++] = buffer.getSize();
            }
            // Split the batch, then collect the split result on the native side.
            jniWrapper.split(nativeSplitter, cb.numRows(), bufAddrs, bufSizes, false);
            jniWrapper.collect(nativeSplitter, cb.numRows());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private native boolean nativeHasNext(long instance);

    /**
     * Returns true while the native splitter still has output; once it is
     * exhausted, splits the next upstream batch before re-checking.
     */
    @Override
    public boolean hasNext() {
        // 1. Lazily initialize the native splitter.
        if (nativeSplitter == 0) {
            if (!iterator.hasNext()) {
                return false;
            } else {
                nativeCreateInstance();
            }
        }
        // 2. Ask the native side whether output remains.
        if (nativeHasNext(nativeSplitter)) {
            return true;
        } else if (iterator.hasNext()) {
            // 3. Split the next record batch.
            nativeCreateInstance();
        }
        return nativeHasNext(nativeSplitter);
    }

    private native byte[] nativeNext(long instance);

    @Override
    public ColumnarBatch next() {
        // Deserialize the next split batch produced on the native side.
        byte[] serializedRecordBatch = nativeNext(nativeSplitter);
        return ConverterUtils.createRecordBatch(serializedRecordBatch,
                options.getNativePartitioning().getSchema());
    }

    private native int nativeNextPartitionId(long nativeSplitter);

    public int nextPartitionId() {
        return nativeNextPartitionId(nativeSplitter);
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            // Guard against clearing a splitter that was never created.
            if (nativeSplitter != 0) {
                jniWrapper.clear(nativeSplitter);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

}
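How a caller might drive SplitIterator; a sketch, assuming batches is the upstream Iterator[ColumnarBatch], options is a populated IteratorOptions, and writePartition is a hypothetical per-partition sink. Reading the partition id before next() follows the native cursor model but is an assumption:

// Hypothetical driver loop over the split output.
val splitIter = new SplitIterator(new ShuffleSplitterJniWrapper(), batches, options)
while (splitIter.hasNext) {
  val partitionId = splitIter.nextPartitionId() // id of the batch next() returns (assumed)
  val cb = splitIter.next()
  writePartition(partitionId, cb)               // hypothetical sink
}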
Changed file: GazellePluginConfig.scala
@@ -83,6 +83,11 @@ class GazellePluginConfig(conf: SQLConf) extends Logging {
  val enableColumnarShuffledHashJoin: Boolean =
    conf.getConfString("spark.oap.sql.columnar.shuffledhashjoin", "true").toBoolean && enableCpu

  // enable or disable fallback shuffle manager
  val enableFallbackShuffle: Boolean = conf
    .getConfString("spark.oap.sql.columnar.enableFallbackShuffle", "true")
    .equals("true") && enableCpu

Review thread on enableFallbackShuffle:

Collaborator: can you please also add a short note on how to use this feature? and also make this default to false

Contributor (author): Added that in the description dialog

  val enableArrowColumnarToRow: Boolean =
    conf.getConfString("spark.oap.sql.columnar.columnartorow", "true").toBoolean && enableCpu

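Note that the flag is compared against the literal string "true", so any other value disables it. A minimal sketch of toggling it per session, assuming an active SparkSession named spark:

// Disable the fallback shuffle path for this session (key defined above).
spark.conf.set("spark.oap.sql.columnar.enableFallbackShuffle", "false")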
Changed file: ConverterUtils.scala
@@ -21,7 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException}
import java.nio.channels.Channels
import java.nio.ByteBuffer
import java.util.ArrayList

import com.intel.oap.vectorized.ArrowWritableColumnVector
import io.netty.buffer.{ByteBufAllocator, ByteBufOutputStream}
import org.apache.arrow.memory.ArrowBuf
@@ -50,28 +49,42 @@ import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import io.netty.buffer.{ByteBuf, ByteBufAllocator, ByteBufOutputStream}
import java.nio.channels.{Channels, WritableByteChannel}
import com.google.common.collect.Lists
import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer

import java.io.{InputStream, OutputStream}
import java.util
import java.util.concurrent.TimeUnit.SECONDS

import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.TimeUnit
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision}
import org.apache.spark.sql.catalyst.util.{DateTimeConstants, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils, SparkVectorUtils}

object ConverterUtils extends Logging {
  def calcuateEstimatedSize(columnarBatch: ColumnarBatch): Long = {
    SparkVectorUtils.estimateSize(columnarBatch)
  }

  /**
   * Deserialize a record batch produced on the native side and wrap it
   * into a ColumnarBatch using the given serialized schema.
   */
  def createRecordBatch(serializedRecordBatch: Array[Byte], serializedSchema: Array[Byte]): ColumnarBatch = {
    val schema = ConverterUtils.getSchemaFromBytesBuf(serializedSchema)
    val allocator = SparkMemoryUtils.contextAllocatorForBufferImport
    val resultBatch = UnsafeRecordBatchSerializer.deserializeUnsafe(allocator, serializedRecordBatch)
    if (resultBatch == null) {
      throw new Exception("Failed to deserialize the record batch into a ColumnarBatch.")
    } else {
      val resultColumnVectorList = fromArrowRecordBatch(schema, resultBatch)
      val length = resultBatch.getLength
      ConverterUtils.releaseArrowRecordBatch(resultBatch)
      new ColumnarBatch(resultColumnVectorList.map(v => v.asInstanceOf[ColumnVector]), length)
    }
  }

  def createArrowRecordBatch(columnarBatch: ColumnarBatch): ArrowRecordBatch = {
    SparkVectorUtils.toArrowRecordBatch(columnarBatch)
  }
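For orientation, the producing direction can be sketched with the helpers already in this file. Assumptions: inputBatch is a ColumnarBatch, serializedSchema holds its schema bytes, and serializeUnsafe on the Arrow dataset serializer is the counterpart of the deserializeUnsafe call used above:

// Round trip: ColumnarBatch -> ArrowRecordBatch -> bytes -> ColumnarBatch.
val recordBatch = ConverterUtils.createArrowRecordBatch(inputBatch)
val bytes = UnsafeRecordBatchSerializer.serializeUnsafe(recordBatch)
ConverterUtils.releaseArrowRecordBatch(recordBatch)
val rebuilt = ConverterUtils.createRecordBatch(bytes, serializedSchema)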
@@ -369,6 +382,19 @@
}
}

  /**
   * Strip a trailing parenthesized part from the attribute name,
   * e.g. "sum(sales)" becomes "sum".
   */
  def getShortAttributeName(attr: Attribute): String = {
    val index = attr.name.indexOf("(")
    if (index != -1) {
      attr.name.substring(0, index)
    } else {
      attr.name
    }
  }

  /** Short attribute name suffixed with its expression id, e.g. "sum#7". */
  def genColumnNameWithExprId(attr: Attribute): String = {
    getShortAttributeName(attr) + "#" + attr.exprId.id
  }

  def getResultAttrFromExpr(
      fieldExpr: Expression,
      name: String = "None",
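A short illustration of the two naming helpers above; the attribute is constructed ad hoc and the exprId value is whatever Catalyst assigns:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.LongType

val attr = AttributeReference("sum(sales)", LongType)() // fresh exprId assigned by Catalyst
ConverterUtils.getShortAttributeName(attr)              // "sum"
ConverterUtils.genColumnNameWithExprId(attr)            // "sum#<exprId>"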