commit fbb285ba01aa4db1ae65a4af09e49f307790396c
parent abd8f9bdb27dff0e4ed253f17d2dc0eb1bb5e2cb
Author: Andrew <andrewlaack1@gmail.com>
Date: Mon, 23 Dec 2024 10:38:46 -0600
Set up bind
Diffstat:
25 files changed, 941 insertions(+), 636 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -6,7 +6,11 @@ __pycache__/
# C extensions
*.so
+*.dot
+cmake_install.cmake
+CMakeCache.txt
+datasets/
# Distribution / packaging
.Python
build/
@@ -22,6 +26,7 @@ sdist/
var/
wheels/
share/python-wheels/
+CMakeFiles/
*.egg-info/
.installed.cfg
*.egg
diff --git a/Notes.md b/Notes.md
@@ -28,3 +28,18 @@ C++ - 15.271s (user)
WITH 30,000 ELEMENTS, 4 FEATURES:
PYTHON - 37.5128 (user)
C++ 34.335 (user)
+
+
+WITH WORKLOAD OF 6000 SAMPLES AND 784 FEATURES:
+
+MAX MEMORY C++ - 24628kB
+MAX MEMORY Python - 521764kB
+
+ADD BENCHMARKING WITHOUT LIKELY AND THEN WITH.
+
+20,000 Samples, 4 Features:
+ WITHOUT LIKELY: 15.526 seconds (user), 15.632 seconds (user)
+ WITH LIKELY: 15.524 seconds (user), 15.608 seconds (user)
+
+
+BUILD OPTIMAL DECISION TREE
diff --git a/classifier/Testing.py b/classifier/Testing.py
@@ -1,71 +0,0 @@
-from Podtc import PseudoOptimalDecisionTreeClassifier
-import numpy as np
-##import plotly.express as px
-#from sklearn.tree import DecisionTreeClassifier
-#from sklearn.datasets import load_digits
-#import pandas as pd
-#from keras.datasets import mnist
-#from sklearn.metrics import accuracy_score
-#from setuptools import setup
-
-# (train_X, train_y), (test_X, test_y) = mnist.load_data()
-
-# train_X = train_X.reshape(-1, 784)
-# test_X = test_X.reshape(-1, 784)
-
-X_train = np.random.randint(1, 10, size=(30000, 4))
-
-# Generate 60,000 labels for y_train (for example, labels between 1 and 10)
-y_train = np.random.randint(1, 11, size=30000)
-
-# train_X = [[2,5], [5,2], [3,4], [4,4], [5,5], [10, 10], [2,2], [12,12]]
-# train_y = [1, 1 , 2, 1, 5, 2,1 ,3]
-
-classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=1);
-
-classifier.fit(X_train, y_train)
-
-classifier.graph()
-
-exit()
-
-y_pred = classifier.predict(test_X)
-print("MY ACCURACY:")
-print(accuracy_score(y_true=test_y, y_pred=y_pred))
-
-classifier = DecisionTreeClassifier(max_depth=4)
-classifier.fit(train_X, train_y)
-y_pred = classifier.predict(test_X)
-
-print("SECOND ACCURACY:")
-print(accuracy_score(y_true=test_y, y_pred=y_pred))
-
-
-exit()
-
-X = np.random.random((200, 2))
-y = np.random.random((200)) * 10
-y = y.round()
-# y = (X[:,0] + X[:,1]) > .5
-
-
-#clf = DecisionTreeClassifier()
-#clf.fit(X,y)
-
-classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=2);
-
-classifier.fit(X,y)
-
-
-X_pred = np.random.random((20, 2)).round()
-
-print(X_pred)
-print(classifier.predict(X_pred))
-
-# classifier.predict()
-# print(classifier)
-
-
-
-#scatter = px.scatter(x=X[:,0], y=X[:,1], color=y)
-#scatter.show()
diff --git a/classifier/LeafNode.py b/old/LeafNode.py
diff --git a/classifier/Makefile b/old/Makefile
diff --git a/classifier/Podtc.py b/old/Podtc.py
diff --git a/classifier/SplittingNode.py b/old/SplittingNode.py
diff --git a/old/Testing.py b/old/Testing.py
@@ -0,0 +1,67 @@
+from Podtc import PseudoOptimalDecisionTreeClassifier
+import numpy as np
+##import plotly.express as px
+#from sklearn.tree import DecisionTreeClassifier
+#from sklearn.datasets import load_digits
+#import pandas as pd
+#from keras.datasets import mnist
+#from sklearn.metrics import accuracy_score
+#from setuptools import setup
+
+# (train_X, train_y), (test_X, test_y) = mnist.load_data()
+
+# train_X = train_X.reshape(-1, 784)
+# test_X = test_X.reshape(-1, 784)
+
+X_train = np.random.randint(1, 10, size=(6000, 784))
+
+# Generate 6,000 labels for y_train (integer class labels from 1 to 10 inclusive)
+y_train = np.random.randint(1, 11, size=6000)
+
+# train_X = [[2,5], [5,2], [3,4], [4,4], [5,5], [10, 10], [2,2], [12,12]]
+# train_y = [1, 1 , 2, 1, 5, 2,1 ,3]
+
+classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=1);
+classifier.fit(X_train, y_train)
+exit()
+
+y_pred = classifier.predict(test_X)
+print("MY ACCURACY:")
+print(accuracy_score(y_true=test_y, y_pred=y_pred))
+
+classifier = DecisionTreeClassifier(max_depth=4)
+classifier.fit(train_X, train_y)
+y_pred = classifier.predict(test_X)
+
+print("SECOND ACCURACY:")
+print(accuracy_score(y_true=test_y, y_pred=y_pred))
+
+
+exit()
+
+X = np.random.random((200, 2))
+y = np.random.random((200)) * 10
+y = y.round()
+# y = (X[:,0] + X[:,1]) > .5
+
+
+#clf = DecisionTreeClassifier()
+#clf.fit(X,y)
+
+classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=2);
+
+classifier.fit(X,y)
+
+
+X_pred = np.random.random((20, 2)).round()
+
+print(X_pred)
+print(classifier.predict(X_pred))
+
+# classifier.predict()
+# print(classifier)
+
+
+
+#scatter = px.scatter(x=X[:,0], y=X[:,1], color=y)
+#scatter.show()
diff --git a/classifier/__init__.py b/old/__init__.py
diff --git a/classifier/cpp/gini.cpp b/old/cpp/gini.cpp
diff --git a/classifier/cpp/split.cpp b/old/cpp/split.cpp
diff --git a/rewrite/CMakeLists.txt b/rewrite/CMakeLists.txt
@@ -0,0 +1,42 @@
+
+cmake_minimum_required(VERSION 3.10)
+
+# Set project name
+project(DecisionTreeClassifier)
+
+# Set C++ standard
+set(CMAKE_CXX_STANDARD 17)
+
+# Find Python and Pybind11
+find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
+find_package(pybind11 REQUIRED)
+
+# Set compiler flags
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -Wall -fPIC")
+
+# Add source files
+set(SOURCES
+ cpp/DecisionTreeClassifier.cpp
+ cpp/TreeNode.cpp
+ cpp/Criterion.cpp
+ cpp/bindings.cpp
+)
+
+# Create the shared library
+add_library(decision_tree MODULE ${SOURCES})
+
+# Link with Python and Pybind11
+target_include_directories(decision_tree PRIVATE ${Python3_INCLUDE_DIRS})
+target_link_libraries(decision_tree PRIVATE ${Python3_LIBRARIES} pybind11::module)
+
+# Drop the default "lib" prefix and force a ".so" suffix so the module file is named decision_tree.so
+set_target_properties(decision_tree PROPERTIES
+ PREFIX ""
+ SUFFIX ".so"
+)
+
+# Custom cleanup target; named clean_build so it does not clash with the generated "clean" target
+add_custom_target(clean_build
+ COMMAND rm -f ${CMAKE_BINARY_DIR}/*.o
+ COMMENT "Clean build files"
+)
diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp
@@ -1,175 +0,0 @@
-#include "DecisionTreeClassifier.h"
-#include <limits>
-#include <iostream>
-#include <unordered_map>
-
-using namespace std;
-
-DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){
- depth = maxDepth;
-}
-
-void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){
- if (splittingTree != nullptr){
- throw logic_error("Decision trees don't support incremental learning, fit can only be called once.");
- }
-
- if(features <= 0){
- throw invalid_argument("Invalid argument, there must be 1 or more features to train on.");
- }
-
- if(samples <= 0){
- throw invalid_argument("Invalid argument, there must be 1 or more samples to train on.");
- }
-
- splittingTree = recurse(X, samples, y, features, depth);
- featureCount = features;
-}
-
-
-std::string DecisionTreeClassifier::getDot(){
- if (splittingTree == nullptr){
- throw logic_error("Decision tree must be created prior to generating dot output.");
- }
- std::string edges = splittingTree->getDotEdges();
- std::string dot = "digraph decisionTree {\n" + edges + "}";
- return dot;
-}
-
-int DecisionTreeClassifier::primaryClass(int* y, int labelCount){
-
- unordered_map map = unordered_map<int,int>();
-
- for(int i = 0; i < labelCount; ++i){
- map[y[i]] += 1;
- }
-
- int mostElements = 0;
- int label = 0;
-
- for (auto& item : map){
- if(item.second > mostElements){
- mostElements = item.second;
- label = item.first;
- }
- }
-
- return label;
-}
-
-
-
-// add depth
-TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){
-
- if(depthRem == 0){
- TreeNode* ret = new TreeNode(primaryClass(y, rows));
- return ret;
- }
-
- // found minimum node
- if(rows == 1){
- TreeNode* ret = new TreeNode(primaryClass(y, rows));
- return ret;
- }
-
- // get best split option
- TreeNode* chosen = bestSplit(X, rows, y, columns);
- SplitResults split = chosen->splitOnNode(X, y, rows, columns);
-
- // no valid splits, but we still did create some new arrays.
- if(split.rightSize == rows || split.leftSize == rows){
- TreeNode* ret = new TreeNode(primaryClass(y, rows));
- delete split.XLeft;
- delete split.XRight;
- delete split.yLeft;
- delete split.yRight;
- return ret;
- }
-
- // traverse lt tree
- TreeNode* left = recurse(split.XLeft, split.leftSize, split.yLeft, columns, depthRem - 1);
- // traverse gt tree
- TreeNode* right = recurse(split.XRight, split.rightSize, split.yRight, columns, depthRem - 1);
-
- chosen->setLeftChild(left);
- chosen->setRightChild(right);
-
- delete split.XLeft;
- delete split.XRight;
- delete split.yLeft;
- delete split.yRight;
-
- return chosen;
-}
-
-
-
-// 1 1 0
-// 3 3 0
-// 2 1 1
-// 4 1 3
-
-// consider adding interpolation to this and sorting the list first.
-// Also, no reason to consider the 0th split if that is the case.
-
-TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns){
-
- TreeNode* bestNode = nullptr;
- float bestGini = std::numeric_limits<float>::max();
-
- for(int col = 0 ; col < columns; ++col){
- for(int row = 0; row < rows; ++row){
-
- float val = X[row*columns + col];
- TreeNode* current = new TreeNode(val, col);
- float gini = current->evalSplit(X, y, rows, columns, "gini");
- if (gini < bestGini){
-
- TreeNode* prevBest = bestNode;
- delete prevBest;
-
- bestNode = current;
- bestGini = gini;
- }
- else{
- delete current;
- }
- }
- }
-
- return bestNode;
-
-}
-
-int* DecisionTreeClassifier::predict(float* X, int samples, int features){
-
- if(featureCount == -1){
- throw logic_error("Unable to predict prior to calling fit().");
- }
-
- if(features != this->featureCount){
- throw invalid_argument("Incorrect number of features for prediction.");
- }
- cout << "PREDICTING" << endl;
-
- int* predictions = new int[samples];
-
- for(int i = 0; i < samples; ++i){
- TreeNode* current = splittingTree;
- while(!current->isLeaf()){
- float* currentElement = X;
- currentElement += features * i;
- bool lessThan = current->lessThan(currentElement, features);
- if(lessThan){
- current = current->getLeftChild();
- }
- else{
- current = current->getRightChild();
- }
- }
- predictions[i] = current->getClassification();
- }
-
- return predictions;
-}
diff --git a/rewrite/Makefile b/rewrite/Makefile
@@ -1,17 +1,276 @@
+# CMAKE generated file: DO NOT EDIT!
+# Generated by "Unix Makefiles" Generator, CMake Version 3.31
+
+# Default target executed when no arguments are given to make.
+default_target: all
+.PHONY : default_target
+
+# Allow only one "make -f Makefile2" at a time, but pass parallelism.
+.NOTPARALLEL:
+
+#=============================================================================
+# Special targets provided by cmake.
+
+# Disable implicit rules so canonical targets will work.
+.SUFFIXES:
+
+# Disable VCS-based implicit rules.
+% : %,v
+
+# Disable VCS-based implicit rules.
+% : RCS/%
+
+# Disable VCS-based implicit rules.
+% : RCS/%,v
+
+# Disable VCS-based implicit rules.
+% : SCCS/s.%
+
+# Disable VCS-based implicit rules.
+% : s.%
+
+.SUFFIXES: .hpux_make_needs_suffix_list
+
+# Command-line flag to silence nested $(MAKE).
+$(VERBOSE)MAKESILENT = -s
+
+#Suppress display of executed commands.
+$(VERBOSE).SILENT:
+
+# A target that is always out of date.
+cmake_force:
+.PHONY : cmake_force
+
+#=============================================================================
+# Set environment variables for the build.
+
+# The shell in which to execute make rules.
+SHELL = /bin/sh
+
+# The CMake executable.
+CMAKE_COMMAND = /usr/bin/cmake
+
+# The command to remove a file.
+RM = /usr/bin/cmake -E rm -f
+
+# Escaping for special characters.
+EQUALS = =
+
+# The top-level source directory on which CMake was run.
+CMAKE_SOURCE_DIR = /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite
+
+# The top-level build directory on which CMake was run.
+CMAKE_BINARY_DIR = /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite
+
+#=============================================================================
+# Targets provided globally by CMake.
+
+# Special rule for the target edit_cache
+edit_cache:
+ @$(CMAKE_COMMAND) -E cmake_echo_color "--switch=$(COLOR)" --cyan "No interactive CMake dialog available..."
+ /usr/bin/cmake -E echo No\ interactive\ CMake\ dialog\ available.
+.PHONY : edit_cache
+
+# Special rule for the target edit_cache
+edit_cache/fast: edit_cache
+.PHONY : edit_cache/fast
+
+# Special rule for the target rebuild_cache
+rebuild_cache:
+ @$(CMAKE_COMMAND) -E cmake_echo_color "--switch=$(COLOR)" --cyan "Running CMake to regenerate build system..."
+ /usr/bin/cmake --regenerate-during-build -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR)
+.PHONY : rebuild_cache
+
+# Special rule for the target rebuild_cache
+rebuild_cache/fast: rebuild_cache
+.PHONY : rebuild_cache/fast
+
+# The main all target
+all: cmake_check_build_system
+ $(CMAKE_COMMAND) -E cmake_progress_start /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite/CMakeFiles /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite//CMakeFiles/progress.marks
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 all
+ $(CMAKE_COMMAND) -E cmake_progress_start /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite/CMakeFiles 0
+.PHONY : all
+
+# The main clean target
clean:
- rm *.o
- rm *.out
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 clean
+.PHONY : clean
+
+# The main clean target
+clean/fast: clean
+.PHONY : clean/fast
+
+# Prepare targets for installation.
+preinstall: all
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall
+.PHONY : preinstall
+
+# Prepare targets for installation.
+preinstall/fast:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall
+.PHONY : preinstall/fast
+
+# clear depends
+depend:
+ $(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 1
+.PHONY : depend
+
+#=============================================================================
+# Target rules for targets named decision_tree
+
+# Build rule for target.
+decision_tree: cmake_check_build_system
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 decision_tree
+.PHONY : decision_tree
+
+# fast build rule for target.
+decision_tree/fast:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/build
+.PHONY : decision_tree/fast
+
+#=============================================================================
+# Target rules for targets named clean_build
+
+# Build rule for target.
+clean_build: cmake_check_build_system
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 clean_build
+.PHONY : clean_build
+
+# fast build rule for target.
+clean_build/fast:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/clean_build.dir/build.make CMakeFiles/clean_build.dir/build
+.PHONY : clean_build/fast
+
+cpp/Criterion.o: cpp/Criterion.cpp.o
+.PHONY : cpp/Criterion.o
+
+# target to build an object file
+cpp/Criterion.cpp.o:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/Criterion.cpp.o
+.PHONY : cpp/Criterion.cpp.o
+
+cpp/Criterion.i: cpp/Criterion.cpp.i
+.PHONY : cpp/Criterion.i
+
+# target to preprocess a source file
+cpp/Criterion.cpp.i:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/Criterion.cpp.i
+.PHONY : cpp/Criterion.cpp.i
+
+cpp/Criterion.s: cpp/Criterion.cpp.s
+.PHONY : cpp/Criterion.s
+
+# target to generate assembly for a file
+cpp/Criterion.cpp.s:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/Criterion.cpp.s
+.PHONY : cpp/Criterion.cpp.s
+
+cpp/DecisionTreeClassifier.o: cpp/DecisionTreeClassifier.cpp.o
+.PHONY : cpp/DecisionTreeClassifier.o
+
+# target to build an object file
+cpp/DecisionTreeClassifier.cpp.o:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/DecisionTreeClassifier.cpp.o
+.PHONY : cpp/DecisionTreeClassifier.cpp.o
+
+cpp/DecisionTreeClassifier.i: cpp/DecisionTreeClassifier.cpp.i
+.PHONY : cpp/DecisionTreeClassifier.i
+
+# target to preprocess a source file
+cpp/DecisionTreeClassifier.cpp.i:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/DecisionTreeClassifier.cpp.i
+.PHONY : cpp/DecisionTreeClassifier.cpp.i
+
+cpp/DecisionTreeClassifier.s: cpp/DecisionTreeClassifier.cpp.s
+.PHONY : cpp/DecisionTreeClassifier.s
+
+# target to generate assembly for a file
+cpp/DecisionTreeClassifier.cpp.s:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/DecisionTreeClassifier.cpp.s
+.PHONY : cpp/DecisionTreeClassifier.cpp.s
+
+cpp/TreeNode.o: cpp/TreeNode.cpp.o
+.PHONY : cpp/TreeNode.o
+
+# target to build an object file
+cpp/TreeNode.cpp.o:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/TreeNode.cpp.o
+.PHONY : cpp/TreeNode.cpp.o
+
+cpp/TreeNode.i: cpp/TreeNode.cpp.i
+.PHONY : cpp/TreeNode.i
+
+# target to preprocess a source file
+cpp/TreeNode.cpp.i:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/TreeNode.cpp.i
+.PHONY : cpp/TreeNode.cpp.i
+
+cpp/TreeNode.s: cpp/TreeNode.cpp.s
+.PHONY : cpp/TreeNode.s
+
+# target to generate assembly for a file
+cpp/TreeNode.cpp.s:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/TreeNode.cpp.s
+.PHONY : cpp/TreeNode.cpp.s
+
+cpp/bindings.o: cpp/bindings.cpp.o
+.PHONY : cpp/bindings.o
+
+# target to build an object file
+cpp/bindings.cpp.o:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/bindings.cpp.o
+.PHONY : cpp/bindings.cpp.o
+
+cpp/bindings.i: cpp/bindings.cpp.i
+.PHONY : cpp/bindings.i
+
+# target to preprocess a source file
+cpp/bindings.cpp.i:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/bindings.cpp.i
+.PHONY : cpp/bindings.cpp.i
+
+cpp/bindings.s: cpp/bindings.cpp.s
+.PHONY : cpp/bindings.s
+
+# target to generate assembly for a file
+cpp/bindings.cpp.s:
+ $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/bindings.cpp.s
+.PHONY : cpp/bindings.cpp.s
+
+# Help Target
+help:
+ @echo "The following are some of the valid targets for this Makefile:"
+ @echo "... all (the default if no target is provided)"
+ @echo "... clean"
+ @echo "... depend"
+ @echo "... edit_cache"
+ @echo "... rebuild_cache"
+ @echo "... clean_build"
+ @echo "... decision_tree"
+ @echo "... cpp/Criterion.o"
+ @echo "... cpp/Criterion.i"
+ @echo "... cpp/Criterion.s"
+ @echo "... cpp/DecisionTreeClassifier.o"
+ @echo "... cpp/DecisionTreeClassifier.i"
+ @echo "... cpp/DecisionTreeClassifier.s"
+ @echo "... cpp/TreeNode.o"
+ @echo "... cpp/TreeNode.i"
+ @echo "... cpp/TreeNode.s"
+ @echo "... cpp/bindings.o"
+ @echo "... cpp/bindings.i"
+ @echo "... cpp/bindings.s"
+.PHONY : help
-CXXFLAGS = -O3 -Wall -std=c++17 # Add -O3 for optimization and other flags
-LDFLAGS = # Add any linker flags here
-# Targets
-node:
- g++ $(CXXFLAGS) -c TreeNode.cpp -o TreeNode.o
+#=============================================================================
+# Special targets to cleanup operation of make.
-decisionTree:
- g++ $(CXXFLAGS) -c DecisionTreeClassifier.cpp -o DecisionTreeClassifier.o
+# Special rule to run CMake to check the build system integrity.
+# No rule that depends on this can have commands that come from listfiles
+# because they might be regenerated.
+cmake_check_build_system:
+ $(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 0
+.PHONY : cmake_check_build_system
-tests: DecisionTreeClassifier.o TreeNode.o
- g++ $(CXXFLAGS) Tests.cpp DecisionTreeClassifier.o TreeNode.o -o a.out
diff --git a/rewrite/Test.py b/rewrite/Test.py
@@ -0,0 +1,38 @@
+from sklearn.datasets import load_digits
+import numpy as np
+from decision_tree import DecisionTreeClassifier
+from keras.datasets import mnist
+from sklearn.metrics import accuracy_score
+
+(train_X, train_y), (test_X, test_y) = mnist.load_data()
+
+train_X = train_X.reshape(-1, 784)
+test_X = test_X.reshape(-1, 784)
+
+
+train_X = train_X[0:2000]
+train_y = train_y[0:2000]
+
+
+clf = DecisionTreeClassifier(20)
+
+print("TRAINING")
+clf.fit(train_X, len(train_X), train_y, len(train_X[0]))
+
+predictions = clf.predict(test_X, test_X.shape[0], test_X.shape[1])
+
+# Output results
+print("Predictions:", predictions)
+print(accuracy_score(y_pred=predictions, y_true=test_y))
+
+
+
+
+
+
+dot_representation = clf.getDot()
+# print("\nDecision Tree (DOT Format):\n", dot_representation)
+
+# Save the DOT representation to decision_tree.dot
+with open("decision_tree.dot", "w") as file:
+ file.write(dot_representation)
diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp
@@ -1,94 +0,0 @@
-#include "DecisionTreeClassifier.h"
-#include <fstream> // For file I/O
-#include "iostream"
-#include "assert.h"
-
-using namespace std;
-
-void testTreeNode(){
- int labels[] = {10, 10, 10, 1, 2, 3};
- float samples[][4] = {
- {1,1,5,3},
- {1,2,5,3},
- {1,7,5,3},
- {1,3,5,3},
- {1,7,5,3},
- {1,1,5,3}
- };
-
-
- TreeNode tn = TreeNode(5.0f ,1);
- bool isLeaf = tn.isLeaf();
-
- assert(!isLeaf);
-
- float giniVal = tn.evalSplit(*samples, labels, 6, 4, "gini");
- assert(abs(giniVal - .5833333) < .0001 );
-
- tn.setSplit(0.0f, 0);
- float giniVal2 = tn.evalSplit(*samples, labels, 6, 4, "gini");
- assert(abs(giniVal2 - .6666666) < .0001);
-
-}
-
-
-int main(){
-
- cout << "STARTING" << endl;
-
- int SAMPLES = 20000;
-
- DecisionTreeClassifier clf(1);
- int labels[SAMPLES];
- for (int i = 0; i < SAMPLES; ++i){
- labels[i] = i % 10; // Example: create labels that cycle through 0 to 9
- }
-
- float samples[SAMPLES][4];
- for (int i = 0; i < SAMPLES; ++i){
- // Fill the samples with some arbitrary data (example: sequential)
- samples[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10)
- samples[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5)
- samples[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7)
- samples[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3)
- }
-
- clf.fit(*samples, SAMPLES, labels, 4);
- int PREDS = 10;
-
- float preds[PREDS][4];
- for (int i = 0; i < PREDS; ++i){
- // Fill the samples with some arbitrary data (example: sequential)
- preds[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10)
- preds[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5)
- preds[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7)
- preds[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3)
- }
-
- int* predsOut = clf.predict(*preds, PREDS, 4);
-
- cout << "DONE" << endl;
-
- for(int i = 0 ; i < PREDS; ++i){
- cout << preds[i][0] << " " << preds[i][1] << " " <<preds[i][2] << " " << preds[i][3] << " " << predsOut[i] << endl;
- }
-
- delete[] predsOut;
-
- return 0;
-
-
- ofstream outFile("decision_tree.dot");
-
- // Check if the file is open
- if (outFile.is_open()) {
- // Write the decision tree to the file
- outFile << clf.getDot() << endl;
- outFile.close(); // Close the file after writing
- cout << "Decision tree written to 'decision_tree.dot'" << endl;
- } else {
- cerr << "Failed to open the file." << endl;
- }
-
- return 0;
-}
diff --git a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp
@@ -1,244 +0,0 @@
-#include "TreeNode.h"
-#include "stdexcept"
-#include "unordered_map"
-#include "math.h"
-#include "iostream"
-#include <sstream> //for std::stringstream
-#include <string> //for std::string
-
-TreeNode::TreeNode(int classification){
- leaf = true;
- this->classification = classification;
-}
-
-TreeNode::TreeNode(float splittingVal, int featureIndex){
- splitValue = splittingVal;
- index = featureIndex;
- leaf = false;
-}
-
-void TreeNode::setSplit(float splittingVal, int featureIndex){
- splitValue = splittingVal;
- index = featureIndex;
- leaf = false;
-}
-
-bool TreeNode::isLeaf(){
- return leaf;
-}
-
-float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){
-
- if(isLeaf()){
- throw std::logic_error("Cannot evaluate split on leaf node.");
- }
-
- if(criterion != "gini"){
- throw std::invalid_argument("Gini impurity is the only supported criterion.");
- }
-
- return giniImpurity(X, y, samples, features);
-}
-
-
-void TreeNode::setLeftChild(TreeNode* child){
- leftChild = child;
-}
-
-void TreeNode::setRightChild(TreeNode* child){
- rightChild = child;
-}
-
-TreeNode* TreeNode::getLeftChild(){
- return leftChild;
-}
-
-TreeNode* TreeNode::getRightChild(){
- return rightChild;
-}
-
-float TreeNode::getSplitVal(){
- return splitValue;
-}
-
-int TreeNode::getIndexSplit(){
- return index;
-}
-
-SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){
-
- SplitResults result = SplitResults();
-
- int ltCount = 0;
- int gteqCount = 0;
-
- for(int i = 0 ; i < samples; ++i){
- if(X[(i*features) + index] < splitValue){
- ltCount += 1;
- }
- else{
- gteqCount += 1;
- }
- }
-
- // Create X arrays to return
-
- float* ltArr = new float[ltCount * features];
- float* gteqArr = new float[gteqCount * features];
-
- // Create array ptr next open
-
- float* nextLtX = ltArr;
- float* nextGteqX = gteqArr;
-
- // Create y arrays to return
-
- int* ltYArr = new int[ltCount];
- int* gteqYArr = new int[gteqCount];
-
- // Create array ptr next open
-
- int* nextLtY = ltYArr;
- int* nextGteqY = gteqYArr;
-
- // Set pointers for return to the new arrays
-
- result.XLeft = ltArr;
- result.yLeft = ltYArr;
-
- result.XRight = gteqArr;
- result.yRight = gteqYArr;
-
- result.leftSize = ltCount;
- result.rightSize = gteqCount;
-
- // Set arrays with correct values
-
- for(int i = 0 ; i < samples; ++i){
- if(X[(i*features) + index] < splitValue){
- for(int x = 0; x < features; ++x){
- nextLtX[x] = X[(i*features) + x];
- }
-
- nextLtX += features;
-
- nextLtY[0] = y[i];
- nextLtY += 1;
- }
- else{
- for(int x = 0; x < features; ++x){
- nextGteqX[x] = X[(i*features) + x];
- }
-
- nextGteqX += features;
-
- nextGteqY[0] = y[i];
- nextGteqY += 1;
- }
- }
-
- //for(int x = 0 ; x < ltCount; ++x){
- // for(int i = 0 ; i < features; ++i){
- // std::cout << ltArr[x*features + i];
- // }
- // std::cout << std::endl;
- //}
-
- //for(int x = 0 ; x < ltCount; ++x){
- // std::cout << ltYArr[x] << std::endl;
- //}
-
- return result;
-}
-
-
-
-float TreeNode::giniImpurity(float* X, int* y, int samples, int features){
-
- std::unordered_map<int, int> ltMap;
- std::unordered_map<int, int> gtMap;
-
- int ltCount = 0;
- int gteqCount = 0;
-
- for(int i = 0; i < samples; ++i){
- if(X[index + (i * features)] < splitValue){
- ltMap[y[i]]++;
- ltCount++;
- }
- else{
- gtMap[y[i]]++;
- gteqCount++;
- }
- }
-
-
- float ltGini= 1.0f;
-
- for (const auto& pair : ltMap) {
- ltGini -= pow(float(pair.second) / ltCount, 2);
- }
-
- float gteqGini = 1.0f;
-
- for (const auto& pair : gtMap) {
- gteqGini -= pow(float(pair.second) / gteqCount, 2);
- }
-
- if(gteqCount == 0){
- gteqGini = 0.0f;
- }
- if(ltCount == 0){
- ltGini = 0.0f;
- }
-
- float gini = gteqGini * float(gteqCount) / samples;
- gini += ltGini * float(ltCount) / samples;
-
- return gini;
-}
-
-
-
-std::string TreeNode::getDotEdges(){
-
- if(isLeaf()){
- return "";
- }
-
- std::string current = getDotLabel() + "->" + leftChild->getDotLabel() + ";\n";
- current += getDotLabel() + "->" + rightChild->getDotLabel() + ";\n";
-
- current += rightChild->getDotEdges();
- current += leftChild->getDotEdges();
-
- return current;
-}
-
-std::string TreeNode::getDotLabel(){
- const void * address = static_cast<const void*>(this);
- std::stringstream ss;
- ss << address;
- std::string name = ss.str();
- if (isLeaf()){
- return "\"" + name + "\nCLASSIFICATION: " + std::to_string(classification) + "\"";
- }
-
- return "\"" + name + "\nINDEX: " + std::to_string(index) + "\nVALUE:" + std::to_string(splitValue) + "\"";
-}
-
-int TreeNode::getClassification(){
- if(isLeaf()){
- return classification;
- }
- throw std::logic_error("Unable to call getClassification() on internal vertices.");
-}
-
-bool TreeNode::lessThan(float* sample, int features){
-
- if(features < this->index){
- throw std::invalid_argument("Attempting to evaluate split with input that contains less features.");
- }
-
- return(sample[index] < splitValue);
-}
diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h
@@ -1,41 +0,0 @@
-#include "string"
-
-struct SplitResults{
- float* XLeft;
- float* XRight;
- int* yLeft;
- int* yRight;
- int leftSize;
- int rightSize;
-};
-
-class TreeNode{
- public:
- TreeNode(int classification);
- TreeNode(float splittingVal, int featureIndex);
- bool isLeaf();
- void setSplit(float splittingValue, int featureIndex);
- float evalSplit(float* X, int* y, int samples, int features, std::string criterion);
- TreeNode* getLeftChild();
- TreeNode* getRightChild();
- void setLeftChild(TreeNode* child);
- void setRightChild(TreeNode* child);
- float getSplitVal();
- int getIndexSplit();
- SplitResults splitOnNode(float* X, int* y, int samples, int features);
- std::string getDotEdges();
- int getClassification();
- bool lessThan(float* sample, int features);
-
- private:
- bool leaf;
- float splitValue;
- int index;
- TreeNode* leftChild;
- TreeNode* rightChild;
- float giniImpurity(float* X, int* y, int samples, int features);
- std::string getDotLabel();
- int classification;
-};
-
-
diff --git a/rewrite/cpp/Criterion.cpp b/rewrite/cpp/Criterion.cpp
@@ -0,0 +1,49 @@
+#include "Criterion.h"
+#include <unordered_map>
+#include <math.h>
+
+float Criterion::giniImpurity(float* X, int* y, int samples, int features, int index, float splitValue){
+
+ std::unordered_map<int, int> ltMap;
+ std::unordered_map<int, int> gtMap;
+
+ int ltCount = 0;
+ int gteqCount = 0;
+
+ for(int i = 0; i < samples; ++i){
+ if(X[index + (i * features)] < splitValue){
+ ltMap[y[i]]++;
+ ltCount++;
+ }
+ else{
+ gtMap[y[i]]++;
+ gteqCount++;
+ }
+ }
+
+
+ float ltGini= 1.0f;
+
+ for (const auto& pair : ltMap) {
+ ltGini -= pow(float(pair.second) / ltCount, 2);
+ }
+
+ float gteqGini = 1.0f;
+
+ for (const auto& pair : gtMap) {
+ gteqGini -= pow(float(pair.second) / gteqCount, 2);
+ }
+
+ if(gteqCount == 0){
+ gteqGini = 0.0f;
+ }
+ if(ltCount == 0){
+ ltGini = 0.0f;
+ }
+
+ float gini = gteqGini * float(gteqCount) / samples;
+ gini += ltGini * float(ltCount) / samples;
+
+ return gini;
+}
+
diff --git a/rewrite/cpp/Criterion.h b/rewrite/cpp/Criterion.h
@@ -0,0 +1,5 @@
+
+class Criterion{
+ public:
+ float giniImpurity(float* X, int* y, int samples, int features, int index, float splitValue);
+};
diff --git a/rewrite/cpp/DecisionTreeClassifier.cpp b/rewrite/cpp/DecisionTreeClassifier.cpp
@@ -0,0 +1,177 @@
+#include "DecisionTreeClassifier.h"
+#include <limits>
+#include <iostream>
+#include <unordered_map>
+
+using namespace std;
+
+// Stores maxDepth in `depth`, the value fit() hands to recurse() as the depth budget.
+DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth)
+    : depth(maxDepth) {}
+
+// Trains the tree on X (samples rows of `features` values, read row-major) with one label per row in y.
+// Throws logic_error when a tree already exists and invalid_argument when features or samples is not positive.
+void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){
+    if (nullptr != splittingTree){
+        throw logic_error("Decision trees don't support incremental learning, fit can only be called once.");
+    }
+    if(features < 1){
+        throw invalid_argument("Invalid argument, there must be 1 or more features to train on.");
+    }
+    if(samples < 1){
+        throw invalid_argument("Invalid argument, there must be 1 or more samples to train on.");
+    }
+    // The feature count is recorded only after the tree was built; predict() checks inputs against it.
+    splittingTree = recurse(X, samples, y, features, depth);
+    featureCount = features;
+}
+
+
+// Renders the tree as a Graphviz "digraph decisionTree" whose body is the root's getDotEdges() output.
+// Throws logic_error when no tree has been built yet.
+std::string DecisionTreeClassifier::getDot(){
+    if (nullptr == splittingTree){
+        throw logic_error("Decision tree must be created prior to generating dot output.");
+    }
+    return "digraph decisionTree {\n" + splittingTree->getDotEdges() + "}";
+}
+
+// Returns the most frequent label among the first labelCount entries of y.
+// Ties go to the smallest label: the previous code kept whichever tied label
+// the hash map happened to iterate first, so the winner could differ between
+// standard-library implementations. Returns 0 when labelCount <= 0.
+int DecisionTreeClassifier::primaryClass(int* y, int labelCount){
+    unordered_map<int, int> counts;
+    for(int i = 0; i < labelCount; ++i){
+        ++counts[y[i]];
+    }
+
+    int mostElements = 0;
+    int label = 0;
+    for (const auto& item : counts){
+        if(item.second > mostElements || (item.second == mostElements && item.first < label)){
+            mostElements = item.second;
+            label = item.first;
+        }
+    }
+    return label;
+}
+
+
+
+// Recursively builds the subtree for the given rows and returns its root.
+// X holds `rows` samples of `columns` features (row-major), y one label per
+// row, and depthRem is the number of split levels still allowed. A leaf that
+// predicts the majority label is produced when depthRem is 0, when only one
+// row is left, or when the best split leaves one side empty.
+// Nodes are allocated with new; the temporary partitions created by
+// splitOnNode() are freed before returning.
+TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){
+
+    // depth budget used up, or a single row that cannot be split further
+    if(depthRem == 0 || rows == 1){
+        return new TreeNode(primaryClass(y, rows));
+    }
+
+    // get best split option and partition the rows with it
+    TreeNode* chosen = bestSplit(X, rows, y, columns);
+    SplitResults split = chosen->splitOnNode(X, y, rows, columns);
+
+    TreeNode* result = nullptr;
+    if(split.rightSize == rows || split.leftSize == rows){
+        // No valid split: fall back to a leaf and release the candidate
+        // node, which used to be leaked on this path.
+        result = new TreeNode(primaryClass(y, rows));
+        delete chosen;
+    }
+    else{
+        // traverse lt tree, then gteq tree
+        TreeNode* left = recurse(split.XLeft, split.leftSize, split.yLeft, columns, depthRem - 1);
+        TreeNode* right = recurse(split.XRight, split.rightSize, split.yRight, columns, depthRem - 1);
+        chosen->setLeftChild(left);
+        chosen->setRightChild(right);
+        result = chosen;
+    }
+
+    // splitOnNode() allocates these four buffers with new[], so they must be
+    // released with delete[]; plain delete on them is undefined behaviour.
+    delete[] split.XLeft;
+    delete[] split.XRight;
+    delete[] split.yLeft;
+    delete[] split.yRight;
+
+    return result;
+}
+
+
+
+
+
+// 1 1 0
+// 3 3 0
+// 2 1 1
+// 4 1 3
+
+// Exhaustively searches for the split with the lowest Gini impurity: every
+// value found in X (rows x columns, row-major) is tried as a threshold for
+// its own column. Returns a newly allocated internal node for the first
+// candidate that reaches the minimum, or nullptr when nothing was scored.
+//
+// Candidates are scored on a stack-allocated scratch node and only the winner
+// is heap-allocated, instead of one new/delete pair per candidate.
+//
+// Idea noted by the original author: sort the values first and interpolate
+// between neighbours; the 0th split then presumably need not be considered.
+TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns){
+    bool found = false;
+    float bestValue = 0.0f;
+    int bestColumn = 0;
+    float bestGini = std::numeric_limits<float>::max();
+
+    for(int col = 0; col < columns; ++col){
+        for(int row = 0; row < rows; ++row){
+            const float val = X[row*columns + col];
+            TreeNode candidate(val, col);
+            const float gini = candidate.evalSplit(X, y, rows, columns, "gini");
+            if (gini < bestGini){
+                found = true;
+                bestValue = val;
+                bestColumn = col;
+                bestGini = gini;
+            }
+        }
+    }
+
+    return found ? new TreeNode(bestValue, bestColumn) : nullptr;
+}
+
+// Classifies `samples` rows of X (row-major, `features` values per row) and
+// returns the labels in an array allocated with new[].
+// Throws logic_error when featureCount is still -1 (the value this check
+// treats as "fit() not called") and invalid_argument when `features` differs
+// from the stored feature count.
+int* DecisionTreeClassifier::predict(float* X, int samples, int features){
+    if(-1 == featureCount){
+        throw logic_error("Unable to predict prior to calling fit().");
+    }
+    if(this->featureCount != features){
+        throw invalid_argument("Incorrect number of features for prediction.");
+    }
+
+    int* predictions = new int[samples];
+
+    for(int i = 0; i < samples; ++i){
+        // Row i starts features * i floats into X.
+        float* currentElement = X + features * i;
+        // Descend from the root: left on "less than", right otherwise.
+        TreeNode* current = splittingTree;
+        while(!current->isLeaf()){
+            current = current->lessThan(currentElement, features)
+                ? current->getLeftChild()
+                : current->getRightChild();
+        }
+        predictions[i] = current->getClassification();
+    }
+
+    return predictions;
+}
diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/cpp/DecisionTreeClassifier.h
diff --git a/rewrite/cpp/TreeNode.cpp b/rewrite/cpp/TreeNode.cpp
@@ -0,0 +1,203 @@
+#include "TreeNode.h"
+#include "stdexcept"
+#include "Criterion.h"
+#include "math.h"
+#include "iostream"
+#include <sstream> //for std::stringstream
+#include <string> //for std::string
+
+// Leaf constructor: the node predicts `classification`. The split fields and child pointers
+// get defined values too (index -1, null children) instead of being left indeterminate.
+TreeNode::TreeNode(int classification)
+    : leaf(true), splitValue(0.0f), index(-1), leftChild(nullptr), rightChild(nullptr), classification(classification) {}
+
+// Internal-node constructor: the node tests feature `featureIndex` against `splittingVal`.
+// Children start out null and classification at 0 so that no member is left
+// indeterminate; children are attached later through setLeftChild()/setRightChild().
+TreeNode::TreeNode(float splittingVal, int featureIndex)
+    : leaf(false), splitValue(splittingVal), index(featureIndex), leftChild(nullptr), rightChild(nullptr), classification(0) {}
+// Turns this node into an internal node that tests feature `featureIndex` against `splittingVal`.
+void TreeNode::setSplit(float splittingVal, int featureIndex){
+    this->leaf = false;
+    this->index = featureIndex;
+    this->splitValue = splittingVal;
+}
+// True when this node was built or left as a leaf, false once it carries a split.
+bool TreeNode::isLeaf(){
+    return this->leaf;
+}
+
+// Scores this node's split (feature `index`, threshold `splitValue`) on the given data by
+// forwarding X, y, samples and features to Criterion::giniImpurity.
+// Throws logic_error on a leaf and invalid_argument for any criterion other than "gini".
+float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){
+    if(leaf){
+        throw std::logic_error("Cannot evaluate split on leaf node.");
+    }
+    const bool supported = (criterion == "gini");
+    if(!supported){
+        throw std::invalid_argument("Gini impurity is the only supported criterion.");
+    }
+    Criterion evalCriterion;
+    return evalCriterion.giniImpurity(X, y, samples, features, index, splitValue);
+}
+
+// Stores `child` as the left child pointer; the previous pointer is overwritten, not freed.
+void TreeNode::setLeftChild(TreeNode* child){
+    this->leftChild = child;
+}
+// Stores `child` as the right child pointer; the previous pointer is overwritten, not freed.
+void TreeNode::setRightChild(TreeNode* child){
+    this->rightChild = child;
+}
+// Returns the stored left child pointer (the branch predict() follows when lessThan() is true).
+TreeNode* TreeNode::getLeftChild(){
+    return this->leftChild;
+}
+// Returns the stored right child pointer (the branch predict() follows when lessThan() is false).
+TreeNode* TreeNode::getRightChild(){
+    return this->rightChild;
+}
+// Returns the threshold this node compares a feature value against.
+float TreeNode::getSplitVal(){
+    return this->splitValue;
+}
+// Returns the index of the feature this node tests.
+int TreeNode::getIndexSplit(){
+    return this->index;
+}
+
+// Partitions the data set by this node's test "X[row][index] < splitValue".
+//
+// Parameters:
+//   X        - `samples` rows of `features` floats each, read row-major.
+//   y        - one label per row of X.
+//   samples  - number of rows in X and of entries in y.
+//   features - number of floats per row.
+//
+// Returns a SplitResults in which XLeft/yLeft hold copies of the rows that
+// pass the test and XRight/yRight copies of all remaining rows; rows keep
+// their original relative order. leftSize and rightSize are the row counts.
+//
+// The four arrays are allocated here with new[] (with zero elements when a
+// side is empty) and are only reachable through the returned struct, so the
+// caller has to release them with delete[].
+//
+// The input arrays are only read.
+SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){
+
+    // Pass 1: count the rows on each side so that both partitions can be
+    // allocated with their exact size.
+    int ltCount = 0;
+    int gteqCount = 0;
+
+    for(int row = 0; row < samples; ++row){
+        const bool goesLeft = X[(row*features) + index] < splitValue;
+        if(goesLeft){
+            ++ltCount;
+        }
+        else{
+            ++gteqCount;
+        }
+    }
+
+    SplitResults result = SplitResults();
+
+    // Feature rows of each partition.
+    result.XLeft = new float[ltCount * features];
+    result.XRight = new float[gteqCount * features];
+
+    // Labels of each partition, one int per row.
+    result.yLeft = new int[ltCount];
+    result.yRight = new int[gteqCount];
+
+    // Row counts of the two partitions.
+    result.leftSize = ltCount;
+    result.rightSize = gteqCount;
+
+    // Pass 2: copy every row into its partition. leftRows/rightRows count
+    // the rows written so far and therefore give the next free row slot of
+    // the corresponding output.
+    int leftRows = 0;
+    int rightRows = 0;
+
+    for(int row = 0; row < samples; ++row){
+
+        const float* source = X + (row*features);
+
+        // Same test as in pass 1, so the copies match the counted sizes.
+        if(source[index] < splitValue){
+
+            // Row belongs to the "less than" partition.
+            float* target = result.XLeft + (leftRows*features);
+            for(int col = 0; col < features; ++col){
+                target[col] = source[col];
+            }
+
+            result.yLeft[leftRows] = y[row];
+            ++leftRows;
+        }
+        else{
+
+            // Row belongs to the "greater than or equal" partition.
+            float* target = result.XRight + (rightRows*features);
+            for(int col = 0; col < features; ++col){
+                target[col] = source[col];
+            }
+
+            result.yRight[rightRows] = y[row];
+            ++rightRows;
+        }
+    }
+
+    return result;
+}
+
+
+
+
+
+
+
+// Returns the Graphviz edge statements of the subtree rooted at this node. A leaf contributes
+// nothing. An internal node emits the edge to its left child, the edge to its right child,
+// then the right subtree's edges and finally the left subtree's edges.
+std::string TreeNode::getDotEdges(){
+    if(leaf){
+        return "";
+    }
+    const std::string self = getDotLabel();
+    std::string edges = self + "->" + leftChild->getDotLabel() + ";\n";
+    edges += self + "->" + rightChild->getDotLabel() + ";\n";
+    edges += rightChild->getDotEdges();
+    edges += leftChild->getDotEdges();
+    return edges;
+}
+
+// Builds the double-quoted DOT identifier of this node: its address followed by
+// either the classification (leaf) or the split index and value (internal node).
+std::string TreeNode::getDotLabel(){
+    std::stringstream addressText;
+    addressText << static_cast<const void*>(this);
+    // The address keeps the identifiers of distinct nodes distinct.
+    const std::string details = leaf
+        ? "\nCLASSIFICATION: " + std::to_string(classification)
+        : "\nINDEX: " + std::to_string(index) + "\nVALUE:" + std::to_string(splitValue);
+    return "\"" + addressText.str() + details + "\"";
+}
+// Returns the label stored in a leaf; throws logic_error when called on an internal node.
+int TreeNode::getClassification(){
+    if(!leaf){
+        throw std::logic_error("Unable to call getClassification() on internal vertices.");
+    }
+    return classification;
+}
+
+// True when sample[index] is below this node's threshold; `features` is the length of `sample`.
+// Throws invalid_argument unless 0 <= index < features (the old `features < index` test let index == features read past the end).
+bool TreeNode::lessThan(float* sample, int features){
+    if(index < 0 || index >= features){
+        throw std::invalid_argument("Attempting to evaluate split with input that contains less features.");
+    }
+    return(sample[index] < splitValue);
+}
diff --git a/rewrite/cpp/TreeNode.h b/rewrite/cpp/TreeNode.h
@@ -0,0 +1,40 @@
+#include "string"
+
+struct SplitResults{ // the two partitions produced by TreeNode::splitOnNode; members start out null/0
+    float* XLeft = nullptr;  // feature rows that satisfied the node's "less than" test
+    float* XRight = nullptr; // all remaining feature rows
+    int* yLeft = nullptr;    // labels of the rows in XLeft
+    int* yRight = nullptr;   // labels of the rows in XRight
+    int leftSize = 0;        // number of rows in XLeft / entries in yLeft
+    int rightSize = 0;       // number of rows in XRight / entries in yRight
+};
+
+class TreeNode{ // one vertex of the decision tree: a leaf holding a classification or an internal split node
+    private:
+    bool leaf;                 // true for leaves, false for split nodes
+    float splitValue;          // threshold the tested feature is compared against
+    int index;                 // column of the feature this node tests
+    TreeNode* leftChild;       // branch taken when lessThan() is true
+    TreeNode* rightChild;      // branch taken otherwise
+    int classification;        // label returned by getClassification() on a leaf
+    std::string getDotLabel(); // quoted DOT identifier of this node
+
+    public:
+    TreeNode(int classification);                   // builds a leaf
+    TreeNode(float splittingVal, int featureIndex); // builds an internal split node
+    void setSplit(float splittingValue, int featureIndex);
+    bool isLeaf();
+    int getClassification();   // throws on internal nodes
+    float getSplitVal();
+    int getIndexSplit();
+    TreeNode* getLeftChild();
+    TreeNode* getRightChild();
+    void setLeftChild(TreeNode* child);
+    void setRightChild(TreeNode* child);
+    bool lessThan(float* sample, int features);     // sample[index] < splitValue
+    float evalSplit(float* X, int* y, int samples, int features, std::string criterion);
+    SplitResults splitOnNode(float* X, int* y, int samples, int features);
+    std::string getDotEdges(); // DOT edge statements of the subtree below this node
+};
+
+
diff --git a/rewrite/cpp/bindings.cpp b/rewrite/cpp/bindings.cpp
@@ -0,0 +1,30 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include "DecisionTreeClassifier.h"
+
+namespace py = pybind11;
+// Python module `decision_tree`. Arrays are requested as C-contiguous buffers of the exact element
+// type (NumPy converts/copies otherwise) because the C++ side indexes raw row-major memory, and the
+// buffer sizes are checked against samples/features so the C++ code cannot read past their end.
+PYBIND11_MODULE(decision_tree, m) {
+    using FloatArray = py::array_t<float, py::array::c_style | py::array::forcecast>;
+    using IntArray = py::array_t<int, py::array::c_style | py::array::forcecast>;
+    py::class_<DecisionTreeClassifier>(m, "DecisionTreeClassifier")
+        .def(py::init<int>())
+        .def("fit", [](DecisionTreeClassifier &self, FloatArray X, int samples, IntArray y, int features) {
+            if (samples > 0 && features > 0 && (X.size() != py::ssize_t(samples) * features || y.size() != samples))
+                throw py::value_error("X must hold samples * features values and y must hold samples labels.");
+            self.fit(static_cast<float*>(X.request().ptr), samples, static_cast<int*>(y.request().ptr), features);
+        })
+        .def("predict", [](DecisionTreeClassifier &self, FloatArray X, int samples, int features) {
+            if (samples < 0 || X.size() != py::ssize_t(samples) * features)
+                throw py::value_error("X must hold samples * features values.");
+            int* result = self.predict(static_cast<float*>(X_ptr_cast(X)), samples, features);
+            py::array_t<int> labels(samples, result); // no base object given, so the data is copied
+            delete[] result;                          // predict()'s new[] buffer used to be leaked here
+            return labels;
+        })
+        .def("getDot", &DecisionTreeClassifier::getDot)
+        .def("__repr__", [](const DecisionTreeClassifier &dt) { return "<DecisionTreeClassifier>"; });
+}
+