Skip to content

Commit c9e953c

Browse files
committed
wip: adding dataset iterator class
1 parent 1d00f6b commit c9e953c

7 files changed

Lines changed: 243 additions & 75 deletions

File tree

pytorch_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,17 @@
2323
img = Image.new("L", [diffgram_dataset[0]['diffgram_file'].image['width'], diffgram_dataset[0]['diffgram_file'].image['height']], 0)
2424
mask1 = diffgram_dataset[0]['polygon_mask_list'][0]
2525
mask2 = diffgram_dataset[0]['polygon_mask_list'][1]
26-
print(mask1)
27-
for x in mask1:
28-
print(x)
2926
plt.figure()
3027
plt.subplot(1,2,1)
3128
# plt.imshow(img, 'gray', interpolation='none')
3229
plt.imshow(mask1, 'jet', interpolation='none', alpha=0.7)
33-
plt.imshow(mask2, 'Oranges', interpolation='none', alpha=0.7)
30+
plt.imshow(mask2, 'Oranges', interpolation='none', alpha=0.7)
31+
plt.show()
32+
33+
34+
# Dataset Example
35+
36+
dataset = project.directory.get('Default')
37+
38+
sliced_dataset = dataset.slice(query = 'labels.sheep > 0 or labels.sofa > 0')
39+
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from PIL import Image, ImageDraw
2+
from imageio import imread
3+
4+
5+
class DiffgramDatasetIterator:
6+
7+
def __init__(self, project, diffgram_file_id_list):
8+
"""
9+
10+
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
11+
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
12+
"""
13+
self.diffgram_file_id_list = diffgram_file_id_list
14+
15+
self.project = project
16+
self._internal_file_list = []
17+
self.__validate_file_ids()
18+
self.current_file_index = 0
19+
20+
def __iter__(self):
21+
self.current_file_index = 0
22+
return self
23+
24+
def __next__(self):
25+
file_id = self.diffgram_file_id_list[self.current_file_index]
26+
diffgram_file = self.project.file.get_by_id(file_id, with_instances = True)
27+
instance_data = self.get_file_instances(diffgram_file)
28+
self.current_file_index += 1
29+
return instance_data
30+
31+
def __validate_file_ids(self):
32+
result = self.project.file.file_list_exists(self.diffgram_file_id_list)
33+
if not result:
34+
raise Exception(
35+
'Some file IDs do not belong to the project. Please provide only files from the same project.')
36+
37+
def get_image_data(self, diffgram_file):
38+
if hasattr(diffgram_file, 'image'):
39+
image = imread(diffgram_file.image.get('url_signed'))
40+
return image
41+
else:
42+
raise Exception('Pytorch datasets only support images. Please provide only file_ids from images')
43+
44+
def get_file_instances(self, diffgram_file):
45+
if diffgram_file['type'] not in ['image', 'frame']:
46+
raise NotImplementedError('File type "{}" is not supported yet'.format(diffgram_file['type']))
47+
48+
image = self.get_image_data(diffgram_file)
49+
instance_list = diffgram_file.instance_list
50+
instance_types_in_file = set([x['type'] for x in instance_list])
51+
# Process the instances of each file
52+
sample = {'image': image, 'diffgram_file': diffgram_file}
53+
has_boxes = False
54+
has_poly = False
55+
if 'box' in instance_types_in_file:
56+
has_boxes = True
57+
x_min_list, x_max_list, y_min_list, y_max_list = self.extract_bbox_values(instance_list, diffgram_file)
58+
sample['x_min_list'] = x_min_list
59+
sample['x_max_list'] = x_max_list
60+
sample['y_min_list'] = y_min_list
61+
sample['y_max_list'] = y_max_list
62+
63+
if 'polygon' in instance_types_in_file:
64+
has_poly = True
65+
mask_list = self.extract_masks_from_polygon(instance_list, diffgram_file)
66+
sample['polygon_mask_list'] = mask_list
67+
68+
if len(instance_types_in_file) > 2 and has_boxes and has_boxes:
69+
raise NotImplementedError(
70+
'SDK only supports boxes and polygon types currently. If you want a new instance type to be supported please contact us!'
71+
)
72+
73+
label_id_list, label_name_list = self.extract_labels(instance_list)
74+
sample['label_id_list'] = label_id_list
75+
sample['label_name_list'] = label_name_list
76+
77+
return sample
78+
79+
def extract_masks_from_polygon(self, instance_list, diffgram_file, empty_value = 0):
80+
nx, ny = diffgram_file.image['width'], diffgram_file.image['height']
81+
mask_list = []
82+
for instance in instance_list:
83+
if instance['type'] != 'polygon':
84+
continue
85+
poly = [(p['x'], p['y']) for p in instance['points']]
86+
87+
img = Image.new(mode = 'L', size = (nx, ny), color = 0) # mode L = 8-bit pixels, black and white
88+
draw = ImageDraw.Draw(img)
89+
draw.polygon(poly, outline = 1, fill = 1)
90+
mask = np.array(img).astype('float32')
91+
# mask[np.where(mask == 0)] = empty_value
92+
mask_list.append(mask)
93+
return mask_list
94+
95+
def extract_labels(self, instance_list, allowed_instance_types = None):
96+
label_file_id_list = []
97+
label_names_list = []
98+
99+
for inst in instance_list:
100+
if allowed_instance_types and inst['type'] in allowed_instance_types:
101+
continue
102+
103+
label_file_id_list.append(inst['label_file']['id'])
104+
label_names_list.append(inst['label_file']['label']['name'])
105+
106+
return label_file_id_list, label_names_list
107+
108+
def extract_bbox_values(self, instance_list, diffgram_file):
109+
"""
110+
Creates a pytorch tensor based on the instance type.
111+
For now we are assuming shapes here, but we can extend it
112+
to accept custom shapes specified by the user.
113+
:param instance:
114+
:return:
115+
"""
116+
x_min_list = []
117+
x_max_list = []
118+
y_min_list = []
119+
y_max_list = []
120+
121+
for inst in instance_list:
122+
if inst['type'] != 'box':
123+
continue
124+
x_min_list.append(inst['x_min'] / diffgram_file.image['width'])
125+
x_max_list.append(inst['x_max'] / diffgram_file.image['width'])
126+
y_min_list.append(inst['y_min'] / diffgram_file.image['width'])
127+
y_max_list.append(inst['y_max'] / diffgram_file.image['width'])
128+
129+
return x_min_list, x_max_list, y_min_list, y_max_list

