@@ -298,7 +298,7 @@ def _compute_linear_output_bound_by_lp(
             ) if linear_layer.bias is None else linear_layer.bias[j]
         Ain_linear_input, Ain_neuron_output, Ain_neuron_binary,\
             rhs_in, Aeq_linear_input, Aeq_neuron_output,\
-            Aeq_neuron_binary, rhs_eq, _, _ = \
+            Aeq_neuron_binary, rhs_eq, _, _, _, _ = \
             _add_constraint_by_neuron(
                 linear_layer.weight[j], bij, relu_layer,
                 torch.tensor(previous_neuron_input_lo[
@@ -588,7 +588,14 @@ def output_constraint(self, x_lo, x_up,
                 (self.model[-1].out_features, ), dtype=self.dtype)
         else:
             mip_constr_return.Cout = self.model[-1].bias.clone()
-
+        binary_lo = torch.zeros((self.num_relu_units,), dtype=self.dtype)
+        binary_up = torch.ones((self.num_relu_units,), dtype=self.dtype)
+        # If the input to the relu is always >= 0, then the relu will always
+        # be active.
+        binary_lo[z_pre_relu_lo >= 0] = 1.
+        # If the input to the relu is always <= 0, then the relu will always
+        # be inactive.
+        binary_up[z_pre_relu_up <= 0] = 0.
         mip_constr_return.Ain_input = Ain_input[:ineq_constr_count]
         mip_constr_return.Ain_slack = Ain_slack[:ineq_constr_count]
         mip_constr_return.Ain_binary = Ain_binary[:ineq_constr_count]
@@ -597,6 +604,10 @@ def output_constraint(self, x_lo, x_up,
         mip_constr_return.Aeq_slack = Aeq_slack[:eq_constr_count]
         mip_constr_return.Aeq_binary = Aeq_binary[:eq_constr_count]
         mip_constr_return.rhs_eq = rhs_eq[:eq_constr_count]
+        mip_constr_return.input_lo = x_lo
+        mip_constr_return.input_up = x_up
+        mip_constr_return.binary_lo = binary_lo
+        mip_constr_return.binary_up = binary_up
         return (mip_constr_return, z_pre_relu_lo, z_pre_relu_up,
                 z_post_relu_lo, z_post_relu_up, output_lo, output_up)
@@ -994,6 +1005,8 @@ def _add_constraint_by_neuron(
         Aeq_neuron_output = torch.empty((0, 1), dtype=dtype)
         Aeq_binary = torch.empty((0, 1), dtype=dtype)
         rhs_eq = torch.empty((0, ), dtype=dtype)
+        binary_lo = 0
+        binary_up = 1
     else:
         # The (leaky) ReLU is always active, or always inactive. If
         # the lower bound output_lo[j] >= 0, then it is always active,
@@ -1004,18 +1017,17 @@ def _add_constraint_by_neuron(
         # zᵢ₊₁(j) = c*((Wᵢzᵢ)(j) + bᵢ(j)) and βᵢ(j) = 0
         if neuron_input_lo >= 0:
             slope = 1.
-            binary_value = 1
+            binary_lo = 1
+            binary_up = 1
         elif neuron_input_up <= 0:
             slope = relu_layer.negative_slope if isinstance(
                 relu_layer, nn.LeakyReLU) else 0.
-            binary_value = 0.
-        Aeq_linear_input = torch.cat((-slope * Wij.reshape(
-            (1, -1)), torch.zeros((1, Wij.numel()), dtype=dtype)),
-                                     dim=0)
-        Aeq_neuron_output = torch.tensor([[1.], [0]], dtype=dtype)
-        Aeq_binary = torch.tensor([[0.], [1.]], dtype=dtype)
-        rhs_eq = torch.stack(
-            (slope * bij, torch.tensor(binary_value, dtype=dtype)))
+            binary_lo = 0
+            binary_up = 0
+        Aeq_linear_input = -slope * Wij.reshape((1, -1))
+        Aeq_neuron_output = torch.tensor([[1.]], dtype=dtype)
+        Aeq_binary = torch.tensor([[0.]], dtype=dtype)
+        rhs_eq = slope * bij.reshape((1,))
         Ain_linear_input = torch.empty((0, Wij.numel()), dtype=dtype)
         Ain_neuron_output = torch.empty((0, 1), dtype=dtype)
         Ain_binary = torch.empty((0, 1), dtype=dtype)
@@ -1024,7 +1036,7 @@ def _add_constraint_by_neuron(
         relu_layer, neuron_input_lo, neuron_input_up)
     return Ain_linear_input, Ain_neuron_output, Ain_binary, rhs_in,\
         Aeq_linear_input, Aeq_neuron_output, Aeq_binary, rhs_eq,\
-        neuron_output_lo, neuron_output_up
+        neuron_output_lo, neuron_output_up, binary_lo, binary_up
 
 
 def _add_constraint_by_layer(linear_layer, relu_layer,
@@ -1057,10 +1069,13 @@ def _add_constraint_by_layer(linear_layer, relu_layer,
     z_next_up = []
     bias = linear_layer.bias if linear_layer.bias is not None else \
         torch.zeros((linear_layer.out_features,), dtype=dtype)
+    binary_lo = torch.zeros((linear_layer.out_features,), dtype=dtype)
+    binary_up = torch.ones((linear_layer.out_features,), dtype=dtype)
     for j in range(linear_layer.out_features):
         Ain_linear_input, Ain_neuron_output, Ain_binary_j, rhs_in_j,\
             Aeq_linear_input, Aeq_neuron_output, Aeq_binary_j, rhs_eq_j,\
-            neuron_output_lo, neuron_output_up = _add_constraint_by_neuron(
+            neuron_output_lo, neuron_output_up, binary_lo[j], binary_up[j] = \
+            _add_constraint_by_neuron(
                 linear_layer.weight[j], bias[j], relu_layer,
                 linear_output_lo[j], linear_output_up[j])
         Ain_z_curr.append(Ain_linear_input)
@@ -1089,4 +1104,4 @@ def _add_constraint_by_layer(linear_layer, relu_layer,
         torch.cat(Ain_binary, dim=0), torch.cat(rhs_in, dim=0),\
         torch.cat(Aeq_z_curr, dim=0), torch.cat(Aeq_z_next, dim=0),\
         torch.cat(Aeq_binary, dim=0), torch.cat(rhs_eq, dim=0),\
-        torch.stack(z_next_lo), torch.stack(z_next_up)
+        torch.stack(z_next_lo), torch.stack(z_next_up), binary_lo, binary_up
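
Note: the binary_lo/binary_up fields added above let the caller fix a ReLU's activation binary ahead of time whenever the pre-activation bounds already decide the neuron's phase. Below is a minimal standalone sketch of that rule, assuming tensors of pre-activation lower/upper bounds; the helper name fix_binary_bounds is illustrative and not part of this commit.

import torch


def fix_binary_bounds(z_pre_relu_lo, z_pre_relu_up):
    # Hypothetical helper (not in the diffed module): derive bounds on the
    # ReLU activation binaries from bounds on the pre-activation values,
    # mirroring the logic added to output_constraint in this commit.
    binary_lo = torch.zeros_like(z_pre_relu_lo)
    binary_up = torch.ones_like(z_pre_relu_up)
    # A neuron whose input is never negative is always active: binary fixed to 1.
    binary_lo[z_pre_relu_lo >= 0] = 1.
    # A neuron whose input is never positive is always inactive: binary fixed to 0.
    binary_up[z_pre_relu_up <= 0] = 0.
    return binary_lo, binary_up


# Example with three neurons: always active, undecided, always inactive.
lo = torch.tensor([0.5, -1.0, -2.0])
up = torch.tensor([2.0, 1.0, -0.1])
print(fix_binary_bounds(lo, up))
# expected: (tensor([1., 0., 0.]), tensor([1., 1., 0.]))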