Skip to content

Commit 0e5cb62

Browse files
committed
address comments
1 parent 9206764 commit 0e5cb62

File tree

4 files changed

+259
-16
lines changed

4 files changed

+259
-16
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public class CommonValue {
5656
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
5757
public static final String ML_TASK_INDEX = ".plugins-ml-task";
5858
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
59-
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9;
59+
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
6060
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
6161
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
6262
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;

common/src/main/java/org/opensearch/ml/common/model/MLGuard.java

+12-8
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.concurrent.atomic.AtomicReference;
3131
import java.util.regex.Matcher;
3232
import java.util.regex.Pattern;
33+
import java.util.stream.Collectors;
3334

3435
import static java.util.concurrent.TimeUnit.SECONDS;
3536
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
@@ -42,6 +43,8 @@ public class MLGuard {
4243
private Map<String, List<String>> stopWordsIndicesOutput = new HashMap<>();
4344
private List<String> inputRegex;
4445
private List<String> outputRegex;
46+
private List<Pattern> inputRegexPattern;
47+
private List<Pattern> outputRegexPattern;
4548
private NamedXContentRegistry xContentRegistry;
4649
private Client client;
4750

@@ -56,10 +59,12 @@ public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Cl
5659
if (inputGuardrail != null) {
5760
fillStopWordsToMap(inputGuardrail, stopWordsIndicesInput);
5861
inputRegex = inputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(inputGuardrail.getRegex());
62+
inputRegexPattern = inputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
5963
}
6064
if (outputGuardrail != null) {
6165
fillStopWordsToMap(outputGuardrail, stopWordsIndicesOutput);
6266
outputRegex = outputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(outputGuardrail.getRegex());
67+
outputRegexPattern = outputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
6368
}
6469
}
6570

@@ -76,25 +81,24 @@ private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map<String, List<S
7681
/**
 * Validates text against the guardrail rules of one side (model input or model output).
 *
 * This span is a rendered diff: a line following an "NN-" marker is the removed version and a
 * line following an "NN+" marker is the added version.
 *
 * @param input text to check
 * @param type  0 selects the input regex rules and input stop-word indices,
 *              1 selects the output regex rules and output stop-word indices
 * @return the result of validateRegexList(...) && validateStopWords(...) for the selected side;
 *         because of the short-circuit &&, stop words are only checked when the regex rules pass
 * @throws IllegalArgumentException in the added version, for any type other than 0 or 1
 *         (the removed version returned true instead)
 */
public Boolean validate(String input, int type) {
7782
switch (type) {
7883
case 0: // validate input
79-
return validateRegexList(input, inputRegex) && validateStopWords(input, stopWordsIndicesInput);
84+
return validateRegexList(input, inputRegexPattern) && validateStopWords(input, stopWordsIndicesInput);
8085
case 1: // validate output
81-
return validateRegexList(input, outputRegex) && validateStopWords(input, stopWordsIndicesOutput);
86+
return validateRegexList(input, outputRegexPattern) && validateStopWords(input, stopWordsIndicesOutput);
8287
default:
83-
// removed version: unknown types silently passed validation
return true;
88+
// added version: unknown types are rejected explicitly
throw new IllegalArgumentException("Unsupported type to validate for guardrails.");
8489
}
8590
}
8691

87-
public Boolean validateRegexList(String input, List<String> regexList) {
88-
for (String regex : regexList) {
89-
if (!validateRegex(input, regex)) {
92+
public Boolean validateRegexList(String input, List<Pattern> regexPatterns) {
93+
for (Pattern pattern : regexPatterns) {
94+
if (!validateRegex(input, pattern)) {
9095
return false;
9196
}
9297
}
9398
return true;
9499
}
95100

96-
public Boolean validateRegex(String input, String regex) {
97-
Pattern pattern = Pattern.compile(regex);
101+
/**
 * Checks the input against a single precompiled pattern.
 *
 * @param input   text to check
 * @param pattern compiled regular expression describing forbidden content
 * @return true when the pattern does NOT match, false when it does. Matcher.matches() only
 *         succeeds when the entire input matches, so a rule meant to catch a substring has to
 *         account for the surrounding text itself.
 */
public Boolean validateRegex(String input, Pattern pattern) {
98102
Matcher matcher = pattern.matcher(input);
99103
// inverted on purpose: a match means the guardrail was hit
return !matcher.matches();
100104
}

common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java

+16-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import lombok.Data;
99
import lombok.Builder;
1010
import lombok.Getter;
11+
import org.opensearch.Version;
1112
import org.opensearch.core.common.io.stream.StreamInput;
1213
import org.opensearch.core.common.io.stream.StreamOutput;
1314
import org.opensearch.core.common.io.stream.Writeable;
@@ -46,6 +47,8 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
4647
// request
4748
public static final String GUARDRAILS_FIELD = "guardrails";
4849

50+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS = Version.V_2_13_0;
51+
4952
@Getter
5053
private String modelId;
5154
private String description;
@@ -81,6 +84,7 @@ public MLUpdateModelInput(String modelId, String description, String version, St
8184
}
8285

8386
public MLUpdateModelInput(StreamInput in) throws IOException {
87+
Version streamInputVersion = in.getVersion();
8488
modelId = in.readString();
8589
description = in.readOptionalString();
8690
version = in.readOptionalString();
@@ -101,8 +105,10 @@ public MLUpdateModelInput(StreamInput in) throws IOException {
101105
connector = new MLCreateConnectorInput(in);
102106
}
103107
lastUpdateTime = in.readOptionalInstant();
104-
if (in.readBoolean()) {
105-
this.guardrails = new Guardrails(in);
108+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) {
109+
if (in.readBoolean()) {
110+
this.guardrails = new Guardrails(in);
111+
}
106112
}
107113
}
108114

@@ -193,6 +199,7 @@ public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Pa
193199

194200
@Override
195201
public void writeTo(StreamOutput out) throws IOException {
202+
Version streamOutputVersion = out.getVersion();
196203
out.writeString(modelId);
197204
out.writeOptionalString(description);
198205
out.writeOptionalString(version);
@@ -225,11 +232,13 @@ public void writeTo(StreamOutput out) throws IOException {
225232
out.writeBoolean(false);
226233
}
227234
out.writeOptionalInstant(lastUpdateTime);
228-
if (guardrails != null) {
229-
out.writeBoolean(true);
230-
guardrails.writeTo(out);
231-
} else {
232-
out.writeBoolean(false);
235+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) {
236+
if (guardrails != null) {
237+
out.writeBoolean(true);
238+
guardrails.writeTo(out);
239+
} else {
240+
out.writeBoolean(false);
241+
}
233242
}
234243
}
235244

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.model;
7+
8+
import org.apache.lucene.search.TotalHits;
9+
import org.junit.Assert;
10+
import org.junit.Before;
11+
import org.junit.Test;
12+
import org.mockito.Mock;
13+
import org.mockito.MockitoAnnotations;
14+
import org.opensearch.action.search.SearchResponse;
15+
import org.opensearch.action.search.ShardSearchFailure;
16+
import org.opensearch.client.Client;
17+
import org.opensearch.common.action.ActionFuture;
18+
import org.opensearch.common.settings.Settings;
19+
import org.opensearch.common.unit.TimeValue;
20+
import org.opensearch.common.util.concurrent.ThreadContext;
21+
import org.opensearch.common.xcontent.XContentFactory;
22+
import org.opensearch.core.common.bytes.BytesReference;
23+
import org.opensearch.core.xcontent.NamedXContentRegistry;
24+
import org.opensearch.core.xcontent.ToXContent;
25+
import org.opensearch.core.xcontent.XContentBuilder;
26+
import org.opensearch.search.SearchHit;
27+
import org.opensearch.search.SearchHits;
28+
import org.opensearch.search.SearchModule;
29+
import org.opensearch.search.aggregations.InternalAggregations;
30+
import org.opensearch.search.internal.InternalSearchResponse;
31+
import org.opensearch.search.profile.SearchProfileShardResults;
32+
import org.opensearch.search.suggest.Suggest;
33+
import org.opensearch.threadpool.ThreadPool;
34+
35+
import java.io.IOException;
36+
import java.util.Collections;
37+
import java.util.List;
38+
import java.util.Map;
39+
import java.util.concurrent.ExecutionException;
40+
import java.util.concurrent.TimeUnit;
41+
import java.util.concurrent.TimeoutException;
42+
import java.util.regex.Pattern;
43+
44+
import static org.mockito.ArgumentMatchers.any;
45+
import static org.mockito.Mockito.when;
46+
47+
/**
 * Unit tests for MLGuard covering the combined validate(...) entry point, the regex checks and
 * the stop-word checks. The OpenSearch Client and ThreadPool are Mockito mocks and the search
 * response is built by hand, so no cluster is involved.
 */
public class MLGuardTests {
48+
49+
NamedXContentRegistry xContentRegistry;
50+
@Mock
51+
Client client;
52+
@Mock
53+
ThreadPool threadPool;
54+
ThreadContext threadContext;
55+
56+
StopWords stopWords;
57+
String[] regex;
58+
List<Pattern> regexPatterns;
59+
Guardrail inputGuardrail;
60+
Guardrail outputGuardrail;
61+
Guardrails guardrails;
62+
MLGuard mlGuard;
63+
64+
/**
 * Initializes the mocks and builds an MLGuard whose input and output guardrails share the same
 * stop-word index ("test_index" / "test_field") and the same single regex rule.
 */
@Before
65+
public void setUp() {
66+
MockitoAnnotations.openMocks(this);
67+
xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
68+
Settings settings = Settings.builder().build();
69+
this.threadContext = new ThreadContext(settings);
70+
// the mocked client hands out the mocked thread pool, which hands out a real ThreadContext
when(this.client.threadPool()).thenReturn(this.threadPool);
71+
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
72+
73+
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
74+
// "." does not match line terminators by default, hence "(.|\n)*" on both sides: the rule
// matches any text, including multi-line text, that contains "stop words"
regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]);
75+
// same expression, precompiled, for the tests that call the regex helpers directly
regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*"));
76+
inputGuardrail = new Guardrail(List.of(stopWords), regex);
77+
outputGuardrail = new Guardrail(List.of(stopWords), regex);
78+
guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
79+
mlGuard = new MLGuard(guardrails, xContentRegistry, client);
80+
}
81+
82+
/**
 * validate(..., 0) must reject text containing "stop words". client.search is not stubbed in
 * this test, so the rejection is expected to come from the regex rule.
 */
@Test
83+
public void validateInput() {
84+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
85+
Boolean res = mlGuard.validate(input, 0);
86+
87+
Assert.assertFalse(res);
88+
}
89+
90+
/**
 * Same check as validateInput, but through the output side (type 1).
 */
@Test
91+
public void validateOutput() {
92+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
93+
Boolean res = mlGuard.validate(input, 1);
94+
95+
Assert.assertFalse(res);
96+
}
97+
98+
/**
 * Text without the forbidden phrase passes the pattern list.
 */
@Test
99+
public void validateRegexListSuccess() {
100+
String input = "\n\nHuman:hello good words.\n\nAssistant:";
101+
Boolean res = mlGuard.validateRegexList(input, regexPatterns);
102+
103+
Assert.assertTrue(res);
104+
}
105+
106+
/**
 * Text containing the forbidden phrase is rejected by the pattern list.
 */
@Test
107+
public void validateRegexListFailed() {
108+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
109+
Boolean res = mlGuard.validateRegexList(input, regexPatterns);
110+
111+
Assert.assertFalse(res);
112+
}
113+
114+
/**
 * Text without the forbidden phrase passes a single pattern.
 */
@Test
115+
public void validateRegexSuccess() {
116+
String input = "\n\nHuman:hello good words.\n\nAssistant:";
117+
Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0));
118+
119+
Assert.assertTrue(res);
120+
}
121+
122+
/**
 * Text containing the forbidden phrase is rejected by a single pattern.
 */
@Test
123+
public void validateRegexFailed() {
124+
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
125+
Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0));
126+
127+
Assert.assertFalse(res);
128+
}
129+
130+
/**
 * Stubs client.search(any()) with a one-hit response and expects validateStopWords to return true.
 * NOTE(review): MLGuard.validateStopWords is not visible in this view - confirm that it really
 * goes through the one-argument client.search(...) overload stubbed here, and that "true" is the
 * intended outcome when the search returns a hit.
 */
