Comments (2)
Since your state is also part of the observation, I think you should just add it to your observation specs:
self.observation_spec = CompositeSpec({
    "observation": CompositeSpec({
        "positions": BoundedTensorSpec(
            low=0.0,
            high=1.0,
            shape=torch.Size([NUM_BLOCKS, 2]),
            dtype=torch.float32
        ),
        "sizes": BoundedTensorSpec(
            low=0.1,
            high=1.0,
            shape=torch.Size([NUM_BLOCKS, 2]),
            dtype=torch.float32
        )
    }),
    "state": CompositeSpec({
        "distance_from_center": UnboundedContinuousTensorSpec(
            shape=torch.Size([NUM_BLOCKS]),
            dtype=torch.float32
        ),
    }),
})
The reason we don't do it automatically is that you could have a state that sits at the root (not in "next") and is copied at every call of step_mdp, i.e. carried over unchanged throughout the rollout.
Here is an example using your code:
import torch
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from tensordict import TensorDict

COUNTER = 0
NUM_BLOCKS = 4


class BlockArrangementEnv(EnvBase):
    def __init__(self):
        super().__init__()
        self.observation_spec = CompositeSpec({
            "observation": CompositeSpec({
                "positions": BoundedTensorSpec(
                    low=0.0,
                    high=1.0,
                    shape=torch.Size([NUM_BLOCKS, 2]),
                    dtype=torch.float32
                ),
                "sizes": BoundedTensorSpec(
                    low=0.1,
                    high=1.0,
                    shape=torch.Size([NUM_BLOCKS, 2]),
                    dtype=torch.float32
                )
            }),
            "state": CompositeSpec({
                "distance_from_center": UnboundedContinuousTensorSpec(
                    shape=torch.Size([NUM_BLOCKS]),
                    dtype=torch.float32
                ),
            }),
        })
        self.full_state_spec = CompositeSpec({
            "state": CompositeSpec({
                "distance_from_center": UnboundedContinuousTensorSpec(
                    shape=torch.Size([NUM_BLOCKS]),
                    dtype=torch.float32
                ),
            }),
            "context": UnboundedContinuousTensorSpec(
                shape=torch.Size([NUM_BLOCKS]),
                dtype=torch.float32
            )
        })
        self.action_spec = CompositeSpec({
            "action": CompositeSpec({
                "index": BoundedTensorSpec(
                    low=0,
                    high=NUM_BLOCKS - 1,
                    shape=torch.Size([1]),
                    dtype=torch.int
                ),
                "delta": BoundedTensorSpec(
                    low=-1.0,
                    high=1.0,
                    shape=torch.Size([2]),
                    dtype=torch.float32
                )
            })
        })
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=torch.Size([NUM_BLOCKS]),
            dtype=torch.float32
        )

    def _reset(self, td):
        global COUNTER
        COUNTER += 1
        return TensorDict({
            "observation": {
                "positions": torch.rand([NUM_BLOCKS, 2]),
                "sizes": torch.FloatTensor(NUM_BLOCKS, 2).uniform_(0.1, 1.0),
            },
            "state": {
                "distance_from_center": torch.rand([NUM_BLOCKS]),
            },
            "context": self.full_state_spec["context"].zero() + COUNTER,
        }, batch_size=[])

    def _step(self, td, **kwargs):
        return TensorDict({
            "observation": {
                "positions": torch.rand([NUM_BLOCKS, 2]),
                "sizes": torch.FloatTensor(NUM_BLOCKS, 2).uniform_(0.1, 1.0),
            },
            "state": {
                "distance_from_center": torch.rand([NUM_BLOCKS]),
            },
            "reward": torch.rand([NUM_BLOCKS]),
            "done": torch.tensor(False)
        }, batch_size=[])

    def _set_seed(self, seed):
        pass


env = BlockArrangementEnv()
check_env_specs(env)
assert (env.rollout(3)["context"] == COUNTER).all()
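To make the carry-over behavior concrete, here is a minimal sketch that uses plain Python dictionaries instead of TensorDicts. It is a simplified illustration only, not torchrl's implementation: `toy_step_mdp` is a hypothetical helper, and the rule of dropping "action" and "reward" is an assumption chosen to keep the example small.

```python
def toy_step_mdp(td: dict) -> dict:
    """Build the input of the next step from the output of the current one.

    Hypothetical stand-in for the behavior described above; it is not
    torchrl's step_mdp.
    """
    nxt = td.get("next", {})
    # Root entries with no counterpart under "next" (e.g. "context")
    # are carried over unchanged.
    out = {
        key: value
        for key, value in td.items()
        if key not in ("next", "action") and key not in nxt
    }
    # Entries written under "next" become the new root entries.
    out.update({key: value for key, value in nxt.items() if key != "reward"})
    return out


td = {
    "observation": 0.1,
    "state": 0.5,
    "context": 1.0,  # root-only entry, never written by the step
    "action": 0,
    "next": {"observation": 0.2, "state": 0.7, "reward": 1.0, "done": False},
}
new_td = toy_step_mdp(td)
```

In this sketch "state" is refreshed because the step writes it under "next", which is what declaring it in the observation spec achieves in the example above, while "context" keeps its reset-time value because it exists only at the root.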
from rl.
Ah, thanks, that makes sense. There's no compelling reason for me to keep distance_from_center as a separate state in this case, other than to learn how the observation and state mechanisms work.