aLoRA Support #15327
Conversation
…ation_string Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
This is the preferred method in PEFT, which is the source of ground truth: https://github.com/huggingface/peft/pull/2609/files#diff-13380145401d203d5935c5189dd09879f990b81aa63e8e3aaff8ce9110333f0e Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
This does not yet identify the invocation tokens and apply the lora adapter only afterwards, but it does seem to produce correct results if the invocation tokens are at the beginning of the uncached input. Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
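As an aside, a minimal sketch of what identifying the invocation tokens could look like is below; this is purely illustrative (the function name and token values are made up), not the PR's implementation.

```python
# Illustrative sketch only: find where an aLoRA invocation sequence starts
# inside a tokenized prompt so the adapter can be applied from that point on.

def find_invocation_start(prompt_tokens: list[int], invocation_tokens: list[int]) -> int:
    """Return the start index of the last occurrence of the invocation
    sequence in the prompt, or -1 if it is not present."""
    if not invocation_tokens or len(invocation_tokens) > len(prompt_tokens):
        return -1
    # Scan from the end so the most recent invocation wins
    for start in range(len(prompt_tokens) - len(invocation_tokens), -1, -1):
        if prompt_tokens[start:start + len(invocation_tokens)] == invocation_tokens:
            return start
    return -1

if __name__ == "__main__":
    prompt = [1, 5, 9, 2, 7, 7, 3, 4]   # made-up token ids
    invocation = [7, 7, 3]              # made-up invocation sequence
    print(find_invocation_start(prompt, invocation))  # -> 4
```

Everything before the returned index can keep using base-model cache, while the adapter only needs to be active from the invocation start onward.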
One interesting update: for the specific adapters I'm using to test here, I've updated my sniff test script above to use client-side template expansion and the raw /v1/completions endpoint. NOTE: This is a property of these adapters and not of aLoRA in general; theoretically, an adapter could be trained to invoke on the full assistant generation prompt.
Add the following to your request to remove the assistant generation prompt: "add_generation_prompt": false
Ah, yep, that will definitely help, but it won't eliminate the
Ah, didn't notice that, I suppose that's just because the template doesn't properly handle unknown roles?
Yeah, the real issue is that it was trained to act like the generation prompt, so the activation sequence is intentionally an incomplete turn, but with a different role.
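For reference, a request that follows the suggestion above might look roughly like the sketch below; the endpoint and lora payload mirror the sniff-test script later in the thread, and passing the flag at the top level of the request body is an assumption here rather than something this PR defines.

```python
import requests

# Hedged example: same chat request shape as the sniff-test script, but asking
# the template expansion not to append the assistant generation prompt.
resp = requests.post("http://localhost:8081/v1/chat/completions", json={
    "model": "unused",
    "messages": [{"role": "user", "content": "Who does Gabe work for?"}],
    "add_generation_prompt": False,
    "temperature": 0.0,
    "lora": [
        {"id": 0, "scale": 0.0},  # alora
        {"id": 1, "scale": 0.0},  # lora
    ],
})
print(resp.json()["choices"][0]["message"]["content"])
```

This only affects how the chat template is expanded; it does not change how the adapter itself is activated.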
This currently limits usage to a single enabled alora per slot. Multiple aloras with different invocation sequences would be possible, but it would require a more complex integration of the adapter toggling and is not really a well-studied case for alora, since it's unclear if one alora can reuse cache from a previous prefill computed with a different alora. Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
This is a bit of an edge case, but theoretically a user could try the same query with the alora disabled (just using the base model), then retry with the alora. The cached tokens from the first pass should be invalid. Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
The solution is to only fill up to the token before the invocation start in the batch if there are any tokens to be prefilled between those pulled from cache and the invocation start. When this is detected, the alora is temporarily disabled with a scale of 0.0, then immediately re-enabled after it has been initialized for the internal graph. Since the batch does not complete the prompt tokens, the remaining prompt tokens are handled in the next task, pulling all of the non-alora tokens from cache and proceeding with prefill for the alora tokens. Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
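A rough sketch of the control flow described in that commit message is below; this is illustrative Python pseudologic with made-up names, not the actual tools/server changes.

```python
# Hypothetical sketch of the prefill-splitting logic described above; the
# llama.cpp internals differ, this only illustrates the control flow.

def plan_prefill(prompt_tokens: list[int], n_cached: int, invocation_start: int):
    """Decide which tokens to prefill in this pass and what scale the aLoRA
    should have while doing so. Returns (tokens_for_this_batch, alora_scale)."""
    uncached = prompt_tokens[n_cached:]
    if invocation_start < 0 or invocation_start <= n_cached:
        # No invocation in the prompt, or it is at/behind the cache boundary:
        # prefill everything with the adapter active.
        return uncached, 1.0
    # There are uncached tokens *before* the invocation start: prefill only
    # those with the adapter scale forced to 0.0. The invocation tokens and
    # everything after them are left for the next pass, which pulls the
    # just-computed tokens from cache and runs the rest with the aLoRA on.
    return prompt_tokens[n_cached:invocation_start], 0.0

# Example: 10-token prompt, 3 tokens cached, invocation starts at index 6
batch, scale = plan_prefill(list(range(10)), n_cached=3, invocation_start=6)
print(batch, scale)  # -> [3, 4, 5] 0.0
```

The key point is that the tokens before the invocation start are computed with the adapter effectively off, so they remain valid base-model cache entries for later requests.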
Update
I've now added support for correctly applying the adapter only to the invocation sequence and the tokens that follow it.
Testing
I've got a few tweaks to my test script that allow it to simulate these conditions:
uq-req.py
import json
import time
from transformers import AutoTokenizer
import requests
tokenizer = AutoTokenizer.from_pretrained("/Users/ghart/models/granite-3.2-8b-instruct")
url = "http://localhost:8081"
documents = [
{"text": "My name is Gabe"},
{"text": "I work for IBM"}
]
messages = [{"role": "user", "content": "Who does Gabe work for?"}]
adapter_message = {
"role": "certainty",
"content": ""
}
# Run base messages
print("----")
start = time.time()
resp = requests.post(f"{url}/v1/chat/completions", json={
"model": "unused",
"messages": messages,
"chat_template_kwargs": {
"documents": documents,
},
"temperature": 0.0,
"lora": [
# alora
{"id": 0, "scale": 0.0},
# lora
{"id": 1, "scale": 0.0},
],
})
end = time.time()
assistant_resp = resp.json()["choices"][0]["message"]
print(f"ASSISTANT RESPONSE ({end-start}s):")
print(assistant_resp["content"])
# UNCOMMENT this to extend the assistant's response so that it isn't cached
"""
assistant_resp["content"] = assistant_resp["content"] + "\nRespect my authority!"
"""
# Create the serialized version as a string so we can append the right prompt
messages.append(assistant_resp)
raw_prompt = tokenizer.apply_chat_template(messages, documents=documents, tokenize=False)
uq_prompt = raw_prompt + "<|start_of_role|>certainty<|end_of_role|>"
# Run with both adapters disabled
# UNCOMMENT this to exercise the case where the invocation string itself has
# been cached without the adapter
"""
print("----")
start = time.time()
resp = requests.post(f"{url}/v1/completions", json={
"model": "unused",
"prompt": uq_prompt,
"temperature": 0.0,
"lora": [
# alora
{"id": 0, "scale": 0.0},
# lora
{"id": 1, "scale": 0.0},
],
})
end = time.time()
js = resp.json()
uq_resp = js["choices"][0]["text"]
print(f"UQ RESPONSE w/out adapters ({end-start}s)")
print(uq_resp)
print(">>")
print(json.dumps(js["usage"], indent=2))
print(json.dumps(js["timings"], indent=2))
"""
# Run with the adapter and the prompt for UQ with the alora enabled
print("----")
start = time.time()
resp = requests.post(f"{url}/v1/completions", json={
"model": "unused",
"prompt": uq_prompt,
"temperature": 0.0,
"max_tokens": 100,
"lora": [
# alora
{"id": 0, "scale": 1.0},
# lora
{"id": 1, "scale": 0.0},
],
})
end = time.time()
js = resp.json()
uq_resp = js["choices"][0]["text"]
print(f"UQ RESPONSE w/ aLoRA ({end-start}s)")
print(uq_resp)
print(">>")
print(json.dumps(js["usage"], indent=2))
print(json.dumps(js["timings"], indent=2))
# Run with the adapter and the prompt for UQ with the lora enabled
print("----")
start = time.time()
resp = requests.post(f"{url}/v1/completions", json={
"model": "unused",
"prompt": uq_prompt,
"temperature": 0.0,
"max_tokens": 100,
"lora": [
# alora
{"id": 0, "scale": 0.0},
# lora
{"id": 1, "scale": 1.0},
],
})
end = time.time()
js = resp.json()
uq_resp = js["choices"][0]["text"]
print(f"UQ RESPONSE w/ LoRA ({end-start}s)")
print(uq_resp)
print(">>")
print(json.dumps(js["usage"], indent=2))
print(json.dumps(js["timings"], indent=2))

Don't use cached invocation sequence from base model
This simulates the case where the user ran the invocation sequence through the base model without the adapter and those tokens are cached (uncomment starting at line 57).
Don't use adapter for uncached tokens before invocation sequence
This simulates the case where, for some reason, there are additional tokens not pulled from cache that come before the invocation sequence (uncomment line 45).
Too much python 🤦 Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
This was the cause of the inconsistent results from the dummy test script with and without the turn that runs the prompt without the adapter before running it with the adapter. Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
I've now extended this to test with multiple adapters.
Adapters are converted using convert_lora_to_gguf.py.

Boot with adapters
./bin/llama-server \
-m ~/models/granite-3.2-8b-instruct/granite-3.2-8B-instruct-F16.gguf \
--lora ~/models/granite-3.2-8b-alora-uncertainty/granite-3.2-8B-alora-uncertainty-F16-LoRA.gguf \
--lora ~/models/granite-3.2-8b-alora-rag-answerability-prediction/granite-3.2-8B-alora-rag-answerability-prediction-F16-LoRA.gguf \
--port 8081 \
--jinja \
--reasoning-budget 0

Test Script
(sorry, it requires my personal logging framework just 'cuz 😉 ...)
alora-chat.py
#!/usr/bin/env python
"""
This is a simple implementation of an interactive chat that leverages several
aLoRA adapters during the flow
"""
# Standard
import argparse
import os
# First Party
import alog
# Third Party
import requests
log = alog.use_channel("MAIN")
def make_document(i: int, doc: str) -> dict:
"""Make a document dict from the given doc as either text or a path"""
log.info("Adding document: %s", doc)
if os.path.exists(doc):
with open(doc, "r") as handle:
return {"text": handle.read(), "doc_id": i, "title": doc}
return{"text": doc, "doc_id": i}
def make_lora_req(adapter_ids: list[int], loras: list[int]) -> list[dict]:
return [
{"id": i, "scale": 1.0 if i in loras else 0.0}
for i in adapter_ids
]
def make_chat_req(messages: list[dict], documents: list[dict], adapter_ids: list[int], loras: list[int]) -> dict:
return {
"messages": messages,
"chat_template_kwargs": {
"documents": documents,
},
"temperature": 0.0,
"lora": make_lora_req(adapter_ids, loras),
}
def make_completion_req(prompt: str, documents: list[dict], adapter_ids: list[int], loras: list[int], **kwargs) -> dict:
kwargs.update({
"prompt": prompt,
"chat_template_kwargs": {
"documents": documents,
},
"temperature": 0.0,
"lora": make_lora_req(adapter_ids, loras),
})
return kwargs
def run_main_loop(host: str, documents: list[dict], uq_id: int, ans_id: int, adapter_ids: list[int]):
"""Run the main loop with questions"""
help_cmd = "/?"
doc_cmd = "/doc"
reset_cmd = "/reset"
quit_cmd = "/quit"
doc_pfx = f"{doc_cmd} "
def print_help():
print("Commands:")
print(f"{help_cmd}: Print help")
print(f"{doc_cmd}: Add a document")
print(f"{reset_cmd}: Reset the chat history")
print(f"{quit_cmd}: Quit")
messages = []
print_help()
while True:
inp = input("?> ").strip()
if inp == quit_cmd:
break
if not inp:
continue
if inp == help_cmd:
print_help()
continue
if inp == reset_cmd:
messages.clear()
continue
if inp.startswith(doc_pfx):
doc = inp[len(doc_pfx):].lstrip()
documents.append(make_document(len(documents), doc))
continue
# Apply the chat template with the user query
user_message = {"role": "user", "content": inp}
resp = requests.post(f"{host}/apply-template", json=make_chat_req(messages + [user_message], documents, adapter_ids, []))
resp.raise_for_status()
formatted_prompt = resp.json()["prompt"]
log.debug4("Formatted prompt: %s", formatted_prompt)
# Run the Answerability query
ans_prompt = formatted_prompt + "<|end_of_text|>\n<|start_of_role|>answerability<|end_of_role|>"
resp = requests.post(f"{host}/v1/completions", json=make_completion_req(ans_prompt, documents, adapter_ids, [ans_id], max_tokens=3))
resp.raise_for_status()
js = resp.json()
answerability = js["choices"][0]["text"]
log.debug("Answerability: %s", answerability)
log.debug2("Usage: %s", js["usage"])
log.debug2("Timings: %s", js["timings"])
answerable = not answerability.split()[0].lower().startswith("unanswerable")
if answerable:
print(">> The question is answerable!")
else:
print(">> I'm sorry, but that question isn't answerable with the given context")
if input("?> Do you want to try anyway [yN]? ").strip().lower() not in ["y", "yes"]:
continue
messages.append(user_message)
# If not unanswerable, run the question and get the assistant's response
resp = requests.post(f"{host}/v1/chat/completions", json=make_chat_req(messages, documents, adapter_ids, []))
resp.raise_for_status()
js = resp.json()
assistant_msg = js["choices"][0]["message"]
answer = assistant_msg["content"]
messages.append(assistant_msg)
print(f"ASSISTANT: {answer}")
# Get the uncertainty
formatted_prompt = requests.post(f"{host}/apply-template", json=make_chat_req(messages, documents, adapter_ids, [])).json()["prompt"]
uq_prompt = formatted_prompt + "<|end_of_text|>\n<|start_of_role|>certainty<|end_of_role|>"
resp = requests.post(f"{host}/v1/completions", json=make_completion_req(uq_prompt, documents, adapter_ids, [uq_id], max_tokens=5))
resp.raise_for_status()
js = resp.json()
uq = js["choices"][0]["text"]
print(f">> CERTAINTY: {uq}")
log.debug2("Usage: %s", js["usage"])
log.debug2("Timings: %s", js["timings"])
print()
def main():
parser = argparse.ArgumentParser(description=__doc__)
# Logging
parser.add_argument("--log-level", "-l", default=os.getenv("LOG_LEVEL", "info"))
parser.add_argument("--log-filters", "-lf", default=os.getenv("LOG_FILTERS", "urllib3.connectionpool:info"))
parser.add_argument("--log-json", "-lj", action="store_true", default=os.getenv("LOG_JSON", "").lower() == "true")
# Models
parser.add_argument("--alora-uq", "-u", type=int, default=None, help="Adapter ID for the UQ adapter")
parser.add_argument("--alora-answerability", "-a", type=int, default=None, help="Adapter ID for the Answerability adapter")
# Server
parser.add_argument("--host", "-s", default="http://localhost:8081", help="Host where llama-server is running")
# Docs
parser.add_argument("--document", "-d", nargs="+", help="document (text or path) to add as context")
# Configure logging
args = parser.parse_args()
alog.configure(
default_level=args.log_level,
filters=args.log_filters,
formatter="json" if args.log_json else "pretty",
thread_id=True,
)
# Make sure llama-server is up!
resp = requests.get(f"{args.host}/health")
resp.raise_for_status()
log.info("llama-server is up at %s", args.host)
# Get the loaded adapters
resp = requests.get(f"{args.host}/lora-adapters")
adapters = resp.json()
adapter_ids = [entry["id"] for entry in adapters]
# Figure out which adapter is which
uq_id = args.alora_uq
if uq_id is None:
candidates = [entry for entry in adapters if "uncertainty" in entry["path"]]
assert len(candidates) == 1, "Couldn't auto-deduce UQ adapter ID"
uq_id = candidates[0]["id"]
ans_id = args.alora_answerability
if ans_id is None:
candidates = [entry for entry in adapters if "answerability" in entry["path"]]
assert len(candidates) == 1, "Couldn't auto-deduce Answerability adapter ID"
ans_id = candidates[0]["id"]
log.info("UQ aLoRA ID: %d, Answerability aLoRA ID: %d", uq_id, ans_id)
# Load documents
documents = []
for i, doc in enumerate(args.document or []):
documents.append(make_document(i, doc))
# Start the prompt loop
log.info("Starting main loop")
run_main_loop(args.host, documents, uq_id, ans_id, adapter_ids)
if __name__ == "__main__":
main()

Example Output
(NOTE: It's clear from my experiments that these adapters are not particularly robust, but that's a property of these specific ones that are being continuously refined!)
I realized that my local
…er_config.json While this has been replaced in the PEFT PR in favor of alora_invocation_tokens, the existing adapters in the ibm-granite org on HF use "invocation_string," so this will enable backwards compatibility and enable testing now (before PEFT PR changes have percolated everywhere). Branch: gabe-l-hart/alora-support Signed-off-by: Gabe Goodhart <[email protected]>
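To illustrate the fallback described in that commit, here is a minimal sketch (not the converter's actual code; the paths are placeholders) that prefers alora_invocation_tokens and falls back to tokenizing the legacy invocation_string:

```python
import json
from transformers import AutoTokenizer

# Illustrative only: resolve an aLoRA's invocation tokens from its
# adapter_config.json, preferring the newer "alora_invocation_tokens" field
# and falling back to the legacy "invocation_string". Paths are placeholders.
adapter_dir = "/path/to/alora-adapter"
base_model = "/path/to/base-model"

with open(f"{adapter_dir}/adapter_config.json") as f:
    cfg = json.load(f)

invocation_tokens = cfg.get("alora_invocation_tokens")
if invocation_tokens is None and "invocation_string" in cfg:
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    invocation_tokens = tokenizer.encode(cfg["invocation_string"], add_special_tokens=False)

print("adapter.alora.invocation_tokens =", invocation_tokens)
```

Whichever path is taken, the resulting token list is what ends up stored in the adapter.alora.invocation_tokens GGUF KV.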
The other contingency for this PR is #15404. The functionality is not linked at all, but the above chat script will fail when trying to perform the chat template expansion without the fix there.
One additional note: These adapters seem to still work well when attached to a quantized model, so they don't require losing the speed/footprint benefits of quantization.
./bin/llama-server \
-m ~/models/granite-3.2-8b-instruct/ggml-model-Q4_K_M.gguf \
--lora ~/models/granite-3.2-8b-alora-uncertainty/granite-3.2-8B-alora-uncertainty-F16-LoRA.gguf \
--lora ~/models/granite-3.2-8b-alora-rag-answerability-prediction/granite-3.2-8B-alora-rag-answerability-prediction-F16-LoRA.gguf \
--port 8081 \
--jinja \
--reasoning-budget 0
EDIT: I also tried with
Also important to test will be concurrent requests to the same
DRAFT STATUS
This PR was originally opened in draft as a proof-of-concept while we discussed the best path forward. The implementation is now robust enough to be ready for full review. The changes were a bit more involved than I had originally hoped based on Georgi's comment, but they are all contained to tools/server except for the changes to support the new GGUF field.

Description
Closes #15212
Supports #15213
This PR adds support for Activated LoRA (aLoRA) in llama-server and in the GGUF representation of a LoRA adapter. The primary benefit of aLoRA is the ability to hot-swap adapters without needing to clear cache. This enables a much more efficient multi-adapter model where individual adapters provide "add-on" features to a model and can be applied during a model flow without redoing the prefill work.

Current Changes
- New `adapter.alora.invocation_tokens` GGUF KV
- Populate `adapter.alora.invocation_tokens` from `"alora_invocation_tokens"` in `convert_lora_to_gguf.py`
- Parse `adapter.alora.invocation_tokens` when loading an adapter
- Add `alora_invocation_tokens` to the `llama_lora_adapter` struct
- Add an API in `llama.h` to support getting the invocation tokens from a `const llama_lora_adapter *`
- Update the `server` to conditionally not clear cache when a request with an adapter change arrives under the following conditions: the adapters whose scales change are `alora`s
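As a rough sketch of that last condition (illustrative Python with assumed data structures, not the server's actual code): a scale change on a plain LoRA invalidates the cached prefix, while a change that only toggles aLoRA scales does not, since the tokens before the invocation sequence were computed with base weights either way.

```python
# Illustrative sketch of the "conditionally don't clear cache" decision.
# The dict-of-scales representation and `is_alora` map are assumptions.

def must_clear_cache(prev_scales: dict[int, float],
                     new_scales: dict[int, float],
                     is_alora: dict[int, bool]) -> bool:
    """Return True if a prompt cache built with `prev_scales` cannot be
    reused for a request arriving with `new_scales`."""
    for adapter_id in set(prev_scales) | set(new_scales):
        if prev_scales.get(adapter_id, 0.0) != new_scales.get(adapter_id, 0.0):
            # A changed non-aLoRA adapter affects every cached token.
            if not is_alora.get(adapter_id, False):
                return True
    # Only aLoRA scales changed: tokens before the invocation sequence were
    # computed with base weights either way, so the cache can be kept.
    return False

print(must_clear_cache({0: 0.0, 1: 0.0}, {0: 1.0, 1: 0.0}, {0: True, 1: False}))  # False
print(must_clear_cache({0: 0.0, 1: 0.0}, {0: 0.0, 1: 1.0}, {0: True, 1: False}))  # True
```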
TODO
The main remaining work to fully support `alora` is to identify the invocation tokens within an input request and only use the adapter for tokens starting with the invocation sequence. This may require a much deeper intervention to support adapter scaling on a per-token basis rather than on a per-computation basis.

Testing
I'm testing this using the following models and adapters:
Conversion
Execution
Sniff test
This script simply verifies that the two adapters can be toggled and that the cache is cleared appropriately. The example inputs are trivial, so the timings are not particularly valuable.
server-req.py
Response