Skip to content
This repository was archived by the owner on Mar 11, 2021. It is now read-only.

Commit 0e59c3c

Browse files
committed
Change avg_stones to histogram
1 parent f194b0b commit 0e59c3c

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

dual_net.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def model_fn(features, labels, mode, params):
285285

286286
# Computations to be executed on CPU, outside of the main TPU queues.
287287
def eval_metrics_host_call_fn(
288-
features,
288+
avg_stones, avg_stones_delta,
289289
policy_output, value_output,
290290
pi_tensor, value_tensor,
291291
policy_cost, value_cost,
@@ -309,8 +309,6 @@ def eval_metrics_host_call_fn(
309309

310310
value_cost_normalized = value_cost / params['value_cost_weight']
311311
avg_value_observed = tf.reduce_mean(value_tensor)
312-
avg_stones_black = tf.reduce_mean(tf.reduce_sum(features[:,:,:,1], [1,2]))
313-
avg_stones_white = tf.reduce_mean(tf.reduce_sum(features[:,:,:,0], [1,2]))
314312

315313
with tf.variable_scope('metrics'):
316314
metric_ops = {
@@ -329,8 +327,7 @@ def eval_metrics_host_call_fn(
329327
'policy_target_top_1_confidence': tf.metrics.mean(
330328
policy_target_top_1_confidence),
331329
'avg_value_observed': tf.metrics.mean(avg_value_observed),
332-
'avg_stones_black': tf.metrics.mean(avg_stones_black),
333-
'avg_stones_white': tf.metrics.mean(avg_stones_white),
330+
'avg_stones_black': tf.metrics.mean(tf.reduce_mean(avg_stones)),
334331
}
335332

336333
if est_mode == tf.estimator.ModeKeys.EVAL:
@@ -348,6 +345,8 @@ def eval_metrics_host_call_fn(
348345
for metric_name, metric_op in metric_ops.items():
349346
summary.scalar(metric_name, metric_op[1], step=eval_step)
350347

348+
tf.summary.histogram("avg_stones_white", avg_stones_delta)
349+
351350
# Reset metrics occasionally so that they are mean of recent batches.
352351
reset_op = tf.variables_initializer(tf.local_variables('metrics'))
353352
cond_reset_op = tf.cond(
@@ -357,8 +356,14 @@ def eval_metrics_host_call_fn(
357356

358357
return summary.all_summary_ops() + [cond_reset_op]
359358

359+
# compute here to avoid sending all of features to cpu.
360+
avg_stones_black = tf.reduce_sum(features[:,:,:,1], [1,2])
361+
avg_stones_white = tf.reduce_sum(features[:,:,:,0], [1,2])
362+
avg_stones = avg_stones_black + avg_stones_white
363+
avg_stones_delta = avg_stones_black - avg_stones_white
364+
360365
metric_args = [
361-
features,
366+
avg_stones, avg_stones_delta,
362367
policy_output,
363368
value_output,
364369
labels['pi_tensor'],

0 commit comments

Comments
 (0)