decision-tree-classifier

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit fbb285ba01aa4db1ae65a4af09e49f307790396c
parent abd8f9bdb27dff0e4ed253f17d2dc0eb1bb5e2cb
Author: Andrew <andrewlaack1@gmail.com>
Date:   Mon, 23 Dec 2024 10:38:46 -0600

Set up bind

Diffstat:
M.gitignore | 5+++++
MNotes.md | 15+++++++++++++++
Dclassifier/Testing.py | 71-----------------------------------------------------------------------
Rclassifier/LeafNode.py -> old/LeafNode.py | 0
Rclassifier/Makefile -> old/Makefile | 0
Rclassifier/Podtc.py -> old/Podtc.py | 0
Rclassifier/SplittingNode.py -> old/SplittingNode.py | 0
Aold/Testing.py | 67+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Rclassifier/__init__.py -> old/__init__.py | 0
Rclassifier/cpp/gini.cpp -> old/cpp/gini.cpp | 0
Rclassifier/cpp/split.cpp -> old/cpp/split.cpp | 0
Arewrite/CMakeLists.txt | 42++++++++++++++++++++++++++++++++++++++++++
Drewrite/DecisionTreeClassifier.cpp | 175-------------------------------------------------------------------------------
Mrewrite/Makefile | 281+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
Arewrite/Test.py | 38++++++++++++++++++++++++++++++++++++++
Drewrite/Tests.cpp | 94-------------------------------------------------------------------------------
Drewrite/TreeNode.cpp | 244-------------------------------------------------------------------------------
Drewrite/TreeNode.h | 41-----------------------------------------
Arewrite/cpp/Criterion.cpp | 49+++++++++++++++++++++++++++++++++++++++++++++++++
Arewrite/cpp/Criterion.h | 5+++++
Arewrite/cpp/DecisionTreeClassifier.cpp | 177+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Rrewrite/DecisionTreeClassifier.h -> rewrite/cpp/DecisionTreeClassifier.h | 0
Arewrite/cpp/TreeNode.cpp | 203+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Arewrite/cpp/TreeNode.h | 40++++++++++++++++++++++++++++++++++++++++
Arewrite/cpp/bindings.cpp | 30++++++++++++++++++++++++++++++
25 files changed, 941 insertions(+), 636 deletions(-)

