decision-tree-classifier

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit a8e1ba5db2e658a6b806deccdde07f9bd3b28a15
parent 93a716343d4dfad27a8ca44df30ef0fc5de879e2
Author: Andrew <andrewlaack1@gmail.com>
Date:   Wed, 25 Dec 2024 15:11:17 -0600

Did stuff

Diffstat:
Mrewrite/cpp/DecisionTreeClassifier.cpp | 17+++++++----------
1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/rewrite/cpp/DecisionTreeClassifier.cpp b/rewrite/cpp/DecisionTreeClassifier.cpp @@ -126,44 +126,41 @@ TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int colu float bestGini = std::numeric_limits<float>::max(); std::mutex mtx; - for(int i = 0 ; i < rows; ++i){} - auto evalColumn = [&](int col){ TreeNode* localBestNode = nullptr; float localBestGini = std::numeric_limits<float>::max(); - for (int row = 0; row < rows; ++row) { + for (int row = 0; row < rows; ++row){ float val = X[row * columns + col]; TreeNode* current = new TreeNode(val, col); float gini = current->evalSplit(X, y, rows, columns, "gini"); - if (gini < localBestGini) { + if (gini < localBestGini){ delete localBestNode; localBestNode = current; localBestGini = gini; - } else { + } + else{ delete current; } } - // Update global best values with mutex protection std::lock_guard<std::mutex> lock(mtx); - if (localBestGini < bestGini) { + if (localBestGini < bestGini){ delete bestNode; bestNode = localBestNode; bestGini = localBestGini; - } else { + } + else{ delete localBestNode; } }; - // Create a thread pool to evaluate each column in parallel std::vector<std::thread> threads; for (int col = 0; col < columns; ++col) { threads.emplace_back(evalColumn, col); } - // Wait for all threads to finish for (auto& thread : threads) { thread.join(); }