Skip to content

Commit

Permalink
fix(fuzzer): Add filter parsing to toSql methods for hasJoinNode in ReferenceQueryRunners (facebookincubator#11566)
Browse files Browse the repository at this point in the history

Summary:

This change updates both DuckQueryRunner and PrestoQueryRunner so that their HashJoinNode toSql methods also translate the join node's filter into the generated SQL.

Differential Revision: D66021799
  • Loading branch information
Daniel Hunte authored and facebook-github-bot committed Nov 25, 2024
1 parent 78d761b commit 47de11d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 21 deletions.
44 changes: 34 additions & 10 deletions velox/exec/fuzzer/DuckQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,12 @@ std::optional<std::string> DuckQueryRunner::toSql(
return out.str();
};

const auto& equiClausesToSql = [](auto joinNode) {
// Converts a join filter expression to text via toCallSql; the result is
// streamed into the generated DuckDB query by the callers below.
// NOTE(review): std::dynamic_pointer_cast yields an empty pointer when
// `filter` is not a core::CallTypedExpr, and that result is forwarded
// unchecked -- confirm that join filters are always call expressions or that
// toCallSql tolerates a null argument.
const auto filterToSql = [](core::TypedExprPtr filter) {
auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(filter);
return toCallSql(call);
};

const auto& joinConditionAsSql = [&](auto joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode->leftKeys().size(); ++i) {
if (i > 0) {
Expand All @@ -363,6 +368,9 @@ std::optional<std::string> DuckQueryRunner::toSql(
out << joinNode->leftKeys()[i]->name() << " = "
<< joinNode->rightKeys()[i]->name();
}
if (joinNode->filter()) {
out << " AND " << filterToSql(joinNode->filter());
}
return out.str();
};

Expand All @@ -378,39 +386,55 @@ std::optional<std::string> DuckQueryRunner::toSql(

switch (joinNode->joinType()) {
case core::JoinType::kInner:
sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeft:
sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kFull:
sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeftSemiFilter:
// Multiple columns returned by a scalar subquery is not supported in
// DuckDB.
if (joinNode->leftKeys().size() > 1) {
return std::nullopt;
}
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
break;
case core::JoinType::kLeftSemiProject:
if (joinNode->isNullAware()) {
sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ") FROM t";
} else {
sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
<< ") FROM t";
sql << ", EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode);
sql << ") FROM t";
}
break;
case core::JoinType::kAnti:
if (joinNode->isNullAware()) {
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
} else {
sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
<< equiClausesToSql(joinNode) << ")";
<< joinConditionAsSql(joinNode);
sql << ")";
}
break;
default:
Expand Down
45 changes: 34 additions & 11 deletions velox/exec/fuzzer/PrestoQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,12 @@ std::optional<std::string> PrestoQueryRunner::toSql(
return out.str();
};

const auto equiClausesToSql = [](auto joinNode) {
// Converts a join filter expression to text via toCallSql; the result is
// streamed into the generated Presto query by the callers below.
// NOTE(review): std::dynamic_pointer_cast yields an empty pointer when
// `filter` is not a core::CallTypedExpr, and that result is forwarded
// unchecked -- confirm that join filters are always call expressions or that
// toCallSql tolerates a null argument.
const auto filterToSql = [](core::TypedExprPtr filter) {
auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(filter);
return toCallSql(call);
};

const auto& joinConditionAsSql = [&](auto joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode->leftKeys().size(); ++i) {
if (i > 0) {
Expand All @@ -578,6 +583,9 @@ std::optional<std::string> PrestoQueryRunner::toSql(
out << joinNode->leftKeys()[i]->name() << " = "
<< joinNode->rightKeys()[i]->name();
}
if (joinNode->filter()) {
out << " AND " << filterToSql(joinNode->filter());
}
return out.str();
};

Expand All @@ -593,46 +601,61 @@ std::optional<std::string> PrestoQueryRunner::toSql(

switch (joinNode->joinType()) {
case core::JoinType::kInner:
sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeft:
sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kFull:
sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeftSemiFilter:
// Multiple columns returned by a scalar subquery is not supported in
// Presto.
if (joinNode->leftKeys().size() > 1) {
return std::nullopt;
}
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
break;
case core::JoinType::kLeftSemiProject:
if (joinNode->isNullAware()) {
sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ") FROM t";
} else {
sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
<< ") FROM t";
sql << ", EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode);
sql << ") FROM t";
}
break;
case core::JoinType::kAnti:
if (joinNode->isNullAware()) {
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
} else {
sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
<< equiClausesToSql(joinNode) << ")";
<< joinConditionAsSql(joinNode);
sql << ")";
}
break;
default:
VELOX_UNREACHABLE(
"Unknown join type: {}", static_cast<int>(joinNode->joinType()));
}

return sql.str();
}

Expand Down

0 comments on commit 47de11d

Please sign in to comment.