Skip to content

Commit

Permalink
add regression tests for extra floating point functions
Browse files Browse the repository at this point in the history
  • Loading branch information
julmb committed Mar 11, 2024
1 parent d226855 commit 8916e4f
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion tests/Regression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

module Main (main) where

import Numeric
import qualified Numeric.AD.Mode.Forward as F
import qualified Numeric.AD.Mode.Forward.Double as FD
import qualified Numeric.AD.Mode.Reverse as R
Expand All @@ -28,7 +29,11 @@ tests = testGroup "tests" [
mode "reverse-double" (\ f -> RD.diff' f) (\ f -> RD.grad f) (\ f -> RD.jacobian f) (\ f -> RD.hessian f)]

mode :: String -> Diff -> Grad -> Jacobian -> Hessian -> TestTree
mode name diff grad jacobian hessian = testGroup name [basic diff grad jacobian hessian, issue97 diff, issue104 diff grad]
mode name diff grad jacobian hessian = testGroup name [
basic diff grad jacobian hessian,
issue97 diff,
issue104 diff grad,
issue108 diff]

basic :: Diff -> Grad -> Jacobian -> Hessian -> TestTree
basic diff grad jacobian hessian = testGroup "basic" [tdiff, tgrad, tjacobian, thessian] where
Expand Down Expand Up @@ -90,6 +95,42 @@ issue104 diff grad = testGroup "issue-104" [inside, outside] where
f x y = sqrt x * sqrt y -- grad f [x, y] = [sqrt y / 2 sqrt x, sqrt x / 2 sqrt y]
binary f [x, y] = f x y

issue108 :: Diff -> TestTree
issue108 diff = testGroup "issue-108" [tlog1p, texpm1, tlog1pexp, tlog1mexp] where
tlog1p = testCase "log1p" $ do
equal (-inf, inf) $ diff log1p (-1)
equal (-1.0000000000000007e-15, 1.000000000000001) $ diff log1p (-1e-15)
equal (-1e-20, 1) $ diff log1p (-1e-20)
equal (0, 1) $ diff log1p 0
equal (1e-20, 1) $ diff log1p 1e-20
equal (9.999999999999995e-16, 0.9999999999999989) $ diff log1p 1e-15
equal (0.6931471805599453, 0.5) $ diff log1p 1
texpm1 = testCase "expm1" $ do
equal (-0.6321205588285577, 0.36787944117144233) $ diff expm1 (-1)
equal (-9.999999999999995e-16, 0.999999999999999) $ diff expm1 (-1e-15)
equal (-1e-20, 1) $ diff expm1 (-1e-20)
equal (0, 1) $ diff expm1 0
equal (1e-20, 1) $ diff expm1 1e-20
equal (1.0000000000000007e-15, 1.000000000000001) $ diff expm1 1e-15
equal (1.718281828459045, 2.718281828459045) $ diff expm1 1
tlog1pexp = testCase "log1pexp" $ do
equal (0, 0) $ diff log1pexp (-1000)
equal (3.720075976020836e-44, 3.7200759760208356e-44) $ diff log1pexp (-100)
equal (0.31326168751822286, 0.2689414213699951) $ diff log1pexp (-1)
equal (0.6931471805599453, 0.5) $ diff log1pexp 0
equal (1.3132616875182228, 0.7310585786300049) $ diff log1pexp 1
equal (100, 1) $ diff log1pexp 100
equal (1000, 1) $ diff log1pexp 1000
tlog1mexp = testCase "log1mexp" $ do
equal (-0, -0) $ diff log1mexp (-1000)
equal (-3.720075976020836e-44, -3.7200759760208356e-44) $ diff log1mexp (-100)
equal (-0.45867514538708193, -0.5819767068693265) $ diff log1mexp (-1)
equal (-0.9327521295671886, -1.5414940825367982) $ diff log1mexp (-0.5)
equal (-2.3521684610440907, -9.50833194477505) $ diff log1mexp (-0.1)
equal (-34.538776394910684, -9.999999999999994e14) $ diff log1mexp (-1e-15)
equal (-46.051701859880914, -1e20) $ diff log1mexp (-1e-20)
equal = expect $ \ (a, b) (c, d) -> eq a c && eq b d

eq :: Double -> Double -> Bool
eq a b = isNaN a && isNaN b || a == b

Expand Down

0 comments on commit 8916e4f

Please sign in to comment.