Skip to content

Commit e82fd8c

Browse files
authored
match_delete_rows in format_input (#24)
* match_delete_rows in format_input, solve dimensional errors in cross validation * flake8 corrections
1 parent e126565 commit e82fd8c

2 files changed

Lines changed: 10 additions & 8 deletions

File tree

src/diffupy/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def diffuse(
165165
threshold,
166166
)
167167

168-
click.secho(f'Computing the diffusion algorithm.')
168+
click.secho('Computing the diffusion algorithm.')
169169

170170
results = run_diffusion(
171171
input_scores_dict,

src/diffupy/process_input.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ def _load_data_input_from_file(path: str, **further_parse_args) -> Union[pd.Data
199199

200200
else:
201201
raise IOError(
202-
f'There is a problem with your file. Please ensure the file you submitted is correctly formatted with a'
203-
f'.csv or .tsv file extension.'
202+
'There is a problem with your file. Please ensure the file you submitted is correctly formatted with a'
203+
'.csv or .tsv file extension.'
204204
)
205205

206206

@@ -231,9 +231,9 @@ def _codify_input_data(
231231

232232
# Standardize the title of the node column labeling column to 'Label', for later processing.
233233
if LABEL not in df.columns:
234-
for l in list(df.columns):
235-
if l in NODE_LABELING:
236-
df = df.rename(columns={l: LABEL})
234+
for column_label in list(df.columns):
235+
if column_label in NODE_LABELING:
236+
df = df.rename(columns={column_label: LABEL})
237237
break
238238

239239
# If node type provided in a column, classify in a dictionary the input codification by its node type.
@@ -852,7 +852,8 @@ def format_categorical_input_vector_from_label_list(
852852
)
853853
)
854854

855-
return input_mat.match_missing_rows(kernel.rows_labels, missing_value).match_rows(kernel)
855+
return input_mat.match_delete_rows(kernel.rows_labels).match_missing_rows(kernel.rows_labels,
856+
missing_value).match_rows(kernel)
856857

857858

858859
def format_categorical_input_matrix_from_label_list(
@@ -935,7 +936,8 @@ def format_input_vector_from_label_score_dict(
935936
)
936937
)
937938

938-
return input_mat.match_missing_rows(kernel.rows_labels, missing_value).match_rows(kernel)
939+
return input_mat.match_delete_rows(kernel.rows_labels).match_missing_rows(kernel.rows_labels,
940+
missing_value).match_rows(kernel)
939941

940942

941943
def format_input_matrix_from_type_label_score_dict(

0 commit comments

Comments
 (0)