sdk/diffgram/core/directory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def slice(self, query):
109109
file_view_mode = 'ids_only'
110110
)
111111
sliced_dataset = SlicedDirectory(
112+
client = self.client,
112113
query = query,
113114
original_directory = self
114115
)
Binary file not shown.

sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py

Lines changed: 23 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from torch.utils.data import Dataset, DataLoader
2-
import torch
31
import os
4-
from imageio import imread
2+
53
import numpy as np
64
import scipy as sp
7-
from PIL import Image, ImageDraw
5+
6+
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
87

98

10-
class DiffgramPytorchDataset(Dataset):
9+
class DiffgramPytorchDataset(DiffgramDatasetIterator, Dataset):
1110

1211
def __init__(self, project, diffgram_file_id_list = None, transform = None):
1312
"""
@@ -16,60 +15,21 @@ def __init__(self, project, diffgram_file_id_list = None, transform = None):
1615
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
1716
:param transform (callable, optional): Optional transforms to be applied on a sample
1817
"""
18+
super(DiffgramDatasetIterator, self).__init__(project, diffgram_file_id_list)
19+
global torch, Dataset, DataLoader
20+
try:
21+
import torch as torch # type: ignore
22+
from torch.utils.data import Dataset, DataLoader
23+
except ModuleNotFoundError:
24+
raise ModuleNotFoundError(
25+
"'torch' module should be installed to convert the Dataset into pytorch format"
26+
)
1927
self.diffgram_file_id_list = diffgram_file_id_list
2028

