Skip to content

Commit 0fd424b

Browse files
committed
guardrails model support
Signed-off-by: Jing Zhang <[email protected]>
1 parent 9b072c4 commit 0fd424b

File tree

11 files changed

+827
-537
lines changed

11 files changed

+827
-537
lines changed

common/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies {
2525
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
2626
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
2727
compileOnly group: 'org.json', name: 'json', version: '20231013'
28+
implementation 'com.jayway.jsonpath:json-path:2.9.0'
2829

2930
implementation('com.google.guava:guava:32.1.2-jre') {
3031
exclude group: 'com.google.guava', module: 'failureaccess'
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,17 @@
1-
/*
2-
* Copyright OpenSearch Contributors
3-
* SPDX-License-Identifier: Apache-2.0
4-
*/
5-
61
package org.opensearch.ml.common.model;
72

8-
import lombok.Builder;
9-
import lombok.EqualsAndHashCode;
10-
import lombok.Getter;
11-
import org.opensearch.core.common.io.stream.StreamInput;
3+
import org.opensearch.client.Client;
124
import org.opensearch.core.common.io.stream.StreamOutput;
5+
import org.opensearch.core.xcontent.NamedXContentRegistry;
136
import org.opensearch.core.xcontent.ToXContentObject;
14-
import org.opensearch.core.xcontent.XContentBuilder;
15-
import org.opensearch.core.xcontent.XContentParser;
167

178
import java.io.IOException;
18-
import java.util.ArrayList;
19-
import java.util.List;
20-
21-
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
22-
23-
@EqualsAndHashCode
24-
@Getter
25-
public class Guardrail implements ToXContentObject {
26-
public static final String STOP_WORDS_FIELD = "stop_words";
27-
public static final String REGEX_FIELD = "regex";
28-
29-
private List<StopWords> stopWords;
30-
private String[] regex;
31-
32-
@Builder(toBuilder = true)
33-
public Guardrail(List<StopWords> stopWords, String[] regex) {
34-
this.stopWords = stopWords;
35-
this.regex = regex;
36-
}
37-
38-
public Guardrail(StreamInput input) throws IOException {
39-
if (input.readBoolean()) {
40-
stopWords = new ArrayList<>();
41-
int size = input.readInt();
42-
for (int i=0; i<size; i++) {
43-
stopWords.add(new StopWords(input));
44-
}
45-
}
46-
regex = input.readStringArray();
47-
}
48-
49-
public void writeTo(StreamOutput out) throws IOException {
50-
if (stopWords != null && stopWords.size() > 0) {
51-
out.writeBoolean(true);
52-
out.writeInt(stopWords.size());
53-
for (StopWords e : stopWords) {
54-
e.writeTo(out);
55-
}
56-
} else {
57-
out.writeBoolean(false);
58-
}
59-
out.writeStringArray(regex);
60-
}
619

62-
@Override
63-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
64-
builder.startObject();
65-
if (stopWords != null && stopWords.size() > 0) {
66-
builder.field(STOP_WORDS_FIELD, stopWords);
67-
}
68-
if (regex != null) {
69-
builder.field(REGEX_FIELD, regex);
70-
}
71-
builder.endObject();
72-
return builder;
73-
}
10+
public abstract class Guardrail implements ToXContentObject {
7411

75-
public static Guardrail parse(XContentParser parser) throws IOException {
76-
List<StopWords> stopWords = null;
77-
String[] regex = null;
12+
public abstract void writeTo(StreamOutput out) throws IOException;
7813

79-
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
80-
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
81-
String fieldName = parser.currentName();
82-
parser.nextToken();
14+
public abstract Boolean validate(String input);
8315

84-
switch (fieldName) {
85-
case STOP_WORDS_FIELD:
86-
stopWords = new ArrayList<>();
87-
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
88-
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
89-
stopWords.add(StopWords.parse(parser));
90-
}
91-
break;
92-
case REGEX_FIELD:
93-
regex = parser.list().toArray(new String[0]);
94-
break;
95-
default:
96-
parser.skipChildren();
97-
break;
98-
}
99-
}
100-
return Guardrail.builder()
101-
.stopWords(stopWords)
102-
.regex(regex)
103-
.build();
104-
}
16+
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
10517
}

common/src/main/java/org/opensearch/ml/common/model/Guardrails.java

+53-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import org.opensearch.core.xcontent.XContentParser;
1616

1717
import java.io.IOException;
18+
import java.util.Map;
19+
import java.util.Set;
1820

1921
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
2022

@@ -39,10 +41,26 @@ public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardra
3941
public Guardrails(StreamInput input) throws IOException {
4042
type = input.readString();
4143
if (input.readBoolean()) {
42-
inputGuardrail = new Guardrail(input);
44+
switch (type) {
45+
case "local_regex":
46+
inputGuardrail = new LocalRegexGuardrail(input);
47+
break;
48+
case "model":
49+
break;
50+
default:
51+
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
52+
}
4353
}
4454
if (input.readBoolean()) {
45-
outputGuardrail = new Guardrail(input);
55+
switch (type) {
56+
case "local_regex":
57+
outputGuardrail = new LocalRegexGuardrail(input);
58+
break;
59+
case "model":
60+
break;
61+
default:
62+
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
63+
}
4664
}
4765
}
4866

@@ -80,8 +98,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
8098

8199
public static Guardrails parse(XContentParser parser) throws IOException {
82100
String type = null;
83-
Guardrail inputGuardrail = null;
84-
Guardrail outputGuardrail = null;
101+
Map<String, Object> inputGuardrailMap = null;
102+
Map<String, Object> outputGuardrailMap = null;
85103

86104
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
87105
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -93,20 +111,47 @@ public static Guardrails parse(XContentParser parser) throws IOException {
93111
type = parser.text();
94112
break;
95113
case INPUT_GUARDRAIL_FIELD:
96-
inputGuardrail = Guardrail.parse(parser);
114+
inputGuardrailMap = parser.map();
97115
break;
98116
case OUTPUT_GUARDRAIL_FIELD:
99-
outputGuardrail = Guardrail.parse(parser);
117+
outputGuardrailMap = parser.map();
100118
break;
101119
default:
102120
parser.skipChildren();
103121
break;
104122
}
105123
}
124+
if (!validateType(type)) {
125+
throw new IllegalArgumentException("The type of guardrails is required, can not be null.");
126+
}
127+
106128
return Guardrails.builder()
107129
.type(type)
108-
.inputGuardrail(inputGuardrail)
109-
.outputGuardrail(outputGuardrail)
130+
.inputGuardrail(createGuardrail(type, inputGuardrailMap))
131+
.outputGuardrail(createGuardrail(type, outputGuardrailMap))
110132
.build();
111133
}
134+
135+
private static Boolean validateType(String type) {
136+
Set<String> types = Set.of("local_regex", "model");
137+
if (types.contains(type)) {
138+
return true;
139+
}
140+
return false;
141+
}
142+
143+
private static Guardrail createGuardrail(String type, Map<String, Object> params) {
144+
if (params == null || params.isEmpty()) {
145+
return null;
146+
}
147+
148+
switch (type) {
149+
case "local_regex":
150+
return new LocalRegexGuardrail(params);
151+
case "model":
152+
return new ModelGuardrail(params);
153+
default:
154+
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
155+
}
156+
}
112157
}

0 commit comments

Comments
 (0)