Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/generate-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ pbjson-build.workspace = true
typify.workspace = true
schemars.workspace = true
schemafy.workspace = true
regex = "1"
36 changes: 36 additions & 0 deletions crates/generate-types/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::path::{Path, PathBuf};
use big_serde_json as serde_json;
use tool_generator::{generate_all_tool_code, replace_tool_struct_with_enum};

mod pbjson_patches;
mod proto_patches;
mod schema_converter;
mod tool_generator;

Expand Down Expand Up @@ -1298,6 +1300,28 @@ fn generate_google_protobuf_types_from_git() {

println!("✅ Found protobuf files, compiling with complete dependencies...");

// Patch the proto file with missing fields/enums before compilation
let patches = &*proto_patches::GOOGLE_TYPE_PATCHES;
match std::fs::read_to_string(&proto_file) {
Ok(proto_content) => match patches.apply(&proto_content) {
Ok(patched_proto) => {
if let Err(e) = std::fs::write(&proto_file, patched_proto) {
println!("❌ Failed to write patched proto: {}", e);
return;
}
println!("✅ Applied proto patches for missing fields/enums");
}
Err(e) => {
println!("❌ Failed to apply proto patches: {}", e);
return;
}
},
Err(e) => {
println!("❌ Failed to read proto file: {}", e);
return;
}
}

// Include both the main service proto and google.type dependencies
let proto_paths = vec![
proto_file.to_string_lossy().to_string(),
Expand Down Expand Up @@ -1469,6 +1493,18 @@ fn create_google_combined_output(temp_dir: &std::path::Path, pbjson_dir: Option<
pbjson_content.push_str("// pbjson serde output unavailable\n");
}

// Fix float deserialization (pbjson wraps floats in NumberDeserialize which we don't want)
match pbjson_patches::fix_float_fields(&pbjson_content, &pbjson_patches::FLOAT_PATCHES) {
Ok(fixed) => {
pbjson_content = fixed;
println!("✅ Applied float deserialization fix");
}
Err(e) => {
println!("⚠️ Failed to apply float fix: {}", e);
// Continue anyway - floats will just accept strings too
}
}

if let Some(parent) = Path::new(pbjson_dest_path).parent() {
let _ = std::fs::create_dir_all(parent);
}
Expand Down
137 changes: 137 additions & 0 deletions crates/generate-types/src/pbjson_patches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
//! Post-processing patches for pbjson generated code.
//!
//! pbjson wraps floats in `NumberDeserialize<_>` to accept both numbers and
//! strings per proto3 JSON spec. We remove this wrapper since we want strict
//! JSON number parsing for floats.

use regex::Regex;
use std::sync::LazyLock;

/// Float field patches for strict JSON number parsing
pub static FLOAT_PATCHES: LazyLock<Vec<FloatFieldPatch>> = LazyLock::new(|| {
vec![
FloatFieldPatch {
field_name: "temperature",
},
FloatFieldPatch {
field_name: "top_p",
},
FloatFieldPatch {
field_name: "presence_penalty",
},
FloatFieldPatch {
field_name: "frequency_penalty",
},
]
});

pub struct FloatFieldPatch {
pub field_name: &'static str,
}

/// Fix float deserialization by removing NumberDeserialize wrapper.
///
/// pbjson generates:
/// ```ignore
/// field__ =
/// map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?
/// .map(|x| x.0)
/// ;
/// ```
///
/// We replace with:
/// ```ignore
/// field__ = map_.next_value()?;
/// ```
pub fn fix_float_fields(content: &str, patches: &[FloatFieldPatch]) -> Result<String, String> {
let mut result = content.to_string();

for f in patches {
// pbjson format:
// field__ =
// map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
// ;
let pattern = format!(
r#"(?s)({}__\s*=\s*)\n\s*map_\.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>\(\)\?\.map\(\|x\| x\.0\)\s*;"#,
regex::escape(f.field_name)
);
let re = Regex::new(&pattern).map_err(|e| format!("Invalid regex: {}", e))?;

// This might not match if already fixed or format different - that's OK
if re.is_match(&result) {
result = re
.replace_all(&result, |caps: &regex::Captures| {
format!("{} map_.next_value()?;", &caps[1])
})
.to_string();
}
}

Ok(result)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_fix_float_field() {
// Actual pbjson format has = on one line, value on next with indentation
let input = r#"
GeneratedField::Temperature => {
if temperature__.is_some() {
return Err(serde::de::Error::duplicate_field("temperature"));
}
temperature__ =
map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
;
}
"#;
let patches = vec![FloatFieldPatch {
field_name: "temperature",
}];

let result = fix_float_fields(input, &patches).unwrap();

assert!(result.contains("temperature__ = map_.next_value()?;"));
assert!(!result.contains("NumberDeserialize"));
}

#[test]
fn test_fix_multiple_float_fields() {
// Actual pbjson format
let input = r#"
temperature__ =
map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
;
top_p__ =
map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0)
;
"#;
let patches = vec![
FloatFieldPatch {
field_name: "temperature",
},
FloatFieldPatch {
field_name: "top_p",
},
];

let result = fix_float_fields(input, &patches).unwrap();

assert!(result.contains("temperature__ = map_.next_value()?;"));
assert!(result.contains("top_p__ = map_.next_value()?;"));
assert!(!result.contains("NumberDeserialize"));
}

#[test]
fn test_no_match_is_ok() {
let input = "some unrelated content";
let patches = vec![FloatFieldPatch {
field_name: "temperature",
}];

let result = fix_float_fields(input, &patches).unwrap();
assert_eq!(result, input);
}
}
Loading
Loading