This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Add a strategy to fall back to Vanilla Spark shuffle manager #1047

Open · wants to merge 15 commits into main
IteratorWrapper.java (new file)
@@ -0,0 +1,43 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.oap.vectorized;

import java.util.Iterator;
import java.util.List;

/**
 * A thin wrapper exposing an Iterator of Long lists through plain
 * hasNext()/next() methods.
 */
public class IteratorWrapper {

  private final Iterator<List<Long>> in;

  public IteratorWrapper(Iterator<List<Long>> in) {
    this.in = in;
  }

  public boolean hasNext() {
    return in.hasNext();
  }

  public List<Long> next() {
    return in.next();
  }
}
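The wrapper's plain method surface is presumably what makes it convenient to drive from native code. A minimal sketch of JVM-side usage, assuming java.util.Arrays and illustrative data not taken from this PR:

// Sketch only: illustrative data; what each List<Long> holds (e.g. buffer
// addresses for one batch) is an assumption, not stated in this diff.
List<List<Long>> rows = Arrays.asList(
    Arrays.asList(1L, 2L, 3L),
    Arrays.asList(4L, 5L, 6L));
IteratorWrapper wrapper = new IteratorWrapper(rows.iterator());
while (wrapper.hasNext()) {
  List<Long> row = wrapper.next();
}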
ShuffleSplitterJniWrapper.java
@@ -83,6 +83,28 @@ public native long nativeMake(
      long memoryPoolId,
      boolean writeSchema);

  /**
   * Create a native splitter from a NativePartitioning, forwarding its fields
   * to initSplit.
   */
  public long make(
      NativePartitioning part,
      long offheapPerTask,
      int bufferSize) {
    return initSplit(
        part.getShortName(),
        part.getNumPartitions(),
        part.getSchema(),
        part.getExprList(),
        offheapPerTask,
        bufferSize);
  }

  public native long initSplit(
      String shortName,
      int numPartitions,
      byte[] schema,
      byte[] exprList,
      long offheapPerTask,
      int bufferSize);

  /**
   * Spill partition data to disk.
@@ -113,6 +135,11 @@ public native long split(
      long splitterId, int numRows, long[] bufAddrs, long[] bufSizes, boolean firstRecordBatch)
      throws IOException;

  /**
   * Collect the record batch after splitting.
   */
  public native void collect(long splitterId, int numRows) throws IOException;

  /**
   * Update the compress type.
   */
@@ -127,6 +154,15 @@ public native long split(
   */
  public native SplitResult stop(long splitterId) throws IOException;

  /**
   * Clear the buffers and stop splitting.
   *
   * @param splitterId splitter instance id
   * @return SplitResult
   */
  public native SplitResult clear(long splitterId) throws IOException;

  /**
   * Release resources associated with the designated splitter instance.
   *
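Taken together, the new entry points suggest a make → split → collect → stop lifecycle, with clear as the failure path and close as the final release. The following is a hedged sketch, not code from this PR; the partitioning, row count, and buffer arrays are assumed to be prepared by the caller, as SplitIterator below prepares them:

// Sketch only: intended call sequence for the new JNI entry points.
void splitOneBatch(NativePartitioning partitioning, int numRows,
    long[] bufAddrs, long[] bufSizes) throws IOException {
  ShuffleSplitterJniWrapper wrapper = new ShuffleSplitterJniWrapper();
  // 64 MB off-heap budget and 4096-row buffers are illustrative values.
  long splitterId = wrapper.make(partitioning, 64 * 1024 * 1024, 4096);
  try {
    wrapper.split(splitterId, numRows, bufAddrs, bufSizes, false);
    wrapper.collect(splitterId, numRows);  // gather the split record batches
    SplitResult result = wrapper.stop(splitterId);
  } catch (IOException e) {
    wrapper.clear(splitterId);             // drop buffers and stop splitting
    throw e;
  } finally {
    wrapper.close(splitterId);             // release the native instance
  }
}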
SplitIterator.java (new file)
@@ -0,0 +1,233 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.oap.vectorized;


import com.intel.oap.expression.ConverterUtils;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.ipc.message.ArrowBuffer;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.spark.sql.vectorized.ColumnarBatch;

import java.io.IOException;
import java.io.Serializable;
import java.util.Iterator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SplitIterator implements Iterator<ColumnarBatch> {

  private static final Logger logger = LoggerFactory.getLogger(SplitIterator.class);

  public static class IteratorOptions implements Serializable {
    private static final long serialVersionUID = -1L;

    private int partitionNum;

    private String name;

    private long offheapPerTask;

    private int bufferSize;

    private String expr;

    private NativePartitioning nativePartitioning;

    public NativePartitioning getNativePartitioning() {
      return nativePartitioning;
    }

    public void setNativePartitioning(NativePartitioning nativePartitioning) {
      this.nativePartitioning = nativePartitioning;
    }

    public int getPartitionNum() {
      return partitionNum;
    }

    public void setPartitionNum(int partitionNum) {
      this.partitionNum = partitionNum;
    }

    public String getName() {
      return name;
    }

    public void setName(String name) {
      this.name = name;
    }

    public long getOffheapPerTask() {
      return offheapPerTask;
    }

    public void setOffheapPerTask(long offheapPerTask) {
      this.offheapPerTask = offheapPerTask;
    }

    public int getBufferSize() {
      return bufferSize;
    }

    public void setBufferSize(int bufferSize) {
      this.bufferSize = bufferSize;
    }

    public String getExpr() {
      return expr;
    }

    public void setExpr(String expr) {
      this.expr = expr;
    }
  }

  private ShuffleSplitterJniWrapper jniWrapper = null;

  private long nativeSplitter = 0;
  private final Iterator<ColumnarBatch> iterator;
  private final IteratorOptions options;

  private ColumnarBatch cb = null;

  public SplitIterator(Iterator<ColumnarBatch> iterator, IteratorOptions options) {
    this.iterator = iterator;
    this.options = options;
  }

  private void nativeCreateInstance() {
    for (int i = 0; i < cb.numCols(); i++) {
      ArrowWritableColumnVector vector = (ArrowWritableColumnVector) (cb.column(i));
      vector.getValueVector().setValueCount(cb.numRows());
    }
    ArrowRecordBatch recordBatch = ConverterUtils.createArrowRecordBatch(cb);
    try {
      if (jniWrapper == null) {
        jniWrapper = new ShuffleSplitterJniWrapper();
      }
      // Any previous splitter should have been drained by now; clear it
      // before creating a new native instance.
      if (nativeSplitter != 0) {
        jniWrapper.clear(nativeSplitter);
        nativeSplitter = 0;
      }
      nativeSplitter = jniWrapper.make(
          options.getNativePartitioning(),
          options.getOffheapPerTask(),
          options.getBufferSize());
      long[] bufAddrs = new long[recordBatch.getBuffers().size()];
      long[] bufSizes = new long[recordBatch.getBuffersLayout().size()];
      int i = 0, j = 0;
      for (ArrowBuf buffer : recordBatch.getBuffers()) {
        bufAddrs[i++] = buffer.memoryAddress();
      }
      for (ArrowBuffer buffer : recordBatch.getBuffersLayout()) {
        bufSizes[j++] = buffer.getSize();
      }
      if (i != j || i < 1) {
        logger.warn("bufAddrs and bufSizes have mismatched or empty lengths: "
            + i + " vs " + j);
      }
      jniWrapper.split(nativeSplitter, cb.numRows(), bufAddrs, bufSizes, false);
      jniWrapper.collect(nativeSplitter, cb.numRows());
    } catch (Exception e) {
      if (nativeSplitter != 0) {
        try {
          jniWrapper.clear(nativeSplitter);
        } catch (IOException ex) {
          throw new RuntimeException(ex);
        }
        nativeSplitter = 0;
      }
      throw new RuntimeException(e);
    } finally {
      ConverterUtils.releaseArrowRecordBatch(recordBatch);
    }
  }

  private native boolean nativeHasNext(long instance);

  // Advances the upstream iterator to the next non-empty batch and feeds it
  // to a fresh native splitter. Returns false once upstream is exhausted.
  public boolean hasRecordBatch() {
    while (iterator.hasNext()) {
      cb = iterator.next();
      if (cb.numRows() != 0 && cb.numCols() != 0) {
        nativeCreateInstance();
        return true;
      }
    }
    if (nativeSplitter != 0) {
      try {
        jniWrapper.clear(nativeSplitter);
        nativeSplitter = 0;
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
    return false;
  }

  @Override
  public boolean hasNext() {
    // 1. Initialize the native splitter with the first non-empty batch.
    if (nativeSplitter == 0) {
      return hasRecordBatch() && nativeHasNext(nativeSplitter);
    }
    // 2. Ask the native side; if it is drained, pull the next upstream batch.
    if (nativeHasNext(nativeSplitter)) {
      return true;
    } else {
      return hasRecordBatch() && nativeHasNext(nativeSplitter);
    }
  }

  private native byte[] nativeNext(long instance);

  @Override
  public ColumnarBatch next() {
    byte[] serializedRecordBatch = nativeNext(nativeSplitter);
    return ConverterUtils.createRecordBatch(serializedRecordBatch,
        options.getNativePartitioning().getSchema());
  }

  private native int nativeNextPartitionId(long nativeSplitter);

  // Returns the partition id of the batch the next call to next() will yield.
  public int nextPartitionId() {
    return nativeNextPartitionId(nativeSplitter);
  }

  @Override
  protected void finalize() throws Throwable {
    try {
      if (nativeSplitter != 0) {
        logger.error("NativeSplitter was not cleared before finalization.");
        jniWrapper.clear(nativeSplitter);
        // Close before zeroing so the native instance id is still valid.
        jniWrapper.close(nativeSplitter);
        nativeSplitter = 0;
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    } finally {
      super.finalize();
    }
  }

}
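A hedged sketch of how the new iterator might be driven from a fallback shuffle writer loop; the upstream batches iterator, the partitioning, and the option values are illustrative assumptions, not part of this diff:

// Sketch only: nextPartitionId() is read before next(), matching the API above.
void writeWithFallback(Iterator<ColumnarBatch> batches,
    NativePartitioning partitioning) {
  SplitIterator.IteratorOptions options = new SplitIterator.IteratorOptions();
  options.setNativePartitioning(partitioning);  // assumed prepared by caller
  options.setPartitionNum(200);                 // illustrative values
  options.setOffheapPerTask(64 * 1024 * 1024);
  options.setBufferSize(4096);

  SplitIterator splitIterator = new SplitIterator(batches, options);
  while (splitIterator.hasNext()) {
    int partitionId = splitIterator.nextPartitionId();
    ColumnarBatch batch = splitIterator.next();
    // hand `batch` to the vanilla Spark shuffle writer for `partitionId`
  }
}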
GazellePluginConfig.scala
@@ -83,6 +83,11 @@ class GazellePluginConfig(conf: SQLConf) extends Logging {
  val enableColumnarShuffledHashJoin: Boolean =
    conf.getConfString("spark.oap.sql.columnar.shuffledhashjoin", "true").toBoolean && enableCpu

  // enable or disable falling back to the vanilla Spark shuffle manager
  val enableFallbackShuffle: Boolean = conf
    .getConfString("spark.oap.sql.columnar.enableFallbackShuffle", "false")
    .equals("true") && enableCpu

Review comment (Collaborator): Can you please also add a short note on how to use this feature? And also make this default to false.

Reply (Contributor, Author): Added that in the description dialog.
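As a usage note of the kind the reviewer asked for, here is one hedged way to turn the feature on; the config key and its "false" default come from the diff above, while the session setup itself is illustrative:

// Illustrative only: enable the fallback shuffle path for a session.
SparkSession spark = SparkSession.builder()
    .appName("fallback-shuffle-demo")
    .config("spark.oap.sql.columnar.enableFallbackShuffle", "true")
    .getOrCreate();

The same key can also be passed to spark-submit with --conf spark.oap.sql.columnar.enableFallbackShuffle=true; left at its false default, the plugin keeps its native columnar shuffle.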

  val enableArrowColumnarToRow: Boolean =
    conf.getConfString("spark.oap.sql.columnar.columnartorow", "true").toBoolean && enableCpu
