Skip to content

Commit 286ea9e

Browse files
committed
Refactored label_per_data using managers.
1 parent 6455754 commit 286ea9e

3 files changed

Lines changed: 42 additions & 16 deletions

File tree

app/api/managers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from collections import Counter
2+
3+
from django.db.models import Manager, Count
4+
5+
6+
class AnnotationManager(Manager):
7+
8+
def get_label_per_data(self, project):
9+
label_count = Counter()
10+
user_count = Counter()
11+
docs = project.documents.all()
12+
annotations = self.filter(document_id__in=docs.all())
13+
14+
for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')):
15+
label_count[d['label__text']] += d['label__count']
16+
user_count[d['user__username']] += d['user__count']
17+
18+
return label_count, user_count
19+
20+
21+
class Seq2seqAnnotationManager(Manager):
22+
23+
def get_label_per_data(self, project):
24+
label_count = Counter()
25+
user_count = Counter()
26+
docs = project.documents.all()
27+
annotations = self.filter(document_id__in=docs.all())
28+
29+
for d in annotations.values('text', 'user__username').annotate(Count('text'), Count('user')):
30+
label_count[d['text']] += d['text__count']
31+
user_count[d['user__username']] += d['user__count']
32+
33+
return label_count, user_count

app/api/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from django.core.exceptions import ValidationError
88
from polymorphic.models import PolymorphicModel
99

10+
from .managers import AnnotationManager, Seq2seqAnnotationManager
11+
1012
DOCUMENT_CLASSIFICATION = 'DocumentClassification'
1113
SEQUENCE_LABELING = 'SequenceLabeling'
1214
SEQ2SEQ = 'Seq2seq'
@@ -192,6 +194,8 @@ def __str__(self):
192194

193195

194196
class Annotation(models.Model):
197+
objects = AnnotationManager()
198+
195199
prob = models.FloatField(default=0.0)
196200
manual = models.BooleanField(default=False)
197201
user = models.ForeignKey(User, on_delete=models.CASCADE)
@@ -225,6 +229,9 @@ class Meta:
225229

226230

227231
class Seq2seqAnnotation(Annotation):
232+
# Override AnnotationManager for custom functionality
233+
objects = Seq2seqAnnotationManager()
234+
228235
document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE)
229236
text = models.TextField()
230237

app/api/views.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections import Counter
2-
31
from django.conf import settings
42
from django.shortcuts import get_object_or_404, redirect
53
from django_filters.rest_framework import DjangoFilterBackend
@@ -15,7 +13,7 @@
1513
from rest_framework_csv.renderers import CSVRenderer
1614

1715
from .filters import DocumentFilter
18-
from .models import Project, Label, Document, Seq2seqAnnotation
16+
from .models import Project, Label, Document
1917
from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsOwnAnnotation
2018
from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer, UserSerializer
2119
from .serializers import ProjectPolymorphicSerializer
@@ -85,20 +83,8 @@ def progress(self, project):
8583
return {'total': total, 'remaining': remaining}
8684

8785
def label_per_data(self, project):
88-
label_count = Counter()
89-
user_count = Counter()
9086
annotation_class = project.get_annotation_class()
91-
docs = project.documents.all()
92-
annotations = annotation_class.objects.filter(document_id__in=docs.all())
93-
if annotation_class == Seq2seqAnnotation:
94-
for d in annotations.values('text', 'user__username').annotate(Count('text'), Count('user')):
95-
label_count[d['text']] += d['text__count']
96-
user_count[d['user__username']] += d['user__count']
97-
else:
98-
for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')):
99-
label_count[d['label__text']] += d['label__count']
100-
user_count[d['user__username']] += d['user__count']
101-
return label_count, user_count
87+
return annotation_class.objects.get_label_per_data(project=project)
10288

10389

10490
class ApproveLabelsAPI(APIView):

0 commit comments

Comments
 (0)