pytorch / rl

1.9K 40.0 254.0 61.28 MB

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

Home Page: https://pytorch.org/rl

License: MIT License

Python 98.88% C++ 0.35% Shell 0.60% Batchfile 0.14% PowerShell 0.03%
ai control decision-making distributed-computing machine-learning marl model-based-reinforcement-learning multi-agent-reinforcement-learning pytorch reinforcement-learning rl robotics torch

rl's Introduction

TorchRL

Documentation | TensorDict | Features | Examples, tutorials and demos | Citation | Installation | Asking a question | Contributing

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

It provides PyTorch- and Python-first, low- and high-level abstractions for RL that are intended to be efficient, modular, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort.

This repo attempts to align with the existing PyTorch ecosystem libraries in that it has a dataset pillar (torchrl/envs), transforms, models, data utilities (e.g. collectors and containers), etc. TorchRL aims to have as few dependencies as possible (Python standard library, NumPy and PyTorch). Common environment libraries (e.g. OpenAI Gym) are only optional.

On the low-level end, torchrl comes with a set of highly re-usable functionals for cost functions, returns and data processing.

TorchRL aims at (1) high modularity and (2) good runtime performance. Read the full paper for a more curated description of the library.

Getting started

Check our Getting Started tutorials to quickly ramp up with the basic features of the library!

Documentation and knowledge base

The TorchRL documentation can be found here. It contains tutorials and the API reference.

TorchRL also provides an RL knowledge base to help you debug your code, or simply learn the basics of RL. Check it out here.

We have some introductory videos to help you get to know the library better; check them out:

Writing simplified and portable RL codebase with TensorDict

RL algorithms are very heterogeneous, and it can be hard to recycle a codebase across settings (e.g. from online to offline, from state-based to pixel-based learning). TorchRL solves this problem through TensorDict, a convenient data structure(1) that can be used to streamline one's RL codebase. With this tool, one can write a complete PPO training script in fewer than 100 lines of code!

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
model = TensorDictModule(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 2),
        NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"]
)
critic = ValueOperator(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
actor = ProbabilisticActor(
    model,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"min": -1.0, "max": 1.0},
    return_log_prob=True
    )
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    )
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=1000,
    total_frames=1_000_000
    )
loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data.view(-1))
        for i in range(20):  # consume data
            sample = buffer.sample(50)  # mini-batch
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items() if
                key.startswith("loss")
                )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")

The environment API relies on tensordict to carry data from one function to another during a rollout execution.

TensorDict makes it easy to re-use pieces of code across environments, models and algorithms.

For instance, here's how to code a rollout in TorchRL:

- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
    model,
    in_keys=["observation_pixels", "observation_vector"],
    out_keys=["action"],
)
out = []
for i in range(n_steps):
-     action, log_prob = policy(obs)
-     next_obs, reward, done, info = env.step(action)
-     out.append((obs, next_obs, action, log_prob, reward, done))
-     obs = next_obs
+     tensordict = policy(tensordict)
+     tensordict = env.step(tensordict)
+     out.append(tensordict)
+     tensordict = step_mdp(tensordict)  # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0)  # TensorDict supports multiple tensor operations
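The rollout above can be mimicked with plain Python dicts to show what each call reads and writes. This is a minimal, stdlib-only sketch: `ToyEnv`, `toy_policy` and `step_mdp_sketch` are hypothetical stand-ins rather than TorchRL APIs, and the sketch assumes the nested layout used in the PPO example above (results of a step stored under a `"next"` entry, as in `data['next', 'reward']`).

```python
# Hypothetical stand-ins: none of these names are TorchRL APIs.


class ToyEnv:
    """A counter environment: the observation is the step index, episodes last 3 steps."""

    def reset(self):
        self.t = 0
        return {"observation": 0, "done": False}

    def step(self, td):
        # read the action written by the policy, write the results under "next"
        self.t += 1
        td = dict(td)
        td["next"] = {
            "observation": self.t,
            "reward": float(td["action"]),
            "done": self.t >= 3,
        }
        return td


def toy_policy(td):
    # reads "observation", writes "action"
    td = dict(td)
    td["action"] = td["observation"] * 2
    return td


def step_mdp_sketch(td):
    # the entries stored under "next" become the new root-level state
    return dict(td["next"])


env = ToyEnv()
td = env.reset()
out = []
for _ in range(3):
    td = toy_policy(td)
    td = env.step(td)
    out.append(td)
    td = step_mdp_sketch(td)
    if td["done"]:
        break

# "stack": one list per key, mirroring torch.stack(out, 0)
stacked = {key: [step[key] for step in out] for key in ("observation", "action")}
print(stacked)
```

Because every stage only reads and writes named entries, swapping the policy or the environment does not change the loop itself.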

Using this, TorchRL abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing all primitives to be easily recycled across settings.

Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):

- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
-     replay_buffer.add((obs, next_obs, action, hidden_state, reward, done))
+     replay_buffer.extend(tensordict)  # a collector batch holds many transitions
    for j in range(num_optim_steps):
-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+         tensordict = replay_buffer.sample(batch_size)
+         loss = loss_fn(tensordict)
        loss.backward()
        optim.step()
        optim.zero_grad()

This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
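To illustrate why such a loop needs so few assumptions, here is a hypothetical, stdlib-only replay buffer (`RingBufferSketch` is not a TorchRL class). It never inspects the items it stores, so the same buffer serves dicts of any layout. Overwriting the oldest entries once the buffer is full is a write policy chosen for this sketch.

```python
import random


class RingBufferSketch:
    """A fixed-capacity buffer that stores arbitrary items and samples uniformly."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._data = []
        self._cursor = 0  # next slot to overwrite once the buffer is full

    def __len__(self):
        return len(self._data)

    def extend(self, items):
        for item in items:
            if len(self._data) < self.capacity:
                self._data.append(item)
            else:
                self._data[self._cursor] = item
            self._cursor = (self._cursor + 1) % self.capacity

    def sample(self, batch_size, rng=random):
        # uniform sampling with replacement; items are returned untouched
        return [rng.choice(self._data) for _ in range(batch_size)]


buffer = RingBufferSketch(capacity=3)
buffer.extend({"obs": i, "reward": float(i)} for i in range(5))
batch = buffer.sample(4, rng=random.Random(0))
print(len(buffer), [item["obs"] for item in buffer._data])
```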

TensorDict supports multiple tensor operations on its device and shape. The shape of a TensorDict, or its batch size, is given by the first N dimensions shared by all of its tensors, where N is arbitrary:

# stack and cat
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location
tensordict.cuda()
tensordict.to("cuda:1")
tensordict.share_memory_()
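The batch-size rule stated above can be written down in a few lines. This stdlib-only sketch works on shape tuples instead of tensors. `common_batch_dims` and `is_valid_batch_size` are hypothetical helpers and are not part of the tensordict API.

```python
def common_batch_dims(shapes):
    """Longest prefix of leading dimensions shared by every shape."""
    shapes = [tuple(s) for s in shapes]
    if not shapes:
        return ()
    prefix = []
    for dims in zip(*shapes):  # zip stops at the shortest shape
        if all(d == dims[0] for d in dims):
            prefix.append(dims[0])
        else:
            break
    return tuple(prefix)


def is_valid_batch_size(shapes, batch_size):
    """A batch size is valid if it is a prefix of every entry's shape."""
    batch_size = tuple(batch_size)
    n = len(batch_size)
    return all(tuple(s)[:n] == batch_size for s in shapes)


# shapes of a 4-env, 20-step rollout
rollout = {"observation": (4, 20, 3), "action": (4, 20, 1), "done": (4, 20, 1)}
print(common_batch_dims(rollout.values()))  # (4, 20)
```

Any prefix of that result, including the empty one, is a valid batch size. For example, shapes (10, 32, 512) and (20, 32, 512) share no leading dimension, so only an empty batch size fits them.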

TensorDict comes with a dedicated tensordict.nn module that contains everything you might need to write your model with it. And it is functorch and torch.compile compatible!

transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[])  # src and tgt share no leading dim
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
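At its core, the key-based wrapping shown above reads inputs by name and writes outputs by name. Below is a stdlib-only sketch of that idea. `KeyedModule` is a hypothetical class, and plain dicts stand in for TensorDict. The sketch assumes the wrapper writes its outputs into the dict it receives and returns that same dict, which matches how the snippet above reads `tensordict["out"]` after the call.

```python
class KeyedModule:
    """Calls `fn` on the entries named by in_keys and stores results under out_keys."""

    def __init__(self, fn, in_keys, out_keys):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def __call__(self, td):
        outputs = self.fn(*(td[key] for key in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)  # a single output is treated as a 1-tuple
        for key, value in zip(self.out_keys, outputs):
            td[key] = value
        return td


adder = KeyedModule(lambda src, tgt: src + tgt, in_keys=["src", "tgt"], out_keys=["out"])
td = {"src": 2, "tgt": 3}
adder(td)
print(td["out"])  # 5

splitter = KeyedModule(divmod, in_keys=["a", "b"], out_keys=["quotient", "remainder"])
print(splitter({"a": 7, "b": 2}))
```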

The TensorDictSequential class makes it possible to chain nn.Module instances in a highly modular way. For instance, here is an implementation of a transformer using the encoder and decoder blocks:

encoder_module = TransformerEncoder(...)
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]

TensorDictSequential can isolate subgraphs by querying a set of desired input / output keys:

transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
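The key bookkeeping behind `in_keys`, `out_keys` and `select_subsequence` can be sketched in plain Python. This is an approximation under stated assumptions, not the TensorDictSequential implementation. `Node` and `SequenceSketch` are hypothetical. The sketch assumes that a sequence's inputs are the keys read before any earlier module writes them. It also assumes that subsequence selection keeps the modules that depend on the given inputs (forward pass) or that feed the requested outputs (backward pass).

```python
class Node:
    """A named step that reads `in_keys` and writes `out_keys`."""

    def __init__(self, name, in_keys, out_keys):
        self.name = name
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)


class SequenceSketch:
    def __init__(self, *nodes):
        self.nodes = list(nodes)

    @property
    def in_keys(self):
        # keys that must be provided from outside: read before any node writes them
        produced, keys = set(), []
        for node in self.nodes:
            for key in node.in_keys:
                if key not in produced and key not in keys:
                    keys.append(key)
            produced.update(node.out_keys)
        return keys

    @property
    def out_keys(self):
        keys = []
        for node in self.nodes:
            keys.extend(k for k in node.out_keys if k not in keys)
        return keys

    def select_subsequence(self, in_keys=None, out_keys=None):
        nodes = self.nodes
        if in_keys is not None:
            # forward pass: keep nodes reading something reachable from in_keys
            available, kept = set(in_keys), []
            for node in nodes:
                if available.intersection(node.in_keys):
                    kept.append(node)
                    available.update(node.out_keys)
            nodes = kept
        if out_keys is not None:
            # backward pass: keep nodes whose outputs are (transitively) needed
            needed, kept = set(out_keys), []
            for node in reversed(nodes):
                if needed.intersection(node.out_keys):
                    kept.append(node)
                    needed.update(node.in_keys)
            nodes = kept[::-1]
        return SequenceSketch(*nodes)


encoder = Node("encoder", ["src", "src_mask"], ["memory"])
decoder = Node("decoder", ["tgt", "memory"], ["output"])
transformer = SequenceSketch(encoder, decoder)
print(transformer.in_keys)   # ['src', 'src_mask', 'tgt']
print(transformer.out_keys)  # ['memory', 'output']
print([n.name for n in transformer.select_subsequence(out_keys=["memory"]).nodes])  # ['encoder']
print([n.name for n in transformer.select_subsequence(in_keys=["tgt", "memory"]).nodes])  # ['decoder']
```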

Check TensorDict tutorials to learn more!

Features

  • A common interface for environments which supports common libraries (OpenAI Gym, DeepMind Control, etc.)(1) and stateless execution (e.g. model-based environments). The batched environment containers allow parallel execution(2). A common PyTorch-first tensor-specification class is also provided. TorchRL's environments API is simple but stringent and specific. Check the documentation and tutorial to learn more!

    env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
    env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
    tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
    assert tensordict.shape == [4, 20]  # 4 envs, 20 steps rollout
    env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
  • multiprocess and distributed data collectors(2) that work synchronously or asynchronously. Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised learning (although the "dataloader" -- read data collector -- is modified on-the-fly):

    env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
    collector = MultiaSyncDataCollector(
        [env_make, env_make],
        policy=policy,
        devices=["cuda:0", "cuda:0"],
        total_frames=10000,
        frames_per_batch=50,
        ...
    )
    for i, tensordict_data in enumerate(collector):
        loss = loss_module(tensordict_data)
        loss.backward()
        optim.step()
        optim.zero_grad()
        collector.update_policy_weights_()

    Check our distributed collector examples to learn more about ultra-fast data collection with TorchRL.

  • efficient(2) and generic(1) replay buffers with modularized storage:

    storage = LazyMemmapStorage(  # memory-mapped (physical) storage
        cfg.buffer_size,
        scratch_dir="/tmp/"
    )
    buffer = TensorDictPrioritizedReplayBuffer(
        alpha=0.7,
        beta=0.5,
        collate_fn=lambda x: x,
        pin_memory=device != torch.device("cpu"),
        prefetch=10,  # multi-threaded sampling
        storage=storage
    )

    Replay buffers are also offered as wrappers around common datasets for offline RL:

    from torchrl.data.replay_buffers import SamplerWithoutReplacement
    from torchrl.data.datasets.d4rl import D4RLExperienceReplay
    data = D4RLExperienceReplay(
        "maze2d-open-v0",
        split_trajs=True,
        batch_size=128,
        sampler=SamplerWithoutReplacement(drop_last=True),
    )
    for sample in data:  # or alternatively sample = data.sample()
        fun(sample)
  • cross-library environment transforms(1), executed on device and in a vectorized fashion(2), which process and prepare the data coming out of the environments to be used by the agent:

    env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
    env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
    env = TransformedEnv(
        env_base,
        Compose(
            ToTensorImage(),
            ObservationNorm(loc=0.5, scale=1.0)),  # executes the transforms once and on device
    )
    tensordict = env.reset()
    assert tensordict.device == torch.device("cuda:0")

    Other transforms include: reward scaling (RewardScaling), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of successive frames (CatFrames), resizing (Resize) and many more.

    Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it easy to add and remove them at will:

    env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at the index 0

    Nevertheless, transforms can access and execute operations on the parent environment:

    transform = env.transform[1]  # retrieves the second transform of the list
    parent_env = transform.parent  # returns the environment the second transform sees, i.e. the base env + the first transform
  • various tools for distributed learning (e.g. memory mapped tensors)(2);

  • various architectures and models (e.g. actor-critic)(1):

    # create an nn.Module
    common_module = ConvNet(
        bias_last_layer=True,
        depth=None,
        num_cells=[32, 64, 64],
        kernel_sizes=[8, 4, 3],
        strides=[4, 2, 1],
    )
    # Wrap it in a SafeModule, indicating what key to read in and where to
    # write out the output
    common_module = SafeModule(
        common_module,
        in_keys=["pixels"],
        out_keys=["hidden"],
    )
    # Wrap the policy module in NormalParamWrapper, such that the output
    # tensor is split in loc and scale, and scale is mapped onto a positive space
    policy_module = SafeModule(
        NormalParamWrapper(
            MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
        ),
        in_keys=["hidden"],
        out_keys=["loc", "scale"],
    )
    # Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
    # SafeProbabilisticModule, indicating how to build the
    # torch.distribution.Distribution object and what to do with it
    policy_module = SafeProbabilisticTensorDictSequential(  # stochastic policy
        policy_module,
        SafeProbabilisticModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=TanhNormal,
        ),
    )
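    # Wrap the value network in a ValueOperator that reads the shared
    # "hidden" features and writes the "state_value" key
    value_module = ValueOperator(
        MLP(
            num_cells=[64, 64],
            out_features=1,
            activation=nn.ELU,
        ),
        in_keys=["hidden"],
    )
    # Wrap the policy and value function in a common module
    actor_value = ActorValueOperator(common_module, policy_module, value_module)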
    # standalone policy from this
    standalone_policy = actor_value.get_policy_operator()
  • exploration wrappers and modules to easily swap between exploration and exploitation(1):

    policy_explore = EGreedyWrapper(policy)
    with set_exploration_type(ExplorationType.RANDOM):
        tensordict = policy_explore(tensordict)  # will use eps-greedy
    with set_exploration_type(ExplorationType.MODE):
        tensordict = policy_explore(tensordict)  # will not use eps-greedy
  • A series of efficient loss modules and highly vectorized functional return and advantage computation.

    Loss modules

    from torchrl.objectives import DQNLoss
    loss_module = DQNLoss(value_network=value_network, gamma=0.99)
    tensordict = replay_buffer.sample(batch_size)
    loss = loss_module(tensordict)

    Advantage computation

    from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
    advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
  • a generic trainer class(1) that executes the aforementioned training loop. Through a hooking mechanism, it also supports any logging or data transformation operation at any given time.

  • various recipes to build models that correspond to the environment being deployed.
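The advantage functions listed above are vectorized, but the recursion they compute is easier to read as a loop. The following is a reference sketch of generalized advantage estimation (GAE) in plain Python. It assumes per-step lists of rewards, values, next-state values and done flags. `gae_reference` is a hypothetical helper, and it treats every done flag as a true termination, whereas TorchRL distinguishes `done` from `terminated`.

```python
def gae_reference(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Return (advantages, value_targets), computed backwards over one trajectory."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # one-step TD error: r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


adv, targets = gae_reference(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
    gamma=1.0,
    lmbda=1.0,
)
print(adv)  # [3.0, 2.0, 1.0]
```

With gamma = lambda = 1 and zero value estimates, the advantage reduces to the reward-to-go, which makes the example easy to check by hand. The default gamma=0.99 and lmbda=0.95 mirror the values passed to GAE in the PPO script above.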

If you feel a feature is missing from the library, please submit an issue! If you would like to contribute to new features, check our call for contributions and our contribution page.

Examples, tutorials and demos

A series of examples is provided for illustrative purposes:

and many more to come!

Check the examples markdown directory for more details about handling the various configuration settings.

We also provide tutorials and demos that give a sense of what the library can do.

Citation

If you're using TorchRL, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Installation

Create a conda environment where the packages will be installed.

conda create --name torch_rl python=3.9
conda activate torch_rl

PyTorch

Depending on how you plan to use functorch, you may want to install the latest (nightly) PyTorch release or the latest stable version of PyTorch. See here for a detailed list of commands, including pip3 or other special installation instructions.

TorchRL

You can install the latest stable release by using

pip3 install torchrl

This should work on Linux, Windows 10 and macOS (Intel or Apple Silicon chips). On certain Windows machines (Windows 11), one should install the library locally (see below).

The nightly build can be installed via

pip install torchrl-nightly

which we currently only ship for Linux and macOS (Intel) machines. Importantly, the nightly builds require the nightly builds of PyTorch too.

To install extra dependencies, call

pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,checkpointing]"

or a subset of these.

One may also desire to install the library locally. Three main reasons can motivate this:

  • the nightly/stable release isn't available for one's platform (e.g. Windows 11, nightlies for Apple Silicon etc.);
  • contributing to the code;
  • installing torchrl with a previous version of PyTorch (note that this should also be doable via a regular install followed by a downgrade to a previous PyTorch version, but the C++ binaries will not be available).

To install the library locally, start by cloning the repo:

git clone https://github.com/pytorch/rl

Go to the directory where you have cloned the torchrl repo and install it (after installing ninja)

cd /path/to/torchrl/
pip install ninja -U
python setup.py develop

(unfortunately, pip install -e . will not work).

On M1 machines, this should work out-of-the-box with the nightly build of PyTorch. If building this artifact on an M1 Mac fails, or if the message (mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e')) appears at runtime, then try

ARCHFLAGS="-arch arm64" python setup.py develop

To run a quick sanity check, leave that directory (e.g. by executing cd ~/) and try to import the library.

python -c "import torchrl"

This should not return any warning or error.

Optional dependencies

The following libraries can be installed depending on the usage one wants to make of torchrl:

# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher

# rendering
pip3 install moviepy

# deepmind control suite
pip3 install dm_control

# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame

# tests
pip3 install pytest pyyaml pytest-instafail

# tensorboard
pip3 install tensorboard

# wandb
pip3 install wandb

Troubleshooting

If a ModuleNotFoundError: No module named 'torchrl._torchrl' error occurs (or a warning indicating that the C++ binaries could not be loaded), it means that the C++ extensions were not installed or not found.

  • One common reason might be that you are trying to import torchrl from within the git repo location. The following code snippet should return an error if torchrl has not been installed in develop mode:
    cd ~/path/to/rl/repo
    python -c 'from torchrl.envs.libs.gym import GymEnv'
    
    If this is the case, consider executing torchrl from another location.
  • If you're not importing torchrl from within its repo location, it could be caused by a problem during the local installation. Check the log after the python setup.py develop. One common cause is a g++/C++ version discrepancy and/or a problem with the ninja library.
  • If the problem persists, feel free to open an issue on the topic in the repo; we'll do our best to help!
  • On macOS, we recommend installing Xcode first. With Apple Silicon M1 chips, make sure you are using the arm64-built Python (e.g. here). Running the following lines of code
    wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
    python collect_env.py
    
    should display
    OS: macOS *** (arm64)
    
    and not
    OS: macOS **** (x86_64)
    

Versioning issues can cause error messages such as undefined symbol. For these, refer to the versioning issues document for a complete explanation and proposed workarounds.

Asking a question

If you spot a bug in the library, please raise an issue in this repo.

If you have a more generic question regarding RL in PyTorch, post it on the PyTorch forum.

Contributing

Internal collaborations to torchrl are welcome! Feel free to fork, submit issues and PRs. You can check out the detailed contribution guide here. As mentioned above, a list of open contributions can be found here.

Contributors are recommended to install pre-commit hooks (using pre-commit install). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending -n to your commit command: git commit -m <commit message> -n

Disclaimer

This library is released as a PyTorch beta feature. BC-breaking changes are likely to happen, but they will be announced with a deprecation warning and will only take effect after a few release cycles.

License

TorchRL is licensed under the MIT License. See LICENSE for details.

rl's People

Contributors

albertbou92, apbard, benjamin-eecs, blonck, by571, danilbaibak, degensean, fedebotu, franktiantt, matteobettini, nairbv, nicolas-dufour, ordinskiy, osalpekar, remidomingues, riiswa, robandpdx, rohitnig, romainjln, seemethere, shagunsodhani, skandermoalla, sriramsk1999, tcbegley, vmoens, xiaomengy, xmaples, yohann-benchetrit, yushiyangk, zeenolife

rl's Issues

Prevent DataCollectors from passing private tensordict keys

Some tensordict keys may benefit from being hidden from the agent by default, e.g. to avoid overloading the shared memory. The tracked variables of the OU exploration noise are one such case.
We should implement a collector rule that every tensordict key that starts with an underscore is private and should not be passed.
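The proposed rule could look like the following stdlib-only sketch. Plain nested dicts stand in for tensordicts, and `public_view` is a hypothetical helper. The sketch assumes string keys and drops private entries at every nesting level, which is one possible reading of the proposal.

```python
def public_view(td):
    """Return a copy of `td` without keys that start with an underscore."""
    return {
        key: public_view(value) if isinstance(value, dict) else value
        for key, value in td.items()
        if not key.startswith("_")
    }


collected = {
    "observation": [0.1, 0.2],
    "_ou_prev_noise": [0.05],  # exploration state that should stay in the worker
    "next": {"reward": 1.0, "_ou_step": 4},
}
print(public_view(collected))
```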

[Feature Request] TensorDict gradient support

We've overlooked gradient support of tensordicts so far but some checks should be run to make sure that all runs as expected.

New features

  • is_leaf() method, that tells whether a tensordict has values that require a gradient or not

New tests

  • behaviour when setting a tensor that requires grad in a SubTensorDict / SavedTensorDict instance

[Feature Request] Make `SavedTensorDict` lazy

Nearly all ops in SavedTensorDict could be done lazily:

class SavedTensorDict(_TensorDict):
    def register_op(self, op, args, kwargs):
        self._op_register.append((op, args, kwargs))

    def some_op(self, *args, **kwargs):
        self.register_op("some_op", args, kwargs)

    def _load(self):
        td = load(...)
        for op, args, kwargs in self._op_register:
            td = getattr(td, op)(*args, **kwargs)
        return td

    def _save_td(self, td):
        ...
        self._op_register = []
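The register-then-replay idea can be tried out without any tensordict machinery. In this stdlib-only sketch, `LazyOps` is a hypothetical class and the `loader` callable stands in for reading a saved tensordict from disk. As in the snippet above, operations are identified by method name.

```python
class LazyOps:
    """Records method calls and replays them on the object produced by `loader`."""

    def __init__(self, loader):
        self._loader = loader
        self._op_register = []

    def register_op(self, op, *args, **kwargs):
        self._op_register.append((op, args, kwargs))
        return self  # allows chaining register_op calls

    def load(self):
        obj = self._loader()
        for op, args, kwargs in self._op_register:
            obj = getattr(obj, op)(*args, **kwargs)
        return obj


calls = []


def load_from_disk():
    calls.append("load")  # records each (expensive) load
    return "  saved tensordict  "


lazy = LazyOps(load_from_disk).register_op("strip").register_op("replace", " ", "_")
print(calls)        # [] (nothing has been loaded yet)
print(lazy.load())  # saved_tensordict
print(calls)        # ['load']
```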

Remove singleton last dims in tensordicts

There are several occurrences of td.get(key).squeeze(-1) in the code due to early bugs in tensordicts where some of the values had shape that matched exactly the batch size.
If this is resolved, we should allow envs/policies to write tensors with a single value and no shape.

Unify losses that have target parameters

In the past we had two clearly separated loss classes for DQN and DoubleDQN.
Our convert_to_functional API allows us to easily create a target_network_parameters list that can play efficiently with functorch, thereby making it easy to create a copy of the parameters that can be used as target.

However, we still have separate objects for DQNLoss and DoubleDQNLoss, DDPGLoss and DoubleDDPGLoss, SACLoss and DoubleSACLoss, REDQLoss and DoubleREDQLoss.
This goes against our philosophy of having as few inherited classes as possible. All of them can easily be merged in a single class.

It would make more sense to tell the constructor whether the target parameters should be created.

The plan would be to move the class attributes delay_value (for DQN), delay_actor and delay_value (for DDPG) to the __init__ constructor of the parent class and get rid of the subclass.
For SAC, we would just need to have the args delay_actor, delay_qvalue, delay_value in the constructor.

Here is a list of classes to update for the record:

  • DQNLoss and DoubleDQNLoss
  • DDPGLoss and DoubleDDPGLoss
  • SACLoss and DoubleSACLoss
  • REDQLoss and DoubleREDQLoss.

Importantly, those changes should be reflected in the tests and in the scripts!
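The proposed merge could be prototyped as a single class whose constructor decides whether a separate copy of the parameters is kept as the target. This is a hypothetical, stdlib-only sketch: dicts of floats stand in for parameter tensors, `UnifiedValueLossSketch` is not a TorchRL class, and the soft-update rule in `update_target` is an illustrative assumption rather than the library's target updater.

```python
import copy


class UnifiedValueLossSketch:
    """One class for both variants: delay_value=True plays the role of the Double* subclass."""

    def __init__(self, value_params, delay_value=False):
        self.value_params = value_params
        self.delay_value = delay_value
        # with delay_value the target is an independent copy, otherwise it aliases the live params
        self.target_value_params = copy.deepcopy(value_params) if delay_value else value_params

    def update_target(self, tau=1.0):
        """Move the target params towards the live params (tau=1.0 is a hard copy)."""
        if not self.delay_value:
            return  # target and live parameters are the same object
        for name, value in self.value_params.items():
            old = self.target_value_params[name]
            self.target_value_params[name] = (1.0 - tau) * old + tau * value


params = {"w": 1.0}
loss = UnifiedValueLossSketch(params, delay_value=True)
params["w"] = 2.0                     # stands in for an optimizer step on the live parameters
print(loss.target_value_params["w"])  # 1.0, the target lags behind
loss.update_target(tau=0.5)
print(loss.target_value_params["w"])  # 1.5
```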

[Feature Request] Please provide the generic and low-level functionality rather than the high-level interface like agent.train()

hi, it's really great that facebookresearch is considering providing a library for reinforcement learning research.

it would be very helpful if the library provided the low-level functionality rather than a high-level interface like agent.train(), which would turn it into yet another library like stable-baselines3 (https://github.com/DLR-RM/stable-baselines3).

would you mind referring to the philosophy and design of another rl library, cherry, which only provides general-purpose low-level functionality? https://github.com/learnables/cherry

have a good day.

Batched environment is not closed

Error Log:

Traceback (most recent call last):
  File "/private/home/sodhani/projects/rl/torchrl/envs/vec_env.py", line 256, in __del__
RuntimeError: Batched environment must be explicitely closed before it turns out of scope.

Full Trace:

(torch_rl) ➜  rl git:(main) ✗ python examples/ppo/ppo.py --config=examples/ppo/configs/humanoid.txt --total_frames=10

/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
...
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
collector = MultiSyncDataCollector();
loss_module = ClipPPOLoss(
  (actor): ProbabilisticActor(module=NormalParamWrapper(
    (operator): MLP(
      (0): Linear(in_features=67, out_features=400, bias=True)
      (1): Tanh()
      (2): Linear(in_features=400, out_features=300, bias=True)
      (3): Tanh()
      (4): Linear(in_features=300, out_features=42, bias=True)
    )
  ), distribution_class=<class 'torchrl.modules.distributions.continuous.TanhNormal'>, device=cuda:0)
  (critic): ValueOperator(
      module=MLP(
        (0): Linear(in_features=67, out_features=400, bias=True)
        (1): Tanh()
        (2): Linear(in_features=400, out_features=300, bias=True)
        (3): Tanh()
        (4): Linear(in_features=300, out_features=1, bias=True)
      ),
      device=cuda:0,
      in_keys=['observation_vector'],
      out_keys=['state_value'])
);
recorder = TransformedEnv(env=DMControlEnv(env=humanoid, task=walk, batch_size=torch.Size([])), transform=Compose(
	VideoRecorder(keys=['next_pixels']),
	RewardScaling(loc=0.0000, scale=1.0000, keys=['reward']),
	CatTensors(in_keys=['next_com_velocity', 'next_extremities', 'next_head_height', 'next_joint_angles', 'next_torso_vertical', 'next_velocity'], out_key=next_observation_vector),
	ObservationNorm(keys=['next_observation_vector']),
	DoubleToFloat(keys=['reward', 'action', 'next_observation_vector']),
	FiniteTensorDictCheck(keys=[])));
target_net_updater = None;
policy_exploration = ProbabilisticActor(module=NormalParamWrapper(
  (operator): MLP(
    (0): Linear(in_features=67, out_features=400, bias=True)
    (1): Tanh()
    (2): Linear(in_features=400, out_features=300, bias=True)
    (3): Tanh()
    (4): Linear(in_features=300, out_features=42, bias=True)
  )
), distribution_class=<class 'torchrl.modules.distributions.continuous.TanhNormal'>, device=cuda:0);
replay_buffer = None;
writer = <torch.utils.tensorboard.writer.SummaryWriter object at 0x7f8c20dc7430>;
args = Namespace(config='examples/ppo/configs/humanoid.txt', optim_steps_per_collection=10, optimizer='adam', selected_keys=None, batch_size=256, log_interval=10000, lr=0.0003, weight_decay=2e-05, clip_norm=1.0, clip_grad_norm=False, normalize_rewards_online=True, sub_traj_len=-1, collector_devices='cpu', pin_memory=False, init_with_lag=True, frames_per_batch=250, total_frames=2, num_workers=32, env_per_collector=8, seed=42, exploration_mode=None, async_collection=False, env_library='dm_control', env_name='humanoid', env_task='walk', from_pixels=False, frame_skip=4, reward_scaling=1.0, init_env_steps=250, vecnorm=False, norm_rewards=False, noops=0, max_frames_per_traj=250, loss='clip', gamma=0.99, lamda=0.95, entropy_factor=0.0001, loss_function='smooth_l1', tanh_loc=True, gSDE=False, default_policy_scale=1.0, distribution='tanh_normal', lstm=False, shared_mapping=False, record_video=True, exp_name='', record_interval=200, record_frames=250);

  0%|                                                                                                                                                                                           | 0/8 [00:00<?, ?it/s]1 is done!
2 is done!
0 is done!
3 is done!
grad_norm: 4.5215, optim_steps: 10.0000, loss_objective: 0.5818, entropy: 13.2145, loss_entropy: -0.0013, loss_critic: 0.3802, ESS: 0.6759, reward_training: 0.1401: : 1024it [00:00, 1958.13it/s]
Exception ignored in: <function _BatchedEnv.__del__ at 0x7f8c22dcea60>
Traceback (most recent call last):
  File "/private/home/sodhani/projects/rl/torchrl/envs/vec_env.py", line 256, in __del__
RuntimeError: Batched environment must be explicitely closed before it turns out of scope.
grad_norm: 4.5215, optim_steps: 10.0000, loss_objective: 0.5818, entropy: 13.2145, loss_entropy: -0.0013, loss_critic: 0.3802, ESS: 0.6759, reward_training: 0.1401: : 1024it [00:01, 575.76it/s]

[Feature Request] Default creation of LazyStackedTensorDict vs TensorDict when calling torch.stack?

By default, torch.stack(list_of_tds, dim) will return a LazyStackedTensorDict object.
Leaving aside the question of how useful this class is, it may be preferable to create a LazyStackedTensorDict only when the user explicitly asks for it, e.g.

stacked_td = torch.stack(list_of_tds, dim)
assert isinstance(stacked_td, TensorDict) # passes

stacked_td = torch.stack(list_of_tds, dim, contiguous=False)
assert isinstance(stacked_td, LazyStackedTensorDict) # passes

The main motivation is that working with a LazyStackedTensorDict can be computationally expensive, and users may not be aware of this bottleneck.
Also, when called on a list of torch.Tensor objects, torch.stack does not return a lazy, non-contiguous object. Aligning the TensorDict behaviour with this would improve consistency.
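To make the proposed semantics concrete, here is a minimal pure-Python sketch. Plain dicts of lists stand in for TensorDicts and tensors, and `stack_dicts`, `LazyStack` and the `contiguous` flag are hypothetical names mirroring the proposal, not the torchrl API. The eager path materializes every entry once; the lazy path redoes the stacking work on each key access, which is the hidden cost mentioned above.

```python
from typing import Dict, List, Sequence


class LazyStack:
    """Keeps references to the stacked dicts and stacks a key only on access."""

    def __init__(self, dicts: Sequence[Dict[str, list]]):
        self.dicts = list(dicts)
        self.num_access = 0  # counts how often the stacking work is redone

    def __getitem__(self, key: str) -> List[list]:
        # The stacking cost is paid again on every access.
        self.num_access += 1
        return [d[key] for d in self.dicts]


def stack_dicts(dicts: Sequence[Dict[str, list]], contiguous: bool = True):
    """Hypothetical stack: eager (contiguous) by default, lazy only on request."""
    if not contiguous:
        return LazyStack(dicts)
    keys = dicts[0].keys()
    if any(d.keys() != keys for d in dicts):
        raise KeyError("all dicts must share the same keys")
    # Materialize every key once, as torch.stack does for tensors.
    return {k: [d[k] for d in dicts] for k in keys}


tds = [{"obs": [0, 1], "reward": [1]}, {"obs": [2, 3], "reward": [0]}]
eager = stack_dicts(tds)                    # a plain dict, like a TensorDict
lazy = stack_dicts(tds, contiguous=False)   # only when explicitly requested
lazy_rewards = lazy["reward"]
```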

Functorch integration

Using functorch would enable torchrl to support meta-RL algorithms.
Many existing features could also benefit from functorch, e.g. calls to networks and target networks, or multiple Q-value networks.
This would, for instance, simplify the Double classes and bring many more features in the future; in particular, it would make it possible to implement REDQ efficiently.
There is, however, a choice to be made: if we adopt functorch, do we want every module to be functional, or do we want to restrict this to the cases where it is needed?
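The benefit of functional modules can be sketched without functorch itself: if a network is a pure function of (params, input), then a target network or an ensemble of Q-networks is just another set of parameters passed to the same function. The sketch below is pure Python and assumes a one-layer linear "Q-network"; `linear_q`, `init_params`, `soft_update` and `tau` are illustrative names, not torchrl or functorch APIs. The list comprehension over the ensemble is the loop that functorch's `vmap` would vectorize.

```python
import random
from typing import Dict, List

Params = Dict[str, List[float]]


def linear_q(params: Params, x: List[float]) -> float:
    """Pure function: the output depends only on the params and the input."""
    return sum(w * xi for w, xi in zip(params["w"], x)) + params["b"][0]


def init_params(dim: int, rng: random.Random) -> Params:
    return {"w": [rng.gauss(0.0, 1.0) for _ in range(dim)], "b": [0.0]}


def soft_update(target: Params, online: Params, tau: float) -> Params:
    """Polyak averaging on parameters only: no module copy is needed."""
    return {
        k: [(1 - tau) * t + tau * o for t, o in zip(target[k], online[k])]
        for k in target
    }


rng = random.Random(0)
online = init_params(3, rng)
target = {k: list(v) for k, v in online.items()}  # same function, separate params
ensemble = [init_params(3, rng) for _ in range(5)]  # e.g. a REDQ-style Q-ensemble

x = [1.0, 2.0, 3.0]
q_online = linear_q(online, x)
q_target = linear_q(target, x)
q_ensemble = [linear_q(p, x) for p in ensemble]  # what vmap would vectorize
# REDQ-style target: minimum over a random subset of the ensemble.
q_min = min(rng.sample(q_ensemble, 2))
target = soft_update(target, online, tau=0.005)
```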

`TensorDict.values()` method

We have a _TensorDict.items() method (and the corresponding _TensorDict.items_meta()) that returns an iterator of key-value tuples.
It would be nice to also have a _TensorDict.values() method that only returns an iterator over the tensordict values.
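A possible shape for the method, as a minimal sketch: `ToyTensorDict` is a hypothetical stand-in for _TensorDict that stores its entries in a plain dict. Deriving values() from items() keeps both iterators in the same order, and using a generator means values() yields lazily instead of building a list.

```python
from typing import Any, Dict, Iterator, Tuple


class ToyTensorDict:
    """Hypothetical stand-in for _TensorDict, storing entries in a plain dict."""

    def __init__(self, source: Dict[str, Any]):
        self._dict = dict(source)

    def keys(self) -> Iterator[str]:
        return iter(self._dict)

    def items(self) -> Iterator[Tuple[str, Any]]:
        # Yields (key, value) tuples, like _TensorDict.items().
        for key in self.keys():
            yield key, self._dict[key]

    def values(self) -> Iterator[Any]:
        # Only the values, in the same order as items().
        for _, value in self.items():
            yield value


td = ToyTensorDict({"obs": [0.0, 1.0], "reward": [1.0]})
```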

[Feature Request] Make objective modules compatible with dictionaries

Using the TensorDict class simplifies passing data across processes and designing general classes that are oblivious to the keys used in a specific algorithm (e.g. whether or not an action_log_prob / hidden_state key should be expected).
However, introducing new classes can prevent users from copy-pasting and re-using modules (see here). We should make sure that TensorDict is used only when absolutely necessary. These cases include situations where all the content of a dictionary is treated in the same way:

  • indexing
  • reshaping
  • sending from worker to worker, device to device
  • concatenation / stacking

In general, TensorDict should be used for high-level classes: Agent, DataCollector, possibly probabilistic operator modules.

Objectives should not require TensorDicts in general.

However, in some cases they may need to check the batch size (0th dimension), the trajectory length (1st dimension), or even the device. In those cases, one option would be to infer these from a specific tensor in the dictionary (possibly the reward).

Plan

  • Test and fix modules such that they all accept a dictionary as input.
  • Update the type hints accordingly.
  • Currently, modules return a TensorDict, but they could just as well return a regular dict.
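A minimal sketch of an objective that accepts any Mapping (a plain dict, or a TensorDict-like object) and infers the batch size from one reference entry, here the reward, as suggested above. Plain Python lists stand in for tensors, so device inference is left out; `value_loss` and the key names are hypothetical, not torchrl's loss API.

```python
from types import MappingProxyType
from typing import Mapping, Sequence


def value_loss(
    batch: Mapping[str, Sequence[float]],
    value_key: str = "state_value",
    target_key: str = "value_target",
    reference_key: str = "reward",
) -> float:
    """Mean-squared value loss that only needs a mapping, not a TensorDict."""
    # Infer the batch size from one reference entry instead of td.batch_size.
    batch_size = len(batch[reference_key])
    values, targets = batch[value_key], batch[target_key]
    if len(values) != batch_size or len(targets) != batch_size:
        raise ValueError("all entries must share the batch size of the reference key")
    return sum((v - t) ** 2 for v, t in zip(values, targets)) / batch_size


batch = {"reward": [1.0, 0.0], "state_value": [0.5, 0.5], "value_target": [1.5, 0.5]}
loss = value_loss(batch)                              # works on a plain dict
loss_ro = value_loss(MappingProxyType(batch))         # and on any other Mapping
```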

tensordict vs tensor_dict

We keep switching between tensordict and tensor_dict in the code (docstrings, identifiers and file names).
We should pick one.

torchvision dependency

torchvision should only be an optional dependency, but as of now it must be installed to import torchrl primitives.

[Feature Request] Integrate new environments

  • BRAX
    BRAX environments can be differentiated through, which is awesome.
    Also, they are fully functional (i.e. they take state and action as input). This would play well with TensorDict, we'd just need to keep track of what the previous state was (and make sure it hasn't been replaced by a transformed version).
    Here are some examples
  • Jumanji
  • Unity ML-Agents
  • habitat-lab
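Because a functional environment maps (state, action) to a new state, a stateful wrapper only has to keep the last raw state, as noted in the BRAX item above. Below is a minimal pure-Python sketch with toy dynamics; `functional_step`, `functional_reset` and `StatefulWrapper` are hypothetical and do not reflect the BRAX API. The raw state is stored separately from what the observation transform returns, which is what prevents a transformed version from replacing it.

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional, Tuple


@dataclass(frozen=True)
class State:
    position: float
    step_count: int


def functional_reset() -> State:
    return State(position=0.0, step_count=0)


def functional_step(state: State, action: float) -> Tuple[State, float]:
    """Toy pure dynamics: returns a new state and a reward, never mutates."""
    new_state = replace(state, position=state.position + action,
                        step_count=state.step_count + 1)
    reward = -abs(new_state.position)
    return new_state, reward


class StatefulWrapper:
    """Exposes a stateful step() on top of a functional environment."""

    def __init__(self,
                 step_fn: Callable[[State, float], Tuple[State, float]],
                 reset_fn: Callable[[], State],
                 obs_transform: Callable[[float], float] = lambda x: x):
        self.step_fn = step_fn
        self.reset_fn = reset_fn
        self.obs_transform = obs_transform
        self._state: Optional[State] = None  # the raw, untransformed state

    def reset(self) -> float:
        self._state = self.reset_fn()
        return self.obs_transform(self._state.position)

    def step(self, action: float) -> Tuple[float, float]:
        # Always feed the raw previous state back, never the transformed obs.
        self._state, reward = self.step_fn(self._state, action)
        return self.obs_transform(self._state.position), reward


env = StatefulWrapper(functional_step, functional_reset, obs_transform=lambda x: 10 * x)
obs0 = env.reset()
obs1, r1 = env.step(0.5)
obs2, r2 = env.step(0.5)
```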

Test the agent class

There are no tests for the agent class yet, as it requires many components to run.

[Feature Request] Remove `env.dtype` attribute

Initially, the plan was for each environment to have an associated dtype, shared by actions, observations and rewards.
However, this makes little sense when observations consist of multiple tensors. Also, an action could be a floating-point number while the observation is a uint8 tensor.
We should rely only on env.observation_spec.dtype (or env.action_spec.dtype) to infer dtypes.
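A minimal sketch of why one env-wide dtype is ill-defined and how per-spec dtypes avoid the problem. `ToySpec` and `dtypes_of` are hypothetical, dtypes are plain strings, and the keys and shapes are made up for illustration; the point is only that each spec entry carries its own dtype.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical minimal spec: each entry carries its own shape and dtype."""
    shape: Tuple[int, ...]
    dtype: str


observation_spec: Dict[str, ToySpec] = {
    "pixels": ToySpec((84, 84, 3), "uint8"),
    "velocity": ToySpec((6,), "float32"),
}
action_spec = ToySpec((2,), "float32")
reward_spec = ToySpec((1,), "float32")


def dtypes_of(spec_dict: Dict[str, ToySpec]) -> Dict[str, str]:
    """Look up dtypes per key instead of relying on one env-wide attribute."""
    return {key: spec.dtype for key, spec in spec_dict.items()}


obs_dtypes = dtypes_of(observation_spec)
# More than one dtype is in use, so a single env.dtype cannot describe this env.
all_dtypes = set(obs_dtypes.values()) | {action_spec.dtype, reward_spec.dtype}
```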

Unittests for binarized reward

For some RL tasks, knowing that a reward is non-null matters more than knowing its magnitude.
The BinarizeReward transform maps every reward to either 1 (positive) or 0 (null or negative).
This class is currently not tested. The test should (1) check that the output rewards match what is expected and (2) check that the reward_spec environment attribute is modified accordingly (it should be of type BinaryDiscreteTensorSpec).
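A sketch of the two checks, written against hypothetical pure-Python stand-ins rather than the real BinarizeReward class (whose exact signature is not shown here): `binarize_rewards` follows the rule described above, and `ToyRewardSpec` / `transform_reward_spec` mimic the spec update, using a type-name string in place of an actual BinaryDiscreteTensorSpec.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyRewardSpec:
    """Hypothetical stand-in for a reward spec: a type name and a shape."""
    kind: str
    shape: Tuple[int, ...]


def binarize_rewards(rewards: List[float]) -> List[int]:
    # 1 for strictly positive rewards, 0 for null or negative ones.
    return [1 if r > 0 else 0 for r in rewards]


def transform_reward_spec(spec: ToyRewardSpec) -> ToyRewardSpec:
    # The binarized reward should be described by a binary discrete spec.
    return ToyRewardSpec(kind="BinaryDiscreteTensorSpec", shape=spec.shape)


def test_binarize_reward_values():
    # (1) output rewards match the expected binarization, including edge cases.
    assert binarize_rewards([2.5, 0.0, -1.0, 1e-6]) == [1, 0, 0, 1]


def test_binarize_reward_spec():
    # (2) the reward spec is updated, and its shape is preserved.
    spec = transform_reward_spec(ToyRewardSpec(kind="continuous", shape=(1,)))
    assert spec.kind == "BinaryDiscreteTensorSpec"
    assert spec.shape == (1,)
```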

[Feature Request] Trainer saving and loading utils

The Trainer class supports saving model and environment variables so that they can be re-used. However, this workflow is still experimental.
Eventually, it should be possible to save and load an agent's state and resume training where it left off.
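A minimal sketch of the save/resume workflow: everything needed to continue training (counters as well as weights) goes into one state dict, which is written to disk and later loaded into a fresh trainer. `ToyTrainer` is hypothetical and uses pickle with plain lists; in PyTorch code the same pattern is usually built on `state_dict()` / `load_state_dict()` together with `torch.save` / `torch.load`.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class ToyTrainer:
    """Hypothetical trainer tracking what is needed to resume training."""

    def __init__(self):
        self.collected_frames = 0
        self.optim_steps = 0
        self.model_weights = [0.0, 0.0]

    def train_step(self, frames: int) -> None:
        self.collected_frames += frames
        self.optim_steps += 1
        self.model_weights = [w + 0.1 for w in self.model_weights]

    def state_dict(self) -> Dict[str, Any]:
        return {"collected_frames": self.collected_frames,
                "optim_steps": self.optim_steps,
                "model_weights": list(self.model_weights)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.collected_frames = state["collected_frames"]
        self.optim_steps = state["optim_steps"]
        self.model_weights = list(state["model_weights"])

    def save(self, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def load(self, path: str) -> None:
        with open(path, "rb") as f:
            self.load_state_dict(pickle.load(f))


trainer = ToyTrainer()
for _ in range(3):
    trainer.train_step(frames=250)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "checkpoint.pkl")
    trainer.save(path)
    resumed = ToyTrainer()
    resumed.load(path)  # restart where training was left

resumed.train_step(frames=250)
```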

Implement gSDE exploration

gSDE (generalized State-Dependent Exploration) keeps the exploration noise fixed throughout the trajectory to avoid erratic actor behaviour.
gSDE can be used with SAC, DDPG, PPO and others. We should think ahead about how implementing it would modify the workflow, which implementation changes would be needed, which new classes should be created, etc.
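A minimal pure-Python sketch of the idea, assuming a linear noise model: the exploration noise is eps(s) = theta · z(s), where z(s) are latent features of the policy and theta is sampled from N(0, sigma^2) only at the start of a trajectory (or every `resample_every` steps) and then held fixed. `GSDENoise`, `features`, `sigma` and `resample_every` are hypothetical names, and the feature map is the identity purely for illustration. Because theta is fixed, the same state receives the same perturbation within a trajectory, unlike independent per-step Gaussian noise.

```python
import random
from typing import List


def features(state: List[float]) -> List[float]:
    """Hypothetical latent features z(s) of the policy (identity here)."""
    return list(state)


class GSDENoise:
    """State-dependent exploration noise eps(s) = theta . z(s)."""

    def __init__(self, feature_dim: int, sigma: float,
                 resample_every: int, seed: int = 0):
        self.rng = random.Random(seed)
        self.feature_dim = feature_dim
        self.sigma = sigma
        self.resample_every = resample_every  # e.g. the trajectory length
        self.steps = 0
        self.resample()

    def resample(self) -> None:
        # theta ~ N(0, sigma^2), held fixed until the next resample.
        self.theta = [self.rng.gauss(0.0, self.sigma)
                      for _ in range(self.feature_dim)]

    def __call__(self, state: List[float]) -> float:
        if self.steps > 0 and self.steps % self.resample_every == 0:
            self.resample()
        self.steps += 1
        return sum(t * z for t, z in zip(self.theta, features(state)))


noise = GSDENoise(feature_dim=2, sigma=0.5, resample_every=100)
s = [1.0, -2.0]
eps_a = noise(s)
eps_b = noise(s)      # same state, same theta: same noise
action = 0.3 + eps_a  # deterministic mean action plus exploration noise
```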

Extend tensor-like behaviour of `TensorDict` classes

TensorDict classes support a series of tensor operations such as tensordict.view() or tensordict.masked_fill_().
However, as of now these are only implemented as methods.
It should be possible to seamlessly call torch.view(tensordict, ...) using __torch_function__.
The operations to support include:

  • masked_fill: currently only the in-place method (masked_fill_) is implemented. The out-of-place version should clone the tensordict before calling masked_fill_, so that the result is a new tensordict rather than the original one.
  • squeeze
  • unsqueeze
  • unbind
  • stack
  • cat
  • masked_select
  • clone
  • permute
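The `__torch_function__` protocol lets a non-tensor class intercept calls such as `torch.unsqueeze(obj, 0)`. Below is a minimal sketch of that dispatch for a toy container, assuming PyTorch is installed; `MiniTensorDict`, `implements` and `HANDLED_FUNCTIONS` are hypothetical names following the registry pattern from the PyTorch "extending torch" documentation, not TensorDict's actual implementation. The masked_fill handler clones before filling in place, as described in the list above.

```python
import torch

HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register a handler for a torch.* function on MiniTensorDict."""
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class MiniTensorDict:
    """Hypothetical toy container; not the TensorDict implementation."""

    def __init__(self, source, batch_size):
        self._source = dict(source)
        self.batch_size = torch.Size(batch_size)

    def __getitem__(self, key):
        return self._source[key]

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(torch.clone)
def _clone(td, **kwargs):
    return MiniTensorDict({k: v.clone() for k, v in td._source.items()}, td.batch_size)


@implements(torch.unsqueeze)
def _unsqueeze(td, dim):
    if dim < 0:
        raise ValueError("this sketch only supports non-negative dims")
    batch_size = list(td.batch_size)
    batch_size.insert(dim, 1)
    return MiniTensorDict({k: v.unsqueeze(dim) for k, v in td._source.items()}, batch_size)


@implements(torch.masked_fill)
def _masked_fill(td, mask, value):
    # Out-of-place: clone first, then fill in place, so `td` is left untouched.
    out = torch.clone(td)
    for tensor in out._source.values():
        # Align the batch mask with the leading (batch) dims of each entry.
        m = mask.reshape(tuple(mask.shape) + (1,) * (tensor.ndim - mask.ndim))
        tensor.masked_fill_(m, value)
    return out


td = MiniTensorDict({"obs": torch.zeros(3, 4), "reward": torch.ones(3)}, [3])
expanded = torch.unsqueeze(td, 0)
filled = torch.masked_fill(td, torch.tensor([True, False, True]), 5.0)
```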

PPO example is broken

Error:

(torch_rl) ➜  rl git:(main) ✗ python examples/ppo/ppo.py --config=examples/ppo/configs/humanoid.txt
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/lazy.py:178: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:328: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
  interpolation: int = Image.BICUBIC,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:228: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
  interpolation: int = Image.BILINEAR,
/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchvision/transforms/functional_pil.py:295: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
  interpolation: int = Image.NEAREST,
collector = MultiSyncDataCollector();
loss_module = ClipPPOLoss(
  (actor): ProbabilisticActor(module=NormalParamWrapper(
    (operator): MLP(
      (0): Linear(in_features=67, out_features=400, bias=True)
      (1): Tanh()
      (2): Linear(in_features=400, out_features=300, bias=True)
      (3): Tanh()
      (4): Linear(in_features=300, out_features=42, bias=True)
    )
  ), distribution_class=<class 'torchrl.modules.distributions.continuous.TanhNormal'>, device=cuda:0)
  (critic): ValueOperator(
      module=MLP(
        (0): Linear(in_features=67, out_features=400, bias=True)
        (1): Tanh()
        (2): Linear(in_features=400, out_features=300, bias=True)
        (3): Tanh()
        (4): Linear(in_features=300, out_features=1, bias=True)
      ),
      device=cuda:0,
      in_keys=['observation_vector'],
      out_keys=['state_value'])
);
recorder = TransformedEnv(env=DMControlEnv(env=humanoid, task=walk, batch_size=torch.Size([])), transform=Compose(
	VideoRecorder(keys=['next_pixels']),
	RewardScaling(loc=0.0000, scale=1.0000, keys=['reward']),
	CatTensors(in_keys=['next_com_velocity', 'next_extremities', 'next_head_height', 'next_joint_angles', 'next_torso_vertical', 'next_velocity'], out_key=next_observation_vector),
	ObservationNorm(keys=['next_observation_vector']),
	DoubleToFloat(keys=['reward', 'action', 'next_observation_vector']),
	FiniteTensorDictCheck(keys=[])));
target_net_updater = None;
policy_exploration = ProbabilisticActor(module=NormalParamWrapper(
  (operator): MLP(
    (0): Linear(in_features=67, out_features=400, bias=True)
    (1): Tanh()
    (2): Linear(in_features=400, out_features=300, bias=True)
    (3): Tanh()
    (4): Linear(in_features=300, out_features=42, bias=True)
  )
), distribution_class=<class 'torchrl.modules.distributions.continuous.TanhNormal'>, device=cuda:0);
replay_buffer = None;
writer = <torch.utils.tensorboard.writer.SummaryWriter object at 0x7f197916ef10>;
args = Namespace(config='examples/ppo/configs/humanoid.txt', optim_steps_per_collection=10, optimizer='adam', selected_keys=None, batch_size=256, log_interval=10000, lr=0.0003, weight_decay=2e-05, clip_norm=1.0, clip_grad_norm=False, normalize_rewards_online=True, sub_traj_len=-1, collector_devices='cpu', pin_memory=False, init_with_lag=True, frames_per_batch=250, total_frames=12500000, num_workers=32, env_per_collector=8, seed=42, exploration_mode=None, async_collection=False, env_library='dm_control', env_name='humanoid', env_task='walk', from_pixels=False, frame_skip=4, reward_scaling=1.0, init_env_steps=250, vecnorm=False, norm_rewards=False, noops=0, max_frames_per_traj=250, loss='clip', gamma=0.99, lamda=0.95, entropy_factor=0.0001, loss_function='smooth_l1', tanh_loc=True, gSDE=False, default_policy_scale=1.0, distribution='tanh_normal', lstm=False, shared_mapping=False, record_video=True, exp_name='', record_interval=200, record_frames=250);

  0%|          | 0/50000000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/private/home/sodhani/projects/rl/examples/ppo/ppo.py", line 131, in <module>
    agent.train()
  File "/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchrl/agents/agents.py", line 340, in train
    self.steps(batch)
  File "/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchrl/agents/agents.py", line 408, in steps
    losses_td = self.loss_module(sub_batch_device)
  File "/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/private/home/sodhani/.conda/envs/torch_rl/lib/python3.9/site-packages/torchrl/objectives/costs/ppo.py", line 223, in forward
    raise RuntimeError(
RuntimeError: The key state_value returns a value that requires a gradient, consider detaching.

Fix `__call__` to `forward` in `_LossModule` instances

_LossModule usually implements a __call__ method, which conflicts with the usual nn.Module convention of overriding the forward method.
Using forward instead would allow us to decorate these functions with pre/post-forward hooks, for instance hooks that handle the case where the input is a dictionary.
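To illustrate why this matters, here is a framework-free toy that mimics the relevant part of PyTorch's call protocol: `__call__` is defined once in the base class, runs any registered pre-hooks, then dispatches to `forward`. `ToyModule`, `ToyLoss` and `unpack_dict_input` are hypothetical names used only for this sketch; in real PyTorch the equivalent hook is registered with `nn.Module.register_forward_pre_hook`, whose hook receives `(module, args)` and may return replacement arguments.

```python
from typing import Any, Callable, Dict, List


class ToyModule:
    """Framework-free stand-in for nn.Module's call protocol (illustrative only)."""

    def __init__(self) -> None:
        self._pre_hooks: List[Callable[["ToyModule", tuple], Any]] = []

    def register_forward_pre_hook(self, hook: Callable[["ToyModule", tuple], Any]) -> None:
        self._pre_hooks.append(hook)

    def __call__(self, *args: Any) -> Any:
        # Subclasses never override __call__: it runs the hooks, then forward.
        for hook in self._pre_hooks:
            result = hook(self, args)
            if result is not None:
                args = result
        return self.forward(*args)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError


class ToyLoss(ToyModule):
    """Hypothetical loss that only implements forward, as an nn.Module would."""

    def forward(self, pred: List[float], target: List[float]) -> float:
        # Mean squared error.
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def unpack_dict_input(module: ToyModule, args: tuple) -> Any:
    # If a single dict is passed, unpack its "pred" and "target" entries.
    if len(args) == 1 and isinstance(args[0], dict):
        batch: Dict[str, List[float]] = args[0]
        return (batch["pred"], batch["target"])
    return None  # leave the arguments untouched


loss = ToyLoss()
loss.register_forward_pre_hook(unpack_dict_input)
print(loss([1.0, 2.0], [1.0, 4.0]))                      # positional call
print(loss({"pred": [1.0, 2.0], "target": [1.0, 4.0]}))  # dict call, same result
```

Because subclasses only override `forward`, the dictionary handling lives in a single hook instead of being re-implemented in every loss's `__call__`.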

Simplify `ProbabilisticOperator` class

The ProbabilisticOperator class would benefit from a deep refactoring.
This class has two responsibilities: it is the interface between nn.Module instances and TensorDict objects, and it samples outputs from a target distribution to fill the tensordict used as input.
I propose the following refactoring: a parent class TDModule would be responsible for reading from and writing to the input tensordict.
A child class ProbabilisticTDModule would read, sample and write in the input tensordict. At first, we could keep the current logic where everything is probabilistic, with the Delta distribution used by default.
A TDSequence class would act similarly to nn.Sequential, in that it would read and write over a tensordict sequentially.
This class would ultimately replace (or be the parent of) the ActorCritic and ActorValue operators.
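The proposed split can be sketched without PyTorch, using a plain dict as a stand-in for a TensorDict and Python's `random.gauss` as the sampler. `TDModule`, `ProbabilisticTDModule` and `TDSequence` below are hypothetical classes following the proposal, not TorchRL's API. With no sampler, `ProbabilisticTDModule` behaves like a Delta distribution and writes its location parameter unchanged.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence

TD = Dict[str, float]  # a plain dict stands in for a TensorDict


class TDModule:
    """Reads in_keys from the dict, applies fn, writes the results under out_keys."""

    def __init__(self, fn: Callable[..., Sequence[float]],
                 in_keys: List[str], out_keys: List[str]) -> None:
        self.fn = fn
        self.in_keys = in_keys
        self.out_keys = out_keys

    def _read(self, td: TD) -> Sequence[float]:
        return self.fn(*(td[key] for key in self.in_keys))

    def __call__(self, td: TD) -> TD:
        for key, value in zip(self.out_keys, self._read(td)):
            td[key] = value  # written in place, like a tensordict update
        return td


class ProbabilisticTDModule(TDModule):
    """Treats fn's outputs as distribution parameters and writes one sample.

    With sampler=None it acts as a Delta distribution: the first parameter
    (the location) is written unchanged.
    """

    def __init__(self, fn: Callable[..., Sequence[float]], in_keys: List[str],
                 out_key: str, sampler: Optional[Callable[..., float]] = None) -> None:
        super().__init__(fn, in_keys, [out_key])
        self.sampler = sampler if sampler is not None else (lambda loc, *rest: loc)

    def __call__(self, td: TD) -> TD:
        td[self.out_keys[0]] = self.sampler(*self._read(td))
        return td


class TDSequence:
    """Runs modules one after another on the same dict, like nn.Sequential."""

    def __init__(self, *modules: TDModule) -> None:
        self.modules = modules

    def __call__(self, td: TD) -> TD:
        for module in self.modules:
            td = module(td)
        return td


# Hypothetical actor: a deterministic "model" followed by a Gaussian head.
model = TDModule(lambda obs: (2.0 * obs,), in_keys=["obs"], out_keys=["hidden"])
head = ProbabilisticTDModule(lambda h: (h + 1.0, 0.1), in_keys=["hidden"],
                             out_key="action", sampler=random.gauss)
actor = TDSequence(model, head)
print(actor({"obs": 1.5}))  # "action" is a random draw around 4.0
```

Keeping the read/write logic in the parent class means the probabilistic child only has to decide how to turn parameters into a value, which is the separation of responsibilities the proposal is after.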

Removing _CustomOpTensorDict classes

_CustomOpTensorDict is an abstract class for some operations (mainly operations on TensorDict shapes).
It may not be necessary to keep it, and maintaining it may become a burden in the future.

Pros:

  • In most cases, the operations we would perform do not change the storage anyway. For instance, calling view on the values of a regular TensorDict returns views of those tensors, at no extra memory cost.
  • Code clarity: this is one more class that inherits from _TensorDict, and each new operation will (in the current implementation) require its own class. This is not ideal.

Cons:

  • In the case of LazyStackedTensorDict, we would probably like to have a view method that keeps the storage location unchanged. Without _CustomOpTensorDict, we would need to call contiguous on the stack before calling view, which assigns new storage to the resulting tensors. Another option would be to prevent users from calling these fixed-storage operations on LazyStackedTensorDict, but that could cause issues in the future (e.g. users requesting a feature that was previously deprecated).
  • The ProbabilisticOperator class can write to tensordict instances that are unsqueezed. This makes it easy to do in-place modifications of a tensordict without having to create multiple copies of the tensors. This feature is used to execute a policy on a single environment that has no batch size.

`_TensorDict.permute(...)` operation (+ transpose)

The _TensorDict parent class supports several tensor operations, such as view and unsqueeze.
Modifying the resulting tensordicts modifies the original tensordict (provided the changes don't violate the tensordict specs).

unsqueeze:

>>> t = TensorDict({'a': torch.randn(3, 4)}, [3])
>>> t.unsqueeze(0).set('b', torch.randn(1, 3, 4));
>>> print(t)
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
>>> t.unsqueeze(0).fill_('a', 0.0);
>>> print(t.get('a'))
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

view:

>>> t = TensorDict({'a': torch.randn(3, 4)}, [3, 4])
>>> t.view(-1).set('b', torch.randn(12));
>>> print(t)
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
>>> t.view(-1).fill_('a', 0.0);
>>> print(t.get('a'))
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

A similar behaviour should be implemented for _TensorDict.permute(...). The expected behaviour would be:
permute:

>>> t = TensorDict({'a': torch.randn(3, 4, 1)}, [3, 4])
>>> t.permute(1, 0).set('b', torch.randn(4, 3));
>>> print(t)
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
>>> t.permute(1, 0).fill_('a', 0.0);
>>> print(t.get('a'))
tensor([[[0.],
         [0.],
         [0.],
         [0.]],
        [[0.],
         [0.],
         [0.],
         [0.]],
        [[0.],
         [0.],
         [0.],
         [0.]]])

The same should work with transpose (which would permute two dimensions only).

This would require a new _CustomOpTensorDict subclass, similar to ViewedTensorDict and UnsqueezedTensorDict.
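A minimal sketch of how such a lazy permuted view could work, in the spirit of ViewedTensorDict and UnsqueezedTensorDict, assuming a two-dimensional batch and nested Python lists instead of tensors. `ToyTensorDict` and `PermutedTensorDict` are hypothetical classes: the view stores only the source and the permutation, applies the permutation when reading, and applies the inverse permutation when writing, so writes land in the original object. For two dimensions, the inverse of a transpose is the transpose itself.

```python
from typing import Dict, List, Sequence, Tuple

Grid = List[List[float]]  # one scalar per batch element, batch shape (rows, cols)


def _transpose_grid(grid: Grid) -> Grid:
    return [list(col) for col in zip(*grid)]


class ToyTensorDict:
    def __init__(self, fields: Dict[str, Grid], batch_size: Sequence[int]) -> None:
        self.fields = dict(fields)
        self.batch_size = tuple(batch_size)

    def get(self, key: str) -> Grid:
        return self.fields[key]

    def set(self, key: str, value: Grid) -> "ToyTensorDict":
        self.fields[key] = value
        return self

    def permute(self, *dims: int) -> "PermutedTensorDict":
        return PermutedTensorDict(self, dims)

    def transpose(self, dim0: int, dim1: int) -> "PermutedTensorDict":
        # transpose is a permute that swaps exactly two dimensions
        dims = [0, 1]
        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
        return self.permute(*dims)


class PermutedTensorDict:
    """Lazy view: no data is copied when the view is created."""

    def __init__(self, source: ToyTensorDict, dims: Tuple[int, ...]) -> None:
        if sorted(dims) != [0, 1]:
            raise ValueError("this sketch only supports 2-D batch permutations")
        self.source = source
        self.dims = tuple(dims)
        self.batch_size = tuple(source.batch_size[d] for d in self.dims)

    def _permute(self, grid: Grid) -> Grid:
        return _transpose_grid(grid) if self.dims == (1, 0) else grid

    def get(self, key: str) -> Grid:
        return self._permute(self.source.get(key))

    def set(self, key: str, value: Grid) -> "PermutedTensorDict":
        rows, cols = self.batch_size
        if len(value) != rows or any(len(row) != cols for row in value):
            raise ValueError(f"value does not match batch size {self.batch_size}")
        # Inverse permutation: for 2-D, transposing again restores the source layout.
        self.source.set(key, self._permute(value))
        return self
```

A real implementation would permute tensor dimensions rather than rebuild lists, and would need the general inverse permutation for more than two batch dimensions.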

value of eps involving tanh in torchrl.modules.distributions

`requires_grad` attribute for meta tensors

MetaTensor is a special TorchRL class that tracks the metadata of tensors stored in a TensorDict.
Its main purpose is to avoid the overhead of loading tensors just to read an attribute. For instance, we already have a SavedTensorDict class that represents tensors saved on disk. If we had to load the whole tensordict from disk every time an attribute (shape, etc.) is queried, this would create unnecessary overhead.
We carefully choose which attributes are part of MetaTensor. Some are not easy to infer and are therefore not included.
It may be useful to implement a requires_grad attribute in MetaTensor. Its purpose would simply be to tell whether a tensor is part of a computational graph or not.

In the tests, one should make sure that these aspects are accounted for:

  • For LazyStackedTensorDict objects, the stacked tensor requires a gradient if any tensor in the list requires a gradient. The corresponding MetaTensor should follow the same logic.
  • In theory, MemmapTensors and SavedTensorDicts are not compatible with tensors that are part of a computational graph. This means that writing a tensor with requires_grad=True to a SavedTensorDict should raise an error. The same applies when a regular TensorDict is cast using tensordict.to(SavedTensorDict) or tensordict.memmap_().
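A minimal sketch of the two rules above, using hypothetical stand-ins (`ToyMetaTensor`, `stack_meta`, `ToySavedStore`) rather than TorchRL's actual MetaTensor or SavedTensorDict: stacking propagates requires_grad with `any()`, and a disk-backed store rejects metadata flagged as requiring a gradient. Whether the real implementation should raise or detach is a design choice; the sketch assumes it raises, as the issue suggests.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class ToyMetaTensor:
    """Hypothetical metadata record: readable without loading the tensor itself."""
    shape: Tuple[int, ...]
    dtype: str
    requires_grad: bool = False


def stack_meta(metas: List[ToyMetaTensor]) -> ToyMetaTensor:
    """Metadata of a lazily stacked tensor: new leading dim, grad if any input needs it."""
    first = metas[0]
    if any(m.shape != first.shape or m.dtype != first.dtype for m in metas):
        raise ValueError("cannot stack tensors with different shapes or dtypes")
    return ToyMetaTensor(
        shape=(len(metas),) + first.shape,
        dtype=first.dtype,
        requires_grad=any(m.requires_grad for m in metas),
    )


class ToySavedStore:
    """Stand-in for a disk-backed tensordict that refuses tensors attached to a graph."""

    def __init__(self) -> None:
        self._meta: Dict[str, ToyMetaTensor] = {}

    def set(self, key: str, meta: ToyMetaTensor) -> None:
        if meta.requires_grad:
            raise RuntimeError(f"{key!r} requires grad; detach it before saving to disk")
        self._meta[key] = meta

    def meta(self, key: str) -> ToyMetaTensor:
        return self._meta[key]
```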

Missing tests

Many tests have not been written yet; placeholders were put in place to fill the gaps.
Because of the CircleCI integration, we may want to write those tests as soon as possible.

The following tests should be written:

  • test the recipes from the agent module
  • test agent class
  • test_categorical in test_distributions
  • test_gym envs (currently commented out because they take very long to run)
  • tensordict tests with cuda
  • modules with cuda
  • distributions with cuda
  • replay buffer with cuda
  • costs with cuda
  • recipes with cuda
  • mask_fill_ for all TD classes
  • doc build

[Feature Request] Trajectories, episodes and rollout naming

The terms "trajectory" and "episode" conflict. The term "rollout" also appears from time to time in the code and should be clarified or changed too.
See this thread: "trajectory" could be used for a possibly incomplete sequence of steps, whereas an episode should have both an initial and a terminal state.
"Rollout" should perhaps be reserved for MCTS usage.

consistent frame count in data collectors

When a data collector is asked for a certain number of frames, we should ensure that every collector returns exactly that number. A test should be written for this.
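A test along these lines could look like the following sketch. `toy_collector` and `check_frame_count` are hypothetical names: the toy collector steps several environments in lockstep, and it assumes that frames_per_batch must be divisible by the number of environments and that the last batch may overshoot total_frames by less than one batch. The real collectors may handle the remainder differently, which is exactly what such a test should pin down.

```python
from typing import Iterator, List


def toy_collector(frames_per_batch: int, total_frames: int,
                  num_envs: int) -> Iterator[List[int]]:
    """Hypothetical collector: steps num_envs envs in lockstep, yields flat batches of frame ids."""
    if frames_per_batch % num_envs != 0:
        raise ValueError("frames_per_batch must be divisible by num_envs")
    steps_per_batch = frames_per_batch // num_envs
    collected = 0
    next_frame = 0
    while collected < total_frames:
        batch: List[int] = []
        for _step in range(steps_per_batch):
            for _env in range(num_envs):
                batch.append(next_frame)
                next_frame += 1
        collected += len(batch)
        yield batch


def check_frame_count(collector: Iterator[List[int]], frames_per_batch: int,
                      total_frames: int) -> None:
    """The check to run against every collector class."""
    total = 0
    for batch in collector:
        assert len(batch) == frames_per_batch, (len(batch), frames_per_batch)
        total += len(batch)
    # Overshoot is allowed, but by less than one batch.
    assert total_frames <= total < total_frames + frames_per_batch, total
```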

Test TensorDict `items()`, `values()` and `keys()`

There is currently no test for items(), values() and keys() (nor for items_meta() and values_meta()).

Those methods should be tested for every TensorDict subclass using the TestTensorDicts class.

We should test the following things:

  • The keys must be sorted. This is important because other tests rely on it, so we want to catch any case where they aren't. I would (1) add a new key-value pair to the tensordict to make sure that the keys remain sorted after that operation, and (2) check that the keys are sorted with something like all(keys[i] <= keys[i+1] for i in range(len(keys) - 1)) (after converting the keys to a list).
  • The following loops match: [(k, td.get(k)) for k in td.keys()], list(td.items()), list(zip(td.keys(), td.values()))
  • After editing the tensordict with set and update, the above checks still pass.

This should be done for: ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td", "td_reset_bs"]

Any other test would be welcome!
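The three checks could be factored into one helper applied to each fixture. The sketch below runs it on a hypothetical `ToySortedTD` with integer values so that `==` comparisons are unambiguous; with real tensors, the values would have to be compared with something like `torch.equal` (or by identity) instead of plain list equality. `check_iterators` and `run_iterator_checks` are illustrative names, not existing test helpers.

```python
from typing import Any, Dict, Iterator, List, Tuple


class ToySortedTD:
    """Hypothetical mapping that, like the tensordicts under test, yields keys in sorted order."""

    def __init__(self, data: Dict[str, Any]) -> None:
        self._data = dict(data)

    def keys(self) -> List[str]:
        return sorted(self._data)

    def values(self) -> Iterator[Any]:
        return (self._data[k] for k in self.keys())

    def items(self) -> Iterator[Tuple[str, Any]]:
        return ((k, self._data[k]) for k in self.keys())

    def get(self, key: str) -> Any:
        return self._data[key]

    def set(self, key: str, value: Any) -> "ToySortedTD":
        self._data[key] = value
        return self

    def update(self, other: Dict[str, Any]) -> "ToySortedTD":
        for key, value in other.items():
            self.set(key, value)
        return self


def check_iterators(td: Any) -> None:
    keys = list(td.keys())
    assert all(keys[i] <= keys[i + 1] for i in range(len(keys) - 1)), keys
    by_get = [(k, td.get(k)) for k in td.keys()]
    assert by_get == list(td.items()) == list(zip(td.keys(), td.values()))


def run_iterator_checks(td: Any) -> None:
    check_iterators(td)
    td.set("zz_new", 0)       # a key that sorts last
    td.update({"aa_new": 1})  # a key that sorts first
    check_iterators(td)
```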

Test optional dependencies

Write a set of checks for the optional dependencies.
This would involve installing all the required dependencies, but none of the optional ones, and running the full test suite.
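One way to approach this, sketched with the standard library only: `missing` reports which modules cannot be found, and `imports_cleanly` imports a package in a fresh interpreter and checks that no optional backend was pulled in as a side effect. Both function names, and the choice of `gym` and `dm_control` as the optional set, are assumptions for illustration.

```python
import importlib.util
import subprocess
import sys
from typing import List

# Assumed examples of optional backends; adjust to the project's actual list.
OPTIONAL_DEPS = ["gym", "dm_control"]


def missing(modules: List[str]) -> List[str]:
    """Return the top-level modules that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


def imports_cleanly(package: str, forbidden: List[str]) -> bool:
    """Import `package` in a fresh interpreter; True if it succeeds without loading a forbidden module."""
    code = (
        "import importlib, sys\n"
        f"importlib.import_module({package!r})\n"
        f"leaked = [m for m in {forbidden!r} if m in sys.modules]\n"
        "sys.exit(1 if leaked else 0)\n"
    )
    result = subprocess.run([sys.executable, "-c", code], capture_output=True)
    return result.returncode == 0
```

In a CI job with only the required dependencies installed, one could assert `missing(OPTIONAL_DEPS) == OPTIONAL_DEPS` and `imports_cleanly("torchrl", OPTIONAL_DEPS)`, and guard backend-specific tests with `unittest.skipIf(missing(["gym"]), "gym not installed")`.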

Unittest for CatFrames

The test for the CatFrames transform has been left out.
CatFrames was proposed in DQN for Atari games as a way to integrate the game dynamics in the observations (a single Atari frame does not contain enough information: the problem is a POMDP). CatFrames keeps the previous N frames in a buffer and returns a stack of those frames, concatenated along a specific dimension.

To test this class, we should assert that the concatenation returns what is expected.
When there is no previous observation, the same observation is repeated several times. One should check that this works as expected.
We should also check that the observation spec is transformed as expected (the other transform tests can serve as inspiration).
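A test could compare the transform against a small reference model of the expected behaviour. The sketch below is such a reference, not TorchRL's CatFrames: the hypothetical `ToyCatFrames` works on flat 1-D observations (so concatenation happens along their only dimension) and pads an empty buffer by repeating the first observation, while `transform_shape` shows the expected spec change, where the concatenation dimension grows N-fold.

```python
from collections import deque
from typing import Deque, List, Sequence, Tuple


class ToyCatFrames:
    """Keeps the last n observations and returns them concatenated, oldest first."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.buffer: Deque[List[float]] = deque(maxlen=n)

    def reset(self) -> None:
        self.buffer.clear()

    def __call__(self, obs: List[float]) -> List[float]:
        if not self.buffer:
            # No history yet: pad by repeating the first observation.
            self.buffer.extend([obs] * (self.n - 1))
        self.buffer.append(obs)  # maxlen drops the oldest frame automatically
        return [x for frame in self.buffer for x in frame]


def transform_shape(shape: Sequence[int], n: int, dim: int) -> Tuple[int, ...]:
    """Expected observation-spec shape: the concatenation dimension grows n-fold."""
    out = list(shape)
    out[dim] *= n
    return tuple(out)
```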

Policy, actor, model naming

The terms policy, actor and model should be used in appropriate contexts.
For instance, we have an Actor class but a policy is passed to the DataCollector instances. A precise definition should be written in the documentation and the code usage should follow it.
The term model should be restricted to the actual network architecture (i.e. the mapping_operator in the ProbabilisticOperator classes).
Implicitly, so far, we have considered that an actor is a type of policy that has a model (for instance, a random policy would not have a model).

See this post for some definitions.
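To make the implicit definition concrete, here is a sketch under the assumption stated above: an actor is a policy that owns a model, while a random policy has none. All class names are hypothetical, and the "model" is any callable standing in for the network (the role the mapping_operator plays in ProbabilisticOperator). A collector would only depend on the `Policy` interface.

```python
import random
from typing import Callable, List, Protocol

Observation = List[float]
Action = int


class Policy(Protocol):
    """Anything that maps an observation to an action; this is all a collector needs."""

    def __call__(self, obs: Observation) -> Action: ...


class RandomPolicy:
    """A policy without a model."""

    def __init__(self, num_actions: int) -> None:
        self.num_actions = num_actions

    def __call__(self, obs: Observation) -> Action:
        return random.randrange(self.num_actions)


class Actor:
    """A policy that owns a model: the model scores actions, the actor picks one."""

    def __init__(self, model: Callable[[Observation], List[float]]) -> None:
        self.model = model  # the network architecture proper

    def __call__(self, obs: Observation) -> Action:
        scores = self.model(obs)
        return max(range(len(scores)), key=scores.__getitem__)


def collect(policy: Policy, observations: List[Observation]) -> List[Action]:
    """A collector relies only on the Policy interface, not on Actor."""
    return [policy(obs) for obs in observations]
```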
