**`stage_advantage/README.md`** (+35 lines, −7 lines)
````diff
@@ -141,7 +141,7 @@ TrainConfig(
 
 ### Usage
 
-From the **repository root**:
+From the **repository root**, the core training command is:
 
 ```bash
 # Single GPU (KAI0 or PI06)
````
````diff
@@ -163,6 +163,17 @@ uv run python scripts/train_pytorch.py ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD --exp_n
 
 Logs and checkpoints go to `experiment/<config_name>/` and `experiment/<config_name>/log/<exp_name>.log`. Redirect to a log file if desired, e.g. `2>&1 | tee experiment/ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD/log/run1.log`.
 
+For a ready-to-use script with environment setup (conda/venv activation, DDP configuration) and automatic log management, see **`annotation/train_estimator.sh`**:
````
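The log-management pattern such a wrapper follows can be sketched roughly as below. This is a minimal sketch of the directory-creation and `tee` redirection described above only; the config name is one example from the README, the training command is left as a comment, and the real `annotation/train_estimator.sh` (conda/venv activation, DDP configuration) may differ:

```shell
#!/usr/bin/env bash
# Illustrative sketch of the log-management pattern, NOT the real
# annotation/train_estimator.sh.
set -euo pipefail

CONFIG_NAME="ADVANTAGE_TORCH_KAI0_FLATTEN_FOLD"  # example config from the README
EXP_NAME="${1:-run1}"                            # experiment name, default run1
LOG_DIR="experiment/${CONFIG_NAME}/log"

# Create the log directory before redirecting into it.
mkdir -p "${LOG_DIR}"

# Replace the echo with the actual training command, e.g.:
#   uv run python scripts/train_pytorch.py "${CONFIG_NAME}" --exp_name="${EXP_NAME}"
echo "training ${CONFIG_NAME} (${EXP_NAME})" 2>&1 | tee "${LOG_DIR}/${EXP_NAME}.log"
```

The `2>&1 | tee` pairing keeps output on the console while also persisting it to `experiment/<config_name>/log/<exp_name>.log`, matching the layout stated above.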
````diff
 uv run python stage_advantage/annotation/eval.py Flatten-Fold PI06 /path/to/dataset
 ```
 
-`<model_type>` is a key in `eval.py`’s `MODELS_CONFIG_MAP` (e.g. `Flatten-Fold`); `<model_name>` is `PI06` or `KAI0`; `<repo_id>` is the path to the LeRobot dataset. Results are written under `<repo_id>/data_<model_name>_<ckpt_steps>/`.
+`<model_type>` is a key in `eval.py`'s `MODELS_CONFIG_MAP` (e.g. `Flatten-Fold`); `<model_name>` is `PI06` or `KAI0`; `<repo_id>` is the path to the LeRobot dataset. Results are written under `<repo_id>/data_<model_name>_<ckpt_steps>/`.
+
+For a ready-to-use script with environment setup (conda/venv activation, environment variables) and status logging, see **`annotation/eval.sh`**:
````
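Concretely, the argument-to-output mapping described above can be sketched in shell. The checkpoint step count `30000` is a hypothetical value for illustration; only the `<repo_id>/data_<model_name>_<ckpt_steps>/` convention comes from the README:

```shell
#!/usr/bin/env bash
# Illustrative mapping of eval arguments to the results directory; the
# checkpoint step count is hypothetical, not read from eval.py.
set -euo pipefail

MODEL_TYPE="Flatten-Fold"    # key in eval.py's MODELS_CONFIG_MAP
MODEL_NAME="PI06"            # PI06 or KAI0
REPO_ID="/path/to/dataset"   # path to the LeRobot dataset
CKPT_STEPS="30000"           # hypothetical checkpoint step count

# eval.py writes results under <repo_id>/data_<model_name>_<ckpt_steps>/:
OUT_DIR="${REPO_ID}/data_${MODEL_NAME}_${CKPT_STEPS}"
echo "results for ${MODEL_TYPE} would land in: ${OUT_DIR}"
```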
````diff
@@ -286,14 +303,22 @@ At **inference** time you must use the **same prompt format** as in training. To
 
 ### Usage
 
-From the repository root, run JAX training with the AWBC config and an experiment name:
+From the repository root, the core training command is:
 
 ```bash
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_flatten_fold_awbc --exp_name=run1
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_tee_shirt_sort_awbc --exp_name=run1
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_hang_cloth_awbc --exp_name=run1
 ```
 
+For a ready-to-use script with environment setup (venv activation, `XLA_PYTHON_CLIENT_MEM_FRACTION`, `WANDB_MODE`) and automatic log management, see **`awbc/train_awbc.sh`**:
````
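The environment-setup pattern such a wrapper follows can be sketched as below. The venv path, `WANDB_MODE` default, and log layout are assumptions for illustration, not the actual `awbc/train_awbc.sh`:

```shell
#!/usr/bin/env bash
# Illustrative sketch of the environment setup + logging pattern, NOT the
# real awbc/train_awbc.sh.
set -euo pipefail

# source .venv/bin/activate                  # venv activation (path assumed)
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9    # cap JAX GPU memory preallocation
export WANDB_MODE="${WANDB_MODE:-offline}"   # assumed default when W&B is unconfigured

CONFIG_NAME="pi05_flatten_fold_awbc"         # one of the AWBC configs above
EXP_NAME="${1:-run1}"
LOG_DIR="experiment/${CONFIG_NAME}/${EXP_NAME}"
mkdir -p "${LOG_DIR}"

# Replace the echo with:
#   uv run scripts/train.py "${CONFIG_NAME}" --exp_name="${EXP_NAME}"
echo "awbc training ${CONFIG_NAME}" 2>&1 | tee "${LOG_DIR}/train.log"
```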
**`stage_advantage/awbc/README.md`** (+9 lines, −1 line)
````diff
@@ -36,7 +36,7 @@ Each uses `base_config=DataConfig(prompt_from_task=True)` so that the dataset’
 
 ## Usage
 
-From the **repository root**, run training with the config name and `--exp_name`:
+From the **repository root**, the core training command is:
 
 ```bash
 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_flatten_fold_awbc --exp_name=run1
````
````diff
@@ -46,6 +46,14 @@ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_hang_cloth_awbc
 
 Checkpoints and logs are written under `experiment/<config_name>/<exp_name>/` (e.g. `experiment/pi05_flatten_fold_awbc/run1/`).
 
+For a ready-to-use script with environment setup (venv activation, `XLA_PYTHON_CLIENT_MEM_FRACTION`, `WANDB_MODE`) and automatic log management, see **`train_awbc.sh`**:
+The shell script handles output directory creation and log redirection (via `tee`) automatically.
 
 ## Prompt format (training and inference)
 
 During **training**, the prompt is taken from **`meta/tasks.jsonl`**: each sample’s `task_index` is mapped to a string (written by `gt_label.py` when creating the advantage dataset).
````
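The `task_index` → prompt resolution can be sketched in shell. The `meta/tasks.jsonl` layout below is the usual LeRobot convention and the task strings are invented placeholders; the real file is produced by `gt_label.py`, and the grep/sed lookup is a naive stand-in for what the training loader does:

```shell
#!/usr/bin/env bash
# Sketch: resolve a sample's task_index to its prompt via meta/tasks.jsonl.
# File contents and lookup are illustrative, not from gt_label.py.
set -euo pipefail

ROOT="$(mktemp -d)"
mkdir -p "${ROOT}/meta"
cat > "${ROOT}/meta/tasks.jsonl" <<'EOF'
{"task_index": 0, "task": "fold the shirt"}
{"task_index": 1, "task": "sort the tee shirts"}
EOF

# Naive lookup of the prompt for task_index 1 (one JSON object per line):
PROMPT=$(grep '"task_index": 1' "${ROOT}/meta/tasks.jsonl" \
  | sed 's/.*"task": "\(.*\)"}.*/\1/')
echo "prompt: ${PROMPT}"
```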