Enabled Weighted Sampling #635

Draft
mkolodner-sc wants to merge 6 commits into main from mkolodner-sc/enable_weighted_sampling

mkolodner-sc commented May 12, 2026

Summary

Adds native weighted edge sampling to GiGL's distributed training pipeline via GLT's CPUWeightedSampler. When enabled, neighbors are sampled proportionally to edge weights rather than uniformly.
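
To make the uniform versus weight-proportional distinction concrete, here is a minimal pure-Python sketch of one-hop neighbor sampling. It is not GLT's CPUWeightedSampler: the function name `sample_neighbors`, the list-based inputs, and the choice to sample with replacement are all assumptions made for illustration.

```python
import random
from typing import Optional, Sequence


def sample_neighbors(
    neighbors: Sequence[int],
    num_samples: int,
    weights: Optional[Sequence[float]] = None,
    rng: Optional[random.Random] = None,
) -> list[int]:
    """Hypothetical helper: draw `num_samples` neighbors with replacement.

    With `weights=None` every neighbor is equally likely; otherwise each
    neighbor is drawn with probability proportional to its edge weight.
    """
    rng = rng or random.Random()
    if not neighbors:
        return []
    if weights is not None and len(weights) != len(neighbors):
        raise ValueError("expected one weight per neighbor")
    return rng.choices(list(neighbors), weights=weights, k=num_samples)


rng = random.Random(0)
neighbors = [10, 11, 12]
edge_weights = [0.0, 1.0, 3.0]  # the edge to node 10 has weight 0

uniform = sample_neighbors(neighbors, 1000, rng=rng)
weighted = sample_neighbors(neighbors, 1000, weights=edge_weights, rng=rng)
```

In this sketch node 10 never appears in `weighted` because its edge weight is 0, and node 12 is expected to be drawn about three times as often as node 11.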

New API

  • DistPartitioner.register_edge_weights(edge_weights) — registers a 1D per-edge weight tensor (homogeneous or dict[EdgeType, Tensor] for heterogeneous) before calling partition_edge_index_and_edge_features(). Weights are partitioned alongside edge features in the same pass (co-partitioned, mirroring the node features + labels pattern).
  • build_dataset(weight_edge_feat_name=...) — accepts the name of an existing edge feature column to use as weights. The column is sliced out and removed from the feature tensor before registration, so it is never duplicated in memory.
  • DistNeighborLoader(with_weight=True) / DistABLPLoader(with_weight=True) — enables weighted sampling. Defaults to False; must be set explicitly.
  • BaseDistLoader.validate_with_weight() — shared validation: raises ValueError if with_weight=True but no weights are registered in the dataset; raises NotImplementedError if used with PPRSamplerOptions (weight-proportional PPR residual propagation is deferred to a future PR).
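
The validation rules and the weight-column extraction described above can be sketched in plain Python. None of this is GiGL code: `LoaderArgs`, `split_weight_column`, and this standalone `validate_with_weight` are hypothetical stand-ins, rows are modeled as lists rather than tensors, and the order of the two checks is an assumption.

```python
from dataclasses import dataclass


@dataclass
class LoaderArgs:
    """Hypothetical stand-in for the loader state the check inspects."""
    with_weight: bool = False  # mirrors the documented default of False
    dataset_has_edge_weights: bool = False
    uses_ppr_sampler: bool = False


def validate_with_weight(args: LoaderArgs) -> None:
    """Sketch of the two failure modes listed for with_weight=True."""
    if not args.with_weight:
        return  # nothing to validate when weighted sampling is off
    if not args.dataset_has_edge_weights:
        raise ValueError("with_weight=True but the dataset has no edge weights")
    if args.uses_ppr_sampler:
        raise NotImplementedError("weighted sampling with PPR is not supported")


def split_weight_column(rows, column_names, weight_name):
    """Sketch: pull one named column out as weights and drop it from features."""
    idx = column_names.index(weight_name)
    weights = [row[idx] for row in rows]
    features = [row[:idx] + row[idx + 1:] for row in rows]
    remaining = column_names[:idx] + column_names[idx + 1:]
    return features, remaining, weights


features, names, weights = split_weight_column(
    rows=[[0.5, 2.0, 7.0], [0.1, 0.0, 9.0]],
    column_names=["score", "weight", "age"],
    weight_name="weight",
)
```

After the split, the weight values live only in `weights`; the remaining feature rows no longer contain that column, which is the property the PR describes as avoiding duplication.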

Implementation notes

  • GraphPartitionData.weights (field already existed) carries the partitioned weight tensor to DistDataset._initialize_graph(), which forwards it to GLT's init_graph(edge_weights=...).
  • DistDataset.has_edge_weights property reflects whether weights were registered at construction time.
  • SamplingConfig.with_weight is now threaded through from the loader rather than hardcoded to False.
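
As an illustration of co-partitioning, the sketch below assigns each edge to a partition once and routes its feature row and weight in the same loop, so the per-partition edge IDs, feature IDs, and weights stay aligned. The modulo-on-source-node assignment and every name here (`Partition`, `partition_edges`) are assumptions for the example, not GiGL's partitioning scheme.

```python
from dataclasses import dataclass, field
from typing import Optional, Sequence


@dataclass
class Partition:
    """Hypothetical per-rank container; GiGL's real types differ."""
    edge_ids: list[int] = field(default_factory=list)
    edges: list[tuple[int, int]] = field(default_factory=list)
    feature_ids: list[int] = field(default_factory=list)
    features: list[list[float]] = field(default_factory=list)
    weights: list[float] = field(default_factory=list)


def partition_edges(
    edges: Sequence[tuple[int, int]],
    num_partitions: int,
    features: Optional[Sequence[list[float]]] = None,
    weights: Optional[Sequence[float]] = None,
) -> list[Partition]:
    """Route each edge, plus its optional feature row and weight, in one pass."""
    parts = [Partition() for _ in range(num_partitions)]
    for edge_id, (src, dst) in enumerate(edges):
        part = parts[src % num_partitions]  # assumed assignment rule
        part.edge_ids.append(edge_id)
        part.edges.append((src, dst))
        if features is not None:
            part.feature_ids.append(edge_id)
            part.features.append(features[edge_id])
        if weights is not None:
            part.weights.append(weights[edge_id])
    return parts


parts = partition_edges(
    edges=[(0, 1), (1, 2), (2, 3), (3, 0)],
    num_partitions=2,
    features=[[0.1], [0.2], [0.3], [0.4]],
    weights=[1.0, 0.0, 2.0, 0.5],
)
```

With both inputs supplied, `edge_ids == feature_ids` holds on every partition in this sketch; passing only `weights` leaves `features` and `feature_ids` empty.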

Tests

New test file tests/unit/distributed/distributed_weighted_sampling_test.py with 8 tests:

  • Correctness (homogeneous + heterogeneous): weight=0 edges to "bad" nodes are never traversed in sampled subgraphs — verified by encoding node class in features and asserting no bad node appears after weighted sampling.
  • Partitioner edge cases: features only, weights only, neither, both (with consistency check that GraphPartitionData.edge_ids == FeaturePartitionData.ids), and heterogeneous partial weights (one edge type weighted, another not).
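
The shape of the correctness check can be sketched on a toy graph. This is not the test in the PR: the graph, the `sample_subgraph` helper, and the trial counts are hypothetical, and sampling uses Python's `random.choices` rather than GLT. The idea it illustrates is the same: encode each node's class in its feature, reach "bad" nodes only through weight-0 edges, and count bad nodes in the sampled subgraphs.

```python
import random

# Toy graph: node -> list of (neighbor, edge_weight).
# Nodes 0-3 are "good"; nodes 4-5 are "bad" and only reachable via weight-0 edges.
ADJ = {
    0: [(1, 1.0), (4, 0.0)],
    1: [(2, 1.0), (5, 0.0)],
    2: [(3, 1.0), (4, 0.0)],
    3: [(0, 1.0), (5, 0.0)],
    4: [(0, 1.0)],
    5: [(1, 1.0)],
}
# Node class encoded as a one-element feature: 0.0 = good, 1.0 = bad.
NODE_FEATURES = {n: [1.0 if n >= 4 else 0.0] for n in ADJ}


def sample_subgraph(seed, fanout, num_hops, rng):
    """Multi-hop weighted sampling; returns the set of visited nodes."""
    visited, frontier = {seed}, [seed]
    for _ in range(num_hops):
        next_frontier = []
        for node in frontier:
            nbrs = [n for n, _ in ADJ[node]]
            wts = [w for _, w in ADJ[node]]
            next_frontier.extend(rng.choices(nbrs, weights=wts, k=fanout))
        visited.update(next_frontier)
        frontier = next_frontier
    return visited


def count_bad_nodes(num_trials=200):
    """Count bad nodes across many sampled subgraphs seeded at good nodes."""
    rng = random.Random(42)
    bad = 0
    for trial in range(num_trials):
        nodes = sample_subgraph(seed=trial % 4, fanout=2, num_hops=3, rng=rng)
        bad += sum(1 for n in nodes if NODE_FEATURES[n] == [1.0])
    return bad
```

Because every edge into a bad node has weight 0, a correct weighted sampler should leave `count_bad_nodes()` at zero regardless of the seed or fanout.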

mkolodner-sc changed the title from "[WIP] Enabled Weighted Sampling" to "Enabled Weighted Sampling" on May 13, 2026