
Commit 5cf2c69

update sa folder logic
1 parent 6d26bba commit 5cf2c69

5 files changed

Lines changed: 256 additions & 8 deletions


stage_advantage/README.md

Lines changed: 35 additions & 7 deletions
@@ -141,7 +141,7 @@ TrainConfig(

 ### Usage

-From the **repository root**:
+From the **repository root**, the core training command is:

 ```bash
 # Single GPU (KAI0 or PI06)
@@ -163,6 +163,17 @@ uv run python scripts/train_pytorch.py ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD --exp_n

 Logs and checkpoints go to `experiment/<config_name>/` and `experiment/<config_name>/log/<exp_name>.log`. Redirect to a log file if desired, e.g. `2>&1 | tee experiment/ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD/log/run1.log`.

+For a ready-to-use script with environment setup (conda/venv activation, DDP configuration) and automatic log management, see **`annotation/train_estimator.sh`**:
+
+```bash
+RUNNAME=ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD RUNTIME=run1 bash stage_advantage/annotation/train_estimator.sh
+
+# Multi-GPU
+RUNNAME=ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD RUNTIME=run1 NPROC_PER_NODE=8 bash stage_advantage/annotation/train_estimator.sh
+```
+
+The shell script handles output directory creation, log redirection (via `tee`), and multi-GPU/multi-node dispatch automatically.
+
 ### Training Outputs

 ```
@@ -212,7 +223,7 @@ experiment/ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD/ # or ADVANTAGE_TORCH_PI06_FLATTE

 ### Usage

-From the **repository root** (or ensure Python can import the project and paths are correct):
+From the **repository root**, the core evaluation command is:

 ```bash
 uv run python stage_advantage/annotation/eval.py <model_type> <model_name> <repo_id>
@@ -228,7 +239,13 @@ uv run python stage_advantage/annotation/eval.py Flatten-Fold KAI0 /path/to/data
 uv run python stage_advantage/annotation/eval.py Flatten-Fold PI06 /path/to/dataset
 ```

-`<model_type>` is a key in `eval.py`’s `MODELS_CONFIG_MAP` (e.g. `Flatten-Fold`); `<model_name>` is `PI06` or `KAI0`; `<repo_id>` is the path to the LeRobot dataset. Results are written under `<repo_id>/data_<model_name>_<ckpt_steps>/`.
+`<model_type>` is a key in `eval.py`'s `MODELS_CONFIG_MAP` (e.g. `Flatten-Fold`); `<model_name>` is `PI06` or `KAI0`; `<repo_id>` is the path to the LeRobot dataset. Results are written under `<repo_id>/data_<model_name>_<ckpt_steps>/`.
+
+For a ready-to-use script with environment setup (conda/venv activation, environment variables) and status logging, see **`annotation/eval.sh`**:
+
+```bash
+bash stage_advantage/annotation/eval.sh Flatten-Fold KAI0 /path/to/dataset
+```

 ### Evaluation Outputs

@@ -286,14 +303,22 @@ At **inference** time you must use the **same prompt format** as in training. To

 ### Usage

-From the repository root, run JAX training with the AWBC config and an experiment name:
+From the repository root, the core training command is:

 ```bash
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_flatten_fold_awbc --exp_name=run1
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_tee_shirt_sort_awbc --exp_name=run1
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_hang_cloth_awbc --exp_name=run1
 ```

+For a ready-to-use script with environment setup (venv activation, `XLA_PYTHON_CLIENT_MEM_FRACTION`, `WANDB_MODE`) and automatic log management, see **`awbc/train_awbc.sh`**:
+
+```bash
+RUNNAME=pi05_flatten_fold_awbc RUNTIME=run1 bash stage_advantage/awbc/train_awbc.sh
+```
+
+The shell script handles output directory creation and log redirection (via `tee`) automatically.
+
 ---

 ## Directory Structure
@@ -304,9 +329,12 @@ stage_advantage/
 ├── annotation/              # Stages 0–2: labeling & estimator training
 │   ├── README.md
 │   ├── gt_label.py          # Core labeling script (progress → advantage → task_index)
-│   ├── gt_labeling.sh       # Batch labeling for PI06 / KAI0 variants (only .sh kept here)
+│   ├── gt_labeling.sh       # Batch labeling for PI06 / KAI0 variants
+│   ├── train_estimator.sh   # Shell script for Stage 1 training (env + DDP + logging)
 │   ├── eval.py              # Evaluate trained estimator on datasets
+│   ├── eval.sh              # Shell script for Stage 2 evaluation (env + logging)
 │   └── evaluator.py         # SimpleValueEvaluator: batched GPU inference
-└── awbc/                    # Stage 3: AWBC (commands in README)
-    └── README.md
+└── awbc/                    # Stage 3: AWBC
+    ├── README.md
+    └── train_awbc.sh        # Shell script for Stage 3 AWBC training (env + logging)
 ```
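
The two new scripts compose into the Stage 1 → Stage 2 flow this README describes. A minimal end-to-end sketch, assuming a KAI0 estimator, default single-GPU settings, and a placeholder dataset path of `/path/to/dataset`:

```bash
# Stage 1: train the advantage estimator
# (logs land in experiment/<RUNNAME>/log/<RUNTIME>.log)
RUNNAME=ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD RUNTIME=run1 \
  bash stage_advantage/annotation/train_estimator.sh

# Stage 2: label the dataset with the trained estimator
# (first point MODELS_CONFIG_MAP in eval.py at the run1 checkpoint)
bash stage_advantage/annotation/eval.sh Flatten-Fold KAI0 /path/to/dataset
```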

stage_advantage/annotation/eval.sh

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
#!/bin/bash
###############################################################################
# eval.sh
#
# Use a trained Advantage Estimator to label a dataset with predicted
# advantage values (relative_advantage, absolute_value, absolute_advantage).
#
# This script calls eval.py, which:
#   1. Loads a trained Advantage Estimator checkpoint
#   2. Iterates over all episodes in the LeRobot dataset
#   3. Reads video frames from three camera views (top, left, right)
#   4. Runs batched GPU inference to predict advantage values per frame
#   5. Writes results as new parquet files with advantage columns appended
#
# The output parquets are saved under:
#   <repo_id>/data_<model_name>_<ckpt_steps>/chunk-*/episode_*.parquet
#
# Prerequisites:
#   - A trained Advantage Estimator checkpoint (from Stage 1)
#   - Update MODELS_CONFIG_MAP in eval.py with the correct checkpoint paths
#
# Usage:
#   bash eval.sh <model_type> <model_name> <repo_id>
#
# Examples:
#   bash eval.sh Flatten-Fold KAI0 /path/to/dataset
#   bash eval.sh Flatten-Fold PI06 /path/to/dataset
#
# Arguments:
#   model_type : Flatten-Fold / demo_A / demo_B
#   model_name : PI06 (single-timestep) / KAI0 (two-timestep stage-level)
#   repo_id    : Path to the LeRobot dataset to evaluate
###############################################################################
set -xe
set -o pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../../" && pwd)"
cd "${PROJECT_ROOT}"
echo "Project root: ${PROJECT_ROOT}"

# ─── Conda / venv activation ─────────────────────────────────────────────────
source /cpfs01/shared/smch/miniconda3/etc/profile.d/conda.sh
conda activate uv_py311
source .venv/bin/activate

export TZ='Asia/Shanghai'

# ─── Other environment variables ─────────────────────────────────────────────
export UV_DEFAULT_INDEX="https://mirrors.aliyun.com/pypi/simple/"
export WANDB_MODE=offline

# ─── Parse arguments ─────────────────────────────────────────────────────────
MODEL_TYPE=${1:?"Usage: bash eval.sh <model_type> <model_name> <repo_id>"}
MODEL_NAME=${2:?"Usage: bash eval.sh <model_type> <model_name> <repo_id>"}
REPO_ID=${3:?"Usage: bash eval.sh <model_type> <model_name> <repo_id>"}

echo "============================================================"
echo " Advantage Estimator Evaluation"
echo "   Model type: ${MODEL_TYPE}"
echo "   Model name: ${MODEL_NAME}"
echo "   Dataset:    ${REPO_ID}"
echo "============================================================"

uv run python "${SCRIPT_DIR}/eval.py" "${MODEL_TYPE}" "${MODEL_NAME}" "${REPO_ID}"

echo "============================================================"
echo " Evaluation complete!"
echo " Results saved under: ${REPO_ID}/data_${MODEL_NAME}_*/"
echo "============================================================"
stage_advantage/annotation/train_estimator.sh

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
#!/bin/bash
###############################################################################
# train_estimator.sh
###############################################################################
set -xe
set -o pipefail

# ─── Navigate to project root ────────────────────────────────────────────────
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../../" && pwd)"
cd "${PROJECT_ROOT}"

source .venv/bin/activate

# ─── Training config name ────────────────────────────────────────────────────
# RUNNAME must be one of: ADVANTAGE_TORCH_PI06_FLATTEN_FOLD, ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD
# Default to ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD if RUNNAME is not set
CFG=${RUNNAME:-ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD}

# ─── DDP environment variables ───────────────────────────────────────────────
WORLD_SIZE=${WORLD_SIZE:-1}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
RANK=${RANK:-0}
NPROC_PER_NODE=${NPROC_PER_NODE:-1}
MASTER_PORT=${MASTER_PORT:-12345}

# ─── Validate required environment variables ─────────────────────────────────
if [ -z "${RUNNAME+x}" ]; then
  echo "[WARNING] RUNNAME is not set, using default: ${CFG}"
  export RUNNAME=${CFG}
else
  echo "RUNNAME is set to: ${RUNNAME}"
fi

if [ -z "${RUNTIME+x}" ]; then
  echo "[ERROR] RUNTIME is not set. Please set RUNTIME for the experiment output directory."
  echo "        Example: RUNTIME=run1 bash train_estimator.sh"
  exit 1
else
  echo "RUNTIME is set to: ${RUNTIME}"
fi

# ─── Create output directories ───────────────────────────────────────────────
OUTPUT_DIR="./experiment/${RUNNAME}"
LOG_OUTPUT_DIR="${OUTPUT_DIR}/log"
mkdir -p "${OUTPUT_DIR}" "${LOG_OUTPUT_DIR}"

# Set to "offline" for offline logging; remove or set to "online" for cloud sync
export WANDB_MODE=${WANDB_MODE:-offline}

if [ "${NPROC_PER_NODE}" -gt 1 ] || [ "${WORLD_SIZE}" -gt 1 ]; then
  # Multi-GPU / multi-node training via torchrun
  echo "Launching DDP training with torchrun..."
  uv run torchrun \
    --nnodes=${WORLD_SIZE} \
    --nproc_per_node=${NPROC_PER_NODE} \
    --node_rank=${RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    scripts/train_pytorch.py ${CFG} \
    --exp_name=${RUNTIME} \
    --save_interval 10000 \
    2>&1 | tee "${LOG_OUTPUT_DIR}/${RUNTIME}.log"
else
  # Single-GPU training
  echo "Launching single-GPU training..."
  uv run python scripts/train_pytorch.py ${CFG} \
    --exp_name=${RUNTIME} \
    --save_interval 10000 \
    2>&1 | tee "${LOG_OUTPUT_DIR}/${RUNTIME}.log"
fi
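
All of the DDP variables above are read from the environment, so a multi-node launch is just a matter of exporting them consistently on each machine. A sketch, assuming two 8-GPU nodes and a placeholder master address of `10.0.0.1`:

```bash
# On node 0 (also hosts the rendezvous master)
WORLD_SIZE=2 RANK=0 MASTER_ADDR=10.0.0.1 MASTER_PORT=12345 NPROC_PER_NODE=8 \
  RUNNAME=ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD RUNTIME=run1 \
  bash stage_advantage/annotation/train_estimator.sh

# On node 1: identical invocation except RANK
WORLD_SIZE=2 RANK=1 MASTER_ADDR=10.0.0.1 MASTER_PORT=12345 NPROC_PER_NODE=8 \
  RUNNAME=ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD RUNTIME=run1 \
  bash stage_advantage/annotation/train_estimator.sh
```

Note that the script feeds `WORLD_SIZE` into `--nnodes` and `RANK` into `--node_rank`, so `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT` must agree across nodes while `RANK` differs per node.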

stage_advantage/awbc/README.md

Lines changed: 9 additions & 1 deletion
@@ -36,7 +36,7 @@ Each uses `base_config=DataConfig(prompt_from_task=True)` so that the dataset’

 ## Usage

-From the **repository root**, run training with the config name and `--exp_name`:
+From the **repository root**, the core training command is:

 ```bash
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_flatten_fold_awbc --exp_name=run1
@@ -46,6 +46,14 @@ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_hang_cloth_awbc

 Checkpoints and logs are written under `experiment/<config_name>/<exp_name>/` (e.g. `experiment/pi05_flatten_fold_awbc/run1/`).

+For a ready-to-use script with environment setup (venv activation, `XLA_PYTHON_CLIENT_MEM_FRACTION`, `WANDB_MODE`) and automatic log management, see **`train_awbc.sh`**:
+
+```bash
+RUNNAME=pi05_flatten_fold_awbc RUNTIME=run1 bash stage_advantage/awbc/train_awbc.sh
+```
+
+The shell script handles output directory creation and log redirection (via `tee`) automatically.
+
 ## Prompt format (training and inference)

 During **training**, the prompt is taken from **`meta/tasks.jsonl`**: each sample’s `task_index` is mapped to a string (written by `gt_label.py` when creating the advantage dataset).
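
To confirm the training prompts before launching AWBC, it can help to eyeball `meta/tasks.jsonl` directly. A sketch, assuming the advantage dataset lives at the placeholder path `data/FlattenFold/advantage` and that each JSONL row pairs a `task_index` with its prompt string:

```bash
# Show the first few task_index -> prompt mappings written by gt_label.py
head -n 5 data/FlattenFold/advantage/meta/tasks.jsonl
```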

stage_advantage/awbc/train_awbc.sh

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
#!/bin/bash
###############################################################################
# train_awbc.sh
#
# Train a policy with Advantage-Weighted Behavior Cloning (AWBC) using
# advantage-estimator-labeled data. The data must have task_index per frame and
# meta/tasks.jsonl mapping task_index -> prompt string (from Stage 0 + Stage 2).
#
# Configs (see src/openpi/training/config.py):
#   pi05_flatten_fold_awbc
#   pi05_tee_shirt_sort_awbc
#   pi05_hang_cloth_awbc
#
# Prerequisites:
#   - Complete Stage 0 (GT labeling) and Stage 2 (advantage estimation on data),
#     then run gt_label.py with --advantage-source absolute_advantage to produce
#     the "advantage" dataset with task_index and tasks.jsonl.
#   - Set repo_id in the AWBC config to the path of that dataset
#     (e.g. <path_to_repo_root>/data/FlattenFold/advantage).
#   - Run compute_norm_states_fast.py for the chosen config before training.
#   - Set weight_loader in config to your π₀.5 base checkpoint.
#
# Usage:
#   RUNNAME=pi05_flatten_fold_awbc RUNTIME=run1 bash stage_advantage/awbc/train_awbc.sh
###############################################################################
set -xe
set -o pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../../" && pwd)"
cd "${PROJECT_ROOT}"

source .venv/bin/activate

# ─── Training config name ────────────────────────────────────────────────────
# RUNNAME must be one of: pi05_flatten_fold_awbc, pi05_tee_shirt_sort_awbc, pi05_hang_cloth_awbc
CFG=${RUNNAME:-pi05_flatten_fold_awbc}

# ─── Validate required environment variables ─────────────────────────────────
if [ -z "${RUNNAME+x}" ]; then
  echo "[WARNING] RUNNAME is not set, using default: ${CFG}"
  export RUNNAME=${CFG}
else
  echo "RUNNAME is set to: ${RUNNAME}"
fi

if [ -z "${RUNTIME+x}" ]; then
  echo "[ERROR] RUNTIME is not set. Please set RUNTIME for the experiment output directory."
  echo "        Example: RUNTIME=run1 bash stage_advantage/awbc/train_awbc.sh"
  exit 1
else
  echo "RUNTIME is set to: ${RUNTIME}"
fi

# ─── Output directories ──────────────────────────────────────────────────────
OUTPUT_DIR="./experiment/${RUNNAME}"
LOG_OUTPUT_DIR="${OUTPUT_DIR}/log"
mkdir -p "${OUTPUT_DIR}" "${LOG_OUTPUT_DIR}"

export WANDB_MODE=${WANDB_MODE:-offline}
export XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}

# ─── Launch JAX training ─────────────────────────────────────────────────────
echo "Launching AWBC training (JAX)..."
uv run scripts/train.py ${CFG} \
  --exp_name=${RUNTIME} \
  2>&1 | tee "${LOG_OUTPUT_DIR}/${RUNTIME}.log"

echo "============================================================"
echo " AWBC training finished. Checkpoints: ${OUTPUT_DIR}/${RUNTIME}/"
echo "============================================================"
