michaeltmatthews / craftax
(Crafter + NetHack) in JAX. ICML 2024 Spotlight.
License: MIT License
Hi!
In the paper, you mention:
For perspective, it took one of the authors (with extensive knowledge of the game mechanics) roughly 5 hours of gameplay to first achieve a ‘perfect’ run where every achievement was completed. This was playing in a GUI that allowed for unlimited time to pause and think before taking each action.
I also see that the play_craftax() function records trajectories of environment states, actions, and rewards. Could you please release that data for the "perfect run" you mention in the paper?
Thank you! :)
When I import Craftax functions such as the one below, a large part of my GPU memory fills up immediately. For example, when I run the single line below
from craftax.craftax_env import make_craftax_env_from_name
and then check GPU memory usage, I see that 60 GB is already taken on an H100 machine.
Is this normal? And if so, what's causing this huge memory allocation?
Thanks a lot for your help!
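A likely explanation, assuming default JAX settings: JAX's GPU backend preallocates 75% of total GPU memory when the first JAX operation runs. Importing Craftax plausibly triggers such an operation, for example while building module-level arrays. On an 80 GB H100, 75% is exactly 60 GB. The sketch below turns preallocation off using the XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION environment variables described in the JAX GPU memory documentation. The helper names are hypothetical. The variables must be set before jax (or craftax) is first imported.

```python
import os
from typing import Optional


def configure_jax_gpu_memory(preallocate: bool = False,
                             mem_fraction: Optional[float] = None) -> None:
    """Set XLA's GPU allocator flags (hypothetical helper).

    JAX reads these variables when it initialises its GPU backend,
    so call this before importing jax or anything that imports it.
    """
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of total GPU memory reserved up front when preallocation is on.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


def default_preallocation_gb(total_gb: float, fraction: float = 0.75) -> float:
    """Memory reserved up front under JAX's default 75% preallocation."""
    return total_gb * fraction


if __name__ == "__main__":
    print(default_preallocation_gb(80.0))  # 60.0 on an 80 GB H100
    configure_jax_gpu_memory(preallocate=False)
    # Only now import code that pulls in jax, e.g.:
    # from craftax.craftax_env import make_craftax_env_from_name
```

Equivalently, you can run `export XLA_PYTHON_CLIENT_PREALLOCATE=false` in the shell before starting Python. With preallocation disabled, memory is allocated on demand. This can be somewhat slower and more prone to fragmentation.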
Hi Team,
Thank you for your excellent work. Could you please provide, or point me to, a detailed explanation of each component of the environment's state, perhaps as a .md file?
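In the meantime, one way to list the components is to walk the state's fields. This sketch assumes the state is a (flax.struct) dataclass, which the `.replace(...)` calls elsewhere on this page suggest. `describe_state` and the `ToyState`/`ToyInventory` stand-ins are hypothetical and only illustrate the pattern. They do not mirror Craftax's real field list.

```python
import dataclasses
from typing import Any, List, Tuple

import numpy as np


def describe_state(state: Any, prefix: str = "") -> List[Tuple[str, Any, Any]]:
    """Recursively list (field path, shape, dtype) for a dataclass-based state."""
    rows = []
    for field in dataclasses.fields(state):
        value = getattr(state, field.name)
        path = f"{prefix}{field.name}"
        if dataclasses.is_dataclass(value):
            # Nested structs (e.g. an inventory) are expanded field by field.
            rows.extend(describe_state(value, prefix=path + "."))
        else:
            arr = np.asarray(value)  # also works for JAX arrays
            rows.append((path, arr.shape, arr.dtype))
    return rows


# Hypothetical stand-ins; with Craftax, pass the state returned by env.reset.
@dataclasses.dataclass
class ToyInventory:
    wood: int
    stone: int


@dataclasses.dataclass
class ToyState:
    map: np.ndarray
    health: float
    inventory: ToyInventory


state = ToyState(map=np.zeros((48, 48), dtype=np.int32), health=9.0,
                 inventory=ToyInventory(wood=0, stone=0))
for path, shape, dtype in describe_state(state):
    print(f"{path:20s} {shape} {dtype}")
```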
In play_craftax.py, on lines 181-182, the script checks whether the step method returned the termination signal, so the game ends when the player's health reaches zero.
In play_craftax_classic.py, there is no such check, and indeed the player can keep playing with zero or negative health.
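The missing check could look roughly like the sketch below. Here `step_fn` stands in for a (jitted) `env.step` that returns `(obs, state, reward, done, info)`, the same tuple used in the symbolic-env example further down this page. `play_until_done` and `toy_step` are hypothetical.

```python
from typing import Any, Callable, Iterable, Optional, Tuple

StepResult = Tuple[Any, Any, float, bool, dict]


def play_until_done(step_fn: Callable[[Any, int], StepResult],
                    state: Any,
                    actions: Iterable[Optional[int]]) -> Tuple[Any, int]:
    """Apply actions until the environment signals termination.

    `None` entries model frames where no key was pressed.
    Returns the final state and the number of steps taken.
    """
    steps = 0
    for action in actions:
        if action is None:
            continue
        _obs, state, _reward, done, _info = step_fn(state, action)
        steps += 1
        if done:
            # Stop once the episode ends (e.g. health reaches zero),
            # as play_craftax.py does, instead of letting play continue.
            break
    return state, steps


# Hypothetical toy environment whose "state" is just health:
# every action costs one point of health.
def toy_step(health: float, action: int) -> StepResult:
    health -= 1.0
    return None, health, 0.0, health <= 0.0, {}


final_health, n = play_until_done(toy_step, 3.0, [0, None, 1, 2, 3, 4])
print(final_health, n)  # 0.0 3
```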
When I run play_craftax, I noticed the following two performance problems:
The first environment step is really slow. I guess this is due to JAX tracing and compilation. I wonder whether compilation is necessary in interactive mode: if human experts play at ~5 steps per second in a single environment (the author speed cited in the Craftax paper), is jit needed at all?
Subsequent steps are fast, but the game runs very intensively on my laptop (as measured by CPU usage, laptop temperature, battery drain, etc.), making it uncomfortable to play for more than a couple of minutes. Since little processing happens between turns, I think this should be considered a bug.
I think the cause is as follows. Looking at the play_craftax game loop, I noticed that it polls for input and renders the state in a loop that does not appear to have any rate-limiting mechanism. It therefore runs my processor as fast as it can, repeatedly re-rendering the state even though nothing is changing. Two potential solutions are (1) inserting a fraction-of-a-second time.sleep at the end of the loop or, better yet, (2) if available, using a pygame primitive that blocks the process while waiting for input instead of polling for input in get_action_from_keypress.
I know that Craftax is an RL research environment, not primarily a video game designed for human players. But these issues seem like they could be easy to fix, and fixing them could make a big difference to the experience of researchers getting oriented with the environment (such as myself).
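Suggestion (1) can be sketched as a frame limiter. It sleeps away the remainder of each frame and re-renders only after the state changes. This is a hypothetical, library-agnostic sketch: `poll_action`, `apply_action`, and `render` stand in for `get_action_from_keypress`, the environment step, and the renderer. For suggestion (2), pygame offers `pygame.event.wait()`, which blocks until an event arrives, and `pygame.time.Clock().tick(fps)`, which caps the loop rate.

```python
import time
from typing import Callable, Optional


def run_rate_limited(poll_action: Callable[[], Optional[int]],
                     apply_action: Callable[[int], None],
                     render: Callable[[], None],
                     max_fps: float = 30.0,
                     max_frames: Optional[int] = None) -> int:
    """Poll for input at most `max_fps` times per second.

    Renders only on the first frame and after an action, so an idle
    game does not redraw the same image. Returns the number of renders.
    """
    frame_time = 1.0 / max_fps
    renders, frame = 0, 0
    needs_render = True
    while max_frames is None or frame < max_frames:
        start = time.monotonic()
        action = poll_action()
        if action is not None:
            apply_action(action)
            needs_render = True
        if needs_render:
            render()
            renders += 1
            needs_render = False
        # Sleep for whatever is left of this frame instead of spinning the CPU.
        remaining = frame_time - (time.monotonic() - start)
        if remaining > 0:
            time.sleep(remaining)
        frame += 1
    return renders


# Usage with fake input: two key presses across five frames.
inputs = iter([None, 3, None, None, 5])
log = []
n = run_rate_limited(lambda: next(inputs, None), log.append, lambda: None,
                     max_fps=200, max_frames=5)
print(n, log)  # 3 [3, 5]
```

At 30 FPS the loop wakes only 30 times per second, which is far more than a human's ~5 actions per second but a tiny fraction of an unthrottled loop.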
Love this env!
It looks like the dtypes of the action and observation spaces don't match the actual returned values:
https://gist.github.com/wassname/56b8c323a9fd5c92ec777904ce319f59
P.S. Might be useful to you, on a fork I
Hi,
It doesn't seem like this is an issue for most people, but I'm running into it :(
Traceback (most recent call last):
File "/Users/miniconda3/lib/python3.10/site-packages/craftax/craftax/play_craftax.py", line 163, in main
renderer.render(env_state)
File "/Users/miniconda3/lib/python3.10/site-packages/craftax/craftax/play_craftax.py", line 112, in render
pixels = self._render(env_state, block_pixel_size=BLOCK_PIXEL_SIZE_HUMAN)
File "/Users/miniconda3/lib/python3.10/site-packages/craftax/craftax/renderer.py", line 241, in render_craftax_pixels
map_pixels, _ = jax.lax.scan(
File "/Users/miniconda3/lib/python3.10/site-packages/craftax/craftax/renderer.py", line 236, in _add_block_type_to_pixels
+ textures["full_map_block_textures"][block_index]
File "/Users/miniconda3/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 739, in op
return getattr(self.aval, f"_{name}")(self, *args)
File "/Users/miniconda3/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
return binary_op(*args)
File "/Users/miniconda3/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 102, in fn
return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
TypeError: mul got incompatible shapes for broadcasting: (576, 1216, 3), (576, 704, 3).
My dependencies are:
jax 0.4.27
jaxlib 0.4.27
gym 0.26.2
gymnasium 0.29.1
gymnax 0.0.8
numpy 1.26.4
I hope this is just a version incompatibility. Has anyone else run into the same problem?
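The error itself is a plain broadcasting failure. In elementwise ops, corresponding dimensions must be equal or 1, and 1216 vs 704 are neither. So the two arrays combined in `_add_block_type_to_pixels` were built with different widths. The sketch below reproduces the check with NumPy's `np.broadcast_shapes`, which follows the same rules as JAX. It does not identify which Craftax constant or cached texture causes the mismatch.

```python
import numpy as np


def can_broadcast(*shapes) -> bool:
    """Return True if the shapes are broadcast-compatible (NumPy/JAX rules)."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False


# Shapes taken from the traceback above.
print(can_broadcast((576, 1216, 3), (576, 704, 3)))  # False: 1216 != 704
print(can_broadcast((576, 704, 3), (576, 704, 3)))   # True
print(can_broadcast((576, 704, 3), (1, 704, 1)))     # True: size-1 dims stretch
```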
The paper explains that Craftax pixel-based observations are composed of 10x10 squares (emphasis added):
Pixels: The pixel-based observations for Craftax-Classic take the same form as those in Crafter, with each 16x16 square downscaled to 7x7. For Craftax, we downscale only to 10x10. This is because we deal with numbers greater than 9, meaning that the digits had to be rendered in a smaller font size. At 7x7 downscaling many of these digits were indistinguishable. Pixel observations from Crafter, Craftax-Classic and Craftax are shown in Figure 25.
But the source file craftax/craftax/constants.py sets BLOCK_PIXEL_SIZE_AGENT = 7 on line 17.
I confirmed that the following code produces observations in which certain digits are indistinguishable (1 from 5, 2 from 3, and 0 from 8 and 9).
import jax
from craftax.envs.craftax_pixels_env import CraftaxPixelsEnv
import matplotlib.pyplot as plt

env = CraftaxPixelsEnv()
rng = jax.random.PRNGKey(seed=0)
obs, state = env.reset(rng)
for i in range(100):
    print(f"rendering {i}...")
    # Overwrite the wood count so each frame displays a different number
    custom_state = state.replace(
        inventory=state.inventory.replace(
            wood=i,
        ),
    )
    rgb = env.get_obs(custom_state)
    print(rgb.shape)
    plt.imshow(rgb)
    plt.show()
I think this means that, in the default configuration as released, Craftax does not match the paper's description in terms of either observation size or digit distinguishability.
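To check which block size a given observation actually uses, divide its pixel dimensions by the number of blocks it spans. This is a minimal sketch: `infer_block_pixel_size` is hypothetical. The grid dimensions are parameters you supply, not values taken from the Craftax source, so pass the number of block rows and columns the image covers, including any inventory area.

```python
from typing import Tuple


def infer_block_pixel_size(obs_shape: Tuple[int, ...],
                           grid_rows: int, grid_cols: int) -> int:
    """Infer the square block size (in pixels) of an (H, W, C) observation."""
    height, width = obs_shape[0], obs_shape[1]
    if height % grid_rows or width % grid_cols:
        raise ValueError(f"{obs_shape} is not an exact {grid_rows}x{grid_cols} grid of blocks")
    block_h, block_w = height // grid_rows, width // grid_cols
    if block_h != block_w:
        raise ValueError(f"non-square blocks: {block_h}x{block_w}")
    return block_h


# Hypothetical 9x11 grid: 7-pixel blocks give a 63x77 image, 10-pixel blocks 90x110.
print(infer_block_pixel_size((63, 77, 3), 9, 11))   # 7
print(infer_block_pixel_size((90, 110, 3), 9, 11))  # 10
```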
From craftax/world_gen/world_gen_configs.py
I have been training your PPO-RNN baseline.
I was able to reproduce the results, so that's great!
I saved the model via your --save_policy argument, but I did not see any way to load the policy back once it was saved.
For example, suppose I trained the model for 1B steps and saved the policy. Later, I want to train for another 1B steps starting from that checkpoint, rather than retraining all 2B steps from scratch.
Could you show me how I can do this?
Thanks!
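Resuming essentially means saving the parameter pytree and later using it in place of a fresh initialisation. The sketch below illustrates that idea under assumptions and is not the --save_policy format. `save_checkpoint`, `load_checkpoint`, and `to_host` are hypothetical. They pickle nested dicts of NumPy arrays, whereas the baseline may store its policy differently (for example with Orbax or Flax serialization). To continue training exactly where it left off, the optimiser state and step counter should be saved and restored too.

```python
import os
import pickle
import tempfile
from typing import Any, Dict

import numpy as np


def to_host(tree: Any) -> Any:
    """Recursively convert arrays in a nested dict to host NumPy arrays."""
    if isinstance(tree, dict):
        return {k: to_host(v) for k, v in tree.items()}
    return np.asarray(tree)


def save_checkpoint(path: str, params: Dict[str, Any], update_step: int) -> None:
    with open(path, "wb") as f:
        pickle.dump({"params": to_host(params), "update_step": update_step}, f)


def load_checkpoint(path: str) -> Dict[str, Any]:
    with open(path, "rb") as f:
        return pickle.load(f)


# Toy parameters standing in for the trained network's params.
params = {"actor": {"kernel": np.ones((4, 2), np.float32)},
          "critic": {"bias": np.zeros(1, np.float32)}}
path = os.path.join(tempfile.mkdtemp(), "policy.pkl")
save_checkpoint(path, params, update_step=1000)

ckpt = load_checkpoint(path)
# Hypothetical resume: build the train state from ckpt["params"]
# instead of network.init(...), then continue the update loop.
print(ckpt["update_step"], ckpt["params"]["actor"]["kernel"].shape)  # 1000 (4, 2)
```

If the training script uses Flax's `TrainState.create(apply_fn=..., params=..., tx=...)`, the loaded params would go into the `params` argument.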
Hi,
I wanted to check out Craftax as a possible benchmark for an RNN-RL project, but I'm having trouble building the project locally (pip install craftax works fine), which seems like it might be a bug. When I try to pip install in editable mode (on Python 3.10.12), I get the following:
riley@monolith:~/research/rl_rnn_demo/Craftax$ pip install -e .
Defaulting to user installation because normal site-packages is not writeable
Obtaining file:///home/riley/research/rl_rnn_demo/Craftax
Installing build dependencies ... done
Checking if build backend supports build_editable ... done
ERROR: Project file:///home/riley/research/rl_rnn_demo/Craftax has a 'pyproject.toml' and its build backend is missing the 'build_editable' hook. Since it does not have a 'setup.py' nor a 'setup.cfg', it cannot be installed in editable mode. Consider using a build backend that supports PEP 660.
Can you advise on how to resolve this, or whether this is a Craftax issue?
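One possible cause, offered as an assumption rather than a diagnosis: an editable install of a project with only a pyproject.toml needs a build backend that implements the PEP 660 `build_editable` hook. setuptools provides this from version 64 onward, and pip supports PEP 660 from 21.3. The hypothetical helper below compares version strings against those minimums. It only reflects the local environment. With build isolation, pip builds using whatever `[build-system] requires` in pyproject.toml specifies, and the thresholds do not apply if the backend is not setuptools.

```python
import re
from typing import Tuple

MIN_SETUPTOOLS_PEP660 = (64, 0)
MIN_PIP_PEP660 = (21, 3)


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse leading numeric components: "64.0.0" -> (64, 0, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if not match:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # keep the leading digits, then stop at suffixes such as "rc1"
    return tuple(parts)


def supports_editable(setuptools_version: str, pip_version: str) -> bool:
    return (parse_version(setuptools_version) >= MIN_SETUPTOOLS_PEP660
            and parse_version(pip_version) >= MIN_PIP_PEP660)


if __name__ == "__main__":
    from importlib.metadata import PackageNotFoundError, version
    try:
        print(supports_editable(version("setuptools"), version("pip")))
    except PackageNotFoundError as exc:
        print(f"not installed: {exc}")
```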
Hi!
Does Craftax currently provide natural-language descriptions (or captions) of states?
If so, how can these be accessed?
Thanks!
Edit: I was able to render the textual descriptions:
import jax
from craftax_classic.envs.craftax_symbolic_env import CraftaxClassicSymbolicEnv
from craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
from environment_base.wrappers import LogWrapper
from craftax.renderer import render_craftax_text
rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
rngs = jax.random.split(_rng, 3)
# Create environment
env = CraftaxSymbolicEnv()
env_params = env.default_params
# Get an initial state and observation
obs, state = env.reset(rngs[0], env_params)
# Pick random action
action = env.action_space(env_params).sample(rngs[1])
# Step environment
obs, state, reward, done, info = env.step(rngs[2], state, action, env_params)
# Get the text representation of the state
text_state = render_craftax_text(state)
but I am not sure whether this works with batched environments and/or the LogWrapper. Could you please confirm?
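With batched (vmapped) environments, a text renderer written for a single state generally needs one environment's slice. This sketch makes two assumptions. First, every leaf of the batched state has the environment index as its leading axis. Second, the LogWrapper keeps the wrapped state in an `env_state` field, as purejaxrl-style LogWrapper implementations do; check the field name in environment_base/wrappers.py. `unbatch` and `unwrap_log_state` are hypothetical helpers.

```python
from typing import Any

import jax
import jax.numpy as jnp


def unbatch(batched_state: Any, index: int) -> Any:
    """Select environment `index` from every leaf of a vmapped state pytree."""
    return jax.tree_util.tree_map(lambda leaf: leaf[index], batched_state)


def unwrap_log_state(state: Any) -> Any:
    """Peel off wrapper states that store the inner state as `env_state` (assumed)."""
    while hasattr(state, "env_state"):
        state = state.env_state
    return state


# Toy batched pytree standing in for a batch of 4 environment states.
batched = {"health": jnp.arange(4.0), "map": jnp.zeros((4, 3, 3), jnp.int32)}
single = unbatch(batched, 2)
print(float(single["health"]), single["map"].shape)  # 2.0 (3, 3)

# With Craftax (not run here), render environment 0 of a wrapped batch:
# text_state = render_craftax_text(unbatch(unwrap_log_state(state), 0))
```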