Skip to content

Commit 9b399c5

Browse files
author
Olcay Taner YILDIZ
committed
Updated DecisionTree model saving.
1 parent 12c7572 commit 9b399c5

28 files changed

Lines changed: 185989 additions & 10 deletions

Classification/Model/DecisionTree/DecisionNode.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class DecisionNode(object):
2121
__class_label: str = None
2222
leaf: bool
2323
__condition: DecisionCondition
24+
__classLabelsDistribution: DiscreteDistribution
2425
EPSILON = 0.0000000001
2526

2627
def constructor1(self,
@@ -65,7 +66,11 @@ def constructor1(self,
6566
best_split_value = 0
6667
self.__condition = condition
6768
self.__data = data
68-
self.__class_label = Model.getMaximum(self.__data.getClassLabels())
69+
self.__classLabelsDistribution = DiscreteDistribution()
70+
labels = self.__data.getClassLabels()
71+
for label in labels:
72+
self.__classLabelsDistribution.addItem(label)
73+
self.__class_label = Model.getMaximum(labels)
6974
self.leaf = True
7075
self.children = []
7176
class_labels = self.__data.getDistinctClassLabels()
@@ -158,6 +163,7 @@ def constructor2(self, inputFile: TextIOWrapper):
158163
else:
159164
self.leaf = True
160165
self.__class_label = inputFile.readline().strip()
166+
self.__classLabelsDistribution = Model.loadClassDistribution(inputFile)
161167

162168
def __init__(self,
163169
data: object,
@@ -321,7 +327,7 @@ def predict(self, instance: Instance) -> str:
321327

322328
def predictProbabilityDistribution(self, instance: Instance) -> dict:
323329
if self.leaf:
324-
return self.__data.classDistribution().getProbabilityDistribution()
330+
return self.__classLabelsDistribution.getProbabilityDistribution()
325331
else:
326332
for node in self.children:
327333
if node.__condition.satisfy(instance):

Classification/Model/DummyModel.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,7 @@ def constructor1(self, trainSet: InstanceList):
2222

2323
def constructor2(self, fileName: str):
2424
inputFile = open(fileName, mode='r', encoding='utf-8')
25-
self.distribution = DiscreteDistribution()
26-
size = int(inputFile.readline().strip())
27-
for i in range(size):
28-
line = inputFile.readline().strip()
29-
items = line.split(" ")
30-
count = int(items[1])
31-
for j in range(count):
32-
self.distribution.addItem(items[0])
25+
self.distribution = Model.loadClassDistribution(inputFile)
3326
inputFile.close()
3427

3528
def __init__(self, trainSet: object):

Classification/Model/Model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from io import TextIOWrapper
33

44
from DataStructure.CounterHashMap import CounterHashMap
5+
from Math.DiscreteDistribution import DiscreteDistribution
56
from Math.Matrix import Matrix
67

78
from Classification.Instance.Instance import Instance
@@ -50,6 +51,18 @@ def loadMatrix(self, inputFile: TextIOWrapper) -> Matrix:
5051
matrix.setValue(j, k, float(items[k]))
5152
return matrix
5253

54+
@staticmethod
55+
def loadClassDistribution(inputFile: TextIOWrapper) -> DiscreteDistribution:
56+
distribution = DiscreteDistribution()
57+
size = int(inputFile.readline().strip())
58+
for i in range(size):
59+
line = inputFile.readline().strip()
60+
items = line.split(" ")
61+
count = int(items[1])
62+
for j in range(count):
63+
distribution.addItem(items[0])
64+
return distribution
65+
5366
@staticmethod
5467
def getMaximum(classLabels: list) -> str:
5568
"""

0 commit comments

Comments
 (0)