Skip to content

Commit 1e39b7b

Browse files
committed
add some UT
Signed-off-by: Jing Zhang <[email protected]>
1 parent 6f2af00 commit 1e39b7b

File tree

8 files changed

+419
-0
lines changed

8 files changed

+419
-0
lines changed

common/src/main/java/org/opensearch/ml/common/model/Guardrail.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.model;
27

38
import lombok.Builder;

common/src/main/java/org/opensearch/ml/common/model/Guardrails.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.model;
27

38
import lombok.Builder;

common/src/main/java/org/opensearch/ml/common/model/MLGuard.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.model;
27

38
import lombok.Getter;

common/src/main/java/org/opensearch/ml/common/model/StopWords.java

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.model;
27

38
import lombok.Builder;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.model;
7+
8+
import org.junit.Assert;
9+
import org.junit.Before;
10+
import org.junit.Test;
11+
import org.opensearch.common.io.stream.BytesStreamOutput;
12+
import org.opensearch.common.settings.Settings;
13+
import org.opensearch.common.xcontent.XContentType;
14+
import org.opensearch.core.xcontent.NamedXContentRegistry;
15+
import org.opensearch.core.xcontent.ToXContent;
16+
import org.opensearch.core.xcontent.XContentBuilder;
17+
import org.opensearch.core.xcontent.XContentParser;
18+
import org.opensearch.ml.common.TestHelper;
19+
import org.opensearch.search.SearchModule;
20+
21+
import java.io.IOException;
22+
import java.util.Collections;
23+
import java.util.List;
24+
25+
import static org.junit.Assert.*;
26+
27+
public class GuardrailTests {
28+
StopWords stopWords;
29+
String[] regex;
30+
31+
@Before
32+
public void setUp() {
33+
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
34+
regex = List.of("regex1").toArray(new String[0]);
35+
}
36+
37+
@Test
38+
public void writeTo() throws IOException {
39+
Guardrail guardrail = new Guardrail(List.of(stopWords), regex);
40+
BytesStreamOutput output = new BytesStreamOutput();
41+
guardrail.writeTo(output);
42+
Guardrail guardrail1 = new Guardrail(output.bytes().streamInput());
43+
44+
Assert.assertArrayEquals(guardrail.getStopWords().toArray(), guardrail1.getStopWords().toArray());
45+
Assert.assertArrayEquals(guardrail.getRegex(), guardrail1.getRegex());
46+
}
47+
48+
@Test
49+
public void toXContent() throws IOException {
50+
Guardrail guardrail = new Guardrail(List.of(stopWords), regex);
51+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
52+
guardrail.toXContent(builder, ToXContent.EMPTY_PARAMS);
53+
String content = TestHelper.xContentBuilderToString(builder);
54+
55+
Assert.assertEquals("{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}", content);
56+
}
57+
58+
@Test
59+
public void parse() throws IOException {
60+
String jsonStr = "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}";
61+
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
62+
Collections.emptyList()).getNamedXContents()), null, jsonStr);
63+
parser.nextToken();
64+
Guardrail guardrail = Guardrail.parse(parser);
65+
66+
Assert.assertArrayEquals(guardrail.getStopWords().toArray(), List.of(stopWords).toArray());
67+
Assert.assertArrayEquals(guardrail.getRegex(), regex);
68+
}
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.model;
7+
8+
import org.junit.Assert;
9+
import org.junit.Before;
10+
import org.junit.Test;
11+
import org.opensearch.common.io.stream.BytesStreamOutput;
12+
import org.opensearch.common.settings.Settings;
13+
import org.opensearch.common.xcontent.XContentType;
14+
import org.opensearch.core.xcontent.NamedXContentRegistry;
15+
import org.opensearch.core.xcontent.ToXContent;
16+
import org.opensearch.core.xcontent.XContentBuilder;
17+
import org.opensearch.core.xcontent.XContentParser;
18+
import org.opensearch.ml.common.TestHelper;
19+
import org.opensearch.search.SearchModule;
20+
21+
import java.io.IOException;
22+
import java.util.Collections;
23+
import java.util.List;
24+
25+
import static org.junit.Assert.*;
26+
27+
public class GuardrailsTests {
28+
StopWords stopWords;
29+
String[] regex;
30+
Guardrail inputGuardrail;
31+
Guardrail outputGuardrail;
32+
33+
@Before
34+
public void setUp() {
35+
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
36+
regex = List.of("regex1").toArray(new String[0]);
37+
inputGuardrail = new Guardrail(List.of(stopWords), regex);
38+
outputGuardrail = new Guardrail(List.of(stopWords), regex);
39+
}
40+
41+
@Test
42+
public void writeTo() throws IOException {
43+
Guardrails guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
44+
BytesStreamOutput output = new BytesStreamOutput();
45+
guardrails.writeTo(output);
46+
Guardrails guardrails1 = new Guardrails(output.bytes().streamInput());
47+
48+
Assert.assertEquals(guardrails.getType(), guardrails1.getType());
49+
Assert.assertEquals(guardrails.getEngDetectionEnabled(), guardrails1.getEngDetectionEnabled());
50+
Assert.assertEquals(guardrails.getInputGuardrail(), guardrails1.getInputGuardrail());
51+
Assert.assertEquals(guardrails.getOutputGuardrail(), guardrails1.getOutputGuardrail());
52+
}
53+
54+
@Test
55+
public void toXContent() throws IOException {
56+
Guardrails guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
57+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
58+
guardrails.toXContent(builder, ToXContent.EMPTY_PARAMS);
59+
String content = TestHelper.xContentBuilderToString(builder);
60+
61+
Assert.assertEquals("{\"type\":\"test_type\"," +
62+
"\"english_detection_enabled\":false," +
63+
"\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," +
64+
"\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}",
65+
content);
66+
}
67+
68+
@Test
69+
public void parse() throws IOException {
70+
String jsonStr = "{\"type\":\"test_type\"," +
71+
"\"english_detection_enabled\":false," +
72+
"\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," +
73+
"\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}";
74+
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
75+
Collections.emptyList()).getNamedXContents()), null, jsonStr);
76+
parser.nextToken();
77+
Guardrails guardrails = Guardrails.parse(parser);
78+
79+
Assert.assertEquals(guardrails.getType(), "test_type");
80+
Assert.assertEquals(guardrails.getEngDetectionEnabled(), false);
81+
Assert.assertEquals(guardrails.getInputGuardrail(), inputGuardrail);
82+
Assert.assertEquals(guardrails.getOutputGuardrail(), outputGuardrail);
83+
}
84+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.model;
7+
8+
import org.apache.lucene.search.TotalHits;
9+
import org.junit.Assert;
10+
import org.junit.Before;
11+
import org.junit.Test;
12+
import org.mockito.Mock;
13+
import org.mockito.MockitoAnnotations;
14+
import org.opensearch.action.search.SearchResponse;
15+
import org.opensearch.action.search.ShardSearchFailure;
16+
import org.opensearch.client.Client;
17+
import org.opensearch.common.action.ActionFuture;
18+
import org.opensearch.common.settings.Settings;
19+
import org.opensearch.common.unit.TimeValue;
20+
import org.opensearch.common.util.concurrent.ThreadContext;
21+
import org.opensearch.common.xcontent.XContentFactory;
22+
import org.opensearch.core.common.bytes.BytesReference;
23+
import org.opensearch.core.xcontent.NamedXContentRegistry;
24+
import org.opensearch.core.xcontent.ToXContent;
25+
import org.opensearch.core.xcontent.XContentBuilder;
26+
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
27+
import org.opensearch.search.SearchHit;
28+
import org.opensearch.search.SearchHits;
29+
import org.opensearch.search.SearchModule;
30+
import org.opensearch.search.aggregations.InternalAggregations;
31+
import org.opensearch.search.internal.InternalSearchResponse;
32+
import org.opensearch.search.profile.SearchProfileShardResults;
33+
import org.opensearch.search.suggest.Suggest;
34+
import org.opensearch.threadpool.ThreadPool;
35+
36+
import java.io.IOException;
37+
import java.util.Collections;
38+
import java.util.HashMap;
39+
import java.util.List;
40+
import java.util.Map;
41+
import java.util.concurrent.ExecutionException;
42+
import java.util.concurrent.TimeUnit;
43+
import java.util.concurrent.TimeoutException;
44+
45+
import static org.junit.Assert.*;
46+
import static org.mockito.ArgumentMatchers.any;
47+
import static org.mockito.Mockito.doNothing;
48+
import static org.mockito.Mockito.when;
49+
50+
public class MLGuardTests {
51+
52+
NamedXContentRegistry xContentRegistry;
53+
@Mock
54+
Client client;
55+
@Mock
56+
ThreadPool threadPool;
57+
ThreadContext threadContext;
58+
59+
StopWords stopWords;
60+
String[] regex;
61+
Guardrail inputGuardrail;
62+
Guardrail outputGuardrail;
63+
Guardrails guardrails;
64+
MLGuard mlGuard;
65+
66+
@Before
67+
public void setUp() {
68+
MockitoAnnotations.openMocks(this);
69+
xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
70+
Settings settings = Settings.builder().build();
71+
this.threadContext = new ThreadContext(settings);
72+
when(this.client.threadPool()).thenReturn(this.threadPool);
73+
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
74+
75+
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
76+
regex = List.of("regex1").toArray(new String[0]);
77+
inputGuardrail = new Guardrail(List.of(stopWords), regex);
78+
outputGuardrail = new Guardrail(List.of(stopWords), regex);
79+
guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
80+
mlGuard = new MLGuard(guardrails, xContentRegistry, client);
81+
}
82+
83+
@Test
84+
public void validate() {
85+
}
86+
87+
@Test
88+
public void validateRegexList() {
89+
}
90+
91+
@Test
92+
public void validateRegex() {
93+
}
94+
95+
@Test
96+
public void validateStopWords() {
97+
}
98+
99+
@Test
100+
public void validateStopWordsSingleIndex() throws IOException {
101+
SearchResponse searchResponse = createSearchResponse(1);
102+
ActionFuture<SearchResponse> future = createSearchResponseFuture(searchResponse);
103+
when(this.client.search(any())).thenReturn(future);
104+
105+
Boolean res = mlGuard.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field"));
106+
Assert.assertFalse(res);
107+
}
108+
109+
private SearchResponse createSearchResponse(int size) throws IOException {
110+
XContentBuilder content = guardrails.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
111+
SearchHit[] hits = new SearchHit[size];
112+
if (size > 0) {
113+
hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content));
114+
}
115+
return new SearchResponse(
116+
new InternalSearchResponse(
117+
new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f),
118+
InternalAggregations.EMPTY,
119+
new Suggest(Collections.emptyList()),
120+
new SearchProfileShardResults(Collections.emptyMap()),
121+
false,
122+
false,
123+
1
124+
),
125+
"",
126+
5,
127+
5,
128+
0,
129+
100,
130+
ShardSearchFailure.EMPTY_ARRAY,
131+
SearchResponse.Clusters.EMPTY
132+
);
133+
}
134+
135+
private ActionFuture<SearchResponse> createSearchResponseFuture(SearchResponse searchResponse) {
136+
return new ActionFuture<>() {
137+
@Override
138+
public SearchResponse actionGet() {
139+
return searchResponse;
140+
}
141+
142+
@Override
143+
public SearchResponse actionGet(String timeout) {
144+
return searchResponse;
145+
}
146+
147+
@Override
148+
public SearchResponse actionGet(long timeoutMillis) {
149+
return searchResponse;
150+
}
151+
152+
@Override
153+
public SearchResponse actionGet(long timeout, TimeUnit unit) {
154+
return searchResponse;
155+
}
156+
157+
@Override
158+
public SearchResponse actionGet(TimeValue timeout) {
159+
return searchResponse;
160+
}
161+
162+
@Override
163+
public boolean cancel(boolean mayInterruptIfRunning) {
164+
return false;
165+
}
166+
167+
@Override
168+
public boolean isCancelled() {
169+
return false;
170+
}
171+
172+
@Override
173+
public boolean isDone() {
174+
return false;
175+
}
176+
177+
@Override
178+
public SearchResponse get() throws InterruptedException, ExecutionException {
179+
return searchResponse;
180+
}
181+
182+
@Override
183+
public SearchResponse get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
184+
return searchResponse;
185+
}
186+
};
187+
}
188+
}

0 commit comments

Comments
 (0)