feat: add config for optional parameters in a chat message #2260

Status: Open. Wants to merge 47 commits into base branch main.

Commits (47):
3f14db7  feat: add config for optional parameters in a chat message (NJordan72, Jan 14, 2025)
a685c8a  chore: cleanup (NJordan72, Jan 14, 2025)
48c2748  chore: fix nits and add light docs (NJordan72, Jan 14, 2025)
8b77971  docs: update docs/dataset-formats/conversation.qmd (NJordan72, Jan 15, 2025)
f446988  feat: configurable message mappings, jinja template analyzer (NJordan72, Jan 16, 2025)
c3ba9be  chore: handle bradley terry (NJordan72, Jan 16, 2025)
31d8a83  docs: update docs (NJordan72, Jan 16, 2025)
69771b7  refactor: change order of mappings, improve message transform (NJordan72, Jan 16, 2025)
e79ae25  refactor: make chat awware of property mappings (NJordan72, Jan 16, 2025)
2928567  chore: remove .python-version (NJordan72, Jan 16, 2025)
609f9e2  chore: revert change (NJordan72, Jan 16, 2025)
c080d95  chore: add dataset validation to tests where appropriate (NJordan72, Jan 17, 2025)
dccd4d0  chore: add dataset validation to tests where appropriate (NJordan72, Jan 17, 2025)
2808782  chore: clean up handling of ds_cfg (NJordan72, Jan 17, 2025)
ab04956  chore: recursively serialize config (NJordan72, Jan 17, 2025)
b81f6da  make sure to use the return value from validate_config (winglian, Jan 17, 2025)
5805f c3  DefaultDict pickle/unpickle fix (winglian, Jan 17, 2025)
fe5e394  fix super call for override (winglian, Jan 17, 2025)
f6586bc  refactor: message fields (NJordan72, Jan 21, 2025)
0f9f5dd  chore: empty commit (NJordan72, Jan 14, 2025)
ec66d47  tests: validate config before using (NJordan72, Jan 21, 2025)
696f122  chore: add config validation to all e2e tests (NJordan72, Jan 21, 2025)
2b82d58  chore: add unneeded logging (NJordan72, Jan 21, 2025)
a7f32f7  chore: add missed config validation (NJordan72, Jan 21, 2025)
fc26310  chore: pass field_messages to prompter (NJordan72, Jan 21, 2025)
05646f4  test: fix borked test (NJordan72, Jan 21, 2025)
26fa3a1  chore: remove uninteded file (NJordan72, Jan 23, 2025)
5ff0b1e  chore: add deprecation warning and update chat_datasets script (NJordan72, Jan 24, 2025)
756a801  chore: lint (NJordan72, Jan 24, 2025)
04a42bb  refactor: message fields (NJordan72, Jan 21, 2025)
daac330  feat: update axolotlinputconfig and test_models (NJordan72, Feb 2, 2025)
b4bdcfe  feat: simplify dpodataset and ktodataset classes in config models (NJordan72, Feb 2, 2025)
993878f  feat: improve readability and structure in dataset configuration models (NJordan72, Feb 2, 2025)
965fd80  feat: change log level from info to debug in chattemplatestrategy (NJordan72, Feb 2, 2025)
ff8569f  feat(prompt_strategies): refactor chattemplateprompter and chattempla… (NJordan72, Feb 2, 2025)
e98afd9  feat(tests/utils): ignore type check in load_model call in test_model… (NJordan72, Feb 2, 2025)
f82ed68  feat: improve type handling and test structure in chat templates (NJordan72, Feb 2, 2025)
1af653c  feat(axolotl): enhance chat strategy with datasetconfig support (NJordan72, Feb 2, 2025)
1f74e22  feat: update message handling in btchattemplatestrategy (NJordan72, Feb 2, 2025)
69e8f56  feat: add config validation in test_kd.py (NJordan72, Feb 2, 2025)
d1b7e16  feat: enhance config validation and capabilities handling (NJordan72, Feb 2, 2025)
e79ca76  feat: update config validation in axolotl utils (NJordan72, Feb 2, 2025)
0a32481  feat: refactor strategyloader in chat_template.py (NJordan72, Feb 3, 2025)
dc7a526  trigger CI (NJordan72, Feb 3, 2025)
d46a178  chore: revert dataset config changes for kto/dpo (NJordan72, Feb 3, 2025)
4a98876  subject: refactor: rename 'messages_array_name' to 'field_messages' (NJordan72, Feb 3, 2025)
bbfda64  feat: refactor prompt strategies and update config models (NJordan72, Feb 3, 2025)
docs/config.qmd (2 additions, 0 deletions)

@@ -137,6 +137,8 @@ datasets:
     message_field_role: role
     # Key for content in each message (default: "content")
     message_field_content: content
+    # Mapping of properties from the input dataset to the chat template. (default: None)
+    message_property_mappings:
 
     # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
     roles:
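As an illustration, the new option would let a ShareGPT-style dataset (messages keyed by "from"/"value") be described in a user config roughly as follows; the dataset path and field names here are hypothetical, not part of this diff:

    datasets:
      - path: user/sharegpt-style-data  # hypothetical
        type: chat_template
        field_messages: conversations
        message_property_mappings:
          role: from      # read each message's role from its "from" key
          content: value  # read each message's content from its "value" key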
scripts/chat_datasets.py (7 additions, 8 deletions)

@@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
         ds_cfg["field_messages"] = field_messages
 
         message_fields = features[field_messages][0].keys()
-        message_field_role = None
+
+        message_property_mappings = {"role": None, "content": None}
         for key in ["from", "role"]:
             if key in message_fields:
-                message_field_role = key
+                message_property_mappings["role"] = key
                 break
-        if not message_field_role:
+        if not message_property_mappings["role"]:
             raise ValueError(
                 f'No role field found in messages: {", ".join(message_fields)}'
             )
-        ds_cfg["message_field_role"] = message_field_role
 
-        message_field_content = None
         for key in ["content", "text", "value"]:
             if key in message_fields:
-                message_field_content = key
+                message_property_mappings["content"] = key
                 break
-        if not message_field_content:
+        if not message_property_mappings["content"]:
             raise ValueError(
                 f'No content field found in messages: {", ".join(message_fields)}'
             )
-        ds_cfg["message_field_content"] = message_field_content
+        ds_cfg["message_property_mappings"] = message_property_mappings
 
     print(yaml.dump({"datasets": [ds_cfg]}))
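With the change above, pointing this script at a dataset whose messages carry "from"/"value" keys would print a suggested config along these lines (a sketch; the path and messages field are hypothetical, and yaml.dump sorts keys alphabetically):

    datasets:
    - field_messages: conversations
      message_property_mappings:
        content: value
        role: from
      path: user/sharegpt-style-data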
src/axolotl/prompt_strategies/__init__.py (1 addition, 1 deletion)

@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
             load_kwargs["ds_cfg"] = ds_cfg
         if "processor" in sig.parameters:
             load_kwargs["processor"] = processor
 
         return func(tokenizer, cfg, **load_kwargs)
     except ModuleNotFoundError:
         return None
     except Exception as exc:  # pylint: disable=broad-exception-caught
         LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
-        raise exc
+        return None
src/axolotl/prompt_strategies/bradley_terry/chat_template.py (15 additions, 21 deletions)

@@ -34,15 +34,12 @@ def _tokenize_single_prompt(self, prompt):
 
         max_length = self.prompter.max_length
 
-        self.messages = "chosen_messages"
         # pylint: disable=duplicate-code
-        prompt[self.messages] = []
+        prompt["messages"] = []
         if prompt["system"]:
-            prompt[self.messages].append(
-                {"role": "system", "content": prompt["system"]}
-            )
-        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
-        prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
+            prompt["messages"].append({"role": "system", "content": prompt["system"]})
+        prompt["messages"].append({"role": "user", "content": prompt["input"]})
+        prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
         chosen_tokenized = super()._tokenize_single_prompt(prompt)
 
         if len(chosen_tokenized["input_ids"]) > max_length:

@@ -55,17 +52,12 @@ def _tokenize_single_prompt(self, prompt):
                 :max_length
             ]
 
-        self.messages = "rejected_messages"
         # pylint: disable=duplicate-code
-        prompt[self.messages] = []
+        prompt["messages"] = []
         if prompt["system"]:
-            prompt[self.messages].append(
-                {"role": "system", "content": prompt["system"]}
-            )
-        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
-        prompt[self.messages].append(
-            {"role": "assistant", "content": prompt["rejected"]}
-        )
+            prompt["messages"].append({"role": "system", "content": prompt["system"]})
+        prompt["messages"].append({"role": "user", "content": prompt["input"]})
+        prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
         rejected_tokenized = super()._tokenize_single_prompt(prompt)
 
         if len(rejected_tokenized["input_ids"]) > max_length:

@@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     prompter_params = {
         "tokenizer": tokenizer,
         "chat_template": chat_template_string,
-        "message_field_role": ds_cfg.get("message_field_role", "role"),
-        "message_field_content": ds_cfg.get("message_field_content", "content"),
+        "message_property_mappings": ds_cfg.get(
+            "message_property_mappings",
+            {
+                "role": "role",
+                "content": "content",
+            },
+        ),
         "message_field_training": ds_cfg.get("message_field_training", None),
         "message_field_training_detail": ds_cfg.get(
             "message_field_training_detail", None

@@ -124,7 +121,4 @@
         ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
     )
 
-    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
-        strategy.messages = ds_cfg["field_messages"]
-
     return strategy
src/axolotl/prompt_strategies/chat_template.py (94 additions, 49 deletions)

@@ -4,13 +4,16 @@
 
 import logging
 from collections import defaultdict
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set, Union
 
+from pydantic import BaseModel
 from transformers import ProcessorMixin
 
+from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
 from axolotl.utils.chat_templates import get_chat_template_from_config
+from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
 
 # Configure the logger
 LOG = logging.getLogger("axolotl")
@@ -23,16 +26,22 @@ class ChatTemplatePrompter(Prompter):
     def __init__(
         self,
         tokenizer,
+        chat_template: str,
         processor=None,
-        chat_template=None,
         max_length=2048,
-        message_field_role: str = "role",
-        message_field_content: str = "content",
+        message_property_mappings: Optional[Dict[str, str]] = None,
         message_field_training: Optional[str] = None,
         message_field_training_detail: Optional[str] = None,
+        field_messages: str = "messages",
         roles: Optional[Dict[str, List[str]]] = None,
         drop_system_message: bool = False,
     ):
+        if message_property_mappings is None:
+            message_property_mappings = {
+                "role": "role",
+                "content": "content",
+            }
+
         if roles:
             self.roles = {s: t for t, sources in roles.items() for s in sources}
         else:
@@ -45,18 +54,28 @@ def __init__(
                 "tool": "tool",
             }
 
-        self.message_field_role = message_field_role
-        self.message_field_content = message_field_content
+        self._chat_template_msg_variables = self.get_chat_template_msg_variables(
+            chat_template, field_messages
+        )
+        self.message_property_mappings = message_property_mappings
         self.message_field_training = message_field_training
         self.message_field_training_detail = message_field_training_detail
+        self.field_messages = field_messages
         self.tokenizer = tokenizer
-        self.processor: ProcessorMixin = processor
+        self.processor: Optional[ProcessorMixin] = processor
         self.chat_template = chat_template
         self.max_length = max_length
         self.drop_system_message = drop_system_message
 
+    @property
+    def chat_template_msg_variables(self) -> Set[str]:
+        return self._chat_template_msg_variables
+
     def build_prompt(self, conversation, add_generation_prompt=False, images=None):
         if self.processor:
+            if not callable(self.processor):
+                raise TypeError("Processor must be callable")
+
             text = self.processor.apply_chat_template(
                 conversation,
                 chat_template=self.chat_template,
@@ -184,24 +203,29 @@ def adjust_train_details(
 
         return adjusted_details
 
+    def get_chat_template_msg_variables(
+        self, chat_template: str, field_messages: str
+    ) -> Set[str]:
+        template_analyzer = JinjaTemplateAnalyzer(chat_template)
+        return template_analyzer.get_message_vars(field_messages)
+
 
 class ChatTemplateStrategy(PromptTokenizingStrategy):
     """
     Tokenizing strategy for instruction-based prompts.
     """
 
-    _messages = "messages"
-
     def __init__(
         self,
-        prompter: ChatTemplatePrompter,
+        prompter: "ChatTemplatePrompter",
         tokenizer,
         train_on_inputs,
         sequence_len,
        roles_to_train=None,
         train_on_eos=None,
     ):
         super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
+        self.prompter: ChatTemplatePrompter = prompter
 
         self.roles_to_train = []
         if roles_to_train:
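To make the new helper concrete: a hedged sketch of how `get_chat_template_msg_variables` in the hunk above is presumably meant to behave. Only the `JinjaTemplateAnalyzer(...)` and `get_message_vars(...)` calls appear in this diff; the toy template and the expected result are assumptions:

    from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer

    # Toy template that reads three properties off each message.
    template = (
        "{% for message in messages %}"
        "{{ message.role }}: {{ message.content }}{{ message.tool_calls }}"
        "{% endfor %}"
    )

    analyzer = JinjaTemplateAnalyzer(template)
    # Assumed result: the set of attributes the template accesses on each
    # element of `messages`, i.e. {"role", "content", "tool_calls"}.
    msg_vars = analyzer.get_message_vars("messages")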
@@ -213,13 +237,9 @@ def __init__(
         self.train_on_eos = train_on_eos
         self.images = "images"
 
-    @property
-    def messages(self):
-        return self._messages
-
-    @messages.setter
-    def messages(self, messages):
-        self._messages = messages
+        LOG.debug(
+            f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
+        )
 
     @property
     def supports_batched(self) -> bool:
@@ -229,7 +249,7 @@ def supports_batched(self) -> bool:
     def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
         try:
             return all(isinstance(v, list) for v in prompt.values()) and all(
-                isinstance(v, list) for v in prompt[self.messages]
+                isinstance(v, list) for v in prompt[self.prompter.field_messages]
             )
         except KeyError:
             return False
@@ -464,37 +484,52 @@ def find_turn(self, turns: list[dict], turn_idx: int):
 
     def get_conversation_thread(self, prompt):
         turns = []
-        optional_keys = [
-            "tool_calls",  # tool that 'assistant' calls
-            "name",  # name of tool given by 'tool'
-            "tool_call_id",  # mistral/mixtral requires this
-        ]
-        for message in prompt[self.messages]:
+        for message in prompt[self.prompter.field_messages]:
+            transformed_message = self.transform_message(message)
+
             turn = {
-                "role": self.prompter.roles[message[self.prompter.message_field_role]],
+                **transformed_message,
                 "training": message.get(self.prompter.message_field_training),
                 "training_detail": message.get(
                     self.prompter.message_field_training_detail
                 ),
             }
 
-            # do not add content if None as it may conflict with some templates due to tools
-            content = message.get(self.prompter.message_field_content, None)
-            if content is not None:
-                turn["content"] = content
-
-            for key in optional_keys:
-                value = message.get(key, None)
-                if value is not None:
-                    turn[key] = value
-
             turns.append(turn)
 
         if self.prompter.drop_system_message and turns[0]["role"] == "system":
             turns = turns[1:]
 
         return turns
 
+    def transform_message(self, message):
+        # Build the initial transformed message from the mappings
+        transformed_message = {
+            key: message[value]
+            for key, value in self.prompter.message_property_mappings.items()
+            if message.get(value) is not None
+        }
+
+        # Map the role if necessary
+        if "role" in transformed_message:
+            transformed_message["role"] = self.prompter.roles.get(
+                transformed_message["role"], transformed_message["role"]
+            )
+
+        # Determine which keys in the original message were not mapped
+        mapped_values = set(self.prompter.message_property_mappings.values())
+        remaining_keys = set(message) - mapped_values
+
+        # Keep only the properties defined in the chat template
+        # and not already mapped
+        for key in self.prompter.chat_template_msg_variables:
+            if key in remaining_keys:
+                val = message.get(key)
+                if val is not None:
+                    transformed_message[key] = val
+
+        return transformed_message
+
     def get_images(self, prompt):
         return prompt.get(self.images, None)

A collaborator left a review comment on the `**transformed_message,` line above: "Would we need to worry about any nested copying here? Should we set turn = transformed_message then set the other two fields after?"
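A standalone sketch of what `transform_message` does, mirroring the logic in the hunk above; the message, mappings, roles, and template variables below are invented for illustration:

    def transform_message_sketch(message, mappings, roles, template_vars):
        # 1) Rename dataset keys to chat-template properties.
        out = {k: message[v] for k, v in mappings.items() if message.get(v) is not None}
        # 2) Normalize the role through the roles mapping.
        if "role" in out:
            out["role"] = roles.get(out["role"], out["role"])
        # 3) Keep unmapped keys only if the chat template actually uses them.
        remaining = set(message) - set(mappings.values())
        for key in template_vars:
            if key in remaining and message.get(key) is not None:
                out[key] = message[key]
        return out

    transform_message_sketch(
        {"from": "gpt", "value": "Hello!", "tool_calls": None, "weight": 1},
        mappings={"role": "from", "content": "value"},
        roles={"human": "user", "gpt": "assistant"},
        template_vars={"role", "content", "tool_calls"},
    )
    # -> {"role": "assistant", "content": "Hello!"}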

@@ -516,33 +551,46 @@ def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
         }
 
     def __call__(
-        self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
+        self,
+        tokenizer,
+        cfg,
+        ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
+        processor=None,
     ):
         # pylint: disable=duplicate-code
-        ds_cfg = ds_cfg or {}
+        if ds_cfg is None:
+            dataset_config = {}
+        elif isinstance(ds_cfg, BaseModel):
+            dataset_config = ds_cfg.model_dump()
+        else:
+            dataset_config = ds_cfg
 
         chat_template_string = get_chat_template_from_config(
-            cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+            cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
         )
         LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
 
         prompter_params = {
             "tokenizer": tokenizer,
             "chat_template": chat_template_string,
-            "message_field_role": ds_cfg.get("message_field_role", "role"),
-            "message_field_content": ds_cfg.get("message_field_content", "content"),
-            "message_field_training": ds_cfg.get("message_field_training", None),
-            "message_field_training_detail": ds_cfg.get(
+            "message_property_mappings": dataset_config.get(
+                "message_property_mappings", {}
+            ),
+            "message_field_training": dataset_config.get(
+                "message_field_training", None
+            ),
+            "message_field_training_detail": dataset_config.get(
                 "message_field_training_detail",
                 None,
             ),
-            "roles": ds_cfg.get("roles"),
-            "drop_system_message": ds_cfg.get("drop_system_message", False),
+            "field_messages": dataset_config.get("field_messages", "messages"),
+            "roles": dataset_config.get("roles"),
+            "drop_system_message": dataset_config.get("drop_system_message", False),
             # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
             "max_length": cfg.sequence_len + 1,
             "processor": processor,
         }
 
-        strategy_params = self._get_strategy_params(cfg, ds_cfg)
+        strategy_params = self._get_strategy_params(cfg, dataset_config)
         strategy_cls = self._get_strategy_cls()
 
         strategy = strategy_cls(
@@ -551,9 +599,6 @@ def __call__(
             **strategy_params,
         )
 
-        if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
-            strategy.messages = ds_cfg["field_messages"]
-
         return strategy