From 2b6a0300f625e3667111fe48af49f4e83f4a8a3c Mon Sep 17 00:00:00 2001 From: Fernando Hernandez Date: Fri, 26 Jan 2024 10:09:20 -0800 Subject: [PATCH] Pass in `strict` flag to allow non-strict state_dict loading Summary: X-link: https://github.com/pytorch/tnt/pull/691 This allows users to ignore keys in the state_dict that aren't part of the given module, or part of the module that aren't in the state_dict. The default value for the flag (true) keeps the status quo and what the pytorch interface uses. Reviewed By: anshulverma Differential Revision: D53066198 fbshipit-source-id: 8a849f46d09d6e7d9185d589b966f7a7a089d9fc --- tests/test_snapshot.py | 21 +++++++++++++++++++++ torchsnapshot/snapshot.py | 15 +++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py index aaa3d0b..ef3b41b 100644 --- a/tests/test_snapshot.py +++ b/tests/test_snapshot.py @@ -75,6 +75,27 @@ def test_nn_sequential(tmp_path: Path) -> None: assert check_state_dict_eq(foo.state_dict(), bar.state_dict()) +@pytest.mark.usefixtures("toggle_batching") +def test_strict_false(tmp_path: Path) -> None: + foo = torch.nn.Sequential( + torch.nn.Linear(128, 64), + torch.nn.Linear(64, 32), + torch.nn.Linear(32, 16), + ) + bar = torch.nn.Sequential( + torch.nn.Linear(128, 64), + torch.nn.Linear(64, 32), + torch.nn.Linear(32, 16), + torch.nn.Linear(16, 8), + ) + assert not check_state_dict_eq(foo.state_dict(), bar.state_dict()) + + expected_dict = foo.state_dict() + snapshot = Snapshot.take(str(tmp_path), {"foo": foo}) + snapshot.restore({"foo": bar}, strict=False) + assert check_state_dict_eq(foo.state_dict(), expected_dict) + + @pytest.mark.usefixtures("toggle_batching") def test_adagrad(tmp_path: Path) -> None: model = torch.nn.Sequential( diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py index ce9a8ce..a6984e0 100644 --- a/torchsnapshot/snapshot.py +++ b/torchsnapshot/snapshot.py @@ -303,7 +303,7 @@ def async_take( unique_id=unique_id, ) - def restore(self, app_state: AppState) -> None: + def restore(self, app_state: AppState, strict: bool = True) -> None: """ Restores the application state from the snapshot. @@ -312,6 +312,10 @@ def restore(self, app_state: AppState) -> None: ``app_state`` needs to be either identical to or a subset of the ``app_state`` used for :func:`Snapshot.take` when the snapshot was taken. + strict (bool, optional): If ``True``, raises an error if the expected + state_dict keys in the snapshot do not match the actual keys in + the :class:`torch.nn.Module`. This only applies to :class:`torch.nn.Module` + and not other objects being restored in ``app_state``. """ torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore") self._validate_app_state(app_state) @@ -340,6 +344,7 @@ def restore(self, app_state: AppState) -> None: self._load_stateful( stateful_key=key, stateful=app_state.get(key), + strict=strict, storage=storage, pg=pg_wrapper, event_loop=event_loop, @@ -352,6 +357,7 @@ def restore(self, app_state: AppState) -> None: self._load_stateful( stateful_key=key, stateful=stateful, + strict=strict, storage=storage, pg=pg_wrapper, event_loop=event_loop, @@ -662,6 +668,7 @@ def _load_stateful( # noqa self, stateful_key: str, stateful: Optional[Stateful], + strict: bool, storage: StoragePlugin, pg: PGWrapper, event_loop: asyncio.AbstractEventLoop, @@ -737,7 +744,11 @@ def _load_stateful( # noqa flattened={k: fut.obj for k, fut in futs.items()}, prefix=stateful_key, ) - stateful.load_state_dict(state_dict) + + if isinstance(stateful, torch.nn.Module): + stateful.load_state_dict(state_dict, strict=strict) + else: + stateful.load_state_dict(state_dict) @staticmethod def _write_snapshot_metadata(