About adding batch normalization #1

@zhiyuanyou

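Hello!
Thank you very much for open-sourcing your code! It has helped me a great deal in understanding DDPG.
When I used your code to train DDPG in a new environment, the output actions were always at the action bounds. After some research, I believe batch normalization could solve this problem.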
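The saturation symptom is consistent with a tanh-squashed actor head receiving large pre-activations. The NumPy snippet below is a hypothetical sketch, not code from this repository: it assumes a `bound * tanh(z)` output head and an illustrative `ACTION_BOUND`, and compares how often actions land within 1% of the bound with and without per-batch standardization of the pre-activation (batch norm with gamma=1, beta=0).

```python
import numpy as np

ACTION_BOUND = 2.0  # hypothetical action bound, not taken from the original code

def actor_output(pre_activation, bound=ACTION_BOUND):
    # Assumed DDPG actor head: tanh squashes to (-1, 1), then scale to the bound.
    return bound * np.tanh(pre_activation)

def normalize(z, epsilon=1e-5):
    # Per-feature standardization over the batch axis (BN with gamma=1, beta=0).
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + epsilon)

def frac_saturated(a, bound=ACTION_BOUND, tol=0.01):
    # Fraction of actions whose magnitude is within 1% of the bound.
    return np.mean(np.abs(a) > bound * (1 - tol))

rng = np.random.default_rng(0)
# Large-magnitude pre-activations, e.g. from inputs with very different scales.
z = rng.normal(loc=0.0, scale=50.0, size=(256, 1))

raw = actor_output(z)            # most samples are pushed to +/- ACTION_BOUND
bn = actor_output(normalize(z))  # standardized inputs mostly stay in tanh's linear range

print(frac_saturated(raw), frac_saturated(bn))
```

In this toy setting the effect comes purely from rescaling the pre-activation, so scaling the observations themselves would behave similarly here.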
To try this, I applied batch normalization between layers using the tf.contrib.layers.batch_norm function. After the change, however, I get a long list of errors:

Traceback (most recent call last):
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[{{node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm}}]]
	 [[critic_net/ddpg/critic_net/q_output/Relu/_67]]
  (1) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[{{node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm}}]]
0 successful operations.
0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "run_ddpg.py", line 168, in <module>
    main()
  File "run_ddpg.py", line 78, in main
    action_without_clip, q = agent.select_action(state, p)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/agent/ddpg.py", line 21, in select_action
    pred_action, pred_q = self.predict_action(observation)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/agent/ddpg.py", line 14, in predict_action
    return self.model.predict_action_q(observation)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_model.py", line 79, in predict_action_q
    q = self.critic.predict_q_source_net(observation, action, sess)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 210, in predict_q_source_net
    self.input_action: feed_action})
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm (defined at /home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py:259) ]]
	 [[critic_net/ddpg/critic_net/q_output/Relu/_67]]
  (1) Internal: cuDNN launch failure : input shape ([1,100,1,1])
	 [[node critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm (defined at /home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py:259) ]]
0 successful operations.
0 derived errors ignored.

Original stack trace for 'critic_net/ddpg/critic_net/cond_1/batch_norm_2/FusedBatchNorm':
  File "run_ddpg.py", line 168, in <module>
    main()
  File "run_ddpg.py", line 53, in main
    tau=TAU)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_model.py", line 52, in __init__
    sess=self.sess)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 53, in __init__
    self.q_output = self.__create_critic_network()
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 99, in __create_critic_network
    activation=tf.nn.relu)
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 260, in batch_norm_layer
    lambda: batch_norm(x, activation_fn=activation, center=True, scale=True,
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1977, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1814, in BuildCondBranch
    original_result = fn()
  File "/home/youzhiyuan/Desktop/TensorAgent_single/model/ddpg_critic.py", line 259, in <lambda>
    scope=scope_bn, decay=0.9, epsilon=1e-5),
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 182, in func_with_args
    return func(*args, **current_args)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 596, in batch_norm
    scope=scope)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 383, in _fused_batch_norm
    _fused_batch_norm_inference)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/utils.py", line 214, in smart_cond
    return static_cond(pred_value, fn1, fn2)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/utils.py", line 192, in static_cond
    return fn1()
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/layers.py", line 368, in _fused_batch_norm_training
    inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
    name=name)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 3946, in _fused_batch_norm
    name=name)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/youzhiyuan/anaconda3/envs/pool/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()
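In the trace, the failing FusedBatchNorm node is built inside the training path (_fused_batch_norm_training, L368 of layers.py) of the tf.cond in batch_norm_layer, yet the call comes from select_action, and the reported input shape [1,100,1,1] suggests a batch of a single observation with 100 features. The NumPy sketch below is a hypothetical illustration (the helper batch_norm_forward is not from the repository; decay=0.9 and epsilon=1e-5 mirror the values visible in the trace). It shows the arithmetic of training mode (normalize with the current batch's statistics) versus inference mode (normalize with moving averages), and why a batch of one is degenerate in training mode: each feature's batch variance is zero and x - mean is zero, so the output collapses to beta. Whether this is also what triggers the cuDNN launch failure is an assumption; in standard batch-norm usage, single-sample action selection runs in inference mode and only minibatch updates use training mode.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, moving_mean, moving_var,
                       training, decay=0.9, epsilon=1e-5):
    """Hypothetical batch norm over axis 0 of a 2-D input [batch, features].

    Returns the normalized output and the (possibly updated) moving statistics.
    The moving-average update is simplified; details such as Bessel's
    correction differ between real implementations.
    """
    if training:
        # Training mode: normalize with the statistics of the current batch.
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        moving_mean = decay * moving_mean + (1 - decay) * mean
        moving_var = decay * moving_var + (1 - decay) * var
    else:
        # Inference mode: normalize with the accumulated moving statistics.
        mean, var = moving_mean, moving_var
    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    return y, moving_mean, moving_var

features = 100
gamma, beta = np.ones(features), np.zeros(features)
mm, mv = np.zeros(features), np.ones(features)
rng = np.random.default_rng(0)

# A batch of ONE observation, as when selecting a single action.
single_obs = rng.normal(size=(1, features))
y_train, _, _ = batch_norm_forward(single_obs, gamma, beta, mm, mv, training=True)
y_infer, _, _ = batch_norm_forward(single_obs, gamma, beta, mm, mv, training=False)
# y_train equals beta everywhere: the observation carries no information after BN.

# A real minibatch in training mode: per-feature mean ~0 and std ~1 after BN.
batch = rng.normal(loc=3.0, scale=5.0, size=(64, features))
y_batch, mm, mv = batch_norm_forward(batch, gamma, beta, mm, mv, training=True)
```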

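My system is Ubuntu 16.04, and I am using TensorFlow 1.14.0.
I have been stuck on this problem for several days and still have not been able to solve it. I would really appreciate your help when you have time.
Thank you!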