1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(qwen3)
add_subdirectory(qwen3_service)
add_subdirectory(deepseek_ocr)
add_subdirectory(smollm3_3B)
add_subdirectory(internlm2_5)

if(MLLM_BUILD_QNN_BACKEND)
add_subdirectory(qwen_npu)
3 changes: 3 additions & 0 deletions examples/internlm2_5/CMakeLists.txt
@@ -0,0 +1,3 @@
add_executable(mllm-internlm2_5-chat-runner main.cpp)
target_link_libraries(mllm-internlm2_5-chat-runner PRIVATE MllmRT MllmCPUBackend)
target_include_directories(mllm-internlm2_5-chat-runner PRIVATE ${MLLM_INCLUDE_DIR})
29 changes: 29 additions & 0 deletions examples/internlm2_5/config_1.8B.json
@@ -0,0 +1,29 @@
{
  "architectures": [
    "InternLM2ForCausalLM"
  ],
  "bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 32768,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "num_key_value_heads": 8,
  "pad_token_id": 2,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 2.0,
    "type": "dynamic"
  },
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.34.0",
  "use_cache": true,
  "vocab_size": 92544,
  "linear_impl_type": "Default"
}
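
For reference, the attention geometry implied by this config is plain arithmetic: head_dim = hidden_size / num_attention_heads = 2048 / 16 = 128, and with num_key_value_heads = 8 each key/value head is shared by 16 / 8 = 2 query heads (grouped-query attention). A minimal sketch of those derived values, with illustrative variable names that are not part of this PR:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Values copied from config_1.8B.json above; names are illustrative only.
  const int32_t hidden_size = 2048;
  const int32_t num_attention_heads = 16;
  const int32_t num_key_value_heads = 8;

  const int32_t head_dim = hidden_size / num_attention_heads;               // 2048 / 16 = 128
  const int32_t kv_group_size = num_attention_heads / num_key_value_heads;  // 2 query heads per KV head

  assert(head_dim == 128);
  assert(kv_group_size == 2);
  return 0;
}
```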
67 changes: 67 additions & 0 deletions examples/internlm2_5/main.cpp
@@ -0,0 +1,67 @@
#include <iostream>
#include <fmt/core.h>
#include <mllm/mllm.hpp>
#include <mllm/models/internlm2/modeling_internlm2.hpp>
#include <mllm/models/internlm2/tokenization_internlm2.hpp>

using mllm::Argparse;

MLLM_MAIN({
  auto& help = Argparse::add<bool>("-h|--help").help("Show help message");
  auto& model_path = Argparse::add<std::string>("-m|--model_path").help("Model path").required(true);
  auto& model_version = Argparse::add<std::string>("-mv|--model_version").help("Model version").required(true);
  auto& tokenizer_path = Argparse::add<std::string>("-t|--tokenizer_path").help("Tokenizer JSON path").required(true);
  auto& config_path = Argparse::add<std::string>("-c|--config_path").help("Config path").required(true);

  Argparse::parse(argc, argv);

#ifdef MLLM_PERFETTO_ENABLE
  mllm::perf::start();
#endif

  mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1;
  if (model_version.get() == "v2") { file_version = mllm::ModelFileVersion::kV2; }

  if (help.isSet()) {
    Argparse::printHelp();
    mllm::shutdownContext();
    return 0;
  }

  auto cfg = mllm::models::internlm2::InternLM2Config(config_path.get());
  auto tokenizer = mllm::models::internlm2::InternLM2Tokenizer(tokenizer_path.get());
  auto model = mllm::models::internlm2::InternLM2ForCausalLM(cfg);

  auto params = mllm::load(model_path.get(), file_version);
  model.load(params);

  fmt::print("\n{:*^60}\n", " InternLM2.5 1.5B CLI ");
  fmt::print("Enter 'exit' or 'quit' to end the session\n\n");
Comment on lines +38 to +39
⚠️ Potential issue | 🟡 Minor

Fix the banner text
The banner still says “InternLM2.5 1.5B CLI”, which is misleading now that this runner targets the 1.8B chat model. Please update the label to match the actual model size.

Apply this diff:

-  fmt::print("\n{:*^60}\n", " InternLM2.5 1.5B CLI ");
+  fmt::print("\n{:*^60}\n", " InternLM2.5 1.8B CLI ");
🤖 Prompt for AI Agents
In examples/internlm2_5/main.cpp around lines 38 to 39, the banner currently
reads "InternLM2.5 1.5B CLI" but the runner targets the 1.8B chat model; update
the banner text to "InternLM2.5 1.8B CLI" by changing the formatted string
passed to fmt::print so the displayed label reflects the correct model size.


  std::string prompt_text;
  fmt::print("💬 Prompt text (or 'exit/quit'): ");
  std::getline(std::cin, prompt_text);
  if (!(prompt_text == "exit" || prompt_text == "quit")) {
    try {
      fmt::print("🔄 Processing...\n");
      mllm::models::internlm2::InternLM2Message prompt{prompt_text};
      auto inputs = tokenizer.convertMessage(prompt);

      fmt::print("\n🤖 Response: ");
      for (auto& step : model.chat(inputs)) {
        auto token = tokenizer.detokenize(step.cur_token_id);
        std::wcout << token << std::flush;
      }
      fmt::print("\n{}\n", std::string(60, '-'));
    } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); }
    model.perfSummary();
  }
Comment on lines +41 to +58
⚠️ Potential issue | 🟠 Major

Loop the interactive chat
The CLI prints “Enter 'exit' or 'quit' to end the session”, but it only processes a single prompt and then exits. Users can’t hold a multi-turn conversation, which defeats the whole purpose of exposing a chat runner. Wrap the prompt/response block in a loop so exit/quit truly controls termination.

Apply this diff to address the issue:

-  std::string prompt_text;
-  fmt::print("💬 Prompt text (or 'exit/quit'): ");
-  std::getline(std::cin, prompt_text);
-  if (!(prompt_text == "exit" || prompt_text == "quit")) {
-    try {
-      fmt::print("🔄 Processing...\n");
-      mllm::models::internlm2::InternLM2Message prompt{prompt_text};
-      auto inputs = tokenizer.convertMessage(prompt);
-
-      fmt::print("\n🤖 Response: ");
-      for (auto& step : model.chat(inputs)) {
-        auto token = tokenizer.detokenize(step.cur_token_id);
-        std::wcout << token << std::flush;
-      }
-      fmt::print("\n{}\n", std::string(60, '-'));
-    } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); }
-    model.perfSummary();
-  }
+  for (std::string prompt_text;;) {
+    fmt::print("💬 Prompt text (or 'exit/quit'): ");
+    if (!std::getline(std::cin, prompt_text)) { break; }
+    if (prompt_text == "exit" || prompt_text == "quit") { break; }
+    try {
+      fmt::print("🔄 Processing...\n");
+      mllm::models::internlm2::InternLM2Message prompt{prompt_text};
+      auto inputs = tokenizer.convertMessage(prompt);
+
+      fmt::print("\n🤖 Response: ");
+      for (auto& step : model.chat(inputs)) {
+        auto token = tokenizer.detokenize(step.cur_token_id);
+        std::wcout << token << std::flush;
+      }
+      fmt::print("\n{}\n", std::string(60, '-'));
+    } catch (const std::exception& e) {
+      fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-'));
+    }
+    model.perfSummary();
+  }

🤖 Prompt for AI Agents
In examples/internlm2_5/main.cpp around lines 41-58, the interactive prompt is
executed only once; wrap the prompt/response block in a loop so users can hold a
multi-turn conversation. Replace the single read+if with a for/while loop that
repeatedly prints the prompt, uses std::getline to read into prompt_text
(breaking the loop on EOF), and breaks when the input equals "exit" or "quit";
keep the existing try/catch/response printing and call model.perfSummary() after
each iteration.


#ifdef MLLM_PERFETTO_ENABLE
  mllm::perf::stop();
  mllm::perf::saveReport("internlm2_5.perf");
#endif

  mllm::print("\n");
  mllm::memoryReport();
})
82 changes: 82 additions & 0 deletions mllm/models/internlm2/configuration_internlm2.hpp
@@ -0,0 +1,82 @@
// Copyright (c) MLLM Team.
// Licensed under the MIT License.

#pragma once

#include <string>

#include "mllm/core/aops/LinearOp.hpp"
#include "mllm/engine/ConfigFile.hpp"

namespace mllm::models::internlm2 {

struct InternLM2Config : protected ConfigFile {
  InternLM2Config() = default;

  explicit InternLM2Config(const std::string& file_path) : ConfigFile(file_path) {
    auto& json = data();

    if (json.contains("bias")) { bias = json["bias"].get<bool>(); }
    if (json.contains("hidden_size")) { hidden_size = json["hidden_size"].get<int32_t>(); }
    if (json.contains("intermediate_size")) { intermediate_size = json["intermediate_size"].get<int32_t>(); }
    if (json.contains("num_hidden_layers")) { num_hidden_layers = json["num_hidden_layers"].get<int32_t>(); }
    if (json.contains("num_attention_heads")) { num_attention_heads = json["num_attention_heads"].get<int32_t>(); }
    if (json.contains("num_key_value_heads")) {
      num_key_value_heads = json["num_key_value_heads"].get<int32_t>();
    } else {
      num_key_value_heads = num_attention_heads;
    }
    if (json.contains("max_position_embeddings")) { max_position_embeddings = json["max_position_embeddings"].get<int32_t>(); }
    if (json.contains("rms_norm_eps")) { rms_norm_eps = json["rms_norm_eps"].get<float>(); }
    if (json.contains("vocab_size")) { vocab_size = json["vocab_size"].get<int32_t>(); }
    if (json.contains("rope_theta")) { rope_theta = json["rope_theta"].get<float>(); }
    if (json.contains("tie_word_embeddings")) { tie_word_embeddings = json["tie_word_embeddings"].get<bool>(); }
    if (json.contains("use_cache")) { use_cache = json["use_cache"].get<bool>(); }
    if (json.contains("pad_token_id")) { pad_token_id = json["pad_token_id"].get<int32_t>(); }
    if (json.contains("bos_token_id")) { bos_token_id = json["bos_token_id"].get<int32_t>(); }
    if (json.contains("eos_token_id")) { eos_token_id = json["eos_token_id"].get<int32_t>(); }
    if (json.contains("initializer_range")) { initializer_range = json["initializer_range"].get<float>(); }

    if (json.contains("rope_scaling")) {
      const auto& scaling = json["rope_scaling"];
      if (scaling.contains("type")) { rope_scaling_type = scaling["type"].get<std::string>(); }
      if (scaling.contains("factor")) { rope_scaling_factor = scaling["factor"].get<float>(); }
    }

    if (json.contains("linear_impl_type")) {
      linear_impl_type = aops::str2LinearImplTypes(json["linear_impl_type"].get<std::string>());
    }

    head_dim = hidden_size / num_attention_heads;
    max_cache_length = max_position_embeddings;
    end_of_text_token_id = static_cast<int32_t>(eos_token_id);
  }
Comment on lines +50 to +53
⚠️ Potential issue | 🟠 Major

Validate head dimension before use
head_dim is computed with plain integer division. If hidden_size isn’t divisible by num_attention_heads (or if num_attention_heads is zero), this silently truncates and breaks attention math later on. The reference InternLM2 implementations explicitly guard against this with a runtime check. Please add the same validation before assigning head_dim so that bad configs fail fast instead of producing corrupted tensors. (huggingface.co)

Use this diff as a starting point:

-    head_dim = hidden_size / num_attention_heads;
+    if (num_attention_heads <= 0 || hidden_size % num_attention_heads != 0) {
+      throw std::invalid_argument(
+          fmt::format("hidden_size ({}) must be divisible by num_attention_heads ({})",
+                      hidden_size, num_attention_heads));
+    }
+    head_dim = hidden_size / num_attention_heads;
🤖 Prompt for AI Agents
In mllm/models/internlm2/configuration_internlm2.hpp around lines 50 to 53,
validate that num_attention_heads is > 0 and that hidden_size is divisible by
num_attention_heads before computing head_dim; if either check fails, throw or
log a clear runtime error (include hidden_size and num_attention_heads in the
message) so misconfigured models fail fast rather than producing truncated
head_dim, then compute head_dim = hidden_size / num_attention_heads and proceed
with max_cache_length and end_of_text_token_id assignments.


  bool bias = false;
  int32_t hidden_size = 4096;
  int32_t intermediate_size = 11008;
  int32_t num_hidden_layers = 32;
  int32_t num_attention_heads = 32;
  int32_t num_key_value_heads = 32;
  int32_t max_position_embeddings = 2048;
  int32_t max_cache_length = 2048;
  int32_t head_dim = 128;
  int32_t vocab_size = 32000;
  float rms_norm_eps = 1e-6f;
  float rope_theta = 10000.0f;
  float rope_scaling_factor = 1.0f;
  std::string rope_scaling_type;

  float initializer_range = 0.02f;
  bool tie_word_embeddings = false;
  bool use_cache = true;

  int32_t pad_token_id = 0;
  int32_t bos_token_id = 1;
  int32_t eos_token_id = 2;
  int32_t end_of_text_token_id = 2;

  aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault;
};

} // namespace mllm::models::internlm2
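
The config keeps rope_scaling_type / rope_scaling_factor ("dynamic" with factor 2.0 in config_1.8B.json), but the modeling code that consumes them is not part of this diff. As a point of reference only, here is a minimal sketch of the usual dynamic-NTK rule, as used by the upstream InternLM2 / Hugging Face rotary embeddings; the helper name and placement are hypothetical, not mllm API:

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical helper, not part of this PR: effective RoPE base under "dynamic"
// NTK scaling. When the sequence exceeds max_position_embeddings, the base is
// stretched so the rotary frequencies cover the longer context.
inline float dynamicNtkRopeBase(float rope_theta, float scaling_factor, int32_t seq_len,
                                int32_t max_position_embeddings, int32_t head_dim) {
  if (seq_len <= max_position_embeddings) { return rope_theta; }  // no rescaling needed
  const float ratio = scaling_factor * static_cast<float>(seq_len) / static_cast<float>(max_position_embeddings)
                      - (scaling_factor - 1.0f);
  // base' = base * ratio^(d / (d - 2)); inv_freq_i = base'^(-2i/d) is then computed as usual.
  return rope_theta * std::pow(ratio, static_cast<float>(head_dim) / static_cast<float>(head_dim - 2));
}
```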