Skip to content

Commit a455e8d

Browse files
committed
General refactors in DiffuPy
1 parent 67752e9 commit a455e8d

3 files changed

Lines changed: 51 additions & 18 deletions

File tree

src/diffupy/matrix.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
if init_value is not None and self.rows_labels and list(self.cols_labels):
6969
mat = np.full((len(self.rows_labels), len(self.cols_labels)), init_value)
7070

71-
elif not list(mat):
71+
elif mat is None:
7272
raise ValueError('A path matrix or initialization should be provided.')
7373

7474
self.mat = np.array(mat)
@@ -114,7 +114,9 @@ def __next__(self):
114114

115115
nxt = tuple()
116116
if len(self.rows_labels) == 1:
117-
nxt += (self.mat[self.j],)
117+
nxt += (self.mat[0, self.j],)
118+
elif len(self.cols_labels) == 1:
119+
nxt += (self.mat[self.i, 0],)
118120
else:
119121
nxt += (self.mat[self.i][self.j],)
120122

@@ -284,11 +286,21 @@ def delete_col_from_label(self, label):
284286

285287
def set_cell_from_labels(self, row_label, col_label, x):
286288
"""Set cell from labels."""
287-
self.mat[self.rows_labels_ix_mapping[row_label], self.cols_labels_ix_mapping[col_label]] = x
289+
if len(self.rows_labels) == 1:
290+
self.mat[0, self.cols_labels_ix_mapping[col_label]] = x
291+
elif len(self.cols_labels) == 1:
292+
self.mat[self.rows_labels_ix_mapping[row_label], 0] = x
293+
else:
294+
self.mat[self.rows_labels_ix_mapping[row_label], self.cols_labels_ix_mapping[col_label]] = x
288295

289296
def get_cell_from_labels(self, row_label, col_label):
290297
"""Get cell from labels."""
291-
return self.mat[self.rows_labels_ix_mapping[row_label], self.cols_labels_ix_mapping[col_label]]
298+
if len(self.rows_labels) == 1:
299+
return self.mat[0, self.cols_labels_ix_mapping[col_label]]
300+
elif len(self.cols_labels) == 1:
301+
return self.mat[self.rows_labels_ix_mapping[row_label], 0]
302+
else:
303+
return self.mat[self.rows_labels_ix_mapping[row_label], self.cols_labels_ix_mapping[col_label]]
292304

293305
"""Methods"""
294306

src/diffupy/utils.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
"""Miscellaneous utils of the package."""
4-
4+
import itertools
55
import json
66
import logging
77
import pickle
@@ -123,23 +123,30 @@ def get_idx_scores_mapping(scores):
123123
return {i: score for i, score in enumerate(scores)}
124124

125125

126-
def print_dict_dimensions(entities_db, message='Total number of '):
126+
def map_intersection_type_background(background_labels: Dict[str, list], input_labels: list):
127+
"""Intersection mapping."""
128+
labels_dict = {}
129+
130+
for bck_label, bck_entities in background_labels.items():
131+
labels_dict[bck_label] = set(background_labels[bck_label]).intersection(input_labels)
132+
133+
return labels_dict
134+
135+
136+
def print_dict_dimensions(entities_db, title='Title', message=''):
127137
"""Print dimension of the dictionary."""
128138
total = 0
139+
m = f'{title}\n'
129140

130141
for k1, v1 in entities_db.items():
131-
m = ''
142+
m += f'\n{message}{k1}:\n'
132143
if isinstance(v1, dict):
133144
for k2, v2 in v1.items():
134-
m += f'{k2}({len(v2)}), '
135-
total += len(v2)
145+
m += f'{k2} ({v2})\n'
136146
else:
137-
m += f'{len(v1)} '
138-
total += len(v1)
139-
140-
log_dict({k1: m}, message)
147+
m += f'{v1}'
141148

142-
print(f'Total: {total} ')
149+
print(f'{m}\n\n')
143150

144151

145152
def log_dict(dict_to_print: dict, message: str = ''):
@@ -159,7 +166,12 @@ def get_random_value_from_dict(d: dict):
159166
return d[get_random_key_from_dict(d)]
160167

161168

162-
"""File loading utils."""
169+
def lists_combinations(list_1, list_2):
170+
"""Return all string combination from two list of strings."""
171+
return [x[0] + ' ' + x[1] for x in itertools.product(list_1, list_2)]
172+
173+
174+
"""File loading/writting utils."""
163175

164176

165177
def format_checker(fmt: str, fmt_list: list = GRAPH_FORMATS) -> None:
@@ -188,6 +200,12 @@ def from_json(path: str):
188200
return json.load(f)
189201

190202

203+
def to_json(data, path: str):
204+
"""Save json file."""
205+
with open(path, 'w') as f:
206+
json.dump(data, f)
207+
208+
191209
def from_pickle(input_path):
192210
"""Read from pickle file."""
193211
with open(input_path, 'rb') as f:
@@ -277,8 +295,11 @@ def munge_cell(cell):
277295
elif isinstance(cell, float) or isinstance(cell, int):
278296
return cell
279297

298+
elif cell is None:
299+
return 'NA'
300+
280301
else:
281-
raise TypeError('The cell type could not be processed.')
302+
raise TypeError(f'The cell "{cell}" could not be processed.')
282303

283304

284305
def parse_xls_sheet_to_df(sheet: opxl.workbook,
@@ -314,7 +335,7 @@ def parse_xls_to_df(path: str,
314335
return {sheets[ix].lower(): parse_xls_sheet_to_df(sheet, min_row, relevant_cols, irrelevant_cols)
315336
for ix, sheet in enumerate(wb)
316337
if (relevant_sheets is not None and sheets[ix] in relevant_sheets) or (
317-
irrelevant_sheets is not None and sheets[ix] in irrelevant_sheets)
338+
irrelevant_sheets is not None and sheets[ix] not in irrelevant_sheets)
318339
}
319340

320341
else:

src/diffupy/validate_input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _validate_scores(scores: Matrix) -> None:
6363
if row_label in ['Nan', None]:
6464
raise ValueError("The scores in background must have col names to differentiate score sets.")
6565

66-
std_mat = Matrix(np.std(scores.mat, axis=0), ['sd'], scores.cols_labels)
66+
std_mat = Matrix([np.std(scores.mat, axis=0)], ['sd'], scores.cols_labels)
6767

6868
for sd, row_label, col_label in iter(std_mat):
6969
if sd in ['Nan', None]:

0 commit comments

Comments
 (0)