From 47de11deaca8525c0e7aee0eca56f40fe26601dc Mon Sep 17 00:00:00 2001 From: Daniel Hunte Date: Mon, 25 Nov 2024 00:32:22 -0800 Subject: [PATCH] fix(fuzzer): Add filter parsing to toSql methods for hasJoinNode in ReferenceQueryRunners (#11566) Summary: This change updates both the DuckQueryRunner and PrestoQueryRunner to parse filters in their hasJoinNode toSql methods. Differential Revision: D66021799 --- velox/exec/fuzzer/DuckQueryRunner.cpp | 44 ++++++++++++++++++------ velox/exec/fuzzer/PrestoQueryRunner.cpp | 45 +++++++++++++++++++------ 2 files changed, 68 insertions(+), 21 deletions(-) diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index 15b74efae88eb..3b10efef767c4 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -354,7 +354,12 @@ std::optional DuckQueryRunner::toSql( return out.str(); }; - const auto& equiClausesToSql = [](auto joinNode) { + const auto filterToSql = [](core::TypedExprPtr filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); + }; + + const auto& joinConditionAsSql = [&](auto joinNode) { std::stringstream out; for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { if (i > 0) { @@ -363,6 +368,9 @@ std::optional DuckQueryRunner::toSql( out << joinNode->leftKeys()[i]->name() << " = " << joinNode->rightKeys()[i]->name(); } + if (joinNode->filter()) { + out << " AND " << filterToSql(joinNode->filter()); + } return out.str(); }; @@ -378,39 +386,55 @@ std::optional DuckQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode); + sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode); + sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode); + sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); break; case core::JoinType::kLeftSemiFilter: + // Multiple columns returned by a scalar subquery is not supported in + // DuckDB. if (joinNode->leftKeys().size() > 1) { return std::nullopt; } sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u)"; + << " FROM u"; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; break; case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ") FROM t"; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode) - << ") FROM t"; + sql << ", EXISTS (SELECT * FROM u WHERE " + << joinConditionAsSql(joinNode); + sql << ") FROM t"; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u)"; + << " FROM u"; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; } else { sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << equiClausesToSql(joinNode) << ")"; + << joinConditionAsSql(joinNode); + sql << ")"; } break; default: diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index d1613579b265e..3e9119373e12d 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -569,7 +569,12 @@ std::optional PrestoQueryRunner::toSql( return out.str(); }; - const auto equiClausesToSql = [](auto joinNode) { + const auto filterToSql = [](core::TypedExprPtr filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); + }; + + const auto& joinConditionAsSql = [&](auto joinNode) { std::stringstream out; for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { if (i > 0) { @@ -578,6 +583,9 @@ std::optional PrestoQueryRunner::toSql( out << joinNode->leftKeys()[i]->name() << " = " << joinNode->rightKeys()[i]->name(); } + if (joinNode->filter()) { + out << " AND " << filterToSql(joinNode->filter()); + } return out.str(); }; @@ -593,46 +601,61 @@ std::optional PrestoQueryRunner::toSql( switch (joinNode->joinType()) { case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode); + sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); break; case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode); + sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); break; case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode); + sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); break; case core::JoinType::kLeftSemiFilter: + // Multiple columns returned by a scalar subquery is not supported in + // Presto. if (joinNode->leftKeys().size() > 1) { return std::nullopt; } sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u)"; + << " FROM u"; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; break; case core::JoinType::kLeftSemiProject: if (joinNode->isNullAware()) { sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t"; + << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ") FROM t"; } else { - sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode) - << ") FROM t"; + sql << ", EXISTS (SELECT * FROM u WHERE " + << joinConditionAsSql(joinNode); + sql << ") FROM t"; } break; case core::JoinType::kAnti: if (joinNode->isNullAware()) { sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u)"; + << " FROM u"; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; } else { sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << equiClausesToSql(joinNode) << ")"; + << joinConditionAsSql(joinNode); + sql << ")"; } break; default: VELOX_UNREACHABLE( "Unknown join type: {}", static_cast(joinNode->joinType())); } - return sql.str(); }