Skip to content

Commit

Permalink
Workflow Templates (deepjavalibrary#1594)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk authored Mar 6, 2024
1 parent ce4afea commit 22acc49
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 8 deletions.
1 change: 1 addition & 0 deletions serving/docs/adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ DEL models/{modelName}/adapters/{adapterName} - Delete adapter

The final option for working with adapters is through the [DJL Serving workflows system](workflows.md).
You can use the adapter `WorkflowFunction` to create and call an adapted version of a model within the workflow.
For the simple model + adapter case, you can also directly use the adapter [workflow template](workflow_templates.md).
With our workflows, multiple workflows sharing models will be de-duplicated.
So, the effect of having multiple adapters can be easily made with having one workflow for each adapter.
This system can be used on [Amazon SageMaker Multi-Model Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html).
Expand Down
5 changes: 4 additions & 1 deletion serving/docs/management_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ curl -v -X POST "http://localhost:8080/models?url=https%3A%2F%2Fresources.djl.ai

`POST /workflows`

* url - Workflow url.
* url - Workflow url
* template - A workflow template to use
* engine - the name of engine to load the model. DJL will try to infer engine if not specified.
* device - the device to load the model. DJL will pick optimal device if not specified, the value device can be:
* CPU device: cpu or simply -1
Expand All @@ -82,6 +83,8 @@ curl -v -X POST "http://localhost:8080/models?url=https%3A%2F%2Fresources.djl.ai
* max_worker - the maximum number of worker processes. The default is the same as the setting for `min_worker`.
* synchronous - if the creation of worker is synchronous. The default value is true.

Either a url or [template](workflow_templates.md) is required.

```bash
curl -X POST "http://localhost:8080/workflows?url=https%3A%2F%2Fresources.djl.ai%2Ftest-workflows%2Fmlp.zip"

Expand Down
42 changes: 42 additions & 0 deletions serving/docs/workflow_templates.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Workflow Templates

Workflow templates are a tool to make it easier to register similar workflows.

## Registering a workflow with a template

To register a new workflow using a template, the template must first be registered (see below).
Then, you can use the management API to register the workflow.
To do so, call register while specifying the template using the `template` query parameter.
Then, specify all template replacement values using additional query parameters.

## Available templates

### adapter

The adapter template is used for a simple workflow that reflects an [adapter model](adapters.md).

Parameters:

- template=`adapter`
- adapter - The adapter name
- url - The adapter URL
- model - The model URL

Example:

`POST /workflows?template=adapter&adapter={adapterName}&url={adapterUrl}&model={modelUrl}`

## Adding new templates

To add a new template, begin by creating the template JSON file.
This mostly matches the standard format of a [workflow](workflows.md).
However, your template can indicate variable sections of the template to be replaced.
This is done by prefixing the name to replace with a `$`.
So, a parameter `param` would replace the value `$param` within the template.
This replacement is directly in place, so if your parameter is a string you will still have to surround it with quotation marks.

Then, you must register your new workflow template.
There are two options for doing this.
First, you can add it to the classpath as a resource with path `workflowTemplates/{workflowTemplateName}.json`.
Alternatively, you can register it from a plugin by calling `WorkflowTemplates.register(..)`.
Once this is done, you will be able to begin creating workflows with your template.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class LoadModelRequest {

static final String URL = "url";
static final String TEMPLATE = "template";
static final String DEVICE = "device";
static final String MAX_WORKER = "max_worker";
static final String MIN_WORKER = "min_worker";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import ai.djl.serving.workflow.BadWorkflowException;
import ai.djl.serving.workflow.Workflow;
import ai.djl.serving.workflow.WorkflowDefinition;
import ai.djl.serving.workflow.WorkflowTemplates;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
Expand All @@ -47,6 +49,7 @@
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/** A class handling inbound HTTP requests to the management API. */
public class ManagementRequestHandler extends HttpRequestHandler {
Expand Down Expand Up @@ -258,8 +261,9 @@ private void handleRegisterModel(
private void handleRegisterWorkflow(
final ChannelHandlerContext ctx, QueryStringDecoder decoder) {
String workflowUrl = NettyUtils.getParameter(decoder, LoadModelRequest.URL, null);
if (workflowUrl == null) {
throw new BadRequestException("Parameter url is required.");
String workflowTemplate = NettyUtils.getParameter(decoder, LoadModelRequest.TEMPLATE, null);
if (workflowUrl == null && workflowTemplate == null) {
throw new BadRequestException("Either parameter url or template is required.");
}

boolean synchronous =
Expand All @@ -269,8 +273,20 @@ private void handleRegisterWorkflow(
try {
final ModelManager modelManager = ModelManager.getInstance();

URI uri = URI.create(workflowUrl);
Workflow workflow = WorkflowDefinition.parse(null, uri).toWorkflow();
Workflow workflow;
if (workflowTemplate != null) { // Workflow from template
Map<String, String> templateReplacements = // NOPMD
decoder.parameters().entrySet().stream()
.filter(e -> e.getValue().size() == 1)
.map(e -> new Pair<>(e.getKey(), e.getValue().get(0)))
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
workflow =
WorkflowTemplates.template(workflowTemplate, templateReplacements)
.toWorkflow();
} else { // Workflow from URL
URI uri = URI.create(workflowUrl);
workflow = WorkflowDefinition.parse(null, uri).toWorkflow();
}
String workflowName = workflow.getName();

CompletableFuture<Void> f =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
import com.google.gson.JsonParseException;
import com.google.gson.annotations.SerializedName;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringReader;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
Expand All @@ -49,6 +51,7 @@
import java.util.Map.Entry;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
* This class is for parsing the JSON or YAML definition for a {@link Workflow}.
Expand Down Expand Up @@ -93,16 +96,21 @@ public static WorkflowDefinition parse(Path path) throws IOException {
/**
* Parses a new {@link WorkflowDefinition} from an input stream.
*
* @param name the workflow name
* @param name the workflow name (null for no name)
* @param uri the uri of the file
* @return the parsed {@link WorkflowDefinition}
* @throws IOException if read from uri failed
*/
public static WorkflowDefinition parse(String name, URI uri) throws IOException {
return parse(name, uri, null);
}

static WorkflowDefinition parse(String name, URI uri, Map<String, String> templateReplacements)
throws IOException {
String type = FilenameUtils.getFileExtension(Objects.requireNonNull(uri.toString()));
try (InputStream is = uri.toURL().openStream();
Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
WorkflowDefinition wd = parse(type, reader);
WorkflowDefinition wd = parse(type, reader, templateReplacements);
if (name != null) {
wd.name = name;
}
Expand All @@ -113,7 +121,26 @@ public static WorkflowDefinition parse(String name, URI uri) throws IOException
}
}

private static WorkflowDefinition parse(String type, Reader input) {
private static WorkflowDefinition parse(
String type, Reader input, Map<String, String> templateReplacements) {
if (templateReplacements != null) {
String updatedInput =
new BufferedReader(input)
.lines()
.map(
l -> {
for (Entry<String, String> replacement :
templateReplacements.entrySet()) {
l =
l.replace(
"$" + replacement.getKey(),
replacement.getValue());
}
return l;
})
.collect(Collectors.joining("\n"));
input = new StringReader(updatedInput);
}
if ("yml".equalsIgnoreCase(type) || "yaml".equalsIgnoreCase(type)) {
try {
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.workflow;

import ai.djl.util.ClassLoaderUtils;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** A class for managing and using {@link WorkflowDefinition} templates. */
public final class WorkflowTemplates {

private static final Map<String, URI> TEMPLATES = new ConcurrentHashMap<>();

private WorkflowTemplates() {}

/**
* Registers a new workflow template.
*
* @param name the template name
* @param template the template location
*/
public static void register(String name, URI template) {
TEMPLATES.put(name, template);
}

/**
* Constructs a {@link WorkflowDefinition} using a registered template.
*
* @param templateName the template name
* @param templateReplacements a map of replacements to be applied to the template
* @return the new {@link WorkflowDefinition} based off the template
* @throws IOException if it fails to load the template file for parsing
*/
public static WorkflowDefinition template(
String templateName, Map<String, String> templateReplacements) throws IOException {
URI uri = TEMPLATES.get(templateName);

if (uri == null) {
URL fromResource =
ClassLoaderUtils.getResource("workflowTemplates/" + templateName + ".json");
if (fromResource != null) {
try {
uri = fromResource.toURI();
} catch (URISyntaxException ignored) {
}
}
}

if (uri == null) {
throw new IllegalArgumentException(
"The workflow template " + templateName + " could not be found");
}

return WorkflowDefinition.parse(null, uri, templateReplacements);
}
}
18 changes: 18 additions & 0 deletions serving/src/main/resources/workflowTemplates/adapter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"name": "$adapter",
"version": "0.1",
"models": {
"m": "$model"
},
"configs": {
"adapters": {
"$adapter": {
"model": "m",
"src": "$url"
}
}
},
"workflow": {
"out": ["adapter", "$adapter", "in"]
}
}
14 changes: 14 additions & 0 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ public void testAdapterWorkflows()

testAdapterWorkflowPredict(channel, "adapter1", "a1");
testAdapterWorkflowPredict(channel, "adapter2", "a2");
testRegisterAdapterWorkflowTemplate(channel);

channel.close().sync();

Expand Down Expand Up @@ -925,6 +926,19 @@ private void testAdapterWorkflowPredict(Channel channel, String workflow, String
assertEquals(result, adapter + "testAWP");
}

private void testRegisterAdapterWorkflowTemplate(Channel channel) throws InterruptedException {
logTestFunction();
String adapterUrl = "dummy";
String modelUrl = URLEncoder.encode("src/test/resources/adaptecho", StandardCharsets.UTF_8);

String url =
"/workflows?template=adapter&adapter=a&url=" + adapterUrl + "&model=" + modelUrl;
request(channel, HttpMethod.POST, url);

StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
assertEquals(resp.getStatus(), "Workflow \"a\" registered.");
}

private void testAdapterInvoke(Channel channel) throws InterruptedException {
logTestFunction();
String url = "/invocations?model_name=adaptecho&adapter=adaptable";
Expand Down

0 comments on commit 22acc49

Please sign in to comment.