tanzhenyu / baselines-tf2
openai baselines with tensorflow 2.0
License: MIT License
This is a strange request, but how can I set the weights to 0 to kill the training (I am currently using trpo_mpi.py with Ant-v2)? Eventually, I'd like to zero out specific weights so that certain actions are not trained. I tried to do this in policies.py by adding another function:
def kill_actions(self):
    # Multiply the output layer's kernel and bias by 0, then write them back.
    weights2 = 0 * self.policy_network.get_layer('dense_2').get_weights()[0]
    bias2 = 0 * self.policy_network.get_layer('dense_2').get_weights()[1]
    self.policy_network.get_layer('dense_2').set_weights([weights2, bias2])
while using the network:
def network_fn(input_shape):  # TEST
    # input_shape = (27,) - for Ant-v2
    print('input shape is {}'.format(input_shape))
    x_input = tf.keras.Input(shape=input_shape)
    # h = tf.keras.layers.Flatten(x_input)
    h = x_input
    h = tf.keras.layers.Dense(units=num_hidden, kernel_initializer=ortho_init(np.sqrt(2)),
                              name='dense_0', activation=activation)(h)
    h = tf.keras.layers.Dense(units=num_hidden, kernel_initializer=ortho_init(np.sqrt(2)),
                              name='dense_1', activation=activation)(h)
    h = tf.keras.layers.Dense(units=8, kernel_initializer=ortho_init(np.sqrt(2)),
                              name='dense_2', activation=activation)(h)
    network = tf.keras.Model(inputs=[x_input], outputs=[h])
    return network

shape = (27,)
network = network_fn(shape)  # for Ant-v2
return network_fn, network
and calling kill_actions just before every call to pi.step(ob). However, this neither sets the actions to 0 nor prevents the ant from training. Shouldn't the actions go to 0 if I zero out the weights and biases in the output layer?
Thanks
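One possible explanation, sketched under assumptions rather than taken from the repository: zeroing dense_2 does make that layer output 0 (tanh(0) = 0), but if the policy computes its actions through a separate Gaussian action head on top of the network output, the sampled actions still carry noise and will not be 0. The action head below (a mean equal to the latent output plus exp(log_std) times noise) is hypothetical and only illustrates the effect; check policies.py and the distribution code to see what pi.step(ob) actually does.

```python
import math
import random


def dense(x, weights, bias, activation=math.tanh):
    """Fully connected layer for one input vector: activation(x @ weights + bias)."""
    return [
        activation(sum(xi * w_row[j] for xi, w_row in zip(x, weights)) + bias[j])
        for j in range(len(bias))
    ]


random.seed(0)
hidden = [random.uniform(-1, 1) for _ in range(4)]  # stand-in for hidden activations

# Zeroed output layer, which is what kill_actions is trying to achieve (4 -> 3 units).
zero_w = [[0.0] * 3 for _ in range(4)]
zero_b = [0.0] * 3
latent = dense(hidden, zero_w, zero_b)  # every unit is tanh(0) == 0

# Hypothetical Gaussian action head (an assumption, not the baselines code):
# action = mean + exp(log_std) * noise, with the mean taken from the latent output.
log_std = [0.0] * 3
actions = [m + math.exp(s) * random.gauss(0.0, 1.0) for m, s in zip(latent, log_std)]

print(latent)   # all zeros
print(actions)  # generally nonzero because of the sampling noise
```

Also note that zeroed weights still receive gradients, so an optimizer step can move them away from 0 again unless the update itself is masked or the layer is excluded from the trainable variables.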
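Hello, I'll just ask my question directly. When using feature columns in TF 2.0, do you have to write one tf.keras.Input for every raw input feature? Not only is the syntax cumbersome, but the latency also seems high once the model is deployed. Is there another way to solve this?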
When I try to run the DQN Cartpole example by running this command (from this README)
python -m baselines.run --alg=deepq --env=CartPole-v0 --save_path=./cartpole_model.pkl --num_timesteps=1e5
I get the following error (full error log at the end)
TypeError: Input 'y' of 'Mul' Op has type float64 that does not match type float32 of argument 'x'.
For reference, here is my pip list:
Package Version
-------------------- --------------------
absl-py 0.7.1
astor 0.8.0
certifi 2019.6.16
cloudpickle 1.2.1
future 0.17.1
gast 0.2.2
google-pasta 0.1.7
grpcio 1.21.1
gym 0.13.0
h5py 2.9.0
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.0
Markdown 3.1.1
numpy 1.16.4
opencv-python 4.1.0.25
pip 19.1.1
protobuf 3.8.0
pyglet 1.3.2
scipy 1.3.0
setuptools 41.0.1
six 1.12.0
tb-nightly 1.14.0a20190603
tensorflow 2.0.0b1
termcolor 1.1.0
tf-estimator-nightly 1.14.0.dev2019060501
Werkzeug 0.15.4
wheel 0.33.4
wrapt 1.11.2
Full error log:
(baselines-tf2) ryanlee@ryanlee-ThinkPad-T430s:~/git/baselines-tf2$ python -m baselines.run --alg=deepq --env=CartPole-v0 --save_path=./cartpole_model.pkl --num_timesteps=1e5
Logging to /tmp/openai-2019-07-02-12-05-25-399138
env_type: classic_control
Training deepq on classic_control:CartPole-v0 with arguments
{'network': 'mlp'}
input shape is (4,)
2019-07-02 12:05:25.466451: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2893300000 Hz
2019-07-02 12:05:25.466750: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x557cacf8d590 executing computations on platform Host. Devices:
2019-07-02 12:05:25.466798: I tensorflow/compiler/xla/service/service.cc:175] StreamExecutor device (0): <undefined>, <undefined>
2019-07-02 12:05:25.657545: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1483] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU. To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
input shape is (4,)
/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/numpy/core/_methods.py:85: RuntimeWarning: invalid value encountered in double_scalars
ret = ret.dtype.type(ret / rcount)
Traceback (most recent call last):
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/ryanlee/git/baselines-tf2/baselines/run.py", line 240, in <module>
main(sys.argv)
File "/home/ryanlee/git/baselines-tf2/baselines/run.py", line 203, in main
model, env = train(args, extra_args)
File "/home/ryanlee/git/baselines-tf2/baselines/run.py", line 79, in train
**alg_kwargs
File "/home/ryanlee/git/baselines-tf2/baselines/deepq/deepq.py", line 207, in learn
td_errors = model.train(obses_t, actions, rewards, obses_tp1, dones, weights)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 416, in __call__
self._initialize(args, kwds, add_initializers_to=initializer_map)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 359, in _initialize
*args, **kwds))
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1360, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1648, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1541, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 716, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 309, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2155, in bound_method_wrapper
return wrapped_fn(*args, **kwargs)
File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 706, in wrapper
raise e.ag_error_metadata.to_exception(type(e))
TypeError: in converted code:
relative to /home/ryanlee:
git/baselines-tf2/baselines/deepq/deepq_learner.py:152 train *
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(actions, self.num_actions, dtype=tf.float64), 1)
anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:884 binary_op_wrapper
return func(x, y, name=name)
anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:1180 _mul_dispatch
return gen_math_ops.mul(x, y, name=name)
anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:6490 mul
"Mul", x=x, y=y, name=name)
anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:563 _apply_op_helper
inferred_from[input_arg.type_attr]))
TypeError: Input 'y' of 'Mul' Op has type float64 that does not match type float32 of argument 'x'.
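The traceback points at the line in deepq_learner.py where q_t (float32, according to the error) is multiplied by a one-hot mask created with dtype=tf.float64. Below is a minimal NumPy sketch of the same selection step with the mask built in q_t's own dtype. The helper name select_q_values is hypothetical and not part of baselines. NumPy would silently promote mixed dtypes to float64, whereas TensorFlow's Mul op refuses them, which is why the original line raises.

```python
import numpy as np


def select_q_values(q_t, actions, num_actions):
    """Pick Q(s, a) for each row using a one-hot mask in the same dtype as q_t."""
    one_hot = np.eye(num_actions, dtype=q_t.dtype)[actions]  # shape: (batch, num_actions)
    return np.sum(q_t * one_hot, axis=1)


q_t = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)  # Q-values for 2 states
actions = np.array([1, 0])                                  # action taken in each state
q_t_selected = select_q_values(q_t, actions, num_actions=2)

print(q_t_selected, q_t_selected.dtype)
```

In the TensorFlow code, the analogous change would be to have tf.one_hot use the dtype of q_t instead of a hard-coded tf.float64, or to cast one operand with tf.cast. Which dtype the rest of the learner expects is not clear from the log, so treat this as a starting point rather than a confirmed fix.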
I have started working on upgrading Generative Adversarial Imitation Learning (GAIL) to TF2 in my forked repository. This will probably need multiple PRs :)