Skip to content

Commit 7332ec4

Browse files
author
Olcay Taner YILDIZ
committed
Removed data field from DecisionNode.
1 parent 9b399c5 commit 7332ec4

1 file changed

Lines changed: 35 additions & 28 deletions

File tree

Classification/Model/DecisionTree/DecisionNode.py

Lines changed: 35 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -17,19 +17,18 @@
1717

1818
class DecisionNode(object):
1919
children: list
20-
__data: InstanceList
2120
__class_label: str = None
2221
leaf: bool
2322
__condition: DecisionCondition
2423
__classLabelsDistribution: DiscreteDistribution
2524
EPSILON = 0.0000000001
2625

2726
def constructor1(self,
28-
data: InstanceList,
29-
condition=None,
30-
parameter=None,
31-
isStump=False
32-
):
27+
data: InstanceList,
28+
condition=None,
29+
parameter=None,
30+
isStump=False
31+
):
3332
"""
3433
The DecisionNode method takes InstanceList data as input and then it sets the class label parameter by finding
3534
the most occurred class label of given data, it then gets distinct class labels as class labels ArrayList.
@@ -65,15 +64,14 @@ def constructor1(self,
6564
best_attribute = -1
6665
best_split_value = 0
6766
self.__condition = condition
68-
self.__data = data
6967
self.__classLabelsDistribution = DiscreteDistribution()
70-
labels = self.__data.getClassLabels()
68+
labels = data.getClassLabels()
7169
for label in labels:
7270
self.__classLabelsDistribution.addItem(label)
7371
self.__class_label = Model.getMaximum(labels)
7472
self.leaf = True
7573
self.children = []
76-
class_labels = self.__data.getDistinctClassLabels()
74+
class_labels = data.getDistinctClassLabels()
7775
if len(class_labels) == 1:
7876
return
7977
if isStump and condition is not None:
@@ -93,15 +91,14 @@ def constructor1(self,
9391
distribution = data.discreteIndexedAttributeClassDistribution(index, k)
9492
if distribution.getSum() > 0:
9593
class_distribution.removeDistribution(distribution)
96-
entropy = (
97-
class_distribution.entropy() * class_distribution.getSum() + distribution.entropy() * distribution.getSum()) / data.size()
94+
entropy = (class_distribution.entropy() * class_distribution.getSum() + distribution.entropy() * distribution.getSum()) / data.size()
9895
if entropy + self.EPSILON < best_entropy:
9996
best_entropy = entropy
10097
best_attribute = index
10198
best_split_value = k
10299
class_distribution.addDistribution(distribution)
103100
elif isinstance(data.get(0).getAttribute(index), DiscreteAttribute):
104-
entropy = self.__entropyForDiscreteAttribute(index)
101+
entropy = self.__entropyForDiscreteAttribute(data, index)
105102
if entropy + self.EPSILON < best_entropy:
106103
best_entropy = entropy
107104
best_attribute = index
@@ -128,16 +125,19 @@ def constructor1(self,
128125
if best_attribute != -1:
129126
self.leaf = False
130127
if isinstance(data.get(0).getAttribute(best_attribute), DiscreteIndexedAttribute):
131-
self.__createChildrenForDiscreteIndexed(attributeIndex=best_attribute,
128+
self.__createChildrenForDiscreteIndexed(data=data,
129+
attributeIndex=best_attribute,
132130
attributeValue=best_split_value,
133131
parameter=parameter,
134132
isStump=isStump)
135133
elif isinstance(data.get(0).getAttribute(best_attribute), DiscreteAttribute):
136-
self.__createChildrenForDiscrete(attributeIndex=best_attribute,
134+
self.__createChildrenForDiscrete(data=data,
135+
attributeIndex=best_attribute,
137136
parameter=parameter,
138137
isStump=isStump)
139138
elif isinstance(data.get(0).getAttribute(best_attribute), ContinuousAttribute):
140-
self.__createChildrenForContinuous(attributeIndex=best_attribute,
139+
self.__createChildrenForContinuous(data=data,
140+
attributeIndex=best_attribute,
141141
splitValue=best_split_value,
142142
parameter=parameter,
143143
isStump=isStump)
@@ -149,7 +149,8 @@ def constructor2(self, inputFile: TextIOWrapper):
149149
if items[1][0] == '=':
150150
self.__condition = DecisionCondition(int(items[0]), DiscreteAttribute(items[2]), items[1][0])
151151
elif items[1][0] == ':':
152-
self.__condition = DecisionCondition(int(items[0]), DiscreteIndexedAttribute("", int(items[2]), int(items[3])), '=')
152+
self.__condition = DecisionCondition(int(items[0]),
153+
DiscreteIndexedAttribute("", int(items[2]), int(items[3])), '=')
153154
else:
154155
self.__condition = DecisionCondition(int(items[0]), ContinuousAttribute(float(items[2])), items[1][0])
155156
else:
@@ -175,7 +176,7 @@ def __init__(self,
175176
elif isinstance(data, TextIOWrapper):
176177
self.constructor2(data)
177178

178-
def __entropyForDiscreteAttribute(self, attributeIndex: int):
179+
def __entropyForDiscreteAttribute(self, data: InstanceList, attributeIndex: int):
179180
"""
180181
The entropyForDiscreteAttribute method takes an attributeIndex and creates an ArrayList of DiscreteDistribution.
181182
Then loops through the distributions and calculates the total entropy.
@@ -191,12 +192,13 @@ def __entropyForDiscreteAttribute(self, attributeIndex: int):
191192
Total entropy for the discrete attribute.
192193
"""
193194
total = 0.0
194-
distributions = self.__data.attributeClassDistribution(attributeIndex)
195+
distributions = data.attributeClassDistribution(attributeIndex)
195196
for distribution in distributions:
196-
total += (distribution.getSum() / self.__data.size()) * distribution.entropy()
197+
total += (distribution.getSum() / data.size()) * distribution.entropy()
197198
return total
198199

199200
def __createChildrenForDiscreteIndexed(self,
201+
data: InstanceList,
200202
attributeIndex: int,
201203
attributeValue: int,
202204
parameter: RandomForestParameter,
@@ -216,13 +218,13 @@ def __createChildrenForDiscreteIndexed(self,
216218
isStump : bool
217219
Refers to decision trees with only 1 splitting rule.
218220
"""
219-
children_data = Partition(self.__data, attributeIndex, attributeValue)
221+
children_data = Partition(data, attributeIndex, attributeValue)
220222
self.children.append(
221223
DecisionNode(data=children_data.get(0),
222224
condition=DecisionCondition(attributeIndex,
223225
DiscreteIndexedAttribute("",
224226
attributeValue,
225-
self.__data.get(0).getAttribute(
227+
data.get(0).getAttribute(
226228
attributeIndex).getMaxIndex())),
227229
parameter=parameter,
228230
isStump=isStump))
@@ -231,12 +233,13 @@ def __createChildrenForDiscreteIndexed(self,
231233
condition=DecisionCondition(attributeIndex,
232234
DiscreteIndexedAttribute("",
233235
-1,
234-
self.__data.get(0).getAttribute(
236+
data.get(0).getAttribute(
235237
attributeIndex).getMaxIndex())),
236238
parameter=parameter,
237239
isStump=isStump))
238240

239241
def __createChildrenForDiscrete(self,
242+
data: InstanceList,
240243
attributeIndex: int,
241244
parameter: RandomForestParameter,
242245
isStump: bool):
@@ -253,16 +256,20 @@ def __createChildrenForDiscrete(self,
253256
isStump : bool
254257
Refers to decision trees with only 1 splitting rule.
255258
"""
256-
value_list = self.__data.getAttributeValueList(attributeIndex)
257-
children_data = Partition(self.__data, attributeIndex)
259+
value_list = data.getAttributeValueList(attributeIndex)
260+
children_data = Partition(data, attributeIndex)
258261
for i in range(len(value_list)):
259262
self.children.append(DecisionNode(data=children_data.get(i),
260263
condition=DecisionCondition(attributeIndex=attributeIndex,
261264
value=DiscreteAttribute(value_list[i])),
262265
parameter=parameter,
263266
isStump=isStump))
264267

265-
def __createChildrenForContinuous(self, attributeIndex: int, splitValue: float, parameter: RandomForestParameter,
268+
def __createChildrenForContinuous(self,
269+
data: InstanceList,
270+
attributeIndex: int,
271+
splitValue: float,
272+
parameter: RandomForestParameter,
266273
isStump: bool):
267274
"""
268275
The createChildrenForContinuous method creates a list of DecisionNodes as children and a partition with respect
@@ -279,7 +286,7 @@ def __createChildrenForContinuous(self, attributeIndex: int, splitValue: float,
279286
splitValue : float
280287
Split value is used for partitioning.
281288
"""
282-
children_data = Partition(self.__data, attributeIndex, splitValue)
289+
children_data = Partition(data, attributeIndex, splitValue)
283290
self.children.append(DecisionNode(children_data.get(0),
284291
DecisionCondition(attributeIndex, ContinuousAttribute(splitValue), "<"),
285292
parameter, isStump))
@@ -304,7 +311,7 @@ def predict(self, instance: Instance) -> str:
304311
"""
305312
if isinstance(instance, CompositeInstance):
306313
possible_class_labels = instance.getPossibleClassLabels()
307-
distribution = self.__data.classDistribution()
314+
distribution = self.__classLabelsDistribution
308315
predicted_class = distribution.getMaxItemIncludeTheseOnly(possible_class_labels)
309316
if self.leaf:
310317
return predicted_class
@@ -332,4 +339,4 @@ def predictProbabilityDistribution(self, instance: Instance) -> dict:
332339
for node in self.children:
333340
if node.__condition.satisfy(instance):
334341
return node.predictProbabilityDistribution(instance)
335-
return self.__data.classDistribution().getProbabilityDistribution()
342+
return self.__classLabelsDistribution.getProbabilityDistribution()

0 commit comments

Comments (0)