Skip to content

Commit c12eed6

Browse files
author
Olcay Taner YILDIZ
committed
Added predictProbability method.
1 parent f1b3372 commit c12eed6

9 files changed

Lines changed: 42 additions & 1 deletion

File tree

Classification/Model/DecisionTree/DecisionNode.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,12 @@ def predict(self, instance: Instance) -> str:
253253
if node.__condition.satisfy(instance):
254254
return node.predict(instance)
255255
return self.__classLabel
256+
257+
def predictProbabilityDistribution(self, instance: Instance) -> dict:
    """
    Return the class probability distribution of the node that the given instance ends up in.

    A non-leaf node hands the instance to the first child whose condition the instance satisfies and
    returns that child's answer. A leaf node, or a node none of whose children accept the instance,
    answers with the probability distribution of the class distribution of its own data.

    Parameters
    ----------
    instance : Instance
        Instance to route through the tree.

    Returns
    -------
    dict
        Result of getProbabilityDistribution() on the class distribution of the selected node's data.
    """
    if not self.leaf:
        for child in self.children:
            if child.__condition.satisfy(instance):
                # Only the first accepting child is consulted.
                return child.predictProbabilityDistribution(instance)
    # Reached for leaves and for inner nodes where no child condition matched.
    return self.__data.classDistribution().getProbabilityDistribution()

Classification/Model/DecisionTree/DecisionTree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def predict(self, instance: Instance) -> str:
4040
predictedClass = instance.getPossibleClassLabels()
4141
return predictedClass
4242

43+
def predictProbability(self, instance: Instance) -> dict:
    """
    Return the probability distribution that the root node computes for the given instance.

    Parameters
    ----------
    instance : Instance
        Instance to evaluate.

    Returns
    -------
    dict
        Whatever the root node's predictProbabilityDistribution returns for the instance.
    """
    root = self.__root
    return root.predictProbabilityDistribution(instance)
45+
4346
def pruneNode(self, node: DecisionNode, pruneSet: InstanceList):
4447
"""
4548
The prune method takes a DecisionNode and an InstanceList as inputs. It checks the classification performance

Classification/Model/DummyModel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ def predict(self, instance: Instance) -> str:
4040
return self.distribution.getMaxItemIncludeTheseOnly(possibleClassLabels)
4141
else:
4242
return self.distribution.getMaxItem()
43+
44+
def predictProbability(self, instance: Instance) -> dict:
    """
    Return the probability distribution obtained from the model's stored distribution.

    The instance argument is not consulted; every instance receives the same answer.

    Parameters
    ----------
    instance : Instance
        Unused.

    Returns
    -------
    dict
        Result of getProbabilityDistribution() on self.distribution.
    """
    stored = self.distribution
    return stored.getProbabilityDistribution()

Classification/Model/KnnModel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def predict(self, instance: Instance) -> str:
5454
predictedClass = Model.getMaximum(nearestNeighbors.getClassLabels())
5555
return predictedClass
5656

57+
def predictProbability(self, instance: Instance) -> dict:
    """
    Return the class probability distribution among the nearest neighbors of the given instance.

    Parameters
    ----------
    instance : Instance
        Instance whose neighbors are looked up through self.nearestNeighbors.

    Returns
    -------
    dict
        Result of getProbabilityDistribution() on the class distribution of the neighbor collection.
    """
    neighbors = self.nearestNeighbors(instance)
    neighborClasses = neighbors.classDistribution()
    return neighborClasses.getProbabilityDistribution()
60+
5761
def makeComparator(self):
5862
def compare(instanceA: KnnInstance, instanceB: KnnInstance):
5963
if instanceA.distance < instanceB.distance:

Classification/Model/Model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ def predict(self, instance: Instance) -> str:
2424
"""
2525
pass
2626

27+
@abstractmethod
def predictProbability(self, instance: Instance) -> dict:
    """
    Abstract hook that concrete models override to report a probability distribution for an instance.

    Parameters
    ----------
    instance : Instance
        Instance to evaluate.

    Returns
    -------
    dict
        Supplied by the overriding implementation; this stub itself returns None.
    """
    pass
30+
2731
@staticmethod
2832
def getMaximum(classLabels: list) -> str:
2933
"""

Classification/Model/NeuralNetworkModel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,9 @@ def predict(self, instance: Instance) -> str:
225225
return self.predictWithCompositeInstance(instance.getPossibleClassLabels())
226226
else:
227227
return self.classLabels[self.y.maxIndex()]
228+
229+
def predictProbability(self, instance: Instance) -> dict:
    """
    Return the current output values of the network keyed by class label.

    Each class label in self.classLabels is paired with the value stored at the same index of self.y.

    NOTE(review): instance is never read here, so the result reflects whatever self.y holds at call
    time - presumably the outputs of the most recent forward pass. Confirm that callers refresh
    self.y for the instance before calling this method.

    Parameters
    ----------
    instance : Instance
        Currently unused.

    Returns
    -------
    dict
        Mapping from class label to self.y.getValue(index of that label).
    """
    return {label: self.y.getValue(index) for index, label in enumerate(self.classLabels)}

Classification/Model/RandomModel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ def predict(self, instance: Instance) -> str:
4646
size = len(self.__classLabels)
4747
index = random.randrange(size)
4848
return self.__classLabels[index]
49+
50+
def predictProbability(self, instance: Instance) -> dict:
    """
    Return a uniform probability distribution over the model's class labels.

    Every class label receives 1.0 / (number of class labels). The instance argument is not consulted.

    Parameters
    ----------
    instance : Instance
        Unused.

    Returns
    -------
    dict
        Mapping from each class label to the same probability; empty when there are no class labels.
    """
    labels = self.__classLabels
    if not labels:
        # Guard keeps the empty case returning {} instead of dividing by zero.
        return {}
    # The share is identical for every label, so compute it once.
    return dict.fromkeys(labels, 1.0 / len(labels))

Classification/Model/TreeEnsembleModel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,9 @@ def predict(self, instance: Instance) -> str:
3838
for tree in self.__forest:
3939
distribution.addItem(tree.predict(instance))
4040
return distribution.getMaxItem()
41+
42+
def predictProbability(self, instance: Instance) -> dict:
    """
    Return the distribution of the class labels predicted by the trees of the forest.

    Each tree predicts a class for the instance; the predictions are tallied in a DiscreteDistribution
    and the probability distribution of that tally is returned.

    Parameters
    ----------
    instance : Instance
        Instance passed to every tree's predict method.

    Returns
    -------
    dict
        Result of getProbabilityDistribution() on the tally of per-tree predictions.
    """
    votes = DiscreteDistribution()
    for member in self.__forest:
        vote = member.predict(instance)
        votes.addItem(vote)
    return votes.getProbabilityDistribution()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='NlpToolkit-Classification',
5-
version='1.0.7',
5+
version='1.0.8',
66
packages=['Classification', 'Classification.Model', 'Classification.Model.DecisionTree', 'Classification.Filter',
77
'Classification.DataSet', 'Classification.Instance', 'Classification.Attribute',
88
'Classification.Parameter', 'Classification.Classifier', 'Classification.Experiment',

0 commit comments

Comments
 (0)