feat: add support for internlm2.5-1.8B-chat model #515
base: v2
New file — CMake build rules for the chat runner (`@@ -0,0 +1,3 @@`):

```cmake
add_executable(mllm-internlm2_5-chat-runner main.cpp)
target_link_libraries(mllm-internlm2_5-chat-runner PRIVATE MllmRT MllmCPUBackend)
target_include_directories(mllm-internlm2_5-chat-runner PRIVATE ${MLLM_INCLUDE_DIR})
```
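Note that `MLLM_INCLUDE_DIR` is not defined in this file, so it presumably comes from the parent build. A minimal sketch of how a parent `CMakeLists.txt` might wire this in — the directory name and variable value here are assumptions, not taken from this diff:

```cmake
# Hypothetical parent examples/CMakeLists.txt fragment.
# MLLM_INCLUDE_DIR is assumed to be set by the top-level project
# before the runner's subdirectory is added.
set(MLLM_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include)
add_subdirectory(internlm2_5)  # contains the three lines above
```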
New file — model config JSON for InternLM2.5-1.8B-chat (`@@ -0,0 +1,29 @@`):

```json
{
  "architectures": [
    "InternLM2ForCausalLM"
  ],
  "bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 32768,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "num_key_value_heads": 8,
  "pad_token_id": 2,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 2.0,
    "type": "dynamic"
  },
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.34.0",
  "use_cache": true,
  "vocab_size": 92544,
  "linear_impl_type": "Default"
}
```
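A note on the `rope_scaling` block: assuming the runtime mirrors the Hugging Face `dynamic` (NTK-aware) convention — an assumption, since this diff does not include the RoPE implementation — the base $\theta$ is rescaled whenever the sequence length $s$ exceeds `max_position_embeddings` $L$:

$$\theta' = \theta \cdot \left(\frac{f \cdot s}{L} - (f - 1)\right)^{d/(d-2)}$$

with factor $f = 2.0$, head dimension $d = 2048 / 16 = 128$, and $\theta$ = `rope_theta` = 1000000; for $s \le L$ the base is left unchanged.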
New file — `main.cpp`, the interactive chat runner (`@@ -0,0 +1,67 @@`):

```cpp
#include <iostream>
#include <fmt/core.h>
#include <mllm/mllm.hpp>
#include <mllm/models/internlm2/modeling_internlm2.hpp>
#include <mllm/models/internlm2/tokenization_internlm2.hpp>

using mllm::Argparse;

MLLM_MAIN({
  auto& help = Argparse::add<bool>("-h|--help").help("Show help message");
  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model path").required(true);
  auto& model_version = Argparse::add<std::string>("-mv|--model_version").help("Model version").required(true);
  auto& tokenizer_path = Argparse::add<std::string>("-t|--tokenizer_path").help("Tokenizer JSON path").required(true);
  auto& config_path = Argparse::add<std::string>("-c|--config_path").help("Config path").required(true);

  Argparse::parse(argc, argv);

#ifdef MLLM_PERFETTO_ENABLE
  mllm::perf::start();
#endif

  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
  if (model_version.get() == "v2") { file_version = mllm::ModelFileVersion::kV2; }

  if (help.isSet()) {
    Argparse::printHelp();
    mllm::shutdownContext();
    return 0;
  }

  auto cfg = mllm::models::internlm2::InternLM2Config(config_path.get());
  auto tokenizer = mllm::models::internlm2::InternLM2Tokenizer(tokenizer_path.get());
  auto model = mllm::models::internlm2::InternLM2ForCausalLM(cfg);

  auto params = mllm::load(model_path.get(), file_version);
  model.load(params);

  fmt::print("\n{:*^60}\n", " InternLM2.5 1.5B CLI ");
  fmt::print("Enter 'exit' or 'quit' to end the session\n\n");

  std::string prompt_text;
  fmt::print("💬 Prompt text (or 'exit/quit'): ");
  std::getline(std::cin, prompt_text);
  if (!(prompt_text == "exit" || prompt_text == "quit")) {
    try {
      fmt::print("🔄 Processing...\n");
      mllm::models::internlm2::InternLM2Message prompt{prompt_text};
      auto inputs = tokenizer.convertMessage(prompt);

      fmt::print("\n🤖 Response: ");
      for (auto& step : model.chat(inputs)) {
        auto token = tokenizer.detokenize(step.cur_token_id);
        std::wcout << token << std::flush;
      }
      fmt::print("\n{}\n", std::string(60, '-'));
    } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); }
    model.perfSummary();
  }
```
**Comment on lines +41 to +58** (Contributor):

**Loop the interactive chat.** Apply this diff to address the issue:

```diff
-  std::string prompt_text;
-  fmt::print("💬 Prompt text (or 'exit/quit'): ");
-  std::getline(std::cin, prompt_text);
-  if (!(prompt_text == "exit" || prompt_text == "quit")) {
-    try {
-      fmt::print("🔄 Processing...\n");
-      mllm::models::internlm2::InternLM2Message prompt{prompt_text};
-      auto inputs = tokenizer.convertMessage(prompt);
-
-      fmt::print("\n🤖 Response: ");
-      for (auto& step : model.chat(inputs)) {
-        auto token = tokenizer.detokenize(step.cur_token_id);
-        std::wcout << token << std::flush;
-      }
-      fmt::print("\n{}\n", std::string(60, '-'));
-    } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); }
-    model.perfSummary();
-  }
+  for (std::string prompt_text;;) {
+    fmt::print("💬 Prompt text (or 'exit/quit'): ");
+    if (!std::getline(std::cin, prompt_text)) { break; }
+    if (prompt_text == "exit" || prompt_text == "quit") { break; }
+    try {
+      fmt::print("🔄 Processing...\n");
+      mllm::models::internlm2::InternLM2Message prompt{prompt_text};
+      auto inputs = tokenizer.convertMessage(prompt);
+
+      fmt::print("\n🤖 Response: ");
+      for (auto& step : model.chat(inputs)) {
+        auto token = tokenizer.detokenize(step.cur_token_id);
+        std::wcout << token << std::flush;
+      }
+      fmt::print("\n{}\n", std::string(60, '-'));
+    } catch (const std::exception& e) {
+      fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-'));
+    }
+    model.perfSummary();
+  }
```
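Beyond looping, the suggested rewrite has two side benefits visible in the diff itself: checking the return of `std::getline` treats end-of-file (e.g. Ctrl-D) as a clean exit rather than processing an empty read, and moving `model.perfSummary()` inside the loop prints a performance summary after every turn instead of only once.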
The file's diff resumes after the review comment:

```cpp
#ifdef MLLM_PERFETTO_ENABLE
  mllm::perf::stop();
  mllm::perf::saveReport("internlm2_5.perf");
#endif

  mllm::print("\n");
  mllm::memoryReport();
})
```
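For reference, assuming the target name from the CMake rules above, a typical invocation might look like `./mllm-internlm2_5-chat-runner -m internlm2_5.mllm -mv v2 -t tokenizer.json -c config.json` — the model and tokenizer file names here are placeholders, not taken from this PR.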
New file — InternLM2 configuration header (`@@ -0,0 +1,82 @@`):

```cpp
// Copyright (c) MLLM Team.
// Licensed under the MIT License.

#pragma once

#include <string>

#include "mllm/core/aops/LinearOp.hpp"
#include "mllm/engine/ConfigFile.hpp"

namespace mllm::models::internlm2 {

struct InternLM2Config : protected ConfigFile {
  InternLM2Config() = default;

  explicit InternLM2Config(const std::string& file_path) : ConfigFile(file_path) {
    auto& json = data();

    if (json.contains("bias")) { bias = json["bias"].get<bool>(); }
    if (json.contains("hidden_size")) { hidden_size = json["hidden_size"].get<int32_t>(); }
    if (json.contains("intermediate_size")) { intermediate_size = json["intermediate_size"].get<int32_t>(); }
    if (json.contains("num_hidden_layers")) { num_hidden_layers = json["num_hidden_layers"].get<int32_t>(); }
    if (json.contains("num_attention_heads")) { num_attention_heads = json["num_attention_heads"].get<int32_t>(); }
    if (json.contains("num_key_value_heads")) {
      num_key_value_heads = json["num_key_value_heads"].get<int32_t>();
    } else {
      num_key_value_heads = num_attention_heads;
    }
    if (json.contains("max_position_embeddings")) { max_position_embeddings = json["max_position_embeddings"].get<int32_t>(); }
    if (json.contains("rms_norm_eps")) { rms_norm_eps = json["rms_norm_eps"].get<float>(); }
    if (json.contains("vocab_size")) { vocab_size = json["vocab_size"].get<int32_t>(); }
    if (json.contains("rope_theta")) { rope_theta = json["rope_theta"].get<float>(); }
    if (json.contains("tie_word_embeddings")) { tie_word_embeddings = json["tie_word_embeddings"].get<bool>(); }
    if (json.contains("use_cache")) { use_cache = json["use_cache"].get<bool>(); }
    if (json.contains("pad_token_id")) { pad_token_id = json["pad_token_id"].get<int32_t>(); }
    if (json.contains("bos_token_id")) { bos_token_id = json["bos_token_id"].get<int32_t>(); }
    if (json.contains("eos_token_id")) { eos_token_id = json["eos_token_id"].get<int32_t>(); }
    if (json.contains("initializer_range")) { initializer_range = json["initializer_range"].get<float>(); }

    if (json.contains("rope_scaling")) {
      const auto& scaling = json["rope_scaling"];
      if (scaling.contains("type")) { rope_scaling_type = scaling["type"].get<std::string>(); }
      if (scaling.contains("factor")) { rope_scaling_factor = scaling["factor"].get<float>(); }
    }

    if (json.contains("linear_impl_type")) {
      linear_impl_type = aops::str2LinearImplTypes(json["linear_impl_type"].get<std::string>());
    }

    head_dim = hidden_size / num_attention_heads;
    max_cache_length = max_position_embeddings;
    end_of_text_token_id = static_cast<int32_t>(eos_token_id);
  }
```
**Comment on lines +50 to +53** (Contributor):

**Validate head dimension before use.** Use this diff as a starting point:

```diff
-    head_dim = hidden_size / num_attention_heads;
+    if (num_attention_heads <= 0 || hidden_size % num_attention_heads != 0) {
+      throw std::invalid_argument(
+          fmt::format("hidden_size ({}) must be divisible by num_attention_heads ({})",
+                      hidden_size, num_attention_heads));
+    }
+    head_dim = hidden_size / num_attention_heads;
```
The header's diff resumes:

```cpp
  bool bias = false;
  int32_t hidden_size = 4096;
  int32_t intermediate_size = 11008;
  int32_t num_hidden_layers = 32;
  int32_t num_attention_heads = 32;
  int32_t num_key_value_heads = 32;
  int32_t max_position_embeddings = 2048;
  int32_t max_cache_length = 2048;
  int32_t head_dim = 128;
  int32_t vocab_size = 32000;
  float rms_norm_eps = 1e-6f;
  float rope_theta = 10000.0f;
  float rope_scaling_factor = 1.0f;
  std::string rope_scaling_type;

  float initializer_range = 0.02f;
  bool tie_word_embeddings = false;
  bool use_cache = true;

  int32_t pad_token_id = 0;
  int32_t bos_token_id = 1;
  int32_t eos_token_id = 2;
  int32_t end_of_text_token_id = 2;

  aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault;
};

}  // namespace mllm::models::internlm2
```
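To make the derived geometry concrete: with the 1.8B values from the `config.json` above, the constructor's arithmetic yields a head dimension of 128 and grouped-query attention with two query heads per KV head. A standalone sketch (deliberately not using the mllm headers) that reproduces the derivation:

```cpp
#include <cstdio>

int main() {
  // Values from the config.json shipped in this PR.
  const int hidden_size         = 2048;
  const int num_attention_heads = 16;
  const int num_key_value_heads = 8;

  // Same derivation as the InternLM2Config constructor above.
  const int head_dim  = hidden_size / num_attention_heads;          // 2048 / 16 = 128
  // Grouped-query attention: how many query heads share one KV head.
  const int gqa_group = num_attention_heads / num_key_value_heads;  // 16 / 8 = 2

  std::printf("head_dim=%d, query heads per KV head=%d\n", head_dim, gqa_group);
  return 0;
}
```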
**Comment** (Contributor):

**Fix the banner text.** The banner still says "InternLM2.5 1.5B CLI", which is misleading now that this runner targets the 1.8B chat model. Please update the label to match the actual model size. Apply this diff:

```diff
-  fmt::print("\n{:*^60}\n", " InternLM2.5 1.5B CLI ");
+  fmt::print("\n{:*^60}\n", " InternLM2.5 1.8B CLI ");
```