Skip to content

Commit

Permalink
✨ adds svm classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Jan 1, 2024
1 parent 8a2b75e commit 7a166ec
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 6 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "hog-detector"
version = "0.5.0"
version = "0.6.0"
edition = "2021"
description = "Histogram of Oriented Gradients and Object Detection"
authors = ["Christian <[email protected]>"]
Expand All @@ -24,8 +24,9 @@ name = "bin"
path = "src/main.rs"

[features]
default = ["wasm"]
default = ["wasm", "svm"]
brief = ["dep:brief-rs"]
svm = ["svm-burns"]
mnist = ["dep:mnist"]
eyes = ["reqwest", "zip"]
wasm = [
Expand Down Expand Up @@ -68,6 +69,7 @@ smartcore = { version = "0.3.0", features = [
"ndarray-bindings",
"serde",
], default-features = false, git = "https://github.com/smartcorelib/smartcore", branch = "fix-245" }
svm-burns = { git = "https://github.com/chriamue/svm-burns", optional = true}
zip = { version = "0.6.6", default-features = false, features = [
"deflate",
], optional = true }
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ You can find a [demo here](https://chriamue.github.io/hog-detector)
2. Compile the code to WASM:

```sh
wasm-pack build --target web
trunk build --release
```

3. Run the Web version in your browser

```sh
python3 -m http.server
trunk serve --release
```

Open your browser on [Localhost](http://localhost:8000)
Expand Down
Binary file added res/eyes_svm_burns_model.json
Binary file not shown.
4 changes: 4 additions & 0 deletions src/classifier/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
/// naive bayes classifier module
pub mod bayes;

#[cfg(feature = "svm")]
/// svm classifier module
pub mod svm;
pub use bayes::BayesClassifier;
pub use object_detector_rust::prelude::Classifier;
121 changes: 121 additions & 0 deletions src/classifier/svm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use std::fmt::Debug;

use ndarray::{Array1, ArrayView1, ArrayView2};
use object_detector_rust::{
classifier::Classifier, predictable::Predictable, trainable::Trainable,
window_generator::PyramidWindow,
};
use serde::{Serialize, Deserialize};
use svm_burns::{svm::SVM, Parameters, RBFKernel, SVC};

use crate::HogDetector;

/// A support vector machine classifier
#[derive(Default, Serialize, Deserialize)]
pub struct SVMClassifier {
model: Option<SVC>,
}

impl Debug for SVMClassifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SVMClassifier").finish()
}
}

impl SVMClassifier {
/// Creates a new `SVMClassifier`
pub fn new() -> Self {
SVMClassifier { model: None }
}
}

impl PartialEq for SVMClassifier {
fn eq(&self, other: &SVMClassifier) -> bool {
self.model.is_none() && other.model.is_none()
|| self.model.is_some() && other.model.is_some()
}
}

impl HogDetector<f32, usize, SVMClassifier, PyramidWindow> {
/// new default svm detector
pub fn svm() -> Self {
HogDetector::<f32, usize, SVMClassifier, PyramidWindow>::default()
}
}

impl Trainable<f32, usize> for SVMClassifier {
fn fit(&mut self, x: &ArrayView2<f32>, y: &ArrayView1<usize>) -> Result<(), String> {
let x_vec: Vec<Vec<f64>> = x
.outer_iter()
.map(|row| row.iter().map(|&elem| elem as f64).collect())
.collect();

let y_vec: Vec<i32> = y.iter().map(|&elem| elem as i32).collect();

let mut parameters = Parameters::default();
parameters.with_kernel(Box::new(RBFKernel::new(0.7)));
let mut svc = SVC::new(parameters);

svc.fit(&x_vec, &y_vec);
self.model = Some(svc);
Ok(())
}
}

impl Predictable<f32, usize> for SVMClassifier {
fn predict(&self, x: &ArrayView2<f32>) -> Result<Array1<usize>, String> {
let x_vec: Vec<Vec<f64>> = x
.outer_iter()
.map(|row| row.iter().map(|&elem| elem as f64).collect())
.collect();
let prediction = self.model.as_ref().unwrap().predict(&x_vec);
let prediction: Vec<usize> = prediction
.iter()
.map(|&x| if x > 0 { 1 } else { 0 })
.collect();
Ok(Array1::from(prediction))
}
}

impl Classifier<f32, usize> for SVMClassifier {}

#[cfg(test)]
mod tests {
use super::*;
use crate::hogdetector::HogDetectorTrait;
use image::Rgb;
use object_detector_rust::dataset::DataSet;
use object_detector_rust::detector::Detector;
use object_detector_rust::{prelude::MemoryDataSet, tests::test_image};

#[test]
fn test_default() {
let classifier = super::SVMClassifier::default();
assert!(classifier.model.is_none());
}

#[test]
fn test_partial_eq() {
let detector1 = HogDetector::<f32, usize, super::SVMClassifier, _>::default();
let detector2 = HogDetector::<f32, usize, super::SVMClassifier, _>::svm();
assert!(detector1.eq(&detector2));
}

#[test]
fn test_detector() {
let img = test_image();
let mut dataset = MemoryDataSet::new_test();
dataset.load().unwrap();
let (x, y) = dataset.get_data();
let x = x.into_iter().map(|x| x.thumbnail_exact(32, 32)).collect();
let y = y.into_iter().map(|y| y as usize).collect::<Vec<_>>();

let mut detector: HogDetector<f32, usize, super::SVMClassifier, _> = HogDetector::default();
detector.fit_class(&x, &y, 1).unwrap();
let detections = detector.detect(&img);
assert!(detections.is_empty());
let visualization = detector.visualize_detections(&img).to_rgb8();
assert_eq!(&Rgb([0, 0, 0]), visualization.get_pixel(55, 0));
assert_eq!(&Rgb([255, 0, 0]), visualization.get_pixel(75, 0));
}
}
20 changes: 18 additions & 2 deletions src/dataset/eyes_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ mod tests {
use object_detector_rust::classifier::CombinedClassifier;
use object_detector_rust::detector::PersistentDetector;
use object_detector_rust::prelude::RandomForestClassifier;
use object_detector_rust::prelude::SVMClassifier;
use std::fs::File;

#[test]
Expand Down Expand Up @@ -202,7 +201,7 @@ mod tests {
#[ignore = "takes more than 200s in debug mode"]
#[test]
fn test_train_svm_model() {
let mut model: HogDetector<f32, bool, SVMClassifier<_, _>, _> = HogDetector::default();
let mut model: HogDetector<f32, bool, object_detector_rust::prelude::SVMClassifier<_, _>, _> = HogDetector::default();

let mut dataset = EyesDataSet::default();
dataset.load().unwrap();
Expand Down Expand Up @@ -249,6 +248,23 @@ mod tests {
model.save(file_writer).unwrap();
}

#[ignore = "takes more than 200s in debug mode"]
#[test]
fn test_train_svm_burns_model() {
let mut model: HogDetector<f32, usize, crate::classifier::svm::SVMClassifier, _> =
HogDetector::default();

let mut dataset = EyesDataSet::default();
dataset.load().unwrap();
let (x, y) = dataset.get_data();
let y = y.into_iter().map(|y| y as usize).collect::<Vec<_>>();
model.fit_class(&x, &y, 1).unwrap();
assert!(model.classifier.is_some());

let file_writer = File::create("res/eyes_svm_burns_model.json").unwrap();
model.save(file_writer).unwrap();
}

#[ignore = "takes more than 200s in debug mode"]
#[test]
fn test_train_combined_model() {
Expand Down
12 changes: 12 additions & 0 deletions src/wasm/hogdetector_js.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::classifier::BayesClassifier;
use crate::classifier::svm::SVMClassifier;
use crate::detection_filter::{DetectionFilter, TrackerFilter};
use crate::detector::visualize_detections;
use crate::hogdetector::HogDetectorTrait;
Expand Down Expand Up @@ -95,6 +96,17 @@ impl HogDetectorJS {
*self.hog.lock().unwrap() = Box::new(hog);
}

pub fn init_svm_classifier(&self) {
let hog = {
let mut model: HogDetector<f32, usize, SVMClassifier, _> =
HogDetector::default();
let file = Cursor::new(include_bytes!("../../res/eyes_svm_burns_model.json"));
model.load(file).unwrap();
model
};
*self.hog.lock().unwrap() = Box::new(hog);
}

#[wasm_bindgen]
pub fn init_combined_classifier(&self) {
let hog = {
Expand Down
11 changes: 11 additions & 0 deletions src/wasm/trainer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub enum Msg {
SwitchBayesClassifier,
SwitchRandomForestClassifier,
SwitchCombinedClassifier,
SwitchSVMClassifier,
}

#[derive(Clone, PartialEq, Properties)]
Expand Down Expand Up @@ -65,6 +66,11 @@ impl Component for TrainerApp {
log::info!("Switched to Combined Classifier");
true
}
Msg::SwitchSVMClassifier => {
ctx.props().detector.init_svm_classifier();
log::info!("Switched to SVM Classifier");
true
}
}
}

Expand All @@ -74,6 +80,7 @@ impl Component for TrainerApp {
ctx.link().callback(|_| Msg::TrainWithHardNegativeSamples);
let onclick_bayes = ctx.link().callback(|_| Msg::SwitchBayesClassifier);
let onclick_random_forest = ctx.link().callback(|_| Msg::SwitchRandomForestClassifier);
let onclick_svm = ctx.link().callback(|_| Msg::SwitchSVMClassifier);
let onclick_combined = ctx.link().callback(|_| Msg::SwitchCombinedClassifier);
html! {
<div id="train-classifier-buttons">
Expand All @@ -89,6 +96,9 @@ impl Component for TrainerApp {
<button type="button" class="btn btn-success" onclick={onclick_random_forest}>
{ "Switch to Random Forest Classifier" }
</button>
<button type="button" class="btn btn-success" onclick={onclick_svm}>
{ "Switch to SVM Classifier" }
</button>
<button type="button" class="btn btn-success" onclick={onclick_combined}>
{ "Switch to Combined Classifier" }
</button>
Expand Down Expand Up @@ -116,6 +126,7 @@ mod tests {
assert!(rendered.contains("Train Detector"));
assert!(rendered.contains("Train Detector with hard negative samples"));
assert!(rendered.contains("Switch to Random Forest Classifier"));
assert!(rendered.contains("Switch to SVM Classifier"));
assert!(rendered.contains("Switch to Combined Classifier"));
}
}

0 comments on commit 7a166ec

Please sign in to comment.