-
Notifications
You must be signed in to change notification settings - Fork 11
/
utils.py
67 lines (53 loc) · 2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) 2023, Haruka Kiyohara, Ren Kishimoto, HAKUHODO Technologies Inc., and Hanjuku-kaso Co., Ltd. All rights reserved.
# Licensed under the Apache 2.0 License.
"""Useful tools."""
from typing import Union, Optional
import numpy as np
def sigmoid(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """Logistic sigmoid function, ``1 / (1 + exp(-x))``.

    Parameters
    ----------
    x: float or np.ndarray
        Input value(s). Arrays are processed element-wise.

    Returns
    -------
    y: float or np.ndarray
        Sigmoid of ``x``; every value lies in the closed interval [0, 1].

    """
    # For large negative x (about x < -709 in float64) exp(-x) overflows to
    # inf. 1 / (1 + inf) == 0.0 is still the correct limit, so the overflow is
    # benign: silence it instead of emitting a RuntimeWarning (or raising a
    # FloatingPointError when the caller has set np.seterr(over="raise")).
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))
def check_array(
    array: np.ndarray,
    name: str,
    expected_dim: int = 1,
    expected_dtype: Optional[Union[type, tuple]] = None,
    min_val: Optional[float] = None,
    max_val: Optional[float] = None,
) -> None:
    """Input validation on array.

    Parameters
    ----------
    array: object
        Input array to check. Must be a ``numpy.ndarray``.

    name: str
        Name of the input array, used in error messages.

    expected_dim: int, default=1
        Expected number of dimensions (``array.ndim``) of the input array.

    expected_dtype: {type, tuple of type}, default=None
        Expected dtype of the input array. The check passes when
        ``array.dtype`` is a sub-dtype (``np.issubdtype``) of the given type,
        or of any type in the tuple. Abstract NumPy types such as
        ``np.integer`` or ``np.floating`` are accepted. If None, the dtype is
        not checked.

    min_val: float, default=None
        Minimum value allowed in the input array (inclusive).
        If None, no lower bound is checked.

    max_val: float, default=None
        Maximum value allowed in the input array (inclusive).
        If None, no upper bound is checked.

    Raises
    ------
    ValueError
        If ``array`` is not an ndarray, has the wrong number of dimensions,
        has an unexpected dtype, or contains values outside
        ``[min_val, max_val]``.

    Notes
    -----
    Empty arrays pass the value-range checks trivially. A NaN entry makes
    ``min()``/``max()`` return NaN, which compares False against the bounds,
    so an array containing NaN passes the range checks.

    """
    if not isinstance(array, np.ndarray):
        raise ValueError(f"{name} must be {expected_dim}D array, but got {type(array)}")
    if array.ndim != expected_dim:
        raise ValueError(
            f"{name} must be {expected_dim}D array, but got {array.ndim}D array"
        )

    if expected_dtype is not None:
        # np.issubsctype was deprecated in NumPy 1.25 and removed in 2.0;
        # np.issubdtype on the array's dtype is the supported equivalent.
        # Normalise to a tuple so the documented "tuple of type" form works.
        allowed = (
            expected_dtype if isinstance(expected_dtype, tuple) else (expected_dtype,)
        )
        if not any(np.issubdtype(array.dtype, dtype) for dtype in allowed):
            raise ValueError(
                f"The elements of {name} must be {expected_dtype}, but got {array.dtype}"
            )

    if array.size == 0:
        # No element can violate the bounds, and ndarray.min()/max() raise on
        # zero-size input.
        return

    if min_val is not None:
        observed_min = array.min()
        if observed_min < min_val:
            raise ValueError(
                f"The elements of {name} must be larger than {min_val}, but got minimum value {observed_min}"
            )
    if max_val is not None:
        observed_max = array.max()
        if observed_max > max_val:
            raise ValueError(
                f"The elements of {name} must be smaller than {max_val}, but got maximum value {observed_max}"
            )