-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdemo_load_hypergraph_datasets.py
More file actions
58 lines (50 loc) · 2.07 KB
/
demo_load_hypergraph_datasets.py
File metadata and controls
58 lines (50 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python3
import os
import pickle
import logging
from trainers.trainer import get_data_loader
# Script-wide logging: INFO and above, each record tagged with the logger
# name and level, e.g. "[__main__:INFO] Loading dataset: imdb".
logging.basicConfig(
    format='[%(name)s:%(levelname)s] %(message)s',
    level=logging.INFO,
)
# Module-level logger used by main() below.
logger = logging.getLogger(__name__)
def main(base_path=os.path.join("data", "hyperbert"),
         dataset_names=("cora_co", "dblp_a", "imdb", "pubmed")):
    """Log statistics and print one sample entry for each hypergraph dataset.

    For every name in ``dataset_names`` this:

    1. unpickles ``<base_path>/<name>/<name>_pyg.pkl`` and logs its node,
       hyperedge and distinct-label counts;
    2. builds a dataloader with ``get_data_loader`` and prints the
       ``node_idx``, ``text`` and ``y`` fields of its first entry.

    A dataset whose pickle or dataloader fails to load, or whose dataloader
    yields nothing, is logged and skipped; the loop continues with the next
    dataset.

    Args:
        base_path: Directory holding one sub-directory per dataset. A relative
            path is resolved against the current working directory.
        dataset_names: Names of the datasets to process, in order.
    """
    for ds in dataset_names:
        logger.info(f"Loading dataset: {ds}")
        config = {
            "data_base_path": base_path,
            "dataset_name": ds,
            # Only the first entry is inspected below, so a batch of one suffices.
            "batch_size": 1,
        }

        # Load the underlying PyG graph to compute statistics.
        pickle_path = os.path.join(base_path, ds, f"{ds}_pyg.pkl")
        try:
            # NOTE: pickle.load can execute arbitrary code while unpickling;
            # only point this at locally produced, trusted files.
            with open(pickle_path, "rb") as file:
                pyg_graph = pickle.load(file)
        except Exception as e:
            # Best-effort demo: report the failure and move on to the next dataset.
            logger.error(f"Error loading pickle for {ds}: {e}")
            continue

        # Use getattr(..., None) + "is not None" rather than hasattr(): in PyG 2.x
        # ``Data`` defines ``x`` and ``y`` as properties that return None when
        # unset, so hasattr() is always True and ``None.shape`` would crash.
        features = getattr(pyg_graph, "x", None)
        hyperedge_index = getattr(pyg_graph, "hyperedge_index", None)
        labels = getattr(pyg_graph, "y", None)

        num_nodes = features.shape[0] if features is not None else 0
        # NOTE(review): under PyG's hypergraph convention hyperedge_index is
        # (2, num_incidences), in which case shape[1] counts node-hyperedge
        # incidences rather than distinct hyperedges - confirm the pickle format.
        num_hyperedges = hyperedge_index.shape[1] if hyperedge_index is not None else 0
        num_labels = (
            len({int(label.item()) for label in labels}) if labels is not None else 0
        )
        logger.info(f"Dataset: {ds} - Nodes: {num_nodes}, Hyperedges: {num_hyperedges}, Labels: {num_labels}")

        # Load the dataloader using get_data_loader function.
        try:
            dataloader = get_data_loader(config)
        except Exception as e:
            logger.error(f"Error loading dataloader for {ds}: {e}")
            continue

        # Show one dataset entry.
        logger.info(f"Dataset {ds}: printing one sample entry:")
        try:
            entry = next(iter(dataloader))
        except StopIteration:
            # An empty loader must not abort the remaining datasets.
            logger.warning(f"Dataset {ds}: dataloader yielded no entries")
            continue
        print("Node index:", entry["node_idx"])
        print("Text:", entry["text"])
        print("Label:", entry["y"])
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()