diff --git a/.gitignore b/.gitignore @@ -6,7 +6,11 @@ __pycache__/ # C extensions *.so +*.dot +cmake_install.cmake +CMakeCache.txt +datasets/ # Distribution / packaging .Python build/ @@ -22,6 +26,7 @@ sdist/ var/ wheels/ share/python-wheels/ +CMakeFiles/ *.egg-info/ .installed.cfg *.egg diff --git a/Notes.md b/Notes.md @@ -28,3 +28,18 @@ C++ - 15.271s (user) WITH 30,000 ELEMENTS, 4 FEATURES: PYTHON - 37.5128 (user) C++ 34.335 (user) + + +WITH WORKLOAD OF 6000 SAMPLES AND 784 FEATURES: + +MAX MEMORY C++ - 24628kB +MAX MEMORY Python - 521764kB + +ADD BENCHMARKING WITHOUT LIKELY AND THEN WITH. + +20,000 Samples, 4 Features: + WITHOUT LIKELY: 15.526 seconds (user), 15.632 seconds (user) + WITH LIKELY: 15.524 seconds (user), 15.608s (user) + + +BUILD OPTIMAL DECISION TREE diff --git a/classifier/Testing.py b/classifier/Testing.py @@ -1,71 +0,0 @@ -from Podtc import PseudoOptimalDecisionTreeClassifier -import numpy as np -##import plotly.express as px -#from sklearn.tree import DecisionTreeClassifier -#from sklearn.datasets import load_digits -#import pandas as pd -#from keras.datasets import mnist -#from sklearn.metrics import accuracy_score -#from setuptools import setup - -# (train_X, train_y), (test_X, test_y) = mnist.load_data() - -# train_X = train_X.reshape(-1, 784) -# test_X = test_X.reshape(-1, 784) - -X_train = np.random.randint(1, 10, size=(30000, 4)) - -# Generate 60,000 labels for y_train (for example, labels between 1 and 10) -y_train = np.random.randint(1, 11, size=30000) - -# train_X = [[2,5], [5,2], [3,4], [4,4], [5,5], [10, 10], [2,2], [12,12]] -# train_y = [1, 1 , 2, 1, 5, 2,1 ,3] - -classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=1); - -classifier.fit(X_train, y_train) - -classifier.graph() - -exit() - -y_pred = classifier.predict(test_X) -print("MY ACCURACY:") -print(accuracy_score(y_true=test_y, y_pred=y_pred)) - -classifier = DecisionTreeClassifier(max_depth=4) -classifier.fit(train_X, train_y) -y_pred = classifier.predict(test_X) - -print("SECOND ACCURACY:") -print(accuracy_score(y_true=test_y, y_pred=y_pred)) - - -exit() - -X = np.random.random((200, 2)) -y = np.random.random((200)) * 10 -y = y.round() -# y = (X[:,0] + X[:,1]) > .5 - - -#clf = DecisionTreeClassifier() -#clf.fit(X,y) - -classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=2); - -classifier.fit(X,y) - - -X_pred = np.random.random((20, 2)).round() - -print(X_pred) -print(classifier.predict(X_pred)) - -# classifier.predict() -# print(classifier) - - - -#scatter = px.scatter(x=X[:,0], y=X[:,1], color=y) -#scatter.show() diff --git a/classifier/LeafNode.py b/old/LeafNode.py diff --git a/classifier/Makefile b/old/Makefile diff --git a/classifier/Podtc.py b/old/Podtc.py diff --git a/classifier/SplittingNode.py b/old/SplittingNode.py diff --git a/old/Testing.py b/old/Testing.py @@ -0,0 +1,67 @@ +from Podtc import PseudoOptimalDecisionTreeClassifier +import numpy as np +##import plotly.express as px +#from sklearn.tree import DecisionTreeClassifier +#from sklearn.datasets import load_digits +#import pandas as pd +#from keras.datasets import mnist +#from sklearn.metrics import accuracy_score +#from setuptools import setup + +# (train_X, train_y), (test_X, test_y) = mnist.load_data() + +# train_X = train_X.reshape(-1, 784) +# test_X = test_X.reshape(-1, 784) + +X_train = np.random.randint(1, 10, size=(6000, 784)) + +# Generate 60,000 labels for y_train (for example, labels between 1 and 10) +y_train = np.random.randint(1, 11, size=6000) + +# train_X = [[2,5], [5,2], [3,4], [4,4], [5,5], [10, 10], [2,2], [12,12]] +# train_y = [1, 1 , 2, 1, 5, 2,1 ,3] + +classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=1); +classifier.fit(X_train, y_train) +exit() + +y_pred = classifier.predict(test_X) +print("MY ACCURACY:") +print(accuracy_score(y_true=test_y, y_pred=y_pred)) + +classifier = DecisionTreeClassifier(max_depth=4) +classifier.fit(train_X, train_y) +y_pred = classifier.predict(test_X) + +print("SECOND ACCURACY:") +print(accuracy_score(y_true=test_y, y_pred=y_pred)) + + +exit() + +X = np.random.random((200, 2)) +y = np.random.random((200)) * 10 +y = y.round() +# y = (X[:,0] + X[:,1]) > .5 + + +#clf = DecisionTreeClassifier() +#clf.fit(X,y) + +classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=2); + +classifier.fit(X,y) + + +X_pred = np.random.random((20, 2)).round() + +print(X_pred) +print(classifier.predict(X_pred)) + +# classifier.predict() +# print(classifier) + + + +#scatter = px.scatter(x=X[:,0], y=X[:,1], color=y) +#scatter.show() diff --git a/classifier/__init__.py b/old/__init__.py diff --git a/classifier/cpp/gini.cpp b/old/cpp/gini.cpp diff --git a/classifier/cpp/split.cpp b/old/cpp/split.cpp diff --git a/rewrite/CMakeLists.txt b/rewrite/CMakeLists.txt @@ -0,0 +1,42 @@ + +cmake_minimum_required(VERSION 3.10) + +# Set project name +project(DecisionTreeClassifier) + +# Set C++ standard +set(CMAKE_CXX_STANDARD 17) + +# Find Python and Pybind11 +find_package(Python3 REQUIRED COMPONENTS Interpreter Development) +find_package(pybind11 REQUIRED) + +# Set compiler flags +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -Wall -fPIC") + +# Add source files +set(SOURCES + cpp/DecisionTreeClassifier.cpp + cpp/TreeNode.cpp + cpp/Criterion.cpp + cpp/bindings.cpp +) + +# Create the shared library +add_library(decision_tree MODULE ${SOURCES}) + +# Link with Python and Pybind11 +target_include_directories(decision_tree PRIVATE ${Python3_INCLUDE_DIRS}) +target_link_libraries(decision_tree PRIVATE ${Python3_LIBRARIES} pybind11::module) + +# Set output name based on Python extension suffix +set_target_properties(decision_tree PROPERTIES + PREFIX "" + SUFFIX ".so" +) + +# Rename the custom clean target to avoid conflict +add_custom_target(clean_build + COMMAND rm -f ${CMAKE_BINARY_DIR}/*.o + COMMENT "Clean build files" +) diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp @@ -1,175 +0,0 @@ -#include "DecisionTreeClassifier.h" -#include <limits> -#include <iostream> -#include <unordered_map> - -using namespace std; - -DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){ - depth = maxDepth; -} - -void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){ - if (splittingTree != nullptr){ - throw logic_error("Decision trees don't support incremental learning, fit can only be called once."); - } - - if(features <= 0){ - throw invalid_argument("Invalid argument, there must be 1 or more features to train on."); - } - - if(samples <= 0){ - throw invalid_argument("Invalid argument, there must be 1 or more samples to train on."); - } - - splittingTree = recurse(X, samples, y, features, depth); - featureCount = features; -} - - -std::string DecisionTreeClassifier::getDot(){ - if (splittingTree == nullptr){ - throw logic_error("Decision tree must be created prior to generating dot output."); - } - std::string edges = splittingTree->getDotEdges(); - std::string dot = "digraph decisionTree {\n" + edges + "}"; - return dot; -} - -int DecisionTreeClassifier::primaryClass(int* y, int labelCount){ - - unordered_map map = unordered_map<int,int>(); - - for(int i = 0; i < labelCount; ++i){ - map[y[i]] += 1; - } - - int mostElements = 0; - int label = 0; - - for (auto& item : map){ - if(item.second > mostElements){ - mostElements = item.second; - label = item.first; - } - } - - return label; -} - - - -// add depth -TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){ - - if(depthRem == 0){ - TreeNode* ret = new TreeNode(primaryClass(y, rows)); - return ret; - } - - // found minimum node - if(rows == 1){ - TreeNode* ret = new TreeNode(primaryClass(y, rows)); - return ret; - } - - // get best split option - TreeNode* chosen = bestSplit(X, rows, y, columns); - SplitResults split = chosen->splitOnNode(X, y, rows, columns); - - // no valid splits, but we still did create some new arrays. - if(split.rightSize == rows || split.leftSize == rows){ - TreeNode* ret = new TreeNode(primaryClass(y, rows)); - delete split.XLeft; - delete split.XRight; - delete split.yLeft; - delete split.yRight; - return ret; - } - - // traverse lt tree - TreeNode* left = recurse(split.XLeft, split.leftSize, split.yLeft, columns, depthRem - 1); - // traverse gt tree - TreeNode* right = recurse(split.XRight, split.rightSize, split.yRight, columns, depthRem - 1); - - chosen->setLeftChild(left); - chosen->setRightChild(right); - - delete split.XLeft; - delete split.XRight; - delete split.yLeft; - delete split.yRight; - - return chosen; -} - - - -// 1 1 0 -// 3 3 0 -// 2 1 1 -// 4 1 3 - -// consider adding interpolation to this and sorting the list first. -// Also, no reason to consider the 0th split if that is the case. - -TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns){ - - TreeNode* bestNode = nullptr; - float bestGini = std::numeric_limits<float>::max(); - - for(int col = 0 ; col < columns; ++col){ - for(int row = 0; row < rows; ++row){ - - float val = X[row*columns + col]; - TreeNode* current = new TreeNode(val, col); - float gini = current->evalSplit(X, y, rows, columns, "gini"); - if (gini < bestGini){ - - TreeNode* prevBest = bestNode; - delete prevBest; - - bestNode = current; - bestGini = gini; - } - else{ - delete current; - } - } - } - - return bestNode; - -} - -int* DecisionTreeClassifier::predict(float* X, int samples, int features){ - - if(featureCount == -1){ - throw logic_error("Unable to predict prior to calling fit()."); - } - - if(features != this->featureCount){ - throw invalid_argument("Incorrect number of features for prediction."); - } - cout << "PREDICTING" << endl; - - int* predictions = new int[samples]; - - for(int i = 0; i < samples; ++i){ - TreeNode* current = splittingTree; - while(!current->isLeaf()){ - float* currentElement = X; - currentElement += features * i; - bool lessThan = current->lessThan(currentElement, features); - if(lessThan){ - current = current->getLeftChild(); - } - else{ - current = current->getRightChild(); - } - } - predictions[i] = current->getClassification(); - } - - return predictions; -} diff --git a/rewrite/Makefile b/rewrite/Makefile @@ -1,17 +1,276 @@ +# CMAKE generated file: DO NOT EDIT! +# Generated by "Unix Makefiles" Generator, CMake Version 3.31 + +# Default target executed when no arguments are given to make. +default_target: all +.PHONY : default_target + +# Allow only one "make -f Makefile2" at a time, but pass parallelism. +.NOTPARALLEL: + +#============================================================================= +# Special targets provided by cmake. + +# Disable implicit rules so canonical targets will work. +.SUFFIXES: + +# Disable VCS-based implicit rules. +% : %,v + +# Disable VCS-based implicit rules. +% : RCS/% + +# Disable VCS-based implicit rules. +% : RCS/%,v + +# Disable VCS-based implicit rules. +% : SCCS/s.% + +# Disable VCS-based implicit rules. +% : s.% + +.SUFFIXES: .hpux_make_needs_suffix_list + +# Command-line flag to silence nested $(MAKE). +$(VERBOSE)MAKESILENT = -s + +#Suppress display of executed commands. +$(VERBOSE).SILENT: + +# A target that is always out of date. +cmake_force: +.PHONY : cmake_force + +#============================================================================= +# Set environment variables for the build. + +# The shell in which to execute make rules. +SHELL = /bin/sh + +# The CMake executable. +CMAKE_COMMAND = /usr/bin/cmake + +# The command to remove a file. +RM = /usr/bin/cmake -E rm -f + +# Escaping for special characters. +EQUALS = = + +# The top-level source directory on which CMake was run. +CMAKE_SOURCE_DIR = /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite + +# The top-level build directory on which CMake was run. +CMAKE_BINARY_DIR = /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite + +#============================================================================= +# Targets provided globally by CMake. + +# Special rule for the target edit_cache +edit_cache: + @$(CMAKE_COMMAND) -E cmake_echo_color "--switch=$(COLOR)" --cyan "No interactive CMake dialog available..." + /usr/bin/cmake -E echo No\ interactive\ CMake\ dialog\ available. +.PHONY : edit_cache + +# Special rule for the target edit_cache +edit_cache/fast: edit_cache +.PHONY : edit_cache/fast + +# Special rule for the target rebuild_cache +rebuild_cache: + @$(CMAKE_COMMAND) -E cmake_echo_color "--switch=$(COLOR)" --cyan "Running CMake to regenerate build system..." + /usr/bin/cmake --regenerate-during-build -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) +.PHONY : rebuild_cache + +# Special rule for the target rebuild_cache +rebuild_cache/fast: rebuild_cache +.PHONY : rebuild_cache/fast + +# The main all target +all: cmake_check_build_system + $(CMAKE_COMMAND) -E cmake_progress_start /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite/CMakeFiles /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite//CMakeFiles/progress.marks + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 all + $(CMAKE_COMMAND) -E cmake_progress_start /home/andrew/gitRepos/pseudo-optimal-decision-tree/rewrite/CMakeFiles 0 +.PHONY : all + +# The main clean target clean: - rm *.o - rm *.out + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 clean +.PHONY : clean + +# The main clean target +clean/fast: clean +.PHONY : clean/fast + +# Prepare targets for installation. +preinstall: all + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall +.PHONY : preinstall + +# Prepare targets for installation. +preinstall/fast: + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall +.PHONY : preinstall/fast + +# clear depends +depend: + $(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 1 +.PHONY : depend + +#============================================================================= +# Target rules for targets named decision_tree + +# Build rule for target. +decision_tree: cmake_check_build_system + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 decision_tree +.PHONY : decision_tree + +# fast build rule for target. +decision_tree/fast: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/build +.PHONY : decision_tree/fast + +#============================================================================= +# Target rules for targets named clean_build + +# Build rule for target. +clean_build: cmake_check_build_system + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 clean_build +.PHONY : clean_build + +# fast build rule for target. +clean_build/fast: + $(MAKE) $(MAKESILENT) -f CMakeFiles/clean_build.dir/build.make CMakeFiles/clean_build.dir/build +.PHONY : clean_build/fast + +cpp/Criterion.o: cpp/Criterion.cpp.o +.PHONY : cpp/Criterion.o + +# target to build an object file +cpp/Criterion.cpp.o: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/Criterion.cpp.o +.PHONY : cpp/Criterion.cpp.o + +cpp/Criterion.i: cpp/Criterion.cpp.i +.PHONY : cpp/Criterion.i + +# target to preprocess a source file +cpp/Criterion.cpp.i: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/Criterion.cpp.i +.PHONY : cpp/Criterion.cpp.i + +cpp/Criterion.s: cpp/Criterion.cpp.s +.PHONY : cpp/Criterion.s + +# target to generate assembly for a file +cpp/Criterion.cpp.s: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/Criterion.cpp.s +.PHONY : cpp/Criterion.cpp.s + +cpp/DecisionTreeClassifier.o: cpp/DecisionTreeClassifier.cpp.o +.PHONY : cpp/DecisionTreeClassifier.o + +# target to build an object file +cpp/DecisionTreeClassifier.cpp.o: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/DecisionTreeClassifier.cpp.o +.PHONY : cpp/DecisionTreeClassifier.cpp.o + +cpp/DecisionTreeClassifier.i: cpp/DecisionTreeClassifier.cpp.i +.PHONY : cpp/DecisionTreeClassifier.i + +# target to preprocess a source file +cpp/DecisionTreeClassifier.cpp.i: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/DecisionTreeClassifier.cpp.i +.PHONY : cpp/DecisionTreeClassifier.cpp.i + +cpp/DecisionTreeClassifier.s: cpp/DecisionTreeClassifier.cpp.s +.PHONY : cpp/DecisionTreeClassifier.s + +# target to generate assembly for a file +cpp/DecisionTreeClassifier.cpp.s: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/DecisionTreeClassifier.cpp.s +.PHONY : cpp/DecisionTreeClassifier.cpp.s + +cpp/TreeNode.o: cpp/TreeNode.cpp.o +.PHONY : cpp/TreeNode.o + +# target to build an object file +cpp/TreeNode.cpp.o: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/TreeNode.cpp.o +.PHONY : cpp/TreeNode.cpp.o + +cpp/TreeNode.i: cpp/TreeNode.cpp.i +.PHONY : cpp/TreeNode.i + +# target to preprocess a source file +cpp/TreeNode.cpp.i: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/TreeNode.cpp.i +.PHONY : cpp/TreeNode.cpp.i + +cpp/TreeNode.s: cpp/TreeNode.cpp.s +.PHONY : cpp/TreeNode.s + +# target to generate assembly for a file +cpp/TreeNode.cpp.s: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/TreeNode.cpp.s +.PHONY : cpp/TreeNode.cpp.s + +cpp/bindings.o: cpp/bindings.cpp.o +.PHONY : cpp/bindings.o + +# target to build an object file +cpp/bindings.cpp.o: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/bindings.cpp.o +.PHONY : cpp/bindings.cpp.o + +cpp/bindings.i: cpp/bindings.cpp.i +.PHONY : cpp/bindings.i + +# target to preprocess a source file +cpp/bindings.cpp.i: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/bindings.cpp.i +.PHONY : cpp/bindings.cpp.i + +cpp/bindings.s: cpp/bindings.cpp.s +.PHONY : cpp/bindings.s + +# target to generate assembly for a file +cpp/bindings.cpp.s: + $(MAKE) $(MAKESILENT) -f CMakeFiles/decision_tree.dir/build.make CMakeFiles/decision_tree.dir/cpp/bindings.cpp.s +.PHONY : cpp/bindings.cpp.s + +# Help Target +help: + @echo "The following are some of the valid targets for this Makefile:" + @echo "... all (the default if no target is provided)" + @echo "... clean" + @echo "... depend" + @echo "... edit_cache" + @echo "... rebuild_cache" + @echo "... clean_build" + @echo "... decision_tree" + @echo "... cpp/Criterion.o" + @echo "... cpp/Criterion.i" + @echo "... cpp/Criterion.s" + @echo "... cpp/DecisionTreeClassifier.o" + @echo "... cpp/DecisionTreeClassifier.i" + @echo "... cpp/DecisionTreeClassifier.s" + @echo "... cpp/TreeNode.o" + @echo "... cpp/TreeNode.i" + @echo "... cpp/TreeNode.s" + @echo "... cpp/bindings.o" + @echo "... cpp/bindings.i" + @echo "... cpp/bindings.s" +.PHONY : help -CXXFLAGS = -O3 -Wall -std=c++17 # Add -O3 for optimization and other flags -LDFLAGS = # Add any linker flags here -# Targets -node: - g++ $(CXXFLAGS) -c TreeNode.cpp -o TreeNode.o +#============================================================================= +# Special targets to cleanup operation of make. -decisionTree: - g++ $(CXXFLAGS) -c DecisionTreeClassifier.cpp -o DecisionTreeClassifier.o +# Special rule to run CMake to check the build system integrity. +# No rule that depends on this can have commands that come from listfiles +# because they might be regenerated. +cmake_check_build_system: + $(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 0 +.PHONY : cmake_check_build_system -tests: DecisionTreeClassifier.o TreeNode.o - g++ $(CXXFLAGS) Tests.cpp DecisionTreeClassifier.o TreeNode.o -o a.out diff --git a/rewrite/Test.py b/rewrite/Test.py @@ -0,0 +1,38 @@ +from sklearn.datasets import load_digits +import numpy as np +from decision_tree import DecisionTreeClassifier +from keras.datasets import mnist +from sklearn.metrics import accuracy_score + +(train_X, train_y), (test_X, test_y) = mnist.load_data() + +train_X = train_X.reshape(-1, 784) +test_X = test_X.reshape(-1, 784) + + +train_X = train_X[0:2000] +train_y = train_y[0:2000] + + +clf = DecisionTreeClassifier(20) + +print("TRAINING") +clf.fit(train_X, len(train_X), train_y, len(train_X[0])) + +predictions = clf.predict(test_X, test_X.shape[0], test_X.shape[1]) + +# Output results +print("Predictions:", predictions) +print(accuracy_score(y_pred=predictions, y_true=test_y)) + + + + + + +dot_representation = clf.getDot() +# print("\nDecision Tree (DOT Format):\n", dot_representation) + +# Optionally save DOT representation to a file +with open("decision_tree.dot", "w") as file: + file.write(dot_representation) diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp @@ -1,94 +0,0 @@ -#include "DecisionTreeClassifier.h" -#include <fstream> // For file I/O -#include "iostream" -#include "assert.h" - -using namespace std; - -void testTreeNode(){ - int labels[] = {10, 10, 10, 1, 2, 3}; - float samples[][4] = { - {1,1,5,3}, - {1,2,5,3}, - {1,7,5,3}, - {1,3,5,3}, - {1,7,5,3}, - {1,1,5,3} - }; - - - TreeNode tn = TreeNode(5.0f ,1); - bool isLeaf = tn.isLeaf(); - - assert(!isLeaf); - - float giniVal = tn.evalSplit(*samples, labels, 6, 4, "gini"); - assert(abs(giniVal - .5833333) < .0001 ); - - tn.setSplit(0.0f, 0); - float giniVal2 = tn.evalSplit(*samples, labels, 6, 4, "gini"); - assert(abs(giniVal2 - .6666666) < .0001); - -} - - -int main(){ - - cout << "STARTING" << endl; - - int SAMPLES = 20000; - - DecisionTreeClassifier clf(1); - int labels[SAMPLES]; - for (int i = 0; i < SAMPLES; ++i){ - labels[i] = i % 10; // Example: create labels that cycle through 0 to 9 - } - - float samples[SAMPLES][4]; - for (int i = 0; i < SAMPLES; ++i){ - // Fill the samples with some arbitrary data (example: sequential) - samples[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10) - samples[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5) - samples[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7) - samples[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3) - } - - clf.fit(*samples, SAMPLES, labels, 4); - int PREDS = 10; - - float preds[PREDS][4]; - for (int i = 0; i < PREDS; ++i){ - // Fill the samples with some arbitrary data (example: sequential) - preds[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10) - preds[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5) - preds[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7) - preds[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3) - } - - int* predsOut = clf.predict(*preds, PREDS, 4); - - cout << "DONE" << endl; - - for(int i = 0 ; i < PREDS; ++i){ - cout << preds[i][0] << " " << preds[i][1] << " " <<preds[i][2] << " " << preds[i][3] << " " << predsOut[i] << endl; - } - - delete[] predsOut; - - return 0; - - - ofstream outFile("decision_tree.dot"); - - // Check if the file is open - if (outFile.is_open()) { - // Write the decision tree to the file - outFile << clf.getDot() << endl; - outFile.close(); // Close the file after writing - cout << "Decision tree written to 'decision_tree.dot'" << endl; - } else { - cerr << "Failed to open the file." << endl; - } - - return 0; -} diff --git a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp @@ -1,244 +0,0 @@ -#include "TreeNode.h" -#include "stdexcept" -#include "unordered_map" -#include "math.h" -#include "iostream" -#include <sstream> //for std::stringstream -#include <string> //for std::string - -TreeNode::TreeNode(int classification){ - leaf = true; - this->classification = classification; -} - -TreeNode::TreeNode(float splittingVal, int featureIndex){ - splitValue = splittingVal; - index = featureIndex; - leaf = false; -} - -void TreeNode::setSplit(float splittingVal, int featureIndex){ - splitValue = splittingVal; - index = featureIndex; - leaf = false; -} - -bool TreeNode::isLeaf(){ - return leaf; -} - -float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){ - - if(isLeaf()){ - throw std::logic_error("Cannot evaluate split on leaf node."); - } - - if(criterion != "gini"){ - throw std::invalid_argument("Gini impurity is the only supported criterion."); - } - - return giniImpurity(X, y, samples, features); -} - - -void TreeNode::setLeftChild(TreeNode* child){ - leftChild = child; -} - -void TreeNode::setRightChild(TreeNode* child){ - rightChild = child; -} - -TreeNode* TreeNode::getLeftChild(){ - return leftChild; -} - -TreeNode* TreeNode::getRightChild(){ - return rightChild; -} - -float TreeNode::getSplitVal(){ - return splitValue; -} - -int TreeNode::getIndexSplit(){ - return index; -} - -SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){ - - SplitResults result = SplitResults(); - - int ltCount = 0; - int gteqCount = 0; - - for(int i = 0 ; i < samples; ++i){ - if(X[(i*features) + index] < splitValue){ - ltCount += 1; - } - else{ - gteqCount += 1; - } - } - - // Create X arrays to return - - float* ltArr = new float[ltCount * features]; - float* gteqArr = new float[gteqCount * features]; - - // Create array ptr next open - - float* nextLtX = ltArr; - float* nextGteqX = gteqArr; - - // Create y arrays to return - - int* ltYArr = new int[ltCount]; - int* gteqYArr = new int[gteqCount]; - - // Create array ptr next open - - int* nextLtY = ltYArr; - int* nextGteqY = gteqYArr; - - // Set pointers for return to the new arrays - - result.XLeft = ltArr; - result.yLeft = ltYArr; - - result.XRight = gteqArr; - result.yRight = gteqYArr; - - result.leftSize = ltCount; - result.rightSize = gteqCount; - - // Set arrays with correct values - - for(int i = 0 ; i < samples; ++i){ - if(X[(i*features) + index] < splitValue){ - for(int x = 0; x < features; ++x){ - nextLtX[x] = X[(i*features) + x]; - } - - nextLtX += features; - - nextLtY[0] = y[i]; - nextLtY += 1; - } - else{ - for(int x = 0; x < features; ++x){ - nextGteqX[x] = X[(i*features) + x]; - } - - nextGteqX += features; - - nextGteqY[0] = y[i]; - nextGteqY += 1; - } - } - - //for(int x = 0 ; x < ltCount; ++x){ - // for(int i = 0 ; i < features; ++i){ - // std::cout << ltArr[x*features + i]; - // } - // std::cout << std::endl; - //} - - //for(int x = 0 ; x < ltCount; ++x){ - // std::cout << ltYArr[x] << std::endl; - //} - - return result; -} - - - -float TreeNode::giniImpurity(float* X, int* y, int samples, int features){ - - std::unordered_map<int, int> ltMap; - std::unordered_map<int, int> gtMap; - - int ltCount = 0; - int gteqCount = 0; - - for(int i = 0; i < samples; ++i){ - if(X[index + (i * features)] < splitValue){ - ltMap[y[i]]++; - ltCount++; - } - else{ - gtMap[y[i]]++; - gteqCount++; - } - } - - - float ltGini= 1.0f; - - for (const auto& pair : ltMap) { - ltGini -= pow(float(pair.second) / ltCount, 2); - } - - float gteqGini = 1.0f; - - for (const auto& pair : gtMap) { - gteqGini -= pow(float(pair.second) / gteqCount, 2); - } - - if(gteqCount == 0){ - gteqGini = 0.0f; - } - if(ltCount == 0){ - ltGini = 0.0f; - } - - float gini = gteqGini * float(gteqCount) / samples; - gini += ltGini * float(ltCount) / samples; - - return gini; -} - - - -std::string TreeNode::getDotEdges(){ - - if(isLeaf()){ - return ""; - } - - std::string current = getDotLabel() + "->" + leftChild->getDotLabel() + ";\n"; - current += getDotLabel() + "->" + rightChild->getDotLabel() + ";\n"; - - current += rightChild->getDotEdges(); - current += leftChild->getDotEdges(); - - return current; -} - -std::string TreeNode::getDotLabel(){ - const void * address = static_cast<const void*>(this); - std::stringstream ss; - ss << address; - std::string name = ss.str(); - if (isLeaf()){ - return "\"" + name + "\nCLASSIFICATION: " + std::to_string(classification) + "\""; - } - - return "\"" + name + "\nINDEX: " + std::to_string(index) + "\nVALUE:" + std::to_string(splitValue) + "\""; -} - -int TreeNode::getClassification(){ - if(isLeaf()){ - return classification; - } - throw std::logic_error("Unable to call getClassification() on internal vertices."); -} - -bool TreeNode::lessThan(float* sample, int features){ - - if(features < this->index){ - throw std::invalid_argument("Attempting to evaluate split with input that contains less features."); - } - - return(sample[index] < splitValue); -} diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h @@ -1,41 +0,0 @@ -#include "string" - -struct SplitResults{ - float* XLeft; - float* XRight; - int* yLeft; - int* yRight; - int leftSize; - int rightSize; -}; - -class TreeNode{ - public: - TreeNode(int classification); - TreeNode(float splittingVal, int featureIndex); - bool isLeaf(); - void setSplit(float splittingValue, int featureIndex); - float evalSplit(float* X, int* y, int samples, int features, std::string criterion); - TreeNode* getLeftChild(); - TreeNode* getRightChild(); - void setLeftChild(TreeNode* child); - void setRightChild(TreeNode* child); - float getSplitVal(); - int getIndexSplit(); - SplitResults splitOnNode(float* X, int* y, int samples, int features); - std::string getDotEdges(); - int getClassification(); - bool lessThan(float* sample, int features); - - private: - bool leaf; - float splitValue; - int index; - TreeNode* leftChild; - TreeNode* rightChild; - float giniImpurity(float* X, int* y, int samples, int features); - std::string getDotLabel(); - int classification; -}; - - diff --git a/rewrite/cpp/Criterion.cpp b/rewrite/cpp/Criterion.cpp @@ -0,0 +1,49 @@ +#include "Criterion.h" +#include <unordered_map> +#include <math.h> + +float Criterion::giniImpurity(float* X, int* y, int samples, int features, int index, float splitValue){ + + std::unordered_map<int, int> ltMap; + std::unordered_map<int, int> gtMap; + + int ltCount = 0; + int gteqCount = 0; + + for(int i = 0; i < samples; ++i){ + if(X[index + (i * features)] < splitValue){ + ltMap[y[i]]++; + ltCount++; + } + else{ + gtMap[y[i]]++; + gteqCount++; + } + } + + + float ltGini= 1.0f; + + for (const auto& pair : ltMap) { + ltGini -= pow(float(pair.second) / ltCount, 2); + } + + float gteqGini = 1.0f; + + for (const auto& pair : gtMap) { + gteqGini -= pow(float(pair.second) / gteqCount, 2); + } + + if(gteqCount == 0){ + gteqGini = 0.0f; + } + if(ltCount == 0){ + ltGini = 0.0f; + } + + float gini = gteqGini * float(gteqCount) / samples; + gini += ltGini * float(ltCount) / samples; + + return gini; +} + diff --git a/rewrite/cpp/Criterion.h b/rewrite/cpp/Criterion.h @@ -0,0 +1,5 @@ + +class Criterion{ + public: + float giniImpurity(float* X, int* y, int samples, int features, int index, float splitValue); +}; diff --git a/rewrite/cpp/DecisionTreeClassifier.cpp b/rewrite/cpp/DecisionTreeClassifier.cpp @@ -0,0 +1,177 @@ +#include "DecisionTreeClassifier.h" +#include <limits> +#include <iostream> +#include <unordered_map> + +using namespace std; + +DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){ + depth = maxDepth; +} + +void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){ + if (splittingTree != nullptr){ + throw logic_error("Decision trees don't support incremental learning, fit can only be called once."); + } + + if(features <= 0){ + throw invalid_argument("Invalid argument, there must be 1 or more features to train on."); + } + + if(samples <= 0){ + throw invalid_argument("Invalid argument, there must be 1 or more samples to train on."); + } + + splittingTree = recurse(X, samples, y, features, depth); + featureCount = features; +} + + +std::string DecisionTreeClassifier::getDot(){ + if (splittingTree == nullptr){ + throw logic_error("Decision tree must be created prior to generating dot output."); + } + std::string edges = splittingTree->getDotEdges(); + std::string dot = "digraph decisionTree {\n" + edges + "}"; + return dot; +} + +int DecisionTreeClassifier::primaryClass(int* y, int labelCount){ + + unordered_map map = unordered_map<int,int>(); + + for(int i = 0; i < labelCount; ++i){ + map[y[i]] += 1; + } + + int mostElements = 0; + int label = 0; + + for (auto& item : map){ + if(item.second > mostElements){ + mostElements = item.second; + label = item.first; + } + } + + return label; +} + + + +// add depth +TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){ + + if(depthRem == 0){ + TreeNode* ret = new TreeNode(primaryClass(y, rows)); + return ret; + } + + // found minimum node + if(rows == 1){ + TreeNode* ret = new TreeNode(primaryClass(y, rows)); + return ret; + } + + // get best split option + TreeNode* chosen = bestSplit(X, rows, y, columns); + SplitResults split = chosen->splitOnNode(X, y, rows, columns); + + // no valid splits, but we still did create some new arrays. + if(split.rightSize == rows || split.leftSize == rows){ + TreeNode* ret = new TreeNode(primaryClass(y, rows)); + delete split.XLeft; + delete split.XRight; + delete split.yLeft; + delete split.yRight; + return ret; + } + + // traverse lt tree + TreeNode* left = recurse(split.XLeft, split.leftSize, split.yLeft, columns, depthRem - 1); + // traverse gt tree + TreeNode* right = recurse(split.XRight, split.rightSize, split.yRight, columns, depthRem - 1); + + chosen->setLeftChild(left); + chosen->setRightChild(right); + + delete split.XLeft; + delete split.XRight; + delete split.yLeft; + delete split.yRight; + + return chosen; +} + + + + + +// 1 1 0 +// 3 3 0 +// 2 1 1 +// 4 1 3 + +// consider adding interpolation to this and sorting the list first. +// Also, no reason to consider the 0th split if that is the case. + +TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns){ + + TreeNode* bestNode = nullptr; + float bestGini = std::numeric_limits<float>::max(); + + for(int col = 0 ; col < columns; ++col){ + for(int row = 0; row < rows; ++row){ + + float val = X[row*columns + col]; + TreeNode* current = new TreeNode(val, col); + float gini = current->evalSplit(X, y, rows, columns, "gini"); + if (gini < bestGini){ + + TreeNode* prevBest = bestNode; + delete prevBest; + + bestNode = current; + bestGini = gini; + } + else{ + delete current; + } + } + + } + + return bestNode; + +} + +int* DecisionTreeClassifier::predict(float* X, int samples, int features){ + + if(featureCount == -1){ + throw logic_error("Unable to predict prior to calling fit()."); + } + + if(features != this->featureCount){ + throw invalid_argument("Incorrect number of features for prediction."); + } + + int* predictions = new int[samples]; + + for(int i = 0; i < samples; ++i){ + TreeNode* current = splittingTree; + while(!current->isLeaf()){ + float* currentElement = X; + currentElement += features * i; + bool lessThan = current->lessThan(currentElement, features); + if(lessThan){ + current = current->getLeftChild(); + } + else{ + current = current->getRightChild(); + } + } + predictions[i] = current->getClassification(); + } + + return predictions; +} diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/cpp/DecisionTreeClassifier.h diff --git a/rewrite/cpp/TreeNode.cpp b/rewrite/cpp/TreeNode.cpp @@ -0,0 +1,203 @@ +#include "TreeNode.h" +#include "stdexcept" +#include "Criterion.h" +#include "math.h" +#include "iostream" +#include <sstream> //for std::stringstream +#include <string> //for std::string + +TreeNode::TreeNode(int classification){ + leaf = true; + this->classification = classification; +} + +TreeNode::TreeNode(float splittingVal, int featureIndex){ + splitValue = splittingVal; + index = featureIndex; + leaf = false; +} + +void TreeNode::setSplit(float splittingVal, int featureIndex){ + splitValue = splittingVal; + index = featureIndex; + leaf = false; +} + +bool TreeNode::isLeaf(){ + return leaf; +} + +float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){ + + if(isLeaf()){ + throw std::logic_error("Cannot evaluate split on leaf node."); + } + + if(criterion != "gini"){ + throw std::invalid_argument("Gini impurity is the only supported criterion."); + } + + Criterion evalCriterion= Criterion(); + + return evalCriterion.giniImpurity(X, y, samples, features, this->index, this->splitValue); +} + + +void TreeNode::setLeftChild(TreeNode* child){ + leftChild = child; +} + +void TreeNode::setRightChild(TreeNode* child){ + rightChild = child; +} + +TreeNode* TreeNode::getLeftChild(){ + return leftChild; +} + +TreeNode* TreeNode::getRightChild(){ + return rightChild; +} + +float TreeNode::getSplitVal(){ + return splitValue; +} + +int TreeNode::getIndexSplit(){ + return index; +} + +SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){ + + SplitResults result = SplitResults(); + + int ltCount = 0; + int gteqCount = 0; + + for(int i = 0 ; i < samples; ++i){ + if(X[(i*features) + index] < splitValue){ + ltCount += 1; + } + else{ + gteqCount += 1; + } + } + + // Create X arrays to return + + float* ltArr = new float[ltCount * features]; + float* gteqArr = new float[gteqCount * features]; + + // Create array ptr next open + + float* nextLtX = ltArr; + float* nextGteqX = gteqArr; + + // Create y arrays to return + + int* ltYArr = new int[ltCount]; + int* gteqYArr = new int[gteqCount]; + + // Create array ptr next open + + int* nextLtY = ltYArr; + int* nextGteqY = gteqYArr; + + // Set pointers for return to the new arrays + + result.XLeft = ltArr; + result.yLeft = ltYArr; + + result.XRight = gteqArr; + result.yRight = gteqYArr; + + result.leftSize = ltCount; + result.rightSize = gteqCount; + + // Set arrays with correct values + + for(int i = 0 ; i < samples; ++i){ + if(X[(i*features) + index] < splitValue){ + for(int x = 0; x < features; ++x){ + nextLtX[x] = X[(i*features) + x]; + } + + nextLtX += features; + + nextLtY[0] = y[i]; + nextLtY += 1; + } + else{ + for(int x = 0; x < features; ++x){ + nextGteqX[x] = X[(i*features) + x]; + } + + nextGteqX += features; + + nextGteqY[0] = y[i]; + nextGteqY += 1; + } + } + + //for(int x = 0 ; x < ltCount; ++x){ + // for(int i = 0 ; i < features; ++i){ + // std::cout << ltArr[x*features + i]; + // } + // std::cout << std::endl; + //} + + //for(int x = 0 ; x < ltCount; ++x){ + // std::cout << ltYArr[x] << std::endl; + //} + + return result; +} + + + + + + + +std::string TreeNode::getDotEdges(){ + + if(isLeaf()){ + return ""; + } + + std::string current = getDotLabel() + "->" + leftChild->getDotLabel() + ";\n"; + current += getDotLabel() + "->" + rightChild->getDotLabel() + ";\n"; + + current += rightChild->getDotEdges(); + current += leftChild->getDotEdges(); + + return current; +} + +std::string TreeNode::getDotLabel(){ + const void * address = static_cast<const void*>(this); + std::stringstream ss; + ss << address; + std::string name = ss.str(); + if (isLeaf()){ + return "\"" + name + "\nCLASSIFICATION: " + std::to_string(classification) + "\""; + } + + return "\"" + name + "\nINDEX: " + std::to_string(index) + "\nVALUE:" + std::to_string(splitValue) + "\""; +} + +int TreeNode::getClassification(){ + if(isLeaf()){ + return classification; + } + throw std::logic_error("Unable to call getClassification() on internal vertices."); +} + +bool TreeNode::lessThan(float* sample, int features){ + + if(features < this->index){ + throw std::invalid_argument("Attempting to evaluate split with input that contains less features."); + } + + return(sample[index] < splitValue); +} diff --git a/rewrite/cpp/TreeNode.h b/rewrite/cpp/TreeNode.h @@ -0,0 +1,40 @@ +#include "string" + +struct SplitResults{ + float* XLeft; + float* XRight; + int* yLeft; + int* yRight; + int leftSize; + int rightSize; +}; + +class TreeNode{ + public: + TreeNode(int classification); + TreeNode(float splittingVal, int featureIndex); + bool isLeaf(); + void setSplit(float splittingValue, int featureIndex); + float evalSplit(float* X, int* y, int samples, int features, std::string criterion); + TreeNode* getLeftChild(); + TreeNode* getRightChild(); + void setLeftChild(TreeNode* child); + void setRightChild(TreeNode* child); + float getSplitVal(); + int getIndexSplit(); + SplitResults splitOnNode(float* X, int* y, int samples, int features); + std::string getDotEdges(); + int getClassification(); + bool lessThan(float* sample, int features); + + private: + bool leaf; + float splitValue; + int index; + TreeNode* leftChild; + TreeNode* rightChild; + std::string getDotLabel(); + int classification; +}; + + diff --git a/rewrite/cpp/bindings.cpp b/rewrite/cpp/bindings.cpp @@ -0,0 +1,30 @@ +#include <pybind11/pybind11.h> +#include <pybind11/numpy.h> +#include "DecisionTreeClassifier.h" + +namespace py = pybind11; + +PYBIND11_MODULE(decision_tree, m) { + py::class_<DecisionTreeClassifier>(m, "DecisionTreeClassifier") + .def(py::init<int>()) + .def("fit", [](DecisionTreeClassifier &self, py::array_t<float> X, int samples, py::array_t<int> y, int features) { + auto X_buf = X.request(); // Request a buffer from NumPy array + auto y_buf = y.request(); // Request a buffer from NumPy array + float* X_ptr = static_cast<float*>(X_buf.ptr); + int* y_ptr = static_cast<int*>(y_buf.ptr); + self.fit(X_ptr, samples, y_ptr, features); + }) + .def("predict", [](DecisionTreeClassifier &self, py::array_t<float> X, int samples, int features) { + auto X_buf = X.request(); // Request a buffer from NumPy array + float* X_ptr = static_cast<float*>(X_buf.ptr); + int* result = self.predict(X_ptr, samples, features); + + // Return a NumPy array from the result + return py::array_t<int>(samples, result); + }) + .def("getDot", &DecisionTreeClassifier::getDot) + .def("__repr__", [](const DecisionTreeClassifier &dt) { + return "<DecisionTreeClassifier>"; + }); +} +