Skip to content

Commit 1d00f6b

Browse files
committed
wip: slice class and segmentation mask
1 parent a9c16ca commit 1d00f6b

7 files changed

Lines changed: 136 additions & 21 deletions

File tree

pytorch_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Manual smoke test: load one Diffgram file as a PyTorch sample and plot its polygon masks.

API credentials are read from the environment instead of being written in
this file, so they never end up in version control.
"""
import os

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

import diffgram
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset

FILE_ID = 1554

# SECURITY: a client id/secret pair used to be hard-coded here. Anything that
# was ever committed must be treated as leaked and rotated.
# A missing variable fails fast with KeyError.
project = diffgram.Project(
    project_string_id = "voc-test",
    client_id = os.environ["DIFFGRAM_CLIENT_ID"],
    client_secret = os.environ["DIFFGRAM_CLIENT_SECRET"],
    debug = True)

diffgram_file = project.file.get_by_id(FILE_ID, with_instances = True)

diffgram_dataset = DiffgramPytorchDataset(
    project = project,
    diffgram_file_id_list = [FILE_ID]
)

# Draw
# Index the dataset once: every `diffgram_dataset[0]` goes through
# `__getitem__`, which fetches the file and downloads its image again.
sample = diffgram_dataset[0]
image_info = sample['diffgram_file'].image
img = Image.new("L", (image_info['width'], image_info['height']), 0)
mask1 = sample['polygon_mask_list'][0]
mask2 = sample['polygon_mask_list'][1]
print(mask1)
for x in mask1:
    print(x)
plt.figure()
plt.subplot(1, 2, 1)
# Optional blank background; enable to draw it under the masks.
# plt.imshow(img, 'gray', interpolation='none')
plt.imshow(mask1, 'jet', interpolation='none', alpha=0.7)
plt.imshow(mask2, 'Oranges', interpolation='none', alpha=0.7)

sdk/diffgram/core/directory.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,35 @@ def all_files(self):
9292
result = result + diffgram_files
9393
return result
9494

95+
def all_file_ids(self):
    """
    Collect the results of every page of `list_files` in 'ids_only' mode.

    Pages are requested 1000 entries at a time, starting at page 1, and
    the loop stops once `self.file_list_metadata['next_page']` is None.

    :return: one list holding the entries of every page, in page order.
    """
    page_num = 1
    result = []
    while page_num is not None:
        diffgram_files = self.list_files(
            limit = 1000,
            page_num = page_num,
            file_view_mode = 'ids_only')
        # Presumably `list_files` refreshes `file_list_metadata` on each
        # call -- TODO confirm against its implementation.
        page_num = self.file_list_metadata['next_page']
        # Extend in place; `result = result + page` would copy everything
        # gathered so far on every iteration.
        result.extend(diffgram_files)
    return result
103+
104+
def slice(self, query):
    """
    Build a SlicedDirectory over this directory for the given query.

    :param query: query string stored on the returned slice; the slice
        passes it to `list_files` when its file IDs are requested.
    :return: a SlicedDirectory whose `original_directory` is this object.
    """
    # Imported lazily: sliced_directory imports this module at load time,
    # so a top-level import here would be circular.
    from diffgram.core.sliced_directory import SlicedDirectory

    # `client` is a required argument of SlicedDirectory.__init__; leaving
    # it out raised TypeError. The earlier throw-away `list_files` call was
    # dropped because its result was never used.
    sliced_dataset = SlicedDirectory(
        client = self.client,
        query = query,
        original_directory = self
    )
    return sliced_dataset
116+
95117
def to_pytorch(self, transform = None):
96118
"""
97119
Transforms the file list inside the dataset into a pytorch dataset.
98120
:return:
99121
"""
100-
dataset_files = self.all_files()
101-
file_id_list = [file.id for file in dataset_files]
122+
from diffgram.core.sliced_directory import SlicedDirectory
123+
file_id_list = self.all_file_ids()
102124
pytorch_dataset = DiffgramPytorchDataset(
103125
project = self.client,
104126
diffgram_file_id_list = file_id_list,
@@ -162,7 +184,8 @@ def list_files(
162184
page_num=1,
163185
limit=100,
164186
search_term: str =None,
165-
file_view_mode: str = 'annotation'):
187+
file_view_mode: str = 'annotation',
188+
query: str = None):
166189
"""
167190
Get a list of files in directory (from Diffgram service).
168191
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from diffgram.core.directory import Directory
2+
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
3+
4+
class SlicedDirectory(Directory):
    """
    A Directory restricted by a query string.

    Keeps a reference to the directory it was created from and forwards
    its query to `list_files` whenever file IDs are collected.
    """

    def __init__(self, client, original_directory: Directory, query: str):
        """
        :param client: project/client object; handed to the PyTorch dataset
            as its `project` in `to_pytorch`.
        :param original_directory: the Directory this slice was taken from.
        :param query: query string passed to `list_files`.
        """
        # NOTE(review): Directory.__init__ is not called, so any attribute
        # the inherited `list_files` relies on may be missing -- confirm.
        self.original_directory = original_directory
        self.query = query
        self.client = client

    def all_file_ids(self):
        """
        Collect the results of every page of `list_files` for this query.

        Pages are requested 1000 entries at a time in 'ids_only' mode until
        `self.file_list_metadata['next_page']` is None.

        :return: one list holding the entries of every page, in page order.
        """
        page_num = 1
        result = []
        while page_num is not None:
            diffgram_files = self.list_files(
                limit = 1000,
                page_num = page_num,
                file_view_mode = 'ids_only',
                query = self.query)
            page_num = self.file_list_metadata['next_page']
            # Extend in place rather than rebuilding the list on each page.
            result.extend(diffgram_files)
        return result

    def to_pytorch(self, transform = None):
        """
        Build a DiffgramPytorchDataset from the files matched by the query.

        :param transform: optional callable forwarded to the dataset.
        :return: the DiffgramPytorchDataset instance.
        """
        file_id_list = self.all_file_ids()
        pytorch_dataset = DiffgramPytorchDataset(
            project = self.client,
            diffgram_file_id_list = file_id_list,
            transform = transform
        )
        return pytorch_dataset
