1
+ /*
2
+ * Copyright OpenSearch Contributors
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ */
5
+
6
+ package org .opensearch .ml .common .model ;
7
+
8
+ import org .apache .lucene .search .TotalHits ;
9
+ import org .junit .Assert ;
10
+ import org .junit .Before ;
11
+ import org .junit .Test ;
12
+ import org .mockito .Mock ;
13
+ import org .mockito .MockitoAnnotations ;
14
+ import org .opensearch .action .search .SearchResponse ;
15
+ import org .opensearch .action .search .ShardSearchFailure ;
16
+ import org .opensearch .client .Client ;
17
+ import org .opensearch .common .action .ActionFuture ;
18
+ import org .opensearch .common .settings .Settings ;
19
+ import org .opensearch .common .unit .TimeValue ;
20
+ import org .opensearch .common .util .concurrent .ThreadContext ;
21
+ import org .opensearch .common .xcontent .XContentFactory ;
22
+ import org .opensearch .core .common .bytes .BytesReference ;
23
+ import org .opensearch .core .xcontent .NamedXContentRegistry ;
24
+ import org .opensearch .core .xcontent .ToXContent ;
25
+ import org .opensearch .core .xcontent .XContentBuilder ;
26
+ import org .opensearch .search .SearchHit ;
27
+ import org .opensearch .search .SearchHits ;
28
+ import org .opensearch .search .SearchModule ;
29
+ import org .opensearch .search .aggregations .InternalAggregations ;
30
+ import org .opensearch .search .internal .InternalSearchResponse ;
31
+ import org .opensearch .search .profile .SearchProfileShardResults ;
32
+ import org .opensearch .search .suggest .Suggest ;
33
+ import org .opensearch .threadpool .ThreadPool ;
34
+
35
+ import java .io .IOException ;
36
+ import java .util .Collections ;
37
+ import java .util .List ;
38
+ import java .util .Map ;
39
+ import java .util .concurrent .ExecutionException ;
40
+ import java .util .concurrent .TimeUnit ;
41
+ import java .util .concurrent .TimeoutException ;
42
+ import java .util .regex .Pattern ;
43
+
44
+ import static org .mockito .ArgumentMatchers .any ;
45
+ import static org .mockito .Mockito .when ;
46
+
47
+ public class MLGuardTests {
48
+
49
+ NamedXContentRegistry xContentRegistry ;
50
+ @ Mock
51
+ Client client ;
52
+ @ Mock
53
+ ThreadPool threadPool ;
54
+ ThreadContext threadContext ;
55
+
56
+ StopWords stopWords ;
57
+ String [] regex ;
58
+ List <Pattern > regexPatterns ;
59
+ Guardrail inputGuardrail ;
60
+ Guardrail outputGuardrail ;
61
+ Guardrails guardrails ;
62
+ MLGuard mlGuard ;
63
+
64
+ @ Before
65
+ public void setUp () {
66
+ MockitoAnnotations .openMocks (this );
67
+ xContentRegistry = new NamedXContentRegistry (new SearchModule (Settings .EMPTY , Collections .emptyList ()).getNamedXContents ());
68
+ Settings settings = Settings .builder ().build ();
69
+ this .threadContext = new ThreadContext (settings );
70
+ when (this .client .threadPool ()).thenReturn (this .threadPool );
71
+ when (this .threadPool .getThreadContext ()).thenReturn (this .threadContext );
72
+
73
+ stopWords = new StopWords ("test_index" , List .of ("test_field" ).toArray (new String [0 ]));
74
+ regex = List .of ("(.|\n )*stop words(.|\n )*" ).toArray (new String [0 ]);
75
+ regexPatterns = List .of (Pattern .compile ("(.|\n )*stop words(.|\n )*" ));
76
+ inputGuardrail = new Guardrail (List .of (stopWords ), regex );
77
+ outputGuardrail = new Guardrail (List .of (stopWords ), regex );
78
+ guardrails = new Guardrails ("test_type" , false , inputGuardrail , outputGuardrail );
79
+ mlGuard = new MLGuard (guardrails , xContentRegistry , client );
80
+ }
81
+
82
+ @ Test
83
+ public void validateInput () {
84
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
85
+ Boolean res = mlGuard .validate (input , 0 );
86
+
87
+ Assert .assertFalse (res );
88
+ }
89
+
90
+ @ Test
91
+ public void validateOutput () {
92
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
93
+ Boolean res = mlGuard .validate (input , 1 );
94
+
95
+ Assert .assertFalse (res );
96
+ }
97
+
98
+ @ Test
99
+ public void validateRegexListSuccess () {
100
+ String input = "\n \n Human:hello good words.\n \n Assistant:" ;
101
+ Boolean res = mlGuard .validateRegexList (input , regexPatterns );
102
+
103
+ Assert .assertTrue (res );
104
+ }
105
+
106
+ @ Test
107
+ public void validateRegexListFailed () {
108
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
109
+ Boolean res = mlGuard .validateRegexList (input , regexPatterns );
110
+
111
+ Assert .assertFalse (res );
112
+ }
113
+
114
+ @ Test
115
+ public void validateRegexSuccess () {
116
+ String input = "\n \n Human:hello good words.\n \n Assistant:" ;
117
+ Boolean res = mlGuard .validateRegex (input , regexPatterns .get (0 ));
118
+
119
+ Assert .assertTrue (res );
120
+ }
121
+
122
+ @ Test
123
+ public void validateRegexFailed () {
124
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
125
+ Boolean res = mlGuard .validateRegex (input , regexPatterns .get (0 ));
126
+
127
+ Assert .assertFalse (res );
128
+ }
129
+
130
+ @ Test
131
+ public void validateStopWords () throws IOException {
132
+ Map <String , List <String >> stopWordsIndices = Map .of ("test_index" , List .of ("test_field" ));
133
+ SearchResponse searchResponse = createSearchResponse (1 );
134
+ ActionFuture <SearchResponse > future = createSearchResponseFuture (searchResponse );
135
+ when (this .client .search (any ())).thenReturn (future );
136
+
137
+ Boolean res = mlGuard .validateStopWords ("hello world" , stopWordsIndices );
138
+ Assert .assertTrue (res );
139
+ }
140
+
141
+ @ Test
142
+ public void validateStopWordsSingleIndex () throws IOException {
143
+ SearchResponse searchResponse = createSearchResponse (1 );
144
+ ActionFuture <SearchResponse > future = createSearchResponseFuture (searchResponse );
145
+ when (this .client .search (any ())).thenReturn (future );
146
+
147
+ Boolean res = mlGuard .validateStopWordsSingleIndex ("hello world" , "test_index" , List .of ("test_field" ));
148
+ Assert .assertTrue (res );
149
+ }
150
+
151
+ private SearchResponse createSearchResponse (int size ) throws IOException {
152
+ XContentBuilder content = guardrails .toXContent (XContentFactory .jsonBuilder (), ToXContent .EMPTY_PARAMS );
153
+ SearchHit [] hits = new SearchHit [size ];
154
+ if (size > 0 ) {
155
+ hits [0 ] = new SearchHit (0 ).sourceRef (BytesReference .bytes (content ));
156
+ }
157
+ return new SearchResponse (
158
+ new InternalSearchResponse (
159
+ new SearchHits (hits , new TotalHits (size , TotalHits .Relation .EQUAL_TO ), 1.0f ),
160
+ InternalAggregations .EMPTY ,
161
+ new Suggest (Collections .emptyList ()),
162
+ new SearchProfileShardResults (Collections .emptyMap ()),
163
+ false ,
164
+ false ,
165
+ 1
166
+ ),
167
+ "" ,
168
+ 5 ,
169
+ 5 ,
170
+ 0 ,
171
+ 100 ,
172
+ ShardSearchFailure .EMPTY_ARRAY ,
173
+ SearchResponse .Clusters .EMPTY
174
+ );
175
+ }
176
+
177
+ private ActionFuture <SearchResponse > createSearchResponseFuture (SearchResponse searchResponse ) {
178
+ return new ActionFuture <>() {
179
+ @ Override
180
+ public SearchResponse actionGet () {
181
+ return searchResponse ;
182
+ }
183
+
184
+ @ Override
185
+ public SearchResponse actionGet (String timeout ) {
186
+ return searchResponse ;
187
+ }
188
+
189
+ @ Override
190
+ public SearchResponse actionGet (long timeoutMillis ) {
191
+ return searchResponse ;
192
+ }
193
+
194
+ @ Override
195
+ public SearchResponse actionGet (long timeout , TimeUnit unit ) {
196
+ return searchResponse ;
197
+ }
198
+
199
+ @ Override
200
+ public SearchResponse actionGet (TimeValue timeout ) {
201
+ return searchResponse ;
202
+ }
203
+
204
+ @ Override
205
+ public boolean cancel (boolean mayInterruptIfRunning ) {
206
+ return false ;
207
+ }
208
+
209
+ @ Override
210
+ public boolean isCancelled () {
211
+ return false ;
212
+ }
213
+
214
+ @ Override
215
+ public boolean isDone () {
216
+ return false ;
217
+ }
218
+
219
+ @ Override
220
+ public SearchResponse get () throws InterruptedException , ExecutionException {
221
+ return searchResponse ;
222
+ }
223
+
224
+ @ Override
225
+ public SearchResponse get (long timeout , TimeUnit unit ) throws InterruptedException , ExecutionException , TimeoutException {
226
+ return searchResponse ;
227
+ }
228
+ };
229
+ }
230
+ }
0 commit comments