decision-tree-classifier

Decision tree classifier implementation in C++
git clone git://git.laack.co/decision-tree-classifier.git
Log | Files | Refs | README | LICENSE

bindings.cpp (1650B)


// pybind11 first: it includes Python.h, which must precede standard headers.
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include "DecisionTreeClassifier.h"

#include <memory>
#include <vector>
      5 
      6 namespace py = pybind11;
      7 
      8 PYBIND11_MODULE(decision_tree, m) {
      9     py::class_<DecisionTreeClassifier>(m, "DecisionTreeClassifier")
     10         .def(py::init<int>())
     11         .def("fit", [](DecisionTreeClassifier &self, py::array_t<float> X, int samples, py::array_t<int> y, int features) {
     12             auto X_buf = X.request(); // Request a buffer from NumPy array
     13             auto y_buf = y.request(); // Request a buffer from NumPy array
     14             float* X_ptr = static_cast<float*>(X_buf.ptr);
     15             int* y_ptr = static_cast<int*>(y_buf.ptr);
     16             self.fit(X_ptr, samples, y_ptr, features);
     17         })
     18         .def("predict", [](DecisionTreeClassifier &self, py::array_t<float> X, int samples, int features) {
     19             auto X_buf = X.request(); // Request a buffer from NumPy array
     20             float* X_ptr = static_cast<float*>(X_buf.ptr);
     21 
     22             // Get the prediction result as a raw pointer (dynamically allocated array)
     23             int* result = self.predict(X_ptr, samples, features);
     24 
     25 
     26             // Create a NumPy array from the raw pointer
     27             py::array_t<int> result_array(samples, result);
     28 
     29             // Once the NumPy array is created, delete the dynamically allocated array
     30             delete[] result;  // Properly deallocate the memory
     31 
     32             return result_array;  // Return the NumPy array
     33         })
     34         .def("getDot", &DecisionTreeClassifier::getDot)
     35         .def("__repr__", [](const DecisionTreeClassifier &dt) {
     36             return "<DecisionTreeClassifier>";
     37         });
     38 }
     39