@Test
131+
public void validateStopWords() throws IOException {
132+
Map<String, List<String>> stopWordsIndices = Map.of("test_index", List.of("test_field"));
133+
SearchResponse searchResponse = createSearchResponse(1);
134+
ActionFuture<SearchResponse> future = createSearchResponseFuture(searchResponse);
135+
when(this.client.search(any())).thenReturn(future);
136+
137+
Boolean res = mlGuard.validateStopWords("hello world", stopWordsIndices);
138+
Assert.assertTrue(res);
139+
}
140+
141+
/**
 * Same stubbing as validateStopWords, exercised against a single index/field list.
 * NOTE(review): the same open question applies - the implementation under test is not shown here.
 */
@Test
142+
public void validateStopWordsSingleIndex() throws IOException {
143+
SearchResponse searchResponse = createSearchResponse(1);
144+
ActionFuture<SearchResponse> future = createSearchResponseFuture(searchResponse);
145+
when(this.client.search(any())).thenReturn(future);
146+
147+
Boolean res = mlGuard.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field"));
148+
Assert.assertTrue(res);
149+
}
150+
151+
/**
 * Builds a SearchResponse whose hits array has the given length and whose reported total equals it.
 * Only hits[0] is populated (its source is the guardrails object serialized as JSON); for a size
 * greater than 1 the remaining array entries stay null.
 */
