Skip to content

Commit 8826749

Browse files
committed
Add support for the "expanded cost features" model.
This commit includes two separate changes: 1) the addition of extra features which contribute to the "inlining cost" estimate, and 2) extending the size and training duration of the model to account for the additional features.
1 parent b235539 commit 8826749

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+44542
-2523
lines changed

compiler_opt/rl/feature_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def _build_quantile_map(quantile_file_dir):
4242

4343
@gin.configurable
4444
def get_observation_processing_layer_creator(quantile_file_dir,
45+
with_sqrt=False,
4546
with_z_score_normalization=False,
4647
eps=1e-8):
4748
"""Wrapper for observation_processing_layer."""
@@ -67,7 +68,9 @@ def normalization(obs):
6768
x = tf.cast(
6869
tf.raw_ops.Bucketize(input=expanded_obs, boundaries=quantile),
6970
tf.float32) / len(quantile)
70-
features = [x, tf.sqrt(x), x * x]
71+
features = [x, x * x]
72+
if with_sqrt:
73+
features.append(tf.sqrt(x))
7174
if with_z_score_normalization:
7275
y = tf.cast(expanded_obs, tf.float32)
7376
y = (y - mean) / (std + eps)

compiler_opt/rl/feature_ops_test.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,18 @@
1818
import os
1919

2020
from absl.testing import parameterized
21+
import numpy as np
2122
import tensorflow as tf
2223

2324
from compiler_opt.rl import constant
2425
from compiler_opt.rl import feature_ops
2526

27+
_WITH_Z_SCORE_SQRT_PRODUCT_VALUES = [('with_sqrt_with_z_score', True, True),
28+
('without_sqrt_with_z_score', True, False),
29+
('with_sqrt_without_z_score', False, True),
30+
('without_sqrt_without_z_score', False,
31+
False)]
32+
2633

2734
class FeatureUtilsTest(tf.test.TestCase, parameterized.TestCase):
2835

@@ -50,12 +57,11 @@ def test_build_quantile_map_from_config(self):
5057
# std
5158
self.assertAllClose(14.988885, std)
5259

53-
@parameterized.named_parameters(('with_z_score', True),
54-
('without_z_score', False))
55-
def test_create_observation_processing_layer(self, with_z_score):
60+
@parameterized.named_parameters(*_WITH_Z_SCORE_SQRT_PRODUCT_VALUES)
61+
def test_create_observation_processing_layer(self, with_z_score, with_sqrt):
5662
observation_processing_layer = (
5763
feature_ops.get_observation_processing_layer_creator(
58-
self._quantile_file_dir, with_z_score))
64+
self._quantile_file_dir, with_sqrt, with_z_score))
5965

6066
obs_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='edge_count')
6167
processing_layer = observation_processing_layer(obs_spec)
@@ -65,24 +71,28 @@ def test_create_observation_processing_layer(self, with_z_score):
6571

6672
outputs = self.evaluate(outputs)
6773

74+
expected_shape = [2, 1, 2]
75+
expected = np.array([[[0.333333, 0.111111]], [[0.777778, 0.604938]]])
76+
77+
if with_sqrt:
78+
expected_shape[2] += 1
79+
expected = np.concatenate([expected, [[[0.57735]], [[0.881917]]]],
80+
axis=-1)
81+
6882
if with_z_score:
69-
self.assertAllEqual([2, 1, 4], outputs.shape)
70-
self.assertAllClose([[[0.333333, 0.57735, 0.111111, -0.555968]],
71-
[[0.777778, 0.881917, 0.604938, -0.155671]]],
72-
outputs)
73-
else:
74-
self.assertAllEqual([2, 1, 3], outputs.shape)
75-
self.assertAllClose(
76-
[[[0.333333, 0.57735, 0.111111]], [[0.777778, 0.881917, 0.604938]]],
77-
outputs)
78-
79-
@parameterized.named_parameters(('with_z_score', True),
80-
('without_z_score', False))
83+
expected_shape[2] += 1
84+
expected = np.concatenate([expected, [[[-0.555968]], [[-0.155671]]]],
85+
axis=-1)
86+
87+
self.assertAllEqual(expected_shape, outputs.shape)
88+
self.assertAllClose(expected.tolist(), outputs)
89+
90+
@parameterized.named_parameters(*_WITH_Z_SCORE_SQRT_PRODUCT_VALUES)
8191
def test_create_observation_processing_layer_for_dummy_features(
82-
self, with_z_score):
92+
self, with_z_score, with_sqrt):
8393
observation_processing_layer = (
8494
feature_ops.get_observation_processing_layer_creator(
85-
self._quantile_file_dir, with_z_score))
95+
self._quantile_file_dir, with_sqrt, with_z_score))
8696

