23
23
import org .opensearch .core .xcontent .NamedXContentRegistry ;
24
24
import org .opensearch .core .xcontent .ToXContent ;
25
25
import org .opensearch .core .xcontent .XContentBuilder ;
26
- import org .opensearch .ml .common .conversation .ConversationalIndexConstants ;
27
26
import org .opensearch .search .SearchHit ;
28
27
import org .opensearch .search .SearchHits ;
29
28
import org .opensearch .search .SearchModule ;
35
34
36
35
import java .io .IOException ;
37
36
import java .util .Collections ;
38
- import java .util .HashMap ;
39
37
import java .util .List ;
40
38
import java .util .Map ;
41
39
import java .util .concurrent .ExecutionException ;
42
40
import java .util .concurrent .TimeUnit ;
43
41
import java .util .concurrent .TimeoutException ;
44
42
45
- import static org .junit .Assert .*;
46
43
import static org .mockito .ArgumentMatchers .any ;
47
- import static org .mockito .Mockito .doNothing ;
48
44
import static org .mockito .Mockito .when ;
49
45
50
46
public class MLGuardTests {
@@ -73,27 +69,72 @@ public void setUp() {
73
69
when (this .threadPool .getThreadContext ()).thenReturn (this .threadContext );
74
70
75
71
stopWords = new StopWords ("test_index" , List .of ("test_field" ).toArray (new String [0 ]));
76
- regex = List .of ("regex1 " ).toArray (new String [0 ]);
72
+ regex = List .of ("(.| \n )*stop words(.| \n )* " ).toArray (new String [0 ]);
77
73
inputGuardrail = new Guardrail (List .of (stopWords ), regex );
78
74
outputGuardrail = new Guardrail (List .of (stopWords ), regex );
79
75
guardrails = new Guardrails ("test_type" , false , inputGuardrail , outputGuardrail );
80
76
mlGuard = new MLGuard (guardrails , xContentRegistry , client );
81
77
}
82
78
83
79
@ Test
84
- public void validate () {
80
+ public void validateInput () {
81
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
82
+ Boolean res = mlGuard .validate (input , 0 );
83
+
84
+ Assert .assertFalse (res );
85
+ }
86
+
87
+ @ Test
88
+ public void validateOutput () {
89
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
90
+ Boolean res = mlGuard .validate (input , 1 );
91
+
92
+ Assert .assertFalse (res );
93
+ }
94
+
95
+ @ Test
96
+ public void validateRegexListSuccess () {
97
+ String input = "\n \n Human:hello good words.\n \n Assistant:" ;
98
+ List <String > regexList = List .of (regex );
99
+ Boolean res = mlGuard .validateRegexList (input , regexList );
100
+
101
+ Assert .assertTrue (res );
102
+ }
103
+
104
+ @ Test
105
+ public void validateRegexListFailed () {
106
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
107
+ List <String > regexList = List .of (regex );
108
+ Boolean res = mlGuard .validateRegexList (input , regexList );
109
+
110
+ Assert .assertFalse (res );
85
111
}
86
112
87
113
@ Test
88
- public void validateRegexList () {
114
+ public void validateRegexSuccess () {
115
+ String input = "\n \n Human:hello good words.\n \n Assistant:" ;
116
+ Boolean res = mlGuard .validateRegex (input , regex [0 ]);
117
+
118
+ Assert .assertTrue (res );
89
119
}
90
120
91
121
@ Test
92
- public void validateRegex () {
122
+ public void validateRegexFailed () {
123
+ String input = "\n \n Human:hello stop words.\n \n Assistant:" ;
124
+ Boolean res = mlGuard .validateRegex (input , regex [0 ]);
125
+
126
+ Assert .assertFalse (res );
93
127
}
94
128
95
129
@ Test
96
- public void validateStopWords () {
130
+ public void validateStopWords () throws IOException {
131
+ Map <String , List <String >> stopWordsIndices = Map .of ("test_index" , List .of ("test_field" ));
132
+ SearchResponse searchResponse = createSearchResponse (1 );
133
+ ActionFuture <SearchResponse > future = createSearchResponseFuture (searchResponse );
134
+ when (this .client .search (any ())).thenReturn (future );
135
+
136
+ Boolean res = mlGuard .validateStopWords ("hello world" , stopWordsIndices );
137
+ Assert .assertTrue (res );
97
138
}
98
139
99
140
@ Test
@@ -103,7 +144,7 @@ public void validateStopWordsSingleIndex() throws IOException {
103
144
when (this .client .search (any ())).thenReturn (future );
104
145
105
146
Boolean res = mlGuard .validateStopWordsSingleIndex ("hello world" , "test_index" , List .of ("test_field" ));
106
- Assert .assertFalse (res );
147
+ Assert .assertTrue (res );
107
148
}
108
149
109
150
private SearchResponse createSearchResponse (int size ) throws IOException {
0 commit comments