cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

module.cpp (1695B)


      1 
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <memory>
#include <string>
#include <vector>

#include "ELCClassifier.h"
      7 namespace py = pybind11;
      8 
      9 PYBIND11_MODULE(decision_tree, m) {
     10     py::class_<ELCClassifier>(m, "ELCClassifier")
     11 		.def(py::init<int, int, int>())
     12 		.def(py::init<int, int, int, std::string>(), 
     13 			 py::arg("depth"), 
     14 			 py::arg("combinations"), 
     15 			 py::arg("threadCount"), 
     16 			 py::arg("objectiveFunction") = "")
     17         .def("fit", [](ELCClassifier &self, py::array_t<float> X, int samples, py::array_t<int> y, int features) {
     18             auto X_buf = X.request(); // Request a buffer from NumPy array
     19             auto y_buf = y.request(); // Request a buffer from NumPy array
     20             float* X_ptr = static_cast<float*>(X_buf.ptr);
     21             int* y_ptr = static_cast<int*>(y_buf.ptr);
     22             self.fit(X_ptr, samples, y_ptr, features);
     23         })
     24         .def("predict", [](ELCClassifier &self, py::array_t<float> X, int samples, int features) {
     25             auto X_buf = X.request(); // Request a buffer from NumPy array
     26             float* X_ptr = static_cast<float*>(X_buf.ptr);
     27 
     28             // Get the prediction result as a raw pointer (dynamically allocated array)
     29             int* result = self.predict(X_ptr, samples, features);
     30 
     31 
     32             // Create a NumPy array from the raw pointer
     33             py::array_t<int> result_array(samples, result);
     34 
     35             // Once the NumPy array is created, delete the dynamically allocated array
     36             delete[] result;  // Properly deallocate the memory
     37 
     38             return result_array;  // Return the NumPy array
     39         })
     40         .def("getDot", &ELCClassifier::getDot)
     41 		.def("getSplits", &ELCClassifier::getSplits);
     42 }
     43