2129
self.project = project
2230
self.transform = transform
23-
self._internal_file_list = []
2431
self.__validate_file_ids()
2532

26-
def __validate_file_ids(self):
27-
result = self.project.file.file_list_exists(self.diffgram_file_id_list)
28-
if not result:
29-
raise Exception(
30-
'Some file IDs do not belong to the project. Please provide only files from the same project.')
31-
32-
def __extract_masks_from_polygon(self, instance_list, diffgram_file, empty_value = 0):
33-
nx, ny = diffgram_file.image['width'], diffgram_file.image['height']
34-
mask_list = []
35-
for instance in instance_list:
36-
if instance['type'] != 'polygon':
37-
continue
38-
poly = [(p['x'], p['y']) for p in instance['points']]
39-
40-
img = Image.new(mode = 'L', size = (nx, ny), color = 0) # mode L = 8-bit pixels, black and white
41-
draw = ImageDraw.Draw(img)
42-
print()
43-
draw.polygon(poly, outline = 1, fill = 1)
44-
mask = np.array(img).astype('float32')
45-
# mask[np.where(mask == 0)] = empty_value
46-
print('mask', len(mask))
47-
mask_list.append(mask)
48-
return mask_list
49-
50-
def __extract_bbox_values(self, instance_list, diffgram_file):
51-
"""
52-
Creates a pytorch tensor based on the instance type.
53-
For now we are assuming shapes here, but we can extend it
54-
to accept custom shapes specified by the user.
55-
:param instance:
56-
:return:
57-
"""
58-
x_min_list = []
59-
x_max_list = []
60-
y_min_list = []
61-
y_max_list = []
62-
63-
for inst in instance_list:
64-
if inst['type'] != 'box':
65-
continue
66-
x_min_list.append(inst['x_min'] / diffgram_file.image['width'])
67-
x_max_list.append(inst['x_max'] / diffgram_file.image['width'])
68-
y_min_list.append(inst['y_min'] / diffgram_file.image['width'])
69-
y_max_list.append(inst['y_max'] / diffgram_file.image['width'])
70-
71-
return x_min_list, x_max_list, y_min_list, y_max_list
72-
7333
def __len__(self):
7434
return len(self.diffgram_file_id_list)
7535

@@ -81,25 +41,17 @@ def __getitem__(self, idx):
8141
idx = idx.tolist()
8242

8343
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True)
84-
if hasattr(diffgram_file, 'image'):
85-
image = imread(diffgram_file.image.get('url_signed'))
86-
else:
87-
raise Exception('Pytorch datasets only support images. Please provide only file_ids from images')
8844

89-
instance_list = diffgram_file.instance_list
90-
instance_types_in_file = set([x['type'] for x in instance_list])
91-
# Process the instances of each file
92-
processed_instance_list = []
93-
sample = {'image': image, 'diffgram_file': diffgram_file}
94-
if 'box' in instance_types_in_file:
95-
x_min_list, x_max_list, y_min_list, y_max_list = self.__extract_bbox_values(instance_list, diffgram_file)
96-
sample['x_min_list'] = torch.Tensor(x_min_list)
97-
sample['x_max_list'] = torch.Tensor(x_max_list)
98-
sample['y_min_list'] = torch.Tensor(y_min_list)
99-
sample['y_max_list'] = torch.Tensor(y_max_list)
100-
if 'polygon' in instance_types_in_file:
101-
mask_list = self.__extract_masks_from_polygon(instance_list, diffgram_file)
102-
sample['polygon_mask_list'] = mask_list
45+
sample = self.get_file_instances(diffgram_file)
46+
if 'x_min_list' in sample:
47+
sample['x_min_list'] = torch.Tensor(sample['x_min_list'])
48+
if 'x_max_list' in sample:
49+
sample['x_max_list'] = torch.Tensor(sample['x_max_list'])
50+
if 'y_min_list' in sample:
51+
sample['y_min_list'] = torch.Tensor(sample['y_min_list'])
52+
if 'y_max_list' in sample:
53+
sample['y_max_list'] = torch.Tensor(sample['y_max_list'])
54+
10355
if self.transform:
10456
sample = self.transform(sample)
10557

