-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTree.py
More file actions
153 lines (138 loc) · 6.64 KB
/
Tree.py
File metadata and controls
153 lines (138 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from collections import deque

import numpy as np
from scipy.stats import entropy

import Node
import BuildNode
class Tree:
    """Binary decision tree that is grown, dumped, queried and pruned in place.

    Node construction is delegated to ``BuildNode.build``; the node classes
    (``Node.Node``, ``Node.Terminal``, ``Node.NonTerminal``) come from the
    ``Node`` module. A non-terminal node exposes ``trueNode``, ``falseNode``
    and a ``decision`` object (``featureName`` plus a ``function`` predicate);
    a terminal node exposes ``prediction``. Both expose ``impurity`` and
    ``labels`` (used by :meth:`prune`).
    """

    def __init__(self, inputs, labels, featuresToConsider, impurityCat,
                 impurityNum, featureNames, featureTypes, nmin,
                 numIntervals=999):
        """Store the training data and settings; no splitting happens here.

        Args:
            inputs: Training inputs; must expose ``.shape`` with at least
                two axes.
            labels: Training targets, handed to the root node.
            featuresToConsider: Forwarded unchanged to ``BuildNode.build``.
            impurityCat: Key of ``impurityDict`` ('entropy', 'gini' or 'mse').
            impurityNum: Key of ``impurityDict`` ('entropy', 'gini' or 'mse').
            featureNames: Names of the input columns; :meth:`predict` looks
                a decision's ``featureName`` up in this sequence.
            featureTypes: Forwarded unchanged to ``BuildNode.build``.
            nmin: Forwarded unchanged to ``BuildNode.build`` (presumably a
                minimum node size -- confirm against BuildNode).
            numIntervals: Forwarded unchanged to ``BuildNode.build``.

        Raises:
            KeyError: If ``impurityCat`` or ``impurityNum`` is not a key of
                ``impurityDict``.
        """
        self.inputs = inputs
        self.labels = labels
        # Resolve the impurity names to callables once, up front.
        self.impurityCat = impurityDict[impurityCat]
        self.impurityNum = impurityDict[impurityNum]
        self.featuresToConsider = featuresToConsider
        self.featureNames = featureNames
        self.featureTypes = featureTypes
        self.nmin = nmin
        # The root starts as a plain, unbuilt node; make() turns it into a
        # terminal or non-terminal node.
        self.root = Node.Node(inputs, labels, '')
        self.numIntervals = numIntervals
        # NOTE(review): this is the size of the SECOND axis of inputs, while
        # predict() treats axis 0 as samples -- confirm that BuildNode.build
        # really expects this value under the name "lenOfData".
        self.lenOfData = self.inputs.shape[1]

    def _build_node(self, node):
        """Run ``BuildNode.build`` on *node* with this tree's stored settings."""
        return BuildNode.build(node, self.lenOfData, self.impurityCat,
                               self.impurityNum, self.featureNames,
                               self.featureTypes, self.featuresToConsider,
                               self.nmin, self.numIntervals)

    def make(self):
        """Grow the whole tree breadth-first, starting from ``self.root``.

        Every queue entry is a node that has already been built but whose
        two children have not. Terminal children end their branch;
        non-terminal children are queued so their own children get built.
        """
        self.root = self._build_node(self.root)
        if isinstance(self.root, Node.Terminal):
            return
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        pending = deque([self.root])
        while pending:
            node = pending.popleft()
            node.trueNode = self._build_node(node.trueNode)
            node.falseNode = self._build_node(node.falseNode)
            for child in (node.trueNode, node.falseNode):
                if isinstance(child, Node.NonTerminal):
                    pending.append(child)

    def draw(self):
        """Write a breadth-first dump of the tree to ``TreeDrawing.txt``.

        The file is created or overwritten in the current working directory.
        A ``NEW LEVEL`` line is written whenever the depth changes, followed
        by one line per node containing its level and its ``repr``.
        """
        with open('TreeDrawing.txt', 'w') as f:
            # Each entry pairs a node with its depth (root is level 0).
            pending = deque([(self.root, 0)])
            current_level = -1
            while pending:
                node, level = pending.popleft()
                if level != current_level:
                    f.write('NEW LEVEL \n')
                    current_level = level
                f.write('Level: ' + str(level) + ' NODE: ' + repr(node) + '\n')
                if isinstance(node, Node.NonTerminal):
                    pending.append((node.trueNode, level + 1))
                    pending.append((node.falseNode, level + 1))

    def predict(self, data):
        """Return one prediction per row of *data*.

        Args:
            data: Two-dimensional array indexed as ``data[row, column]``,
                with columns in the same order as ``self.featureNames``.

        Returns:
            list: The ``prediction`` of the terminal node each row reaches,
            in row order.

        Raises:
            IndexError: If a decision refers to a feature name that is not
                present in ``self.featureNames``.
        """
        # Convert once so the element-wise comparison below also works when
        # featureNames was supplied as a plain list rather than an ndarray.
        feature_names = np.asarray(self.featureNames)
        predictions = []
        for i in range(data.shape[0]):
            node = self.root
            # Walk down until a node that is not a NonTerminal is reached.
            while isinstance(node, Node.NonTerminal):
                col = np.where(
                    feature_names == node.decision.featureName)[0][0]
                if node.decision.function(data[i, col]):
                    node = node.trueNode
                else:
                    node = node.falseNode
            predictions.append(node.prediction)
        return predictions

    def prune(self, maxAlpha):
        """Collapse the weakest subtrees until none has alpha below *maxAlpha*.

        For every non-terminal node,
        ``alpha = (node.impurity - sum of its leaves' impurity) / (leaves - 1)``.
        The node with the smallest alpha is repeatedly replaced by a terminal
        node built from its labels, and the tree is rescanned after every
        replacement.

        Args:
            maxAlpha: Threshold; pruning continues while the smallest alpha
                in the tree is strictly below it.

        Raises:
            Exception: If the node selected for pruning is the root.
        """
        smallestAlpha = np.inf
        smallestAlphaNode = None        # node to be pruned
        smallestAlphaNodeParent = None  # parent of the node to be pruned

        def recursePrune(currentNode, parentNode):
            """Return (leaf count, summed leaf impurity) of the subtree.

            As a side effect, records the non-terminal node with the
            smallest alpha seen so far (and its parent) in the enclosing
            scope. Ties keep the node visited first in post-order.
            """
            nonlocal smallestAlpha
            nonlocal smallestAlphaNodeParent
            nonlocal smallestAlphaNode
            if not isinstance(currentNode, Node.NonTerminal):
                # Any other node counts as exactly one leaf.
                return 1, currentNode.impurity
            leavesTr, sumImpurityTr = recursePrune(
                currentNode.trueNode, currentNode)
            leavesFls, sumImpurityFls = recursePrune(
                currentNode.falseNode, currentNode)
            leaves = leavesTr + leavesFls
            sumImpurity = sumImpurityTr + sumImpurityFls
            # A non-terminal node has at least two leaves, so leaves - 1 >= 1.
            alpha = (currentNode.impurity - sumImpurity) / (leaves - 1)
            if alpha < smallestAlpha:
                smallestAlpha = alpha
                smallestAlphaNodeParent = parentNode
                smallestAlphaNode = currentNode
            return leaves, sumImpurity

        recursePrune(self.root, None)
        while smallestAlpha < maxAlpha:
            if smallestAlphaNode is None or smallestAlphaNode is self.root:
                raise Exception('Cannot prune the root, maxalpha is too high')
            # Most build() arguments do not apply when a terminal node is
            # forced, so they are passed as None.
            newNode = BuildNode.build(
                Node.Node(None, smallestAlphaNode.labels, ''),
                self.lenOfData, self.impurityCat, self.impurityNum,
                None, None, None, None, None, forceToTerminal=True)
            if smallestAlphaNode is smallestAlphaNodeParent.trueNode:
                smallestAlphaNodeParent.trueNode = newNode
            else:
                smallestAlphaNodeParent.falseNode = newNode
            # Reset all three trackers before rescanning the pruned tree.
            smallestAlpha = np.inf
            smallestAlphaNode = None
            smallestAlphaNodeParent = None
            recursePrune(self.root, None)
# Registry of impurity measures, keyed by the name passed to Tree(...).
# Each value takes an array of labels and returns a number. 'entropy' and
# 'gini' operate on the class proportions of the labels, 'mse' on the raw
# labels. zeroCase makes every measure return 0 for empty input.
# The lambdas resolve zeroCase/getProbs/Gini/mse at call time, so those
# functions may be defined further down in the module.
impurityDict = dict(
    entropy=lambda labs: zeroCase(getProbs(labs), entropy),
    gini=lambda labs: zeroCase(getProbs(labs), Gini),
    mse=lambda labs: zeroCase(labs, mse),
)
def zeroCase(labs, f):
    """Apply *f* to *labs*, or return 0 when *labs* is empty.

    Emptiness is tested with ``len`` rather than truthiness so that NumPy
    arrays (whose truth value is ambiguous) are accepted.
    """
    return f(labs) if len(labs) else 0
def Gini(probs):
    """Return the Gini impurity, 1 - sum(p_i ** 2), of a proportions array."""
    squared = probs ** 2
    total = np.sum(squared)
    return 1 - total
def mse(labs):
    """Return the mean squared deviation of *labs* from their own mean."""
    centre = np.mean(labs)
    deviations = labs - centre
    return np.mean(deviations ** 2)
def getProbs(arr):
    """Return the relative frequency of each distinct value in *arr*.

    The frequencies follow the sorted order of the distinct values, as
    produced by ``np.unique``.
    """
    counts = np.unique(arr, return_counts=True)[1]
    return counts / counts.sum()