diff --git a/src/fitting.jl b/src/fitting.jl index 1034b5a..d9d2228 100644 --- a/src/fitting.jl +++ b/src/fitting.jl @@ -5,6 +5,11 @@ using Plots using Cuba using SpecialFunctions +""" + get_bkg_info(config) + +Function that retrieves background information from config in input. +""" function get_bkg_info(config) bkg_shape=:uniform bkg_shape_pars=nothing @@ -18,6 +23,33 @@ function get_bkg_info(config) return bkg_shape,bkg_shape_pars end + +""" + get_corr_info(config) + +Function that retrieves information about correlated background from config in input. +""" +function get_corr_info(config) + if !(haskey(config["bkg"], "correlated")) + return false, nothing, nothing + end + + if (haskey(config["bkg"],"correlated")) & (config["bkg"]["correlated"]["mode"]!="none") + corr= true + hier_mode =config["bkg"]["correlated"]["mode"] + hier_range =config["bkg"]["correlated"]["range"] + return corr,hier_mode,hier_range + else + return false,nothing,nothing + end +end + + +""" + get_range(fit_range::Union{Vector{Vector{Int}}, Vector{Vector{Float64}}}) + +Function that returns lower and upper edges of fit ranges. +""" function get_range(fit_range::Union{Vector{Vector{Int}}, Vector{Vector{Float64}}}) range_l = [arr[1] for arr in fit_range] range_h = [arr[2] for arr in fit_range] @@ -137,16 +169,7 @@ Function to retrieve useful pieces (prior, likelihood, posterior), also in savin function get_stat_blocks(partitions,events::Array{Vector{Float64}},part_event_index,fit_ranges;config,bkg_only) settings=get_settings(config) - - if (haskey(config["bkg"],"correlated")) & (config["bkg"]["correlated"]["mode"]!="none") - corr= true - hier_mode =config["bkg"]["correlated"]["mode"] - hier_range =config["bkg"]["correlated"]["range"] - else - corr=false - hier_mode=nothing - hier_range=nothing - end + corr,hier_mode,hier_range = get_corr_info(config) bkg_shape,bkg_shape_pars = get_bkg_info(config) diff --git a/test/io/test_all.jl b/test/io/test_all.jl index c9465ca..10b3ee0 100644 --- a/test/io/test_all.jl +++ b/test/io/test_all.jl @@ -6,4 +6,5 @@ Test.@testset "likelihood" begin include("test_get_events.jl") include("test_get_partition_event_index.jl") include("test_get_partitions_events.jl") + include("test_get_corr_info.jl") end diff --git a/test/io/test_get_corr_info.jl b/test/io/test_get_corr_info.jl new file mode 100644 index 0000000..8110f9c --- /dev/null +++ b/test/io/test_get_corr_info.jl @@ -0,0 +1,73 @@ +using Pkg +Pkg.activate(".") +Pkg.instantiate() +using Random +include("../../src/ZeroNuFit.jl") +using .ZeroNuFit +include("../../main.jl") + +@testset "test_get_corr_info" begin + + @info "Testing function to retrieve bkg correlation info (function 'get_corr_info' in src/fitting.jl)" + + # no entry for correlated bkg + config = Dict("bkg" => Dict()) + + corr = false + hier_mode=nothing + hier_range=nothing + try + corr,hier_mode,hier_range = ZeroNuFit.get_corr_info(config) + catch e + @error "Error in 'get_corr_info' evaluation: $e" + throw(e) + end + + @test corr == false + @test hier_mode == nothing + @test hier_range == nothing + + # not-correlated bkg + config = Dict("bkg" => Dict("correlated" => Dict("mode" => "none", "range" => "none"))) + + corr = false + hier_mode=nothing + hier_range=nothing + try + corr,hier_mode,hier_range = ZeroNuFit.get_corr_info(config) + catch e + @error "Error in 'get_corr_info' evaluation: $e" + throw(e) + end + + @test corr == false + @test hier_mode == nothing + @test hier_range == nothing + + # correlated bkg + config = Dict("bkg" => Dict("correlated" => Dict("mode" => "lognormal", "range" => [0,0.1]))) + + corr = false + hier_mode=nothing + hier_range=nothing + try + corr,hier_mode,hier_range = ZeroNuFit.get_corr_info(config) + catch e + @error "Error in 'get_corr_info' evaluation: $e" + throw(e) + end + + expected_value = true + @testset "Check corr accuracy" begin + @test corr == expected_value + end + expected_value = config["bkg"]["correlated"]["mode"] + @testset "Check hier_mode accuracy" begin + @test hier_mode == expected_value + end + expected_value = config["bkg"]["correlated"]["range"] + @testset "Check hier_range accuracy" begin + @test hier_range == expected_value + end + +end