diff --git a/src/fitting.jl b/src/fitting.jl index d9d2228..2921b67 100644 --- a/src/fitting.jl +++ b/src/fitting.jl @@ -31,7 +31,7 @@ Function that retrieves information about correlated background from config in i """ function get_corr_info(config) if !(haskey(config["bkg"], "correlated")) - return false, nothing, nothing + return false,nothing,nothing end if (haskey(config["bkg"],"correlated")) & (config["bkg"]["correlated"]["mode"]!="none") @@ -45,6 +45,24 @@ function get_corr_info(config) end +""" + get_prior_info(bkg_only::Bool,config) + +Function that retrieves signal prior information. +""" +function get_signal_prior_info(bkg_only::Bool,config) + sqrt_prior=false + s_max=nothing + if bkg_only==false + if (config["signal"]["prior"]=="sqrt") + sqrt_prior=true + s_max = Float64(config["signal"]["upper_bound"]) + end + end + return sqrt_prior,s_max +end + + """ get_range(fit_range::Union{Vector{Vector{Int}}, Vector{Vector{Float64}}}) @@ -178,14 +196,7 @@ function get_stat_blocks(partitions,events::Array{Vector{Float64}},part_event_in @info "using a ",bkg_shape," bkg with ",bkg_shape_pars," parameters" @info "built prior" - sqrt_prior=false - s_max=nothing - if bkg_only==false - if (config["signal"]["prior"]=="sqrt") - sqrt_prior=true - s_max = Float64(config["signal"]["upper_bound"]) - end - end + sqrt_prior,s_max = get_signal_prior_info(bkg_only,config) likelihood = build_likelihood_looping_partitions(partitions, events, part_event_index,settings,sqrt_prior,s_max,fit_ranges,bkg_shape=bkg_shape) @info "built likelihood" diff --git a/test/io/test_all.jl b/test/io/test_all.jl index 10b3ee0..3cffb71 100644 --- a/test/io/test_all.jl +++ b/test/io/test_all.jl @@ -7,4 +7,5 @@ Test.@testset "likelihood" begin include("test_get_partition_event_index.jl") include("test_get_partitions_events.jl") include("test_get_corr_info.jl") + include("test_get_signal_prior_info.jl") end diff --git a/test/io/test_get_signal_prior_info.jl b/test/io/test_get_signal_prior_info.jl new file mode 100644 index 0000000..3abc3b6 --- /dev/null +++ b/test/io/test_get_signal_prior_info.jl @@ -0,0 +1,60 @@ +using Pkg +Pkg.activate(".") +Pkg.instantiate() +using Random +include("../../src/ZeroNuFit.jl") +using .ZeroNuFit +include("../../main.jl") + +@testset "test_get_signal_prior_info" begin + + @info "Testing function to retrieve signal prior info (function 'get_signal_prior_info' in src/fitting.jl)" + + # flat S prior (default) + config = Dict("signal" => Dict("prior" => "uniform", "upper_bound" => 10)) + + # only B fit + bkg_only = true + sqrt_prior = nothing + s_max = nothing + try + sqrt_prior,s_max = ZeroNuFit.get_signal_prior_info(bkg_only,config) + catch e + @error "Error in 'get_signal_prior_info' evaluation: $e" + throw(e) + end + + @test sqrt_prior == false + @test s_max == nothing + + # B+S fit + bkg_only = false + sqrt_prior = nothing + s_max = nothing + try + sqrt_prior,s_max = ZeroNuFit.get_signal_prior_info(bkg_only,config) + catch e + @error "Error in 'get_signal_prior_info' evaluation: $e" + throw(e) + end + + @test sqrt_prior == false + @test s_max == nothing + + # 1/sqrt(S) prior + config = Dict("signal" => Dict("prior" => "sqrt", "upper_bound" => 10)) + + bkg_only = false + sqrt_prior = nothing + s_max = nothing + try + sqrt_prior,s_max = ZeroNuFit.get_signal_prior_info(bkg_only,config) + catch e + @error "Error in 'get_signal_prior_info' evaluation: $e" + throw(e) + end + + @test sqrt_prior == true + @test s_max == 10.0 + +end