diff --git a/torchsnapshot/io_preparers/sharded_tensor.py b/torchsnapshot/io_preparers/sharded_tensor.py index b16d380..eb969c4 100644 --- a/torchsnapshot/io_preparers/sharded_tensor.py +++ b/torchsnapshot/io_preparers/sharded_tensor.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from functools import reduce from operator import mul -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch from torch.distributed._shard.sharded_tensor import ( @@ -199,7 +199,7 @@ def prepare_read( cls, entry: ShardedTensorEntry, obj_out: Optional[ShardedTensor] = None, - ) -> Tuple[List[ReadReq], Future[ShardedTensor | torch.Tensor]]: + ) -> Tuple[List[ReadReq], Future[Union[ShardedTensor, torch.Tensor]]]: # Note: in case obj_out is None, a Future[Tensor] will be returned if obj_out is None: obj_out = ShardedTensorIOPreparer.empty_tensor_from_sharded_tensor_entry(