Skip to content

Commit

Permalink
feat(fuzzer): Support multiple joins in the hash node "toSql" method …
Browse files Browse the repository at this point in the history
…in Presto QR (facebookincubator#11801)

Summary:

Currently, the hash join "toSql" method for PrestoQueryRunner only supports a single join. This change extends it to support multiple joins: given only the node of the last join in the plan, it recurses into that node's inputs and builds the SQL query, emitting each nested join as a sub-query.

Differential Revision: D66977480
  • Loading branch information
Daniel Hunte authored and facebook-github-bot committed Dec 9, 2024
1 parent dfad495 commit 6e9702c
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 60 deletions.
138 changes: 78 additions & 60 deletions velox/exec/fuzzer/PrestoQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,46 +548,70 @@ std::optional<std::string> PrestoQueryRunner::toSql(
return sql.str();
}

std::optional<std::string> PrestoQueryRunner::toSql(
const std::shared_ptr<const core::HashJoinNode>& joinNode) {
if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) {
return std::nullopt;
/// Renders the names of the given join keys as a comma-separated list, in the
/// order the keys appear in `keys` (e.g. "t0, t1").
///
/// @param keys Join key expressions; only each key's name() is read.
/// @return The key names joined with ", ", or an empty string when `keys` is
/// empty.
const std::string joinKeysToSql(
    const std::vector<core::FieldAccessTypedExprPtr>& keys) {
  std::stringstream out;
  // The separator is empty before the first key and ", " afterwards, so no
  // index (and no signed/unsigned comparison against keys.size()) is needed.
  const char* separator = "";
  for (const auto& key : keys) {
    out << separator << key->name();
    separator = ", ";
  }
  return out.str();
}

if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) {
return std::nullopt;
}
/// Converts a join filter expression to SQL text by downcasting it to a
/// core::CallTypedExpr and delegating to toCallSql.
///
/// NOTE(review): if `filter` is not a call expression the downcast yields
/// nullptr, which is passed to toCallSql unchanged - confirm that callers only
/// supply call expressions or that toCallSql tolerates null.
///
/// @param filter The join filter expression.
/// @return Whatever toCallSql produces for the downcast expression.
std::string filterToSql(const core::TypedExprPtr& filter) {
  return toCallSql(
      std::dynamic_pointer_cast<const core::CallTypedExpr>(filter));
}

const auto joinKeysToSql = [](auto keys) {
std::stringstream out;
for (auto i = 0; i < keys.size(); ++i) {
if (i > 0) {
out << ", ";
}
out << keys[i]->name();
std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode.leftKeys().size(); ++i) {
if (i > 0) {
out << " AND ";
}
return out.str();
};

const auto filterToSql = [](core::TypedExprPtr filter) {
auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(filter);
return toCallSql(call);
};
out << joinNode.leftKeys()[i]->name() << " = "
<< joinNode.rightKeys()[i]->name();
}
if (joinNode.filter()) {
out << " AND " << filterToSql(joinNode.filter());
}
return out.str();
};

const auto& joinConditionAsSql = [&](auto joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode->leftKeys().size(); ++i) {
if (i > 0) {
out << " AND ";
}
out << joinNode->leftKeys()[i]->name() << " = "
<< joinNode->rightKeys()[i]->name();
std::optional<std::string> PrestoQueryRunner::toSql(
const std::shared_ptr<const core::HashJoinNode>& joinNode) {
std::string probeTableName =
fmt::format("t_{}", joinNode->sources()[0]->id());
std::string buildTableName =
fmt::format("t_{}", joinNode->sources()[1]->id());

// If an input to this join is another join, change the table name to a
// sub-query.
for (auto i = 0; i < joinNode->sources().size(); ++i) {
// Presto does not support merge join.
if (joinNode->sources()[i]->name() == "MergeJoin") {
return std::nullopt;
}
if (joinNode->filter()) {
out << " AND " << filterToSql(joinNode->filter());
if (joinNode->sources()[i]->name() == "HashJoin" ||
joinNode->sources()[i]->name() == "NestedLoopJoin") {
if (auto sql = toSql(joinNode->sources()[i])) {
if (i == 0) {
probeTableName = fmt::format("({})", *sql);
} else {
buildTableName = fmt::format("({})", *sql);
}
} else {
return std::nullopt;
}
}
return out.str();
};
}

if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) ||
!isSupportedDwrfType(joinNode->sources()[1]->outputType())) {
return std::nullopt;
}

const auto& outputNames = joinNode->outputType()->names();

