15
15
import org .opensearch .core .xcontent .XContentBuilder ;
16
16
import org .opensearch .core .xcontent .XContentParser ;
17
17
import org .opensearch .ml .common .connector .Connector ;
18
+ import org .opensearch .ml .common .model .Guardrails ;
18
19
import org .opensearch .ml .common .model .MLModelConfig ;
19
20
import org .opensearch .ml .common .controller .MLRateLimiter ;
20
21
import org .opensearch .ml .common .model .TextEmbeddingModelConfig ;
@@ -43,6 +44,7 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
43
44
public static final String CONNECTOR_FIELD = "connector" ; // optional
44
45
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time" ; // passively set when sending update
45
46
// request
47
+ public static final String GUARDRAILS_FIELD = "guardrails" ;
46
48
47
49
@ Getter
48
50
private String modelId ;
@@ -57,11 +59,12 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable {
57
59
private String connectorId ;
58
60
private MLCreateConnectorInput connector ;
59
61
private Instant lastUpdateTime ;
62
+ private Guardrails guardrails ;
60
63
61
64
@ Builder (toBuilder = true )
62
65
public MLUpdateModelInput (String modelId , String description , String version , String name , String modelGroupId ,
63
66
Boolean isEnabled , MLRateLimiter rateLimiter , MLModelConfig modelConfig ,
64
- Connector updatedConnector , String connectorId , MLCreateConnectorInput connector , Instant lastUpdateTime ) {
67
+ Connector updatedConnector , String connectorId , MLCreateConnectorInput connector , Instant lastUpdateTime , Guardrails guardrails ) {
65
68
this .modelId = modelId ;
66
69
this .description = description ;
67
70
this .version = version ;
@@ -74,6 +77,7 @@ public MLUpdateModelInput(String modelId, String description, String version, St
74
77
this .connectorId = connectorId ;
75
78
this .connector = connector ;
76
79
this .lastUpdateTime = lastUpdateTime ;
80
+ this .guardrails = guardrails ;
77
81
}
78
82
79
83
public MLUpdateModelInput (StreamInput in ) throws IOException {
@@ -97,6 +101,9 @@ public MLUpdateModelInput(StreamInput in) throws IOException {
97
101
connector = new MLCreateConnectorInput (in );
98
102
}
99
103
lastUpdateTime = in .readOptionalInstant ();
104
+ if (in .readBoolean ()) {
105
+ this .guardrails = new Guardrails (in );
106
+ }
100
107
}
101
108
102
109
@ Override
@@ -136,6 +143,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
136
143
if (lastUpdateTime != null ) {
137
144
builder .field (LAST_UPDATED_TIME_FIELD , lastUpdateTime .toEpochMilli ());
138
145
}
146
+ if (guardrails != null ) {
147
+ builder .field (GUARDRAILS_FIELD , guardrails );
148
+ }
139
149
builder .endObject ();
140
150
return builder ;
141
151
}
@@ -174,6 +184,9 @@ public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Pa
174
184
if (lastUpdateTime != null ) {
175
185
builder .field (LAST_UPDATED_TIME_FIELD , lastUpdateTime .toEpochMilli ());
176
186
}
187
+ if (guardrails != null ) {
188
+ builder .field (GUARDRAILS_FIELD , guardrails );
189
+ }
177
190
builder .endObject ();
178
191
return builder ;
179
192
}
@@ -212,6 +225,12 @@ public void writeTo(StreamOutput out) throws IOException {
212
225
out .writeBoolean (false );
213
226
}
214
227
out .writeOptionalInstant (lastUpdateTime );
228
+ if (guardrails != null ) {
229
+ out .writeBoolean (true );
230
+ guardrails .writeTo (out );
231
+ } else {
232
+ out .writeBoolean (false );
233
+ }
215
234
}
216
235
217
236
public static MLUpdateModelInput parse (XContentParser parser ) throws IOException {
@@ -227,6 +246,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
227
246
String connectorId = null ;
228
247
MLCreateConnectorInput connector = null ;
229
248
Instant lastUpdateTime = null ;
249
+ Guardrails guardrails = null ;
230
250
231
251
ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .currentToken (), parser );
232
252
while (parser .nextToken () != XContentParser .Token .END_OBJECT ) {
@@ -257,6 +277,9 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
257
277
case CONNECTOR_FIELD :
258
278
connector = MLCreateConnectorInput .parse (parser , true );
259
279
break ;
280
+ case GUARDRAILS_FIELD :
281
+ guardrails = Guardrails .parse (parser );
282
+ break ;
260
283
default :
261
284
parser .skipChildren ();
262
285
break ;
@@ -265,6 +288,6 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException
265
288
// Model ID can only be set through RestRequest. Model version can only be set
266
289
// automatically.
267
290
return new MLUpdateModelInput (modelId , description , version , name , modelGroupId , isEnabled , rateLimiter ,
268
- modelConfig , updatedConnector , connectorId , connector , lastUpdateTime );
291
+ modelConfig , updatedConnector , connectorId , connector , lastUpdateTime , guardrails );
269
292
}
270
293
}
0 commit comments