Skip to content

Commit bfee81e

Browse files
committed
guardrails
Signed-off-by: Jing Zhang <[email protected]>
1 parent 2e0946a commit bfee81e

File tree

14 files changed

+597
-9
lines changed

14 files changed

+597
-9
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,10 @@ public class CommonValue {
265265
+ MLModel.CONNECTOR_FIELD
266266
+ "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n},"
267267
+ USER_FIELD_MAPPING
268-
+ " }\n"
268+
+ " },\n"
269+
+ " \""
270+
+ MLModel.GUARDRAILS_FIELD
271+
+ "\" : {\"type\": \"flat_object\"},\n"
269272
+ "}";
270273

271274
public static final String ML_TASK_INDEX_MAPPING = "{\n"

common/src/main/java/org/opensearch/ml/common/MLModel.java

+23-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.opensearch.core.xcontent.XContentBuilder;
1717
import org.opensearch.core.xcontent.XContentParser;
1818
import org.opensearch.ml.common.connector.Connector;
19+
import org.opensearch.ml.common.model.Guardrails;
1920
import org.opensearch.ml.common.model.MLModelConfig;
2021
import org.opensearch.ml.common.controller.MLRateLimiter;
2122
import org.opensearch.ml.common.model.MLModelFormat;
@@ -84,6 +85,7 @@ public class MLModel implements ToXContentObject {
8485
public static final String IS_HIDDEN_FIELD = "is_hidden";
8586
public static final String CONNECTOR_FIELD = "connector";
8687
public static final String CONNECTOR_ID_FIELD = "connector_id";
88+
public static final String GUARDRAILS_FIELD = "guardrails";
8789

8890
private String name;
8991
private String modelGroupId;
@@ -127,6 +129,7 @@ public class MLModel implements ToXContentObject {
127129
@Setter
128130
private Connector connector;
129131
private String connectorId;
132+
private Guardrails guardrails;
130133

131134
@Builder(toBuilder = true)
132135
public MLModel(String name,
@@ -158,7 +161,8 @@ public MLModel(String name,
158161
boolean deployToAllNodes,
159162
Boolean isHidden,
160163
Connector connector,
161-
String connectorId) {
164+
String connectorId,
165+
Guardrails guardrails) {
162166
this.name = name;
163167
this.modelGroupId = modelGroupId;
164168
this.algorithm = algorithm;
@@ -190,6 +194,7 @@ public MLModel(String name,
190194
this.isHidden = isHidden;
191195
this.connector = connector;
192196
this.connectorId = connectorId;
197+
this.guardrails = guardrails;
193198
}
194199

195200
public MLModel(StreamInput input) throws IOException {
@@ -243,6 +248,9 @@ public MLModel(StreamInput input) throws IOException {
243248
connector = Connector.fromStream(input);
244249
}
245250
connectorId = input.readOptionalString();
251+
if (input.readBoolean()) {
252+
this.guardrails = new Guardrails(input);
253+
}
246254
}
247255
}
248256

@@ -308,6 +316,12 @@ public void writeTo(StreamOutput out) throws IOException {
308316
out.writeBoolean(false);
309317
}
310318
out.writeOptionalString(connectorId);
319+
if (guardrails != null) {
320+
out.writeBoolean(true);
321+
guardrails.writeTo(out);
322+
} else {
323+
out.writeBoolean(false);
324+
}
311325
}
312326

313327
@Override
@@ -406,6 +420,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
406420
if (connectorId != null) {
407421
builder.field(CONNECTOR_ID_FIELD, connectorId);
408422
}
423+
if (guardrails != null) {
424+
builder.field(GUARDRAILS_FIELD, guardrails);
425+
}
409426
builder.endObject();
410427
return builder;
411428
}
@@ -448,6 +465,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
448465
boolean isHidden = false;
449466
Connector connector = null;
450467
String connectorId = null;
468+
Guardrails guardrails = null;
451469

452470
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
453471
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -571,6 +589,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
571589
case LAST_UNDEPLOYED_TIME_FIELD:
572590
lastUndeployedTime = Instant.ofEpochMilli(parser.longValue());
573591
break;
592+
case GUARDRAILS_FIELD:
593+
guardrails = Guardrails.parse(parser);
594+
break;
574595
default:
575596
parser.skipChildren();
576597
break;
@@ -608,6 +629,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
608629
.isHidden(isHidden)
609630
.connector(connector)
610631
.connectorId(connectorId)
632+
.guardrails(guardrails)
611633
.build();
612634
}
613635

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package org.opensearch.ml.common.model;
2+
3+
import lombok.Builder;
4+
import lombok.EqualsAndHashCode;
5+
import lombok.Getter;
6+
import org.opensearch.core.common.io.stream.StreamInput;
7+
import org.opensearch.core.common.io.stream.StreamOutput;
8+
import org.opensearch.core.xcontent.ToXContentObject;
9+
import org.opensearch.core.xcontent.XContentBuilder;
10+
import org.opensearch.core.xcontent.XContentParser;
11+
12+
import java.io.IOException;
13+
import java.util.ArrayList;
14+
import java.util.List;
15+
16+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
17+
18+
@EqualsAndHashCode
19+
@Getter
20+
public class Guardrail implements ToXContentObject {
21+
public static final String STOP_WORDS_FIELD = "stop_words";
22+
public static final String REGEX_FIELD = "regex";
23+
24+
private List<StopWords> stopWords;
25+
private String[] regex;
26+
27+
@Builder(toBuilder = true)
28+
public Guardrail(List<StopWords> stopWords, String[] regex) {
29+
this.stopWords = stopWords;
30+
this.regex = regex;
31+
}
32+
33+
public Guardrail(StreamInput input) throws IOException {
34+
if (input.readBoolean()) {
35+
stopWords = new ArrayList<>();
36+
int size = input.readInt();
37+
for (int i=0; i<size; i++) {
38+
stopWords.add(new StopWords(input));
39+
}
40+
}
41+
regex = input.readStringArray();
42+
}
43+
44+
public void writeTo(StreamOutput out) throws IOException {
45+
if (stopWords != null && stopWords.size() > 0) {
46+
out.writeBoolean(true);
47+
out.writeInt(stopWords.size());
48+
for (StopWords e : stopWords) {
49+
e.writeTo(out);
50+
}
51+
} else {
52+
out.writeBoolean(false);
53+
}
54+
out.writeStringArray(regex);
55+
}
56+
57+
@Override
58+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
59+
builder.startObject();
60+
if (stopWords != null && stopWords.size() > 0) {
61+
builder.field(STOP_WORDS_FIELD, stopWords);
62+
}
63+
if (regex != null) {
64+
builder.field(REGEX_FIELD, regex);
65+
}
66+
builder.endObject();
67+
return builder;
68+
}
69+
70+
public static Guardrail parse(XContentParser parser) throws IOException {
71+
List<StopWords> stopWords = null;
72+
String[] regex = null;
73+
74+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
75+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
76+
String fieldName = parser.currentName();
77+
parser.nextToken();
78+
79+
switch (fieldName) {
80+
case STOP_WORDS_FIELD:
81+
stopWords = new ArrayList<>();
82+
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
83+
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
84+
stopWords.add(StopWords.parse(parser));
85+
}
86+
break;
87+
case REGEX_FIELD:
88+
regex = parser.list().toArray(new String[0]);
89+
break;
90+
default:
91+
parser.skipChildren();
92+
break;
93+
}
94+
}
95+
return Guardrail.builder()
96+
.stopWords(stopWords)
97+
.regex(regex)
98+
.build();
99+
}
100+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package org.opensearch.ml.common.model;
2+
3+
import lombok.Builder;
4+
import lombok.EqualsAndHashCode;
5+
import lombok.Getter;
6+
import org.opensearch.core.common.io.stream.StreamInput;
7+
import org.opensearch.core.common.io.stream.StreamOutput;
8+
import org.opensearch.core.xcontent.ToXContentObject;
9+
import org.opensearch.core.xcontent.XContentBuilder;
10+
import org.opensearch.core.xcontent.XContentParser;
11+
12+
import java.io.IOException;
13+
14+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
15+
16+
@EqualsAndHashCode
17+
@Getter
18+
public class Guardrails implements ToXContentObject {
19+
public static final String TYPE_FIELD = "type";
20+
public static final String ENGLISH_DETECTION_ENABLED_FIELD = "english_detection_enabled";
21+
public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail";
22+
public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail";
23+
24+
private String type;
25+
private Boolean engDetectionEnabled;
26+
private Guardrail inputGuardrail;
27+
private Guardrail outputGuardrail;
28+
29+
@Builder(toBuilder = true)
30+
public Guardrails(String type, Boolean engDetectionEnabled, Guardrail inputGuardrail, Guardrail outputGuardrail) {
31+
this.type = type;
32+
this.engDetectionEnabled = engDetectionEnabled;
33+
this.inputGuardrail = inputGuardrail;
34+
this.outputGuardrail = outputGuardrail;
35+
}
36+
37+
public Guardrails(StreamInput input) throws IOException {
38+
type = input.readString();
39+
engDetectionEnabled = input.readBoolean();
40+
if (input.readBoolean()) {
41+
inputGuardrail = new Guardrail(input);
42+
}
43+
if (input.readBoolean()) {
44+
outputGuardrail = new Guardrail(input);
45+
}
46+
}
47+
48+
public void writeTo(StreamOutput out) throws IOException {
49+
out.writeString(type);
50+
out.writeBoolean(engDetectionEnabled);
51+
if (inputGuardrail != null) {
52+
out.writeBoolean(true);
53+
inputGuardrail.writeTo(out);
54+
} else {
55+
out.writeBoolean(false);
56+
}
57+
if (outputGuardrail != null) {
58+
out.writeBoolean(true);
59+
outputGuardrail.writeTo(out);
60+
} else {
61+
out.writeBoolean(false);
62+
}
63+
}
64+
65+
@Override
66+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
67+
builder.startObject();
68+
if (type != null) {
69+
builder.field(TYPE_FIELD, type);
70+
}
71+
if (engDetectionEnabled != null) {
72+
builder.field(ENGLISH_DETECTION_ENABLED_FIELD, engDetectionEnabled);
73+
}
74+
if (inputGuardrail != null) {
75+
builder.field(INPUT_GUARDRAIL_FIELD, inputGuardrail);
76+
}
77+
if (outputGuardrail != null) {
78+
builder.field(OUTPUT_GUARDRAIL_FIELD, outputGuardrail);
79+
}
80+
builder.endObject();
81+
return builder;
82+
}
83+
84+
public static Guardrails parse(XContentParser parser) throws IOException {
85+
String type = null;
86+
Boolean engDetectionEnabled = null;
87+
Guardrail inputGuardrail = null;
88+
Guardrail outputGuardrail = null;
89+
90+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
91+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
92+
String fieldName = parser.currentName();
93+
parser.nextToken();
94+
95+
switch (fieldName) {
96+
case TYPE_FIELD:
97+
type = parser.text();
98+
break;
99+
case ENGLISH_DETECTION_ENABLED_FIELD:
100+
engDetectionEnabled = parser.booleanValue();
101+
break;
102+
case INPUT_GUARDRAIL_FIELD:
103+
inputGuardrail = Guardrail.parse(parser);
104+
break;
105+
case OUTPUT_GUARDRAIL_FIELD:
106+
outputGuardrail = Guardrail.parse(parser);
107+
break;
108+
default:
109+
parser.skipChildren();
110+
break;
111+
}
112+
}
113+
return Guardrails.builder()
114+
.type(type)
115+
.engDetectionEnabled(engDetectionEnabled)
116+
.inputGuardrail(inputGuardrail)
117+
.outputGuardrail(outputGuardrail)
118+
.build();
119+
}
120+
}

0 commit comments

Comments
 (0)