@@ -298,7 +298,7 @@ def _compute_linear_output_bound_by_lp(
         ) if linear_layer.bias is None else linear_layer.bias[j]
         Ain_linear_input, Ain_neuron_output, Ain_neuron_binary,\
             rhs_in, Aeq_linear_input, Aeq_neuron_output,\
-            Aeq_neuron_binary, rhs_eq, _, _ = \
+            Aeq_neuron_binary, rhs_eq, _, _, _, _ = \
             _add_constraint_by_neuron(
                 linear_layer.weight[j], bij, relu_layer,
                 torch.tensor(previous_neuron_input_lo[
@@ -519,8 +519,8 @@ def output_constraint(self, x_lo, x_up,
             z_pre_relu_lo, z_pre_relu_up, x_lo, x_up, method)
         for layer_count in range(len(self.relu_unit_index)):
             Ain_z_curr, Ain_z_next, Ain_binary_layer, rhs_in_layer,\
-                Aeq_z_curr, Aeq_z_next, Aeq_binary_layer, rhs_eq_layer, _, _ = \
-                _add_constraint_by_layer(
+                Aeq_z_curr, Aeq_z_next, Aeq_binary_layer, rhs_eq_layer, _, _, \
+                _, _ = _add_constraint_by_layer(
                     self.model[2 * layer_count], self.model[2 * layer_count + 1],
                     z_pre_relu_lo[self.relu_unit_index[layer_count]],
                     z_pre_relu_up[self.relu_unit_index[layer_count]])
@@ -588,7 +588,14 @@ def output_constraint(self, x_lo, x_up,
                 (self.model[-1].out_features, ), dtype=self.dtype)
         else:
             mip_constr_return.Cout = self.model[-1].bias.clone()
-
+        binary_lo = torch.zeros((self.num_relu_units,), dtype=self.dtype)
+        binary_up = torch.ones((self.num_relu_units,), dtype=self.dtype)
+        # If the input to the relu is always >= 0, then the relu will always
+        # be active.
+        binary_lo[z_pre_relu_lo >= 0] = 1.
+        # If the input to the relu is always <= 0, then the relu will always
+        # be inactive.
+        binary_up[z_pre_relu_up <= 0] = 0.
         mip_constr_return.Ain_input = Ain_input[:ineq_constr_count]
         mip_constr_return.Ain_slack = Ain_slack[:ineq_constr_count]
         mip_constr_return.Ain_binary = Ain_binary[:ineq_constr_count]
@@ -597,6 +604,8 @@ def output_constraint(self, x_lo, x_up,
         mip_constr_return.Aeq_slack = Aeq_slack[:eq_constr_count]
         mip_constr_return.Aeq_binary = Aeq_binary[:eq_constr_count]
         mip_constr_return.rhs_eq = rhs_eq[:eq_constr_count]
+        mip_constr_return.binary_lo = binary_lo
+        mip_constr_return.binary_up = binary_up
         return (mip_constr_return, z_pre_relu_lo, z_pre_relu_up,
                 z_post_relu_lo, z_post_relu_up, output_lo, output_up)

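For reference, the tightening added in `output_constraint` above can be exercised on its own. The snippet below is a minimal standalone sketch of the same rule with made-up pre-activation bounds (the tensor values are illustrative, not produced by `output_constraint`): a ReLU unit whose input lower bound is non-negative is always active, so its binary variable can be bounded below by 1, and a unit whose input upper bound is non-positive is always inactive, so its binary variable can be bounded above by 0.

```python
import torch

# Hypothetical pre-activation bounds for four ReLU units.
z_pre_relu_lo = torch.tensor([-1.0, 0.5, -2.0, 0.0])
z_pre_relu_up = torch.tensor([2.0, 3.0, -0.1, 4.0])

# Start from the trivial bounds 0 <= beta <= 1 on every binary variable.
binary_lo = torch.zeros_like(z_pre_relu_lo)
binary_up = torch.ones_like(z_pre_relu_up)
# Always-active units (input >= 0) get their binary pushed up to 1;
# always-inactive units (input <= 0) get their binary capped at 0.
binary_lo[z_pre_relu_lo >= 0] = 1.
binary_up[z_pre_relu_up <= 0] = 0.

print(binary_lo)  # tensor([0., 1., 0., 1.])
print(binary_up)  # tensor([1., 1., 0., 1.])
```

A downstream MIP builder that reads `mip_constr_return.binary_lo` and `mip_constr_return.binary_up` can then fix the binaries of the second and fourth units to 1 and of the third unit to 0, leaving only the first unit as a genuine branching decision.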
@@ -994,6 +1003,8 @@ def _add_constraint_by_neuron(
         Aeq_neuron_output = torch.empty((0, 1), dtype=dtype)
         Aeq_binary = torch.empty((0, 1), dtype=dtype)
         rhs_eq = torch.empty((0, ), dtype=dtype)
+        binary_lo = 0
+        binary_up = 1
     else:
         # The (leaky) ReLU is always active, or always inactive. If
         # the lower bound output_lo[j] >= 0, then it is always active,
@@ -1004,18 +1015,17 @@ def _add_constraint_by_neuron(
         # zᵢ₊₁(j) = c*((Wᵢzᵢ)(j) + bᵢ(j)) and βᵢ(j) = 0
         if neuron_input_lo >= 0:
             slope = 1.
-            binary_value = 1
+            binary_lo = 1
+            binary_up = 1
         elif neuron_input_up <= 0:
             slope = relu_layer.negative_slope if isinstance(
                 relu_layer, nn.LeakyReLU) else 0.
-            binary_value = 0.
-        Aeq_linear_input = torch.cat((-slope * Wij.reshape(
-            (1, -1)), torch.zeros((1, Wij.numel()), dtype=dtype)),
-                                     dim=0)
-        Aeq_neuron_output = torch.tensor([[1.], [0]], dtype=dtype)
-        Aeq_binary = torch.tensor([[0.], [1.]], dtype=dtype)
-        rhs_eq = torch.stack(
-            (slope * bij, torch.tensor(binary_value, dtype=dtype)))
+            binary_lo = 0
+            binary_up = 0
+        Aeq_linear_input = -slope * Wij.reshape((1, -1))
+        Aeq_neuron_output = torch.tensor([[1.]], dtype=dtype)
+        Aeq_binary = torch.tensor([[0.]], dtype=dtype)
+        rhs_eq = slope * bij.reshape((1,))
        Ain_linear_input = torch.empty((0, Wij.numel()), dtype=dtype)
        Ain_neuron_output = torch.empty((0, 1), dtype=dtype)
        Ain_binary = torch.empty((0, 1), dtype=dtype)
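Taken together, the per-neuron changes above do two things when the unit is known to be always active or always inactive: the binary variable is now reported through `binary_lo`/`binary_up` instead of being pinned by a second equality row, and the remaining single equality row only ties the neuron output to its linear input. The sketch below is an illustrative restatement, not the library function itself; it uses made-up weights and inputs, and it assumes the returned blocks are assembled as `Aeq_linear_input @ zᵢ + Aeq_neuron_output * zᵢ₊₁(j) + Aeq_binary * βᵢ(j) == rhs_eq`, consistent with the comment `zᵢ₊₁(j) = c*((Wᵢzᵢ)(j) + bᵢ(j))` in the surrounding code.

```python
import torch

# Always-active case (neuron_input_lo >= 0): slope c = 1 and the binary is
# fixed through its bounds rather than through an equality constraint.
Wij = torch.tensor([1.0, -2.0])   # illustrative weight row
bij = torch.tensor(3.0)           # illustrative bias
slope = 1.
binary_lo, binary_up = 1, 1

# The single equality row returned in this case encodes
#   z_out = slope * (Wij . z_in + bij).
Aeq_linear_input = -slope * Wij.reshape((1, -1))
Aeq_neuron_output = torch.tensor([[1.]])
Aeq_binary = torch.tensor([[0.]])
rhs_eq = slope * bij.reshape((1,))

# Check the row on an arbitrary input: the true (always-active) ReLU output
# satisfies it.
z_in = torch.tensor([0.5, 1.0])
z_out = slope * (Wij @ z_in + bij)
beta = torch.tensor(float(binary_lo))
lhs = Aeq_linear_input @ z_in + Aeq_neuron_output[0] * z_out \
    + Aeq_binary[0] * beta
assert torch.allclose(lhs, rhs_eq)
```

The removed two-row version (one row for the output, one row pinning βᵢ(j) to `binary_value`) carried the same information; moving the binary fix into bounds keeps the equality system one row smaller per fixed neuron.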
@@ -1024,7 +1034,7 @@ def _add_constraint_by_neuron(
         relu_layer, neuron_input_lo, neuron_input_up)
     return Ain_linear_input, Ain_neuron_output, Ain_binary, rhs_in,\
         Aeq_linear_input, Aeq_neuron_output, Aeq_binary, rhs_eq,\
-        neuron_output_lo, neuron_output_up
+        neuron_output_lo, neuron_output_up, binary_lo, binary_up


 def _add_constraint_by_layer(linear_layer, relu_layer,
@@ -1057,10 +1067,13 @@ def _add_constraint_by_layer(linear_layer, relu_layer,
     z_next_up = []
     bias = linear_layer.bias if linear_layer.bias is not None else \
         torch.zeros((linear_layer.out_features,), dtype=dtype)
+    binary_lo = torch.zeros((linear_layer.out_features,), dtype=dtype)
+    binary_up = torch.ones((linear_layer.out_features,), dtype=dtype)
     for j in range(linear_layer.out_features):
         Ain_linear_input, Ain_neuron_output, Ain_binary_j, rhs_in_j,\
             Aeq_linear_input, Aeq_neuron_output, Aeq_binary_j, rhs_eq_j,\
-            neuron_output_lo, neuron_output_up = _add_constraint_by_neuron(
+            neuron_output_lo, neuron_output_up, binary_lo[j], binary_up[j] = \
+            _add_constraint_by_neuron(
                 linear_layer.weight[j], bias[j], relu_layer,
                 linear_output_lo[j], linear_output_up[j])
         Ain_z_curr.append(Ain_linear_input)
@@ -1089,4 +1102,4 @@ def _add_constraint_by_layer(linear_layer, relu_layer,
         torch.cat(Ain_binary, dim=0), torch.cat(rhs_in, dim=0),\
         torch.cat(Aeq_z_curr, dim=0), torch.cat(Aeq_z_next, dim=0),\
         torch.cat(Aeq_binary, dim=0), torch.cat(rhs_eq, dim=0),\
-        torch.stack(z_next_lo), torch.stack(z_next_up)
+        torch.stack(z_next_lo), torch.stack(z_next_up), binary_lo, binary_up