|
| 1 | +from torch.utils.data import Dataset, DataLoader |
| 2 | +import torch |
| 3 | +import os |
| 4 | +from imageio import imread |
| 5 | +import numpy as np |
| 6 | + |
| 7 | + |
class DiffgramPytorchDataset(Dataset):
    """PyTorch Dataset backed by a list of Diffgram file IDs.

    Files are fetched lazily from the Diffgram project when an item is
    requested, so construction is cheap and no data is downloaded up front.
    """

    def __init__(self, project, diffgram_file_id_list, transform = None):
        """
        :param project (sdk.core.core.Project): A Project object from the Diffgram SDK
        :param diffgram_file_id_list (list): An arbitrary number of file ID's from Diffgram.
        :param transform (callable, optional): Optional transforms to be applied on a sample
        """
        self.diffgram_file_id_list = diffgram_file_id_list
        self.project = project
        self.transform = transform

    def __process_instance(self, instance):
        """
        Creates a pytorch tensor based on the instance type.
        For now we are assuming shapes here, but we can extend it
        to accept custom shapes specified by the user.

        :param instance (dict): a Diffgram instance dict; must contain a 'type'
            key, and for boxes the 'x_min'/'y_min'/'x_max'/'y_max' keys.
        :return: a 1-D torch.Tensor [x_min, y_min, x_max, y_max] for 'box'
            instances; None for unsupported instance types.
        """
        if instance['type'] == 'box':
            coords = np.array([instance['x_min'], instance['y_min'],
                               instance['x_max'], instance['y_max']])
            return torch.tensor(coords)
        # Unsupported instance types are signalled with an explicit None
        # (the original fell off the end and returned None implicitly).
        return None

    def __len__(self):
        # Dataset size is the number of file IDs supplied, not files fetched.
        return len(self.diffgram_file_id_list)

    def __getitem__(self, idx):
        """
        Fetch the file at dataset position idx and build a sample dict.

        :param idx (int or torch.Tensor): dataset position in [0, len(self)).
        :return: dict with 'image' (ndarray from imageio.imread) and
            'instance_list' (list of torch.Tensor, one per supported instance).
        :raises IndexError: if idx is out of range of the file ID list.
        :raises ValueError: if the referenced Diffgram file is not an image.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # BUG FIX: idx is a dataset position, not a Diffgram file ID.
        # The original passed idx straight to get_by_id(), which only worked
        # when IDs happened to coincide with 0-based positions.
        file_id = self.diffgram_file_id_list[idx]
        diffgram_file = self.project.file.get_by_id(file_id, with_instances = True)
        if not hasattr(diffgram_file, 'image'):
            raise ValueError('Pytorch datasets only support images. Please provide only file_ids from images')
        image = imread(diffgram_file.image.get('url_signed'))

        # Convert each annotation instance to a tensor, skipping types
        # __process_instance does not support (it returns None for those).
        processed_instance_list = []
        for instance in diffgram_file.instance_list:
            instance_tensor = self.__process_instance(instance)
            if instance_tensor is not None:
                processed_instance_list.append(instance_tensor)

        # BUG FIX: the original built processed_instance_list and then
        # discarded it, returning the raw instance dicts instead.
        sample = {'image': image, 'instance_list': processed_instance_list}

        if self.transform:
            sample = self.transform(sample)

        return sample
0 commit comments