1- from torch .utils .data import Dataset , DataLoader
2- import torch
31import os
4- from imageio import imread
2+
53import numpy as np
64import scipy as sp
7- from PIL import Image , ImageDraw
5+
6+ from diffgram .core .diffgram_dataset_iterator import DiffgramDatasetIterator
87
98
10- class DiffgramPytorchDataset (Dataset ):
9+ class DiffgramPytorchDataset (DiffgramDatasetIterator , Dataset ):
1110
1211 def __init__ (self , project , diffgram_file_id_list = None , transform = None ):
1312 """
@@ -16,60 +15,21 @@ def __init__(self, project, diffgram_file_id_list = None, transform = None):
1615 :param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
1716 :param transform (callable, optional): Optional transforms to be applied on a sample
1817 """
18+ super (DiffgramDatasetIterator , self ).__init__ (project , diffgram_file_id_list )
19+ global torch , Dataset , DataLoader
20+ try :
21+ import torch as torch # type: ignore
22+ from torch .utils .data import Dataset , DataLoader
23+ except ModuleNotFoundError :
24+ raise ModuleNotFoundError (
25+ "'torch' module should be installed to convert the Dataset into pytorch format"
26+ )
1927 self .diffgram_file_id_list = diffgram_file_id_list
2028
2129 self .project = project
2230 self .transform = transform
23- self ._internal_file_list = []
2431 self .__validate_file_ids ()
2532
26- def __validate_file_ids (self ):
27- result = self .project .file .file_list_exists (self .diffgram_file_id_list )
28- if not result :
29- raise Exception (
30- 'Some file IDs do not belong to the project. Please provide only files from the same project.' )
31-
32- def __extract_masks_from_polygon (self , instance_list , diffgram_file , empty_value = 0 ):
33- nx , ny = diffgram_file .image ['width' ], diffgram_file .image ['height' ]
34- mask_list = []
35- for instance in instance_list :
36- if instance ['type' ] != 'polygon' :
37- continue
38- poly = [(p ['x' ], p ['y' ]) for p in instance ['points' ]]
39-
40- img = Image .new (mode = 'L' , size = (nx , ny ), color = 0 ) # mode L = 8-bit pixels, black and white
41- draw = ImageDraw .Draw (img )
42- print ()
43- draw .polygon (poly , outline = 1 , fill = 1 )
44- mask = np .array (img ).astype ('float32' )
45- # mask[np.where(mask == 0)] = empty_value
46- print ('mask' , len (mask ))
47- mask_list .append (mask )
48- return mask_list
49-
50- def __extract_bbox_values (self , instance_list , diffgram_file ):
51- """
52- Creates a pytorch tensor based on the instance type.
53- For now we are assuming shapes here, but we can extend it
54- to accept custom shapes specified by the user.
55- :param instance:
56- :return:
57- """
58- x_min_list = []
59- x_max_list = []
60- y_min_list = []
61- y_max_list = []
62-
63- for inst in instance_list :
64- if inst ['type' ] != 'box' :
65- continue
66- x_min_list .append (inst ['x_min' ] / diffgram_file .image ['width' ])
67- x_max_list .append (inst ['x_max' ] / diffgram_file .image ['width' ])
68- y_min_list .append (inst ['y_min' ] / diffgram_file .image ['width' ])
69- y_max_list .append (inst ['y_max' ] / diffgram_file .image ['width' ])
70-
71- return x_min_list , x_max_list , y_min_list , y_max_list
72-
7333 def __len__ (self ):
7434 return len (self .diffgram_file_id_list )
7535
@@ -81,25 +41,17 @@ def __getitem__(self, idx):
8141 idx = idx .tolist ()
8242
8343 diffgram_file = self .project .file .get_by_id (self .diffgram_file_id_list [idx ], with_instances = True )
84- if hasattr (diffgram_file , 'image' ):
85- image = imread (diffgram_file .image .get ('url_signed' ))
86- else :
87- raise Exception ('Pytorch datasets only support images. Please provide only file_ids from images' )
8844
89- instance_list = diffgram_file .instance_list
90- instance_types_in_file = set ([x ['type' ] for x in instance_list ])
91- # Process the instances of each file
92- processed_instance_list = []
93- sample = {'image' : image , 'diffgram_file' : diffgram_file }
94- if 'box' in instance_types_in_file :
95- x_min_list , x_max_list , y_min_list , y_max_list = self .__extract_bbox_values (instance_list , diffgram_file )
96- sample ['x_min_list' ] = torch .Tensor (x_min_list )
97- sample ['x_max_list' ] = torch .Tensor (x_max_list )
98- sample ['y_min_list' ] = torch .Tensor (y_min_list )
99- sample ['y_max_list' ] = torch .Tensor (y_max_list )
100- if 'polygon' in instance_types_in_file :
101- mask_list = self .__extract_masks_from_polygon (instance_list , diffgram_file )
102- sample ['polygon_mask_list' ] = mask_list
45+ sample = self .get_file_instances (diffgram_file )
46+ if 'x_min_list' in sample :
47+ sample ['x_min_list' ] = torch .Tensor (sample ['x_min_list' ])
48+ if 'x_max_list' in sample :
49+ sample ['x_max_list' ] = torch .Tensor (sample ['x_max_list' ])
50+ if 'y_min_list' in sample :
51+ sample ['y_min_list' ] = torch .Tensor (sample ['y_min_list' ])
52+ if 'y_max_list' in sample :
53+ sample ['y_max_list' ] = torch .Tensor (sample ['y_max_list' ])
54+
10355 if self .transform :
10456 sample = self .transform (sample )
10557
0 commit comments