Description & Motivation
TensorDict is a dictionary-like class, developed by the PyTorch team, that inherits tensor properties such as indexing, shape operations, and casting to device.
Currently Fabric does not support them:
from lightning.fabric import Fabric
import torch
from tensordict import TensorDict


def main():
    fabric = Fabric(devices=2, accelerator="cpu", strategy="ddp")
    fabric.launch()
    d = TensorDict({"a": torch.rand(10, 1, 3), "b": torch.rand(10, 2, 7)}, batch_size=[10])
    gathered = fabric.all_gather(d)
    fabric.print(gathered)
    reduced = fabric.all_reduce(d)
    fabric.print(reduced)


if __name__ == "__main__":
    main()
gives the following error:
Traceback (most recent call last):
  File "/home/belerico/Desktop/lightning-apps/lightning/examples/fabric/reinforcement_learning/test.py", line 15, in <module>
    main()
  File "/home/belerico/Desktop/lightning-apps/lightning/examples/fabric/reinforcement_learning/test.py", line 11, in main
    gathered = fabric.all_gather(d)
  File "/home/belerico/Desktop/lightning-apps/lightning/src/lightning/fabric/fabric.py", line 496, in all_gather
    data = convert_to_tensors(data, device=self.device)
  File "/home/belerico/Desktop/lightning-apps/lightning/src/lightning/fabric/utilities/apply_func.py", line 107, in convert_to_tensors
    data = apply_to_collection(data, src_dtype, conversion_func, device=device)
  File "/home/belerico/miniconda3/envs/lightning-ai/lib/python3.9/site-packages/lightning_utilities/core/apply_func.py", line 73, in apply_to_collection
    return elem_type(OrderedDict(out))
  File "/home/belerico/miniconda3/envs/lightning-ai/lib/python3.9/site-packages/tensordict/tensordict.py", line 2888, in __init__
    self._batch_size = self._parse_batch_size(source, batch_size)
  File "/home/belerico/miniconda3/envs/lightning-ai/lib/python3.9/site-packages/tensordict/tensordict.py", line 2905, in _parse_batch_size
    raise ValueError(
ValueError: batch size was not specified when creating the TensorDict instance and it could not be retrieved from source.
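The traceback shows the root cause: apply_to_collection rebuilds a mapping with `elem_type(OrderedDict(out))`, which assumes the mapping type's constructor accepts a plain dict of items. TensorDict's constructor additionally requires a batch_size. A minimal stdlib-only sketch of this failure mode (BatchDict is a hypothetical stand-in for TensorDict, not a real class):

```python
from collections import OrderedDict


class BatchDict(dict):
    """Hypothetical stand-in for TensorDict: its constructor needs a batch_size."""

    def __init__(self, source, batch_size=None):
        if batch_size is None:
            raise ValueError("batch size was not specified")
        super().__init__(source)
        self.batch_size = batch_size


def apply_to_mapping(data, fn):
    # Mimics apply_to_collection: apply fn to each value, then
    # rebuild the mapping via its own type with a single argument.
    out = [(k, fn(v)) for k, v in data.items()]
    return type(data)(OrderedDict(out))  # breaks for constructors needing extra args


d = BatchDict({"a": 1, "b": 2}, batch_size=2)
try:
    apply_to_mapping(d, lambda v: v * 2)
except ValueError as e:
    print(e)  # batch size was not specified
```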
Pitch
Support TensorDict in Fabric's distributed collective functions all_gather and all_reduce.
Alternatives
No response
Additional context
With the following addition to the apply_to_collection method (https://github.com/Lightning-AI/utilities/blob/main/src/lightning_utilities/core/apply_func.py#L71):
+from tensordict import make_tensordict
+from tensordict.tensordict import TensorDictBase

 if isinstance(data, defaultdict):
     return elem_type(data.default_factory, OrderedDict(out))
+elif isinstance(data, TensorDictBase):
+    return make_tensordict(OrderedDict(out), device=kwargs.get("device", None))  # batch_size is automatically inferred
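The patch follows the same pattern apply_to_collection already uses for defaultdict: special-case mapping types whose constructors do not accept a single dict argument and rebuild them through a dedicated factory. A stdlib-only sketch of that dispatch pattern (BatchDict and make_batchdict are hypothetical stand-ins for TensorDict and make_tensordict):

```python
from collections import OrderedDict, defaultdict


class BatchDict(dict):
    """Hypothetical stand-in for TensorDict."""

    def __init__(self, source, batch_size):
        super().__init__(source)
        self.batch_size = batch_size


def make_batchdict(source):
    # Mimics make_tensordict: a factory that infers batch_size from the contents
    # (here, trivially, from the number of entries).
    return BatchDict(source, batch_size=len(source))


def apply_to_mapping(data, fn):
    out = OrderedDict((k, fn(v)) for k, v in data.items())
    # Dispatch as in the proposed patch: route types whose constructors
    # need extra arguments through a type-specific factory.
    if isinstance(data, defaultdict):
        return type(data)(data.default_factory, out)
    if isinstance(data, BatchDict):
        return make_batchdict(out)
    return type(data)(out)


d = BatchDict({"a": 1, "b": 2}, batch_size=2)
res = apply_to_mapping(d, lambda v: v * 2)
print(dict(res), res.batch_size)  # {'a': 2, 'b': 4} 2
```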
the above script runs without errors:
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 10, 1, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 10, 2, 7]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 1, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 2, 7]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
I don't know if this modification is enough in every case, but I can investigate further if this becomes a candidate feature.
cc @Borda