diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index d9085215860ab..549fa4327e408 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -324,6 +324,8 @@ class ValuesNode : public PlanNode { const size_t repeatTimes_; }; +using ValuesNodePtr = std::shared_ptr; + class ArrowStreamNode : public PlanNode { public: ArrowStreamNode( diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index d6d606f6497e1..fa8eeb5787414 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -102,6 +102,18 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const { return kAggregationFunctionDataSpecs; } +std::multiset> DuckQueryRunner::execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + DuckDbQueryRunner queryRunner; + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + for (const auto& [tableName, input] : inputMap) { + queryRunner.createTable(tableName, input); + } + return queryRunner.execute(sql, plan->outputType()); +} + std::multiset> DuckQueryRunner::execute( const std::string& sql, const std::vector& input, @@ -164,6 +176,11 @@ std::optional DuckQueryRunner::toSql( return toSql(joinNode); } + if (const auto joinNode = + std::dynamic_pointer_cast(plan)) { + return toSql(joinNode); + } + VELOX_NYI(); } @@ -340,137 +357,4 @@ std::optional DuckQueryRunner::toSql( return sql.str(); } - -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - const auto& joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); - } - return out.str(); - }; - - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; - - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); - } - return out.str(); - }; - - const auto& outputNames = joinNode->outputType()->names(); - - std::stringstream sql; - if (joinNode->isLeftSemiProjectJoin()) { - sql << "SELECT " - << folly::join(", ", outputNames.begin(), --outputNames.end()); - } else { - sql << "SELECT " << folly::join(", ", outputNames); - } - - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeftSemiFilter: - // Multiple columns returned by a scalar subquery is not supported in - // DuckDB. A scalar subquery expression is a subquery that returns one - // result row from exactly one column for every input row. - if (joinNode->leftKeys().size() > 1) { - return std::nullopt; - } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - break; - case core::JoinType::kLeftSemiProject: - if (joinNode->isNullAware()) { - sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ") FROM t"; - } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; - } - break; - case core::JoinType::kAnti: - if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ")"; - } - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - - return sql.str(); -} - -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - std::stringstream sql; - sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); - - // Nested loop join without filter. - VELOX_CHECK( - joinNode->joinCondition() == nullptr, - "This code path should be called only for nested loop join without filter"); - const std::string joinCondition{"(1 = 1)"}; - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - - return sql.str(); -} } // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/DuckQueryRunner.h b/velox/exec/fuzzer/DuckQueryRunner.h index 4fa826af04884..5825c1d06619d 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.h +++ b/velox/exec/fuzzer/DuckQueryRunner.h @@ -46,6 +46,11 @@ class DuckQueryRunner : public ReferenceQueryRunner { /// Assumes that source of AggregationNode or Window Node is 'tmp' table. std::optional toSql(const core::PlanNodePtr& plan) override; + /// Executes SQL query returned by the 'toSql' method based on the plan. + std::multiset> execute( + const std::string& sql, + const core::PlanNodePtr& plan) override; + /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns /// results according to 'resultType' schema. std::multiset> execute( @@ -60,6 +65,8 @@ class DuckQueryRunner : public ReferenceQueryRunner { const RowTypePtr& resultType) override; private: + using ReferenceQueryRunner::toSql; + std::optional toSql( const std::shared_ptr& aggregationNode); @@ -72,12 +79,6 @@ class DuckQueryRunner : public ReferenceQueryRunner { std::optional toSql( const std::shared_ptr& rowNumberNode); - std::optional toSql( - const std::shared_ptr& joinNode); - - std::optional toSql( - const std::shared_ptr& joinNode); - std::unordered_set aggregateFunctionNames_; }; diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 1860eca9df0bb..21e4a7de81fcc 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -680,10 +680,8 @@ std::optional JoinFuzzer::computeReferenceResults( } if (auto sql = referenceQueryRunner_->toSql(plan)) { - return referenceQueryRunner_->execute( - sql.value(), probeInput, buildInput, plan->outputType()); + return referenceQueryRunner_->execute(*sql, plan); } - LOG(INFO) << "Query not supported by the reference DB"; return std::nullopt; } diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index c8bba9cdb64df..9e0169197fd54 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -28,7 +28,6 @@ #include "velox/dwio/common/WriterFactory.h" #include "velox/dwio/dwrf/writer/Writer.h" #include "velox/exec/fuzzer/FuzzerUtil.h" -#include "velox/exec/fuzzer/ToSQLUtil.h" #include "velox/exec/tests/utils/QueryAssertions.h" #include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/IPPrefixType.h" @@ -221,20 +220,6 @@ std::string toWindowCallSql( return sql.str(); } -bool isSupportedDwrfType(const TypePtr& type) { - if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown()) { - return false; - } - - for (auto i = 0; i < type->size(); ++i) { - if (!isSupportedDwrfType(type->childAt(i))) { - return false; - } - } - - return true; -} - } // namespace const std::vector& PrestoQueryRunner::supportedScalarTypes() const { @@ -554,152 +539,10 @@ std::optional PrestoQueryRunner::toSql( return sql.str(); } -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) { - return std::nullopt; - } - - if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) { - return std::nullopt; - } - - const auto joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); - } - return out.str(); - }; - - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; - - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); - } - return out.str(); - }; - - const auto& outputNames = joinNode->outputType()->names(); - - std::stringstream sql; - if (joinNode->isLeftSemiProjectJoin()) { - sql << "SELECT " - << folly::join(", ", outputNames.begin(), --outputNames.end()); - } else { - sql << "SELECT " << folly::join(", ", outputNames); - } - - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeftSemiFilter: - // Multiple columns returned by a scalar subquery is not supported in - // Presto. A scalar subquery expression is a subquery that returns one - // result row from exactly one column for every input row. - if (joinNode->leftKeys().size() > 1) { - return std::nullopt; - } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - break; - case core::JoinType::kLeftSemiProject: - if (joinNode->isNullAware()) { - sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ") FROM t"; - } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; - } - break; - case core::JoinType::kAnti: - if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ")"; - } - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - return sql.str(); -} - -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - std::stringstream sql; - sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); - - // Nested loop join without filter. - VELOX_CHECK( - joinNode->joinCondition() == nullptr, - "This code path should be called only for nested loop join without filter"); - const std::string joinCondition{"(1 = 1)"}; - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - - return sql.str(); -} - -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& valuesNode) { - if (!isSupportedDwrfType(valuesNode->outputType())) { - return std::nullopt; - } - return "tmp"; +std::multiset> PrestoQueryRunner::execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + return exec::test::materialize(executeVector(sql, plan)); } std::multiset> PrestoQueryRunner::execute( @@ -749,6 +592,36 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } +std::vector PrestoQueryRunner::executeVector( + const std::string& sql, + const core::PlanNodePtr& plan) { + std::unordered_map> inputMap = + getAllTablesAndNames(plan); + for (const auto& [tableName, input] : inputMap) { + auto inputType = asRowType(input[0]->type()); + if (inputType->size() == 0) { + inputMap[tableName] = { + makeNullRows(input, fmt::format("{}x", tableName), pool())}; + } + } + + auto writerPool = aggregatePool()->addAggregateChild("writer"); + for (const auto& [tableName, input] : inputMap) { + auto tableDirectoryPath = createTable(tableName, input[0]->type()); + + // Create a new file in table's directory with fuzzer-generated data. + auto filePath = fs::path(tableDirectoryPath) + .append(fmt::format("{}.dwrf", tableName)) + .string() + .substr(strlen("file:")); + + writeToFile(filePath, input, writerPool.get()); + } + + // Run the query. + return execute(sql); +} + std::vector PrestoQueryRunner::executeVector( const std::string& sql, const std::vector& probeInput, diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index a72cae913e101..c4aacdad520de 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -83,6 +83,12 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const velox::RowTypePtr& resultType) override; + /// Executes SQL query returned by the 'toSql' method based on the plan. + /// Returns std::nullopt if the plan is not supported. + std::multiset> execute( + const std::string& sql, + const core::PlanNodePtr& plan) override; + std::multiset> execute( const std::string& sql, const std::vector& probeInput, @@ -100,6 +106,10 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { bool supportsVeloxVectorResults() const override; + std::vector executeVector( + const std::string& sql, + const core::PlanNodePtr& plan) override; + std::vector executeVector( const std::string& sql, const std::vector& input, @@ -116,6 +126,8 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { } private: + using ReferenceQueryRunner::toSql; + memory::MemoryPool* pool() { return pool_.get(); } @@ -136,15 +148,6 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { std::optional toSql( const std::shared_ptr& tableWriteNode); - std::optional toSql( - const std::shared_ptr& joinNode); - - std::optional toSql( - const std::shared_ptr& joinNode); - - std::optional toSql( - const std::shared_ptr& valuesNode); - std::string startQuery( const std::string& sql, const std::string& sessionProperty = ""); diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 5d0c24afdc246..b42e9f809cc56 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -17,6 +17,7 @@ #include #include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/ToSQLUtil.h" #include "velox/expression/FunctionSignature.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -66,6 +67,13 @@ class ReferenceQueryRunner { return true; } + /// Executes SQL query returned by the 'toSql' method based on the plan. + virtual std::multiset> execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + VELOX_UNSUPPORTED(); + } + /// Executes SQL query returned by the 'toSql' method using 'input' data. /// Converts results using 'resultType' schema. virtual std::multiset> execute( @@ -88,6 +96,13 @@ class ReferenceQueryRunner { return false; } + /// Similar to 'execute' but returns results in RowVector format. + virtual std::vector executeVector( + const std::string& sql, + const core::PlanNodePtr& plan) { + VELOX_UNSUPPORTED(); + } + /// Similar to 'execute' but returns results in RowVector format. /// Caller should ensure 'supportsVeloxVectorResults' returns true. virtual std::vector executeVector( @@ -115,6 +130,233 @@ class ReferenceQueryRunner { const std::string& sessionProperty) { VELOX_UNSUPPORTED(); } + /// Returns the name of the values node table in the form t_. + std::string getTableName(const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); + } + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTablesAndNames(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTablesAndNames(source); + result.insert(tablesAndNames.begin(), tablesAndNames.end()); + } + } + return result; + } + + bool isSupportedDwrfType(const TypePtr& type) { + if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown()) { + return false; + } + + for (auto i = 0; i < type->size(); ++i) { + if (!isSupportedDwrfType(type->childAt(i))) { + return false; + } + } + + return true; + } + + static const std::string joinKeysToSql( + const std::vector& keys) { + std::stringstream out; + for (auto i = 0; i < keys.size(); ++i) { + if (i > 0) { + out << ", "; + } + out << keys[i]->name(); + } + return out.str(); + }; + + static std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); + }; + + static std::string joinConditionAsSql( + const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; + } + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + if (!joinNode.leftKeys().empty()) { + out << " AND "; + } + out << filterToSql(joinNode.filter()); + } + return out.str(); + }; + + /// Same as the above toSql but for values join nodes. + virtual std::optional toSql( + const core::ValuesNodePtr& valuesNode) { + if (!isSupportedDwrfType(valuesNode->outputType())) { + return std::nullopt; + } + return getTableName(valuesNode); + } + + /// Same as the above toSql but for hash join nodes. + virtual std::optional toSql( + const std::shared_ptr& joinNode) { + if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) || + !isSupportedDwrfType(joinNode->sources()[1]->outputType())) { + return std::nullopt; + } + std::string probeTableName; + std::string buildTableName; + const std::optional probeSubQuery = + toSql(joinNode->sources()[0]); + const std::optional buildSubQuery = + toSql(joinNode->sources()[1]); + if (probeSubQuery && buildSubQuery) { + probeTableName = probeSubQuery->find(" ") != std::string::npos + ? fmt::format("({})", *probeSubQuery) + : *probeSubQuery; + buildTableName = buildSubQuery->find(" ") != std::string::npos + ? fmt::format("({})", *buildSubQuery) + : *buildSubQuery; + } else { + return std::nullopt; + } + + const auto& outputNames = joinNode->outputType()->names(); + + std::stringstream sql; + if (joinNode->isLeftSemiProjectJoin()) { + sql << "SELECT " + << folly::join(", ", outputNames.begin(), --outputNames.end()); + } else { + sql << "SELECT " << folly::join(", ", outputNames); + } + + switch (joinNode->joinType()) { + case core::JoinType::kInner: + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); + break; + case core::JoinType::kLeft: + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinConditionAsSql(*joinNode); + break; + case core::JoinType::kFull: + sql << " FROM " << probeTableName << " FULL OUTER JOIN " + << buildTableName << " ON " << joinConditionAsSql(*joinNode); + break; + case core::JoinType::kLeftSemiFilter: + // Multiple columns returned by a scalar subquery is not supported in + // Presto. A scalar subquery expression is a subquery that returns one + // result row from exactly one column for every input row. + if (joinNode->leftKeys().size() > 1) { + return std::nullopt; + } + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; + break; + case core::JoinType::kLeftSemiProject: + if (joinNode->isNullAware()) { + sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ") FROM " << probeTableName; + } else { + sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << probeTableName; + } + break; + case core::JoinType::kAnti: + if (joinNode->isNullAware()) { + sql << " FROM " << probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << buildTableName; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; + } else { + sql << " FROM " << probeTableName + << " WHERE NOT EXISTS (SELECT * FROM " << buildTableName + << " WHERE " << joinConditionAsSql(*joinNode); + sql << ")"; + } + break; + default: + VELOX_UNREACHABLE( + "Unknown join type: {}", static_cast(joinNode->joinType())); + } + return sql.str(); + } + + virtual std::optional toSql( + const std::shared_ptr& joinNode) { + std::string probeTableName; + std::string buildTableName; + const std::optional probeSubQuery = + toSql(joinNode->sources()[0]); + const std::optional buildSubQuery = + toSql(joinNode->sources()[1]); + if (probeSubQuery && buildSubQuery) { + probeTableName = probeSubQuery->find(" ") != std::string::npos + ? fmt::format("({})", *probeSubQuery) + : *probeSubQuery; + buildTableName = buildSubQuery->find(" ") != std::string::npos + ? fmt::format("({})", *buildSubQuery) + : *buildSubQuery; + } else { + return std::nullopt; + } + + std::stringstream sql; + sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); + + // Nested loop join without filter. + VELOX_CHECK( + joinNode->joinCondition() == nullptr, + "This code path should be called only for nested loop join without filter"); + const std::string joinCondition{"(1 = 1)"}; + switch (joinNode->joinType()) { + case core::JoinType::kInner: + sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName + << " ON " << joinCondition; + break; + case core::JoinType::kLeft: + sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName + << " ON " << joinCondition; + break; + case core::JoinType::kFull: + sql << " FROM " << probeTableName << " FULL OUTER JOIN " + << buildTableName << " ON " << joinCondition; + break; + default: + VELOX_UNREACHABLE( + "Unknown join type: {}", static_cast(joinNode->joinType())); + } + return sql.str(); + } protected: memory::MemoryPool* aggregatePool() { diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index 25b231dc6c7c1..14447f5eb3967 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -255,4 +255,122 @@ TEST_F(PrestoQueryRunnerTest, toSql) { } } +TEST_F(PrestoQueryRunnerTest, toSqlJoins) { + auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins"); + auto queryRunner = std::make_unique( + aggregatePool.get(), + "http://unused", + "hive", + static_cast(1000)); + + auto t = makeRowVector( + {"t0", "t1", "t2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto u = makeRowVector( + {"u0", "u1", "u2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto v = makeRowVector( + {"v0", "v1", "v2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto w = makeRowVector( + {"w0", "w1", "w2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + + // Single join. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0"); + } + + // Two joins with a filter. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0"}, + core::JoinType::kLeftSemiFilter) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + "v1 > 0", + {"t0", "v1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, v1" + " FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))" + " INNER JOIN t_3 ON t0 = v0 AND (cast(v1 as BIGINT) > BIGINT '0')"); + } + + // Three joins. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kLeft) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + /*filter=*/"", + {"t0", "v1"}, + core::JoinType::kInner) + .hashJoin( + {"t0", "v1"}, + {"w0", "w1"}, + PlanBuilder(planNodeIdGenerator).values({w}).planNode(), + /*filter=*/"", + {"t0", "w1"}, + core::JoinType::kFull) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, w1" + " FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)" + " INNER JOIN t_3 ON t0 = v0)" + " FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1"); + } +} + } // namespace facebook::velox::exec::test