37+

sdk/diffgram/file/file_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def format_assumptions(
383383

384384
return instance_list
385385

386-
def import_bulk:
386+
def import_bulk(self):
387387
"""
388388
Import multiple packets
389389
FUTURE
@@ -392,7 +392,7 @@ def import_bulk:
392392
{ packet_id : { packet }}
393393
394394
"""
395-
pass
395+
raise NotImplementedError
396396

397397
def get_file_list(self, id_list: list, with_instances: bool = False):
398398
"""
Binary file not shown.
Binary file not shown.

sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import os
44
from imageio import imread
55
import numpy as np
6+
import scipy as sp
7+
from PIL import Image, ImageDraw
68

79

810
class DiffgramPytorchDataset(Dataset):
@@ -15,17 +17,37 @@ def __init__(self, project, diffgram_file_id_list = None, transform = None):
1517
:param transform (callable, optional): Optional transforms to be applied on a sample
1618
"""
1719
self.diffgram_file_id_list = diffgram_file_id_list
18-
self.__validate_file_ids()
20+
1921
self.project = project
2022
self.transform = transform
2123
self._internal_file_list = []
22-
24+
self.__validate_file_ids()
2325

2426
def __validate_file_ids(self):
    """
    Check `self.diffgram_file_id_list` against the project.

    Asks `self.project.file.file_list_exists` about the ID list and
    rejects the list when the answer is falsy.

    :raises ValueError: when the existence check does not succeed.
        ValueError subclasses Exception, so callers that catch Exception
        keep working.
    """
    result = self.project.file.file_list_exists(self.diffgram_file_id_list)
    if not result:
        raise ValueError(
            'Some file IDs do not belong to the project. Please provide only files from the same project.')
31+
32+
def __extract_masks_from_polygon(self, instance_list, diffgram_file, empty_value = 0):
    """
    Rasterize every polygon instance into its own mask.

    :param instance_list: iterable of instance dicts. Only entries whose
        'type' is 'polygon' are used; their 'points' entry is a sequence
        of dicts with 'x' and 'y' keys.
    :param diffgram_file: object whose `image` dict provides 'width' and
        'height'; they set the size of the canvas each mask is drawn on.
    :param empty_value: value written to the pixels no polygon covers.
        The default 0 leaves the drawn mask untouched.
    :return: list of float32 numpy arrays, one per polygon instance, in
        input order. Pixels inside the polygon or on its outline are 1.
    """
    nx, ny = diffgram_file.image['width'], diffgram_file.image['height']
    mask_list = []
    for instance in instance_list:
        if instance['type'] != 'polygon':
            continue
        poly = [(p['x'], p['y']) for p in instance['points']]

        # Mode 'L' is an 8-bit single-channel image; start fully black.
        img = Image.new(mode = 'L', size = (nx, ny), color = 0)
        ImageDraw.Draw(img).polygon(poly, outline = 1, fill = 1)
        mask = np.array(img).astype('float32')
        if empty_value != 0:
            # Honour the parameter, which used to be silently ignored.
            mask[mask == 0] = empty_value
        mask_list.append(mask)
    return mask_list
49+
50+
def __extract_bbox_values(self, instance_list, diffgram_file):
    """
    Collect the box coordinates of a file, scaled by the image size.

    Only instances whose 'type' is 'box' are used. x coordinates are
    divided by the image width and y coordinates by the image height.

    :param instance_list: iterable of instance dicts; box entries provide
        'x_min', 'x_max', 'y_min' and 'y_max'.
    :param diffgram_file: object whose `image` dict provides 'width' and
        'height'.
    :return: tuple (x_min_list, x_max_list, y_min_list, y_max_list), each
        with one value per box instance, in input order.
    """
    width = diffgram_file.image['width']
    height = diffgram_file.image['height']
    x_min_list = []
    x_max_list = []
    y_min_list = []
    y_max_list = []

    for inst in instance_list:
        if inst['type'] != 'box':
            continue
        x_min_list.append(inst['x_min'] / width)
        x_max_list.append(inst['x_max'] / width)
        # y values scale with the height; dividing them by the width
        # distorted every box on non-square images.
        y_min_list.append(inst['y_min'] / height)
        y_max_list.append(inst['y_max'] / height)

    return x_min_list, x_max_list, y_min_list, y_max_list
5072

@@ -58,7 +80,7 @@ def __getitem__(self, idx):
5880
if torch.is_tensor(idx):
5981
idx = idx.tolist()
6082

61-
diffgram_file = self.project.file.get_by_id(idx, with_instances = True)
83+
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True)
6284
if hasattr(diffgram_file, 'image'):
6385
image = imread(diffgram_file.image.get('url_signed'))
6486
else:
@@ -68,17 +90,17 @@ def __getitem__(self, idx):
6890
instance_types_in_file = set([x['type'] for x in instance_list])
6991
# Process the instances of each file
7092
processed_instance_list = []
71-
72-
sample = {'image': image}
93+
sample = {'image': image, 'diffgram_file': diffgram_file}
7394
if 'box' in instance_types_in_file:
74-
x_min_list, x_max_list, y_min_list, y_max_list = self.__extract_bbox_values(instance_list)
95+
x_min_list, x_max_list, y_min_list, y_max_list = self.__extract_bbox_values(instance_list, diffgram_file)
7596
sample['x_min_list'] = torch.Tensor(x_min_list)
7697
sample['x_max_list'] = torch.Tensor(x_max_list)
7798
sample['y_min_list'] = torch.Tensor(y_min_list)
7899
sample['y_max_list'] = torch.Tensor(y_max_list)
79100
if 'polygon' in instance_types_in_file:
80-
101+
mask_list = self.__extract_masks_from_polygon(instance_list, diffgram_file)
102+
sample['polygon_mask_list'] = mask_list
81103
if self.transform:
82104
sample = self.transform(sample)
83105

84-
return sample
106+
return sample

0 commit comments

Comments
 (0)