Skip to content

Commit 2a5ab75

Browse files
committed
update guardrails
Signed-off-by: Jing Zhang <[email protected]>
1 parent 28995fc commit 2a5ab75

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java

+25-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.core.xcontent.XContentBuilder;
1616
import org.opensearch.core.xcontent.XContentParser;
1717
import org.opensearch.ml.common.connector.Connector;
18+
import org.opensearch.ml.common.model.Guardrails;
1819
import org.opensearch.ml.common.model.MLModelConfig;
1920
import org.opensearch.ml.common.controller.MLRateLimiter;
2021
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
@@ -43,6 +44,7 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
4344
public static final String CONNECTOR_FIELD = "connector"; // optional
4445
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // passively set when sending update
4546
// request
47+
public static final String GUARDRAILS_FIELD = "guardrails";
4648

4749
@Getter
4850
private String modelId;
@@ -57,11 +59,12 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
5759
private String connectorId;
5860
private MLCreateConnectorInput connector;
5961
private Instant lastUpdateTime;
62+
private Guardrails guardrails;
6063

6164
@Builder(toBuilder = true)
6265
public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId,
6366
Boolean isEnabled, MLRateLimiter rateLimiter, MLModelConfig modelConfig,
64-
Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime) {
67+
Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime, Guardrails guardrails) {
6568
this.modelId = modelId;
6669
this.description = description;
6770
this.version = version;
@@ -74,6 +77,7 @@ public MLUpdateModelInput(String modelId, String description, String version, St
7477
this.connectorId = connectorId;
7578
this.connector = connector;
7679
this.lastUpdateTime = lastUpdateTime;
80+
this.guardrails = guardrails;
7781
}
7882

7983
public MLUpdateModelInput(StreamInput in) throws IOException {
@@ -97,6 +101,9 @@ public MLUpdateModelInput(StreamInput in) throws IOException {
97101
connector = new MLCreateConnectorInput(in);
98102
}
99103
lastUpdateTime = in.readOptionalInstant();
104+
if (in.readBoolean()) {
105+
this.guardrails = new Guardrails(in);
106+
}
100107
}
101108

102109
@Override
@@ -136,6 +143,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
136143
if (lastUpdateTime != null) {
137144
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
138145
}
146+
if (guardrails != null) {
147+
builder.field(GUARDRAILS_FIELD, guardrails);
148+
}
139149
builder.endObject();
140150
return builder;
141151
}
@@ -174,6 +184,9 @@ public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Pa
174184
if (lastUpdateTime != null) {
175185
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
176186
}
187+
if (guardrails != null) {
188+
builder.field(GUARDRAILS_FIELD, guardrails);
189+
}
177190
builder.endObject();
178191
return builder;
179192
}
@@ -212,6 +225,12 @@ public void writeTo(StreamOutput out) throws IOException {
212225
out.writeBoolean(false);
213226
}
214227
out.writeOptionalInstant(lastUpdateTime);
228+
if (guardrails != null) {
229+
out.writeBoolean(true);
230+
guardrails.writeTo(out);
231+
} else {
232+
out.writeBoolean(false);
233+
}
215234
}
216235

217236
public static MLUpdateModelInput parse(XContentParser parser) throws IOException {
@@ -227,6 +246,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
227246
String connectorId = null;
228247
MLCreateConnectorInput connector = null;
229248
Instant lastUpdateTime = null;
249+
Guardrails guardrails = null;
230250

231251
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
232252
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -257,6 +277,9 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
257277
case CONNECTOR_FIELD:
258278
connector = MLCreateConnectorInput.parse(parser, true);
259279
break;
280+
case GUARDRAILS_FIELD:
281+
guardrails = Guardrails.parse(parser);
282+
break;
260283
default:
261284
parser.skipChildren();
262285
break;
@@ -265,6 +288,6 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
265288
// Model ID can only be set through RestRequest. Model version can only be set
266289
// automatically.
267290
return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, rateLimiter,
268-
modelConfig, updatedConnector, connectorId, connector, lastUpdateTime);
291+
modelConfig, updatedConnector, connectorId, connector, lastUpdateTime, guardrails);
269292
}
270293
}

plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,13 @@ private void updateRemoteOrTextEmbeddingModel(
206206
String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null;
207207
boolean isModelDeployed = isModelDeployed(mlModel.getModelState());
208208
// This flag is used to decide if we need to re-deploy the predictor(model) when updating the model cache.
209-
// If one of the internal connector, stand-alone connector id, model quota flag, as well as the model rate limiter needs update, we
209+
// If one of the internal connector, stand-alone connector id, model quota flag, as well as the model rate limiter and guardrails
210+
// need update, we
210211
// need to perform a re-deploy.
211212
boolean isPredictorUpdate = (updateModelInput.getConnector() != null)
212213
|| (newConnectorId != null)
213-
|| !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled());
214+
|| !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled())
215+
|| (updateModelInput.getGuardrails() != null);
214216
if (MLRateLimiter.updateValidityPreCheck(mlModel.getRateLimiter(), updateModelInput.getRateLimiter())) {
215217
MLRateLimiter updatedRateLimiterConfig = MLRateLimiter.update(mlModel.getRateLimiter(), updateModelInput.getRateLimiter());
216218
updateModelInput.setRateLimiter(updatedRateLimiterConfig);

0 commit comments

Comments
 (0)