From f85f303a9d46d21fe3ed8acc6336583bb61a299f Mon Sep 17 00:00:00 2001 From: Sy Brand Date: Sat, 7 Dec 2024 11:22:50 +0000 Subject: [PATCH] Add k_combinations --- include/tl/k_combinations.hpp | 180 ++++++++++++++++++++++++++++++++++ tests/k_combinations.cpp | 30 ++++++ 2 files changed, 210 insertions(+) create mode 100644 include/tl/k_combinations.hpp create mode 100644 tests/k_combinations.cpp diff --git a/include/tl/k_combinations.hpp b/include/tl/k_combinations.hpp new file mode 100644 index 0000000..d7bdd7f --- /dev/null +++ b/include/tl/k_combinations.hpp @@ -0,0 +1,180 @@ +#ifndef TL_RANGES_K_COMBINATIONS +#define TL_RANGES_K_COMBINATIONS + + +#include +#include +#include +#include +#include + +namespace tl { + template + requires std::ranges::view + class k_combinations_view : public std::ranges::view_interface> { + public: + template + class cursor; + + template + class sentinel { + public: + using base = std::ranges::iterator_t >; + sentinel() = default; + sentinel(base end) : end_(std::move(end)) {} + template + friend class cursor; + + private: + base end_; + }; + + template + class cursor { + private: + template + using constify = std::conditional_t; + + public: + cursor() = default; + constexpr explicit cursor(constify* base, std::size_t n, std::ranges::iterator_t> it) : + base_(base), current_(n, it) { + } + + //const-converting constructor + constexpr cursor(cursor i) requires Const&& std::convertible_to< + std::ranges::iterator_t, + std::ranges::iterator_t>> + : base_(i.base_), current_{ std::move(i.current_) } { + } + + constexpr decltype(auto) read() const { + return std::views::transform(current_, [](auto&& i) { + return std::ref(*i); + }); + } + + constexpr void next() { + auto it = current_.rbegin(); + auto end = current_.rend(); + while (it != end) { + ++(*it); + if (*it == std::ranges::end(*base_)) { + if (it != end - 1) { + *it = std::ranges::begin(*base_); + } + ++it; + } + else { + break; + } + } + } + + void prev() requires (std::ranges::bidirectional_range>) { + auto it = current_.rbegin(); + auto end = current_.rend(); + while (it != end) { + if (*it == std::ranges::begin(*base_)) { + std::ranges::advance(*it, std::ranges::end(*base_)); + ++it; + } + --(*it); + } + } + + //TODO advance + + constexpr bool equal(const cursor& rhs) const + requires (std::equality_comparable>>) { + return current_ == rhs.current_; + } + + constexpr bool equal(const sentinel& s) const { + return current_.front() == s.end_; + } + + constexpr auto distance_to(cursor const& other) const + requires (std::ranges::sized_range>) { + std::ptrdiff_t distance = 0; + for (std::size_t i = current_.size() - 1; i >= 0; --i) { + distance += std::ranges::distance(current_[i], other.current_[i]) * std::ranges::size(*base_); + } + return distance; + } + + private: + constify* base_; + std::vector> current_; + + friend class cursor; + friend class k_combinations_view; + }; + + constexpr k_combinations_view() = default; + + constexpr explicit k_combinations_view(V view, std::size_t n) + : base_(std::move(view)), n_(n) { + } + + constexpr auto begin() requires(!tl::simple_view) { + return basic_iterator{ cursor{ std::addressof(base_), n_, std::ranges::begin(base_)} }; + } + + constexpr auto begin() const + requires(std::ranges::range) { + return basic_iterator{ cursor{ std::addressof(base_), n_, std::ranges::begin(base_) } }; + } + + constexpr auto end() requires(!tl::simple_view) { + if constexpr (std::ranges::common_range and std::ranges::sized_range) { + return basic_iterator{ cursor(std::addressof(base_), n_, std::ranges::end(base_)) }; + } + else { + return sentinel{std::ranges::end(base_)}; + } + } + + constexpr auto end() const + requires(std::ranges::range) { + if constexpr (std::ranges::common_range and std::ranges::sized_range) { + return basic_iterator{ cursor(std::addressof(base_), n_, std::ranges::end(base_)) }; + } + else { + return sentinel{std::ranges::end(base_)}; + } + } + + constexpr auto size() requires(std::ranges::sized_range) { + return std::pow(std::ranges::size(base_), n_); + } + + constexpr auto size() const requires(std::ranges::sized_range) { + return std::pow(std::ranges::size(base_), n_); + } + + private: + V base_; + std::size_t n_; + }; + + template + k_combinations_view(R&&, std::size_t n) -> k_combinations_view>; + + namespace views { + namespace detail { + struct k_combinations_fn { + template + constexpr auto operator()(R&& r, std::size_t n) const + requires (std::ranges::forward_range) { + return k_combinations_view(std::forward(r), n); + } + }; + + } + + constexpr inline detail::k_combinations_fn k_combinations; + } +} + +#endif diff --git a/tests/k_combinations.cpp b/tests/k_combinations.cpp new file mode 100644 index 0000000..f5e18a2 --- /dev/null +++ b/tests/k_combinations.cpp @@ -0,0 +1,30 @@ +#include +#include +#include + +TEST_CASE("k_combinations") { + std::vector a{ 0, 1, 2 }; + { + auto k_combinations = tl::views::k_combinations(a, 2); + auto it = std::ranges::begin(k_combinations); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + REQUIRE(std::ranges::equal(std::vector{ i, j }, *it)); + ++it; + } + } + } + + { + auto k_combinations = tl::views::k_combinations(a, 3); + auto it = std::ranges::begin(k_combinations); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + REQUIRE(std::ranges::equal(std::vector{ i, j, k }, *it)); + ++it; + } + } + } + } +} \ No newline at end of file