Skip to content

Commit

Permalink
Use workflow template for workflow model_dir (deepjavalibrary#1612)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk authored Mar 8, 2024
1 parent c63a22d commit 0790c81
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 26 deletions.
4 changes: 2 additions & 2 deletions serving/docs/workflow_templates.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ Example:
To add a new template, begin by creating the template JSON file.
This mostly matches the standard format of a [workflow](workflows.md).
However, your template can indicate variable sections of the template to be replaced.
This is done by prefixing the name to replace with a `$`.
So, a parameter `param` would replace the value `$param` within the template.
This is done by surrounding the name with curly braces.
So, a parameter `param` would replace the value `{param}` within the template.
This replacement is directly in place, so if your parameter is a string you will still have to surround it with quotation marks.

Then, you must register your new workflow template.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,19 @@ public static WorkflowDefinition parse(Path path) throws IOException {
* @throws IOException if read from uri failed
*/
public static WorkflowDefinition parse(String name, URI uri) throws IOException {
return parse(name, uri, null);
return parse(name, uri, new ConcurrentHashMap<>());
}

static WorkflowDefinition parse(String name, URI uri, Map<String, String> templateReplacements)
throws IOException {
String type = FilenameUtils.getFileExtension(Objects.requireNonNull(uri.toString()));

// Default model_dir template replacement
if (templateReplacements == null) {
templateReplacements = new ConcurrentHashMap<>();
}
templateReplacements.put("model_dir", getWorkflowDir(uri.toString()));

try (InputStream is = uri.toURL().openStream();
Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
WorkflowDefinition wd = parse(type, reader, templateReplacements);
Expand All @@ -133,7 +140,7 @@ private static WorkflowDefinition parse(
templateReplacements.entrySet()) {
l =
l.replace(
"$" + replacement.getKey(),
"{" + replacement.getKey() + "}",
replacement.getValue());
}
return l;
Expand Down Expand Up @@ -221,14 +228,12 @@ public static URI toWorkflowUri(String link) {
* @throws BadWorkflowException if the workflow could not be parsed successfully
*/
public Workflow toWorkflow() throws BadWorkflowException {
int pos = baseUri.lastIndexOf('/');
String workflowDir = baseUri.substring(0, pos);
String workflowDir = getWorkflowDir(baseUri);

if (models != null) {
for (Entry<String, ModelInfo<Input, Output>> emd : models.entrySet()) {
ModelInfo<Input, Output> md = emd.getValue();
md.setId(emd.getKey());
md.postWorkflowParsing(workflowDir);
}
}

Expand Down Expand Up @@ -266,6 +271,11 @@ public Workflow toWorkflow() throws BadWorkflowException {
return new Workflow(name, version, wpcs, expressions, configs, loadedFunctions);
}

private static String getWorkflowDir(String uri) {
int pos = uri.lastIndexOf('/');
return uri.substring(0, pos);
}

private static final class ModelDefinitionDeserializer
implements JsonDeserializer<ModelInfo<Input, Output>> {

Expand Down
10 changes: 5 additions & 5 deletions serving/src/main/resources/workflowTemplates/adapter.json
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
{
"name": "$adapter",
"name": "{adapter}",
"version": "0.1",
"models": {
"m": "$model"
"m": "{model}"
},
"configs": {
"adapters": {
"$adapter": {
"{adapter}": {
"model": "m",
"src": "$url"
"src": "{url}"
}
}
},
"workflow": {
"out": ["adapter", "$adapter", "in"]
"out": ["adapter", "{adapter}", "in"]
}
}
14 changes: 0 additions & 14 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;

import com.google.gson.JsonParseException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -194,18 +192,6 @@ public ModelInfo(
dimension = new Dimension("Model", id);
}

/**
* Performs post workflow parsing initialization.
*
* @param workflowDir the workflow parent directory
*/
public void postWorkflowParsing(String workflowDir) {
if (modelUrl == null) {
throw new JsonParseException("modelUrl is required in workflow definition.");
}
modelUrl = modelUrl.replaceAll("\\{model_dir}", workflowDir);
}

/**
* Returns the properties of the model.
*
Expand Down

0 comments on commit 0790c81

Please sign in to comment.