sdk/diffgram/tensorflow_diffgram/__init__.py

Whitespace-only changes.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
2+
import os
3+
4+
5+
class DiffgramTensorflowDataset(DiffgramDatasetIterator):
6+
7+
def __init__(self, project, diffgram_file_id_list = None):
8+
"""
9+
10+
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
11+
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
12+
:param transform (callable, optional): Optional transforms to be applied on a sample
13+
"""
14+
super(DiffgramDatasetIterator, self).__init__(project, diffgram_file_id_list)
15+
global tf
16+
try:
17+
import tensorflow as tf # type: ignore
18+
except ModuleNotFoundError:
19+
raise ModuleNotFoundError(
20+
"'tensorflow' module should be installed to convert the Dataset into tensorflow format"
21+
)
22+
self.diffgram_file_id_list = diffgram_file_id_list
23+
24+
self.project = project
25+
self.__validate_file_ids()
26+
27+
def int64_feature(self, value):
28+
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
29+
30+
def int64_list_feature(self, value):
31+
return tf.train.Feature(int64_list = tf.train.Int64List(value = value))
32+
33+
def bytes_feature(self, value):
34+
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
35+
36+
def bytes_list_feature(self, value):
37+
return tf.train.Feature(bytes_list = tf.train.BytesList(value = value))
38+
39+
def float_feature(self, value):
40+
return tf.train.Feature(float_list = tf.train.FloatList(value = [value]))
41+
42+
def float_list_feature(self, value):
43+
return tf.train.Feature(float_list = tf.train.FloatList(value = value))
44+
45+
def __validate_file_ids(self):
46+
result = self.project.file.file_list_exists(self.diffgram_file_id_list)
47+
if not result:
48+
raise Exception(
49+
'Some file IDs do not belong to the project. Please provide only files from the same project.')
50+
51+
def __iter__(self):
52+
self.current_file_index = 0
53+
return self
54+
55+
def __next__(self):
56+
file_id = self.diffgram_file_id_list[self.current_file_index]
57+
diffgram_file = self.project.file.get_by_id(file_id, with_instances = True)
58+
instance_data = self.get_file_instances(diffgram_file)
59+
filename, file_extension = os.path.splitext(instance_data['diffgram_file']['image']['original_filename'])
60+
print('instance_data', instance_data)
61+
tf_example_dict = {
62+
'image/height': self.int64_feature(instance_data['diffgram_file']['height']),
63+
'image/width': self.int64_feature(instance_data['diffgram_file']['width']),
64+
'image/filename': self.bytes_feature(filename),
65+
'image/source_id': self.bytes_feature(filename),
66+
'image/encoded': self.bytes_feature(instance_data['image']),
67+
'image/format': self.bytes_feature(file_extension),
68+
'image/object/bbox/xmin': self.float_list_feature(instance_data['x_min_list']),
69+
'image/object/bbox/xmax': self.float_list_feature(instance_data['x_max_list']),
70+
'image/object/bbox/ymin': self.float_list_feature(instance_data['y_min_list']),
71+
'image/object/bbox/ymax': self.float_list_feature(instance_data['y_max_list']),
72+
'image/object/class/text': self.bytes_list_feature(instance_data['label_name_list']),
73+
'image/object/class/label': self.int64_list_feature(instance_data['label_id_list']),
74+
}
75+
tf_example = tf.train.Example(features = tf.train.Features(feature = tf_example_dict))
76+
self.current_file_index += 1
77+
return tf_example
78+
79+
def get_dataset_obj(self):
80+
return tf.data.Dataset.from_generator(self.__iter__)

0 commit comments

Comments
 (0)