From 731d39fa593070000dc950b17ed8293856d2c2bb Mon Sep 17 00:00:00 2001 From: Daniel Hunte Date: Sat, 14 Dec 2024 23:07:55 -0800 Subject: [PATCH] feat(fuzzer): Support multiple joins in the join node "toSql" methods for reference query runners (#11801) Summary: Currently, the hash join and nested loop join "toSql" methods for all reference query runners only support a single join. This change extends it to support multiple joins, only needing the join node of the last join in the tree. It traverses up the tree and recursively builds the sql query. Differential Revision: D66977480 --- velox/exec/fuzzer/DuckQueryRunner.cpp | 161 +++++++++++----- velox/exec/fuzzer/DuckQueryRunner.h | 8 + velox/exec/fuzzer/JoinFuzzer.cpp | 6 +- velox/exec/fuzzer/PrestoQueryRunner.cpp | 204 +++++++++++++++------ velox/exec/fuzzer/PrestoQueryRunner.h | 14 ++ velox/exec/fuzzer/ReferenceQueryRunner.h | 8 + velox/exec/tests/PrestoQueryRunnerTest.cpp | 97 ++++++++++ 7 files changed, 397 insertions(+), 101 deletions(-) diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index d6d606f6497e1..5d497bbabcb6c 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -102,6 +102,17 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const { return kAggregationFunctionDataSpecs; } +std::multiset> DuckQueryRunner::execute( + const std::string& sql, + const std::unordered_map>& inputMap, + const RowTypePtr& resultType) { + DuckDbQueryRunner queryRunner; + for (const auto& [tableName, input] : inputMap) { + queryRunner.createTable(tableName, input); + } + return queryRunner.execute(sql, resultType); +} + std::multiset> DuckQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -341,38 +352,62 @@ std::optional DuckQueryRunner::toSql( return sql.str(); } -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - const auto& joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; } - return out.str(); - }; + out << keys[i]->name(); + } + return out.str(); +} - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +} + +std::optional DuckQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "MergeJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } } - return out.str(); - }; + } const auto& outputNames = joinNode->outputType()->names(); @@ -386,24 +421,27 @@ std::optional DuckQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in - // DuckDB. A scalar subquery expression is a subquery that returns one + // Presto. A scalar subquery expression is a subquery that returns one // result row from exactly one column for every input row. if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -412,29 +450,31 @@ std::optional DuckQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -448,6 +488,30 @@ std::optional DuckQueryRunner::toSql( std::optional DuckQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "MergeJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } + } + } + + const auto& outputNames = joinNode->outputType()->names(); std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -458,13 +522,16 @@ std::optional DuckQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( diff --git a/velox/exec/fuzzer/DuckQueryRunner.h b/velox/exec/fuzzer/DuckQueryRunner.h index 4fa826af04884..97184b0f7ce48 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.h +++ b/velox/exec/fuzzer/DuckQueryRunner.h @@ -46,6 +46,14 @@ class DuckQueryRunner : public ReferenceQueryRunner { /// Assumes that source of AggregationNode or Window Node is 'tmp' table. std::optional toSql(const core::PlanNodePtr& plan) override; + /// Creates tables for each entry in 'inputs' and runs 'sql' query. Returns + /// results according to 'resultType' schema. + std::multiset> execute( + const std::string& sql, + const std::unordered_map>& + inputMap, + const RowTypePtr& resultType) override; + /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns /// results according to 'resultType' schema. std::multiset> execute( diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 333a79c24b746..09cbc6bdbc66d 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -680,8 +680,10 @@ std::optional JoinFuzzer::computeReferenceResults( } if (auto sql = referenceQueryRunner_->toSql(plan)) { - return referenceQueryRunner_->execute( - sql.value(), probeInput, buildInput, plan->outputType()); + std::unordered_map> inputs = { + {fmt::format("t_{}", plan->sources()[0]->id()), probeInput}, + {fmt::format("t_{}", plan->sources()[1]->id()), buildInput}}; + return referenceQueryRunner_->execute(*sql, inputs, plan->outputType()); } LOG(INFO) << "Query not supported by the reference DB"; diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index c8bba9cdb64df..21ad7bd8951cc 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -554,46 +554,70 @@ std::optional PrestoQueryRunner::toSql( return sql.str(); } -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) { - return std::nullopt; +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; + } + out << keys[i]->name(); } + return out.str(); +}; - if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) { - return std::nullopt; - } +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +}; - const auto joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - return out.str(); - }; - - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +}; - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); +std::optional PrestoQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + // Presto does not support merge join. + if (joinNode->sources()[i]->name() == "MergeJoin") { + return std::nullopt; } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } } - return out.str(); - }; + } + + if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) || + !isSupportedDwrfType(joinNode->sources()[1]->outputType())) { + return std::nullopt; + } const auto& outputNames = joinNode->outputType()->names(); @@ -607,13 +631,16 @@ std::optional PrestoQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in @@ -622,9 +649,9 @@ std::optional PrestoQueryRunner::toSql( if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -633,29 +660,31 @@ std::optional PrestoQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -668,6 +697,33 @@ std::optional PrestoQueryRunner::toSql( std::optional PrestoQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + // Presto does not support merge join. + if (joinNode->sources()[i]->name() == "MergeJoin") { + return std::nullopt; + } + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } + } + } + + const auto& outputNames = joinNode->outputType()->names(); std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -678,13 +734,16 @@ std::optional PrestoQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( @@ -702,6 +761,13 @@ std::optional PrestoQueryRunner::toSql( return "tmp"; } +std::multiset> PrestoQueryRunner::execute( + const std::string& sql, + const std::unordered_map>& inputMap, + const RowTypePtr& resultType) { + return exec::test::materialize(executeVector(sql, inputMap, resultType)); +} + std::multiset> PrestoQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -749,6 +815,40 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } +std::vector PrestoQueryRunner::executeVector( + const std::string& sql, + const std::unordered_map>& inputMap, + const velox::RowTypePtr& resultType) { + std::unordered_map> inputMapWithNulls; + for (const auto& [tableName, input] : inputMap) { + auto inputType = asRowType(input[0]->type()); + if (inputType->size() == 0) { + inputMapWithNulls.insert( + {tableName, + {makeNullRows(input, fmt::format("{}x", tableName), pool())}}); + } + } + + auto writerPool = aggregatePool()->addAggregateChild("writer"); + for (const auto& [tableName, input] : inputMap) { + const std::vector& currInput = + inputMapWithNulls.contains(tableName) ? inputMapWithNulls[tableName] + : input; + auto tableDirectoryPath = createTable(tableName, currInput[0]->type()); + + // Create a new file in table's directory with fuzzer-generated data. + auto filePath = fs::path(tableDirectoryPath) + .append(fmt::format("{}.dwrf", tableName)) + .string() + .substr(strlen("file:")); + + writeToFile(filePath, currInput, writerPool.get()); + } + + // Run the query. + return execute(sql); +} + std::vector PrestoQueryRunner::executeVector( const std::string& sql, const std::vector& probeInput, diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index a72cae913e101..830454252302c 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -83,6 +83,14 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const velox::RowTypePtr& resultType) override; + /// Creates tables for each entry in 'inputs' and runs 'sql' query. Returns + /// results according to 'resultType' schema. + std::multiset> execute( + const std::string& sql, + const std::unordered_map>& + inputMap, + const RowTypePtr& resultType) override; + std::multiset> execute( const std::string& sql, const std::vector& probeInput, @@ -100,6 +108,12 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool supportsVeloxVectorResults() const override; + std::vector executeVector( + const std::string& sql, + const std::unordered_map>& + inputMap, + const velox::RowTypePtr& resultType); + std::vector executeVector( const std::string& sql, const std::vector& input, diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 5d0c24afdc246..ae0ed12e6459d 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -66,6 +66,14 @@ class ReferenceQueryRunner { return true; } + /// Executes SQL query returned by the 'toSql' method using 'inputs' data. + /// Converts results using 'resultType' schema. + virtual std::multiset> execute( + const std::string& sql, + const std::unordered_map>& + inputMap, + const RowTypePtr& resultType) = 0; + /// Executes SQL query returned by the 'toSql' method using 'input' data. /// Converts results using 'resultType' schema. virtual std::multiset> execute( diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index 25b231dc6c7c1..0ac157f861374 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -255,4 +255,101 @@ TEST_F(PrestoQueryRunnerTest, toSql) { } } +TEST_F(PrestoQueryRunnerTest, toSqlJoins) { + auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins"); + auto queryRunner = std::make_unique( + aggregatePool.get(), + "http://unused", + "hive", + static_cast(1000)); + + auto t = ROW({"t0", "t1", "t2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + auto u = ROW({"u0", "u1", "u2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + auto v = ROW({"v0", "v1", "v2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + auto w = ROW({"w0", "w1", "w2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + + // Single join. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan("t", t) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0"); + } + + // Two joins with a filter. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan("t", t) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(), + /*filter=*/"", + {"t0"}, + core::JoinType::kLeftSemiFilter) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).tableScan("v", v).planNode(), + "v1 > 0", + {"t0", "v1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, v1" + " FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))" + " INNER JOIN t_3 ON t0 = v0 AND (v1 > BIGINT '0')"); + } + + // Three joins. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan("t", t) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kLeft) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).tableScan("v", v).planNode(), + /*filter=*/"", + {"t0", "v1"}, + core::JoinType::kInner) + .hashJoin( + {"t0", "v1"}, + {"w0", "w1"}, + PlanBuilder(planNodeIdGenerator).tableScan("w", w).planNode(), + /*filter=*/"", + {"t0", "w1"}, + core::JoinType::kFull) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, w1" + " FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)" + " INNER JOIN t_3 ON t0 = v0)" + " FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1"); + } +} + } // namespace facebook::velox::exec::test