diff --git a/Cargo.toml b/Cargo.toml index af4c9c7..2f1f454 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "cassowary" -version = "0.3.0" -authors = ["Dylan Ede "] +version = "0.4.0" +authors = ["Dylan Ede ", "Schell Scivally "] description = """ A Rust implementation of the Cassowary linear constraint solving algorithm. @@ -16,3 +16,4 @@ license = "MIT / Apache-2.0" keywords = ["constraint", "simplex", "user", "interface", "layout"] [dependencies] +ordered-float = "1.0" \ No newline at end of file diff --git a/cassowary-rs/src/solver_impl.rs b/cassowary-rs/src/solver_impl.rs new file mode 100644 index 0000000..97a72bf --- /dev/null +++ b/cassowary-rs/src/solver_impl.rs @@ -0,0 +1,955 @@ +use super::{ + Symbol, + Tag, + SymbolType, + Constraint, + Expression, + Term, + Row, + AddConstraintError, + RemoveConstraintError, + InternalSolverError, + SuggestValueError, + AddEditVariableError, + RemoveEditVariableError, + RelationalOperator, + near_zero +}; + +use std::collections::{ HashMap, HashSet }; +use std::hash::Hash; +use std::fmt::Debug; +use std::collections::hash_map::Entry; + + +#[derive(Clone)] +#[derive(Debug)] +struct EditInfo { + tag: Tag, + constraint: Constraint, + constant: f64 +} + +/// A constraint solver using the Cassowary algorithm. For proper usage please see the top level crate documentation. +#[derive(Debug)] +pub struct Solver +where + T: Debug + Clone + Eq + Hash +{ + cns: HashMap, Tag>, + var_data: HashMap, + var_for_symbol: HashMap, + public_changes: Vec<(T, f64)>, + changed: HashSet, + should_clear_changes: bool, + rows: HashMap, + edits: HashMap>, + infeasible_rows: Vec, // never contains external symbols + objective: Row, + artificial: Option, + id_tick: usize +} + +impl Solver +{ + /// Construct a new solver. + pub fn new() -> Solver { + Solver { + cns: HashMap::new(), + var_data: HashMap::new(), + var_for_symbol: HashMap::new(), + public_changes: Vec::new(), + changed: HashSet::new(), + should_clear_changes: false, + rows: HashMap::new(), + edits: HashMap::new(), + infeasible_rows: Vec::new(), + objective: Row::new(0.0), + artificial: None, + id_tick: 1 + } + } + + pub fn add_constraints>>( + &mut self, + constraints: I) -> Result<(), AddConstraintError> + { + for constraint in constraints { + self.add_constraint(constraint)?; + } + Ok(()) + } + + /// Add a constraint to the solver. + pub fn add_constraint(&mut self, constraint: Constraint) -> Result<(), AddConstraintError> { + if self.cns.contains_key(&constraint) { + return Err(AddConstraintError::DuplicateConstraint); + } + + // Creating a row causes symbols to reserved for the variables + // in the constraint. If this method exits with an exception, + // then its possible those variables will linger in the var map. + // Since its likely that those variables will be used in other + // constraints and since exceptional conditions are uncommon, + // i'm not too worried about aggressive cleanup of the var map. + let (mut row, tag) = self.create_row(&constraint); + let mut subject = Symbol::choose_subject(&row, &tag); + + // If chooseSubject could find a valid entering symbol, one + // last option is available if the entire row is composed of + // dummy variables. If the constant of the row is zero, then + // this represents redundant constraints and the new dummy + // marker can enter the basis. If the constant is non-zero, + // then it represents an unsatisfiable constraint. + if subject.type_() == SymbolType::Invalid && row.all_dummies() { + if !near_zero(*row.constant.as_ref()) { + return Err(AddConstraintError::UnsatisfiableConstraint); + } else { + subject = tag.marker; + } + } + + // If an entering symbol still isn't found, then the row must + // be added using an artificial variable. If that fails, then + // the row represents an unsatisfiable constraint. + if subject.type_() == SymbolType::Invalid { + if !try!(self.add_with_artificial_variable(&row) + .map_err(|e| AddConstraintError::InternalSolverError(e.0))) { + return Err(AddConstraintError::UnsatisfiableConstraint); + } + } else { + row.solve_for_symbol(subject); + self.substitute(subject, &row); + if subject.type_() == SymbolType::External && *row.constant.as_ref() != 0.0 { + let v:T = self.var_for_symbol[&subject].clone(); + self.var_changed(v); + } + self.rows.insert(subject, row); + } + + self.cns.insert(constraint, tag); + + // Optimizing after each constraint is added performs less + // aggregate work due to a smaller average system size. It + // also ensures the solver remains in a consistent state. + let objective = self.objective.clone(); + self.optimise(&objective).map_err(|e| AddConstraintError::InternalSolverError(e.0))?; + Ok(()) + } + + /// Remove a constraint from the solver. + pub fn remove_constraint(&mut self, constraint: &Constraint) -> Result<(), RemoveConstraintError> { + let tag = self.cns.remove(constraint).ok_or(RemoveConstraintError::UnknownConstraint)?; + + // Remove the error effects from the objective function + // *before* pivoting, or substitutions into the objective + // will lead to incorrect solver results. + self.remove_constraint_effects(constraint, &tag); + + // If the marker is basic, simply drop the row. Otherwise, + // pivot the marker into the basis and then drop the row. + if let None = self.rows.remove(&tag.marker) { + let (leaving, mut row) = + self.get_marker_leaving_row(tag.marker) + .ok_or( + RemoveConstraintError::InternalSolverError( + "Failed to find leaving row." + ) + )?; + row.solve_for_symbols(leaving, tag.marker); + self.substitute(tag.marker, &row); + } + + // Optimizing after each constraint is removed ensures that the + // solver remains consistent. It makes the solver api easier to + // use at a small tradeoff for speed. + let objective = self.objective.clone(); + self.optimise(&objective).map_err(|e| RemoveConstraintError::InternalSolverError(e.0))?; + + // Check for and decrease the reference count for variables referenced by the constraint + // If the reference count is zero remove the variable from the variable map + for term in &constraint.expr().terms { + if !near_zero(term.coefficient.into_inner()) { + let mut should_remove = false; + if let Some(&mut (_, _, ref mut count)) = self.var_data.get_mut(&term.variable) { + *count -= 1; + should_remove = *count == 0; + } + if should_remove { + self.var_for_symbol.remove(&self.var_data[&term.variable].1); + self.var_data.remove(&term.variable); + } + } + } + Ok(()) + } + + /// Test whether a constraint has been added to the solver. + pub fn has_constraint(&self, constraint: &Constraint) -> bool { + self.cns.contains_key(constraint) + } + + /// Add an edit variable to the solver. + /// + /// This method should be called before the `suggest_value` method is + /// used to supply a suggested value for the given edit variable. + pub fn add_edit_variable(&mut self, v: T, strength: f64) -> Result<(), AddEditVariableError> { + if self.edits.contains_key(&v) { + return Err(AddEditVariableError::DuplicateEditVariable); + } + let strength = ::strength::clip(strength); + if strength == ::strength::REQUIRED { + return Err(AddEditVariableError::BadRequiredStrength); + } + let cn = Constraint::new(Expression::from_term(Term::new(v.clone(), 1.0)), + RelationalOperator::Equal, + strength); + self.add_constraint(cn.clone()).unwrap(); + self.edits.insert(v.clone(), EditInfo { + tag: self.cns[&cn].clone(), + constraint: cn, + constant: 0.0 + }); + Ok(()) + } + + /// Remove an edit variable from the solver. + pub fn remove_edit_variable(&mut self, v: T) -> Result<(), RemoveEditVariableError> { + if let Some(constraint) = self.edits.remove(&v).map(|e| e.constraint) { + try!(self.remove_constraint(&constraint) + .map_err(|e| match e { + RemoveConstraintError::UnknownConstraint => + RemoveEditVariableError::InternalSolverError("Edit constraint not in system"), + RemoveConstraintError::InternalSolverError(s) => + RemoveEditVariableError::InternalSolverError(s) + })); + Ok(()) + } else { + Err(RemoveEditVariableError::UnknownEditVariable) + } + } + + /// Test whether an edit variable has been added to the solver. + pub fn has_edit_variable(&self, v: &T) -> bool { + self.edits.contains_key(v) + } + + /// Suggest a value for the given edit variable. + /// + /// This method should be used after an edit variable has been added to + /// the solver in order to suggest the value for that variable. + pub fn suggest_value(&mut self, variable: T, value: f64) -> Result<(), SuggestValueError> { + let (info_tag_marker, info_tag_other, delta) = { + let info = self.edits.get_mut(&variable).ok_or(SuggestValueError::UnknownEditVariable)?; + let delta = value - info.constant; + info.constant = value; + (info.tag.marker, info.tag.other, delta) + }; + // tag.marker and tag.other are never external symbols + + // The nice version of the following code runs into non-lexical borrow issues. + // Ideally the `if row...` code would be in the body of the if. Pretend that it is. + { + let infeasible_rows = &mut self.infeasible_rows; + if self.rows.get_mut(&info_tag_marker) + .map(|row| + if row.add(-delta) < 0.0 { + infeasible_rows.push(info_tag_marker); + }).is_some() + { + + } else if self.rows.get_mut(&info_tag_other) + .map(|row| + if row.add(delta) < 0.0 { + infeasible_rows.push(info_tag_other); + }).is_some() + { + + } else { + for (symbol, row) in &mut self.rows { + let coeff = row.coefficient_for(info_tag_marker); + let diff = delta * coeff; + if diff != 0.0 && symbol.type_() == SymbolType::External { + let v = self.var_for_symbol[symbol].clone(); + // inline var_changed - borrow checker workaround + if self.should_clear_changes { + self.changed.clear(); + self.should_clear_changes = false; + } + self.changed.insert(v); + } + if coeff != 0.0 && + row.add(diff) < 0.0 && + symbol.type_() != SymbolType::External + { + infeasible_rows.push(*symbol); + } + } + } + } + self.dual_optimise().map_err(|e| SuggestValueError::InternalSolverError(e.0))?; + return Ok(()); + } + + fn var_changed(&mut self, v: T) { + if self.should_clear_changes { + self.changed.clear(); + self.should_clear_changes = false; + } + self.changed.insert(v); + } + + /// Fetches all changes to the values of variables since the last call to this function. + /// + /// The list of changes returned is not in a specific order. Each change comprises the variable changed and + /// the new value of that variable. + pub fn fetch_changes(&mut self) -> &[(T, f64)] { + if self.should_clear_changes { + self.changed.clear(); + self.should_clear_changes = false; + } else { + self.should_clear_changes = true; + } + self.public_changes.clear(); + for v in &self.changed { + if let Some(var_data) = self.var_data.get_mut(&v) { + let new_value = self.rows.get(&var_data.1).map(|r| r.constant).map(|o| o.into_inner()).unwrap_or(0.0); + let old_value = var_data.0; + if old_value != new_value { + self.public_changes.push((v.clone(), new_value)); + var_data.0 = new_value; + } + } + } + &self.public_changes + } + + /// Reset the solver to the empty starting condition. + /// + /// This method resets the internal solver state to the empty starting + /// condition, as if no constraints or edit variables have been added. + /// This can be faster than deleting the solver and creating a new one + /// when the entire system must change, since it can avoid unnecessary + /// heap (de)allocations. + pub fn reset(&mut self) { + self.rows.clear(); + self.cns.clear(); + self.var_data.clear(); + self.var_for_symbol.clear(); + self.changed.clear(); + self.should_clear_changes = false; + self.edits.clear(); + self.infeasible_rows.clear(); + self.objective = Row::new(0.0); + self.artificial = None; + self.id_tick = 1; + } + + /// Get the symbol for the given variable. + /// + /// If a symbol does not exist for the variable, one will be created. + fn get_var_symbol(&mut self, v: T) -> Symbol { + let id_tick = &mut self.id_tick; + let var_for_symbol = &mut self.var_for_symbol; + let value = self.var_data.entry(v.clone()).or_insert_with(|| { + let s = Symbol(*id_tick, SymbolType::External); + var_for_symbol.insert(s, v); + *id_tick += 1; + (std::f64::NAN, s, 0) + }); + value.2 += 1; + value.1 + } + + /// Create a new Row object for the given constraint. + /// + /// The terms in the constraint will be converted to cells in the row. + /// Any term in the constraint with a coefficient of zero is ignored. + /// This method uses the `getVarSymbol` method to get the symbol for + /// the variables added to the row. If the symbol for a given cell + /// variable is basic, the cell variable will be substituted with the + /// basic row. + /// + /// The necessary slack and error variables will be added to the row. + /// If the constant for the row is negative, the sign for the row + /// will be inverted so the constant becomes positive. + /// + /// The tag will be updated with the marker and error symbols to use + /// for tracking the movement of the constraint in the tableau. + fn create_row(&mut self, constraint: &Constraint) -> (Row, Tag) { + let expr = constraint.expr(); + let mut row = Row::new(expr.constant.into_inner()); + // Substitute the current basic variables into the row. + for term in &expr.terms { + if !near_zero(term.coefficient.into_inner()) { + let symbol = self.get_var_symbol(term.variable.clone()); + if let Some(other_row) = self.rows.get(&symbol) { + row.insert_row(other_row, term.coefficient.into_inner()); + } else { + row.insert_symbol(symbol, term.coefficient.into_inner()); + } + } + } + + // Add the necessary slack, error, and dummy variables. + let tag = match constraint.op() { + RelationalOperator::GreaterOrEqual | + RelationalOperator::LessOrEqual => { + let coeff = if constraint.op() == RelationalOperator::LessOrEqual { + 1.0 + } else { + -1.0 + }; + let slack = Symbol(self.id_tick, SymbolType::Slack); + self.id_tick += 1; + row.insert_symbol(slack, coeff); + if constraint.strength() < ::strength::REQUIRED { + let error = Symbol(self.id_tick, SymbolType::Error); + self.id_tick += 1; + row.insert_symbol(error, -coeff); + self.objective.insert_symbol(error, constraint.strength()); + Tag { + marker: slack, + other: error + } + } else { + Tag { + marker: slack, + other: Symbol::invalid() + } + } + } + RelationalOperator::Equal => { + if constraint.strength() < ::strength::REQUIRED { + let errplus = Symbol(self.id_tick, SymbolType::Error); + self.id_tick += 1; + let errminus = Symbol(self.id_tick, SymbolType::Error); + self.id_tick += 1; + row.insert_symbol(errplus, -1.0); // v = eplus - eminus + row.insert_symbol(errminus, 1.0); // v - eplus + eminus = 0 + self.objective.insert_symbol(errplus, constraint.strength()); + self.objective.insert_symbol(errminus, constraint.strength()); + Tag { + marker: errplus, + other: errminus + } + } else { + let dummy = Symbol(self.id_tick, SymbolType::Dummy); + self.id_tick += 1; + row.insert_symbol(dummy, 1.0); + Tag { + marker: dummy, + other: Symbol::invalid() + } + } + } + }; + + // Ensure the row has a positive constant. + if *row.constant.as_ref() < 0.0 { + row.reverse_sign(); + } + (row, tag) + } + + /// Add the row to the tableau using an artificial variable. + /// + /// This will return false if the constraint cannot be satisfied. + fn add_with_artificial_variable(&mut self, row: &Row) -> Result { + // Create and add the artificial variable to the tableau + let art = Symbol(self.id_tick, SymbolType::Slack); + self.id_tick += 1; + self.rows.insert(art, row.clone()); + self.artificial = Some(row.clone()); + + // Optimize the artificial objective. This is successful + // only if the artificial objective is optimized to zero. + let artificial = self.artificial.as_ref().unwrap().clone(); + self.optimise(&artificial)?; + let success = near_zero(*artificial.constant.as_ref()); + self.artificial = None; + + // If the artificial variable is basic, pivot the row so that + // it becomes basic. If the row is constant, exit early. + if let Some(mut row) = self.rows.remove(&art) { + if row.cells.is_empty() { + return Ok(success); + } + let entering = row.any_pivotable_symbol(); // never External + if entering.type_() == SymbolType::Invalid { + return Ok(false); // unsatisfiable (will this ever happen?) + } + row.solve_for_symbols(art, entering); + self.substitute(entering, &row); + self.rows.insert(entering, row); + } + + // Remove the artificial row from the tableau + for (_, row) in &mut self.rows { + row.remove(art); + } + self.objective.remove(art); + Ok(success) + } + + /// Substitute the parametric symbol with the given row. + /// + /// This method will substitute all instances of the parametric symbol + /// in the tableau and the objective function with the given row. + fn substitute(&mut self, symbol: Symbol, row: &Row) { + for (&other_symbol, other_row) in &mut self.rows { + let constant_changed = other_row.substitute(symbol, row); + if other_symbol.type_() == SymbolType::External && constant_changed { + // inline var_changed + if self.should_clear_changes { + self.changed.clear(); + self.should_clear_changes = false; + } + let v = self.var_for_symbol[&other_symbol].clone(); + self.changed.insert(v); + } + if other_symbol.type_() != SymbolType::External && *other_row.constant.as_ref() < 0.0 { + self.infeasible_rows.push(other_symbol); + } + } + self.objective.substitute(symbol, row); + if let Some(artificial) = self.artificial.as_mut() { + artificial.substitute(symbol, row); + } + } + + /// Optimize the system for the given objective function. + /// + /// This method performs iterations of Phase 2 of the simplex method + /// until the objective function reaches a minimum. + fn optimise(&mut self, objective: &Row) -> Result<(), InternalSolverError> { + loop { + let entering = objective.get_entering_symbol(); + if entering.type_() == SymbolType::Invalid { + return Ok(()); + } + let (leaving, mut row) = + self.get_leaving_row(entering).ok_or(InternalSolverError("The objective is unbounded"))?; + // pivot the entering symbol into the basis + row.solve_for_symbols(leaving, entering); + self.substitute(entering, &row); + if entering.type_() == SymbolType::External && *row.constant.as_ref() != 0.0 { + let v = self.var_for_symbol[&entering].clone(); + self.var_changed(v); + } + self.rows.insert(entering, row); + } + } + + /// Optimize the system using the dual of the simplex method. + /// + /// The current state of the system should be such that the objective + /// function is optimal, but not feasible. This method will perform + /// an iteration of the dual simplex method to make the solution both + /// optimal and feasible. + fn dual_optimise(&mut self) -> Result<(), InternalSolverError> { + while !self.infeasible_rows.is_empty() { + let leaving = self.infeasible_rows.pop().unwrap(); + + let row = if let Entry::Occupied(entry) = self.rows.entry(leaving) { + if *entry.get().constant.as_ref() < 0.0 { + Some(entry.remove()) + } else { + None + } + } else { + None + }; + if let Some(mut row) = row { + let entering = self.get_dual_entering_symbol(&row); + if entering.type_() == SymbolType::Invalid { + return Err(InternalSolverError("Dual optimise failed.")); + } + // pivot the entering symbol into the basis + row.solve_for_symbols(leaving, entering); + self.substitute(entering, &row); + if entering.type_() == SymbolType::External && *row.constant.as_ref() != 0.0 { + let v = self.var_for_symbol[&entering].clone(); + self.var_changed(v); + } + self.rows.insert(entering, row); + } + } + Ok(()) + } + + /// Compute the entering symbol for the dual optimize operation. + /// + /// This method will return the symbol in the row which has a positive + /// coefficient and yields the minimum ratio for its respective symbol + /// in the objective function. The provided row *must* be infeasible. + /// If no symbol is found which meats the criteria, an invalid symbol + /// is returned. + /// Could return an External symbol + fn get_dual_entering_symbol(&self, row: &Row) -> Symbol { + let mut entering = Symbol::invalid(); + let mut ratio = std::f64::INFINITY; + for (symbol, value) in &row.cells { + let value = *value.as_ref(); + if value > 0.0 && symbol.type_() != SymbolType::Dummy { + let coeff = self.objective.coefficient_for(*symbol); + let r = coeff / value; + if r < ratio { + ratio = r; + entering = *symbol; + } + } + } + entering + } + + /// Compute the row which holds the exit symbol for a pivot. + /// + /// This method will return an iterator to the row in the row map + /// which holds the exit symbol. If no appropriate exit symbol is + /// found, the end() iterator will be returned. This indicates that + /// the objective function is unbounded. + /// Never returns a row for an External symbol + fn get_leaving_row(&mut self, entering: Symbol) -> Option<(Symbol, Row)> { + let mut ratio = std::f64::INFINITY; + let mut found = None; + for (symbol, row) in &self.rows { + if symbol.type_() != SymbolType::External { + let temp = row.coefficient_for(entering); + if temp < 0.0 { + let temp_ratio = -row.constant.as_ref() / temp; + if temp_ratio < ratio { + ratio = temp_ratio; + found = Some(*symbol); + } + } + } + } + found.map(|s| (s, self.rows.remove(&s).unwrap())) + } + + /// Compute the leaving row for a marker variable. + /// + /// This method will return an iterator to the row in the row map + /// which holds the given marker variable. The row will be chosen + /// according to the following precedence: + /// + /// 1) The row with a restricted basic varible and a negative coefficient + /// for the marker with the smallest ratio of -constant / coefficient. + /// + /// 2) The row with a restricted basic variable and the smallest ratio + /// of constant / coefficient. + /// + /// 3) The last unrestricted row which contains the marker. + /// + /// If the marker does not exist in any row, the row map end() iterator + /// will be returned. This indicates an internal solver error since + /// the marker *should* exist somewhere in the tableau. + fn get_marker_leaving_row(&mut self, marker: Symbol) -> Option<(Symbol, Row)> { + let mut r1 = std::f64::INFINITY; + let mut r2 = r1; + let mut first = None; + let mut second = None; + let mut third = None; + for (symbol, row) in &self.rows { + let c = row.coefficient_for(marker); + let row_constant = row.constant.as_ref(); + if c == 0.0 { + continue; + } + if symbol.type_() == SymbolType::External { + third = Some(*symbol); + } else if c < 0.0 { + let r = -row_constant / c; + if r < r1 { + r1 = r; + first = Some(*symbol); + } + } else { + let r = row_constant / c; + if r < r2 { + r2 = r; + second = Some(*symbol); + } + } + } + first + .or(second) + .or(third) + .and_then(|s| { + if s.type_() == SymbolType::External && *self.rows[&s].constant.as_ref() != 0.0 { + let v = self.var_for_symbol[&s].clone(); + self.var_changed(v); + } + self.rows + .remove(&s) + .map(|r| (s, r)) + }) + } + + /// Remove the effects of a constraint on the objective function. + fn remove_constraint_effects(&mut self, cn: &Constraint, tag: &Tag) { + if tag.marker.type_() == SymbolType::Error { + self.remove_marker_effects(tag.marker, cn.strength()); + } else if tag.other.type_() == SymbolType::Error { + self.remove_marker_effects(tag.other, cn.strength()); + } + } + + /// Remove the effects of an error marker on the objective function. + fn remove_marker_effects(&mut self, marker: Symbol, strength: f64) { + if let Some(row) = self.rows.get(&marker) { + self.objective.insert_row(row, -strength); + } else { + self.objective.insert_symbol(marker, -strength); + } + } + + /// Get the stored value for a variable. + /// + /// Normally values should be retrieved and updated using `fetch_changes`, but + /// this method can be used for debugging or testing. + pub fn get_value(&self, v: T) -> f64 { + self.var_data.get(&v).and_then(|s| { + self.rows.get(&s.1).map(|r| r.constant) + }) + .map(|o| o.into_inner()) + .unwrap_or(0.0) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + + use std::cell::RefCell; + use std::collections::HashMap; + use std::rc::Rc; + + #[derive(Clone, Debug, Hash, PartialEq, Eq)] + enum Variable { + Left(u8), Width(u8), + } + derive_syntax_for!(Variable); + + #[test] + fn example() { + let mut names = HashMap::new(); + fn print_changes(names: &HashMap, changes: &[(Variable, f64)]) { + println!("Changes:"); + for &(ref var, ref val) in changes { + println!("{}: {}", names[var], val); + } + } + + let window_width = Variable::new(); + names.insert(window_width, "window_width"); + struct Element { + left: Variable, + right: Variable + } + let box1 = Element { + left: Variable::Left(1), + right: Variable::Right(1) + }; + names.insert(box1.left, "box1.left"); + names.insert(box1.right, "box1.right"); + let box2 = Element { + left: Variable::Left(2), + right: Variable::Right(2) + }; + names.insert(box2.left, "box2.left"); + names.insert(box2.right, "box2.right"); + let mut solver = Solver::new(); + solver + .add_constraints( + vec![ + window_width |GE(REQUIRED)| 0.0, // positive window width + box1.left |EQ(REQUIRED)| 0.0, // left align + box2.right |EQ(REQUIRED)| window_width, // right align + box2.left |GE(REQUIRED)| box1.right, // no overlap + // positive widths + box1.left |LE(REQUIRED)| box1.right, + box2.left |LE(REQUIRED)| box2.right, + // preferred widths: + box1.right - box1.left |EQ(WEAK)| 50.0, + box2.right - box2.left |EQ(WEAK)| 100.0 + ] + ) + .expect("Could not add box constraints"); + solver.add_edit_variable(window_width, STRONG).unwrap(); + solver.suggest_value(window_width, 300.0).unwrap(); + print_changes(&names, solver.fetch_changes()); + solver.suggest_value(window_width, 75.0).unwrap(); + print_changes(&names, solver.fetch_changes()); + solver.add_constraint( + (box1.right - box1.left) / 50.0 |EQ(MEDIUM)| (box2.right - box2.left) / 100.0 + ).unwrap(); + print_changes(&names, solver.fetch_changes()); + } + + #[derive(Clone, Default)] + struct Values(Rc>>); + + impl Values { + fn value_of(&self, var: Variable) -> f64 { + *self.0.borrow().get(&var).unwrap_or(&0.0) + } + fn update_values(&self, changes: &[(Variable, f64)]) { + for &(ref var, ref value) in changes { + println!("{:?} changed to {:?}", var, value); + self.0.borrow_mut().insert(*var, *value); + } + } + } + + pub fn new_values() -> (Box f64>, Box) { + let values = Values(Rc::new(RefCell::new(HashMap::new()))); + let value_of = { + let values = values.clone(); + move |v| values.value_of(v) + }; + let update_values = { + let values = values.clone(); + move |changes: &[_]| { + values.update_values(changes); + } + }; + (Box::new(value_of), Box::new(update_values)) + } + + #[test] + fn test_quadrilateral() { + use cassowary::strength::{WEAK, STRONG, REQUIRED}; + struct Point { + x: Variable, + y: Variable + } + impl Point { + fn new() -> Point { + Point { + x: Variable::new(), + y: Variable::new() + } + } + } + let (value_of, update_values) = new_values(); + + let points = [Point::new(), + Point::new(), + Point::new(), + Point::new()]; + let point_starts = [(10.0, 10.0), (10.0, 200.0), (200.0, 200.0), (200.0, 10.0)]; + let midpoints = [Point::new(), + Point::new(), + Point::new(), + Point::new()]; + let mut solver = Solver::new(); + let mut weight = 1.0; + let multiplier = 2.0; + for i in 0..4 { + solver + .add_constraints( + vec![points[i].x |EQ(WEAK * weight)| point_starts[i].0, + points[i].y |EQ(WEAK * weight)| point_starts[i].1] + ) + .expect("Could not add initial quad points"); + weight *= multiplier; + } + + for (start, end) in vec![(0, 1), (1, 2), (2, 3), (3, 0)] { + solver + .add_constraints( + vec![midpoints[start].x |EQ(REQUIRED)| (points[start].x + points[end].x) / 2.0, + midpoints[start].y |EQ(REQUIRED)| (points[start].y + points[end].y) / 2.0] + ) + .expect("Could not add quad midpoints"); + } + + solver + .add_constraints( + vec![points[0].x + 20.0 |LE(STRONG)| points[2].x, + points[0].x + 20.0 |LE(STRONG)| points[3].x, + + points[1].x + 20.0 |LE(STRONG)| points[2].x, + points[1].x + 20.0 |LE(STRONG)| points[3].x, + + points[0].y + 20.0 |LE(STRONG)| points[1].y, + points[0].y + 20.0 |LE(STRONG)| points[2].y, + + points[3].y + 20.0 |LE(STRONG)| points[1].y, + points[3].y + 20.0 |LE(STRONG)| points[2].y] + ) + .expect("Could not add quad midpoint constraints"); + + for point in &points { + solver + .add_constraints( + vec![point.x |GE(REQUIRED)| 0.0, + point.y |GE(REQUIRED)| 0.0, + + point.x |LE(REQUIRED)| 500.0, + point.y |LE(REQUIRED)| 500.0] + ) + .expect("Could not add required bounds on quad"); + } + + update_values(solver.fetch_changes()); + + assert_eq!([(value_of(midpoints[0].x), value_of(midpoints[0].y)), + (value_of(midpoints[1].x), value_of(midpoints[1].y)), + (value_of(midpoints[2].x), value_of(midpoints[2].y)), + (value_of(midpoints[3].x), value_of(midpoints[3].y))], + [(10.0, 105.0), + (105.0, 200.0), + (200.0, 105.0), + (105.0, 10.0)]); + + solver.add_edit_variable(points[2].x, STRONG).expect("Could not add x edit variable for 2nd point"); + solver.add_edit_variable(points[2].y, STRONG).expect("Could not add y edit variable for 2nd point"); + solver.suggest_value(points[2].x, 300.0).expect("Could not suggest value for x edit variable for 2nd point"); + solver.suggest_value(points[2].y, 400.0).expect("Could not suggest value for y edit variable for 2nd point"); + + update_values(solver.fetch_changes()); + + assert_eq!([(value_of(points[0].x), value_of(points[0].y)), + (value_of(points[1].x), value_of(points[1].y)), + (value_of(points[2].x), value_of(points[2].y)), + (value_of(points[3].x), value_of(points[3].y))], + [(10.0, 10.0), + (10.0, 200.0), + (300.0, 400.0), + (200.0, 10.0)]); + + assert_eq!([(value_of(midpoints[0].x), value_of(midpoints[0].y)), + (value_of(midpoints[1].x), value_of(midpoints[1].y)), + (value_of(midpoints[2].x), value_of(midpoints[2].y)), + (value_of(midpoints[3].x), value_of(midpoints[3].y))], + [(10.0, 105.0), + (155.0, 300.0), + (250.0, 205.0), + (105.0, 10.0)]); + } + + #[test] + fn remove_constraint() { + let (value_of, update_values) = new_values(); + + let mut solver = Solver::new(); + + let val = Variable::new(); + + let constraint: Constraint = val | EQ(REQUIRED) | 100.0; + solver.add_constraint(constraint.clone()).unwrap(); + update_values(solver.fetch_changes()); + + assert_eq!(value_of(val), 100.0); + + solver.remove_constraint(&constraint).unwrap(); + solver.add_constraint(val | EQ(REQUIRED) | 0.0).unwrap(); + update_values(solver.fetch_changes()); + + assert_eq!(value_of(val), 0.0); + } +} diff --git a/src/derive_syntax.rs b/src/derive_syntax.rs new file mode 100644 index 0000000..378d769 --- /dev/null +++ b/src/derive_syntax.rs @@ -0,0 +1,335 @@ +#[macro_export] +macro_rules! derive_bitor_for { + ( $x:ty ) => { +// impl BitOr for f64 +// { +// type Output = PartialConstraint<$x>; +// fn bitor(self, r: WeightedRelation) -> PartialConstraint<$x> { +// PartialConstraint(self.into(), r) +// } +// } +// impl BitOr for f32 { +// type Output = PartialConstraint<$x>; +// fn bitor(self, r: WeightedRelation) -> PartialConstraint<$x> { +// (self as f64).bitor(r) +// } +// } + impl BitOr for $x { + type Output = PartialConstraint<$x>; + fn bitor(self, r: WeightedRelation) -> PartialConstraint<$x> { + PartialConstraint(self.into(), r) + } + } + } +} + +/// Derives operator support for your cassowary solver variable type. +/// This allows you to use your variable type in writing expressions, to a limited extent. +#[macro_export] +macro_rules! derive_syntax_for { + ( $x:ty ) => { + impl From<$x> for Expression<$x> { + fn from(v: $x) -> Expression<$x> { + Expression::from_term(Term::new(v, 1.0)) + } + } + + impl Add for $x { + type Output = Expression<$x>; + fn add(self, v: f64) -> Expression<$x> { + Expression::new(vec![Term::new(self, 1.0)], v) + } + } + + impl Add for $x { + type Output = Expression<$x>; + fn add(self, v: f32) -> Expression<$x> { + self.add(v as f64) + } + } + + impl Add<$x> for f64 { + type Output = Expression<$x>; + fn add(self, v: $x) -> Expression<$x> { + Expression::new(vec![Term::new(v, 1.0)], self) + } + } + + impl Add<$x> for f32 { + type Output = Expression<$x>; + fn add(self, v: $x) -> Expression<$x> { + (self as f64).add(v) + } + } + + impl Add<$x> for $x { + type Output = Expression<$x>; + fn add(self, v: $x) -> Expression<$x> { + Expression::new(vec![Term::new(self, 1.0), Term::new(v, 1.0)], 0.0) + } + } + + impl Add> for $x { + type Output = Expression<$x>; + fn add(self, t: Term<$x>) -> Expression<$x> { + Expression::new(vec![Term::new(self, 1.0), t], 0.0) + } + } + + impl Add<$x> for Term<$x> { + type Output = Expression<$x>; + fn add(self, v: $x) -> Expression<$x> { + Expression::new(vec![self, Term::new(v, 1.0)], 0.0) + } + } + + impl Add> for $x { + type Output = Expression<$x>; + fn add(self, mut e: Expression<$x>) -> Expression<$x> { + e.terms.push(Term::new(self, 1.0)); + e + } + } + + impl Add<$x> for Expression<$x> { + type Output = Expression<$x>; + fn add(mut self, v: $x) -> Expression<$x> { + self += v; + self + } + } + + impl AddAssign<$x> for Expression<$x> { + fn add_assign(&mut self, v: $x) { + self.terms.push(Term::new(v, 1.0)); + } + } + + impl Neg for $x { + type Output = Term<$x>; + fn neg(self) -> Term<$x> { + Term::new(self, -1.0) + } + } + + impl Sub for $x { + type Output = Expression<$x>; + fn sub(self, v: f64) -> Expression<$x> { + Expression::new(vec![Term::new(self, 1.0)], -v) + } + } + + impl Sub for $x { + type Output = Expression<$x>; + fn sub(self, v: f32) -> Expression<$x> { + self.sub(v as f64) + } + } + + impl Sub<$x> for f64 { + type Output = Expression<$x>; + fn sub(self, v: $x) -> Expression<$x> { + Expression::new(vec![Term::new(v, -1.0)], self) + } + } + + impl Sub<$x> for f32 { + type Output = Expression<$x>; + fn sub(self, v: $x) -> Expression<$x> { + (self as f64).sub(v) + } + } + + impl Sub<$x> for $x { + type Output = Expression<$x>; + fn sub(self, v: $x) -> Expression<$x> { + Expression::new(vec![Term::new(self, 1.0), Term::new(v, -1.0)], 0.0) + } + } + + impl Sub> for $x { + type Output = Expression<$x>; + fn sub(self, t: Term<$x>) -> Expression<$x> { + Expression::new(vec![Term::new(self, 1.0), -t], 0.0) + } + } + + impl Sub<$x> for Term<$x> { + type Output = Expression<$x>; + fn sub(self, v: $x) -> Expression<$x> { + Expression::new(vec![self, Term::new(v, -1.0)], 0.0) + } + } + + impl Sub> for $x { + type Output = Expression<$x>; + fn sub(self, mut e: Expression<$x>) -> Expression<$x> { + e.negate(); + e.terms.push(Term::new(self, 1.0)); + e + } + } + + impl Sub<$x> for Expression<$x> { + type Output = Expression<$x>; + fn sub(mut self, v: $x) -> Expression<$x> { + self -= v; + self + } + } + + impl SubAssign<$x> for Expression<$x> { + fn sub_assign(&mut self, v: $x) { + self.terms.push(Term::new(v, -1.0)); + } + } + + impl Mul for $x { + type Output = Term<$x>; + fn mul(self, v: f64) -> Term<$x> { + Term::new(self, v) + } + } + + impl Mul for $x { + type Output = Term<$x>; + fn mul(self, v: f32) -> Term<$x> { + self.mul(v as f64) + } + } + + impl Mul<$x> for f64 { + type Output = Term<$x>; + fn mul(self, v: $x) -> Term<$x> { + Term::new(v, self) + } + } + + impl Mul<$x> for f32 { + type Output = Term<$x>; + fn mul(self, v: $x) -> Term<$x> { + (self as f64).mul(v) + } + } + + impl Div for $x { + type Output = Term<$x>; + fn div(self, v: f64) -> Term<$x> { + Term::new(self, 1.0 / v) + } + } + + impl Div for $x { + type Output = Term<$x>; + fn div(self, v: f32) -> Term<$x> { + self.div(v as f64) + } + } + + impl BitOr<$x> for PartialConstraint<$x> { + type Output = Constraint<$x>; + fn bitor(self, rhs: $x) -> Constraint<$x> { + let (op, s) = self.1.into(); + Constraint::new(self.0 - rhs, op, s) + } + } + + impl Constrainable<$x> for $x { + fn equal_to(self, x:X) -> Constraint<$x> where X: Into> + Clone { + let lhs:Expression<$x> = + self + .into(); + let rhs:Expression<$x> = + x.into(); + lhs.equal_to(rhs) + } + fn greater_than_or_equal_to(self, x:X) -> Constraint<$x> where X: Into> + Clone { + let lhs:Expression<$x> = + self + .into(); + let rhs:Expression<$x> = + x.into(); + lhs.is_ge(rhs) + } + fn less_than_or_equal_to(self, x:X) -> Constraint<$x> where X: Into> + Clone { + let lhs:Expression<$x> = + self + .into(); + let rhs:Expression<$x> = + x.into(); + lhs.is_le(rhs) + } + } + }; +} + + +#[cfg(test)] +mod tests { + use super::super::{ + Constrainable, + Constraint, + Expression, + PartialConstraint, + Solver, + Term + }; + + use std::ops::*; + + + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + enum VariableX { + Left(usize), Width(usize) + } + derive_syntax_for!(VariableX); + + + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + enum VariableY { + Top(usize), Height(usize) + } + derive_syntax_for!(VariableY); + + + struct Element(usize); + + + impl Element { + fn left(&self) -> VariableX { + VariableX::Left(self.0) + } + fn width(&self) -> VariableX { + VariableX::Width(self.0) + } + fn _top(&self) -> VariableX { + VariableX::Left(self.0) + } + fn _height(&self) -> VariableX { + VariableX::Width(self.0) + } + } + + + #[test] + fn can_do_ops() { + let el0 = + Element(0); + let el1 = + Element(1); + + let mut solver_x = Solver::new(); + solver_x + .add_constraints( + vec![ + el0.left().is(0.0), + el0.width().is(100.0), + el1.left().is_ge(el0.left() + el0.width()) + ] + ) + .unwrap(); + assert_eq!(solver_x.get_value(el0.left()), 0.0); + assert_eq!(solver_x.get_value(el0.width()), 100.0); + assert_eq!(solver_x.get_value(el1.left()), 100.0); + } +} diff --git a/src/lib.rs b/src/lib.rs index b883aa7..739081f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,8 +31,10 @@ //! //! First we need to include the relevant parts of `cassowary`: //! -//! ``` -//! use cassowary::{ Solver, Variable }; +//! ```ignore +//! extern crate cassowary_variable; +//! use cassowary::{ Solver }; +//! use cassowary_variable::Variable; //! use cassowary::WeightedRelation::*; //! use cassowary::strength::{ WEAK, MEDIUM, STRONG, REQUIRED }; //! ``` @@ -157,58 +159,10 @@ //! to control the behaviour when the preferred widths cannot both be satisfied. In this example we are going //! to constrain the boxes to try to maintain a ratio between their widths. //! -//! ``` -//! # use cassowary::{ Solver, Variable }; -//! # use cassowary::WeightedRelation::*; -//! # use cassowary::strength::{ WEAK, MEDIUM, STRONG, REQUIRED }; -//! # -//! # use std::collections::HashMap; -//! # let mut names = HashMap::new(); -//! # fn print_changes(names: &HashMap, changes: &[(Variable, f64)]) { -//! # println!("Changes:"); -//! # for &(ref var, ref val) in changes { -//! # println!("{}: {}", names[var], val); -//! # } -//! # } -//! # -//! # let window_width = Variable::new(); -//! # names.insert(window_width, "window_width"); -//! # struct Element { -//! # left: Variable, -//! # right: Variable -//! # } -//! # let box1 = Element { -//! # left: Variable::new(), -//! # right: Variable::new() -//! # }; -//! # names.insert(box1.left, "box1.left"); -//! # names.insert(box1.right, "box1.right"); -//! # let box2 = Element { -//! # left: Variable::new(), -//! # right: Variable::new() -//! # }; -//! # names.insert(box2.left, "box2.left"); -//! # names.insert(box2.right, "box2.right"); -//! # let mut solver = Solver::new(); -//! # solver.add_constraints(&[window_width |GE(REQUIRED)| 0.0, // positive window width -//! # box1.left |EQ(REQUIRED)| 0.0, // left align -//! # box2.right |EQ(REQUIRED)| window_width, // right align -//! # box2.left |GE(REQUIRED)| box1.right, // no overlap -//! # // positive widths -//! # box1.left |LE(REQUIRED)| box1.right, -//! # box2.left |LE(REQUIRED)| box2.right, -//! # // preferred widths: -//! # box1.right - box1.left |EQ(WEAK)| 50.0, -//! # box2.right - box2.left |EQ(WEAK)| 100.0]).unwrap(); -//! # solver.add_edit_variable(window_width, STRONG).unwrap(); -//! # solver.suggest_value(window_width, 300.0).unwrap(); -//! # print_changes(&names, solver.fetch_changes()); -//! # solver.suggest_value(window_width, 75.0); -//! # print_changes(&names, solver.fetch_changes()); -//! solver.add_constraint( -//! (box1.right - box1.left) / 50.0 |EQ(MEDIUM)| (box2.right - box2.left) / 100.0 -//! ).unwrap(); -//! print_changes(&names, solver.fetch_changes()); +//! These docs are all out of date: +//! +//! ```ignore + //! ``` //! //! Now the result gives values that maintain the ratio between the sizes of the two boxes: @@ -225,100 +179,112 @@ //! One thing that this example exposes is that this crate is a rather low level library. It does not have //! any inherent knowledge of user interfaces, directions or boxes. Thus for use in a user interface this //! crate should ideally be wrapped by a higher level API, which is outside the scope of this crate. -use std::sync::Arc; +extern crate ordered_float; + use std::collections::HashMap; -use std::collections::hash_map::{Entry}; +use std::collections::hash_map::Entry; +use ordered_float::OrderedFloat; mod solver_impl; mod operators; +pub use operators::Constrainable; +#[macro_use] +pub mod derive_syntax; -static VARIABLE_ID: ::std::sync::atomic::AtomicUsize = ::std::sync::atomic::ATOMIC_USIZE_INIT; - -/// Identifies a variable for the constraint solver. -/// Each new variable is unique in the view of the solver, but copying or cloning the variable produces -/// a copy of the same variable. -#[derive(Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] -pub struct Variable(usize); - -impl Variable { - /// Produces a new unique variable for use in constraint solving. - pub fn new() -> Variable { - Variable(VARIABLE_ID.fetch_add(1, ::std::sync::atomic::Ordering::Relaxed)) - } -} /// A variable and a coefficient to multiply that variable by. This is a sub-expression in /// a constraint equation. -#[derive(Copy, Clone, Debug)] -pub struct Term { - pub variable: Variable, - pub coefficient: f64 +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct Term { + pub variable: T, + pub coefficient: OrderedFloat } -impl Term { +impl Term { /// Construct a new Term from a variable and a coefficient. - fn new(variable: Variable, coefficient: f64) -> Term { + pub fn new(variable: T, coefficient: f64) -> Term { Term { variable: variable, - coefficient: coefficient + coefficient: coefficient.into() } } } /// An expression that can be the left hand or right hand side of a constraint equation. /// It is a linear combination of variables, i.e. a sum of variables weighted by coefficients, plus an optional constant. -#[derive(Clone, Debug)] -pub struct Expression { - pub terms: Vec, - pub constant: f64 +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct Expression { + pub terms: Vec>, + pub constant: OrderedFloat } -impl Expression { +impl Expression { /// Constructs an expression of the form _n_, where n is a constant real number, not a variable. - pub fn from_constant(v: f64) -> Expression { + pub fn from_constant(v: f64) -> Expression { Expression { terms: Vec::new(), - constant: v + constant: v.into() } } /// Constructs an expression from a single term. Forms an expression of the form _n x_ /// where n is the coefficient, and x is the variable. - pub fn from_term(term: Term) -> Expression { + pub fn from_term(term: Term) -> Expression { Expression { terms: vec![term], - constant: 0.0 + constant: 0.0.into() } } /// General constructor. Each `Term` in `terms` is part of the sum forming the expression, as well as `constant`. - pub fn new(terms: Vec, constant: f64) -> Expression { + pub fn new(terms: Vec>, constant: f64) -> Expression { Expression { terms: terms, - constant: constant + constant: constant.into() } } /// Mutates this expression by multiplying it by minus one. pub fn negate(&mut self) { - self.constant = -self.constant; + self.constant = (-(self.constant.into_inner())).into(); for t in &mut self.terms { - *t = -*t; + let t2 = t.clone(); + *t = -t2; } } } -impl From for Expression { - fn from(v: f64) -> Expression { - Expression::from_constant(v) +impl Constrainable for Expression { + fn equal_to(self, x: X) -> Constraint where X: Into> + Clone { + self |WeightedRelation::EQ(strength::REQUIRED) | x.into() + } + + fn greater_than_or_equal_to(self, x: X) -> Constraint where X: Into> + Clone { + self |WeightedRelation::GE(strength::REQUIRED) | x.into() + } + + fn less_than_or_equal_to(self, x: X) -> Constraint where X: Into> + Clone { + self |WeightedRelation::LE(strength::REQUIRED) | x.into() } } -impl From for Expression { - fn from(v: Variable) -> Expression { - Expression::from_term(Term::new(v, 1.0)) +impl From for Expression { + fn from(v: f64) -> Expression { + Expression::from_constant(v) } } -impl From for Expression { - fn from(t: Term) -> Expression { +impl From for Expression { + fn from(v: i32) -> Expression { + Expression::from_constant(v as f64) + } +} + +impl From for Expression { + fn from(v: u32) -> Expression { + Expression::from_constant(v as f64) + } +} + +impl From> for Expression { + fn from(t: Term) -> Expression { Expression::from_term(t) } } @@ -382,61 +348,51 @@ impl std::fmt::Display for RelationalOperator { } } -#[derive(Debug)] -struct ConstraintData { - expression: Expression, - strength: f64, - op: RelationalOperator -} /// A constraint, consisting of an equation governed by an expression and a relational operator, /// and an associated strength. -#[derive(Clone, Debug)] -pub struct Constraint(Arc); +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Constraint{ + expression: Expression, + strength: OrderedFloat, + op: RelationalOperator +} -impl Constraint { + +impl Constraint { /// Construct a new constraint from an expression, a relational operator and a strength. /// This corresponds to the equation `e op 0.0`, e.g. `x + y >= 0.0`. For equations with a non-zero /// right hand side, subtract it from the equation to give a zero right hand side. - pub fn new(e: Expression, op: RelationalOperator, strength: f64) -> Constraint { - Constraint(Arc::new(ConstraintData { + pub fn new(e: Expression, op: RelationalOperator, strength: f64) -> Constraint { + Constraint{ expression: e, op: op, - strength: strength - })) + strength: strength.into() + } } /// The expression of the left hand side of the constraint equation. - pub fn expr(&self) -> &Expression { - &self.0.expression + pub fn expr(&self) -> &Expression { + &self.expression } /// The relational operator governing the constraint. pub fn op(&self) -> RelationalOperator { - self.0.op + self.op } /// The strength of the constraint that the solver will use. pub fn strength(&self) -> f64 { - self.0.strength + self.strength.into_inner() } -} - -impl ::std::hash::Hash for Constraint { - fn hash(&self, hasher: &mut H) { - use ::std::ops::Deref; - hasher.write_usize(self.0.deref() as *const _ as usize); - } -} - -impl PartialEq for Constraint { - fn eq(&self, other: &Constraint) -> bool { - use ::std::ops::Deref; - self.0.deref() as *const _ == other.0.deref() as *const _ + /// Set the strength in builder-style + pub fn with_strength(self, s:f64) -> Self { + let mut c = self; + c.strength = s.into(); + c } } -impl Eq for Constraint {} - /// This is part of the syntactic sugar used for specifying constraints. This enum should be used as part of a /// constraint expression. See the module documentation for more information. +#[derive(Debug)] pub enum WeightedRelation { /// `==` EQ(f64), @@ -458,9 +414,11 @@ impl From for (RelationalOperator, f64) { /// This is an intermediate type used in the syntactic sugar for specifying constraints. You should not use it /// directly. -pub struct PartialConstraint(Expression, WeightedRelation); +#[derive(Debug)] +pub struct PartialConstraint(pub Expression, pub WeightedRelation); #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug)] enum SymbolType { Invalid, External, @@ -470,17 +428,59 @@ enum SymbolType { } #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug)] struct Symbol(usize, SymbolType); impl Symbol { + /// Choose the subject for solving for the row. + /// + /// This method will choose the best subject for using as the solve + /// target for the row. An invalid symbol will be returned if there + /// is no valid target. + /// + /// The symbols are chosen according to the following precedence: + /// + /// 1) The first symbol representing an external variable. + /// 2) A negative slack or error tag variable. + /// + /// If a subject cannot be found, an invalid symbol will be returned. + fn choose_subject(row: &Row, tag: &Tag) -> Symbol { + for s in row.cells.keys() { + if s.type_() == SymbolType::External { + return *s + } + } + if tag.marker.type_() == SymbolType::Slack || tag.marker.type_() == SymbolType::Error { + if row.coefficient_for(tag.marker) < 0.0 { + return tag.marker; + } + } + if tag.other.type_() == SymbolType::Slack || tag.other.type_() == SymbolType::Error { + if row.coefficient_for(tag.other) < 0.0 { + return tag.other; + } + } + Symbol::invalid() + } + fn invalid() -> Symbol { Symbol(0, SymbolType::Invalid) } fn type_(&self) -> SymbolType { self.1 } } + +#[derive(Copy, Clone)] +#[derive(Debug)] +struct Tag { + marker: Symbol, + other: Symbol +} + + #[derive(Clone)] +#[derive(Debug)] struct Row { - cells: HashMap, - constant: f64 + cells: HashMap>, + constant: OrderedFloat } fn near_zero(value: f64) -> bool { @@ -493,24 +493,26 @@ fn near_zero(value: f64) -> bool { } impl Row { - fn new(constant: f64) -> Row { + pub fn new(constant: f64) -> Row { Row { cells: HashMap::new(), - constant: constant + constant: constant.into() } } fn add(&mut self, v: f64) -> f64 { - self.constant += v; - self.constant + *(self.constant.as_mut()) += v; + self.constant.into_inner() } fn insert_symbol(&mut self, s: Symbol, coefficient: f64) { match self.cells.entry(s) { Entry::Vacant(entry) => if !near_zero(coefficient) { - entry.insert(coefficient); + entry.insert(coefficient.into()); }, Entry::Occupied(mut entry) => { - *entry.get_mut() += coefficient; - if near_zero(*entry.get_mut()) { + let ofloat = entry.get_mut(); + let float = ofloat.as_mut(); + *float += coefficient; + if near_zero(*float) { entry.remove(); } } @@ -518,10 +520,10 @@ impl Row { } fn insert_row(&mut self, other: &Row, coefficient: f64) -> bool { - let constant_diff = other.constant * coefficient; - self.constant += constant_diff; + let constant_diff = other.constant.as_ref() * coefficient; + *self.constant.as_mut() += constant_diff; for (s, v) in &other.cells { - self.insert_symbol(*s, v * coefficient); + self.insert_symbol(*s, v.into_inner() * coefficient); } constant_diff != 0.0 } @@ -531,20 +533,20 @@ impl Row { } fn reverse_sign(&mut self) { - self.constant = -self.constant; + *self.constant.as_mut() *= -1.0; for (_, v) in &mut self.cells { - *v = -*v; + *v.as_mut() *= -1.0; } } fn solve_for_symbol(&mut self, s: Symbol) { let coeff = -1.0 / match self.cells.entry(s) { - Entry::Occupied(entry) => entry.remove(), + Entry::Occupied(entry) => entry.remove().into_inner(), Entry::Vacant(_) => unreachable!() }; - self.constant *= coeff; + *self.constant.as_mut() *= coeff; for (_, v) in &mut self.cells { - *v *= coeff; + *v.as_mut() *= coeff; } } @@ -554,16 +556,55 @@ impl Row { } fn coefficient_for(&self, s: Symbol) -> f64 { - self.cells.get(&s).cloned().unwrap_or(0.0) + self.cells.get(&s).cloned().map(|o| o.into_inner()).unwrap_or(0.0) } fn substitute(&mut self, s: Symbol, row: &Row) -> bool { if let Some(coeff) = self.cells.remove(&s) { - self.insert_row(row, coeff) + self.insert_row(row, coeff.into()) } else { false } } + + /// Test whether a row is composed of all dummy variables. + fn all_dummies(&self) -> bool { + for symbol in self.cells.keys() { + if symbol.type_() != SymbolType::Dummy { + return false; + } + } + true + } + + /// Get the first Slack or Error symbol in the row. + /// + /// If no such symbol is present, and Invalid symbol will be returned. + /// Never returns an External symbol + fn any_pivotable_symbol(&self) -> Symbol { + for symbol in self.cells.keys() { + if symbol.type_() == SymbolType::Slack || symbol.type_() == SymbolType::Error { + return *symbol; + } + } + Symbol::invalid() + } + + /// Compute the entering variable for a pivot operation. + /// + /// This method will return first symbol in the objective function which + /// is non-dummy and has a coefficient less than zero. If no symbol meets + /// the criteria, it means the objective function is at a minimum, and an + /// invalid symbol is returned. + /// Could return an External symbol + fn get_entering_symbol(&self) -> Symbol { + for (symbol, value) in &self.cells { + if symbol.type_() != SymbolType::Dummy && *value.as_ref() < 0.0 { + return symbol.clone(); + } + } + Symbol::invalid() + } } /// The possible error conditions that `Solver::add_constraint` can fail with. @@ -621,3 +662,229 @@ pub enum SuggestValueError { struct InternalSolverError(&'static str); pub use solver_impl::Solver; + + +#[cfg(test)] +mod tests { + use super::*; + use super::strength::*; + use super::WeightedRelation::*; + //use std::cell::RefCell; + use std::collections::HashMap; + //use std::rc::Rc; + use std::ops::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + + static NEXT_K:AtomicUsize = AtomicUsize::new(0); + + #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] + pub struct Variable(usize); + derive_syntax_for!(Variable); + derive_bitor_for!(Variable); + + impl Variable { + pub fn new() -> Variable { + Variable(NEXT_K.fetch_add(1, Ordering::Relaxed)) + } + } + + #[test] + fn example() { + let mut names = HashMap::new(); + fn print_changes(names: &HashMap, changes: &[(Variable, f64)]) { + println!("Changes:"); + for &(ref var, ref val) in changes { + println!("{}: {}", names[var], val); + } + } + + let window_width = Variable::new(); + names.insert(window_width, "window_width"); + struct Element { + left: Variable, + right: Variable + } + let box1 = Element { + left: Variable::new(), + right: Variable::new() + }; + names.insert(box1.left, "box1.left"); + names.insert(box1.right, "box1.right"); + let box2 = Element { + left: Variable::new(), + right: Variable::new() + }; + names.insert(box2.left, "box2.left"); + names.insert(box2.right, "box2.right"); + let mut solver = Solver::new(); + + //solver + // .add_edit_variable(window_width, STRONG) + // .expect("Could not add window width edit var"); + //solver + // .suggest_value(window_width, 1000.0) + // .expect("Could not suggest window width = 1000"); + solver + .add_constraint(window_width |GE(REQUIRED)| 0.0) + .expect("Could not add window width >= 0"); + solver + .add_constraint(window_width |LE(REQUIRED)| 1000.0) + .expect("Could not add window width <= 1000.0"); + solver + .add_constraint(box1.left |EQ(REQUIRED)| 0.0) + .expect("Could not add left align constraint"); + solver + .add_constraint(box2.right |EQ(REQUIRED)| window_width) + .expect("Could not add right align constraint"); + solver + .add_constraint(box2.left |GE(REQUIRED)| box1.right) + .expect("Could not add no overlap constraint"); + + solver + .add_constraint(box1.right |EQ(WEAK)| box1.left + 50.0) + .expect("Could not add box1 width constraint"); + solver + .add_constraint(box2.right |EQ(WEAK)| box2.left + 100.0) + .expect("Could not add box2 width constraint"); + + solver + .add_constraint(box1.left |LE(REQUIRED)| box1.right) + .expect("Could not add box1 positive width constraint"); + solver + .add_constraint(box2.left |LE(REQUIRED)| box2.right) + .expect("Could not add box2 positive width constraint"); + + //print_changes(&names, solver.fetch_changes()); + //solver + // .suggest_value(window_width, 75.0) + // .expect("Could not suggest window width = 75"); + //print_changes(&names, solver.fetch_changes()); + //solver.add_constraint( + // (box1.right - box1.left) / 50.0 |EQ(MEDIUM)| (box2.right - box2.left) / 100.0 + //).unwrap(); + print_changes(&names, solver.fetch_changes()); + } + + // #[test] + fn _test_quadrilateral() { + struct Point { + x: Variable, + y: Variable + } + impl Point { + fn new() -> Point { + Point { + x: Variable::new(), + y: Variable::new() + } + } + } + + let points = [Point::new(), + Point::new(), + Point::new(), + Point::new()]; + let point_starts = [(10.0, 10.0), (10.0, 200.0), (200.0, 200.0), (200.0, 10.0)]; + let midpoints = [Point::new(), + Point::new(), + Point::new(), + Point::new()]; + let mut solver = Solver::new(); + let mut weight = 1.0; + let multiplier = 2.0; + for i in 0..4 { + solver + .add_constraints( + vec![points[i].x |EQ(WEAK * weight)| point_starts[i].0, + points[i].y |EQ(WEAK * weight)| point_starts[i].1] + ) + .expect("Could not add initial quad points"); + weight *= multiplier; + } + + for (start, end) in vec![(0, 1), (1, 2), (2, 3), (3, 0)] { + solver + .add_constraints( + vec![midpoints[start].x |EQ(REQUIRED)| (points[start].x + points[end].x) / 2.0, + midpoints[start].y |EQ(REQUIRED)| (points[start].y + points[end].y) / 2.0] + ) + .expect("Could not add quad midpoints"); + } + + solver + .add_constraints( + vec![points[0].x + 20.0 |LE(STRONG)| points[2].x, + points[0].x + 20.0 |LE(STRONG)| points[3].x, + + points[1].x + 20.0 |LE(STRONG)| points[2].x, + points[1].x + 20.0 |LE(STRONG)| points[3].x, + + points[0].y + 20.0 |LE(STRONG)| points[1].y, + points[0].y + 20.0 |LE(STRONG)| points[2].y, + + points[3].y + 20.0 |LE(STRONG)| points[1].y, + points[3].y + 20.0 |LE(STRONG)| points[2].y] + ) + .expect("Could not add quad midpoint constraints"); + + for point in &points { + solver + .add_constraints( + vec![point.x |GE(REQUIRED)| 0.0, + point.y |GE(REQUIRED)| 0.0, + + point.x |LE(REQUIRED)| 500.0, + point.y |LE(REQUIRED)| 500.0] + ) + .expect("Could not add required bounds on quad"); + } + + assert_eq!([(solver.get_value(midpoints[0].x), solver.get_value(midpoints[0].y)), + (solver.get_value(midpoints[1].x), solver.get_value(midpoints[1].y)), + (solver.get_value(midpoints[2].x), solver.get_value(midpoints[2].y)), + (solver.get_value(midpoints[3].x), solver.get_value(midpoints[3].y))], + [(10.0, 105.0), + (105.0, 200.0), + (200.0, 105.0), + (105.0, 10.0)]); + + solver.add_edit_variable(points[2].x, STRONG).expect("Could not add x edit variable for 2nd point"); + solver.add_edit_variable(points[2].y, STRONG).expect("Could not add y edit variable for 2nd point"); + solver.suggest_value(points[2].x, 300.0).expect("Could not suggest value for x edit variable for 2nd point"); + solver.suggest_value(points[2].y, 400.0).expect("Could not suggest value for y edit variable for 2nd point"); + + assert_eq!([(solver.get_value(points[0].x), solver.get_value(points[0].y)), + (solver.get_value(points[1].x), solver.get_value(points[1].y)), + (solver.get_value(points[2].x), solver.get_value(points[2].y)), + (solver.get_value(points[3].x), solver.get_value(points[3].y))], + [(10.0, 10.0), + (10.0, 200.0), + (300.0, 400.0), + (200.0, 10.0)]); + + assert_eq!([(solver.get_value(midpoints[0].x), solver.get_value(midpoints[0].y)), + (solver.get_value(midpoints[1].x), solver.get_value(midpoints[1].y)), + (solver.get_value(midpoints[2].x), solver.get_value(midpoints[2].y)), + (solver.get_value(midpoints[3].x), solver.get_value(midpoints[3].y))], + [(10.0, 105.0), + (155.0, 300.0), + (250.0, 205.0), + (105.0, 10.0)]); + } + + #[test] + fn can_add_and_remove_constraints() { + let mut solver = Solver::new(); + + let var = Variable(0); + + let constraint: Constraint = var | EQ(REQUIRED) | 100.0; + solver.add_constraint(constraint.clone()).unwrap(); + assert_eq!(solver.get_value(var), 100.0); + + solver.remove_constraint(&constraint).unwrap(); + solver.add_constraint(var | EQ(REQUIRED) | 0.0).unwrap(); + assert_eq!(solver.get_value(var), 0.0); + } +} diff --git a/src/operators.rs b/src/operators.rs index 159fbb1..aeb1395 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -1,669 +1,482 @@ -use ::std::ops; +use std::ops; + use { Term, - Variable, Expression, WeightedRelation, PartialConstraint, Constraint }; -// Relation +/// A trait for creating constraints using your custom variable types without +/// using the BitOr hack. +pub trait Constrainable +where + Var: Sized, + Self: Sized +{ + fn equal_to(self, x: X) -> Constraint where X: Into> + Clone; -impl ops::BitOr for f64 { - type Output = PartialConstraint; - fn bitor(self, r: WeightedRelation) -> PartialConstraint { - PartialConstraint(self.into(), r) + fn is(self, x: X) -> Constraint where X: Into> + Clone { + self.equal_to(x) } -} -impl ops::BitOr for f32 { - type Output = PartialConstraint; - fn bitor(self, r: WeightedRelation) -> PartialConstraint { - (self as f64).bitor(r) + + fn greater_than_or_equal_to(self, x: X) -> Constraint where X: Into> + Clone; + + fn is_ge(self, x: X) -> Constraint where X: Into> + Clone { + self.greater_than_or_equal_to(x) } -} -impl ops::BitOr for Variable { - type Output = PartialConstraint; - fn bitor(self, r: WeightedRelation) -> PartialConstraint { - PartialConstraint(self.into(), r) + + fn less_than_or_equal_to(self, x: X) -> Constraint where X: Into> + Clone; + + fn is_le(self, x: X) -> Constraint where X: Into> + Clone { + self.less_than_or_equal_to(x) } } -impl ops::BitOr for Term { - type Output = PartialConstraint; - fn bitor(self, r: WeightedRelation) -> PartialConstraint { + + +// WeightedRelation + +impl ops::BitOr for Term { + type Output = PartialConstraint; + fn bitor(self, r: WeightedRelation) -> PartialConstraint { PartialConstraint(self.into(), r) } } -impl ops::BitOr for Expression { - type Output = PartialConstraint; - fn bitor(self, r: WeightedRelation) -> PartialConstraint { +impl ops::BitOr for Expression { + type Output = PartialConstraint; + fn bitor(self, r: WeightedRelation) -> PartialConstraint { PartialConstraint(self.into(), r) } } -impl ops::BitOr for PartialConstraint { - type Output = Constraint; - fn bitor(self, rhs: f64) -> Constraint { +impl ops::BitOr for PartialConstraint { + type Output = Constraint; + fn bitor(self, rhs: f64) -> Constraint { let (op, s) = self.1.into(); Constraint::new(self.0 - rhs, op, s) } } -impl ops::BitOr for PartialConstraint { - type Output = Constraint; - fn bitor(self, rhs: f32) -> Constraint { +impl ops::BitOr for PartialConstraint { + type Output = Constraint; + fn bitor(self, rhs: f32) -> Constraint { self.bitor(rhs as f64) } } -impl ops::BitOr for PartialConstraint { - type Output = Constraint; - fn bitor(self, rhs: Variable) -> Constraint { - let (op, s) = self.1.into(); - Constraint::new(self.0 - rhs, op, s) - } -} -impl ops::BitOr for PartialConstraint { - type Output = Constraint; - fn bitor(self, rhs: Term) -> Constraint { + +impl ops::BitOr> for PartialConstraint { + type Output = Constraint; + fn bitor(self, rhs: Term) -> Constraint { let (op, s) = self.1.into(); Constraint::new(self.0 - rhs, op, s) } } -impl ops::BitOr for PartialConstraint { - type Output = Constraint; - fn bitor(self, rhs: Expression) -> Constraint { +impl ops::BitOr> for PartialConstraint { + type Output = Constraint; + fn bitor(self, rhs: Expression) -> Constraint { let (op, s) = self.1.into(); Constraint::new(self.0 - rhs, op, s) } } -// Variable - -impl ops::Add for Variable { - type Output = Expression; - fn add(self, v: f64) -> Expression { - Expression::new(vec![Term::new(self, 1.0)], v) - } -} - -impl ops::Add for Variable { - type Output = Expression; - fn add(self, v: f32) -> Expression { - self.add(v as f64) - } -} - -impl ops::Add for f64 { - type Output = Expression; - fn add(self, v: Variable) -> Expression { - Expression::new(vec![Term::new(v, 1.0)], self) - } -} - -impl ops::Add for f32 { - type Output = Expression; - fn add(self, v: Variable) -> Expression { - (self as f64).add(v) - } -} - -impl ops::Add for Variable { - type Output = Expression; - fn add(self, v: Variable) -> Expression { - Expression::new(vec![Term::new(self, 1.0), Term::new(v, 1.0)], 0.0) - } -} - -impl ops::Add for Variable { - type Output = Expression; - fn add(self, t: Term) -> Expression { - Expression::new(vec![Term::new(self, 1.0), t], 0.0) - } -} - -impl ops::Add for Term { - type Output = Expression; - fn add(self, v: Variable) -> Expression { - Expression::new(vec![self, Term::new(v, 1.0)], 0.0) - } -} - -impl ops::Add for Variable { - type Output = Expression; - fn add(self, mut e: Expression) -> Expression { - e.terms.push(Term::new(self, 1.0)); - e - } -} - -impl ops::Add for Expression { - type Output = Expression; - fn add(mut self, v: Variable) -> Expression { - self += v; - self - } -} - -impl ops::AddAssign for Expression { - fn add_assign(&mut self, v: Variable) { - self.terms.push(Term::new(v, 1.0)); - } -} - -impl ops::Neg for Variable { - type Output = Term; - fn neg(self) -> Term { - Term::new(self, -1.0) - } -} - -impl ops::Sub for Variable { - type Output = Expression; - fn sub(self, v: f64) -> Expression { - Expression::new(vec![Term::new(self, 1.0)], -v) - } -} - -impl ops::Sub for Variable { - type Output = Expression; - fn sub(self, v: f32) -> Expression { - self.sub(v as f64) - } -} - -impl ops::Sub for f64 { - type Output = Expression; - fn sub(self, v: Variable) -> Expression { - Expression::new(vec![Term::new(v, -1.0)], self) - } -} - -impl ops::Sub for f32 { - type Output = Expression; - fn sub(self, v: Variable) -> Expression { - (self as f64).sub(v) - } -} - -impl ops::Sub for Variable { - type Output = Expression; - fn sub(self, v: Variable) -> Expression { - Expression::new(vec![Term::new(self, 1.0), Term::new(v, -1.0)], 0.0) - } -} - -impl ops::Sub for Variable { - type Output = Expression; - fn sub(self, t: Term) -> Expression { - Expression::new(vec![Term::new(self, 1.0), -t], 0.0) - } -} - -impl ops::Sub for Term { - type Output = Expression; - fn sub(self, v: Variable) -> Expression { - Expression::new(vec![self, Term::new(v, -1.0)], 0.0) - } -} - -impl ops::Sub for Variable { - type Output = Expression; - fn sub(self, mut e: Expression) -> Expression { - e.negate(); - e.terms.push(Term::new(self, 1.0)); - e - } -} - -impl ops::Sub for Expression { - type Output = Expression; - fn sub(mut self, v: Variable) -> Expression { - self -= v; - self - } -} - -impl ops::SubAssign for Expression { - fn sub_assign(&mut self, v: Variable) { - self.terms.push(Term::new(v, -1.0)); - } -} - -impl ops::Mul for Variable { - type Output = Term; - fn mul(self, v: f64) -> Term { - Term::new(self, v) - } -} - -impl ops::Mul for Variable { - type Output = Term; - fn mul(self, v: f32) -> Term { - self.mul(v as f64) - } -} - -impl ops::Mul for f64 { - type Output = Term; - fn mul(self, v: Variable) -> Term { - Term::new(v, self) - } -} - -impl ops::Mul for f32 { - type Output = Term; - fn mul(self, v: Variable) -> Term { - (self as f64).mul(v) - } -} - -impl ops::Div for Variable { - type Output = Term; - fn div(self, v: f64) -> Term { - Term::new(self, 1.0 / v) - } -} - -impl ops::Div for Variable { - type Output = Term; - fn div(self, v: f32) -> Term { - self.div(v as f64) - } -} - // Term -impl ops::Mul for Term { - type Output = Term; - fn mul(mut self, v: f64) -> Term { +impl ops::Mul for Term { + type Output = Term; + fn mul(mut self, v: f64) -> Term { self *= v; self } } -impl ops::MulAssign for Term { +impl ops::MulAssign for Term { fn mul_assign(&mut self, v: f64) { - self.coefficient *= v; + *(self.coefficient.as_mut()) *= v; } } -impl ops::Mul for Term { - type Output = Term; - fn mul(self, v: f32) -> Term { +impl ops::Mul for Term { + type Output = Term; + fn mul(self, v: f32) -> Term { self.mul(v as f64) } } -impl ops::MulAssign for Term { +impl ops::MulAssign for Term { fn mul_assign(&mut self, v: f32) { self.mul_assign(v as f64) } } -impl ops::Mul for f64 { - type Output = Term; - fn mul(self, mut t: Term) -> Term { - t.coefficient *= self; +impl ops::Mul> for f64 { + type Output = Term; + fn mul(self, mut t: Term) -> Term { + *(t.coefficient.as_mut()) *= self; t } } -impl ops::Mul for f32 { - type Output = Term; - fn mul(self, t: Term) -> Term { +impl ops::Mul> for f32 { + type Output = Term; + fn mul(self, t: Term) -> Term { (self as f64).mul(t) } } -impl ops::Div for Term { - type Output = Term; - fn div(mut self, v: f64) -> Term { +impl ops::Div for Term { + type Output = Term; + fn div(mut self, v: f64) -> Term { self /= v; self } } -impl ops::DivAssign for Term { +impl ops::DivAssign for Term { fn div_assign(&mut self, v: f64) { - self.coefficient /= v; + *(self.coefficient.as_mut()) /= v; } } -impl ops::Div for Term { - type Output = Term; - fn div(self, v: f32) -> Term { +impl ops::Div for Term { + type Output = Term; + fn div(self, v: f32) -> Term { self.div(v as f64) } } -impl ops::DivAssign for Term { +impl ops::DivAssign for Term { fn div_assign(&mut self, v: f32) { self.div_assign(v as f64) } } -impl ops::Add for Term { - type Output = Expression; - fn add(self, v: f64) -> Expression { +impl ops::Add for Term { + type Output = Expression; + fn add(self, v: f64) -> Expression { Expression::new(vec![self], v) } } -impl ops::Add for Term { - type Output = Expression; - fn add(self, v: f32) -> Expression { +impl ops::Add for Term { + type Output = Expression; + fn add(self, v: f32) -> Expression { self.add(v as f64) } } -impl ops::Add for f64 { - type Output = Expression; - fn add(self, t: Term) -> Expression { +impl ops::Add> for f64 { + type Output = Expression; + fn add(self, t: Term) -> Expression { Expression::new(vec![t], self) } } -impl ops::Add for f32 { - type Output = Expression; - fn add(self, t: Term) -> Expression { +impl ops::Add> for f32 { + type Output = Expression; + fn add(self, t: Term) -> Expression { (self as f64).add(t) } } -impl ops::Add for Term { - type Output = Expression; - fn add(self, t: Term) -> Expression { +impl ops::Add> for Term { + type Output = Expression; + fn add(self, t: Term) -> Expression { Expression::new(vec![self, t], 0.0) } } -impl ops::Add for Term { - type Output = Expression; - fn add(self, mut e: Expression) -> Expression { +impl ops::Add> for Term { + type Output = Expression; + fn add(self, mut e: Expression) -> Expression { e.terms.push(self); e } } -impl ops::Add for Expression { - type Output = Expression; - fn add(mut self, t: Term) -> Expression { +impl ops::Add> for Expression { + type Output = Expression; + fn add(mut self, t: Term) -> Expression { self += t; self } } -impl ops::AddAssign for Expression { - fn add_assign(&mut self, t: Term) { +impl ops::AddAssign> for Expression { + fn add_assign(&mut self, t: Term) { self.terms.push(t); } } -impl ops::Neg for Term { - type Output = Term; - fn neg(mut self) -> Term { - self.coefficient = -self.coefficient; +impl ops::Neg for Term { + type Output = Term; + fn neg(mut self) -> Term { + *(self.coefficient.as_mut()) = -(self.coefficient.into_inner()); self } } -impl ops::Sub for Term { - type Output = Expression; - fn sub(self, v: f64) -> Expression { +impl ops::Sub for Term { + type Output = Expression; + fn sub(self, v: f64) -> Expression { Expression::new(vec![self], -v) } } -impl ops::Sub for Term { - type Output = Expression; - fn sub(self, v: f32) -> Expression { +impl ops::Sub for Term { + type Output = Expression; + fn sub(self, v: f32) -> Expression { self.sub(v as f64) } } -impl ops::Sub for f64 { - type Output = Expression; - fn sub(self, t: Term) -> Expression { +impl ops::Sub> for f64 { + type Output = Expression; + fn sub(self, t: Term) -> Expression { Expression::new(vec![-t], self) } } -impl ops::Sub for f32 { - type Output = Expression; - fn sub(self, t: Term) -> Expression { +impl ops::Sub> for f32 { + type Output = Expression; + fn sub(self, t: Term) -> Expression { (self as f64).sub(t) } } -impl ops::Sub for Term { - type Output = Expression; - fn sub(self, t: Term) -> Expression { +impl ops::Sub> for Term { + type Output = Expression; + fn sub(self, t: Term) -> Expression { Expression::new(vec![self, -t], 0.0) } } -impl ops::Sub for Term { - type Output = Expression; - fn sub(self, mut e: Expression) -> Expression { +impl ops::Sub> for Term { + type Output = Expression; + fn sub(self, mut e: Expression) -> Expression { e.negate(); e.terms.push(self); e } } -impl ops::Sub for Expression { - type Output = Expression; - fn sub(mut self, t: Term) -> Expression { +impl ops::Sub> for Expression { + type Output = Expression; + fn sub(mut self, t: Term) -> Expression { self -= t; self } } -impl ops::SubAssign for Expression { - fn sub_assign(&mut self, t: Term) { +impl ops::SubAssign> for Expression { + fn sub_assign(&mut self, t: Term) { self.terms.push(-t); } } // Expression -impl ops::Mul for Expression { - type Output = Expression; - fn mul(mut self, v: f64) -> Expression { - self *= v; +impl ops::Mul for Expression { + type Output = Expression; + fn mul(mut self, v: f64) -> Expression { + self *= v.clone(); self } } -impl ops::MulAssign for Expression { +impl ops::MulAssign for Expression { fn mul_assign(&mut self, v: f64) { - self.constant *= v; + *(self.constant.as_mut()) *= v; for t in &mut self.terms { - *t = *t * v; + *t = t.clone() * v; } } } -impl ops::Mul for Expression { - type Output = Expression; - fn mul(self, v: f32) -> Expression { +impl ops::Mul for Expression { + type Output = Expression; + fn mul(self, v: f32) -> Expression { self.mul(v as f64) } } -impl ops::MulAssign for Expression { +impl ops::MulAssign for Expression { fn mul_assign(&mut self, v: f32) { - self.mul_assign(v as f64) + let v2 = v as f64; + *(self.constant.as_mut()) *= v2; + for t in &mut self.terms { + *t = t.clone() * v2; + } } } -impl ops::Mul for f64 { - type Output = Expression; - fn mul(self, mut e: Expression) -> Expression { - e.constant *= self; +impl ops::Mul> for f64 { + type Output = Expression; + fn mul(self, mut e: Expression) -> Expression { + *(e.constant.as_mut()) *= self; for t in &mut e.terms { - *t = *t * self; + *t = t.clone() * self; } e } } -impl ops::Mul for f32 { - type Output = Expression; - fn mul(self, e: Expression) -> Expression { +impl ops::Mul> for f32 { + type Output = Expression; + fn mul(self, e: Expression) -> Expression { (self as f64).mul(e) } } -impl ops::Div for Expression { - type Output = Expression; - fn div(mut self, v: f64) -> Expression { +impl ops::Div for Expression { + type Output = Expression; + fn div(mut self, v: f64) -> Expression { self /= v; self } } -impl ops::DivAssign for Expression { +impl ops::DivAssign for Expression { fn div_assign(&mut self, v: f64) { - self.constant /= v; + *(self.constant.as_mut()) /= v; for t in &mut self.terms { - *t = *t / v; + *t = t.clone() / v; } } } -impl ops::Div for Expression { - type Output = Expression; - fn div(self, v: f32) -> Expression { +impl ops::Div for Expression { + type Output = Expression; + fn div(self, v: f32) -> Expression { self.div(v as f64) } } -impl ops::DivAssign for Expression { +impl ops::DivAssign for Expression { fn div_assign(&mut self, v: f32) { self.div_assign(v as f64) } } -impl ops::Add for Expression { - type Output = Expression; - fn add(mut self, v: f64) -> Expression { +impl ops::Add for Expression { + type Output = Expression; + fn add(mut self, v: f64) -> Expression { self += v; self } } -impl ops::AddAssign for Expression { +impl ops::AddAssign for Expression { fn add_assign(&mut self, v: f64) { - self.constant += v; + *(self.constant.as_mut()) += v; } } -impl ops::Add for Expression { - type Output = Expression; - fn add(self, v: f32) -> Expression { +impl ops::Add for Expression { + type Output = Expression; + fn add(self, v: f32) -> Expression { self.add(v as f64) } } -impl ops::AddAssign for Expression { +impl ops::AddAssign for Expression { fn add_assign(&mut self, v: f32) { self.add_assign(v as f64) } } -impl ops::Add for f64 { - type Output = Expression; - fn add(self, mut e: Expression) -> Expression { - e.constant += self; +impl ops::Add> for f64 { + type Output = Expression; + fn add(self, mut e: Expression) -> Expression { + *(e.constant.as_mut()) += self; e } } -impl ops::Add for f32 { - type Output = Expression; - fn add(self, e: Expression) -> Expression { +impl ops::Add> for f32 { + type Output = Expression; + fn add(self, e: Expression) -> Expression { (self as f64).add(e) } } -impl ops::Add for Expression { - type Output = Expression; - fn add(mut self, e: Expression) -> Expression { +impl ops::Add> for Expression { + type Output = Expression; + fn add(mut self, e: Expression) -> Expression { self += e; self } } -impl ops::AddAssign for Expression { - fn add_assign(&mut self, mut e: Expression) { +impl ops::AddAssign> for Expression { + fn add_assign(&mut self, mut e: Expression) { self.terms.append(&mut e.terms); - self.constant += e.constant; + *(self.constant.as_mut()) += e.constant.into_inner(); } } -impl ops::Neg for Expression { - type Output = Expression; - fn neg(mut self) -> Expression { +impl ops::Neg for Expression { + type Output = Expression; + fn neg(mut self) -> Expression { self.negate(); self } } -impl ops::Sub for Expression { - type Output = Expression; - fn sub(mut self, v: f64) -> Expression { +impl ops::Sub for Expression { + type Output = Expression; + fn sub(mut self, v: f64) -> Expression { self -= v; self } } -impl ops::SubAssign for Expression { +impl ops::SubAssign for Expression { fn sub_assign(&mut self, v: f64) { - self.constant -= v; + *(self.constant.as_mut()) -= v; } } -impl ops::Sub for Expression { - type Output = Expression; - fn sub(self, v: f32) -> Expression { +impl ops::Sub for Expression { + type Output = Expression; + fn sub(self, v: f32) -> Expression { self.sub(v as f64) } } -impl ops::SubAssign for Expression { +impl ops::SubAssign for Expression { fn sub_assign(&mut self, v: f32) { self.sub_assign(v as f64) } } -impl ops::Sub for f64 { - type Output = Expression; - fn sub(self, mut e: Expression) -> Expression { +impl ops::Sub> for f64 { + type Output = Expression; + fn sub(self, mut e: Expression) -> Expression { e.negate(); - e.constant += self; + *(e.constant.as_mut()) += self; e } } -impl ops::Sub for f32 { - type Output = Expression; - fn sub(self, e: Expression) -> Expression { +impl ops::Sub> for f32 { + type Output = Expression; + fn sub(self, e: Expression) -> Expression { (self as f64).sub(e) } } -impl ops::Sub for Expression { - type Output = Expression; - fn sub(mut self, e: Expression) -> Expression { +impl ops::Sub> for Expression { + type Output = Expression; + fn sub(mut self, e: Expression) -> Expression { self -= e; self } } -impl ops::SubAssign for Expression { - fn sub_assign(&mut self, mut e: Expression) { +impl ops::SubAssign> for Expression { + fn sub_assign(&mut self, mut e: Expression) { e.negate(); self.terms.append(&mut e.terms); - self.constant += e.constant; + *(self.constant.as_mut()) += e.constant.into_inner(); } } diff --git a/src/solver_impl.rs b/src/solver_impl.rs index 7c5948e..8f2a7fb 100644 --- a/src/solver_impl.rs +++ b/src/solver_impl.rs @@ -1,8 +1,8 @@ use { Symbol, + Tag, SymbolType, Constraint, - Variable, Expression, Term, Row, @@ -16,43 +16,52 @@ use { near_zero }; -use ::std::rc::Rc; -use ::std::cell::RefCell; -use ::std::collections::{ HashMap, HashSet }; -use ::std::collections::hash_map::Entry; +use std::collections::{ HashMap, HashSet }; +use std::hash::Hash; +use std::fmt::Debug; +use std::collections::hash_map::Entry; -#[derive(Copy, Clone)] -struct Tag { - marker: Symbol, - other: Symbol + +fn print_inc(msg: &str) { + println!(">{}", msg); +} + +fn print_dec(msg: &str) { + println!("<{}", msg); } + #[derive(Clone)] -struct EditInfo { +#[derive(Debug)] +struct EditInfo { tag: Tag, - constraint: Constraint, + constraint: Constraint, constant: f64 } /// A constraint solver using the Cassowary algorithm. For proper usage please see the top level crate documentation. -pub struct Solver { - cns: HashMap, - var_data: HashMap, - var_for_symbol: HashMap, - public_changes: Vec<(Variable, f64)>, - changed: HashSet, +#[derive(Debug)] +pub struct Solver +where + T: Debug + Clone + Eq + Hash +{ + cns: HashMap, Tag>, + var_data: HashMap, + var_for_symbol: HashMap, + public_changes: Vec<(T, f64)>, + changed: HashSet, should_clear_changes: bool, - rows: HashMap>, - edits: HashMap, + rows: HashMap, + edits: HashMap>, infeasible_rows: Vec, // never contains external symbols - objective: Rc>, - artificial: Option>>, + objective: Option, id_tick: usize } -impl Solver { +impl Solver +{ /// Construct a new solver. - pub fn new() -> Solver { + pub fn new() -> Solver { Solver { cns: HashMap::new(), var_data: HashMap::new(), @@ -63,24 +72,23 @@ impl Solver { rows: HashMap::new(), edits: HashMap::new(), infeasible_rows: Vec::new(), - objective: Rc::new(RefCell::new(Row::new(0.0))), - artificial: None, + objective: Some(Row::new(0.0)), id_tick: 1 } } - pub fn add_constraints<'a, I: IntoIterator>( + pub fn add_constraints>>( &mut self, constraints: I) -> Result<(), AddConstraintError> { for constraint in constraints { - try!(self.add_constraint(constraint.clone())); + self.add_constraint(constraint)?; } Ok(()) } /// Add a constraint to the solver. - pub fn add_constraint(&mut self, constraint: Constraint) -> Result<(), AddConstraintError> { + pub fn add_constraint(&mut self, constraint: Constraint) -> Result<(), AddConstraintError> { if self.cns.contains_key(&constraint) { return Err(AddConstraintError::DuplicateConstraint); } @@ -92,7 +100,7 @@ impl Solver { // constraints and since exceptional conditions are uncommon, // i'm not too worried about aggressive cleanup of the var map. let (mut row, tag) = self.create_row(&constraint); - let mut subject = Solver::choose_subject(&row, &tag); + let mut subject = Symbol::choose_subject(&row, &tag); // If chooseSubject could find a valid entering symbol, one // last option is available if the entire row is composed of @@ -100,8 +108,8 @@ impl Solver { // this represents redundant constraints and the new dummy // marker can enter the basis. If the constant is non-zero, // then it represents an unsatisfiable constraint. - if subject.type_() == SymbolType::Invalid && Solver::all_dummies(&row) { - if !near_zero(row.constant) { + if subject.type_() == SymbolType::Invalid && row.all_dummies() { + if !near_zero(*row.constant.as_ref()) { return Err(AddConstraintError::UnsatisfiableConstraint); } else { subject = tag.marker; @@ -119,8 +127,8 @@ impl Solver { } else { row.solve_for_symbol(subject); self.substitute(subject, &row); - if subject.type_() == SymbolType::External && row.constant != 0.0 { - let v = self.var_for_symbol[&subject]; + if subject.type_() == SymbolType::External && *row.constant.as_ref() != 0.0 { + let v:T = self.var_for_symbol[&subject].clone(); self.var_changed(v); } self.rows.insert(subject, row); @@ -131,14 +139,15 @@ impl Solver { // Optimizing after each constraint is added performs less // aggregate work due to a smaller average system size. It // also ensures the solver remains in a consistent state. - let objective = self.objective.clone(); - try!(self.optimise(&objective).map_err(|e| AddConstraintError::InternalSolverError(e.0))); + let mut objective = self.objective.take().expect("Could not take objective in add_constraint"); + self.optimise(&mut objective).map_err(|e| AddConstraintError::InternalSolverError(e.0))?; + self.objective = Some(objective); Ok(()) } /// Remove a constraint from the solver. - pub fn remove_constraint(&mut self, constraint: &Constraint) -> Result<(), RemoveConstraintError> { - let tag = try!(self.cns.remove(constraint).ok_or(RemoveConstraintError::UnknownConstraint)); + pub fn remove_constraint(&mut self, constraint: &Constraint) -> Result<(), RemoveConstraintError> { + let tag = self.cns.remove(constraint).ok_or(RemoveConstraintError::UnknownConstraint)?; // Remove the error effects from the objective function // *before* pivoting, or substitutions into the objective @@ -148,10 +157,13 @@ impl Solver { // If the marker is basic, simply drop the row. Otherwise, // pivot the marker into the basis and then drop the row. if let None = self.rows.remove(&tag.marker) { - let (leaving, mut row) = try!(self.get_marker_leaving_row(tag.marker) - .ok_or( - RemoveConstraintError::InternalSolverError( - "Failed to find leaving row."))); + let (leaving, mut row) = + self.get_marker_leaving_row(tag.marker) + .ok_or( + RemoveConstraintError::InternalSolverError( + "Failed to find leaving row." + ) + )?; row.solve_for_symbols(leaving, tag.marker); self.substitute(tag.marker, &row); } @@ -159,13 +171,14 @@ impl Solver { // Optimizing after each constraint is removed ensures that the // solver remains consistent. It makes the solver api easier to // use at a small tradeoff for speed. - let objective = self.objective.clone(); - try!(self.optimise(&objective).map_err(|e| RemoveConstraintError::InternalSolverError(e.0))); + let mut objective = self.objective.take().expect("Could not take objective in remove_constraint"); + self.optimise(&mut objective).map_err(|e| RemoveConstraintError::InternalSolverError(e.0))?; + self.objective = Some(objective); // Check for and decrease the reference count for variables referenced by the constraint // If the reference count is zero remove the variable from the variable map for term in &constraint.expr().terms { - if !near_zero(term.coefficient) { + if !near_zero(term.coefficient.into_inner()) { let mut should_remove = false; if let Some(&mut (_, _, ref mut count)) = self.var_data.get_mut(&term.variable) { *count -= 1; @@ -181,7 +194,7 @@ impl Solver { } /// Test whether a constraint has been added to the solver. - pub fn has_constraint(&self, constraint: &Constraint) -> bool { + pub fn has_constraint(&self, constraint: &Constraint) -> bool { self.cns.contains_key(constraint) } @@ -189,7 +202,7 @@ impl Solver { /// /// This method should be called before the `suggest_value` method is /// used to supply a suggested value for the given edit variable. - pub fn add_edit_variable(&mut self, v: Variable, strength: f64) -> Result<(), AddEditVariableError> { + pub fn add_edit_variable(&mut self, v: T, strength: f64) -> Result<(), AddEditVariableError> { if self.edits.contains_key(&v) { return Err(AddEditVariableError::DuplicateEditVariable); } @@ -200,7 +213,9 @@ impl Solver { let cn = Constraint::new(Expression::from_term(Term::new(v.clone(), 1.0)), RelationalOperator::Equal, strength); - self.add_constraint(cn.clone()).unwrap(); + self + .add_constraint(cn.clone()) + .expect("Could not add constraint in add_edit_variable"); self.edits.insert(v.clone(), EditInfo { tag: self.cns[&cn].clone(), constraint: cn, @@ -210,7 +225,7 @@ impl Solver { } /// Remove an edit variable from the solver. - pub fn remove_edit_variable(&mut self, v: Variable) -> Result<(), RemoveEditVariableError> { + pub fn remove_edit_variable(&mut self, v: T) -> Result<(), RemoveEditVariableError> { if let Some(constraint) = self.edits.remove(&v).map(|e| e.constraint) { try!(self.remove_constraint(&constraint) .map_err(|e| match e { @@ -226,7 +241,7 @@ impl Solver { } /// Test whether an edit variable has been added to the solver. - pub fn has_edit_variable(&self, v: &Variable) -> bool { + pub fn has_edit_variable(&self, v: &T) -> bool { self.edits.contains_key(v) } @@ -234,9 +249,9 @@ impl Solver { /// /// This method should be used after an edit variable has been added to /// the solver in order to suggest the value for that variable. - pub fn suggest_value(&mut self, variable: Variable, value: f64) -> Result<(), SuggestValueError> { + pub fn suggest_value(&mut self, variable: T, value: f64) -> Result<(), SuggestValueError> { let (info_tag_marker, info_tag_other, delta) = { - let info = try!(self.edits.get_mut(&variable).ok_or(SuggestValueError::UnknownEditVariable)); + let info = self.edits.get_mut(&variable).ok_or(SuggestValueError::UnknownEditVariable)?; let delta = value - info.constant; info.constant = value; (info.tag.marker, info.tag.other, delta) @@ -266,7 +281,7 @@ impl Solver { let coeff = row.coefficient_for(info_tag_marker); let diff = delta * coeff; if diff != 0.0 && symbol.type_() == SymbolType::External { - let v = self.var_for_symbol[symbol]; + let v = self.var_for_symbol[symbol].clone(); // inline var_changed - borrow checker workaround if self.should_clear_changes { self.changed.clear(); @@ -283,11 +298,11 @@ impl Solver { } } } - try!(self.dual_optimise().map_err(|e| SuggestValueError::InternalSolverError(e.0))); + self.dual_optimise().map_err(|e| SuggestValueError::InternalSolverError(e.0))?; return Ok(()); } - fn var_changed(&mut self, v: Variable) { + fn var_changed(&mut self, v: T) { if self.should_clear_changes { self.changed.clear(); self.should_clear_changes = false; @@ -299,7 +314,7 @@ impl Solver { /// /// The list of changes returned is not in a specific order. Each change comprises the variable changed and /// the new value of that variable. - pub fn fetch_changes(&mut self) -> &[(Variable, f64)] { + pub fn fetch_changes(&mut self) -> &[(T, f64)] { if self.should_clear_changes { self.changed.clear(); self.should_clear_changes = false; @@ -307,12 +322,12 @@ impl Solver { self.should_clear_changes = true; } self.public_changes.clear(); - for &v in &self.changed { + for v in &self.changed { if let Some(var_data) = self.var_data.get_mut(&v) { - let new_value = self.rows.get(&var_data.1).map(|r| r.constant).unwrap_or(0.0); + let new_value = self.rows.get(&var_data.1).map(|r| r.constant).map(|o| o.into_inner()).unwrap_or(0.0); let old_value = var_data.0; if old_value != new_value { - self.public_changes.push((v, new_value)); + self.public_changes.push((v.clone(), new_value)); var_data.0 = new_value; } } @@ -336,22 +351,21 @@ impl Solver { self.should_clear_changes = false; self.edits.clear(); self.infeasible_rows.clear(); - *self.objective.borrow_mut() = Row::new(0.0); - self.artificial = None; + self.objective = Some(Row::new(0.0)); self.id_tick = 1; } /// Get the symbol for the given variable. /// /// If a symbol does not exist for the variable, one will be created. - fn get_var_symbol(&mut self, v: Variable) -> Symbol { + fn get_var_symbol(&mut self, v: T) -> Symbol { let id_tick = &mut self.id_tick; let var_for_symbol = &mut self.var_for_symbol; - let value = self.var_data.entry(v).or_insert_with(|| { + let value = self.var_data.entry(v.clone()).or_insert_with(|| { let s = Symbol(*id_tick, SymbolType::External); var_for_symbol.insert(s, v); *id_tick += 1; - (::std::f64::NAN, s, 0) + (std::f64::NAN, s, 0) }); value.2 += 1; value.1 @@ -372,23 +386,22 @@ impl Solver { /// /// The tag will be updated with the marker and error symbols to use /// for tracking the movement of the constraint in the tableau. - fn create_row(&mut self, constraint: &Constraint) -> (Box, Tag) { + fn create_row(&mut self, constraint: &Constraint) -> (Row, Tag) { let expr = constraint.expr(); - let mut row = Row::new(expr.constant); + let mut row = Row::new(expr.constant.into_inner()); // Substitute the current basic variables into the row. for term in &expr.terms { - if !near_zero(term.coefficient) { - let symbol = self.get_var_symbol(term.variable); + if !near_zero(term.coefficient.into_inner()) { + let symbol = self.get_var_symbol(term.variable.clone()); if let Some(other_row) = self.rows.get(&symbol) { - row.insert_row(other_row, term.coefficient); + row.insert_row(other_row, term.coefficient.into_inner()); } else { - row.insert_symbol(symbol, term.coefficient); + row.insert_symbol(symbol, term.coefficient.into_inner()); } } } - let mut objective = self.objective.borrow_mut(); - + let mut objective = self.objective.take().expect("Could not take objective in create_row"); // Add the necessary slack, error, and dummy variables. let tag = match constraint.op() { RelationalOperator::GreaterOrEqual | @@ -442,43 +455,15 @@ impl Solver { } } }; + self.objective = Some(objective); // Ensure the row has a positive constant. - if row.constant < 0.0 { + if *row.constant.as_ref() < 0.0 { row.reverse_sign(); } - (Box::new(row), tag) - } - /// Choose the subject for solving for the row. - /// - /// This method will choose the best subject for using as the solve - /// target for the row. An invalid symbol will be returned if there - /// is no valid target. - /// - /// The symbols are chosen according to the following precedence: - /// - /// 1) The first symbol representing an external variable. - /// 2) A negative slack or error tag variable. - /// - /// If a subject cannot be found, an invalid symbol will be returned. - fn choose_subject(row: &Row, tag: &Tag) -> Symbol { - for s in row.cells.keys() { - if s.type_() == SymbolType::External { - return *s - } - } - if tag.marker.type_() == SymbolType::Slack || tag.marker.type_() == SymbolType::Error { - if row.coefficient_for(tag.marker) < 0.0 { - return tag.marker; - } - } - if tag.other.type_() == SymbolType::Slack || tag.other.type_() == SymbolType::Error { - if row.coefficient_for(tag.other) < 0.0 { - return tag.other; - } - } - Symbol::invalid() + print_dec("objective:ref mut out"); + (row, tag) } /// Add the row to the tableau using an artificial variable. @@ -488,15 +473,13 @@ impl Solver { // Create and add the artificial variable to the tableau let art = Symbol(self.id_tick, SymbolType::Slack); self.id_tick += 1; - self.rows.insert(art, Box::new(row.clone())); - self.artificial = Some(Rc::new(RefCell::new(row.clone()))); + self.rows.insert(art, row.clone()); // Optimize the artificial objective. This is successful // only if the artificial objective is optimized to zero. - let artificial = self.artificial.as_ref().unwrap().clone(); - try!(self.optimise(&artificial)); - let success = near_zero(artificial.borrow().constant); - self.artificial = None; + let mut artificial:Row = row.clone(); + self.optimise(&mut artificial)?; + let success = near_zero(*artificial.constant.as_ref()); // If the artificial variable is basic, pivot the row so that // it becomes basic. If the row is constant, exit early. @@ -504,7 +487,7 @@ impl Solver { if row.cells.is_empty() { return Ok(success); } - let entering = Solver::any_pivotable_symbol(&row); // never External + let entering = row.any_pivotable_symbol(); // never External if entering.type_() == SymbolType::Invalid { return Ok(false); // unsatisfiable (will this ever happen?) } @@ -517,7 +500,10 @@ impl Solver { for (_, row) in &mut self.rows { row.remove(art); } - self.objective.borrow_mut().remove(art); + self.objective + .as_mut() + .expect("Could not mutate objective in add_with_artificial_variable") + .remove(art); Ok(success) } @@ -529,45 +515,47 @@ impl Solver { for (&other_symbol, other_row) in &mut self.rows { let constant_changed = other_row.substitute(symbol, row); if other_symbol.type_() == SymbolType::External && constant_changed { - let v = self.var_for_symbol[&other_symbol]; // inline var_changed if self.should_clear_changes { self.changed.clear(); self.should_clear_changes = false; } + let v = self.var_for_symbol[&other_symbol].clone(); self.changed.insert(v); } - if other_symbol.type_() != SymbolType::External && other_row.constant < 0.0 { + if other_symbol.type_() != SymbolType::External && *other_row.constant.as_ref() < 0.0 { self.infeasible_rows.push(other_symbol); } } - self.objective.borrow_mut().substitute(symbol, row); - if let Some(artificial) = self.artificial.as_ref() { - artificial.borrow_mut().substitute(symbol, row); - } } /// Optimize the system for the given objective function. /// /// This method performs iterations of Phase 2 of the simplex method /// until the objective function reaches a minimum. - fn optimise(&mut self, objective: &RefCell) -> Result<(), InternalSolverError> { - loop { - let entering = Solver::get_entering_symbol(&objective.borrow()); + /// + /// Returns the optimized objective function. + fn optimise(&mut self, objective: &mut Row) -> Result<(), InternalSolverError> { + 'optimisation: loop { + let entering = objective.get_entering_symbol(); if entering.type_() == SymbolType::Invalid { - return Ok(()); + break 'optimisation; } - let (leaving, mut row) = try!(self.get_leaving_row(entering) - .ok_or(InternalSolverError("The objective is unbounded"))); + let (leaving, mut row) = + self + .get_leaving_row(entering) + .ok_or(InternalSolverError("The objective is unbounded"))?; // pivot the entering symbol into the basis row.solve_for_symbols(leaving, entering); self.substitute(entering, &row); - if entering.type_() == SymbolType::External && row.constant != 0.0 { - let v = self.var_for_symbol[&entering]; + objective.substitute(entering, &row); + if entering.type_() == SymbolType::External && *row.constant.as_ref() != 0.0 { + let v = self.var_for_symbol[&entering].clone(); self.var_changed(v); } self.rows.insert(entering, row); } + Ok(()) } /// Optimize the system using the dual of the simplex method. @@ -581,7 +569,7 @@ impl Solver { let leaving = self.infeasible_rows.pop().unwrap(); let row = if let Entry::Occupied(entry) = self.rows.entry(leaving) { - if entry.get().constant < 0.0 { + if *entry.get().constant.as_ref() < 0.0 { Some(entry.remove()) } else { None @@ -597,8 +585,8 @@ impl Solver { // pivot the entering symbol into the basis row.solve_for_symbols(leaving, entering); self.substitute(entering, &row); - if entering.type_() == SymbolType::External && row.constant != 0.0 { - let v = self.var_for_symbol[&entering]; + if entering.type_() == SymbolType::External && *row.constant.as_ref() != 0.0 { + let v = self.var_for_symbol[&entering].clone(); self.var_changed(v); } self.rows.insert(entering, row); @@ -607,22 +595,6 @@ impl Solver { Ok(()) } - /// Compute the entering variable for a pivot operation. - /// - /// This method will return first symbol in the objective function which - /// is non-dummy and has a coefficient less than zero. If no symbol meets - /// the criteria, it means the objective function is at a minimum, and an - /// invalid symbol is returned. - /// Could return an External symbol - fn get_entering_symbol(objective: &Row) -> Symbol { - for (symbol, value) in &objective.cells { - if symbol.type_() != SymbolType::Dummy && *value < 0.0 { - return *symbol; - } - } - Symbol::invalid() - } - /// Compute the entering symbol for the dual optimize operation. /// /// This method will return the symbol in the row which has a positive @@ -633,12 +605,14 @@ impl Solver { /// Could return an External symbol fn get_dual_entering_symbol(&self, row: &Row) -> Symbol { let mut entering = Symbol::invalid(); - let mut ratio = ::std::f64::INFINITY; - let objective = self.objective.borrow(); + let mut ratio = std::f64::INFINITY; + let objective = + self.objective.as_ref().expect("Could not get objective in get_dual_entering_symbol"); for (symbol, value) in &row.cells { - if *value > 0.0 && symbol.type_() != SymbolType::Dummy { + let value = *value.as_ref(); + if value > 0.0 && symbol.type_() != SymbolType::Dummy { let coeff = objective.coefficient_for(*symbol); - let r = coeff / *value; + let r = coeff / value; if r < ratio { ratio = r; entering = *symbol; @@ -648,19 +622,6 @@ impl Solver { entering } - /// Get the first Slack or Error symbol in the row. - /// - /// If no such symbol is present, and Invalid symbol will be returned. - /// Never returns an External symbol - fn any_pivotable_symbol(row: &Row) -> Symbol { - for symbol in row.cells.keys() { - if symbol.type_() == SymbolType::Slack || symbol.type_() == SymbolType::Error { - return *symbol; - } - } - Symbol::invalid() - } - /// Compute the row which holds the exit symbol for a pivot. /// /// This method will return an iterator to the row in the row map @@ -668,14 +629,14 @@ impl Solver { /// found, the end() iterator will be returned. This indicates that /// the objective function is unbounded. /// Never returns a row for an External symbol - fn get_leaving_row(&mut self, entering: Symbol) -> Option<(Symbol, Box)> { - let mut ratio = ::std::f64::INFINITY; + fn get_leaving_row(&mut self, entering: Symbol) -> Option<(Symbol, Row)> { + let mut ratio = std::f64::INFINITY; let mut found = None; for (symbol, row) in &self.rows { if symbol.type_() != SymbolType::External { let temp = row.coefficient_for(entering); if temp < 0.0 { - let temp_ratio = -row.constant / temp; + let temp_ratio = -row.constant.as_ref() / temp; if temp_ratio < ratio { ratio = temp_ratio; found = Some(*symbol); @@ -703,27 +664,28 @@ impl Solver { /// If the marker does not exist in any row, the row map end() iterator /// will be returned. This indicates an internal solver error since /// the marker *should* exist somewhere in the tableau. - fn get_marker_leaving_row(&mut self, marker: Symbol) -> Option<(Symbol, Box)> { - let mut r1 = ::std::f64::INFINITY; + fn get_marker_leaving_row(&mut self, marker: Symbol) -> Option<(Symbol, Row)> { + let mut r1 = std::f64::INFINITY; let mut r2 = r1; let mut first = None; let mut second = None; let mut third = None; for (symbol, row) in &self.rows { let c = row.coefficient_for(marker); + let row_constant = row.constant.as_ref(); if c == 0.0 { continue; } if symbol.type_() == SymbolType::External { third = Some(*symbol); } else if c < 0.0 { - let r = -row.constant / c; + let r = -row_constant / c; if r < r1 { r1 = r; first = Some(*symbol); } } else { - let r = row.constant / c; + let r = row_constant / c; if r < r2 { r2 = r; second = Some(*symbol); @@ -734,8 +696,8 @@ impl Solver { .or(second) .or(third) .and_then(|s| { - if s.type_() == SymbolType::External && self.rows[&s].constant != 0.0 { - let v = self.var_for_symbol[&s]; + if s.type_() == SymbolType::External && *self.rows[&s].constant.as_ref() != 0.0 { + let v = self.var_for_symbol[&s].clone(); self.var_changed(v); } self.rows @@ -745,7 +707,7 @@ impl Solver { } /// Remove the effects of a constraint on the objective function. - fn remove_constraint_effects(&mut self, cn: &Constraint, tag: &Tag) { + fn remove_constraint_effects(&mut self, cn: &Constraint, tag: &Tag) { if tag.marker.type_() == SymbolType::Error { self.remove_marker_effects(tag.marker, cn.strength()); } else if tag.other.type_() == SymbolType::Error { @@ -755,30 +717,24 @@ impl Solver { /// Remove the effects of an error marker on the objective function. fn remove_marker_effects(&mut self, marker: Symbol, strength: f64) { + print_inc("objective:in remove_marker_effects"); if let Some(row) = self.rows.get(&marker) { - self.objective.borrow_mut().insert_row(row, -strength); + self.objective.as_mut().expect("Could not get objective remove_marker_effects 1").insert_row(row, -strength); } else { - self.objective.borrow_mut().insert_symbol(marker, -strength); - } - } - - /// Test whether a row is composed of all dummy variables. - fn all_dummies(row: &Row) -> bool { - for symbol in row.cells.keys() { - if symbol.type_() != SymbolType::Dummy { - return false; - } + self.objective.as_mut().expect("Could not get objective remove_marker_effects 2").insert_symbol(marker, -strength); } - true + print_dec("objective:out remove_marker_effects"); } /// Get the stored value for a variable. /// /// Normally values should be retrieved and updated using `fetch_changes`, but /// this method can be used for debugging or testing. - pub fn get_value(&self, v: Variable) -> f64 { + pub fn get_value(&self, v: T) -> f64 { self.var_data.get(&v).and_then(|s| { self.rows.get(&s.1).map(|r| r.constant) - }).unwrap_or(0.0) + }) + .map(|o| o.into_inner()) + .unwrap_or(0.0) } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs deleted file mode 100644 index c03016a..0000000 --- a/tests/common/mod.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::collections::HashMap; -use std::cell::RefCell; -use std::rc::Rc; - -use cassowary::Variable; - -#[derive(Clone, Default)] -struct Values(Rc>>); - -impl Values { - fn value_of(&self, var: Variable) -> f64 { - *self.0.borrow().get(&var).unwrap_or(&0.0) - } - fn update_values(&self, changes: &[(Variable, f64)]) { - for &(ref var, ref value) in changes { - println!("{:?} changed to {:?}", var, value); - self.0.borrow_mut().insert(*var, *value); - } - } -} - -pub fn new_values() -> (Box f64>, Box) { - let values = Values(Rc::new(RefCell::new(HashMap::new()))); - let value_of = { - let values = values.clone(); - move |v| values.value_of(v) - }; - let update_values = { - let values = values.clone(); - move |changes: &[_]| { - values.update_values(changes); - } - }; - (Box::new(value_of), Box::new(update_values)) -} \ No newline at end of file diff --git a/tests/quadrilateral.rs b/tests/quadrilateral.rs deleted file mode 100644 index 5df91a4..0000000 --- a/tests/quadrilateral.rs +++ /dev/null @@ -1,106 +0,0 @@ -extern crate cassowary; -use cassowary::{ Solver, Variable }; -use cassowary::WeightedRelation::*; - -mod common; -use common::new_values; - -#[test] -fn test_quadrilateral() { - use cassowary::strength::{WEAK, STRONG, REQUIRED}; - struct Point { - x: Variable, - y: Variable - } - impl Point { - fn new() -> Point { - Point { - x: Variable::new(), - y: Variable::new() - } - } - } - let (value_of, update_values) = new_values(); - - let points = [Point::new(), - Point::new(), - Point::new(), - Point::new()]; - let point_starts = [(10.0, 10.0), (10.0, 200.0), (200.0, 200.0), (200.0, 10.0)]; - let midpoints = [Point::new(), - Point::new(), - Point::new(), - Point::new()]; - let mut solver = Solver::new(); - let mut weight = 1.0; - let multiplier = 2.0; - for i in 0..4 { - solver.add_constraints(&[points[i].x |EQ(WEAK * weight)| point_starts[i].0, - points[i].y |EQ(WEAK * weight)| point_starts[i].1]) - .unwrap(); - weight *= multiplier; - } - - for (start, end) in vec![(0, 1), (1, 2), (2, 3), (3, 0)] { - solver.add_constraints(&[midpoints[start].x |EQ(REQUIRED)| (points[start].x + points[end].x) / 2.0, - midpoints[start].y |EQ(REQUIRED)| (points[start].y + points[end].y) / 2.0]) - .unwrap(); - } - - solver.add_constraints(&[points[0].x + 20.0 |LE(STRONG)| points[2].x, - points[0].x + 20.0 |LE(STRONG)| points[3].x, - - points[1].x + 20.0 |LE(STRONG)| points[2].x, - points[1].x + 20.0 |LE(STRONG)| points[3].x, - - points[0].y + 20.0 |LE(STRONG)| points[1].y, - points[0].y + 20.0 |LE(STRONG)| points[2].y, - - points[3].y + 20.0 |LE(STRONG)| points[1].y, - points[3].y + 20.0 |LE(STRONG)| points[2].y]) - .unwrap(); - - for point in &points { - solver.add_constraints(&[point.x |GE(REQUIRED)| 0.0, - point.y |GE(REQUIRED)| 0.0, - - point.x |LE(REQUIRED)| 500.0, - point.y |LE(REQUIRED)| 500.0]).unwrap() - } - - update_values(solver.fetch_changes()); - - assert_eq!([(value_of(midpoints[0].x), value_of(midpoints[0].y)), - (value_of(midpoints[1].x), value_of(midpoints[1].y)), - (value_of(midpoints[2].x), value_of(midpoints[2].y)), - (value_of(midpoints[3].x), value_of(midpoints[3].y))], - [(10.0, 105.0), - (105.0, 200.0), - (200.0, 105.0), - (105.0, 10.0)]); - - solver.add_edit_variable(points[2].x, STRONG).unwrap(); - solver.add_edit_variable(points[2].y, STRONG).unwrap(); - solver.suggest_value(points[2].x, 300.0).unwrap(); - solver.suggest_value(points[2].y, 400.0).unwrap(); - - update_values(solver.fetch_changes()); - - assert_eq!([(value_of(points[0].x), value_of(points[0].y)), - (value_of(points[1].x), value_of(points[1].y)), - (value_of(points[2].x), value_of(points[2].y)), - (value_of(points[3].x), value_of(points[3].y))], - [(10.0, 10.0), - (10.0, 200.0), - (300.0, 400.0), - (200.0, 10.0)]); - - assert_eq!([(value_of(midpoints[0].x), value_of(midpoints[0].y)), - (value_of(midpoints[1].x), value_of(midpoints[1].y)), - (value_of(midpoints[2].x), value_of(midpoints[2].y)), - (value_of(midpoints[3].x), value_of(midpoints[3].y))], - [(10.0, 105.0), - (155.0, 300.0), - (250.0, 205.0), - (105.0, 10.0)]); -} diff --git a/tests/removal.rs b/tests/removal.rs deleted file mode 100644 index d9387b2..0000000 --- a/tests/removal.rs +++ /dev/null @@ -1,30 +0,0 @@ -extern crate cassowary; - -use cassowary::{Variable, Solver, Constraint}; -use cassowary::WeightedRelation::*; -use cassowary::strength::*; - -mod common; - -use common::new_values; - -#[test] -fn remove_constraint() { - let (value_of, update_values) = new_values(); - - let mut solver = Solver::new(); - - let val = Variable::new(); - - let constraint: Constraint = val | EQ(REQUIRED) | 100.0; - solver.add_constraint(constraint.clone()).unwrap(); - update_values(solver.fetch_changes()); - - assert_eq!(value_of(val), 100.0); - - solver.remove_constraint(&constraint).unwrap(); - solver.add_constraint(val | EQ(REQUIRED) | 0.0).unwrap(); - update_values(solver.fetch_changes()); - - assert_eq!(value_of(val), 0.0); -}