Skip to content

Commit

Permalink
[RNG] Add geometric distribution to Device API (#622)
Browse files Browse the repository at this point in the history
  • Loading branch information
iMartyan authored Dec 23, 2024
1 parent db520e4 commit be58ee0
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/oneapi/math/rng/device/detail/distribution_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class poisson;
template <typename IntType = std::uint32_t, typename Method = bernoulli_method::by_default>
class bernoulli;

template <typename IntType = std::uint32_t, typename Method = geometric_method::by_default>
class geometric;

} // namespace oneapi::math::rng::device

#include "oneapi/math/rng/device/detail/uniform_impl.hpp"
Expand All @@ -75,6 +78,7 @@ class bernoulli;
#include "oneapi/math/rng/device/detail/exponential_impl.hpp"
#include "oneapi/math/rng/device/detail/poisson_impl.hpp"
#include "oneapi/math/rng/device/detail/bernoulli_impl.hpp"
#include "oneapi/math/rng/device/detail/geometric_impl.hpp"
#include "oneapi/math/rng/device/detail/beta_impl.hpp"
#include "oneapi/math/rng/device/detail/gamma_impl.hpp"

Expand Down
99 changes: 99 additions & 0 deletions include/oneapi/math/rng/device/detail/geometric_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_
#define ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_

namespace oneapi::math::rng::device::detail {

template <typename IntType, typename Method>
class distribution_base<oneapi::math::rng::device::geometric<IntType, Method>> {
public:
struct param_type {
param_type(float p) : p_(p) {}
float p_;
};

distribution_base(float p) : p_(p) {
#ifndef __SYCL_DEVICE_ONLY__
if ((p > 1.0f) || (p < 0.0f)) {
throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1");
}
#endif
}

float p() const {
return p_;
}

param_type param() const {
return param_type(p_);
}

void param(const param_type& pt) {
#ifndef __SYCL_DEVICE_ONLY__
if ((pt.p_ > 1.0f) || (pt.p_ < 0.0f)) {
throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1");
}
#endif
p_ = pt.p_;
}

protected:
template <typename EngineType>
auto generate(EngineType& engine) ->
typename std::conditional<EngineType::vec_size == 1, IntType,
sycl::vec<IntType, EngineType::vec_size>>::type {
using FpType = typename std::conditional<std::is_same_v<IntType, std::uint64_t> ||
std::is_same_v<IntType, std::int64_t>,
double, float>::type;

auto uni_res = engine.generate(FpType(0.0), FpType(1.0));
FpType inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
if constexpr (EngineType::vec_size == 1) {
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}
else {
sycl::vec<IntType, EngineType::vec_size> vec_out;
for (int i = 0; i < EngineType::vec_size; i++) {
vec_out[i] = static_cast<IntType>(sycl::floor(ln_wrapper(uni_res[i]) * inv_ln));
}
return vec_out;
}
}

template <typename EngineType>
IntType generate_single(EngineType& engine) {
using FpType = typename std::conditional<std::is_same_v<IntType, std::uint64_t> ||
std::is_same_v<IntType, std::int64_t>,
double, float>::type;

FpType uni_res = engine.generate_single(FpType(0.0), FpType(1.0));
FpType inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}

float p_;
};

} // namespace oneapi::math::rng::device::detail

#endif // ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_
58 changes: 58 additions & 0 deletions include/oneapi/math/rng/device/distributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,64 @@ class bernoulli : detail::distribution_base<bernoulli<IntType, Method>> {
friend typename Distr::result_type generate_single(Distr& distr, Engine& engine);
};

