cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

diabetes.py (1534B)


      1 import pandas as pd
      2 import matplotlib.pyplot as plt
      3 import numpy as np
      4 from sklearn.tree import DecisionTreeClassifier, plot_tree
      5 
      6 
      7 df = pd.read_csv('./diabetes.csv')
      8 X = df['BMI']
      9 y = df['Outcome']
     10 
     11 xSub = []
     12 ySub = []
     13 X = X.to_numpy()
     14 y = y.to_numpy()
     15 
     16 for i in range(60, 90):
     17     if y[i] == 0:
     18         if X[i] > 32:
     19             continue
     20     xSub.append(X[i])
     21     ySub.append(y[i])
     22 
     23 xSub = np.array(xSub).reshape(-1, 1)
     24 ySub = np.array(ySub).reshape(-1, 1)
     25 
     26 clf = DecisionTreeClassifier(max_depth=1)
     27 clf.fit(xSub, ySub)
     28 
     29 # Save decision tree with dark background
     30 fig, ax = plt.subplots(figsize=(8, 6))
     31 plot_tree(clf, filled=True, feature_names=["BMI"], class_names=["No Diabetes", "Diabetes"], ax=ax)
     32 plt.tight_layout()
     33 plt.savefig("tree1.pdf")
     34 plt.close()
     35 
     36 # Save first scatter plot (dia1)
     37 plt.figure(figsize=(10, 6))
     38 plt.scatter(xSub, ySub, c=ySub, cmap='bwr', s=60, edgecolors='#000000')
     39 plt.xlabel("BMI")
     40 plt.ylabel("Diagnosis")
     41 plt.xlim([0, 50])
     42 plt.tight_layout()
     43 plt.savefig("dia1.pdf")
     44 plt.close()
     45 
     46 # Save second scatter plot with vertical line at 32.25 (dia2)
     47 plt.figure(figsize=(10, 6))
     48 plt.scatter(xSub, ySub, c=ySub, cmap='bwr', s=60, edgecolors="#000000")
     49 
     50 # Add shaded regions
     51 plt.axvspan(0, 32.25, color='blue', alpha=0.3)  # Light blue to the left
     52 plt.axvspan(32.25, 50, color='red', alpha=0.3)  # Light red to the right
     53 
     54 # Vertical decision boundary
     55 plt.axvline(32.25, color="black", linestyle="-")
     56 
     57 plt.xlabel("BMI")
     58 plt.ylabel("Diagnosis")
     59 plt.xlim([0, 50])
     60 plt.tight_layout()
     61 plt.savefig("dia2.pdf")
     62 plt.close()