
Commit b96e65f

[minor]: update sa readme

1 parent 2e08ecf commit b96e65f

2 files changed: 20 additions & 15 deletions

stage_advantage/README.md

Lines changed: 16 additions & 12 deletions
````diff
@@ -1,36 +1,40 @@
 # Stage Advantage Pipeline
 
-This module implements a two-stage pipeline for training an **Advantage Estimator** and using it in **Advantage-Weighted Behavior Cloning (AWBC)**.
+This module implements a pipeline for training an **Advantage Estimator** and using it in **Advantage-Weighted Behavior Cloning (AWBC)**.
 
 ## Pipeline Overview
 
 ```
 ┌──────────────────────────────────────────────────────────────────────────┐
 │ Stage 0: GT Labeling (annotation/gt_labeling.sh + gt_label.py) │
-│ Compute advantage from progress and assign task_index labels
+│ Compute advantage (from progress or from Stage 2 output) → task_index
 ├──────────────────────────────────────────────────────────────────────────┤
 │ Stage 1: Train Advantage Estimator (annotation/train_estimator.sh) │
-│ Fine-tune pi0 model to predict advantage from observations │
+│ Fine-tune pi0 model to predict advantage from observations
 ├──────────────────────────────────────────────────────────────────────────┤
-│ Stage 2: Advantage Estimation on New Data (annotation/eval.py) │
-│ Use trained estimator to label new datasets with advantage values
+│ Stage 2: Advantage Estimation on New Data (annotation/eval.py)
+│ Use trained estimator → parquets with data_PI06_* / data_KAI0_*
 ├──────────────────────────────────────────────────────────────────────────┤
-│ Stage 3: AWBC Training (awbc/train_awbc.sh)
-│ Train policy with advantage-weighted behavior cloning
+│ Stage 3: AWBC Training (scripts/train.py pi05_*_awbc)
+│ Train policy with advantage-weighted behavior cloning (prompt_from_task)
 └──────────────────────────────────────────────────────────────────────────┘
 ```
 
+**End-to-end order for AWBC:** (1) Stage 0 on data with `progress` → optional for Stage 1. (2) Stage 1 → train estimator. (3) Stage 2 → run eval on your dataset so it gets `data_PI06_100000/` or `data_KAI0_100000/` with advantage columns. (4) Run Stage 0 again with `--advantage-source absolute_advantage` on that dataset (e.g. via `gt_labeling.sh` with `DATA_PATH` = the repo you ran eval on, and source subdirs `data_PI06_100000` / `data_KAI0_100000`). (5) Point AWBC config `repo_id` at the resulting advantage-labeled directory and run Stage 3 training.
+
 ---
 
 ## Stage 0: GT Data Labeling
 
-**Goal**: Compute advantage values from raw trajectory progress and label each frame with a discretized `task_index`.
+**Goal**: Compute advantage values (from `progress` or from Stage 2’s `absolute_advantage`) and label each frame with a discretized `task_index`; write `meta/tasks.jsonl` (prompt strings per `task_index`).
 
 **Script**: `annotation/gt_labeling.sh` (calls `annotation/gt_label.py`)
 
+**For AWBC:** Run Stage 2 (eval) first so the dataset has `data_PI06_100000/` or `data_KAI0_100000/` with advantage columns. Then run Stage 0 with `--advantage-source absolute_advantage` on that output (e.g. set `gt_labeling.sh`’s `DATA_PATH` to the eval repo and use source subdirs `data_PI06_100000` / `data_KAI0_100000`; the script copies them into the target’s `data/` and runs `gt_label.py`).
+
 ### How it works
 
-1. **Prepare dataset directory**: Copy/link the raw dataset (parquet + videos + meta) into a new working directory with standard LeRobot layout.
+1. **Prepare dataset directory**: Copy/link the source (parquet + videos + meta) into a new working directory with standard LeRobot layout. For AWBC, the source parquets are the Stage 2 output (with `absolute_advantage`).
 2. **Compute advantage**: For each frame `i`, the advantage is defined as:
    ```
    advantage[i] = progress[i + chunk_size] - progress[i]
````
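The advantage definition in the diff above can be sketched in NumPy. This is an illustrative reconstruction, not code from the repo: the function name, the clamp-to-final-frame boundary handling near the end of a trajectory, and the binary threshold used for `task_index` are all assumptions.

```python
import numpy as np

def compute_advantage(progress: np.ndarray, chunk_size: int) -> np.ndarray:
    """advantage[i] = progress[i + chunk_size] - progress[i].

    Frames within chunk_size of the end are compared against the final
    progress value (an assumed boundary choice, not taken from gt_label.py).
    """
    n = len(progress)
    future = np.minimum(np.arange(n) + chunk_size, n - 1)
    return progress[future] - progress

# Toy trajectory: progress rises monotonically from 0 to 1 over 6 frames.
progress = np.array([0.0, 0.1, 0.3, 0.6, 0.9, 1.0])
adv = compute_advantage(progress, chunk_size=2)

# Discretize into a binary task_index (the 0.25 threshold is illustrative).
task_index = (adv > 0.25).astype(int)
```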
````diff
@@ -276,9 +280,9 @@ At **inference** time you must use the **same prompt format** as in training. To
 
 ### Before training
 
-1. Produce the advantage dataset (Stage 0 + Stage 2) and place it at e.g. `./data/FlattenFold/advantage`.
-2. In `config.py`, set **`repo_id`** to that path and **`weight_loader`** to your π₀.5 base checkpoint for the three AWBC configs you use.
-3. Compute norm stats:
+1. **Produce the advantage dataset:** Run Stage 2 (eval) on your dataset so it has `data_PI06_100000/` or `data_KAI0_100000/`. Then run Stage 0 (e.g. `gt_labeling.sh`) with `DATA_PATH` = that repo and source subdirs `data_PI06_100000` / `data_KAI0_100000`; the script outputs a directory with `data/` (parquets with `task_index`), `meta/tasks.jsonl`, and `videos`. Use that directory as the advantage dataset (e.g. copy or link it to `./data/FlattenFold/advantage`).
+2. In `config.py`, set **`repo_id`** to that advantage dataset path and **`weight_loader`** to your π₀.5 base checkpoint for the AWBC config(s) you use.
+3. **Compute norm stats:**
    `uv run python scripts/compute_norm_states_fast.py --config-name pi05_flatten_fold_awbc`
    (and similarly for `pi05_tee_shirt_sort_awbc` / `pi05_hang_cloth_awbc` if needed.)
````
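Stage 3 trains the policy with an advantage-weighted behavior cloning loss. The commit does not show that loss, so here is a minimal sketch assuming the standard AWR-style exponential weighting; `beta`, the clip value `w_max`, and both function names are illustrative rather than the repo's actual training code.

```python
import numpy as np

def awbc_weights(advantage: np.ndarray, beta: float = 1.0, w_max: float = 10.0) -> np.ndarray:
    # Exponential weighting: actions with higher estimated advantage
    # contribute more to the cloning loss; clip keeps weights bounded.
    return np.minimum(np.exp(advantage / beta), w_max)

def awbc_loss(log_probs: np.ndarray, advantage: np.ndarray, beta: float = 1.0) -> float:
    # Advantage-weighted negative log-likelihood over a batch of actions.
    w = awbc_weights(advantage, beta)
    return float(-(w * log_probs).mean())

# High-advantage actions are weighted up, low-advantage ones down;
# at advantage 0 the weight is exactly 1 (plain behavior cloning).
adv = np.array([-1.0, 0.0, 1.0])
weights = awbc_weights(adv)
```

With binary `task_index` labels (as in this pipeline's `*_binary` datasets), the weighting can degenerate to simply filtering or prompting on the high-advantage bin, which is consistent with the `prompt_from_task` note in the diagram.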

stage_advantage/annotation/gt_labeling.sh

Lines changed: 4 additions & 3 deletions
````diff
@@ -55,11 +55,12 @@ prepare_and_label() {
 }
 
 # ─── Dataset variants (only PI06 and KAI0) ─────────────────────────────────────
-# PI06: single-timestep / timeline-style labeling (1 stage)
-prepare_and_label "data_1T_TL_100000" "${base_name}_PI06_binary" ""
+# Source subdirs must match Stage 2 (eval) output: data_PI06_100000 / data_KAI0_100000
+# PI06: single-timestep labeling (1 stage)
+prepare_and_label "data_PI06_100000" "${base_name}_PI06_binary" ""
 
 # KAI0: two-stage, stage-level labeling
-prepare_and_label "data_2T_SL_100000" "${base_name}_KAI0_abs_binary" "--stage-nums 2"
+prepare_and_label "data_KAI0_100000" "${base_name}_KAI0_abs_binary" "--stage-nums 2"
 
 echo "============================================================"
 echo " All datasets labeled successfully!"
````
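The labeling step these `prepare_and_label` calls run ends by writing `meta/tasks.jsonl`, which maps each discretized `task_index` to its prompt string. A minimal sketch of that file format, assuming the LeRobot-style one-JSON-object-per-line layout; the field names, helper functions, and prompt strings below are illustrative, not taken from `gt_label.py`.

```python
import json

def write_tasks_jsonl(path: str, tasks: dict) -> None:
    # One JSON object per line: {"task_index": ..., "task": ...}
    with open(path, "w") as f:
        for idx in sorted(tasks):
            f.write(json.dumps({"task_index": idx, "task": tasks[idx]}) + "\n")

def read_tasks_jsonl(path: str) -> dict:
    # Inverse of the writer: recover the task_index -> prompt mapping.
    with open(path) as f:
        return {rec["task_index"]: rec["task"] for rec in map(json.loads, f)}

# Hypothetical binary labeling: one prompt per advantage bin.
write_tasks_jsonl("tasks.jsonl", {0: "low-advantage behavior",
                                  1: "high-advantage behavior"})
```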
