Skip to content

Commit 8cae1a8

Browse files
committed
add more UT
Signed-off-by: Jing Zhang <[email protected]>
1 parent af43b45 commit 8cae1a8

File tree

1 file changed

+51
-10
lines changed

1 file changed

+51
-10
lines changed

common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java

+51-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.opensearch.core.xcontent.NamedXContentRegistry;
2424
import org.opensearch.core.xcontent.ToXContent;
2525
import org.opensearch.core.xcontent.XContentBuilder;
26-
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
2726
import org.opensearch.search.SearchHit;
2827
import org.opensearch.search.SearchHits;
2928
import org.opensearch.search.SearchModule;
@@ -35,16 +34,13 @@
3534

3635
import java.io.IOException;
3736
import java.util.Collections;
38-
import java.util.HashMap;
3937
import java.util.List;
4038
import java.util.Map;
4139
import java.util.concurrent.ExecutionException;
4240
import java.util.concurrent.TimeUnit;
4341
import java.util.concurrent.TimeoutException;
4442

45-
import static org.junit.Assert.*;
4643
import static org.mockito.ArgumentMatchers.any;
47-
import static org.mockito.Mockito.doNothing;
4844
import static org.mockito.Mockito.when;
4945

5046
public class MLGuardTests {
@@ -73,27 +69,72 @@ public void setUp() {
7369
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
7470

7571
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
76-
regex = List.of("regex1").toArray(new String[0]);
72+
regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]);
7773
inputGuardrail = new Guardrail(List.of(stopWords), regex);
7874
outputGuardrail = new Guardrail(List.of(stopWords), regex);
7975
guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
8076
mlGuard = new MLGuard(guardrails, xContentRegistry, client);
8177
}
8278

8379
@Test
84-
public void validate() {
80+
public void validateInput() {
81+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
82+
Boolean res = mlGuard.validate(input, 0);
83+
84+
Assert.assertFalse(res);
85+
}
86+
87+
@Test
88+
public void validateOutput() {
89+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
90+
Boolean res = mlGuard.validate(input, 1);
91+
92+
Assert.assertFalse(res);
93+
}
94+
95+
@Test
96+
public void validateRegexListSuccess() {
97+
String input = "\n\nHuman:hello good words.\n\nAssistant:";
98+
List<String> regexList = List.of(regex);
99+
Boolean res = mlGuard.validateRegexList(input, regexList);
100+
101+
Assert.assertTrue(res);
102+
}
103+
104+
@Test
105+
public void validateRegexListFailed() {
106+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
107+
List<String> regexList = List.of(regex);
108+
Boolean res = mlGuard.validateRegexList(input, regexList);
109+
110+
Assert.assertFalse(res);
85111
}
86112

87113
@Test
88-
public void validateRegexList() {
114+
public void validateRegexSuccess() {
115+
String input = "\n\nHuman:hello good words.\n\nAssistant:";
116+
Boolean res = mlGuard.validateRegex(input, regex[0]);
117+
118+
Assert.assertTrue(res);
89119
}
90120

91121
@Test
92-
public void validateRegex() {
122+
public void validateRegexFailed() {
123+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
124+
Boolean res = mlGuard.validateRegex(input, regex[0]);
125+
126+
Assert.assertFalse(res);
93127
}
94128

95129
@Test
96-
public void validateStopWords() {
130+
public void validateStopWords() throws IOException {
131+
Map<String, List<String>> stopWordsIndices = Map.of("test_index", List.of("test_field"));
132+
SearchResponse searchResponse = createSearchResponse(1);
133+
ActionFuture<SearchResponse> future = createSearchResponseFuture(searchResponse);
134+
when(this.client.search(any())).thenReturn(future);
135+
136+
Boolean res = mlGuard.validateStopWords("hello world", stopWordsIndices);
137+
Assert.assertTrue(res);
97138
}
98139

99140
@Test
@@ -103,7 +144,7 @@ public void validateStopWordsSingleIndex() throws IOException {
103144
when(this.client.search(any())).thenReturn(future);
104145

105146
Boolean res = mlGuard.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field"));
106-
Assert.assertFalse(res);
147+
Assert.assertTrue(res);
107148
}
108149

109150
private SearchResponse createSearchResponse(int size) throws IOException {

0 commit comments

Comments
 (0)