This project has been moved to: https://github.com/coax-dev/coax
microsoft / coax
Home Page: https://github.com/coax-dev/coax
There are important files that all Microsoft projects should have that are not present in this repository. A pull request has been opened to add the missing file(s). When the PR is merged, this issue will be closed automatically.
Microsoft teams can learn more about this effort and share feedback within the open source guidance available internally.
This repository is currently missing a LICENSE file.
A license helps users understand how to use your project in a compliant manner. You can find the standard MIT license Microsoft uses at: https://github.com/microsoft/repo-templates/blob/main/shared/LICENSE.
If you would like to learn more about open source licenses, please visit the document at https://aka.ms/license (Microsoft-internal guidance).
Describe the bug
When running the PPO on Pong example, the line pi = coax.Policy(func_pi, env) throws the following exception:
--------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-9-e81d2937c9e6> in <module>
1 # function approximators
----> 2 pi = coax.Policy(func_pi, env)
3 v = coax.V(func_v, env)
4
5 # target networks
/opt/conda/lib/python3.8/site-packages/coax/_core/policy.py in __init__(self, func, env, observation_preprocessor, proba_dist, random_seed)
70 proba_dist = ProbaDist(env.action_space)
71
---> 72 super().__init__(
73 func=func,
74 observation_space=env.observation_space,
/opt/conda/lib/python3.8/site-packages/coax/_core/base_stochastic_func_type2.py in __init__(self, func, observation_space, action_space, observation_preprocessor, proba_dist, random_seed)
151
152 # note: self._modeltype is set in super().__init__ via self._check_signature
--> 153 super().__init__(
154 func=func,
155 observation_space=observation_space,
/opt/conda/lib/python3.8/site-packages/coax/_core/base_func.py in __init__(self, func, observation_space, action_space, random_seed)
99
100 # init function params and state
--> 101 self._params, self._function_state = transformed.init(self.rng, *example_data.inputs.args)
102
103 # check if output has the expected shape etc.
/opt/conda/lib/python3.8/site-packages/haiku/_src/transform.py in init_fn(rng, *args, **kwargs)
275 rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
276 with base.new_context(rng=rng) as ctx:
--> 277 f(*args, **kwargs)
278 return ctx.collect_params(), ctx.collect_initial_state()
279
<ipython-input-8-51963c5ee651> in func_pi(S, is_training)
15 hk.Linear(env.action_space.n, w_init=jnp.zeros),
16 ))
---> 17 X = shared(S, is_training)
18 return {'logits': logits(X)}
19
<ipython-input-8-51963c5ee651> in shared(S, is_training)
7 ])
8 X = jnp.stack(S, axis=-1) / 255. # stack frames
----> 9 return seq(X)
10
11
/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in wrapped(self, *args, **kwargs)
404 f = stateful.named_call(f, name=local_name)
405
--> 406 out = f(*args, **kwargs)
407
408 # Notify parent modules about our existence.
/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in run_interceptors(bound_method, method_name, self, *args, **kwargs)
261 """Runs any method interceptors or the original method."""
262 if not interceptor_stack:
--> 263 return bound_method(*args, **kwargs)
264
265 ctx = MethodContext(module=self,
/opt/conda/lib/python3.8/site-packages/haiku/_src/basic.py in __call__(self, inputs, *args, **kwargs)
124 out = layer(out, *args, **kwargs)
125 else:
--> 126 out = layer(out)
127 return out
128
/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in wrapped(self, *args, **kwargs)
404 f = stateful.named_call(f, name=local_name)
405
--> 406 out = f(*args, **kwargs)
407
408 # Notify parent modules about our existence.
/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in run_interceptors(bound_method, method_name, self, *args, **kwargs)
261 """Runs any method interceptors or the original method."""
262 if not interceptor_stack:
--> 263 return bound_method(*args, **kwargs)
264
265 ctx = MethodContext(module=self,
/opt/conda/lib/python3.8/site-packages/haiku/_src/conv.py in __call__(self, inputs)
193 w *= self.mask
194
--> 195 out = lax.conv_general_dilated(inputs,
196 w,
197 window_strides=self.stride,
/opt/conda/lib/python3.8/site-packages/jax/_src/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision)
596 np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,
597 window_strides, padding)
--> 598 return conv_general_dilated_p.bind(
599 lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
600 lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
/opt/conda/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
280 top_trace = find_top_trace(args)
281 tracers = map(top_trace.full_raise, args)
--> 282 out = top_trace.process_primitive(self, tracers, params)
283 return map(full_lower, out) if self.multiple_results else full_lower(out)
284
/opt/conda/lib/python3.8/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
626
627 def process_primitive(self, primitive, tracers, params):
--> 628 return primitive.impl(*tracers, **params)
629
630 def process_call(self, primitive, f, tracers, params):
/opt/conda/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
237 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
238 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
--> 239 return compiled_fun(*args)
240
241
/opt/conda/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, result_handler, *args)
355 device, = compiled.local_devices()
356 input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 357 out_bufs = compiled.execute(input_bufs)
358 check_special(prim, out_bufs)
359 return result_handler(*out_bufs)
RuntimeError: Unimplemented: DNN library is not found.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
Expected behavior
The code should run without throwing an exception.
Desktop (please complete the following information):
jax version: 0.2.9
jaxlib version: 0.1.60+cuda101
coax version: 0.1.6
To Reproduce
Steps to reproduce the behavior:
docker build -t rl:gpu -f gpu.Dockerfile . && \
  docker run -it \
    --gpus all \
    -u vscode \
    -p 8888:8888 \
    -v $(pwd):/workspaces/rl \
    -w /workspaces/rl \
    --rm \
    --name rl \
    rl:gpu jupyter lab --ip 0.0.0.0 --no-browser
Script to run:
import os

# set some env vars
os.environ.setdefault('JAX_PLATFORM_NAME', 'gpu')  # tell JAX to use GPU
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # tell XLA to be quiet

import gym
import jax
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam

# the name of this script
name = 'ppo'

# env with preprocessing
env = gym.make('PongNoFrameskip-v4')
env = gym.wrappers.AtariPreprocessing(env)
env = coax.wrappers.FrameStacking(env, num_frames=3)
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")


def shared(S, is_training):
    seq = hk.Sequential([
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4), jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Flatten(),
    ])
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return seq(X)


def func_pi(S, is_training):
    logits = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = shared(S, is_training)
    return {'logits': logits(X)}


def func_v(S, is_training):
    value = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
    ))
    X = shared(S, is_training)
    return value(X)


# function approximators
pi = coax.Policy(func_pi, env)
Additional context
The DQN on CartPole example runs without problems and is utilizing the GPU, so I think all the necessary GPU drivers are correctly installed.
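One likely reason the two examples behave differently: the failing call is lax.conv_general_dilated, and XLA's GPU backend typically hands convolutions to cuDNN (the "DNN library" in the error), whereas a model built only from dense layers, as the CartPole DQN presumably is, never needs cuDNN. A working GPU driver therefore does not prove cuDNN is visible inside the container. A minimal sketch to check this in isolation, assuming only jax is installed (the function name conv_smoke_test is made up for illustration), is to run a single small convolution on the default backend:

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv_smoke_test():
    """Run one tiny 2D convolution on JAX's default backend.

    On a GPU backend this exercises the same convolution primitive that fails
    in the traceback, so a missing cuDNN should surface here as well.
    """
    x = jnp.ones((1, 1, 8, 8), dtype=jnp.float32)  # NCHW: batch, channels, height, width
    w = jnp.ones((4, 1, 3, 3), dtype=jnp.float32)  # OIHW: out-channels, in-channels, kh, kw
    y = lax.conv_general_dilated(x, w, window_strides=(1, 1), padding='VALID')
    return jax.devices()[0].platform, y


platform, y = conv_smoke_test()
print(platform, y.shape)
```

If this raises the same "DNN library is not found" error while platform is 'gpu', the problem lies in the container's CUDA/cuDNN setup rather than in coax.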
Hi there,
Great library! Thanks a ton for doing this work.
I'm working on some environments that require a lot of samples and so was reading through the documentation looking for some sort of built-in way to handle running multiple environments in parallel. I didn't find anything, so I'm wondering, is there a utility I'm missing or perhaps a recommended way to handle this with coax?
If not, are there plans to add support for parallel environments in the future?
Thanks again, so far coax is a delight to use!
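Independent of whatever coax itself provides, one generic pattern is to keep several environment copies and roll them out concurrently, then feed the collected transitions into the learner. The sketch below is not coax API: ToyEnv, run_episode and collect_parallel are hypothetical names, ToyEnv is a stand-in for a gym-style environment (reset() returns a state, step() returns (s_next, r, done, info)), and passing a coax policy as the policy callable is an untested assumption.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyEnv:
    """Hypothetical gym-style env: each episode lasts 10 steps, reward equals the action."""

    def __init__(self):
        self._t = 0

    def reset(self):
        self._t = 0
        return 0.0

    def step(self, action):
        self._t += 1
        done = self._t >= 10
        return float(self._t), float(action), done, {}


def run_episode(env, policy):
    """Roll out one full episode and return its (s, a, r, done) transitions."""
    transitions = []
    s = env.reset()
    done = False
    while not done:
        a = policy(s)
        s_next, r, done, _ = env.step(a)
        transitions.append((s, a, r, done))
        s = s_next
    return transitions


def collect_parallel(envs, policy):
    """Run one episode per env concurrently, one worker thread per env."""
    with ThreadPoolExecutor(max_workers=len(envs)) as pool:
        return list(pool.map(lambda env: run_episode(env, policy), envs))


envs = [ToyEnv() for _ in range(4)]
episodes = collect_parallel(envs, policy=lambda s: 1)
```

Threads may only give a real speedup when the environment's step releases the GIL (for example when the heavy work runs in native code); for pure-Python environments a process-based pool would likely be needed instead.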
Suppose I train a Q-Learning agent on FrozenLake as per this tutorial. What is the best way to save this model/agent for use elsewhere?
I'm thinking I just need to save the Q-values and potentially the policy pi? Interestingly, the docs seem to have a Q.save(...) and a pi.save(...); however, when I try to use them, they appear not to be implemented in the code yet.
Many thanks for any help, and for this fantastic lib! :)
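As a generic fallback that does not rely on Q.save or pi.save, the learned parameters can be persisted as a plain JAX pytree. This is a sketch under assumptions: it presumes the trained function approximator's parameters can be pulled out as a nested dict of arrays (Haiku-style; in coax they are expected to be reachable from the approximator object, e.g. via a params attribute, but check that name against your version), and save_params / load_params are hypothetical helpers, not coax functions.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_params(params, path):
    """Pickle a pytree of arrays after copying every leaf to a host NumPy array."""
    host_params = jax.tree_util.tree_map(np.asarray, params)
    with open(path, 'wb') as f:
        pickle.dump(host_params, f)


def load_params(path):
    """Unpickle the pytree and turn its leaves back into JAX arrays."""
    with open(path, 'rb') as f:
        host_params = pickle.load(f)
    return jax.tree_util.tree_map(jnp.asarray, host_params)


# Round trip with a made-up, Haiku-style nested parameter dict.
example_params = {'linear': {'w': jnp.arange(6.0).reshape(2, 3), 'b': jnp.zeros(3)}}
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'q_params.pkl')
    save_params(example_params, path)
    restored = load_params(path)
```

Converting leaves to NumPy before pickling keeps the file independent of the device the model was trained on; to reuse the agent elsewhere, the same function approximator would be rebuilt and given the restored parameters.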