Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
iMartyan committed Dec 13, 2024
1 parent 40b7e0a commit 7f81a95
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/oneapi/math/rng/device/detail/distribution_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class poisson;
template <typename IntType = std::uint32_t, typename Method = bernoulli_method::by_default>
class bernoulli;

template <typename IntType = std::uint32_t, typename Method = geometric_method::by_default>
class geometric;

} // namespace oneapi::math::rng::device

#include "oneapi/math/rng/device/detail/uniform_impl.hpp"
Expand All @@ -75,6 +78,7 @@ class bernoulli;
#include "oneapi/math/rng/device/detail/exponential_impl.hpp"
#include "oneapi/math/rng/device/detail/poisson_impl.hpp"
#include "oneapi/math/rng/device/detail/bernoulli_impl.hpp"
#include "oneapi/math/rng/device/detail/geometric_impl.hpp"
#include "oneapi/math/rng/device/detail/beta_impl.hpp"
#include "oneapi/math/rng/device/detail/gamma_impl.hpp"

Expand Down
95 changes: 95 additions & 0 deletions include/oneapi/math/rng/device/detail/geometric_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_
#define ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_

namespace oneapi::math::rng::device::detail {

template <typename IntType, typename Method>
class distribution_base<oneapi::math::rng::device::geometric<IntType, Method>> {
public:
struct param_type {
param_type(float p) : p_(p) {}
float p_;
};

distribution_base(float p) : p_(p) {
#ifndef __SYCL_DEVICE_ONLY__
if ((p > 1.0f) || (p < 0.0f)) {
throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1");
}
#endif
}

float p() const {
return p_;
}

param_type param() const {
return param_type(p_);
}

void param(const param_type& pt) {
#ifndef __SYCL_DEVICE_ONLY__
if ((pt.p_ > 1.0f) || (pt.p_ < 0.0f)) {
throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1");
}
#endif
p_ = pt.p_;
}

protected:
template <typename EngineType>
auto generate(EngineType& engine) ->
typename std::conditional<EngineType::vec_size == 1, IntType,
sycl::vec<IntType, EngineType::vec_size>>::type {
using FpType =
typename std::conditional<std::is_same_v<IntType, std::uint64_t> ||
std::is_same_v<IntType, std::int64_t>,
double, float>::type;

auto uni_res = engine.generate(FpType(0.0), FpType(1.0));
FpType inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
if constexpr (EngineType::vec_size == 1) {
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}
else {
sycl::vec<IntType, EngineType::vec_size> vec_out;
for (int i = 0; i < EngineType::vec_size; i++)
vec_out[i] = static_cast<IntType>(sycl::floor(ln_wrapper(uni_res[i]) * inv_ln));
return vec_out;
}
}

template <typename EngineType>
IntType generate_single(EngineType& engine) {
auto uni_res = engine.generate_single(FpType(0.0), FpType(1.0));
float inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}

float p_;
};

} // namespace oneapi::math::rng::device::detail

#endif // ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_
59 changes: 59 additions & 0 deletions include/oneapi/math/rng/device/distributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,65 @@ class bernoulli : detail::distribution_base<bernoulli<IntType, Method>> {
friend typename Distr::result_type generate_single(Distr& distr, Engine& engine);
};

// Class template oneapi::mkl::rng::device::geometric
//
// Represents discrete geometric random number distribution
//
// Supported types:
// std::uint32_t
// std::int32_t
// std::uint64_t
// std::int64_t
//
// Supported methods:
// oneapi::mkl::rng::geometric_method::icdf;
//
// Input arguments:
// p - success probablity of a trial. 0.5 by default
//
template <typename IntType, typename Method>
class geometric : detail::distribution_base<geometric<IntType, Method>> {
public:
static_assert(std::is_same<Method, geometric_method::icdf>::value,
"oneMKL: rng/geometric: method is incorrect");

static_assert(std::is_same<IntType, std::int32_t>::value ||
std::is_same<IntType, std::uint32_t>::value ||
std::is_same<IntType, std::int64_t>::value ||
std::is_same<IntType, std::uint64_t>::value,
"oneMKL: rng/geometric: type is not supported");

using method_type = Method;
using result_type = IntType;
using param_type = typename detail::distribution_base<geometric<IntType, Method>>::param_type;

geometric() : detail::distribution_base<geometric<IntType, Method>>(0.5f) {}

explicit geometric(float p) : detail::distribution_base<geometric<IntType, Method>>(p) {}
explicit geometric(const param_type& pt)
: detail::distribution_base<geometric<IntType, Method>>(pt.p_) {}

float p() const {
return detail::distribution_base<geometric<IntType, Method>>::p();
}

param_type param() const {
return detail::distribution_base<geometric<IntType, Method>>::param();
}

void param(const param_type& pt) {
detail::distribution_base<geometric<IntType, Method>>::param(pt);
}

template <typename Distr, typename Engine>
friend auto generate(Distr& distr, Engine& engine) ->
typename std::conditional<Engine::vec_size == 1, typename Distr::result_type,
sycl::vec<typename Distr::result_type, Engine::vec_size>>::type;

template <typename Distr, typename Engine>
friend typename Distr::result_type generate_single(Distr& distr, Engine& engine);
};

} // namespace oneapi::math::rng::device

#endif // ONEMATH_RNG_DEVICE_DISTRIBUTIONS_HPP_
5 changes: 5 additions & 0 deletions include/oneapi/math/rng/device/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ struct icdf {};
using by_default = icdf;
} // namespace bernoulli_method

namespace geometric_method {
struct icdf {};
using by_default = icdf;
} // namespace geometric_method

namespace beta_method {
struct cja {};
struct cja_accurate {};
Expand Down

0 comments on commit 7f81a95

Please sign in to comment.