Comments (7)

vmoens commented on May 18, 2024

Yes this is closed now!

skandermoalla commented on May 18, 2024

@vmoens Is this issue still open? @BY571 Does #1038 close it?

skandermoalla commented on May 18, 2024

Nice! @vmoens Do you know if this is meant to handle truncation correctly (e.g. as introduced in Gymnasium)? I.e., does it allow giving a value for the truncation state, $V(s_{\text{truncation}})$, from which to bootstrap all the reward-to-go values, $G_t$, in that episode?
Or maybe, if this is deferred to the learning logic, is there a flag for states belonging to truncated episodes, so that one could bootstrap their values (e.g. $\nabla \log \pi(a_t \mid s_t)\big[G_t + \mathbb{1}_{s_t \in \text{truncated episode}} \cdot V(s_{\text{truncation}}) \dots\big]$)?
I'm not aware of the updates since #403.
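For concreteness, here is a minimal sketch (plain PyTorch, hypothetical numbers) of the bootstrapping I mean: the critic's estimate for the truncation state seeds the backward accumulation of the returns.

>>> import torch
>>> rewards = torch.tensor([1.0, 1.0, 1.0, 1.0])  # one truncated episode
>>> gamma = 0.99
>>> running = torch.tensor(10.0)  # hypothetical critic estimate V(s_truncation)
>>> g = torch.zeros_like(rewards)
>>> for t in reversed(range(len(rewards))):
...     running = rewards[t] + gamma * running  # G_t = r_t + gamma * G_{t+1}
...     g[t] = running
...
>>> g  # each G_t in the episode is bootstrapped from V(s_truncation)
tensor([13.5464, 12.6731, 11.7910, 10.9000])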

skandermoalla commented on May 18, 2024

Similarly, for episodes that have not finished yet (typically the last episode in each concurrent environment), is there a way to find those and mask them out in the loss?
Thanks!

vmoens commented on May 18, 2024

I think that -- provided that you pass the correct mask to the function -- truncation should be handled properly.
@BY571 can you confirm?

This is the transform. It looks at done or truncated here.
The functional handles both of these as a "done"; as you can see, upstream the transform will do done = done | truncated.
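In a rough sketch (plain PyTorch, not the actual TorchRL source), the effective logic is:

>>> import torch
>>> done = torch.tensor([[False], [True], [False], [False]])
>>> truncated = torch.tensor([[False], [False], [False], [True]])
>>> done = done | truncated  # truncation folded into done upstream
>>> reward = torch.ones(4, 1)
>>> rtg = torch.zeros_like(reward)
>>> running = torch.zeros(1)
>>> # accumulate backwards (gamma=1 here), restarting at each episode end
>>> for t in reversed(range(reward.shape[0])):
...     running = reward[t] + (~done[t]).float() * running
...     rtg[t] = running
...
>>> rtg.squeeze(-1)
tensor([2., 1., 2., 1.])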

Let us know if something is not clear!

BY571 commented on May 18, 2024

Yes. When an episode ends without done=True, truncated is set to True on that last state, and the transform handles it as if that were the last state of the episode:

>>> from torchrl.envs.transforms import Reward2GoTransform
>>> import torch
>>> from tensordict import TensorDict
>>> r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])
>>> td = TensorDict({"reward": torch.ones(4, 1), "next": {"done": torch.zeros(4, 1, dtype=torch.bool), "truncated": torch.zeros(4, 1, dtype=torch.bool)}}, batch_size=())
>>> td["next"]["truncated"][-1] = True  # mark the last step as truncated
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[4.],
        [3.],
        [2.],
        [1.]])

If you want to mask these episodes out completely, you might have to set states, rewards, actions, etc. to zero. Simply setting truncated to True for all those steps would not work; the transform would then return only the current reward at each step, as it expects each step to be a single episode of length 1:

>>> td = TensorDict({"reward": torch.ones(4, 1), "next": {"done": torch.zeros(4, 1, dtype=torch.bool), "truncated": torch.zeros(4, 1, dtype=torch.bool)}}, batch_size=())
>>> td["next"]["truncated"][-3:] = True  # mark the last three steps as truncated
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[2.],
        [1.],
        [1.],
        [1.]])

>>> td = TensorDict({"reward": torch.ones(4, 1), "next": {"done": torch.zeros(4, 1, dtype=torch.bool), "truncated": torch.zeros(4, 1, dtype=torch.bool)}}, batch_size=())
>>> td["next"]["truncated"][-3:] = True
>>> td["reward"][-3:] = 0  # zero out the rewards of the masked steps as well
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[1.],
        [0.],
        [0.],
        [0.]])

Let me know if this helped to clarify things.

skandermoalla commented on May 18, 2024

Thanks both for the clarifications.

So I guess the answer to my question is that the transform is aware of truncation and handles it as termination. It will not bootstrap truncated episodes or mask unfinished ones; this is left to the user (see the sketch after the example below).

Also, it does not expect the last state to have ["next"]["truncated"] = True or ["next"]["done"] = True; it will only complain if there is no done or truncated flag anywhere in the batch:

>>> from torchrl.envs.transforms import Reward2GoTransform
>>> import torch
>>> from tensordict import TensorDict
>>> r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])
>>> td = TensorDict({"reward": torch.ones(4, 1), "next": {"done": torch.zeros(4, 1, dtype=torch.bool), "truncated": torch.zeros(4, 1, dtype=torch.bool)}}, batch_size=())
>>> td["next"]["truncated"][1] = True  # truncate mid-batch, not at the last step
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[2.],   # Belongs to truncated episode.
        [1.],   # Belongs to truncated episode. Next is truncated.
        [2.],   # Belongs to unfinished episode.
        [1.]])  # Belongs to unfinished episode. Next is not truncated, nor done.
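So masking out the unfinished tail could look like this (a sketch, continuing the session above; a step belongs to a finished episode iff some done/truncated flag appears at or after it):

>>> end = (td["next"]["done"] | td["next"]["truncated"]).squeeze(-1)
>>> finished = (end.flip(0).cumsum(0) > 0).flip(0)  # any episode end at step t or later
>>> finished
tensor([ True,  True, False, False])
>>> # per_step_loss[finished]  # hypothetical per-step loss, keeping only finished episodes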

Otherwise, regarding truncation at the last step:

"Yes, when an episode ends without done=True, truncated is set to True"

Where does this happen? It doesn't seem to be done by a collector at the last frame of a batch. (I'm new to TorchRL and in the process of deciding whether I should adopt it!)