Expand All @@ -601,13 +625,16 @@ std::optional<std::string> PrestoQueryRunner::toSql(

switch (joinNode->joinType()) {
case core::JoinType::kInner:
sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode);
sql << " FROM " << probeTableName << " INNER JOIN " << buildTableName
<< " ON " << joinConditionAsSql(*joinNode);
break;
case core::JoinType::kLeft:
sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode);
sql << " FROM " << probeTableName << " LEFT JOIN " << buildTableName
<< " ON " << joinConditionAsSql(*joinNode);
break;
case core::JoinType::kFull:
sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode);
sql << " FROM " << probeTableName << " FULL OUTER JOIN " << buildTableName
<< " ON " << joinConditionAsSql(*joinNode);
break;
case core::JoinType::kLeftSemiFilter:
// Multiple columns returned by a scalar subquery is not supported in
Expand All @@ -616,9 +643,9 @@ std::optional<std::string> PrestoQueryRunner::toSql(
if (joinNode->leftKeys().size() > 1) {
return std::nullopt;
}
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u";
sql << " FROM " << probeTableName << " WHERE "
<< joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM " << buildTableName;
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
Expand All @@ -627,29 +654,31 @@ std::optional<std::string> PrestoQueryRunner::toSql(
case core::JoinType::kLeftSemiProject:
if (joinNode->isNullAware()) {
sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u";
<< joinKeysToSql(joinNode->rightKeys()) << " FROM "
<< buildTableName;
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ") FROM t";
sql << ") FROM " << probeTableName;
} else {
sql << ", EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode);
sql << ") FROM t";
sql << ", EXISTS (SELECT * FROM " << buildTableName << " WHERE "
<< joinConditionAsSql(*joinNode);
sql << ") FROM " << probeTableName;
}
break;
case core::JoinType::kAnti:
if (joinNode->isNullAware()) {
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u";
sql << " FROM " << probeTableName << " WHERE "
<< joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM "
<< buildTableName;
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
} else {
sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode);
sql << " FROM " << probeTableName << " WHERE NOT EXISTS (SELECT * FROM "
<< buildTableName << " WHERE " << joinConditionAsSql(*joinNode);
sql << ")";
}
break;
Expand All @@ -662,17 +691,6 @@ std::optional<std::string> PrestoQueryRunner::toSql(

std::optional<std::string> PrestoQueryRunner::toSql(
const std::shared_ptr<const core::NestedLoopJoinNode>& joinNode) {
const auto& joinKeysToSql = [](auto keys) {
std::stringstream out;
for (auto i = 0; i < keys.size(); ++i) {
if (i > 0) {
out << ", ";
}
out << keys[i]->name();
}
return out.str();
};

const auto& outputNames = joinNode->outputType()->names();
std::stringstream sql;

Expand Down
97 changes: 97 additions & 0 deletions velox/exec/tests/PrestoQueryRunnerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,101 @@ TEST_F(PrestoQueryRunnerTest, toSql) {
}
}

// Verifies the SQL text that PrestoQueryRunner::toSql produces for plans made
// of one, two and three chained hash joins. Only the generated string is
// compared; nothing in this test executes a query.
//
// In the expected SQL the tables are named "t_<id>" (t_0, t_1, t_3, t_5), not
// by the names given to tableScan. Each case creates a fresh
// PlanNodeIdGenerator and shares it between the probe-side chain and every
// build-side PlanBuilder, so the ids depend on the order in which the builder
// calls below run. Presumably the skipped ids (2, 4) belong to the join nodes
// themselves - TODO confirm against PlanBuilder.
TEST_F(PrestoQueryRunnerTest, toSqlJoins) {
auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins");
// NOTE(review): the coordinator URL is a placeholder; this test only calls
// toSql - confirm the constructor does not try to connect.
auto queryRunner = std::make_unique<PrestoQueryRunner>(
aggregatePool.get(),
"http://unused",
"hive",
static_cast<std::chrono::milliseconds>(1000));

// Four input row types with identical column types; the column-name prefix
// (t/u/v/w) makes it visible which input a column in the SQL came from.
auto t = ROW({"t0", "t1", "t2"}, {BIGINT(), BIGINT(), BOOLEAN()});
auto u = ROW({"u0", "u1", "u2"}, {BIGINT(), BIGINT(), BOOLEAN()});
auto v = ROW({"v0", "v1", "v2"}, {BIGINT(), BIGINT(), BOOLEAN()});
auto w = ROW({"w0", "w1", "w2"}, {BIGINT(), BIGINT(), BOOLEAN()});

// Single join: both inputs are table scans, so both sides appear as plain
// table names.
{
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan =
PlanBuilder(planNodeIdGenerator)
.tableScan("t", t)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(),
/*filter=*/"",
{"t0", "t1"},
core::JoinType::kInner)
.planNode();
EXPECT_EQ(
*queryRunner->toSql(plan),
"SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0");
}

// Two joins with a filter: the probe side of the outer inner join is itself a
// left semi filter join, which is expected to be rendered as a parenthesized
// sub-query in the FROM clause. The outer join's filter "v1 > 0" is expected
// to be appended to the ON condition with the literal typed as BIGINT '0'.
{
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan =
PlanBuilder(planNodeIdGenerator)
.tableScan("t", t)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(),
/*filter=*/"",
{"t0"},
core::JoinType::kLeftSemiFilter)
.hashJoin(
{"t0"},
{"v0"},
PlanBuilder(planNodeIdGenerator).tableScan("v", v).planNode(),
"v1 > 0",
{"t0", "v1"},
core::JoinType::kInner)
.planNode();
EXPECT_EQ(
*queryRunner->toSql(plan),
"SELECT t0, v1"
" FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))"
" INNER JOIN t_3 ON t0 = v0 AND (v1 > BIGINT '0')");
}

// Three joins: LEFT, then INNER, then FULL OUTER. Each earlier join is
// expected to become a nested sub-query on the probe side of the next one, and
// the last join uses two key pairs, which are expected to be combined with AND.
{
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan =
PlanBuilder(planNodeIdGenerator)
.tableScan("t", t)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator).tableScan("u", u).planNode(),
/*filter=*/"",
{"t0", "t1"},
core::JoinType::kLeft)
.hashJoin(
{"t0"},
{"v0"},
PlanBuilder(planNodeIdGenerator).tableScan("v", v).planNode(),
/*filter=*/"",
{"t0", "v1"},
core::JoinType::kInner)
.hashJoin(
{"t0", "v1"},
{"w0", "w1"},
PlanBuilder(planNodeIdGenerator).tableScan("w", w).planNode(),
/*filter=*/"",
{"t0", "w1"},
core::JoinType::kFull)
.planNode();
EXPECT_EQ(
*queryRunner->toSql(plan),
"SELECT t0, w1"
" FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)"
" INNER JOIN t_3 ON t0 = v0)"
" FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1");
}
}

} // namespace facebook::velox::exec::test

0 comments on commit 6e9702c

Please sign in to comment.