Skip to content

Commit 6c67b1c

Browse files
committed
address comments
Signed-off-by: Jing Zhang <[email protected]>
1 parent 8cae1a8 commit 6c67b1c

File tree

4 files changed

+36
-22
lines changed

4 files changed

+36
-22
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ public class CommonValue {
5656
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
5757
public static final String ML_TASK_INDEX = ".plugins-ml-task";
5858
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
59-
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9;
59+
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
6060
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
6161
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
6262
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;

common/src/main/java/org/opensearch/ml/common/model/MLGuard.java

+12-8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.concurrent.atomic.AtomicReference;
3636
import java.util.regex.Matcher;
3737
import java.util.regex.Pattern;
38+
import java.util.stream.Collectors;
3839

3940
import static java.util.concurrent.TimeUnit.SECONDS;
4041
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
@@ -47,6 +48,8 @@ public class MLGuard {
4748
private Map<String, List<String>> stopWordsIndicesOutput = new HashMap<>();
4849
private List<String> inputRegex;
4950
private List<String> outputRegex;
51+
private List<Pattern> inputRegexPattern;
52+
private List<Pattern> outputRegexPattern;
5053
private NamedXContentRegistry xContentRegistry;
5154
private Client client;
5255

@@ -61,10 +64,12 @@ public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Cl
6164
if (inputGuardrail != null) {
6265
fillStopWordsToMap(inputGuardrail, stopWordsIndicesInput);
6366
inputRegex = inputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(inputGuardrail.getRegex());
67+
inputRegexPattern = inputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
6468
}
6569
if (outputGuardrail != null) {
6670
fillStopWordsToMap(outputGuardrail, stopWordsIndicesOutput);
6771
outputRegex = outputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(outputGuardrail.getRegex());
72+
outputRegexPattern = outputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList());
6873
}
6974
}
7075

@@ -81,25 +86,24 @@ private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map<String, List<S
8186
public Boolean validate(String input, int type) {
8287
switch (type) {
8388
case 0: // validate input
84-
return validateRegexList(input, inputRegex) && validateStopWords(input, stopWordsIndicesInput);
89+
return validateRegexList(input, inputRegexPattern) && validateStopWords(input, stopWordsIndicesInput);
8590
case 1: // validate output
86-
return validateRegexList(input, outputRegex) && validateStopWords(input, stopWordsIndicesOutput);
91+
return validateRegexList(input, outputRegexPattern) && validateStopWords(input, stopWordsIndicesOutput);
8792
default:
88-
return true;
93+
throw new IllegalArgumentException("Unsupported type to validate for guardrails.");
8994
}
9095
}
9196

