1
+ /*
2
+ * Copyright OpenSearch Contributors
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ */
5
+
6
+ package org .opensearch .ml .common .model ;
7
+
8
+ import org .apache .lucene .search .TotalHits ;
9
+ import org .junit .Assert ;
10
+ import org .junit .Before ;
11
+ import org .junit .Test ;
12
+ import org .mockito .Mock ;
13
+ import org .mockito .MockitoAnnotations ;
14
+ import org .opensearch .action .search .SearchResponse ;
15
+ import org .opensearch .action .search .ShardSearchFailure ;
16
+ import org .opensearch .client .Client ;
17
+ import org .opensearch .common .action .ActionFuture ;
18
+ import org .opensearch .common .settings .Settings ;
19
+ import org .opensearch .common .unit .TimeValue ;
20
+ import org .opensearch .common .util .concurrent .ThreadContext ;
21
+ import org .opensearch .common .xcontent .XContentFactory ;
22
+ import org .opensearch .core .common .bytes .BytesReference ;
23
+ import org .opensearch .core .xcontent .NamedXContentRegistry ;
24
+ import org .opensearch .core .xcontent .ToXContent ;
25
+ import org .opensearch .core .xcontent .XContentBuilder ;
26
+ import org .opensearch .ml .common .conversation .ConversationalIndexConstants ;
27
+ import org .opensearch .search .SearchHit ;
28
+ import org .opensearch .search .SearchHits ;
29
+ import org .opensearch .search .SearchModule ;
30
+ import org .opensearch .search .aggregations .InternalAggregations ;
31
+ import org .opensearch .search .internal .InternalSearchResponse ;
32
+ import org .opensearch .search .profile .SearchProfileShardResults ;
33
+ import org .opensearch .search .suggest .Suggest ;
34
+ import org .opensearch .threadpool .ThreadPool ;
35
+
36
+ import java .io .IOException ;
37
+ import java .util .Collections ;
38
+ import java .util .HashMap ;
39
+ import java .util .List ;
40
+ import java .util .Map ;
41
+ import java .util .concurrent .ExecutionException ;
42
+ import java .util .concurrent .TimeUnit ;
43
+ import java .util .concurrent .TimeoutException ;
44
+
45
+ import static org .junit .Assert .*;
46
+ import static org .mockito .ArgumentMatchers .any ;
47
+ import static org .mockito .Mockito .doNothing ;
48
+ import static org .mockito .Mockito .when ;
49
+
50
+ public class MLGuardTests {
51
+
52
+ NamedXContentRegistry xContentRegistry ;
53
+ @ Mock
54
+ Client client ;
55
+ @ Mock
56
+ ThreadPool threadPool ;
57
+ ThreadContext threadContext ;
58
+
59
+ StopWords stopWords ;
60
+ String [] regex ;
61
+ Guardrail inputGuardrail ;
62
+ Guardrail outputGuardrail ;
63
+ Guardrails guardrails ;
64
+ MLGuard mlGuard ;
65
+
66
+ @ Before
67
+ public void setUp () {
68
+ MockitoAnnotations .openMocks (this );
69
+ xContentRegistry = new NamedXContentRegistry (new SearchModule (Settings .EMPTY , Collections .emptyList ()).getNamedXContents ());
70
+ Settings settings = Settings .builder ().build ();
71
+ this .threadContext = new ThreadContext (settings );
72
+ when (this .client .threadPool ()).thenReturn (this .threadPool );
73
+ when (this .threadPool .getThreadContext ()).thenReturn (this .threadContext );
74
+
75
+ stopWords = new StopWords ("test_index" , List .of ("test_field" ).toArray (new String [0 ]));
76
+ regex = List .of ("regex1" ).toArray (new String [0 ]);
77
+ inputGuardrail = new Guardrail (List .of (stopWords ), regex );
78
+ outputGuardrail = new Guardrail (List .of (stopWords ), regex );
79
+ guardrails = new Guardrails ("test_type" , false , inputGuardrail , outputGuardrail );
80
+ mlGuard = new MLGuard (guardrails , xContentRegistry , client );
81
+ }
82
+
83
+ @ Test
84
+ public void validate () {
85
+ }
86
+
87
+ @ Test
88
+ public void validateRegexList () {
89
+ }
90
+
91
+ @ Test
92
+ public void validateRegex () {
93
+ }
94
+
95
+ @ Test
96
+ public void validateStopWords () {
97
+ }
98
+
99
+ @ Test
100
+ public void validateStopWordsSingleIndex () throws IOException {
101
+ SearchResponse searchResponse = createSearchResponse (1 );
102
+ ActionFuture <SearchResponse > future = createSearchResponseFuture (searchResponse );
103
+ when (this .client .search (any ())).thenReturn (future );
104
+
105
+ Boolean res = mlGuard .validateStopWordsSingleIndex ("hello world" , "test_index" , List .of ("test_field" ));
106
+ Assert .assertFalse (res );
107
+ }
108
+
109
+ private SearchResponse createSearchResponse (int size ) throws IOException {
110
+ XContentBuilder content = guardrails .toXContent (XContentFactory .jsonBuilder (), ToXContent .EMPTY_PARAMS );
111
+ SearchHit [] hits = new SearchHit [size ];
112
+ if (size > 0 ) {
113
+ hits [0 ] = new SearchHit (0 ).sourceRef (BytesReference .bytes (content ));
114
+ }
115
+ return new SearchResponse (
116
+ new InternalSearchResponse (
117
+ new SearchHits (hits , new TotalHits (size , TotalHits .Relation .EQUAL_TO ), 1.0f ),
118
+ InternalAggregations .EMPTY ,
119
+ new Suggest (Collections .emptyList ()),
120
+ new SearchProfileShardResults (Collections .emptyMap ()),
121
+ false ,
122
+ false ,
123
+ 1
124
+ ),
125
+ "" ,
126
+ 5 ,
127
+ 5 ,
128
+ 0 ,
129
+ 100 ,
130
+ ShardSearchFailure .EMPTY_ARRAY ,
131
+ SearchResponse .Clusters .EMPTY
132
+ );
133
+ }
134
+
135
+ private ActionFuture <SearchResponse > createSearchResponseFuture (SearchResponse searchResponse ) {
136
+ return new ActionFuture <>() {
137
+ @ Override
138
+ public SearchResponse actionGet () {
139
+ return searchResponse ;
140
+ }
141
+
142
+ @ Override
143
+ public SearchResponse actionGet (String timeout ) {
144
+ return searchResponse ;
145
+ }
146
+
147
+ @ Override
148
+ public SearchResponse actionGet (long timeoutMillis ) {
149
+ return searchResponse ;
150
+ }
151
+
152
+ @ Override
153
+ public SearchResponse actionGet (long timeout , TimeUnit unit ) {
154
+ return searchResponse ;
155
+ }
156
+
157
+ @ Override
158
+ public SearchResponse actionGet (TimeValue timeout ) {
159
+ return searchResponse ;
160
+ }
161
+
162
+ @ Override
163
+ public boolean cancel (boolean mayInterruptIfRunning ) {
164
+ return false ;
165
+ }
166
+
167
+ @ Override
168
+ public boolean isCancelled () {
169
+ return false ;
170
+ }
171
+
172
+ @ Override
173
+ public boolean isDone () {
174
+ return false ;
175
+ }
176
+
177
+ @ Override
178
+ public SearchResponse get () throws InterruptedException , ExecutionException {
179
+ return searchResponse ;
180
+ }
181
+
182
+ @ Override
183
+ public SearchResponse get (long timeout , TimeUnit unit ) throws InterruptedException , ExecutionException , TimeoutException {
184
+ return searchResponse ;
185
+ }
186
+ };
187
+ }
188
+ }
0 commit comments