commit a8e1ba5db2e658a6b806deccdde07f9bd3b28a15
parent 93a716343d4dfad27a8ca44df30ef0fc5de879e2
Author: Andrew <andrewlaack1@gmail.com>
Date: Wed, 25 Dec 2024 15:11:17 -0600
Did stuff
Diffstat:
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/rewrite/cpp/DecisionTreeClassifier.cpp b/rewrite/cpp/DecisionTreeClassifier.cpp
@@ -126,44 +126,41 @@ TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int colu
float bestGini = std::numeric_limits<float>::max();
std::mutex mtx;
- for(int i = 0 ; i < rows; ++i){}
-
auto evalColumn = [&](int col){
TreeNode* localBestNode = nullptr;
float localBestGini = std::numeric_limits<float>::max();
- for (int row = 0; row < rows; ++row) {
+ for (int row = 0; row < rows; ++row){
float val = X[row * columns + col];
TreeNode* current = new TreeNode(val, col);
float gini = current->evalSplit(X, y, rows, columns, "gini");
- if (gini < localBestGini) {
+ if (gini < localBestGini){
delete localBestNode;
localBestNode = current;
localBestGini = gini;
- } else {
+ }
+ else{
delete current;
}
}
- // Update global best values with mutex protection
std::lock_guard<std::mutex> lock(mtx);
- if (localBestGini < bestGini) {
+ if (localBestGini < bestGini){
delete bestNode;
bestNode = localBestNode;
bestGini = localBestGini;
- } else {
+ }
+ else{
delete localBestNode;
}
};
- // Create a thread pool to evaluate each column in parallel
std::vector<std::thread> threads;
for (int col = 0; col < columns; ++col) {
threads.emplace_back(evalColumn, col);
}
- // Wait for all threads to finish
for (auto& thread : threads) {
thread.join();
}