baselines-tf2's People

Contributors

tanzhenyu

baselines-tf2's Issues

I would like to ask you some questions

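Hello. When you use feature columns in TF 2.0, do you have to write one tf.keras.Input for every raw input feature? That makes the syntax verbose, and the latency also seems high once the model is deployed for serving. Is there another way to solve this?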
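One way to cut down the verbosity, sketched here under assumptions, is to build the inputs in a loop instead of writing each one by hand. The feature names below are hypothetical, every feature is assumed to be a single numeric value, and the sketch does not use tf.feature_column itself or measure serving latency.

```python
import tensorflow as tf

# Hypothetical feature names; replace with the real raw features.
feature_names = ["age", "income", "clicks"]

# Build one Input per feature programmatically rather than by hand.
inputs = {
    name: tf.keras.Input(shape=(1,), name=name, dtype="float32")
    for name in feature_names
}

# Concatenate the per-feature tensors into one (batch, n_features) tensor.
x = tf.keras.layers.Concatenate()([inputs[name] for name in feature_names])
output = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs=inputs, outputs=output)

# The model is called with a dict keyed by feature name.
batch = {name: tf.ones((2, 1)) for name in feature_names}
prediction = model(batch)
```

If all features are numeric, a single tf.keras.Input of shape (n_features,) fed with a pre-stacked array is an even simpler alternative.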

Cannot run DQN (deepq) Cartpole example

When I try to run the DQN Cartpole example by running this command (from this README)

python -m baselines.run --alg=deepq --env=CartPole-v0 --save_path=./cartpole_model.pkl --num_timesteps=1e5

I get the following error (full error log at the end)

TypeError: Input 'y' of 'Mul' Op has type float64 that does not match type float32 of argument 'x'.

For reference, here is my pip list:

Package              Version
-------------------- --------------------
absl-py              0.7.1
astor                0.8.0
certifi              2019.6.16
cloudpickle          1.2.1
future               0.17.1
gast                 0.2.2
google-pasta         0.1.7
grpcio               1.21.1
gym                  0.13.0
h5py                 2.9.0
Keras-Applications   1.0.8
Keras-Preprocessing  1.1.0
Markdown             3.1.1
numpy                1.16.4
opencv-python        4.1.0.25
pip                  19.1.1
protobuf             3.8.0
pyglet               1.3.2
scipy                1.3.0
setuptools           41.0.1
six                  1.12.0
tb-nightly           1.14.0a20190603
tensorflow           2.0.0b1
termcolor            1.1.0
tf-estimator-nightly 1.14.0.dev2019060501
Werkzeug             0.15.4
wheel                0.33.4
wrapt                1.11.2

Full error log:

(baselines-tf2) ryanlee@ryanlee-ThinkPad-T430s:~/git/baselines-tf2$ python -m baselines.run --alg=deepq --env=CartPole-v0 --save_path=./cartpole_model.pkl --num_timesteps=1e5
Logging to /tmp/openai-2019-07-02-12-05-25-399138
env_type: classic_control
Training deepq on classic_control:CartPole-v0 with arguments
{'network': 'mlp'}
input shape is (4,)
2019-07-02 12:05:25.466451: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2893300000 Hz
2019-07-02 12:05:25.466750: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x557cacf8d590 executing computations on platform Host. Devices:
2019-07-02 12:05:25.466798: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
2019-07-02 12:05:25.657545: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1483] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set.  If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU.  To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
input shape is (4,)
/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/numpy/core/_methods.py:85: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
Traceback (most recent call last):
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ryanlee/git/baselines-tf2/baselines/run.py", line 240, in <module>
    main(sys.argv)
  File "/home/ryanlee/git/baselines-tf2/baselines/run.py", line 203, in main
    model, env = train(args, extra_args)
  File "/home/ryanlee/git/baselines-tf2/baselines/run.py", line 79, in train
    **alg_kwargs
  File "/home/ryanlee/git/baselines-tf2/baselines/deepq/deepq.py", line 207, in learn
    td_errors = model.train(obses_t, actions, rewards, obses_tp1, dones, weights)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 416, in __call__
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 359, in _initialize
    *args, **kwds))
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1360, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1648, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1541, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 716, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 309, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2155, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/home/ryanlee/anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 706, in wrapper
    raise e.ag_error_metadata.to_exception(type(e))
TypeError: in converted code:
    relative to /home/ryanlee:

    git/baselines-tf2/baselines/deepq/deepq_learner.py:152 train  *
        q_t_selected = tf.reduce_sum(q_t * tf.one_hot(actions, self.num_actions, dtype=tf.float64), 1)
    anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:884 binary_op_wrapper
        return func(x, y, name=name)
    anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:1180 _mul_dispatch
        return gen_math_ops.mul(x, y, name=name)
    anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:6490 mul
        "Mul", x=x, y=y, name=name)
    anaconda3/envs/baselines-tf2/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:563 _apply_op_helper
        inferred_from[input_arg.type_attr]))

    TypeError: Input 'y' of 'Mul' Op has type float64 that does not match type float32 of argument 'x'.
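
The converted-code trace points at deepq_learner.py:152, where the one-hot mask is created with dtype=tf.float64 while q_t appears to be float32. A minimal sketch of one possible fix, assuming q_t is float32, is to create the mask with the same dtype as q_t. The tensor values below are made up for illustration.

```python
import tensorflow as tf

# Made-up Q-values for a batch of 2 states and 2 actions (float32, as in the error).
q_t = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype=tf.float32)
actions = tf.constant([0, 1])
num_actions = 2

# Build the one-hot mask with q_t's dtype so the multiply sees matching types.
mask = tf.one_hot(actions, num_actions, dtype=q_t.dtype)
q_t_selected = tf.reduce_sum(q_t * mask, axis=1)
```

Casting with tf.cast(mask, q_t.dtype) would work as well. Other float64 tensors further down in train (for example rewards or importance weights) may need the same treatment.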

Setting weights to 0 to kill training

This is a strange request, but how can I set the weights to 0 to kill the training (I'm currently using trpo_mpi.py and Ant-v2)? Eventually, I'd like to set certain weights to 0 so that certain actions are not trained. I tried to do this in policies.py by adding another function:

def kill_actions(self):
    weights2 = 0*self.policy_network.get_layer('dense_2').get_weights()[0]
    bias2 = 0*self.policy_network.get_layer('dense_2').get_weights()[1]
    self.policy_network.get_layer('dense_2').set_weights([weights2,bias2])

while using the network:

def network_fn(input_shape): # TEST
##################################################################################################
# input_shape = (27,) - for Ant-v2

    print('input shape is {}'.format(input_shape))
    x_input = tf.keras.Input(shape=input_shape)
    # h = tf.keras.layers.Flatten(x_input)
    h = x_input
    h = tf.keras.layers.Dense(units=num_hidden, kernel_initializer=ortho_init(np.sqrt(2)),
                                name='dense_0'.format(0), activation=activation)(h)
    h = tf.keras.layers.Dense(units=num_hidden, kernel_initializer=ortho_init(np.sqrt(2)),
                                name='dense_1'.format(1), activation=activation)(h)
    h = tf.keras.layers.Dense(units=8, kernel_initializer=ortho_init(np.sqrt(2)),
                                name='dense_2'.format(2), activation=activation)(h)
    network = tf.keras.Model(inputs=[x_input], outputs=[h])
    return network

shape = (27,)
network = network_fn(shape) # FOR Ant-v2 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
return network_fn, network

and calling kill_actions just before every call to pi.step(ob). However, this neither sets the actions to 0 nor prevents the ant from training. Shouldn't the actions go to 0 if I zero out the weights and biases in the output layer?

Thanks
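
A minimal sketch of what zeroing a layer does at the Keras level, under stated assumptions: the network below is a small stand-in (layer sizes and tanh are assumptions), and a plain SGD step stands in for whatever parameter update the trainer applies, which for TRPO is different. It shows that the zeroed layer outputs tanh(0) = 0, and that a single update makes the weights nonzero again.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the policy network; sizes and activation are assumptions.
x_input = tf.keras.Input(shape=(27,))
h = tf.keras.layers.Dense(64, activation="tanh", name="dense_1")(x_input)
h = tf.keras.layers.Dense(8, activation="tanh", name="dense_2")(h)
network = tf.keras.Model(inputs=x_input, outputs=h)

# Zero the kernel and bias of the last layer.
layer = network.get_layer("dense_2")
kernel, bias = layer.get_weights()
layer.set_weights([np.zeros_like(kernel), np.zeros_like(bias)])

obs = np.random.randn(5, 27).astype(np.float32)
out_zeroed = network(obs).numpy()  # tanh(0) = 0 for every output unit

# One gradient step on an arbitrary loss moves the weights away from zero.
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(network(obs) - 1.0))
grads = tape.gradient(loss, layer.trainable_variables)
tf.keras.optimizers.SGD(0.1).apply_gradients(zip(grads, layer.trainable_variables))
weights_after = layer.get_weights()
```

Whether the sampled actions are also zero depends on what the policy applies after this network (for example a further output layer or sampling noise), which this sketch does not model.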
