From 3b21e8f6b8daf97b803246f1d8135a1e2f38cd0a Mon Sep 17 00:00:00 2001 From: Daniel Hunte Date: Thu, 26 Dec 2024 17:39:33 -0800 Subject: [PATCH] feat(fuzzer): Support multiple joins in the join node "toSql" methods for reference query runners (#11801) Summary: Currently, the hash join and nested loop join "toSql" methods for all reference query runners only support a single join. This change extends it to support multiple joins, only needing the join node of the last join in the tree. It traverses up the tree and recursively builds the sql query. Differential Revision: D66977480 --- velox/core/PlanNode.h | 2 + velox/exec/fuzzer/DuckQueryRunner.cpp | 160 ++++++++++++----- velox/exec/fuzzer/DuckQueryRunner.h | 12 ++ velox/exec/fuzzer/JoinFuzzer.cpp | 4 +- velox/exec/fuzzer/PrestoQueryRunner.cpp | 196 +++++++++++++++------ velox/exec/fuzzer/PrestoQueryRunner.h | 22 ++- velox/exec/fuzzer/ReferenceQueryRunner.h | 14 ++ velox/exec/tests/PrestoQueryRunnerTest.cpp | 118 +++++++++++++ 8 files changed, 420 insertions(+), 108 deletions(-) diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index d9085215860ab..549fa4327e408 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -324,6 +324,8 @@ class ValuesNode : public PlanNode { const size_t repeatTimes_; }; +using ValuesNodePtr = std::shared_ptr; + class ArrowStreamNode : public PlanNode { public: ArrowStreamNode( diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index d6d606f6497e1..39ae4114c19fd 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -102,6 +102,38 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const { return kAggregationFunctionDataSpecs; } +std::string DuckQueryRunner::getTableName( + const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); +} + +std::unordered_map> +DuckQueryRunner::getAllTablesAndNames(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTablesAndNames(source); + result.insert(tablesAndNames.begin(), tablesAndNames.end()); + } + } + return result; +} + +std::multiset> DuckQueryRunner::execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + DuckDbQueryRunner queryRunner; + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + for (const auto& [tableName, input] : inputMap) { + queryRunner.createTable(tableName, input); + } + return queryRunner.execute(sql, plan->outputType()); +} + std::multiset> DuckQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -341,38 +373,53 @@ std::optional DuckQueryRunner::toSql( return sql.str(); } -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - const auto& joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; } - return out.str(); - }; + out << keys[i]->name(); + } + return out.str(); +} - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + if (!joinNode.leftKeys().empty()) { + out << " AND "; } - return out.str(); - }; + out << filterToSql(joinNode.filter()); + } + return out.str(); +}; + +std::optional DuckQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto probeSubQuery = toSql(joinNode->sources()[0]), + buildSubQuery = toSql(joinNode->sources()[1]); + probeSubQuery && buildSubQuery) { + probeTableName = fmt::format("({})", *probeSubQuery); + buildTableName = fmt::format("({})", *buildSubQuery); + } else { + return std::nullopt; + } const auto& outputNames = joinNode->outputType()->names(); @@ -386,24 +433,27 @@ std::optional DuckQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in - // DuckDB. A scalar subquery expression is a subquery that returns one + // Presto. A scalar subquery expression is a subquery that returns one // result row from exactly one column for every input row. if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -412,29 +462,31 @@ std::optional DuckQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -448,6 +500,17 @@ std::optional DuckQueryRunner::toSql( std::optional DuckQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto probeSubQuery = toSql(joinNode->sources()[0]), + buildSubQuery = toSql(joinNode->sources()[1]); + probeSubQuery && buildSubQuery) { + probeTableName = fmt::format("({})", *probeSubQuery); + buildTableName = fmt::format("({})", *buildSubQuery); + } else { + return std::nullopt; + } + std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -458,13 +521,16 @@ std::optional DuckQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( diff --git a/velox/exec/fuzzer/DuckQueryRunner.h b/velox/exec/fuzzer/DuckQueryRunner.h index 4fa826af04884..6de5f4684fdff 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.h +++ b/velox/exec/fuzzer/DuckQueryRunner.h @@ -46,6 +46,18 @@ class DuckQueryRunner : public ReferenceQueryRunner { /// Assumes that source of AggregationNode or Window Node is 'tmp' table. std::optional toSql(const core::PlanNodePtr& plan) override; + /// Returns the name of the values node table in the form t_. + std::string getTableName(const core::ValuesNodePtr& valuesNode); + + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTablesAndNames(const core::PlanNodePtr& plan); + + /// Executes SQL query returned by the 'toSql' method based on the plan. + std::multiset> execute( + const std::string& sql, + const core::PlanNodePtr& plan) override; + /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns /// results according to 'resultType' schema. std::multiset> execute( diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 1860eca9df0bb..21e4a7de81fcc 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -680,10 +680,8 @@ std::optional JoinFuzzer::computeReferenceResults( } if (auto sql = referenceQueryRunner_->toSql(plan)) { - return referenceQueryRunner_->execute( - sql.value(), probeInput, buildInput, plan->outputType()); + return referenceQueryRunner_->execute(*sql, plan); } - LOG(INFO) << "Query not supported by the reference DB"; return std::nullopt; } diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index c8bba9cdb64df..b7952b9eed863 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -554,46 +554,62 @@ std::optional PrestoQueryRunner::toSql( return sql.str(); } -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) { - return std::nullopt; +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; + } + out << keys[i]->name(); } + return out.str(); +}; - if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) { - return std::nullopt; - } +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +}; - const auto joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - return out.str(); - }; + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + if (!joinNode.leftKeys().empty()) { + out << " AND "; + } + out << filterToSql(joinNode.filter()); + } + return out.str(); +}; - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +std::string PrestoQueryRunner::getTableName( + const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); - } - return out.str(); - }; +std::optional PrestoQueryRunner::toSql( + const std::shared_ptr& joinNode) { + if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) || + !isSupportedDwrfType(joinNode->sources()[1]->outputType())) { + return std::nullopt; + } + std::string probeTableName; + std::string buildTableName; + if (const auto probeSubQuery = toSql(joinNode->sources()[0]), + buildSubQuery = toSql(joinNode->sources()[1]); + probeSubQuery && buildSubQuery) { + probeTableName = fmt::format("({})", *probeSubQuery); + buildTableName = fmt::format("({})", *buildSubQuery); + } else { + return std::nullopt; + } const auto& outputNames = joinNode->outputType()->names(); @@ -607,13 +623,16 @@ std::optional PrestoQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in @@ -622,9 +641,9 @@ std::optional PrestoQueryRunner::toSql( if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -633,29 +652,31 @@ std::optional PrestoQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -668,6 +689,17 @@ std::optional PrestoQueryRunner::toSql( std::optional PrestoQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + if (const auto probeSubQuery = toSql(joinNode->sources()[0]), + buildSubQuery = toSql(joinNode->sources()[1]); + probeSubQuery && buildSubQuery) { + probeTableName = fmt::format("({})", *probeSubQuery); + buildTableName = fmt::format("({})", *buildSubQuery); + } else { + return std::nullopt; + } + std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -678,13 +710,16 @@ std::optional PrestoQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( @@ -695,11 +730,17 @@ std::optional PrestoQueryRunner::toSql( } std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& valuesNode) { + const core::ValuesNodePtr& valuesNode) { if (!isSupportedDwrfType(valuesNode->outputType())) { return std::nullopt; } - return "tmp"; + return getTableName(valuesNode); +} + +std::multiset> PrestoQueryRunner::execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + return exec::test::materialize(executeVector(sql, plan)); } std::multiset> PrestoQueryRunner::execute( @@ -749,6 +790,51 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } +std::unordered_map> +PrestoQueryRunner::getAllTablesAndNames(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTablesAndNames(source); + result.insert(tablesAndNames.begin(), tablesAndNames.end()); + } + } + return result; +} + +std::vector PrestoQueryRunner::executeVector( + const std::string& sql, + const core::PlanNodePtr& plan) { + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + for (const auto& [tableName, input] : inputMap) { + auto inputType = asRowType(input[0]->type()); + if (inputType->size() == 0) { + inputMap[tableName] = { + makeNullRows(input, fmt::format("{}x", tableName), pool())}; + } + } + + auto writerPool = aggregatePool()->addAggregateChild("writer"); + for (const auto& [tableName, input] : inputMap) { + auto tableDirectoryPath = createTable(tableName, input[0]->type()); + + // Create a new file in table's directory with fuzzer-generated data. + auto filePath = fs::path(tableDirectoryPath) + .append(fmt::format("{}.dwrf", tableName)) + .string() + .substr(strlen("file:")); + + writeToFile(filePath, input, writerPool.get()); + } + + // Run the query. + return execute(sql); +} + std::vector PrestoQueryRunner::executeVector( const std::string& sql, const std::vector& probeInput, diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index a72cae913e101..b8ab4795eae89 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -70,6 +70,9 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool isSupported(const exec::FunctionSignature& signature) override; + /// Returns the name of the values node table in the form t_. + std::string getTableName(const core::ValuesNodePtr& valuesNode); + /// Creates 'tmp' table using specified data, executes SQL query generated by /// 'toSql' and returns the results. /// @@ -83,6 +86,12 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const velox::RowTypePtr& resultType) override; + /// Executes SQL query returned by the 'toSql' method based on the plan. + /// Returns std::nullopt if the plan is not supported. + std::multiset> execute( + const std::string& sql, + const core::PlanNodePtr& plan) override; + std::multiset> execute( const std::string& sql, const std::vector& probeInput, @@ -100,6 +109,14 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool supportsVeloxVectorResults() const override; + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTablesAndNames(const core::PlanNodePtr& plan); + + std::vector executeVector( + const std::string& sql, + const core::PlanNodePtr& plan) override; + std::vector executeVector( const std::string& sql, const std::vector& input, @@ -137,13 +154,12 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::shared_ptr& tableWriteNode); std::optional toSql( - const std::shared_ptr& joinNode); + const std::shared_ptr& joinNode); std::optional toSql( const std::shared_ptr& joinNode); - std::optional toSql( - const std::shared_ptr& valuesNode); + std::optional toSql(const core::ValuesNodePtr& valuesNode); std::string startQuery( const std::string& sql, diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 5d0c24afdc246..8e475191dccf2 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -66,6 +66,13 @@ class ReferenceQueryRunner { return true; } + /// Executes SQL query returned by the 'toSql' method based on the plan. + virtual std::multiset> execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + VELOX_UNSUPPORTED(); + } + /// Executes SQL query returned by the 'toSql' method using 'input' data. /// Converts results using 'resultType' schema. virtual std::multiset> execute( @@ -88,6 +95,13 @@ class ReferenceQueryRunner { return false; } + /// Similar to 'execute' but returns results in RowVector format. + virtual std::vector executeVector( + const std::string& sql, + const core::PlanNodePtr& plan) { + VELOX_UNSUPPORTED(); + } + /// Similar to 'execute' but returns results in RowVector format. /// Caller should ensure 'supportsVeloxVectorResults' returns true. virtual std::vector executeVector( diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index 25b231dc6c7c1..14447f5eb3967 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -255,4 +255,122 @@ TEST_F(PrestoQueryRunnerTest, toSql) { } } +TEST_F(PrestoQueryRunnerTest, toSqlJoins) { + auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins"); + auto queryRunner = std::make_unique( + aggregatePool.get(), + "http://unused", + "hive", + static_cast(1000)); + + auto t = makeRowVector( + {"t0", "t1", "t2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto u = makeRowVector( + {"u0", "u1", "u2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto v = makeRowVector( + {"v0", "v1", "v2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto w = makeRowVector( + {"w0", "w1", "w2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + + // Single join. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0"); + } + + // Two joins with a filter. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0"}, + core::JoinType::kLeftSemiFilter) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + "v1 > 0", + {"t0", "v1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, v1" + " FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))" + " INNER JOIN t_3 ON t0 = v0 AND (cast(v1 as BIGINT) > BIGINT '0')"); + } + + // Three joins. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kLeft) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + /*filter=*/"", + {"t0", "v1"}, + core::JoinType::kInner) + .hashJoin( + {"t0", "v1"}, + {"w0", "w1"}, + PlanBuilder(planNodeIdGenerator).values({w}).planNode(), + /*filter=*/"", + {"t0", "w1"}, + core::JoinType::kFull) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, w1" + " FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)" + " INNER JOIN t_3 ON t0 = v0)" + " FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1"); + } +} + } // namespace facebook::velox::exec::test