Skip to content

Commit db66db1

Browse files
committed
Support loading from gcs
1 parent d3fef93 commit db66db1

1 file changed

Lines changed: 4 additions & 10 deletions

File tree

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ..pipelines.wan.wan_pipeline import WanPipeline
2424
from .. import max_logging, max_utils
2525
import orbax.checkpoint as ocp
26+
from etils import epath
2627

2728
WAN_CHECKPOINT = "WAN_CHECKPOINT"
2829

@@ -33,7 +34,7 @@ def __init__(self, config, checkpoint_type):
3334
self.config = config
3435
self.checkpoint_type = checkpoint_type
3536

36-
self.checkpoint_manager = create_orbax_checkpoint_manager(
37+
self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
3738
self.config.checkpoint_dir,
3839
enable_checkpointing=True,
3940
save_interval_steps=1,
@@ -68,17 +69,10 @@ def load_wan_configs_from_orbax(self, step):
6869
)
6970
)
7071

71-
params_restore_util_way = load_params_from_path(
72-
self.config,
73-
self.checkpoint_manager,
74-
abstract_tree_structure_params,
75-
"wan_state",
76-
step
77-
)
78-
7972
max_logging.log("Restoring WAN checkpoint")
8073
restored_checkpoint = self.checkpoint_manager.restore(
81-
step,
74+
directory=epath.Path(self.config.checkpoint_dir),
75+
step=step,
8276
args=ocp.args.Composite(
8377
wan_state=params_restore,
8478
# wan_state=params_restore_util_way,

0 commit comments

Comments
 (0)