CART-based (Classification and Regression Trees) machine learning algorithm implemented using the standard Java and Kotlin libraries.
Start by creating a training data set, which the classifier analyzes to build the decision tree.
Diagnosis Question 1 | Diagnosis Question 2 | Diagnosis |
---|---|---|
Symptom1 | Symptom2 | DiagnosisA |
Symptom1 | Symptom1 | DiagnosisB |
Symptom1 | Symptom5 | DiagnosisC |
Symptom1 | Symptom3 | DiagnosisD |
Symptom5 | Symptom3 | DiagnosisE |
Symptom1 | Symptom4 | DiagnosisB |
Symptom1 | Symptom1 | DiagnosisC |
Symptom2 | Symptom3 | DiagnosisC |
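
Standard CART for classification grows a tree by choosing, at each node, the split that most reduces Gini impurity. This README does not say which criterion the library uses, so the following is only an illustration of the measure itself, applied to the Diagnosis column of the table above. The function name `giniImpurity` is hypothetical.

```kotlin
// Gini impurity of a list of class labels: 1 minus the sum of squared class proportions.
fun <T> giniImpurity(labels: List<T>): Double {
    if (labels.isEmpty()) return 0.0
    val n = labels.size.toDouble()
    return 1.0 - labels.groupingBy { it }.eachCount().values.sumOf { (it / n) * (it / n) }
}

fun main() {
    // The Diagnosis column of the training table, top to bottom.
    val diagnoses = listOf(
        "DiagnosisA", "DiagnosisB", "DiagnosisC", "DiagnosisD",
        "DiagnosisE", "DiagnosisB", "DiagnosisC", "DiagnosisC")
    println(giniImpurity(diagnoses)) // 0.75
}
```

An impurity of 0 would mean every row shares one diagnosis; 0.75 reflects five diagnoses spread over eight rows.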
Use a data class (a regular class also works) to create a typed representation of a row in the training data set. The class must extend `DecisionTreeClassifierDataRow<T>` and override its `classification` method to return the value of the classification column. In the example below, the classification column is the `diagnosis` field, whose type is the `Diagnosis` enum.
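```kotlin
// One row of data: two symptom columns plus the diagnosis (classification) column.
// `diagnosis` defaults to null so that unclassified rows can be passed to the classifier.
data class DataRow<T>(
    val diagnosisSymptom1: Symptom,
    val diagnosisSymptom2: Symptom,
    var diagnosis: Diagnosis? = null) : DecisionTreeClassifierDataRow<T>() {

    // Returns the value of the classification column.
    override fun classification(): T {
        @Suppress("UNCHECKED_CAST")
        return diagnosis as T
    }
}
```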
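The examples use `Symptom`, `Diagnosis`, and (for the questions) `QuestionLabels` without showing their definitions. Assuming they are ordinary Kotlin enums, which the source does not confirm (in particular, `PredicateFunction` may expect a different label type), a minimal hypothetical declaration could look like this:

```kotlin
// Hypothetical declarations: this README does not show these types.
enum class Symptom { Symptom1, Symptom2, Symptom3, Symptom4, Symptom5 }

enum class Diagnosis { DiagnosisA, DiagnosisB, DiagnosisC, DiagnosisD, DiagnosisE }

// Hypothetical labels for the questions q1..q5.
enum class QuestionLabels { Q1, Q2, Q3, Q4, Q5 }

fun main() {
    println(Symptom.values().joinToString()) // Symptom1, Symptom2, Symptom3, Symptom4, Symptom5
}
```

The unqualified names used in the examples (`Symptom1`, `DiagnosisA`, ...) imply that the enum entries are imported, for example with `import Symptom.*` and `import Diagnosis.*`.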
The training data is a `List` of `DecisionTreeClassifierDataRow<T>` objects, where `T` is the type of the classification column in the training model.
```kotlin
// Training model: one DataRow per row of the table above.
val trainingModel: List<DecisionTreeClassifierDataRow<Diagnosis>> = listOf(
    DataRow(Symptom1, Symptom2, DiagnosisA),
    DataRow(Symptom1, Symptom1, DiagnosisB),
    DataRow(Symptom1, Symptom5, DiagnosisC),
    DataRow(Symptom1, Symptom3, DiagnosisD),
    DataRow(Symptom5, Symptom3, DiagnosisE),
    DataRow(Symptom1, Symptom4, DiagnosisB),
    DataRow(Symptom1, Symptom1, DiagnosisC),
    DataRow(Symptom2, Symptom3, DiagnosisC))
```
The predicates (questions) used to analyze the training data are defined with the `PredicateFunction` class, which takes a label and a lambda that serves as the predicate function.
```kotlin
val q1: PredicateFunction<DataRow<Diagnosis>> = PredicateFunction(label = QuestionLabels.Q1) {
    it.diagnosisSymptom1 == Symptom1 || it.diagnosisSymptom2 == Symptom5
}
val q2: PredicateFunction<DataRow<Diagnosis>> = PredicateFunction(label = QuestionLabels.Q2) {
    it.diagnosisSymptom1 == Symptom1 || it.diagnosisSymptom2 == Symptom3
}
val q3: PredicateFunction<DataRow<Diagnosis>> = PredicateFunction(label = QuestionLabels.Q3) {
    it.diagnosisSymptom1 == Symptom3 || it.diagnosisSymptom2 == Symptom3
}
val q4: PredicateFunction<DataRow<Diagnosis>> = PredicateFunction(label = QuestionLabels.Q4) {
    it.diagnosisSymptom1 == Symptom4 || it.diagnosisSymptom2 == Symptom4
}
val q5: PredicateFunction<DataRow<Diagnosis>> = PredicateFunction(label = QuestionLabels.Q5) {
    it.diagnosisSymptom1 == Symptom5 || it.diagnosisSymptom2 == Symptom5
}
```
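Conceptually, each question splits a set of rows into those that answer "yes" and those that answer "no". The sketch below models this with a hypothetical `LabeledPredicate` class that mirrors the label-plus-lambda shape of `PredicateFunction`; it is not the library's class, and it uses integers and letters in place of the `Symptom` and `Diagnosis` enums.

```kotlin
// Hypothetical stand-in for PredicateFunction: a label plus a (T) -> Boolean lambda.
class LabeledPredicate<T>(val label: String, val test: (T) -> Boolean)

// Simplified row: symptom numbers stand in for the Symptom enum, a letter for Diagnosis.
data class Row(val symptom1: Int, val symptom2: Int, val diagnosis: Char)

// The training table from the top of this README.
val training = listOf(
    Row(1, 2, 'A'), Row(1, 1, 'B'), Row(1, 5, 'C'), Row(1, 3, 'D'),
    Row(5, 3, 'E'), Row(1, 4, 'B'), Row(1, 1, 'C'), Row(2, 3, 'C'))

// Split rows into those that answer the question "yes" (first) and "no" (second).
fun <T> split(rows: List<T>, question: LabeledPredicate<T>): Pair<List<T>, List<T>> =
    rows.partition(question.test)

fun main() {
    // Same logic as q3: is either symptom Symptom3?
    val q3 = LabeledPredicate<Row>("Q3") { it.symptom1 == 3 || it.symptom2 == 3 }
    val (yes, no) = split(training, q3)
    println("${q3.label} yes: ${yes.map { it.diagnosis }}") // Q3 yes: [D, E, C]
    println("${q3.label} no: ${no.map { it.diagnosis }}")   // Q3 no: [A, B, C, B, C]
}
```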
Now create the decision tree by providing the list of predicates and the training model. When the `DecisionTreeClassifier` class is instantiated, it automatically performs the necessary calculations and builds the decision tree.
```kotlin
val p: List<PredicateFunction<DataRow<Diagnosis>>> = listOf(q1, q2, q3, q4, q5)
val classifier: DecisionTreeClassifier<Diagnosis> = DecisionTreeClassifier(
    trainingModel = trainingModel,
    predicateFunctions = p)
```
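This README does not describe the classifier's internals. As an illustration only, here is a minimal sketch of a greedy CART-style build, assuming Gini impurity as the split criterion and the stopping rule "make a leaf when no question both splits the rows and lowers impurity"; the library may work differently. `Row`, `Question`, `Node`, `build`, and `evaluate` are hypothetical names, and leaves store the set of diagnoses that reached them.

```kotlin
// Simplified stand-ins for DataRow and PredicateFunction; not the library's classes.
data class Row(val symptom1: Int, val symptom2: Int, val diagnosis: Char)
class Question(val label: String, val test: (Row) -> Boolean)

// A node is either a leaf holding the diagnoses that reached it, or a split on a question.
sealed class Node
class Leaf(val diagnoses: Set<Char>) : Node()
class Split(val question: Question, val yes: Node, val no: Node) : Node()

// The training table from the top of this README.
val training = listOf(
    Row(1, 2, 'A'), Row(1, 1, 'B'), Row(1, 5, 'C'), Row(1, 3, 'D'),
    Row(5, 3, 'E'), Row(1, 4, 'B'), Row(1, 1, 'C'), Row(2, 3, 'C'))

// The logic of q1..q5.
val questions = listOf(
    Question("Q1") { it.symptom1 == 1 || it.symptom2 == 5 },
    Question("Q2") { it.symptom1 == 1 || it.symptom2 == 3 },
    Question("Q3") { it.symptom1 == 3 || it.symptom2 == 3 },
    Question("Q4") { it.symptom1 == 4 || it.symptom2 == 4 },
    Question("Q5") { it.symptom1 == 5 || it.symptom2 == 5 })

// Gini impurity of the diagnoses in a group of rows (0.0 for an empty group).
fun gini(rows: List<Row>): Double {
    if (rows.isEmpty()) return 0.0
    val n = rows.size.toDouble()
    return 1.0 - rows.groupingBy { it.diagnosis }.eachCount().values.sumOf { (it / n) * (it / n) }
}

// Greedily pick the question whose split has the lowest weighted Gini impurity, then recurse.
fun build(rows: List<Row>, questions: List<Question>): Node {
    var best: Question? = null
    var bestScore = gini(rows)
    for (q in questions) {
        val (yes, no) = rows.partition(q.test)
        if (yes.isEmpty() || no.isEmpty()) continue // the question does not separate these rows
        val score = (yes.size * gini(yes) + no.size * gini(no)) / rows.size
        if (score < bestScore) {
            best = q
            bestScore = score
        }
    }
    val chosen = best ?: return Leaf(rows.map { it.diagnosis }.toSet())
    val (yes, no) = rows.partition(chosen.test)
    return Split(chosen, build(yes, questions), build(no, questions))
}

// Follow the row's answers down the tree and return the diagnoses stored in the leaf.
fun evaluate(node: Node, row: Row): Set<Char> = when (node) {
    is Leaf -> node.diagnoses
    is Split -> evaluate(if (node.question.test(row)) node.yes else node.no, row)
}

fun main() {
    val tree = build(training, questions)
    // '?' marks the unknown diagnosis of a row being classified.
    println(evaluate(tree, Row(1, 4, '?'))) // [B]
    println(evaluate(tree, Row(1, 2, '?'))) // [A, B, C]
}
```

Under these assumptions, the sketch splits first on Q4 and matches both results reported for the sample rows: (Symptom1, Symptom4) reaches a leaf holding only B, while (Symptom1, Symptom2) reaches a leaf holding A, B, and C.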
Here is a sample list of data provided to the classifier for analysis.
```kotlin
val data: List<DataRow<Diagnosis>> = listOf(
    DataRow(Symptom1, Symptom2),
    DataRow(Symptom1, Symptom4))
```
To evaluate a row of data and retrieve its classification, call `evaluate`:

```kotlin
classifier.evaluate(data[1])
```

Returns:

```
DiagnosisB
```
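In the following example, the classifier responds with `DiagnosisA`, `DiagnosisB`, or `DiagnosisC`. The training rows (Symptom1, Symptom2, DiagnosisA), (Symptom1, Symptom1, DiagnosisB), and (Symptom1, Symptom1, DiagnosisC) answer all five questions the same way as `data.first()`, so the provided questions cannot partition the training data any further.

```kotlin
classifier.evaluate(data.first())
```

Returns:

```
DiagnosisA, DiagnosisB or DiagnosisC
```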
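Rows that give the same answer to every question can never be separated by any tree built from those questions. The sketch below makes this concrete by computing each row's vector of answers to q1..q5 and collecting the training diagnoses that share the sample row's vector. It uses a simplified hypothetical `Row` type (integers for symptoms, letters for diagnoses) and hypothetical helpers `signature` and `indistinguishable`; it is not part of the library.

```kotlin
// Simplified row: symptom numbers stand in for the Symptom enum, a letter for Diagnosis.
data class Row(val symptom1: Int, val symptom2: Int, val diagnosis: Char)

// The training table from the top of this README.
val training = listOf(
    Row(1, 2, 'A'), Row(1, 1, 'B'), Row(1, 5, 'C'), Row(1, 3, 'D'),
    Row(5, 3, 'E'), Row(1, 4, 'B'), Row(1, 1, 'C'), Row(2, 3, 'C'))

// The logic of q1..q5, in order, as plain lambdas.
val questions = listOf<(Row) -> Boolean>(
    { it.symptom1 == 1 || it.symptom2 == 5 },
    { it.symptom1 == 1 || it.symptom2 == 3 },
    { it.symptom1 == 3 || it.symptom2 == 3 },
    { it.symptom1 == 4 || it.symptom2 == 4 },
    { it.symptom1 == 5 || it.symptom2 == 5 })

// A row's answers to every question, in order.
fun signature(row: Row): List<Boolean> = questions.map { it(row) }

// Diagnoses of the training rows that answer every question exactly as `row` does.
fun indistinguishable(row: Row): Set<Char> =
    training.filter { signature(it) == signature(row) }.map { it.diagnosis }.toSet()

fun main() {
    // '?' marks the unknown diagnosis of the row being classified.
    println(indistinguishable(Row(1, 2, '?'))) // [A, B, C]
    println(indistinguishable(Row(1, 4, '?'))) // [B]
}
```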