information-retrieval

Exploration of information retrieval topics
git clone git://git.laack.co/information-retrieval.git
Log | Files | Refs

train.py (959B)


      1 import pandas as pd
      2 import numpy as np
      3 from sklearn.tree import DecisionTreeClassifier
      4 from sklearn.model_selection import StratifiedKFold
      5 from statistics import mean, stdev
      6 
      7 df = pd.read_csv("embeddings.csv")
      8 
      9 def parse_embedding(embedding_str):
     10     cleaned = embedding_str.strip("[]").split()
     11     return np.array([float(x) for x in cleaned])
     12 
     13 
     14 df["embedding"] = df["embedding"].apply(parse_embedding)
     15 
     16 x_scaled = np.stack(df["embedding"].values)
     17 y = df["generated"].values
     18 
     19 tree_clf = DecisionTreeClassifier()
     20 
     21 skf = StratifiedKFold(n_splits=5, random_state=1, shuffle=True)
     22 lst_accu_stratified = []
     23 
     24 for train_index, test_index in skf.split(x_scaled, y):
     25 	x_train_fold, x_test_fold = x_scaled[train_index], x_scaled[test_index]
     26 	y_train_fold, y_test_fold = y[train_index], y[test_index]
     27 	tree_clf.fit(x_train_fold, y_train_fold)
     28 	lst_accu_stratified.append(tree_clf.score(x_test_fold, y_test_fold))
     29 print(f'Mean Accuracy: { mean(lst_accu_stratified)*100}')