
Comments (2)

vmoens commented on September 18, 2024

Since your state is also part of the observation, I think you should just add it to your observation spec:

        self.observation_spec = CompositeSpec({
            "observation": CompositeSpec({
                "positions": BoundedTensorSpec(
                    low=0.0,
                    high=1.0,
                    shape=torch.Size([NUM_BLOCKS, 2]),
                    dtype=torch.float32
                ),
                "sizes": BoundedTensorSpec(
                    low=0.1,
                    high=1.0,
                    shape=torch.Size([NUM_BLOCKS, 2]),
                    dtype=torch.float32
                )
            }),
            "state": CompositeSpec({
                "distance_from_center": UnboundedContinuousTensorSpec(
                    shape=torch.Size([NUM_BLOCKS]),
                    dtype=torch.float32
                ),
            }),
        })

The reason we don't do this automatically is that you could have a state entry that lives at the root (not in "next") and is copied over at every call of step_mdp, i.e. carried unchanged throughout the rollout.
Here is an example using your code:

import torch
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from tensordict import TensorDict

COUNTER = 0
NUM_BLOCKS = 4

class BlockArrangementEnv(EnvBase):

    def __init__(self):
        super().__init__()

        self.observation_spec = CompositeSpec({
            "observation": CompositeSpec({
                "positions": BoundedTensorSpec(
                    low=0.0,
                    high=1.0,
                    shape=torch.Size([NUM_BLOCKS, 2]),
                    dtype=torch.float32
                ),
                "sizes": BoundedTensorSpec(
                    low=0.1,
                    high=1.0,
                    shape=torch.Size([NUM_BLOCKS, 2]),
                    dtype=torch.float32
                )
            }),
            "state": CompositeSpec({
                "distance_from_center": UnboundedContinuousTensorSpec(
                    shape=torch.Size([NUM_BLOCKS]),
                    dtype=torch.float32
                ),
            }),
        })

        self.full_state_spec = CompositeSpec({
            "state": CompositeSpec({
                "distance_from_center": UnboundedContinuousTensorSpec(
                    shape=torch.Size([NUM_BLOCKS]),
                    dtype=torch.float32
                ),
            }),
            "context": UnboundedContinuousTensorSpec(
                shape=torch.Size([NUM_BLOCKS]),
                dtype=torch.float32
            )
        })

        self.action_spec = CompositeSpec({
            "action": CompositeSpec({
                "index": BoundedTensorSpec(
                    low=0,
                    high=NUM_BLOCKS - 1,
                    shape=torch.Size([1]),
                    dtype=torch.int
                ),
                "delta": BoundedTensorSpec(
                    low=-1.0,
                    high=1.0,
                    shape=torch.Size([2]),
                    dtype=torch.float32
                )
            })
        })

        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size([NUM_BLOCKS]),
            dtype=torch.float32
        )


    def _reset(self, td):
        global COUNTER
        COUNTER += 1
        return TensorDict({
            "observation": {
                "positions": torch.rand([NUM_BLOCKS, 2]),
                "sizes": torch.empty(NUM_BLOCKS, 2).uniform_(0.1, 1.0),
            },
            "state": {
                "distance_from_center": torch.rand([NUM_BLOCKS]),
            },
            "context": self.full_state_spec["context"].zero() + COUNTER,
        }, batch_size=[])

    def _step(self, td, **kwargs):
        return TensorDict({
            "observation": {
                "positions": torch.rand([NUM_BLOCKS, 2]),
                "sizes": torch.empty(NUM_BLOCKS, 2).uniform_(0.1, 1.0),
            },
            "state": {
                "distance_from_center": torch.rand([NUM_BLOCKS]),
            },
            "reward": torch.rand([NUM_BLOCKS]),
            "done": torch.tensor(False)
        }, batch_size=[])

    def _set_seed(self, seed):
        pass


env = BlockArrangementEnv()
check_env_specs(env)
assert (env.rollout(3)["context"] == COUNTER).all()
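
To make the carry-over behaviour concrete without needing TorchRL installed, here is a toy sketch with plain dicts. `toy_step_mdp` is a hypothetical stand-in written for illustration only; it is not TorchRL's `step_mdp` and ignores most of what the real function handles. It only assumes the idea described above: entries under "next" are promoted to the root, and root entries with no counterpart in "next" (like "context") are copied over unchanged.

```python
def toy_step_mdp(td):
    """Toy stand-in for step_mdp (illustration only, not the TorchRL function).

    Promotes the entries under "next" to the root and keeps every other root
    entry as-is. This toy version drops "action" and "reward" for brevity.
    """
    nxt = td.get("next", {})
    # Start from the root entries, minus the nested "next" and the action.
    out = {k: v for k, v in td.items() if k not in ("next", "action")}
    # Entries found in "next" overwrite their root counterparts.
    out.update({k: v for k, v in nxt.items() if k != "reward"})
    return out


# "context" only exists at the root, "state" exists at the root and in "next".
td = {
    "context": 1,
    "state": 0.5,
    "action": 2,
    "next": {"state": 0.7, "reward": 1.0, "done": False},
}
carried = toy_step_mdp(td)
print(carried)
```

In this sketch "state" is refreshed from "next" while "context" survives untouched, which is the behaviour the assert on the rollout above relies on.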


mneilly commented on September 18, 2024

Ah, thanks, that makes sense. There's no compelling reason for me to keep distance_from_center as separate state in this case, other than to learn how the observation and state mechanisms work.
