Skip to content

Commit 6fc04eb

Browse files
committed
fix bug related to column names in datamodule - turn int to string
1 parent a4c5992 commit 6fc04eb

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

mambular/data_utils/datamodule.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ def setup(self, stage: str):
212212
else torch.long
213213
)
214214

215-
cat_key = (
216-
"cat_" + key
215+
cat_key = "cat_" + str(
216+
key
217217
) # Assuming categorical keys are prefixed with 'cat_'
218218
if cat_key in train_preprocessed_data:
219219
train_cat_tensors.append(
@@ -224,7 +224,7 @@ def setup(self, stage: str):
224224
torch.tensor(val_preprocessed_data[cat_key], dtype=dtype)
225225
)
226226

227-
binned_key = "num_" + key # for binned features
227+
binned_key = "num_" + str(key) # for binned features
228228
if binned_key in train_preprocessed_data:
229229
train_cat_tensors.append(
230230
torch.tensor(train_preprocessed_data[binned_key], dtype=dtype)
@@ -237,8 +237,8 @@ def setup(self, stage: str):
237237

238238
# Populate tensors for numerical features, if present in processed data
239239
for key in self.num_feature_info: # type: ignore
240-
num_key = (
241-
"num_" + key
240+
num_key = "num_" + str(
241+
key
242242
) # Assuming numerical keys are prefixed with 'num_'
243243
if num_key in train_preprocessed_data:
244244
train_num_tensors.append(
@@ -306,21 +306,25 @@ def preprocess_new_data(self, X, embeddings):
306306
)
307307
else torch.long
308308
)
309-
cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_'
309+
cat_key = "cat_" + str(
310+
key
311+
) # Assuming categorical keys are prefixed with 'cat_'
310312
if cat_key in preprocessed_data:
311313
cat_tensors.append(
312314
torch.tensor(preprocessed_data[cat_key], dtype=dtype)
313315
)
314316

315-
binned_key = "num_" + key # for binned features
317+
binned_key = "num_" + str(key) # for binned features
316318
if binned_key in preprocessed_data:
317319
cat_tensors.append(
318320
torch.tensor(preprocessed_data[binned_key], dtype=dtype)
319321
)
320322

321323
# Populate tensors for numerical features, if present in processed data
322324
for key in self.num_feature_info: # type: ignore
323-
num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_'
325+
num_key = "num_" + str(
326+
key
327+
) # Assuming numerical keys are prefixed with 'num_'
324328
if num_key in preprocessed_data:
325329
num_tensors.append(
326330
torch.tensor(preprocessed_data[num_key], dtype=torch.float32)

0 commit comments

Comments
 (0)