diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index d6d606f6497e1..5d497bbabcb6c 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -102,6 +102,17 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const { return kAggregationFunctionDataSpecs; } +std::multiset> DuckQueryRunner::execute( + const std::string& sql, + const std::unordered_map>& inputMap, + const RowTypePtr& resultType) { + DuckDbQueryRunner queryRunner; + for (const auto& [tableName, input] : inputMap) { + queryRunner.createTable(tableName, input); + } + return queryRunner.execute(sql, resultType); +} + std::multiset> DuckQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -341,38 +352,62 @@ std::optional DuckQueryRunner::toSql( return sql.str(); } -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - const auto& joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; } - return out.str(); - }; + out << keys[i]->name(); + } + return out.str(); +} - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +} - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +} + +std::optional DuckQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "MergeJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } } - return out.str(); - }; + } const auto& outputNames = joinNode->outputType()->names(); @@ -386,24 +421,27 @@ std::optional DuckQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in - // DuckDB. A scalar subquery expression is a subquery that returns one + // Presto. A scalar subquery expression is a subquery that returns one // result row from exactly one column for every input row. if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -412,29 +450,31 @@ std::optional DuckQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -448,6 +488,30 @@ std::optional DuckQueryRunner::toSql( std::optional DuckQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "MergeJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } + } + } + + const auto& outputNames = joinNode->outputType()->names(); std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -458,13 +522,16 @@ std::optional DuckQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( diff --git a/velox/exec/fuzzer/DuckQueryRunner.h b/velox/exec/fuzzer/DuckQueryRunner.h index 4fa826af04884..97184b0f7ce48 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.h +++ b/velox/exec/fuzzer/DuckQueryRunner.h @@ -46,6 +46,14 @@ class DuckQueryRunner : public ReferenceQueryRunner { /// Assumes that source of AggregationNode or Window Node is 'tmp' table. std::optional toSql(const core::PlanNodePtr& plan) override; + /// Creates tables for each entry in 'inputs' and runs 'sql' query. Returns + /// results according to 'resultType' schema. + std::multiset> execute( + const std::string& sql, + const std::unordered_map>& + inputMap, + const RowTypePtr& resultType) override; + /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns /// results according to 'resultType' schema. std::multiset> execute( diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 333a79c24b746..09cbc6bdbc66d 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -680,8 +680,10 @@ std::optional JoinFuzzer::computeReferenceResults( } if (auto sql = referenceQueryRunner_->toSql(plan)) { - return referenceQueryRunner_->execute( - sql.value(), probeInput, buildInput, plan->outputType()); + std::unordered_map> inputs = { + {fmt::format("t_{}", plan->sources()[0]->id()), probeInput}, + {fmt::format("t_{}", plan->sources()[1]->id()), buildInput}}; + return referenceQueryRunner_->execute(*sql, inputs, plan->outputType()); } LOG(INFO) << "Query not supported by the reference DB"; diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index c8bba9cdb64df..21ad7bd8951cc 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -554,46 +554,70 @@ std::optional PrestoQueryRunner::toSql( return sql.str(); } -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) { - return std::nullopt; +static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; + } + out << keys[i]->name(); } + return out.str(); +}; - if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) { - return std::nullopt; - } +static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +}; - const auto joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); +static std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; } - return out.str(); - }; - - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + out << " AND " << filterToSql(joinNode.filter()); + } + return out.str(); +}; - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); +std::optional PrestoQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + // Presto does not support merge join. + if (joinNode->sources()[i]->name() == "MergeJoin") { + return std::nullopt; } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } } - return out.str(); - }; + } + + if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) || + !isSupportedDwrfType(joinNode->sources()[1]->outputType())) { + return std::nullopt; + } const auto& outputNames = joinNode->outputType()->names(); @@ -607,13 +631,16 @@ std::optional PrestoQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); break; case core::JoinType::kLeftSemiFilter: // Multiple columns returned by a scalar subquery is not supported in @@ -622,9 +649,9 @@ std::optional PrestoQueryRunner::toSql( if (joinNode->leftKeys().size() > 1) { return std::nullopt; } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } @@ -633,29 +660,31 @@ std::optional PrestoQueryRunner::toSql( case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } - sql << ") FROM t"; + sql << ") FROM " << probeTableName; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; if (joinNode->filter()) { sql << " WHERE " << filterToSql(joinNode->filter()); } sql << ")"; } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); + sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM " + << buildTableName << " WHERE " << joinConditionAsSql(*joinNode); sql << ")"; } break; @@ -668,6 +697,33 @@ std::optional PrestoQueryRunner::toSql( std::optional PrestoQueryRunner::toSql( const std::shared_ptr& joinNode) { + std::string probeTableName = + fmt::format("t_{}", joinNode->sources()[0]->id()); + std::string buildTableName = + fmt::format("t_{}", joinNode->sources()[1]->id()); + + // If an input to this join is another join, change the table name to a + // sub-query. + for (auto i = 0; i < joinNode->sources().size(); ++i) { + // Presto does not support merge join. + if (joinNode->sources()[i]->name() == "MergeJoin") { + return std::nullopt; + } + if (joinNode->sources()[i]->name() == "HashJoin" || + joinNode->sources()[i]->name() == "NestedLoopJoin") { + if (auto sql = toSql(joinNode->sources()[i])) { + if (i == 0) { + probeTableName = fmt::format("({})", *sql); + } else { + buildTableName = fmt::format("({})", *sql); + } + } else { + return std::nullopt; + } + } + } + + const auto& outputNames = joinNode->outputType()->names(); std::stringstream sql; sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); @@ -678,13 +734,16 @@ std::optional PrestoQueryRunner::toSql( const std::string joinCondition{"(1 = 1)"}; switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; + sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName + << " ON " << joinCondition; break; default: VELOX_UNREACHABLE( @@ -702,6 +761,13 @@ std::optional PrestoQueryRunner::toSql( return "tmp"; } +std::multiset> PrestoQueryRunner::execute( + const std::string& sql, + const std::unordered_map>& inputMap, + const RowTypePtr& resultType) { + return exec::test::materialize(executeVector(sql, inputMap, resultType)); +} + std::multiset> PrestoQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -749,6 +815,40 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } +std::vector PrestoQueryRunner::executeVector( + const std::string& sql, + const std::unordered_map>& inputMap, + const velox::RowTypePtr& resultType) { + std::unordered_map> inputMapWithNulls; + for (const auto& [tableName, input] : inputMap) { + auto inputType = asRowType(input[0]->type()); + if (inputType->size() == 0) { + inputMapWithNulls.insert( + {tableName, + {makeNullRows(input, fmt::format("{}x", tableName), pool())}}); + } + } + + auto writerPool = aggregatePool()->addAggregateChild("writer"); + for (const auto& [tableName, input] : inputMap) { + const std::vector& currInput = + inputMapWithNulls.contains(tableName) ? inputMapWithNulls[tableName] + : input; + auto tableDirectoryPath = createTable(tableName, currInput[0]->type()); + + // Create a new file in table's directory with fuzzer-generated data. + auto filePath = fs::path(tableDirectoryPath) + .append(fmt::format("{}.dwrf", tableName)) + .string() + .substr(strlen("file:")); + + writeToFile(filePath, currInput, writerPool.get()); + } + + // Run the query. + return execute(sql); +} + std::vector PrestoQueryRunner::executeVector( const std::string& sql, const std::vector& probeInput, diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index a72cae913e101..830454252302c 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -83,6 +83,14 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const velox::RowTypePtr& resultType) override; + /// Creates tables for each entry in 'inputs' and runs 'sql' query. Returns + /// results according to 'resultType' schema. + std::multiset> execute( + const std::string& sql, + const std::unordered_map>& + inputMap, + const RowTypePtr& resultType) override; + std::multiset> execute( const std::string& sql, const std::vector& probeInput, @@ -100,6 +108,12 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool supportsVeloxVectorResults() const override; + std::vector executeVector( + const std::string& sql, + const std::unordered_map>& + inputMap, + const velox::RowTypePtr& resultType); + std::vector executeVector( const std::string& sql, const std::vector& input, diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 5d0c24afdc246..ae0ed12e6459d 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -66,6 +66,14 @@ class ReferenceQueryRunner { return true; } + /// Executes SQL query returned by the 'toSql' method using 'inputs' data. + /// Converts results using 'resultType' schema. + virtual std::multiset> execute( + const std::string& sql, + const std::unordered_map>& + inputMap, + const RowTypePtr& resultType) = 0; + /// Executes SQL query returned by the 'toSql' method using 'input' data. /// Converts results using 'resultType' schema. virtual std::multiset> execute( diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index 25b231dc6c7c1..0ac157f861374 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -255,4 +255,101 @@ TEST_F(PrestoQueryRunnerTest, toSql) { } } +TEST_F(PrestoQueryRunnerTest, toSqlJoins) { + auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins"); + auto queryRunner = std::make_unique( + aggregatePool.get(), + "http://unused", + "hive", + static_cast(1000)); + + auto t = ROW({"t0", "t1", "t2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + auto u = ROW({"u0", "u1", "u2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + auto v = ROW({"v0", "v1", "v2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + auto w = ROW({"w0", "w1", "w2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + + // Single join. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan("t", t) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0"); + } + + // Two joins with a filter. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan("t", t) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(), + /*filter=*/"", + {"t0"}, + core::JoinType::kLeftSemiFilter) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).tableScan("v", v).planNode(), + "v1 > 0", + {"t0", "v1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, v1" + " FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))" + " INNER JOIN t_3 ON t0 = v0 AND (v1 > BIGINT '0')"); + } + + // Three joins. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan("t", t) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kLeft) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).tableScan("v", v).planNode(), + /*filter=*/"", + {"t0", "v1"}, + core::JoinType::kInner) + .hashJoin( + {"t0", "v1"}, + {"w0", "w1"}, + PlanBuilder(planNodeIdGenerator).tableScan("w", w).planNode(), + /*filter=*/"", + {"t0", "w1"}, + core::JoinType::kFull) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, w1" + " FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)" + " INNER JOIN t_3 ON t0 = v0)" + " FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1"); + } +} + } // namespace facebook::velox::exec::test