Skip to content

Commit

Permalink
Pass in strict flag to allow non-strict state_dict loading
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/tnt#691

This allows users to ignore keys in the state_dict that aren't part of
the given module, or part of the module that aren't in the state_dict.

The default value for the flag (true) keeps the status quo and what the pytorch
interface uses.

Reviewed By: anshulverma

Differential Revision: D53066198

fbshipit-source-id: 8a849f46d09d6e7d9185d589b966f7a7a089d9fc
  • Loading branch information
schwarzmx authored and facebook-github-bot committed Jan 26, 2024
1 parent 0e60109 commit 2b6a030
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
21 changes: 21 additions & 0 deletions tests/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ def test_nn_sequential(tmp_path: Path) -> None:
assert check_state_dict_eq(foo.state_dict(), bar.state_dict())


@pytest.mark.usefixtures("toggle_batching")
def test_strict_false(tmp_path: Path) -> None:
foo = torch.nn.Sequential(
torch.nn.Linear(128, 64),
torch.nn.Linear(64, 32),
torch.nn.Linear(32, 16),
)
bar = torch.nn.Sequential(
torch.nn.Linear(128, 64),
torch.nn.Linear(64, 32),
torch.nn.Linear(32, 16),
torch.nn.Linear(16, 8),
)
assert not check_state_dict_eq(foo.state_dict(), bar.state_dict())

expected_dict = foo.state_dict()
snapshot = Snapshot.take(str(tmp_path), {"foo": foo})
snapshot.restore({"foo": bar}, strict=False)
assert check_state_dict_eq(foo.state_dict(), expected_dict)


@pytest.mark.usefixtures("toggle_batching")
def test_adagrad(tmp_path: Path) -> None:
model = torch.nn.Sequential(
Expand Down
15 changes: 13 additions & 2 deletions torchsnapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def async_take(
unique_id=unique_id,
)

def restore(self, app_state: AppState) -> None:
def restore(self, app_state: AppState, strict: bool = True) -> None:
"""
Restores the application state from the snapshot.
Expand All @@ -312,6 +312,10 @@ def restore(self, app_state: AppState) -> None:
``app_state`` needs to be either identical to or a subset of the
``app_state`` used for :func:`Snapshot.take` when the snapshot was
taken.
strict (bool, optional): If ``True``, raises an error if the expected
state_dict keys in the snapshot do not match the actual keys in
the :class:`torch.nn.Module`. This only applies to :class:`torch.nn.Module`
and not other objects being restored in ``app_state``.
"""
torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore")
self._validate_app_state(app_state)
Expand Down Expand Up @@ -340,6 +344,7 @@ def restore(self, app_state: AppState) -> None:
self._load_stateful(
stateful_key=key,
stateful=app_state.get(key),
strict=strict,
storage=storage,
pg=pg_wrapper,
event_loop=event_loop,
Expand All @@ -352,6 +357,7 @@ def restore(self, app_state: AppState) -> None:
self._load_stateful(
stateful_key=key,
stateful=stateful,
strict=strict,
storage=storage,
pg=pg_wrapper,
event_loop=event_loop,
Expand Down Expand Up @@ -662,6 +668,7 @@ def _load_stateful( # noqa
self,
stateful_key: str,
stateful: Optional[Stateful],
strict: bool,
storage: StoragePlugin,
pg: PGWrapper,
event_loop: asyncio.AbstractEventLoop,
Expand Down Expand Up @@ -737,7 +744,11 @@ def _load_stateful( # noqa
flattened={k: fut.obj for k, fut in futs.items()},
prefix=stateful_key,
)
stateful.load_state_dict(state_dict)

if isinstance(stateful, torch.nn.Module):
stateful.load_state_dict(state_dict, strict=strict)
else:
stateful.load_state_dict(state_dict)

@staticmethod
def _write_snapshot_metadata(
Expand Down

0 comments on commit 2b6a030

Please sign in to comment.