-
Notifications
You must be signed in to change notification settings - Fork 10
/
min_metric.jl
50 lines (43 loc) · 1.92 KB
/
min_metric.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
export MinMetricGoal
"""
    MinMetricGoal(terms, metric::Term)
    MinMetricGoal(goal::Term, metric::Term)
    MinMetricGoal(problem::Problem)

[`Goal`](@ref) specification where each step has a cost specified by the
difference in values of a `metric` formula between the next state and the
current state, and the goal formula is a conjunction of `terms`. Planners
called with this specification will try to minimize the `metric` formula when
solving for the goal.

When constructed from a `problem`, a `(minimize expr)` metric is used as-is,
while a `(maximize expr)` metric is negated, so that minimizing it maximizes
`expr`.

# Fields
- `terms::Vector{Term}`: goal terms that must all be satisfied.
- `metric::Term`: numeric formula whose value is to be minimized.
"""
struct MinMetricGoal <: Goal
    terms::Vector{Term} # Goal terms to be satisfied
    metric::Term # Cost metric to be minimized
end
# Construct a `MinMetricGoal` from a PDDL `problem`.
#
# The problem's goal formula is flattened into a list of conjuncts, and its
# metric is handled by direction:
# - `(minimize expr)` is used directly.
# - `(maximize expr)` becomes `(- expr)`, so minimizing it maximizes `expr`.
#
# Throws `ArgumentError` if the problem declares no metric, or if the metric is
# neither a `minimize` nor a `maximize` declaration.
function MinMetricGoal(problem::Problem)
    goals = PDDL.flatten_conjs(PDDL.get_goal(problem))
    metric = PDDL.get_metric(problem)
    # Report a missing metric clearly instead of failing on `nothing.name`.
    # NOTE(review): assumes `get_metric` returns `nothing` when the problem
    # has no `:metric` section, as PDDL.jl's generic problem does; confirm.
    isnothing(metric) &&
        throw(ArgumentError("problem has no metric to minimize or maximize"))
    if metric.name == :minimize
        metric = metric.args[1]
    elseif metric.name == :maximize
        # Negate so that minimizing the result maximizes the original formula.
        metric = Compound(:-, metric.args)
    else
        # Previously any other head was silently treated as `maximize`.
        throw(ArgumentError("unsupported metric direction: $(metric.name)"))
    end
    return MinMetricGoal(goals, metric)
end
# Convenience constructor: split a single (possibly conjunctive) goal formula
# into its conjuncts, then store those alongside the metric.
function MinMetricGoal(goal::Term, metric::Term)
    conjuncts = PDDL.flatten_conjs(goal)
    return MinMetricGoal(conjuncts, metric)
end
# Multi-line (REPL) display.
# - Forwards the `:indent` prefix from the IO context (empty string by
#   default) to `show_struct`.
# - Renders `terms` as a list of PDDL formulas and `metric` as a single
#   PDDL formula.
function Base.show(io::IO, ::MIME"text/plain", spec::MinMetricGoal)
    return show_struct(io, spec;
                       indent=get(io, :indent, ""),
                       show_pddl_list=(:terms,),
                       show_pddl=(:metric,))
end
# Hash consistently with `==` for this type.
# - Goal terms are hashed as a `Set`, so their order and duplicates do not
#   matter.
# - The metric is then mixed in.
function Base.hash(spec::MinMetricGoal, h::UInt)
    terms_hash = hash(Set(spec.terms), h)
    return hash(spec.metric, terms_hash)
end
# Two specs are equal when their metrics are equal and they contain the same
# goal terms, regardless of term order or repetition.
function Base.:(==)(s1::MinMetricGoal, s2::MinMetricGoal)
    s1.metric == s2.metric || return false
    return Set(s1.terms) == Set(s2.terms)
end
# Check whether `state` satisfies the conjunction of goal terms under `domain`.
function is_goal(spec::MinMetricGoal, domain::Domain, state::State)
    goal_terms = spec.terms
    return satisfy(domain, state, goal_terms)
end
# This specification places no constraints beyond reaching the goal, so no
# state is ever reported as a violation.
function is_violated(spec::MinMetricGoal, domain::Domain, state::State)
    return false
end
# Step cost of moving from `s1` to `s2`: the increase in the metric's value
# (negative if the metric decreases). The action argument is ignored.
function get_cost(spec::MinMetricGoal, domain::Domain,
                  s1::State, ::Term, s2::State)
    metric = spec.metric
    next_value = domain[s2 => metric]
    prev_value = domain[s1 => metric]
    return next_value - prev_value
end
# Step reward of moving from `s1` to `s2`: the decrease in the metric's value,
# i.e. the opposite sign of the step cost. The action argument is ignored.
function get_reward(spec::MinMetricGoal, domain::Domain,
                    s1::State, ::Term, s2::State)
    metric = spec.metric
    prev_value = domain[s1 => metric]
    next_value = domain[s2 => metric]
    return prev_value - next_value
end
# Return the stored goal terms (the vector itself, not a copy).
function get_goal_terms(spec::MinMetricGoal)
    return spec.terms
end
# Return a new spec with `terms` as its goal and the same metric as `spec`;
# `spec` itself is left unchanged.
function set_goal_terms(spec::MinMetricGoal, terms)
    metric = spec.metric
    return MinMetricGoal(terms, metric)
end