Skip to content

Commit

Permalink
refactor: Refactor spiller to have better abstraction (#11656)
Browse files Browse the repository at this point in the history
Summary:
Spiller is a single class that handles several different spilling scenarios. We want to abstract out Spiller, move its implementations closer to their use sites, and expose APIs specific to each use case.

Pull Request resolved: #11656

Reviewed By: xiaoxmeng

Differential Revision: D67666563

Pulled By: tanjialiang

fbshipit-source-id: 66c5cfc307e248f00c1ad871409023f05d0da4d0
  • Loading branch information
tanjialiang authored and facebook-github-bot committed Dec 28, 2024
1 parent f4ac9dd commit f973c65
Show file tree
Hide file tree
Showing 30 changed files with 1,226 additions and 1,227 deletions.
112 changes: 93 additions & 19 deletions velox/exec/GroupingSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,23 @@ void GroupingSet::noMoreInput() {
addRemainingInput();
}

VELOX_CHECK_NULL(outputSpiller_);
// Spill the remaining in-memory state to disk if spilling has been triggered
// on this grouping set. This simplifies query OOM prevention when producing
// output, as we don't support spilling during that stage for now.
if (hasSpilled()) {
if (inputSpiller_ != nullptr) {
spill();
}

ensureOutputFits();
}

// Returns true if spilling has been triggered on this grouping set, i.e. if
// either 'inputSpiller_' or 'outputSpiller_' has been created. When
// 'inputSpiller_' is set, this also checks that 'outputSpiller_' is not set.
bool GroupingSet::hasSpilled() const {
  if (inputSpiller_ != nullptr) {
    VELOX_CHECK_NULL(outputSpiller_);
    return true;
  }
  return outputSpiller_ != nullptr;
}

bool GroupingSet::hasOutput() {
Expand Down Expand Up @@ -980,6 +985,18 @@ RowTypePtr GroupingSet::makeSpillType() const {
return ROW(std::move(names), std::move(types));
}

// Returns the stats of whichever spiller is active on this grouping set, or
// std::nullopt if no spill has happened. When 'inputSpiller_' is set its stats
// are returned (and 'outputSpiller_' is checked to be null); otherwise
// 'outputSpiller_' is checked to be set and its stats are returned.
std::optional<common::SpillStats> GroupingSet::spilledStats() const {
  if (hasSpilled()) {
    if (inputSpiller_ == nullptr) {
      VELOX_CHECK_NOT_NULL(outputSpiller_);
      return outputSpiller_->stats();
    }
    VELOX_CHECK_NULL(outputSpiller_);
    return inputSpiller_->stats();
  }
  return std::nullopt;
}

void GroupingSet::spill() {
// NOTE: if the disk spilling is triggered by the memory arbitrator, then it
// is possible that the grouping set hasn't processed any input data yet.
Expand All @@ -989,11 +1006,11 @@ void GroupingSet::spill() {
}

auto* rows = table_->rows();
if (!hasSpilled()) {
VELOX_CHECK_NULL(outputSpiller_);
if (inputSpiller_ == nullptr) {
VELOX_DCHECK(pool_.trackUsage());
VELOX_CHECK(numDistinctSpillFilesPerPartition_.empty());
spiller_ = std::make_unique<Spiller>(
Spiller::Type::kAggregateInput,
inputSpiller_ = std::make_unique<AggregationInputSpiller>(
rows,
makeSpillType(),
HashBitRange(
Expand All @@ -1006,22 +1023,23 @@ void GroupingSet::spill() {
spillConfig_,
spillStats_);
VELOX_CHECK_EQ(
spiller_->state().maxPartitions(), 1 << spillConfig_->numPartitionBits);
inputSpiller_->state().maxPartitions(),
1 << spillConfig_->numPartitionBits);
}
// Spilling may execute on multiple partitions in parallel, and
// HashStringAllocator is not thread safe. If any aggregations
// allocate/deallocate memory during spilling it can lead to concurrency bugs.
// Freeze the HashStringAllocator to make it effectively immutable and
// guarantee we don't accidentally enter an unsafe situation.
rows->stringAllocator().freezeAndExecute([&]() { spiller_->spill(); });
rows->stringAllocator().freezeAndExecute([&]() { inputSpiller_->spill(); });
if (isDistinct() && numDistinctSpillFilesPerPartition_.empty()) {
size_t totalNumDistinctSpilledFiles{0};
numDistinctSpillFilesPerPartition_.resize(
spiller_->state().maxPartitions(), 0);
for (int partition = 0; partition < spiller_->state().maxPartitions();
inputSpiller_->state().maxPartitions(), 0);
for (int partition = 0; partition < inputSpiller_->state().maxPartitions();
++partition) {
numDistinctSpillFilesPerPartition_[partition] =
spiller_->state().numFinishedFiles(partition);
inputSpiller_->state().numFinishedFiles(partition);
totalNumDistinctSpilledFiles +=
numDistinctSpillFilesPerPartition_[partition];
}
Expand All @@ -1042,20 +1060,16 @@ void GroupingSet::spill(const RowContainerIterator& rowIterator) {

auto* rows = table_->rows();
VELOX_CHECK(pool_.trackUsage());
spiller_ = std::make_unique<Spiller>(
Spiller::Type::kAggregateOutput,
rows,
makeSpillType(),
spillConfig_,
spillStats_);
outputSpiller_ = std::make_unique<AggregationOutputSpiller>(
rows, makeSpillType(), spillConfig_, spillStats_);

// Spilling may execute on multiple partitions in parallel, and
// HashStringAllocator is not thread safe. If any aggregations
// allocate/deallocate memory during spilling it can lead to concurrency bugs.
// Freeze the HashStringAllocator to make it effectively immutable and
// guarantee we don't accidentally enter an unsafe situation.
rows->stringAllocator().freezeAndExecute(
[&]() { spiller_->spill(rowIterator); });
[&]() { outputSpiller_->spill(rowIterator); });
table_->clear(/*freeTable=*/true);
}

Expand Down Expand Up @@ -1091,7 +1105,13 @@ bool GroupingSet::getOutputWithSpill(
table_->clear(/*freeTable=*/true);

VELOX_CHECK_NULL(merge_);
spiller_->finishSpill(spillPartitionSet_);
if (inputSpiller_ != nullptr) {
VELOX_CHECK_NULL(outputSpiller_);
inputSpiller_->finishSpill(spillPartitionSet_);
} else {
VELOX_CHECK_NOT_NULL(outputSpiller_);
outputSpiller_->finishSpill(spillPartitionSet_);
}
removeEmptyPartitions(spillPartitionSet_);

if (!prepareNextSpillPartitionOutput()) {
Expand Down Expand Up @@ -1176,9 +1196,11 @@ bool GroupingSet::mergeNextWithoutAggregates(
const RowVectorPtr& result) {
VELOX_CHECK_NOT_NULL(merge_);
VELOX_CHECK(isDistinct());
VELOX_CHECK_NULL(outputSpiller_);
VELOX_CHECK_NOT_NULL(inputSpiller_);
VELOX_CHECK_EQ(
numDistinctSpillFilesPerPartition_.size(),
spiller_->state().maxPartitions());
inputSpiller_->state().maxPartitions());

// We are looping over sorted rows produced by tree-of-losers. We logically
// split the stream into runs of duplicate rows. As we process each run we
Expand Down Expand Up @@ -1414,4 +1436,56 @@ std::optional<int64_t> GroupingSet::estimateOutputRowSize() const {
}
return table_->rows()->estimateRowSize();
}

// Constructs the spiller that GroupingSet::spill() uses. All the work is
// delegated to the SpillerBase constructor: 'container', 'rowType',
// 'hashBitRange', 'numSortingKeys', 'sortCompareFlags', 'spillConfig' and
// 'spillStats' are forwarded unchanged ('rowType' is moved).
//
// 'spillConfig' is dereferenced in the initializer list, so it must not be
// null.
AggregationInputSpiller::AggregationInputSpiller(
RowContainer* container,
RowTypePtr rowType,
const HashBitRange& hashBitRange,
int32_t numSortingKeys,
const std::vector<CompareFlags>& sortCompareFlags,
const common::SpillConfig* spillConfig,
folly::Synchronized<common::SpillStats>* spillStats)
: SpillerBase(
container,
std::move(rowType),
hashBitRange,
numSortingKeys,
sortCompareFlags,
// NOTE(review): presumably the target spill file size, i.e. "no limit" --
// confirm against the SpillerBase constructor signature.
std::numeric_limits<uint64_t>::max(),
spillConfig->maxSpillRunRows,
spillConfig,
spillStats) {}

// Constructs the spiller that GroupingSet::spill(rowIterator) uses. All the
// work is delegated to the SpillerBase constructor. Unlike
// AggregationInputSpiller, the caller supplies no partitioning or sorting
// information: a default-constructed HashBitRange, zero sorting keys and an
// empty compare-flags list are passed to the base.
//
// 'spillConfig' is dereferenced in the initializer list, so it must not be
// null.
AggregationOutputSpiller::AggregationOutputSpiller(
RowContainer* container,
RowTypePtr rowType,
const common::SpillConfig* spillConfig,
folly::Synchronized<common::SpillStats>* spillStats)
: SpillerBase(
container,
std::move(rowType),
HashBitRange{},
0,
{},
// NOTE(review): presumably the target spill file size, i.e. "no limit" --
// confirm against the SpillerBase constructor signature.
std::numeric_limits<uint64_t>::max(),
spillConfig->maxSpillRunRows,
spillConfig,
spillStats) {}

// Runs a spill by calling SpillerBase::spill() with a null start-row iterator.
// NOTE(review): a null iterator presumably means "spill from the beginning of
// the row container" -- confirm against SpillerBase::spill().
void AggregationInputSpiller::spill() {
SpillerBase::spill(nullptr);
}

// Runs a spill by calling SpillerBase::spill() with the address of
// 'startRowIter'. Only a pointer to the caller's iterator is passed, so
// 'startRowIter' must stay valid for the duration of this call.
void AggregationOutputSpiller::spill(const RowContainerIterator& startRowIter) {
SpillerBase::spill(&startRowIter);
}

// Runs one round of spilling through SpillerBase::runSpill(). After the last
// run ('lastRun' is true), additionally calls state_.finishFile() for every
// partition index in 'spillRuns_'.
void AggregationOutputSpiller::runSpill(bool lastRun) {
  SpillerBase::runSpill(lastRun);
  if (!lastRun) {
    return;
  }
  // Use an unsigned index so the comparison against spillRuns_.size() does not
  // mix signed and unsigned operands.
  for (uint32_t partition = 0; partition < spillRuns_.size(); ++partition) {
    state_.finishFile(partition);
  }
}
} // namespace facebook::velox::exec
70 changes: 59 additions & 11 deletions velox/exec/GroupingSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "velox/exec/VectorHasher.h"

namespace facebook::velox::exec {
class AggregationInputSpiller;
class AggregationOutputSpiller;

class GroupingSet {
public:
Expand All @@ -46,7 +48,7 @@ class GroupingSet {

~GroupingSet();

// Used by MarkDistinct operator to identify rows with unique values.
/// Used by MarkDistinct operator to identify rows with unique values.
static std::unique_ptr<GroupingSet> createForMarkDistinct(
const RowTypePtr& inputType,
std::vector<std::unique_ptr<VectorHasher>>&& hashers,
Expand Down Expand Up @@ -110,16 +112,12 @@ class GroupingSet {
void spill();

/// Spills all the rows in container starting from the offset specified by
/// 'rowIterator'.
/// 'rowIterator'. This should only be called during output processing and
/// when no spill has occurred previously.
void spill(const RowContainerIterator& rowIterator);

/// Returns the spiller stats including total bytes and rows spilled so far.
std::optional<common::SpillStats> spilledStats() const {
if (spiller_ == nullptr) {
return std::nullopt;
}
return spiller_->stats();
}
std::optional<common::SpillStats> spilledStats() const;

/// Returns true if spilling has triggered on this grouping set.
bool hasSpilled() const;
Expand All @@ -134,8 +132,8 @@ class GroupingSet {
return table_ ? table_->rows()->numRows() : 0;
}

// Frees hash tables and other state when giving up partial aggregation as
// non-productive. Must be called before toIntermediate() is used.
/// Frees hash tables and other state when giving up partial aggregation as
/// non-productive. Must be called before toIntermediate() is used.
void abandonPartialAggregation();

/// Translates the raw input in input to accumulators initialized from a
Expand Down Expand Up @@ -342,7 +340,9 @@ class GroupingSet {
// 'remainingInput_'.
bool remainingMayPushdown_;

std::unique_ptr<Spiller> spiller_;
std::unique_ptr<AggregationInputSpiller> inputSpiller_;

std::unique_ptr<AggregationOutputSpiller> outputSpiller_;

// The current spill partition in producing spill output. If it is -1, then we
// haven't started yet.
Expand Down Expand Up @@ -391,4 +391,52 @@ class GroupingSet {
folly::Synchronized<common::SpillStats>* const spillStats_;
};

/// SpillerBase specialization that GroupingSet::spill() creates. It reports
/// that it needs sorting (needSort() returns true) and identifies itself as
/// "AggregationInputSpiller".
class AggregationInputSpiller : public SpillerBase {
public:
/// Name returned by type().
static constexpr std::string_view kType = "AggregationInputSpiller";

/// Forwards all arguments to SpillerBase. 'spillConfig' must not be null as
/// it is dereferenced during construction.
AggregationInputSpiller(
RowContainer* container,
RowTypePtr rowType,
const HashBitRange& hashBitRange,
int32_t numSortingKeys,
const std::vector<CompareFlags>& sortCompareFlags,
const common::SpillConfig* spillConfig,
folly::Synchronized<common::SpillStats>* spillStats);

/// Runs a spill via SpillerBase::spill() without a start-row iterator.
void spill();

private:
// Returns 'kType' as a std::string.
std::string type() const override {
return std::string(kType);
}

// Always true for this spiller.
bool needSort() const override {
return true;
}
};

/// SpillerBase specialization that GroupingSet::spill(rowIterator) creates.
/// It reports that it does not need sorting (needSort() returns false),
/// overrides runSpill(), and identifies itself as "AggregationOutputSpiller".
class AggregationOutputSpiller : public SpillerBase {
public:
/// Name returned by type().
static constexpr std::string_view kType = "AggregationOutputSpiller";

/// Forwards 'container', 'rowType', 'spillConfig' and 'spillStats' to
/// SpillerBase. 'spillConfig' must not be null as it is dereferenced during
/// construction.
AggregationOutputSpiller(
RowContainer* container,
RowTypePtr rowType,
const common::SpillConfig* spillConfig,
folly::Synchronized<common::SpillStats>* spillStats);

/// Runs a spill via SpillerBase::spill(), passing a pointer to
/// 'startRowIter'.
void spill(const RowContainerIterator& startRowIter);

private:
// Returns 'kType' as a std::string.
std::string type() const override {
return std::string(kType);
}

// Overrides the SpillerBase spill run; defined in GroupingSet.cpp.
void runSpill(bool lastRun) override;

// Always false for this spiller.
bool needSort() const override {
return false;
}
};
} // namespace facebook::velox::exec
Loading

0 comments on commit f973c65

Please sign in to comment.