33import os
44from imageio import imread
55import numpy as np
6+ import scipy as sp
7+ from PIL import Image , ImageDraw
68
79
810class DiffgramPytorchDataset (Dataset ):
@@ -15,17 +17,37 @@ def __init__(self, project, diffgram_file_id_list = None, transform = None):
1517 :param transform (callable, optional): Optional transforms to be applied on a sample
1618 """
1719 self .diffgram_file_id_list = diffgram_file_id_list
18- self . __validate_file_ids ()
20+
1921 self .project = project
2022 self .transform = transform
2123 self ._internal_file_list = []
22-
24+ self . __validate_file_ids ()
2325
2426 def __validate_file_ids (self ):
25- url = '/api/'
26- raise NotImplementedError
27-
28- def __extract_bbox_values (self , instance_list ):
27+ result = self .project .file .file_list_exists (self .diffgram_file_id_list )
28+ if not result :
29+ raise Exception (
30+ 'Some file IDs do not belong to the project. Please provide only files from the same project.' )
31+
32+ def __extract_masks_from_polygon (self , instance_list , diffgram_file , empty_value = 0 ):
33+ nx , ny = diffgram_file .image ['width' ], diffgram_file .image ['height' ]
34+ mask_list = []
35+ for instance in instance_list :
36+ if instance ['type' ] != 'polygon' :
37+ continue
38+ poly = [(p ['x' ], p ['y' ]) for p in instance ['points' ]]
39+
40+ img = Image .new (mode = 'L' , size = (nx , ny ), color = 0 ) # mode L = 8-bit pixels, black and white
41+ draw = ImageDraw .Draw (img )
42+ print ()
43+ draw .polygon (poly , outline = 1 , fill = 1 )
44+ mask = np .array (img ).astype ('float32' )
45+ # mask[np.where(mask == 0)] = empty_value
46+ print ('mask' , len (mask ))
47+ mask_list .append (mask )
48+ return mask_list
49+
50+ def __extract_bbox_values (self , instance_list , diffgram_file ):
2951 """
3052 Creates a pytorch tensor based on the instance type.
3153 For now we are assuming shapes here, but we can extend it
@@ -41,10 +63,10 @@ def __extract_bbox_values(self, instance_list):
4163 for inst in instance_list :
4264 if inst ['type' ] != 'box' :
4365 continue
44- x_min_list .append (inst ['x_min' ])
45- x_max_list .append (inst ['x_max' ])
46- y_min_list .append (inst ['y_min' ])
47- y_max_list .append (inst ['y_max' ])
66+ x_min_list .append (inst ['x_min' ] / diffgram_file . image [ 'width' ] )
67+ x_max_list .append (inst ['x_max' ] / diffgram_file . image [ 'width' ] )
68+ y_min_list .append (inst ['y_min' ] / diffgram_file . image [ 'width' ] )
69+ y_max_list .append (inst ['y_max' ] / diffgram_file . image [ 'width' ] )
4870
4971 return x_min_list , x_max_list , y_min_list , y_max_list
5072
@@ -58,7 +80,7 @@ def __getitem__(self, idx):
5880 if torch .is_tensor (idx ):
5981 idx = idx .tolist ()
6082
61- diffgram_file = self .project .file .get_by_id (idx , with_instances = True )
83+ diffgram_file = self .project .file .get_by_id (self . diffgram_file_id_list [ idx ] , with_instances = True )
6284 if hasattr (diffgram_file , 'image' ):
6385 image = imread (diffgram_file .image .get ('url_signed' ))
6486 else :
@@ -68,17 +90,17 @@ def __getitem__(self, idx):
6890 instance_types_in_file = set ([x ['type' ] for x in instance_list ])
6991 # Process the instances of each file
7092 processed_instance_list = []
71-
72- sample = {'image' : image }
93+ sample = {'image' : image , 'diffgram_file' : diffgram_file }
7394 if 'box' in instance_types_in_file :
74- x_min_list , x_max_list , y_min_list , y_max_list = self .__extract_bbox_values (instance_list )
95+ x_min_list , x_max_list , y_min_list , y_max_list = self .__extract_bbox_values (instance_list , diffgram_file )
7596 sample ['x_min_list' ] = torch .Tensor (x_min_list )
7697 sample ['x_max_list' ] = torch .Tensor (x_max_list )
7798 sample ['y_min_list' ] = torch .Tensor (y_min_list )
7899 sample ['y_max_list' ] = torch .Tensor (y_max_list )
79100 if 'polygon' in instance_types_in_file :
80-
101+ mask_list = self .__extract_masks_from_polygon (instance_list , diffgram_file )
102+ sample ['polygon_mask_list' ] = mask_list
81103 if self .transform :
82104 sample = self .transform (sample )
83105
84- return sample
106+ return sample
0 commit comments