Skip to content
Snippets Groups Projects
Commit d90d5169 authored by ibe23's avatar ibe23
Browse files

50% Class Design of DT

parent e23aa855
Branches
No related tags found
No related merge requests found
Pipeline #1439 passed
from typing import Any, Callable, Dict, List, Tuple
import pandas as pd
from pandas.plotting import scatter_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
#processing
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.pipeline import make_pipeline
#validation
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
#model selection
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
#metrics
from sklearn.metrics import roc_curve
from sklearn.svm import SVC
#Scoring
from sklearn.metrics import mean_squared_error, r2_score
# Set pandas' 'display.max_columns' option to None so wide DataFrames are
# printed with every column instead of being truncated.
pd.set_option('display.max_columns', None)
from sklearn.tree import export_graphviz
from subprocess import call
from IPython.display import Image
# %matplotlib inline
class DecisionTree:
    """Load a CSV dataset, fit a decision tree and explain how it decides.

    The explainability methods (``verboseExplainability``, ``binaryStructure``
    and ``decision_rules``) need a *fitted* classifier.  It is looked up in
    this order: the explicit ``clf`` argument, the instance attribute
    ``clf_dec``, and finally a module-level ``clf_dec`` variable (kept only
    for backward compatibility with the original notebook-style usage).
    """

    # Labels used by ``verboseExplainability`` when the caller supplies none.
    DEFAULT_FEATURE_NAMES = (
        "sepal length",
        "sepal width",
        "petal length",
        "petal width",
    )
    DEFAULT_CLASS_NAMES = ("setosa", "versicolor", "virginica")

    def __init__(self, filename, names) -> None:
        """Remember where the data lives and how its columns are named.

        Args:
            filename: Path (or buffer) handed to ``pandas.read_csv``.
            names: Column names that replace the header row of the CSV.
        """
        self.filename = filename
        self.names = names
        # Fitted classifier used by the explainability methods; assign the
        # result of ``train_decisiontree`` here or pass ``clf=`` explicitly.
        self.clf_dec = None

    def load_dataset(self) -> pd.DataFrame:
        """Read the CSV file, replacing its header row with ``self.names``."""
        return pd.read_csv(self.filename, names=self.names, header=0)

    @staticmethod
    def split_data(
        data: pd.DataFrame,
        train_size: float,
        test_size: float,
        labels: list,
    ) -> Tuple[
        pd.DataFrame, pd.DataFrame, pd.DataFrame,
        pd.DataFrame, pd.DataFrame, pd.DataFrame,
    ]:
        """Split ``data`` into train, test and validation partitions.

        Args:
            data: Full dataset containing feature and label columns.
            train_size: Share of ``data`` that goes into the training set.
            test_size: Share of the *remaining* rows that goes into the test
                set; whatever is left after that is the validation set.
            labels: Names of the target column(s).

        Returns:
            ``(X_train, X_test, X_val, y_train, y_test, y_val)``.
        """
        # ``columns=`` is required: a bare ``drop(labels)`` would look for
        # the labels in the row index instead of the column index.
        X = data.drop(columns=labels)
        y = data[labels]
        # ``train_size`` / ``test_size`` are keyword-only in scikit-learn;
        # passed positionally they would be treated as a third array.
        X_train, X_extra, y_train, y_extra = train_test_split(
            X, y, train_size=train_size
        )
        X_val, X_test, y_val, y_test = train_test_split(
            X_extra, y_extra, test_size=test_size
        )
        return X_train, X_test, X_val, y_train, y_test, y_val

    @staticmethod
    def train_decisiontree(
        X_train: pd.DataFrame,
        y_train: pd.DataFrame,
    ) -> "DecisionTreeClassifier":
        """Fit and return a decision tree limited to 8 leaves (seed 0)."""
        clf = DecisionTreeClassifier(max_leaf_nodes=8, random_state=0)
        return clf.fit(X_train, y_train)

    @staticmethod
    def _resolve_classifier(clf: Any = None, owner: Any = None) -> Any:
        """Return the fitted classifier the explainability methods should use.

        Raises:
            RuntimeError: If no classifier can be found anywhere.
        """
        if clf is not None:
            return clf
        stored = getattr(owner, "clf_dec", None)
        if stored is not None:
            return stored
        # Legacy behaviour: the original code read a module-level ``clf_dec``.
        legacy = globals().get("clf_dec")
        if legacy is not None:
            return legacy
        raise RuntimeError(
            "No fitted classifier available: pass clf=..., or assign the "
            "result of train_decisiontree() to the instance's clf_dec."
        )

    def verboseExplainability(self, clf: Any = None, feature_names=None,
                              class_names=None):
        """Render the tree to ``tree.png`` via Graphviz and return the image.

        Args:
            clf: Fitted classifier; see the class docstring for the fallback.
            feature_names: Feature labels for the plot.  Defaults to the four
                iris feature names.
            class_names: Class labels for the plot.  Defaults to the three
                iris species.

        Returns:
            An ``IPython.display.Image`` of the rendered tree, so a notebook
            cell ending in this call displays it.

        Raises:
            RuntimeError: If no classifier is available or ``dot`` fails.
        """
        model = self._resolve_classifier(clf, self)
        fname = list(
            self.DEFAULT_FEATURE_NAMES if feature_names is None
            else feature_names
        )
        cname = list(
            self.DEFAULT_CLASS_NAMES if class_names is None else class_names
        )
        export_graphviz(
            model, out_file='tree.dot', feature_names=fname,
            class_names=cname, rounded=True, proportion=False, precision=2,
            filled=True,
        )
        status = call(['dot', '-Tpng', 'tree.dot', '-o', 'tree.png', '-Gdpi=400'])
        if status != 0:
            # Without this check a stale tree.png from an earlier run could
            # be shown as if it belonged to the current model.
            raise RuntimeError(
                "Graphviz 'dot' exited with status {}".format(status)
            )
        return Image(filename='tree.png')

    def binaryStructure(self, clf: Any = None) -> None:
        """Print every node of the fitted tree, indented by its depth.

        Args:
            clf: Fitted classifier; see the class docstring for the fallback.

        Raises:
            RuntimeError: If no classifier is available.
        """
        tree = self._resolve_classifier(clf, self).tree_
        n_nodes = tree.node_count
        children_left = tree.children_left
        children_right = tree.children_right
        feature = tree.feature
        threshold = tree.threshold

        node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
        is_leaves = np.zeros(shape=n_nodes, dtype=bool)
        # Depth-first walk from the root: (node id, depth of that node).
        stack = [(0, 0)]
        while stack:
            node_id, depth = stack.pop()
            node_depth[node_id] = depth
            # A leaf stores the same sentinel for both children, so differing
            # child ids identify a split node.
            if children_left[node_id] != children_right[node_id]:
                stack.append((children_left[node_id], depth + 1))
                stack.append((children_right[node_id], depth + 1))
            else:
                is_leaves[node_id] = True

        print(
            "The binary tree structure has {n} nodes and has "
            "the following tree structure:\n".format(n=n_nodes)
        )
        for i in range(n_nodes):
            # int() keeps the repetition independent of NumPy scalar rules.
            indent = "\t" * int(node_depth[i])
            if is_leaves[i]:
                print(f"{indent}node={i} is a leaf node.")
            else:
                print(
                    f"{indent}node={i} is a split node: "
                    f"go to node {children_left[i]} "
                    f"if X[:, {feature[i]}] <= {threshold[i]} "
                    f"else to node {children_right[i]}."
                )

    @staticmethod
    def decision_rules(test_samples, clf: Any = None) -> None:
        """Print the split rules each sample follows on its way to a leaf.

        Args:
            test_samples: Samples to explain; must support ``.iloc``
                positional indexing (e.g. a ``pandas.DataFrame``).
            clf: Fitted classifier.  Falls back to a module-level ``clf_dec``
                when omitted.

        Raises:
            RuntimeError: If no classifier is available.
        """
        model = DecisionTree._resolve_classifier(clf)
        # Previously read from undefined names; they belong to the tree.
        feature = model.tree_.feature
        threshold = model.tree_.threshold

        node_indicator = model.decision_path(test_samples)
        leaf_id = model.apply(test_samples)
        indptr = node_indicator.indptr

        for sample_id in range(len(test_samples)):
            # Ids of the nodes this sample passes through (one CSR row).
            node_index = node_indicator.indices[
                indptr[sample_id]: indptr[sample_id + 1]
            ]
            print("\nRules used to predict sample {id}:".format(id=sample_id))
            for node_id in node_index:
                # The leaf itself holds no rule, only the prediction.
                if leaf_id[sample_id] == node_id:
                    continue
                value = test_samples.iloc[sample_id, feature[node_id]]
                threshold_sign = "<=" if value <= threshold[node_id] else ">"
                print(
                    "decision node {node} : (test_sample[{sample}, {feature}] = {value}) "
                    "{inequality} {threshold})".format(
                        node=node_id,
                        sample=sample_id,
                        feature=feature[node_id],
                        value=value,
                        inequality=threshold_sign,
                        threshold=threshold[node_id],
                    )
                )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment