Giter Club home page Giter Club logo

Comments (3)

Trinkle23897 avatar Trinkle23897 commented on September 18, 2024 1

Do you want it inside envpool package or in the example folder?

I think it should be in example. I don't want to put tianshou's code in envpool/ either, so that the library's code is clean enough.

from envpool.

Trinkle23897 avatar Trinkle23897 commented on September 18, 2024

because I cannot step a particular env... (env.send() does exist, but env.recv() does not guarantee that the result comes from the same env).

Not quite sure about your approach, but envpool now supports this feature when num_envs == batch_size:

def test_partial_step(self) -> None:
    """Check per-env episode bookkeeping when only a subset of envs is stepped.

    After each reset, the odd-indexed envs are stepped once on their own, so
    they are expected to report ``TimeLimit.truncated`` one full-batch step
    before the even-indexed envs do.
    """
    num_envs = 5
    max_episode_steps = 10
    config = AtariEnvSpec.gen_config(
        task="defender", num_envs=num_envs, max_episode_steps=max_episode_steps
    )
    spec = AtariEnvSpec(config)
    env = AtariGymEnvPool(spec)
    # Run the same scenario three times, resetting the pool each time.
    for _ in range(3):
        print(env)
        env.reset()
        # partial_ids[0] -> even ids [0, 2, 4]; partial_ids[1] -> odd ids [1, 3].
        partial_ids = [np.arange(num_envs)[::2], np.arange(num_envs)[1::2]]
        # Step only the odd envs once, putting them one step ahead of the evens.
        env.step(np.zeros(len(partial_ids[1]), dtype=int), env_id=partial_ids[1])
        # After these full-batch steps the odd envs have taken
        # max_episode_steps - 1 steps, so nobody may be truncated yet.
        for _ in range(max_episode_steps - 2):
            info = env.step(
                np.zeros(num_envs, dtype=int), env_id=np.arange(num_envs)
            )[-1]
            assert np.all(~info["TimeLimit.truncated"])
        # One more full-batch step: the odd envs reach max_episode_steps and
        # must be exactly the truncated ones.
        info = env.step(
            np.zeros(num_envs, dtype=int), env_id=np.arange(num_envs)
        )[-1]
        env_id = np.array(info["env_id"])
        # Sorted before comparing; presumably the order of the returned env
        # ids is not guaranteed to match the request order — verify.
        done_id = np.array(sorted(env_id[info["TimeLimit.truncated"]]))
        assert np.all(done_id == partial_ids[1])
        # Stepping only the even envs brings them to max_episode_steps as well.
        info = env.step(
            np.zeros(len(partial_ids[0]), dtype=int),
            env_id=partial_ids[0],
        )[-1]
        assert np.all(info["TimeLimit.truncated"])

# Of course, you can specify env_id to step the corresponding envs.

Indeed, this feature lacks documentation; I'll add it later...

I can also make a PR if you think it makes sense to integrate it directly into envpool (would make it easier for people already using gym / SB3 to adopt envpool ;))

Awesome! Looking forward to that.

from envpool.

araffin avatar araffin commented on September 18, 2024

Thanks for the heads up =)
Here is my updated code; I'll try to make a PR tomorrow ;). Do you want it inside the envpool package or in the example folder?

from typing import Optional

import envpool
import gym
import numpy as np
import torch as th
from envpool.python.protocol import EnvPool
from gym.envs.registration import EnvSpec

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor
from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvObs,
    VecEnvStepReturn,
    VecEnvWrapper,
)
from stable_baselines3.common.evaluation import evaluate_policy

# Experiment settings shared by the environment setup and PPO below.
num_envs = 4  # number of environments in the vectorized env
env_id = "Pendulum-v0"  # task id used for envpool.make / make_vec_env / gym.make
seed = 0  # passed to envpool.make() and to PPO
use_env_pool = True  # False -> build the vectorized env with SB3's make_vec_env

# Simple environments gain little from intra-op parallelism, so limiting
# PyTorch to a single thread makes training faster here.
th.set_num_threads(1)


class VecAdapter(VecEnvWrapper):
    """
    Convert an EnvPool object to a Stable-Baselines3 (SB3) VecEnv.

    EnvPool returns the step infos as one dict of per-env arrays. This adapter
    splits it into a list with one info dict per environment. When an
    environment reports ``done``, its last observation is stored under
    ``infos[i]["terminal_observation"]``. That environment is then reset and
    ``obs[i]`` is replaced with the reset observation.

    :param venv: The envpool object.
    """

    def __init__(self, venv: EnvPool):
        # EnvPool keeps the number of envs in its spec config; expose it as
        # ``venv.num_envs`` before handing the pool to VecEnvWrapper.
        venv.num_envs = venv.spec.config.num_envs
        super().__init__(venv=venv)

    def step_async(self, actions: np.ndarray) -> None:
        """Store ``actions``; they are sent to EnvPool in ``step_wait``."""
        self.actions = actions

    def reset(self) -> VecEnvObs:
        """Reset every environment and return the batched observations."""
        return self.venv.reset()

    def seed(self, seed: Optional[int] = None) -> None:
        """No-op: EnvPool envs can only be seeded by calling ``envpool.make()``."""

    def step_wait(self) -> VecEnvStepReturn:
        """
        Step all environments with the actions stored by ``step_async``.

        :return: ``(obs, rewards, dones, infos)`` where ``infos`` is a list with
            one dict per environment. For every finished env ``i``,
            ``infos[i]["terminal_observation"]`` holds the last observation of
            the episode and ``obs[i]`` holds the observation after resetting it.
        """
        obs, rewards, dones, info_dict = self.venv.step(self.actions)
        # Only ndarray entries can be split into one value per env; the set of
        # such keys does not depend on the env index, so compute it once.
        array_keys = [
            key for key, value in info_dict.items() if isinstance(value, np.ndarray)
        ]
        infos = []
        for i in range(self.num_envs):
            infos.append({key: info_dict[key][i] for key in array_keys})
            if dones[i]:
                # ``obs[i]`` is a view into ``obs``. Copy it before the row is
                # overwritten with the reset observation below, otherwise the
                # stored terminal observation would change along with it.
                infos[i]["terminal_observation"] = obs[i].copy()
                obs[i] = self.venv.reset(np.array([i]))

        return obs, rewards, dones, infos


if use_env_pool:
    env = envpool.make(env_id, env_type="gym", num_envs=num_envs, seed=seed)
    # NOTE(review): presumably needed because SB3 code reads ``env.spec.id``,
    # which EnvPool does not set to the Gym id — confirm.
    env.spec.id = env_id
    # Adapt EnvPool's batched API to SB3's VecEnv interface, then wrap it in
    # SB3's VecMonitor.
    env = VecAdapter(env)
    env = VecMonitor(env)
else:
    # Baseline: SB3's own vectorized Gym environments.
    env = make_vec_env(env_id, n_envs=num_envs)

# Tuned hyperparams for Pendulum-v0
model = PPO(
    "MlpPolicy",
    env,
    n_steps=1024,
    learning_rate=1e-3,
    use_sde=True,
    sde_sample_freq=4,
    gae_lambda=0.95,
    gamma=0.9,
    verbose=1,
    seed=seed,
)
# Alternative configuration: drop ``n_steps``, ``use_sde`` and
# ``sde_sample_freq`` from the call above and keep the other hyperparameters.
try:
    model.learn(100_000)
except KeyboardInterrupt:
    # Ctrl+C stops training early; the agent is still evaluated below.
    pass

# Agent trained on envpool version should also perform well on regular Gym env
test_env = gym.make(env_id)

# Test with EnvPool
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
print(f"EnvPool - {env_id}")
print(f"Mean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Test with Gym; release the Gym env's resources even if evaluation fails.
try:
    mean_reward, std_reward = evaluate_policy(
        model, test_env, n_eval_episodes=20, warn=False
    )
finally:
    test_env.close()
print(f"Gym - {env_id}")
print(f"Mean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")

from envpool.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.