// Class template oneapi::math::rng::device::geometric
//
// Represents discrete geometric random number distribution
//
// Supported types:
// std::uint32_t
// std::int32_t
// std::uint64_t
// std::int64_t
//
// Supported methods:
// oneapi::math::rng::geometric_method::icdf;
//
// Input arguments:
// p - success probablity of a trial. 0.5 by default
//
template <typename IntType, typename Method>
class geometric : detail::distribution_base<geometric<IntType, Method>> {
public:
static_assert(std::is_same<Method, geometric_method::icdf>::value,
"oneMath: rng/geometric: method is incorrect");

static_assert(std::is_same<IntType, std::int32_t>::value ||
std::is_same<IntType, std::uint32_t>::value ||
std::is_same<IntType, std::int64_t>::value ||
std::is_same<IntType, std::uint64_t>::value,
"oneMath: rng/geometric: type is not supported");

using method_type = Method;
using result_type = IntType;
using param_type = typename detail::distribution_base<geometric<IntType, Method>>::param_type;

geometric() : detail::distribution_base<geometric<IntType, Method>>(0.5f) {}

explicit geometric(float p) : detail::distribution_base<geometric<IntType, Method>>(p) {}
explicit geometric(const param_type& pt)
: detail::distribution_base<geometric<IntType, Method>>(pt.p_) {}

float p() const {
return detail::distribution_base<geometric<IntType, Method>>::p();
}

param_type param() const {
return detail::distribution_base<geometric<IntType, Method>>::param();
}

void param(const param_type& pt) {
detail::distribution_base<geometric<IntType, Method>>::param(pt);
}

template <typename Distr, typename Engine>
friend auto generate(Distr& distr, Engine& engine) ->
typename std::conditional<Engine::vec_size == 1, typename Distr::result_type,
sycl::vec<typename Distr::result_type, Engine::vec_size>>::type;
template <typename Distr, typename Engine>
friend typename Distr::result_type generate_single(Distr& distr, Engine& engine);
};

} // namespace oneapi::math::rng::device

#endif // ONEMATH_RNG_DEVICE_DISTRIBUTIONS_HPP_
5 changes: 5 additions & 0 deletions include/oneapi/math/rng/device/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ struct icdf {};
using by_default = icdf;
} // namespace bernoulli_method

namespace geometric_method {
struct icdf {};
using by_default = icdf;
} // namespace geometric_method

namespace beta_method {
struct cja {};
struct cja_accurate {};
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/rng/device/include/rng_device_test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,22 @@ struct statistics_device<oneapi::math::rng::device::bernoulli<Fp, Method>> {
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::math::rng::device::geometric<Fp, Method>> {
template <typename AllocType>
bool check(const std::vector<Fp, AllocType>& r,
const oneapi::math::rng::device::geometric<Fp, Method>& distr) {
double tM, tD, tQ;
double p = static_cast<double>(distr.p());

tM = (1.0 - p) / p;
tD = (1.0 - p) / (p * p);
tQ = (1.0 - p) * (p * p - 9.0 * p + 9.0) / (p * p * p * p);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::math::rng::device::beta<Fp, Method>> {
template <typename AllocType>
Expand Down
95 changes: 95 additions & 0 deletions tests/unit_tests/rng/device/moments/moments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1416,4 +1416,99 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10BernoulliIcdfDeviceMomentsTestsSuite,
Philox4x32x10BernoulliIcdfDeviceMomentsTests, ::testing::ValuesIn(devices),
::DeviceNamePrint());

class Philox4x32x10GeometricIcdfDeviceMomentsTests
: public ::testing::TestWithParam<sycl::device*> {};

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, IntegerPrecision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedIntegerPrecision) {
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, Integer64Precision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedInteger64Precision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());

rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<1>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test1;
EXPECT_TRUEORSKIP((test1(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<4>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test2;
EXPECT_TRUEORSKIP((test2(GetParam())));
rng_device_test<
moments_test<oneapi::math::rng::device::philox4x32x10<16>,
oneapi::math::rng::device::geometric<
std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>>
test3;
EXPECT_TRUEORSKIP((test3(GetParam())));
}

INSTANTIATE_TEST_SUITE_P(Philox4x32x10GeometricIcdfDeviceMomentsTestsSuite,
Philox4x32x10GeometricIcdfDeviceMomentsTests, ::testing::ValuesIn(devices),
::DeviceNamePrint());

} // anonymous namespace

0 comments on commit be58ee0

Please sign in to comment.