Skip to content

Commit b9e032c

Browse files
committed
feat: initial support to export to pytorch
Gives users the ability to export any dataset into a pytorch dataset. Pending support for other instance types different from boxes and video support.
1 parent e85f04c commit b9e032c

7 files changed

Lines changed: 161 additions & 13 deletions

File tree

sdk/add_file_id_to_json.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from diffgram.core.core import Project
2+
import json
3+
4+
project = Project(project_string_id = "coco-dataset",
5+
debug = True,
6+
client_id = "LIVE__rj6whqkwxkups7oczqis",
7+
client_secret = "fr5vy64v2096qad9av0dgw3fr0kjavt4c156soiwx51ntyv9qswpuxkhg0lf")
8+
9+
10+
def find_file(file_list, name):
11+
for f in file_list:
12+
if f.original_filename == name:
13+
return f
14+
return None
15+
16+
17+
with open('/home/pablo/Downloads/coco2017.json') as json_file:
18+
data = json.load(json_file)
19+
20+
dataset_default = project.directory.get(name = "Default")
21+
22+
page_num = 1
23+
all_files = []
24+
print('start')
25+
while page_num != None:
26+
print('Current page', page_num)
27+
diffgram_files = dataset_default.list_files(limit = 1000, page_num = page_num, file_view_mode = 'base')
28+
page_num = dataset_default.file_list_metadata['next_page']
29+
print('{} of {}'.format(page_num, dataset_default.file_list_metadata['total_pages']))
30+
all_files = all_files + diffgram_files
31+
32+
print('')
33+
print('Files fetched: ', len(all_files))
34+
result = []
35+
for elm in data:
36+
file = find_file(all_files, name = elm['image_name'])
37+
if file:
38+
print('Adding file ID {} to {}'.format(file.id, elm['image_name']))
39+
elm['file_id'] = file.id
40+
result.append(elm)
41+
else:
42+
print(elm['image_name'], 'not found.')
43+
44+
s = json.dumps(result).
45+
f = open('/home/pablo/Downloads/coco2017_with_ids.json', 'w')
46+
f.write(s)

sdk/diffgram/core/directory.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from diffgram.file.file import File
22
from ..regular.regular import refresh_from_dict
33
import logging
4-
4+
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
55

66
def get_directory_list(self):
77
"""
@@ -78,6 +78,34 @@ def __init__(self,
7878
self.id = None
7979
self.file_list_metadata = {}
8080

81+
def all_files(self):
82+
"""
83+
Get all the files of the directoy.
84+
Warning! This can be an expensive function and take a long time.
85+
:return:
86+
"""
87+
page_num = 1
88+
result = []
89+
while page_num is not None:
90+
diffgram_files = self.list_files(limit = 1000, page_num = page_num, file_view_mode = 'base')
91+
page_num = self.file_list_metadata['next_page']
92+
result = result + diffgram_files
93+
return result
94+
95+
def to_pytorch(self, transform = None):
96+
"""
97+
Transforms the file list inside the dataset into a pytorch dataset.
98+
:return:
99+
"""
100+
dataset_files = self.all_files()
101+
file_id_list = [file.id for file in dataset_files]
102+
pytorch_dataset = DiffgramPytorchDataset(
103+
project = self.client,
104+
diffgram_file_id_list = file_id_list,
105+
transform = transform
106+
107+
)
108+
return pytorch_dataset
81109

82110
def new(self, name: str):
83111
"""

sdk/diffgram/file/file.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from ..regular.regular import refresh_from_dict
22

3-
43
class File():
54
"""
65
file literal object
@@ -11,11 +10,12 @@ class File():
1110

1211
def __init__(
1312
self,
14-
id=None,
15-
client=None):
13+
id = None,
14+
client = None):
1615
self.id = id
1716
self.client = client
1817

18+
@staticmethod
1919
def new(
2020
client,
2121
file_json):

sdk/diffgram/file/file_constructor.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -414,29 +414,42 @@ def import_bulk():
414414

415415

416416
def get_by_id(self,
417-
id: int):
417+
id: int,
418+
with_instances: bool = False):
418419
"""
419420
returns Diffgram File object
420421
"""
421-
422-
endpoint = "/api/v1/file/view"
423422

424-
spec_dict = {
425-
'file_id': id,
426-
'project_string_id': self.client.project_string_id
423+
if not with_instances:
424+
endpoint = "/api/v1/file/view"
425+
426+
spec_dict = {
427+
'file_id': id,
428+
'project_string_id': self.client.project_string_id,
429+
}
430+
431+
432+
file_response_key = 'file'
433+
434+
else:
435+
endpoint = "/api/project/{}/file/{}/annotation/list".format(self.client.project_string_id, id)
436+
spec_dict = {
437+
'directory_id': self.client.directory_id
427438
}
439+
file_response_key = 'file_serialized'
428440

429441
response = self.client.session.post(
430442
self.client.host + endpoint,
431443
json = spec_dict)
432-
444+
433445
self.client.handle_errors(response)
434446

435447
response_json = response.json()
448+
file_data = response_json.get(file_response_key)
436449

437450
return File.new(
438451
client = self.client,
439-
file_json = response_json.get('file'))
452+
file_json = file_data)
440453

441454

442455

sdk/diffgram/pytorch_diffgram/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from torch.utils.data import Dataset, DataLoader
2+
import torch
3+
import os
4+
from imageio import imread
5+
import numpy as np
6+
7+
8+
class DiffgramPytorchDataset(Dataset):
9+
10+
def __init__(self, project, diffgram_file_id_list, transform = None):
11+
"""
12+
13+
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
14+
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
15+
:param transform (callable, optional): Optional transforms to be applied on a sample
16+
"""
17+
self.diffgram_file_id_list = diffgram_file_id_list
18+
self.project = project
19+
self.transform = transform
20+
21+
def __process_instance(self, instance):
22+
"""
23+
Creates a pytorch tensor based on the instance type.
24+
For now we are assuming shapes here, but we can extend it
25+
to accept custom shapes specified by the user.
26+
:param instance:
27+
:return:
28+
"""
29+
if instance['type'] == 'box':
30+
result = np.array([instance['x_min'], instance['y_min'], instance['x_max'], instance['y_max']])
31+
result = torch.tensor(result)
32+
return result
33+
34+
def __len__(self):
35+
return len(self.diffgram_file_id_list)
36+
37+
def __getitem__(self, idx):
38+
if torch.is_tensor(idx):
39+
idx = idx.tolist()
40+
41+
diffgram_file = self.project.file.get_by_id(idx, with_instances = True)
42+
if hasattr(diffgram_file, 'image'):
43+
image = imread(diffgram_file.image.get('url_signed'))
44+
else:
45+
raise Exception('Pytorch datasets only support images. Please provide only file_ids from images')
46+
47+
instance_list = diffgram_file.instance_list
48+
49+
# Process the instances of each file
50+
processed_instance_list = []
51+
for instance in instance_list:
52+
instnace_tensor = self.__process_instance(instance)
53+
processed_instance_list.append(instnace_tensor)
54+
sample = {'image': image, 'instance_list': instance_list}
55+
56+
if self.transform:
57+
sample = self.transform(sample)
58+
59+
return sample

sdk/requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@ opencv-python>=4.0.0.21
33
scipy>=1.1.0
44
six>=1.9.0
55
tensorflow>=1.12.0
6-
pillow
6+
pillow
7+
torch
8+
imageio

0 commit comments

Comments
 (0)