private SearchResponse createSearchResponse(int size) throws IOException {
152+
XContentBuilder content = guardrails.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
153+
SearchHit[] hits = new SearchHit[size];
154+
if (size > 0) {
155+
hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content));
156+
}
157+
return new SearchResponse(
158+
new InternalSearchResponse(
159+
new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f),
160+
InternalAggregations.EMPTY,
161+
new Suggest(Collections.emptyList()),
162+
new SearchProfileShardResults(Collections.emptyMap()),
163+
false,
164+
false,
165+
1
166+
),
167+
"",
168+
5,
169+
5,
170+
0,
171+
100,
172+
ShardSearchFailure.EMPTY_ARRAY,
173+
SearchResponse.Clusters.EMPTY
174+
);
175+
}
176+
177+
/**
 * Wraps a ready-made response in a hand-written ActionFuture: every actionGet/get overload
 * returns the response immediately, and cancel, isCancelled and isDone all return false.
 */
private ActionFuture<SearchResponse> createSearchResponseFuture(SearchResponse searchResponse) {
178+
return new ActionFuture<>() {
179+
@Override
180+
public SearchResponse actionGet() {
181+
return searchResponse;
182+
}
183+
184+
@Override
185+
public SearchResponse actionGet(String timeout) {
186+
return searchResponse;
187+
}
188+
189+
@Override
190+
public SearchResponse actionGet(long timeoutMillis) {
191+
return searchResponse;
192+
}
193+
194+
@Override
195+
public SearchResponse actionGet(long timeout, TimeUnit unit) {
196+
return searchResponse;
197+
}
198+
199+
@Override
200+
public SearchResponse actionGet(TimeValue timeout) {
201+
return searchResponse;
202+
}
203+
204+
@Override
205+
public boolean cancel(boolean mayInterruptIfRunning) {
206+
return false;
207+
}
208+
209+
@Override
210+
public boolean isCancelled() {
211+
return false;
212+
}
213+
214+
@Override
215+
public boolean isDone() {
216+
// reports "not done" even though the get/actionGet methods return right away
return false;
217+
}
218+
219+
@Override
220+
public SearchResponse get() throws InterruptedException, ExecutionException {
221+
return searchResponse;
222+
}
223+
224+
@Override
225+
public SearchResponse get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
226+
return searchResponse;
227+
}
228+
};
229+
}
230+
}

0 commit comments

Comments
 (0)