wrap dtensor import for OSS compatibility
Summary: `compute_local_shape_and_global_offset()` is only available in torch nightly, so guard this import to prevent import errors for PyTorch stable users.

Reviewed By: diego-urgell

Differential Revision: D56484333

fbshipit-source-id: 6be460235b7e20ad8dd36091fc4136c091a03397
JKSenthil authored and facebook-github-bot committed Apr 24, 2024
1 parent 7c14dd4 commit 987bb5b
Showing 1 changed file with 14 additions and 1 deletion: torchsnapshot/io_preparers/dtensor.py
@@ -30,7 +30,20 @@
     Replicate,
     Shard as ShardPlacement,
 )
-from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
+
+try:
+    from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
+
+except ImportError:
+
+    def compute_local_shape_and_global_offset(
+        global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
+    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+        raise RuntimeError(
+            "Please use the latest nightly pytorch release to use this feature."
+        )
+
+
 from torchsnapshot.io_preparers.sharded_tensor import (
     _OverlappingRegion,
     ShardedTensorBufferConsumer,
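The diff above guards a nightly-only import behind try/except, substituting a stub that raises only when called. A minimal, self-contained sketch of the same pattern, using a hypothetical module name `nightly_only_module` in place of `torch.distributed._tensor._utils`:

```python
# Sketch of the guarded-import pattern from this commit. The module name
# "nightly_only_module" is hypothetical; the real code imports from
# torch.distributed._tensor._utils, which exists only in torch nightly.
try:
    from nightly_only_module import compute_local_shape_and_global_offset
except ImportError:
    # Fallback stub: importing this file succeeds on stable PyTorch;
    # only an actual call to the missing function fails, with a clear message.
    def compute_local_shape_and_global_offset(global_shape, mesh, placements):
        raise RuntimeError(
            "Please use the latest nightly pytorch release to use this feature."
        )

# The stub defers the failure from import time to call time:
msg = None
try:
    compute_local_shape_and_global_offset((4, 4), None, [])
except RuntimeError as e:
    msg = str(e)
print(msg)
```

The key design choice is that the fallback keeps module import side-effect free: stable-PyTorch users can import and use the rest of the module, and only the nightly-dependent code path fails, with an actionable error.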