8797
obs_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='dummy_feature')
8898
processing_layer = observation_processing_layer(obs_spec)

compiler_opt/rl/inlining/config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def get_inlining_signature_spec():
2727
"""Returns (time_step_spec, action_spec) for LLVM inlining."""
2828
observation_spec = dict(
2929
(key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key)) for key in (
30+
# Base features
3031
'caller_basic_block_count',
3132
'caller_conditionally_executed_blocks',
3233
'caller_users',
@@ -38,6 +39,33 @@ def get_inlining_signature_spec():
3839
'edge_count',
3940
'callsite_height',
4041
'cost_estimate',
42+
43+
# Expanded cost features
44+
'sroa_savings',
45+
'sroa_losses',
46+
'load_elimination',
47+
'call_penalty',
48+
'call_argument_setup',
49+
'load_relative_intrinsic',
50+
'lowered_call_arg_setup',
51+
'indirect_call_penalty',
52+
'jump_table_penalty',
53+
'case_cluster_penalty',
54+
'switch_penalty',
55+
'unsimplified_common_instructions',
56+
'num_loops',
57+
'dead_blocks',
58+
'simplified_instructions',
59+
'constant_args',
60+
'constant_offset_ptr_args',
61+
'callsite_cost',
62+
'cold_cc_penalty',
63+
'last_call_to_static_bonus',
64+
'is_multiple_blocks',
65+
'nested_inlines',
66+
'nested_inline_cost_estimate',
67+
'threshold',
68+
4169
# inlining_default is not used as feature in training.
4270
'inlining_default'))
4371
reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')

compiler_opt/rl/inlining/gin_configs/behavioral_cloning_nn_agent.gin

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ train_eval.batch_size=64
1111
train_eval.train_sequence_length=1
1212

1313
get_observation_processing_layer_creator.quantile_file_dir='compiler_opt/rl/inlining/vocab'
14+
get_observation_processing_layer_creator.with_sqrt = False
1415
get_observation_processing_layer_creator.with_z_score_normalization = False
1516

1617
create_agent.policy_network = @q_network.QNetwork
1718

1819
19-
QNetwork.fc_layer_params=(40, 20)
20-
QNetwork.dropout_layer_params=(0.2, 0.2)
20+
QNetwork.fc_layer_params=(40, 40, 20)
21+
QNetwork.dropout_layer_params=(0.2, 0.2, 0.2)
2122
2223

2324
tf.train.AdamOptimizer.learning_rate = 0.001

compiler_opt/rl/inlining/gin_configs/ppo_nn_agent.gin

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,20 @@ train_eval.get_signature_spec_fn=@config.get_inlining_signature_spec
99
train_eval.agent_name='ppo'
1010
train_eval.warmstart_policy_dir=''
1111
train_eval.num_policy_iterations=3000
12-
train_eval.num_iterations=200
12+
train_eval.num_iterations=300
1313
train_eval.batch_size=128
1414
train_eval.train_sequence_length=16
1515
train_eval.deploy_policy_name='saved_collect_policy'
1616
train_eval.use_stale_results=False
1717

1818
get_observation_processing_layer_creator.quantile_file_dir='compiler_opt/rl/inlining/vocab'
19+
get_observation_processing_layer_creator.with_sqrt = False
1920
get_observation_processing_layer_creator.with_z_score_normalization = False
2021

2122
create_agent.policy_network = @actor_distribution_network.ActorDistributionNetwork
2223

2324
ActorDistributionNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
24-
ActorDistributionNetwork.fc_layer_params=(40, 20)
25+
ActorDistributionNetwork.fc_layer_params=(40, 40, 20)
2526
ActorDistributionNetwork.dropout_layer_params=None
2627
ActorDistributionNetwork.activation_fn=@tf.keras.activations.relu
2728

0 commit comments

Comments
 (0)