diff --git a/src/it/scala/com/qubole/spark/hiveacid/LockSuite.scala b/src/it/scala/com/qubole/spark/hiveacid/LockSuite.scala index 6666e9e..d6bdb4b 100644 --- a/src/it/scala/com/qubole/spark/hiveacid/LockSuite.scala +++ b/src/it/scala/com/qubole/spark/hiveacid/LockSuite.scala @@ -183,6 +183,8 @@ class TestLockHelper extends TestHelper { .config("spark.hadoop.hive.txn.timeout", "6") //.config("spark.ui.enabled", "true") //.config("spark.ui.port", "4041") + // All V1 tests are executed USING HiveAcid + .config("spark.hive.acid.datasource.version", "v2") .enableHiveSupport() .getOrCreate() } diff --git a/src/it/scala/com/qubole/spark/hiveacid/ReadSuite.scala b/src/it/scala/com/qubole/spark/hiveacid/ReadSuite.scala index 8d52739..6097441 100644 --- a/src/it/scala/com/qubole/spark/hiveacid/ReadSuite.scala +++ b/src/it/scala/com/qubole/spark/hiveacid/ReadSuite.scala @@ -26,7 +26,7 @@ import org.scalatest._ import scala.util.control.NonFatal -@Ignore +//@Ignore class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { val log: Logger = LogManager.getLogger(this.getClass) @@ -222,9 +222,10 @@ class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll // Special case of comparing result read before conversion // and after conversion. log.info("++ Compare result across conversion") - val (dfFromSql, dfFromScala) = helper.sparkGetDF(table) + val (dfFromSql, dfFromScala, dfFromSqlV2) = helper.sparkGetDF(table) helper.compareResult(hiveResStr, dfFromSql.collect()) helper.compareResult(hiveResStr, dfFromScala.collect()) + helper.compareResult(hiveResStr, dfFromSqlV2.collect()) helper.verify(table, insertOnly = false) } @@ -272,11 +273,12 @@ class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll val hiveResStr = helper.hiveExecuteQuery(table.hiveSelect) - val (df1, df2) = helper.sparkGetDF(table) + val (df1, df2, dfV2) = helper.sparkGetDF(table) // Materialize it once helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) + helper.compareResult(hiveResStr, dfV2.collect()) helper.hiveExecute(table.insertIntoHiveTableKey(11)) helper.hiveExecute(table.insertIntoHiveTableKey(12)) @@ -284,9 +286,9 @@ class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll helper.hiveExecute(table.insertIntoHiveTableKey(14)) helper.hiveExecute(table.insertIntoHiveTableKey(15)) if (isPartitioned) { - compactPartitionedAndTest(hiveResStr, df1, df2, Seq(11,12,13,14,15)) + compactPartitionedAndTest(hiveResStr, df1, df2, dfV2, Seq(11,12,13,14,15)) } else { - compactAndTest(hiveResStr, df1, df2) + compactAndTest(hiveResStr, df1, df2, dfV2) } // Shortcut for insert Only @@ -296,9 +298,9 @@ class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll helper.hiveExecute(table.deleteFromHiveTableKey(5)) helper.hiveExecute(table.deleteFromHiveTableKey(6)) if (isPartitioned) { - compactPartitionedAndTest(hiveResStr, df1, df2, Seq(3,4,5,6)) + compactPartitionedAndTest(hiveResStr, df1, df2, dfV2, Seq(3,4,5,6)) } else { - compactAndTest(hiveResStr, df1, df2) + compactAndTest(hiveResStr, df1, df2, dfV2) } helper.hiveExecute(table.updateInHiveTableKey(7)) @@ -306,33 +308,39 @@ class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll helper.hiveExecute(table.updateInHiveTableKey(9)) helper.hiveExecute(table.updateInHiveTableKey(10)) if (isPartitioned) { - compactPartitionedAndTest(hiveResStr, df1, df2, Seq(7,8,9,10)) + compactPartitionedAndTest(hiveResStr, df1, df2, 
dfV2, Seq(7,8,9,10)) } else { - compactAndTest(hiveResStr, df1, df2) + compactAndTest(hiveResStr, df1, df2, dfV2) } } } - def compactAndTest(hiveResStr: String, df1: DataFrame, df2: DataFrame): Unit = { + def compactAndTest(hiveResStr: String, df1: DataFrame, df2: DataFrame, dfV2: DataFrame): Unit = { helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) + helper.compareResult(hiveResStr, dfV2.collect()) helper.hiveExecute(table.minorCompaction) helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) + helper.compareResult(hiveResStr, dfV2.collect()) helper.hiveExecute(table.majorCompaction) helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) + helper.compareResult(hiveResStr, dfV2.collect()) } - def compactPartitionedAndTest(hiveResStr: String, df1: DataFrame, df2: DataFrame, keys: Seq[Int]): Unit = { + def compactPartitionedAndTest(hiveResStr: String, df1: DataFrame, df2: DataFrame, dfV2: DataFrame, keys: Seq[Int]): Unit = { helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) + helper.compareResult(hiveResStr, dfV2.collect()) keys.foreach(k => helper.hiveExecute(table.minorPartitionCompaction(k))) helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) + helper.compareResult(hiveResStr, dfV2.collect()) keys.foreach((k: Int) => helper.hiveExecute(table.majorPartitionCompaction(k))) helper.compareResult(hiveResStr, df1.collect()) helper.compareResult(hiveResStr, df2.collect()) + helper.compareResult(hiveResStr, dfV2.collect()) } helper.myRun(testName, code) @@ -365,7 +373,7 @@ class ReadSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll helper.hiveExecute(table2.insertIntoHiveTableKeyRange(10, 25)) var hiveResStr = helper.hiveExecuteQuery(Table.hiveJoin(table1, table2)) - val sparkRes1 = helper.sparkCollect(Table.sparkJoin(table1, table2)) + val sparkRes1 = helper.sparkCollect(Table.hiveJoin(table1, table2)) helper.compareResult(hiveResStr, sparkRes1) } diff --git a/src/it/scala/com/qubole/spark/hiveacid/TestHelper.scala b/src/it/scala/com/qubole/spark/hiveacid/TestHelper.scala index c5075f6..c49eaab 100644 --- a/src/it/scala/com/qubole/spark/hiveacid/TestHelper.scala +++ b/src/it/scala/com/qubole/spark/hiveacid/TestHelper.scala @@ -76,26 +76,29 @@ class TestHelper extends SQLImplicits { def compare(table: Table, msg: String): Unit = { log.info(s"Verify simple $msg") val hiveResStr = hiveExecuteQuery(table.hiveSelect) - val (dfFromSql, dfFromScala) = sparkGetDF(table) + val (dfFromSql, dfFromScala, dfFromSqlV2) = sparkGetDF(table) compareResult(hiveResStr, dfFromSql.collect()) compareResult(hiveResStr, dfFromScala.collect()) + compareResult(hiveResStr, dfFromSqlV2.collect()) } // With Predicate private def compareWithPred(table: Table, msg: String, pred: String): Unit = { log.info(s"Verify with predicate $msg") val hiveResStr = hiveExecuteQuery(table.hiveSelectWithPred(pred)) - val (dfFromSql, dfFromScala) = sparkGetDFWithPred(table, pred) + val (dfFromSql, dfFromScala, dfFromSqlV2) = sparkGetDFWithPred(table, pred) compareResult(hiveResStr, dfFromSql.collect()) compareResult(hiveResStr, dfFromScala.collect()) + compareResult(hiveResStr, dfFromSqlV2.collect()) } // With Projection private def compareWithProj(table: Table, msg: String): Unit = { log.info(s"Verify with projection $msg") val hiveResStr = hiveExecuteQuery(table.hiveSelectWithProj) - val (dfFromSql, 
dfFromScala) = sparkGetDFWithProj(table) + val (dfFromSql, dfFromScala, dfFromSqlV2) = sparkGetDFWithProj(table) compareResult(hiveResStr, dfFromSql.collect()) compareResult(hiveResStr, dfFromScala.collect()) + compareResult(hiveResStr, dfFromSqlV2.collect()) } // Compare result of 2 tables via hive @@ -198,28 +201,31 @@ class TestHelper extends SQLImplicits { compareWithProj(table, "After Delete") } - def sparkGetDFWithProj(table: Table): (DataFrame, DataFrame) = { + def sparkGetDFWithProj(table: Table): (DataFrame, DataFrame, DataFrame) = { val dfSql = sparkSQL(table.sparkSelect) + val dfSqlV2 = sparkSQL(table.hiveSelect) var dfScala = spark.read.format("HiveAcid").options(Map("table" -> table.hiveTname)).load().select(table.sparkDFProj) dfScala = totalOrderBy(table, dfScala) - (dfSql, dfScala) + (dfSql, dfScala, dfSqlV2) } - def sparkGetDFWithPred(table: Table, pred: String): (DataFrame, DataFrame) = { + def sparkGetDFWithPred(table: Table, pred: String): (DataFrame, DataFrame, DataFrame) = { val dfSql = sparkSQL(table.sparkSelectWithPred(pred)) + val dfSqlV2 = sparkSQL(table.hiveSelectWithPred(pred)) var dfScala = spark.read.format("HiveAcid").options(Map("table" -> table.hiveTname)).load().where(col("intCol") < "5") dfScala = totalOrderBy(table, dfScala) - (dfSql, dfScala) + (dfSql, dfScala, dfSqlV2) } - def sparkGetDF(table: Table): (DataFrame, DataFrame) = { + def sparkGetDF(table: Table): (DataFrame, DataFrame, DataFrame) = { val dfSql = sparkSQL(table.sparkSelect) + val dfSqlV2 = sparkSQL(table.hiveSelect) var dfScala = spark.read.format("HiveAcid").options(Map("table" -> table.hiveTname)).load() dfScala = totalOrderBy(table, dfScala) - (dfSql, dfScala) + (dfSql, dfScala, dfSqlV2) } def sparkSQL(cmd: String): DataFrame = { diff --git a/src/it/scala/com/qubole/spark/hiveacid/TestSparkSession.scala b/src/it/scala/com/qubole/spark/hiveacid/TestSparkSession.scala index 1831feb..69d4e61 100644 --- a/src/it/scala/com/qubole/spark/hiveacid/TestSparkSession.scala +++ b/src/it/scala/com/qubole/spark/hiveacid/TestSparkSession.scala @@ -30,6 +30,8 @@ private[hiveacid] object TestSparkSession { .config("spark.sql.extensions", "com.qubole.spark.hiveacid.HiveAcidAutoConvertExtension") //.config("spark.ui.enabled", "true") //.config("spark.ui.port", "4041") + // All V1 tests are executed USING HiveAcid + .config("spark.hive.acid.datasource.version", "v2") .enableHiveSupport() .getOrCreate() spark.sparkContext.setLogLevel("WARN") diff --git a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidAutoConvert.scala b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidAutoConvert.scala index 683a612..f537125 100644 --- a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidAutoConvert.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidAutoConvert.scala @@ -28,7 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, InsertIntoTable, Log import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation -import com.qubole.spark.hiveacid.datasource.HiveAcidDataSource +import com.qubole.spark.hiveacid.datasource.{HiveAcidDataSource, HiveAcidDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.SparkContext +import org.apache.spark.sql.internal.HiveSerDe /** @@ -43,14 +46,27 @@ case class HiveAcidAutoConvert(spark: SparkSession) extends Rule[LogicalPlan] { relation.tableMeta.properties.getOrElse("transactional", "false").toBoolean 
} - private def convert(relation: HiveTableRelation): LogicalRelation = { + private def convert(relation: HiveTableRelation): LogicalPlan = { val options = relation.tableMeta.properties ++ relation.tableMeta.storage.properties ++ Map("table" -> relation.tableMeta.qualifiedName) - val newRelation = new HiveAcidDataSource().createRelation(spark.sqlContext, options) LogicalRelation(newRelation, isStreaming = false) } + private def convertV2(relation: HiveTableRelation): LogicalPlan = { + val serde = relation.tableMeta.storage.serde.getOrElse("") + if (!serde.equalsIgnoreCase(HiveSerDe.sourceToSerDe("orc").get.serde.get)) { + // Only ORC formatted is supported as of now. If its not ORC, then fallback to + // datasource V1. + logInfo("Falling back to datasource v1 as " + serde + " is not supported by v2 reader.") + return convert(relation) + } + val dbName = relation.tableMeta.identifier.database.getOrElse("default") + val tableName = relation.tableMeta.identifier.table + val tableOpts = Map("database" -> dbName, "table" -> tableName) + DataSourceV2Relation.create(new HiveAcidDataSourceV2, tableOpts, None, None) + } + override def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { // Write path @@ -61,7 +77,11 @@ case class HiveAcidAutoConvert(spark: SparkSession) extends Rule[LogicalPlan] { // Read path case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => - convert(relation) + if (spark.conf.get("spark.hive.acid.datasource.version", "v1").equals("v2")) { + convertV2(relation) + } else { + convert(relation) + } } } } diff --git a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidDataSourceV2Reader.scala b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidDataSourceV2Reader.scala new file mode 100644 index 0000000..c368494 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidDataSourceV2Reader.scala @@ -0,0 +1,126 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid + +import java.lang.String.format +import java.io.IOException +import java.util.{ArrayList, List, Map} + +import org.apache.spark.sql.sources.v2.reader.DataSourceReader +import com.qubole.spark.hiveacid.hive.{HiveAcidMetadata, HiveConverter} +import org.apache.spark.SparkContext +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.DataSourceReader +import com.qubole.spark.hiveacid.transaction.HiveAcidTxn +import com.qubole.spark.hiveacid.util.{SerializableConfiguration, Util} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.reader._ +import com.qubole.spark.hiveacid.reader.v2.HiveAcidInputPartitionV2 +import com.qubole.spark.hiveacid.reader.TableReader +import com.qubole.spark.hiveacid.reader.hive.HiveAcidSearchArgument +import com.qubole.spark.hiveacid.reader.hive.HiveAcidSearchArgument.{buildTree, castLiteralValue, getPredicateLeafType, isSearchableType, quoteAttributeNameIfNeeded} + +/** + * Data source V2 implementation for HiveACID +*/ +class HiveAcidDataSourceV2Reader + extends DataSourceV2 with DataSourceReader with SupportsScanColumnarBatch + with SupportsPushDownRequiredColumns + with SupportsPushDownFilters with Logging { + + def this(options: java.util.Map[String, String], + sparkSession : SparkSession, + dbName : String, + tblName : String) { + this() + this.options = options + this.sparkSession = sparkSession + if (dbName != null) { + hiveAcidMetadata = HiveAcidMetadata.fromSparkSession(sparkSession, dbName + "." + tblName) + } else { + // If db name is null, default db is chosen. + hiveAcidMetadata = HiveAcidMetadata.fromSparkSession(sparkSession, tblName) + } + + // This is a hack to prevent the following situation: + // Spark(v 2.4.0) creates one instance of DataSourceReader to call readSchema() + // and then a new instance of DataSourceReader to call pushFilters(), + // planBatchInputPartitions() etc. Since it uses different DataSourceReader instances, + // and reads schema in former instance, schema remains null in the latter instance + // (which causes problems for other methods). More discussion: + // http://apache-spark-user-list.1001560.n3.nabble.com/DataSourceV2-APIs-creating-multiple-instances-of-DataSourceReader-and-hence-not-preserving-the-state-tc33646.html + // Also a null check on schema is already there in readSchema() to prevent initialization + // more than once just in case. 
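The eager `readSchema` call below is the workaround described in the comment above. As a minimal, self-contained sketch of the same defensive pattern against Spark 2.4 re-instantiating `DataSourceReader` (the toy reader and hard-coded schema are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Toy reader: resolve the schema once at construction time so that no code path
// can ever observe a null schema, mirroring the constructor above.
class EagerSchemaReader extends DataSourceReader {
  private val resolvedSchema: StructType = StructType(Seq(StructField("intCol", IntegerType)))
  override def readSchema(): StructType = resolvedSchema
  override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] =
    java.util.Collections.emptyList()
}
```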
+ readSchema + } + + private var options: java.util.Map[String, String] = null + private var sparkSession : SparkSession = null + + //The pruned schema + private var schema: StructType = null + + private var pushedFilterArray : Array[Filter] = null + + private var hiveAcidMetadata: HiveAcidMetadata = _ + + override def readSchema: StructType = { + if (schema == null) { + schema = hiveAcidMetadata.tableSchema + } + schema + } + + override def planBatchInputPartitions() : java.util.List[InputPartition[ColumnarBatch]] = { + val factories = new java.util.ArrayList[InputPartition[ColumnarBatch]] + inTxn { + txn: HiveAcidTxn => { + import scala.collection.JavaConversions._ + val reader = new TableReader(sparkSession, txn, hiveAcidMetadata) + val hiveReader = reader.getPartitionsV2(schema.fieldNames, + pushedFilterArray, new SparkAcidConf(sparkSession, options.toMap)) + factories.addAll(hiveReader) + } + } + factories + } + + private def inTxn(f: HiveAcidTxn => Unit): Unit = { + new HiveTxnWrapper(sparkSession).inTxn(f) + } + + override def pushFilters (filters: Array[Filter]): Array[Filter] = { + this.pushedFilterArray = HiveAcidSearchArgument. + getSupportedFilters(hiveAcidMetadata.tableSchema, filters.toSeq).toArray + // ORC does not do row level filtering. So the filters has to be applied again. + filters + } + + override def pushedFilters(): Array[Filter] = this.pushedFilterArray + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.schema = requiredSchema + } +} \ No newline at end of file diff --git a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidTable.scala b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidTable.scala index bd0731b..a3d3a6d 100644 --- a/src/main/scala/com/qubole/spark/hiveacid/HiveAcidTable.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/HiveAcidTable.scala @@ -231,7 +231,7 @@ object HiveAcidTable { * This wrapper can be used just once for running an operation. That operation is not allowed to recursively call this again * @param sparkSession */ -private class HiveTxnWrapper(sparkSession: SparkSession) extends Logging { +private[hiveacid] class HiveTxnWrapper(sparkSession: SparkSession) extends Logging { private var isLocalTxn: Boolean = _ private var curTxn: HiveAcidTxn = _ diff --git a/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidDataSourceV2.scala b/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidDataSourceV2.scala new file mode 100644 index 0000000..38d3bea --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/datasource/HiveAcidDataSourceV2.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2019 Qubole, Inc. All rights reserved. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid.datasource + +import java.util.{ArrayList, List, Map} +import org.apache.spark.sql.sources.v2.reader.DataSourceReader +import org.apache.spark.internal.Logging +import com.qubole.spark.hiveacid.HiveAcidDataSourceV2Reader +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.sources.v2.{ReadSupport,DataSourceOptions,DataSourceV2} + +class HiveAcidDataSourceV2 extends DataSourceV2 with ReadSupport with Logging { + override def createReader (options: DataSourceOptions) : DataSourceReader = { + logInfo("Creating datasource V2 for table" + options.tableName.get) + new HiveAcidDataSourceV2Reader(options.asMap, + SparkSession.getActiveSession.orNull, + options.databaseName.get, options.tableName.get) + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidRDD.scala b/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidRDD.scala index 06d8c24..84e2d1f 100644 --- a/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidRDD.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/rdd/HiveAcidRDD.scala @@ -140,7 +140,7 @@ private[hiveacid] class HiveAcidRDD[K, V](sc: SparkContext, sparkContext.getConf.getBoolean("spark.hadoopRDD.ignoreEmptySplits", defaultValue = false) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. - protected def getJobConf: JobConf = { + def getJobConf: JobConf = { val conf: Configuration = broadcastedConf.value.value if (shouldCloneJobConf) { // Hadoop Configuration objects are not thread-safe, which may lead to various problems if diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/Reader.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/Reader.scala index 172c99c..90937a1 100644 --- a/src/main/scala/com/qubole/spark/hiveacid/reader/Reader.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/Reader.scala @@ -20,14 +20,15 @@ package com.qubole.spark.hiveacid.reader import com.qubole.spark.hiveacid.hive.HiveAcidMetadata - +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow private[reader] trait Reader { def makeRDDForTable(hiveAcidMetadata: HiveAcidMetadata): RDD[InternalRow] - def makeRDDForPartitionedTable(hiveAcidMetadata: HiveAcidMetadata, - partitions: Seq[ReaderPartition]): RDD[InternalRow] + def makeRDDForPartitionedTable(hiveAcidMetadata: HiveAcidMetadata): RDD[InternalRow] + def makeV2ReaderForTable(hiveAcidMetadata: HiveAcidMetadata): java.util.List[InputPartition[ColumnarBatch]] } private[reader] case class ReaderPartition(ptn: Any) diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/TableReader.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/TableReader.scala index 18c1405..4a94408 100644 --- a/src/main/scala/com/qubole/spark/hiveacid/reader/TableReader.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/TableReader.scala @@ -23,12 +23,15 @@ import com.qubole.spark.hiveacid.{HiveAcidOperation, SparkAcidConf} import com.qubole.spark.hiveacid.transaction._ import com.qubole.spark.hiveacid.hive.HiveAcidMetadata import com.qubole.spark.hiveacid.reader.hive.{HiveAcidReader, HiveAcidReaderOptions} +import com.qubole.spark.hiveacid.SparkAcidConf import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources.Filter +import 
org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Table reader object @@ -41,9 +44,9 @@ private[hiveacid] class TableReader(sparkSession: SparkSession, curTxn: HiveAcidTxn, hiveAcidMetadata: HiveAcidMetadata) extends Logging { - def getRdd(requiredColumns: Array[String], + private def getTableReader(requiredColumns: Array[String], filters: Array[Filter], - readConf: SparkAcidConf): RDD[Row] = { + readConf: SparkAcidConf): HiveAcidReader = { val rowIdColumnSet = HiveAcidMetadata.rowIdSchema.fields.map(_.name).toSet val requiredColumnsWithoutRowId = requiredColumns.filterNot(rowIdColumnSet.contains) val partitionColumnNames = hiveAcidMetadata.partitionSchema.fields.map(_.name) @@ -73,7 +76,7 @@ private[hiveacid] class TableReader(sparkSession: SparkSession, // Filters val (partitionFilters, otherFilters) = filters.partition { predicate => !predicate.references.isEmpty && - predicate.references.toSet.subsetOf(partitionedColumnSet) + predicate.references.map(_.toLowerCase).toSet.subsetOf(partitionedColumnSet) } val dataFilters = otherFilters.filter(_ .references.intersect(partitionColumnNames).isEmpty @@ -113,18 +116,38 @@ private[hiveacid] class TableReader(sparkSession: SparkSession, val validWriteIds = HiveAcidTxn.getValidWriteIds(curTxn, hiveAcidMetadata) - val reader = new HiveAcidReader( + new HiveAcidReader( sparkSession, readerOptions, hiveAcidReaderOptions, - validWriteIds) + validWriteIds, + partitions) + } + def getPartitionsV2(requiredColumns: Array[String], + filters: Array[Filter], + readConf: SparkAcidConf): java.util.List[InputPartition[ColumnarBatch]] = { + val reader = getTableReader(requiredColumns, filters, readConf) + if (hiveAcidMetadata.isPartitioned) { + logDebug("getReader for Partitioned table") + reader.makeReaderForPartitionedTable(hiveAcidMetadata) + } else { + logDebug("getReader for non Partitioned table ") + reader.makeV2ReaderForTable(hiveAcidMetadata) + } + } + + def getRdd(requiredColumns: Array[String], + filters: Array[Filter], + readConf: SparkAcidConf): RDD[Row] = { + val reader = getTableReader(requiredColumns, filters, readConf) val rdd = if (hiveAcidMetadata.isPartitioned) { - reader.makeRDDForPartitionedTable(hiveAcidMetadata, partitions) + logDebug("getRdd for Partitioned table") + reader.makeRDDForPartitionedTable(hiveAcidMetadata) } else { + logDebug("getRdd for non Partitioned table ") reader.makeRDDForTable(hiveAcidMetadata) } - rdd.asInstanceOf[RDD[Row]] } } diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReader.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReader.scala index d883767..7e0e8cb 100644 --- a/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReader.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidReader.scala @@ -20,7 +20,7 @@ package com.qubole.spark.hiveacid.reader.hive import java.util -import java.util.Properties +import java.util.{List, Properties} import scala.collection.JavaConverters._ import com.esotericsoftware.kryo.Kryo @@ -43,6 +43,7 @@ import com.qubole.spark.hiveacid.hive.HiveConverter import com.qubole.spark.hiveacid.reader.{Reader, ReaderOptions, ReaderPartition} import com.qubole.spark.hiveacid.rdd._ import com.qubole.spark.hiveacid.util._ +import com.qubole.spark.hiveacid.reader.v2.HiveAcidInputPartitionV2 import org.apache.commons.codec.binary.Base64 import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} @@ -65,6 +66,9 @@ import 
org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.hive.{Hive3Inspectors, HiveAcidUtils} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.types._ /** * Helper class for scanning tables stored in Hadoop - e.g., to read @@ -73,11 +77,13 @@ import org.apache.spark.unsafe.types.UTF8String * @param readerOptions - reader options for creating RDD * @param hiveAcidOptions - hive related reader options for creating RDD * @param validWriteIds - validWriteIds + * @param partitions - The list of partitions to be scanned. Valid only for partitioned table. */ private[reader] class HiveAcidReader(sparkSession: SparkSession, readerOptions: ReaderOptions, hiveAcidOptions: HiveAcidReaderOptions, - validWriteIds: ValidWriteIdList) + validWriteIds: ValidWriteIdList, + partitions: Seq[ReaderPartition]) extends CastSupport with Reader with Logging { @@ -108,13 +114,68 @@ extends CastSupport with Reader with Logging { override def conf: SQLConf = sparkSession.sessionState.conf - /** - * @param hiveAcidMetadata - hive acid metadata for underlying table - * @return - Returns RDD on top of non partitioned hive acid table and list of partitionNames empty list - * for entire table - */ - def makeRDDForTable(hiveAcidMetadata: HiveAcidMetadata): RDD[InternalRow] = { - val hiveTable = hiveAcidMetadata.hTable + def makeV2ReaderForPath(hiveAcidMetadata: HiveAcidMetadata, + path : String, + fieldSchemas: util.List[FieldSchema], + partitionValues : InternalRow): java.util.List[InputPartition[ColumnarBatch]] = { + setReaderOptions(hiveAcidMetadata) + + val ifcName = hiveAcidMetadata.hTable.getInputFormatClass.getName + val inputFormatClass = Util.classForName(ifcName, loadShaded = true) + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + + val colNames = getColumnNamesFromFieldSchema(fieldSchemas) + val colTypes = getColumnTypesFromFieldSchema(fieldSchemas) + val initializeJobConfFunc = HiveAcidReader.initializeLocalJobConfFunc( + path, hiveAcidOptions.tableDesc, + hiveAcidMetadata.hTable.getParameters, + colNames, colTypes) _ + + //TODO :Its a ugly hack, but avoids lots of duplicate code. 
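The `HiveAcidRDD` constructed below is used only to derive the `JobConf` and the Hadoop input splits; each split is then wrapped in a `HiveAcidInputPartitionV2`. As a rough standalone illustration of that split-planning step (hypothetical `tablePath`/`minSplits` parameters, and a plain `TextInputFormat` standing in for the shaded ACID ORC input format):

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.{FileInputFormat, InputSplit, JobConf, TextInputFormat}

// Sketch only: enumerate the Hadoop splits for a path. In the patch, each such split
// later backs exactly one DataSource V2 InputPartition.
def planSplits(tablePath: String, minSplits: Int): Array[InputSplit] = {
  val jobConf = new JobConf()
  FileInputFormat.setInputPaths(jobConf, new Path(tablePath))
  val inputFormat = new TextInputFormat()
  inputFormat.configure(jobConf)
  inputFormat.getSplits(jobConf, minSplits)
}
```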
+ val rdd = new HiveAcidRDD( + sparkSession.sparkContext, + validWriteIds, + hiveAcidOptions.isFullAcidTable, + _broadcastedHadoopConf.asInstanceOf[Broadcast[SerializableConfiguration]], + Some(initializeJobConfFunc), + inputFormatClass, + classOf[Writable], + classOf[Writable], + _minSplitsPerRDD) + val jobConf = rdd.getJobConf + val inputSplits = rdd.getPartitions + + val reqFields = hiveAcidMetadata.tableSchema.fields.filter(field => + readerOptions.requiredNonPartitionedColumns.contains(field.name)) + + val broadCastConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(jobConf)) + val partitionArray = new java.util.ArrayList[InputPartition[ColumnarBatch]] + for (i <- 0 until inputSplits.size) { + partitionArray.add(new HiveAcidInputPartitionV2(inputSplits(i).asInstanceOf[HiveAcidPartition], + broadCastConf, partitionValues, reqFields, hiveAcidMetadata.partitionSchema, hiveAcidMetadata.isFullAcidTable)) + logDebug("getPartitions : Input split: " + inputSplits(i)) + } + partitionArray + } + + def makeV2ReaderForTable(hiveAcidMetadata: HiveAcidMetadata): java.util.List[InputPartition[ColumnarBatch]] = { + makeV2ReaderForPath(hiveAcidMetadata, hiveAcidMetadata.hTable.getPath.toString, + hiveAcidMetadata.hTable.getSd.getCols, + new SpecificInternalRow(hiveAcidMetadata.partitionSchema)) + } + + private def setReaderOptions(hiveAcidMetadata: HiveAcidMetadata) : Unit = { + // Push Down Predicate + if (readerOptions.readConf.predicatePushdownEnabled) { + setPushDownFiltersInHadoopConf(readerOptions.hadoopConf, + hiveAcidMetadata, + readerOptions.dataFilters) + } + + // Set Required column. + setRequiredColumnsInHadoopConf(readerOptions.hadoopConf, + hiveAcidMetadata, + readerOptions.requiredNonPartitionedColumns) logDebug(s"sarg.pushdown: " + s"${readerOptions.hadoopConf.get("sarg.pushdown")}," + @@ -122,9 +183,17 @@ extends CastSupport with Reader with Logging { s"${readerOptions.hadoopConf.get("hive.io.file.readcolumn.names")}, " + s"hive.io.file.readcolumn.ids: " + s"${readerOptions.hadoopConf.get("hive.io.file.readcolumn.ids")}") + } + /** + * @param hiveAcidMetadata - hive acid metadata for underlying table + * @return - Returns RDD on top of non partitioned hive acid table and list of partitionNames empty list + * for entire table + */ + def makeRDDForTable(hiveAcidMetadata: HiveAcidMetadata): RDD[InternalRow] = { + setReaderOptions(hiveAcidMetadata) makeRDDForTable( - hiveTable, + hiveAcidMetadata.hTable, Util.classForName(hiveAcidOptions.tableDesc.getSerdeClassName, loadShaded = true).asInstanceOf[Class[Deserializer]], hiveAcidMetadata, @@ -134,12 +203,10 @@ extends CastSupport with Reader with Logging { /** * @param hiveAcidMetadata - hive acid metadata of underlying table - * @param partitions - partitions for the table * * @return - Returns RDD on top of partitioned hive acid table */ - def makeRDDForPartitionedTable(hiveAcidMetadata: HiveAcidMetadata, - partitions: Seq[ReaderPartition]): RDD[InternalRow] = { + def makeRDDForPartitionedTable(hiveAcidMetadata: HiveAcidMetadata): RDD[InternalRow] = { val partitionToDeserializer = partitions.map(p => p.ptn.asInstanceOf[HiveJarPartition]).map { part => @@ -151,6 +218,59 @@ extends CastSupport with Reader with Logging { makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None, readerOptions) } + def makeReaderForPartitionedTable(hiveAcidMetadata: HiveAcidMetadata): + java.util.ArrayList[InputPartition[ColumnarBatch]] = { + val partitionToDeserializer = getPartitionToDeserializer(partitions) + val 
partitionArray = new java.util.ArrayList[InputPartition[ColumnarBatch]] + val partList = partitionToDeserializer.map { case (partition, partDeserializer) => + val partPath = partition.getDataLocation + val inputPathStr = applyFilterIfNeeded(partPath, None) + val partSpec = partition.getSpec + val partCols = partition.getTable.getPartitionKeys.asScala.map(_.getName) + + // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'. + val partValues = if (partSpec == null) { + Array.fill(partCols.size)(new String) + } else { + partCols.map(col => new String(partSpec.get(col))).toArray + } + + val mutableRow = new SpecificInternalRow(hiveAcidMetadata.partitionSchema) + + val partitionKeyAttrs = + readerOptions.requiredAttributes.zipWithIndex.filter { attr => + readerOptions.partitionAttributes.contains(attr) + } + + //TODO : The partition values can be filled directly using hive acid batch reader. + def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { + var offset = 0 + partitionKeyAttrs.foreach { case (attr, ordinal) => + val partOrdinal = readerOptions.partitionAttributes.indexOf(attr) + row(offset) = cast( + Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) + offset = offset + 1 + } + } + fillPartitionKeys(partValues, mutableRow) + + makeV2ReaderForPath(hiveAcidMetadata, inputPathStr, partition.getTPartition.getSd.getCols, mutableRow) + } + for (list <- partList) partitionArray.addAll(list) + partitionArray + } + + private def getPartitionToDeserializer(partitions: Seq[ReaderPartition]) + : Map[HiveJarPartition, Class[_ <: Deserializer]] = { + partitions.map(p => p.ptn.asInstanceOf[HiveJarPartition]).map { + part => + val deserializerClassName = part.getTPartition.getSd.getSerdeInfo.getSerializationLib + val deserializer = Util.classForName(deserializerClassName, loadShaded = true) + .asInstanceOf[Class[Deserializer]] + (part, deserializer) + }.toMap + } + /** * Creates a Hadoop RDD to read data from the target table's data directory. * Returns a transformed RDD that contains deserialized rows. diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidSearchArgument.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidSearchArgument.scala index 0a1ba9b..813019a 100644 --- a/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidSearchArgument.scala +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/hive/HiveAcidSearchArgument.scala @@ -59,7 +59,7 @@ import org.apache.spark.sql.types._ * builder methods mentioned above can only be found in test code, where all tested filters are * known to be convertible. */ -private[hive] object HiveAcidSearchArgument { +private[hiveacid] object HiveAcidSearchArgument { private def buildTree(filters: Seq[Filter]): Option[Filter] = { filters match { case Seq() => None @@ -75,9 +75,9 @@ private[hive] object HiveAcidSearchArgument { // in order to distinguish predicate pushdown for nested columns. 
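The changes below normalize attribute names to lower case before looking them up in the schema map, and expose the convertible subset of filters through `getSupportedFilters`. For orientation, these are the kinds of Spark source filters the builder can translate into an ORC `SearchArgument` (illustrative column name and values only; the reader still returns all filters from `pushFilters` so Spark re-applies them row by row):

```scala
import org.apache.spark.sql.sources._

// Filters over a searchable column such as intCol that map onto SearchArgument leaves.
val candidateFilters: Seq[Filter] = Seq(
  EqualTo("intCol", 5),
  GreaterThanOrEqual("intCol", 2),
  In("intCol", Array[Any](1, 2, 3)),
  IsNotNull("intCol"))
```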
private def quoteAttributeNameIfNeeded(name: String) : String = { if (!name.contains("`") && name.contains(".")) { - s"`$name`" + s"`${name.toLowerCase()}`" } else { - name + name.toLowerCase() } } @@ -135,7 +135,9 @@ private[hive] object HiveAcidSearchArgument { expression: Filter, builder: Builder): Option[Builder] = { def getType(attribute: String): PredicateLeaf.Type = - getPredicateLeafType(dataTypeMap(attribute)) + getPredicateLeafType(dataTypeMap(attribute.toLowerCase)) + + def getTypeFromMap(attribute: String): DataType = dataTypeMap(attribute.toLowerCase) import org.apache.spark.sql.sources._ @@ -173,47 +175,47 @@ private[hive] object HiveAcidSearchArgument { // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case EqualTo(attribute, value) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getTypeFromMap(attribute)) Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case EqualNullSafe(attribute, value) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getTypeFromMap(attribute)) Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case LessThan(attribute, value) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getTypeFromMap(attribute)) Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case LessThanOrEqual(attribute, value) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getTypeFromMap(attribute)) Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case GreaterThan(attribute, value) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getTypeFromMap(attribute)) Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case GreaterThanOrEqual(attribute, value) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getTypeFromMap(attribute)) Some(builder.startNot().lessThan(quotedName, getType(attribute), 
castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case IsNull(attribute) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case IsNotNull(attribute) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + case In(attribute, values) if isSearchableType(getTypeFromMap(attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + val castedValues = values.map(v => castLiteralValue(v, getTypeFromMap(attribute))) Some(builder.startAnd().in(quotedName, getType(attribute), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) @@ -221,6 +223,15 @@ private[hive] object HiveAcidSearchArgument { } } + def getSupportedFilters(schema: StructType, filters: Seq[Filter]): Seq[Filter] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) + } yield filter + convertibleFilters + } + /** * Create filters as a SearchArgument instance. */ @@ -229,10 +240,7 @@ private[hive] object HiveAcidSearchArgument { // First, tries to convert each filter individually to see whether it's convertible, and then // collect all convertible ones to build the final `SearchArgument`. - val convertibleFilters = for { - filter <- filters - _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) - } yield filter + val convertibleFilters = getSupportedFilters(schema, filters) for { // Combines all convertible filters using `And` to produce a single conjunction diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidColumnVector.java b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidColumnVector.java new file mode 100644 index 0000000..fbe667c --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidColumnVector.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid.reader.v2; + +import java.math.BigDecimal; + +import com.qubole.shaded.hadoop.hive.ql.exec.vector.*; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; +import com.qubole.shaded.hadoop.hive.ql.exec.vector.ColumnVector; + +/** + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarVector. This class is a copy of spark ColumnVector which is declared private. + */ +public class HiveAcidColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private ColumnVector baseData; + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + private final boolean isTimestamp; + + private int batchSize; + + HiveAcidColumnVector(DataType type, ColumnVector vector) { + super(type); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + baseData = vector; + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public void close() { + + } + + @Override + public boolean hasNull() { + return !baseData.noNulls; + } + + @Override + public int numNulls() { + if (baseData.isRepeating) { + if (baseData.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (baseData.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (baseData.isNull[i]) count++; + } + return count; + } + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return baseData.isRepeating ? 
0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return baseData.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int getInt(int rowId) { + return (int) longData.vector[getRowIndex(rowId)]; + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000; + } else { + return longData.vector[index]; + } + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidInputPartitionReaderV2.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidInputPartitionReaderV2.scala new file mode 100644 index 0000000..7144489 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidInputPartitionReaderV2.scala @@ -0,0 +1,89 @@ +package com.qubole.spark.hiveacid.reader.v2 + +import java.io.IOException +import java.util._ +import java.util.List + +import scala.collection.JavaConverters._ +import com.qubole.spark.hiveacid.util.SerializableConfiguration +import com.qubole.spark.hiveacid.rdd.HiveAcidPartition +import org.apache.hadoop.mapred.JobConf +import com.qubole.shaded.orc.TypeDescription +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.TaskContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.hadoop.mapred.TaskAttemptID +import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapred.JobID +import org.apache.spark.sql.execution.datasources.orc._ +import 
com.qubole.shaded.orc.OrcFile +import com.qubole.shaded.orc.OrcConf +import com.qubole.shaded.hadoop.hive.ql.io.orc.OrcSplit +import com.qubole.shaded.hadoop.hive.serde2.ColumnProjectionUtils + +private[v2] class HiveAcidInputPartitionReaderV2(split: HiveAcidPartition, + broadcastedConf: Broadcast[SerializableConfiguration], + partitionValues : InternalRow, + requiredFields: Array[StructField], + partitionSchema : StructType, + isFullAcidTable: Boolean) + extends InputPartitionReader[ColumnarBatch] { + //TODO : Need to get a unique id to cache the jobConf. + private val jobConf = new JobConf(broadcastedConf.value.value) + private val defaultBatchSize = jobConf.get("spark.hive.acid.default.row.batch.size", "1024") + private val orcColumnarBatchReader = new OrcColumnarBatchReader(defaultBatchSize.toInt) + + private def initReader() : Unit = { + // Get the reader schema using the column names and types set in hive conf. + val readerSchema: TypeDescription = + com.qubole.shaded.hadoop.hive.ql.io.orc.OrcInputFormat.getDesiredRowTypeDescr(jobConf, true, 2147483647) + + // Set it as orc.mapred.input.schema so that the reader will read only the required columns + jobConf.set("orc.mapred.input.schema", readerSchema.toString) + + val fileSplit = split.inputSplit.value.asInstanceOf[OrcSplit] + val readerLocal = OrcFile.createReader(fileSplit.getPath, + OrcFile.readerOptions(jobConf).maxLength( + OrcConf.MAX_FILE_LENGTH.getLong(jobConf)).filesystem(fileSplit.getPath.getFileSystem(jobConf))) + + // Get the column id from hive conf para. TODO : Can be sent via a parameter + val colIds = jobConf.get(ColumnProjectionUtils.READ_COLUMN_IDS_CONF_STR) + val requestedColIds = if (!colIds.isEmpty()) { + colIds.split(",").map(a => a.toInt) + } else { + Array[Int]() + } + + // Register the listener for closing the reader before init is done. + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new org.apache.hadoop.mapred.TaskAttemptContextImpl(jobConf, attemptId) + val iter = new org.apache.spark.sql.execution.datasources.RecordReaderIterator(orcColumnarBatchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + + //TODO: Need to generalize it for supporting other kind of file format. 
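One possible shape for the TODO above, purely as a sketch and not part of this change: hide the ORC-specific batch reader behind a narrow trait so another columnar file format could be plugged in later without touching the partition-reader plumbing.

```scala
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical abstraction of what the V2 partition reader needs from a file format.
trait ColumnarFormatReader extends AutoCloseable {
  def nextBatch(): Boolean          // advance to the next batch; false when the split is exhausted
  def currentBatch(): ColumnarBatch // the batch produced by the last successful nextBatch()
}
```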
+ orcColumnarBatchReader.initialize(fileSplit, taskAttemptContext) + orcColumnarBatchReader.initBatch(readerLocal.getSchema, requestedColIds, + requiredFields, partitionSchema, partitionValues, isFullAcidTable, fileSplit.isOriginal) + } + initReader() + + @throws(classOf[IOException]) + override def next() : Boolean = { + orcColumnarBatchReader.nextKeyValue() + } + + override def get () : ColumnarBatch = { + orcColumnarBatchReader.getCurrentValue + } + + @throws(classOf[IOException]) + override def close() : Unit = { + orcColumnarBatchReader.close() + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidInputPartitionV2.scala b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidInputPartitionV2.scala new file mode 100644 index 0000000..c572fa2 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/HiveAcidInputPartitionV2.scala @@ -0,0 +1,35 @@ +package com.qubole.spark.hiveacid.reader.v2 + +import com.qubole.spark.hiveacid.rdd.HiveAcidPartition +import com.qubole.spark.hiveacid.util.{SerializableConfiguration} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.InputFormat +import org.apache.hadoop.mapred.JobConf +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +private[hiveacid] class HiveAcidInputPartitionV2(split: HiveAcidPartition, + broadcastedConf: Broadcast[SerializableConfiguration], + partitionValues : InternalRow, + requiredFields: Array[StructField], + partitionSchema : StructType, + isFullAcidTable: Boolean) + extends InputPartition[ColumnarBatch] { + override def preferredLocations: Array[String] = { + try split.inputSplit.value.getLocations + catch { + case e: Exception => + //preferredLocations specifies to return empty array if no preference + new Array[String] (0) + } + } + + override def createPartitionReader: InputPartitionReader[ColumnarBatch] = { + new HiveAcidInputPartitionReaderV2(split, broadcastedConf, partitionValues, + requiredFields, partitionSchema, isFullAcidTable) + } +} diff --git a/src/main/scala/com/qubole/spark/hiveacid/reader/v2/OrcColumnarBatchReader.java b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/OrcColumnarBatchReader.java new file mode 100644 index 0000000..2a663c3 --- /dev/null +++ b/src/main/scala/com/qubole/spark/hiveacid/reader/v2/OrcColumnarBatchReader.java @@ -0,0 +1,785 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.qubole.spark.hiveacid.reader.v2; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.stream.IntStream; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import com.qubole.shaded.orc.OrcConf; +import com.qubole.shaded.orc.OrcFile; +import com.qubole.shaded.orc.Reader; +import com.qubole.shaded.orc.TypeDescription; +import com.qubole.shaded.orc.mapred.OrcInputFormat; +import com.qubole.shaded.hadoop.hive.common.type.HiveDecimal; +import com.qubole.shaded.hadoop.hive.ql.exec.vector.*; +import com.qubole.shaded.hadoop.hive.serde2.io.HiveDecimalWritable; +import com.qubole.shaded.hadoop.hive.ql.io.orc.OrcSplit; +import com.qubole.shaded.hadoop.hive.ql.io.orc.VectorizedOrcAcidRowBatchReader; +import com.qubole.shaded.hadoop.hive.ql.plan.MapWork; +import com.qubole.shaded.hadoop.hive.ql.exec.Utilities; +import com.qubole.shaded.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx; +import com.qubole.shaded.orc.OrcProto; +import org.apache.hadoop.mapred.Reporter; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.io.NullWritable; +import com.qubole.shaded.orc.OrcUtils; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import com.qubole.shaded.hadoop.hive.ql.io.sarg.*; + +/** + * After creating, `initialize` and `initBatch` should be called sequentially. This internally uses + * the Hive ACID Vectorized ORC reader to support reading of deleted and updated data. + */ +public class OrcColumnarBatchReader extends RecordReader { + + // The capacity of vectorized batch. + private int capacity; + + // Vectorized ORC Row Batch + private VectorizedRowBatch batch; + + // ROW schema. This has the columns that needs to be projected. Even though we dont need + // the ACID related columns like row id, write-id are also projected so that ACID reader can + // use it to filter out the deleted/updated records. + private TypeDescription schema; + + + // The column IDs of the physical ORC file schema which are required by this reader. + // -1 means this required column doesn't exist in the ORC file. + private int[] requestedColIds; + + private StructField[] requiredFields; + + // Record reader from ORC row batch. + private com.qubole.shaded.orc.RecordReader baseRecordReader; + + // Wrapper reader over baseRecordReader for filtering out deleted/updated records. + private VectorizedOrcAcidRowBatchReader fullAcidRecordReader; + + // The result columnar batch for vectorized execution by whole-stage codegen. + private ColumnarBatch columnarBatch; + + // Writable column vectors of the result columnar batch. + private WritableColumnVector[] columnVectors; + + // The wrapped ORC column vectors. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + // File(split) to be read. + private OrcSplit fileSplit; + + private Configuration conf; + + // For full ACID scan, the first 5 fields are transaction related. 
+  // For a full ACID scan, the first 5 fields are transaction related. These fields are used by
+  // fullAcidRecordReader. While forming the batch to emit, the first 5 columns are skipped. For
+  // a normal scan this value is 0, as the ORC file does not have the transaction related columns.
+  private int rootColIdx;
+
+  // Constructor.
+  public OrcColumnarBatchReader(int capacity) {
+    this.capacity = capacity;
+  }
+
+  @Override
+  public Void getCurrentKey() {
+    return null;
+  }
+
+  @Override
+  public ColumnarBatch getCurrentValue() {
+    return columnarBatch;
+  }
+
+  @Override
+  public float getProgress() throws IOException {
+    if (fullAcidRecordReader != null) {
+      return fullAcidRecordReader.getProgress();
+    } else {
+      return baseRecordReader.getProgress();
+    }
+  }
+
+  @Override
+  public boolean nextKeyValue() throws IOException {
+    return nextBatch();
+  }
+
+  @Override
+  public void close() throws IOException {
+    if (columnarBatch != null) {
+      columnarBatch.close();
+      columnarBatch = null;
+    }
+    if (fullAcidRecordReader != null) {
+      fullAcidRecordReader.close();
+      fullAcidRecordReader = null;
+    }
+    if (baseRecordReader != null) {
+      baseRecordReader.close();
+      baseRecordReader = null;
+    }
+  }
+
+  // Returns the names of the columns that are pushed down as search arguments to the ORC reader.
+  private String[] getSargColumnNames(String[] originalColumnNames,
+                                      List<OrcProto.Type> types,
+                                      boolean[] includedColumns,
+                                      boolean isOriginal) {
+    // Skip the ACID related columns if present.
+    int dataColIdx = isOriginal ? 0 : rootColIdx + 1;
+    String[] columnNames = new String[types.size() - dataColIdx];
+    int i = 0;
+    Iterator<Integer> iterator = types.get(dataColIdx).getSubtypesList().iterator();
+
+    while (true) {
+      int columnId;
+      do {
+        if (!iterator.hasNext()) {
+          return columnNames;
+        }
+        columnId = iterator.next();
+      } while (includedColumns != null && !includedColumns[columnId - dataColIdx]);
+      columnNames[columnId - dataColIdx] = originalColumnNames[i++];
+    }
+  }
+
+  private void setSearchArgument(Reader.Options options,
+                                 List<OrcProto.Type> types,
+                                 Configuration conf,
+                                 boolean isOriginal) {
+    String neededColumnNames = conf.get("hive.io.file.readcolumn.names");
+    if (neededColumnNames == null) {
+      options.searchArgument((SearchArgument)null, (String[])null);
+    } else {
+      // The pushed-down filters are set in the config under `sarg.pushdown`.
+      SearchArgument sarg = ConvertAstToSearchArg.createFromConf(conf);
+      if (sarg == null) {
+        options.searchArgument((SearchArgument)null, (String[])null);
+      } else {
+        String[] colNames = getSargColumnNames(neededColumnNames.split(","),
+          types, options.getInclude(), isOriginal);
+        options.searchArgument(sarg, colNames);
+      }
+    }
+  }
+
+  private void setSearchArgumentForOption(Configuration conf,
+                                          TypeDescription readerSchema,
+                                          Reader.Options readerOptions,
+                                          boolean isOriginal) {
+    final List<OrcProto.Type> schemaTypes = OrcUtils.getOrcTypes(readerSchema);
+    setSearchArgument(readerOptions, schemaTypes, conf, isOriginal);
+  }
+
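+  // Note: the search argument is not built here. setSearchArgument() only reads back whatever was
+  // serialized into the Configuration (`sarg.pushdown` via ConvertAstToSearchArg, and the projected
+  // column names under `hive.io.file.readcolumn.names`), presumably set by the data source while
+  // planning the scan; if nothing was pushed down, the scan simply runs without ORC-level filtering.
+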
+  /**
+   * Initialize the ORC file reader and the batch record reader.
+   * Note that `initBatch` needs to be called after this.
+   */
+  @Override
+  public void initialize(InputSplit inputSplit,
+                         TaskAttemptContext taskAttemptContext) {
+    fileSplit = (OrcSplit)inputSplit;
+    conf = taskAttemptContext.getConfiguration();
+  }
+
+  // Wrapper ACID reader over the base ORC record reader.
+  private VectorizedOrcAcidRowBatchReader initHiveAcidReader(Configuration conf,
+                                                             OrcSplit orcSplit,
+                                                             com.qubole.shaded.orc.RecordReader innerReader) {
+    conf.set("hive.vectorized.execution.enabled", "true");
+    MapWork mapWork = new MapWork();
+    VectorizedRowBatchCtx rbCtx = new VectorizedRowBatchCtx();
+    mapWork.setVectorMode(true);
+    mapWork.setVectorizedRowBatchCtx(rbCtx);
+    Utilities.setMapWork(conf, mapWork);
+
+    org.apache.hadoop.mapred.RecordReader<NullWritable, VectorizedRowBatch> baseReader
+      = new org.apache.hadoop.mapred.RecordReader<NullWritable, VectorizedRowBatch>() {
+
+      @Override
+      public boolean next(NullWritable key, VectorizedRowBatch value) throws IOException {
+        // This is the baseRecordReader, which the ACID reader calls internally to fetch
+        // the records.
+        return innerReader.nextBatch(value);
+      }
+
+      @Override
+      public NullWritable createKey() {
+        return NullWritable.get();
+      }
+
+      @Override
+      public VectorizedRowBatch createValue() {
+        // This column batch is passed as the value by the ACID reader when calling next. So the
+        // baseRecordReader populates the batch directly and no extra copy is needed when
+        // `selectedInUse` is false.
+        return batch;
+      }
+
+      @Override
+      public long getPos() throws IOException {
+        return 0;
+      }
+
+      @Override
+      public void close() throws IOException {
+        innerReader.close();
+      }
+
+      @Override
+      public float getProgress() throws IOException {
+        return innerReader.getProgress();
+      }
+    };
+
+    try {
+      return new VectorizedOrcAcidRowBatchReader(orcSplit, new JobConf(conf), Reporter.NULL, baseReader, rbCtx, true);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Initialize the columnar batch by setting the required schema and partition information.
+   * With this information it creates a ColumnarBatch with the full result schema.
+   */
+  public void initBatch(
+      TypeDescription orcSchema,
+      int[] requestedColIds,
+      StructField[] requiredFields,
+      StructType partitionSchema,
+      InternalRow partitionValues,
+      boolean isFullAcidTable,
+      boolean isOriginal
+  ) throws IOException {
+    boolean isAcidScan = isFullAcidTable && !isOriginal;
+    if (!isAcidScan) {
+      //rootCol = org.apache.hadoop.hive.ql.io.orc.OrcInputFormat.getRootColumn(true);
+      rootColIdx = 0;
+    } else {
+      //rootCol = org.apache.hadoop.hive.ql.io.orc.OrcInputFormat.getRootColumn(false) - 1;
+      // In ORC, for a full ACID scan, the first 5 fields store the transaction metadata.
+      rootColIdx = 5;
+    }
+
+    // Create the baseRecordReader. This reader actually does the reading from the ORC file.
+    Reader readerInner = OrcFile.createReader(
+      fileSplit.getPath(), OrcFile.readerOptions(conf)
+        .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
+        .filesystem(fileSplit.getPath().getFileSystem(conf)));
+    Reader.Options options =
+      OrcInputFormat.buildOptions(conf, readerInner, fileSplit.getStart(), fileSplit.getLength());
+    setSearchArgumentForOption(conf, orcSchema, options, isOriginal);
+    baseRecordReader = readerInner.rows(options);
+
+    // This schema has both the required fields and the fields used by the ACID reader.
+    schema = orcSchema;
+    batch = orcSchema.createRowBatch(capacity);
+    assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`.
+
+    this.requiredFields = requiredFields;
+    this.requestedColIds = requestedColIds;
+    assert(requiredFields.length == requestedColIds.length);
+
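+    // From this point on, initBatch prepares two alternative representations of the output batch:
+    //   1. orcVectorWrappers - thin wrappers around the ORC column vectors, used whenever the
+    //      batch read from ORC can be emitted as is (no rows filtered out).
+    //   2. columnVectors     - writable on-heap vectors, allocated only for ACID scans and used
+    //      by nextBatch() to re-assemble the batch when the ACID reader has filtered out rows.
+    // Missing columns and partition columns are constant, so they are populated just once here.
+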
+    // The result schema consists of the required (projected) fields plus the partition columns.
+    StructType resultSchema = new StructType(requiredFields);
+    for (StructField f : partitionSchema.fields()) {
+      resultSchema = resultSchema.add(f);
+    }
+
+    // For an ACID scan, the ACID batch reader might filter out some of the records read from the
+    // ORC file, in which case the batch has to be recreated. These columnVectors are used for
+    // that. Missing columns and partition columns are filled in here; the other columns are
+    // filled in once a batch of records has been read.
+    // TODO: We can set the config to let the ORC reader fill in the partition values.
+    if (isAcidScan) {
+      columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema);
+
+      // Initialize the missing columns once.
+      for (int i = 0; i < requiredFields.length; i++) {
+        if (requestedColIds[i] == -1) {
+          columnVectors[i].putNulls(0, capacity);
+          columnVectors[i].setIsConstant();
+        }
+      }
+
+      if (partitionValues.numFields() > 0) {
+        int partitionIdx = requiredFields.length;
+        for (int i = 0; i < partitionValues.numFields(); i++) {
+          ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
+          columnVectors[i + partitionIdx].setIsConstant();
+        }
+      }
+    }
+
+    // Just wrap the ORC column vectors instead of copying them into Spark column vectors. These
+    // wrappers are used for insert-only table scans, for scanning original files (ACID V1) and
+    // compacted files, and in general whenever the batch read from ORC is emitted as is, so no
+    // separate copy needs to be prepared.
+    ColumnVector[] fields;
+    if (rootColIdx == 0) {
+      fields = batch.cols;
+    } else {
+      fields = ((StructColumnVector)batch.cols[rootColIdx]).fields;
+    }
+
+    orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()];
+    for (int i = 0; i < requiredFields.length; i++) {
+      DataType dt = requiredFields[i].dataType();
+      int colId = requestedColIds[i];
+      // Initialize the missing columns once.
+      if (colId == -1) {
+        OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
+        missingCol.putNulls(0, capacity);
+        missingCol.setIsConstant();
+        orcVectorWrappers[i] = missingCol;
+      } else {
+        orcVectorWrappers[i] = new HiveAcidColumnVector(dt, fields[colId]);
+      }
+    }
+
+    if (partitionValues.numFields() > 0) {
+      int partitionIdx = requiredFields.length;
+      for (int i = 0; i < partitionValues.numFields(); i++) {
+        DataType dt = partitionSchema.fields()[i].dataType();
+        OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt);
+        ColumnVectorUtils.populate(partitionCol, partitionValues, i);
+        partitionCol.setIsConstant();
+        orcVectorWrappers[partitionIdx + i] = partitionCol;
+      }
+    }
+
+    if (isAcidScan) {
+      fullAcidRecordReader = initHiveAcidReader(conf, fileSplit, baseRecordReader);
+    } else {
+      fullAcidRecordReader = null;
+    }
+  }
+
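+  // Note on the VectorizedRowBatch semantics relied upon below: when `selectedInUse` is true,
+  // `selected` holds the indices of the rows that survived filtering and `size` is their count.
+  // For example, if the ACID reader keeps only rows 0, 2 and 5 of a batch, then size == 3 and
+  // selected == {0, 2, 5}; the put* helpers below copy exactly those rows into `columnVectors`,
+  // handling ORC's repeating, no-null and nullable vector encodings separately.
+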
+  /**
+   * Returns true if there is more data in the next batch. For an ACID scan, the ACID batch
+   * reader is used; it internally uses the baseRecordReader and then filters out the deleted
+   * or not-visible records. That filtering is propagated here through `selectedInUse`. If
+   * `selectedInUse` is false, no filtering happened and the orcVectorWrappers can be used
+   * directly. If `selectedInUse` is true, the column batch has to be recreated using the
+   * `selected` array.
+   */
+  private boolean nextBatch() throws IOException {
+    VectorizedRowBatch vrb;
+    if (fullAcidRecordReader != null) {
+      vrb = schema.createRowBatch(capacity);
+      // Internally the ACID batch reader changes the batch schema, so vrb is passed instead of batch.
+      if (!fullAcidRecordReader.next(NullWritable.get(), vrb)) {
+        // The batch size should not be used to detect the end of data for fullAcidRecordReader:
+        // it may be 0 when a whole batch of records has been filtered out.
+        return false;
+      }
+    } else {
+      if (!baseRecordReader.nextBatch(batch)) {
+        //TODO: Should we return false if batch size is 0?
+        return false;
+      }
+      vrb = batch;
+    }
+
+    int batchSize = vrb.size;
+
+    // If selectedInUse is false, no filtering was done, so the wrappers can be used directly and
+    // there is no need to recreate the column batch.
+    if (!vrb.selectedInUse) {
+      for (int i = 0; i < requiredFields.length; i++) {
+        if (requestedColIds[i] != -1) {
+          ((HiveAcidColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize);
+        }
+      }
+      columnarBatch = new ColumnarBatch(orcVectorWrappers);
+      columnarBatch.setNumRows(batchSize);
+      return true;
+    }
+
+    // Recreate the batch using the selected array: only the rows whose indices appear in
+    // `selected` are kept. The batch size may even become 0, but we should still return true so
+    // that the caller calls next again. Before that, reset the column vectors to signal that
+    // they hold no data yet.
+    for (WritableColumnVector toColumn : columnVectors) {
+      toColumn.reset();
+    }
+
+    if (batchSize > 0) {
+      StructColumnVector dataCols = (StructColumnVector)vrb.cols[rootColIdx];
+      for (int i = 0; i < requiredFields.length; i++) {
+        StructField field = requiredFields[i];
+        WritableColumnVector toColumn = columnVectors[i];
+        if (requestedColIds[i] >= 0) {
+          ColumnVector fromColumn = dataCols.fields[requestedColIds[i]];
+          if (fromColumn.isRepeating) {
+            putRepeatingValues(batchSize, field, fromColumn, toColumn);
+          } else if (fromColumn.noNulls) {
+            putNonNullValues(batchSize, field, fromColumn, toColumn, vrb.selected);
+          } else {
+            putValues(batchSize, field, fromColumn, toColumn, vrb.selected);
+          }
+        }
+      }
+    }
+
+    columnarBatch = new ColumnarBatch(columnVectors);
+    columnarBatch.setNumRows(batchSize);
+    return true;
+  }
+
+  private void putRepeatingValues(
+      int batchSize,
+      StructField field,
+      ColumnVector fromColumn,
+      WritableColumnVector toColumn) {
+    if (fromColumn.isNull[0]) {
+      toColumn.putNulls(0, batchSize);
+    } else {
+      DataType type = field.dataType();
+      if (type instanceof BooleanType) {
+        toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1);
+      } else if (type instanceof ByteType) {
+        toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]);
+      } else if (type instanceof ShortType) {
+        toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]);
+      } else if (type instanceof IntegerType || type instanceof DateType) {
+        toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]);
+      } else if (type instanceof LongType) {
+        toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]);
+      } else if (type instanceof TimestampType) {
+        toColumn.putLongs(0, batchSize,
+          fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0));
+      } else if (type instanceof FloatType) {
+        toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]);
+      } else if (type instanceof DoubleType) {
+        toColumn.putDoubles(0, batchSize,
((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = (BytesColumnVector)fromColumn; + int size = data.vector[0].length; + toColumn.arrayData().reserve(size); + toColumn.arrayData().putBytes(0, size, data.vector[0], 0); + for (int index = 0; index < batchSize; index++) { + toColumn.putArray(index, 0, size); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + putDecimalWritables( + toColumn, + batchSize, + decimalType.precision(), + decimalType.scale(), + ((DecimalColumnVector)fromColumn).vector[0]); + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + } + + private void putNonNullValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn, + int[] selected) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putBoolean(index, data[logicalIdx] == 1); + } + } else if (type instanceof ByteType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putByte(index, (byte)data[logicalIdx]); + } + } else if (type instanceof ShortType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putShort(index, (short)data[logicalIdx]); + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putInt(index, (int)data[logicalIdx]); + } + } else if (type instanceof LongType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putLong(index, data[logicalIdx]); + } + //toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0); + } else if (type instanceof TimestampType) { + TimestampColumnVector data = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putLong(index, fromTimestampColumnVector(data, logicalIdx)); + } + } else if (type instanceof FloatType) { + double[] data = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putFloat(index, (float)data[logicalIdx]); + } + } else if (type instanceof DoubleType) { + //toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); + double[] data = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + toColumn.putDouble(index, data[logicalIdx]); + } + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = ((BytesColumnVector)fromColumn); + WritableColumnVector arrayData = toColumn.arrayData(); + int totalNumBytes = IntStream.of(data.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + arrayData.putBytes(pos, data.length[logicalIdx], data.vector[logicalIdx], data.start[logicalIdx]); 
+ toColumn.putArray(index, pos, data.length[logicalIdx]); + pos += data.length[logicalIdx]; + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + DecimalColumnVector data = ((DecimalColumnVector)fromColumn); + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + toColumn.arrayData().reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + data.vector[logicalIdx]); + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + private void putValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn, + int[] selected) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putBoolean(index, vector[logicalIdx] == 1); + } + } + } else if (type instanceof ByteType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putByte(index, (byte)vector[logicalIdx]); + } + } + } else if (type instanceof ShortType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putShort(index, (short)vector[logicalIdx]); + } + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putInt(index, (int)vector[logicalIdx]); + } + } + } else if (type instanceof LongType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, vector[logicalIdx]); + } + } + } else if (type instanceof TimestampType) { + TimestampColumnVector vector = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, fromTimestampColumnVector(vector, logicalIdx)); + } + } + } else if (type instanceof FloatType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putFloat(index, (float)vector[logicalIdx]); + } + } + } else if (type instanceof DoubleType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + int logicalIdx = selected[index]; + if (fromColumn.isNull[logicalIdx]) { + toColumn.putNull(index); + } else { + toColumn.putDouble(index, vector[logicalIdx]); + } + } + } else if (type instanceof 
StringType || type instanceof BinaryType) {
+      BytesColumnVector vector = (BytesColumnVector)fromColumn;
+      WritableColumnVector arrayData = toColumn.arrayData();
+      int totalNumBytes = IntStream.of(vector.length).sum();
+      arrayData.reserve(totalNumBytes);
+      for (int index = 0, pos = 0; index < batchSize; index++) {
+        int logicalIdx = selected[index];
+        if (fromColumn.isNull[logicalIdx]) {
+          toColumn.putNull(index);
+        } else {
+          arrayData.putBytes(pos, vector.length[logicalIdx], vector.vector[logicalIdx], vector.start[logicalIdx]);
+          toColumn.putArray(index, pos, vector.length[logicalIdx]);
+          // Advance by the length of the row that was actually copied (the selected one).
+          pos += vector.length[logicalIdx];
+        }
+      }
+    } else if (type instanceof DecimalType) {
+      DecimalType decimalType = (DecimalType)type;
+      HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector;
+      if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
+        toColumn.arrayData().reserve(batchSize * 16);
+      }
+      for (int index = 0; index < batchSize; index++) {
+        int logicalIdx = selected[index];
+        if (fromColumn.isNull[logicalIdx]) {
+          toColumn.putNull(index);
+        } else {
+          putDecimalWritable(
+            toColumn,
+            index,
+            decimalType.precision(),
+            decimalType.scale(),
+            vector[logicalIdx]);
+        }
+      }
+    } else {
+      throw new UnsupportedOperationException("Unsupported Data Type: " + type);
+    }
+  }
+
+  /**
+   * Returns the number of microseconds since epoch from an element of TimestampColumnVector.
+   */
+  private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) {
+    return vector.time[index] * 1000 + (vector.nanos[index] / 1000 % 1000);
+  }
+
+  /**
+   * Put a `HiveDecimalWritable` to a `WritableColumnVector`.
+   */
+  private static void putDecimalWritable(
+      WritableColumnVector toColumn,
+      int index,
+      int precision,
+      int scale,
+      HiveDecimalWritable decimalWritable) {
+    HiveDecimal decimal = decimalWritable.getHiveDecimal();
+    Decimal value =
+      Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale());
+    value.changePrecision(precision, scale);
+
+    if (precision <= Decimal.MAX_INT_DIGITS()) {
+      toColumn.putInt(index, (int) value.toUnscaledLong());
+    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      toColumn.putLong(index, value.toUnscaledLong());
+    } else {
+      byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
+      toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0);
+      toColumn.putArray(index, index * 16, bytes.length);
+    }
+  }
+
+  /**
+   * Put `HiveDecimalWritable`s to a `WritableColumnVector`.
+   */
+  private static void putDecimalWritables(
+      WritableColumnVector toColumn,
+      int size,
+      int precision,
+      int scale,
+      HiveDecimalWritable decimalWritable) {
+    HiveDecimal decimal = decimalWritable.getHiveDecimal();
+    Decimal value =
+      Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale());
+    value.changePrecision(precision, scale);
+
+    if (precision <= Decimal.MAX_INT_DIGITS()) {
+      toColumn.putInts(0, size, (int) value.toUnscaledLong());
+    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      toColumn.putLongs(0, size, value.toUnscaledLong());
+    } else {
+      byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
+      toColumn.arrayData().reserve(bytes.length);
+      toColumn.arrayData().putBytes(0, bytes.length, bytes, 0);
+      for (int index = 0; index < size; index++) {
+        toColumn.putArray(index, 0, bytes.length);
+      }
+    }
+  }
+}