92-
public Boolean validateRegexList(String input, List<String> regexList) {
93-
for (String regex : regexList) {
94-
if (!validateRegex(input, regex)) {
97+
public Boolean validateRegexList(String input, List<Pattern> regexPatterns) {
98+
for (Pattern pattern : regexPatterns) {
99+
if (!validateRegex(input, pattern)) {
95100
return false;
96101
}
97102
}
98103
return true;
99104
}
100105

101-
public Boolean validateRegex(String input, String regex) {
102-
Pattern pattern = Pattern.compile(regex);
106+
public Boolean validateRegex(String input, Pattern pattern) {
103107
Matcher matcher = pattern.matcher(input);
104108
return !matcher.matches();
105109
}

common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java

+16-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import lombok.Data;
99
import lombok.Builder;
1010
import lombok.Getter;
11+
import org.opensearch.Version;
1112
import org.opensearch.core.common.io.stream.StreamInput;
1213
import org.opensearch.core.common.io.stream.StreamOutput;
1314
import org.opensearch.core.common.io.stream.Writeable;
@@ -46,6 +47,8 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
4647
// request
4748
public static final String GUARDRAILS_FIELD = "guardrails";
4849

50+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS = Version.V_2_13_0;
51+
4952
@Getter
5053
private String modelId;
5154
private String description;
@@ -81,6 +84,7 @@ public MLUpdateModelInput(String modelId, String description, String version, St
8184
}
8285

8386
public MLUpdateModelInput(StreamInput in) throws IOException {
87+
Version streamInputVersion = in.getVersion();
8488
modelId = in.readString();
8589
description = in.readOptionalString();
8690
version = in.readOptionalString();
@@ -101,8 +105,10 @@ public MLUpdateModelInput(StreamInput in) throws IOException {
101105
connector = new MLCreateConnectorInput(in);
102106
}
103107
lastUpdateTime = in.readOptionalInstant();
104-
if (in.readBoolean()) {
105-
this.guardrails = new Guardrails(in);
108+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) {
109+
if (in.readBoolean()) {
110+
this.guardrails = new Guardrails(in);
111+
}
106112
}
107113
}
108114

@@ -193,6 +199,7 @@ public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Pa
193199

194200
@Override
195201
public void writeTo(StreamOutput out) throws IOException {
202+
Version streamOutputVersion = out.getVersion();
196203
out.writeString(modelId);
197204
out.writeOptionalString(description);
198205
out.writeOptionalString(version);
@@ -225,11 +232,13 @@ public void writeTo(StreamOutput out) throws IOException {
225232
out.writeBoolean(false);
226233
}
227234
out.writeOptionalInstant(lastUpdateTime);
228-
if (guardrails != null) {
229-
out.writeBoolean(true);
230-
guardrails.writeTo(out);
231-
} else {
232-
out.writeBoolean(false);
235+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS)) {
236+
if (guardrails != null) {
237+
out.writeBoolean(true);
238+
guardrails.writeTo(out);
239+
} else {
240+
out.writeBoolean(false);
241+
}
233242
}
234243
}
235244

common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import java.util.concurrent.ExecutionException;
4040
import java.util.concurrent.TimeUnit;
4141
import java.util.concurrent.TimeoutException;
42+
import java.util.regex.Pattern;
4243

4344
import static org.mockito.ArgumentMatchers.any;
4445
import static org.mockito.Mockito.when;
@@ -54,6 +55,7 @@ public class MLGuardTests {
5455

5556
StopWords stopWords;
5657
String[] regex;
58+
List<Pattern> regexPatterns;
5759
Guardrail inputGuardrail;
5860
Guardrail outputGuardrail;
5961
Guardrails guardrails;
@@ -70,6 +72,7 @@ public void setUp() {
7072

7173
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
7274
regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]);
75+
regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*"));
7376
inputGuardrail = new Guardrail(List.of(stopWords), regex);
7477
outputGuardrail = new Guardrail(List.of(stopWords), regex);
7578
guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
@@ -95,33 +98,31 @@ public void validateOutput() {
9598
@Test
9699
public void validateRegexListSuccess() {
97100
String input = "\n\nHuman:hello good words.\n\nAssistant:";
98-
List<String> regexList = List.of(regex);
99-
Boolean res = mlGuard.validateRegexList(input, regexList);
101+
Boolean res = mlGuard.validateRegexList(input, regexPatterns);
100102

101103
Assert.assertTrue(res);
102104
}
103105

104106
@Test
105107
public void validateRegexListFailed() {
106108
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
107-
List<String> regexList = List.of(regex);
108-
Boolean res = mlGuard.validateRegexList(input, regexList);
109+
Boolean res = mlGuard.validateRegexList(input, regexPatterns);
109110

110111
Assert.assertFalse(res);
111112
}
112113

113114
@Test
114115
public void validateRegexSuccess() {
115116
String input = "\n\nHuman:hello good words.\n\nAssistant:";
116-
Boolean res = mlGuard.validateRegex(input, regex[0]);
117+
Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0));
117118

118119
Assert.assertTrue(res);
119120
}
120121

121122
@Test
122123
public void validateRegexFailed() {
123124
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
124-
Boolean res = mlGuard.validateRegex(input, regex[0]);
125+
Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0));
125126

126127
Assert.assertFalse(res);
127128
}

0 commit comments

Comments
 (0)