diff --git a/Cargo.lock b/Cargo.lock index 37d26cc465..1fa781a2e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -853,6 +853,7 @@ dependencies = [ "allocator", "libfuzzer-sys", "qsc", + "qsc_llvm", ] [[package]] @@ -2190,6 +2191,7 @@ dependencies = [ "qsc_formatter", "qsc_frontend", "qsc_hir", + "qsc_llvm", "qsc_lowerer", "qsc_partial_eval", "qsc_passes", @@ -2314,6 +2316,20 @@ dependencies = [ "thiserror", ] +[[package]] +name = "qsc_llvm" +version = "0.0.0" +dependencies = [ + "arbitrary", + "expect-test", + "half", + "indoc", + "miette", + "rustc-hash", + "thiserror", + "winnow", +] + [[package]] name = "qsc_lowerer" version = "0.0.0" @@ -2512,6 +2528,8 @@ dependencies = [ "pyo3", "qdk_simulators", "qsc", + "qsc_codegen", + "qsc_llvm", "rand 0.8.5", "rayon", "resource_estimator", @@ -3818,6 +3836,15 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index cdd0b260c2..b68cce9b8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ license = "MIT" version = "0.0.0" [workspace.dependencies] +arbitrary = "1" bitflags = "2.11" memchr = "2.8" clap = "4.4" @@ -79,6 +80,7 @@ sorted-vec = "0.8" sorted-iter = "0.1" wasm-bindgen = "0.2.114" wasm-bindgen-futures = "0.4" +winnow = "0.7" rand = "0.8" serde_json = "1.0" pyo3 = "0.28" diff --git a/build.py b/build.py index 4cd38aaa3d..036624f506 100755 --- a/build.py +++ b/build.py @@ -345,8 +345,6 @@ def install_python_test_requirements(cwd, interpreter, check: bool = True): requirement, "--only-binary", "qirrunner", - "--only-binary", - "pyqir", ] subprocess.run(command_args, check=check, 
text=True, cwd=cwd) diff --git a/source/compiler/qsc_codegen/Cargo.toml b/source/compiler/qsc_codegen/Cargo.toml index 285f1ea1d8..4be7e584b3 100644 --- a/source/compiler/qsc_codegen/Cargo.toml +++ b/source/compiler/qsc_codegen/Cargo.toml @@ -19,6 +19,7 @@ qsc_fir = { path = "../qsc_fir" } qsc_formatter = { path = "../qsc_formatter" } qsc_frontend = { path = "../qsc_frontend" } qsc_hir = { path = "../qsc_hir" } +qsc_llvm = { path = "../qsc_llvm" } qsc_lowerer = { path = "../qsc_lowerer" } qsc_partial_eval = { path = "../qsc_partial_eval" } qsc_rca = { path = "../qsc_rca" } diff --git a/source/compiler/qsc_codegen/src/qir.rs b/source/compiler/qsc_codegen/src/qir.rs index 5efb55f353..3276bf6fad 100644 --- a/source/compiler/qsc_codegen/src/qir.rs +++ b/source/compiler/qsc_codegen/src/qir.rs @@ -1,16 +1,20 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use qsc_data_structures::target::{Profile, TargetCapabilityFlags}; -use qsc_eval::val::Value; +use qsc_data_structures::{ + span::Span, + target::{Profile, TargetCapabilityFlags}, +}; +use qsc_eval::{PackageSpan, val::Value}; +use qsc_llvm::qir::QirProfile; +use qsc_lowerer::map_fir_package_to_hir; use qsc_partial_eval::{ PartialEvalConfig, Program, ProgramEntry, partially_evaluate, partially_evaluate_call, }; use qsc_rca::PackageStoreComputeProperties; use qsc_rir::{passes::check_and_transform, rir}; -pub mod v1; -pub mod v2; +mod common; /// converts the given sources to RIR using the given language features. 
pub fn fir_to_rir( @@ -49,11 +53,58 @@ pub fn fir_to_qir( }, )?; check_and_transform(&mut program); - if capabilities <= Profile::AdaptiveRIF.into() { - Ok(v1::ToQir::::to_qir(&program, &program)) - } else { - Ok(v2::ToQir::::to_qir(&program, &program)) + let module = build_module(&program, capabilities); + #[cfg(debug_assertions)] + { + let ir_errors = qsc_llvm::validate_ir(&module); + assert!( + ir_errors.is_empty(), + "codegen produced invalid IR in fir_to_qir: {ir_errors:?}" + ); } + Ok(qsc_llvm::write_module_to_string(&module)) +} + +/// converts the given sources to QIR bitcode using the given language features. +pub fn fir_to_qir_bitcode( + fir_store: &qsc_fir::fir::PackageStore, + capabilities: TargetCapabilityFlags, + compute_properties: Option, + entry: &ProgramEntry, +) -> Result, qsc_partial_eval::Error> { + let mut program = get_rir_from_compilation( + fir_store, + compute_properties, + entry, + capabilities, + PartialEvalConfig { + generate_debug_metadata: false, + }, + )?; + check_and_transform(&mut program); + let module = build_module(&program, capabilities); + #[cfg(debug_assertions)] + { + let ir_errors = qsc_llvm::validate_ir(&module); + assert!( + ir_errors.is_empty(), + "codegen produced invalid IR in fir_to_qir_bitcode: {ir_errors:?}" + ); + } + qsc_llvm::try_write_bitcode(&module).map_err(|error| bitcode_write_error(entry, &error)) +} + +fn bitcode_write_error( + entry: &ProgramEntry, + error: &qsc_llvm::WriteError, +) -> qsc_partial_eval::Error { + qsc_partial_eval::Error::Unexpected( + format!("QIR bitcode emission failed: {error}"), + PackageSpan { + package: map_fir_package_to_hir(entry.expr.package), + span: Span::default(), + }, + ) } /// converts the given callable to QIR using the given arguments and language features. 
@@ -80,11 +131,16 @@ pub fn fir_to_qir_from_callable( }, )?; check_and_transform(&mut program); - if capabilities <= Profile::AdaptiveRIF.into() { - Ok(v1::ToQir::::to_qir(&program, &program)) - } else { - Ok(v2::ToQir::::to_qir(&program, &program)) + let module = build_module(&program, capabilities); + #[cfg(debug_assertions)] + { + let ir_errors = qsc_llvm::validate_ir(&module); + assert!( + ir_errors.is_empty(), + "codegen produced invalid IR in fir_to_qir_from_callable: {ir_errors:?}" + ); } + Ok(qsc_llvm::write_module_to_string(&module)) } /// converts the given callable to RIR using the given arguments and language features. @@ -114,6 +170,22 @@ pub fn fir_to_rir_from_callable( Ok((orig, program)) } +fn build_module( + program: &rir::Program, + capabilities: TargetCapabilityFlags, +) -> qsc_llvm::model::Module { + let profile = if capabilities <= Profile::AdaptiveRIF.into() { + if program.config.is_base() { + QirProfile::BaseV1 + } else { + QirProfile::AdaptiveV1 + } + } else { + QirProfile::AdaptiveV2 + }; + common::build_qir_module(program, profile) +} + fn get_rir_from_compilation( fir_store: &qsc_fir::fir::PackageStore, compute_properties: Option, diff --git a/source/compiler/qsc_codegen/src/qir/common.rs b/source/compiler/qsc_codegen/src/qir/common.rs new file mode 100644 index 0000000000..01ea162f5d --- /dev/null +++ b/source/compiler/qsc_codegen/src/qir/common.rs @@ -0,0 +1,683 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#[cfg(test)] +mod tests; + +use qsc_data_structures::{attrs::Attributes, target::TargetCapabilityFlags}; +use qsc_rir::{ + rir::{self, ConditionCode, FcmpConditionCode}, + utils::get_all_block_successors, +}; + +use qsc_llvm::model::Type; +use qsc_llvm::model::{ + Attribute, AttributeGroup, BasicBlock, BinOpKind, CastKind, Constant, FloatPredicate, Function, + GlobalVariable, Instruction, IntPredicate, Linkage, MetadataNode, MetadataValue, Module, + NamedMetadata, Operand, Param, +}; +use qsc_llvm::qir::{self, QirProfile}; + +/// Whether to use typed pointers (`%Qubit*`, `i8*`) or opaque pointers (`ptr`). +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum PointerStyle { + Typed, // v1: `%Qubit*`, `%Result*`, `i8*` + Opaque, // v2: `ptr` +} + +/// Build a complete `Module` from a RIR program for any QIR profile. +pub fn build_qir_module(program: &rir::Program, profile: QirProfile) -> Module { + let style = if profile.uses_typed_pointers() { + PointerStyle::Typed + } else { + PointerStyle::Opaque + }; + + let globals = build_globals(program); + let functions = build_functions(program, style); + let attribute_groups = build_attribute_groups(program, profile.profile_name()); + let (named_metadata, metadata_nodes) = build_metadata(program, profile); + + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: profile.struct_types(), + globals, + functions, + attribute_groups, + named_metadata, + metadata_nodes, + } +} + +pub fn rir_ty(ty: rir::Ty, style: PointerStyle) -> Type { + match ty { + rir::Ty::Boolean => Type::Integer(1), + rir::Ty::Double => Type::Double, + rir::Ty::Integer => Type::Integer(64), + rir::Ty::Pointer => match style { + PointerStyle::Typed => Type::TypedPtr(Box::new(Type::Integer(8))), + PointerStyle::Opaque => Type::Ptr, + }, + rir::Ty::Qubit => match style { + PointerStyle::Typed => Type::NamedPtr(qir::QUBIT_TYPE_NAME.into()), + PointerStyle::Opaque => Type::Ptr, + }, + rir::Ty::Result => match style { 
+ PointerStyle::Typed => Type::NamedPtr(qir::RESULT_TYPE_NAME.into()), + PointerStyle::Opaque => Type::Ptr, + }, + } +} + +pub fn rir_output_ty(ty: Option, style: PointerStyle) -> Option { + ty.map(|t| rir_ty(t, style)) +} + +fn var_name(id: rir::VariableId) -> String { + format!("var_{}", id.0) +} + +fn block_name(id: rir::BlockId) -> String { + format!("block_{}", id.0) +} + +/// Convert an RIR operand to an untyped `llvm_ir::Operand` (for use within binary ops, etc.). +fn operand_untyped(op: &rir::Operand, style: PointerStyle) -> Operand { + match op { + rir::Operand::Literal(lit) => literal_untyped(lit, style), + rir::Operand::Variable(var) => Operand::LocalRef(var_name(var.variable_id)), + } +} + +fn literal_untyped(lit: &rir::Literal, style: PointerStyle) -> Operand { + match lit { + rir::Literal::Bool(b) => Operand::IntConst(Type::Integer(1), i64::from(*b)), + rir::Literal::Double(d) => Operand::float_const(Type::Double, *d), + rir::Literal::Integer(i) => Operand::IntConst(Type::Integer(64), *i), + rir::Literal::Pointer => Operand::NullPtr, + rir::Literal::Qubit(q) => { + let q_i64 = i64::from(*q); + match style { + PointerStyle::Typed => { + Operand::IntToPtr(q_i64, Type::NamedPtr(qir::QUBIT_TYPE_NAME.into())) + } + PointerStyle::Opaque => Operand::IntToPtr(q_i64, Type::Ptr), + } + } + rir::Literal::Result(r) => { + let r_i64 = i64::from(*r); + match style { + PointerStyle::Typed => { + Operand::IntToPtr(r_i64, Type::NamedPtr(qir::RESULT_TYPE_NAME.into())) + } + PointerStyle::Opaque => Operand::IntToPtr(r_i64, Type::Ptr), + } + } + rir::Literal::Tag(idx, len) => { + let idx_i64 = i64::try_from(*idx).expect("tag index should fit in i64"); + let array_len = u64::try_from(*len).expect("tag length should fit in u64") + 1; // +1 for null terminator + match style { + PointerStyle::Typed => { + let arr_ty = Type::Array(array_len, Box::new(Type::Integer(8))); + Operand::GetElementPtr { + ty: arr_ty.clone(), + ptr: idx_i64.to_string(), + ptr_ty: 
Type::TypedPtr(Box::new(arr_ty)), + indices: vec![ + Operand::IntConst(Type::Integer(64), 0), + Operand::IntConst(Type::Integer(64), 0), + ], + } + } + PointerStyle::Opaque => Operand::GlobalRef(idx_i64.to_string()), + } + } + } +} + +/// Convert an RIR operand to a (Type, Operand) pair for use as a call argument. +fn call_arg(op: &rir::Operand, style: PointerStyle) -> (Type, Operand) { + match op { + rir::Operand::Literal(lit) => call_arg_literal(lit, style), + rir::Operand::Variable(var) => { + let ty = rir_ty(var.ty, style); + (ty, Operand::LocalRef(var_name(var.variable_id))) + } + } +} + +fn call_arg_literal(lit: &rir::Literal, style: PointerStyle) -> (Type, Operand) { + match lit { + rir::Literal::Bool(b) => ( + Type::Integer(1), + Operand::IntConst(Type::Integer(1), i64::from(*b)), + ), + rir::Literal::Double(d) => (Type::Double, Operand::float_const(Type::Double, *d)), + rir::Literal::Integer(i) => (Type::Integer(64), Operand::IntConst(Type::Integer(64), *i)), + rir::Literal::Pointer => match style { + PointerStyle::Typed => (Type::TypedPtr(Box::new(Type::Integer(8))), Operand::NullPtr), + PointerStyle::Opaque => (Type::Ptr, Operand::NullPtr), + }, + rir::Literal::Qubit(q) => { + let q_i64 = i64::from(*q); + match style { + PointerStyle::Typed => { + let ty = Type::NamedPtr(qir::QUBIT_TYPE_NAME.into()); + (ty.clone(), Operand::IntToPtr(q_i64, ty)) + } + PointerStyle::Opaque => (Type::Ptr, Operand::IntToPtr(q_i64, Type::Ptr)), + } + } + rir::Literal::Result(r) => { + let r_i64 = i64::from(*r); + match style { + PointerStyle::Typed => { + let ty = Type::NamedPtr(qir::RESULT_TYPE_NAME.into()); + (ty.clone(), Operand::IntToPtr(r_i64, ty)) + } + PointerStyle::Opaque => (Type::Ptr, Operand::IntToPtr(r_i64, Type::Ptr)), + } + } + rir::Literal::Tag(idx, len) => { + let idx_i64 = i64::try_from(*idx).expect("tag index should fit in i64"); + let array_len = u64::try_from(*len).expect("tag length should fit in u64") + 1; + match style { + PointerStyle::Typed => { + let 
arr_ty = Type::Array(array_len, Box::new(Type::Integer(8))); + ( + Type::TypedPtr(Box::new(Type::Integer(8))), + Operand::GetElementPtr { + ty: arr_ty.clone(), + ptr: idx_i64.to_string(), + ptr_ty: Type::TypedPtr(Box::new(arr_ty)), + indices: vec![ + Operand::IntConst(Type::Integer(64), 0), + Operand::IntConst(Type::Integer(64), 0), + ], + }, + ) + } + PointerStyle::Opaque => (Type::Ptr, Operand::GlobalRef(idx_i64.to_string())), + } + } + } +} + +fn operand_ir_ty(op: &rir::Operand, style: PointerStyle) -> Type { + match op { + rir::Operand::Literal(lit) => match lit { + rir::Literal::Bool(_) => Type::Integer(1), + rir::Literal::Double(_) => Type::Double, + rir::Literal::Integer(_) => Type::Integer(64), + rir::Literal::Pointer + | rir::Literal::Qubit(_) + | rir::Literal::Result(_) + | rir::Literal::Tag(..) => match style { + PointerStyle::Typed => match lit { + rir::Literal::Qubit(_) => Type::NamedPtr(qir::QUBIT_TYPE_NAME.into()), + rir::Literal::Result(_) => Type::NamedPtr(qir::RESULT_TYPE_NAME.into()), + _ => Type::TypedPtr(Box::new(Type::Integer(8))), + }, + PointerStyle::Opaque => Type::Ptr, + }, + }, + rir::Operand::Variable(var) => rir_ty(var.ty, style), + } +} + +fn convert_binop( + op: BinOpKind, + lhs: &rir::Operand, + rhs: &rir::Operand, + var: rir::Variable, + style: PointerStyle, +) -> Instruction { + Instruction::BinOp { + op, + ty: rir_ty(var.ty, style), + lhs: operand_untyped(lhs, style), + rhs: operand_untyped(rhs, style), + result: var_name(var.variable_id), + } +} + +#[allow(clippy::too_many_lines)] +pub fn convert_instruction( + instr: &rir::Instruction, + program: &rir::Program, + style: PointerStyle, +) -> Instruction { + match instr { + rir::Instruction::Add(lhs, rhs, var) => { + convert_binop(BinOpKind::Add, lhs, rhs, *var, style) + } + rir::Instruction::Ashr(lhs, rhs, var) => { + convert_binop(BinOpKind::Ashr, lhs, rhs, *var, style) + } + rir::Instruction::BitwiseAnd(lhs, rhs, var) + | rir::Instruction::LogicalAnd(lhs, rhs, var) => { + 
convert_binop(BinOpKind::And, lhs, rhs, *var, style) + } + rir::Instruction::BitwiseNot(value, var) => Instruction::BinOp { + op: BinOpKind::Xor, + ty: rir_ty(var.ty, style), + lhs: operand_untyped(value, style), + rhs: Operand::IntConst(Type::Integer(64), -1), + result: var_name(var.variable_id), + }, + rir::Instruction::BitwiseOr(lhs, rhs, var) | rir::Instruction::LogicalOr(lhs, rhs, var) => { + convert_binop(BinOpKind::Or, lhs, rhs, *var, style) + } + rir::Instruction::BitwiseXor(lhs, rhs, var) => { + convert_binop(BinOpKind::Xor, lhs, rhs, *var, style) + } + rir::Instruction::Branch(cond, true_id, false_id, _) => Instruction::Br { + cond_ty: rir_ty(cond.ty, style), + cond: Operand::LocalRef(var_name(cond.variable_id)), + true_dest: block_name(*true_id), + false_dest: block_name(*false_id), + }, + rir::Instruction::Call(call_id, args, output, _) => { + let callable = program.get_callable(*call_id); + Instruction::Call { + return_ty: rir_output_ty(callable.output_type, style), + callee: callable.name.clone(), + args: args.iter().map(|a| call_arg(a, style)).collect(), + result: output.map(|v| var_name(v.variable_id)), + attr_refs: vec![], + } + } + rir::Instruction::Convert(operand, var) => { + let from_ty = operand_ir_ty(operand, style); + let to_ty = rir_ty(var.ty, style); + let cast_op = match (&from_ty, &to_ty) { + (Type::Integer(64), Type::Double) => CastKind::Sitofp, + (Type::Double, Type::Integer(64)) => CastKind::Fptosi, + _ => panic!("unsupported conversion from {from_ty} to {to_ty}"), + }; + Instruction::Cast { + op: cast_op, + from_ty, + to_ty, + value: operand_untyped(operand, style), + result: var_name(var.variable_id), + } + } + rir::Instruction::Fadd(lhs, rhs, var) => { + convert_binop(BinOpKind::Fadd, lhs, rhs, *var, style) + } + rir::Instruction::Fdiv(lhs, rhs, var) => { + convert_binop(BinOpKind::Fdiv, lhs, rhs, *var, style) + } + rir::Instruction::Fmul(lhs, rhs, var) => { + convert_binop(BinOpKind::Fmul, lhs, rhs, *var, style) + } + 
rir::Instruction::Fsub(lhs, rhs, var) => { + convert_binop(BinOpKind::Fsub, lhs, rhs, *var, style) + } + rir::Instruction::Fcmp(op, lhs, rhs, var) => Instruction::FCmp { + pred: convert_fcmp(*op), + ty: operand_ir_ty(lhs, style), + lhs: operand_untyped(lhs, style), + rhs: operand_untyped(rhs, style), + result: var_name(var.variable_id), + }, + rir::Instruction::Icmp(op, lhs, rhs, var) => Instruction::ICmp { + pred: convert_icmp(*op), + ty: operand_ir_ty(lhs, style), + lhs: operand_untyped(lhs, style), + rhs: operand_untyped(rhs, style), + result: var_name(var.variable_id), + }, + rir::Instruction::Jump(block_id) => Instruction::Jump { + dest: block_name(*block_id), + }, + rir::Instruction::LogicalNot(value, var) => Instruction::BinOp { + op: BinOpKind::Xor, + ty: Type::Integer(1), + lhs: operand_untyped(value, style), + rhs: Operand::IntConst(Type::Integer(1), 1), + result: var_name(var.variable_id), + }, + rir::Instruction::Mul(lhs, rhs, var) => { + convert_binop(BinOpKind::Mul, lhs, rhs, *var, style) + } + rir::Instruction::Phi(args, var) => Instruction::Phi { + ty: rir_ty(var.ty, style), + incoming: args + .iter() + .map(|(op, bid)| (operand_untyped(op, style), block_name(*bid))) + .collect(), + result: var_name(var.variable_id), + }, + rir::Instruction::Return => Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))), + rir::Instruction::Sdiv(lhs, rhs, var) => { + convert_binop(BinOpKind::Sdiv, lhs, rhs, *var, style) + } + rir::Instruction::Shl(lhs, rhs, var) => { + convert_binop(BinOpKind::Shl, lhs, rhs, *var, style) + } + rir::Instruction::Srem(lhs, rhs, var) => { + convert_binop(BinOpKind::Srem, lhs, rhs, *var, style) + } + rir::Instruction::Store(operand, variable) => Instruction::Store { + ty: operand_ir_ty(operand, style), + value: operand_untyped(operand, style), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef(var_name(variable.variable_id)), + }, + rir::Instruction::Sub(lhs, rhs, var) => { + convert_binop(BinOpKind::Sub, lhs, rhs, *var, style) + 
} + rir::Instruction::Alloca(var) => Instruction::Alloca { + ty: rir_ty(var.ty, style), + result: var_name(var.variable_id), + }, + rir::Instruction::Load(var_from, var_to) => Instruction::Load { + ty: rir_ty(var_to.ty, style), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef(var_name(var_from.variable_id)), + result: var_name(var_to.variable_id), + }, + } +} + +fn convert_icmp(op: ConditionCode) -> IntPredicate { + match op { + ConditionCode::Eq => IntPredicate::Eq, + ConditionCode::Ne => IntPredicate::Ne, + ConditionCode::Sgt => IntPredicate::Sgt, + ConditionCode::Sge => IntPredicate::Sge, + ConditionCode::Slt => IntPredicate::Slt, + ConditionCode::Sle => IntPredicate::Sle, + } +} + +fn convert_fcmp(op: FcmpConditionCode) -> FloatPredicate { + match op { + FcmpConditionCode::False | FcmpConditionCode::True => { + panic!("unsupported fcmp predicate: {op}") + } + FcmpConditionCode::OrderedAndEqual => FloatPredicate::Oeq, + FcmpConditionCode::OrderedAndGreaterThan => FloatPredicate::Ogt, + FcmpConditionCode::OrderedAndGreaterThanOrEqual => FloatPredicate::Oge, + FcmpConditionCode::OrderedAndLessThan => FloatPredicate::Olt, + FcmpConditionCode::OrderedAndLessThanOrEqual => FloatPredicate::Ole, + FcmpConditionCode::OrderedAndNotEqual => FloatPredicate::One, + FcmpConditionCode::Ordered => FloatPredicate::Ord, + FcmpConditionCode::UnorderedOrEqual => FloatPredicate::Ueq, + FcmpConditionCode::UnorderedOrGreaterThan => FloatPredicate::Ugt, + FcmpConditionCode::UnorderedOrGreaterThanOrEqual => FloatPredicate::Uge, + FcmpConditionCode::UnorderedOrLessThan => FloatPredicate::Ult, + FcmpConditionCode::UnorderedOrLessThanOrEqual => FloatPredicate::Ule, + FcmpConditionCode::UnorderedOrNotEqual => FloatPredicate::Une, + FcmpConditionCode::Unordered => FloatPredicate::Uno, + } +} + +pub fn build_globals(program: &rir::Program) -> Vec { + program + .tags + .iter() + .enumerate() + .map(|(idx, tag)| { + let array_len = u64::try_from(tag.len() + 1).expect("tag length should fit in 
u64"); + GlobalVariable { + name: idx.to_string(), + ty: Type::Array(array_len, Box::new(Type::Integer(8))), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::CString(tag.clone())), + } + }) + .collect() +} + +pub fn build_functions(program: &rir::Program, style: PointerStyle) -> Vec { + let mut declarations = Vec::new(); + let mut definitions = Vec::new(); + + for (_, callable) in program.callables.iter() { + if callable.body.is_some() { + definitions.push(build_definition(callable, program, style)); + } else { + declarations.push(build_declaration(callable, style)); + } + } + + // Definitions first, then declarations (matching the original QIR generator ordering) + definitions.extend(declarations); + definitions +} + +fn build_declaration(callable: &rir::Callable, style: PointerStyle) -> Function { + let attr_refs = match callable.call_type { + rir::CallableType::Measurement | rir::CallableType::Reset => vec![1], + rir::CallableType::NoiseIntrinsic => vec![2], + _ => vec![], + }; + + Function { + name: callable.name.clone(), + return_type: rir_output_ty(callable.output_type, style).unwrap_or(Type::Void), + params: callable + .input_type + .iter() + .map(|t| Param { + ty: rir_ty(*t, style), + name: None, + }) + .collect(), + is_declaration: true, + attribute_group_refs: attr_refs, + basic_blocks: vec![], + } +} + +fn build_definition( + callable: &rir::Callable, + program: &rir::Program, + style: PointerStyle, +) -> Function { + let entry_id = callable.body.expect("definition should have a body"); + + let mut all_blocks = vec![entry_id]; + all_blocks.extend(get_all_block_successors(entry_id, program)); + + let basic_blocks = all_blocks + .iter() + .map(|&bid| { + let block = program.get_block(bid); + BasicBlock { + name: block_name(bid), + instructions: block + .0 + .iter() + .map(|i| convert_instruction(i, program, style)) + .collect(), + } + }) + .collect(); + + Function { + name: qir::ENTRYPOINT_NAME.into(), + return_type: 
Type::Integer(64), + params: vec![], + is_declaration: false, + attribute_group_refs: vec![0], + basic_blocks, + } +} + +pub fn build_attribute_groups(program: &rir::Program, profile: &str) -> Vec { + let mut groups = vec![ + AttributeGroup { + id: 0, + attributes: vec![ + Attribute::StringAttr(qir::ENTRY_POINT_ATTR.into()), + Attribute::StringAttr(qir::OUTPUT_LABELING_SCHEMA_ATTR.into()), + Attribute::KeyValue(qir::QIR_PROFILES_ATTR.into(), profile.into()), + Attribute::KeyValue( + qir::REQUIRED_NUM_QUBITS_ATTR.into(), + program.num_qubits.to_string(), + ), + Attribute::KeyValue( + qir::REQUIRED_NUM_RESULTS_ATTR.into(), + program.num_results.to_string(), + ), + ], + }, + AttributeGroup { + id: 1, + attributes: vec![Attribute::StringAttr(qir::IRREVERSIBLE_ATTR.into())], + }, + ]; + + if program.attrs.contains(Attributes::QdkNoise) { + groups.push(AttributeGroup { + id: 2, + attributes: vec![Attribute::StringAttr(qir::QDK_NOISE_ATTR.into())], + }); + } + + groups +} + +/// Build module flags metadata for any QIR profile. +/// +/// The metadata encodes: +/// - `qir_major_version` / `qir_minor_version` (from `profile`) +/// - `dynamic_qubit_management` / `dynamic_result_management` (always false for now) +/// - Optional capability flags (int computations, float computations, backwards branching, arrays) +/// based on `profile` and `program.config.capabilities`. 
+#[allow(clippy::too_many_lines)] +pub fn build_metadata( + program: &rir::Program, + profile: QirProfile, +) -> (Vec, Vec) { + let mut nodes = vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String(qir::QIR_MAJOR_VERSION_KEY.into()), + MetadataValue::Int(Type::Integer(32), profile.major_version()), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), 7), + MetadataValue::String(qir::QIR_MINOR_VERSION_KEY.into()), + MetadataValue::Int(Type::Integer(32), profile.minor_version()), + ], + }, + MetadataNode { + id: 2, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String(qir::DYNAMIC_QUBIT_MGMT_KEY.into()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }, + MetadataNode { + id: 3, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String(qir::DYNAMIC_RESULT_MGMT_KEY.into()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }, + ]; + + let mut next_id: u32 = 4; + + // For v1 profiles, capabilities come from the program config. + // For v2/v2.1 profiles, all capabilities are always emitted. + match profile { + QirProfile::BaseV1 => { + // Base profile: no capability metadata beyond version + dynamic mgmt. + } + QirProfile::AdaptiveV1 => { + // Adaptive v1: emit capabilities from the program config. 
+ for cap in program.config.capabilities.iter() { + match cap { + TargetCapabilityFlags::IntegerComputations => { + nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String(qir::INT_COMPUTATIONS_KEY.into()), + MetadataValue::SubList(vec![MetadataValue::String("i64".into())]), + ], + }); + next_id += 1; + } + TargetCapabilityFlags::FloatingPointComputations => { + nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String(qir::FLOAT_COMPUTATIONS_KEY.into()), + MetadataValue::SubList(vec![MetadataValue::String( + "double".into(), + )]), + ], + }); + next_id += 1; + } + _ => {} + } + } + } + QirProfile::AdaptiveV2 => { + // Adaptive v2/v2.1: always emit all capability metadata. + nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String(qir::INT_COMPUTATIONS_KEY.into()), + MetadataValue::SubList(vec![MetadataValue::String("i64".into())]), + ], + }); + next_id += 1; + nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String(qir::FLOAT_COMPUTATIONS_KEY.into()), + MetadataValue::SubList(vec![MetadataValue::String("double".into())]), + ], + }); + next_id += 1; + nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), 7), + MetadataValue::String(qir::BACKWARDS_BRANCHING_KEY.into()), + MetadataValue::Int(Type::Integer(2), 3), + ], + }); + next_id += 1; + nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String(qir::ARRAYS_KEY.into()), + MetadataValue::Int(Type::Integer(1), 1), + ], + }); + next_id += 1; + } + } + + let node_refs: Vec = (0..next_id).collect(); + let named = vec![NamedMetadata { + name: qir::MODULE_FLAGS_NAME.into(), + node_refs, + }]; + + (named, nodes) +} diff --git 
a/source/compiler/qsc_codegen/src/qir/common/tests.rs b/source/compiler/qsc_codegen/src/qir/common/tests.rs new file mode 100644 index 0000000000..e582460225 --- /dev/null +++ b/source/compiler/qsc_codegen/src/qir/common/tests.rs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use qsc_rir::rir::{self, Operand, Ty, Variable, VariableId}; + +fn bool_var(id: u32) -> Variable { + Variable { + variable_id: VariableId(id), + ty: Ty::Boolean, + } +} + +fn int_var(id: u32) -> Variable { + Variable { + variable_id: VariableId(id), + ty: Ty::Integer, + } +} + +fn bool_ref(id: u32) -> Operand { + Operand::Variable(bool_var(id)) +} + +fn int_ref(id: u32) -> Operand { + Operand::Variable(int_var(id)) +} + +/// Logical AND operates on `i1` (boolean) operands. +/// Bitwise AND operates on `i64` (integer) operands. +/// Both map to LLVM `and` — the type is what distinguishes them. +#[test] +fn logical_and_produces_i1_type() { + let instr = rir::Instruction::LogicalAnd(bool_ref(0), bool_ref(1), bool_var(2)); + let result = convert_instruction(&instr, &rir::Program::default(), PointerStyle::Opaque); + match result { + Instruction::BinOp { op, ty, .. } => { + assert_eq!(op, BinOpKind::And); + assert_eq!(ty, Type::Integer(1)); + } + other => panic!("expected BinOp, got {other:?}"), + } +} + +#[test] +fn bitwise_and_produces_i64_type() { + let instr = rir::Instruction::BitwiseAnd(int_ref(0), int_ref(1), int_var(2)); + let result = convert_instruction(&instr, &rir::Program::default(), PointerStyle::Opaque); + match result { + Instruction::BinOp { op, ty, .. 
} => { + assert_eq!(op, BinOpKind::And); + assert_eq!(ty, Type::Integer(64)); + } + other => panic!("expected BinOp, got {other:?}"), + } +} + +#[test] +fn logical_or_produces_i1_type() { + let instr = rir::Instruction::LogicalOr(bool_ref(0), bool_ref(1), bool_var(2)); + let result = convert_instruction(&instr, &rir::Program::default(), PointerStyle::Opaque); + match result { + Instruction::BinOp { op, ty, .. } => { + assert_eq!(op, BinOpKind::Or); + assert_eq!(ty, Type::Integer(1)); + } + other => panic!("expected BinOp, got {other:?}"), + } +} + +#[test] +fn bitwise_or_produces_i64_type() { + let instr = rir::Instruction::BitwiseOr(int_ref(0), int_ref(1), int_var(2)); + let result = convert_instruction(&instr, &rir::Program::default(), PointerStyle::Opaque); + match result { + Instruction::BinOp { op, ty, .. } => { + assert_eq!(op, BinOpKind::Or); + assert_eq!(ty, Type::Integer(64)); + } + other => panic!("expected BinOp, got {other:?}"), + } +} + +/// Logical NOT is `xor i1 %val, true` (flip a boolean). +/// Bitwise NOT is `xor i64 %val, -1` (flip all 64 bits). +#[test] +fn logical_not_produces_xor_i1_with_true() { + let instr = rir::Instruction::LogicalNot(bool_ref(0), bool_var(1)); + let result = convert_instruction(&instr, &rir::Program::default(), PointerStyle::Opaque); + match result { + Instruction::BinOp { op, ty, rhs, .. } => { + assert_eq!(op, BinOpKind::Xor); + assert_eq!(ty, Type::Integer(1)); + assert_eq!(rhs, super::Operand::IntConst(Type::Integer(1), 1)); + } + other => panic!("expected BinOp, got {other:?}"), + } +} + +#[test] +fn bitwise_not_produces_xor_i64_with_minus_one() { + let instr = rir::Instruction::BitwiseNot(int_ref(0), int_var(1)); + let result = convert_instruction(&instr, &rir::Program::default(), PointerStyle::Opaque); + match result { + Instruction::BinOp { op, ty, rhs, .. 
} => { + assert_eq!(op, BinOpKind::Xor); + assert_eq!(ty, Type::Integer(64)); + assert_eq!(rhs, super::Operand::IntConst(Type::Integer(64), -1)); + } + other => panic!("expected BinOp, got {other:?}"), + } +} + +#[test] +fn bitwise_xor_produces_i64_type() { + let instr = rir::Instruction::BitwiseXor(int_ref(0), int_ref(1), int_var(2)); + let result = convert_instruction(&instr, &rir::Program::default(), PointerStyle::Opaque); + match result { + Instruction::BinOp { op, ty, .. } => { + assert_eq!(op, BinOpKind::Xor); + assert_eq!(ty, Type::Integer(64)); + } + other => panic!("expected BinOp, got {other:?}"), + } +} diff --git a/source/compiler/qsc_codegen/src/qir/v1.rs b/source/compiler/qsc_codegen/src/qir/v1.rs deleted file mode 100644 index 8ed60ba29b..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v1.rs +++ /dev/null @@ -1,737 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#[cfg(test)] -mod instruction_tests; - -#[cfg(test)] -mod tests; - -use qsc_data_structures::{attrs::Attributes, target::TargetCapabilityFlags}; -use qsc_rir::{ - rir::{self, ConditionCode, FcmpConditionCode}, - utils::get_all_block_successors, -}; -use std::fmt::Write; - -/// A trait for converting a type into QIR of type `T`. -/// This can be used to generate QIR strings or other representations. -pub trait ToQir { - fn to_qir(&self, program: &rir::Program) -> T; -} - -impl ToQir for rir::Literal { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::Literal::Bool(b) => format!("i1 {b}"), - rir::Literal::Double(d) => { - if (d.floor() - d.ceil()).abs() < f64::EPSILON { - // The value is a whole number, which requires at least one decimal point - // to differentiate it from an integer value. 
- format!("double {d:.1}") - } else { - format!("double {d}") - } - } - rir::Literal::Integer(i) => format!("i64 {i}"), - rir::Literal::Pointer => "i8* null".to_string(), - rir::Literal::Qubit(q) => format!("%Qubit* inttoptr (i64 {q} to %Qubit*)"), - rir::Literal::Result(r) => format!("%Result* inttoptr (i64 {r} to %Result*)"), - rir::Literal::Tag(idx, len) => { - let len = len + 1; // +1 for the null terminator - format!( - "i8* getelementptr inbounds ([{len} x i8], [{len} x i8]* @{idx}, i64 0, i64 0)" - ) - } - } - } -} - -impl ToQir for rir::Ty { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::Ty::Boolean => "i1".to_string(), - rir::Ty::Double => "double".to_string(), - rir::Ty::Integer => "i64".to_string(), - rir::Ty::Pointer => "i8*".to_string(), - rir::Ty::Qubit => "%Qubit*".to_string(), - rir::Ty::Result => "%Result*".to_string(), - } - } -} - -impl ToQir for Option { - fn to_qir(&self, program: &rir::Program) -> String { - match self { - Some(ty) => ToQir::::to_qir(ty, program), - None => "void".to_string(), - } - } -} - -impl ToQir for rir::VariableId { - fn to_qir(&self, _program: &rir::Program) -> String { - format!("%var_{}", self.0) - } -} - -impl ToQir for rir::Variable { - fn to_qir(&self, program: &rir::Program) -> String { - format!( - "{} {}", - ToQir::::to_qir(&self.ty, program), - ToQir::::to_qir(&self.variable_id, program) - ) - } -} - -impl ToQir for rir::Operand { - fn to_qir(&self, program: &rir::Program) -> String { - match self { - rir::Operand::Literal(lit) => ToQir::::to_qir(lit, program), - rir::Operand::Variable(var) => ToQir::::to_qir(var, program), - } - } -} - -impl ToQir for rir::FcmpConditionCode { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::FcmpConditionCode::False => "false".to_string(), - rir::FcmpConditionCode::OrderedAndEqual => "oeq".to_string(), - rir::FcmpConditionCode::OrderedAndGreaterThan => "ogt".to_string(), - 
rir::FcmpConditionCode::OrderedAndGreaterThanOrEqual => "oge".to_string(), - rir::FcmpConditionCode::OrderedAndLessThan => "olt".to_string(), - rir::FcmpConditionCode::OrderedAndLessThanOrEqual => "ole".to_string(), - rir::FcmpConditionCode::OrderedAndNotEqual => "one".to_string(), - rir::FcmpConditionCode::Ordered => "ord".to_string(), - rir::FcmpConditionCode::UnorderedOrEqual => "ueq".to_string(), - rir::FcmpConditionCode::UnorderedOrGreaterThan => "ugt".to_string(), - rir::FcmpConditionCode::UnorderedOrGreaterThanOrEqual => "uge".to_string(), - rir::FcmpConditionCode::UnorderedOrLessThan => "ult".to_string(), - rir::FcmpConditionCode::UnorderedOrLessThanOrEqual => "ule".to_string(), - rir::FcmpConditionCode::UnorderedOrNotEqual => "une".to_string(), - rir::FcmpConditionCode::Unordered => "uno".to_string(), - rir::FcmpConditionCode::True => "true".to_string(), - } - } -} - -impl ToQir for rir::ConditionCode { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::ConditionCode::Eq => "eq".to_string(), - rir::ConditionCode::Ne => "ne".to_string(), - rir::ConditionCode::Sgt => "sgt".to_string(), - rir::ConditionCode::Sge => "sge".to_string(), - rir::ConditionCode::Slt => "slt".to_string(), - rir::ConditionCode::Sle => "sle".to_string(), - } - } -} - -impl ToQir for rir::Instruction { - fn to_qir(&self, program: &rir::Program) -> String { - match self { - rir::Instruction::Add(lhs, rhs, variable) => { - binop_to_qir("add", lhs, rhs, *variable, program) - } - rir::Instruction::Ashr(lhs, rhs, variable) => { - binop_to_qir("ashr", lhs, rhs, *variable, program) - } - rir::Instruction::BitwiseAnd(lhs, rhs, variable) => { - simple_bitwise_to_qir("and", lhs, rhs, *variable, program) - } - rir::Instruction::BitwiseNot(value, variable) => { - bitwise_not_to_qir(value, *variable, program) - } - rir::Instruction::BitwiseOr(lhs, rhs, variable) => { - simple_bitwise_to_qir("or", lhs, rhs, *variable, program) - } - rir::Instruction::BitwiseXor(lhs, rhs, 
variable) => { - simple_bitwise_to_qir("xor", lhs, rhs, *variable, program) - } - rir::Instruction::Branch(cond, true_id, false_id, _) => { - format!( - " br {}, label %{}, label %{}", - ToQir::::to_qir(cond, program), - ToQir::::to_qir(true_id, program), - ToQir::::to_qir(false_id, program) - ) - } - rir::Instruction::Call(call_id, args, output, _) => { - call_to_qir(args, *call_id, *output, program) - } - rir::Instruction::Convert(operand, variable) => { - convert_to_qir(operand, *variable, program) - } - rir::Instruction::Fadd(lhs, rhs, variable) => { - fbinop_to_qir("fadd", lhs, rhs, *variable, program) - } - rir::Instruction::Fdiv(lhs, rhs, variable) => { - fbinop_to_qir("fdiv", lhs, rhs, *variable, program) - } - rir::Instruction::Fmul(lhs, rhs, variable) => { - fbinop_to_qir("fmul", lhs, rhs, *variable, program) - } - rir::Instruction::Fsub(lhs, rhs, variable) => { - fbinop_to_qir("fsub", lhs, rhs, *variable, program) - } - rir::Instruction::LogicalAnd(lhs, rhs, variable) => { - logical_binop_to_qir("and", lhs, rhs, *variable, program) - } - rir::Instruction::LogicalNot(value, variable) => { - logical_not_to_qir(value, *variable, program) - } - rir::Instruction::LogicalOr(lhs, rhs, variable) => { - logical_binop_to_qir("or", lhs, rhs, *variable, program) - } - rir::Instruction::Mul(lhs, rhs, variable) => { - binop_to_qir("mul", lhs, rhs, *variable, program) - } - rir::Instruction::Fcmp(op, lhs, rhs, variable) => { - fcmp_to_qir(*op, lhs, rhs, *variable, program) - } - rir::Instruction::Icmp(op, lhs, rhs, variable) => { - icmp_to_qir(*op, lhs, rhs, *variable, program) - } - rir::Instruction::Jump(block_id) => { - format!(" br label %{}", ToQir::::to_qir(block_id, program)) - } - rir::Instruction::Phi(args, variable) => phi_to_qir(args, *variable, program), - rir::Instruction::Return => " ret i64 0".to_string(), - rir::Instruction::Sdiv(lhs, rhs, variable) => { - binop_to_qir("sdiv", lhs, rhs, *variable, program) - } - rir::Instruction::Shl(lhs, rhs, variable) 
=> { - binop_to_qir("shl", lhs, rhs, *variable, program) - } - rir::Instruction::Srem(lhs, rhs, variable) => { - binop_to_qir("srem", lhs, rhs, *variable, program) - } - rir::Instruction::Store(_, _) => unimplemented!("store should be removed by pass"), - rir::Instruction::Sub(lhs, rhs, variable) => { - binop_to_qir("sub", lhs, rhs, *variable, program) - } - rir::Instruction::Alloca(..) | rir::Instruction::Load(..) => { - unimplemented!("advanced instructions are not supported in QIR v1 generation") - } - } - } -} - -fn convert_to_qir( - operand: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let operand_ty = get_value_ty(operand); - let var_ty = get_variable_ty(variable); - assert_ne!( - operand_ty, var_ty, - "input/output types ({operand_ty}, {var_ty}) should not match in convert" - ); - - let convert_instr = match (operand_ty, var_ty) { - ("i64", "double") => "sitofp i64", - ("double", "i64") => "fptosi double", - _ => panic!("unsupported conversion from {operand_ty} to {var_ty} in convert instruction"), - }; - - format!( - " {} = {convert_instr} {} to {var_ty}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(operand, program), - ) -} - -fn logical_not_to_qir( - value: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let value_ty = get_value_ty(value); - let var_ty = get_variable_ty(variable); - assert_eq!( - value_ty, var_ty, - "mismatched input/output types ({value_ty}, {var_ty}) for not" - ); - assert_eq!(var_ty, "i1", "unsupported type {var_ty} for not"); - - format!( - " {} = xor i1 {}, true", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(value, program) - ) -} - -fn logical_binop_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - 
lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "i1", "unsupported type {var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn bitwise_not_to_qir( - value: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let value_ty = get_value_ty(value); - let var_ty = get_variable_ty(variable); - assert_eq!( - value_ty, var_ty, - "mismatched input/output types ({value_ty}, {var_ty}) for not" - ); - assert_eq!(var_ty, "i64", "unsupported type {var_ty} for not"); - - format!( - " {} = xor {var_ty} {}, -1", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(value, program) - ) -} - -fn call_to_qir( - args: &[rir::Operand], - call_id: rir::CallableId, - output: Option, - program: &rir::Program, -) -> String { - let args = args - .iter() - .map(|arg| ToQir::::to_qir(arg, program)) - .collect::>() - .join(", "); - let callable = program.get_callable(call_id); - if let Some(output) = output { - format!( - " {} = call {} @{}({args})", - ToQir::::to_qir(&output.variable_id, program), - ToQir::::to_qir(&callable.output_type, program), - callable.name - ) - } else { - format!( - " call {} @{}({args})", - ToQir::::to_qir(&callable.output_type, program), - callable.name - ) - } -} - -fn fcmp_to_qir( - op: FcmpConditionCode, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for fcmp {op}" - ); - - assert_eq!(var_ty, "i1", "unsupported output type {var_ty} for fcmp"); - format!( - " {} = fcmp {} {lhs_ty} 
{}, {}", - ToQir::::to_qir(&variable.variable_id, program), - ToQir::::to_qir(&op, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn icmp_to_qir( - op: ConditionCode, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for icmp {op}" - ); - - assert_eq!(var_ty, "i1", "unsupported output type {var_ty} for icmp"); - format!( - " {} = icmp {} {lhs_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - ToQir::::to_qir(&op, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn binop_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "i64", "unsupported type {var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn fbinop_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "double", "unsupported type 
{var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn simple_bitwise_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "i64", "unsupported type {var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn phi_to_qir( - args: &[(rir::Operand, rir::BlockId)], - variable: rir::Variable, - program: &rir::Program, -) -> String { - assert!( - !args.is_empty(), - "phi instruction should have at least one argument" - ); - let var_ty = get_variable_ty(variable); - let args = args - .iter() - .map(|(arg, block_id)| { - let arg_ty = get_value_ty(arg); - assert_eq!( - arg_ty, var_ty, - "mismatched types ({var_ty} [... {arg_ty}]) for phi" - ); - format!( - "[{}, %{}]", - get_value_as_str(arg, program), - ToQir::::to_qir(block_id, program) - ) - }) - .collect::>() - .join(", "); - - format!( - " {} = phi {var_ty} {args}", - ToQir::::to_qir(&variable.variable_id, program) - ) -} - -fn get_value_as_str(value: &rir::Operand, program: &rir::Program) -> String { - match value { - rir::Operand::Literal(lit) => match lit { - rir::Literal::Bool(b) => format!("{b}"), - rir::Literal::Double(d) => { - if (d.floor() - d.ceil()).abs() < f64::EPSILON { - // The value is a whole number, which requires at least one decimal point - // to differentiate it from an integer value. 
- format!("{d:.1}") - } else { - format!("{d}") - } - } - rir::Literal::Integer(i) => format!("{i}"), - rir::Literal::Pointer => "null".to_string(), - rir::Literal::Qubit(q) => format!("{q}"), - rir::Literal::Result(r) => format!("{r}"), - rir::Literal::Tag(..) => panic!( - "tag literals should not be used as string values outside of output recording" - ), - }, - rir::Operand::Variable(var) => ToQir::::to_qir(&var.variable_id, program), - } -} - -fn get_value_ty(lhs: &rir::Operand) -> &str { - match lhs { - rir::Operand::Literal(lit) => match lit { - rir::Literal::Integer(_) => "i64", - rir::Literal::Bool(_) => "i1", - rir::Literal::Double(_) => get_f64_ty(), - rir::Literal::Qubit(_) => "%Qubit*", - rir::Literal::Result(_) => "%Result*", - rir::Literal::Pointer | rir::Literal::Tag(..) => "i8*", - }, - rir::Operand::Variable(var) => get_variable_ty(*var), - } -} - -fn get_variable_ty(variable: rir::Variable) -> &'static str { - match variable.ty { - rir::Ty::Integer => "i64", - rir::Ty::Boolean => "i1", - rir::Ty::Double => get_f64_ty(), - rir::Ty::Qubit => "%Qubit*", - rir::Ty::Result => "%Result*", - rir::Ty::Pointer => "i8*", - } -} - -/// phi only supports "Floating-Point Types" which are defined as: -/// - `half` (`f16`) -/// - `bfloat` -/// - `float` (`f32`) -/// - `double` (`f64`) -/// - `fp128` -/// -/// We only support `f64`, so we break the pattern used for integers -/// and have to use `double` here. -/// -/// This conflicts with the QIR spec which says f64. Need to follow up on this. 
-fn get_f64_ty() -> &'static str { - "double" -} - -impl ToQir for rir::BlockId { - fn to_qir(&self, _program: &rir::Program) -> String { - format!("block_{}", self.0) - } -} - -impl ToQir for rir::Block { - fn to_qir(&self, program: &rir::Program) -> String { - self.0 - .iter() - .map(|instr| ToQir::::to_qir(instr, program)) - .collect::>() - .join("\n") - } -} - -impl ToQir for rir::Callable { - fn to_qir(&self, program: &rir::Program) -> String { - let input_type = self - .input_type - .iter() - .map(|t| ToQir::::to_qir(t, program)) - .collect::>() - .join(", "); - let output_type = ToQir::::to_qir(&self.output_type, program); - let Some(entry_id) = self.body else { - return format!( - "declare {output_type} @{}({input_type}){}", - self.name, - match self.call_type { - rir::CallableType::Measurement | rir::CallableType::Reset => { - // These callables are a special case that need the irreversible attribute. - " #1" - } - rir::CallableType::NoiseIntrinsic => " #2", - _ => "", - } - ); - }; - let mut body = String::new(); - let mut all_blocks = vec![entry_id]; - all_blocks.extend(get_all_block_successors(entry_id, program)); - for block_id in all_blocks { - let block = program.get_block(block_id); - write!( - body, - "{}:\n{}\n", - ToQir::::to_qir(&block_id, program), - ToQir::::to_qir(block, program) - ) - .expect("writing to string should succeed"); - } - assert!( - input_type.is_empty(), - "entry point should not have an input" - ); - format!("define {output_type} @ENTRYPOINT__main() #0 {{\n{body}}}",) - } -} - -impl ToQir for rir::Program { - fn to_qir(&self, _program: &rir::Program) -> String { - let callables = self - .callables - .iter() - .map(|(_, callable)| ToQir::::to_qir(callable, self)) - .collect::>() - .join("\n\n"); - let profile = if self.config.is_base() { - "base_profile" - } else { - "adaptive_profile" - }; - let mut constants = String::default(); - for (idx, tag) in self.tags.iter().enumerate() { - // We need to add the tag as a global 
constant. - writeln!( - constants, - "@{idx} = internal constant [{} x i8] c\"{tag}\\00\"", - tag.len() + 1 - ) - .expect("writing to string should succeed"); - } - let body = format!( - include_str!("./v1/template.ll"), - constants, - callables, - profile, - self.num_qubits, - self.num_results, - get_additional_module_attributes(self) - ); - let flags = get_module_metadata(self); - body + "\n" + &flags - } -} - -fn get_additional_module_attributes(program: &rir::Program) -> String { - let mut attrs = String::new(); - if program.attrs.contains(Attributes::QdkNoise) { - attrs.push_str("\nattributes #2 = { \"qdk_noise\" }"); - } - - attrs -} - -/// Create the module metadata for the given program. -/// creating the `llvm.module.flags` and its associated values. -fn get_module_metadata(program: &rir::Program) -> String { - let mut flags = String::new(); - - // push the default attrs, we don't have any config values - // for now that would change any of them. - flags.push_str( - r#" -!0 = !{i32 1, !"qir_major_version", i32 1} -!1 = !{i32 7, !"qir_minor_version", i32 0} -!2 = !{i32 1, !"dynamic_qubit_management", i1 false} -!3 = !{i32 1, !"dynamic_result_management", i1 false} -"#, - ); - - let mut index = 4; - - // If we are not in the base profile, we need to add the capabilities - // associated with the adaptive profile. - if !program.config.is_base() { - // loop through the capabilities and add them to the metadata - // for values that we can generate. - for cap in program.config.capabilities.iter() { - match cap { - TargetCapabilityFlags::IntegerComputations => { - // Use `5` as the flag to signify "Append" mode. See https://llvm.org/docs/LangRef.html#module-flags-metadata - writeln!( - flags, - "!{index} = !{{i32 5, !\"int_computations\", !{{!\"i64\"}}}}", - ) - .expect("writing to string should succeed"); - index += 1; - } - TargetCapabilityFlags::FloatingPointComputations => { - // Use `5` as the flag to signify "Append" mode. 
See https://llvm.org/docs/LangRef.html#module-flags-metadata - writeln!( - flags, - "!{index} = !{{i32 5, !\"float_computations\", !{{!\"double\"}}}}", - ) - .expect("writing to string should succeed"); - index += 1; - } - _ => {} - } - } - } - - let mut metadata_def = String::new(); - metadata_def.push_str("!llvm.module.flags = !{"); - for i in 0..index - 1 { - write!(metadata_def, "!{i}, ").expect("writing to string should succeed"); - } - writeln!(metadata_def, "!{}}}", index - 1).expect("writing to string should succeed"); - metadata_def + &flags -} diff --git a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/bool.rs b/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/bool.rs deleted file mode 100644 index a2968d9882..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/bool.rs +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use super::super::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -fn logical_and_literals() { - let inst = rir::Instruction::LogicalAnd( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = and i1 true, false"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_and_variables() { - let inst = rir::Instruction::LogicalAnd( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Boolean, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Boolean, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = and i1 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_not_true_literal() { - let inst = rir::Instruction::LogicalNot( - 
rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = xor i1 true, true"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_not_variables() { - let inst = rir::Instruction::LogicalNot( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Boolean, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = xor i1 %var_1, true"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_not_false_literal() { - let inst = rir::Instruction::LogicalNot( - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = xor i1 false, true"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_or_literals() { - let inst = rir::Instruction::LogicalOr( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = or i1 true, false"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_or_variables() { - let inst = rir::Instruction::LogicalOr( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Boolean, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Boolean, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = or i1 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/double.rs b/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/double.rs deleted file mode 100644 index 649ce98f6e..0000000000 --- 
a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/double.rs +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use core::f64::consts::{E, PI}; - -use super::super::ToQir; -use expect_test::expect; -use qsc_rir::rir::{ - FcmpConditionCode, Instruction, Literal, Operand, Program, Ty, Variable, VariableId, -}; - -#[test] -#[should_panic(expected = "unsupported type double for add")] -fn add_double_literals() { - let inst = Instruction::Add( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for sub")] -fn sub_double_literals() { - let inst = Instruction::Sub( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for mul")] -fn mul_double_literals() { - let inst = Instruction::Mul( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for sdiv")] -fn sdiv_double_literals() { - let inst = Instruction::Sdiv( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -fn fadd_double_literals() { - let inst = Instruction::Fadd( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fadd double 3.141592653589793, 
2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -#[should_panic(expected = "unsupported type double for ashr")] -fn ashr_double_literals() { - let inst = Instruction::Ashr( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for and")] -fn bitwise_and_double_literals() { - let inst = Instruction::BitwiseAnd( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for not")] -fn bitwise_not_double_literals() { - let inst = Instruction::BitwiseNot( - Operand::Literal(Literal::Double(PI)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for or")] -fn bitwise_or_double_literals() { - let inst = Instruction::BitwiseOr( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for xor")] -fn bitwise_xor_double_literals() { - let inst = Instruction::BitwiseXor( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -fn fadd_double_variables() { - let inst = Instruction::Fadd( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), 
- Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fadd double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_oeq_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oeq double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_oeq_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oeq double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_one_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndNotEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp one double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_one_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndNotEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp one double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_olt_double_literals() { - let inst = Instruction::Fcmp( - 
FcmpConditionCode::OrderedAndLessThan, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp olt double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_olt_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndLessThan, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp olt double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_ole_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndLessThanOrEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ole double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_ole_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndLessThanOrEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ole double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_ogt_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThan, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ogt double 3.141592653589793, 
2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_ogt_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThan, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ogt double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_oge_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThanOrEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oge double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_oge_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThanOrEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oge double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fmul_double_literals() { - let inst = Instruction::Fmul( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fmul double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fmul_double_variables() { - let inst = Instruction::Fmul( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: 
Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fmul double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fdiv_double_literals() { - let inst = Instruction::Fdiv( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fdiv double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fdiv_double_variables() { - let inst = Instruction::Fdiv( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fdiv double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fsub_double_literals() { - let inst = Instruction::Fsub( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fsub double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fsub_double_variables() { - let inst = Instruction::Fsub( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fsub double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn convert_double_literal_to_integer() { - let inst = Instruction::Convert( - Operand::Literal(Literal::Double(PI)), - Variable { - variable_id: VariableId(0), - ty: Ty::Integer, - }, - ); - expect![" %var_0 = fptosi double 3.141592653589793 to i64"] - 
.assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn convert_double_variable_to_integer() { - let inst = Instruction::Convert( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Integer, - }, - ); - expect![" %var_0 = fptosi double %var_1 to i64"].assert_eq(&inst.to_qir(&Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/int.rs b/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/int.rs deleted file mode 100644 index fe2c725544..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/int.rs +++ /dev/null @@ -1,587 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use super::super::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -fn add_integer_literals() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = add i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn add_integer_variables() { - let inst = rir::Instruction::Add( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = add i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn ashr_integer_literals() { - let inst = rir::Instruction::Ashr( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = ashr i64 2, 
5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn ashr_integer_variables() { - let inst = rir::Instruction::Ashr( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = ashr i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_and_integer_literals() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = and i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_add_integer_variables() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = and i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_not_integer_literals() { - let inst = rir::Instruction::BitwiseNot( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = xor i64 2, -1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_not_integer_variables() { - let inst = rir::Instruction::BitwiseNot( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - 
); - expect![" %var_0 = xor i64 %var_1, -1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_or_integer_literals() { - let inst = rir::Instruction::BitwiseOr( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = or i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_or_integer_variables() { - let inst = rir::Instruction::BitwiseOr( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = or i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_xor_integer_literals() { - let inst = rir::Instruction::BitwiseXor( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = xor i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_xor_integer_variables() { - let inst = rir::Instruction::BitwiseXor( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = xor i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_eq_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Eq, - rir::Operand::Literal(rir::Literal::Integer(2)), - 
rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp eq i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_eq_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Eq, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp eq i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_ne_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Ne, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp ne i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_ne_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Ne, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp ne i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_slt_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Slt, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp slt i64 2, 
5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_slt_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Slt, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp slt i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_sle_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sle, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sle i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_sle_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sle, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sle i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_sgt_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sgt, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sgt i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_sgt_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sgt, - 
rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sgt i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_sge_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sge, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sge i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_sge_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sge, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sge i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn mul_integer_literals() { - let inst = rir::Instruction::Mul( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = mul i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn mul_integer_variables() { - let inst = rir::Instruction::Mul( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: 
rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = mul i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sdiv_integer_literals() { - let inst = rir::Instruction::Sdiv( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sdiv i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sdiv_integer_variables() { - let inst = rir::Instruction::Sdiv( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sdiv i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn shl_integer_literals() { - let inst = rir::Instruction::Shl( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = shl i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn shl_integer_variables() { - let inst = rir::Instruction::Shl( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = shl i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn srem_integer_literals() { - let inst = rir::Instruction::Srem( - rir::Operand::Literal(rir::Literal::Integer(2)), - 
rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = srem i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn srem_integer_variables() { - let inst = rir::Instruction::Srem( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = srem i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sub_integer_literals() { - let inst = rir::Instruction::Sub( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sub i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sub_integer_variables() { - let inst = rir::Instruction::Sub( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sub i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn convert_integer_literal_to_double() { - let inst = rir::Instruction::Convert( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - expect![" %var_0 = sitofp i64 2 to double"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn convert_integer_variable_to_double() { - let inst = rir::Instruction::Convert( - rir::Operand::Variable(rir::Variable 
{ - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - expect![" %var_0 = sitofp i64 %var_1 to double"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/invalid.rs b/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/invalid.rs deleted file mode 100644 index 62383253c0..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/invalid.rs +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use super::super::ToQir; -use qsc_rir::rir; - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for add")] -fn add_mismatched_literal_input_tys_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Double(1.0)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for add")] -fn add_mismatched_literal_input_output_tys_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for add")] -fn add_mismatched_variable_input_tys_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Double, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: 
rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for add")] -fn add_mismatched_variable_input_output_tys_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for and")] -fn bitwise_and_mismatched_literal_input_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Double(1.0)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for and")] -fn bitwise_and_mismatched_literal_input_output_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for and")] -fn bitwise_and_mismatched_variable_input_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Double, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - 
let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for and")] -fn bitwise_and_mismatched_variable_input_output_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type i1 for add")] -fn add_bool_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched types (i64 [... 
i1]) for phi")] -fn phi_with_mismatched_args_should_panic() { - let args = [ - ( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(13), - ty: rir::Ty::Integer, - }), - rir::BlockId(3), - ), - ( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Boolean, - }), - rir::BlockId(7), - ), - ]; - let inst = rir::Instruction::Phi( - args.to_vec(), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported output type i64 for icmp")] -fn icmp_with_non_boolean_result_var_should_panic() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Eq, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = inst.to_qir(&rir::Program::default()); -} diff --git a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/phi.rs b/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/phi.rs deleted file mode 100644 index f06c346d80..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests/phi.rs +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -use super::super::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -#[should_panic(expected = "phi instruction should have at least one argument")] -fn phi_with_empty_args() { - let args = []; - let inst = rir::Instruction::Phi( - args.to_vec(), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -fn phi_with_single_arg() { - let args = [( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(13), - ty: rir::Ty::Integer, - }), - rir::BlockId(3), - )]; - let inst = rir::Instruction::Phi( - args.to_vec(), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = phi i64 [%var_13, %block_3]"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn phi_with_multiple_args() { - let args = [ - ( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(13), - ty: rir::Ty::Integer, - }), - rir::BlockId(3), - ), - ( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::BlockId(7), - ), - ]; - let inst = rir::Instruction::Phi( - args.to_vec(), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = phi i64 [%var_13, %block_3], [%var_2, %block_7]"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v1/template.ll b/source/compiler/qsc_codegen/src/qir/v1/template.ll deleted file mode 100644 index f18691f5fe..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v1/template.ll +++ /dev/null @@ -1,10 +0,0 @@ -%Result = type opaque -%Qubit = type opaque - -{} -{} - -attributes #0 = {{ "entry_point" "output_labeling_schema" "qir_profiles"="{}" "required_num_qubits"="{}" "required_num_results"="{}" }} -attributes #1 = {{ "irreversible" }}{} - -; module flags diff --git 
a/source/compiler/qsc_codegen/src/qir/v1/tests.rs b/source/compiler/qsc_codegen/src/qir/v1/tests.rs deleted file mode 100644 index 6c18afa9c4..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v1/tests.rs +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use super::ToQir; -use expect_test::expect; -use qsc_rir::builder; -use qsc_rir::rir; - -#[test] -fn single_qubit_gate_decl_works() { - let decl = builder::x_decl(); - expect!["declare void @__quantum__qis__x__body(%Qubit*)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn two_qubit_gate_decl_works() { - let decl = builder::cx_decl(); - expect!["declare void @__quantum__qis__cx__body(%Qubit*, %Qubit*)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn single_qubit_rotation_decl_works() { - let decl = builder::rx_decl(); - expect!["declare void @__quantum__qis__rx__body(double, %Qubit*)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn measurement_decl_works() { - let decl = builder::m_decl(); - expect!["declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn read_result_decl_works() { - let decl = builder::read_result_decl(); - expect!["declare i1 @__quantum__rt__read_result(%Result*)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn result_record_decl_works() { - let decl = builder::result_record_decl(); - expect!["declare void @__quantum__rt__result_record_output(%Result*, i8*)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn single_qubit_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::x_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![rir::Operand::Literal(rir::Literal::Qubit(0))], - None, - None, - ); - expect![" call void @__quantum__qis__x__body(%Qubit* 
inttoptr (i64 0 to %Qubit*))"] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn qubit_rotation_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::rx_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![ - rir::Operand::Literal(rir::Literal::Double(std::f64::consts::PI)), - rir::Operand::Literal(rir::Literal::Qubit(0)), - ], - None, - None, - ); - expect![" call void @__quantum__qis__rx__body(double 3.141592653589793, %Qubit* inttoptr (i64 0 to %Qubit*))"] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn qubit_rotation_round_number_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::rx_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![ - rir::Operand::Literal(rir::Literal::Double(3.0)), - rir::Operand::Literal(rir::Literal::Qubit(0)), - ], - None, - None, - ); - expect![ - " call void @__quantum__qis__rx__body(double 3.0, %Qubit* inttoptr (i64 0 to %Qubit*))" - ] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn qubit_rotation_variable_angle_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::rx_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![ - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }), - rir::Operand::Literal(rir::Literal::Qubit(0)), - ], - None, - None, - ); - expect![ - " call void @__quantum__qis__rx__body(double %var_0, %Qubit* inttoptr (i64 0 to %Qubit*))" - ] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn bell_program() { - let program = builder::bell_program(); - expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_a\00" - @1 = internal constant [6 x i8] c"1_a0r\00" - @2 = internal constant [6 x i8] c"2_a1r\00" - - declare void @__quantum__qis__h__body(%Qubit*) 
- - declare void @__quantum__qis__cx__body(%Qubit*, %Qubit*) - - declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 - - declare void @__quantum__rt__array_record_output(i64, i8*) - - declare void @__quantum__rt__result_record_output(%Result*, i8*) - - define i64 @ENTRYPOINT__main() #0 { - block_0: - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - call void @__quantum__rt__array_record_output(i64 2, i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @1, i64 0, i64 0)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* getelementptr inbounds ([6 x i8], [6 x i8]* @2, i64 0, i64 0)) - ret i64 0 - } - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="2" } - attributes #1 = { "irreversible" } - - ; module flags - - !llvm.module.flags = !{!0, !1, !2, !3} - - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} - !2 = !{i32 1, !"dynamic_qubit_management", i1 false} - !3 = !{i32 1, !"dynamic_result_management", i1 false} - "#]].assert_eq(&program.to_qir(&program)); -} - -#[test] -fn teleport_program() { - let program = builder::teleport_program(); - expect![[r#" - %Result = type opaque - %Qubit = type opaque - - @0 = internal constant [4 x i8] c"0_r\00" - - declare void @__quantum__qis__h__body(%Qubit*) - - declare void @__quantum__qis__z__body(%Qubit*) - - declare void 
@__quantum__qis__x__body(%Qubit*) - - declare void @__quantum__qis__cx__body(%Qubit*, %Qubit*) - - declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 - - declare i1 @__quantum__rt__read_result(%Result*) - - declare void @__quantum__rt__result_record_output(%Result*, i8*) - - define i64 @ENTRYPOINT__main() #0 { - block_0: - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*)) - call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) - call void @__quantum__qis__cx__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 2 to %Qubit*)) - call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) - %var_0 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 0 to %Result*)) - br i1 %var_0, label %block_1, label %block_2 - block_1: - call void @__quantum__qis__z__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - br label %block_2 - block_2: - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 1 to %Result*)) - %var_1 = call i1 @__quantum__rt__read_result(%Result* inttoptr (i64 1 to %Result*)) - br i1 %var_1, label %block_3, label %block_4 - block_3: - call void @__quantum__qis__x__body(%Qubit* inttoptr (i64 1 to %Qubit*)) - br label %block_4 - block_4: - call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 2 to %Result*)) - call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 2 to %Result*), i8* getelementptr inbounds ([4 x i8], [4 x i8]* @0, i64 0, i64 0)) - ret i64 0 - } - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="3" } - attributes #1 = { "irreversible" } - 
- ; module flags - - !llvm.module.flags = !{!0, !1, !2, !3} - - !0 = !{i32 1, !"qir_major_version", i32 1} - !1 = !{i32 7, !"qir_minor_version", i32 0} - !2 = !{i32 1, !"dynamic_qubit_management", i1 false} - !3 = !{i32 1, !"dynamic_result_management", i1 false} - "#]].assert_eq(&program.to_qir(&program)); -} diff --git a/source/compiler/qsc_codegen/src/qir/v2.rs b/source/compiler/qsc_codegen/src/qir/v2.rs deleted file mode 100644 index 9a4a02f833..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2.rs +++ /dev/null @@ -1,661 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#[cfg(test)] -mod instruction_tests; - -#[cfg(test)] -mod tests; - -use qsc_data_structures::attrs::Attributes; -use qsc_rir::{ - rir::{self, ConditionCode, FcmpConditionCode}, - utils::get_all_block_successors, -}; -use std::fmt::Write; - -/// A trait for converting a type into QIR of type `T`. -/// This can be used to generate QIR strings or other representations. -pub trait ToQir { - fn to_qir(&self, program: &rir::Program) -> T; -} - -impl ToQir for rir::Literal { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::Literal::Bool(b) => format!("i1 {b}"), - rir::Literal::Double(d) => { - if (d.floor() - d.ceil()).abs() < f64::EPSILON { - // The value is a whole number, which requires at least one decimal point - // to differentiate it from an integer value. 
- format!("double {d:.1}") - } else { - format!("double {d}") - } - } - rir::Literal::Integer(i) => format!("i64 {i}"), - rir::Literal::Pointer => "ptr null".to_string(), - rir::Literal::Qubit(q) => format!("ptr inttoptr (i64 {q} to ptr)"), - rir::Literal::Result(r) => format!("ptr inttoptr (i64 {r} to ptr)"), - rir::Literal::Tag(idx, _) => format!("ptr @{idx}"), - } - } -} - -impl ToQir for rir::Ty { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::Ty::Boolean => "i1".to_string(), - rir::Ty::Double => "double".to_string(), - rir::Ty::Integer => "i64".to_string(), - rir::Ty::Pointer | rir::Ty::Qubit | rir::Ty::Result => "ptr".to_string(), - } - } -} - -impl ToQir for Option { - fn to_qir(&self, program: &rir::Program) -> String { - match self { - Some(ty) => ToQir::::to_qir(ty, program), - None => "void".to_string(), - } - } -} - -impl ToQir for rir::VariableId { - fn to_qir(&self, _program: &rir::Program) -> String { - format!("%var_{}", self.0) - } -} - -impl ToQir for rir::Variable { - fn to_qir(&self, program: &rir::Program) -> String { - format!( - "{} {}", - ToQir::::to_qir(&self.ty, program), - ToQir::::to_qir(&self.variable_id, program) - ) - } -} - -impl ToQir for rir::Operand { - fn to_qir(&self, program: &rir::Program) -> String { - match self { - rir::Operand::Literal(lit) => ToQir::::to_qir(lit, program), - rir::Operand::Variable(var) => ToQir::::to_qir(var, program), - } - } -} - -impl ToQir for rir::FcmpConditionCode { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::FcmpConditionCode::False => "false".to_string(), - rir::FcmpConditionCode::OrderedAndEqual => "oeq".to_string(), - rir::FcmpConditionCode::OrderedAndGreaterThan => "ogt".to_string(), - rir::FcmpConditionCode::OrderedAndGreaterThanOrEqual => "oge".to_string(), - rir::FcmpConditionCode::OrderedAndLessThan => "olt".to_string(), - rir::FcmpConditionCode::OrderedAndLessThanOrEqual => "ole".to_string(), - 
rir::FcmpConditionCode::OrderedAndNotEqual => "one".to_string(), - rir::FcmpConditionCode::Ordered => "ord".to_string(), - rir::FcmpConditionCode::UnorderedOrEqual => "ueq".to_string(), - rir::FcmpConditionCode::UnorderedOrGreaterThan => "ugt".to_string(), - rir::FcmpConditionCode::UnorderedOrGreaterThanOrEqual => "uge".to_string(), - rir::FcmpConditionCode::UnorderedOrLessThan => "ult".to_string(), - rir::FcmpConditionCode::UnorderedOrLessThanOrEqual => "ule".to_string(), - rir::FcmpConditionCode::UnorderedOrNotEqual => "une".to_string(), - rir::FcmpConditionCode::Unordered => "uno".to_string(), - rir::FcmpConditionCode::True => "true".to_string(), - } - } -} - -impl ToQir for rir::ConditionCode { - fn to_qir(&self, _program: &rir::Program) -> String { - match self { - rir::ConditionCode::Eq => "eq".to_string(), - rir::ConditionCode::Ne => "ne".to_string(), - rir::ConditionCode::Sgt => "sgt".to_string(), - rir::ConditionCode::Sge => "sge".to_string(), - rir::ConditionCode::Slt => "slt".to_string(), - rir::ConditionCode::Sle => "sle".to_string(), - } - } -} - -impl ToQir for rir::Instruction { - fn to_qir(&self, program: &rir::Program) -> String { - match self { - rir::Instruction::Add(lhs, rhs, variable) => { - binop_to_qir("add", lhs, rhs, *variable, program) - } - rir::Instruction::Ashr(lhs, rhs, variable) => { - binop_to_qir("ashr", lhs, rhs, *variable, program) - } - rir::Instruction::BitwiseAnd(lhs, rhs, variable) => { - simple_bitwise_to_qir("and", lhs, rhs, *variable, program) - } - rir::Instruction::BitwiseNot(value, variable) => { - bitwise_not_to_qir(value, *variable, program) - } - rir::Instruction::BitwiseOr(lhs, rhs, variable) => { - simple_bitwise_to_qir("or", lhs, rhs, *variable, program) - } - rir::Instruction::BitwiseXor(lhs, rhs, variable) => { - simple_bitwise_to_qir("xor", lhs, rhs, *variable, program) - } - rir::Instruction::Branch(cond, true_id, false_id, _) => { - format!( - " br {}, label %{}, label %{}", - ToQir::::to_qir(cond, program), - 
ToQir::::to_qir(true_id, program), - ToQir::::to_qir(false_id, program) - ) - } - rir::Instruction::Call(call_id, args, output, _) => { - call_to_qir(args, *call_id, *output, program) - } - rir::Instruction::Convert(operand, variable) => { - convert_to_qir(operand, *variable, program) - } - rir::Instruction::Fadd(lhs, rhs, variable) => { - fbinop_to_qir("fadd", lhs, rhs, *variable, program) - } - rir::Instruction::Fdiv(lhs, rhs, variable) => { - fbinop_to_qir("fdiv", lhs, rhs, *variable, program) - } - rir::Instruction::Fmul(lhs, rhs, variable) => { - fbinop_to_qir("fmul", lhs, rhs, *variable, program) - } - rir::Instruction::Fsub(lhs, rhs, variable) => { - fbinop_to_qir("fsub", lhs, rhs, *variable, program) - } - rir::Instruction::LogicalAnd(lhs, rhs, variable) => { - logical_binop_to_qir("and", lhs, rhs, *variable, program) - } - rir::Instruction::LogicalNot(value, variable) => { - logical_not_to_qir(value, *variable, program) - } - rir::Instruction::LogicalOr(lhs, rhs, variable) => { - logical_binop_to_qir("or", lhs, rhs, *variable, program) - } - rir::Instruction::Mul(lhs, rhs, variable) => { - binop_to_qir("mul", lhs, rhs, *variable, program) - } - rir::Instruction::Fcmp(op, lhs, rhs, variable) => { - fcmp_to_qir(*op, lhs, rhs, *variable, program) - } - rir::Instruction::Icmp(op, lhs, rhs, variable) => { - icmp_to_qir(*op, lhs, rhs, *variable, program) - } - rir::Instruction::Jump(block_id) => { - format!(" br label %{}", ToQir::::to_qir(block_id, program)) - } - rir::Instruction::Phi(..) 
=> { - unreachable!("phi nodes should not be inserted for QIR v2 generation") - } - rir::Instruction::Return => " ret i64 0".to_string(), - rir::Instruction::Sdiv(lhs, rhs, variable) => { - binop_to_qir("sdiv", lhs, rhs, *variable, program) - } - rir::Instruction::Shl(lhs, rhs, variable) => { - binop_to_qir("shl", lhs, rhs, *variable, program) - } - rir::Instruction::Srem(lhs, rhs, variable) => { - binop_to_qir("srem", lhs, rhs, *variable, program) - } - rir::Instruction::Store(operand, variable) => { - store_to_qir(*operand, *variable, program) - } - rir::Instruction::Sub(lhs, rhs, variable) => { - binop_to_qir("sub", lhs, rhs, *variable, program) - } - rir::Instruction::Alloca(variable) => alloca_to_qir(*variable, program), - rir::Instruction::Load(var_from, var_to) => load_to_qir(*var_from, *var_to, program), - } - } -} - -fn convert_to_qir( - operand: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let operand_ty = get_value_ty(operand); - let var_ty = get_variable_ty(variable); - assert_ne!( - operand_ty, var_ty, - "input/output types ({operand_ty}, {var_ty}) should not match in convert" - ); - - let convert_instr = match (operand_ty, var_ty) { - ("i64", "double") => "sitofp i64", - ("double", "i64") => "fptosi double", - _ => panic!("unsupported conversion from {operand_ty} to {var_ty} in convert instruction"), - }; - - format!( - " {} = {convert_instr} {} to {var_ty}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(operand, program), - ) -} - -fn store_to_qir(operand: rir::Operand, variable: rir::Variable, program: &rir::Program) -> String { - let op_ty = get_value_ty(&operand); - format!( - " store {op_ty} {}, ptr {}", - get_value_as_str(&operand, program), - ToQir::::to_qir(&variable.variable_id, program) - ) -} - -fn load_to_qir(var_from: rir::Variable, var_to: rir::Variable, program: &rir::Program) -> String { - let var_to_ty = get_variable_ty(var_to); - format!( - " {} = load {var_to_ty}, ptr 
{}", - ToQir::::to_qir(&var_to.variable_id, program), - ToQir::::to_qir(&var_from.variable_id, program) - ) -} - -fn alloca_to_qir(variable: rir::Variable, program: &rir::Program) -> String { - let variable_ty = get_variable_ty(variable); - format!( - " {} = alloca {variable_ty}", - ToQir::::to_qir(&variable.variable_id, program) - ) -} - -fn logical_not_to_qir( - value: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let value_ty = get_value_ty(value); - let var_ty = get_variable_ty(variable); - assert_eq!( - value_ty, var_ty, - "mismatched input/output types ({value_ty}, {var_ty}) for not" - ); - assert_eq!(var_ty, "i1", "unsupported type {var_ty} for not"); - - format!( - " {} = xor i1 {}, true", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(value, program) - ) -} - -fn logical_binop_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "i1", "unsupported type {var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn bitwise_not_to_qir( - value: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let value_ty = get_value_ty(value); - let var_ty = get_variable_ty(variable); - assert_eq!( - value_ty, var_ty, - "mismatched input/output types ({value_ty}, {var_ty}) for not" - ); - assert_eq!(var_ty, "i64", "unsupported type {var_ty} for not"); - - format!( - " {} = xor {var_ty} {}, -1", - ToQir::::to_qir(&variable.variable_id, program), - 
get_value_as_str(value, program) - ) -} - -fn call_to_qir( - args: &[rir::Operand], - call_id: rir::CallableId, - output: Option, - program: &rir::Program, -) -> String { - let args = args - .iter() - .map(|arg| ToQir::::to_qir(arg, program)) - .collect::>() - .join(", "); - let callable = program.get_callable(call_id); - if let Some(output) = output { - format!( - " {} = call {} @{}({args})", - ToQir::::to_qir(&output.variable_id, program), - ToQir::::to_qir(&callable.output_type, program), - callable.name - ) - } else { - format!( - " call {} @{}({args})", - ToQir::::to_qir(&callable.output_type, program), - callable.name - ) - } -} - -fn fcmp_to_qir( - op: FcmpConditionCode, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for fcmp {op}" - ); - - assert_eq!(var_ty, "i1", "unsupported output type {var_ty} for fcmp"); - format!( - " {} = fcmp {} {lhs_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - ToQir::::to_qir(&op, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn icmp_to_qir( - op: ConditionCode, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for icmp {op}" - ); - - assert_eq!(var_ty, "i1", "unsupported output type {var_ty} for icmp"); - format!( - " {} = icmp {} {lhs_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - ToQir::::to_qir(&op, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn binop_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: 
&rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "i64", "unsupported type {var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn fbinop_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "double", "unsupported type {var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, program) - ) -} - -fn simple_bitwise_to_qir( - op: &str, - lhs: &rir::Operand, - rhs: &rir::Operand, - variable: rir::Variable, - program: &rir::Program, -) -> String { - let lhs_ty = get_value_ty(lhs); - let rhs_ty = get_value_ty(rhs); - let var_ty = get_variable_ty(variable); - assert_eq!( - lhs_ty, rhs_ty, - "mismatched input types ({lhs_ty}, {rhs_ty}) for {op}" - ); - assert_eq!( - lhs_ty, var_ty, - "mismatched input/output types ({lhs_ty}, {var_ty}) for {op}" - ); - assert_eq!(var_ty, "i64", "unsupported type {var_ty} for {op}"); - - format!( - " {} = {op} {var_ty} {}, {}", - ToQir::::to_qir(&variable.variable_id, program), - get_value_as_str(lhs, program), - get_value_as_str(rhs, 
program) - ) -} - -fn get_value_as_str(value: &rir::Operand, program: &rir::Program) -> String { - match value { - rir::Operand::Literal(lit) => match lit { - rir::Literal::Bool(b) => format!("{b}"), - rir::Literal::Double(d) => { - if (d.floor() - d.ceil()).abs() < f64::EPSILON { - // The value is a whole number, which requires at least one decimal point - // to differentiate it from an integer value. - format!("{d:.1}") - } else { - format!("{d}") - } - } - rir::Literal::Integer(i) => format!("{i}"), - rir::Literal::Pointer => "null".to_string(), - rir::Literal::Qubit(q) => format!("{q}"), - rir::Literal::Result(r) => format!("{r}"), - rir::Literal::Tag(..) => panic!( - "tag literals should not be used as string values outside of output recording" - ), - }, - rir::Operand::Variable(var) => ToQir::::to_qir(&var.variable_id, program), - } -} - -fn get_value_ty(lhs: &rir::Operand) -> &str { - match lhs { - rir::Operand::Literal(lit) => match lit { - rir::Literal::Integer(_) => "i64", - rir::Literal::Bool(_) => "i1", - rir::Literal::Double(_) => get_f64_ty(), - rir::Literal::Qubit(_) - | rir::Literal::Result(_) - | rir::Literal::Pointer - | rir::Literal::Tag(..) => "ptr", - }, - rir::Operand::Variable(var) => get_variable_ty(*var), - } -} - -fn get_variable_ty(variable: rir::Variable) -> &'static str { - match variable.ty { - rir::Ty::Integer => "i64", - rir::Ty::Boolean => "i1", - rir::Ty::Double => get_f64_ty(), - rir::Ty::Qubit | rir::Ty::Result | rir::Ty::Pointer => "ptr", - } -} - -/// phi only supports "Floating-Point Types" which are defined as: -/// - `half` (`f16`) -/// - `bfloat` -/// - `float` (`f32`) -/// - `double` (`f64`) -/// - `fp128` -/// -/// We only support `f64`, so we break the pattern used for integers -/// and have to use `double` here. -/// -/// This conflicts with the QIR spec which says f64. Need to follow up on this. 
-fn get_f64_ty() -> &'static str { - "double" -} - -impl ToQir for rir::BlockId { - fn to_qir(&self, _program: &rir::Program) -> String { - format!("block_{}", self.0) - } -} - -impl ToQir for rir::Block { - fn to_qir(&self, program: &rir::Program) -> String { - self.0 - .iter() - .map(|instr| ToQir::::to_qir(instr, program)) - .collect::>() - .join("\n") - } -} - -impl ToQir for rir::Callable { - fn to_qir(&self, program: &rir::Program) -> String { - let input_type = self - .input_type - .iter() - .map(|t| ToQir::::to_qir(t, program)) - .collect::>() - .join(", "); - let output_type = ToQir::::to_qir(&self.output_type, program); - let Some(entry_id) = self.body else { - return format!( - "declare {output_type} @{}({input_type}){}", - self.name, - match self.call_type { - rir::CallableType::Measurement | rir::CallableType::Reset => { - // These callables are a special case that need the irreversible attribute. - " #1" - } - rir::CallableType::NoiseIntrinsic => " #2", - _ => "", - } - ); - }; - let mut body = String::new(); - let mut all_blocks = vec![entry_id]; - all_blocks.extend(get_all_block_successors(entry_id, program)); - for block_id in all_blocks { - let block = program.get_block(block_id); - write!( - body, - "{}:\n{}\n", - ToQir::::to_qir(&block_id, program), - ToQir::::to_qir(block, program) - ) - .expect("writing to string should succeed"); - } - assert!( - input_type.is_empty(), - "entry point should not have an input" - ); - format!("define {output_type} @ENTRYPOINT__main() #0 {{\n{body}}}",) - } -} - -impl ToQir for rir::Program { - fn to_qir(&self, _program: &rir::Program) -> String { - let callables = self - .callables - .iter() - .map(|(_, callable)| ToQir::::to_qir(callable, self)) - .collect::>() - .join("\n\n"); - let mut constants = String::default(); - for (idx, tag) in self.tags.iter().enumerate() { - // We need to add the tag as a global constant. 
- writeln!( - constants, - "@{idx} = internal constant [{} x i8] c\"{tag}\\00\"", - tag.len() + 1 - ) - .expect("writing to string should succeed"); - } - let body = format!( - include_str!("./v2/template.ll"), - constants, - callables, - self.num_qubits, - self.num_results, - get_additional_module_attributes(self) - ); - body - } -} - -fn get_additional_module_attributes(program: &rir::Program) -> String { - let mut attrs = String::new(); - if program.attrs.contains(Attributes::QdkNoise) { - attrs.push_str("\nattributes #2 = { \"qdk_noise\" }"); - } - - attrs -} diff --git a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests.rs deleted file mode 100644 index 581b070a12..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests.rs +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -mod alloca; -mod bool; -mod double; -mod int; -mod invalid; -mod load; -mod store; diff --git a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/alloca.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/alloca.rs deleted file mode 100644 index bbed29d762..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/alloca.rs +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -use crate::qir::v2::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -fn alloca_integer_without_size() { - let inst = rir::Instruction::Alloca(rir::Variable::new_integer(rir::VariableId(0))); - expect![" %var_0 = alloca i64"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn alloca_bool_without_size() { - let inst = rir::Instruction::Alloca(rir::Variable::new_boolean(rir::VariableId(0))); - expect![" %var_0 = alloca i1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn alloca_double_without_size() { - let inst = rir::Instruction::Alloca(rir::Variable::new_double(rir::VariableId(0))); - expect![" %var_0 = alloca double"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn alloca_pointer_without_size() { - let inst = rir::Instruction::Alloca(rir::Variable::new_ptr(rir::VariableId(0))); - expect![" %var_0 = alloca ptr"].assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/bool.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/bool.rs deleted file mode 100644 index a2968d9882..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/bool.rs +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -use super::super::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -fn logical_and_literals() { - let inst = rir::Instruction::LogicalAnd( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = and i1 true, false"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_and_variables() { - let inst = rir::Instruction::LogicalAnd( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Boolean, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Boolean, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = and i1 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_not_true_literal() { - let inst = rir::Instruction::LogicalNot( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = xor i1 true, true"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_not_variables() { - let inst = rir::Instruction::LogicalNot( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Boolean, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = xor i1 %var_1, true"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_not_false_literal() { - let inst = rir::Instruction::LogicalNot( - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = xor i1 false, true"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_or_literals() { - let inst = 
rir::Instruction::LogicalOr( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = or i1 true, false"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn logical_or_variables() { - let inst = rir::Instruction::LogicalOr( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Boolean, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Boolean, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = or i1 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/double.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/double.rs deleted file mode 100644 index 649ce98f6e..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/double.rs +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -use core::f64::consts::{E, PI}; - -use super::super::ToQir; -use expect_test::expect; -use qsc_rir::rir::{ - FcmpConditionCode, Instruction, Literal, Operand, Program, Ty, Variable, VariableId, -}; - -#[test] -#[should_panic(expected = "unsupported type double for add")] -fn add_double_literals() { - let inst = Instruction::Add( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for sub")] -fn sub_double_literals() { - let inst = Instruction::Sub( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for mul")] -fn mul_double_literals() { - let inst = Instruction::Mul( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for sdiv")] -fn sdiv_double_literals() { - let inst = Instruction::Sdiv( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -fn fadd_double_literals() { - let inst = Instruction::Fadd( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fadd double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -#[should_panic(expected = "unsupported type double for ashr")] -fn ashr_double_literals() { - let inst 
= Instruction::Ashr( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for and")] -fn bitwise_and_double_literals() { - let inst = Instruction::BitwiseAnd( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for not")] -fn bitwise_not_double_literals() { - let inst = Instruction::BitwiseNot( - Operand::Literal(Literal::Double(PI)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for or")] -fn bitwise_or_double_literals() { - let inst = Instruction::BitwiseOr( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type double for xor")] -fn bitwise_xor_double_literals() { - let inst = Instruction::BitwiseXor( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - let _ = &inst.to_qir(&Program::default()); -} - -#[test] -fn fadd_double_variables() { - let inst = Instruction::Fadd( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fadd double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn 
fcmp_oeq_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oeq double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_oeq_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oeq double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_one_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndNotEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp one double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_one_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndNotEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp one double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_olt_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndLessThan, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = 
fcmp olt double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_olt_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndLessThan, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp olt double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_ole_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndLessThanOrEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ole double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_ole_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndLessThanOrEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ole double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_ogt_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThan, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ogt double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_ogt_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThan, - Operand::Variable(Variable { 
- variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp ogt double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} -#[test] -fn fcmp_oge_double_literals() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThanOrEqual, - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oge double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fcmp_oge_double_variables() { - let inst = Instruction::Fcmp( - FcmpConditionCode::OrderedAndGreaterThanOrEqual, - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Boolean, - }, - ); - expect![" %var_0 = fcmp oge double %var_1, %var_2"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fmul_double_literals() { - let inst = Instruction::Fmul( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fmul double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fmul_double_variables() { - let inst = Instruction::Fmul( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fmul double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fdiv_double_literals() { 
- let inst = Instruction::Fdiv( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fdiv double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fdiv_double_variables() { - let inst = Instruction::Fdiv( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fdiv double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fsub_double_literals() { - let inst = Instruction::Fsub( - Operand::Literal(Literal::Double(PI)), - Operand::Literal(Literal::Double(E)), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fsub double 3.141592653589793, 2.718281828459045"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn fsub_double_variables() { - let inst = Instruction::Fsub( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Operand::Variable(Variable { - variable_id: VariableId(2), - ty: Ty::Double, - }), - Variable { - variable_id: VariableId(0), - ty: Ty::Double, - }, - ); - expect![" %var_0 = fsub double %var_1, %var_2"].assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn convert_double_literal_to_integer() { - let inst = Instruction::Convert( - Operand::Literal(Literal::Double(PI)), - Variable { - variable_id: VariableId(0), - ty: Ty::Integer, - }, - ); - expect![" %var_0 = fptosi double 3.141592653589793 to i64"] - .assert_eq(&inst.to_qir(&Program::default())); -} - -#[test] -fn convert_double_variable_to_integer() { - let inst = Instruction::Convert( - Operand::Variable(Variable { - variable_id: VariableId(1), - ty: Ty::Double, - }), - Variable { - variable_id: 
VariableId(0), - ty: Ty::Integer, - }, - ); - expect![" %var_0 = fptosi double %var_1 to i64"].assert_eq(&inst.to_qir(&Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/int.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/int.rs deleted file mode 100644 index fe2c725544..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/int.rs +++ /dev/null @@ -1,587 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use super::super::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -fn add_integer_literals() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = add i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn add_integer_variables() { - let inst = rir::Instruction::Add( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = add i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn ashr_integer_literals() { - let inst = rir::Instruction::Ashr( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = ashr i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn ashr_integer_variables() { - let inst = rir::Instruction::Ashr( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - 
rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = ashr i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_and_integer_literals() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = and i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_add_integer_variables() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = and i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_not_integer_literals() { - let inst = rir::Instruction::BitwiseNot( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = xor i64 2, -1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_not_integer_variables() { - let inst = rir::Instruction::BitwiseNot( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = xor i64 %var_1, -1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_or_integer_literals() { - let inst = rir::Instruction::BitwiseOr( - rir::Operand::Literal(rir::Literal::Integer(2)), - 
rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = or i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_or_integer_variables() { - let inst = rir::Instruction::BitwiseOr( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = or i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_xor_integer_literals() { - let inst = rir::Instruction::BitwiseXor( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = xor i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn bitwise_xor_integer_variables() { - let inst = rir::Instruction::BitwiseXor( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = xor i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_eq_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Eq, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp eq i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn 
icmp_eq_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Eq, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp eq i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_ne_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Ne, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp ne i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_ne_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Ne, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp ne i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_slt_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Slt, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp slt i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_slt_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Slt, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: 
rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp slt i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_sle_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sle, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sle i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_sle_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sle, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sle i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_sgt_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sgt, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sgt i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_sgt_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sgt, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), 
- ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sgt i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} -#[test] -fn icmp_sge_integer_literals() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sge, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sge i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn icmp_sge_integer_variables() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Sge, - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - expect![" %var_0 = icmp sge i64 %var_1, %var_2"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn mul_integer_literals() { - let inst = rir::Instruction::Mul( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = mul i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn mul_integer_variables() { - let inst = rir::Instruction::Mul( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = mul i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sdiv_integer_literals() { - let inst = rir::Instruction::Sdiv( - 
rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sdiv i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sdiv_integer_variables() { - let inst = rir::Instruction::Sdiv( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sdiv i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn shl_integer_literals() { - let inst = rir::Instruction::Shl( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = shl i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn shl_integer_variables() { - let inst = rir::Instruction::Shl( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = shl i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn srem_integer_literals() { - let inst = rir::Instruction::Srem( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = srem i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn srem_integer_variables() { - let inst 
= rir::Instruction::Srem( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = srem i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sub_integer_literals() { - let inst = rir::Instruction::Sub( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sub i64 2, 5"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn sub_integer_variables() { - let inst = rir::Instruction::Sub( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - expect![" %var_0 = sub i64 %var_1, %var_2"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn convert_integer_literal_to_double() { - let inst = rir::Instruction::Convert( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - expect![" %var_0 = sitofp i64 2 to double"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn convert_integer_variable_to_double() { - let inst = rir::Instruction::Convert( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - expect![" %var_0 = sitofp i64 %var_1 to double"] - .assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git 
a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/invalid.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/invalid.rs deleted file mode 100644 index 54dd71713d..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/invalid.rs +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use super::super::ToQir; -use qsc_rir::rir; - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for add")] -fn add_mismatched_literal_input_tys_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Double(1.0)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for add")] -fn add_mismatched_literal_input_output_tys_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for add")] -fn add_mismatched_variable_input_tys_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Double, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for add")] -fn add_mismatched_variable_input_output_tys_should_panic() { - let inst = rir::Instruction::Add( - 
rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for and")] -fn bitwise_and_mismatched_literal_input_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Double(1.0)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for and")] -fn bitwise_and_mismatched_literal_input_output_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input types (i64, double) for and")] -fn bitwise_and_mismatched_variable_input_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Double, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "mismatched input/output types (i64, double) for and")] -fn bitwise_and_mismatched_variable_input_output_tys_should_panic() { - let inst = rir::Instruction::BitwiseAnd( - 
rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(1), - ty: rir::Ty::Integer, - }), - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(2), - ty: rir::Ty::Integer, - }), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported type i1 for add")] -fn add_bool_should_panic() { - let inst = rir::Instruction::Add( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Operand::Literal(rir::Literal::Bool(false)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Boolean, - }, - ); - let _ = &inst.to_qir(&rir::Program::default()); -} - -#[test] -#[should_panic(expected = "unsupported output type i64 for icmp")] -fn icmp_with_non_boolean_result_var_should_panic() { - let inst = rir::Instruction::Icmp( - rir::ConditionCode::Eq, - rir::Operand::Literal(rir::Literal::Integer(2)), - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Integer, - }, - ); - let _ = inst.to_qir(&rir::Program::default()); -} diff --git a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/load.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/load.rs deleted file mode 100644 index 65b4fe7ec5..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/load.rs +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -use crate::qir::v2::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -fn load_integer_from_pointer() { - let inst = rir::Instruction::Load( - rir::Variable::new_ptr(rir::VariableId(1)), - rir::Variable::new_integer(rir::VariableId(0)), - ); - expect![" %var_0 = load i64, ptr %var_1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn load_bool_from_pointer() { - let inst = rir::Instruction::Load( - rir::Variable::new_ptr(rir::VariableId(1)), - rir::Variable::new_boolean(rir::VariableId(0)), - ); - expect![" %var_0 = load i1, ptr %var_1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn load_double_from_pointer() { - let inst = rir::Instruction::Load( - rir::Variable::new_ptr(rir::VariableId(1)), - rir::Variable::new_double(rir::VariableId(0)), - ); - expect![" %var_0 = load double, ptr %var_1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn load_pointer_from_pointer() { - let inst = rir::Instruction::Load( - rir::Variable::new_ptr(rir::VariableId(1)), - rir::Variable::new_ptr(rir::VariableId(0)), - ); - expect![" %var_0 = load ptr, ptr %var_1"].assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/store.rs b/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/store.rs deleted file mode 100644 index 7820f4ac12..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/instruction_tests/store.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -use crate::qir::v2::ToQir; -use expect_test::expect; -use qsc_rir::rir; - -#[test] -fn store_integer_literal_to_pointer() { - let inst = rir::Instruction::Store( - rir::Operand::Literal(rir::Literal::Integer(5)), - rir::Variable::new_ptr(rir::VariableId(0)), - ); - expect![" store i64 5, ptr %var_0"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn store_integer_variable_to_pointer() { - let inst = rir::Instruction::Store( - rir::Operand::Variable(rir::Variable::new_integer(rir::VariableId(1))), - rir::Variable::new_ptr(rir::VariableId(0)), - ); - expect![" store i64 %var_1, ptr %var_0"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn store_bool_literal_to_pointer() { - let inst = rir::Instruction::Store( - rir::Operand::Literal(rir::Literal::Bool(true)), - rir::Variable::new_ptr(rir::VariableId(0)), - ); - expect![" store i1 true, ptr %var_0"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn store_double_literal_to_pointer() { - let inst = rir::Instruction::Store( - rir::Operand::Literal(rir::Literal::Double(2.5)), - rir::Variable::new_ptr(rir::VariableId(0)), - ); - expect![" store double 2.5, ptr %var_0"].assert_eq(&inst.to_qir(&rir::Program::default())); -} - -#[test] -fn store_pointer_literal_to_pointer() { - let inst = rir::Instruction::Store( - rir::Operand::Literal(rir::Literal::Pointer), - rir::Variable::new_ptr(rir::VariableId(0)), - ); - expect![" store ptr null, ptr %var_0"].assert_eq(&inst.to_qir(&rir::Program::default())); -} diff --git a/source/compiler/qsc_codegen/src/qir/v2/template.ll b/source/compiler/qsc_codegen/src/qir/v2/template.ll deleted file mode 100644 index e4c0ad408c..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/template.ll +++ /dev/null @@ -1,18 +0,0 @@ -{} -{} - -attributes #0 = {{ "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="{}" "required_num_results"="{}" }} -attributes #1 = {{ "irreversible" }}{} - -; module 
flags - -!llvm.module.flags = !{{!0, !1, !2, !3, !4, !5, !6, !7}} - -!0 = !{{i32 1, !"qir_major_version", i32 2}} -!1 = !{{i32 7, !"qir_minor_version", i32 1}} -!2 = !{{i32 1, !"dynamic_qubit_management", i1 false}} -!3 = !{{i32 1, !"dynamic_result_management", i1 false}} -!4 = !{{i32 5, !"int_computations", !{{!"i64"}}}} -!5 = !{{i32 5, !"float_computations", !{{!"double"}}}} -!6 = !{{i32 7, !"backwards_branching", i2 3}} -!7 = !{{i32 1, !"arrays", i1 true}} diff --git a/source/compiler/qsc_codegen/src/qir/v2/tests.rs b/source/compiler/qsc_codegen/src/qir/v2/tests.rs deleted file mode 100644 index 19facf0b9b..0000000000 --- a/source/compiler/qsc_codegen/src/qir/v2/tests.rs +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -use super::ToQir; -use expect_test::expect; -use qsc_rir::builder; -use qsc_rir::rir; - -#[test] -fn single_qubit_gate_decl_works() { - let decl = builder::x_decl(); - expect!["declare void @__quantum__qis__x__body(ptr)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn two_qubit_gate_decl_works() { - let decl = builder::cx_decl(); - expect!["declare void @__quantum__qis__cx__body(ptr, ptr)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn single_qubit_rotation_decl_works() { - let decl = builder::rx_decl(); - expect!["declare void @__quantum__qis__rx__body(double, ptr)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn measurement_decl_works() { - let decl = builder::m_decl(); - expect!["declare void @__quantum__qis__m__body(ptr, ptr) #1"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn read_result_decl_works() { - let decl = builder::read_result_decl(); - expect!["declare i1 @__quantum__rt__read_result(ptr)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn result_record_decl_works() { - let decl = builder::result_record_decl(); - expect!["declare void 
@__quantum__rt__result_record_output(ptr, ptr)"] - .assert_eq(&decl.to_qir(&rir::Program::default())); -} - -#[test] -fn single_qubit_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::x_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![rir::Operand::Literal(rir::Literal::Qubit(0))], - None, - None, - ); - expect![" call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr))"] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn qubit_rotation_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::rx_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![ - rir::Operand::Literal(rir::Literal::Double(std::f64::consts::PI)), - rir::Operand::Literal(rir::Literal::Qubit(0)), - ], - None, - None, - ); - expect![" call void @__quantum__qis__rx__body(double 3.141592653589793, ptr inttoptr (i64 0 to ptr))"] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn qubit_rotation_round_number_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::rx_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![ - rir::Operand::Literal(rir::Literal::Double(3.0)), - rir::Operand::Literal(rir::Literal::Qubit(0)), - ], - None, - None, - ); - expect![" call void @__quantum__qis__rx__body(double 3.0, ptr inttoptr (i64 0 to ptr))"] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn qubit_rotation_variable_angle_call() { - let mut program = rir::Program::default(); - program - .callables - .insert(rir::CallableId(0), builder::rx_decl()); - let call = rir::Instruction::Call( - rir::CallableId(0), - vec![ - rir::Operand::Variable(rir::Variable { - variable_id: rir::VariableId(0), - ty: rir::Ty::Double, - }), - rir::Operand::Literal(rir::Literal::Qubit(0)), - ], - None, - None, - ); - expect![" call void 
@__quantum__qis__rx__body(double %var_0, ptr inttoptr (i64 0 to ptr))"] - .assert_eq(&call.to_qir(&program)); -} - -#[test] -fn bell_program() { - let program = builder::bell_program(); - expect![[r#" - @0 = internal constant [4 x i8] c"0_a\00" - @1 = internal constant [6 x i8] c"1_a0r\00" - @2 = internal constant [6 x i8] c"2_a1r\00" - - declare void @__quantum__qis__h__body(ptr) - - declare void @__quantum__qis__cx__body(ptr, ptr) - - declare void @__quantum__qis__m__body(ptr, ptr) #1 - - declare void @__quantum__rt__array_record_output(i64, ptr) - - declare void @__quantum__rt__result_record_output(ptr, ptr) - - define i64 @ENTRYPOINT__main() #0 { - block_0: - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__m__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__m__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) - call void @__quantum__rt__array_record_output(i64 2, ptr @0) - call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @1) - call void @__quantum__rt__result_record_output(ptr inttoptr (i64 1 to ptr), ptr @2) - ret i64 0 - } - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } - attributes #1 = { "irreversible" } - - ; module flags - - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - - !0 = !{i32 1, !"qir_major_version", i32 2} - !1 = !{i32 7, !"qir_minor_version", i32 1} - !2 = !{i32 1, !"dynamic_qubit_management", i1 false} - !3 = !{i32 1, !"dynamic_result_management", i1 false} - !4 = !{i32 5, !"int_computations", !{!"i64"}} - !5 = !{i32 5, !"float_computations", !{!"double"}} - !6 = !{i32 7, !"backwards_branching", i2 3} - !7 = !{i32 1, !"arrays", i1 true} - "#]].assert_eq(&program.to_qir(&program)); -} - -#[test] -fn 
teleport_program() { - let program = builder::teleport_program(); - expect![[r#" - @0 = internal constant [4 x i8] c"0_r\00" - - declare void @__quantum__qis__h__body(ptr) - - declare void @__quantum__qis__z__body(ptr) - - declare void @__quantum__qis__x__body(ptr) - - declare void @__quantum__qis__cx__body(ptr, ptr) - - declare void @__quantum__qis__mresetz__body(ptr, ptr) #1 - - declare i1 @__quantum__rt__read_result(ptr) - - declare void @__quantum__rt__result_record_output(ptr, ptr) - - define i64 @ENTRYPOINT__main() #0 { - block_0: - call void @__quantum__qis__x__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__cx__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 1 to ptr)) - call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 2 to ptr)) - call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) - call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) - %var_0 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr)) - br i1 %var_0, label %block_1, label %block_2 - block_1: - call void @__quantum__qis__z__body(ptr inttoptr (i64 1 to ptr)) - br label %block_2 - block_2: - call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 2 to ptr), ptr inttoptr (i64 1 to ptr)) - %var_1 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 1 to ptr)) - br i1 %var_1, label %block_3, label %block_4 - block_3: - call void @__quantum__qis__x__body(ptr inttoptr (i64 1 to ptr)) - br label %block_4 - block_4: - call void @__quantum__qis__mresetz__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 2 to ptr)) - call void @__quantum__rt__result_record_output(ptr inttoptr (i64 2 to ptr), ptr @0) - ret i64 0 - } - - attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="3" "required_num_results"="3" } - attributes #1 = { "irreversible" } - - ; module 
flags - - !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7} - - !0 = !{i32 1, !"qir_major_version", i32 2} - !1 = !{i32 7, !"qir_minor_version", i32 1} - !2 = !{i32 1, !"dynamic_qubit_management", i1 false} - !3 = !{i32 1, !"dynamic_result_management", i1 false} - !4 = !{i32 5, !"int_computations", !{!"i64"}} - !5 = !{i32 5, !"float_computations", !{!"double"}} - !6 = !{i32 7, !"backwards_branching", i2 3} - !7 = !{i32 1, !"arrays", i1 true} - "#]].assert_eq(&program.to_qir(&program)); -} diff --git a/source/compiler/qsc_llvm/Cargo.toml b/source/compiler/qsc_llvm/Cargo.toml new file mode 100644 index 0000000000..90e198bcfb --- /dev/null +++ b/source/compiler/qsc_llvm/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "qsc_llvm" + +version.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +arbitrary = { workspace = true } +half = "2" +miette = { workspace = true } +rustc-hash = { workspace = true } +thiserror = { workspace = true } +winnow = { workspace = true } + +[dev-dependencies] +expect-test = { workspace = true } +indoc = { workspace = true } + + +[lints] +workspace = true + +[lib] +doctest = false diff --git a/source/compiler/qsc_llvm/src/bitcode.rs b/source/compiler/qsc_llvm/src/bitcode.rs new file mode 100644 index 0000000000..0ed6bc3b33 --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode.rs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +mod bitstream; +pub(crate) mod constants; +pub mod reader; +pub mod writer; diff --git a/source/compiler/qsc_llvm/src/bitcode/bitstream.rs b/source/compiler/qsc_llvm/src/bitcode/bitstream.rs new file mode 100644 index 0000000000..0b53290df1 --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode/bitstream.rs @@ -0,0 +1,604 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#[cfg(test)] +mod tests; + +use rustc_hash::FxHashMap; + +struct BlockScope { + #[allow(dead_code)] + outer_abbrev_width: u32, + length_position: usize, + start_position: usize, + saved_next_abbrev_id: u32, + saved_abbrevs: Vec, +} + +pub struct BitstreamWriter { + buffer: Vec, + cur_byte: u8, + cur_bit: u32, + block_stack: Vec, + next_abbrev_id: u32, + defined_abbrevs: Vec, +} + +impl BitstreamWriter { + pub fn new() -> Self { + Self { + buffer: Vec::new(), + cur_byte: 0, + cur_bit: 0, + block_stack: Vec::new(), + next_abbrev_id: 4, + defined_abbrevs: Vec::new(), + } + } + + pub(crate) fn bit_position(&self) -> usize { + self.buffer.len() * 8 + self.cur_bit as usize + } + + pub fn emit_bits(&mut self, val: u64, width: u32) { + let mut remaining = width; + let mut v = val; + while remaining > 0 { + let bits_free = 8 - self.cur_bit; + let to_write = remaining.min(bits_free); + let mask = if to_write == 64 { + u64::MAX + } else { + (1u64 << to_write) - 1 + }; + self.cur_byte |= ((v & mask) as u8) << self.cur_bit; + self.cur_bit += to_write; + v >>= to_write; + remaining -= to_write; + if self.cur_bit == 8 { + self.buffer.push(self.cur_byte); + self.cur_byte = 0; + self.cur_bit = 0; + } + } + } + + pub fn emit_vbr(&mut self, val: u64, chunk_width: u32) { + let data_bits = chunk_width - 1; + let data_mask = (1u64 << data_bits) - 1; + let mut v = val; + loop { + let chunk = v & data_mask; + v >>= data_bits; + if v == 0 { + self.emit_bits(chunk, chunk_width); + break; + } + self.emit_bits(chunk | (1u64 << data_bits), chunk_width); + } + } + + pub fn enter_subblock( + &mut self, + block_id: u32, + new_abbrev_width: u32, + current_abbrev_width: u32, + ) { + // ENTER_SUBBLOCK abbrev id = 1 + self.emit_bits(1, current_abbrev_width); + self.emit_vbr(u64::from(block_id), 8); + self.emit_vbr(u64::from(new_abbrev_width), 4); + self.align32(); + + let length_position = self.buffer.len(); + // Write 4 zero bytes as placeholder for block length + 
self.buffer.extend_from_slice(&[0u8; 4]); + let start_position = self.buffer.len(); + + let saved_next_abbrev_id = self.next_abbrev_id; + let saved_abbrevs = std::mem::take(&mut self.defined_abbrevs); + self.next_abbrev_id = 4; + + self.block_stack.push(BlockScope { + outer_abbrev_width: current_abbrev_width, + length_position, + start_position, + saved_next_abbrev_id, + saved_abbrevs, + }); + } + + pub fn exit_block(&mut self, current_abbrev_width: u32) { + // END_BLOCK abbrev id = 0 + self.emit_bits(0, current_abbrev_width); + self.align32(); + + let scope = self.block_stack.pop().expect("no block to exit"); + let content_len = self.buffer.len() - scope.start_position; + let len_words = (content_len / 4) as u32; + let bytes = len_words.to_le_bytes(); + self.buffer[scope.length_position..scope.length_position + 4].copy_from_slice(&bytes); + self.next_abbrev_id = scope.saved_next_abbrev_id; + self.defined_abbrevs = scope.saved_abbrevs; + } + + pub fn emit_record(&mut self, code: u32, values: &[u64], abbrev_width: u32) { + // UNABBREV_RECORD abbrev id = 3 + self.emit_bits(3, abbrev_width); + self.emit_vbr(u64::from(code), 6); + self.emit_vbr(values.len() as u64, 6); + for &v in values { + self.emit_vbr(v, 6); + } + } + + /// Emit a ``DEFINE_ABBREV`` record and return the abbreviation ID assigned. + /// The first abbreviation in a block gets ID 4 (IDs 0-3 are reserved). 
+ #[allow(dead_code)] + pub(crate) fn emit_define_abbrev(&mut self, abbrev: &AbbrevDef, abbrev_width: u32) -> u32 { + // DEFINE_ABBREV abbrev id = 2 + self.emit_bits(2, abbrev_width); + // Count operands: Arrays count as 2 (array + element) + let num_ops: usize = abbrev + .operands + .iter() + .map(|op| { + if matches!(op, AbbrevOperand::Array(_)) { + 2 + } else { + 1 + } + }) + .sum(); + self.emit_vbr(num_ops as u64, 5); + for op in &abbrev.operands { + self.emit_abbrev_operand(op); + } + let id = self.next_abbrev_id; + self.defined_abbrevs.push(abbrev.clone()); + self.next_abbrev_id += 1; + id + } + + #[allow(dead_code)] + fn emit_abbrev_operand(&mut self, op: &AbbrevOperand) { + match op { + AbbrevOperand::Literal(v) => { + self.emit_bits(1, 1); // is_literal = true + self.emit_vbr(*v, 8); + } + AbbrevOperand::Fixed(w) => { + self.emit_bits(0, 1); + self.emit_bits(1, 3); // encoding = Fixed + self.emit_vbr(u64::from(*w), 5); + } + AbbrevOperand::Vbr(w) => { + self.emit_bits(0, 1); + self.emit_bits(2, 3); // encoding = VBR + self.emit_vbr(u64::from(*w), 5); + } + AbbrevOperand::Array(elem) => { + self.emit_bits(0, 1); + self.emit_bits(3, 3); // encoding = Array + self.emit_abbrev_operand(elem); + } + AbbrevOperand::Char6 => { + self.emit_bits(0, 1); + self.emit_bits(4, 3); // encoding = Char6 + } + AbbrevOperand::Blob => { + self.emit_bits(0, 1); + self.emit_bits(5, 3); // encoding = Blob + } + } + } + + /// Emit a record using a previously defined abbreviation. + /// `fields` contains values for all non-literal operands in order. 
+ #[allow(dead_code)] + pub(crate) fn emit_abbreviated_record( + &mut self, + abbrev_id: u32, + fields: &[u64], + abbrev_width: u32, + ) { + let def_idx = (abbrev_id - 4) as usize; + let abbrev = self.defined_abbrevs[def_idx].clone(); + self.emit_bits(u64::from(abbrev_id), abbrev_width); + let mut field_idx = 0; + for op in &abbrev.operands { + match op { + AbbrevOperand::Literal(_) => {} // implicit, not emitted + AbbrevOperand::Fixed(w) => { + self.emit_bits(fields[field_idx], *w); + field_idx += 1; + } + AbbrevOperand::Vbr(w) => { + self.emit_vbr(fields[field_idx], *w); + field_idx += 1; + } + AbbrevOperand::Char6 => { + self.emit_bits(u64::from(encode_char6(fields[field_idx] as u8)), 6); + field_idx += 1; + } + AbbrevOperand::Array(elem) => { + let count = fields[field_idx] as usize; + self.emit_vbr(fields[field_idx], 6); + field_idx += 1; + for _ in 0..count { + self.emit_array_element(elem, fields[field_idx]); + field_idx += 1; + } + } + AbbrevOperand::Blob => { + let len = fields[field_idx] as usize; + self.emit_vbr(fields[field_idx], 6); + field_idx += 1; + self.align32(); + for i in 0..len { + self.emit_bits(fields[field_idx + i], 8); + } + field_idx += len; + self.align32(); + } + } + } + } + + #[allow(dead_code)] + fn emit_array_element(&mut self, elem: &AbbrevOperand, value: u64) { + match elem { + AbbrevOperand::Fixed(w) => self.emit_bits(value, *w), + AbbrevOperand::Vbr(w) => self.emit_vbr(value, *w), + AbbrevOperand::Char6 => { + self.emit_bits(u64::from(encode_char6(value as u8)), 6); + } + _ => {} + } + } + + pub fn align32(&mut self) { + if self.cur_bit > 0 { + self.buffer.push(self.cur_byte); + self.cur_byte = 0; + self.cur_bit = 0; + } + let rem = self.buffer.len() % 4; + if rem != 0 { + let pad = 4 - rem; + self.buffer.extend(std::iter::repeat(0u8).take(pad)); + } + } + + pub fn finish(mut self) -> Vec { + if self.cur_bit > 0 { + self.buffer.push(self.cur_byte); + } + self.buffer + } + + pub(crate) fn patch_u32_bits(&mut self, bit_position: 
usize, value: u32) { + self.patch_bits(bit_position, u64::from(value), 32); + } + + fn patch_bits(&mut self, bit_position: usize, value: u64, width: u32) { + for bit_offset in 0..width as usize { + let absolute_bit = bit_position + bit_offset; + let byte_index = absolute_bit / 8; + let bit_index = (absolute_bit % 8) as u8; + let mask = 1u8 << bit_index; + let is_set = ((value >> bit_offset) & 1) != 0; + + if byte_index < self.buffer.len() { + if is_set { + self.buffer[byte_index] |= mask; + } else { + self.buffer[byte_index] &= !mask; + } + continue; + } + + if byte_index == self.buffer.len() && u32::from(bit_index) < self.cur_bit { + if is_set { + self.cur_byte |= mask; + } else { + self.cur_byte &= !mask; + } + continue; + } + + panic!("cannot patch future bit position {bit_position}"); + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct AbbrevDef { + pub(crate) operands: Vec, +} + +#[derive(Clone, Debug)] +pub(crate) enum AbbrevOperand { + Literal(u64), + Fixed(u32), + Vbr(u32), + Array(Box), + Char6, + Blob, +} + +struct ReaderBlockScope { + abbrevs: Vec, +} + +fn decode_char6(v: u8) -> u8 { + match v { + 0..=25 => b'a' + v, + 26..=51 => b'A' + v - 26, + 52..=61 => b'0' + v - 52, + 62 => b'.', + 63 => b'_', + _ => b'?', + } +} + +#[allow(dead_code)] +fn encode_char6(c: u8) -> u8 { + match c { + b'a'..=b'z' => c - b'a', + b'A'..=b'Z' => c - b'A' + 26, + b'0'..=b'9' => c - b'0' + 52, + b'.' 
=> 62, + b'_' => 63, + _ => 0, + } +} + +pub struct BitstreamReader<'a> { + data: &'a [u8], + byte_pos: usize, + bit_pos: u32, + block_scope_stack: Vec, + blockinfo_abbrevs: FxHashMap>, +} + +impl<'a> BitstreamReader<'a> { + pub fn new(data: &'a [u8]) -> Self { + Self { + data, + byte_pos: 0, + bit_pos: 0, + block_scope_stack: Vec::new(), + blockinfo_abbrevs: FxHashMap::default(), + } + } + + pub fn read_bits(&mut self, width: u32) -> u64 { + let mut result: u64 = 0; + let mut remaining = width; + let mut shift = 0u32; + while remaining > 0 { + let bits_avail = 8 - self.bit_pos; + let to_read = remaining.min(bits_avail); + let mask = if to_read == 8 { + 0xFF + } else { + ((1u16 << to_read) - 1) as u8 + }; + let bits = (self.data[self.byte_pos] >> self.bit_pos) & mask; + result |= u64::from(bits) << shift; + shift += to_read; + self.bit_pos += to_read; + remaining -= to_read; + if self.bit_pos == 8 { + self.byte_pos += 1; + self.bit_pos = 0; + } + } + result + } + + pub fn read_vbr(&mut self, chunk_width: u32) -> u64 { + let data_bits = chunk_width - 1; + let data_mask = (1u64 << data_bits) - 1; + let cont_bit = 1u64 << data_bits; + let mut result: u64 = 0; + let mut shift = 0u32; + loop { + let chunk = self.read_bits(chunk_width); + result |= (chunk & data_mask) << shift; + if chunk & cont_bit == 0 { + break; + } + shift += data_bits; + } + result + } + + pub fn read_abbrev_id(&mut self, abbrev_width: u32) -> u32 { + self.read_bits(abbrev_width) as u32 + } + + pub fn enter_subblock(&mut self) -> (u32, u32, usize) { + let block_id = self.read_vbr(8) as u32; + let new_abbrev_width = self.read_vbr(4) as u32; + self.align32(); + let block_len_words = self.read_bits(32) as usize; + (block_id, new_abbrev_width, block_len_words) + } + + pub fn skip_block(&mut self, block_len_words: usize) { + self.byte_pos += block_len_words * 4; + self.bit_pos = 0; + } + + pub fn read_unabbrev_record(&mut self) -> (u32, Vec) { + let code = self.read_vbr(6) as u32; + let num_ops = 
self.read_vbr(6) as usize; + let mut values = Vec::with_capacity(num_ops); + for _ in 0..num_ops { + values.push(self.read_vbr(6)); + } + (code, values) + } + + pub fn align32(&mut self) { + if self.bit_pos > 0 { + self.byte_pos += 1; + self.bit_pos = 0; + } + let rem = self.byte_pos % 4; + if rem != 0 { + self.byte_pos += 4 - rem; + } + } + + pub fn at_end(&self) -> bool { + self.byte_pos >= self.data.len() + } + + pub fn byte_position(&self) -> usize { + self.byte_pos + } + + pub fn push_block_scope(&mut self, block_id: u32) { + let mut scope = ReaderBlockScope { + abbrevs: Vec::new(), + }; + if let Some(inherited) = self.blockinfo_abbrevs.get(&block_id) { + scope.abbrevs.clone_from(inherited); + } + self.block_scope_stack.push(scope); + } + + pub fn pop_block_scope(&mut self) { + self.block_scope_stack.pop(); + } + + pub fn read_define_abbrev(&mut self) -> Result<(), String> { + let abbrev = self.read_abbrev_def()?; + if let Some(scope) = self.block_scope_stack.last_mut() { + scope.abbrevs.push(abbrev); + } + Ok(()) + } + + pub fn read_blockinfo_abbrev(&mut self, target_block_id: u32) -> Result<(), String> { + let abbrev = self.read_abbrev_def()?; + self.blockinfo_abbrevs + .entry(target_block_id) + .or_default() + .push(abbrev); + Ok(()) + } + + pub fn read_abbreviated_record(&mut self, abbrev_id: u32) -> Result<(u32, Vec), String> { + let abbrev_index = (abbrev_id - 4) as usize; + let abbrev = self + .block_scope_stack + .last() + .and_then(|scope| scope.abbrevs.get(abbrev_index)) + .ok_or_else(|| format!("abbreviation {abbrev_id} not defined"))? 
+ .clone(); + + if abbrev.operands.is_empty() { + return Err("abbreviation has no operands".to_string()); + } + + let code = match &abbrev.operands[0] { + AbbrevOperand::Literal(v) => *v as u32, + AbbrevOperand::Fixed(w) => self.read_bits(*w) as u32, + AbbrevOperand::Vbr(w) => self.read_vbr(*w) as u32, + AbbrevOperand::Char6 => u32::from(decode_char6(self.read_bits(6) as u8)), + AbbrevOperand::Array(_) | AbbrevOperand::Blob => { + return Err("abbreviation starts with Array or Blob".to_string()); + } + }; + + let mut values = Vec::new(); + for op in &abbrev.operands[1..] { + self.decode_abbrev_value(op, &mut values)?; + } + Ok((code, values)) + } + + fn read_abbrev_def(&mut self) -> Result { + let num_ops = self.read_vbr(5) as usize; + let mut operands = Vec::with_capacity(num_ops); + let mut i = 0; + while i < num_ops { + let is_literal = self.read_bits(1) != 0; + if is_literal { + operands.push(AbbrevOperand::Literal(self.read_vbr(8))); + } else { + let encoding = self.read_bits(3) as u8; + match encoding { + 1 => operands.push(AbbrevOperand::Fixed(self.read_vbr(5) as u32)), + 2 => operands.push(AbbrevOperand::Vbr(self.read_vbr(5) as u32)), + 3 => { + i += 1; + if i >= num_ops { + return Err("Array abbreviation missing element operand".to_string()); + } + let elem = self.read_single_abbrev_operand()?; + operands.push(AbbrevOperand::Array(Box::new(elem))); + } + 4 => operands.push(AbbrevOperand::Char6), + 5 => operands.push(AbbrevOperand::Blob), + _ => return Err(format!("unknown abbreviation encoding {encoding}")), + } + } + i += 1; + } + Ok(AbbrevDef { operands }) + } + + fn read_single_abbrev_operand(&mut self) -> Result { + let is_literal = self.read_bits(1) != 0; + if is_literal { + Ok(AbbrevOperand::Literal(self.read_vbr(8))) + } else { + let encoding = self.read_bits(3) as u8; + match encoding { + 1 => Ok(AbbrevOperand::Fixed(self.read_vbr(5) as u32)), + 2 => Ok(AbbrevOperand::Vbr(self.read_vbr(5) as u32)), + 4 => Ok(AbbrevOperand::Char6), + 5 => 
Ok(AbbrevOperand::Blob), + _ => Err(format!( + "invalid element encoding {encoding} in array abbreviation" + )), + } + } + } + + fn decode_abbrev_value( + &mut self, + op: &AbbrevOperand, + values: &mut Vec, + ) -> Result<(), String> { + match op { + AbbrevOperand::Literal(v) => values.push(*v), + AbbrevOperand::Fixed(w) => values.push(self.read_bits(*w)), + AbbrevOperand::Vbr(w) => values.push(self.read_vbr(*w)), + AbbrevOperand::Char6 => { + values.push(u64::from(decode_char6(self.read_bits(6) as u8))); + } + AbbrevOperand::Array(elem) => { + let len = self.read_vbr(6) as usize; + for _ in 0..len { + self.decode_abbrev_value(elem, values)?; + } + } + AbbrevOperand::Blob => { + let len = self.read_vbr(6) as usize; + self.align32(); + for _ in 0..len { + values.push(u64::from(self.data[self.byte_pos])); + self.byte_pos += 1; + } + self.align32(); + } + } + Ok(()) + } +} diff --git a/source/compiler/qsc_llvm/src/bitcode/bitstream/tests.rs b/source/compiler/qsc_llvm/src/bitcode/bitstream/tests.rs new file mode 100644 index 0000000000..c2d27692f7 --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode/bitstream/tests.rs @@ -0,0 +1,616 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use expect_test::expect; + +#[test] +fn fixed_bits_round_trip() { + let mut w = BitstreamWriter::new(); + w.emit_bits(1, 1); + w.emit_bits(0xAB, 8); + w.emit_bits(0x1234, 16); + w.emit_bits(0xDEAD_BEEF, 32); + w.emit_bits(0x0123_4567_89AB_CDEF, 64); + let data = w.finish(); + let mut r = BitstreamReader::new(&data); + expect!["1"].assert_eq(&r.read_bits(1).to_string()); + expect!["171"].assert_eq(&r.read_bits(8).to_string()); // 0xAB + expect!["4660"].assert_eq(&r.read_bits(16).to_string()); // 0x1234 + expect!["3735928559"].assert_eq(&r.read_bits(32).to_string()); // 0xDEADBEEF + expect!["81985529216486895"].assert_eq(&r.read_bits(64).to_string()); +} + +#[test] +fn vbr_round_trip_small() { + let mut w = BitstreamWriter::new(); + w.emit_vbr(0, 6); + w.emit_vbr(1, 6); + w.emit_vbr(31, 6); + let data = w.finish(); + let mut r = BitstreamReader::new(&data); + expect!["0"].assert_eq(&r.read_vbr(6).to_string()); + expect!["1"].assert_eq(&r.read_vbr(6).to_string()); + expect!["31"].assert_eq(&r.read_vbr(6).to_string()); +} + +#[test] +fn vbr_round_trip_large() { + let mut w = BitstreamWriter::new(); + w.emit_vbr(127, 6); + w.emit_vbr(255, 6); + w.emit_vbr(1023, 6); + w.emit_vbr(u64::from(u32::MAX), 6); + let data = w.finish(); + let mut r = BitstreamReader::new(&data); + expect!["127"].assert_eq(&r.read_vbr(6).to_string()); + expect!["255"].assert_eq(&r.read_vbr(6).to_string()); + expect!["1023"].assert_eq(&r.read_vbr(6).to_string()); + expect!["4294967295"].assert_eq(&r.read_vbr(6).to_string()); +} + +#[test] +fn align32_at_boundary() { + let mut w = BitstreamWriter::new(); + w.emit_bits(0xAABBCCDD, 32); + let len_before = w.buffer.len(); + w.align32(); + expect!["4"].assert_eq(&len_before.to_string()); + expect!["4"].assert_eq(&w.buffer.len().to_string()); +} + +#[test] +fn align32_pads_correctly() { + // 1 byte partial + let mut w = BitstreamWriter::new(); + w.emit_bits(0xFF, 8); + w.align32(); + expect!["4"].assert_eq(&w.buffer.len().to_string()); 
+ + // 3 bits partial + let mut w = BitstreamWriter::new(); + w.emit_bits(0b101, 3); + w.align32(); + expect!["4"].assert_eq(&w.buffer.len().to_string()); + + // 2 bytes + let mut w = BitstreamWriter::new(); + w.emit_bits(0xFFFF, 16); + w.align32(); + expect!["4"].assert_eq(&w.buffer.len().to_string()); + + // 5 bytes + let mut w = BitstreamWriter::new(); + w.emit_bits(0xFF, 8); + w.emit_bits(0xFFFF_FFFF, 32); + w.align32(); + expect!["8"].assert_eq(&w.buffer.len().to_string()); +} + +#[test] +fn block_enter_exit_round_trip() { + let abbrev_width = 2; + let mut w = BitstreamWriter::new(); + w.enter_subblock(8, 4, abbrev_width); + w.exit_block(4); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + let id = r.read_abbrev_id(abbrev_width); + expect!["1"].assert_eq(&id.to_string()); // ENTER_SUBBLOCK + + let (block_id, new_abbrev, block_len) = r.enter_subblock(); + expect!["8"].assert_eq(&block_id.to_string()); + expect!["4"].assert_eq(&new_abbrev.to_string()); + + // The block content should contain only the END_BLOCK + alignment + let end_pos = r.byte_position() + block_len * 4; + let end_id = r.read_abbrev_id(new_abbrev); + expect!["0"].assert_eq(&end_id.to_string()); // END_BLOCK + r.align32(); + expect!["true"].assert_eq(&(r.byte_position() == end_pos).to_string()); +} + +#[test] +fn nested_blocks_round_trip() { + let outer_aw = 2; + let inner_aw = 3; + let leaf_aw = 4; + + let mut w = BitstreamWriter::new(); + w.enter_subblock(10, inner_aw, outer_aw); + w.enter_subblock(20, leaf_aw, inner_aw); + w.exit_block(leaf_aw); + w.exit_block(inner_aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + // Outer block + let id = r.read_abbrev_id(outer_aw); + expect!["1"].assert_eq(&id.to_string()); + let (bid, aw, _blen) = r.enter_subblock(); + expect!["10"].assert_eq(&bid.to_string()); + expect!["3"].assert_eq(&aw.to_string()); + + // Inner block + let id = r.read_abbrev_id(aw); + expect!["1"].assert_eq(&id.to_string()); + let 
(bid2, aw2, _blen2) = r.enter_subblock(); + expect!["20"].assert_eq(&bid2.to_string()); + expect!["4"].assert_eq(&aw2.to_string()); + + // END inner + let end = r.read_abbrev_id(aw2); + expect!["0"].assert_eq(&end.to_string()); + r.align32(); + + // END outer + let end = r.read_abbrev_id(aw); + expect!["0"].assert_eq(&end.to_string()); + r.align32(); + expect!["true"].assert_eq(&r.at_end().to_string()); +} + +#[test] +fn unabbrev_record_round_trip() { + let aw = 4; + let mut w = BitstreamWriter::new(); + w.emit_record(7, &[100, 200, 300], aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + let id = r.read_abbrev_id(aw); + expect!["3"].assert_eq(&id.to_string()); // UNABBREV_RECORD + let (code, vals) = r.read_unabbrev_record(); + expect!["7"].assert_eq(&code.to_string()); + expect!["[100, 200, 300]"].assert_eq(&format!("{vals:?}")); +} + +#[test] +fn multiple_records_in_block() { + let aw = 4; + let mut w = BitstreamWriter::new(); + w.enter_subblock(5, aw, 2); + w.emit_record(1, &[10, 20], aw); + w.emit_record(2, &[30], aw); + w.emit_record(3, &[], aw); + w.exit_block(aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + // ENTER_SUBBLOCK + let id = r.read_abbrev_id(2); + expect!["1"].assert_eq(&id.to_string()); + let (bid, new_aw, _blen) = r.enter_subblock(); + expect!["5"].assert_eq(&bid.to_string()); + expect!["4"].assert_eq(&new_aw.to_string()); + + // Record 1 + let id = r.read_abbrev_id(new_aw); + expect!["3"].assert_eq(&id.to_string()); + let (code, vals) = r.read_unabbrev_record(); + expect!["1"].assert_eq(&code.to_string()); + expect!["[10, 20]"].assert_eq(&format!("{vals:?}")); + + // Record 2 + let id = r.read_abbrev_id(new_aw); + expect!["3"].assert_eq(&id.to_string()); + let (code, vals) = r.read_unabbrev_record(); + expect!["2"].assert_eq(&code.to_string()); + expect!["[30]"].assert_eq(&format!("{vals:?}")); + + // Record 3 + let id = r.read_abbrev_id(new_aw); + expect!["3"].assert_eq(&id.to_string()); + 
let (code, vals) = r.read_unabbrev_record(); + expect!["3"].assert_eq(&code.to_string()); + expect!["[]"].assert_eq(&format!("{vals:?}")); + + // END_BLOCK + let id = r.read_abbrev_id(new_aw); + expect!["0"].assert_eq(&id.to_string()); + r.align32(); + expect!["true"].assert_eq(&r.at_end().to_string()); +} + +#[test] +fn patch_u32_bits_updates_non_byte_aligned_field() { + let mut w = BitstreamWriter::new(); + w.emit_bits(0b101, 3); + let patch_position = w.bit_position(); + w.emit_bits(0, 32); + w.emit_bits(0b11, 2); + w.patch_u32_bits(patch_position, 0xDEAD_BEEF); + + let data = w.finish(); + let mut r = BitstreamReader::new(&data); + expect!["5"].assert_eq(&r.read_bits(3).to_string()); + expect!["3735928559"].assert_eq(&r.read_bits(32).to_string()); + expect!["3"].assert_eq(&r.read_bits(2).to_string()); +} + +#[test] +fn char6_decode_table() { + // a-z + expect!["a"].assert_eq(&(decode_char6(0) as char).to_string()); + expect!["z"].assert_eq(&(decode_char6(25) as char).to_string()); + // A-Z + expect!["A"].assert_eq(&(decode_char6(26) as char).to_string()); + expect!["Z"].assert_eq(&(decode_char6(51) as char).to_string()); + // 0-9 + expect!["0"].assert_eq(&(decode_char6(52) as char).to_string()); + expect!["9"].assert_eq(&(decode_char6(61) as char).to_string()); + // special + expect!["."].assert_eq(&(decode_char6(62) as char).to_string()); + expect!["_"].assert_eq(&(decode_char6(63) as char).to_string()); +} + +#[test] +fn abbrev_fixed_round_trip() { + let mut w = BitstreamWriter::new(); + // Emit a DEFINE_ABBREV definition manually: + // 2 operands: Literal(7), Fixed(8) + w.emit_vbr(2, 5); // num_ops = 2 + w.emit_bits(1, 1); // op1: is_literal = true + w.emit_vbr(7, 8); // op1: value = 7 + w.emit_bits(0, 1); // op2: is_literal = false + w.emit_bits(1, 3); // op2: encoding = Fixed + w.emit_vbr(8, 5); // op2: width = 8 + // Emit an abbreviated record using that abbreviation: + // code=7 (from literal), value=0xAB (Fixed 8 bits) + w.emit_bits(0xAB, 8); + let data 
= w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + r.read_define_abbrev() + .expect("read_define_abbrev should succeed"); + + let (code, values) = r + .read_abbreviated_record(4) + .expect("read_abbreviated_record should succeed"); + expect!["7"].assert_eq(&code.to_string()); + expect!["[171]"].assert_eq(&format!("{values:?}")); // 0xAB = 171 + r.pop_block_scope(); +} + +#[test] +fn abbrev_vbr_round_trip() { + let mut w = BitstreamWriter::new(); + // 2 operands: Literal(5), VBR(6) + w.emit_vbr(2, 5); // num_ops = 2 + w.emit_bits(1, 1); // op1: is_literal + w.emit_vbr(5, 8); // op1: value = 5 + w.emit_bits(0, 1); // op2: not literal + w.emit_bits(2, 3); // op2: encoding = VBR + w.emit_vbr(6, 5); // op2: chunk_width = 6 + // Emit record: code=5 (literal), value=12345 (VBR6) + w.emit_vbr(12345, 6); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + r.read_define_abbrev() + .expect("read_define_abbrev should succeed"); + + let (code, values) = r + .read_abbreviated_record(4) + .expect("read_abbreviated_record should succeed"); + expect!["5"].assert_eq(&code.to_string()); + expect!["[12345]"].assert_eq(&format!("{values:?}")); + r.pop_block_scope(); +} + +#[test] +fn abbrev_array_char6_round_trip() { + let mut w = BitstreamWriter::new(); + // 3 operands: Literal(19), Array, Char6 (element) + // LLVM counts the Array and its element encoding as separate operands + w.emit_vbr(3, 5); // num_ops = 3 + w.emit_bits(1, 1); // op1: is_literal + w.emit_vbr(19, 8); // op1: value = 19 (TYPE_CODE_STRUCT_NAME) + w.emit_bits(0, 1); // op2: not literal + w.emit_bits(3, 3); // op2: encoding = Array + // Array element encoding: Char6 + w.emit_bits(0, 1); // not literal + w.emit_bits(4, 3); // encoding = Char6 + // Emit record: code=19, array of Char6 spelling "Hello" + w.emit_vbr(5, 6); // array length = 5 + // Char6 encoding: H=33 (26+7), e=4, l=11, l=11, o=14 + w.emit_bits(33, 6); // H + w.emit_bits(4, 
6); // e + w.emit_bits(11, 6); // l + w.emit_bits(11, 6); // l + w.emit_bits(14, 6); // o + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + r.read_define_abbrev() + .expect("read_define_abbrev should succeed"); + + let (code, values) = r + .read_abbreviated_record(4) + .expect("read_abbreviated_record should succeed"); + expect!["19"].assert_eq(&code.to_string()); + let s: String = values.iter().map(|&v| v as u8 as char).collect(); + expect!["Hello"].assert_eq(&s); + r.pop_block_scope(); +} + +#[test] +fn abbrev_blockinfo_propagation() { + let mut w = BitstreamWriter::new(); + // Emit a DEFINE_ABBREV for blockinfo + w.emit_vbr(2, 5); // num_ops = 2 + w.emit_bits(1, 1); // literal + w.emit_vbr(42, 8); // code = 42 + w.emit_bits(0, 1); // not literal + w.emit_bits(1, 3); // Fixed + w.emit_vbr(16, 5); // width = 16 + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + // Store abbreviation in blockinfo for block type 8 + r.push_block_scope(0); // dummy scope for reading + r.read_blockinfo_abbrev(8) + .expect("read_blockinfo_abbrev should succeed"); + r.pop_block_scope(); + + // Now push scope for block 8 — should inherit the abbreviation + r.push_block_scope(8); + // Emit a record using inherited abbreviation (id=4) + let mut w2 = BitstreamWriter::new(); + w2.emit_bits(0x1234, 16); // Fixed(16) + let data2 = w2.finish(); + + let mut r2 = BitstreamReader::new(&data2); + r2.blockinfo_abbrevs = r.blockinfo_abbrevs; + r2.push_block_scope(8); + let (code, values) = r2 + .read_abbreviated_record(4) + .expect("read_abbreviated_record should succeed"); + expect!["42"].assert_eq(&code.to_string()); + expect!["[4660]"].assert_eq(&format!("{values:?}")); // 0x1234 = 4660 + r2.pop_block_scope(); +} + +#[test] +fn multiple_abbrevs_in_scope() { + let mut w = BitstreamWriter::new(); + // Abbrev 1 (id=4): Literal(1), Fixed(8) + w.emit_vbr(2, 5); + w.emit_bits(1, 1); + w.emit_vbr(1, 8); + w.emit_bits(0, 1); + 
w.emit_bits(1, 3); + w.emit_vbr(8, 5); + // Abbrev 2 (id=5): Literal(2), Fixed(16) + w.emit_vbr(2, 5); + w.emit_bits(1, 1); + w.emit_vbr(2, 8); + w.emit_bits(0, 1); + w.emit_bits(1, 3); + w.emit_vbr(16, 5); + // Record using abbrev 1 (id=4) + w.emit_bits(0xFF, 8); + // Record using abbrev 2 (id=5) + w.emit_bits(0xBEEF, 16); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + r.read_define_abbrev().expect("first abbrev should succeed"); + r.read_define_abbrev() + .expect("second abbrev should succeed"); + + let (code1, vals1) = r + .read_abbreviated_record(4) + .expect("record 1 should succeed"); + expect!["1"].assert_eq(&code1.to_string()); + expect!["[255]"].assert_eq(&format!("{vals1:?}")); + + let (code2, vals2) = r + .read_abbreviated_record(5) + .expect("record 2 should succeed"); + expect!["2"].assert_eq(&code2.to_string()); + expect!["[48879]"].assert_eq(&format!("{vals2:?}")); // 0xBEEF + r.pop_block_scope(); +} + +#[test] +fn writer_define_abbrev_fixed_round_trip() { + let abbrev = AbbrevDef { + operands: vec![AbbrevOperand::Literal(7), AbbrevOperand::Fixed(8)], + }; + let aw = 4; + let mut w = BitstreamWriter::new(); + let id = w.emit_define_abbrev(&abbrev, aw); + expect!["4"].assert_eq(&id.to_string()); + w.emit_abbreviated_record(id, &[0xAB], aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + let aid = r.read_abbrev_id(aw); + expect!["2"].assert_eq(&aid.to_string()); // DEFINE_ABBREV + r.read_define_abbrev() + .expect("read_define_abbrev should succeed"); + let aid2 = r.read_abbrev_id(aw); + expect!["4"].assert_eq(&aid2.to_string()); + let (code, values) = r + .read_abbreviated_record(4) + .expect("read_abbreviated_record should succeed"); + expect!["7"].assert_eq(&code.to_string()); + expect!["[171]"].assert_eq(&format!("{values:?}")); // 0xAB = 171 + r.pop_block_scope(); +} + +#[test] +fn writer_define_abbrev_vbr_round_trip() { + let abbrev = AbbrevDef { + 
operands: vec![AbbrevOperand::Literal(5), AbbrevOperand::Vbr(6)], + }; + let aw = 4; + let mut w = BitstreamWriter::new(); + let id = w.emit_define_abbrev(&abbrev, aw); + w.emit_abbreviated_record(id, &[12345], aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + let aid = r.read_abbrev_id(aw); + expect!["2"].assert_eq(&aid.to_string()); + r.read_define_abbrev().expect("should succeed"); + let aid2 = r.read_abbrev_id(aw); + expect!["4"].assert_eq(&aid2.to_string()); + let (code, values) = r.read_abbreviated_record(4).expect("should succeed"); + expect!["5"].assert_eq(&code.to_string()); + expect!["[12345]"].assert_eq(&format!("{values:?}")); + r.pop_block_scope(); +} + +#[test] +fn writer_define_abbrev_char6_array_round_trip() { + // Abbreviation: Literal(19), Array(Char6) — struct name style + let abbrev = AbbrevDef { + operands: vec![ + AbbrevOperand::Literal(19), + AbbrevOperand::Array(Box::new(AbbrevOperand::Char6)), + ], + }; + let aw = 4; + let mut w = BitstreamWriter::new(); + let id = w.emit_define_abbrev(&abbrev, aw); + // "Hello" as character values: H=72, e=101, l=108, l=108, o=111 + let fields: Vec = vec![5, 72, 101, 108, 108, 111]; // [count, H, e, l, l, o] + w.emit_abbreviated_record(id, &fields, aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + let aid = r.read_abbrev_id(aw); + expect!["2"].assert_eq(&aid.to_string()); + r.read_define_abbrev().expect("should succeed"); + let aid2 = r.read_abbrev_id(aw); + expect!["4"].assert_eq(&aid2.to_string()); + let (code, values) = r.read_abbreviated_record(4).expect("should succeed"); + expect!["19"].assert_eq(&code.to_string()); + let s: String = values.iter().map(|&v| v as u8 as char).collect(); + expect!["Hello"].assert_eq(&s); + r.pop_block_scope(); +} + +#[test] +fn writer_multiple_abbrevs_round_trip() { + let abbrev1 = AbbrevDef { + operands: vec![AbbrevOperand::Literal(1), AbbrevOperand::Fixed(8)], + }; + 
let abbrev2 = AbbrevDef { + operands: vec![AbbrevOperand::Literal(2), AbbrevOperand::Fixed(16)], + }; + let aw = 4; + let mut w = BitstreamWriter::new(); + let id1 = w.emit_define_abbrev(&abbrev1, aw); + let id2 = w.emit_define_abbrev(&abbrev2, aw); + expect!["4"].assert_eq(&id1.to_string()); + expect!["5"].assert_eq(&id2.to_string()); + w.emit_abbreviated_record(id1, &[0xFF], aw); + w.emit_abbreviated_record(id2, &[0xBEEF], aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + r.push_block_scope(0); + // Read two DEFINE_ABBREVs + let a1 = r.read_abbrev_id(aw); + expect!["2"].assert_eq(&a1.to_string()); + r.read_define_abbrev().expect("first abbrev"); + let a2 = r.read_abbrev_id(aw); + expect!["2"].assert_eq(&a2.to_string()); + r.read_define_abbrev().expect("second abbrev"); + // Read two abbreviated records + let a3 = r.read_abbrev_id(aw); + expect!["4"].assert_eq(&a3.to_string()); + let (code1, vals1) = r.read_abbreviated_record(4).expect("record 1"); + expect!["1"].assert_eq(&code1.to_string()); + expect!["[255]"].assert_eq(&format!("{vals1:?}")); + let a4 = r.read_abbrev_id(aw); + expect!["5"].assert_eq(&a4.to_string()); + let (code2, vals2) = r.read_abbreviated_record(5).expect("record 2"); + expect!["2"].assert_eq(&code2.to_string()); + expect!["[48879]"].assert_eq(&format!("{vals2:?}")); + r.pop_block_scope(); +} + +#[test] +fn writer_abbrev_in_block_round_trip() { + let outer_aw = 2; + let block_aw = 4; + let abbrev = AbbrevDef { + operands: vec![ + AbbrevOperand::Literal(7), + AbbrevOperand::Fixed(8), + AbbrevOperand::Vbr(6), + ], + }; + let mut w = BitstreamWriter::new(); + w.enter_subblock(8, block_aw, outer_aw); + let id = w.emit_define_abbrev(&abbrev, block_aw); + expect!["4"].assert_eq(&id.to_string()); + w.emit_abbreviated_record(id, &[0xAB, 999], block_aw); + w.emit_record(3, &[42], block_aw); // unabbreviated mixed in + w.exit_block(block_aw); + let data = w.finish(); + + let mut r = BitstreamReader::new(&data); + // 
ENTER_SUBBLOCK + let aid = r.read_abbrev_id(outer_aw); + expect!["1"].assert_eq(&aid.to_string()); + let (bid, new_aw, _blen) = r.enter_subblock(); + expect!["8"].assert_eq(&bid.to_string()); + expect!["4"].assert_eq(&new_aw.to_string()); + r.push_block_scope(bid); + // DEFINE_ABBREV + let aid = r.read_abbrev_id(new_aw); + expect!["2"].assert_eq(&aid.to_string()); + r.read_define_abbrev().expect("define abbrev"); + // Abbreviated record + let aid = r.read_abbrev_id(new_aw); + expect!["4"].assert_eq(&aid.to_string()); + let (code, values) = r.read_abbreviated_record(4).expect("abbrev record"); + expect!["7"].assert_eq(&code.to_string()); + expect!["[171, 999]"].assert_eq(&format!("{values:?}")); + // Unabbreviated record + let aid = r.read_abbrev_id(new_aw); + expect!["3"].assert_eq(&aid.to_string()); + let (code, vals) = r.read_unabbrev_record(); + expect!["3"].assert_eq(&code.to_string()); + expect!["[42]"].assert_eq(&format!("{vals:?}")); + // END_BLOCK + let aid = r.read_abbrev_id(new_aw); + expect!["0"].assert_eq(&aid.to_string()); + r.pop_block_scope(); + r.align32(); + expect!["true"].assert_eq(&r.at_end().to_string()); +} + +#[test] +fn writer_abbrev_ids_reset_per_block() { + let outer_aw = 2; + let block_aw = 4; + let abbrev = AbbrevDef { + operands: vec![AbbrevOperand::Literal(1), AbbrevOperand::Fixed(8)], + }; + let mut w = BitstreamWriter::new(); + // First block: define abbrev gets id=4 + w.enter_subblock(8, block_aw, outer_aw); + let id1 = w.emit_define_abbrev(&abbrev, block_aw); + expect!["4"].assert_eq(&id1.to_string()); + w.exit_block(block_aw); + // Second block: define abbrev should also get id=4 (reset) + w.enter_subblock(9, block_aw, outer_aw); + let id2 = w.emit_define_abbrev(&abbrev, block_aw); + expect!["4"].assert_eq(&id2.to_string()); + w.exit_block(block_aw); + let _data = w.finish(); +} diff --git a/source/compiler/qsc_llvm/src/bitcode/constants.rs b/source/compiler/qsc_llvm/src/bitcode/constants.rs new file mode 100644 index 
0000000000..667ba96640 --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode/constants.rs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Shared LLVM bitcode block IDs and record codes used by both the reader and +//! writer. + +// Block IDs +pub(crate) const MODULE_BLOCK_ID: u32 = 8; +pub(crate) const PARAMATTR_BLOCK_ID: u32 = 9; +pub(crate) const PARAMATTR_GROUP_BLOCK_ID: u32 = 10; +pub(crate) const CONSTANTS_BLOCK_ID: u32 = 11; +pub(crate) const FUNCTION_BLOCK_ID: u32 = 12; +pub(crate) const IDENTIFICATION_BLOCK_ID: u32 = 13; +pub(crate) const VALUE_SYMTAB_BLOCK_ID: u32 = 14; +pub(crate) const METADATA_BLOCK_ID: u32 = 15; +pub(crate) const TYPE_BLOCK_ID_NEW: u32 = 17; +pub(crate) const STRTAB_BLOCK_ID: u32 = 23; + +// Module codes +pub(crate) const MODULE_CODE_VERSION: u32 = 1; +pub(crate) const MODULE_CODE_TRIPLE: u32 = 2; +pub(crate) const MODULE_CODE_DATALAYOUT: u32 = 3; +pub(crate) const MODULE_CODE_GLOBALVAR: u32 = 7; +pub(crate) const MODULE_CODE_FUNCTION: u32 = 8; +pub(crate) const MODULE_CODE_VSTOFFSET: u32 = 13; +pub(crate) const MODULE_CODE_SOURCE_FILENAME: u32 = 16; + +// Type codes +pub(crate) const TYPE_CODE_NUMENTRY: u32 = 1; +pub(crate) const TYPE_CODE_VOID: u32 = 2; +pub(crate) const TYPE_CODE_FLOAT: u32 = 3; +pub(crate) const TYPE_CODE_DOUBLE: u32 = 4; +pub(crate) const TYPE_CODE_LABEL: u32 = 5; +pub(crate) const TYPE_CODE_OPAQUE: u32 = 6; +pub(crate) const TYPE_CODE_INTEGER: u32 = 7; +pub(crate) const TYPE_CODE_HALF: u32 = 10; +pub(crate) const TYPE_CODE_ARRAY: u32 = 11; +pub(crate) const TYPE_CODE_POINTER: u32 = 16; +pub(crate) const TYPE_CODE_STRUCT_NAME: u32 = 19; +pub(crate) const TYPE_CODE_FUNCTION_TYPE: u32 = 21; +pub(crate) const TYPE_CODE_OPAQUE_POINTER: u32 = 25; + +// Constant codes +pub(crate) const CST_CODE_SETTYPE: u32 = 1; +pub(crate) const CST_CODE_NULL: u32 = 2; +pub(crate) const CST_CODE_INTEGER: u32 = 4; +pub(crate) const CST_CODE_FLOAT: u32 = 6; +pub(crate) const 
CST_CODE_CSTRING: u32 = 9; +pub(crate) const CST_CODE_CE_CAST: u32 = 11; +pub(crate) const CST_CODE_CE_INBOUNDS_GEP: u32 = 20; + +// Function instruction codes +pub(crate) const FUNC_CODE_DECLAREBLOCKS: u32 = 1; +pub(crate) const FUNC_CODE_INST_BINOP: u32 = 2; +pub(crate) const FUNC_CODE_INST_CAST: u32 = 3; +pub(crate) const FUNC_CODE_INST_SELECT: u32 = 5; +pub(crate) const FUNC_CODE_INST_RET: u32 = 10; +pub(crate) const FUNC_CODE_INST_BR: u32 = 11; +pub(crate) const FUNC_CODE_INST_SWITCH: u32 = 12; +pub(crate) const FUNC_CODE_INST_UNREACHABLE: u32 = 15; +pub(crate) const FUNC_CODE_INST_PHI: u32 = 16; +pub(crate) const FUNC_CODE_INST_ALLOCA: u32 = 19; +pub(crate) const FUNC_CODE_INST_LOAD: u32 = 20; +pub(crate) const FUNC_CODE_INST_CMP2: u32 = 28; +pub(crate) const FUNC_CODE_INST_CALL: u32 = 34; +pub(crate) const FUNC_CODE_INST_GEP: u32 = 43; +pub(crate) const FUNC_CODE_INST_STORE: u32 = 44; + +// Packed CALL cc-info flags +pub(crate) const CALL_EXPLICIT_TYPE_FLAG: u64 = 1_u64 << 15; + +// Value symbol table codes +pub(crate) const VST_CODE_ENTRY: u32 = 1; +pub(crate) const VST_CODE_BBENTRY: u32 = 2; +pub(crate) const VST_CODE_FNENTRY: u32 = 3; + +// String table codes +pub(crate) const STRTAB_BLOB: u32 = 1; + +// Metadata record codes +pub(crate) const METADATA_STRING_OLD: u32 = 1; +pub(crate) const METADATA_VALUE: u32 = 2; +pub(crate) const METADATA_NODE: u32 = 3; +pub(crate) const METADATA_NAME: u32 = 4; +pub(crate) const METADATA_NAMED_NODE: u32 = 10; diff --git a/source/compiler/qsc_llvm/src/bitcode/reader.rs b/source/compiler/qsc_llvm/src/bitcode/reader.rs new file mode 100644 index 0000000000..9c0f797253 --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode/reader.rs @@ -0,0 +1,3540 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Winnow-based LLVM bitcode reader. +//! +//! Two-layer design: +//! Layer 1 (unchanged): `BitstreamReader` handles bit-level I/O, VBR encoding, +//! 
block/sub-block navigation, and abbreviation management. +//! Layer 2 (this file): winnow combinators parse record-value slices (`&[u64]`) +//! into LLVM IR model types. + +#[cfg(test)] +mod tests; + +use super::bitstream::BitstreamReader; +use crate::model::Type; +use crate::model::{ + Attribute, AttributeGroup, BasicBlock, BinOpKind, CastKind, Constant, FloatPredicate, Function, + GlobalVariable, Instruction, IntPredicate, Linkage, MetadataNode, MetadataValue, Module, + NamedMetadata, Operand, Param, +}; +use crate::{ReadDiagnostic, ReadDiagnosticKind, ReadPolicy, ReadReport}; +use rustc_hash::FxHashMap; +use std::{cell::RefCell, fmt}; +use winnow::combinator::opt; +use winnow::error::{ContextError, ErrMode}; +use winnow::prelude::*; +use winnow::token::{any, rest}; + +/// Converts a `ParseError` into a winnow `PResult`-compatible error. +fn map_parse_err(result: Result) -> PResult { + result.map_err(|_| ErrMode::Cut(ContextError::new())) +} + +// --------------------------------------------------------------------------- +// Winnow type aliases — &[u64] is a native winnow Stream (Token = u64) +// --------------------------------------------------------------------------- + +type RecordInput<'a> = &'a [u64]; +type PResult = winnow::ModalResult; + +// --------------------------------------------------------------------------- +// Constants — block IDs and record codes +// --------------------------------------------------------------------------- + +use super::constants::*; + +const BLOCKINFO_BLOCK_ID: u32 = 0; +const BLOCKINFO_CODE_SETBID: u32 = 1; + +// --------------------------------------------------------------------------- +// ParseError +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct ParseError { + pub kind: ReadDiagnosticKind, + pub context: &'static str, + pub offset: usize, + pub message: String, +} + +impl ParseError { + fn malformed(offset: usize, context: &'static str, message: 
impl Into) -> Self { + Self { + kind: ReadDiagnosticKind::MalformedInput, + context, + offset, + message: message.into(), + } + } + + fn unsupported(offset: usize, context: &'static str, message: impl Into) -> Self { + Self { + kind: ReadDiagnosticKind::UnsupportedSemanticConstruct, + context, + offset, + message: message.into(), + } + } +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "bitcode parse error at byte {}: {}", + self.offset, self.message + ) + } +} + +impl std::error::Error for ParseError {} + +impl From for ReadDiagnostic { + fn from(error: ParseError) -> Self { + Self { + kind: error.kind, + offset: Some(error.offset), + context: error.context, + message: error.message, + } + } +} + +impl From for ParseError { + fn from(diagnostic: ReadDiagnostic) -> Self { + Self { + kind: diagnostic.kind, + context: diagnostic.context, + offset: diagnostic.offset.unwrap_or_default(), + message: diagnostic.message, + } + } +} + +// --------------------------------------------------------------------------- +// Value tracking types +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +enum ValueEntry { + Global(String), + Function(String), + Constant(Type, Constant), + Local(String, Type), + Param(String, Type), + IntToPtrConst(i64, Type), + GepConst { + source_ty: Type, + ptr_name: String, + ptr_ty: Type, + indices: Vec, + }, +} + +#[derive(Debug, Clone)] +enum MetadataSlotEntry { + String(String), + Value(Type, i64), + Node(u32), +} + +#[derive(Debug, Clone)] +struct FuncProto { + #[allow(dead_code)] + func_type_id: u32, + is_declaration: bool, + #[allow(dead_code)] + paramattr_index: u32, +} + +#[derive(Debug, Clone, Copy)] +struct PendingGlobalInitializer { + global_index: usize, + value_id: u32, +} + +#[derive(Debug, Clone, Copy)] +struct PendingStrtabName { + value_id: usize, + offset: usize, + size: usize, +} + +// 
--------------------------------------------------------------------------- +// InstrContext — carries resolution state for instruction-level winnow parsers +// --------------------------------------------------------------------------- + +struct InstrContext<'a> { + global_value_table: &'a [ValueEntry], + local_values: &'a [ValueEntry], + type_table: &'a [Type], + paramattr_lists: &'a [Vec], + bb_names: &'a FxHashMap, + diagnostics: &'a RefCell>, + current_value_id: u32, + byte_offset: usize, + policy: ReadPolicy, +} + +impl InstrContext<'_> { + fn record_compatibility_diagnostic( + &self, + kind: ReadDiagnosticKind, + context: &'static str, + message: impl Into, + ) { + if self.policy == ReadPolicy::Compatibility { + self.diagnostics.borrow_mut().push(ReadDiagnostic { + kind, + offset: Some(self.byte_offset), + context, + message: message.into(), + }); + } + } + + fn resolve_known_global_name(&self, value_id: usize) -> Option { + match self.global_value_table.get(value_id) { + Some(ValueEntry::Global(name) | ValueEntry::Function(name)) => Some(name.clone()), + _ => None, + } + } + + fn resolve_operand(&self, relative_id: u64) -> Result { + if relative_id > u64::from(self.current_value_id) { + return Err(ParseError::malformed( + self.byte_offset, + "value resolution", + format!( + "unresolvable relative value ID: {relative_id} exceeds current value ID {}", + self.current_value_id + ), + )); + } + let absolute_id = self.current_value_id - relative_id as u32; + let global_count = self.global_value_table.len() as u32; + + if absolute_id < global_count { + match &self.global_value_table[absolute_id as usize] { + ValueEntry::Global(name) | ValueEntry::Function(name) => { + Ok(Operand::GlobalRef(name.clone())) + } + _ => Err(ParseError::malformed( + self.byte_offset, + "value resolution", + format!( + "global value ID {absolute_id} does not reference a global or function" + ), + )), + } + } else { + let local_idx = (absolute_id - global_count) as usize; + if local_idx < 
self.local_values.len() { + match &self.local_values[local_idx] { + ValueEntry::Constant(ty, Constant::Int(val)) => { + Ok(Operand::IntConst(ty.clone(), *val)) + } + ValueEntry::Constant(_, Constant::Float(ty, val)) => { + Ok(Operand::FloatConst(ty.clone(), *val)) + } + ValueEntry::Constant(_, Constant::Null) => Ok(Operand::NullPtr), + ValueEntry::Constant(_, Constant::CString(_)) => Ok(Operand::NullPtr), + ValueEntry::IntToPtrConst(val, ty) => Ok(Operand::IntToPtr(*val, ty.clone())), + ValueEntry::GepConst { + source_ty, + ptr_name, + ptr_ty, + indices, + } => Ok(Operand::GetElementPtr { + ty: source_ty.clone(), + ptr: ptr_name.clone(), + ptr_ty: ptr_ty.clone(), + indices: indices.clone(), + }), + ValueEntry::Local(name, ty) => { + Ok(Operand::TypedLocalRef(name.clone(), ty.clone())) + } + ValueEntry::Param(name, ty) => { + Ok(Operand::TypedLocalRef(name.clone(), ty.clone())) + } + ValueEntry::Global(name) | ValueEntry::Function(name) => { + Ok(Operand::GlobalRef(name.clone())) + } + } + } else { + Err(ParseError::malformed( + self.byte_offset, + "value resolution", + format!( + "local value index {local_idx} out of range (have {} local values)", + self.local_values.len() + ), + )) + } + } + } + + fn resolve_phi_operand(&self, encoded_delta: u64, ty: &Type) -> Result { + let delta = sign_unrotate(encoded_delta); + let absolute_id_i64 = i64::from(self.current_value_id) - delta; + if absolute_id_i64 < 0 { + return Err(ParseError::malformed( + self.byte_offset, + "phi instruction", + format!( + "unresolvable PHI value delta {delta} for current value ID {}", + self.current_value_id + ), + )); + } + + let absolute_id = absolute_id_i64 as u32; + let global_count = self.global_value_table.len() as u32; + + if absolute_id < global_count { + return match &self.global_value_table[absolute_id as usize] { + ValueEntry::Global(name) | ValueEntry::Function(name) => { + Ok(Operand::GlobalRef(name.clone())) + } + _ => Err(ParseError::malformed( + self.byte_offset, + "phi 
instruction", + format!("PHI value ID {absolute_id} does not reference a global or function"), + )), + }; + } + + let local_idx = (absolute_id - global_count) as usize; + if local_idx < self.local_values.len() { + return match &self.local_values[local_idx] { + ValueEntry::Constant(inner_ty, Constant::Int(val)) => { + Ok(Operand::IntConst(inner_ty.clone(), *val)) + } + ValueEntry::Constant(_, Constant::Float(ty, val)) => { + Ok(Operand::FloatConst(ty.clone(), *val)) + } + ValueEntry::Constant(_, Constant::Null | Constant::CString(_)) => { + Ok(Operand::NullPtr) + } + ValueEntry::IntToPtrConst(val, target_ty) => { + Ok(Operand::IntToPtr(*val, target_ty.clone())) + } + ValueEntry::GepConst { + source_ty, + ptr_name, + ptr_ty, + indices, + } => Ok(Operand::GetElementPtr { + ty: source_ty.clone(), + ptr: ptr_name.clone(), + ptr_ty: ptr_ty.clone(), + indices: indices.clone(), + }), + ValueEntry::Local(name, inner_ty) => { + Ok(Operand::TypedLocalRef(name.clone(), inner_ty.clone())) + } + ValueEntry::Param(name, inner_ty) => { + Ok(Operand::TypedLocalRef(name.clone(), inner_ty.clone())) + } + ValueEntry::Global(name) | ValueEntry::Function(name) => { + Ok(Operand::GlobalRef(name.clone())) + } + }; + } + + Ok(Operand::TypedLocalRef( + format!("val_{absolute_id}"), + ty.clone(), + )) + } + + fn resolve_call_target_name(&self, encoded_id: u64) -> Result { + let relative_target = self.resolve_operand(encoded_id).ok(); + + if let Some(Operand::GlobalRef(name)) = &relative_target { + return Ok(name.clone()); + } + + if let Some(name) = self.resolve_known_global_name(encoded_id as usize) { + return Ok(name); + } + + match relative_target { + Some(Operand::LocalRef(name) | Operand::TypedLocalRef(name, _)) + if self.policy == ReadPolicy::Compatibility => + { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "call instruction", + format!( + "call target value {encoded_id} resolved to a local value and was imported using placeholder callee 
`{name}`" + ), + ); + Ok(name) + } + _ if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "call instruction", + format!( + "call target value {encoded_id} does not resolve to a known function and was imported as `unknown_{encoded_id}`" + ), + ); + Ok(format!("unknown_{encoded_id}")) + } + _ => Err(ParseError::unsupported( + self.byte_offset, + "call instruction", + format!("call target value {encoded_id} does not resolve to a known function"), + )), + } + } + + fn resolve_type(&self, type_id: u32) -> Result { + self.type_table + .get(type_id as usize) + .cloned() + .ok_or_else(|| { + ParseError::malformed( + self.byte_offset, + "type resolution", + format!( + "invalid type ID {type_id} (type table has {} entries)", + self.type_table.len() + ), + ) + }) + } + + fn resolve_function_type(&self, type_id: u32) -> Result<(Type, Vec), ParseError> { + match self.type_table.get(type_id as usize) { + Some(Type::Function(ret, params)) => Ok((ret.as_ref().clone(), params.clone())), + _ if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "call instruction", + format!( + "type ID {type_id} does not resolve to a function type and was imported as `void ()`" + ), + ); + Ok((Type::Void, Vec::new())) + } + _ => Err(ParseError::unsupported( + self.byte_offset, + "call instruction", + format!("type ID {type_id} does not resolve to a function type"), + )), + } + } + + fn resolve_bb_name(&self, bb_id: u32) -> String { + self.bb_names + .get(&bb_id) + .cloned() + .unwrap_or_else(|| format!("bb_{bb_id}")) + } + + fn infer_type(&self, op: &Operand) -> Type { + match op { + Operand::IntConst(ty, _) => ty.clone(), + Operand::FloatConst(ty, _) => ty.clone(), + Operand::NullPtr => Type::Ptr, + Operand::IntToPtr(_, ty) => ty.clone(), + Operand::TypedLocalRef(_, ty) => ty.clone(), + Operand::LocalRef(name) => { + // 
Look up tracked type from value table + for entry in self.local_values { + match entry { + ValueEntry::Local(n, ty) if n == name => return ty.clone(), + _ => {} + } + } + for entry in self.local_values { + if let ValueEntry::Param(param_name, ty) = entry + && param_name == name + { + return ty.clone(); + } + } + Type::Integer(64) + } + Operand::GlobalRef(_) => Type::Ptr, + Operand::GetElementPtr { ty, .. } => ty.clone(), + } + } + + fn result_name(&self) -> String { + format!("val_{}", self.current_value_id) + } +} + +// --------------------------------------------------------------------------- +// Primitive winnow parsers on &[u64] record values +// --------------------------------------------------------------------------- + +fn parse_char_string(input: &mut RecordInput<'_>) -> PResult { + let values = rest.parse_next(input)?; + Ok(values.iter().map(|&v| v as u8 as char).collect()) +} + +fn remap_block_name(name: &str, bb_names: &FxHashMap) -> String { + let Some(id) = name.strip_prefix("bb_") else { + return name.to_string(); + }; + + let Ok(bb_id) = id.parse::() else { + return name.to_string(); + }; + + bb_names + .get(&bb_id) + .cloned() + .unwrap_or_else(|| name.to_string()) +} + +fn remap_instruction_block_names(instr: &mut Instruction, bb_names: &FxHashMap) { + match instr { + Instruction::Jump { dest } => { + *dest = remap_block_name(dest, bb_names); + } + Instruction::Br { + true_dest, + false_dest, + .. + } => { + *true_dest = remap_block_name(true_dest, bb_names); + *false_dest = remap_block_name(false_dest, bb_names); + } + Instruction::Phi { incoming, .. } => { + for (_, block) in incoming { + *block = remap_block_name(block, bb_names); + } + } + Instruction::Switch { + default_dest, + cases, + .. + } => { + *default_dest = remap_block_name(default_dest, bb_names); + for (_, dest) in cases { + *dest = remap_block_name(dest, bb_names); + } + } + Instruction::Ret(_) + | Instruction::BinOp { .. } + | Instruction::ICmp { .. } + | Instruction::FCmp { .. 
} + | Instruction::Cast { .. } + | Instruction::Call { .. } + | Instruction::Alloca { .. } + | Instruction::Load { .. } + | Instruction::Store { .. } + | Instruction::Select { .. } + | Instruction::Unreachable + | Instruction::GetElementPtr { .. } => {} + } +} + +// --------------------------------------------------------------------------- +// Type record winnow parsers +// --------------------------------------------------------------------------- + +fn parse_type_integer(input: &mut RecordInput<'_>) -> PResult { + let width = opt(any).map(|v| v.unwrap_or(32) as u32).parse_next(input)?; + Ok(Type::Integer(width)) +} + +// --------------------------------------------------------------------------- +// Module record winnow parsers +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +struct ParsedGlobalVarRecord { + is_const: bool, + init_value_id: Option, + linkage: Linkage, + elem_type_id: Option, + legacy_placeholder: bool, +} + +fn decode_global_linkage(encoded: u64) -> Linkage { + if encoded == 3 { + Linkage::Internal + } else { + Linkage::External + } +} + +fn parse_global_var_record(input: &mut RecordInput<'_>) -> PResult { + let fields = rest.parse_next(input)?; + + if fields.len() >= 18 { + let Some(elem_type_id) = fields.get(2).copied() else { + return Err(ErrMode::Cut(ContextError::new())); + }; + let Some(flags) = fields.get(3).copied() else { + return Err(ErrMode::Cut(ContextError::new())); + }; + let Some(init_raw) = fields.get(4).copied() else { + return Err(ErrMode::Cut(ContextError::new())); + }; + let Some(linkage) = fields.get(5).copied() else { + return Err(ErrMode::Cut(ContextError::new())); + }; + + Ok(ParsedGlobalVarRecord { + is_const: flags & 1 != 0, + init_value_id: init_raw + .checked_sub(1) + .map(|value| u32::try_from(value).map_err(|_| ErrMode::Cut(ContextError::new()))) + .transpose()?, + linkage: decode_global_linkage(linkage), + elem_type_id: Some( + 
u32::try_from(elem_type_id).map_err(|_| ErrMode::Cut(ContextError::new()))?, + ), + legacy_placeholder: false, + }) + } else { + if fields.len() < 5 { + return Err(ErrMode::Cut(ContextError::new())); + } + + let raw_init = u32::try_from(fields[3]).map_err(|_| ErrMode::Cut(ContextError::new()))?; + Ok(ParsedGlobalVarRecord { + is_const: fields[2] != 0, + init_value_id: if raw_init > 1 { + Some(raw_init - 1) + } else { + None + }, + linkage: decode_global_linkage(fields[4]), + elem_type_id: if fields.len() >= 14 { + Some(u32::try_from(fields[13]).map_err(|_| ErrMode::Cut(ContextError::new()))?) + } else { + None + }, + legacy_placeholder: raw_init == 1, + }) + } +} + +fn parse_function_record(input: &mut RecordInput<'_>) -> PResult<(u32, bool, u32)> { + let func_type_id = any.map(|v: u64| v as u32).parse_next(input)?; + let _cc = any.parse_next(input)?; + let is_declaration = any.map(|v: u64| v != 0).parse_next(input)?; + let _linkage = opt(any).parse_next(input)?; + let paramattr = opt(any).map(|v| v.unwrap_or(0) as u32).parse_next(input)?; + Ok((func_type_id, is_declaration, paramattr)) +} + +// --------------------------------------------------------------------------- +// Instruction winnow parsers +// --------------------------------------------------------------------------- + +fn parse_ret_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let maybe_val = opt(any).parse_next(input)?; + match maybe_val { + None => Ok(Instruction::Ret(None)), + Some(val_id) => { + let op = map_parse_err(ctx.resolve_operand(val_id))?; + Ok(Instruction::Ret(Some(op))) + } + } +} + +fn parse_br_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let first = any.parse_next(input)?; + let second = opt(any).parse_next(input)?; + match second { + None => { + let dest = ctx.resolve_bb_name(first as u32); + Ok(Instruction::Jump { dest }) + } + Some(false_val) => { + let cond_id = any.parse_next(input)?; + let true_dest = 
ctx.resolve_bb_name(first as u32); + let false_dest = ctx.resolve_bb_name(false_val as u32); + let cond = map_parse_err(ctx.resolve_operand(cond_id))?; + Ok(Instruction::Br { + cond_ty: Type::Integer(1), + cond, + true_dest, + false_dest, + }) + } + } +} + +fn parse_binop_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let lhs_id = any.parse_next(input)?; + let rhs_id = any.parse_next(input)?; + let opcode = any.parse_next(input)?; + let lhs = map_parse_err(ctx.resolve_operand(lhs_id))?; + let rhs = map_parse_err(ctx.resolve_operand(rhs_id))?; + let ty = ctx.infer_type(&lhs); + let op = map_parse_err(opcode_to_binop(opcode, &ty))?; + Ok(Instruction::BinOp { + op, + ty, + lhs, + rhs, + result: ctx.result_name(), + }) +} + +fn parse_cmp2_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let lhs_id = any.parse_next(input)?; + let rhs_id = any.parse_next(input)?; + let pred_code = any.parse_next(input)?; + let lhs = map_parse_err(ctx.resolve_operand(lhs_id))?; + let rhs = map_parse_err(ctx.resolve_operand(rhs_id))?; + let ty = ctx.infer_type(&lhs); + let result = ctx.result_name(); + + if pred_code >= 32 { + Ok(Instruction::ICmp { + pred: map_parse_err(icmp_code_to_predicate(pred_code))?, + ty, + lhs, + rhs, + result, + }) + } else { + Ok(Instruction::FCmp { + pred: map_parse_err(fcmp_code_to_predicate(pred_code))?, + ty, + lhs, + rhs, + result, + }) + } +} + +fn parse_cast_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let val_id = any.parse_next(input)?; + let to_ty_id = any.parse_next(input)? 
as u32; + let cast_opcode = any.parse_next(input)?; + let value = map_parse_err(ctx.resolve_operand(val_id))?; + let from_ty = ctx.infer_type(&value); + let to_ty = map_parse_err(ctx.resolve_type(to_ty_id))?; + Ok(Instruction::Cast { + op: map_parse_err(opcode_to_cast(cast_opcode))?, + from_ty, + to_ty, + value, + result: ctx.result_name(), + }) +} + +fn parse_call_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let paramattr = any.parse_next(input)?; + let _packed_call_cc_info = any.parse_next(input)?; + let func_ty_id = any.parse_next(input)? as u32; + let callee_val_id = any.parse_next(input)?; + + let callee_name = map_parse_err(ctx.resolve_call_target_name(callee_val_id))?; + let (return_type, param_types) = map_parse_err(ctx.resolve_function_type(func_ty_id))?; + + let has_result = !matches!(return_type, Type::Void); + let result = if has_result { + Some(ctx.result_name()) + } else { + None + }; + let return_ty = if has_result { Some(return_type) } else { None }; + + let remaining = rest.parse_next(input)?; + let mut args = Vec::with_capacity(remaining.len()); + for (index, &rel_id) in remaining.iter().enumerate() { + let ty = if let Some(ty) = param_types.get(index).cloned() { + ty + } else if ctx.policy == ReadPolicy::Compatibility { + ctx.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "call instruction", + format!( + "call argument {index} exceeds imported function signature with {} parameter(s) and was imported as `ptr`", + param_types.len() + ), + ); + Type::Ptr + } else { + return map_parse_err(Err(ParseError::unsupported( + ctx.byte_offset, + "call instruction", + format!( + "call argument {index} exceeds imported function signature with {} parameter(s)", + param_types.len() + ), + ))); + }; + let op = map_parse_err(ctx.resolve_operand(rel_id))?; + args.push((ty, op)); + } + + Ok(Instruction::Call { + return_ty, + callee: callee_name, + args, + result, + attr_refs: if paramattr == 0 { + 
Vec::new() + } else { + ctx.paramattr_lists + .get((paramattr - 1) as usize) + .cloned() + .unwrap_or_default() + }, + }) +} + +fn parse_phi_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let ty_id = any.parse_next(input)? as u32; + let ty = map_parse_err(ctx.resolve_type(ty_id))?; + let remaining = rest.parse_next(input)?; + let mut incoming = Vec::new(); + let mut i = 0; + while i + 1 < remaining.len() { + let val_op = map_parse_err(ctx.resolve_phi_operand(remaining[i], &ty))?; + let bb_id = remaining[i + 1] as u32; + incoming.push((val_op, ctx.resolve_bb_name(bb_id))); + i += 2; + } + Ok(Instruction::Phi { + ty, + incoming, + result: ctx.result_name(), + }) +} + +fn parse_alloca_record( + ctx: &InstrContext<'_>, + input: &mut RecordInput<'_>, +) -> PResult { + let ty_id = any.parse_next(input)? as u32; + let ty = map_parse_err(ctx.resolve_type(ty_id))?; + Ok(Instruction::Alloca { + ty, + result: ctx.result_name(), + }) +} + +fn parse_load_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let ptr_id = any.parse_next(input)?; + let ty_id = any.parse_next(input)? 
as u32; + let ptr = map_parse_err(ctx.resolve_operand(ptr_id))?; + let ty = map_parse_err(ctx.resolve_type(ty_id))?; + let ptr_ty = ctx.infer_type(&ptr); + Ok(Instruction::Load { + ty, + ptr_ty, + ptr, + result: ctx.result_name(), + }) +} + +fn parse_store_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let ptr_id = any.parse_next(input)?; + let value_id = any.parse_next(input)?; + let ptr = map_parse_err(ctx.resolve_operand(ptr_id))?; + let value = map_parse_err(ctx.resolve_operand(value_id))?; + let ty = ctx.infer_type(&value); + let ptr_ty = ctx.infer_type(&ptr); + Ok(Instruction::Store { + ty, + value, + ptr_ty, + ptr, + }) +} + +fn parse_select_record( + ctx: &InstrContext<'_>, + input: &mut RecordInput<'_>, +) -> PResult { + let true_id = any.parse_next(input)?; + let false_id = any.parse_next(input)?; + let cond_id = any.parse_next(input)?; + let true_val = map_parse_err(ctx.resolve_operand(true_id))?; + let false_val = map_parse_err(ctx.resolve_operand(false_id))?; + let cond = map_parse_err(ctx.resolve_operand(cond_id))?; + let ty = ctx.infer_type(&true_val); + Ok(Instruction::Select { + cond, + true_val, + false_val, + ty, + result: ctx.result_name(), + }) +} + +fn parse_switch_record( + ctx: &InstrContext<'_>, + input: &mut RecordInput<'_>, +) -> PResult { + let ty_id = any.parse_next(input)? as u32; + let ty = map_parse_err(ctx.resolve_type(ty_id))?; + let value_id = any.parse_next(input)?; + let value = map_parse_err(ctx.resolve_operand(value_id))?; + let default_id = any.parse_next(input)? 
as u32; + let default_dest = ctx.resolve_bb_name(default_id); + + let remaining = rest.parse_next(input)?; + let mut cases = Vec::new(); + let mut i = 0; + while i + 1 < remaining.len() { + let case_val = sign_unrotate(remaining[i]); + let dest_id = remaining[i + 1] as u32; + cases.push((case_val, ctx.resolve_bb_name(dest_id))); + i += 2; + } + Ok(Instruction::Switch { + ty, + value, + default_dest, + cases, + }) +} + +fn parse_gep_record(ctx: &InstrContext<'_>, input: &mut RecordInput<'_>) -> PResult { + let inbounds = any.map(|v: u64| v != 0).parse_next(input)?; + let pointee_type_id = any.parse_next(input)? as u32; + let pointee_ty = map_parse_err(ctx.resolve_type(pointee_type_id))?; + let ptr_id = any.parse_next(input)?; + let ptr = map_parse_err(ctx.resolve_operand(ptr_id))?; + let ptr_ty = ctx.infer_type(&ptr); + + let remaining = rest.parse_next(input)?; + let indices = remaining + .iter() + .map(|&idx_id| map_parse_err(ctx.resolve_operand(idx_id))) + .collect::>>()?; + + Ok(Instruction::GetElementPtr { + inbounds, + pointee_ty, + ptr_ty, + ptr, + indices, + result: ctx.result_name(), + }) +} + +// --------------------------------------------------------------------------- +// Instruction dispatch +// --------------------------------------------------------------------------- + +fn dispatch_instruction( + code: u32, + ctx: &InstrContext<'_>, + values: &[u64], + byte_offset: usize, +) -> Result, ParseError> { + let mut input: RecordInput<'_> = values; + let make_err = |msg: &str| ParseError::malformed(byte_offset, "instruction record", msg); + + match code { + FUNC_CODE_INST_RET => parse_ret_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("RET record malformed")), + FUNC_CODE_INST_BR => parse_br_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("BR record has invalid number of values")), + FUNC_CODE_INST_BINOP => parse_binop_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("BINOP record too short")), + FUNC_CODE_INST_CMP2 
=> parse_cmp2_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("CMP2 record too short")), + FUNC_CODE_INST_CAST => parse_cast_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("CAST record too short")), + FUNC_CODE_INST_CALL => parse_call_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("CALL record too short")), + FUNC_CODE_INST_PHI => parse_phi_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("PHI record too short")), + FUNC_CODE_INST_ALLOCA => parse_alloca_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("ALLOCA record too short")), + FUNC_CODE_INST_LOAD => parse_load_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("LOAD record too short")), + FUNC_CODE_INST_STORE => parse_store_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("STORE record too short")), + FUNC_CODE_INST_SELECT => parse_select_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("SELECT record too short")), + FUNC_CODE_INST_SWITCH => parse_switch_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("SWITCH record too short")), + FUNC_CODE_INST_GEP => parse_gep_record(ctx, &mut input) + .map(Some) + .map_err(|_| make_err("GEP record too short")), + FUNC_CODE_INST_UNREACHABLE => Ok(Some(Instruction::Unreachable)), + _ if ctx.policy == ReadPolicy::Compatibility => { + ctx.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "instruction record", + format!("unsupported instruction record code {code} was dropped during import"), + ); + Ok(None) + } + _ => Err(ParseError::unsupported( + byte_offset, + "instruction record", + format!("unsupported instruction record code {code}"), + )), + } +} + +fn instruction_produces_value(code: u32, _values: &[u64], instr: &Instruction) -> bool { + match code { + FUNC_CODE_INST_BINOP + | FUNC_CODE_INST_CAST + | FUNC_CODE_INST_CMP2 + | FUNC_CODE_INST_PHI + | FUNC_CODE_INST_ALLOCA + | FUNC_CODE_INST_LOAD + | FUNC_CODE_INST_SELECT + | 
FUNC_CODE_INST_GEP => true, + FUNC_CODE_INST_CALL => matches!( + instr, + Instruction::Call { + result: Some(_), + .. + } + ), + _ => false, + } +} + +fn instruction_result_type(instr: &Instruction) -> Type { + match instr { + Instruction::BinOp { ty, .. } => ty.clone(), + Instruction::ICmp { .. } | Instruction::FCmp { .. } => Type::Integer(1), + Instruction::Call { return_ty, .. } => return_ty.clone().unwrap_or(Type::Void), + Instruction::Cast { to_ty, .. } => to_ty.clone(), + Instruction::Alloca { .. } => Type::Ptr, + Instruction::Load { ty, .. } => ty.clone(), + Instruction::Phi { ty, .. } => ty.clone(), + Instruction::Select { ty, .. } => ty.clone(), + Instruction::GetElementPtr { .. } => Type::Ptr, + _ => Type::Void, + } +} + +// --------------------------------------------------------------------------- +// BlockReader — block-level traversal using BitstreamReader +// --------------------------------------------------------------------------- + +struct BlockReader<'a> { + reader: BitstreamReader<'a>, + policy: ReadPolicy, + diagnostics: RefCell>, + type_table: Vec, + source_filename: Option, + target_triple: Option, + target_datalayout: Option, + global_value_table: Vec, + globals: Vec, + pending_global_initializers: Vec, + func_protos: Vec, + functions: Vec, + pending_struct_name: Option, + struct_types: Vec, + attribute_groups: Vec, + paramattr_lists: Vec>, + named_metadata: Vec, + metadata_nodes: Vec, + metadata_slot_map: Vec, + module_constants: Vec<(Type, Constant)>, + module_constant_value_offset: u32, + module_version: u32, + pending_strtab_names: Vec, + string_table: Vec, +} + +impl<'a> BlockReader<'a> { + fn new(data: &'a [u8], policy: ReadPolicy) -> Result { + if data.len() < 4 { + return Err(ParseError::malformed( + 0, + "bitcode header", + "data too short for magic bytes", + )); + } + if data[0] != 0x42 || data[1] != 0x43 || data[2] != 0xC0 || data[3] != 0xDE { + return Err(ParseError::malformed( + 0, + "bitcode header", + "invalid bitcode magic 
bytes", + )); + } + let mut reader = BitstreamReader::new(data); + reader.read_bits(32); + Ok(Self { + reader, + policy, + diagnostics: RefCell::new(Vec::new()), + type_table: Vec::new(), + source_filename: None, + target_triple: None, + target_datalayout: None, + global_value_table: Vec::new(), + globals: Vec::new(), + pending_global_initializers: Vec::new(), + func_protos: Vec::new(), + functions: Vec::new(), + pending_struct_name: None, + struct_types: Vec::new(), + attribute_groups: Vec::new(), + paramattr_lists: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + metadata_slot_map: Vec::new(), + module_constants: Vec::new(), + module_constant_value_offset: 0, + module_version: 0, + pending_strtab_names: Vec::new(), + string_table: Vec::new(), + }) + } + + fn error(&self, message: impl Into) -> ParseError { + ParseError::malformed(self.reader.byte_position(), "bitcode reader", message) + } + + fn unsupported(&self, context: &'static str, message: impl Into) -> ParseError { + ParseError::unsupported(self.reader.byte_position(), context, message) + } + + fn record_compatibility_diagnostic( + &self, + kind: ReadDiagnosticKind, + context: &'static str, + message: impl Into, + ) { + if self.policy == ReadPolicy::Compatibility { + self.diagnostics.borrow_mut().push(ReadDiagnostic { + kind, + offset: Some(self.reader.byte_position()), + context, + message: message.into(), + }); + } + } + + fn unsupported_or_recover( + &self, + context: &'static str, + strict_message: impl Into, + compatibility_message: impl Into, + fallback: T, + ) -> Result { + let strict_message = strict_message.into(); + let compatibility_message = compatibility_message.into(); + + match self.policy { + ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + context, + compatibility_message, + ); + Ok(fallback) + } + ReadPolicy::QirSubsetStrict => Err(self.unsupported(context, strict_message)), + } + } + + fn 
resolve_constant_type( + &self, + type_id: u32, + context: &'static str, + description: &'static str, + fallback: Type, + ) -> Result { + match self.type_table.get(type_id as usize).cloned() { + Some(ty) => Ok(ty), + None => self.unsupported_or_recover( + context, + format!("{description} references unknown type ID {type_id}"), + format!( + "{description} references unknown type ID {type_id} and was imported as `{fallback}`" + ), + fallback, + ), + } + } + + fn remap_operand_names(operand: &mut Operand, name_remap: &FxHashMap) { + match operand { + Operand::GlobalRef(name) => { + if let Some(final_name) = name_remap.get(name.as_str()) { + name.clone_from(final_name); + } + } + Operand::GetElementPtr { ptr, indices, .. } => { + if let Some(final_name) = name_remap.get(ptr.as_str()) { + ptr.clone_from(final_name); + } + for index in indices { + Self::remap_operand_names(index, name_remap); + } + } + Operand::LocalRef(_) + | Operand::TypedLocalRef(_, _) + | Operand::IntConst(_, _) + | Operand::FloatConst(_, _) + | Operand::NullPtr + | Operand::IntToPtr(_, _) => {} + } + } + + fn remap_instruction_names(instr: &mut Instruction, name_remap: &FxHashMap) { + match instr { + Instruction::Ret(Some(value)) => Self::remap_operand_names(value, name_remap), + Instruction::Br { cond, .. } => Self::remap_operand_names(cond, name_remap), + Instruction::BinOp { lhs, rhs, .. } + | Instruction::ICmp { lhs, rhs, .. } + | Instruction::FCmp { lhs, rhs, .. } => { + Self::remap_operand_names(lhs, name_remap); + Self::remap_operand_names(rhs, name_remap); + } + Instruction::Cast { value, .. } => Self::remap_operand_names(value, name_remap), + Instruction::Call { callee, args, .. } => { + if let Some(final_name) = name_remap.get(callee.as_str()) { + callee.clone_from(final_name); + } + for (_, operand) in args { + Self::remap_operand_names(operand, name_remap); + } + } + Instruction::Phi { incoming, .. 
} => { + for (operand, _) in incoming { + Self::remap_operand_names(operand, name_remap); + } + } + Instruction::Load { ptr, .. } => Self::remap_operand_names(ptr, name_remap), + Instruction::Store { value, ptr, .. } => { + Self::remap_operand_names(value, name_remap); + Self::remap_operand_names(ptr, name_remap); + } + Instruction::Select { + cond, + true_val, + false_val, + .. + } => { + Self::remap_operand_names(cond, name_remap); + Self::remap_operand_names(true_val, name_remap); + Self::remap_operand_names(false_val, name_remap); + } + Instruction::Switch { value, .. } => Self::remap_operand_names(value, name_remap), + Instruction::GetElementPtr { ptr, indices, .. } => { + Self::remap_operand_names(ptr, name_remap); + for index in indices { + Self::remap_operand_names(index, name_remap); + } + } + Instruction::Ret(None) + | Instruction::Jump { .. } + | Instruction::Alloca { .. } + | Instruction::Unreachable => {} + } + } + + fn remap_local_operand_names(operand: &mut Operand, name_remap: &FxHashMap) { + match operand { + Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => { + if let Some(final_name) = name_remap.get(name.as_str()) { + name.clone_from(final_name); + } + } + Operand::GetElementPtr { indices, .. } => { + for index in indices { + Self::remap_local_operand_names(index, name_remap); + } + } + Operand::IntConst(_, _) + | Operand::FloatConst(_, _) + | Operand::NullPtr + | Operand::IntToPtr(_, _) + | Operand::GlobalRef(_) => {} + } + } + + fn remap_local_instruction_names( + instr: &mut Instruction, + name_remap: &FxHashMap, + ) { + match instr { + Instruction::Ret(Some(value)) => Self::remap_local_operand_names(value, name_remap), + Instruction::Br { cond, .. } => Self::remap_local_operand_names(cond, name_remap), + Instruction::BinOp { + lhs, rhs, result, .. + } + | Instruction::ICmp { + lhs, rhs, result, .. + } + | Instruction::FCmp { + lhs, rhs, result, .. 
+ } => { + Self::remap_local_operand_names(lhs, name_remap); + Self::remap_local_operand_names(rhs, name_remap); + if let Some(final_name) = name_remap.get(result.as_str()) { + result.clone_from(final_name); + } + } + Instruction::Cast { value, result, .. } => { + Self::remap_local_operand_names(value, name_remap); + if let Some(final_name) = name_remap.get(result.as_str()) { + result.clone_from(final_name); + } + } + Instruction::Call { args, result, .. } => { + for (_, operand) in args { + Self::remap_local_operand_names(operand, name_remap); + } + if let Some(name) = result + && let Some(final_name) = name_remap.get(name.as_str()) + { + name.clone_from(final_name); + } + } + Instruction::Phi { + incoming, result, .. + } => { + for (operand, _) in incoming { + Self::remap_local_operand_names(operand, name_remap); + } + if let Some(final_name) = name_remap.get(result.as_str()) { + result.clone_from(final_name); + } + } + Instruction::Alloca { result, .. } => { + if let Some(final_name) = name_remap.get(result.as_str()) { + result.clone_from(final_name); + } + } + Instruction::Load { ptr, result, .. } => { + Self::remap_local_operand_names(ptr, name_remap); + if let Some(final_name) = name_remap.get(result.as_str()) { + result.clone_from(final_name); + } + } + Instruction::Store { value, ptr, .. } => { + Self::remap_local_operand_names(value, name_remap); + Self::remap_local_operand_names(ptr, name_remap); + } + Instruction::Select { + cond, + true_val, + false_val, + result, + .. + } => { + Self::remap_local_operand_names(cond, name_remap); + Self::remap_local_operand_names(true_val, name_remap); + Self::remap_local_operand_names(false_val, name_remap); + if let Some(final_name) = name_remap.get(result.as_str()) { + result.clone_from(final_name); + } + } + Instruction::Switch { value, .. } => { + Self::remap_local_operand_names(value, name_remap); + } + Instruction::GetElementPtr { + ptr, + indices, + result, + .. 
+ } => { + Self::remap_local_operand_names(ptr, name_remap); + for index in indices { + Self::remap_local_operand_names(index, name_remap); + } + if let Some(final_name) = name_remap.get(result.as_str()) { + result.clone_from(final_name); + } + } + Instruction::Ret(None) | Instruction::Jump { .. } | Instruction::Unreachable => {} + } + } + + fn remap_module_symbol_uses(&mut self, name_remap: &FxHashMap) { + if name_remap.is_empty() { + return; + } + + for function in &mut self.functions { + for block in &mut function.basic_blocks { + for instruction in &mut block.instructions { + Self::remap_instruction_names(instruction, name_remap); + } + } + } + } + + fn resolve_strtab_name(&self, offset: usize, size: usize) -> Option { + let end = offset.checked_add(size)?; + let bytes = self.string_table.get(offset..end)?; + String::from_utf8(bytes.to_vec()).ok() + } + + fn apply_module_value_name( + &mut self, + value_id: usize, + name: String, + name_remap: &mut FxHashMap, + ) { + if value_id >= self.global_value_table.len() { + return; + } + + match self.global_value_table[value_id].clone() { + ValueEntry::Global(old_name) => { + let global_idx = self.global_index_from_value_id(value_id); + if global_idx < self.globals.len() { + self.globals[global_idx].name.clone_from(&name); + } + self.global_value_table[value_id] = ValueEntry::Global(name.clone()); + if old_name != name { + name_remap.insert(old_name, name); + } + } + ValueEntry::Function(old_name) => { + let func_idx = self.func_index_from_value_id(value_id); + if func_idx < self.functions.len() { + self.functions[func_idx].name.clone_from(&name); + } + self.global_value_table[value_id] = ValueEntry::Function(name.clone()); + if old_name != name { + name_remap.insert(old_name, name); + } + } + _ => {} + } + } + + fn apply_pending_strtab_names(&mut self) { + if self.pending_strtab_names.is_empty() || self.string_table.is_empty() { + return; + } + + let mut name_remap = FxHashMap::default(); + let pending_names = 
std::mem::take(&mut self.pending_strtab_names); + for pending in pending_names { + let Some(name) = self.resolve_strtab_name(pending.offset, pending.size) else { + continue; + }; + + self.apply_module_value_name(pending.value_id, name, &mut name_remap); + } + + self.remap_module_symbol_uses(&name_remap); + } + + fn read_record(&mut self, abbrev_id: u32) -> Result<(u32, Vec), ParseError> { + if abbrev_id == 3 { + Ok(self.reader.read_unabbrev_record()) + } else { + self.reader + .read_abbreviated_record(abbrev_id) + .map_err(|e| self.error(e)) + } + } + + // ------------------------------------------------------------------- + // Top-level reading + // ------------------------------------------------------------------- + + fn read_top_level(&mut self) -> Result<(), ParseError> { + let top_abbrev = 2; + while !self.reader.at_end() { + let abbrev_id = self.reader.read_abbrev_id(top_abbrev); + match abbrev_id { + 0 => break, + 1 => { + let (block_id, new_abbrev, block_len) = self.reader.enter_subblock(); + match block_id { + BLOCKINFO_BLOCK_ID => self.read_blockinfo_block(new_abbrev)?, + IDENTIFICATION_BLOCK_ID => self.reader.skip_block(block_len), + MODULE_BLOCK_ID => self.read_module_block(new_abbrev)?, + STRTAB_BLOCK_ID => self.read_strtab_block(new_abbrev)?, + _ => self.reader.skip_block(block_len), + } + } + _ => { + return Err(self.error(format!("unexpected top-level abbrev id {abbrev_id}"))); + } + } + } + Ok(()) + } + + // ------------------------------------------------------------------- + // Module block + // ------------------------------------------------------------------- + + fn read_module_block(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(MODULE_BLOCK_ID); + let mut func_body_index: usize = 0; + + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (block_id, new_abbrev, 
block_len) = self.reader.enter_subblock(); + match block_id { + BLOCKINFO_BLOCK_ID => self.read_blockinfo_block(new_abbrev)?, + TYPE_BLOCK_ID_NEW => self.read_type_block(new_abbrev)?, + FUNCTION_BLOCK_ID => { + while func_body_index < self.func_protos.len() + && self.func_protos[func_body_index].is_declaration + { + func_body_index += 1; + } + if func_body_index < self.func_protos.len() { + self.read_function_block(new_abbrev, func_body_index)?; + func_body_index += 1; + } else { + self.reader.skip_block(block_len); + } + } + VALUE_SYMTAB_BLOCK_ID => self.read_module_vst(new_abbrev)?, + CONSTANTS_BLOCK_ID => { + self.read_module_constants(new_abbrev)?; + } + PARAMATTR_BLOCK_ID => self.read_paramattr_block(new_abbrev)?, + PARAMATTR_GROUP_BLOCK_ID => { + self.read_paramattr_group_block(new_abbrev)?; + } + METADATA_BLOCK_ID => self.read_metadata_block(new_abbrev)?, + _ => self.reader.skip_block(block_len), + } + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + self.handle_module_record(code, &values)?; + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + fn handle_module_record(&mut self, code: u32, values: &[u64]) -> Result<(), ParseError> { + match code { + MODULE_CODE_VERSION => { + self.module_version = values.first().copied().unwrap_or(0) as u32; + } + MODULE_CODE_TRIPLE => { + let s = parse_char_string + .parse(values) + .map_err(|_| self.error("failed to parse triple"))?; + self.target_triple = Some(s); + } + MODULE_CODE_DATALAYOUT => { + let s = parse_char_string + .parse(values) + .map_err(|_| self.error("failed to parse datalayout"))?; + self.target_datalayout = Some(s); + } + MODULE_CODE_GLOBALVAR => { + let mut input: RecordInput<'_> = values; + let parsed_record = parse_global_var_record(&mut input) + .map_err(|_| self.error("global var record too short"))?; + let ty = match parsed_record.elem_type_id { + Some(type_id) => match 
self.type_table.get(type_id as usize).cloned() { + Some(ty) => ty, + None if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "global variable", + format!( + "global variable references unknown element type ID {type_id} and was imported as `ptr`" + ), + ); + Type::Ptr + } + None => { + return Err(self.unsupported( + "global variable", + format!( + "global variable references unknown element type ID {type_id}" + ), + )); + } + }, + None => Type::Ptr, + }; + let initializer = if parsed_record.legacy_placeholder { + Some(self.unsupported_or_recover( + "global variable", + "legacy placeholder global initializer encoding is not supported", + "legacy placeholder global initializer encoding was imported as null", + Constant::Null, + )?) + } else { + None + }; + let global = GlobalVariable { + name: String::new(), + ty, + linkage: parsed_record.linkage, + is_constant: parsed_record.is_const, + initializer, + }; + let idx = self.globals.len(); + self.globals.push(global); + if let Some(value_id) = parsed_record.init_value_id { + self.pending_global_initializers + .push(PendingGlobalInitializer { + global_index: idx, + value_id, + }); + } + self.global_value_table + .push(ValueEntry::Global(format!("__global_{idx}"))); + } + MODULE_CODE_VSTOFFSET => {} + MODULE_CODE_SOURCE_FILENAME => { + let s = parse_char_string + .parse(values) + .map_err(|_| self.error("failed to parse source_filename"))?; + self.source_filename = Some(s); + } + MODULE_CODE_FUNCTION => { + let legacy_type_is_function = values + .first() + .and_then(|value| self.type_table.get(*value as usize)) + .is_some_and(|ty| matches!(ty, Type::Function(_, _))); + let modern_type_is_function = values + .get(2) + .and_then(|value| self.type_table.get(*value as usize)) + .is_some_and(|ty| matches!(ty, Type::Function(_, _))); + let use_modern_v2_layout = + self.module_version >= 2 && modern_type_is_function && 
!legacy_type_is_function; + + let (func_type_id, is_declaration, paramattr, pending_name) = + if use_modern_v2_layout { + let func_type_id = values.get(2).copied().unwrap_or(0) as u32; + let is_declaration = values.get(4).copied().unwrap_or(0) != 0; + let paramattr = values.get(6).copied().unwrap_or(0) as u32; + let pending_name = Some(( + values.first().copied().unwrap_or(0) as usize, + values.get(1).copied().unwrap_or(0) as usize, + )); + (func_type_id, is_declaration, paramattr, pending_name) + } else { + let mut input: RecordInput<'_> = values; + let (func_type_id, is_declaration, paramattr) = + parse_function_record(&mut input) + .map_err(|_| self.error("function record too short"))?; + (func_type_id, is_declaration, paramattr, None) + }; + + let proto = FuncProto { + func_type_id, + is_declaration, + paramattr_index: paramattr, + }; + + let (return_type, param_types) = match self.type_table.get(func_type_id as usize) { + Some(Type::Function(ret, params)) => (ret.as_ref().clone(), params.clone()), + _ if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "function declaration", + format!( + "function record references non-function type ID {func_type_id} and was imported as `void ()`" + ), + ); + (Type::Void, Vec::new()) + } + _ => { + return Err(self.unsupported( + "function declaration", + format!( + "function record references non-function type ID {func_type_id}" + ), + )); + } + }; + + // Resolve attribute group refs from paramattr index (1-based; 0 = no attrs) + let attribute_group_refs = if paramattr > 0 { + self.paramattr_lists + .get((paramattr - 1) as usize) + .cloned() + .unwrap_or_default() + } else { + Vec::new() + }; + + let func = Function { + name: String::new(), + return_type, + params: param_types + .into_iter() + .enumerate() + .map(|(index, ty)| Param { + ty, + name: Some(format!("param_{index}")), + }) + .collect(), + is_declaration, + 
attribute_group_refs, + basic_blocks: Vec::new(), + }; + + let func_idx = self.functions.len(); + let value_id = self.global_value_table.len(); + self.functions.push(func); + self.func_protos.push(proto); + self.global_value_table + .push(ValueEntry::Function(format!("__func_{func_idx}"))); + if let Some((offset, size)) = pending_name { + self.pending_strtab_names.push(PendingStrtabName { + value_id, + offset, + size, + }); + } + } + _ => {} + } + Ok(()) + } + + // ------------------------------------------------------------------- + // Type block + // ------------------------------------------------------------------- + + fn read_type_block(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(TYPE_BLOCK_ID_NEW); + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + self.handle_type_record(code, &values)?; + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + fn handle_type_record(&mut self, code: u32, values: &[u64]) -> Result<(), ParseError> { + let mut input: RecordInput<'_> = values; + match code { + TYPE_CODE_NUMENTRY => { + if let Some(&count) = values.first() { + self.type_table.reserve(count as usize); + } + } + TYPE_CODE_VOID => self.type_table.push(Type::Void), + TYPE_CODE_HALF => self.type_table.push(Type::Half), + TYPE_CODE_FLOAT => self.type_table.push(Type::Float), + TYPE_CODE_DOUBLE => self.type_table.push(Type::Double), + TYPE_CODE_LABEL => self.type_table.push(Type::Label), + TYPE_CODE_INTEGER => { + let ty = parse_type_integer(&mut input) + .map_err(|_| self.error("INTEGER type record malformed"))?; + self.type_table.push(ty); + } + 
TYPE_CODE_OPAQUE_POINTER => self.type_table.push(Type::Ptr), + TYPE_CODE_POINTER => { + let Some(inner_id) = values.first().copied() else { + self.type_table.push(Type::Ptr); + return Ok(()); + }; + let inner_id = inner_id as u32; + let inner = match self.type_table.get(inner_id as usize).cloned() { + Some(ty) => ty, + None if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "type record", + format!( + "pointer type references unknown element type ID {inner_id} and was imported as `void`" + ), + ); + Type::Void + } + None => { + return Err(self.unsupported( + "type record", + format!("pointer type references unknown element type ID {inner_id}"), + )); + } + }; + let ty = match inner { + Type::Named(name) => Type::NamedPtr(name), + other => Type::TypedPtr(Box::new(other)), + }; + self.type_table.push(ty); + } + TYPE_CODE_STRUCT_NAME => { + let name = parse_char_string(&mut input) + .map_err(|_| self.error("STRUCT_NAME record malformed"))?; + self.pending_struct_name = Some(name); + } + TYPE_CODE_OPAQUE => { + let name = self + .pending_struct_name + .take() + .unwrap_or_else(|| "unknown".to_string()); + self.struct_types.push(name.clone()); + self.type_table.push(Type::Named(name)); + } + TYPE_CODE_ARRAY => { + if values.len() < 2 { + return Err(self.error("ARRAY type record too short")); + } + let len = values[0]; + let elem_id = values[1] as u32; + let elem = match self.type_table.get(elem_id as usize).cloned() { + Some(ty) => ty, + None if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "type record", + format!( + "array type references unknown element type ID {elem_id} and was imported as `void`" + ), + ); + Type::Void + } + None => { + return Err(self.unsupported( + "type record", + format!("array type references unknown element type ID {elem_id}"), + )); + } + }; + let ty = 
Type::Array(len, Box::new(elem)); + self.type_table.push(ty); + } + TYPE_CODE_FUNCTION_TYPE => { + if values.len() < 2 { + return Err(self.error("FUNCTION type record too short")); + } + let ret_id = values[1] as u32; + let ret = match self.type_table.get(ret_id as usize).cloned() { + Some(ty) => ty, + None if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "type record", + format!( + "function type references unknown return type ID {ret_id} and was imported as `void`" + ), + ); + Type::Void + } + None => { + return Err(self.unsupported( + "type record", + format!("function type references unknown return type ID {ret_id}"), + )); + } + }; + let mut param_types = Vec::with_capacity(values.len().saturating_sub(2)); + for ¶m_id in &values[2..] { + let param_id = param_id as u32; + let param_ty = match self.type_table.get(param_id as usize).cloned() { + Some(ty) => ty, + None if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "type record", + format!( + "function type references unknown parameter type ID {param_id} and was imported as `void`" + ), + ); + Type::Void + } + None => { + return Err(self.unsupported( + "type record", + format!( + "function type references unknown parameter type ID {param_id}" + ), + )); + } + }; + param_types.push(param_ty); + } + let ty = Type::Function(Box::new(ret), param_types); + self.type_table.push(ty); + } + _ if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "type record", + format!("unsupported type record code {code} was imported as `void`"), + ); + self.type_table.push(Type::Void); + } + _ => { + return Err(self.unsupported( + "type record", + format!("unsupported type record code {code}"), + )); + } + } + Ok(()) + } + + // 
------------------------------------------------------------------- + // Function block + // ------------------------------------------------------------------- + + fn read_function_block( + &mut self, + abbrev_width: u32, + func_index: usize, + ) -> Result<(), ParseError> { + self.reader.push_block_scope(FUNCTION_BLOCK_ID); + let mut local_values: Vec = Vec::new(); + + let func = &self.functions[func_index]; + for p in &func.params { + let name = p + .name + .clone() + .unwrap_or_else(|| format!("param_{}", local_values.len())); + local_values.push(ValueEntry::Param(name, p.ty.clone())); + } + + let mut num_bbs: usize = 0; + let mut current_bb: usize = 0; + let mut basic_blocks: Vec = Vec::new(); + let mut current_instructions: Vec = Vec::new(); + let mut bb_names: FxHashMap = FxHashMap::default(); + let mut local_name_entries: FxHashMap = FxHashMap::default(); + let mut next_result_id: u32 = + self.global_value_table.len() as u32 + local_values.len() as u32; + + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + if !current_instructions.is_empty() || current_bb < num_bbs { + let bb_name = bb_names + .get(&(current_bb as u32)) + .cloned() + .unwrap_or_else(|| format!("bb_{current_bb}")); + basic_blocks.push(BasicBlock { + name: bb_name, + instructions: std::mem::take(&mut current_instructions), + }); + } + break; + } + 1 => { + let (block_id, new_abbrev, block_len) = self.reader.enter_subblock(); + match block_id { + CONSTANTS_BLOCK_ID => { + self.read_function_constants( + new_abbrev, + &mut local_values, + &mut next_result_id, + )?; + } + VALUE_SYMTAB_BLOCK_ID => { + self.read_function_vst( + new_abbrev, + &mut bb_names, + &mut local_name_entries, + )?; + } + _ => self.reader.skip_block(block_len), + } + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + if 
code == FUNC_CODE_DECLAREBLOCKS { + num_bbs = values.first().copied().unwrap_or(1) as usize; + } else { + let is_terminator = matches!( + code, + FUNC_CODE_INST_RET + | FUNC_CODE_INST_BR + | FUNC_CODE_INST_SWITCH + | FUNC_CODE_INST_UNREACHABLE + ); + let byte_offset = self.reader.byte_position(); + + let ctx = InstrContext { + global_value_table: &self.global_value_table, + local_values: &local_values, + type_table: &self.type_table, + paramattr_lists: &self.paramattr_lists, + bb_names: &bb_names, + diagnostics: &self.diagnostics, + current_value_id: next_result_id, + byte_offset, + policy: self.policy, + }; + + if let Some(instr) = dispatch_instruction(code, &ctx, &values, byte_offset)? + { + let produces_value = instruction_produces_value(code, &values, &instr); + + if produces_value { + let result_ty = instruction_result_type(&instr); + local_values.push(ValueEntry::Local( + format!("val_{next_result_id}"), + result_ty, + )); + next_result_id += 1; + } + + current_instructions.push(instr); + } + + if is_terminator && current_bb < num_bbs { + let bb_name = bb_names + .get(&(current_bb as u32)) + .cloned() + .unwrap_or_else(|| format!("bb_{current_bb}")); + basic_blocks.push(BasicBlock { + name: bb_name, + instructions: std::mem::take(&mut current_instructions), + }); + current_bb += 1; + } + } + } + } + } + + let mut local_name_remap = FxHashMap::default(); + let global_value_count = self.global_value_table.len() as u32; + let uses_absolute_local_ids = local_name_entries + .keys() + .any(|value_id| *value_id >= local_values.len() as u32); + for (value_id, name) in local_name_entries { + let local_id = if uses_absolute_local_ids { + let Some(local_id) = value_id.checked_sub(global_value_count) else { + continue; + }; + local_id + } else { + value_id + }; + let Some(entry) = local_values.get(local_id as usize) else { + continue; + }; + + match entry { + ValueEntry::Local(old_name, _) | ValueEntry::Param(old_name, _) => { + if old_name != &name { + 
local_name_remap.insert(old_name.clone(), name); + } + } + _ => {} + } + } + + for (i, bb) in basic_blocks.iter_mut().enumerate() { + if let Some(name) = bb_names.get(&(i as u32)) { + bb.name.clone_from(name); + } + + for instruction in &mut bb.instructions { + Self::remap_local_instruction_names(instruction, &local_name_remap); + remap_instruction_block_names(instruction, &bb_names); + } + } + + let func = &mut self.functions[func_index]; + for param in &mut func.params { + if let Some(name) = &mut param.name + && let Some(final_name) = local_name_remap.get(name.as_str()) + { + name.clone_from(final_name); + } + } + func.basic_blocks = basic_blocks; + self.reader.pop_block_scope(); + Ok(()) + } + + // ------------------------------------------------------------------- + // Attribute blocks + // ------------------------------------------------------------------- + + const PARAMATTR_GRP_CODE_ENTRY: u32 = 3; + const PARAMATTR_CODE_ENTRY: u32 = 2; + + fn read_paramattr_group_block(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(PARAMATTR_GROUP_BLOCK_ID); + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + if code == Self::PARAMATTR_GRP_CODE_ENTRY && values.len() >= 2 { + let group_id = values[0] as u32; + // values[1] = param_index (0xFFFFFFFF for function attrs); not stored + let attributes = Self::parse_attr_encodings( + &values[2..], + self.policy, + self.reader.byte_position(), + &self.diagnostics, + )?; + self.attribute_groups.push(AttributeGroup { + id: group_id, + attributes, + }); + } + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + fn 
parse_attr_encodings( + values: &[u64], + policy: ReadPolicy, + byte_offset: usize, + diagnostics: &RefCell>, + ) -> Result, ParseError> { + let mut attrs = Vec::new(); + let mut i = 0; + while i < values.len() { + let kind = values[i]; + i += 1; + match kind { + // Code 3 = string attribute: null-terminated string + 3 => { + let mut s = String::new(); + while i < values.len() && values[i] != 0 { + s.push(values[i] as u8 as char); + i += 1; + } + i += 1; // skip null terminator + attrs.push(Attribute::StringAttr(s)); + } + // Code 4 = string key/value: key (null-terminated), value (null-terminated) + 4 => { + let mut key = String::new(); + while i < values.len() && values[i] != 0 { + key.push(values[i] as u8 as char); + i += 1; + } + i += 1; // skip null terminator + let mut val = String::new(); + while i < values.len() && values[i] != 0 { + val.push(values[i] as u8 as char); + i += 1; + } + i += 1; // skip null terminator + attrs.push(Attribute::KeyValue(key, val)); + } + // Skip enum (0/1) and other attribute kinds for v1 + _ if policy == ReadPolicy::Compatibility => { + diagnostics.borrow_mut().push(ReadDiagnostic { + kind: ReadDiagnosticKind::UnsupportedSemanticConstruct, + offset: Some(byte_offset), + context: "attribute group", + message: format!( + "attribute group contains unsupported encoded attribute kind {kind}; remaining attributes were skipped" + ), + }); + break; + } + _ => { + return Err(ParseError::unsupported( + byte_offset, + "attribute group", + format!( + "attribute group contains unsupported encoded attribute kind {kind}" + ), + )); + } + } + } + Ok(attrs) + } + + fn read_paramattr_block(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(PARAMATTR_BLOCK_ID); + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + 
self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + if code == Self::PARAMATTR_CODE_ENTRY { + self.paramattr_lists + .push(values.iter().map(|&v| v as u32).collect()); + } + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + // ------------------------------------------------------------------- + // Metadata block + // ------------------------------------------------------------------- + + fn read_metadata_block(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(METADATA_BLOCK_ID); + let mut pending_named_metadata_name: Option = None; + let mut pending_metadata_nodes: Vec> = Vec::new(); + let mut pending_named_metadata_nodes: Vec<(String, Vec)> = Vec::new(); + + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + match code { + METADATA_STRING_OLD => { + let s: String = values.iter().map(|&v| v as u8 as char).collect(); + self.metadata_slot_map.push(MetadataSlotEntry::String(s)); + } + METADATA_VALUE => { + if values.len() < 2 { + return Err(self.error("METADATA_VALUE record too short")); + } + let type_id = values[0] as usize; + let value_id = values[1] as usize; + let ty = match self.type_table.get(type_id).cloned() { + Some(ty) => ty, + None if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "metadata", + format!( + "metadata value references unknown type ID {type_id} and was imported as `void`" + ), + ); + Type::Void + } + None => { + 
return Err(self.unsupported( + "metadata", + format!( + "metadata value references unknown type ID {type_id}" + ), + )); + } + }; + // Look up the value from global_value_table + // (module constants are stored there) + let val = match self.global_value_table.get(value_id) { + Some(ValueEntry::Constant(_, Constant::Int(v))) => *v, + _ if self.policy == ReadPolicy::Compatibility => { + self.record_compatibility_diagnostic( + ReadDiagnosticKind::UnsupportedSemanticConstruct, + "metadata", + format!( + "metadata value {value_id} is not an integer constant and was normalized to `0` during import" + ), + ); + 0 + } + _ => { + return Err(self.unsupported( + "metadata", + format!( + "metadata value {value_id} is not an integer constant and would be normalized during import" + ), + )); + } + }; + self.metadata_slot_map + .push(MetadataSlotEntry::Value(ty, val)); + } + METADATA_NODE => { + let node_id = self.metadata_nodes.len() as u32; + pending_metadata_nodes.push( + values + .iter() + .map(|&operand_ref| operand_ref as usize) + .collect(), + ); + self.metadata_nodes.push(MetadataNode { + id: node_id, + values: Vec::new(), + }); + self.metadata_slot_map + .push(MetadataSlotEntry::Node(node_id)); + } + METADATA_NAME => { + let s: String = values.iter().map(|&v| v as u8 as char).collect(); + pending_named_metadata_name = Some(s); + } + METADATA_NAMED_NODE => { + let name = pending_named_metadata_name.take().unwrap_or_default(); + pending_named_metadata_nodes.push(( + name, + values.iter().map(|&slot_ref| slot_ref as usize).collect(), + )); + } + _ => {} // skip other metadata codes + } + } + } + } + self.reader.pop_block_scope(); + + for (node, operand_slots) in self.metadata_nodes.iter_mut().zip(pending_metadata_nodes) { + let mut node_values = Vec::new(); + for slot_idx in operand_slots { + match self.metadata_slot_map.get(slot_idx) { + Some(MetadataSlotEntry::String(s)) => { + node_values.push(MetadataValue::String(s.clone())); + } + Some(MetadataSlotEntry::Value(ty, 
val)) => { + node_values.push(MetadataValue::Int(ty.clone(), *val)); + } + Some(MetadataSlotEntry::Node(child_id)) => { + node_values.push(MetadataValue::NodeRef(*child_id)); + } + None => {} + } + } + node.values = node_values; + } + + for (name, slot_refs) in pending_named_metadata_nodes { + let mut node_refs = Vec::new(); + for slot_idx in slot_refs { + if let Some(MetadataSlotEntry::Node(node_id)) = self.metadata_slot_map.get(slot_idx) + { + node_refs.push(*node_id); + } + } + self.named_metadata.push(NamedMetadata { name, node_refs }); + } + + // Post-process: reconstruct SubList values from synthetic child nodes + self.reconstruct_sublists(); + + Ok(()) + } + + /// Reconstructs `MetadataValue::SubList` from synthetic child nodes. + /// + /// Nodes that are referenced only from other nodes' operands (not from + /// named_metadata directly) are "synthetic child" nodes. We replace + /// their parent's `NodeRef` with `SubList` containing the child's values, + /// then remove the synthetic children and re-number remaining nodes. 
+ fn reconstruct_sublists(&mut self) { + use rustc_hash::FxHashSet; + + fn expand_synthetic_value( + value: &MetadataValue, + synthetic_ids: &FxHashSet, + child_values: &FxHashMap>, + ) -> MetadataValue { + match value { + MetadataValue::NodeRef(child_id) if synthetic_ids.contains(child_id) => { + let Some(children) = child_values.get(child_id) else { + return MetadataValue::NodeRef(*child_id); + }; + + MetadataValue::SubList( + children + .iter() + .map(|child| expand_synthetic_value(child, synthetic_ids, child_values)) + .collect(), + ) + } + MetadataValue::SubList(children) => MetadataValue::SubList( + children + .iter() + .map(|child| expand_synthetic_value(child, synthetic_ids, child_values)) + .collect(), + ), + MetadataValue::Int(ty, value) => MetadataValue::Int(ty.clone(), *value), + MetadataValue::String(text) => MetadataValue::String(text.clone()), + MetadataValue::NodeRef(node_id) => MetadataValue::NodeRef(*node_id), + } + } + + fn remap_node_refs(values: &mut [MetadataValue], id_remap: &FxHashMap) { + for value in values { + match value { + MetadataValue::NodeRef(node_id) => { + if let Some(remapped) = id_remap.get(node_id) { + *node_id = *remapped; + } + } + MetadataValue::SubList(children) => remap_node_refs(children, id_remap), + MetadataValue::Int(_, _) | MetadataValue::String(_) => {} + } + } + } + + // Collect node IDs referenced directly by named_metadata + let directly_referenced: FxHashSet = self + .named_metadata + .iter() + .flat_map(|nm| nm.node_refs.iter().copied()) + .collect(); + + // Collect node IDs referenced from other nodes' operands + let mut node_ref_parents: FxHashMap> = FxHashMap::default(); + for node in &self.metadata_nodes { + for val in &node.values { + if let MetadataValue::NodeRef(child_id) = val { + node_ref_parents.entry(*child_id).or_default().push(node.id); + } + } + } + + // Identify synthetic child nodes: referenced from other nodes only, + // not from named_metadata + let synthetic_ids: FxHashSet = node_ref_parents + 
.keys() + .copied() + .filter(|id| !directly_referenced.contains(id)) + .collect(); + + if synthetic_ids.is_empty() { + return; + } + + // Build a map from node ID to its values for synthetic children + let child_values: FxHashMap> = self + .metadata_nodes + .iter() + .filter(|n| synthetic_ids.contains(&n.id)) + .map(|n| (n.id, n.values.clone())) + .collect(); + + // Replace synthetic NodeRef operands with recursively reconstructed SubList values. + for node in &mut self.metadata_nodes { + node.values = node + .values + .iter() + .map(|value| expand_synthetic_value(value, &synthetic_ids, &child_values)) + .collect(); + } + + // Remove synthetic child nodes + self.metadata_nodes + .retain(|n| !synthetic_ids.contains(&n.id)); + + // Re-number remaining node IDs sequentially + let old_ids: Vec = self.metadata_nodes.iter().map(|n| n.id).collect(); + let id_remap: FxHashMap = old_ids + .iter() + .enumerate() + .map(|(new_idx, &old_id)| (old_id, new_idx as u32)) + .collect(); + + for node in &mut self.metadata_nodes { + node.id = id_remap[&node.id]; + remap_node_refs(&mut node.values, &id_remap); + } + + // Update named_metadata node_refs + for nm in &mut self.named_metadata { + nm.node_refs = nm + .node_refs + .iter() + .filter_map(|old_id| id_remap.get(old_id).copied()) + .collect(); + } + } + + // ------------------------------------------------------------------- + // Constant expression helpers + // ------------------------------------------------------------------- + + /// Resolve an absolute value ID to a global name from the global value table. 
+ fn resolve_global_name_by_id( + &self, + value_id: u32, + context: &'static str, + description: &'static str, + ) -> Result { + match self.global_value_table.get(value_id as usize) { + Some(ValueEntry::Global(name) | ValueEntry::Function(name)) => Ok(name.clone()), + _ => self.unsupported_or_recover( + context, + format!("{description} value ID {value_id} does not resolve to a global value"), + format!( + "{description} value ID {value_id} does not resolve to a global value and was imported as `unknown_{value_id}`" + ), + format!("unknown_{value_id}"), + ), + } + } + + /// Resolve an absolute value ID to an integer constant from the global table. + fn resolve_constant_int_from_global_table( + &self, + value_id: u32, + context: &'static str, + description: &'static str, + ) -> Result { + match self.global_value_table.get(value_id as usize) { + Some(ValueEntry::Constant(_, Constant::Int(val))) => Ok(*val), + _ => self.unsupported_or_recover( + context, + format!("{description} value ID {value_id} is not an integer constant"), + format!( + "{description} value ID {value_id} is not an integer constant and was normalized to `0` during import" + ), + 0, + ), + } + } + + /// Resolve an absolute value ID to an integer constant, checking both + /// global and function-local value tables. 
+ fn resolve_constant_int_from_tables( + &self, + value_id: u32, + local_values: &[ValueEntry], + context: &'static str, + description: &'static str, + ) -> Result { + let global_count = self.global_value_table.len() as u32; + if value_id < global_count { + self.resolve_constant_int_from_global_table(value_id, context, description) + } else { + let local_idx = (value_id - global_count) as usize; + if let Some(ValueEntry::Constant(_, Constant::Int(val))) = local_values.get(local_idx) { + Ok(*val) + } else { + self.unsupported_or_recover( + context, + format!("{description} value ID {value_id} is not an integer constant"), + format!( + "{description} value ID {value_id} is not an integer constant and was normalized to `0` during import" + ), + 0, + ) + } + } + } + + /// Parse a GEP CE record and push the result into the global value table. + fn parse_gep_ce_into_global_table(&mut self, values: &[u64]) -> Result<(), ParseError> { + let mut op_idx = 0; + // If odd number of values, first element is the pointee type ID + let source_type_id = if values.len() % 2 == 1 { + let id = values[op_idx] as u32; + op_idx += 1; + id + } else { + 0 + }; + let source_ty = self.resolve_constant_type( + source_type_id, + "constant expression", + "getelementptr constant source type", + Type::Void, + )?; + + // First pair: pointer type + pointer value + if op_idx + 1 >= values.len() { + return Err(self.error("getelementptr constant expression record too short")); + } + let ptr_type_id = values[op_idx] as u32; + let ptr_value_id = values[op_idx + 1] as u32; + op_idx += 2; + + let ptr_ty = self.resolve_constant_type( + ptr_type_id, + "constant expression", + "getelementptr constant pointer type", + Type::Ptr, + )?; + let ptr_name = self.resolve_global_name_by_id( + ptr_value_id, + "constant expression", + "getelementptr constant base pointer", + )?; + + // Remaining pairs: index type + index value + let mut indices = Vec::new(); + while op_idx + 1 < values.len() { + let idx_type_id = 
values[op_idx] as u32; + let idx_value_id = values[op_idx + 1] as u32; + let idx_ty = self.resolve_constant_type( + idx_type_id, + "constant expression", + "getelementptr constant index type", + Type::Integer(64), + )?; + let idx_val = self.resolve_constant_int_from_global_table( + idx_value_id, + "constant expression", + "getelementptr constant index", + )?; + indices.push(Operand::IntConst(idx_ty, idx_val)); + op_idx += 2; + } + + self.global_value_table.push(ValueEntry::GepConst { + source_ty, + ptr_name, + ptr_ty, + indices, + }); + + Ok(()) + } + + /// Parse a GEP CE record and push the result into function-local values. + fn parse_gep_ce_into_local_values( + &self, + values: &[u64], + local_values: &mut Vec, + ) -> Result<(), ParseError> { + let mut op_idx = 0; + // If odd number of values, first element is the pointee type ID + let source_type_id = if values.len() % 2 == 1 { + let id = values[op_idx] as u32; + op_idx += 1; + id + } else { + 0 + }; + let source_ty = self.resolve_constant_type( + source_type_id, + "constant expression", + "getelementptr constant source type", + Type::Void, + )?; + + // First pair: pointer type + pointer value + if op_idx + 1 >= values.len() { + return Err(self.error("getelementptr constant expression record too short")); + } + let ptr_type_id = values[op_idx] as u32; + let ptr_value_id = values[op_idx + 1] as u32; + op_idx += 2; + + let ptr_ty = self.resolve_constant_type( + ptr_type_id, + "constant expression", + "getelementptr constant pointer type", + Type::Ptr, + )?; + let global_count = self.global_value_table.len() as u32; + let ptr_name = if ptr_value_id < global_count { + self.resolve_global_name_by_id( + ptr_value_id, + "constant expression", + "getelementptr constant base pointer", + )? 
+ } else { + let local_idx = (ptr_value_id - global_count) as usize; + match local_values.get(local_idx) { + Some(ValueEntry::Global(name) | ValueEntry::Function(name)) => name.clone(), + _ => { + self.unsupported_or_recover( + "constant expression", + format!( + "getelementptr constant base pointer value ID {ptr_value_id} does not resolve to a global value" + ), + format!( + "getelementptr constant base pointer value ID {ptr_value_id} does not resolve to a global value and was imported as `unknown_{ptr_value_id}`" + ), + format!("unknown_{ptr_value_id}"), + )? + } + } + }; + + // Remaining pairs: index type + index value + let mut indices = Vec::new(); + while op_idx + 1 < values.len() { + let idx_type_id = values[op_idx] as u32; + let idx_value_id = values[op_idx + 1] as u32; + let idx_ty = self.resolve_constant_type( + idx_type_id, + "constant expression", + "getelementptr constant index type", + Type::Integer(64), + )?; + let idx_val = self.resolve_constant_int_from_tables( + idx_value_id, + local_values, + "constant expression", + "getelementptr constant index", + )?; + indices.push(Operand::IntConst(idx_ty, idx_val)); + op_idx += 2; + } + + local_values.push(ValueEntry::GepConst { + source_ty, + ptr_name, + ptr_ty, + indices, + }); + + Ok(()) + } + + // ------------------------------------------------------------------- + // Constants block + // ------------------------------------------------------------------- + + fn read_module_constants(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.module_constant_value_offset = self.global_value_table.len() as u32; + self.reader.push_block_scope(CONSTANTS_BLOCK_ID); + let mut current_type_id: u32 = 0; + + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + 
self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + match code { + CST_CODE_SETTYPE => { + let Some(type_id) = values.first().copied() else { + return Err(self.error("SETTYPE constant record too short")); + }; + current_type_id = type_id as u32; + } + CST_CODE_INTEGER => { + let Some(encoded) = values.first().copied() else { + return Err(self.error("INTEGER constant record too short")); + }; + let val = sign_unrotate(encoded); + let ty = self.resolve_constant_type( + current_type_id, + "constant record", + "integer constant current type", + Type::Void, + )?; + self.module_constants.push((ty.clone(), Constant::Int(val))); + self.global_value_table + .push(ValueEntry::Constant(ty, Constant::Int(val))); + } + CST_CODE_FLOAT => { + let Some(bits) = values.first().copied() else { + return Err(self.error("FLOAT constant record too short")); + }; + let ty = self.resolve_constant_type( + current_type_id, + "constant record", + "floating constant current type", + Type::Void, + )?; + let Some(val) = ty.decode_float_bits(bits) else { + return Err(self + .error("FLOAT constant record has non-floating current type")); + }; + self.module_constants + .push((ty.clone(), Constant::float(ty.clone(), val))); + self.global_value_table + .push(ValueEntry::Constant(ty.clone(), Constant::float(ty, val))); + } + CST_CODE_CSTRING => { + let ty = self.resolve_constant_type( + current_type_id, + "constant record", + "cstring constant current type", + Type::Void, + )?; + let bytes = values + .iter() + .map(|value| { + u8::try_from(*value).map_err(|_| { + self.error("CSTRING constant contains out-of-range byte") + }) + }) + .collect::, _>>()?; + let text = String::from_utf8(bytes).map_err(|_| { + self.error("CSTRING constant contains invalid UTF-8") + })?; + self.module_constants + .push((ty.clone(), Constant::CString(text.clone()))); + self.global_value_table + .push(ValueEntry::Constant(ty, 
Constant::CString(text))); + } + CST_CODE_NULL => { + let ty = self.resolve_constant_type( + current_type_id, + "constant record", + "null constant current type", + Type::Void, + )?; + self.module_constants.push((ty.clone(), Constant::Null)); + self.global_value_table + .push(ValueEntry::Constant(ty, Constant::Null)); + } + CST_CODE_CE_CAST => { + if values.len() >= 3 && values[0] == 10 { + let src_value_id = values[2] as u32; + let int_val = self.resolve_constant_int_from_global_table( + src_value_id, + "constant expression", + "inttoptr constant source", + )?; + let target_ty = self.resolve_constant_type( + current_type_id, + "constant expression", + "inttoptr constant target type", + Type::Ptr, + )?; + self.global_value_table + .push(ValueEntry::IntToPtrConst(int_val, target_ty)); + } + } + CST_CODE_CE_INBOUNDS_GEP => { + self.parse_gep_ce_into_global_table(&values)?; + } + _ => {} + } + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + fn read_function_constants( + &mut self, + abbrev_width: u32, + local_values: &mut Vec, + next_result_id: &mut u32, + ) -> Result<(), ParseError> { + self.reader.push_block_scope(CONSTANTS_BLOCK_ID); + let mut current_type_id: u32 = 0; + + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + match code { + CST_CODE_SETTYPE => { + let Some(type_id) = values.first().copied() else { + return Err(self.error("SETTYPE constant record too short")); + }; + current_type_id = type_id as u32; + } + CST_CODE_INTEGER => { + let Some(encoded) = values.first().copied() else { + return Err(self.error("INTEGER constant record too short")); + }; + let val = 
sign_unrotate(encoded); + let ty = self.resolve_constant_type( + current_type_id, + "constant record", + "integer constant current type", + Type::Void, + )?; + local_values.push(ValueEntry::Constant(ty, Constant::Int(val))); + *next_result_id += 1; + } + CST_CODE_FLOAT => { + let Some(bits) = values.first().copied() else { + return Err(self.error("FLOAT constant record too short")); + }; + let ty = self.resolve_constant_type( + current_type_id, + "constant record", + "floating constant current type", + Type::Void, + )?; + let Some(val) = ty.decode_float_bits(bits) else { + return Err(self + .error("FLOAT constant record has non-floating current type")); + }; + local_values + .push(ValueEntry::Constant(ty.clone(), Constant::float(ty, val))); + *next_result_id += 1; + } + CST_CODE_NULL => { + let ty = self.resolve_constant_type( + current_type_id, + "constant record", + "null constant current type", + Type::Void, + )?; + local_values.push(ValueEntry::Constant(ty, Constant::Null)); + *next_result_id += 1; + } + CST_CODE_CE_CAST => { + if values.len() >= 3 && values[0] == 10 { + let src_value_id = values[2] as u32; + let int_val = self.resolve_constant_int_from_tables( + src_value_id, + local_values, + "constant expression", + "inttoptr constant source", + )?; + let target_ty = self.resolve_constant_type( + current_type_id, + "constant expression", + "inttoptr constant target type", + Type::Ptr, + )?; + local_values.push(ValueEntry::IntToPtrConst(int_val, target_ty)); + *next_result_id += 1; + } + } + CST_CODE_CE_INBOUNDS_GEP => { + self.parse_gep_ce_into_local_values(&values, local_values)?; + *next_result_id += 1; + } + _ => {} + } + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + // ------------------------------------------------------------------- + // Value symbol tables + // ------------------------------------------------------------------- + + fn read_function_vst( + &mut self, + abbrev_width: u32, + bb_names: &mut FxHashMap, + local_names: &mut 
FxHashMap, + ) -> Result<(), ParseError> { + self.reader.push_block_scope(VALUE_SYMTAB_BLOCK_ID); + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + if !values.is_empty() { + let name: String = values[1..].iter().map(|&v| v as u8 as char).collect(); + match code { + VST_CODE_ENTRY => { + local_names.insert(values[0] as u32, name); + } + VST_CODE_BBENTRY => { + bb_names.insert(values[0] as u32, name); + } + _ => {} + } + } + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + fn read_module_vst(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(VALUE_SYMTAB_BLOCK_ID); + let mut name_remap = FxHashMap::default(); + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + match code { + VST_CODE_ENTRY if values.len() > 1 => { + let value_id = values[0] as usize; + let name: String = + values[1..].iter().map(|&v| v as u8 as char).collect(); + self.apply_module_value_name(value_id, name, &mut name_remap); + } + VST_CODE_FNENTRY if values.len() > 2 => { + let value_id = values[0] as usize; + let name: String = + values[2..].iter().map(|&v| v as u8 as char).collect(); + self.apply_module_value_name(value_id, name, &mut name_remap); + } + _ => {} + } + } + } + } + self.reader.pop_block_scope(); + 
self.remap_module_symbol_uses(&name_remap); + Ok(()) + } + + fn read_strtab_block(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(STRTAB_BLOCK_ID); + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + self.reader + .read_define_abbrev() + .map_err(|e| self.error(e))?; + } + id => { + let (code, values) = self.read_record(id)?; + if code == STRTAB_BLOB { + self.string_table = values.iter().map(|&value| value as u8).collect(); + } + } + } + } + self.reader.pop_block_scope(); + self.apply_pending_strtab_names(); + Ok(()) + } + + // ------------------------------------------------------------------- + // Blockinfo block + // ------------------------------------------------------------------- + + fn read_blockinfo_block(&mut self, abbrev_width: u32) -> Result<(), ParseError> { + self.reader.push_block_scope(BLOCKINFO_BLOCK_ID); + let mut current_block_id: Option = None; + + loop { + if self.reader.at_end() { + break; + } + let abbrev_id = self.reader.read_abbrev_id(abbrev_width); + match abbrev_id { + 0 => { + self.reader.align32(); + break; + } + 1 => { + let (_, _, block_len) = self.reader.enter_subblock(); + self.reader.skip_block(block_len); + } + 2 => { + if let Some(target_id) = current_block_id { + self.reader + .read_blockinfo_abbrev(target_id) + .map_err(|e| self.error(e))?; + } else { + return Err(self.error("DEFINE_ABBREV in BLOCKINFO before SETBID")); + } + } + id => { + let (code, values) = self.read_record(id)?; + if code == BLOCKINFO_CODE_SETBID + && let Some(&block_id_val) = values.first() + { + current_block_id = Some(block_id_val as u32); + } + } + } + } + self.reader.pop_block_scope(); + Ok(()) + } + + // ------------------------------------------------------------------- + 
// Helpers + // ------------------------------------------------------------------- + + fn global_index_from_value_id(&self, value_id: usize) -> usize { + let mut count = 0; + for (i, entry) in self.global_value_table.iter().enumerate() { + if i == value_id { + return count; + } + if matches!(entry, ValueEntry::Global(_)) { + count += 1; + } + } + count + } + + fn local_value_name_exists(local_values: &[ValueEntry], name: &str) -> bool { + local_values.iter().any(|entry| match entry { + ValueEntry::Local(entry_name, _) | ValueEntry::Param(entry_name, _) => { + entry_name == name + } + ValueEntry::Global(_) + | ValueEntry::Function(_) + | ValueEntry::Constant(_, _) + | ValueEntry::IntToPtrConst(_, _) + | ValueEntry::GepConst { .. } => false, + }) + } + + fn validate_phi_operands_resolved( + &self, + basic_blocks: &[BasicBlock], + local_values: &[ValueEntry], + ) -> Result<(), ParseError> { + for block in basic_blocks { + for instruction in &block.instructions { + let Instruction::Phi { incoming, .. 
} = instruction else { + continue; + }; + + for (operand, _) in incoming { + let (Operand::LocalRef(name) | Operand::TypedLocalRef(name, _)) = operand + else { + continue; + }; + + if Self::local_value_name_exists(local_values, name) { + continue; + } + + self.unsupported_or_recover( + "phi instruction", + format!( + "PHI incoming value `{name}` could not be resolved during bitcode import" + ), + format!( + "PHI incoming value `{name}` could not be resolved during bitcode import and was preserved as a placeholder" + ), + (), + )?; + } + } + } + + Ok(()) + } + + fn func_index_from_value_id(&self, value_id: usize) -> usize { + let mut count = 0; + for (i, entry) in self.global_value_table.iter().enumerate() { + if i == value_id { + return count; + } + if matches!(entry, ValueEntry::Function(_)) { + count += 1; + } + } + count + } + + fn resolve_pending_global_initializers(&mut self) -> Result<(), ParseError> { + let pending = std::mem::take(&mut self.pending_global_initializers); + + for PendingGlobalInitializer { + global_index, + value_id, + } in pending + { + let initializer = match self.global_value_table.get(value_id as usize).cloned() { + Some(ValueEntry::Constant( + _, + constant @ (Constant::CString(_) + | Constant::Int(_) + | Constant::Float(_, _) + | Constant::Null), + )) => constant, + Some(_) => self.unsupported_or_recover( + "global variable", + format!( + "global initializer value ID {value_id} resolves to an unsupported initializer form" + ), + format!( + "global initializer value ID {value_id} resolves to an unsupported initializer form and was imported as null" + ), + Constant::Null, + )?, + None => self.unsupported_or_recover( + "global variable", + format!( + "global initializer value ID {value_id} could not be resolved during bitcode import" + ), + format!( + "global initializer value ID {value_id} could not be resolved during bitcode import and was imported as null" + ), + Constant::Null, + )?, + }; + + self.globals[global_index].initializer = 
Some(initializer); + } + + Ok(()) + } + + fn build_module(self) -> Module { + use crate::model::StructType as ModelStructType; + + let struct_types: Vec = self + .struct_types + .into_iter() + .map(|name| ModelStructType { + name, + is_opaque: true, + }) + .collect(); + + Module { + source_filename: self.source_filename, + target_datalayout: self.target_datalayout, + target_triple: self.target_triple, + struct_types, + globals: self.globals, + functions: self.functions, + attribute_groups: self.attribute_groups, + named_metadata: self.named_metadata, + metadata_nodes: self.metadata_nodes, + } + } +} + +// --------------------------------------------------------------------------- +// Helper functions +// --------------------------------------------------------------------------- + +fn sign_unrotate(val: u64) -> i64 { + if val & 1 == 0 { + (val >> 1) as i64 + } else if val == 1 { + i64::MIN + } else { + -((val >> 1) as i64) + } +} + +fn opcode_to_binop(opcode: u64, ty: &Type) -> Result { + let is_fp = ty.is_floating_point(); + match opcode { + 0 => { + if is_fp { + Ok(BinOpKind::Fadd) + } else { + Ok(BinOpKind::Add) + } + } + 1 => { + if is_fp { + Ok(BinOpKind::Fsub) + } else { + Ok(BinOpKind::Sub) + } + } + 2 => { + if is_fp { + Ok(BinOpKind::Fmul) + } else { + Ok(BinOpKind::Mul) + } + } + 3 => Ok(BinOpKind::Udiv), + 4 => { + if is_fp { + Ok(BinOpKind::Fdiv) + } else { + Ok(BinOpKind::Sdiv) + } + } + 5 => Ok(BinOpKind::Urem), + 6 => Ok(BinOpKind::Srem), + 7 => Ok(BinOpKind::Shl), + 8 => Ok(BinOpKind::Lshr), + 9 => Ok(BinOpKind::Ashr), + 10 => Ok(BinOpKind::And), + 11 => Ok(BinOpKind::Or), + 12 => Ok(BinOpKind::Xor), + _ => Err(ParseError::malformed( + 0, + "instruction record", + format!("unknown binop opcode: {opcode}"), + )), + } +} + +fn icmp_code_to_predicate(code: u64) -> Result { + match code { + 32 => Ok(IntPredicate::Eq), + 33 => Ok(IntPredicate::Ne), + 34 => Ok(IntPredicate::Ugt), + 35 => Ok(IntPredicate::Uge), + 36 => Ok(IntPredicate::Ult), + 37 => 
Ok(IntPredicate::Ule), + 38 => Ok(IntPredicate::Sgt), + 39 => Ok(IntPredicate::Sge), + 40 => Ok(IntPredicate::Slt), + 41 => Ok(IntPredicate::Sle), + _ => Err(ParseError::malformed( + 0, + "instruction record", + format!("unknown icmp predicate code: {code}"), + )), + } +} + +fn fcmp_code_to_predicate(code: u64) -> Result { + match code { + 1 => Ok(FloatPredicate::Oeq), + 2 => Ok(FloatPredicate::Ogt), + 3 => Ok(FloatPredicate::Oge), + 4 => Ok(FloatPredicate::Olt), + 5 => Ok(FloatPredicate::Ole), + 6 => Ok(FloatPredicate::One), + 7 => Ok(FloatPredicate::Ord), + 8 => Ok(FloatPredicate::Uno), + 9 => Ok(FloatPredicate::Ueq), + 10 => Ok(FloatPredicate::Ugt), + 11 => Ok(FloatPredicate::Uge), + 12 => Ok(FloatPredicate::Ult), + 13 => Ok(FloatPredicate::Ule), + 14 => Ok(FloatPredicate::Une), + _ => Err(ParseError::malformed( + 0, + "instruction record", + format!("unknown fcmp predicate code: {code}"), + )), + } +} + +fn opcode_to_cast(opcode: u64) -> Result { + match opcode { + 0 => Ok(CastKind::Trunc), + 1 => Ok(CastKind::Zext), + 2 => Ok(CastKind::Sext), + 4 => Ok(CastKind::FpTrunc), + 5 => Ok(CastKind::FpExt), + 6 => Ok(CastKind::Sitofp), + 7 => Ok(CastKind::Fptosi), + 9 => Ok(CastKind::PtrToInt), + 10 => Ok(CastKind::IntToPtr), + 11 => Ok(CastKind::Bitcast), + _ => Err(ParseError::malformed( + 0, + "instruction record", + format!("unknown cast opcode: {opcode}"), + )), + } +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +fn parse_bitcode_with_policy(data: &[u8], policy: ReadPolicy) -> Result { + let mut reader = BlockReader::new(data, policy)?; + reader.read_top_level()?; + reader.resolve_pending_global_initializers()?; + Ok(reader.build_module()) +} + +fn parse_bitcode_report_with_policy( + data: &[u8], + policy: ReadPolicy, +) -> Result> { + let mut reader = BlockReader::new(data, policy).map_err(|error| vec![error.into()])?; + match 
reader.read_top_level() { + Ok(()) => { + if let Err(error) = reader.resolve_pending_global_initializers() { + let mut diagnostics = std::mem::take(reader.diagnostics.get_mut()); + diagnostics.push(error.into()); + return Err(diagnostics); + } + let diagnostics = std::mem::take(reader.diagnostics.get_mut()); + let module = reader.build_module(); + Ok(ReadReport { + module, + diagnostics, + }) + } + Err(error) => { + let mut diagnostics = std::mem::take(reader.diagnostics.get_mut()); + diagnostics.push(error.into()); + Err(diagnostics) + } + } +} + +pub fn parse_bitcode_detailed( + data: &[u8], + policy: ReadPolicy, +) -> Result> { + let report = parse_bitcode_report_with_policy(data, policy)?; + if policy == ReadPolicy::Compatibility && !report.diagnostics.is_empty() { + Err(report.diagnostics) + } else { + Ok(report.module) + } +} + +/// Parses LLVM bitcode data into a `Module`. +pub fn parse_bitcode(data: &[u8]) -> Result { + parse_bitcode_with_policy(data, ReadPolicy::QirSubsetStrict) +} + +pub fn parse_bitcode_compatibility(data: &[u8]) -> Result { + parse_bitcode_detailed(data, ReadPolicy::Compatibility) + .map_err(|mut diagnostics| diagnostics.remove(0).into()) +} + +pub fn parse_bitcode_compatibility_report(data: &[u8]) -> Result> { + parse_bitcode_report_with_policy(data, ReadPolicy::Compatibility) +} diff --git a/source/compiler/qsc_llvm/src/bitcode/reader/tests.rs b/source/compiler/qsc_llvm/src/bitcode/reader/tests.rs new file mode 100644 index 0000000000..18f48e76cb --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode/reader/tests.rs @@ -0,0 +1,786 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use crate::bitcode::bitstream::BitstreamWriter; +use crate::bitcode::writer::write_bitcode; +use crate::model::Param; +use crate::model::test_helpers::*; +use crate::test_utils::{PointerProbe, assemble_text_ir, available_fast_matrix_lanes}; +use crate::{ReadDiagnosticKind, ReadPolicy}; +use std::cell::RefCell; + +fn round_trip_module(module: &Module) -> Module { + let bc = write_bitcode(module); + parse_bitcode(&bc).expect("should parse round-tripped bitcode") +} + +fn build_module_constants_bitcode( + type_records: &[(u32, Vec)], + module_records: &[(u32, Vec)], + constant_records: &[(u32, Vec)], +) -> Vec { + const TOP_ABBREV_WIDTH: u32 = 2; + const BLOCK_ABBREV_WIDTH: u32 = 4; + + let mut writer = BitstreamWriter::new(); + for magic in [0x42_u64, 0x43, 0xC0, 0xDE] { + writer.emit_bits(magic, 8); + } + + writer.enter_subblock(MODULE_BLOCK_ID, BLOCK_ABBREV_WIDTH, TOP_ABBREV_WIDTH); + + if !type_records.is_empty() { + writer.enter_subblock(TYPE_BLOCK_ID_NEW, BLOCK_ABBREV_WIDTH, BLOCK_ABBREV_WIDTH); + for (code, values) in type_records { + writer.emit_record(*code, values, BLOCK_ABBREV_WIDTH); + } + writer.exit_block(BLOCK_ABBREV_WIDTH); + } + + for (code, values) in module_records { + writer.emit_record(*code, values, BLOCK_ABBREV_WIDTH); + } + + if !constant_records.is_empty() { + writer.enter_subblock(CONSTANTS_BLOCK_ID, BLOCK_ABBREV_WIDTH, BLOCK_ABBREV_WIDTH); + for (code, values) in constant_records { + writer.emit_record(*code, values, BLOCK_ABBREV_WIDTH); + } + writer.exit_block(BLOCK_ABBREV_WIDTH); + } + + writer.exit_block(BLOCK_ABBREV_WIDTH); + writer.finish() +} + +fn bad_current_type_id_in_constants_block_fixture() -> Vec { + build_module_constants_bitcode( + &[(TYPE_CODE_NUMENTRY, vec![1]), (TYPE_CODE_INTEGER, vec![64])], + &[], + &[(CST_CODE_SETTYPE, vec![99]), (CST_CODE_INTEGER, vec![0])], + ) +} + +fn bad_inttoptr_source_id_constants_block_fixture() -> Vec { + build_module_constants_bitcode( + &[ + (TYPE_CODE_NUMENTRY, vec![2]), + 
(TYPE_CODE_INTEGER, vec![64]), + (TYPE_CODE_OPAQUE_POINTER, vec![]), + ], + &[], + &[ + (CST_CODE_SETTYPE, vec![1]), + (CST_CODE_CE_CAST, vec![10, 0, 77]), + ], + ) +} + +fn bad_gep_pointer_type_id_constants_block_fixture() -> Vec { + build_module_constants_bitcode( + &[(TYPE_CODE_NUMENTRY, vec![1]), (TYPE_CODE_INTEGER, vec![64])], + &[(MODULE_CODE_GLOBALVAR, vec![0, 0, 0, 0, 3])], + &[(CST_CODE_CE_INBOUNDS_GEP, vec![0, 99, 0])], + ) +} + +fn bad_gep_index_constant_id_constants_block_fixture() -> Vec { + build_module_constants_bitcode( + &[ + (TYPE_CODE_NUMENTRY, vec![2]), + (TYPE_CODE_INTEGER, vec![64]), + (TYPE_CODE_OPAQUE_POINTER, vec![]), + ], + &[(MODULE_CODE_GLOBALVAR, vec![0, 0, 0, 0, 3])], + &[(CST_CODE_CE_INBOUNDS_GEP, vec![0, 1, 0, 0, 77])], + ) +} + +fn assert_strict_rejection_and_report_recovery( + bitcode: &[u8], + expected_context: &'static str, + expected_message_fragment: &str, +) { + let diagnostics = parse_bitcode_detailed(bitcode, ReadPolicy::QirSubsetStrict) + .expect_err("strict reader should reject malformed constants-block input"); + + assert_eq!(diagnostics.len(), 1); + assert_eq!( + diagnostics[0].kind, + ReadDiagnosticKind::UnsupportedSemanticConstruct + ); + assert_eq!(diagnostics[0].context, expected_context); + assert!( + diagnostics[0].message.contains(expected_message_fragment), + "unexpected strict diagnostic: {:?}", + diagnostics + ); + + let report = parse_bitcode_compatibility_report(bitcode) + .expect("compatibility report path should recover malformed constants-block input"); + + assert_eq!(report.diagnostics.len(), 1); + assert_eq!( + report.diagnostics[0].kind, + ReadDiagnosticKind::UnsupportedSemanticConstruct + ); + assert_eq!(report.diagnostics[0].context, expected_context); + assert!( + report.diagnostics[0] + .message + .contains(expected_message_fragment), + "unexpected compatibility diagnostic: {:?}", + report.diagnostics + ); +} + +#[test] +fn sign_unrotate_positive() { + assert_eq!(sign_unrotate(0), 0); + 
assert_eq!(sign_unrotate(2), 1); + assert_eq!(sign_unrotate(4), 2); + assert_eq!(sign_unrotate(200), 100); +} + +#[test] +fn sign_unrotate_negative() { + assert_eq!(sign_unrotate(1), i64::MIN); + assert_eq!(sign_unrotate(3), -1); + assert_eq!(sign_unrotate(5), -2); + assert_eq!(sign_unrotate(7), -3); +} + +#[test] +fn invalid_magic_returns_error() { + let data = vec![0x00, 0x00, 0x00, 0x00]; + let result = parse_bitcode(&data); + assert!(result.is_err()); + let err = result.expect_err("should error on invalid magic"); + assert!(err.message.contains("magic"), "error: {err}"); +} + +#[test] +fn too_short_returns_error() { + let data = vec![0x42, 0x43]; + let result = parse_bitcode(&data); + assert!(result.is_err()); +} + +#[test] +fn parse_empty_module_bitcode() { + let m = empty_module(); + let bc = write_bitcode(&m); + let parsed = parse_bitcode(&bc).expect("should parse empty module bitcode"); + assert!(parsed.functions.is_empty()); + assert!(parsed.globals.is_empty()); +} + +#[test] +fn parse_module_with_declaration() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "__quantum__qis__h__body".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let bc = write_bitcode(&m); + let parsed = parse_bitcode(&bc).expect("should parse module with declaration"); + assert_eq!(parsed.functions.len(), 1); + assert!(parsed.functions[0].is_declaration); + assert_eq!(parsed.functions[0].name, "__quantum__qis__h__body"); + assert_eq!(parsed.functions[0].params.len(), 1); +} + +#[test] +fn parse_simple_function_body() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + 
struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "main".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Ret(None)], + }], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let bc = write_bitcode(&m); + let parsed = parse_bitcode(&bc).expect("should parse simple function body"); + assert_eq!(parsed.functions.len(), 1); + assert!(!parsed.functions[0].is_declaration); + assert_eq!(parsed.functions[0].basic_blocks.len(), 1); + assert_eq!(parsed.functions[0].basic_blocks[0].instructions.len(), 1); + assert!(matches!( + parsed.functions[0].basic_blocks[0].instructions[0], + Instruction::Ret(None) + )); +} + +#[test] +fn apply_pending_strtab_names_remaps_function_placeholders_and_call_uses() { + let mut reader = BlockReader::new(&[0x42, 0x43, 0xC0, 0xDE], ReadPolicy::QirSubsetStrict) + .expect("bitcode magic header should construct a reader"); + reader.functions.push(Function { + name: "__func_0".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Call { + return_ty: None, + callee: "__func_0".to_string(), + args: Vec::new(), + result: None, + attr_refs: Vec::new(), + }], + }], + }); + reader + .global_value_table + .push(ValueEntry::Function("__func_0".to_string())); + reader.pending_strtab_names.push(PendingStrtabName { + value_id: 0, + offset: 0, + size: 4, + }); + reader.string_table = b"test".to_vec(); + + reader.apply_pending_strtab_names(); + + assert!(reader.pending_strtab_names.is_empty()); + assert_eq!(reader.functions[0].name, "test"); + assert!(matches!( + reader.global_value_table[0], + ValueEntry::Function(ref name) if name == "test" 
+ )); + assert!(matches!( + &reader.functions[0].basic_blocks[0].instructions[0], + Instruction::Call { callee, .. } if callee == "test" + )); +} + +#[test] +fn bitcode_round_trip_preserves_half_float_double_kinds() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + Function { + name: "takes_fp".to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Half, + name: None, + }, + Param { + ty: Type::Float, + name: None, + }, + Param { + ty: Type::Double, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "caller".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "takes_fp".to_string(), + args: vec![ + (Type::Half, Operand::float_const(Type::Half, 1.5)), + (Type::Float, Operand::float_const(Type::Float, 2.5)), + (Type::Double, Operand::float_const(Type::Double, 3.5)), + ], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let parsed = round_trip_module(&m); + + assert_eq!( + parsed.functions[0] + .params + .iter() + .map(|param| param.ty.clone()) + .collect::>(), + vec![Type::Half, Type::Float, Type::Double] + ); + match &parsed.functions[1].basic_blocks[0].instructions[0] { + Instruction::Call { + return_ty, + args, + result, + attr_refs, + .. 
+ } => { + assert_eq!(return_ty, &None); + assert_eq!( + args, + &vec![ + (Type::Half, Operand::float_const(Type::Half, 1.5)), + (Type::Float, Operand::float_const(Type::Float, 2.5)), + (Type::Double, Operand::float_const(Type::Double, 3.5)), + ] + ); + assert_eq!(result, &None); + assert!(attr_refs.is_empty()); + } + other => panic!("expected call instruction, found {other:?}"), + } +} + +#[test] +fn parse_call_record_uses_relative_callee_value_with_explicit_type_flag() { + let globals = vec![ + ValueEntry::Function("__func_0".to_string()), + ValueEntry::Function("main".to_string()), + ]; + let locals = vec![ValueEntry::Param("a".to_string(), Type::Integer(64))]; + let type_table = vec![Type::Function( + Box::new(Type::Void), + vec![Type::Integer(64)], + )]; + let bb_names = FxHashMap::default(); + let diagnostics = RefCell::new(Vec::new()); + let ctx = InstrContext { + global_value_table: &globals, + local_values: &locals, + type_table: &type_table, + paramattr_lists: &[], + bb_names: &bb_names, + diagnostics: &diagnostics, + current_value_id: 3, + byte_offset: 12, + policy: ReadPolicy::QirSubsetStrict, + }; + let values = [0u64, CALL_EXPLICIT_TYPE_FLAG, 0, 3, 1]; + let mut input: RecordInput<'_> = &values; + + let instruction = parse_call_record(&ctx, &mut input).expect("should parse call record"); + + assert!(matches!( + instruction, + Instruction::Call { callee, args, .. 
} + if callee == "__func_0" + && matches!( + args.as_slice(), + [(Type::Integer(64), Operand::TypedLocalRef(name, ty))] + if name == "a" && ty == &Type::Integer(64) + ) + )); +} + +#[test] +fn parse_call_record_accepts_legacy_absolute_callee_value_with_explicit_type_flag() { + let globals = vec![ + ValueEntry::Function("__func_0".to_string()), + ValueEntry::Function("main".to_string()), + ]; + let locals = vec![ValueEntry::Param("a".to_string(), Type::Integer(64))]; + let type_table = vec![Type::Function( + Box::new(Type::Void), + vec![Type::Integer(64)], + )]; + let bb_names = FxHashMap::default(); + let diagnostics = RefCell::new(Vec::new()); + let ctx = InstrContext { + global_value_table: &globals, + local_values: &locals, + type_table: &type_table, + paramattr_lists: &[], + bb_names: &bb_names, + diagnostics: &diagnostics, + current_value_id: 3, + byte_offset: 12, + policy: ReadPolicy::QirSubsetStrict, + }; + let values = [0u64, CALL_EXPLICIT_TYPE_FLAG, 0, 0, 1]; + let mut input: RecordInput<'_> = &values; + + let instruction = parse_call_record(&ctx, &mut input).expect("should parse legacy call record"); + + assert!(matches!( + instruction, + Instruction::Call { callee, args, .. 
} + if callee == "__func_0" + && matches!( + args.as_slice(), + [(Type::Integer(64), Operand::TypedLocalRef(name, ty))] + if name == "a" && ty == &Type::Integer(64) + ) + )); +} + +#[test] +fn parse_call_record_resolves_attr_refs_from_paramattr_list() { + let globals = vec![ValueEntry::Function("callee".to_string())]; + let locals = Vec::new(); + let type_table = vec![Type::Function(Box::new(Type::Void), Vec::new())]; + let paramattr_lists = vec![vec![7, 11]]; + let bb_names = FxHashMap::default(); + let diagnostics = RefCell::new(Vec::new()); + let ctx = InstrContext { + global_value_table: &globals, + local_values: &locals, + type_table: &type_table, + paramattr_lists: ¶mattr_lists, + bb_names: &bb_names, + diagnostics: &diagnostics, + current_value_id: 1, + byte_offset: 12, + policy: ReadPolicy::QirSubsetStrict, + }; + let values = [1u64, 0, 0, 1]; + let mut input: RecordInput<'_> = &values; + + let instruction = parse_call_record(&ctx, &mut input).expect("should parse call record"); + + assert!(matches!( + instruction, + Instruction::Call { callee, attr_refs, .. 
} + if callee == "callee" && attr_refs == vec![7, 11] + )); +} + +#[test] +fn strict_call_target_placeholder_is_rejected() { + let globals = Vec::new(); + let locals = Vec::new(); + let type_table = vec![Type::Function(Box::new(Type::Void), Vec::new())]; + let bb_names = FxHashMap::default(); + let diagnostics = RefCell::new(Vec::new()); + let ctx = InstrContext { + global_value_table: &globals, + local_values: &locals, + type_table: &type_table, + paramattr_lists: &[], + bb_names: &bb_names, + diagnostics: &diagnostics, + current_value_id: 0, + byte_offset: 33, + policy: ReadPolicy::QirSubsetStrict, + }; + + let err = ctx + .resolve_call_target_name(0) + .expect_err("strict mode should reject unresolved callees"); + + assert_eq!(err.kind, ReadDiagnosticKind::UnsupportedSemanticConstruct); + assert_eq!(err.context, "call instruction"); +} + +#[test] +fn strict_phi_forward_reference_uses_placeholder_name() { + let globals = Vec::new(); + let locals = Vec::new(); + let type_table = vec![Type::Integer(64)]; + let bb_names = FxHashMap::default(); + let diagnostics = RefCell::new(Vec::new()); + let ctx = InstrContext { + global_value_table: &globals, + local_values: &locals, + type_table: &type_table, + paramattr_lists: &[], + bb_names: &bb_names, + diagnostics: &diagnostics, + current_value_id: 4, + byte_offset: 41, + policy: ReadPolicy::QirSubsetStrict, + }; + + let operand = ctx + .resolve_phi_operand(3, &Type::Integer(64)) + .expect("strict mode should preserve forward PHI references as placeholders"); + + assert_eq!( + operand, + Operand::TypedLocalRef("val_5".to_string(), Type::Integer(64)) + ); +} + +#[test] +fn strict_unresolved_phi_placeholder_is_rejected_after_function_parse() { + let reader = BlockReader::new(&[0x42, 0x43, 0xC0, 0xDE], ReadPolicy::QirSubsetStrict) + .expect("bitcode magic header should construct a reader"); + let basic_blocks = vec![BasicBlock { + name: "loop".to_string(), + instructions: vec![Instruction::Phi { + ty: Type::Integer(64), + 
incoming: vec![( + Operand::TypedLocalRef("val_99".to_string(), Type::Integer(64)), + "loop".to_string(), + )], + result: "val_4".to_string(), + }], + }]; + let local_values = vec![ValueEntry::Local("val_4".to_string(), Type::Integer(64))]; + + let err = reader + .validate_phi_operands_resolved(&basic_blocks, &local_values) + .expect_err("strict mode should reject unresolved PHI placeholders after function parse"); + + assert_eq!(err.kind, ReadDiagnosticKind::UnsupportedSemanticConstruct); + assert_eq!(err.context, "phi instruction"); +} + +#[test] +fn strict_unknown_attribute_encoding_is_rejected() { + let diagnostics = RefCell::new(Vec::new()); + let err = + BlockReader::parse_attr_encodings(&[7], ReadPolicy::QirSubsetStrict, 19, &diagnostics) + .expect_err("strict mode should reject unknown attribute encodings"); + + assert_eq!(err.kind, ReadDiagnosticKind::UnsupportedSemanticConstruct); + assert_eq!(err.context, "attribute group"); +} + +#[test] +fn strict_global_initializer_record_is_rejected() { + let mut reader = BlockReader::new(&[0x42, 0x43, 0xC0, 0xDE], ReadPolicy::QirSubsetStrict) + .expect("bitcode magic header should construct a reader"); + reader.type_table.push(Type::Integer(8)); + + let err = reader + .handle_module_record( + MODULE_CODE_GLOBALVAR, + &[0, 0, 1, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) + .expect_err("strict mode should reject placeholder global initializers"); + + assert_eq!(err.kind, ReadDiagnosticKind::UnsupportedSemanticConstruct); + assert_eq!(err.context, "global variable"); +} + +#[test] +fn strict_unknown_type_record_is_rejected() { + let mut reader = BlockReader::new(&[0x42, 0x43, 0xC0, 0xDE], ReadPolicy::QirSubsetStrict) + .expect("bitcode magic header should construct a reader"); + + let err = reader + .handle_type_record(999, &[]) + .expect_err("strict mode should reject unsupported type records"); + + assert_eq!(err.kind, ReadDiagnosticKind::UnsupportedSemanticConstruct); + assert_eq!(err.context, "type record"); +} + 
// Strict-policy recovery tests: each `*_fixture()` helper builds a bitcode
// buffer with one deliberately malformed CONSTANTS_BLOCK record, and
// `assert_strict_rejection_and_report_recovery` checks both that strict mode
// rejects it and that the report API surfaces the expected recovery message.

#[test]
fn strict_reader_rejects_bad_current_type_id_in_constants_block() {
    let bitcode = bad_current_type_id_in_constants_block_fixture();

    assert_strict_rejection_and_report_recovery(&bitcode, "constant record", "unknown type ID 99");
}

#[test]
fn strict_reader_rejects_bad_inttoptr_source_id() {
    let bitcode = bad_inttoptr_source_id_constants_block_fixture();

    assert_strict_rejection_and_report_recovery(
        &bitcode,
        "constant expression",
        "inttoptr constant source value ID 77",
    );
}

#[test]
fn strict_reader_rejects_bad_gep_pointer_type_id() {
    let bitcode = bad_gep_pointer_type_id_constants_block_fixture();

    assert_strict_rejection_and_report_recovery(
        &bitcode,
        "constant expression",
        "getelementptr constant pointer type references unknown type ID 99",
    );
}

#[test]
fn strict_reader_rejects_bad_gep_index_constant_id() {
    let bitcode = bad_gep_index_constant_id_constants_block_fixture();

    assert_strict_rejection_and_report_recovery(
        &bitcode,
        "constant expression",
        "getelementptr constant index value ID 77",
    );
}

// The compatibility lane must not silently recover: the detailed entry point
// returns the diagnostics, and the legacy helper (which has no diagnostics
// channel) must reject outright.
#[test]
fn compatibility_entry_points_require_report_api_for_constants_recovery() {
    let bitcode = bad_current_type_id_in_constants_block_fixture();

    let diagnostics = parse_bitcode_detailed(&bitcode, ReadPolicy::Compatibility)
        .expect_err("compatibility detailed parse should require the report API for recovery");

    assert_eq!(diagnostics.len(), 1);
    assert_eq!(
        diagnostics[0].kind,
        ReadDiagnosticKind::UnsupportedSemanticConstruct
    );
    assert_eq!(diagnostics[0].context, "constant record");

    let err = parse_bitcode_compatibility(&bitcode)
        .expect_err("legacy compatibility helper should reject recoveries without diagnostics");

    assert_eq!(err.kind, ReadDiagnosticKind::UnsupportedSemanticConstruct);
    assert_eq!(err.context, "constant record");
    assert!(err.message.contains("unknown type ID 99"));
}

// Uses a real external `llvm-as` lane to assemble the fixture; skipped (with a
// note on stderr) when no fast-matrix lane is installed on the machine.
#[test]
fn strict_bitcode_import_rejects_non_opaque_struct_body_fixture() {
    let Some(lane) = available_fast_matrix_lanes().into_iter().next() else {
        eprintln!(
            "no external LLVM fast-matrix lane is available, skipping non-opaque struct bitcode fixture"
        );
        return;
    };

    let bitcode = assemble_text_ir(
        lane,
        PointerProbe::OpaqueText,
        // `llvm-as` drops unused type aliases, so force the non-opaque struct
        // body into the type table via a declaration that references `%Pair`.
        "%Pair = type { i64, i64 }\ndeclare void @use(%Pair)\n",
    )
    .unwrap_or_else(|error| {
        panic!(
            "llvm@{} should assemble non-opaque struct fixture: {error}",
            lane.version
        )
    });

    let diagnostics = parse_bitcode_detailed(&bitcode, ReadPolicy::QirSubsetStrict)
        .expect_err("strict bitcode import should reject non-opaque struct bodies");

    assert_eq!(diagnostics.len(), 1);
    assert_eq!(
        diagnostics[0].kind,
        ReadDiagnosticKind::UnsupportedSemanticConstruct
    );
    assert_eq!(diagnostics[0].context, "type record");
    assert!(
        diagnostics[0]
            .message
            .contains("unsupported type record code")
    );
}

// Winnow-specific tests: verify primitive parsers on raw &[u64] slices

#[test]
fn winnow_parse_char_string_from_record() {
    let values: Vec<u64> = "hello".bytes().map(u64::from).collect();
    let mut input: RecordInput<'_> = &values;
    let result = parse_char_string(&mut input).expect("should parse char string");
    assert_eq!(result, "hello");
    // The parser must consume the whole record payload.
    assert!(input.is_empty());
}

#[test]
fn winnow_parse_type_integer_record() {
    let values = [64u64];
    let mut input: RecordInput<'_> = &values;
    let result = parse_type_integer(&mut input).expect("should parse integer type");
    assert_eq!(result, Type::Integer(64));
}

// An empty TYPE_CODE_INTEGER record falls back to a 32-bit width.
#[test]
fn winnow_parse_type_integer_default() {
    let values: [u64; 0] = [];
    let mut input: RecordInput<'_> = &values;
    let result = parse_type_integer(&mut input).expect("should default to 32");
    assert_eq!(result, Type::Integer(32));
}
+#[test] +fn winnow_parse_type_pointer_named_record_preserves_named_ptr() { + let mut reader = BlockReader::new(&[0x42, 0x43, 0xC0, 0xDE], ReadPolicy::QirSubsetStrict) + .expect("bitcode magic header should construct a reader"); + reader.type_table.push(Type::Named("Qubit".to_string())); + + reader + .handle_type_record(TYPE_CODE_POINTER, &[0]) + .expect("should parse named pointer type record"); + + assert_eq!(reader.type_table[1], Type::NamedPtr("Qubit".to_string())); +} + +#[test] +fn winnow_parse_type_label_record_preserves_slot_identity() { + let mut reader = BlockReader::new(&[0x42, 0x43, 0xC0, 0xDE], ReadPolicy::QirSubsetStrict) + .expect("bitcode magic header should construct a reader"); + + reader + .handle_type_record(TYPE_CODE_LABEL, &[]) + .expect("should parse label type record"); + reader + .handle_type_record(TYPE_CODE_INTEGER, &[64]) + .expect("should parse later integer type record"); + + assert_eq!(reader.type_table[0], Type::Label); + assert_eq!(reader.type_table[1], Type::Integer(64)); +} + +#[test] +fn winnow_parse_global_var_record() { + let values = [0u64, 0, 1, 1, 3]; // ptr_ty=0, addr=0, const=true, init=true, internal + let mut input: RecordInput<'_> = &values; + let record = parse_global_var_record(&mut input).expect("should parse global var"); + assert!(record.is_const); + assert!(record.legacy_placeholder); + assert_eq!(record.init_value_id, None); + assert!(matches!(record.linkage, Linkage::Internal)); + assert_eq!(record.elem_type_id, None); // no trailing element type in short record +} diff --git a/source/compiler/qsc_llvm/src/bitcode/writer.rs b/source/compiler/qsc_llvm/src/bitcode/writer.rs new file mode 100644 index 0000000000..7771c0c617 --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode/writer.rs @@ -0,0 +1,2669 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
#[cfg(test)]
mod tests;

use rustc_hash::FxHashMap;
use thiserror::Error;

use super::bitstream::{AbbrevDef, AbbrevOperand, BitstreamWriter};
use crate::model::Type;
use crate::model::{
    Attribute, BinOpKind, CastKind, Constant, FloatPredicate, Function, Instruction, IntPredicate,
    Linkage, MetadataNode, MetadataValue, Module, Operand,
};
use crate::qir::{QIR_MAJOR_VERSION_KEY, QirEmitTarget};

use super::constants::*;

// Identification block codes
const IDENTIFICATION_CODE_STRING: u32 = 1;
const IDENTIFICATION_CODE_EPOCH: u32 = 2;

// Attribute record codes
const PARAMATTR_GRP_CODE_ENTRY: u32 = 3;
const PARAMATTR_CODE_ENTRY: u32 = 2;

// Fixed abbreviation width for all our blocks
const ABBREV_WIDTH: u32 = 4;
const TOP_LEVEL_ABBREV_WIDTH: u32 = 2;
// Synthetic metadata node IDs (for lowered SubList values) start in the upper
// half of the u32 range so they can never collide with user-authored node IDs.
const SYNTHETIC_METADATA_NODE_START: u32 = u32::MAX / 2;

/// Errors that can occur while serializing a `Module` as LLVM bitcode.
///
/// Each variant carries enough context (operand text, block name, callee, …)
/// to point at the IR construct the writer could not encode.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WriteError {
    #[error("bitcode writer could not resolve operand `{operand}` in {context}")]
    UnresolvedOperand { context: String, operand: String },

    #[error("bitcode writer could not resolve basic block `{block}` in {context}")]
    MissingBasicBlock { context: String, block: String },

    #[error("bitcode writer could not resolve callee `@{callee}`")]
    UnknownCallee { callee: String },

    #[error("bitcode writer could not resolve attribute refs {attr_refs:?} in {context}")]
    MissingAttributeList {
        context: String,
        attr_refs: Vec<u32>,
    },

    #[error("bitcode writer could not encode floating constant `{value}` as `{ty}`")]
    InvalidFloatingConstant { ty: Type, value: f64 },

    #[error("bitcode writer could not resolve metadata constant `{ty} {value}`")]
    MissingMetadataConstant { ty: Type, value: i64 },

    #[error("bitcode writer could not resolve metadata node `!{node_id}`")]
    MissingMetadataNode { node_id: u32 },

    #[error("bitcode writer could not resolve module constant for {context}")]
    MissingModuleConstant { context: String },
}

impl WriteError {
    // Convenience constructor: renders the operand via `format_operand` so
    // call sites don't repeat the formatting logic.
    fn unresolved_operand(context: impl Into<String>, operand: &Operand) -> Self {
        Self::UnresolvedOperand {
            context: context.into(),
            operand: format_operand(operand),
        }
    }

    // Convenience constructor for a branch/phi target that has no block entry.
    fn missing_basic_block(context: impl Into<String>, block: &str) -> Self {
        Self::MissingBasicBlock {
            context: context.into(),
            block: block.to_string(),
        }
    }
}

// Renders an operand in LLVM-IR-like syntax for use in error messages only;
// this is not the on-disk encoding.
fn format_operand(operand: &Operand) -> String {
    match operand {
        Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => format!("%{name}"),
        Operand::IntConst(ty, value) => format!("{ty} {value}"),
        Operand::FloatConst(ty, value) => format!("{ty} {value}"),
        Operand::NullPtr => "null".to_string(),
        Operand::IntToPtr(value, ty) => format!("inttoptr (i64 {value} to {ty})"),
        Operand::GetElementPtr { ptr, .. } => format!("getelementptr from {ptr}"),
        Operand::GlobalRef(name) => format!("@{name}"),
    }
}

/// Writes a `Module` as LLVM bitcode.
///
/// # Panics
///
/// Panics if emission fails; use [`try_write_bitcode`] for a `Result`.
pub fn write_bitcode(module: &Module) -> Vec<u8> {
    try_write_bitcode(module)
        .unwrap_or_else(|error| panic!("failed to write LLVM bitcode: {error}"))
}

/// Writes a `Module` as LLVM bitcode for the requested QIR compatibility target.
///
/// # Panics
///
/// Panics if emission fails; use [`try_write_bitcode_for_target`] for a `Result`.
pub fn write_bitcode_for_target(module: &Module, emit_target: QirEmitTarget) -> Vec<u8> {
    try_write_bitcode_for_target(module, emit_target)
        .unwrap_or_else(|error| panic!("failed to write LLVM bitcode: {error}"))
}

/// Writes a `Module` as LLVM bitcode, returning a structured error if emission fails.
// The emit target is inferred from the module contents (see `infer_emit_target`).
pub fn try_write_bitcode(module: &Module) -> Result<Vec<u8>, WriteError> {
    try_write_bitcode_for_target(module, infer_emit_target(module))
}

/// Writes a `Module` as LLVM bitcode for the requested QIR compatibility target.
+pub fn try_write_bitcode_for_target( + module: &Module, + emit_target: QirEmitTarget, +) -> Result, WriteError> { + let mut ctx = WriteContext::new(module, emit_target); + ctx.write()?; + Ok(ctx.writer.finish()) +} + +struct TypeTable { + types: Vec, + map: FxHashMap, +} + +impl TypeTable { + fn new() -> Self { + Self { + types: Vec::new(), + map: FxHashMap::default(), + } + } + + fn get_or_insert(&mut self, ty: &Type) -> u32 { + if let Some(&id) = self.map.get(ty) { + return id; + } + let id = self.types.len() as u32; + self.types.push(ty.clone()); + self.map.insert(ty.clone(), id); + id + } +} + +#[derive(Debug, Clone)] +enum MetadataSlotKind { + String(String), + Value(Type, i64), + Node(u32), +} + +#[derive(Debug, Clone)] +enum LoweredMetadataValue { + String(String), + Int(Type, i64), + NodeRef(u32), +} + +#[derive(Debug, Clone)] +struct LoweredMetadataNode { + id: u32, + values: Vec, +} + +struct WriteContext<'a> { + module: &'a Module, + emit_target: QirEmitTarget, + writer: BitstreamWriter, + type_table: TypeTable, + // Value enumeration: maps (name) -> value_id at module scope + global_value_ids: FxHashMap, + next_global_value_id: u32, + attr_list_table: Vec>, + metadata_slots: Vec, + module_constant_ids: FxHashMap, + module_constants: Vec<(Type, Constant)>, + module_strtab: Vec, + module_function_name_offsets: FxHashMap, + function_word_offsets: FxHashMap, + module_vst_offset_placeholder_bit: Option, +} + +/// Describes a constant expression to be emitted in pass 2 of the constants +/// block, after all regular (non-CE) constants have been assigned value IDs. 
+#[derive(Debug, Clone)] +enum PendingCE { + IntToPtr { + val: i64, + }, + InboundsGep { + source_ty: Type, + ptr_name: String, + indices: Vec, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ModuleConstantKey { + Int(Type, i64), + Float(Type, u64), + Null(Type), + CString(Type, String), +} + +fn lower_metadata_graph( + nodes: &[MetadataNode], +) -> (Vec, Vec) { + let mut synthetic_nodes = Vec::new(); + let mut visible_nodes = Vec::with_capacity(nodes.len()); + let mut next_synthetic_id = SYNTHETIC_METADATA_NODE_START; + + for node in nodes { + let values = + lower_metadata_values(&node.values, &mut next_synthetic_id, &mut synthetic_nodes); + visible_nodes.push(LoweredMetadataNode { + id: node.id, + values, + }); + } + + (synthetic_nodes, visible_nodes) +} + +fn lower_metadata_values( + values: &[MetadataValue], + next_synthetic_id: &mut u32, + synthetic_nodes: &mut Vec, +) -> Vec { + let mut lowered = Vec::with_capacity(values.len()); + + for value in values { + match value { + MetadataValue::String(text) => { + lowered.push(LoweredMetadataValue::String(text.clone())); + } + MetadataValue::Int(ty, value) => { + lowered.push(LoweredMetadataValue::Int(ty.clone(), *value)); + } + MetadataValue::NodeRef(node_id) => { + lowered.push(LoweredMetadataValue::NodeRef(*node_id)); + } + MetadataValue::SubList(children) => { + let synthetic_id = *next_synthetic_id; + *next_synthetic_id += 1; + let lowered_children = + lower_metadata_values(children, next_synthetic_id, synthetic_nodes); + synthetic_nodes.push(LoweredMetadataNode { + id: synthetic_id, + values: lowered_children, + }); + lowered.push(LoweredMetadataValue::NodeRef(synthetic_id)); + } + } + } + + lowered +} + +fn encode_metadata_operands( + values: &[LoweredMetadataValue], + find_string_slot: &impl Fn(&str) -> Option, + find_value_slot: &impl Fn(&Type, i64) -> Option, + find_node_slot: &impl Fn(u32) -> Option, +) -> Result, WriteError> { + let mut operands = Vec::with_capacity(values.len()); + + for 
value in values { + match value { + LoweredMetadataValue::String(text) => { + if let Some(idx) = find_string_slot(text) { + operands.push(idx as u64); + } + } + LoweredMetadataValue::Int(ty, value) => { + if let Some(idx) = find_value_slot(ty, *value) { + operands.push(idx as u64); + } + } + LoweredMetadataValue::NodeRef(node_id) => { + let idx = find_node_slot(*node_id) + .ok_or(WriteError::MissingMetadataNode { node_id: *node_id })?; + operands.push(idx as u64); + } + } + } + + Ok(operands) +} + +impl<'a> WriteContext<'a> { + fn new(module: &'a Module, emit_target: QirEmitTarget) -> Self { + Self { + module, + emit_target, + writer: BitstreamWriter::new(), + type_table: TypeTable::new(), + global_value_ids: FxHashMap::default(), + next_global_value_id: 0, + attr_list_table: Vec::new(), + metadata_slots: Vec::new(), + module_constant_ids: FxHashMap::default(), + module_constants: Vec::new(), + module_strtab: Vec::new(), + module_function_name_offsets: FxHashMap::default(), + function_word_offsets: FxHashMap::default(), + module_vst_offset_placeholder_bit: None, + } + } + + fn write(&mut self) -> Result<(), WriteError> { + // 1. Emit magic + self.writer.emit_bits(0x42, 8); + self.writer.emit_bits(0x43, 8); + self.writer.emit_bits(0xC0, 8); + self.writer.emit_bits(0xDE, 8); + + // 2. Write identification block + self.write_identification_block(); + + // 3. Collect all types from the module + self.collect_types(); + + // 4. Enumerate global values + self.enumerate_global_values(); + + // Reserve module-scope constant IDs before globals need to reference + // them as initializers and before metadata needs to reference them. + self.collect_module_constants()?; + + // 4b. Build the module string table payload used by opaque-lane + // modern function naming records. + self.build_module_function_strtab(); + + // 5. 
Enter MODULE_BLOCK + self.writer + .enter_subblock(MODULE_BLOCK_ID, ABBREV_WIDTH, TOP_LEVEL_ABBREV_WIDTH); + + // Emit the module layout version that matches the records written + // below for the selected compatibility lane. + self.writer.emit_record( + MODULE_CODE_VERSION, + &[self.emit_target.module_bitcode_version()], + ABBREV_WIDTH, + ); + + // 7. Write triple/datalayout if present + if let Some(ref triple) = self.module.target_triple { + let chars: Vec = triple.bytes().map(u64::from).collect(); + self.writer + .emit_record(MODULE_CODE_TRIPLE, &chars, ABBREV_WIDTH); + } + if let Some(ref dl) = self.module.target_datalayout { + let chars: Vec = dl.bytes().map(u64::from).collect(); + self.writer + .emit_record(MODULE_CODE_DATALAYOUT, &chars, ABBREV_WIDTH); + } + if let Some(ref sf) = self.module.source_filename { + let chars: Vec = sf.bytes().map(u64::from).collect(); + self.writer + .emit_record(MODULE_CODE_SOURCE_FILENAME, &chars, ABBREV_WIDTH); + } + + // 7b. Build and write attribute blocks (before type block per LLVM convention) + self.build_attr_list_table(); + self.write_paramattr_group_block(); + self.write_paramattr_block(); + + // 7c. Pre-register metadata constant types so they appear in the type block. + self.register_metadata_constant_types(); + + // 8. Write type block + self.write_type_block(); + + // 9. Write global variables + self.write_global_vars()?; + + // 10. Write function prototypes (declarations and definition headers) + self.write_function_protos()?; + + if self.uses_modern_function_naming_container() { + self.write_module_vst_offset_placeholder(); + } + + // 11. Write module-scope constants after globals and prototypes. + if !self.module_constants.is_empty() { + self.write_module_constants_block()?; + } + + // 12. Build and write metadata block + self.build_metadata_slots(); + self.write_metadata_block()?; + + // 13. 
Write function bodies + for func in &self.module.functions.clone() { + if !func.is_declaration { + self.write_function_body(func)?; + } + } + + // 14. Write value symbol table + self.write_value_symtab(); + + // 15. Exit MODULE_BLOCK + self.writer.exit_block(ABBREV_WIDTH); + + if self.uses_modern_function_naming_container() { + self.write_module_strtab_block(); + } + + Ok(()) + } + + fn uses_modern_function_naming_container(&self) -> bool { + self.emit_target == QirEmitTarget::QirV2Opaque + } + + fn build_module_function_strtab(&mut self) { + self.module_strtab.clear(); + self.module_function_name_offsets.clear(); + + if !self.uses_modern_function_naming_container() { + return; + } + + for function in &self.module.functions { + let offset = self.module_strtab.len() as u32; + let bytes = function.name.as_bytes(); + self.module_strtab.extend_from_slice(bytes); + self.module_function_name_offsets + .insert(function.name.clone(), (offset, bytes.len() as u32)); + } + } + + fn module_function_name_range(&self, name: &str) -> (u32, u32) { + self.module_function_name_offsets + .get(name) + .copied() + .unwrap_or((0, 0)) + } + + fn write_module_vst_offset_placeholder(&mut self) { + let abbrev = AbbrevDef { + operands: vec![ + AbbrevOperand::Literal(u64::from(MODULE_CODE_VSTOFFSET)), + AbbrevOperand::Fixed(32), + ], + }; + let abbrev_id = self.writer.emit_define_abbrev(&abbrev, ABBREV_WIDTH); + self.writer + .emit_abbreviated_record(abbrev_id, &[0], ABBREV_WIDTH); + self.module_vst_offset_placeholder_bit = Some(self.writer.bit_position() - 32); + } + + fn patch_module_vst_offset(&mut self) { + let Some(patch_bit_position) = self.module_vst_offset_placeholder_bit.take() else { + return; + }; + + let vst_bit_position = self.writer.bit_position(); + self.writer + .patch_u32_bits(patch_bit_position, (vst_bit_position / 32) as u32); + } + + fn write_module_strtab_block(&mut self) { + self.writer + .enter_subblock(STRTAB_BLOCK_ID, ABBREV_WIDTH, TOP_LEVEL_ABBREV_WIDTH); + + let 
abbrev = AbbrevDef { + operands: vec![ + AbbrevOperand::Literal(u64::from(STRTAB_BLOB)), + AbbrevOperand::Blob, + ], + }; + let abbrev_id = self.writer.emit_define_abbrev(&abbrev, ABBREV_WIDTH); + + let mut fields = Vec::with_capacity(self.module_strtab.len() + 1); + fields.push(self.module_strtab.len() as u64); + fields.extend(self.module_strtab.iter().map(|&byte| u64::from(byte))); + self.writer + .emit_abbreviated_record(abbrev_id, &fields, ABBREV_WIDTH); + + self.writer.exit_block(ABBREV_WIDTH); + } + + fn write_identification_block(&mut self) { + self.writer + .enter_subblock(IDENTIFICATION_BLOCK_ID, ABBREV_WIDTH, 2); + let producer = "qsc_codegen"; + let chars: Vec = producer.bytes().map(u64::from).collect(); + self.writer + .emit_record(IDENTIFICATION_CODE_STRING, &chars, ABBREV_WIDTH); + // epoch = 0 (current) + self.writer + .emit_record(IDENTIFICATION_CODE_EPOCH, &[0], ABBREV_WIDTH); + self.writer.exit_block(ABBREV_WIDTH); + } + + fn build_attr_list_table(&mut self) { + fn insert_attr_list( + attr_list_table: &mut Vec>, + seen: &mut FxHashMap, usize>, + attr_refs: &[u32], + normalize: bool, + ) { + if attr_refs.is_empty() { + return; + } + + let mut key = attr_refs.to_vec(); + if normalize { + key.sort_unstable(); + } + + if !seen.contains_key(&key) { + let idx = attr_list_table.len(); + attr_list_table.push(key.clone()); + seen.insert(key, idx); + } + } + + let mut seen: FxHashMap, usize> = FxHashMap::default(); + for f in &self.module.functions { + insert_attr_list( + &mut self.attr_list_table, + &mut seen, + &f.attribute_group_refs, + true, + ); + + for block in &f.basic_blocks { + for instruction in &block.instructions { + if let Instruction::Call { attr_refs, .. } = instruction { + insert_attr_list(&mut self.attr_list_table, &mut seen, attr_refs, false); + } + } + } + } + } + + fn write_paramattr_group_block(&mut self) { + // Collect the set of attribute group IDs actually referenced by functions or calls. 
+ let used_ids: rustc_hash::FxHashSet = + self.attr_list_table.iter().flatten().copied().collect(); + + let groups: Vec<_> = self + .module + .attribute_groups + .iter() + .filter(|g| used_ids.contains(&g.id)) + .collect(); + + if groups.is_empty() { + return; + } + + self.writer + .enter_subblock(PARAMATTR_GROUP_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH); + + for group in &groups { + let mut values: Vec = vec![ + u64::from(group.id), + 0xFFFF_FFFF, // function-level attributes + ]; + for attr in &group.attributes { + match attr { + Attribute::StringAttr(s) => { + values.push(3); // string attr code + for ch in s.bytes() { + values.push(u64::from(ch)); + } + values.push(0); // null terminator + } + Attribute::KeyValue(key, val) => { + values.push(4); // key/value attr code + for ch in key.bytes() { + values.push(u64::from(ch)); + } + values.push(0); // null terminator + for ch in val.bytes() { + values.push(u64::from(ch)); + } + values.push(0); // null terminator + } + } + } + self.writer + .emit_record(PARAMATTR_GRP_CODE_ENTRY, &values, ABBREV_WIDTH); + } + + self.writer.exit_block(ABBREV_WIDTH); + } + + fn write_paramattr_block(&mut self) { + if self.attr_list_table.is_empty() { + return; + } + + self.writer + .enter_subblock(PARAMATTR_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH); + + for group_ids in &self.attr_list_table { + let values: Vec = group_ids.iter().map(|&id| u64::from(id)).collect(); + self.writer + .emit_record(PARAMATTR_CODE_ENTRY, &values, ABBREV_WIDTH); + } + + self.writer.exit_block(ABBREV_WIDTH); + } + + fn collect_types(&mut self) { + // Collect all types used in the module + for st in &self.module.struct_types { + if st.is_opaque { + self.type_table.get_or_insert(&Type::Named(st.name.clone())); + } + } + + for g in &self.module.globals { + self.collect_type(&g.ty); + // Globals are accessed via pointer + let global_ptr_ty = self.emit_target.pointer_type_for_pointee(&g.ty); + self.type_table.get_or_insert(&global_ptr_ty); + } + + for f in 
&self.module.functions { + self.collect_type(&f.return_type); + for p in &f.params { + self.collect_type(&p.ty); + } + // Function type + let func_ty = Type::Function( + Box::new(f.return_type.clone()), + f.params.iter().map(|p| p.ty.clone()).collect(), + ); + self.collect_type(&func_ty); + // Pointer to function + let function_ptr_ty = self.emit_target.pointer_type_for_pointee(&func_ty); + self.type_table.get_or_insert(&function_ptr_ty); + + for bb in &f.basic_blocks { + for instr in &bb.instructions { + self.collect_instruction_types(instr); + } + } + } + + // Collect types used in metadata values + for node in &self.module.metadata_nodes { + self.collect_metadata_value_types(&node.values); + } + } + + fn collect_metadata_value_types(&mut self, values: &[MetadataValue]) { + for val in values { + match val { + MetadataValue::Int(ty, _) => { + self.collect_type(ty); + } + MetadataValue::SubList(children) => { + self.collect_metadata_value_types(children); + } + MetadataValue::String(_) | MetadataValue::NodeRef(_) => {} + } + } + } + + fn collect_type(&mut self, ty: &Type) { + match ty { + Type::Function(ret, params) => { + self.collect_type(ret); + for p in params { + self.collect_type(p); + } + self.type_table.get_or_insert(ty); + } + Type::Array(_, elem) => { + self.collect_type(elem); + self.type_table.get_or_insert(ty); + } + Type::TypedPtr(inner) => { + self.collect_type(inner); + self.type_table.get_or_insert(ty); + } + Type::NamedPtr(name) => { + self.type_table.get_or_insert(&Type::Named(name.clone())); + self.type_table.get_or_insert(ty); + } + _ => { + self.type_table.get_or_insert(ty); + } + } + } + + fn collect_instruction_types(&mut self, instr: &Instruction) { + match instr { + Instruction::Ret(Some(op)) => self.collect_operand_types(op), + Instruction::Ret(None) => {} + Instruction::Br { cond_ty, cond, .. } => { + self.collect_type(cond_ty); + self.collect_operand_types(cond); + } + Instruction::Jump { .. 
} => {} + Instruction::BinOp { ty, lhs, rhs, .. } => { + self.collect_type(ty); + self.collect_operand_types(lhs); + self.collect_operand_types(rhs); + } + Instruction::ICmp { ty, lhs, rhs, .. } | Instruction::FCmp { ty, lhs, rhs, .. } => { + self.collect_type(ty); + self.collect_operand_types(lhs); + self.collect_operand_types(rhs); + self.type_table.get_or_insert(&Type::Integer(1)); + } + Instruction::Cast { + from_ty, + to_ty, + value, + .. + } => { + self.collect_type(from_ty); + self.collect_type(to_ty); + self.collect_operand_types(value); + } + Instruction::Call { + return_ty, args, .. + } => { + if let Some(ret) = return_ty { + self.collect_type(ret); + } + for (ty, op) in args { + self.collect_type(ty); + self.collect_operand_types(op); + } + } + Instruction::Phi { ty, incoming, .. } => { + self.collect_type(ty); + for (op, _) in incoming { + self.collect_operand_types(op); + } + } + Instruction::Alloca { ty, .. } => { + self.collect_type(ty); + let alloca_ptr_ty = self.emit_target.pointer_type_for_pointee(ty); + self.type_table.get_or_insert(&alloca_ptr_ty); + } + Instruction::Load { + ty, ptr_ty, ptr, .. + } => { + self.collect_type(ty); + self.collect_type(ptr_ty); + self.collect_operand_types(ptr); + } + Instruction::Store { + ty, + value, + ptr_ty, + ptr, + } => { + self.collect_type(ty); + self.collect_operand_types(value); + self.collect_type(ptr_ty); + self.collect_operand_types(ptr); + } + Instruction::Select { + cond, + true_val, + false_val, + ty, + .. + } => { + self.collect_type(&Type::Integer(1)); + self.collect_operand_types(cond); + self.collect_type(ty); + self.collect_operand_types(true_val); + self.collect_operand_types(false_val); + } + Instruction::Switch { ty, value, .. } => { + self.collect_type(ty); + self.collect_operand_types(value); + } + Instruction::GetElementPtr { + pointee_ty, + ptr_ty, + ptr, + indices, + .. 
+ } => { + self.collect_type(pointee_ty); + self.collect_type(ptr_ty); + self.collect_operand_types(ptr); + for idx in indices { + self.collect_operand_types(idx); + } + } + Instruction::Unreachable => {} + } + } + + fn collect_operand_types(&mut self, op: &Operand) { + match op { + Operand::IntConst(ty, _) => self.collect_type(ty), + Operand::FloatConst(ty, _) => self.collect_type(ty), + Operand::NullPtr => { + let null_ptr_ty = self.emit_target.default_pointer_type(); + self.type_table.get_or_insert(&null_ptr_ty); + } + Operand::IntToPtr(_, ty) => { + self.collect_type(&Type::Integer(64)); + self.collect_type(ty); + } + Operand::GetElementPtr { + ty, + ptr_ty, + indices, + .. + } => { + self.collect_type(ty); + self.collect_type(ptr_ty); + for index in indices { + self.collect_operand_types(index); + } + } + Operand::LocalRef(_) | Operand::TypedLocalRef(_, _) | Operand::GlobalRef(_) => {} + } + } + + fn enumerate_global_values(&mut self) { + // Globals first + for g in &self.module.globals { + let id = self.next_global_value_id; + self.global_value_ids.insert(g.name.clone(), id); + self.next_global_value_id += 1; + // Suppress unused variable warning + let _ = &g.ty; + } + // Then functions + for f in &self.module.functions { + let id = self.next_global_value_id; + self.global_value_ids.insert(f.name.clone(), id); + self.next_global_value_id += 1; + } + } + + fn module_constant_key( + ty: &Type, + constant: &Constant, + ) -> Result { + match constant { + Constant::Int(value) => Ok(ModuleConstantKey::Int(ty.clone(), *value)), + Constant::Float(_, value) => { + let bits = ty.encode_float_bits(*value).ok_or_else(|| { + WriteError::InvalidFloatingConstant { + ty: ty.clone(), + value: *value, + } + })?; + Ok(ModuleConstantKey::Float(ty.clone(), bits)) + } + Constant::Null => Ok(ModuleConstantKey::Null(ty.clone())), + Constant::CString(text) => Ok(ModuleConstantKey::CString(ty.clone(), text.clone())), + } + } + + fn push_module_constant(&mut self, ty: Type, constant: 
Constant) -> Result<(), WriteError> { + let key = Self::module_constant_key(&ty, &constant)?; + if self.module_constant_ids.contains_key(&key) { + return Ok(()); + } + + let index = + u32::try_from(self.module_constants.len()).expect("module constant count exceeded u32"); + self.module_constant_ids + .insert(key, self.next_global_value_id + index); + self.module_constants.push((ty, constant)); + Ok(()) + } + + fn resolve_module_constant_value_id( + &self, + ty: &Type, + constant: &Constant, + context: impl Into, + ) -> Result { + let key = Self::module_constant_key(ty, constant)?; + self.module_constant_ids.get(&key).copied().ok_or_else(|| { + WriteError::MissingModuleConstant { + context: context.into(), + } + }) + } + + fn write_type_block(&mut self) { + self.writer + .enter_subblock(TYPE_BLOCK_ID_NEW, ABBREV_WIDTH, ABBREV_WIDTH); + + let num_types = self.type_table.types.len() as u64; + self.writer + .emit_record(TYPE_CODE_NUMENTRY, &[num_types], ABBREV_WIDTH); + + let types = self.type_table.types.clone(); + for ty in &types { + self.write_type_record(ty); + } + + self.writer.exit_block(ABBREV_WIDTH); + } + + fn write_type_record(&mut self, ty: &Type) { + match ty { + Type::Void => { + self.writer.emit_record(TYPE_CODE_VOID, &[], ABBREV_WIDTH); + } + Type::Integer(width) => { + self.writer + .emit_record(TYPE_CODE_INTEGER, &[u64::from(*width)], ABBREV_WIDTH); + } + Type::Half => { + self.writer.emit_record(TYPE_CODE_HALF, &[], ABBREV_WIDTH); + } + Type::Float => { + self.writer.emit_record(TYPE_CODE_FLOAT, &[], ABBREV_WIDTH); + } + Type::Double => { + self.writer.emit_record(TYPE_CODE_DOUBLE, &[], ABBREV_WIDTH); + } + Type::Label => { + self.writer.emit_record(TYPE_CODE_LABEL, &[], ABBREV_WIDTH); + } + Type::Ptr => { + self.writer + .emit_record(TYPE_CODE_OPAQUE_POINTER, &[0], ABBREV_WIDTH); + } + Type::Named(name) => { + // Emit struct name then opaque + let chars: Vec = name.bytes().map(u64::from).collect(); + self.writer + .emit_record(TYPE_CODE_STRUCT_NAME, 
&chars, ABBREV_WIDTH); + self.writer + .emit_record(TYPE_CODE_OPAQUE, &[0], ABBREV_WIDTH); + } + Type::NamedPtr(name) => { + if self.emit_target.uses_typed_pointers() { + let inner_id = self.type_table.get_or_insert(&Type::Named(name.clone())); + self.writer.emit_record( + TYPE_CODE_POINTER, + &[u64::from(inner_id), 0], + ABBREV_WIDTH, + ); + } else { + self.writer + .emit_record(TYPE_CODE_OPAQUE_POINTER, &[0], ABBREV_WIDTH); + } + } + Type::TypedPtr(inner) => { + if self.emit_target.uses_typed_pointers() { + let inner_id = self.type_table.get_or_insert(inner); + self.writer.emit_record( + TYPE_CODE_POINTER, + &[u64::from(inner_id), 0], + ABBREV_WIDTH, + ); + } else { + self.writer + .emit_record(TYPE_CODE_OPAQUE_POINTER, &[0], ABBREV_WIDTH); + } + } + Type::Array(len, elem) => { + let elem_id = self.type_table.get_or_insert(elem); + self.writer + .emit_record(TYPE_CODE_ARRAY, &[*len, u64::from(elem_id)], ABBREV_WIDTH); + } + Type::Function(ret, params) => { + let ret_id = self.type_table.get_or_insert(ret); + let mut values = vec![0u64, u64::from(ret_id)]; // 0 = not vararg + for p in params { + let p_id = self.type_table.get_or_insert(p); + values.push(u64::from(p_id)); + } + self.writer + .emit_record(TYPE_CODE_FUNCTION_TYPE, &values, ABBREV_WIDTH); + } + } + } + + fn write_global_vars(&mut self) -> Result<(), WriteError> { + for g in &self.module.globals { + let ty_id = self.type_table.get_or_insert(&g.ty); + let global_ptr_ty = self.emit_target.pointer_type_for_pointee(&g.ty); + let ptr_ty_id = self.type_table.get_or_insert(&global_ptr_ty); + let linkage: u64 = match g.linkage { + Linkage::External => 0, + Linkage::Internal => 3, + }; + let is_const: u64 = u64::from(g.is_constant); + let init_id = if let Some(initializer) = &g.initializer { + u64::from( + self.resolve_module_constant_value_id( + &g.ty, + initializer, + format!("global initializer @{}", g.name), + )? 
+ 1, + ) + } else { + 0 + }; + // MODULE_CODE_GLOBALVAR: [pointer_type, address_space, is_const, init_id, linkage, alignment, section, ...] + // We append the actual element type ID as an extra trailing field + // so our reader can recover the global's element type. + let values = vec![ + u64::from(ptr_ty_id), // pointer type + 0, // address_space + is_const, + init_id, // init id (0 = none, otherwise value_id + 1) + linkage, + 0, // alignment + 0, // section + 0, // visibility + 0, // thread_local + 0, // unnamed_addr + 0, // externally_initialized + 0, // dso_local + 0, // comdat + u64::from(ty_id), // actual element type (our extension) + ]; + self.writer + .emit_record(MODULE_CODE_GLOBALVAR, &values, ABBREV_WIDTH); + } + + Ok(()) + } + + fn write_function_protos(&mut self) -> Result<(), WriteError> { + for f in &self.module.functions { + let func_ty = Type::Function( + Box::new(f.return_type.clone()), + f.params.iter().map(|p| p.ty.clone()).collect(), + ); + let func_ty_id = self.type_table.get_or_insert(&func_ty); + let is_decl: u64 = u64::from(f.is_declaration); + let linkage: u64 = 0; // external + + // Look up the 1-based paramattr index from the attr list table. + // 0 means no attributes. + let paramattr: u64 = if f.attribute_group_refs.is_empty() { + 0 + } else { + self.resolve_attr_list_index( + &f.attribute_group_refs, + true, + format!("function prototype @{}", f.name), + )? 
+ }; + + let values = if self.uses_modern_function_naming_container() { + let (name_offset, name_size) = self.module_function_name_range(&f.name); + vec![ + u64::from(name_offset), + u64::from(name_size), + u64::from(func_ty_id), + 0, // calling conv + is_decl, // isproto (1 = declaration) + linkage, + paramattr, + 0, // alignment + 0, // section + 0, // visibility + 0, // gc + ] + } else { + // MODULE_CODE_FUNCTION: [type, callingconv, isproto, linkage, paramattr, alignment, section, visibility, gc, unnamed_addr, prologuedata, dllstorageclass, comdat, prefixdata, personalityfn, dso_local] + vec![ + u64::from(func_ty_id), + 0, // calling conv + is_decl, // isproto (1 = declaration) + linkage, + paramattr, + 0, // alignment + 0, // section + 0, // visibility + 0, // gc + ] + }; + self.writer + .emit_record(MODULE_CODE_FUNCTION, &values, ABBREV_WIDTH); + } + + Ok(()) + } + + fn register_metadata_constant_types(&mut self) { + // Pre-register types used by metadata integer constants so they + // appear in the type block. Does not allocate value IDs. 
+ for node in &self.module.metadata_nodes.clone() { + Self::register_metadata_types_from_values(&node.values, &mut self.type_table); + } + } + + fn register_metadata_types_from_values(values: &[MetadataValue], type_table: &mut TypeTable) { + for v in values { + match v { + MetadataValue::Int(ty, _) => { + type_table.get_or_insert(ty); + } + MetadataValue::SubList(sub) => { + Self::register_metadata_types_from_values(sub, type_table); + } + MetadataValue::String(_) | MetadataValue::NodeRef(_) => {} + } + } + } + + fn collect_module_constants(&mut self) -> Result<(), WriteError> { + fn visit_metadata_values( + values: &[MetadataValue], + writer: &mut WriteContext<'_>, + ) -> Result<(), WriteError> { + for v in values { + match v { + MetadataValue::Int(ty, val) => { + writer.push_module_constant(ty.clone(), Constant::Int(*val))?; + } + MetadataValue::SubList(sub) => visit_metadata_values(sub, writer)?, + MetadataValue::String(_) | MetadataValue::NodeRef(_) => {} + } + } + + Ok(()) + } + + self.module_constant_ids.clear(); + self.module_constants.clear(); + + for global in &self.module.globals { + if let Some(initializer) = &global.initializer { + self.push_module_constant(global.ty.clone(), initializer.clone())?; + } + } + + for node in &self.module.metadata_nodes.clone() { + visit_metadata_values(&node.values, self)?; + } + + self.next_global_value_id += + u32::try_from(self.module_constants.len()).expect("module constant count exceeded u32"); + + Ok(()) + } + + fn write_module_constants_block(&mut self) -> Result<(), WriteError> { + if self.module_constants.is_empty() { + return Ok(()); + } + + self.writer + .enter_subblock(CONSTANTS_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH); + + let mut current_type: Option = None; + + let module_constants = self.module_constants.clone(); + for (ty, constant) in &module_constants { + let ty_id = self.type_table.get_or_insert(ty); + if current_type != Some(ty_id) { + self.writer + .emit_record(CST_CODE_SETTYPE, &[u64::from(ty_id)], 
ABBREV_WIDTH); + current_type = Some(ty_id); + } + + match constant { + Constant::Int(value) => { + let encoded = sign_rotate(*value); + self.writer + .emit_record(CST_CODE_INTEGER, &[encoded], ABBREV_WIDTH); + } + Constant::Float(_, value) => { + let bits = ty.encode_float_bits(*value).ok_or_else(|| { + WriteError::InvalidFloatingConstant { + ty: ty.clone(), + value: *value, + } + }); + self.writer + .emit_record(CST_CODE_FLOAT, &[bits?], ABBREV_WIDTH); + } + Constant::Null => { + self.writer.emit_record(CST_CODE_NULL, &[], ABBREV_WIDTH); + } + Constant::CString(text) => { + let chars: Vec = text.bytes().map(u64::from).collect(); + self.writer + .emit_record(CST_CODE_CSTRING, &chars, ABBREV_WIDTH); + } + } + } + + self.writer.exit_block(ABBREV_WIDTH); + + Ok(()) + } + + fn build_metadata_slots(&mut self) { + if self.module.metadata_nodes.is_empty() && self.module.named_metadata.is_empty() { + return; + } + + // Phase 1: Collect unique strings and values from all metadata nodes, + // including SubList children. + let mut string_set: Vec = Vec::new(); + let mut value_set: Vec<(Type, i64)> = Vec::new(); + + fn collect_leaf_entries( + values: &[MetadataValue], + strings: &mut Vec, + vals: &mut Vec<(Type, i64)>, + ) { + for v in values { + match v { + MetadataValue::String(s) => { + if !strings.contains(s) { + strings.push(s.clone()); + } + } + MetadataValue::Int(ty, val) => { + let key = (ty.clone(), *val); + if !vals.contains(&key) { + vals.push(key); + } + } + MetadataValue::SubList(sub) => { + collect_leaf_entries(sub, strings, vals); + } + MetadataValue::NodeRef(_) => {} + } + } + } + + for node in &self.module.metadata_nodes { + collect_leaf_entries(&node.values, &mut string_set, &mut value_set); + } + + // Phase 2: Assign slots in order: strings, then values, then nodes. 
+ self.metadata_slots.clear(); + + for s in &string_set { + self.metadata_slots + .push(MetadataSlotKind::String(s.clone())); + } + for (ty, val) in &value_set { + self.metadata_slots + .push(MetadataSlotKind::Value(ty.clone(), *val)); + } + + let (synthetic_nodes, _) = lower_metadata_graph(&self.module.metadata_nodes); + + // Add synthetic child node slots first + for node in &synthetic_nodes { + self.metadata_slots.push(MetadataSlotKind::Node(node.id)); + } + + // Then add visible metadata node slots + for node in &self.module.metadata_nodes { + self.metadata_slots.push(MetadataSlotKind::Node(node.id)); + } + } + + fn write_metadata_block(&mut self) -> Result<(), WriteError> { + if self.module.metadata_nodes.is_empty() && self.module.named_metadata.is_empty() { + return Ok(()); + } + + self.writer + .enter_subblock(METADATA_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH); + + // Build helper structures for slot lookup + let slots = self.metadata_slots.clone(); + + // Helper: find slot index for a string + let find_string_slot = |s: &str| -> Option { + slots + .iter() + .position(|slot| matches!(slot, MetadataSlotKind::String(ss) if ss == s)) + }; + + // Helper: find slot index for a value + let find_value_slot = |ty: &Type, val: i64| -> Option { + slots.iter().position( + |slot| matches!(slot, MetadataSlotKind::Value(t, v) if t == ty && *v == val), + ) + }; + + let find_node_slot = |node_id: u32| -> Option { + slots + .iter() + .position(|slot| matches!(slot, MetadataSlotKind::Node(id) if *id == node_id)) + }; + + // Emit METADATA_STRING_OLD records + for slot in &slots { + if let MetadataSlotKind::String(s) = slot { + let chars: Vec = s.bytes().map(u64::from).collect(); + self.writer + .emit_record(METADATA_STRING_OLD, &chars, ABBREV_WIDTH); + } + } + + // Emit METADATA_VALUE records + for slot in &slots { + if let MetadataSlotKind::Value(ty, val) = slot { + let type_id = self.type_table.get_or_insert(ty); + let key = Self::module_constant_key(ty, &Constant::Int(*val))?; 
+ let value_id = self.module_constant_ids.get(&key).copied().ok_or_else(|| { + WriteError::MissingMetadataConstant { + ty: ty.clone(), + value: *val, + } + })?; + self.writer.emit_record( + METADATA_VALUE, + &[u64::from(type_id), u64::from(value_id)], + ABBREV_WIDTH, + ); + } + } + + let (synthetic_nodes, visible_nodes) = lower_metadata_graph(&self.module.metadata_nodes); + + // Emit synthetic child METADATA_NODE records + for node in &synthetic_nodes { + let operands = encode_metadata_operands( + &node.values, + &find_string_slot, + &find_value_slot, + &find_node_slot, + )?; + self.writer + .emit_record(METADATA_NODE, &operands, ABBREV_WIDTH); + } + + // Emit visible METADATA_NODE records + for node in &visible_nodes { + let operands = encode_metadata_operands( + &node.values, + &find_string_slot, + &find_value_slot, + &find_node_slot, + )?; + self.writer + .emit_record(METADATA_NODE, &operands, ABBREV_WIDTH); + } + + // Emit named metadata + for nm in &self.module.named_metadata { + // METADATA_NAME + let name_chars: Vec = nm.name.bytes().map(u64::from).collect(); + self.writer + .emit_record(METADATA_NAME, &name_chars, ABBREV_WIDTH); + + // METADATA_NAMED_NODE — slot indexes for referenced visible nodes + let mut node_slot_refs: Vec = Vec::new(); + for &node_ref in &nm.node_refs { + let idx = find_node_slot(node_ref) + .ok_or(WriteError::MissingMetadataNode { node_id: node_ref })?; + node_slot_refs.push(idx as u64); + } + self.writer + .emit_record(METADATA_NAMED_NODE, &node_slot_refs, ABBREV_WIDTH); + } + + self.writer.exit_block(ABBREV_WIDTH); + + Ok(()) + } + + fn write_function_body(&mut self, func: &Function) -> Result<(), WriteError> { + if self.uses_modern_function_naming_container() + && let Some(&value_id) = self.global_value_ids.get(&func.name) + { + let function_bit_position = self.writer.bit_position(); + self.function_word_offsets + .insert(value_id, (function_bit_position / 32) as u32); + } + + self.writer + .enter_subblock(FUNCTION_BLOCK_ID, 
ABBREV_WIDTH, ABBREV_WIDTH); + + // Declare number of basic blocks + let num_bbs = func.basic_blocks.len() as u64; + self.writer + .emit_record(FUNC_CODE_DECLAREBLOCKS, &[num_bbs], ABBREV_WIDTH); + + // Build local value mapping for this function + let mut local_value_ids = FxHashMap::default(); + let base_value_id = self.next_global_value_id; + let mut next_value_id: u32 = base_value_id; + + // Parameters get value IDs first + for p in &func.params { + if let Some(ref name) = p.name { + local_value_ids.insert(name.clone(), next_value_id); + } + next_value_id += 1; + } + + // Build basic block name -> index map + let bb_map = func + .basic_blocks + .iter() + .enumerate() + .map(|(i, bb)| (bb.name.clone(), i as u32)) + .collect::>(); + + // Collect constants used in this function + let (constants, ces) = self.collect_function_constants(func); + if !constants.is_empty() || !ces.is_empty() { + self.write_constants_block(&constants, &ces, &mut local_value_ids, &mut next_value_id)?; + } + + let mut reserved_value_id = next_value_id; + Self::reserve_function_result_ids(func, &mut local_value_ids, &mut reserved_value_id); + + // Write instructions + for bb in &func.basic_blocks { + for instr in &bb.instructions { + self.write_instruction(instr, &mut local_value_ids, &mut next_value_id, &bb_map)?; + } + } + + // Write value symbol table for function locals + let named_local_entries = + Self::collect_named_function_vst_entries(func, &local_value_ids, base_value_id); + self.write_function_vst(func, &bb_map, &named_local_entries); + + self.writer.exit_block(ABBREV_WIDTH); + + Ok(()) + } + + fn collect_named_function_vst_entries( + func: &Function, + local_value_ids: &FxHashMap, + base_value_id: u32, + ) -> Vec<(u32, String)> { + let mut entries = Vec::new(); + + for param in &func.params { + let Some(name) = ¶m.name else { + continue; + }; + let Some(&value_id) = local_value_ids.get(name) else { + continue; + }; + entries.push((value_id - base_value_id, name.clone())); + } + 
+ for block in &func.basic_blocks { + for instruction in &block.instructions { + let Some(name) = Self::instruction_result_name(instruction) else { + continue; + }; + let Some(&value_id) = local_value_ids.get(name) else { + continue; + }; + entries.push((value_id - base_value_id, name.to_string())); + } + } + + entries + } + + fn reserve_function_result_ids( + func: &Function, + local_value_ids: &mut FxHashMap, + next_value_id: &mut u32, + ) { + for block in &func.basic_blocks { + for instruction in &block.instructions { + let Some(name) = Self::instruction_result_name(instruction) else { + continue; + }; + local_value_ids.insert(name.to_string(), *next_value_id); + *next_value_id += 1; + } + } + } + + fn instruction_result_name(instr: &Instruction) -> Option<&str> { + match instr { + Instruction::BinOp { result, .. } + | Instruction::ICmp { result, .. } + | Instruction::FCmp { result, .. } + | Instruction::Cast { result, .. } + | Instruction::Phi { result, .. } + | Instruction::Alloca { result, .. } + | Instruction::Load { result, .. } + | Instruction::Select { result, .. } + | Instruction::GetElementPtr { result, .. } + | Instruction::Call { + result: Some(result), + .. + } => Some(result.as_str()), + Instruction::Ret(_) + | Instruction::Call { result: None, .. } + | Instruction::Br { .. } + | Instruction::Jump { .. } + | Instruction::Store { .. } + | Instruction::Switch { .. 
} + | Instruction::Unreachable => None, + } + } + + fn collect_function_constants( + &self, + func: &Function, + ) -> (Vec<(Type, Constant)>, Vec<(Type, PendingCE)>) { + let mut constants: Vec<(Type, Constant)> = Vec::new(); + let mut ces: Vec<(Type, PendingCE)> = Vec::new(); + let mut seen: FxHashMap<(String, String), bool> = FxHashMap::default(); + + for bb in &func.basic_blocks { + for instr in &bb.instructions { + self.collect_constants_from_instruction(instr, &mut constants, &mut ces, &mut seen); + } + } + (constants, ces) + } + + fn collect_constants_from_instruction( + &self, + instr: &Instruction, + constants: &mut Vec<(Type, Constant)>, + ces: &mut Vec<(Type, PendingCE)>, + seen: &mut FxHashMap<(String, String), bool>, + ) { + match instr { + Instruction::Ret(Some(op)) => { + self.collect_constants_from_operand(op, constants, ces, seen); + } + Instruction::Br { cond, .. } => { + self.collect_constants_from_operand(cond, constants, ces, seen); + } + Instruction::BinOp { lhs, rhs, .. } => { + self.collect_constants_from_operand(lhs, constants, ces, seen); + self.collect_constants_from_operand(rhs, constants, ces, seen); + } + Instruction::ICmp { lhs, rhs, .. } | Instruction::FCmp { lhs, rhs, .. } => { + self.collect_constants_from_operand(lhs, constants, ces, seen); + self.collect_constants_from_operand(rhs, constants, ces, seen); + } + Instruction::Cast { value, .. } => { + self.collect_constants_from_operand(value, constants, ces, seen); + } + Instruction::Call { args, .. } => { + for (_, op) in args { + self.collect_constants_from_operand(op, constants, ces, seen); + } + } + Instruction::Phi { incoming, .. } => { + for (op, _) in incoming { + self.collect_constants_from_operand(op, constants, ces, seen); + } + } + Instruction::Load { ptr, .. } => { + self.collect_constants_from_operand(ptr, constants, ces, seen); + } + Instruction::Store { value, ptr, .. 
} => { + self.collect_constants_from_operand(value, constants, ces, seen); + self.collect_constants_from_operand(ptr, constants, ces, seen); + } + Instruction::Select { + cond, + true_val, + false_val, + .. + } => { + self.collect_constants_from_operand(cond, constants, ces, seen); + self.collect_constants_from_operand(true_val, constants, ces, seen); + self.collect_constants_from_operand(false_val, constants, ces, seen); + } + Instruction::Switch { value, .. } => { + self.collect_constants_from_operand(value, constants, ces, seen); + } + Instruction::GetElementPtr { ptr, indices, .. } => { + self.collect_constants_from_operand(ptr, constants, ces, seen); + for idx in indices { + self.collect_constants_from_operand(idx, constants, ces, seen); + } + } + Instruction::Ret(None) + | Instruction::Jump { .. } + | Instruction::Alloca { .. } + | Instruction::Unreachable => {} + } + } + + fn collect_constants_from_operand( + &self, + op: &Operand, + constants: &mut Vec<(Type, Constant)>, + ces: &mut Vec<(Type, PendingCE)>, + seen: &mut FxHashMap<(String, String), bool>, + ) { + match op { + Operand::IntConst(ty, val) => { + let key = (ty.to_string(), format!("int:{val}")); + if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(key) { + e.insert(true); + constants.push((ty.clone(), Constant::Int(*val))); + } + } + Operand::FloatConst(ty, val) => { + let bits = ty.encode_float_bits(*val).unwrap_or_else(|| val.to_bits()); + let key = (ty.to_string(), format!("float:{bits}")); + if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(key) { + e.insert(true); + constants.push((ty.clone(), Constant::float(ty.clone(), *val))); + } + } + Operand::NullPtr => { + let null_ptr_ty = self.emit_target.default_pointer_type(); + let key = (null_ptr_ty.to_string(), "null".to_string()); + if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(key) { + e.insert(true); + constants.push((null_ptr_ty, Constant::Null)); + } + } + Operand::IntToPtr(val, ty) => { + // 
Collect the integer constant for the CE's source operand + let int_ty = Type::Integer(64); + let key = (int_ty.to_string(), format!("int:{val}")); + if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(key) { + e.insert(true); + constants.push((int_ty, Constant::Int(*val))); + } + // Track CE for pass 2 emission + let ce_key = ("ce".to_string(), format!("inttoptr:{val}:{ty}")); + if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(ce_key) { + e.insert(true); + ces.push((ty.clone(), PendingCE::IntToPtr { val: *val })); + } + } + Operand::GetElementPtr { + ty: source_ty, + ptr: ptr_name, + indices, + .. + } => { + // Collect index constants for the CE's source operands + for idx in indices { + self.collect_constants_from_operand(idx, constants, ces, seen); + } + // Track CE for pass 2 emission + let idx_desc = indices + .iter() + .map(|i| format!("{i:?}")) + .collect::>() + .join(","); + let ce_key = ( + "ce".to_string(), + format!("gep:{source_ty}:{ptr_name}:{idx_desc}"), + ); + if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(ce_key) { + e.insert(true); + ces.push(( + Type::Ptr, + PendingCE::InboundsGep { + source_ty: source_ty.clone(), + ptr_name: ptr_name.clone(), + indices: indices.clone(), + }, + )); + } + } + Operand::LocalRef(_) | Operand::TypedLocalRef(_, _) | Operand::GlobalRef(_) => {} + } + } + + fn write_constants_block( + &mut self, + constants: &[(Type, Constant)], + pending_ces: &[(Type, PendingCE)], + local_value_ids: &mut FxHashMap, + next_value_id: &mut u32, + ) -> Result<(), WriteError> { + self.writer + .enter_subblock(CONSTANTS_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH); + + let mut current_type: Option = None; + + // Pass 1: Emit regular constants + for (ty, cst) in constants { + let ty_id = self.type_table.get_or_insert(ty); + + // Emit SETTYPE if type changed + if current_type != Some(ty_id) { + self.writer + .emit_record(CST_CODE_SETTYPE, &[u64::from(ty_id)], ABBREV_WIDTH); + current_type = Some(ty_id); + 
} + + match cst { + Constant::Int(val) => { + let encoded = sign_rotate(*val); + self.writer + .emit_record(CST_CODE_INTEGER, &[encoded], ABBREV_WIDTH); + let name = format!("__const_int_{ty}_{val}"); + local_value_ids.insert(name, *next_value_id); + } + Constant::Float(float_ty, val) => { + let bits = float_ty.encode_float_bits(*val).ok_or_else(|| { + WriteError::InvalidFloatingConstant { + ty: float_ty.clone(), + value: *val, + } + })?; + self.writer + .emit_record(CST_CODE_FLOAT, &[bits], ABBREV_WIDTH); + let name = format!("__const_float_{float_ty}_{bits}"); + local_value_ids.insert(name, *next_value_id); + } + Constant::Null => { + self.writer.emit_record(CST_CODE_NULL, &[], ABBREV_WIDTH); + let name = format!("__const_null_{ty}"); + local_value_ids.insert(name, *next_value_id); + } + Constant::CString(_) => { + // CString constants are handled as global initializers, not in constant blocks + } + } + *next_value_id += 1; + } + + // Pass 2: Emit constant expressions with absolute value IDs + for (result_ty, ce) in pending_ces { + let result_ty_id = self.type_table.get_or_insert(result_ty); + if current_type != Some(result_ty_id) { + self.writer + .emit_record(CST_CODE_SETTYPE, &[u64::from(result_ty_id)], ABBREV_WIDTH); + current_type = Some(result_ty_id); + } + + match ce { + PendingCE::IntToPtr { val } => { + let src_type = Type::Integer(64); + let src_type_id = self.type_table.get_or_insert(&src_type); + let src_key = format!("__const_int_{src_type}_{val}"); + let src_value_id = local_value_ids.get(&src_key).copied().ok_or_else(|| { + WriteError::unresolved_operand( + "inttoptr constant expression source", + &Operand::IntConst(src_type.clone(), *val), + ) + })?; + self.writer.emit_record( + CST_CODE_CE_CAST, + &[10, u64::from(src_type_id), u64::from(src_value_id)], + ABBREV_WIDTH, + ); + let ce_key = format!("__ce_inttoptr_{val}_{result_ty}"); + local_value_ids.insert(ce_key, *next_value_id); + } + PendingCE::InboundsGep { + source_ty, + ptr_name, + indices, 
+ } => { + let source_ty_id = self.type_table.get_or_insert(source_ty); + let gep_ptr_ty = self.emit_target.pointer_type_for_pointee(source_ty); + let ptr_type_id = self.type_table.get_or_insert(&gep_ptr_ty); + let ptr_value_id = + self.global_value_ids + .get(ptr_name) + .copied() + .ok_or_else(|| { + WriteError::unresolved_operand( + "getelementptr constant expression base pointer", + &Operand::GlobalRef(ptr_name.clone()), + ) + })?; + // Record format (odd length → first is pointee type): + // [pointee_type_id, ptr_type_id, ptr_value_id, idx_type, idx_val, ...] + let mut record = vec![ + u64::from(source_ty_id), + u64::from(ptr_type_id), + u64::from(ptr_value_id), + ]; + for idx in indices { + if let Operand::IntConst(idx_ty, idx_val) = idx { + let idx_type_id = self.type_table.get_or_insert(idx_ty); + let idx_key = format!("__const_int_{idx_ty}_{idx_val}"); + let idx_value_id = + local_value_ids.get(&idx_key).copied().ok_or_else(|| { + WriteError::unresolved_operand( + "getelementptr constant expression index", + idx, + ) + })?; + record.push(u64::from(idx_type_id)); + record.push(u64::from(idx_value_id)); + } + } + self.writer + .emit_record(CST_CODE_CE_INBOUNDS_GEP, &record, ABBREV_WIDTH); + let ce_key = Self::gep_ce_key(source_ty, ptr_name, indices); + local_value_ids.insert(ce_key, *next_value_id); + } + } + *next_value_id += 1; + } + + self.writer.exit_block(ABBREV_WIDTH); + + Ok(()) + } + + #[allow(clippy::too_many_lines)] + fn write_instruction( + &mut self, + instr: &Instruction, + local_value_ids: &mut FxHashMap, + next_value_id: &mut u32, + bb_map: &FxHashMap, + ) -> Result<(), WriteError> { + match instr { + Instruction::Ret(None) => { + self.writer + .emit_record(FUNC_CODE_INST_RET, &[], ABBREV_WIDTH); + } + Instruction::Ret(Some(op)) => { + let val = self.resolve_operand( + "return instruction", + op, + local_value_ids, + *next_value_id, + )?; + self.writer + .emit_record(FUNC_CODE_INST_RET, &[val], ABBREV_WIDTH); + } + Instruction::Jump { dest } 
=> { + let bb_id = bb_map + .get(dest) + .copied() + .ok_or_else(|| WriteError::missing_basic_block("unconditional branch", dest))?; + self.writer + .emit_record(FUNC_CODE_INST_BR, &[u64::from(bb_id)], ABBREV_WIDTH); + } + Instruction::Br { + cond, + true_dest, + false_dest, + .. + } => { + let true_id = bb_map.get(true_dest).copied().ok_or_else(|| { + WriteError::missing_basic_block( + "conditional branch true destination", + true_dest, + ) + })?; + let false_id = bb_map.get(false_dest).copied().ok_or_else(|| { + WriteError::missing_basic_block( + "conditional branch false destination", + false_dest, + ) + })?; + let cond_val = self.resolve_operand( + "conditional branch condition", + cond, + local_value_ids, + *next_value_id, + )?; + self.writer.emit_record( + FUNC_CODE_INST_BR, + &[u64::from(true_id), u64::from(false_id), cond_val], + ABBREV_WIDTH, + ); + } + Instruction::BinOp { + op, + lhs, + rhs, + result, + .. + } => { + let lhs_val = self.resolve_operand( + "binary operation lhs", + lhs, + local_value_ids, + *next_value_id, + )?; + let rhs_val = self.resolve_operand( + "binary operation rhs", + rhs, + local_value_ids, + *next_value_id, + )?; + let opcode = binop_to_opcode(op); + self.writer.emit_record( + FUNC_CODE_INST_BINOP, + &[lhs_val, rhs_val, opcode], + ABBREV_WIDTH, + ); + local_value_ids.insert(result.clone(), *next_value_id); + *next_value_id += 1; + } + Instruction::ICmp { + pred, + lhs, + rhs, + result, + .. + } => { + let lhs_val = self.resolve_operand( + "integer comparison lhs", + lhs, + local_value_ids, + *next_value_id, + )?; + let rhs_val = self.resolve_operand( + "integer comparison rhs", + rhs, + local_value_ids, + *next_value_id, + )?; + let pred_code = icmp_predicate_code(pred); + self.writer.emit_record( + FUNC_CODE_INST_CMP2, + &[lhs_val, rhs_val, pred_code], + ABBREV_WIDTH, + ); + local_value_ids.insert(result.clone(), *next_value_id); + *next_value_id += 1; + } + Instruction::FCmp { + pred, + lhs, + rhs, + result, + .. 
+ } => { + let lhs_val = self.resolve_operand( + "floating comparison lhs", + lhs, + local_value_ids, + *next_value_id, + )?; + let rhs_val = self.resolve_operand( + "floating comparison rhs", + rhs, + local_value_ids, + *next_value_id, + )?; + let pred_code = fcmp_predicate_code(pred); + self.writer.emit_record( + FUNC_CODE_INST_CMP2, + &[lhs_val, rhs_val, pred_code], + ABBREV_WIDTH, + ); + local_value_ids.insert(result.clone(), *next_value_id); + *next_value_id += 1; + } + Instruction::Cast { + op: cast_op, + to_ty, + value, + result, + .. + } => { + let val = + self.resolve_operand("cast operand", value, local_value_ids, *next_value_id)?; + let to_ty_id = self.type_table.get_or_insert(to_ty); + let cast_opcode = cast_to_opcode(cast_op); + self.writer.emit_record( + FUNC_CODE_INST_CAST, + &[val, u64::from(to_ty_id), cast_opcode], + ABBREV_WIDTH, + ); + local_value_ids.insert(result.clone(), *next_value_id); + *next_value_id += 1; + } + Instruction::Call { + callee, + args, + result, + attr_refs, + .. + } => { + let call_context = format!("call instruction @{callee}"); + let func_ty = self.get_function_type(callee)?; + let callee_operand = Operand::GlobalRef(callee.clone()); + let callee_val = self.resolve_operand( + call_context.clone(), + &callee_operand, + local_value_ids, + *next_value_id, + )?; + let func_ty_id = self.type_table.get_or_insert(&func_ty); + let packed_call_cc_info = if self.emit_target == QirEmitTarget::QirV2Opaque { + CALL_EXPLICIT_TYPE_FLAG + } else { + 0 + }; + let paramattr = if attr_refs.is_empty() { + 0 + } else { + self.resolve_attr_list_index(attr_refs, false, call_context.clone())? + }; + + // Opaque-pointer CALL records must set the explicit function-type flag + // in the packed cc-info operand so external LLVM decodes the callee slot + // using the modern layout. 
+ let mut values = vec![ + paramattr, + packed_call_cc_info, + u64::from(func_ty_id), // function type + callee_val, // callee value ID + ]; + for (_, op) in args { + let arg_val = self.resolve_operand( + call_context.clone(), + op, + local_value_ids, + *next_value_id, + )?; + values.push(arg_val); + } + self.writer + .emit_record(FUNC_CODE_INST_CALL, &values, ABBREV_WIDTH); + + if let Some(res) = result { + local_value_ids.insert(res.clone(), *next_value_id); + *next_value_id += 1; + } + } + Instruction::Phi { + ty, + incoming, + result, + } => { + let ty_id = self.type_table.get_or_insert(ty); + let mut values = vec![u64::from(ty_id)]; + for (op, bb_name) in incoming { + let val = self.resolve_phi_operand( + "phi incoming value", + op, + local_value_ids, + *next_value_id, + )?; + let bb_id = bb_map.get(bb_name).copied().ok_or_else(|| { + WriteError::missing_basic_block("phi incoming edge", bb_name) + })?; + values.push(val); + values.push(u64::from(bb_id)); + } + self.writer + .emit_record(FUNC_CODE_INST_PHI, &values, ABBREV_WIDTH); + local_value_ids.insert(result.clone(), *next_value_id); + *next_value_id += 1; + } + Instruction::Alloca { ty, result } => { + let ty_id = self.type_table.get_or_insert(ty); + let alloca_ptr_ty = self.emit_target.pointer_type_for_pointee(ty); + let ptr_ty_id = self.type_table.get_or_insert(&alloca_ptr_ty); + // ALLOCA: [instty, opty, op, align] + let i32_ty_id = self.type_table.get_or_insert(&Type::Integer(32)); + self.writer.emit_record( + FUNC_CODE_INST_ALLOCA, + &[ + u64::from(ty_id), + u64::from(i32_ty_id), + 0, + u64::from(ptr_ty_id), + ], + ABBREV_WIDTH, + ); + local_value_ids.insert(result.clone(), *next_value_id); + *next_value_id += 1; + } + Instruction::Load { + ty, ptr, result, .. 
+            } => {
+                let ptr_val =
+                    self.resolve_operand("load pointer", ptr, local_value_ids, *next_value_id)?;
+                let ty_id = self.type_table.get_or_insert(ty);
+                // LOAD (explicit-type form): [op, ty, align, vol]
+                self.writer.emit_record(
+                    FUNC_CODE_INST_LOAD,
+                    &[ptr_val, u64::from(ty_id), 0, 0],
+                    ABBREV_WIDTH,
+                );
+                local_value_ids.insert(result.clone(), *next_value_id);
+                *next_value_id += 1;
+            }
+            Instruction::Store { value, ptr, .. } => {
+                let ptr_val =
+                    self.resolve_operand("store pointer", ptr, local_value_ids, *next_value_id)?;
+                let val =
+                    self.resolve_operand("store value", value, local_value_ids, *next_value_id)?;
+                // STORE: [ptr, val, align, vol]
+                self.writer
+                    .emit_record(FUNC_CODE_INST_STORE, &[ptr_val, val, 0, 0], ABBREV_WIDTH);
+            }
+            Instruction::Select {
+                cond,
+                true_val,
+                false_val,
+                result,
+                ..
+            } => {
+                let true_v = self.resolve_operand(
+                    "select true value",
+                    true_val,
+                    local_value_ids,
+                    *next_value_id,
+                )?;
+                let false_v = self.resolve_operand(
+                    "select false value",
+                    false_val,
+                    local_value_ids,
+                    *next_value_id,
+                )?;
+                let cond_v = self.resolve_operand(
+                    "select condition",
+                    cond,
+                    local_value_ids,
+                    *next_value_id,
+                )?;
+                self.writer.emit_record(
+                    FUNC_CODE_INST_SELECT,
+                    &[true_v, false_v, cond_v],
+                    ABBREV_WIDTH,
+                );
+                local_value_ids.insert(result.clone(), *next_value_id);
+                *next_value_id += 1;
+            }
+            Instruction::Switch {
+                ty,
+                value,
+                default_dest,
+                cases,
+            } => {
+                let ty_id = self.type_table.get_or_insert(ty);
+                let val = self.resolve_operand(
+                    "switch selector",
+                    value,
+                    local_value_ids,
+                    *next_value_id,
+                )?;
+                let default_id = bb_map.get(default_dest).copied().ok_or_else(|| {
+                    WriteError::missing_basic_block("switch default destination", default_dest)
+                })?;
+                let mut values = vec![u64::from(ty_id), val, u64::from(default_id)];
+                for (case_val, dest) in cases {
+                    values.push(sign_rotate(*case_val));
+                    let dest_id = bb_map.get(dest).copied().ok_or_else(|| {
+                        WriteError::missing_basic_block("switch 
case destination", dest) + })?; + values.push(u64::from(dest_id)); + } + self.writer + .emit_record(FUNC_CODE_INST_SWITCH, &values, ABBREV_WIDTH); + } + Instruction::Unreachable => { + self.writer + .emit_record(FUNC_CODE_INST_UNREACHABLE, &[], ABBREV_WIDTH); + } + Instruction::GetElementPtr { + inbounds, + pointee_ty, + ptr, + indices, + result, + .. + } => { + let inbounds_flag = u64::from(*inbounds); + let pointee_type_id = self.type_table.get_or_insert(pointee_ty); + let ptr_val = self.resolve_operand( + "getelementptr base pointer", + ptr, + local_value_ids, + *next_value_id, + )?; + let mut values = vec![inbounds_flag, u64::from(pointee_type_id), ptr_val]; + for idx in indices { + let idx_val = self.resolve_operand( + "getelementptr index", + idx, + local_value_ids, + *next_value_id, + )?; + values.push(idx_val); + } + self.writer + .emit_record(FUNC_CODE_INST_GEP, &values, ABBREV_WIDTH); + local_value_ids.insert(result.clone(), *next_value_id); + *next_value_id += 1; + } + } + + Ok(()) + } + + fn resolve_operand( + &self, + context: impl Into, + op: &Operand, + local_value_ids: &FxHashMap, + current_value_id: u32, + ) -> Result { + let context = context.into(); + let id = self.resolve_operand_value_id(&context, local_value_ids, op)?; + current_value_id + .checked_sub(id) + .map(u64::from) + .ok_or_else(|| WriteError::unresolved_operand(context, op)) + } + + fn resolve_phi_operand( + &self, + context: impl Into, + op: &Operand, + local_value_ids: &FxHashMap, + current_value_id: u32, + ) -> Result { + let context = context.into(); + let id = self.resolve_operand_value_id(&context, local_value_ids, op)?; + Ok(sign_rotate(i64::from(current_value_id) - i64::from(id))) + } + + fn resolve_operand_value_id( + &self, + context: &str, + local_value_ids: &FxHashMap, + op: &Operand, + ) -> Result { + let value_id = match op { + Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => { + local_value_ids.get(name).copied() + } + Operand::GlobalRef(name) => 
self.global_value_ids.get(name).copied(), + Operand::IntConst(ty, val) => { + let key = format!("__const_int_{ty}_{val}"); + local_value_ids.get(&key).copied() + } + Operand::FloatConst(ty, val) => { + let bits = ty.encode_float_bits(*val).ok_or_else(|| { + WriteError::InvalidFloatingConstant { + ty: ty.clone(), + value: *val, + } + })?; + let key = format!("__const_float_{ty}_{bits}"); + local_value_ids.get(&key).copied() + } + Operand::NullPtr => { + let key = format!("__const_null_{}", self.emit_target.default_pointer_type()); + local_value_ids.get(&key).copied() + } + Operand::IntToPtr(val, ty) => { + let key = format!("__ce_inttoptr_{val}_{ty}"); + local_value_ids.get(&key).copied() + } + Operand::GetElementPtr { + ty: source_ty, + ptr: ptr_name, + indices, + .. + } => { + let key = Self::gep_ce_key(source_ty, ptr_name, indices); + local_value_ids.get(&key).copied() + } + }; + + value_id.ok_or_else(|| WriteError::unresolved_operand(context.to_string(), op)) + } + + fn gep_ce_key(source_ty: &Type, ptr_name: &str, indices: &[Operand]) -> String { + let idx_desc = indices + .iter() + .map(|i| match i { + Operand::IntConst(ty, val) => format!("{ty}:{val}"), + _ => "?".to_string(), + }) + .collect::>() + .join(","); + format!("__ce_gep_{source_ty}_{ptr_name}_{idx_desc}") + } + + fn get_function_type(&self, name: &str) -> Result { + for f in &self.module.functions { + if f.name == name { + return Ok(Type::Function( + Box::new(f.return_type.clone()), + f.params.iter().map(|p| p.ty.clone()).collect(), + )); + } + } + Err(WriteError::UnknownCallee { + callee: name.to_string(), + }) + } + + fn resolve_attr_list_index( + &self, + attr_refs: &[u32], + normalize: bool, + context: String, + ) -> Result { + let mut key = attr_refs.to_vec(); + if normalize { + key.sort_unstable(); + } + + self.attr_list_table + .iter() + .position(|entry| *entry == key) + .map(|idx| (idx + 1) as u64) + .ok_or_else(|| WriteError::MissingAttributeList { + context, + attr_refs: 
attr_refs.to_vec(), + }) + } + + fn write_value_symtab(&mut self) { + if self.uses_modern_function_naming_container() { + self.patch_module_vst_offset(); + + self.writer + .enter_subblock(VALUE_SYMTAB_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH); + + for global in &self.module.globals { + let Some(&value_id) = self.global_value_ids.get(&global.name) else { + continue; + }; + + let mut values = vec![u64::from(value_id)]; + values.extend(global.name.bytes().map(u64::from)); + self.writer + .emit_record(VST_CODE_ENTRY, &values, ABBREV_WIDTH); + } + + for function in &self.module.functions { + if function.is_declaration { + continue; + } + + let Some(&value_id) = self.global_value_ids.get(&function.name) else { + continue; + }; + let Some(&function_word_offset) = self.function_word_offsets.get(&value_id) else { + continue; + }; + + let mut values = vec![u64::from(value_id), u64::from(function_word_offset)]; + values.extend(function.name.bytes().map(u64::from)); + self.writer + .emit_record(VST_CODE_FNENTRY, &values, ABBREV_WIDTH); + } + + self.writer.exit_block(ABBREV_WIDTH); + return; + } + + let mut entries: Vec<(u32, String)> = Vec::new(); + + // Add global names + for g in &self.module.globals { + if let Some(&id) = self.global_value_ids.get(&g.name) { + entries.push((id, g.name.clone())); + } + } + for f in &self.module.functions { + if let Some(&id) = self.global_value_ids.get(&f.name) { + entries.push((id, f.name.clone())); + } + } + + if entries.is_empty() { + return; + } + + self.writer + .enter_subblock(VALUE_SYMTAB_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH); + + for (id, name) in &entries { + let mut values: Vec = vec![u64::from(*id)]; + for b in name.bytes() { + values.push(u64::from(b)); + } + self.writer + .emit_record(VST_CODE_ENTRY, &values, ABBREV_WIDTH); + } + + self.writer.exit_block(ABBREV_WIDTH); + } + + fn write_function_vst( + &mut self, + func: &Function, + bb_map: &FxHashMap, + named_local_entries: &[(u32, String)], + ) { + // Write basic block entries in 
function VST
+        let has_named_bbs = func.basic_blocks.iter().any(|bb| !bb.name.is_empty());
+        if named_local_entries.is_empty() && !has_named_bbs {
+            return;
+        }
+
+        self.writer
+            .enter_subblock(VALUE_SYMTAB_BLOCK_ID, ABBREV_WIDTH, ABBREV_WIDTH);
+
+        for (local_id, name) in named_local_entries {
+            let mut values: Vec<u64> = vec![u64::from(*local_id)];
+            for b in name.bytes() {
+                values.push(u64::from(b));
+            }
+            self.writer
+                .emit_record(VST_CODE_ENTRY, &values, ABBREV_WIDTH);
+        }
+
+        for bb in &func.basic_blocks {
+            if !bb.name.is_empty()
+                && let Some(&bb_id) = bb_map.get(&bb.name)
+            {
+                let mut values: Vec<u64> = vec![u64::from(bb_id)];
+                for b in bb.name.bytes() {
+                    values.push(u64::from(b));
+                }
+                self.writer
+                    .emit_record(VST_CODE_BBENTRY, &values, ABBREV_WIDTH);
+            }
+        }
+
+        self.writer.exit_block(ABBREV_WIDTH);
+    }
+}
+
+fn infer_emit_target(module: &Module) -> QirEmitTarget {
+    if let Some(MetadataValue::Int(_, major_version)) = module.get_flag(QIR_MAJOR_VERSION_KEY) {
+        return match *major_version {
+            1 => QirEmitTarget::QirV1Typed,
+            _ => QirEmitTarget::QirV2Opaque,
+        };
+    }
+
+    if module_contains_typed_pointers(module) {
+        QirEmitTarget::QirV1Typed
+    } else {
+        QirEmitTarget::QirV2Opaque
+    }
+}
+
+fn module_contains_typed_pointers(module: &Module) -> bool {
+    module
+        .globals
+        .iter()
+        .any(|global| type_contains_typed_pointers(&global.ty))
+        || module
+            .functions
+            .iter()
+            .any(function_contains_typed_pointers)
+}
+
+fn function_contains_typed_pointers(function: &Function) -> bool {
+    type_contains_typed_pointers(&function.return_type)
+        || function
+            .params
+            .iter()
+            .any(|param| type_contains_typed_pointers(&param.ty))
+        || function
+            .basic_blocks
+            .iter()
+            .flat_map(|block| block.instructions.iter())
+            .any(instruction_contains_typed_pointers)
+}
+
+fn instruction_contains_typed_pointers(instruction: &Instruction) -> bool {
+    match instruction {
+        Instruction::Ret(Some(operand)) => operand_contains_typed_pointers(operand),
+        Instruction::Ret(None) | 
Instruction::Jump { .. } | Instruction::Unreachable => false, + Instruction::Br { cond_ty, cond, .. } => { + type_contains_typed_pointers(cond_ty) || operand_contains_typed_pointers(cond) + } + Instruction::BinOp { ty, lhs, rhs, .. } + | Instruction::ICmp { ty, lhs, rhs, .. } + | Instruction::FCmp { ty, lhs, rhs, .. } => { + type_contains_typed_pointers(ty) + || operand_contains_typed_pointers(lhs) + || operand_contains_typed_pointers(rhs) + } + Instruction::Cast { + from_ty, + to_ty, + value, + .. + } => { + type_contains_typed_pointers(from_ty) + || type_contains_typed_pointers(to_ty) + || operand_contains_typed_pointers(value) + } + Instruction::Call { + return_ty, args, .. + } => { + return_ty.as_ref().is_some_and(type_contains_typed_pointers) + || args.iter().any(|(ty, operand)| { + type_contains_typed_pointers(ty) || operand_contains_typed_pointers(operand) + }) + } + Instruction::Phi { ty, incoming, .. } => { + type_contains_typed_pointers(ty) + || incoming + .iter() + .any(|(operand, _)| operand_contains_typed_pointers(operand)) + } + Instruction::Alloca { ty, .. } => type_contains_typed_pointers(ty), + Instruction::Load { + ty, ptr_ty, ptr, .. + } => { + type_contains_typed_pointers(ty) + || type_contains_typed_pointers(ptr_ty) + || operand_contains_typed_pointers(ptr) + } + Instruction::Store { + ty, + value, + ptr_ty, + ptr, + } => { + type_contains_typed_pointers(ty) + || operand_contains_typed_pointers(value) + || type_contains_typed_pointers(ptr_ty) + || operand_contains_typed_pointers(ptr) + } + Instruction::Select { + cond, + true_val, + false_val, + ty, + .. + } => { + operand_contains_typed_pointers(cond) + || operand_contains_typed_pointers(true_val) + || operand_contains_typed_pointers(false_val) + || type_contains_typed_pointers(ty) + } + Instruction::Switch { ty, value, .. } => { + type_contains_typed_pointers(ty) || operand_contains_typed_pointers(value) + } + Instruction::GetElementPtr { + pointee_ty, + ptr_ty, + ptr, + indices, + .. 
+ } => { + type_contains_typed_pointers(pointee_ty) + || type_contains_typed_pointers(ptr_ty) + || operand_contains_typed_pointers(ptr) + || indices.iter().any(operand_contains_typed_pointers) + } + } +} + +fn operand_contains_typed_pointers(operand: &Operand) -> bool { + match operand { + Operand::LocalRef(_) | Operand::GlobalRef(_) | Operand::NullPtr => false, + Operand::TypedLocalRef(_, ty) | Operand::IntConst(ty, _) | Operand::FloatConst(ty, _) => { + type_contains_typed_pointers(ty) + } + Operand::IntToPtr(_, ty) => type_contains_typed_pointers(ty), + Operand::GetElementPtr { + ty, + ptr_ty, + indices, + .. + } => { + type_contains_typed_pointers(ty) + || type_contains_typed_pointers(ptr_ty) + || indices.iter().any(operand_contains_typed_pointers) + } + } +} + +fn type_contains_typed_pointers(ty: &Type) -> bool { + match ty { + Type::NamedPtr(_) | Type::TypedPtr(_) => true, + Type::Array(_, element) => type_contains_typed_pointers(element), + Type::Function(result, params) => { + type_contains_typed_pointers(result) || params.iter().any(type_contains_typed_pointers) + } + Type::Void + | Type::Integer(_) + | Type::Half + | Type::Float + | Type::Double + | Type::Label + | Type::Ptr + | Type::Named(_) => false, + } +} + +fn sign_rotate(val: i64) -> u64 { + let magnitude = val.unsigned_abs(); + if val >= 0 { + magnitude << 1 + } else { + (magnitude << 1) | 1 + } +} + +fn binop_to_opcode(op: &BinOpKind) -> u64 { + match op { + BinOpKind::Add | BinOpKind::Fadd => 0, + BinOpKind::Sub | BinOpKind::Fsub => 1, + BinOpKind::Mul | BinOpKind::Fmul => 2, + BinOpKind::Udiv => 3, + BinOpKind::Sdiv | BinOpKind::Fdiv => 4, + BinOpKind::Urem => 5, + BinOpKind::Srem => 6, + BinOpKind::Shl => 7, + BinOpKind::Lshr => 8, + BinOpKind::Ashr => 9, + BinOpKind::And => 10, + BinOpKind::Or => 11, + BinOpKind::Xor => 12, + } +} + +fn icmp_predicate_code(pred: &IntPredicate) -> u64 { + match pred { + IntPredicate::Eq => 32, + IntPredicate::Ne => 33, + IntPredicate::Ugt => 34, + 
IntPredicate::Uge => 35, + IntPredicate::Ult => 36, + IntPredicate::Ule => 37, + IntPredicate::Sgt => 38, + IntPredicate::Sge => 39, + IntPredicate::Slt => 40, + IntPredicate::Sle => 41, + } +} + +fn fcmp_predicate_code(pred: &FloatPredicate) -> u64 { + match pred { + FloatPredicate::Oeq => 1, + FloatPredicate::Ogt => 2, + FloatPredicate::Oge => 3, + FloatPredicate::Olt => 4, + FloatPredicate::Ole => 5, + FloatPredicate::One => 6, + FloatPredicate::Ord => 7, + FloatPredicate::Uno => 8, + FloatPredicate::Ueq => 9, + FloatPredicate::Ugt => 10, + FloatPredicate::Uge => 11, + FloatPredicate::Ult => 12, + FloatPredicate::Ule => 13, + FloatPredicate::Une => 14, + } +} + +fn cast_to_opcode(op: &CastKind) -> u64 { + match op { + CastKind::Trunc => 0, + CastKind::Zext => 1, + CastKind::Sext => 2, + CastKind::FpTrunc => 4, + CastKind::FpExt => 5, + CastKind::Sitofp => 6, + CastKind::Fptosi => 7, + CastKind::PtrToInt => 9, + CastKind::IntToPtr => 10, + CastKind::Bitcast => 11, + } +} diff --git a/source/compiler/qsc_llvm/src/bitcode/writer/tests.rs b/source/compiler/qsc_llvm/src/bitcode/writer/tests.rs new file mode 100644 index 0000000000..ef0ae1ad55 --- /dev/null +++ b/source/compiler/qsc_llvm/src/bitcode/writer/tests.rs @@ -0,0 +1,448 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+use super::super::bitstream::BitstreamReader;
+use super::*;
+use crate::model::test_helpers::*;
+use crate::model::{Attribute, AttributeGroup, BasicBlock, Param, StructType};
+use crate::parse_bitcode;
+use crate::qir::QirEmitTarget;
+
+fn round_trip_module(module: &Module) -> Module {
+    let bc = write_bitcode(module);
+    parse_bitcode(&bc).expect("should parse round-tripped bitcode")
+}
+
+fn simple_function_module() -> Module {
+    Module {
+        source_filename: None,
+        target_datalayout: None,
+        target_triple: None,
+        struct_types: Vec::new(),
+        globals: Vec::new(),
+        functions: vec![
+            Function {
+                name: "callee".to_string(),
+                return_type: Type::Void,
+                params: Vec::new(),
+                is_declaration: true,
+                attribute_group_refs: Vec::new(),
+                basic_blocks: Vec::new(),
+            },
+            Function {
+                name: "main".to_string(),
+                return_type: Type::Void,
+                params: Vec::new(),
+                is_declaration: false,
+                attribute_group_refs: Vec::new(),
+                basic_blocks: vec![BasicBlock {
+                    name: "entry".to_string(),
+                    instructions: vec![
+                        Instruction::Call {
+                            return_ty: None,
+                            callee: "callee".to_string(),
+                            args: Vec::new(),
+                            result: None,
+                            attr_refs: Vec::new(),
+                        },
+                        Instruction::Ret(None),
+                    ],
+                }],
+            },
+        ],
+        attribute_groups: Vec::new(),
+        named_metadata: Vec::new(),
+        metadata_nodes: Vec::new(),
+    }
+}
+
+fn scan_module_version_and_top_level_blocks(bitcode: &[u8]) -> (Option<u64>, Vec<u64>) {
+    let mut reader = BitstreamReader::new(bitcode);
+    assert_eq!(reader.read_bits(8), 0x42);
+    assert_eq!(reader.read_bits(8), 0x43);
+    assert_eq!(reader.read_bits(8), 0xC0);
+    assert_eq!(reader.read_bits(8), 0xDE);
+
+    let mut module_version = None;
+    let mut top_level_blocks = Vec::new();
+
+    while !reader.at_end() {
+        let abbrev_id = reader.read_abbrev_id(TOP_LEVEL_ABBREV_WIDTH);
+        match abbrev_id {
+            0 => reader.align32(),
+            1 => {
+                let (block_id, new_abbrev_width, block_len_words) = reader.enter_subblock();
+                top_level_blocks.push(block_id);
+
+                if block_id != MODULE_BLOCK_ID {
+                    
reader.skip_block(block_len_words); + continue; + } + + reader.push_block_scope(MODULE_BLOCK_ID); + loop { + let module_abbrev_id = reader.read_abbrev_id(new_abbrev_width); + match module_abbrev_id { + 0 => { + reader.align32(); + break; + } + 1 => { + let (_, _, nested_len_words) = reader.enter_subblock(); + reader.skip_block(nested_len_words); + } + 2 => reader + .read_define_abbrev() + .expect("module DEFINE_ABBREV should decode"), + 3 => { + let (code, values) = reader.read_unabbrev_record(); + if code == MODULE_CODE_VERSION { + module_version = values.first().copied(); + } + } + id => { + let (code, values) = reader + .read_abbreviated_record(id) + .expect("module abbreviated record should decode"); + if code == MODULE_CODE_VERSION { + module_version = values.first().copied(); + } + } + } + } + reader.pop_block_scope(); + } + 2 => reader + .read_define_abbrev() + .expect("top-level DEFINE_ABBREV should decode"), + 3 => { + let _ = reader.read_unabbrev_record(); + } + id => { + let _ = reader + .read_abbreviated_record(id) + .expect("top-level abbreviated record should decode"); + } + } + } + + (module_version, top_level_blocks) +} + +#[test] +fn magic_bytes_present() { + let m = empty_module(); + let bc = write_bitcode(&m); + assert!(bc.len() >= 4); + assert_eq!(&bc[0..4], &[0x42, 0x43, 0xC0, 0xDE]); +} + +#[test] +fn empty_module_produces_valid_bitcode() { + let m = empty_module(); + let bc = write_bitcode(&m); + assert_eq!(&bc[0..4], &[0x42, 0x43, 0xC0, 0xDE]); + // Must have content beyond just the magic bytes + assert!(bc.len() > 4); + // Length must be 4-byte aligned (bitstream alignment) + assert_eq!(bc.len() % 4, 0); +} + +#[test] +fn sign_rotate_matches_llvm_dense_vbr_contract() { + assert_eq!(sign_rotate(0), 0); + assert_eq!(sign_rotate(1), 2); + assert_eq!(sign_rotate(2), 4); + assert_eq!(sign_rotate(-1), 3); + assert_eq!(sign_rotate(-2), 5); + assert_eq!(sign_rotate(i64::MIN), 1); +} + +#[test] +fn module_with_declaration_produces_bitcode() { + let 
m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "__quantum__qis__h__body".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let bc = write_bitcode(&m); + assert_eq!(&bc[0..4], &[0x42, 0x43, 0xC0, 0xDE]); + assert!(bc.len() > 4); + assert_eq!(bc.len() % 4, 0); +} + +#[test] +fn typed_target_keeps_legacy_module_layout_without_top_level_strtab() { + let bitcode = write_bitcode_for_target(&simple_function_module(), QirEmitTarget::QirV1Typed); + let (module_version, top_level_blocks) = scan_module_version_and_top_level_blocks(&bitcode); + + assert_eq!(module_version, Some(1)); + assert!(!top_level_blocks.contains(&STRTAB_BLOCK_ID)); +} + +#[test] +fn opaque_target_emits_module_v2_and_top_level_strtab() { + let bitcode = write_bitcode_for_target(&simple_function_module(), QirEmitTarget::QirV2Opaque); + let (module_version, top_level_blocks) = scan_module_version_and_top_level_blocks(&bitcode); + + assert_eq!(module_version, Some(2)); + assert!(top_level_blocks.contains(&STRTAB_BLOCK_ID)); +} + +#[test] +fn bitcode_round_trip_preserves_named_ptr_argument_shapes() { + let qubit_ty = Type::NamedPtr("Qubit".to_string()); + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: vec![StructType { + name: "Qubit".to_string(), + is_opaque: true, + }], + globals: Vec::new(), + functions: vec![ + Function { + name: "takes_qubit".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: qubit_ty.clone(), + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: 
"caller".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "takes_qubit".to_string(), + args: vec![(qubit_ty.clone(), Operand::int_to_named_ptr(7, "Qubit"))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let parsed = round_trip_module(&m); + + assert_eq!(parsed.functions[0].params[0].ty, qubit_ty.clone()); + match &parsed.functions[1].basic_blocks[0].instructions[0] { + Instruction::Call { + return_ty, + args, + result, + attr_refs, + .. + } => { + assert_eq!(return_ty, &None); + assert_eq!( + args, + &vec![(qubit_ty, Operand::int_to_named_ptr(7, "Qubit"))] + ); + assert_eq!(result, &None); + assert!(attr_refs.is_empty()); + } + other => panic!("expected call instruction, found {other:?}"), + } +} + +#[test] +fn bitcode_roundtrip_preserves_call_site_attr_refs_and_function_attrs() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + Function { + name: "callee".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Integer(64), + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "caller".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: vec![0], + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "callee".to_string(), + args: vec![( + Type::Integer(64), + Operand::IntConst(Type::Integer(64), 7), + )], + result: None, + attr_refs: vec![0], + }, + 
Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: vec![AttributeGroup { + id: 0, + attributes: vec![Attribute::StringAttr("alwaysinline".to_string())], + }], + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let parsed = round_trip_module(&m); + + assert_eq!(parsed.attribute_groups, m.attribute_groups); + assert_eq!(parsed.functions[1].attribute_group_refs, vec![0]); + match &parsed.functions[1].basic_blocks[0].instructions[0] { + Instruction::Call { + args, + attr_refs, + result, + .. + } => { + assert_eq!( + args, + &vec![(Type::Integer(64), Operand::IntConst(Type::Integer(64), 7))] + ); + assert_eq!(attr_refs, &vec![0]); + assert_eq!(result, &None); + } + other => panic!("expected call instruction, found {other:?}"), + } +} + +#[test] +fn try_write_bitcode_reports_unknown_callee() { + let mut module = simple_function_module(); + let Instruction::Call { callee, .. } = &mut module.functions[1].basic_blocks[0].instructions[0] + else { + panic!("expected call instruction in test module"); + }; + *callee = "missing".to_string(); + + let err = try_write_bitcode(&module).expect_err("missing callee should fail emission"); + + assert_eq!( + err, + WriteError::UnknownCallee { + callee: "missing".to_string(), + } + ); +} + +#[test] +fn try_write_bitcode_reports_missing_branch_target() { + let module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "main".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Jump { + dest: "missing".to_string(), + }], + }], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let err = try_write_bitcode(&module).expect_err("missing branch target should fail emission"); 
+ + assert_eq!( + err, + WriteError::MissingBasicBlock { + context: "unconditional branch".to_string(), + block: "missing".to_string(), + } + ); +} + +#[test] +fn try_write_bitcode_reports_invalid_float_constant_type() { + let module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "main".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Ret(Some(Operand::FloatConst( + Type::Integer(64), + 1.0, + )))], + }], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let err = + try_write_bitcode(&module).expect_err("invalid float constant type should fail emission"); + + assert_eq!( + err, + WriteError::InvalidFloatingConstant { + ty: Type::Integer(64), + value: 1.0, + } + ); +} diff --git a/source/compiler/qsc_llvm/src/fuzz.rs b/source/compiler/qsc_llvm/src/fuzz.rs new file mode 100644 index 0000000000..6ba7a82181 --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz.rs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +pub mod mutation; +pub mod qir_mutations; +pub mod qir_smith; diff --git a/source/compiler/qsc_llvm/src/fuzz/mutation.rs b/source/compiler/qsc_llvm/src/fuzz/mutation.rs new file mode 100644 index 0000000000..69df54aa14 --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz/mutation.rs @@ -0,0 +1,1020 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use crate::model::{BasicBlock, BinOpKind, Function, Instruction, Operand, Param, Type}; +use crate::{ + GeneratedArtifact, Module, QirProfilePreset, ReadDiagnostic, ReadDiagnosticKind, ReadPolicy, + ReadReport, parse_bitcode_compatibility_report, parse_bitcode_detailed, parse_module_detailed, + validate_ir, validate_qir_profile, +}; + +#[derive(Copy, Clone)] +pub enum MutationKind { + ReferenceOrdering, + PhiStructure, + Dominance, + InvalidBranchTarget, + CallShape, + GepShape, + AttributeRef, +} + +impl MutationKind { + pub fn from_data(data: &[u8]) -> Self { + match mutation_selector(data, 0) % 7 { + 0 => Self::ReferenceOrdering, + 1 => Self::PhiStructure, + 2 => Self::Dominance, + 3 => Self::InvalidBranchTarget, + 4 => Self::CallShape, + 5 => Self::GepShape, + _ => Self::AttributeRef, + } + } +} + +pub type SeedMutator = fn(&Module, &[u8]) -> Module; + +pub fn mutation_selector(data: &[u8], index: usize) -> u8 { + data.get(index).copied().unwrap_or_default() +} + +fn validate_seed_module(module: &Module) { + // Seed modules continue to exercise the existing no-panic validation path. + let _profile_result = validate_qir_profile(module); + let _ir_result = validate_ir(module); +} + +pub fn validate_seed_artifact(artifact: &GeneratedArtifact) { + validate_seed_module(&artifact.module); +} + +pub fn validate_mutated_module(module: &Module) { + // Mutated modules intentionally violate structural rules, so only the LLVM + // IR validator is relevant on this lane. 
+    let _mutated_ir_result = validate_ir(module);
+}
+
+fn assert_meaningful_diagnostics(label: &str, diagnostics: &[ReadDiagnostic]) {
+    assert!(
+        !diagnostics.is_empty(),
+        "{label} should report at least one diagnostic"
+    );
+
+    for diagnostic in diagnostics {
+        assert!(
+            matches!(
+                diagnostic.kind,
+                ReadDiagnosticKind::MalformedInput
+                    | ReadDiagnosticKind::UnsupportedSemanticConstruct
+            ),
+            "{label} returned an unexpected diagnostic kind: {diagnostic:?}"
+        );
+        assert!(
+            !diagnostic.context.is_empty(),
+            "{label} returned a diagnostic without context: {diagnostic:?}"
+        );
+        assert!(
+            !diagnostic.message.trim().is_empty(),
+            "{label} returned a diagnostic without a message: {diagnostic:?}"
+        );
+    }
+}
+
+fn assert_detailed_result_stable<T>(
+    label: &str,
+    first: &Result<T, Vec<ReadDiagnostic>>,
+    second: &Result<T, Vec<ReadDiagnostic>>,
+) where
+    T: PartialEq + std::fmt::Debug,
+{
+    assert_eq!(
+        first, second,
+        "{label} changed outcome between repeated detailed parses"
+    );
+
+    if let Err(diagnostics) = first {
+        assert_meaningful_diagnostics(label, diagnostics);
+    }
+}
+
+fn assert_report_result_stable(
+    label: &str,
+    first: &Result<ReadReport, Vec<ReadDiagnostic>>,
+    second: &Result<ReadReport, Vec<ReadDiagnostic>>,
+) {
+    assert_eq!(
+        first, second,
+        "{label} changed outcome between repeated compatibility reports"
+    );
+
+    match first {
+        Ok(report) if !report.diagnostics.is_empty() => {
+            assert_meaningful_diagnostics(label, &report.diagnostics);
+        }
+        Err(diagnostics) => assert_meaningful_diagnostics(label, diagnostics),
+        Ok(_) => {}
+    }
+}
+
+fn compile_raw_bitcode(data: &[u8]) {
+    let strict_first = parse_bitcode_detailed(data, ReadPolicy::QirSubsetStrict);
+    let strict_second = parse_bitcode_detailed(data, ReadPolicy::QirSubsetStrict);
+    assert_detailed_result_stable("raw bitcode strict", &strict_first, &strict_second);
+
+    let compatibility_first = parse_bitcode_detailed(data, ReadPolicy::Compatibility);
+    let compatibility_second = parse_bitcode_detailed(data, ReadPolicy::Compatibility);
+    assert_detailed_result_stable(
+        "raw bitcode 
compatibility", + &compatibility_first, + &compatibility_second, + ); + + let report_first = parse_bitcode_compatibility_report(data); + let report_second = parse_bitcode_compatibility_report(data); + assert_report_result_stable( + "raw bitcode compatibility report", + &report_first, + &report_second, + ); + + if compatibility_first.is_err() { + assert!( + strict_first.is_err(), + "strict mode should not salvage raw bitcode after compatibility diagnostics" + ); + } + + assert!( + !(strict_first.is_err() && compatibility_first.is_ok()), + "compatibility mode accepted raw bitcode without diagnostics while strict mode rejected it" + ); + + if let Ok(report) = &report_first { + if !report.diagnostics.is_empty() { + assert!( + strict_first.is_err(), + "strict mode should not salvage raw bitcode that compatibility recovered with diagnostics" + ); + assert!( + compatibility_first.is_err(), + "detailed compatibility parsing should surface recovery diagnostics as an error" + ); + } + } else { + assert!( + strict_first.is_err(), + "strict mode should not salvage raw bitcode after compatibility report failure" + ); + assert!( + compatibility_first.is_err(), + "detailed compatibility parsing should fail when compatibility reporting fails" + ); + } +} + +fn compile_raw_utf8_text(data: &[u8]) { + let Ok(text) = std::str::from_utf8(data) else { + return; + }; + + let strict_first = parse_module_detailed(text, ReadPolicy::QirSubsetStrict); + let strict_second = parse_module_detailed(text, ReadPolicy::QirSubsetStrict); + assert_detailed_result_stable("raw utf-8 text strict", &strict_first, &strict_second); + + let compatibility_first = parse_module_detailed(text, ReadPolicy::Compatibility); + let compatibility_second = parse_module_detailed(text, ReadPolicy::Compatibility); + assert_detailed_result_stable( + "raw utf-8 text compatibility", + &compatibility_first, + &compatibility_second, + ); + + if compatibility_first.is_err() { + assert!( + strict_first.is_err(), + "strict mode 
should not salvage malformed UTF-8 text after compatibility diagnostics" + ); + } + + assert!( + !(strict_first.is_err() && compatibility_first.is_ok()), + "compatibility mode accepted raw UTF-8 text without diagnostics while strict mode rejected it" + ); +} + +pub fn compile_raw_parser_lanes(data: &[u8]) { + compile_raw_bitcode(data); + compile_raw_utf8_text(data); +} + +pub fn dispatch_mutation_family(module: &mut Module, kind: MutationKind, selector: u8) { + // Keep select and memory-shape mutations deferred so failures stay attributable. + match kind { + MutationKind::ReferenceOrdering => mutate_reference_ordering(module, selector), + MutationKind::PhiStructure => { + let Some(function) = first_defined_function(module) else { + return; + }; + mutate_phi_structure(function, selector); + } + MutationKind::Dominance => { + let Some(function) = first_defined_function(module) else { + return; + }; + mutate_dominance(function, selector); + } + MutationKind::InvalidBranchTarget => mutate_invalid_branch_target(module, selector), + MutationKind::CallShape => mutate_call_shape(module, selector), + MutationKind::GepShape => { + let Some(function) = first_defined_function(module) else { + return; + }; + mutate_gep_shape(function, selector); + } + MutationKind::AttributeRef => mutate_invalid_call_site_attr_ref(module), + } +} + +fn mutate_reference_ordering(module: &mut Module, selector: u8) { + match selector % 5 { + 0 => { + let Some(function) = first_defined_function(module) else { + return; + }; + mutate_typed_local_ref_undefined(function); + } + 1 => { + let Some(function) = first_defined_function(module) else { + return; + }; + mutate_typed_local_ref_use_before_def(function); + } + 2 => { + let Some(function) = first_defined_function(module) else { + return; + }; + mutate_undefined_local_ref(function); + } + 3 => { + let Some(function) = first_defined_function(module) else { + return; + }; + mutate_local_ref_use_before_def(function); + } + _ => 
mutate_undefined_callee(module), + } +} + +fn first_defined_function(module: &mut Module) -> Option<&mut Function> { + if let Some(index) = module + .functions + .iter() + .position(|function| !function.is_declaration) + { + module.functions.get_mut(index) + } else { + module.functions.first_mut() + } +} + +fn mutate_typed_local_ref_undefined(function: &mut Function) { + let missing_name = next_available_local_name(function, "__qir_mut_missing"); + let result_name = next_available_local_name(function, "__qir_mut_use"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::TypedLocalRef(missing_name, Type::Integer(64)), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: result_name, + }, + ); + } +} + +fn mutate_typed_local_ref_use_before_def(function: &mut Function) { + let late_name = next_available_local_name(function, "__qir_mut_late"); + let result_name = next_available_local_name(function, "__qir_mut_use"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::TypedLocalRef(late_name.clone(), Type::Integer(64)), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: result_name, + }, + ); + insert_before_terminator( + entry_block, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 2), + rhs: Operand::IntConst(Type::Integer(64), 3), + result: late_name, + }, + ); + } +} + +fn mutate_undefined_local_ref(function: &mut Function) { + let missing_name = next_available_local_name(function, "__qir_mut_missing"); + let result_name = next_available_local_name(function, "__qir_mut_use"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::BinOp { + 
op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef(missing_name), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: result_name, + }, + ); + } +} + +fn mutate_local_ref_use_before_def(function: &mut Function) { + let late_name = next_available_local_name(function, "__qir_mut_late"); + let result_name = next_available_local_name(function, "__qir_mut_use"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef(late_name.clone()), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: result_name, + }, + ); + insert_before_terminator( + entry_block, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 2), + rhs: Operand::IntConst(Type::Integer(64), 3), + result: late_name, + }, + ); + } +} + +fn mutate_undefined_callee(module: &mut Module) { + let missing_callee = next_available_function_name(module, "__qir_mut_missing_callee"); + let Some(function) = first_defined_function(module) else { + return; + }; + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::Call { + return_ty: None, + callee: missing_callee, + args: Vec::new(), + result: None, + attr_refs: Vec::new(), + }, + ); + } +} + +fn mutate_call_shape(module: &mut Module, selector: u8) { + match selector % 2 { + 0 => mutate_call_arg_operand_type_mismatch(module), + _ => mutate_call_typed_local_type_masking(module), + } +} + +fn mutate_call_arg_operand_type_mismatch(module: &mut Module) { + let callee = next_available_function_name(module, "__qir_mut_consume_i64"); + push_declaration( + module, + callee.clone(), + Type::Void, + vec![Param { + ty: Type::Integer(64), + name: Some("value".to_string()), + }], + ); + + let Some(function) = first_defined_function(module) else { + return; + }; + + if let 
Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::Call { + return_ty: None, + callee, + args: vec![(Type::Integer(64), Operand::IntConst(Type::Integer(1), 1))], + result: None, + attr_refs: Vec::new(), + }, + ); + } +} + +fn mutate_call_typed_local_type_masking(module: &mut Module) { + let callee = next_available_function_name(module, "__qir_mut_consume_i1"); + push_declaration( + module, + callee.clone(), + Type::Void, + vec![Param { + ty: Type::Integer(1), + name: Some("flag".to_string()), + }], + ); + + let Some(function) = first_defined_function(module) else { + return; + }; + + let value_name = next_available_local_name(function, "__qir_mut_call_value"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: value_name.clone(), + }, + ); + insert_before_terminator( + entry_block, + Instruction::Call { + return_ty: None, + callee, + args: vec![( + Type::Integer(1), + Operand::TypedLocalRef(value_name, Type::Integer(1)), + )], + result: None, + attr_refs: Vec::new(), + }, + ); + } +} + +fn mutate_phi_structure(function: &mut Function, selector: u8) { + match selector % 3 { + 0 => mutate_phi_predecessor_multiplicity_mismatch(function), + 1 => mutate_phi_non_predecessor_incoming(function), + _ => mutate_phi_duplicate_incoming_diff_values(function), + } +} + +fn mutate_phi_predecessor_multiplicity_mismatch(function: &mut Function) { + let return_type = function.return_type.clone(); + let left_block_name = next_available_block_name(function, "__qir_mut_phi_left"); + let right_block_name = next_available_block_name(function, "__qir_mut_phi_right"); + let merge_block_name = next_available_block_name(function, "__qir_mut_phi_merge"); + let phi_result = 
next_available_local_name(function, "__qir_mut_phi"); + + if !replace_entry_terminator( + function, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: left_block_name.clone(), + false_dest: right_block_name.clone(), + }, + ) { + return; + } + + function.basic_blocks.push(BasicBlock { + name: left_block_name.clone(), + instructions: vec![Instruction::Jump { + dest: merge_block_name.clone(), + }], + }); + function.basic_blocks.push(BasicBlock { + name: right_block_name, + instructions: vec![Instruction::Jump { + dest: merge_block_name.clone(), + }], + }); + function.basic_blocks.push(BasicBlock { + name: merge_block_name, + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + ( + Operand::IntConst(Type::Integer(64), 0), + left_block_name.clone(), + ), + (Operand::IntConst(Type::Integer(64), 0), left_block_name), + ], + result: phi_result, + }, + Instruction::Ret(default_return_operand(&return_type)), + ], + }); +} + +fn mutate_phi_non_predecessor_incoming(function: &mut Function) { + let return_type = function.return_type.clone(); + let merge_block_name = next_available_block_name(function, "__qir_mut_phi_merge"); + let missing_pred_name = next_available_block_name(function, "__qir_mut_missing_pred"); + let phi_result = next_available_local_name(function, "__qir_mut_phi"); + + if !replace_entry_terminator( + function, + Instruction::Jump { + dest: merge_block_name.clone(), + }, + ) { + return; + } + + function.basic_blocks.push(BasicBlock { + name: merge_block_name, + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![(Operand::IntConst(Type::Integer(64), 0), missing_pred_name)], + result: phi_result, + }, + Instruction::Ret(default_return_operand(&return_type)), + ], + }); +} + +fn mutate_phi_duplicate_incoming_diff_values(function: &mut Function) { + let Some(entry_block_name) = function + .basic_blocks + .first() + .map(|block| 
block.name.clone()) + else { + return; + }; + + let return_type = function.return_type.clone(); + let merge_block_name = next_available_block_name(function, "__qir_mut_phi_merge"); + let phi_result = next_available_local_name(function, "__qir_mut_phi"); + + if !replace_entry_terminator( + function, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: merge_block_name.clone(), + false_dest: merge_block_name.clone(), + }, + ) { + return; + } + + function.basic_blocks.push(BasicBlock { + name: merge_block_name, + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + ( + Operand::IntConst(Type::Integer(64), 0), + entry_block_name.clone(), + ), + (Operand::IntConst(Type::Integer(64), 1), entry_block_name), + ], + result: phi_result, + }, + Instruction::Ret(default_return_operand(&return_type)), + ], + }); +} + +fn mutate_dominance(function: &mut Function, selector: u8) { + match selector % 2 { + 0 => mutate_phi_non_dominating_incoming_value(function), + _ => mutate_cross_block_non_dominating_use(function), + } +} + +fn mutate_phi_non_dominating_incoming_value(function: &mut Function) { + let return_type = function.return_type.clone(); + let left_block_name = next_available_block_name(function, "__qir_mut_dom_left"); + let right_block_name = next_available_block_name(function, "__qir_mut_dom_right"); + let merge_block_name = next_available_block_name(function, "__qir_mut_dom_merge"); + let value_name = next_available_local_name(function, "__qir_mut_dom_value"); + let phi_result = next_available_local_name(function, "__qir_mut_dom_phi"); + + if !replace_entry_terminator( + function, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: left_block_name.clone(), + false_dest: right_block_name.clone(), + }, + ) { + return; + } + + function.basic_blocks.push(BasicBlock { + name: left_block_name.clone(), + instructions: vec![ + 
Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: value_name.clone(), + }, + Instruction::Jump { + dest: merge_block_name.clone(), + }, + ], + }); + function.basic_blocks.push(BasicBlock { + name: right_block_name.clone(), + instructions: vec![Instruction::Jump { + dest: merge_block_name.clone(), + }], + }); + function.basic_blocks.push(BasicBlock { + name: merge_block_name, + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + (Operand::LocalRef(value_name), right_block_name), + (Operand::IntConst(Type::Integer(64), 0), left_block_name), + ], + result: phi_result, + }, + Instruction::Ret(default_return_operand(&return_type)), + ], + }); +} + +fn mutate_cross_block_non_dominating_use(function: &mut Function) { + let return_type = function.return_type.clone(); + let then_block_name = next_available_block_name(function, "__qir_mut_dom_then"); + let else_block_name = next_available_block_name(function, "__qir_mut_dom_else"); + let merge_block_name = next_available_block_name(function, "__qir_mut_dom_merge"); + let value_name = next_available_local_name(function, "__qir_mut_dom_value"); + let use_result = next_available_local_name(function, "__qir_mut_dom_use"); + + if !replace_entry_terminator( + function, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: then_block_name.clone(), + false_dest: else_block_name.clone(), + }, + ) { + return; + } + + function.basic_blocks.push(BasicBlock { + name: then_block_name, + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: value_name.clone(), + }, + Instruction::Jump { + dest: merge_block_name.clone(), + }, + ], + }); + function.basic_blocks.push(BasicBlock { + name: 
else_block_name, + instructions: vec![Instruction::Jump { + dest: merge_block_name.clone(), + }], + }); + function.basic_blocks.push(BasicBlock { + name: merge_block_name, + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef(value_name), + rhs: Operand::IntConst(Type::Integer(64), 0), + result: use_result, + }, + Instruction::Ret(default_return_operand(&return_type)), + ], + }); +} + +fn mutate_invalid_branch_target(module: &mut Module, selector: u8) { + let Some(function) = first_defined_function(module) else { + return; + }; + + match selector % 4 { + 0 => mutate_invalid_conditional_branch_target(function), + 1 => mutate_invalid_jump_target(function), + 2 => mutate_invalid_switch_case_target(function), + _ => mutate_invalid_switch_default_target(function), + } +} + +fn mutate_invalid_conditional_branch_target(function: &mut Function) { + let Some(valid_target) = function + .basic_blocks + .first() + .map(|block| block.name.clone()) + else { + return; + }; + let missing_block = next_available_block_name(function, "__qir_mut_missing_block"); + + if let Some(entry_block) = function.basic_blocks.first_mut() + && let Some(terminator) = entry_block.instructions.last_mut() + { + *terminator = Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: missing_block, + false_dest: valid_target, + }; + } +} + +fn mutate_invalid_jump_target(function: &mut Function) { + let missing_block = next_available_block_name(function, "__qir_mut_missing_block"); + + if let Some(entry_block) = function.basic_blocks.first_mut() + && let Some(terminator) = entry_block.instructions.last_mut() + { + *terminator = Instruction::Jump { + dest: missing_block, + }; + } +} + +fn mutate_invalid_switch_case_target(function: &mut Function) { + let Some(valid_target) = function + .basic_blocks + .first() + .map(|block| block.name.clone()) + else { + return; + }; + let missing_block = 
next_available_block_name(function, "__qir_mut_missing_block"); + + if let Some(entry_block) = function.basic_blocks.first_mut() + && let Some(terminator) = entry_block.instructions.last_mut() + { + *terminator = Instruction::Switch { + ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(64), 0), + default_dest: valid_target, + cases: vec![(1, missing_block)], + }; + } +} + +fn mutate_invalid_switch_default_target(function: &mut Function) { + let Some(valid_target) = function + .basic_blocks + .first() + .map(|block| block.name.clone()) + else { + return; + }; + let missing_block = next_available_block_name(function, "__qir_mut_missing_block"); + + if let Some(entry_block) = function.basic_blocks.first_mut() + && let Some(terminator) = entry_block.instructions.last_mut() + { + *terminator = Instruction::Switch { + ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(64), 0), + default_dest: missing_block, + cases: vec![(1, valid_target)], + }; + } +} + +fn mutate_gep_shape(function: &mut Function, selector: u8) { + match selector % 3 { + 0 => mutate_gep_no_indices(function), + 1 => mutate_gep_non_integer_index(function), + _ => mutate_gep_non_pointer(function), + } +} + +fn mutate_gep_no_indices(function: &mut Function) { + let result_name = next_available_local_name(function, "__qir_mut_gep"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::GetElementPtr { + inbounds: true, + pointee_ty: Type::Integer(8), + ptr_ty: Type::Ptr, + ptr: Operand::NullPtr, + indices: Vec::new(), + result: result_name, + }, + ); + } +} + +fn mutate_gep_non_integer_index(function: &mut Function) { + let result_name = next_available_local_name(function, "__qir_mut_gep"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::GetElementPtr { + inbounds: true, + pointee_ty: Type::Integer(8), + ptr_ty: Type::Ptr, + ptr: 
Operand::NullPtr, + indices: vec![Operand::float_const(Type::Double, 0.0)], + result: result_name, + }, + ); + } +} + +fn mutate_gep_non_pointer(function: &mut Function) { + let result_name = next_available_local_name(function, "__qir_mut_gep"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::GetElementPtr { + inbounds: true, + pointee_ty: Type::Integer(8), + ptr_ty: Type::Integer(64), + ptr: Operand::IntConst(Type::Integer(64), 0), + indices: vec![Operand::IntConst(Type::Integer(32), 0)], + result: result_name, + }, + ); + } +} + +fn mutate_invalid_call_site_attr_ref(module: &mut Module) { + let callee = next_available_function_name(module, "__qir_mut_attr_callee"); + let missing_attr_ref = next_missing_attribute_group_id(module); + push_declaration(module, callee.clone(), Type::Void, Vec::new()); + + let Some(function) = first_defined_function(module) else { + return; + }; + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::Call { + return_ty: None, + callee, + args: Vec::new(), + result: None, + attr_refs: vec![missing_attr_ref], + }, + ); + } +} + +fn push_declaration(module: &mut Module, name: String, return_type: Type, params: Vec) { + module.functions.push(Function { + name, + return_type, + params, + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); +} + +fn next_missing_attribute_group_id(module: &Module) -> u32 { + let mut candidate = 0; + + while module + .attribute_groups + .iter() + .any(|group| group.id == candidate) + { + candidate += 1; + } + + candidate +} + +fn default_return_operand(return_type: &Type) -> Option { + match return_type { + Type::Void | Type::Array(_, _) | Type::Function(_, _) | Type::Named(_) | Type::Label => { + None + } + Type::Integer(width) => Some(Operand::IntConst(Type::Integer(*width), 0)), + Type::Double => 
Some(Operand::float_const(Type::Double, 0.0)), + Type::Ptr | Type::NamedPtr(_) | Type::TypedPtr(_) => Some(Operand::NullPtr), + Type::Half => Some(Operand::float_const(Type::Half, 0.0)), + Type::Float => Some(Operand::float_const(Type::Float, 0.0)), + } +} + +fn replace_entry_terminator(function: &mut Function, terminator: Instruction) -> bool { + if let Some(entry_block) = function.basic_blocks.first_mut() + && let Some(current_terminator) = entry_block.instructions.last_mut() + { + *current_terminator = terminator; + true + } else { + false + } +} + +fn insert_before_terminator(block: &mut BasicBlock, instruction: Instruction) { + let insert_index = block.instructions.len().saturating_sub(1); + block.instructions.insert(insert_index, instruction); +} + +fn next_available_block_name(function: &Function, prefix: &str) -> String { + next_available_name(prefix, |candidate| { + function + .basic_blocks + .iter() + .any(|block| block.name == candidate) + }) +} + +fn next_available_function_name(module: &Module, prefix: &str) -> String { + next_available_name(prefix, |candidate| { + module + .functions + .iter() + .any(|function| function.name == candidate) + }) +} + +fn next_available_local_name(function: &Function, prefix: &str) -> String { + next_available_name(prefix, |candidate| { + function + .params + .iter() + .any(|param| param.name.as_deref() == Some(candidate)) + || function + .basic_blocks + .iter() + .flat_map(|block| block.instructions.iter()) + .filter_map(instruction_result_name) + .any(|name| name == candidate) + }) +} + +fn next_available_name(prefix: &str, exists: impl Fn(&str) -> bool) -> String { + if !exists(prefix) { + return prefix.to_string(); + } + + for index in 0.. 
{ + let candidate = format!("{prefix}_{index}"); + if !exists(&candidate) { + return candidate; + } + } + + unreachable!("unbounded suffix search should always find a unique name") +} + +fn instruction_result_name(instruction: &Instruction) -> Option<&str> { + match instruction { + Instruction::BinOp { result, .. } + | Instruction::ICmp { result, .. } + | Instruction::FCmp { result, .. } + | Instruction::Cast { result, .. } + | Instruction::Call { + result: Some(result), + .. + } + | Instruction::Phi { result, .. } + | Instruction::Alloca { result, .. } + | Instruction::Load { result, .. } + | Instruction::Select { result, .. } + | Instruction::GetElementPtr { result, .. } => Some(result.as_str()), + _ => None, + } +} diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_mutations.rs b/source/compiler/qsc_llvm/src/fuzz/qir_mutations.rs new file mode 100644 index 0000000000..a733f95069 --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz/qir_mutations.rs @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use crate::model::{BasicBlock, Type}; +use crate::model::{Function, Instruction, Module, Operand, Param, StructType}; + +#[must_use] +pub fn mutate_adaptive_v1_typed_pointer_seed(seed: &Module, selector: u8) -> Module { + let mut mutated = seed.clone(); + + match selector % 3 { + 0 => { + let Some(function) = first_defined_function(&mut mutated) else { + return mutated; + }; + mutate_load_pointer_operand_mismatch(function); + } + 1 => { + let Some(function) = first_defined_function(&mut mutated) else { + return mutated; + }; + mutate_store_pointer_operand_mismatch(function); + } + _ => mutate_named_pointer_call_arg_mismatch(&mut mutated), + } + + mutated +} + +fn first_defined_function(module: &mut Module) -> Option<&mut Function> { + if let Some(index) = module + .functions + .iter() + .position(|function| !function.is_declaration) + { + module.functions.get_mut(index) + } else { + module.functions.first_mut() + } +} + +fn mutate_load_pointer_operand_mismatch(function: &mut Function) { + let result_name = next_available_local_name(function, "__qir_mut_load"); + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::Load { + ty: Type::Integer(64), + ptr_ty: typed_ptr(Type::Integer(64)), + ptr: Operand::IntToPtr(0, typed_ptr(Type::Integer(8))), + result: result_name, + }, + ); + } +} + +fn mutate_store_pointer_operand_mismatch(function: &mut Function) { + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::Store { + ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(64), 0), + ptr_ty: typed_ptr(Type::Integer(64)), + ptr: Operand::IntToPtr(0, typed_ptr(Type::Integer(8))), + }, + ); + } +} + +fn mutate_named_pointer_call_arg_mismatch(module: &mut Module) { + ensure_opaque_struct(module, "Qubit"); + + let callee = next_available_function_name(module, "__qir_mut_named_qubit_callee"); + let qubit_ptr = 
Type::NamedPtr("Qubit".to_string()); + + push_declaration( + module, + callee.clone(), + Type::Void, + vec![Param { + ty: qubit_ptr.clone(), + name: Some("qubit".to_string()), + }], + ); + + let Some(function) = first_defined_function(module) else { + return; + }; + + if let Some(entry_block) = function.basic_blocks.first_mut() { + insert_before_terminator( + entry_block, + Instruction::Call { + return_ty: None, + callee, + args: vec![( + qubit_ptr, + Operand::IntToPtr(0, Type::Named("Qubit".to_string())), + )], + result: None, + attr_refs: Vec::new(), + }, + ); + } +} + +fn typed_ptr(inner: Type) -> Type { + Type::TypedPtr(Box::new(inner)) +} + +fn ensure_opaque_struct(module: &mut Module, name: &str) { + if module + .struct_types + .iter() + .any(|struct_ty| struct_ty.name == name) + { + return; + } + + module.struct_types.push(StructType { + name: name.to_string(), + is_opaque: true, + }); +} + +fn push_declaration(module: &mut Module, name: String, return_type: Type, params: Vec) { + module.functions.push(Function { + name, + return_type, + params, + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); +} + +fn insert_before_terminator(block: &mut BasicBlock, instruction: Instruction) { + let insert_index = block.instructions.len().saturating_sub(1); + block.instructions.insert(insert_index, instruction); +} + +fn next_available_function_name(module: &Module, prefix: &str) -> String { + next_available_name(prefix, |candidate| { + module + .functions + .iter() + .any(|function| function.name == candidate) + }) +} + +fn next_available_local_name(function: &Function, prefix: &str) -> String { + next_available_name(prefix, |candidate| { + function + .params + .iter() + .any(|param| param.name.as_deref() == Some(candidate)) + || function + .basic_blocks + .iter() + .flat_map(|block| block.instructions.iter()) + .filter_map(instruction_result_name) + .any(|name| name == candidate) + }) +} + +fn next_available_name(prefix: &str, 
exists: impl Fn(&str) -> bool) -> String { + if !exists(prefix) { + return prefix.to_string(); + } + + for index in 0.. { + let candidate = format!("{prefix}_{index}"); + if !exists(&candidate) { + return candidate; + } + } + + unreachable!("unbounded suffix search should always find a unique name") +} + +fn instruction_result_name(instruction: &Instruction) -> Option<&str> { + match instruction { + Instruction::BinOp { result, .. } + | Instruction::ICmp { result, .. } + | Instruction::FCmp { result, .. } + | Instruction::Cast { result, .. } + | Instruction::Call { + result: Some(result), + .. + } + | Instruction::Phi { result, .. } + | Instruction::Alloca { result, .. } + | Instruction::Load { result, .. } + | Instruction::Select { result, .. } + | Instruction::GetElementPtr { result, .. } => Some(result.as_str()), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::BasicBlock; + use crate::{LlvmIrError, validate_ir}; + + fn adaptive_v1_seed_module() -> Module { + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: vec![StructType { + name: "Qubit".to_string(), + is_opaque: true, + }], + globals: Vec::new(), + functions: vec![Function { + name: "caller".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Ret(None)], + }], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } + } + + fn assert_selector_triggers_type_mismatch(selector: u8, expected_instruction: &str) { + let seed = adaptive_v1_seed_module(); + assert!(validate_ir(&seed).is_empty(), "seed module should be valid"); + + let mutated = mutate_adaptive_v1_typed_pointer_seed(&seed, selector); + let errors = validate_ir(&mutated); + + assert!( + errors.iter().any(|error| matches!( + error, + 
LlvmIrError::TypeMismatch { instruction, .. } + if instruction == expected_instruction + )), + "expected {expected_instruction} type mismatch validation error, got {errors:?}" + ); + } + + #[test] + fn load_pointer_operand_mutation_triggers_validator_error() { + assert_selector_triggers_type_mismatch(0, "Load"); + } + + #[test] + fn store_pointer_operand_mutation_triggers_validator_error() { + assert_selector_triggers_type_mismatch(1, "Store"); + } + + #[test] + fn named_pointer_call_argument_mutation_triggers_validator_error() { + let seed = adaptive_v1_seed_module(); + assert!(validate_ir(&seed).is_empty(), "seed module should be valid"); + + let mutated = mutate_adaptive_v1_typed_pointer_seed(&seed, 2); + let errors = validate_ir(&mutated); + + assert!( + errors.iter().any(|error| matches!( + error, + LlvmIrError::TypeMismatch { + instruction, + expected, + found, + .. + } if instruction.starts_with("Call @__qir_mut_named_qubit_callee") + && expected == "%Qubit*" + && found == "%Qubit" + )), + "expected named-pointer call mismatch validation error, got {errors:?}" + ); + } +} diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith.rs new file mode 100644 index 0000000000..77a44ef766 --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith.rs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +mod checked; +mod compare; +mod config; +mod generator; +mod io; +mod metadata; + +#[cfg(test)] +mod tests; + +use arbitrary::Unstructured; + +use crate::model::Module; +#[cfg(test)] +use crate::model::{ + BasicBlock, BinOpKind, CastKind, Constant, Function, GlobalVariable, Instruction, IntPredicate, + Linkage, MetadataValue, Operand, Param, Type, +}; +#[cfg(test)] +use crate::qir; + +pub use config::{ + EffectiveConfig, GeneratedArtifact, OutputMode, QirProfilePreset, QirSmithConfig, + QirSmithError, RoundTripKind, +}; + +use checked::populate_checked_artifact; +#[cfg(test)] +use compare::{ + assert_bitcode_roundtrip_matches_supported_v1_subset, ensure_text_roundtrip_matches, +}; +#[cfg(test)] +use config::{ + BASE_V1_BLOCK_COUNT, DEFAULT_MAX_BLOCKS_PER_FUNC, DEFAULT_MAX_FUNCS, + DEFAULT_MAX_INSTRS_PER_BLOCK, +}; +use generator::build_module_shell; +#[cfg(test)] +use generator::{ + GENERATED_TARGET_DATALAYOUT, GENERATED_TARGET_TRIPLE, QirGenState, ShellCounts, ShellPreset, + StableNameAllocator, +}; +use io::{emit_bitcode, emit_text}; +#[cfg(test)] +use io::{parse_bitcode_roundtrip, parse_text_roundtrip}; +#[cfg(test)] +use metadata::build_qdk_metadata; +use metadata::finalize_float_computations; + +pub fn generate( + config: &QirSmithConfig, + bytes: &mut Unstructured<'_>, +) -> Result { + let effective = config.sanitize(); + generate_artifact(&effective, bytes) +} + +pub fn generate_from_bytes( + config: &QirSmithConfig, + bytes: &[u8], +) -> Result { + with_unstructured_bytes(bytes, |unstructured| generate(config, unstructured)) +} + +pub fn generate_module( + config: &QirSmithConfig, + bytes: &mut Unstructured<'_>, +) -> Result { + let artifact = generate_for_mode(config, bytes, OutputMode::Model)?; + Ok(artifact.module) +} + +pub fn generate_module_from_bytes( + config: &QirSmithConfig, + bytes: &[u8], +) -> Result { + with_unstructured_bytes(bytes, |unstructured| generate_module(config, unstructured)) +} + +pub fn generate_text( + config: &QirSmithConfig, 
+ bytes: &mut Unstructured<'_>, +) -> Result { + let artifact = generate_for_mode(config, bytes, OutputMode::Text)?; + Ok(artifact.text.unwrap_or_default()) +} + +pub fn generate_text_from_bytes( + config: &QirSmithConfig, + bytes: &[u8], +) -> Result { + with_unstructured_bytes(bytes, |unstructured| generate_text(config, unstructured)) +} + +pub fn generate_bitcode( + config: &QirSmithConfig, + bytes: &mut Unstructured<'_>, +) -> Result, QirSmithError> { + let artifact = generate_for_mode(config, bytes, OutputMode::Bitcode)?; + Ok(artifact.bitcode.unwrap_or_default()) +} + +pub fn generate_bitcode_from_bytes( + config: &QirSmithConfig, + bytes: &[u8], +) -> Result, QirSmithError> { + with_unstructured_bytes(bytes, |unstructured| generate_bitcode(config, unstructured)) +} + +pub fn generate_checked( + config: &QirSmithConfig, + bytes: &mut Unstructured<'_>, +) -> Result { + generate_for_mode(config, bytes, OutputMode::RoundTripChecked) +} + +pub fn generate_checked_from_bytes( + config: &QirSmithConfig, + bytes: &[u8], +) -> Result { + with_unstructured_bytes(bytes, |unstructured| generate_checked(config, unstructured)) +} + +fn with_unstructured_bytes( + bytes: &[u8], + generate: impl FnOnce(&mut Unstructured<'_>) -> Result, +) -> Result { + let mut unstructured = Unstructured::new(bytes); + generate(&mut unstructured) +} + +fn generate_for_mode( + config: &QirSmithConfig, + bytes: &mut Unstructured<'_>, + output_mode: OutputMode, +) -> Result { + let config = config.with_output_mode(output_mode); + generate(&config, bytes) +} + +fn generate_artifact( + effective: &EffectiveConfig, + bytes: &mut Unstructured<'_>, +) -> Result { + let mut artifact = build_generated_artifact(effective, bytes)?; + populate_requested_outputs(&mut artifact)?; + Ok(artifact) +} + +fn build_generated_artifact( + effective: &EffectiveConfig, + bytes: &mut Unstructured<'_>, +) -> Result { + let mut module = build_module_shell(effective, bytes); + + if matches!( + effective.profile, + 
QirProfilePreset::AdaptiveV1 | QirProfilePreset::AdaptiveV2 + ) { + finalize_float_computations(&mut module); + } + + Ok(GeneratedArtifact { + effective_config: effective.clone(), + module, + text: None, + bitcode: None, + }) +} + +fn populate_requested_outputs(artifact: &mut GeneratedArtifact) -> Result<(), QirSmithError> { + match artifact.effective_config.output_mode { + OutputMode::Model => {} + OutputMode::Text => { + artifact.text = Some(emit_text(&artifact.module)); + } + OutputMode::Bitcode => { + artifact.bitcode = Some(emit_bitcode(&artifact.module)?); + } + OutputMode::RoundTripChecked => populate_checked_artifact(artifact)?, + } + + Ok(()) +} + +const fn sanitize_count(value: usize, default: usize) -> usize { + if value == 0 { default } else { value } +} diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith/checked.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith/checked.rs new file mode 100644 index 0000000000..ebdc7fdc00 --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith/checked.rs @@ -0,0 +1,270 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use crate::model::{Instruction, Module, Operand, Type}; + +use super::{ + compare::{ + assert_bitcode_roundtrip_matches_supported_v1_subset, ensure_text_roundtrip_matches, + }, + config::{GeneratedArtifact, QirSmithError, RoundTripKind}, + io::{emit_bitcode, emit_text, parse_bitcode_roundtrip, parse_text_roundtrip}, +}; + +pub(super) fn populate_checked_artifact( + artifact: &mut GeneratedArtifact, +) -> Result<(), QirSmithError> { + validate_checked_module( + &artifact.module, + artifact.effective_config.allow_typed_pointers, + )?; + + if artifact.effective_config.profile.to_qir_profile().is_some() { + let validation = crate::validation::validate_qir_profile(&artifact.module); + if let Some(first_error) = validation.errors.into_iter().next() { + return Err(QirSmithError::ProfileViolation(first_error)); + } + } + + match artifact + .effective_config + .roundtrip + .unwrap_or(RoundTripKind::TextAndBitcodeSinglePass) + { + RoundTripKind::TextOnly => { + let text = emit_text(&artifact.module); + let reparsed_text = parse_text_roundtrip(&text)?; + ensure_text_roundtrip_matches(&artifact.module, &reparsed_text)?; + artifact.text = Some(text); + } + RoundTripKind::BitcodeOnly => { + let bitcode = emit_bitcode(&artifact.module)?; + let reparsed_bitcode = parse_bitcode_roundtrip(&bitcode)?; + assert_bitcode_roundtrip_matches_supported_v1_subset( + &artifact.module, + &reparsed_bitcode, + )?; + artifact.bitcode = Some(bitcode); + } + RoundTripKind::TextAndBitcodeSinglePass => { + let text = emit_text(&artifact.module); + let reparsed_text = parse_text_roundtrip(&text)?; + ensure_text_roundtrip_matches(&artifact.module, &reparsed_text)?; + artifact.text = Some(text); + let bitcode = emit_bitcode(&reparsed_text)?; + let reparsed_bitcode = parse_bitcode_roundtrip(&bitcode)?; + assert_bitcode_roundtrip_matches_supported_v1_subset( + &reparsed_text, + &reparsed_bitcode, + )?; + artifact.bitcode = Some(bitcode); + } + } + + Ok(()) +} + +fn validate_checked_module( + module: 
&Module, + allow_typed_pointers: bool, +) -> Result<(), QirSmithError> { + let Some(entry_point) = module.functions.first() else { + return Err(QirSmithError::ModelGeneration( + "checked mode requires a generated entry point".to_string(), + )); + }; + + if entry_point.name != crate::qir::ENTRYPOINT_NAME || entry_point.is_declaration { + return Err(QirSmithError::ModelGeneration( + "checked mode expects the generated module to start with a defined ENTRYPOINT__main" + .to_string(), + )); + } + + if entry_point.return_type != Type::Integer(64) { + return Err(QirSmithError::ModelGeneration( + "checked mode expects ENTRYPOINT__main to return i64".to_string(), + )); + } + + if !entry_point.params.is_empty() { + return Err(QirSmithError::ModelGeneration( + "checked mode expects ENTRYPOINT__main to take no parameters".to_string(), + )); + } + + if module + .functions + .iter() + .filter(|function| !function.is_declaration) + .count() + != 1 + { + return Err(QirSmithError::ModelGeneration( + "checked mode only supports a single defined entry point in the v1 subset".to_string(), + )); + } + + for global in &module.globals { + validate_checked_type(&global.ty, allow_typed_pointers)?; + } + + for function in &module.functions { + validate_checked_type(&function.return_type, allow_typed_pointers)?; + for param in &function.params { + validate_checked_type(¶m.ty, allow_typed_pointers)?; + } + for block in &function.basic_blocks { + for instruction in &block.instructions { + validate_checked_instruction(instruction, allow_typed_pointers)?; + } + } + } + + Ok(()) +} + +fn validate_checked_type(ty: &Type, allow_typed_pointers: bool) -> Result<(), QirSmithError> { + match ty { + Type::Void | Type::Integer(_) | Type::Half | Type::Float | Type::Double | Type::Ptr => { + Ok(()) + } + Type::Label | Type::Function(_, _) => Err(QirSmithError::ModelGeneration(format!( + "type {ty} is outside the supported checked subset" + ))), + Type::Array(_, element) => validate_checked_type(element, 
allow_typed_pointers), + Type::NamedPtr(_) | Type::TypedPtr(_) | Type::Named(_) => { + if allow_typed_pointers { + Ok(()) + } else { + Err(QirSmithError::ModelGeneration(format!( + "type {ty} is outside the supported opaque-pointer checked subset" + ))) + } + } + } +} + +fn validate_checked_operand( + operand: &Operand, + allow_typed_pointers: bool, +) -> Result<(), QirSmithError> { + match operand { + Operand::LocalRef(_) + | Operand::TypedLocalRef(_, _) + | Operand::FloatConst(_, _) + | Operand::NullPtr + | Operand::GlobalRef(_) => Ok(()), + Operand::IntConst(ty, _) => validate_checked_type(ty, allow_typed_pointers), + Operand::IntToPtr(_, ty) => { + validate_checked_type(ty, allow_typed_pointers)?; + if !allow_typed_pointers && ty != &Type::Ptr { + return Err(QirSmithError::ModelGeneration(format!( + "inttoptr target type {ty} is outside the supported opaque-pointer checked subset" + ))); + } + Ok(()) + } + Operand::GetElementPtr { + ty, + ptr_ty, + indices, + .. + } => { + if allow_typed_pointers { + validate_checked_type(ty, allow_typed_pointers)?; + validate_checked_type(ptr_ty, allow_typed_pointers)?; + for idx_op in indices { + validate_checked_operand(idx_op, allow_typed_pointers)?; + } + Ok(()) + } else { + Err(QirSmithError::ModelGeneration( + "getelementptr operands are outside the supported opaque-pointer checked subset" + .to_string(), + )) + } + } + } +} + +fn validate_checked_instruction( + instruction: &Instruction, + allow_typed_pointers: bool, +) -> Result<(), QirSmithError> { + match instruction { + Instruction::Ret(value) => { + if let Some(value) = value { + validate_checked_operand(value, allow_typed_pointers)?; + } + Ok(()) + } + Instruction::Br { cond_ty, cond, .. } => { + validate_checked_type(cond_ty, allow_typed_pointers)?; + validate_checked_operand(cond, allow_typed_pointers) + } + Instruction::Jump { .. } => Ok(()), + Instruction::BinOp { ty, lhs, rhs, .. } + | Instruction::ICmp { ty, lhs, rhs, .. 
} + | Instruction::FCmp { ty, lhs, rhs, .. } => { + validate_checked_type(ty, allow_typed_pointers)?; + validate_checked_operand(lhs, allow_typed_pointers)?; + validate_checked_operand(rhs, allow_typed_pointers) + } + Instruction::Cast { + from_ty, + to_ty, + value, + .. + } => { + validate_checked_type(from_ty, allow_typed_pointers)?; + validate_checked_type(to_ty, allow_typed_pointers)?; + validate_checked_operand(value, allow_typed_pointers) + } + Instruction::Call { + return_ty, + args, + attr_refs, + .. + } => { + if let Some(return_ty) = return_ty { + validate_checked_type(return_ty, allow_typed_pointers)?; + } + for (ty, operand) in args { + validate_checked_type(ty, allow_typed_pointers)?; + validate_checked_operand(operand, allow_typed_pointers)?; + } + if !attr_refs.is_empty() { + return Err(QirSmithError::ModelGeneration( + "call attribute references are outside the supported v1 checked subset" + .to_string(), + )); + } + Ok(()) + } + Instruction::Phi { .. } => Err(QirSmithError::ModelGeneration( + "phi instructions are outside the supported v1 checked subset".to_string(), + )), + Instruction::Alloca { .. } => Err(QirSmithError::ModelGeneration( + "alloca instructions are outside the supported v1 checked subset".to_string(), + )), + Instruction::Load { .. } => Err(QirSmithError::ModelGeneration( + "load instructions are outside the supported v1 checked subset".to_string(), + )), + Instruction::Store { .. } => Err(QirSmithError::ModelGeneration( + "store instructions are outside the supported v1 checked subset".to_string(), + )), + Instruction::Select { .. } => Err(QirSmithError::ModelGeneration( + "select instructions are outside the supported v1 checked subset".to_string(), + )), + Instruction::Switch { .. } => Err(QirSmithError::ModelGeneration( + "switch instructions are outside the supported v1 checked subset".to_string(), + )), + Instruction::GetElementPtr { .. 
} => Err(QirSmithError::ModelGeneration(
+            "getelementptr instructions are outside the supported v1 checked subset".to_string(),
+        )),
+        Instruction::Unreachable => Err(QirSmithError::ModelGeneration(
+            "unreachable instructions are outside the supported v1 checked subset".to_string(),
+        )),
+    }
+}
diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith/compare.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith/compare.rs
new file mode 100644
index 0000000000..acd2ef7c87
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith/compare.rs
@@ -0,0 +1,501 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use rustc_hash::FxHashSet;
+
+use crate::model::{Constant, Instruction, Module, Operand};
+
+use super::config::QirSmithError;
+
+pub(super) fn ensure_text_roundtrip_matches(
+    original: &Module,
+    reparsed: &Module,
+) -> Result<(), QirSmithError> {
+    if original == reparsed {
+        Ok(())
+    } else {
+        Err(QirSmithError::TextRoundTrip(
+            "text roundtrip changed module structure for the supported v1 subset".to_string(),
+        ))
+    }
+}
+
+fn is_synthetic_param_name(name: &str, param_index: usize) -> bool {
+    name.strip_prefix("param_")
+        .and_then(|suffix| suffix.parse::<usize>().ok())
+        .is_some_and(|index| index == param_index)
+}
+
+fn param_names_semantically_equal(
+    expected: Option<&str>,
+    actual: Option<&str>,
+    param_index: usize,
+) -> bool {
+    match (expected, actual) {
+        (None, None) => true,
+        (None, Some(name)) => is_synthetic_param_name(name, param_index),
+        (Some(expected_name), Some(actual_name)) => expected_name == actual_name,
+        (Some(_), None) => false,
+    }
+}
+
+fn local_ref_name(operand: &Operand) -> Option<&str> {
+    match operand {
+        Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => Some(name),
+        _ => None,
+    }
+}
+
+fn constants_semantically_equal(expected: &Constant, actual: &Constant) -> bool {
+    match (expected, actual) {
+        (Constant::CString(expected_text), Constant::CString(actual_text)) => {
expected_text == actual_text + } + (Constant::Int(expected_value), Constant::Int(actual_value)) => { + expected_value == actual_value + } + ( + Constant::Float(expected_ty, expected_value), + Constant::Float(actual_ty, actual_value), + ) => expected_ty == actual_ty && expected_value.to_bits() == actual_value.to_bits(), + (Constant::Null, Constant::Null) => true, + _ => false, + } +} + +fn optional_constants_semantically_equal( + expected: Option<&Constant>, + actual: Option<&Constant>, +) -> bool { + match (expected, actual) { + (None, None) => true, + (Some(expected_constant), Some(actual_constant)) => { + constants_semantically_equal(expected_constant, actual_constant) + } + _ => false, + } +} + +fn operands_semantically_equal(expected: &Operand, actual: &Operand) -> bool { + if let (Some(expected_name), Some(actual_name)) = + (local_ref_name(expected), local_ref_name(actual)) + { + return expected_name == actual_name; + } + + match (expected, actual) { + ( + Operand::IntConst(expected_ty, expected_value), + Operand::IntConst(actual_ty, actual_value), + ) => expected_ty == actual_ty && expected_value == actual_value, + ( + Operand::FloatConst(expected_ty, expected_value), + Operand::FloatConst(actual_ty, actual_value), + ) => expected_ty == actual_ty && expected_value.to_bits() == actual_value.to_bits(), + (Operand::NullPtr, Operand::NullPtr) => true, + ( + Operand::IntToPtr(expected_value, expected_ty), + Operand::IntToPtr(actual_value, actual_ty), + ) => expected_value == actual_value && expected_ty == actual_ty, + ( + Operand::GetElementPtr { + ty: expected_ty, + ptr: expected_ptr, + ptr_ty: expected_ptr_ty, + indices: expected_indices, + }, + Operand::GetElementPtr { + ty: actual_ty, + ptr: actual_ptr, + ptr_ty: actual_ptr_ty, + indices: actual_indices, + }, + ) => { + expected_ty == actual_ty + && expected_ptr == actual_ptr + && expected_ptr_ty == actual_ptr_ty + && expected_indices.len() == actual_indices.len() + && 
expected_indices.iter().zip(actual_indices.iter()).all( + |(expected_index, actual_index)| { + operands_semantically_equal(expected_index, actual_index) + }, + ) + } + (Operand::GlobalRef(expected_name), Operand::GlobalRef(actual_name)) => { + expected_name == actual_name + } + _ => false, + } +} + +fn optional_operands_semantically_equal( + expected: Option<&Operand>, + actual: Option<&Operand>, +) -> bool { + match (expected, actual) { + (None, None) => true, + (Some(expected_operand), Some(actual_operand)) => { + operands_semantically_equal(expected_operand, actual_operand) + } + _ => false, + } +} + +#[allow(clippy::too_many_lines)] +fn instructions_semantically_equal(expected: &Instruction, actual: &Instruction) -> bool { + match (expected, actual) { + (Instruction::Ret(expected_value), Instruction::Ret(actual_value)) => { + optional_operands_semantically_equal(expected_value.as_ref(), actual_value.as_ref()) + } + ( + Instruction::Br { + cond_ty: expected_cond_ty, + cond: expected_cond, + true_dest: expected_true_dest, + false_dest: expected_false_dest, + }, + Instruction::Br { + cond_ty: actual_cond_ty, + cond: actual_cond, + true_dest: actual_true_dest, + false_dest: actual_false_dest, + }, + ) => { + expected_cond_ty == actual_cond_ty + && expected_true_dest == actual_true_dest + && expected_false_dest == actual_false_dest + && operands_semantically_equal(expected_cond, actual_cond) + } + ( + Instruction::Jump { + dest: expected_dest, + }, + Instruction::Jump { dest: actual_dest }, + ) => expected_dest == actual_dest, + ( + Instruction::BinOp { + op: expected_op, + ty: expected_ty, + lhs: expected_lhs, + rhs: expected_rhs, + result: expected_result, + }, + Instruction::BinOp { + op: actual_op, + ty: actual_ty, + lhs: actual_lhs, + rhs: actual_rhs, + result: actual_result, + }, + ) => { + expected_op == actual_op + && expected_ty == actual_ty + && expected_result == actual_result + && operands_semantically_equal(expected_lhs, actual_lhs) + && 
operands_semantically_equal(expected_rhs, actual_rhs) + } + ( + Instruction::ICmp { + pred: expected_pred, + ty: expected_ty, + lhs: expected_lhs, + rhs: expected_rhs, + result: expected_result, + }, + Instruction::ICmp { + pred: actual_pred, + ty: actual_ty, + lhs: actual_lhs, + rhs: actual_rhs, + result: actual_result, + }, + ) => { + expected_pred == actual_pred + && expected_ty == actual_ty + && expected_result == actual_result + && operands_semantically_equal(expected_lhs, actual_lhs) + && operands_semantically_equal(expected_rhs, actual_rhs) + } + ( + Instruction::FCmp { + pred: expected_pred, + ty: expected_ty, + lhs: expected_lhs, + rhs: expected_rhs, + result: expected_result, + }, + Instruction::FCmp { + pred: actual_pred, + ty: actual_ty, + lhs: actual_lhs, + rhs: actual_rhs, + result: actual_result, + }, + ) => { + expected_pred == actual_pred + && expected_ty == actual_ty + && expected_result == actual_result + && operands_semantically_equal(expected_lhs, actual_lhs) + && operands_semantically_equal(expected_rhs, actual_rhs) + } + ( + Instruction::Cast { + op: expected_op, + from_ty: expected_from_ty, + to_ty: expected_to_ty, + value: expected_value, + result: expected_result, + }, + Instruction::Cast { + op: actual_op, + from_ty: actual_from_ty, + to_ty: actual_to_ty, + value: actual_value, + result: actual_result, + }, + ) => { + expected_op == actual_op + && expected_from_ty == actual_from_ty + && expected_to_ty == actual_to_ty + && expected_result == actual_result + && operands_semantically_equal(expected_value, actual_value) + } + ( + Instruction::Call { + return_ty: expected_return_ty, + callee: expected_callee, + args: expected_args, + result: expected_result, + attr_refs: expected_attr_refs, + }, + Instruction::Call { + return_ty: actual_return_ty, + callee: actual_callee, + args: actual_args, + result: actual_result, + attr_refs: actual_attr_refs, + }, + ) => { + expected_return_ty == actual_return_ty + && expected_callee == actual_callee + && 
expected_result == actual_result + && expected_attr_refs == actual_attr_refs + && expected_args.len() == actual_args.len() + && expected_args.iter().zip(actual_args.iter()).all( + |((expected_ty, expected_operand), (actual_ty, actual_operand))| { + expected_ty == actual_ty + && operands_semantically_equal(expected_operand, actual_operand) + }, + ) + } + _ => false, + } +} + +#[allow(clippy::too_many_lines)] +pub(super) fn assert_bitcode_roundtrip_matches_supported_v1_subset( + original: &Module, + reparsed: &Module, +) -> Result<(), QirSmithError> { + if original.source_filename != reparsed.source_filename { + return Err(QirSmithError::BitcodeRoundTrip( + "source_filename changed across the supported v1 bitcode roundtrip".to_string(), + )); + } + + if original.target_datalayout != reparsed.target_datalayout { + return Err(QirSmithError::BitcodeRoundTrip( + "target_datalayout changed across the supported v1 bitcode roundtrip".to_string(), + )); + } + + if original.target_triple != reparsed.target_triple { + return Err(QirSmithError::BitcodeRoundTrip( + "target_triple changed across the supported v1 bitcode roundtrip".to_string(), + )); + } + + if original.struct_types != reparsed.struct_types { + return Err(QirSmithError::BitcodeRoundTrip( + "struct types changed across the supported v1 bitcode roundtrip".to_string(), + )); + } + + if original.globals.len() != reparsed.globals.len() { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "global count changed across the supported v1 bitcode roundtrip: expected {}, found {}", + original.globals.len(), + reparsed.globals.len() + ))); + } + + for (global_index, (expected, actual)) in original + .globals + .iter() + .zip(reparsed.globals.iter()) + .enumerate() + { + if expected.name != actual.name { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "global {global_index} name changed across the supported v1 bitcode roundtrip" + ))); + } + if expected.linkage != actual.linkage { + return 
Err(QirSmithError::BitcodeRoundTrip(format!( + "global {global_index} linkage changed across the supported v1 bitcode roundtrip" + ))); + } + if expected.ty != actual.ty { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "global {global_index} type changed across the supported v1 bitcode roundtrip" + ))); + } + if expected.is_constant != actual.is_constant { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "global {global_index} mutability changed across the supported v1 bitcode roundtrip" + ))); + } + if !optional_constants_semantically_equal( + expected.initializer.as_ref(), + actual.initializer.as_ref(), + ) { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "global {global_index} initializer changed across the supported v1 bitcode roundtrip" + ))); + } + } + + if original.functions.len() != reparsed.functions.len() { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function count changed across the supported v1 bitcode roundtrip: expected {}, found {}", + original.functions.len(), + reparsed.functions.len() + ))); + } + + for (function_index, (expected, actual)) in original + .functions + .iter() + .zip(reparsed.functions.iter()) + .enumerate() + { + if expected.name != actual.name { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} name changed across the supported v1 bitcode roundtrip" + ))); + } + if expected.is_declaration != actual.is_declaration { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} declaration shape changed across the supported v1 bitcode roundtrip" + ))); + } + if expected.return_type != actual.return_type { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} return type changed across the supported v1 bitcode roundtrip" + ))); + } + if expected.attribute_group_refs != actual.attribute_group_refs { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} attribute_group_refs changed across 
the supported v1 bitcode roundtrip" + ))); + } + + if expected.params.len() != actual.params.len() { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} parameter count changed across the supported v1 bitcode roundtrip" + ))); + } + + for (param_index, (expected_param, actual_param)) in + expected.params.iter().zip(actual.params.iter()).enumerate() + { + if expected_param.ty != actual_param.ty { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} parameter {param_index} type changed across the supported v1 bitcode roundtrip" + ))); + } + if !param_names_semantically_equal( + expected_param.name.as_deref(), + actual_param.name.as_deref(), + param_index, + ) { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} parameter {param_index} name changed across the supported v1 bitcode roundtrip" + ))); + } + } + + if expected.basic_blocks.len() != actual.basic_blocks.len() { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} basic block count changed across the supported v1 bitcode roundtrip" + ))); + } + + for (block_index, (expected_block, actual_block)) in expected + .basic_blocks + .iter() + .zip(actual.basic_blocks.iter()) + .enumerate() + { + if expected_block.name != actual_block.name { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} block {block_index} name changed across the supported v1 bitcode roundtrip" + ))); + } + + if expected_block.instructions.len() != actual_block.instructions.len() { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "function {function_index} block {block_index} instruction count changed across the supported v1 bitcode roundtrip" + ))); + } + + for (instruction_index, (expected_instruction, actual_instruction)) in expected_block + .instructions + .iter() + .zip(actual_block.instructions.iter()) + .enumerate() + { + if !instructions_semantically_equal(expected_instruction, 
actual_instruction) {
+                    return Err(QirSmithError::BitcodeRoundTrip(format!(
+                        "function {function_index} block {block_index} instruction {instruction_index} changed across the supported v1 bitcode roundtrip"
+                    )));
+                }
+            }
+        }
+    }
+
+    let referenced_ids: FxHashSet<_> = original
+        .functions
+        .iter()
+        .flat_map(|function| function.attribute_group_refs.iter().copied())
+        .collect();
+    let original_referenced: Vec<_> = original
+        .attribute_groups
+        .iter()
+        .filter(|group| referenced_ids.contains(&group.id))
+        .collect();
+    let reparsed_referenced: Vec<_> = reparsed
+        .attribute_groups
+        .iter()
+        .filter(|group| referenced_ids.contains(&group.id))
+        .collect();
+    if original_referenced != reparsed_referenced {
+        return Err(QirSmithError::BitcodeRoundTrip(
+            "attribute_groups changed across the supported v1 bitcode roundtrip".to_string(),
+        ));
+    }
+
+    if original.named_metadata != reparsed.named_metadata {
+        return Err(QirSmithError::BitcodeRoundTrip(
+            "named_metadata changed across the supported v1 bitcode roundtrip".to_string(),
+        ));
+    }
+
+    if original.metadata_nodes != reparsed.metadata_nodes {
+        return Err(QirSmithError::BitcodeRoundTrip(
+            "metadata_nodes changed across the supported v1 bitcode roundtrip".to_string(),
+        ));
+    }
+
+    Ok(())
+}
diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith/config.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith/config.rs
new file mode 100644
index 0000000000..e36924cb1f
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith/config.rs
@@ -0,0 +1,204 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use miette::Diagnostic;
+use thiserror::Error;
+
+use crate::{model::Module, qir, validation::QirProfileError};
+
+pub(super) const DEFAULT_MAX_FUNCS: usize = 1;
+pub(super) const DEFAULT_MAX_BLOCKS_PER_FUNC: usize = 6;
+pub(super) const DEFAULT_MAX_INSTRS_PER_BLOCK: usize = 12;
+pub(super) const BASE_V1_BLOCK_COUNT: usize = 4;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum QirProfilePreset {
+    BaseV1,
+    AdaptiveV1,
+    AdaptiveV2,
+    BareRoundtrip,
+}
+
+impl QirProfilePreset {
+    #[must_use]
+    pub fn to_qir_profile(self) -> Option<qir::QirProfile> {
+        match self {
+            Self::BaseV1 => Some(qir::QirProfile::BaseV1),
+            Self::AdaptiveV1 => Some(qir::QirProfile::AdaptiveV1),
+            Self::AdaptiveV2 => Some(qir::QirProfile::AdaptiveV2),
+            Self::BareRoundtrip => None,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+pub enum OutputMode {
+    #[default]
+    Model,
+    Text,
+    Bitcode,
+    RoundTripChecked,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RoundTripKind {
+    TextOnly,
+    BitcodeOnly,
+    TextAndBitcodeSinglePass,
+}
+
+#[allow(clippy::struct_excessive_bools)]
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct QirSmithConfig {
+    pub profile: QirProfilePreset,
+    pub output_mode: OutputMode,
+    pub roundtrip: Option<RoundTripKind>,
+    pub max_funcs: usize,
+    pub max_blocks_per_func: usize,
+    pub max_instrs_per_block: usize,
+    pub allow_phi: bool,
+    pub allow_switch: bool,
+    pub allow_memory_ops: bool,
+    pub allow_typed_pointers: bool,
+    pub bare_roundtrip_mode: bool,
+}
+
+impl Default for QirSmithConfig {
+    fn default() -> Self {
+        Self::for_profile(QirProfilePreset::AdaptiveV2)
+    }
+}
+
+impl QirSmithConfig {
+    #[must_use]
+    pub const fn for_profile(profile: QirProfilePreset) -> Self {
+        let bare_roundtrip_mode = matches!(profile, QirProfilePreset::BareRoundtrip);
+        let allow_typed_pointers = matches!(
+            profile,
+            QirProfilePreset::BaseV1 | QirProfilePreset::AdaptiveV1
+        );
+        let max_blocks_per_func = match profile {
+            QirProfilePreset::BaseV1 =>
BASE_V1_BLOCK_COUNT, + _ => DEFAULT_MAX_BLOCKS_PER_FUNC, + }; + Self { + profile, + output_mode: OutputMode::Model, + roundtrip: None, + max_funcs: DEFAULT_MAX_FUNCS, + max_blocks_per_func, + max_instrs_per_block: DEFAULT_MAX_INSTRS_PER_BLOCK, + allow_phi: false, + allow_switch: false, + allow_memory_ops: false, + allow_typed_pointers, + bare_roundtrip_mode, + } + } + + #[must_use] + pub fn sanitize(&self) -> EffectiveConfig { + let profile = if self.bare_roundtrip_mode + || matches!(self.profile, QirProfilePreset::BareRoundtrip) + { + QirProfilePreset::BareRoundtrip + } else { + self.profile + }; + + let defaults = Self::for_profile(profile); + let output_mode = self.output_mode; + let roundtrip = if matches!(output_mode, OutputMode::RoundTripChecked) { + let default_kind = if matches!( + profile, + QirProfilePreset::BaseV1 | QirProfilePreset::AdaptiveV1 + ) { + RoundTripKind::TextOnly + } else { + RoundTripKind::TextAndBitcodeSinglePass + }; + Some(self.roundtrip.unwrap_or(default_kind)) + } else { + None + }; + + EffectiveConfig { + profile, + output_mode, + roundtrip, + max_funcs: super::sanitize_count(self.max_funcs, defaults.max_funcs), + max_blocks_per_func: super::sanitize_count( + self.max_blocks_per_func, + defaults.max_blocks_per_func, + ), + max_instrs_per_block: super::sanitize_count( + self.max_instrs_per_block, + defaults.max_instrs_per_block, + ), + allow_phi: self.allow_phi && !matches!(output_mode, OutputMode::RoundTripChecked), + allow_switch: self.allow_switch && !matches!(output_mode, OutputMode::RoundTripChecked), + allow_memory_ops: self.allow_memory_ops + && !matches!(output_mode, OutputMode::RoundTripChecked), + allow_typed_pointers: defaults.allow_typed_pointers, + bare_roundtrip_mode: defaults.bare_roundtrip_mode, + } + } + + pub(super) fn with_output_mode(&self, output_mode: OutputMode) -> Self { + let mut config = self.clone(); + config.output_mode = output_mode; + config + } +} + +#[allow(clippy::struct_excessive_bools)] 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct EffectiveConfig {
+    pub profile: QirProfilePreset,
+    pub output_mode: OutputMode,
+    pub roundtrip: Option<RoundTripKind>,
+    pub max_funcs: usize,
+    pub max_blocks_per_func: usize,
+    pub max_instrs_per_block: usize,
+    pub allow_phi: bool,
+    pub allow_switch: bool,
+    pub allow_memory_ops: bool,
+    pub allow_typed_pointers: bool,
+    pub bare_roundtrip_mode: bool,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct GeneratedArtifact {
+    pub effective_config: EffectiveConfig,
+    pub module: Module,
+    pub text: Option<String>,
+    pub bitcode: Option<Vec<u8>>,
+}
+
+#[derive(Debug, Error, Diagnostic, Clone, PartialEq, Eq)]
+pub enum QirSmithError {
+    #[error(
+        "qir_smith generation is not implemented yet for profile {profile:?} in {output_mode:?} mode"
+    )]
+    #[diagnostic(code("Qsc.Llvm.QirSmith.GenerationNotImplemented"))]
+    GenerationNotImplemented {
+        profile: QirProfilePreset,
+        output_mode: OutputMode,
+    },
+
+    #[error("qir_smith generated module is outside the supported v1 checked subset: {0}")]
+    #[diagnostic(code("Qsc.Llvm.QirSmith.ModelGenerationFailed"))]
+    ModelGeneration(String),
+
+    #[error("qir_smith text roundtrip failed: {0}")]
+    #[diagnostic(code("Qsc.Llvm.QirSmith.TextRoundTripFailed"))]
+    TextRoundTrip(String),
+
+    #[error("qir_smith bitcode roundtrip failed: {0}")]
+    #[diagnostic(code("Qsc.Llvm.QirSmith.BitcodeRoundTripFailed"))]
+    BitcodeRoundTrip(String),
+
+    #[error(transparent)]
+    #[diagnostic(transparent)]
+    ProfileViolation(QirProfileError),
+}
diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith/generator.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith/generator.rs
new file mode 100644
index 0000000000..25397a7a37
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith/generator.rs
@@ -0,0 +1,1697 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+ +use arbitrary::Unstructured; +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{ + model::{ + BasicBlock, BinOpKind, CastKind, Constant, FloatPredicate, Function, GlobalVariable, + Instruction, IntPredicate, Linkage, Module, Operand, Param, Type, + }, + qir, +}; + +use super::{ + config::{BASE_V1_BLOCK_COUNT, EffectiveConfig, QirProfilePreset}, + metadata::{build_qdk_attribute_groups, build_qdk_metadata}, +}; + +const MAX_SHELL_COUNT: usize = 4; +const MAX_OPTIONAL_GLOBALS: usize = 2; +pub(super) const GENERATED_TARGET_DATALAYOUT: &str = "e-p:64:64"; +pub(super) const GENERATED_TARGET_TRIPLE: &str = "arm64-apple-macosx15.0.0"; +const MODELED_INTEGER_INITIALIZERS: [i64; 4] = [-1, 0, 1, 2]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) struct ShellPreset { + include_qdk_shell: bool, + default_include_declarations: bool, + default_include_globals: bool, +} + +impl ShellPreset { + pub(super) fn from_profile(profile: QirProfilePreset) -> Self { + match profile { + QirProfilePreset::BaseV1 + | QirProfilePreset::AdaptiveV1 + | QirProfilePreset::AdaptiveV2 => Self { + include_qdk_shell: true, + default_include_declarations: true, + default_include_globals: true, + }, + QirProfilePreset::BareRoundtrip => Self { + include_qdk_shell: false, + default_include_declarations: false, + default_include_globals: false, + }, + } + } + + fn entry_point_attr_refs(self) -> Vec { + if self.include_qdk_shell { + vec![qir::ENTRY_POINT_ATTR_GROUP_ID] + } else { + Vec::new() + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub(super) struct ShellCounts { + pub(super) required_num_qubits: usize, + pub(super) required_num_results: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ShellDeclaration { + Hadamard, + ControlledX, + Measure, + ArrayRecordOutput, + ResultRecordOutput, + ResultArrayRecordOutput, + TupleRecordOutput, + BoolRecordOutput, + IntRecordOutput, + DoubleRecordOutput, + QubitAllocate, + QubitRelease, + Initialize, + 
ReadResult,
+}
+
+const SHELL_DECLARATIONS: [ShellDeclaration; 14] = [
+    ShellDeclaration::Hadamard,
+    ShellDeclaration::ControlledX,
+    ShellDeclaration::Measure,
+    ShellDeclaration::ArrayRecordOutput,
+    ShellDeclaration::ResultRecordOutput,
+    ShellDeclaration::ResultArrayRecordOutput,
+    ShellDeclaration::TupleRecordOutput,
+    ShellDeclaration::BoolRecordOutput,
+    ShellDeclaration::IntRecordOutput,
+    ShellDeclaration::DoubleRecordOutput,
+    ShellDeclaration::QubitAllocate,
+    ShellDeclaration::QubitRelease,
+    ShellDeclaration::Initialize,
+    ShellDeclaration::ReadResult,
+];
+
+#[allow(clippy::struct_field_names)]
+#[derive(Debug, Default)]
+pub(super) struct StableNameAllocator {
+    global_index: usize,
+    block_index: usize,
+    local_index: usize,
+}
+
+impl StableNameAllocator {
+    pub(super) fn next_global_name(&mut self) -> String {
+        let name = self.global_index.to_string();
+        self.global_index += 1;
+        name
+    }
+
+    pub(super) fn next_block_name(&mut self) -> String {
+        let name = format!("block_{}", self.block_index);
+        self.block_index += 1;
+        name
+    }
+
+    pub(super) fn next_local_name(&mut self) -> String {
+        let name = format!("var_{}", self.local_index);
+        self.local_index += 1;
+        name
+    }
+}
+
+#[derive(Debug, Clone)]
+struct CfgBlockPlan {
+    name: String,
+    predecessors: Vec<String>,
+    terminator: BlockTerminator,
+}
+
+#[derive(Debug, Clone)]
+enum BlockTerminator {
+    Ret,
+    Jump {
+        dest: String,
+    },
+    Branch {
+        true_dest: String,
+        false_dest: String,
+    },
+    Switch {
+        ty: Type,
+        default_dest: String,
+        cases: Vec<(i64, String)>,
+    },
+    Unreachable,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+struct MemorySlot {
+    ty: Type,
+    ptr_ty: Type,
+    ptr: Operand,
+}
+
+#[derive(Debug, Clone, Default)]
+pub(super) struct TypedValuePool {
+    values: FxHashMap<Type, Vec<Operand>>,
+    memory_slots: Vec<MemorySlot>,
+}
+
+impl TypedValuePool {
+    fn add(&mut self, ty: Type, operand: Operand) {
+        let entry = self.values.entry(ty).or_default();
+        if !entry.contains(&operand) {
+            entry.push(operand);
+        }
+    }
+
+    
fn has_values(&self, ty: &Type) -> bool { + self.values.get(ty).is_some_and(|values| !values.is_empty()) + } + + fn has_local(&self, ty: &Type) -> bool { + self.values.get(ty).is_some_and(|values| { + values.iter().any(|operand| { + matches!(operand, Operand::LocalRef(_) | Operand::TypedLocalRef(_, _)) + }) + }) + } + + fn has_memory_slots(&self) -> bool { + !self.memory_slots.is_empty() + } + + fn add_memory_slot(&mut self, ty: Type, ptr_ty: Type, ptr: Operand) { + self.add(ptr_ty.clone(), ptr.clone()); + + let slot = MemorySlot { ty, ptr_ty, ptr }; + if !self.memory_slots.contains(&slot) { + self.memory_slots.push(slot); + } + } + + fn choose_memory_slot(&self, bytes: &mut Unstructured<'_>) -> Option { + choose_index(bytes, self.memory_slots.len()).map(|index| self.memory_slots[index].clone()) + } + + fn choose( + &self, + ty: &Type, + bytes: &mut Unstructured<'_>, + prefer_locals: bool, + ) -> Option { + let values = self.values.get(ty)?; + if values.is_empty() { + return None; + } + + if prefer_locals { + let locals: Vec<_> = values + .iter() + .filter(|operand| { + matches!(operand, Operand::LocalRef(_) | Operand::TypedLocalRef(_, _)) + }) + .cloned() + .collect(); + if let Some(index) = choose_index(bytes, locals.len()) { + return Some(locals[index].clone()); + } + } + + choose_index(bytes, values.len()).map(|index| values[index].clone()) + } + + fn choose_ptr_operand( + &self, + ty: &Type, + bytes: &mut Unstructured<'_>, + prefer_globals: bool, + ) -> Option { + let values = self.values.get(ty)?; + let filtered: Vec<_> = values + .iter() + .filter(|operand| { + let is_global = is_global_operand(operand); + if prefer_globals { + is_global + } else { + !is_global + } + }) + .cloned() + .collect(); + + if let Some(index) = choose_index(bytes, filtered.len()) { + return Some(filtered[index].clone()); + } + + self.choose(ty, bytes, false) + } + + fn intersection(pools: &[&Self]) -> Self { + let Some((first, rest)) = pools.split_first() else { + return 
Self::default(); + }; + + let mut intersection = (*first).clone(); + for pool in rest { + intersection.retain_common(pool); + } + intersection + } + + fn retain_common(&mut self, other: &Self) { + self.values.retain(|ty, operands| { + operands.retain(|operand| other.contains(ty, operand)); + !operands.is_empty() + }); + self.memory_slots + .retain(|slot| other.memory_slots.contains(slot)); + } + + fn contains(&self, ty: &Type, operand: &Operand) -> bool { + self.values + .get(ty) + .is_some_and(|operands| operands.contains(operand)) + } +} + +#[derive(Debug, Clone)] +struct CallTarget { + name: String, + return_ty: Option, + params: Vec, +} + +impl From<&Function> for CallTarget { + fn from(function: &Function) -> Self { + Self { + name: function.name.clone(), + return_ty: (function.return_type != Type::Void).then(|| function.return_type.clone()), + params: function + .params + .iter() + .map(|param| param.ty.clone()) + .collect(), + } + } +} + +#[derive(Debug, Clone, Copy)] +enum BodyInstructionKind { + Call, + I64BinOp, + FloatBinOp, + ICmp, + FCmp, + Zext, + SIToFP, + FPToSI, + Alloca, + Load, + Store, + Select, + GetElementPtr, +} + +const BODY_INSTRUCTION_KINDS: [BodyInstructionKind; 8] = [ + BodyInstructionKind::Call, + BodyInstructionKind::I64BinOp, + BodyInstructionKind::FloatBinOp, + BodyInstructionKind::ICmp, + BodyInstructionKind::FCmp, + BodyInstructionKind::Zext, + BodyInstructionKind::SIToFP, + BodyInstructionKind::FPToSI, +]; + +const BASE_BODY_INSTRUCTION_KINDS: [BodyInstructionKind; 1] = [BodyInstructionKind::Call]; + +const MEMORY_BODY_INSTRUCTION_KINDS: [BodyInstructionKind; 5] = [ + BodyInstructionKind::Load, + BodyInstructionKind::Store, + BodyInstructionKind::Select, + BodyInstructionKind::GetElementPtr, + BodyInstructionKind::Alloca, +]; + +const I64_BINOPS: [BinOpKind; 6] = [ + BinOpKind::Add, + BinOpKind::Sub, + BinOpKind::Mul, + BinOpKind::And, + BinOpKind::Or, + BinOpKind::Xor, +]; + +const FLOAT_BINOPS: [BinOpKind; 3] = [BinOpKind::Fadd, 
BinOpKind::Fsub, BinOpKind::Fmul]; + +const FLOAT_SCALAR_TYPES: [Type; 3] = [Type::Half, Type::Float, Type::Double]; + +const INT_PREDICATES: [IntPredicate; 6] = [ + IntPredicate::Eq, + IntPredicate::Ne, + IntPredicate::Slt, + IntPredicate::Sle, + IntPredicate::Sgt, + IntPredicate::Sge, +]; + +const FLOAT_PREDICATES: [FloatPredicate; 6] = [ + FloatPredicate::Oeq, + FloatPredicate::One, + FloatPredicate::Olt, + FloatPredicate::Ole, + FloatPredicate::Ogt, + FloatPredicate::Oge, +]; + +#[derive(Debug)] +pub(super) struct QirGenState { + pub(super) module: Module, + preset: ShellPreset, + profile: QirProfilePreset, + shell_counts: ShellCounts, + typed_pointers: bool, + names: StableNameAllocator, + declaration_registry: FxHashSet, + global_registry: FxHashMap, +} + +impl QirGenState { + pub(super) fn new( + preset: ShellPreset, + shell_counts: ShellCounts, + profile: QirProfilePreset, + bytes: &mut Unstructured<'_>, + ) -> Self { + let qir_profile = profile.to_qir_profile(); + let typed_pointers = matches!( + profile, + QirProfilePreset::BaseV1 | QirProfilePreset::AdaptiveV1 + ); + let (attribute_groups, named_metadata, metadata_nodes, struct_types) = + if let Some(qir_profile) = qir_profile { + let (named_metadata, metadata_nodes) = build_qdk_metadata(qir_profile, bytes); + ( + build_qdk_attribute_groups(qir_profile, shell_counts), + named_metadata, + metadata_nodes, + qir_profile.struct_types(), + ) + } else { + (Vec::new(), Vec::new(), Vec::new(), Vec::new()) + }; + + Self { + module: Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types, + globals: Vec::new(), + functions: Vec::new(), + attribute_groups, + named_metadata, + metadata_nodes, + }, + preset, + profile, + shell_counts, + typed_pointers, + names: StableNameAllocator::default(), + declaration_registry: FxHashSet::default(), + global_registry: FxHashMap::default(), + } + } + + fn build(mut self, effective: &EffectiveConfig, bytes: &mut Unstructured<'_>) -> 
Module { + self.add_optional_declarations(bytes); + self.add_optional_globals(effective, bytes); + let entry_point = self.build_entry_point(effective, bytes); + self.module.functions.insert(0, entry_point); + self.module + } + + fn populate_target_headers(&mut self, effective: &EffectiveConfig) { + if !should_emit_target_headers(effective) { + return; + } + + self.module.target_datalayout = Some(GENERATED_TARGET_DATALAYOUT.to_string()); + self.module.target_triple = Some(GENERATED_TARGET_TRIPLE.to_string()); + } + + fn build_entry_point( + &mut self, + effective: &EffectiveConfig, + bytes: &mut Unstructured<'_>, + ) -> Function { + let include_initialize = self.preset.include_qdk_shell; + if include_initialize { + self.register_declaration(ShellDeclaration::Initialize); + } + + if self.preset.include_qdk_shell { + self.register_declaration(ShellDeclaration::TupleRecordOutput); + } + + if matches!( + effective.profile, + QirProfilePreset::AdaptiveV1 | QirProfilePreset::AdaptiveV2 + ) { + self.register_declaration(ShellDeclaration::ReadResult); + } + + if matches!(effective.profile, QirProfilePreset::AdaptiveV2) { + self.register_declaration(ShellDeclaration::ResultArrayRecordOutput); + } + + let cfg = self.plan_entry_cfg(effective, bytes); + let call_targets = self.collect_call_targets(); + let mut basic_blocks = Vec::with_capacity(cfg.len()); + let mut exit_pools = Vec::with_capacity(cfg.len()); + + for (block_index, plan) in cfg.iter().enumerate() { + let predecessor_pools: Vec<_> = plan + .predecessors + .iter() + .map(|&index| &exit_pools[index]) + .collect(); + let predecessor_names: Vec<_> = plan + .predecessors + .iter() + .map(|&index| cfg[index].name.clone()) + .collect(); + + let mut pool = if block_index == 0 { + self.build_base_value_pool() + } else { + TypedValuePool::intersection(&predecessor_pools) + }; + + let require_nontrivial_body = block_index == 0 + || (self.profile == QirProfilePreset::BaseV1 && block_index + 1 < cfg.len()); + + let mut 
instructions = self.build_block_instructions( + effective, + plan, + &predecessor_names, + &predecessor_pools, + &call_targets, + &mut pool, + effective.max_instrs_per_block, + bytes, + require_nontrivial_body, + ); + + if block_index == 0 && include_initialize { + let null_arg = Operand::NullPtr; + let init_call = Instruction::Call { + return_ty: None, + callee: qir::rt::INITIALIZE.to_string(), + args: vec![( + if self.typed_pointers { + Type::TypedPtr(Box::new(Type::Integer(8))) + } else { + Type::Ptr + }, + null_arg, + )], + result: None, + attr_refs: Vec::new(), + }; + instructions.insert(0, init_call); + } + + exit_pools.push(pool); + basic_blocks.push(BasicBlock { + name: plan.name.clone(), + instructions, + }); + } + + Function { + name: qir::ENTRYPOINT_NAME.to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: self.preset.entry_point_attr_refs(), + basic_blocks, + } + } + + fn collect_call_targets(&self) -> Vec { + self.module + .functions + .iter() + .filter(|function| function.is_declaration) + .map(CallTarget::from) + .collect() + } + + fn pointer_result_type(&self, pointee_ty: &Type) -> Type { + if self.typed_pointers { + Type::TypedPtr(Box::new(pointee_ty.clone())) + } else { + Type::Ptr + } + } + + pub(super) fn build_base_value_pool(&self) -> TypedValuePool { + let mut pool = TypedValuePool::default(); + + for value in [0, 1] { + pool.add(Type::Integer(1), Operand::IntConst(Type::Integer(1), value)); + } + for value in [-1, 0, 1, 2, 3] { + pool.add( + Type::Integer(64), + Operand::IntConst(Type::Integer(64), value), + ); + } + for ty in FLOAT_SCALAR_TYPES { + for value in [0.0, 1.0, -1.0, 2.5] { + pool.add(ty.clone(), Operand::float_const(ty.clone(), value)); + } + } + + if self.typed_pointers { + let qubit_ty = Type::NamedPtr("Qubit".to_string()); + let result_ty = Type::NamedPtr("Result".to_string()); + let i8_ptr_ty = Type::TypedPtr(Box::new(Type::Integer(8))); + + for value in 0..=2 
{ + pool.add(qubit_ty.clone(), Operand::int_to_named_ptr(value, "Qubit")); + pool.add( + result_ty.clone(), + Operand::int_to_named_ptr(value, "Result"), + ); + } + for global in self + .module + .globals + .iter() + .filter(|global| is_string_label_global(global)) + { + let array_ty = global.ty.clone(); + pool.add( + i8_ptr_ty.clone(), + Operand::GetElementPtr { + ty: array_ty.clone(), + ptr: global.name.clone(), + ptr_ty: array_ty, + indices: vec![ + Operand::IntConst(Type::Integer(64), 0), + Operand::IntConst(Type::Integer(64), 0), + ], + }, + ); + } + } else { + pool.add(Type::Ptr, Operand::NullPtr); + for value in 0..=2 { + pool.add(Type::Ptr, Operand::IntToPtr(value, Type::Ptr)); + } + for global in self + .module + .globals + .iter() + .filter(|global| is_string_label_global(global)) + { + pool.add(Type::Ptr, Operand::GlobalRef(global.name.clone())); + } + } + + pool + } + + fn plan_entry_cfg( + &mut self, + effective: &EffectiveConfig, + bytes: &mut Unstructured<'_>, + ) -> Vec { + let block_count = + if self.profile == QirProfilePreset::BaseV1 && !has_expanded_generation(effective) { + BASE_V1_BLOCK_COUNT + } else if effective.max_blocks_per_func <= 1 { + 1 + } else { + bytes + .int_in_range(2..=effective.max_blocks_per_func) + .unwrap_or(2) + }; + + let mut names = Vec::with_capacity(block_count); + for _ in 0..block_count { + names.push(self.names.next_block_name()); + } + + let mut predecessors = vec![Vec::new(); block_count]; + let mut terminators = Vec::with_capacity(block_count); + for block_index in 0..block_count { + if block_index + 1 >= block_count { + if effective.allow_switch && take_flag(bytes, false) { + terminators.push(BlockTerminator::Unreachable); + } else { + terminators.push(BlockTerminator::Ret); + } + continue; + } + + let next = block_index + 1; + let can_branch = block_index + 2 < block_count + && (self.profile != QirProfilePreset::BaseV1 || has_expanded_generation(effective)); + let can_switch = effective.allow_switch && 
block_index + 2 < block_count; + let default_branch = block_index == 0 && can_branch; + + if can_switch && take_flag(bytes, false) { + predecessors[next].push(block_index); + predecessors[next + 1].push(block_index); + + let mut cases = vec![(0, names[next + 1].clone())]; + if block_index + 3 < block_count && take_flag(bytes, false) { + let extra_dest = block_count - 1; + if extra_dest > next + 1 { + predecessors[extra_dest].push(block_index); + cases.push((1, names[extra_dest].clone())); + } + } + + terminators.push(BlockTerminator::Switch { + ty: Type::Integer(64), + default_dest: names[next].clone(), + cases, + }); + } else if can_branch && take_flag(bytes, default_branch) { + let false_dest = bytes + .int_in_range((next + 1)..=block_count - 1) + .unwrap_or(block_count - 1); + predecessors[next].push(block_index); + predecessors[false_dest].push(block_index); + terminators.push(BlockTerminator::Branch { + true_dest: names[next].clone(), + false_dest: names[false_dest].clone(), + }); + } else { + predecessors[next].push(block_index); + terminators.push(BlockTerminator::Jump { + dest: names[next].clone(), + }); + } + } + + names + .into_iter() + .zip(predecessors) + .zip(terminators) + .map(|((name, predecessors), terminator)| CfgBlockPlan { + name, + predecessors, + terminator, + }) + .collect() + } + + #[allow(clippy::too_many_arguments)] + fn build_block_instructions( + &mut self, + effective: &EffectiveConfig, + plan: &CfgBlockPlan, + predecessor_names: &[String], + predecessor_pools: &[&TypedValuePool], + call_targets: &[CallTarget], + pool: &mut TypedValuePool, + max_instrs_per_block: usize, + bytes: &mut Unstructured<'_>, + require_nontrivial_body: bool, + ) -> Vec { + let body_budget = max_instrs_per_block.saturating_sub(1); + let min_body_instructions = usize::from(require_nontrivial_body && body_budget > 0); + let body_instruction_count = + take_body_instruction_count(bytes, body_budget, min_body_instructions); + let mut instructions = 
Vec::with_capacity(body_instruction_count + 3); + + if effective.allow_phi + && body_budget > instructions.len() + && let Some(instruction) = + self.build_phi_instruction(predecessor_names, predecessor_pools, pool, bytes) + { + instructions.push(instruction); + } + + if effective.allow_memory_ops + && !pool.has_memory_slots() + && body_budget > instructions.len() + && let Some(instruction) = self.build_alloca_instruction(pool) + { + instructions.push(instruction); + } + + for _ in 0..body_instruction_count { + if instructions.len() >= body_budget { + break; + } + + if let Some(instruction) = + self.build_body_instruction(effective, call_targets, pool, bytes) + { + instructions.push(instruction); + } + } + + if require_nontrivial_body && body_budget > 0 && instructions.is_empty() { + let fallback = if effective.allow_memory_ops { + self.build_alloca_instruction(pool) + .or_else(|| self.build_i64_binop_instruction(pool, bytes)) + } else if self.profile == QirProfilePreset::BaseV1 + && !has_expanded_generation(effective) + { + self.build_call_instruction(call_targets, pool, bytes) + } else { + self.build_i64_binop_instruction(pool, bytes) + .or_else(|| self.build_call_instruction(call_targets, pool, bytes)) + }; + if let Some(instruction) = fallback { + instructions.push(instruction); + } + } + + if matches!(plan.terminator, BlockTerminator::Branch { .. 
}) + && !pool.has_local(&Type::Integer(1)) + && instructions.len() < body_budget + && let Some(instruction) = self.build_compare_instruction(pool, bytes) + { + instructions.push(instruction); + } + + instructions.push(Self::build_terminator(plan, pool, bytes)); + instructions + } + + fn build_body_instruction( + &mut self, + effective: &EffectiveConfig, + call_targets: &[CallTarget], + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let mut kinds: Vec<_> = match self.profile { + QirProfilePreset::BaseV1 if !has_expanded_generation(effective) => { + BASE_BODY_INSTRUCTION_KINDS.to_vec() + } + _ => BODY_INSTRUCTION_KINDS.to_vec(), + }; + + if effective.allow_memory_ops { + kinds.splice(0..0, MEMORY_BODY_INSTRUCTION_KINDS); + } + + let start = choose_index(bytes, kinds.len()).unwrap_or(0); + for offset in 0..kinds.len() { + let kind = kinds[(start + offset) % kinds.len()]; + let instruction = match kind { + BodyInstructionKind::Call => self.build_call_instruction(call_targets, pool, bytes), + BodyInstructionKind::I64BinOp => self.build_i64_binop_instruction(pool, bytes), + BodyInstructionKind::FloatBinOp => self.build_float_binop_instruction(pool, bytes), + BodyInstructionKind::ICmp => self.build_icmp_instruction(pool, bytes), + BodyInstructionKind::FCmp => self.build_fcmp_instruction(pool, bytes), + BodyInstructionKind::Zext => self.build_zext_instruction(pool, bytes), + BodyInstructionKind::SIToFP => self.build_sitofp_instruction(pool, bytes), + BodyInstructionKind::FPToSI => self.build_fptosi_instruction(pool, bytes), + BodyInstructionKind::Alloca => self.build_alloca_instruction(pool), + BodyInstructionKind::Load => self.build_load_instruction(pool, bytes), + BodyInstructionKind::Store => Self::build_store_instruction(pool, bytes), + BodyInstructionKind::Select => self.build_select_instruction(pool, bytes), + BodyInstructionKind::GetElementPtr => self.build_gep_instruction(pool, bytes), + }; + if instruction.is_some() { + return 
instruction; + } + } + + None + } + + fn build_call_instruction( + &mut self, + call_targets: &[CallTarget], + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let candidates: Vec<_> = call_targets + .iter() + .filter(|target| { + self.profile != QirProfilePreset::BaseV1 || is_base_v1_safe_call_target(target) + }) + .filter(|target| target.params.iter().all(|ty| pool.has_values(ty))) + .collect(); + let target = candidates.get(choose_index(bytes, candidates.len())?)?; + + let mut args = Vec::with_capacity(target.params.len()); + for (param_index, ty) in target.params.iter().enumerate() { + let operand = if is_pointer_type(ty) { + pool.choose_ptr_operand( + ty, + bytes, + prefers_global_label_arg(&target.name, param_index), + )? + } else { + pool.choose(ty, bytes, false)? + }; + args.push((ty.clone(), operand)); + } + + let (return_ty, result) = if let Some(ty) = &target.return_ty { + let result_name = self.names.next_local_name(); + pool.add(ty.clone(), Operand::LocalRef(result_name.clone())); + (Some(ty.clone()), Some(result_name)) + } else { + (None, None) + }; + + Some(Instruction::Call { + return_ty, + callee: target.name.clone(), + args, + result, + attr_refs: Vec::new(), + }) + } + + fn build_i64_binop_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let op = choose_from_slice(bytes, &I64_BINOPS)?; + let lhs = pool.choose(&Type::Integer(64), bytes, true)?; + let rhs = pool.choose(&Type::Integer(64), bytes, false)?; + let result = self.names.next_local_name(); + pool.add(Type::Integer(64), Operand::LocalRef(result.clone())); + + Some(Instruction::BinOp { + op, + ty: Type::Integer(64), + lhs, + rhs, + result, + }) + } + + pub(super) fn build_float_binop_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let op = choose_from_slice(bytes, &FLOAT_BINOPS)?; + let ty = choose_available_floating_type(pool, bytes)?; + let lhs = 
pool.choose(&ty, bytes, true)?; + let rhs = pool.choose(&ty, bytes, false)?; + let result = self.names.next_local_name(); + pool.add(ty.clone(), Operand::LocalRef(result.clone())); + + Some(Instruction::BinOp { + op, + ty, + lhs, + rhs, + result, + }) + } + + fn build_icmp_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let pred = choose_from_slice(bytes, &INT_PREDICATES)?; + let lhs = pool.choose(&Type::Integer(64), bytes, true)?; + let rhs = pool.choose(&Type::Integer(64), bytes, false)?; + let result = self.names.next_local_name(); + pool.add(Type::Integer(1), Operand::LocalRef(result.clone())); + + Some(Instruction::ICmp { + pred, + ty: Type::Integer(64), + lhs, + rhs, + result, + }) + } + + pub(super) fn build_fcmp_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let pred = choose_from_slice(bytes, &FLOAT_PREDICATES)?; + let ty = choose_available_floating_type(pool, bytes)?; + let lhs = pool.choose(&ty, bytes, true)?; + let rhs = pool.choose(&ty, bytes, false)?; + let result = self.names.next_local_name(); + pool.add(Type::Integer(1), Operand::LocalRef(result.clone())); + + Some(Instruction::FCmp { + pred, + ty, + lhs, + rhs, + result, + }) + } + + fn build_compare_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + if take_flag(bytes, true) { + self.build_icmp_instruction(pool, bytes) + .or_else(|| self.build_fcmp_instruction(pool, bytes)) + } else { + self.build_fcmp_instruction(pool, bytes) + .or_else(|| self.build_icmp_instruction(pool, bytes)) + } + } + + fn build_zext_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let value = pool.choose(&Type::Integer(1), bytes, true)?; + let result = self.names.next_local_name(); + pool.add(Type::Integer(64), Operand::LocalRef(result.clone())); + + Some(Instruction::Cast { + op: CastKind::Zext, + 
from_ty: Type::Integer(1), + to_ty: Type::Integer(64), + value, + result, + }) + } + + pub(super) fn build_sitofp_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let to_ty = choose_available_floating_type(pool, bytes)?; + let value = pool.choose(&Type::Integer(64), bytes, true)?; + let result = self.names.next_local_name(); + pool.add(to_ty.clone(), Operand::LocalRef(result.clone())); + + Some(Instruction::Cast { + op: CastKind::Sitofp, + from_ty: Type::Integer(64), + to_ty, + value, + result, + }) + } + + pub(super) fn build_fptosi_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let from_ty = choose_available_floating_type(pool, bytes)?; + let value = pool.choose(&from_ty, bytes, true)?; + let result = self.names.next_local_name(); + pool.add(Type::Integer(64), Operand::LocalRef(result.clone())); + + Some(Instruction::Cast { + op: CastKind::Fptosi, + from_ty, + to_ty: Type::Integer(64), + value, + result, + }) + } + + fn build_phi_instruction( + &mut self, + predecessor_names: &[String], + predecessor_pools: &[&TypedValuePool], + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + if predecessor_pools.len() < 2 + || predecessor_names.len() != predecessor_pools.len() + || !predecessor_pools + .iter() + .all(|pred_pool| pred_pool.has_values(&Type::Integer(64))) + { + return None; + } + + let mut incoming = Vec::with_capacity(predecessor_pools.len()); + for (pred_name, pred_pool) in predecessor_names.iter().zip(predecessor_pools.iter()) { + incoming.push(( + pred_pool.choose(&Type::Integer(64), bytes, true)?, + pred_name.clone(), + )); + } + + let result = self.names.next_local_name(); + pool.add(Type::Integer(64), Operand::LocalRef(result.clone())); + + Some(Instruction::Phi { + ty: Type::Integer(64), + incoming, + result, + }) + } + + fn build_alloca_instruction(&mut self, pool: &mut TypedValuePool) -> Option { + let ty = 
Type::Integer(64); + let result = self.names.next_local_name(); + let ptr_ty = self.pointer_result_type(&ty); + let ptr = Operand::LocalRef(result.clone()); + pool.add_memory_slot(ty.clone(), ptr_ty, ptr); + + Some(Instruction::Alloca { ty, result }) + } + + fn build_load_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let slot = pool.choose_memory_slot(bytes)?; + let result = self.names.next_local_name(); + pool.add(slot.ty.clone(), Operand::LocalRef(result.clone())); + + Some(Instruction::Load { + ty: slot.ty, + ptr_ty: slot.ptr_ty, + ptr: slot.ptr, + result, + }) + } + + fn build_store_instruction( + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let slot = pool.choose_memory_slot(bytes)?; + let value = pool.choose(&slot.ty, bytes, true)?; + + Some(Instruction::Store { + ty: slot.ty, + value, + ptr_ty: slot.ptr_ty, + ptr: slot.ptr, + }) + } + + fn build_select_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let result = self.names.next_local_name(); + let cond = pool.choose(&Type::Integer(1), bytes, true)?; + let true_val = pool.choose(&Type::Integer(64), bytes, true)?; + let false_val = pool.choose(&Type::Integer(64), bytes, false)?; + pool.add(Type::Integer(64), Operand::LocalRef(result.clone())); + + Some(Instruction::Select { + cond, + true_val, + false_val, + ty: Type::Integer(64), + result, + }) + } + + fn build_gep_instruction( + &mut self, + pool: &mut TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Option { + let slot = pool.choose_memory_slot(bytes)?; + let result = self.names.next_local_name(); + let ptr = Operand::LocalRef(result.clone()); + let indices = vec![Operand::IntConst(Type::Integer(64), 0)]; + pool.add_memory_slot(slot.ty.clone(), slot.ptr_ty.clone(), ptr); + + Some(Instruction::GetElementPtr { + inbounds: true, + pointee_ty: slot.ty, + ptr_ty: slot.ptr_ty, + ptr: slot.ptr, + indices, + 
result, + }) + } + + fn build_terminator( + plan: &CfgBlockPlan, + pool: &TypedValuePool, + bytes: &mut Unstructured<'_>, + ) -> Instruction { + match &plan.terminator { + BlockTerminator::Ret => Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))), + BlockTerminator::Jump { dest } => Instruction::Jump { dest: dest.clone() }, + BlockTerminator::Branch { + true_dest, + false_dest, + } => Instruction::Br { + cond_ty: Type::Integer(1), + cond: pool + .choose(&Type::Integer(1), bytes, true) + .unwrap_or(Operand::IntConst(Type::Integer(1), 1)), + true_dest: true_dest.clone(), + false_dest: false_dest.clone(), + }, + BlockTerminator::Switch { + ty, + default_dest, + cases, + } => Instruction::Switch { + ty: ty.clone(), + value: pool + .choose(ty, bytes, true) + .unwrap_or_else(|| Operand::IntConst(ty.clone(), 0)), + default_dest: default_dest.clone(), + cases: cases.clone(), + }, + BlockTerminator::Unreachable => Instruction::Unreachable, + } + } + + fn add_optional_declarations(&mut self, bytes: &mut Unstructured<'_>) { + for declaration in SHELL_DECLARATIONS { + match declaration { + ShellDeclaration::DoubleRecordOutput => { + if matches!( + self.profile, + QirProfilePreset::AdaptiveV1 | QirProfilePreset::AdaptiveV2 + ) { + continue; + } + } + ShellDeclaration::BoolRecordOutput + | ShellDeclaration::IntRecordOutput + | ShellDeclaration::QubitAllocate + | ShellDeclaration::QubitRelease => { + if self.profile == QirProfilePreset::BaseV1 { + continue; + } + } + _ => {} + } + if take_flag(bytes, self.preset.default_include_declarations) { + self.register_declaration(declaration); + } + } + } + + fn register_declaration(&mut self, declaration: ShellDeclaration) { + let name = declaration_name(declaration); + if self.declaration_registry.insert(name.to_string()) { + self.module + .functions + .push(build_declaration(declaration, self.typed_pointers)); + } + } + + fn add_optional_globals(&mut self, effective: &EffectiveConfig, bytes: &mut Unstructured<'_>) { + if 
!take_flag(bytes, self.preset.default_include_globals) { + return; + } + + if self.preset.include_qdk_shell { + self.add_qdk_label_globals(); + } else { + let count = take_optional_global_count(bytes); + for index in 0..count { + self.intern_string_global(build_bare_global_string(bytes, index)); + } + } + + if supports_modeled_global_generation(effective.profile) { + self.add_modeled_opaque_globals(bytes); + } + } + + fn add_qdk_label_globals(&mut self) { + self.intern_string_global("0_a".to_string()); + + let result_labels = self + .shell_counts + .required_num_results + .min(MAX_OPTIONAL_GLOBALS); + for result_index in 0..result_labels { + self.intern_string_global(format!("{}_a{}r", result_index + 1, result_index)); + } + } + + fn add_modeled_opaque_globals(&mut self, bytes: &mut Unstructured<'_>) { + if self.typed_pointers { + return; + } + + self.push_fresh_global(Type::Integer(64), Linkage::External, false, None); + self.push_fresh_global(Type::Ptr, Linkage::Internal, false, Some(Constant::Null)); + + let value = choose_from_slice(bytes, &MODELED_INTEGER_INITIALIZERS).unwrap_or(0); + self.push_fresh_global( + Type::Integer(64), + Linkage::Internal, + true, + Some(Constant::Int(value)), + ); + } + + fn push_fresh_global( + &mut self, + ty: Type, + linkage: Linkage, + is_constant: bool, + initializer: Option, + ) -> String { + let name = self.names.next_global_name(); + self.module.globals.push(GlobalVariable { + name: name.clone(), + ty, + linkage, + is_constant, + initializer, + }); + + name + } + + pub(super) fn intern_string_global(&mut self, value: String) -> String { + if let Some(name) = self.global_registry.get(&value) { + return name.clone(); + } + + let array_len = u64::try_from(value.len() + 1).expect("CString length should fit in u64"); + let name = self.push_fresh_global( + Type::Array(array_len, Box::new(Type::Integer(8))), + Linkage::Internal, + true, + Some(Constant::CString(value.clone())), + ); + self.global_registry.insert(value, 
name.clone()); + + name + } +} + +pub(super) fn build_module_shell( + effective: &EffectiveConfig, + bytes: &mut Unstructured<'_>, +) -> Module { + let preset = ShellPreset::from_profile(effective.profile); + let shell_counts = take_shell_counts(preset, bytes); + let mut state = QirGenState::new(preset, shell_counts, effective.profile, bytes); + state.populate_target_headers(effective); + state.build(effective, bytes) +} + +fn take_body_instruction_count( + bytes: &mut Unstructured<'_>, + max_body_instructions: usize, + min_body_instructions: usize, +) -> usize { + if max_body_instructions == 0 { + return 0; + } + + bytes + .int_in_range(min_body_instructions..=max_body_instructions) + .unwrap_or(min_body_instructions) +} + +fn choose_index(bytes: &mut Unstructured<'_>, len: usize) -> Option { + if len == 0 { + return None; + } + + Some(bytes.int_in_range(0..=len - 1).unwrap_or(0)) +} + +fn choose_from_slice(bytes: &mut Unstructured<'_>, values: &[T]) -> Option { + choose_index(bytes, values.len()).map(|index| values[index].clone()) +} + +fn choose_available_floating_type( + pool: &TypedValuePool, + bytes: &mut Unstructured<'_>, +) -> Option { + let available: Vec<_> = FLOAT_SCALAR_TYPES + .into_iter() + .filter(|ty| pool.has_values(ty)) + .collect(); + choose_from_slice(bytes, &available) +} + +fn is_pointer_type(ty: &Type) -> bool { + matches!(ty, Type::Ptr | Type::NamedPtr(_) | Type::TypedPtr(_)) +} + +fn is_global_operand(op: &Operand) -> bool { + matches!(op, Operand::GlobalRef(_) | Operand::GetElementPtr { .. 
}) +} + +fn is_string_label_global(global: &GlobalVariable) -> bool { + global.is_constant + && matches!( + (&global.ty, &global.initializer), + (Type::Array(_, element), Some(Constant::CString(_))) + if element.as_ref() == &Type::Integer(8) + ) +} + +fn should_emit_target_headers(effective: &EffectiveConfig) -> bool { + matches!( + ( + effective.profile, + effective.output_mode, + effective.roundtrip + ), + ( + QirProfilePreset::BareRoundtrip, + super::OutputMode::RoundTripChecked, + Some(super::RoundTripKind::BitcodeOnly) + ) + ) +} + +fn supports_modeled_global_generation(profile: QirProfilePreset) -> bool { + matches!( + profile, + QirProfilePreset::AdaptiveV2 | QirProfilePreset::BareRoundtrip + ) +} + +fn has_expanded_generation(effective: &EffectiveConfig) -> bool { + effective.allow_phi || effective.allow_switch || effective.allow_memory_ops +} + +fn prefers_global_label_arg(callee: &str, param_index: usize) -> bool { + qir::output_label_arg_index(callee) == Some(param_index) +} + +fn is_base_v1_safe_call_target(target: &CallTarget) -> bool { + matches!( + target.name.as_str(), + qir::qis::H | qir::qis::CX | qir::qis::M + ) +} + +fn take_shell_counts(preset: ShellPreset, bytes: &mut Unstructured<'_>) -> ShellCounts { + if !preset.include_qdk_shell { + return ShellCounts::default(); + } + + ShellCounts { + required_num_qubits: take_small_count(bytes, 0), + required_num_results: take_small_count(bytes, 0), + } +} + +fn take_small_count(bytes: &mut Unstructured<'_>, default: usize) -> usize { + bytes.int_in_range(0..=MAX_SHELL_COUNT).unwrap_or(default) +} + +fn take_optional_global_count(bytes: &mut Unstructured<'_>) -> usize { + bytes.int_in_range(0..=MAX_OPTIONAL_GLOBALS).unwrap_or(0) +} + +fn take_flag(bytes: &mut Unstructured<'_>, default: bool) -> bool { + bytes.arbitrary::().unwrap_or(default) +} + +fn build_bare_global_string(bytes: &mut Unstructured<'_>, index: usize) -> String { + let suffix_len = bytes.int_in_range(3..=6_usize).unwrap_or(4); + let mut 
value = format!("g{index}_"); + + for _ in 0..suffix_len { + let symbol = bytes.int_in_range(0..=35_u8).unwrap_or(0); + let ch = if symbol < 26 { + char::from(b'a' + symbol) + } else { + char::from(b'0' + (symbol - 26)) + }; + value.push(ch); + } + + value +} + +fn declaration_name(declaration: ShellDeclaration) -> &'static str { + match declaration { + ShellDeclaration::Hadamard => qir::qis::H, + ShellDeclaration::ControlledX => qir::qis::CX, + ShellDeclaration::Measure => qir::qis::M, + ShellDeclaration::ArrayRecordOutput => qir::rt::ARRAY_RECORD_OUTPUT, + ShellDeclaration::ResultRecordOutput => qir::rt::RESULT_RECORD_OUTPUT, + ShellDeclaration::ResultArrayRecordOutput => qir::rt::RESULT_ARRAY_RECORD_OUTPUT, + ShellDeclaration::TupleRecordOutput => qir::rt::TUPLE_RECORD_OUTPUT, + ShellDeclaration::BoolRecordOutput => qir::rt::BOOL_RECORD_OUTPUT, + ShellDeclaration::IntRecordOutput => qir::rt::INT_RECORD_OUTPUT, + ShellDeclaration::DoubleRecordOutput => qir::rt::DOUBLE_RECORD_OUTPUT, + ShellDeclaration::QubitAllocate => qir::rt::QUBIT_ALLOCATE, + ShellDeclaration::QubitRelease => qir::rt::QUBIT_RELEASE, + ShellDeclaration::Initialize => qir::rt::INITIALIZE, + ShellDeclaration::ReadResult => qir::rt::READ_RESULT, + } +} + +#[allow(clippy::too_many_lines)] +fn build_declaration(declaration: ShellDeclaration, typed_pointers: bool) -> Function { + let qubit_ptr = if typed_pointers { + Type::NamedPtr("Qubit".to_string()) + } else { + Type::Ptr + }; + let result_ptr = if typed_pointers { + Type::NamedPtr("Result".to_string()) + } else { + Type::Ptr + }; + let i8_ptr = if typed_pointers { + Type::TypedPtr(Box::new(Type::Integer(8))) + } else { + Type::Ptr + }; + + match declaration { + ShellDeclaration::Hadamard => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![Param { + ty: qubit_ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + 
ShellDeclaration::ControlledX => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: qubit_ptr.clone(), + name: None, + }, + Param { + ty: qubit_ptr, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::Measure => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: qubit_ptr, + name: None, + }, + Param { + ty: result_ptr, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: vec![qir::IRREVERSIBLE_ATTR_GROUP_ID], + basic_blocks: Vec::new(), + }, + ShellDeclaration::ResultRecordOutput => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: result_ptr, + name: None, + }, + Param { + ty: i8_ptr, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::ResultArrayRecordOutput => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Integer(64), + name: None, + }, + Param { + ty: i8_ptr.clone(), + name: None, + }, + Param { + ty: i8_ptr.clone(), + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::ArrayRecordOutput + | ShellDeclaration::TupleRecordOutput + | ShellDeclaration::IntRecordOutput => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Integer(64), + name: None, + }, + Param { + ty: i8_ptr.clone(), + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::BoolRecordOutput => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: 
vec![ + Param { + ty: Type::Integer(1), + name: None, + }, + Param { + ty: i8_ptr.clone(), + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::DoubleRecordOutput => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Double, + name: None, + }, + Param { + ty: i8_ptr.clone(), + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::QubitAllocate => Function { + name: declaration_name(declaration).to_string(), + return_type: if typed_pointers { + Type::NamedPtr("Qubit".to_string()) + } else { + Type::Ptr + }, + params: vec![Param { + ty: i8_ptr.clone(), + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::QubitRelease => Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![Param { + ty: if typed_pointers { + Type::NamedPtr("Qubit".to_string()) + } else { + Type::Ptr + }, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + ShellDeclaration::Initialize => { + let ptr_param = if typed_pointers { + Type::TypedPtr(Box::new(Type::Integer(8))) + } else { + Type::Ptr + }; + Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Void, + params: vec![Param { + ty: ptr_param, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + } + } + ShellDeclaration::ReadResult => { + let param_ty = if typed_pointers { + Type::NamedPtr("Result".to_string()) + } else { + Type::Ptr + }; + Function { + name: declaration_name(declaration).to_string(), + return_type: Type::Integer(1), + params: vec![Param { + ty: param_ty, + name: None, + }], + is_declaration: true, + 
attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + } + } + } +} diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith/io.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith/io.rs new file mode 100644 index 0000000000..1b12e94e8a --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith/io.rs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::{ + bitcode::{reader::parse_bitcode_compatibility_report, writer::try_write_bitcode}, + model::Module, + text::{reader::parse_module, writer::write_module_to_string}, +}; + +use super::config::QirSmithError; + +pub(super) fn emit_text(module: &Module) -> String { + write_module_to_string(module) +} + +pub(super) fn parse_text_roundtrip(text: &str) -> Result { + parse_module(text).map_err(QirSmithError::TextRoundTrip) +} + +pub(super) fn emit_bitcode(module: &Module) -> Result, QirSmithError> { + try_write_bitcode(module).map_err(|error| { + QirSmithError::ModelGeneration(format!("bitcode emission failed: {error}")) + }) +} + +pub(super) fn parse_bitcode_roundtrip(bitcode: &[u8]) -> Result { + let report = parse_bitcode_compatibility_report(bitcode).map_err(|diagnostics| { + QirSmithError::BitcodeRoundTrip(format_read_diagnostics(&diagnostics)) + })?; + + if !report.diagnostics.is_empty() { + return Err(QirSmithError::BitcodeRoundTrip(format!( + "compatibility diagnostics were reported during bitcode import: {}", + format_read_diagnostics(&report.diagnostics) + ))); + } + + Ok(report.module) +} + +pub(super) fn format_read_diagnostics(diagnostics: &[crate::ReadDiagnostic]) -> String { + diagnostics + .iter() + .map(ToString::to_string) + .collect::>() + .join("; ") +} diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith/metadata.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith/metadata.rs new file mode 100644 index 0000000000..ef45d16127 --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith/metadata.rs @@ -0,0 +1,324 @@ +// Copyright (c) 
Microsoft Corporation. +// Licensed under the MIT License. + +use arbitrary::Unstructured; +use rustc_hash::FxHashMap; + +use crate::{ + model::{Attribute, AttributeGroup, MetadataNode, MetadataValue, Module, NamedMetadata, Type}, + qir, +}; + +use crate::fuzz::qir_smith::generator::ShellCounts; + +const SUPPORTED_FLOAT_COMPUTATIONS: [&str; 3] = ["half", "float", "double"]; + +pub(super) fn build_qdk_attribute_groups( + profile: qir::QirProfile, + shell_counts: ShellCounts, +) -> Vec { + vec![ + AttributeGroup { + id: qir::ENTRY_POINT_ATTR_GROUP_ID, + attributes: vec![ + Attribute::StringAttr(qir::ENTRY_POINT_ATTR.to_string()), + Attribute::StringAttr(qir::OUTPUT_LABELING_SCHEMA_ATTR.to_string()), + Attribute::KeyValue( + qir::QIR_PROFILES_ATTR.to_string(), + profile.profile_name().to_string(), + ), + Attribute::KeyValue( + qir::REQUIRED_NUM_QUBITS_ATTR.to_string(), + shell_counts.required_num_qubits.to_string(), + ), + Attribute::KeyValue( + qir::REQUIRED_NUM_RESULTS_ATTR.to_string(), + shell_counts.required_num_results.to_string(), + ), + ], + }, + AttributeGroup { + id: qir::IRREVERSIBLE_ATTR_GROUP_ID, + attributes: vec![Attribute::StringAttr(qir::IRREVERSIBLE_ATTR.to_string())], + }, + ] +} + +pub(super) fn build_qdk_metadata( + profile: qir::QirProfile, + _bytes: &mut Unstructured<'_>, +) -> (Vec, Vec) { + let mut metadata_nodes = vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_ERROR), + MetadataValue::String(qir::QIR_MAJOR_VERSION_KEY.to_string()), + MetadataValue::Int(Type::Integer(32), profile.major_version()), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_MAX), + MetadataValue::String(qir::QIR_MINOR_VERSION_KEY.to_string()), + MetadataValue::Int(Type::Integer(32), profile.minor_version()), + ], + }, + MetadataNode { + id: 2, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_ERROR), + 
MetadataValue::String(qir::DYNAMIC_QUBIT_MGMT_KEY.to_string()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }, + MetadataNode { + id: 3, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_ERROR), + MetadataValue::String(qir::DYNAMIC_RESULT_MGMT_KEY.to_string()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }, + ]; + + if matches!(profile, qir::QirProfile::AdaptiveV2) { + metadata_nodes.push(MetadataNode { + id: 4, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_APPEND), + MetadataValue::String(qir::INT_COMPUTATIONS_KEY.to_string()), + MetadataValue::SubList(vec![MetadataValue::String("i64".to_string())]), + ], + }); + metadata_nodes.push(MetadataNode { + id: 5, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_APPEND), + MetadataValue::String(qir::FLOAT_COMPUTATIONS_KEY.to_string()), + supported_float_computation_metadata(), + ], + }); + metadata_nodes.push(MetadataNode { + id: 6, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_MAX), + MetadataValue::String(qir::BACKWARDS_BRANCHING_KEY.to_string()), + MetadataValue::Int(Type::Integer(2), 3), + ], + }); + metadata_nodes.push(MetadataNode { + id: 7, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_ERROR), + MetadataValue::String(qir::ARRAYS_KEY.to_string()), + MetadataValue::Int(Type::Integer(1), 1), + ], + }); + } + + if profile == qir::QirProfile::AdaptiveV1 { + let next_id = + u32::try_from(metadata_nodes.len()).expect("metadata node count should fit in u32"); + metadata_nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_APPEND), + MetadataValue::String(qir::INT_COMPUTATIONS_KEY.to_string()), + MetadataValue::SubList(vec![MetadataValue::String("i64".to_string())]), + ], + }); + + let next_id = + u32::try_from(metadata_nodes.len()).expect("metadata node count should fit in u32"); + metadata_nodes.push(MetadataNode 
{ + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_APPEND), + MetadataValue::String(qir::FLOAT_COMPUTATIONS_KEY.to_string()), + supported_float_computation_metadata(), + ], + }); + } + + let node_refs: Vec = metadata_nodes.iter().map(|n| n.id).collect(); + + ( + vec![NamedMetadata { + name: qir::MODULE_FLAGS_NAME.to_string(), + node_refs, + }], + metadata_nodes, + ) +} + +fn supported_float_computation_metadata() -> MetadataValue { + MetadataValue::SubList( + SUPPORTED_FLOAT_COMPUTATIONS + .iter() + .map(|width| MetadataValue::String((*width).to_string())) + .collect(), + ) +} + +pub(super) fn finalize_float_computations(module: &mut Module) { + let analysis = qir::inspect::analyze_float_surface(module); + let float_flag_node_id = find_module_flag_node_id(module, qir::FLOAT_COMPUTATIONS_KEY); + + if !analysis.has_float_op { + if let Some(node_id) = float_flag_node_id { + remove_module_flag_node(module, node_id); + } + return; + } + + let metadata_value = MetadataValue::SubList( + analysis + .surface_width_names() + .into_iter() + .map(|width| MetadataValue::String(width.to_string())) + .collect(), + ); + + if let Some(node_id) = float_flag_node_id + && let Some(node) = module + .metadata_nodes + .iter_mut() + .find(|candidate| candidate.id == node_id) + { + let behavior = node.values.first().cloned().unwrap_or(MetadataValue::Int( + Type::Integer(32), + qir::FLAG_BEHAVIOR_APPEND, + )); + node.values = vec![ + behavior, + MetadataValue::String(qir::FLOAT_COMPUTATIONS_KEY.to_string()), + metadata_value, + ]; + return; + } + + if !module + .named_metadata + .iter() + .any(|metadata| metadata.name == qir::MODULE_FLAGS_NAME) + { + return; + } + + let node_id = next_metadata_node_id(module); + module.metadata_nodes.push(MetadataNode { + id: node_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), qir::FLAG_BEHAVIOR_APPEND), + MetadataValue::String(qir::FLOAT_COMPUTATIONS_KEY.to_string()), + metadata_value, + ], + }); + + if 
let Some(module_flags) = module + .named_metadata + .iter_mut() + .find(|metadata| metadata.name == qir::MODULE_FLAGS_NAME) + { + module_flags.node_refs.push(node_id); + } +} + +fn find_module_flag_node_id(module: &Module, key: &str) -> Option { + let module_flags = module + .named_metadata + .iter() + .find(|metadata| metadata.name == qir::MODULE_FLAGS_NAME)?; + + for &node_ref in &module_flags.node_refs { + let Some(node) = module + .metadata_nodes + .iter() + .find(|candidate| candidate.id == node_ref) + else { + continue; + }; + if node.values.len() >= 2 + && let MetadataValue::String(flag_name) = &node.values[1] + && flag_name == key + { + return Some(node_ref); + } + } + + None +} + +fn next_metadata_node_id(module: &Module) -> u32 { + module + .metadata_nodes + .iter() + .map(|node| node.id) + .max() + .map_or(0, |id| id.saturating_add(1)) +} + +fn remap_metadata_value_node_refs(value: &mut MetadataValue, id_remap: &FxHashMap) { + match value { + MetadataValue::NodeRef(node_id) => { + if let Some(remapped_id) = id_remap.get(node_id).copied() { + *node_id = remapped_id; + } + } + MetadataValue::SubList(values) => { + for child in values { + remap_metadata_value_node_refs(child, id_remap); + } + } + MetadataValue::Int(_, _) | MetadataValue::String(_) => {} + } +} + +fn renumber_metadata_nodes(module: &mut Module) { + let old_ids: Vec = module.metadata_nodes.iter().map(|node| node.id).collect(); + if old_ids + .iter() + .enumerate() + .all(|(index, &node_id)| node_id == u32::try_from(index).expect("invalid index value")) + { + return; + } + + let id_remap: FxHashMap = old_ids + .iter() + .enumerate() + .map(|(index, &old_id)| (old_id, u32::try_from(index).expect("invalid index value"))) + .collect(); + + for metadata in &mut module.named_metadata { + metadata.node_refs = metadata + .node_refs + .iter() + .filter_map(|node_id| id_remap.get(node_id).copied()) + .collect(); + } + + for node in &mut module.metadata_nodes { + node.id = id_remap[&node.id]; + for 
value in &mut node.values { + remap_metadata_value_node_refs(value, &id_remap); + } + } +} + +fn remove_module_flag_node(module: &mut Module, node_id: u32) { + if let Some(module_flags) = module + .named_metadata + .iter_mut() + .find(|metadata| metadata.name == qir::MODULE_FLAGS_NAME) + { + module_flags + .node_refs + .retain(|&candidate| candidate != node_id); + } + + module.metadata_nodes.retain(|node| node.id != node_id); + renumber_metadata_nodes(module); +} diff --git a/source/compiler/qsc_llvm/src/fuzz/qir_smith/tests.rs b/source/compiler/qsc_llvm/src/fuzz/qir_smith/tests.rs new file mode 100644 index 0000000000..6a597dc2cf --- /dev/null +++ b/source/compiler/qsc_llvm/src/fuzz/qir_smith/tests.rs @@ -0,0 +1,1655 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::*; +use crate::test_utils::{PointerProbe, assemble_text_ir, available_fast_matrix_lanes}; + +fn deterministic_seed_bytes() -> Vec { + (0_u8..=127).collect() +} + +fn opaque_adaptive_config() -> QirSmithConfig { + QirSmithConfig { + max_blocks_per_func: 4, + max_instrs_per_block: 6, + ..QirSmithConfig::default() + } +} + +fn bare_roundtrip_config() -> QirSmithConfig { + QirSmithConfig { + max_blocks_per_func: 4, + max_instrs_per_block: 6, + ..QirSmithConfig::for_profile(QirProfilePreset::BareRoundtrip) + } +} + +fn checked_effective_config(config: &QirSmithConfig, roundtrip: RoundTripKind) -> EffectiveConfig { + QirSmithConfig { + output_mode: OutputMode::RoundTripChecked, + roundtrip: Some(roundtrip), + ..config.clone() + } + .sanitize() +} + +fn adaptive_shell_state() -> QirGenState { + let seed_bytes = deterministic_seed_bytes(); + let mut unstructured = Unstructured::new(&seed_bytes); + QirGenState::new( + ShellPreset::from_profile(QirProfilePreset::AdaptiveV1), + ShellCounts::default(), + QirProfilePreset::AdaptiveV1, + &mut unstructured, + ) +} + +fn generated_entry_point(module: &Module) -> &Function { + module + .functions + .first() + 
.expect("generated module should include an entry point") +} + +fn metadata_string_list(module: &Module, key: &str) -> Option> { + match module.get_flag(key) { + Some(MetadataValue::SubList(items)) => Some( + items + .iter() + .filter_map(|value| match value { + MetadataValue::String(text) => Some(text.clone()), + _ => None, + }) + .collect(), + ), + _ => None, + } +} + +fn adaptive_v1_module_with_float_metadata_shell( + globals: Vec, + declarations: Vec, + instructions: Vec, +) -> Module { + let seed_bytes = [0_u8; 1]; + let mut unstructured = Unstructured::new(&seed_bytes); + let (named_metadata, metadata_nodes) = + build_qdk_metadata(qir::QirProfile::AdaptiveV1, &mut unstructured); + + let mut functions = declarations; + functions.push(Function { + name: qir::ENTRYPOINT_NAME.to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions, + }], + }); + + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals, + functions, + attribute_groups: Vec::new(), + named_metadata, + metadata_nodes, + } +} + +fn double_record_output_declaration() -> Function { + Function { + name: qir::rt::DOUBLE_RECORD_OUTPUT.to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Double, + name: None, + }, + Param { + ty: Type::Ptr, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + } +} + +#[test] +fn sanitize_clamps_counts_and_preserves_opt_in_expansion_flags_outside_checked_mode() { + let effective = QirSmithConfig { + output_mode: OutputMode::Text, + roundtrip: Some(RoundTripKind::BitcodeOnly), + max_funcs: 0, + max_blocks_per_func: 0, + max_instrs_per_block: 0, + allow_phi: true, + allow_switch: true, + allow_memory_ops: true, + allow_typed_pointers: true, + ..QirSmithConfig::default() + } + 
.sanitize(); + + assert_eq!(effective.profile, QirProfilePreset::AdaptiveV2); + assert_eq!(effective.output_mode, OutputMode::Text); + assert_eq!(effective.roundtrip, None); + assert_eq!(effective.max_funcs, DEFAULT_MAX_FUNCS); + assert_eq!(effective.max_blocks_per_func, DEFAULT_MAX_BLOCKS_PER_FUNC); + assert_eq!(effective.max_instrs_per_block, DEFAULT_MAX_INSTRS_PER_BLOCK); + assert!(effective.allow_phi); + assert!(effective.allow_switch); + assert!(effective.allow_memory_ops); + assert!(!effective.allow_typed_pointers); + assert!(!effective.bare_roundtrip_mode); +} + +#[test] +fn sanitize_clears_opt_in_expansion_flags_in_checked_mode() { + let effective = QirSmithConfig { + output_mode: OutputMode::RoundTripChecked, + allow_phi: true, + allow_switch: true, + allow_memory_ops: true, + ..bare_roundtrip_config() + } + .sanitize(); + + assert_eq!( + effective.roundtrip, + Some(RoundTripKind::TextAndBitcodeSinglePass) + ); + assert!(!effective.allow_phi); + assert!(!effective.allow_switch); + assert!(!effective.allow_memory_ops); +} + +#[test] +fn sanitize_promotes_bare_roundtrip_profile_and_defaults_checked_roundtrip() { + let effective = QirSmithConfig { + profile: QirProfilePreset::AdaptiveV2, + output_mode: OutputMode::RoundTripChecked, + roundtrip: None, + max_blocks_per_func: 2, + max_instrs_per_block: 3, + bare_roundtrip_mode: true, + ..QirSmithConfig::default() + } + .sanitize(); + + assert_eq!(effective.profile, QirProfilePreset::BareRoundtrip); + assert_eq!(effective.output_mode, OutputMode::RoundTripChecked); + assert_eq!( + effective.roundtrip, + Some(RoundTripKind::TextAndBitcodeSinglePass) + ); + assert_eq!(effective.max_blocks_per_func, 2); + assert_eq!(effective.max_instrs_per_block, 3); + assert!(effective.bare_roundtrip_mode); +} + +#[test] +fn stable_name_allocator_uses_monotonic_indices() { + let mut allocator = StableNameAllocator::default(); + + assert_eq!(allocator.next_global_name(), "0"); + assert_eq!(allocator.next_global_name(), "1"); + 
assert_eq!(allocator.next_block_name(), "block_0"); + assert_eq!(allocator.next_block_name(), "block_1"); + assert_eq!(allocator.next_local_name(), "var_0"); + assert_eq!(allocator.next_local_name(), "var_1"); +} + +#[test] +fn intern_string_global_reuses_existing_names() { + let seed_bytes = deterministic_seed_bytes(); + let mut unstructured = Unstructured::new(&seed_bytes); + let mut state = QirGenState::new( + ShellPreset::from_profile(QirProfilePreset::BareRoundtrip), + ShellCounts::default(), + QirProfilePreset::BareRoundtrip, + &mut unstructured, + ); + + let first = state.intern_string_global("alpha".to_string()); + let second = state.intern_string_global("alpha".to_string()); + let third = state.intern_string_global("beta".to_string()); + + assert_eq!(first, second); + assert_ne!(first, third); + assert_eq!(state.module.globals.len(), 2); + assert_eq!(state.module.globals[0].name, "0"); + assert_eq!(state.module.globals[1].name, "1"); + assert_eq!( + state.module.globals[0].initializer, + Some(Constant::CString("alpha".to_string())) + ); + assert_eq!( + state.module.globals[1].initializer, + Some(Constant::CString("beta".to_string())) + ); +} + +#[test] +fn generate_from_bytes_is_deterministic_for_same_seed_and_config() { + let seed_bytes = deterministic_seed_bytes(); + let config = opaque_adaptive_config(); + + let first = generate_from_bytes(&config, &seed_bytes) + .expect("generation should succeed for deterministic seed bytes"); + let second = generate_from_bytes(&config, &seed_bytes) + .expect("generation should succeed for deterministic seed bytes"); + + assert_eq!(first, second); +} + +#[test] +fn opaque_adaptive_modules_stay_within_safe_v1_shape() { + let seed_bytes = deterministic_seed_bytes(); + let module = generate_module_from_bytes(&opaque_adaptive_config(), &seed_bytes) + .expect("generation should produce a module"); + let entry_point = generated_entry_point(&module); + let float_analysis = crate::qir::inspect::analyze_float_surface(&module); 
+ + assert_eq!(entry_point.name, qir::ENTRYPOINT_NAME); + assert!(!entry_point.is_declaration); + assert_eq!(entry_point.return_type, Type::Integer(64)); + assert!(entry_point.params.is_empty()); + assert_eq!( + entry_point.attribute_group_refs, + vec![qir::ENTRY_POINT_ATTR_GROUP_ID] + ); + assert_eq!(module.attribute_groups.len(), 2); + assert_eq!(module.named_metadata.len(), 1); + assert_eq!( + module.metadata_nodes.len(), + 7 + usize::from(float_analysis.has_float_op) + ); + assert_eq!( + metadata_string_list(&module, qir::FLOAT_COMPUTATIONS_KEY), + float_analysis.has_float_op.then(|| { + float_analysis + .surface_width_names() + .into_iter() + .map(str::to_string) + .collect() + }) + ); + assert!( + module + .get_flag(qir::QIR_MAJOR_VERSION_KEY) + .is_some_and(|value| *value == MetadataValue::Int(Type::Integer(32), 2)) + ); + assert!( + module + .get_flag(qir::QIR_MINOR_VERSION_KEY) + .is_some_and(|value| *value == MetadataValue::Int(Type::Integer(32), 0)) + ); + + let block_names: Vec<_> = entry_point + .basic_blocks + .iter() + .map(|block| block.name.clone()) + .collect(); + let expected_block_names: Vec<_> = (0..block_names.len()) + .map(|index| format!("block_{index}")) + .collect(); + assert_eq!(block_names, expected_block_names); + + let global_names: Vec<_> = module + .globals + .iter() + .map(|global| global.name.clone()) + .collect(); + let expected_global_names: Vec<_> = (0..module.globals.len()) + .map(|index| index.to_string()) + .collect(); + assert_eq!(global_names, expected_global_names); + + for block in &entry_point.basic_blocks { + assert!( + !block.instructions.is_empty(), + "generated blocks should always end with a terminator" + ); + assert!(matches!( + block + .instructions + .last() + .expect("blocks should contain instructions"), + Instruction::Ret(_) | Instruction::Jump { .. } | Instruction::Br { .. } + )); + + for instruction in &block.instructions { + assert!(matches!( + instruction, + Instruction::Ret(_) + | Instruction::Jump { .. 
} + | Instruction::Br { .. } + | Instruction::Call { .. } + | Instruction::BinOp { .. } + | Instruction::ICmp { .. } + | Instruction::FCmp { .. } + | Instruction::Cast { .. } + )); + } + } +} + +#[test] +fn bare_roundtrip_profile_omits_qdk_shell_metadata() { + let seed_bytes = deterministic_seed_bytes(); + let module = generate_module_from_bytes(&bare_roundtrip_config(), &seed_bytes) + .expect("generation should produce a bare roundtrip module"); + let entry_point = generated_entry_point(&module); + + assert_eq!(entry_point.name, qir::ENTRYPOINT_NAME); + assert!(entry_point.attribute_group_refs.is_empty()); + assert!(module.attribute_groups.is_empty()); + assert!(module.named_metadata.is_empty()); + assert!(module.metadata_nodes.is_empty()); +} + +#[test] +fn checked_mode_reuses_model_generation_core() { + let seed_bytes = deterministic_seed_bytes(); + let config = bare_roundtrip_config(); + let checked_config = QirSmithConfig { + roundtrip: Some(RoundTripKind::TextAndBitcodeSinglePass), + ..config.clone() + }; + + let model = generate_module_from_bytes(&config, &seed_bytes) + .expect("model generation should succeed for deterministic seed bytes"); + let checked = generate_checked_from_bytes(&checked_config, &seed_bytes) + .expect("checked generation should succeed for deterministic seed bytes"); + + assert_eq!(checked.module, model); + assert_eq!( + checked.effective_config, + checked_effective_config(&config, RoundTripKind::TextAndBitcodeSinglePass) + ); + assert!(checked.text.is_some()); + assert!(checked.bitcode.is_some()); +} + +#[test] +fn checked_text_only_emits_text_without_bitcode() { + let checked = generate_checked_from_bytes( + &QirSmithConfig { + roundtrip: Some(RoundTripKind::TextOnly), + ..bare_roundtrip_config() + }, + &deterministic_seed_bytes(), + ) + .expect("checked text-only generation should succeed"); + + assert!(checked.text.is_some()); + assert!(checked.bitcode.is_none()); +} + +#[test] +fn checked_bitcode_only_emits_bitcode_without_text() 
{ + let checked = generate_checked_from_bytes( + &QirSmithConfig { + roundtrip: Some(RoundTripKind::BitcodeOnly), + ..bare_roundtrip_config() + }, + &deterministic_seed_bytes(), + ) + .expect("checked bitcode-only generation should succeed"); + + assert!(checked.text.is_none()); + assert!(checked.bitcode.is_some()); +} + +fn synthesize_placeholder_param_names(functions: &mut [Function]) { + for function in functions { + for (param_index, param) in function.params.iter_mut().enumerate() { + if param.name.is_none() { + param.name = Some(format!("param_{param_index}")); + } + } + } +} + +fn annotate_local_ref_operand(operand: &mut Operand, ty: &Type) { + match operand { + Operand::LocalRef(name) => { + *operand = Operand::TypedLocalRef(name.clone(), ty.clone()); + } + Operand::TypedLocalRef(_, actual_ty) => { + *actual_ty = ty.clone(); + } + Operand::GetElementPtr { indices, .. } => { + for index in indices { + annotate_local_ref_operand(index, &Type::Integer(64)); + } + } + Operand::IntConst(_, _) + | Operand::FloatConst(_, _) + | Operand::NullPtr + | Operand::IntToPtr(_, _) + | Operand::GlobalRef(_) => {} + } +} + +fn annotate_checked_subset_local_refs(functions: &mut [Function]) { + for function in functions { + let return_ty = function.return_type.clone(); + for block in &mut function.basic_blocks { + for instruction in &mut block.instructions { + match instruction { + Instruction::Ret(Some(value)) => { + annotate_local_ref_operand(value, &return_ty); + } + Instruction::Br { cond_ty, cond, .. } => { + annotate_local_ref_operand(cond, cond_ty); + } + Instruction::BinOp { ty, lhs, rhs, .. } + | Instruction::ICmp { ty, lhs, rhs, .. } + | Instruction::FCmp { ty, lhs, rhs, .. } => { + annotate_local_ref_operand(lhs, ty); + annotate_local_ref_operand(rhs, ty); + } + Instruction::Cast { from_ty, value, .. } => { + annotate_local_ref_operand(value, from_ty); + } + Instruction::Call { args, .. 
} => { + for (arg_ty, operand) in args { + annotate_local_ref_operand(operand, arg_ty); + } + } + Instruction::Ret(None) + | Instruction::Jump { .. } + | Instruction::Phi { .. } + | Instruction::Alloca { .. } + | Instruction::Load { .. } + | Instruction::Store { .. } + | Instruction::Select { .. } + | Instruction::Switch { .. } + | Instruction::GetElementPtr { .. } + | Instruction::Unreachable => {} + } + } + } + } +} + +fn checked_bitcode_semantic_fixture() -> Module { + Module { + source_filename: Some("checked_bitcode_fixture".to_string()), + target_datalayout: Some("e-p:64:64".to_string()), + target_triple: Some("arm64-apple-macosx15.0.0".to_string()), + struct_types: Vec::new(), + globals: vec![GlobalVariable { + name: "message".to_string(), + ty: Type::Array(6, Box::new(Type::Integer(8))), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::CString("hello".to_string())), + }], + functions: vec![ + Function { + name: "callee".to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Ptr, + name: None, + }, + Param { + ty: Type::Integer(64), + name: Some("count".to_string()), + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "test".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "callee".to_string(), + args: vec![ + (Type::Ptr, Operand::GlobalRef("message".to_string())), + (Type::Integer(64), Operand::IntConst(Type::Integer(64), 3)), + ], + result: None, + attr_refs: Vec::new(), + }, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "sum".to_string(), + }, + Instruction::ICmp { + pred: 
IntPredicate::Eq, + ty: Type::Integer(64), + lhs: Operand::LocalRef("sum".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 3), + result: "cond".to_string(), + }, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::LocalRef("cond".to_string()), + true_dest: "then".to_string(), + false_dest: "exit".to_string(), + }, + ], + }, + BasicBlock { + name: "then".to_string(), + instructions: vec![Instruction::Ret(Some(Operand::LocalRef( + "sum".to_string(), + )))], + }, + BasicBlock { + name: "exit".to_string(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }, + ], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } +} + +#[test] +fn checked_bitcode_equivalence_verifies_attrs_and_metadata() { + let seed_bytes = deterministic_seed_bytes(); + let original = generate_module_from_bytes(&opaque_adaptive_config(), &seed_bytes) + .expect("generation should produce a module"); + let mut reparsed = original.clone(); + synthesize_placeholder_param_names(&mut reparsed.functions); + + assert_bitcode_roundtrip_matches_supported_v1_subset(&original, &reparsed).expect( + "checked bitcode comparison should preserve attrs and metadata while allowing placeholder parameter names", + ); +} + +#[test] +fn checked_bitcode_equivalence_allows_typed_local_refs_in_instruction_payloads() { + let original = checked_bitcode_semantic_fixture(); + let mut reparsed = original.clone(); + synthesize_placeholder_param_names(&mut reparsed.functions); + annotate_checked_subset_local_refs(&mut reparsed.functions); + + assert_bitcode_roundtrip_matches_supported_v1_subset(&original, &reparsed).expect( + "checked bitcode comparison should allow typed local refs and placeholder parameter names when the instruction context fixes the type", + ); +} + +#[test] +fn checked_bitcode_rejects_missing_attribute_groups() { + let seed_bytes = deterministic_seed_bytes(); + let original = 
generate_module_from_bytes(&opaque_adaptive_config(), &seed_bytes) + .expect("generation should produce a module"); + let mut reparsed = original.clone(); + reparsed.attribute_groups.clear(); + + let err = assert_bitcode_roundtrip_matches_supported_v1_subset(&original, &reparsed) + .expect_err("should reject missing attribute groups"); + assert!(matches!( + err, + QirSmithError::BitcodeRoundTrip(message) if message.contains("attribute_groups") + )); +} + +#[test] +fn checked_bitcode_rejects_missing_metadata() { + let seed_bytes = deterministic_seed_bytes(); + let original = generate_module_from_bytes(&opaque_adaptive_config(), &seed_bytes) + .expect("generation should produce a module"); + let mut reparsed = original.clone(); + reparsed.named_metadata.clear(); + reparsed.metadata_nodes.clear(); + + let err = assert_bitcode_roundtrip_matches_supported_v1_subset(&original, &reparsed) + .expect_err("should reject missing metadata"); + assert!(matches!( + err, + QirSmithError::BitcodeRoundTrip(message) if message.contains("named_metadata") + )); +} + +#[test] +fn checked_bitcode_rejects_global_initializer_mismatch() { + let original = checked_bitcode_semantic_fixture(); + let mut reparsed = original.clone(); + reparsed.globals[0].initializer = Some(Constant::CString("hullo".to_string())); + + let err = assert_bitcode_roundtrip_matches_supported_v1_subset(&original, &reparsed) + .expect_err("should reject changed global initializers"); + assert!(matches!( + err, + QirSmithError::BitcodeRoundTrip(message) if message.contains("initializer") + )); +} + +#[test] +fn checked_bitcode_rejects_instruction_payload_mismatch() { + let original = checked_bitcode_semantic_fixture(); + let mut reparsed = original.clone(); + let Instruction::Call { callee, .. 
} = + &mut reparsed.functions[1].basic_blocks[0].instructions[0] + else { + panic!("fixture should start with a call instruction"); + }; + *callee = "other_callee".to_string(); + + let err = assert_bitcode_roundtrip_matches_supported_v1_subset(&original, &reparsed) + .expect_err("should reject changed instruction payloads"); + assert!(matches!( + err, + QirSmithError::BitcodeRoundTrip(message) + if message.contains("instruction 0 changed") + )); +} + +#[test] +fn checked_mode_reports_unsupported_v1_models_as_model_generation_errors() { + let seed_bytes = deterministic_seed_bytes(); + let base_config = bare_roundtrip_config(); + let mut module = generate_module_from_bytes(&base_config, &seed_bytes) + .expect("generation should produce a module"); + let incoming_block = module.functions[0].basic_blocks[0].name.clone(); + module.functions[0].basic_blocks[0].instructions.insert( + 0, + Instruction::Phi { + ty: Type::Integer(1), + incoming: vec![(Operand::IntConst(Type::Integer(1), 1), incoming_block)], + result: "var_phi".to_string(), + }, + ); + + let mut artifact = GeneratedArtifact { + effective_config: checked_effective_config(&base_config, RoundTripKind::TextOnly), + module, + text: None, + bitcode: None, + }; + + let err = populate_checked_artifact(&mut artifact) + .expect_err("checked mode should reject unsupported v1 model shapes"); + assert!(matches!( + err, + QirSmithError::ModelGeneration(message) if message.contains("phi") + )); +} + +#[test] +fn checked_text_mismatches_use_text_error_category() { + let seed_bytes = deterministic_seed_bytes(); + let original = generate_module_from_bytes(&bare_roundtrip_config(), &seed_bytes) + .expect("generation should produce a module"); + let mut reparsed = original.clone(); + reparsed.functions[0].name.push_str("_changed"); + + let err = ensure_text_roundtrip_matches(&original, &reparsed) + .expect_err("mismatched text roundtrip structure should fail"); + assert!(matches!( + err, + QirSmithError::TextRoundTrip(message) 
+ if message.contains("changed module structure") + )); +} + +#[test] +fn checked_bitcode_mismatches_use_bitcode_error_category() { + let seed_bytes = deterministic_seed_bytes(); + let original = generate_module_from_bytes(&bare_roundtrip_config(), &seed_bytes) + .expect("generation should produce a module"); + let mut reparsed = original.clone(); + reparsed.functions.pop(); + + let err = assert_bitcode_roundtrip_matches_supported_v1_subset(&original, &reparsed) + .expect_err("mismatched bitcode roundtrip structure should fail"); + assert!(matches!( + err, + QirSmithError::BitcodeRoundTrip(message) if message.contains("function count") + )); +} + +// --- BaseV1 / AdaptiveV1 module generation tests --- + +fn base_v1_config() -> QirSmithConfig { + QirSmithConfig { + max_blocks_per_func: BASE_V1_BLOCK_COUNT, + max_instrs_per_block: 6, + ..QirSmithConfig::for_profile(QirProfilePreset::BaseV1) + } +} + +fn adaptive_v1_config() -> QirSmithConfig { + QirSmithConfig { + max_blocks_per_func: 4, + max_instrs_per_block: 6, + ..QirSmithConfig::for_profile(QirProfilePreset::AdaptiveV1) + } +} + +fn control_flow_expansion_config() -> QirSmithConfig { + QirSmithConfig { + max_blocks_per_func: 5, + max_instrs_per_block: 10, + allow_phi: true, + allow_switch: true, + ..bare_roundtrip_config() + } +} + +fn memory_expansion_config() -> QirSmithConfig { + QirSmithConfig { + max_blocks_per_func: 4, + max_instrs_per_block: 10, + allow_memory_ops: true, + ..bare_roundtrip_config() + } +} + +fn expansion_seed_bank() -> Vec> { + let mut seeds: Vec<_> = (0_u8..=15).map(|byte| vec![byte; 128]).collect(); + seeds.push(deterministic_seed_bytes()); + seeds +} + +#[allow(clippy::struct_excessive_bools)] +#[derive(Debug, Default)] +struct InstructionCoverage { + phi: bool, + switch: bool, + unreachable: bool, + alloca: bool, + load: bool, + store: bool, + select: bool, + instruction_gep: bool, +} + +impl InstructionCoverage { + fn observe_module(&mut self, module: &Module) { + for function in 
&module.functions { + for block in &function.basic_blocks { + for instruction in &block.instructions { + match instruction { + Instruction::Phi { .. } => self.phi = true, + Instruction::Alloca { .. } => self.alloca = true, + Instruction::Load { .. } => self.load = true, + Instruction::Store { .. } => self.store = true, + Instruction::Select { .. } => self.select = true, + Instruction::Switch { .. } => self.switch = true, + Instruction::Unreachable => self.unreachable = true, + Instruction::GetElementPtr { .. } => self.instruction_gep = true, + Instruction::Ret(_) + | Instruction::Br { .. } + | Instruction::Jump { .. } + | Instruction::BinOp { .. } + | Instruction::ICmp { .. } + | Instruction::FCmp { .. } + | Instruction::Cast { .. } + | Instruction::Call { .. } => {} + } + } + } + } + } +} + +fn generated_text_roundtrip_module(config: &QirSmithConfig, seed: &[u8]) -> Module { + let seed_summary = checked_smoke_seed_summary(seed); + let text = generate_text_from_bytes(config, seed) + .unwrap_or_else(|err| panic!("text generation failed for {seed_summary}: {err}")); + + parse_text_roundtrip(&text) + .unwrap_or_else(|err| panic!("text roundtrip failed for {seed_summary}: {err}")) +} + +fn missing_control_flow_families(coverage: &InstructionCoverage) -> Vec<&'static str> { + let mut missing = Vec::new(); + if !coverage.phi { + missing.push("phi"); + } + if !coverage.switch { + missing.push("switch"); + } + if !coverage.unreachable { + missing.push("unreachable"); + } + missing +} + +fn missing_memory_families(coverage: &InstructionCoverage) -> Vec<&'static str> { + let mut missing = Vec::new(); + if !coverage.alloca { + missing.push("alloca"); + } + if !coverage.load { + missing.push("load"); + } + if !coverage.store { + missing.push("store"); + } + if !coverage.select { + missing.push("select"); + } + if !coverage.instruction_gep { + missing.push("instruction getelementptr"); + } + missing +} + +fn typed_checked_smoke_seeds() -> [&'static [u8]; 3] { + [&[0_u8; 
64][..], &[1_u8; 64][..], &[42_u8; 128][..]] +} + +fn checked_smoke_seed_summary(seed: &[u8]) -> String { + format!( + "len={} first_byte={}", + seed.len(), + seed.first().copied().unwrap_or_default() + ) +} + +fn assert_checked_generation_smoke_case( + config: &QirSmithConfig, + seeds: &[&[u8]], + invariant: impl Fn(&GeneratedArtifact), +) { + let checked_config = QirSmithConfig { + output_mode: OutputMode::RoundTripChecked, + roundtrip: None, + ..config.clone() + }; + let expected_effective = checked_config.sanitize(); + let profile = checked_config.profile; + let expects_bitcode = matches!( + expected_effective.roundtrip, + Some(RoundTripKind::BitcodeOnly | RoundTripKind::TextAndBitcodeSinglePass) + ); + + assert!( + expected_effective.roundtrip.is_some(), + "{profile:?} checked smoke should sanitize to a roundtrip mode" + ); + + for seed in seeds { + let seed_summary = checked_smoke_seed_summary(seed); + let artifact = generate_checked_from_bytes(&checked_config, seed).unwrap_or_else(|err| { + panic!("{profile:?} checked smoke failed for {seed_summary}: {err}") + }); + + assert_eq!( + artifact.effective_config, expected_effective, + "{profile:?} checked smoke should sanitize to the expected config for {seed_summary}" + ); + assert!( + artifact.text.is_some(), + "{profile:?} checked smoke should always emit text for {seed_summary}" + ); + assert_eq!( + artifact.bitcode.is_some(), + expects_bitcode, + "{profile:?} checked smoke bitcode presence should match sanitize defaults for {seed_summary}" + ); + + invariant(&artifact); + } +} + +fn assert_v1_typed_checked_smoke_shell(artifact: &GeneratedArtifact) { + assert!(artifact.effective_config.allow_typed_pointers); + assert!( + artifact + .module + .struct_types + .iter() + .any(|ty| ty.name == "Qubit"), + "typed checked smokes should retain the %Qubit shell type" + ); + assert!( + artifact + .module + .struct_types + .iter() + .any(|ty| ty.name == "Result"), + "typed checked smokes should retain the %Result shell 
type" + ); +} + +fn assert_base_v1_checked_smoke_invariant(artifact: &GeneratedArtifact) { + assert_v1_typed_checked_smoke_shell(artifact); + assert_eq!(artifact.module.metadata_nodes.len(), 4); + assert!( + artifact + .module + .get_flag(qir::INT_COMPUTATIONS_KEY) + .is_none() + ); + assert!(metadata_string_list(&artifact.module, qir::FLOAT_COMPUTATIONS_KEY).is_none()); +} + +fn assert_adaptive_v1_checked_smoke_invariant(artifact: &GeneratedArtifact) { + assert_v1_typed_checked_smoke_shell(artifact); + assert!( + artifact + .module + .get_flag(qir::INT_COMPUTATIONS_KEY) + .is_some() + ); +} + +fn assert_adaptive_v2_checked_smoke_invariant(artifact: &GeneratedArtifact) { + let analysis = crate::qir::inspect::analyze_float_surface(&artifact.module); + + assert!(!analysis.has_float_op); + assert!(metadata_string_list(&artifact.module, qir::FLOAT_COMPUTATIONS_KEY).is_none()); + + let metadata_ids: Vec<_> = artifact + .module + .metadata_nodes + .iter() + .map(|node| node.id) + .collect(); + let expected_ids: Vec<_> = + (0..u32::try_from(metadata_ids.len()).expect("metadata count should fit in u32")).collect(); + assert_eq!(metadata_ids, expected_ids); +} + +fn assert_bare_roundtrip_checked_smoke_invariant(artifact: &GeneratedArtifact) { + assert!(artifact.module.attribute_groups.is_empty()); + assert!(artifact.module.named_metadata.is_empty()); + assert!(artifact.module.metadata_nodes.is_empty()); +} + +fn checked_bitcode_only_artifact(config: &QirSmithConfig, seed: &[u8]) -> GeneratedArtifact { + let checked_config = QirSmithConfig { + output_mode: OutputMode::RoundTripChecked, + roundtrip: Some(RoundTripKind::BitcodeOnly), + ..config.clone() + }; + let profile = checked_config.profile; + let seed_summary = checked_smoke_seed_summary(seed); + + let artifact = generate_checked_from_bytes(&checked_config, seed).unwrap_or_else(|err| { + panic!("{profile:?} checked bitcode-only generation failed for {seed_summary}: {err}") + }); + + assert_eq!( + 
artifact.effective_config, + checked_effective_config(config, RoundTripKind::BitcodeOnly), + "{profile:?} checked bitcode-only generation should sanitize to the expected config for {seed_summary}" + ); + assert!(artifact.text.is_none()); + assert!(artifact.bitcode.is_some()); + + artifact +} + +fn assert_generated_target_headers(module: &Module) { + assert_eq!( + module.target_datalayout.as_deref(), + Some(GENERATED_TARGET_DATALAYOUT) + ); + assert_eq!( + module.target_triple.as_deref(), + Some(GENERATED_TARGET_TRIPLE) + ); +} + +fn has_generated_modeled_globals(module: &Module) -> bool { + module.globals.iter().any(|global| { + matches!( + global, + GlobalVariable { + ty: Type::Integer(64), + linkage: Linkage::External, + is_constant: false, + initializer: None, + .. + } + ) + }) +} + +fn assert_generated_modeled_globals(module: &Module) { + assert!(module.globals.iter().any(|global| { + matches!( + global, + GlobalVariable { + ty: Type::Integer(64), + linkage: Linkage::External, + is_constant: false, + initializer: None, + .. + } + ) + })); + assert!(module.globals.iter().any(|global| { + matches!( + global, + GlobalVariable { + ty: Type::Ptr, + linkage: Linkage::Internal, + is_constant: false, + initializer: Some(Constant::Null), + .. + } + ) + })); + assert!(module.globals.iter().any(|global| { + matches!( + global, + GlobalVariable { + ty: Type::Integer(64), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::Int(_)), + .. 
+ } + ) + })); +} + +fn find_checked_bitcode_only_artifact( + config: &QirSmithConfig, + seeds: &[&[u8]], + predicate: impl Fn(&GeneratedArtifact) -> bool, + description: &str, +) -> GeneratedArtifact { + for seed in seeds { + let artifact = checked_bitcode_only_artifact(config, seed); + if predicate(&artifact) { + return artifact; + } + } + + panic!( + "no checked bitcode-only artifact matched {description} across {} fixed seeds", + seeds.len() + ); +} + +#[test] +fn base_v1_generates_correct_shell() { + let seed_bytes = deterministic_seed_bytes(); + let text = generate_text_from_bytes(&base_v1_config(), &seed_bytes) + .expect("BaseV1 text generation should succeed"); + + assert!(text.contains("%Qubit = type opaque")); + assert!(text.contains("%Result = type opaque")); + assert!(text.contains("base_profile")); + assert!(text.contains("qir_major_version\", i32 1")); + assert!(text.contains("qir_minor_version\", i32 0")); + assert!(!text.contains("int_computations")); + assert!(!text.contains("backwards_branching")); +} + +#[test] +fn base_v1_module_has_struct_types_and_v1_metadata() { + let seed_bytes = deterministic_seed_bytes(); + let module = generate_module_from_bytes(&base_v1_config(), &seed_bytes) + .expect("BaseV1 module generation should succeed"); + let entry_point = generated_entry_point(&module); + + assert_eq!(module.struct_types.len(), 2); + assert_eq!(module.named_metadata.len(), 1); + assert_eq!(module.metadata_nodes.len(), 4); + assert_eq!(entry_point.name, qir::ENTRYPOINT_NAME); + assert!( + module + .get_flag(qir::QIR_MAJOR_VERSION_KEY) + .is_some_and(|value| *value == MetadataValue::Int(Type::Integer(32), 1)) + ); + assert!( + module + .get_flag(qir::QIR_MINOR_VERSION_KEY) + .is_some_and(|value| *value == MetadataValue::Int(Type::Integer(32), 0)) + ); + assert_eq!(entry_point.basic_blocks.len(), BASE_V1_BLOCK_COUNT); +} + +#[test] +fn adaptive_v1_generates_correct_shell() { + let seed_bytes = deterministic_seed_bytes(); + let text = 
generate_text_from_bytes(&adaptive_v1_config(), &seed_bytes) + .expect("AdaptiveV1 text generation should succeed"); + + assert!(text.contains("%Qubit = type opaque")); + assert!(text.contains("%Result = type opaque")); + assert!(text.contains("adaptive_profile")); + assert!(text.contains("qir_major_version\", i32 1")); + assert!(text.contains("qir_minor_version\", i32 0")); + assert!(!text.contains("backwards_branching")); +} + +#[test] +fn adaptive_v1_module_has_struct_types_and_v1_metadata() { + let seed_bytes = deterministic_seed_bytes(); + let module = generate_module_from_bytes(&adaptive_v1_config(), &seed_bytes) + .expect("AdaptiveV1 module generation should succeed"); + let entry_point = generated_entry_point(&module); + + assert_eq!(module.struct_types.len(), 2); + assert_eq!(module.named_metadata.len(), 1); + assert!( + module.metadata_nodes.len() >= 4 && module.metadata_nodes.len() <= 6, + "AdaptiveV1 should have 4 base nodes plus up to 2 optional capability nodes, got {}", + module.metadata_nodes.len() + ); + assert_eq!(entry_point.name, qir::ENTRYPOINT_NAME); + assert!( + module + .get_flag(qir::QIR_MAJOR_VERSION_KEY) + .is_some_and(|value| *value == MetadataValue::Int(Type::Integer(32), 1)) + ); + assert!( + module + .get_flag(qir::QIR_MINOR_VERSION_KEY) + .is_some_and(|value| *value == MetadataValue::Int(Type::Integer(32), 0)) + ); +} + +// --- BaseV1 / AdaptiveV1 text roundtrip tests --- + +#[test] +fn base_v1_text_roundtrip_checked() { + let seeds = typed_checked_smoke_seeds(); + assert_checked_generation_smoke_case( + &base_v1_config(), + &seeds, + assert_base_v1_checked_smoke_invariant, + ); +} + +#[test] +fn adaptive_v1_text_roundtrip_checked() { + let seeds = typed_checked_smoke_seeds(); + assert_checked_generation_smoke_case( + &adaptive_v1_config(), + &seeds, + assert_adaptive_v1_checked_smoke_invariant, + ); +} + +#[test] +fn bare_roundtrip_checked_omits_qdk_shell_metadata() { + let seed_bytes = deterministic_seed_bytes(); + let seeds = 
[seed_bytes.as_slice()]; + + assert_checked_generation_smoke_case( + &bare_roundtrip_config(), + &seeds, + assert_bare_roundtrip_checked_smoke_invariant, + ); +} + +#[test] +fn bare_roundtrip_checked_bitcode_only_emits_target_headers() { + let seed_bytes = deterministic_seed_bytes(); + let artifact = checked_bitcode_only_artifact(&bare_roundtrip_config(), &seed_bytes); + + assert_generated_target_headers(&artifact.module); +} + +#[test] +fn bare_roundtrip_checked_bitcode_only_preserves_broader_modeled_globals() { + let deterministic_seed = deterministic_seed_bytes(); + let seeds = [ + deterministic_seed.as_slice(), + &[0_u8; 64][..], + &[1_u8; 64][..], + &[42_u8; 128][..], + ]; + let artifact = find_checked_bitcode_only_artifact( + &bare_roundtrip_config(), + &seeds, + |artifact| has_generated_modeled_globals(&artifact.module), + "broader modeled globals", + ); + + assert_generated_modeled_globals(&artifact.module); +} + +#[test] +fn bare_roundtrip_checked_bitcode_only_preserves_generated_headers() { + let seed_bytes = deterministic_seed_bytes(); + let artifact = checked_bitcode_only_artifact(&bare_roundtrip_config(), &seed_bytes); + let bitcode = artifact + .bitcode + .as_ref() + .expect("checked bitcode-only generation should emit bitcode"); + let reparsed = parse_bitcode_roundtrip(bitcode) + .expect("checked roundtrip should preserve generated target headers"); + + assert_generated_target_headers(&reparsed); + assert_bitcode_roundtrip_matches_supported_v1_subset(&artifact.module, &reparsed) + .expect("checked compat roundtrip should preserve generated target header fidelity"); +} + +#[test] +fn bare_roundtrip_checked_bitcode_only_also_succeeds_through_strict_parse() { + let deterministic_seed = deterministic_seed_bytes(); + let seeds = [ + deterministic_seed.as_slice(), + &[0_u8; 64][..], + &[1_u8; 64][..], + &[42_u8; 128][..], + ]; + let artifact = find_checked_bitcode_only_artifact( + &bare_roundtrip_config(), + &seeds, + |artifact| 
has_generated_modeled_globals(&artifact.module), + "strict-parse parity coverage for broader modeled globals", + ); + let bitcode = artifact + .bitcode + .as_ref() + .expect("checked bitcode-only generation should emit bitcode"); + let strict = crate::bitcode::reader::parse_bitcode(bitcode) + .expect("zero-diagnostic checked bitcode should also succeed through strict parse_bitcode"); + + assert_generated_target_headers(&strict); + assert_generated_modeled_globals(&strict); + assert_bitcode_roundtrip_matches_supported_v1_subset(&artifact.module, &strict).expect( + "strict bitcode parse should preserve generated header and global fidelity for the supported subset", + ); +} + +// --- Sanitize tests for v1 profiles --- + +#[test] +fn sanitize_preserves_base_v1_profile() { + let effective = QirSmithConfig { + output_mode: OutputMode::Text, + ..QirSmithConfig::for_profile(QirProfilePreset::BaseV1) + } + .sanitize(); + + assert_eq!(effective.profile, QirProfilePreset::BaseV1); + assert!(effective.allow_typed_pointers); + assert!(!effective.bare_roundtrip_mode); + assert_eq!(effective.max_blocks_per_func, BASE_V1_BLOCK_COUNT); +} + +#[test] +fn sanitize_preserves_adaptive_v1_profile() { + let effective = QirSmithConfig { + output_mode: OutputMode::Text, + ..QirSmithConfig::for_profile(QirProfilePreset::AdaptiveV1) + } + .sanitize(); + + assert_eq!(effective.profile, QirProfilePreset::AdaptiveV1); + assert!(effective.allow_typed_pointers); + assert!(!effective.bare_roundtrip_mode); +} + +#[test] +fn sanitize_v1_profiles_default_to_text_only_roundtrip() { + let base_v1_effective = QirSmithConfig { + output_mode: OutputMode::RoundTripChecked, + roundtrip: None, + ..QirSmithConfig::for_profile(QirProfilePreset::BaseV1) + } + .sanitize(); + + assert_eq!(base_v1_effective.roundtrip, Some(RoundTripKind::TextOnly)); + + let adaptive_v1_effective = QirSmithConfig { + output_mode: OutputMode::RoundTripChecked, + roundtrip: None, + 
..QirSmithConfig::for_profile(QirProfilePreset::AdaptiveV1) + } + .sanitize(); + + assert_eq!( + adaptive_v1_effective.roundtrip, + Some(RoundTripKind::TextOnly) + ); +} + +#[test] +fn initialize_declaration_appears_when_flag_enabled() { + // Use all-1s seed: take_flag(bytes, false) returns true for all flags + let seed_bytes: Vec = vec![1; 128]; + let text = generate_text_from_bytes(&base_v1_config(), &seed_bytes) + .expect("BaseV1 text generation should succeed"); + + assert!( + text.contains("@__quantum__rt__initialize"), + "Initialize declaration should appear when flag is set" + ); +} + +#[test] +fn initialize_call_in_entry_block_when_flag_enabled() { + let seed_bytes: Vec = vec![1; 128]; + let text = generate_text_from_bytes(&base_v1_config(), &seed_bytes) + .expect("BaseV1 text generation should succeed"); + + if text.contains("@__quantum__rt__initialize") { + assert!( + text.contains("call void @__quantum__rt__initialize("), + "Initialize call should appear when declaration is present" + ); + } +} + +#[test] +fn initialize_always_present_for_qdk_shell() { + // Initialize is always included for QDK shell presets. 
+ let seed_bytes: Vec = vec![0; 128]; + let text = generate_text_from_bytes(&base_v1_config(), &seed_bytes) + .expect("BaseV1 text generation should succeed"); + + assert!( + text.contains("call void @__quantum__rt__initialize("), + "Initialize call should always appear for QDK shell presets" + ); +} + +#[test] +fn adaptive_v1_emits_capability_metadata_when_flags_enabled() { + // Use all-1s seed: all take_flag calls return true + let seed_bytes: Vec = vec![1; 128]; + let module = generate_module_from_bytes(&adaptive_v1_config(), &seed_bytes) + .expect("AdaptiveV1 module generation should succeed"); + + let has_int = module.get_flag(qir::INT_COMPUTATIONS_KEY).is_some(); + let has_float = module.get_flag(qir::FLOAT_COMPUTATIONS_KEY).is_some(); + + assert!( + has_int || has_float, + "AdaptiveV1 with all-1s seed should emit at least one capability metadata node" + ); +} + +#[test] +fn adaptive_v1_zero_seed_smoke_omits_float_metadata_without_float_operations() { + let seed_bytes = vec![0_u8; 16]; + let module = generate_module_from_bytes(&adaptive_v1_config(), &seed_bytes) + .expect("AdaptiveV1 module generation should succeed"); + let analysis = crate::qir::inspect::analyze_float_surface(&module); + + assert!( + module.get_flag(qir::INT_COMPUTATIONS_KEY).is_some(), + "AdaptiveV1 should always emit int_computations" + ); + assert!( + !analysis.has_float_op, + "zero seed should continue to exercise a no-float generator path" + ); + assert!( + analysis.surface_width_names().is_empty(), + "no-float AdaptiveV1 modules should not retain float-typed IR surface" + ); + assert!( + metadata_string_list(&module, qir::FLOAT_COMPUTATIONS_KEY).is_none(), + "float_computations should be omitted when no floating-point operation exists" + ); +} + +#[test] +fn adaptive_v1_checked_seed_smokes_do_not_leave_float_surface_without_float_ops() { + for seed in [&[0_u8; 64][..], &[1_u8; 64], &[42_u8; 128]] { + let module = generate_module_from_bytes(&adaptive_v1_config(), seed) + 
.expect("AdaptiveV1 module generation should succeed"); + let analysis = crate::qir::inspect::analyze_float_surface(&module); + + assert!( + analysis.has_float_op || analysis.surface_width_names().is_empty(), + "AdaptiveV1 should not leave float-typed IR surface without a float op for seed length {} and first byte {}", + seed.len(), + seed.first().copied().unwrap_or_default() + ); + } +} + +#[test] +fn adaptive_v2_checked_empty_seed_keeps_metadata_ids_dense() { + let seeds: [&[u8]; 1] = [&[]]; + + assert_checked_generation_smoke_case( + &QirSmithConfig::for_profile(QirProfilePreset::AdaptiveV2), + &seeds, + assert_adaptive_v2_checked_smoke_invariant, + ); +} + +#[test] +fn opt_in_text_roundtrip_emits_control_flow_instruction_slice() { + let config = control_flow_expansion_config(); + let mut coverage = InstructionCoverage::default(); + + for seed in expansion_seed_bank() { + let module = generated_text_roundtrip_module(&config, &seed); + coverage.observe_module(&module); + + if missing_control_flow_families(&coverage).is_empty() { + break; + } + } + + let missing = missing_control_flow_families(&coverage); + assert!( + missing.is_empty(), + "fixed qir_smith expansion seeds should emit phi, switch, and unreachable after text roundtrip; missing {}", + missing.join(", ") + ); +} + +#[test] +fn opt_in_text_roundtrip_emits_memory_instruction_slice() { + let config = memory_expansion_config(); + let mut coverage = InstructionCoverage::default(); + + for seed in expansion_seed_bank() { + let module = generated_text_roundtrip_module(&config, &seed); + coverage.observe_module(&module); + + if missing_memory_families(&coverage).is_empty() { + break; + } + } + + let missing = missing_memory_families(&coverage); + assert!( + missing.is_empty(), + "fixed qir_smith expansion seeds should emit alloca, load, store, select, and instruction getelementptr after text roundtrip; missing {}", + missing.join(", ") + ); +} + +#[test] +fn 
adaptive_float_builders_cover_half_float_and_double() { + let mut state = adaptive_shell_state(); + + for (selector, expected_ty) in [ + (0_u8, Type::Half), + (1_u8, Type::Float), + (2_u8, Type::Double), + ] { + let mut pool = state.build_base_value_pool(); + let selector_bytes = [selector; 16]; + let mut bytes = Unstructured::new(&selector_bytes); + + let binop = state + .build_float_binop_instruction(&mut pool, &mut bytes) + .expect("float binop builder should succeed for supported widths"); + assert!(matches!(binop, Instruction::BinOp { ty, .. } if ty == expected_ty.clone())); + + let mut pool = state.build_base_value_pool(); + let selector_bytes = [selector; 16]; + let mut bytes = Unstructured::new(&selector_bytes); + let fcmp = state + .build_fcmp_instruction(&mut pool, &mut bytes) + .expect("fcmp builder should succeed for supported widths"); + assert!(matches!(fcmp, Instruction::FCmp { ty, .. } if ty == expected_ty.clone())); + + let mut pool = state.build_base_value_pool(); + let selector_bytes = [selector; 16]; + let mut bytes = Unstructured::new(&selector_bytes); + let sitofp = state + .build_sitofp_instruction(&mut pool, &mut bytes) + .expect("sitofp builder should succeed for supported widths"); + assert!( + matches!(sitofp, Instruction::Cast { op: CastKind::Sitofp, to_ty, .. } if to_ty == expected_ty.clone()) + ); + + let mut pool = state.build_base_value_pool(); + let selector_bytes = [selector; 16]; + let mut bytes = Unstructured::new(&selector_bytes); + let fptosi = state + .build_fptosi_instruction(&mut pool, &mut bytes) + .expect("fptosi builder should succeed for supported widths"); + assert!( + matches!(fptosi, Instruction::Cast { op: CastKind::Fptosi, from_ty, .. 
} if from_ty == expected_ty) + ); + } +} + +#[test] +fn finalize_float_computations_rewrites_supported_metadata_to_exact_surface_subset() { + let mut module = adaptive_v1_module_with_float_metadata_shell( + vec![GlobalVariable { + name: "g".to_string(), + ty: Type::Float, + linkage: Linkage::Internal, + is_constant: false, + initializer: None, + }], + Vec::new(), + vec![ + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Half, + lhs: Operand::float_const(Type::Half, 1.0), + rhs: Operand::float_const(Type::Half, 2.0), + result: "sum".to_string(), + }, + Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))), + ], + ); + + assert_eq!( + metadata_string_list(&module, qir::FLOAT_COMPUTATIONS_KEY), + Some(vec![ + "half".to_string(), + "float".to_string(), + "double".to_string(), + ]) + ); + + finalize_float_computations(&mut module); + + assert_eq!( + metadata_string_list(&module, qir::FLOAT_COMPUTATIONS_KEY), + Some(vec!["half".to_string(), "float".to_string()]) + ); +} + +#[test] +fn finalize_float_computations_removes_flag_without_float_operations() { + let mut module = adaptive_v1_module_with_float_metadata_shell( + Vec::new(), + vec![double_record_output_declaration()], + vec![ + Instruction::Call { + return_ty: None, + callee: qir::rt::DOUBLE_RECORD_OUTPUT.to_string(), + args: vec![ + (Type::Double, Operand::float_const(Type::Double, 1.0)), + (Type::Ptr, Operand::NullPtr), + ], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))), + ], + ); + let analysis = crate::qir::inspect::analyze_float_surface(&module); + + assert!(!analysis.has_float_op); + assert_eq!(analysis.surface_width_names(), vec!["double"]); + assert!(metadata_string_list(&module, qir::FLOAT_COMPUTATIONS_KEY).is_some()); + + finalize_float_computations(&mut module); + + assert!(metadata_string_list(&module, qir::FLOAT_COMPUTATIONS_KEY).is_none()); +} + +#[test] +fn base_v1_never_emits_capability_metadata() { + for seed in 
[&[0u8; 128][..], &[1; 128], &[42; 128]] { + let module = generate_module_from_bytes(&base_v1_config(), seed) + .expect("BaseV1 module generation should succeed"); + + assert!( + module.get_flag(qir::INT_COMPUTATIONS_KEY).is_none(), + "BaseV1 should never emit int_computations" + ); + assert!( + module.get_flag(qir::FLOAT_COMPUTATIONS_KEY).is_none(), + "BaseV1 should never emit float_computations" + ); + } +} + +#[test] +fn parse_bitcode_roundtrip_preserves_supported_global_initializers() { + let Some(lane) = available_fast_matrix_lanes().into_iter().next() else { + eprintln!( + "no external LLVM fast-matrix lane is available, skipping qir_smith global-initializer regression" + ); + return; + }; + + let bitcode = assemble_text_ir( + lane, + PointerProbe::OpaqueText, + "@0 = internal constant [4 x i8] c\"0_r\\00\"\n", + ) + .unwrap_or_else(|error| { + panic!( + "llvm@{} should assemble qir_smith global-initializer fixture: {error}", + lane.version + ) + }); + + let module = parse_bitcode_roundtrip(&bitcode) + .expect("checked roundtrip should preserve supported global initializers"); + + assert_eq!(module.globals.len(), 1); + assert_eq!( + module.globals[0].ty, + Type::Array(4, Box::new(Type::Integer(8))) + ); + assert!(module.globals[0].is_constant); + assert_eq!( + module.globals[0].initializer, + Some(Constant::CString("0_r".to_string())) + ); +} diff --git a/source/compiler/qsc_llvm/src/lib.rs b/source/compiler/qsc_llvm/src/lib.rs new file mode 100644 index 0000000000..e2223dae13 --- /dev/null +++ b/source/compiler/qsc_llvm/src/lib.rs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#[cfg(test)] +mod test_utils; +#[cfg(test)] +mod tests; + +use miette::Diagnostic; +use thiserror::Error; + +pub mod bitcode; +pub mod fuzz; +pub mod model; +pub mod qir; +pub mod text; +pub mod validation; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ReadPolicy { + Compatibility, + QirSubsetStrict, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ReadDiagnosticKind { + MalformedInput, + UnsupportedSemanticConstruct, +} + +impl std::fmt::Display for ReadDiagnosticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::MalformedInput => write!(f, "malformed input"), + Self::UnsupportedSemanticConstruct => write!(f, "unsupported semantic construct"), + } + } +} + +#[derive(Clone, Debug, Diagnostic, Error, PartialEq, Eq)] +#[error("{kind}: {context}: {message}")] +#[diagnostic(code(qsc_llvm::read))] +pub struct ReadDiagnostic { + pub kind: ReadDiagnosticKind, + pub offset: Option, + pub context: &'static str, + pub message: String, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ReadReport { + pub module: model::Module, + pub diagnostics: Vec, +} + +pub use bitcode::reader::{ + ParseError, parse_bitcode, parse_bitcode_compatibility, parse_bitcode_compatibility_report, + parse_bitcode_detailed, +}; +pub use bitcode::writer::{ + WriteError, try_write_bitcode, try_write_bitcode_for_target, write_bitcode, + write_bitcode_for_target, +}; +pub use fuzz::qir_smith::{ + EffectiveConfig, GeneratedArtifact, OutputMode, QirProfilePreset, QirSmithConfig, + QirSmithError, RoundTripKind, generate, generate_bitcode, generate_bitcode_from_bytes, + generate_checked, generate_checked_from_bytes, generate_from_bytes, generate_module, + generate_module_from_bytes, generate_text, generate_text_from_bytes, +}; +pub use model::Module; +pub use model::builder::ModuleBuilder; +pub use text::reader::{parse_module, parse_module_compatibility, parse_module_detailed}; +pub use text::writer::write_module_to_string; +pub 
use validation::{LlvmIrError, validate_ir, validate_qir_profile}; diff --git a/source/compiler/qsc_llvm/src/model.rs b/source/compiler/qsc_llvm/src/model.rs new file mode 100644 index 0000000000..a4dc0df9ee --- /dev/null +++ b/source/compiler/qsc_llvm/src/model.rs @@ -0,0 +1,526 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#[cfg(test)] +mod tests; + +#[cfg(test)] +pub(crate) mod test_helpers; + +pub mod builder; + +use half::f16; +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Type { + Void, + Integer(u32), + Half, + Float, + Double, + Label, + Ptr, + NamedPtr(String), + TypedPtr(Box), + Array(u64, Box), + Function(Box, Vec), + Named(String), +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Void => write!(f, "void"), + Type::Integer(n) => write!(f, "i{n}"), + Type::Half => write!(f, "half"), + Type::Float => write!(f, "float"), + Type::Double => write!(f, "double"), + Type::Label => write!(f, "label"), + Type::Ptr => write!(f, "ptr"), + Type::NamedPtr(s) => write!(f, "%{s}*"), + Type::TypedPtr(inner) => write!(f, "{inner}*"), + Type::Array(n, ty) => write!(f, "[{n} x {ty}]"), + Type::Function(ret, params) => { + write!(f, "{ret} (")?; + for (i, p) in params.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{p}")?; + } + write!(f, ")") + } + Type::Named(s) => write!(f, "%{s}"), + } + } +} + +impl Type { + #[must_use] + pub fn is_floating_point(&self) -> bool { + matches!(self, Self::Half | Self::Float | Self::Double) + } + + #[must_use] + pub fn floating_point_bit_width(&self) -> Option { + match self { + Self::Half => Some(16), + Self::Float => Some(32), + Self::Double => Some(64), + _ => None, + } + } + + #[must_use] + #[allow(clippy::cast_possible_truncation)] + pub fn canonicalize_float_value(&self, value: f64) -> Option { + match self { + Self::Half => Some(f16::from_f64(value).to_f64()), + Self::Float => 
Some(f64::from(value as f32)), + Self::Double => Some(value), + _ => None, + } + } + + #[must_use] + #[allow(clippy::cast_possible_truncation)] + pub fn encode_float_bits(&self, value: f64) -> Option { + let canonical = self.canonicalize_float_value(value)?; + match self { + Self::Half => Some(u64::from(f16::from_f64(canonical).to_bits())), + Self::Float => Some(u64::from((canonical as f32).to_bits())), + Self::Double => Some(canonical.to_bits()), + _ => None, + } + } + + #[must_use] + pub fn decode_float_bits(&self, bits: u64) -> Option { + match self { + Self::Half => Some(f16::from_bits(u16::try_from(bits).ok()?).to_f64()), + Self::Float => Some(f64::from(f32::from_bits(u32::try_from(bits).ok()?))), + Self::Double => Some(f64::from_bits(bits)), + _ => None, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Module { + pub source_filename: Option, + pub target_datalayout: Option, + pub target_triple: Option, + pub struct_types: Vec, + pub globals: Vec, + pub functions: Vec, + pub attribute_groups: Vec, + pub named_metadata: Vec, + pub metadata_nodes: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StructType { + pub name: String, + pub is_opaque: bool, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct GlobalVariable { + pub name: String, + pub ty: Type, + pub linkage: Linkage, + pub is_constant: bool, + pub initializer: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Linkage { + Internal, + External, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Constant { + CString(String), + Int(i64), + Float(Type, f64), + Null, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Function { + pub name: String, + pub return_type: Type, + pub params: Vec, + pub is_declaration: bool, + pub attribute_group_refs: Vec, + pub basic_blocks: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Param { + pub ty: Type, + pub name: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BasicBlock { + pub name: String, + pub instructions: 
Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Instruction { + Ret(Option), + Br { + cond_ty: Type, + cond: Operand, + true_dest: String, + false_dest: String, + }, + Jump { + dest: String, + }, + BinOp { + op: BinOpKind, + ty: Type, + lhs: Operand, + rhs: Operand, + result: String, + }, + ICmp { + pred: IntPredicate, + ty: Type, + lhs: Operand, + rhs: Operand, + result: String, + }, + FCmp { + pred: FloatPredicate, + ty: Type, + lhs: Operand, + rhs: Operand, + result: String, + }, + Cast { + op: CastKind, + from_ty: Type, + to_ty: Type, + value: Operand, + result: String, + }, + Call { + return_ty: Option, + callee: String, + args: Vec<(Type, Operand)>, + result: Option, + attr_refs: Vec, + }, + Phi { + ty: Type, + incoming: Vec<(Operand, String)>, + result: String, + }, + Alloca { + ty: Type, + result: String, + }, + Load { + ty: Type, + ptr_ty: Type, + ptr: Operand, + result: String, + }, + Store { + ty: Type, + value: Operand, + ptr_ty: Type, + ptr: Operand, + }, + Select { + cond: Operand, + true_val: Operand, + false_val: Operand, + ty: Type, + result: String, + }, + Switch { + ty: Type, + value: Operand, + default_dest: String, + cases: Vec<(i64, String)>, + }, + GetElementPtr { + inbounds: bool, + pointee_ty: Type, + ptr_ty: Type, + ptr: Operand, + indices: Vec, + result: String, + }, + Unreachable, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum BinOpKind { + Add, + Sub, + Mul, + Sdiv, + Srem, + Shl, + Ashr, + And, + Or, + Xor, + Fadd, + Fsub, + Fmul, + Fdiv, + Udiv, + Urem, + Lshr, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum IntPredicate { + Eq, + Ne, + Sgt, + Sge, + Slt, + Sle, + Ult, + Ule, + Ugt, + Uge, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum FloatPredicate { + Oeq, + Ogt, + Oge, + Olt, + Ole, + One, + Ord, + Uno, + Ueq, + Ugt, + Uge, + Ult, + Ule, + Une, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum CastKind { + Sitofp, + Fptosi, + Zext, + Sext, + Trunc, + FpExt, + FpTrunc, + IntToPtr, + PtrToInt, + Bitcast, +} + 
/// Instruction operands. `LocalRef`/`TypedLocalRef` compare equal when their
/// names match (see the manual `PartialEq` impl below).
#[derive(Debug, Clone)]
pub enum Operand {
    LocalRef(String),
    TypedLocalRef(String, Type),
    IntConst(Type, i64),
    FloatConst(Type, f64),
    NullPtr,
    /// `inttoptr` constant expression, e.g. a literal qubit/result id.
    IntToPtr(i64, Type),
    /// `getelementptr` constant expression.
    GetElementPtr {
        ty: Type,
        ptr: String,
        ptr_ty: Type,
        indices: Vec<Operand>,
    },
    GlobalRef(String),
}

impl PartialEq for Operand {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::TypedLocalRef(lhs_name, lhs_ty), Self::TypedLocalRef(rhs_name, rhs_ty)) => {
                lhs_name == rhs_name && lhs_ty == rhs_ty
            }
            (Self::IntConst(lhs_ty, lhs_val), Self::IntConst(rhs_ty, rhs_val)) => {
                lhs_ty == rhs_ty && lhs_val == rhs_val
            }
            (Self::FloatConst(lhs_ty, lhs), Self::FloatConst(rhs_ty, rhs)) => {
                lhs_ty == rhs_ty && lhs == rhs
            }
            (Self::NullPtr, Self::NullPtr) => true,
            (Self::IntToPtr(lhs_val, lhs_ty), Self::IntToPtr(rhs_val, rhs_ty)) => {
                lhs_val == rhs_val && lhs_ty == rhs_ty
            }
            (
                Self::GetElementPtr {
                    ty: lhs_ty,
                    ptr: lhs_ptr,
                    ptr_ty: lhs_ptr_ty,
                    indices: lhs_indices,
                },
                Self::GetElementPtr {
                    ty: rhs_ty,
                    ptr: rhs_ptr,
                    ptr_ty: rhs_ptr_ty,
                    indices: rhs_indices,
                },
            ) => {
                lhs_ty == rhs_ty
                    && lhs_ptr == rhs_ptr
                    && lhs_ptr_ty == rhs_ptr_ty
                    && lhs_indices == rhs_indices
            }
            // A bare LocalRef and a TypedLocalRef with the same name are
            // considered equal (the type annotation is advisory).
            (Self::LocalRef(lhs) | Self::TypedLocalRef(lhs, _), Self::LocalRef(rhs))
            | (Self::LocalRef(lhs), Self::TypedLocalRef(rhs, _))
            | (Self::GlobalRef(lhs), Self::GlobalRef(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}

impl Constant {
    /// Builds a float constant, canonicalizing `value` to `ty`'s precision.
    #[must_use]
    pub fn float(ty: Type, value: f64) -> Self {
        let value = ty.canonicalize_float_value(value).unwrap_or(value);
        Self::Float(ty, value)
    }
}

impl Operand {
    /// Builds a float-constant operand, canonicalizing `value` to `ty`'s precision.
    #[must_use]
    pub fn float_const(ty: Type, value: f64) -> Self {
        let value = ty.canonicalize_float_value(value).unwrap_or(value);
        Self::FloatConst(ty, value)
    }

    /// Builds an `inttoptr` operand targeting the named pointer type `%name*`.
    #[must_use]
    pub fn int_to_named_ptr<S: Into<String>>(value: i64, name: S) -> Self {
        Self::IntToPtr(value, Type::NamedPtr(name.into()))
    }
}

/// An attribute group definition (`attributes #N = { ... }`).
#[derive(Debug, Clone, PartialEq)]
pub struct AttributeGroup {
    pub id: u32,
    pub attributes: Vec<Attribute>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Attribute {
    StringAttr(String),
    KeyValue(String, String),
}

/// A named metadata entry (`!name = !{!0, !1, ...}`).
#[derive(Debug, Clone, PartialEq)]
pub struct NamedMetadata {
    pub name: String,
    pub node_refs: Vec<u32>,
}

/// A numbered metadata node (`!N = !{...}`).
#[derive(Debug, Clone, PartialEq)]
pub struct MetadataNode {
    pub id: u32,
    pub values: Vec<MetadataValue>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum MetadataValue {
    Int(Type, i64),
    String(String),
    NodeRef(u32),
    SubList(Vec<MetadataValue>),
}

/// A structural problem found while auditing `!llvm.module.flags`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ModuleFlagNodeIssue {
    DanglingReference { node_ref: u32 },
    MalformedEntry { node_ref: u32, reason: &'static str },
}

/// One well-formed `(behavior, key, value)` module-flag entry, borrowing from the module.
#[derive(Debug, Clone)]
pub(crate) struct ModuleFlagNode<'a> {
    pub(crate) node_id: u32,
    pub(crate) behavior: &'a MetadataValue,
    pub(crate) key: &'a str,
    pub(crate) value: &'a MetadataValue,
}

#[derive(Debug, Clone, Default)]
pub(crate) struct ModuleFlagAudit<'a> {
    pub(crate) entries: Vec<ModuleFlagNode<'a>>,
    pub(crate) issues: Vec<ModuleFlagNodeIssue>,
}

impl Module {
    /// Collects well-formed `!llvm.module.flags` entries plus any structural
    /// issues (dangling node refs, malformed entries). Returns an empty audit
    /// when the named metadata is absent.
    #[must_use]
    pub(crate) fn audit_module_flags(&self) -> ModuleFlagAudit<'_> {
        let Some(named_metadata) = self
            .named_metadata
            .iter()
            .find(|metadata| metadata.name == "llvm.module.flags")
        else {
            return ModuleFlagAudit::default();
        };

        let mut audit = ModuleFlagAudit::default();

        for &node_ref in &named_metadata.node_refs {
            let Some(node) = self
                .metadata_nodes
                .iter()
                .find(|candidate| candidate.id == node_ref)
            else {
                audit
                    .issues
                    .push(ModuleFlagNodeIssue::DanglingReference { node_ref });
                continue;
            };

            if node.values.len() < 3 {
                audit.issues.push(ModuleFlagNodeIssue::MalformedEntry {
                    node_ref: node.id,
                    reason: "module flag nodes must contain behavior, name, and value operands",
                });
                continue;
            }

            let MetadataValue::String(key) = &node.values[1] else {
                audit.issues.push(ModuleFlagNodeIssue::MalformedEntry {
                    node_ref: node.id,
                    reason: "module flag names must be metadata strings",
                });
                continue;
            };

            audit.entries.push(ModuleFlagNode {
                node_id: node.id,
                behavior: &node.values[0],
                key,
                value: &node.values[2],
            });
        }

        audit
    }

    /// Retrieves a module flag value by key from `!llvm.module.flags` named metadata.
    #[must_use]
    pub fn get_flag(&self, key: &str) -> Option<&MetadataValue> {
        self.audit_module_flags()
            .entries
            .into_iter()
            .find(|entry| entry.key == key)
            .map(|entry| entry.value)
    }
}
diff --git a/source/compiler/qsc_llvm/src/model/builder.rs b/source/compiler/qsc_llvm/src/model/builder.rs
new file mode 100644
index 0000000000..e38e930b7d
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/model/builder.rs
@@ -0,0 +1,185 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#[cfg(test)]
mod tests;

use super::{Function, Instruction, Module, Operand, Param};
use crate::model::Type;
use crate::qir;

/// Error from a builder positioning operation (unknown function/block name).
#[derive(Debug, Clone)]
pub struct BuilderError(pub String);

impl std::fmt::Display for BuilderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "BuilderError: {}", self.0)
    }
}

impl std::error::Error for BuilderError {}

/// Cursor-based builder for in-place IR mutation.
///
/// The cursor is a `(function, block, instruction)` index triple; positioning
/// methods move it and mutation methods insert/remove relative to it.
pub struct ModuleBuilder<'a> {
    module: &'a mut Module,
    func_idx: usize,
    block_idx: usize,
    instr_idx: usize,
}

impl<'a> ModuleBuilder<'a> {
    /// Creates a builder with the cursor at the first function/block/instruction.
    pub fn new(module: &'a mut Module) -> Self {
        Self {
            module,
            func_idx: 0,
            block_idx: 0,
            instr_idx: 0,
        }
    }

    /// Set cursor to the function with the given name.
    ///
    /// # Errors
    /// Returns `BuilderError` when no function with that name exists.
    pub fn position_at_function(&mut self, name: &str) -> Result<(), BuilderError> {
        let idx = self
            .module
            .functions
            .iter()
            .position(|f| f.name == name)
            .ok_or_else(|| BuilderError(format!("function not found: {name}")))?;
        self.func_idx = idx;
        self.block_idx = 0;
        self.instr_idx = 0;
        Ok(())
    }

    /// Set cursor to the block with the given label within the current function.
+ pub fn position_at_block(&mut self, label: &str) -> Result<(), BuilderError> { + let func = &self.module.functions[self.func_idx]; + let idx = func + .basic_blocks + .iter() + .position(|b| b.name == label) + .ok_or_else(|| BuilderError(format!("block not found: {label}")))?; + self.block_idx = idx; + self.instr_idx = 0; + Ok(()) + } + + /// Set cursor before a specific instruction index in the current block. + pub fn position_before(&mut self, instr_idx: usize) { + self.instr_idx = instr_idx; + } + + /// Set cursor to end of current block. + pub fn position_at_end(&mut self) { + let len = self.module.functions[self.func_idx].basic_blocks[self.block_idx] + .instructions + .len(); + self.instr_idx = len; + } + + /// Insert an instruction before the current cursor position. + /// Equivalent to `PyQIR`'s `builder.insert_before(instr, new_instr)`. + pub fn insert_before(&mut self, instruction: Instruction) { + let block = &mut self.module.functions[self.func_idx].basic_blocks[self.block_idx]; + block.instructions.insert(self.instr_idx, instruction); + self.instr_idx += 1; + } + + /// Insert an instruction at the end of the current block. + /// Equivalent to `PyQIR`'s `builder.insert_at_end(block, instr)`. + pub fn insert_at_end(&mut self, instruction: Instruction) { + let block = &mut self.module.functions[self.func_idx].basic_blocks[self.block_idx]; + block.instructions.push(instruction); + } + + /// Remove the instruction at the current cursor position. + /// Equivalent to `PyQIR`'s `call.erase()` / `instr.remove()`. + /// Cursor stays at same index (now pointing to next instruction). + pub fn erase_at_cursor(&mut self) -> Instruction { + let block = &mut self.module.functions[self.func_idx].basic_blocks[self.block_idx]; + block.instructions.remove(self.instr_idx) + } + + /// Returns the number of instructions in the current block. 
+ #[must_use] + pub fn current_block_len(&self) -> usize { + self.module.functions[self.func_idx].basic_blocks[self.block_idx] + .instructions + .len() + } + + /// Construct a Call instruction and insert it at the cursor position. + /// Equivalent to `PyQIR`'s `builder.call(func, args)`. + pub fn call( + &mut self, + callee: &str, + args: Vec<(Type, Operand)>, + return_ty: Option, + result: Option, + ) { + let instr = Instruction::Call { + callee: callee.to_string(), + args, + return_ty, + result, + attr_refs: Vec::new(), + }; + self.insert_before(instr); + } + + /// Construct an arbitrary instruction and insert it at the cursor position. + /// Equivalent to `PyQIR`'s `builder.instr(opcode, operands)`. + pub fn instr(&mut self, instruction: Instruction) { + self.insert_before(instruction); + } + + /// Add a function declaration if one with the same name doesn't already exist. + /// Returns `true` if a new declaration was added. + pub fn ensure_declaration( + &mut self, + name: &str, + return_type: Type, + param_types: Vec, + ) -> bool { + if self.module.functions.iter().any(|f| f.name == name) { + return false; + } + self.module.functions.push(Function { + name: name.to_string(), + return_type, + params: param_types + .into_iter() + .map(|ty| Param { ty, name: None }) + .collect(), + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); + true + } + + /// Remove functions not matching a predicate. + /// Resets cursor indices to 0 after removal to avoid stale references. + pub fn retain_functions bool>(&mut self, f: F) { + self.module.functions.retain(f); + self.func_idx = 0; + self.block_idx = 0; + self.instr_idx = 0; + } + + /// Find the entry point function by scanning attribute groups for + /// the `"entry_point"` attribute. + #[must_use] + pub fn entry_point_index(&self) -> Option { + qir::inspect::find_entry_point(self.module) + } + + /// Immutable access to the underlying `Module`. 
    #[must_use]
    pub fn module(&self) -> &Module {
        self.module
    }

    /// Mutable access to the underlying `Module`.
    pub fn module_mut(&mut self) -> &mut Module {
        self.module
    }
}
diff --git a/source/compiler/qsc_llvm/src/model/builder/tests.rs b/source/compiler/qsc_llvm/src/model/builder/tests.rs
new file mode 100644
index 0000000000..191b6d7733
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/model/builder/tests.rs
@@ -0,0 +1,420 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use super::*;
use crate::model::{Attribute, AttributeGroup, BasicBlock, Function};
use crate::text::writer::write_module_to_string;

// Fixture: one function `test_fn` with a single block containing an `h` gate
// call followed by `ret void`.
fn simple_module() -> Module {
    Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![Function {
            name: "test_fn".to_string(),
            return_type: Type::Void,
            params: Vec::new(),
            is_declaration: false,
            attribute_group_refs: Vec::new(),
            basic_blocks: vec![BasicBlock {
                name: "entry".to_string(),
                instructions: vec![
                    Instruction::Call {
                        return_ty: None,
                        callee: "__quantum__qis__h__body".to_string(),
                        args: vec![(Type::Ptr, Operand::NullPtr)],
                        result: None,
                        attr_refs: Vec::new(),
                    },
                    Instruction::Ret(None),
                ],
            }],
        }],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    }
}

// Fixture: `func_a` (two blocks, jump then ret) and `func_b` (returns i32 42).
fn two_function_module() -> Module {
    Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![
            Function {
                name: "func_a".to_string(),
                return_type: Type::Void,
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![
                    BasicBlock {
                        name: "entry".to_string(),
                        instructions: vec![Instruction::Jump {
                            dest: "exit".to_string(),
                        }],
                    },
                    BasicBlock {
                        name: "exit".to_string(),
                        instructions: vec![Instruction::Ret(None)],
                    },
                ],
            },
            Function {
                name: "func_b".to_string(),
                return_type: Type::Integer(32),
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![BasicBlock {
                    name: "entry".to_string(),
                    instructions: vec![Instruction::Ret(Some(Operand::IntConst(
                        Type::Integer(32),
                        42,
                    )))],
                }],
            },
        ],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    }
}

#[test]
fn insert_before_adds_instruction() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    builder
        .position_at_function("test_fn")
        .expect("function should exist");
    builder
        .position_at_block("entry")
        .expect("block should exist");

    // Position before the ret instruction (index 1)
    builder.position_before(1);
    builder.call(
        "__quantum__qis__cx__body",
        vec![(Type::Ptr, Operand::NullPtr), (Type::Ptr, Operand::NullPtr)],
        None,
        None,
    );

    let block = &module.functions[0].basic_blocks[0];
    assert_eq!(block.instructions.len(), 3);
    // Original call at 0, new call at 1, ret at 2
    assert!(matches!(block.instructions[1], Instruction::Call { .. }));
    assert!(matches!(block.instructions[2], Instruction::Ret(None)));
}

#[test]
fn insert_at_end_appends() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    builder
        .position_at_function("test_fn")
        .expect("function should exist");
    builder
        .position_at_block("entry")
        .expect("block should exist");

    builder.insert_at_end(Instruction::Unreachable);

    let block = &module.functions[0].basic_blocks[0];
    assert_eq!(block.instructions.len(), 3);
    assert!(matches!(block.instructions[2], Instruction::Unreachable));
}

#[test]
fn erase_at_cursor_removes_instruction() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    builder
        .position_at_function("test_fn")
        .expect("function should exist");
    builder
        .position_at_block("entry")
        .expect("block should exist");

    // Erase the first instruction (h gate call)
    builder.position_before(0);
    let erased = builder.erase_at_cursor();
    assert!(matches!(erased, Instruction::Call { .. }));

    let block = &module.functions[0].basic_blocks[0];
    assert_eq!(block.instructions.len(), 1);
    assert!(matches!(block.instructions[0], Instruction::Ret(None)));
}

#[test]
fn position_at_different_functions() {
    let mut module = two_function_module();
    let mut builder = ModuleBuilder::new(&mut module);

    builder
        .position_at_function("func_b")
        .expect("func_b should exist");
    builder
        .position_at_block("entry")
        .expect("block should exist");
    assert_eq!(builder.current_block_len(), 1);

    builder
        .position_at_function("func_a")
        .expect("func_a should exist");
    builder
        .position_at_block("exit")
        .expect("exit block should exist");
    assert_eq!(builder.current_block_len(), 1);
}

#[test]
fn position_at_nonexistent_function_errors() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    let result = builder.position_at_function("no_such_fn");
    assert!(result.is_err());
}

#[test]
fn position_at_nonexistent_block_errors() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    builder
        .position_at_function("test_fn")
        .expect("function should exist");
    let result = builder.position_at_block("no_such_block");
    assert!(result.is_err());
}

#[test]
fn instr_inserts_arbitrary_instruction() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    builder
        .position_at_function("test_fn")
        .expect("function should exist");
    builder
        .position_at_block("entry")
        .expect("block should exist");
    builder.position_before(0);

    builder.instr(Instruction::Alloca {
        ty: Type::Integer(64),
        result: "%x".to_string(),
    });

    let block = &module.functions[0].basic_blocks[0];
    assert_eq!(block.instructions.len(), 3);
    assert!(matches!(block.instructions[0], Instruction::Alloca { .. }));
}

#[test]
fn multiple_sequential_mutations() {
    let mut module = simple_module();
    {
        let mut builder = ModuleBuilder::new(&mut module);
        builder
            .position_at_function("test_fn")
            .expect("function should exist");
        builder
            .position_at_block("entry")
            .expect("block should exist");

        // Insert two calls before the ret (index 1)
        builder.position_before(1);
        builder.call(
            "__quantum__qis__x__body",
            vec![(Type::Ptr, Operand::NullPtr)],
            None,
            None,
        );
        // Cursor is now at 2, insert another
        builder.call(
            "__quantum__qis__z__body",
            vec![(Type::Ptr, Operand::NullPtr)],
            None,
            None,
        );

        // h, x, z, ret
        assert_eq!(builder.current_block_len(), 4);

        // Erase the original h call at index 0
        builder.position_before(0);
        let erased = builder.erase_at_cursor();
        assert!(
            matches!(erased, Instruction::Call { callee, .. } if callee == "__quantum__qis__h__body")
        );

        // x, z, ret
        assert_eq!(builder.current_block_len(), 3);
    }

    // Verify final state after builder is dropped
    assert_eq!(module.functions[0].basic_blocks[0].instructions.len(), 3);
}

#[test]
fn position_at_end_sets_cursor_past_last() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    builder
        .position_at_function("test_fn")
        .expect("function should exist");
    builder
        .position_at_block("entry")
        .expect("block should exist");

    builder.position_at_end();
    assert_eq!(builder.current_block_len(), 2);
    // Inserting at end via insert_before at position == len is effectively append
    builder.instr(Instruction::Unreachable);
    assert_eq!(builder.current_block_len(), 3);
}

#[test]
fn modified_module_serializes_to_valid_ir() {
    let mut module = simple_module();
    {
        let mut builder = ModuleBuilder::new(&mut module);
        builder
            .position_at_function("test_fn")
            .expect("function should exist");
        builder
            .position_at_block("entry")
            .expect("block should exist");

        // Insert a cx call before ret
        builder.position_before(1);
        builder.call(
            "__quantum__qis__cx__body",
            vec![(Type::Ptr, Operand::NullPtr), (Type::Ptr, Operand::NullPtr)],
            None,
            None,
        );
    }

    let text = write_module_to_string(&module);
    assert!(text.contains("call void @__quantum__qis__h__body(ptr null)"));
    assert!(text.contains("call void @__quantum__qis__cx__body(ptr null, ptr null)"));
    assert!(text.contains("ret void"));
}

#[test]
fn round_trip_after_mutation() {
    use crate::text::reader::parse_module;

    let mut module = simple_module();
    {
        let mut builder = ModuleBuilder::new(&mut module);
        builder
            .position_at_function("test_fn")
            .expect("function should exist");
        builder
            .position_at_block("entry")
            .expect("block should exist");
        builder.position_before(1);
        builder.call(
            "__quantum__qis__z__body",
            vec![(Type::Ptr, Operand::NullPtr)],
            None,
            None,
        );
    }

    let text1 = write_module_to_string(&module);
    let parsed = parse_module(&text1).expect("should parse modified module");
    let text2 = write_module_to_string(&parsed);
    assert_eq!(text1, text2);
}

#[test]
fn ensure_declaration_adds_new() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    let added = builder.ensure_declaration("new_decl", Type::Void, vec![Type::Ptr]);
    assert!(added);
    assert_eq!(module.functions.len(), 2);
    assert_eq!(module.functions[1].name, "new_decl");
    assert!(module.functions[1].is_declaration);
}

#[test]
fn ensure_declaration_skips_existing() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    let added = builder.ensure_declaration("test_fn", Type::Void, vec![]);
    assert!(!added);
    assert_eq!(module.functions.len(), 1);
}

#[test]
fn retain_functions_removes_matching() {
    let mut module = two_function_module();
    let mut builder = ModuleBuilder::new(&mut module);
    builder
        .position_at_function("func_b")
        .expect("func_b should exist");
    builder.retain_functions(|f| f.name == "func_a");
    assert_eq!(module.functions.len(), 1);
    assert_eq!(module.functions[0].name, "func_a");
}

#[test]
fn entry_point_index_found() {
    let mut module = simple_module();
    module.functions[0].attribute_group_refs = vec![0];
    module.attribute_groups.push(AttributeGroup {
        id: 0,
        attributes: vec![Attribute::StringAttr("entry_point".to_string())],
    });
    let builder = ModuleBuilder::new(&mut module);
    assert_eq!(builder.entry_point_index(), Some(0));
}

#[test]
fn entry_point_index_not_found() {
    let mut module = simple_module();
    let builder = ModuleBuilder::new(&mut module);
    assert_eq!(builder.entry_point_index(), None);
}

#[test]
fn entry_point_index_skips_declarations() {
    let mut module = simple_module();
    module.functions.insert(
        0,
        Function {
            name: "decl_entry".to_string(),
            return_type: Type::Void,
            params: Vec::new(),
            is_declaration: true,
            attribute_group_refs: vec![0],
            basic_blocks: Vec::new(),
        },
    );
    module.functions[1].attribute_group_refs = vec![0];
    module.attribute_groups.push(AttributeGroup {
        id: 0,
        attributes: vec![Attribute::StringAttr("entry_point".to_string())],
    });

    let builder = ModuleBuilder::new(&mut module);
    assert_eq!(builder.entry_point_index(), Some(1));
}

#[test]
fn module_access() {
    let mut module = simple_module();
    let mut builder = ModuleBuilder::new(&mut module);
    assert_eq!(builder.module().functions.len(), 1);
    builder.module_mut().functions.clear();
    assert_eq!(builder.module().functions.len(), 0);
}
diff --git a/source/compiler/qsc_llvm/src/model/test_helpers.rs b/source/compiler/qsc_llvm/src/model/test_helpers.rs
new file mode 100644
index 0000000000..0fc5c8b0de
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/model/test_helpers.rs
@@ -0,0 +1,314 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use super::*;

// A module with every collection empty, for baseline assertions.
pub fn empty_module() -> Module {
    Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: Vec::new(),
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    }
}

// Wraps one instruction (followed by `ret void`) in a single-block `test_fn`.
pub fn single_instruction_module(instr: Instruction) -> Module {
    Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![Function {
            name: "test_fn".to_string(),
            return_type: Type::Void,
            params: Vec::new(),
            is_declaration: false,
            attribute_group_refs: Vec::new(),
            basic_blocks: vec![BasicBlock {
                name: "entry".to_string(),
                instructions: vec![instr, Instruction::Ret(None)],
            }],
        }],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    }
}

// A complete Bell-pair program in the adaptive-profile shape: QIS declarations,
// output-recording globals, an entry point with attributes, and the full
// `llvm.module.flags` metadata set.
#[allow(clippy::too_many_lines)]
pub fn bell_module_v2() -> Module {
    Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: vec![
            GlobalVariable {
                name: "0".to_string(),
                ty: Type::Array(4, Box::new(Type::Integer(8))),
                linkage: Linkage::Internal,
                is_constant: true,
                initializer: Some(Constant::CString("0_a".to_string())),
            },
            GlobalVariable {
                name: "1".to_string(),
                ty: Type::Array(6, Box::new(Type::Integer(8))),
                linkage: Linkage::Internal,
                is_constant: true,
                initializer: Some(Constant::CString("1_a0r".to_string())),
            },
            GlobalVariable {
                name: "2".to_string(),
                ty: Type::Array(6, Box::new(Type::Integer(8))),
                linkage: Linkage::Internal,
                is_constant: true,
                initializer: Some(Constant::CString("2_a1r".to_string())),
            },
        ],
        functions: vec![
            Function {
                name: "__quantum__qis__h__body".to_string(),
                return_type: Type::Void,
                params: vec![Param {
                    ty: Type::Ptr,
                    name: None,
                }],
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "__quantum__qis__cx__body".to_string(),
                return_type: Type::Void,
                params: vec![
                    Param {
                        ty: Type::Ptr,
                        name: None,
                    },
                    Param {
                        ty: Type::Ptr,
                        name: None,
                    },
                ],
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "__quantum__qis__m__body".to_string(),
                return_type: Type::Void,
                params: vec![
                    Param {
                        ty: Type::Ptr,
                        name: None,
                    },
                    Param {
                        ty: Type::Ptr,
                        name: None,
                    },
                ],
                is_declaration: true,
                attribute_group_refs: vec![1],
                basic_blocks: Vec::new(),
            },
            Function {
                name: "__quantum__rt__array_record_output".to_string(),
                return_type: Type::Void,
                params: vec![
                    Param {
                        ty: Type::Integer(64),
                        name: None,
                    },
                    Param {
                        ty: Type::Ptr,
                        name: None,
                    },
                ],
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "__quantum__rt__result_record_output".to_string(),
                return_type: Type::Void,
                params: vec![
                    Param {
                        ty: Type::Ptr,
                        name: None,
                    },
                    Param {
                        ty: Type::Ptr,
                        name: None,
                    },
                ],
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "ENTRYPOINT__main".to_string(),
                return_type: Type::Integer(64),
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: vec![0],
                basic_blocks: vec![BasicBlock {
                    name: "block_0".to_string(),
                    instructions: vec![
                        Instruction::Call {
                            return_ty: None,
                            callee: "__quantum__qis__h__body".to_string(),
                            args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Call {
                            return_ty: None,
                            callee: "__quantum__qis__cx__body".to_string(),
                            args: vec![
                                (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)),
                                (Type::Ptr, Operand::IntToPtr(1, Type::Ptr)),
                            ],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Call {
                            return_ty: None,
                            callee: "__quantum__qis__m__body".to_string(),
                            args: vec![
                                (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)),
                                (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)),
                            ],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Call {
                            return_ty: None,
                            callee: "__quantum__qis__m__body".to_string(),
                            args: vec![
                                (Type::Ptr, Operand::IntToPtr(1, Type::Ptr)),
                                (Type::Ptr, Operand::IntToPtr(1, Type::Ptr)),
                            ],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Call {
                            return_ty: None,
                            callee: "__quantum__rt__array_record_output".to_string(),
                            args: vec![
                                (Type::Integer(64), Operand::IntConst(Type::Integer(64), 2)),
                                (Type::Ptr, Operand::GlobalRef("0".to_string())),
                            ],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Call {
                            return_ty: None,
                            callee: "__quantum__rt__result_record_output".to_string(),
                            args: vec![
                                (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)),
                                (Type::Ptr, Operand::GlobalRef("1".to_string())),
                            ],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Call {
                            return_ty: None,
                            callee: "__quantum__rt__result_record_output".to_string(),
                            args: vec![
                                (Type::Ptr, Operand::IntToPtr(1, Type::Ptr)),
                                (Type::Ptr, Operand::GlobalRef("2".to_string())),
                            ],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))),
                    ],
                }],
            },
        ],
        attribute_groups: vec![
            AttributeGroup {
                id: 0,
                attributes: vec![
                    Attribute::StringAttr("entry_point".to_string()),
                    Attribute::StringAttr("output_labeling_schema".to_string()),
                    Attribute::KeyValue("qir_profiles".to_string(), "adaptive_profile".to_string()),
                    Attribute::KeyValue("required_num_qubits".to_string(), "2".to_string()),
                    Attribute::KeyValue("required_num_results".to_string(), "2".to_string()),
                ],
            },
            AttributeGroup {
                id: 1,
                attributes: vec![Attribute::StringAttr("irreversible".to_string())],
            },
        ],
        named_metadata: vec![NamedMetadata {
            name: "llvm.module.flags".to_string(),
            node_refs: vec![0, 1, 2, 3, 4, 5, 6],
        }],
        metadata_nodes: vec![
            MetadataNode {
                id: 0,
                values: vec![
                    MetadataValue::Int(Type::Integer(32), 1),
                    MetadataValue::String("qir_major_version".to_string()),
                    MetadataValue::Int(Type::Integer(32), 2),
                ],
            },
            MetadataNode {
                id: 1,
                values: vec![
                    MetadataValue::Int(Type::Integer(32), 7),
                    MetadataValue::String("qir_minor_version".to_string()),
                    MetadataValue::Int(Type::Integer(32), 1),
                ],
            },
            MetadataNode {
                id: 2,
                values: vec![
                    MetadataValue::Int(Type::Integer(32), 1),
                    MetadataValue::String("dynamic_qubit_management".to_string()),
                    MetadataValue::Int(Type::Integer(1), 0),
                ],
            },
            MetadataNode {
                id: 3,
                values: vec![
                    MetadataValue::Int(Type::Integer(32), 1),
                    MetadataValue::String("dynamic_result_management".to_string()),
                    MetadataValue::Int(Type::Integer(1), 0),
                ],
            },
            MetadataNode {
                id: 4,
                values: vec![
                    MetadataValue::Int(Type::Integer(32), 5),
                    MetadataValue::String("int_computations".to_string()),
                    MetadataValue::SubList(vec![MetadataValue::String("i64".to_string())]),
                ],
            },
            MetadataNode {
                id: 5,
                values: vec![
                    MetadataValue::Int(Type::Integer(32), 7),
                    MetadataValue::String("backwards_branching".to_string()),
                    MetadataValue::Int(Type::Integer(2), 3),
                ],
            },
            MetadataNode {
                id: 6,
                values: vec![
                    MetadataValue::Int(Type::Integer(32), 1),
                    MetadataValue::String("arrays".to_string()),
                    MetadataValue::Int(Type::Integer(1), 1),
                ],
            },
        ],
    }
}
diff --git a/source/compiler/qsc_llvm/src/model/tests.rs b/source/compiler/qsc_llvm/src/model/tests.rs
new file mode 100644
index 0000000000..eaed9a117e
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/model/tests.rs
@@ -0,0 +1,166 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
+ +use super::*; +use expect_test::expect; +use test_helpers::{empty_module, single_instruction_module}; + +#[test] +fn display_void() { + expect!["void"].assert_eq(&Type::Void.to_string()); +} + +#[test] +fn display_integer_1() { + expect!["i1"].assert_eq(&Type::Integer(1).to_string()); +} + +#[test] +fn display_integer_64() { + expect!["i64"].assert_eq(&Type::Integer(64).to_string()); +} + +#[test] +fn display_double() { + expect!["double"].assert_eq(&Type::Double.to_string()); +} + +#[test] +fn display_ptr() { + expect!["ptr"].assert_eq(&Type::Ptr.to_string()); +} + +#[test] +fn display_named_ptr() { + expect!["%Qubit*"].assert_eq(&Type::NamedPtr("Qubit".to_string()).to_string()); +} + +#[test] +fn display_array() { + expect!["[4 x i8]"].assert_eq(&Type::Array(4, Box::new(Type::Integer(8))).to_string()); +} + +#[test] +fn display_function_no_params() { + expect!["void ()"].assert_eq(&Type::Function(Box::new(Type::Void), vec![]).to_string()); +} + +#[test] +fn display_function_with_params() { + expect!["void (ptr, ptr)"] + .assert_eq(&Type::Function(Box::new(Type::Void), vec![Type::Ptr, Type::Ptr]).to_string()); +} + +#[test] +fn display_function_with_return() { + expect!["i1 (ptr)"] + .assert_eq(&Type::Function(Box::new(Type::Integer(1)), vec![Type::Ptr]).to_string()); +} + +#[test] +fn display_named() { + expect!["%Qubit"].assert_eq(&Type::Named("Qubit".to_string()).to_string()); +} + +#[test] +fn display_named_result() { + expect!["%Result"].assert_eq(&Type::Named("Result".to_string()).to_string()); +} + +#[test] +fn equality() { + assert_eq!(Type::Integer(32), Type::Integer(32)); + assert_ne!(Type::Integer(32), Type::Integer(64)); + assert_ne!(Type::Ptr, Type::Double); +} + +#[test] +fn clone_preserves_equality() { + let ty = Type::Array(8, Box::new(Type::Integer(8))); + let cloned = ty.clone(); + assert_eq!(ty, cloned); +} + +#[test] +fn empty_module_has_no_functions() { + let m = empty_module(); + assert!(m.functions.is_empty()); + 
assert!(m.globals.is_empty()); + assert!(m.struct_types.is_empty()); +} + +#[test] +fn single_instruction_module_has_one_function() { + let m = single_instruction_module(Instruction::Ret(None)); + assert_eq!(m.functions.len(), 1); + assert_eq!(m.functions[0].name, "test_fn"); + assert_eq!(m.functions[0].basic_blocks.len(), 1); + assert_eq!(m.functions[0].basic_blocks[0].instructions.len(), 2); +} + +#[test] +fn module_clone_preserves_equality() { + let m = single_instruction_module(Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))); + let cloned = m.clone(); + assert_eq!(m, cloned); +} + +#[test] +fn instruction_debug_format() { + let instr = Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("a".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: "sum".to_string(), + }; + let debug_str = format!("{instr:?}"); + assert!(debug_str.contains("BinOp")); + assert!(debug_str.contains("Add")); +} + +#[test] +fn call_instruction_construction() { + let call = Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".to_string(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: None, + attr_refs: Vec::new(), + }; + assert!(matches!(call, Instruction::Call { .. 
})); +} + +#[test] +fn global_variable_construction() { + let g = GlobalVariable { + name: "0".to_string(), + ty: Type::Array(4, Box::new(Type::Integer(8))), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::CString("0_r".to_string())), + }; + assert!(g.is_constant); + assert_eq!(g.linkage, Linkage::Internal); +} + +#[test] +fn metadata_construction() { + let named = NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![0, 1], + }; + let node = MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".to_string()), + MetadataValue::Int(Type::Integer(32), 1), + ], + }; + assert_eq!(named.node_refs.len(), 2); + assert_eq!(node.values.len(), 3); +} diff --git a/source/compiler/qsc_llvm/src/qir.rs b/source/compiler/qsc_llvm/src/qir.rs new file mode 100644 index 0000000000..8cf7592627 --- /dev/null +++ b/source/compiler/qsc_llvm/src/qir.rs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Public QIR facade. + +mod build; +pub(crate) mod inspect; +mod spec; + +pub use build::{double_op, i64_op, qubit_op, result_op, void_call}; +pub use inspect::{ + extract_float, extract_id, find_entry_point, get_function_attribute, operand_key, +}; +pub use spec::*; diff --git a/source/compiler/qsc_llvm/src/qir/build.rs b/source/compiler/qsc_llvm/src/qir/build.rs new file mode 100644 index 0000000000..96fcc0323a --- /dev/null +++ b/source/compiler/qsc_llvm/src/qir/build.rs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::model::Type; +use crate::model::{Instruction, Operand}; + +use super::spec::{QUBIT_TYPE_NAME, RESULT_TYPE_NAME}; + +/// Build a void call instruction: `call void @callee(args...)`. 
+#[must_use] +pub fn void_call(callee: &str, args: Vec<(Type, Operand)>) -> Instruction { + Instruction::Call { + callee: callee.to_string(), + args, + return_ty: None, + result: None, + attr_refs: Vec::new(), + } +} + +/// Build a qubit operand: `inttoptr (i64 to %Qubit*)`. +#[must_use] +pub fn qubit_op(id: u32) -> (Type, Operand) { + ( + Type::NamedPtr(QUBIT_TYPE_NAME.to_string()), + Operand::int_to_named_ptr(i64::from(id), QUBIT_TYPE_NAME), + ) +} + +/// Build a result operand: `inttoptr (i64 to %Result*)`. +#[must_use] +pub fn result_op(id: u32) -> (Type, Operand) { + ( + Type::NamedPtr(RESULT_TYPE_NAME.to_string()), + Operand::int_to_named_ptr(i64::from(id), RESULT_TYPE_NAME), + ) +} + +/// Build a `double` constant operand pair. +#[must_use] +pub fn double_op(val: f64) -> (Type, Operand) { + (Type::Double, Operand::float_const(Type::Double, val)) +} + +/// Build an `i64` constant operand pair. +#[must_use] +pub fn i64_op(val: i64) -> (Type, Operand) { + (Type::Integer(64), Operand::IntConst(Type::Integer(64), val)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::Type; + use crate::model::{Instruction, Operand}; + + #[test] + fn test_void_call() { + let instr = void_call("__quantum__qis__h__body", vec![qubit_op(0)]); + match &instr { + Instruction::Call { + callee, + args, + return_ty, + result, + attr_refs, + } => { + assert_eq!(callee, "__quantum__qis__h__body"); + assert_eq!(args.len(), 1); + assert!(return_ty.is_none()); + assert!(result.is_none()); + assert!(attr_refs.is_empty()); + } + _ => panic!("expected Instruction::Call"), + } + } + + #[test] + fn test_qubit_op() { + let (ty, op) = qubit_op(3); + assert_eq!(ty, Type::NamedPtr("Qubit".to_string())); + assert_eq!( + op, + Operand::IntToPtr(3, Type::NamedPtr("Qubit".to_string())) + ); + } + + #[test] + fn test_result_op() { + let (ty, op) = result_op(5); + assert_eq!(ty, Type::NamedPtr("Result".to_string())); + assert_eq!( + op, + Operand::IntToPtr(5, 
Type::NamedPtr("Result".to_string())) + ); + } + + #[test] + fn test_double_op() { + let (ty, op) = double_op(std::f64::consts::PI); + assert_eq!(ty, Type::Double); + assert_eq!(op, Operand::float_const(Type::Double, std::f64::consts::PI)); + } + + #[test] + fn test_i64_op() { + let (ty, op) = i64_op(42); + assert_eq!(ty, Type::Integer(64)); + assert_eq!(op, Operand::IntConst(Type::Integer(64), 42)); + } +} diff --git a/source/compiler/qsc_llvm/src/qir/inspect.rs b/source/compiler/qsc_llvm/src/qir/inspect.rs new file mode 100644 index 0000000000..b2b227505f --- /dev/null +++ b/source/compiler/qsc_llvm/src/qir/inspect.rs @@ -0,0 +1,931 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::collections::BTreeSet; + +use crate::model::{ + Attribute, BinOpKind, CastKind, Constant, Function, Instruction, MetadataValue, Module, + ModuleFlagNodeIssue, Operand, Type, +}; + +use super::spec::ENTRY_POINT_ATTR; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct FloatSurfaceAnalysis { + pub(crate) has_float_op: bool, + surface_widths: BTreeSet<u32>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ModuleFlagIssue { + DanglingReference { + node_ref: u32, + }, + MalformedNode { + node_ref: u32, + reason: &'static str, + }, + InvalidBehavior { + flag_name: String, + node_id: u32, + found: String, + }, + InvalidValue { + flag_name: String, + node_id: u32, + expected: &'static str, + found: String, + }, + InvalidStringListItem { + flag_name: String, + node_id: u32, + index: usize, + found: String, + }, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct ModuleFlagAccess<T> { + pub(crate) value: Option<T>, + pub(crate) issues: Vec<ModuleFlagIssue>, +} + +impl ModuleFlagIssue { + #[must_use] + pub(crate) fn flag_name(&self) -> Option<&str> { + match self { + Self::DanglingReference { .. } | Self::MalformedNode { .. } => None, + Self::InvalidBehavior { flag_name, .. } + | Self::InvalidValue { flag_name, ..
} + | Self::InvalidStringListItem { flag_name, .. } => Some(flag_name.as_str()), + } + } +} + +fn describe_metadata_value(value: &MetadataValue) -> String { + match value { + MetadataValue::Int(ty, _) => format!("integer ({ty})"), + MetadataValue::String(_) => "string".to_string(), + MetadataValue::NodeRef(node_id) => format!("node reference !{node_id}"), + MetadataValue::SubList(_) => "metadata sublist".to_string(), + } +} + +fn map_module_flag_node_issue(issue: &ModuleFlagNodeIssue) -> ModuleFlagIssue { + match issue { + ModuleFlagNodeIssue::DanglingReference { node_ref } => ModuleFlagIssue::DanglingReference { + node_ref: *node_ref, + }, + ModuleFlagNodeIssue::MalformedEntry { node_ref, reason } => { + ModuleFlagIssue::MalformedNode { + node_ref: *node_ref, + reason, + } + } + } +} + +fn find_module_flag_entry<'a>( + module: &'a Module, + key: &str, +) -> Option<crate::model::ModuleFlagEntry<'a>> { + module + .audit_module_flags() + .entries + .into_iter() + .find(|entry| entry.key == key) +} + +impl FloatSurfaceAnalysis { + fn record_type(&mut self, ty: &Type) { + match ty { + Type::TypedPtr(inner) | Type::Array(_, inner) => self.record_type(inner), + Type::Function(return_type, params) => { + self.record_type(return_type); + for param in params { + self.record_type(param); + } + } + _ => { + if let Some(width) = ty.floating_point_bit_width() { + self.surface_widths.insert(width); + } + } + } + } + + #[must_use] + pub(crate) fn surface_width_names(&self) -> Vec<&'static str> { + self.surface_widths + .iter() + .filter_map(|width| match width { + 16 => Some("half"), + 32 => Some("float"), + 64 => Some("double"), + _ => None, + }) + .collect() + } +} + +#[must_use] +pub(crate) fn analyze_float_surface(module: &Module) -> FloatSurfaceAnalysis { + let mut analysis = FloatSurfaceAnalysis::default(); + + for global in &module.globals { + analysis.record_type(&global.ty); + if let Some(initializer) = &global.initializer { + analyze_constant(initializer, &mut analysis); + } + } + + for function in
&module.functions { + analyze_function(function, &mut analysis); + } + + analysis +} + +fn analyze_function(function: &Function, analysis: &mut FloatSurfaceAnalysis) { + analysis.record_type(&function.return_type); + for param in &function.params { + analysis.record_type(&param.ty); + } + + for block in &function.basic_blocks { + for instruction in &block.instructions { + analyze_instruction(instruction, analysis); + } + } +} + +#[allow(clippy::too_many_lines)] +fn analyze_instruction(instruction: &Instruction, analysis: &mut FloatSurfaceAnalysis) { + analysis.has_float_op |= matches!( + instruction, + Instruction::BinOp { + op: BinOpKind::Fadd | BinOpKind::Fsub | BinOpKind::Fmul | BinOpKind::Fdiv, + .. + } | Instruction::FCmp { .. } + | Instruction::Cast { + op: CastKind::FpExt | CastKind::FpTrunc | CastKind::Sitofp | CastKind::Fptosi, + .. + } + ); + + match instruction { + Instruction::Ret(value) => { + if let Some(value) = value { + analyze_operand(value, analysis); + } + } + Instruction::Br { cond_ty, cond, .. } => { + analysis.record_type(cond_ty); + analyze_operand(cond, analysis); + } + Instruction::Jump { .. } | Instruction::Unreachable => {} + Instruction::BinOp { ty, lhs, rhs, .. } + | Instruction::ICmp { ty, lhs, rhs, .. } + | Instruction::FCmp { ty, lhs, rhs, .. } => { + analysis.record_type(ty); + analyze_operand(lhs, analysis); + analyze_operand(rhs, analysis); + } + Instruction::Cast { + from_ty, + to_ty, + value, + .. + } => { + analysis.record_type(from_ty); + analysis.record_type(to_ty); + analyze_operand(value, analysis); + } + Instruction::Call { + return_ty, args, .. + } => { + if let Some(return_ty) = return_ty { + analysis.record_type(return_ty); + } + for (ty, operand) in args { + analysis.record_type(ty); + analyze_operand(operand, analysis); + } + } + Instruction::Phi { ty, incoming, .. } => { + analysis.record_type(ty); + for (operand, _) in incoming { + analyze_operand(operand, analysis); + } + } + Instruction::Alloca { ty, ..
} => analysis.record_type(ty), + Instruction::Load { + ty, ptr_ty, ptr, .. + } => { + analysis.record_type(ty); + analysis.record_type(ptr_ty); + analyze_operand(ptr, analysis); + } + Instruction::Store { + ty, + value, + ptr_ty, + ptr, + } => { + analysis.record_type(ty); + analyze_operand(value, analysis); + analysis.record_type(ptr_ty); + analyze_operand(ptr, analysis); + } + Instruction::Select { + cond, + true_val, + false_val, + ty, + .. + } => { + analyze_operand(cond, analysis); + analyze_operand(true_val, analysis); + analyze_operand(false_val, analysis); + analysis.record_type(ty); + } + Instruction::Switch { ty, value, .. } => { + analysis.record_type(ty); + analyze_operand(value, analysis); + } + Instruction::GetElementPtr { + pointee_ty, + ptr_ty, + ptr, + indices, + .. + } => { + analysis.record_type(pointee_ty); + analysis.record_type(ptr_ty); + analyze_operand(ptr, analysis); + for index in indices { + analyze_operand(index, analysis); + } + } + } +} + +fn analyze_constant(constant: &Constant, analysis: &mut FloatSurfaceAnalysis) { + if let Constant::Float(ty, _) = constant { + analysis.record_type(ty); + } +} + +fn analyze_operand(operand: &Operand, analysis: &mut FloatSurfaceAnalysis) { + match operand { + Operand::LocalRef(_) | Operand::GlobalRef(_) | Operand::NullPtr => {} + Operand::TypedLocalRef(_, ty) + | Operand::IntConst(ty, _) + | Operand::FloatConst(ty, _) + | Operand::IntToPtr(_, ty) => analysis.record_type(ty), + Operand::GetElementPtr { + ty, + ptr_ty, + indices, + .. 
 + } => { + analysis.record_type(ty); + analysis.record_type(ptr_ty); + for index in indices { + analyze_operand(index, analysis); + } + } + } +} + +fn function_attributes(module: &Module, func_idx: usize) -> impl Iterator<Item = &Attribute> + '_ { + module.functions[func_idx] + .attribute_group_refs + .iter() + .filter_map(|&group_ref| module.attribute_groups.iter().find(|ag| ag.id == group_ref)) + .flat_map(|group| group.attributes.iter()) +} + +fn has_function_string_attribute(module: &Module, func_idx: usize, attr_name: &str) -> bool { + function_attributes(module, func_idx) + .any(|attr| matches!(attr, Attribute::StringAttr(name) if name == attr_name)) +} + +/// Extract the integer ID from an `IntToPtr` or `NullPtr` operand. +/// `NullPtr` is treated as ID 0 (`inttoptr(i64 0)` can normalize to `null`). +#[must_use] +pub fn extract_id(operand: &Operand) -> Option<u32> { + match operand { + Operand::IntToPtr(val, _) => u32::try_from(*val).ok(), + Operand::NullPtr => Some(0), + _ => None, + } +} + +/// Extract an `f64` value from a `FloatConst` operand. +#[must_use] +pub fn extract_float(operand: &Operand) -> Option<f64> { + match operand { + Operand::FloatConst(_, val) => Some(*val), + _ => None, + } +} + +/// Generate a stable string key for an operand. +#[must_use] +pub fn operand_key(operand: &Operand) -> String { + format!("{operand:?}") +} + +/// Find the entry-point function index in a module. +#[must_use] +pub fn find_entry_point(module: &Module) -> Option<usize> { + module + .functions + .iter() + .enumerate() + .find_map(|(func_idx, func)| { + (!func.is_declaration + && has_function_string_attribute(module, func_idx, ENTRY_POINT_ATTR)) + .then_some(func_idx) + }) +} + +/// Count the number of non-declaration entry-point functions in a module. 
+#[must_use] +pub(crate) fn count_entry_points(module: &Module) -> usize { + module + .functions + .iter() + .enumerate() + .filter(|(func_idx, func)| { + !func.is_declaration + && has_function_string_attribute(module, *func_idx, ENTRY_POINT_ATTR) + }) + .count() +} + +/// Extract a key-value attribute from the given function's attribute groups. +#[must_use] +pub fn get_function_attribute<'a>( + module: &'a Module, + func_idx: usize, + key: &str, +) -> Option<&'a str> { + function_attributes(module, func_idx).find_map(|attr| { + if let Attribute::KeyValue(attr_key, value) = attr + && attr_key == key + { + Some(value.as_str()) + } else { + None + } + }) +} + +/// Check whether a function has the given attribute in string or key-value form. +#[must_use] +pub(crate) fn has_function_attribute(module: &Module, func_idx: usize, attr_name: &str) -> bool { + function_attributes(module, func_idx).any(|attr| match attr { + Attribute::StringAttr(name) => name == attr_name, + Attribute::KeyValue(key, _) => key == attr_name, + }) +} + +/// Look up a module flag value by key. +#[must_use] +pub(crate) fn get_module_flag<'a>(module: &'a Module, key: &str) -> Option<&'a MetadataValue> { + module.get_flag(key) +} + +#[must_use] +pub(crate) fn inspect_module_flag_metadata(module: &Module) -> Vec<ModuleFlagIssue> { + module + .audit_module_flags() + .issues + .iter() + .map(map_module_flag_node_issue) + .collect() +} + +#[must_use] +pub(crate) fn inspect_module_flag_int(module: &Module, key: &str) -> ModuleFlagAccess<i64> { + let mut access = ModuleFlagAccess::default(); + + if let Some(entry) = find_module_flag_entry(module, key) { + match entry.value { + MetadataValue::Int(_, value) => access.value = Some(*value), + other => access.issues.push(ModuleFlagIssue::InvalidValue { + flag_name: key.to_string(), + node_id: entry.node_id, + expected: "integer", + found: describe_metadata_value(other), + }), + } + } + + access +} + +/// Look up an integer module flag value by key. 
+#[cfg(test)] +#[must_use] +pub(crate) fn get_module_flag_int(module: &Module, key: &str) -> Option<i64> { + inspect_module_flag_int(module, key).value +} + +#[must_use] +pub(crate) fn inspect_module_flag_bool(module: &Module, key: &str) -> ModuleFlagAccess<bool> { + let mut access = ModuleFlagAccess::default(); + + if let Some(entry) = find_module_flag_entry(module, key) { + match entry.value { + MetadataValue::Int(_, value) => access.value = Some(*value != 0), + other => access.issues.push(ModuleFlagIssue::InvalidValue { + flag_name: key.to_string(), + node_id: entry.node_id, + expected: "integer boolean", + found: describe_metadata_value(other), + }), + } + } + + access +} + +/// Look up a boolean module flag value by key. +#[cfg(test)] +#[must_use] +pub(crate) fn get_module_flag_bool(module: &Module, key: &str) -> bool { + inspect_module_flag_bool(module, key).value.unwrap_or(false) +} + +#[must_use] +pub(crate) fn inspect_module_flag_string_list( + module: &Module, + key: &str, +) -> ModuleFlagAccess<Vec<String>> { + let mut access = ModuleFlagAccess::default(); + + if let Some(entry) = find_module_flag_entry(module, key) { + match entry.value { + MetadataValue::SubList(items) => { + let mut values = Vec::with_capacity(items.len()); + for (index, value) in items.iter().enumerate() { + if let MetadataValue::String(text) = value { + values.push(text.clone()); + } else { + access.issues.push(ModuleFlagIssue::InvalidStringListItem { + flag_name: key.to_string(), + node_id: entry.node_id, + index, + found: describe_metadata_value(value), + }); + } + } + + if access.issues.is_empty() { + access.value = Some(values); + } + } + other => access.issues.push(ModuleFlagIssue::InvalidValue { + flag_name: key.to_string(), + node_id: entry.node_id, + expected: "metadata string list", + found: describe_metadata_value(other), + }), + } + } + + access +} + +/// Look up a string-list module flag value by key. 
+#[cfg(test)] +#[must_use] +pub(crate) fn get_module_flag_string_list(module: &Module, key: &str) -> Vec<String> { + inspect_module_flag_string_list(module, key) + .value + .unwrap_or_default() +} + +#[must_use] +pub(crate) fn inspect_module_flag_behavior(module: &Module, key: &str) -> ModuleFlagAccess<i64> { + let mut access = ModuleFlagAccess::default(); + + if let Some(entry) = find_module_flag_entry(module, key) { + match entry.behavior { + MetadataValue::Int(_, behavior) => access.value = Some(*behavior), + other => access.issues.push(ModuleFlagIssue::InvalidBehavior { + flag_name: key.to_string(), + node_id: entry.node_id, + found: describe_metadata_value(other), + }), + } + } + + access +} + +/// Look up the module-flag merge behavior for a given key. +#[cfg(test)] +#[must_use] +pub(crate) fn get_module_flag_behavior(module: &Module, key: &str) -> Option<i64> { + inspect_module_flag_behavior(module, key).value +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::Type; + use crate::model::{ + AttributeGroup, BasicBlock, FloatPredicate, Function, GlobalVariable, Instruction, Linkage, + MetadataNode, MetadataValue, Module, NamedMetadata, Operand, Param, + }; + use crate::qir::spec::MODULE_FLAGS_NAME; + + fn module_with_inspection_data() -> Module { + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + Function { + name: "decl_entry".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: true, + attribute_group_refs: vec![0], + basic_blocks: Vec::new(), + }, + Function { + name: "ENTRYPOINT__main".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: vec![1], + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }], + }, + ], + attribute_groups: vec![ + AttributeGroup { 
+ id: 0, + attributes: vec![Attribute::StringAttr(ENTRY_POINT_ATTR.to_string())], + }, + AttributeGroup { + id: 1, + attributes: vec![ + Attribute::StringAttr(ENTRY_POINT_ATTR.to_string()), + Attribute::StringAttr("output_labeling_schema".to_string()), + Attribute::KeyValue( + "qir_profiles".to_string(), + "adaptive_profile".to_string(), + ), + ], + }, + ], + named_metadata: vec![NamedMetadata { + name: MODULE_FLAGS_NAME.to_string(), + node_refs: vec![0, 1, 2, 3, 4], + }], + metadata_nodes: vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".to_string()), + MetadataValue::Int(Type::Integer(32), 2), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), 7), + MetadataValue::String("qir_minor_version".to_string()), + MetadataValue::Int(Type::Integer(32), 0), + ], + }, + MetadataNode { + id: 2, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("dynamic_qubit_management".to_string()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }, + MetadataNode { + id: 3, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("dynamic_result_management".to_string()), + MetadataValue::Int(Type::Integer(1), 1), + ], + }, + MetadataNode { + id: 4, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String("int_computations".to_string()), + MetadataValue::SubList(vec![ + MetadataValue::String("i64".to_string()), + MetadataValue::String("i32".to_string()), + ]), + ], + }, + ], + } + } + + #[test] + fn test_extract_id_from_inttoptr() { + let op = Operand::int_to_named_ptr(7, "Qubit"); + assert_eq!(extract_id(&op), Some(7)); + } + + #[test] + fn test_extract_id_from_nullptr() { + assert_eq!(extract_id(&Operand::NullPtr), Some(0)); + } + + #[test] + fn test_extract_id_from_other() { + let op = Operand::float_const(Type::Double, 1.0); + assert_eq!(extract_id(&op), None); + } + + 
#[test] + fn test_extract_float() { + let op = Operand::float_const(Type::Double, std::f64::consts::E); + assert_eq!(extract_float(&op), Some(std::f64::consts::E)); + + let op_int = Operand::IntConst(Type::Integer(64), 1); + assert_eq!(extract_float(&op_int), None); + } + + #[test] + fn test_operand_key() { + let op1 = Operand::int_to_named_ptr(0, "Qubit"); + let op2 = Operand::int_to_named_ptr(1, "Qubit"); + let key1 = operand_key(&op1); + let key2 = operand_key(&op2); + assert_ne!(key1, key2); + + let key1b = operand_key(&op1); + assert_eq!(key1, key1b); + } + + #[test] + fn test_find_entry_point_ignores_declarations() { + let module = module_with_inspection_data(); + assert_eq!(find_entry_point(&module), Some(1)); + } + + #[test] + fn test_count_entry_points_ignores_declarations() { + let module = module_with_inspection_data(); + assert_eq!(count_entry_points(&module), 1); + } + + #[test] + fn test_get_function_attribute_reads_key_value() { + let module = module_with_inspection_data(); + assert_eq!( + get_function_attribute(&module, 1, "qir_profiles"), + Some("adaptive_profile") + ); + } + + #[test] + fn test_has_function_attribute_matches_string_and_key_value() { + let module = module_with_inspection_data(); + assert!(has_function_attribute(&module, 1, ENTRY_POINT_ATTR)); + assert!(has_function_attribute(&module, 1, "output_labeling_schema")); + assert!(has_function_attribute(&module, 1, "qir_profiles")); + } + + #[test] + fn test_get_module_flag_helpers() { + let module = module_with_inspection_data(); + + assert_eq!(get_module_flag_int(&module, "qir_major_version"), Some(2)); + assert!(!get_module_flag_bool(&module, "dynamic_qubit_management")); + assert!(get_module_flag_bool(&module, "dynamic_result_management")); + assert_eq!( + get_module_flag_string_list(&module, "int_computations"), + vec!["i64".to_string(), "i32".to_string()] + ); + assert_eq!( + get_module_flag_behavior(&module, "qir_minor_version"), + Some(7) + ); + } + + #[test] + fn 
test_get_module_flag_string_list_reads_float_computations() { + let mut module = module_with_inspection_data(); + module.named_metadata[0].node_refs.push(5); + module.metadata_nodes.push(MetadataNode { + id: 5, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String("float_computations".to_string()), + MetadataValue::SubList(vec![ + MetadataValue::String("half".to_string()), + MetadataValue::String("double".to_string()), + ]), + ], + }); + + assert_eq!( + get_module_flag_string_list(&module, "float_computations"), + vec!["half".to_string(), "double".to_string()] + ); + } + + #[test] + fn test_inspect_module_flag_metadata_reports_dangling_refs_without_hiding_valid_flags() { + let mut module = module_with_inspection_data(); + module.named_metadata[0].node_refs.insert(0, 999); + + assert_eq!( + inspect_module_flag_metadata(&module), + vec![ModuleFlagIssue::DanglingReference { node_ref: 999 }] + ); + assert_eq!( + inspect_module_flag_int(&module, "qir_major_version").value, + Some(2) + ); + } + + #[test] + fn test_inspect_module_flag_access_reports_malformed_payloads() { + let mut module = module_with_inspection_data(); + module.metadata_nodes[3].values[2] = MetadataValue::String("true".to_string()); + module.metadata_nodes[4].values[2] = MetadataValue::SubList(vec![ + MetadataValue::String("i64".to_string()), + MetadataValue::Int(Type::Integer(32), 1), + ]); + + let bool_flag = inspect_module_flag_bool(&module, "dynamic_result_management"); + assert_eq!(bool_flag.value, None); + assert_eq!( + bool_flag.issues, + vec![ModuleFlagIssue::InvalidValue { + flag_name: "dynamic_result_management".to_string(), + node_id: 3, + expected: "integer boolean", + found: "string".to_string(), + }] + ); + + let string_list_flag = inspect_module_flag_string_list(&module, "int_computations"); + assert_eq!(string_list_flag.value, None); + assert_eq!( + string_list_flag.issues, + vec![ModuleFlagIssue::InvalidStringListItem { + flag_name: 
"int_computations".to_string(), + node_id: 4, + index: 1, + found: "integer (i32)".to_string(), + }] + ); + } + + #[test] + fn test_analyze_float_surface_collects_recursive_widths_and_ops() { + let module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: vec![GlobalVariable { + name: "g".to_string(), + ty: Type::Array(1, Box::new(Type::Float)), + linkage: Linkage::Internal, + is_constant: false, + initializer: None, + }], + functions: vec![ + Function { + name: "decl".to_string(), + return_type: Type::TypedPtr(Box::new(Type::Double)), + params: vec![Param { + ty: Type::Function(Box::new(Type::Void), vec![Type::Half]), + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "entry".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::FCmp { + pred: FloatPredicate::Olt, + ty: Type::Half, + lhs: Operand::TypedLocalRef("lhs".to_string(), Type::Half), + rhs: Operand::float_const(Type::Half, 0.0), + result: "cond".to_string(), + }, + Instruction::Select { + cond: Operand::LocalRef("cond".to_string()), + true_val: Operand::TypedLocalRef("then".to_string(), Type::Float), + false_val: Operand::TypedLocalRef("else".to_string(), Type::Double), + ty: Type::Float, + result: "value".to_string(), + }, + Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let analysis = analyze_float_surface(&module); + + assert!(analysis.has_float_op); + assert_eq!( + analysis.surface_width_names(), + vec!["half", "float", "double"] + ); + } + + #[test] + fn test_analyze_float_surface_tracks_declaration_only_widths_without_ops() { + let 
module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: vec![GlobalVariable { + name: "g".to_string(), + ty: Type::Float, + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::Float(Type::Float, 1.0)), + }], + functions: vec![Function { + name: "decl".to_string(), + return_type: Type::Double, + params: vec![Param { + ty: Type::TypedPtr(Box::new(Type::Half)), + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let analysis = analyze_float_surface(&module); + + assert!(!analysis.has_float_op); + assert_eq!( + analysis.surface_width_names(), + vec!["half", "float", "double"] + ); + } + + #[test] + fn test_analyze_float_surface_ignores_metadata_payload_types() { + let module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: Vec::new(), + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: vec![MetadataNode { + id: 0, + values: vec![MetadataValue::Int(Type::Double, 1)], + }], + }; + + let analysis = analyze_float_surface(&module); + + assert!(!analysis.has_float_op); + assert!(analysis.surface_width_names().is_empty()); + } +} diff --git a/source/compiler/qsc_llvm/src/qir/spec.rs b/source/compiler/qsc_llvm/src/qir/spec.rs new file mode 100644 index 0000000000..6805494a05 --- /dev/null +++ b/source/compiler/qsc_llvm/src/qir/spec.rs @@ -0,0 +1,311 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::model::{StructType, Type}; + +/// Represents the writer-facing compatibility contract for emitted QIR bitcode. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QirEmitTarget { + /// Typed-pointer QIR v1 output for legacy LLVM 14 and 15 style consumers. + QirV1Typed, + /// Opaque-pointer QIR v2 output for LLVM 16 and newer style consumers. + QirV2Opaque, +} + +impl QirEmitTarget { + /// The QIR major version number for this emit target. + #[must_use] + pub const fn major_version(self) -> i64 { + match self { + Self::QirV1Typed => 1, + Self::QirV2Opaque => 2, + } + } + + /// The QIR minor version number for this emit target. + #[must_use] + pub const fn minor_version(self) -> i64 { + 0 + } + + /// Whether this target emits typed pointers (`%Qubit*`) instead of opaque pointers (`ptr`). + #[must_use] + pub const fn uses_typed_pointers(self) -> bool { + matches!(self, Self::QirV1Typed) + } + + /// The module layout version currently emitted for this target. + #[must_use] + pub const fn module_bitcode_version(self) -> u64 { + match self { + Self::QirV1Typed => 1, + Self::QirV2Opaque => 2, + } + } + + /// Opaque struct type declarations needed for typed-pointer emit targets. + #[must_use] + pub fn struct_types(self) -> Vec<StructType> { + if self.uses_typed_pointers() { + vec![ + StructType { + name: RESULT_TYPE_NAME.into(), + is_opaque: true, + }, + StructType { + name: QUBIT_TYPE_NAME.into(), + is_opaque: true, + }, + ] + } else { + vec![] + } + } + + /// Returns the pointer type the writer should synthesize for the given pointee. + #[must_use] + pub fn pointer_type_for_pointee(self, pointee: &Type) -> Type { + match self { + Self::QirV1Typed => match pointee { + Type::Named(name) | Type::NamedPtr(name) => Type::NamedPtr(name.clone()), + Type::TypedPtr(inner) => Type::TypedPtr(inner.clone()), + other => Type::TypedPtr(Box::new(other.clone())), + }, + Self::QirV2Opaque => Type::Ptr, + } + } + + /// Returns the default pointer type to use when the model carries an untyped null pointer.
+ #[must_use] + pub fn default_pointer_type(self) -> Type { + self.pointer_type_for_pointee(&Type::Integer(8)) + } +} + +/// Represents a concrete QIR spec profile and version combination. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QirProfile { + /// QIR Base Profile v1: typed pointers, no dynamic features. + BaseV1, + /// QIR Adaptive Profile v1: typed pointers, forward branching and qubit reuse. + AdaptiveV1, + /// QIR Adaptive Profile v2: opaque pointers, full capabilities. + AdaptiveV2, +} + +impl QirProfile { + /// The writer-facing emit target for this profile. + #[must_use] + pub const fn emit_target(self) -> QirEmitTarget { + match self { + Self::BaseV1 | Self::AdaptiveV1 => QirEmitTarget::QirV1Typed, + Self::AdaptiveV2 => QirEmitTarget::QirV2Opaque, + } + } + + /// The `qir_profiles` attribute value for this profile. + #[must_use] + pub const fn profile_name(self) -> &'static str { + match self { + Self::BaseV1 => BASE_PROFILE, + Self::AdaptiveV1 | Self::AdaptiveV2 => ADAPTIVE_PROFILE, + } + } + + /// The QIR major version number for module flags metadata. + #[must_use] + pub const fn major_version(self) -> i64 { + self.emit_target().major_version() + } + + /// The QIR minor version number for module flags metadata. + #[must_use] + pub const fn minor_version(self) -> i64 { + self.emit_target().minor_version() + } + + /// Whether this profile uses typed pointers (`%Qubit*`) or opaque pointers (`ptr`). + #[must_use] + pub const fn uses_typed_pointers(self) -> bool { + self.emit_target().uses_typed_pointers() + } + + /// Opaque struct type declarations needed for typed-pointer profiles. + #[must_use] + pub fn struct_types(self) -> Vec<StructType> { + self.emit_target().struct_types() + } +} + +/// Returns the argument index that must carry a string output label for the given runtime call.
+#[must_use] +pub fn output_label_arg_index(callee: &str) -> Option<usize> { + match callee { + rt::TUPLE_RECORD_OUTPUT + | rt::ARRAY_RECORD_OUTPUT + | rt::RESULT_RECORD_OUTPUT + | rt::BOOL_RECORD_OUTPUT + | rt::INT_RECORD_OUTPUT + | rt::DOUBLE_RECORD_OUTPUT => Some(1), + rt::RESULT_ARRAY_RECORD_OUTPUT => Some(2), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn qir_profiles_map_to_explicit_emit_targets() { + assert_eq!(QirProfile::BaseV1.emit_target(), QirEmitTarget::QirV1Typed); + assert_eq!( + QirProfile::AdaptiveV1.emit_target(), + QirEmitTarget::QirV1Typed + ); + assert_eq!( + QirProfile::AdaptiveV2.emit_target(), + QirEmitTarget::QirV2Opaque + ); + } + + #[test] + fn qir_emit_targets_choose_expected_pointer_shapes() { + assert_eq!( + QirEmitTarget::QirV1Typed.pointer_type_for_pointee(&Type::Named("Qubit".into())), + Type::NamedPtr("Qubit".into()) + ); + assert_eq!( + QirEmitTarget::QirV1Typed.pointer_type_for_pointee(&Type::Integer(8)), + Type::TypedPtr(Box::new(Type::Integer(8))) + ); + assert_eq!( + QirEmitTarget::QirV2Opaque.pointer_type_for_pointee(&Type::Named("Qubit".into())), + Type::Ptr + ); + } + + #[test] + fn qir_emit_targets_choose_expected_module_bitcode_versions() { + assert_eq!(QirEmitTarget::QirV1Typed.module_bitcode_version(), 1); + assert_eq!(QirEmitTarget::QirV2Opaque.module_bitcode_version(), 2); + } + + #[test] + fn output_recording_calls_map_to_expected_label_argument_indexes() { + for callee in [ + rt::TUPLE_RECORD_OUTPUT, + rt::ARRAY_RECORD_OUTPUT, + rt::RESULT_RECORD_OUTPUT, + rt::BOOL_RECORD_OUTPUT, + rt::INT_RECORD_OUTPUT, + rt::DOUBLE_RECORD_OUTPUT, + ] { + assert_eq!(output_label_arg_index(callee), Some(1)); + } + + assert_eq!( + output_label_arg_index(rt::RESULT_ARRAY_RECORD_OUTPUT), + Some(2) + ); + assert_eq!(output_label_arg_index(rt::INITIALIZE), None); + } +} + +pub const BASE_PROFILE: &str = "base_profile"; +pub const ADAPTIVE_PROFILE: &str = "adaptive_profile"; + +pub const ENTRYPOINT_NAME: &str =
"ENTRYPOINT__main"; + +pub const ENTRY_POINT_ATTR: &str = "entry_point"; +pub const OUTPUT_LABELING_SCHEMA_ATTR: &str = "output_labeling_schema"; +pub const QIR_PROFILES_ATTR: &str = "qir_profiles"; +pub const REQUIRED_NUM_QUBITS_ATTR: &str = "required_num_qubits"; +pub const REQUIRED_NUM_RESULTS_ATTR: &str = "required_num_results"; +pub const IRREVERSIBLE_ATTR: &str = "irreversible"; +pub const QDK_NOISE_ATTR: &str = "qdk_noise"; + +pub const QIR_MAJOR_VERSION_KEY: &str = "qir_major_version"; +pub const QIR_MINOR_VERSION_KEY: &str = "qir_minor_version"; +pub const DYNAMIC_QUBIT_MGMT_KEY: &str = "dynamic_qubit_management"; +pub const DYNAMIC_RESULT_MGMT_KEY: &str = "dynamic_result_management"; +pub const INT_COMPUTATIONS_KEY: &str = "int_computations"; +pub const FLOAT_COMPUTATIONS_KEY: &str = "float_computations"; +pub const BACKWARDS_BRANCHING_KEY: &str = "backwards_branching"; +pub const ARRAYS_KEY: &str = "arrays"; +pub const IR_FUNCTIONS_KEY: &str = "ir_functions"; +pub const MULTIPLE_TARGET_BRANCHING_KEY: &str = "multiple_target_branching"; +pub const MULTIPLE_RETURN_POINTS_KEY: &str = "multiple_return_points"; +pub const MODULE_FLAGS_NAME: &str = "llvm.module.flags"; + +pub const QUBIT_TYPE_NAME: &str = "Qubit"; +pub const RESULT_TYPE_NAME: &str = "Result"; + +pub const ENTRY_POINT_ATTR_GROUP_ID: u32 = 0; +pub const IRREVERSIBLE_ATTR_GROUP_ID: u32 = 1; +pub const QDK_NOISE_ATTR_GROUP_ID: u32 = 2; + +pub const FLAG_BEHAVIOR_ERROR: i64 = 1; +pub const FLAG_BEHAVIOR_APPEND: i64 = 5; +pub const FLAG_BEHAVIOR_MAX: i64 = 7; + +pub mod rt { + pub const INITIALIZE: &str = "__quantum__rt__initialize"; + pub const READ_RESULT: &str = "__quantum__rt__read_result"; + pub const READ_LOSS: &str = "__quantum__rt__read_loss"; + pub const RESULT_RECORD_OUTPUT: &str = "__quantum__rt__result_record_output"; + pub const TUPLE_RECORD_OUTPUT: &str = "__quantum__rt__tuple_record_output"; + pub const ARRAY_RECORD_OUTPUT: &str = "__quantum__rt__array_record_output"; + pub const 
BOOL_RECORD_OUTPUT: &str = "__quantum__rt__bool_record_output"; + pub const INT_RECORD_OUTPUT: &str = "__quantum__rt__int_record_output"; + pub const DOUBLE_RECORD_OUTPUT: &str = "__quantum__rt__double_record_output"; + pub const QUBIT_ALLOCATE: &str = "__quantum__rt__qubit_allocate"; + pub const QUBIT_BORROW: &str = "__quantum__rt__qubit_borrow"; + pub const QUBIT_RELEASE: &str = "__quantum__rt__qubit_release"; + pub const BEGIN_PARALLEL: &str = "__quantum__rt__begin_parallel"; + pub const END_PARALLEL: &str = "__quantum__rt__end_parallel"; + pub const READ_ATOM_RESULT: &str = "__quantum__rt__read_atom_result"; + pub const RESULT_ALLOCATE: &str = "__quantum__rt__result_allocate"; + pub const RESULT_RELEASE: &str = "__quantum__rt__result_release"; + pub const QUBIT_ARRAY_ALLOCATE: &str = "__quantum__rt__qubit_array_allocate"; + pub const QUBIT_ARRAY_RELEASE: &str = "__quantum__rt__qubit_array_release"; + pub const RESULT_ARRAY_ALLOCATE: &str = "__quantum__rt__result_array_allocate"; + pub const RESULT_ARRAY_RELEASE: &str = "__quantum__rt__result_array_release"; + pub const RESULT_ARRAY_RECORD_OUTPUT: &str = "__quantum__rt__result_array_record_output"; +} + +pub mod qis { + pub const X: &str = "__quantum__qis__x__body"; + pub const Y: &str = "__quantum__qis__y__body"; + pub const Z: &str = "__quantum__qis__z__body"; + pub const H: &str = "__quantum__qis__h__body"; + pub const S: &str = "__quantum__qis__s__body"; + pub const S_ADJ: &str = "__quantum__qis__s__adj"; + pub const SX: &str = "__quantum__qis__sx__body"; + pub const T: &str = "__quantum__qis__t__body"; + pub const T_ADJ: &str = "__quantum__qis__t__adj"; + + pub const RX: &str = "__quantum__qis__rx__body"; + pub const RY: &str = "__quantum__qis__ry__body"; + pub const RZ: &str = "__quantum__qis__rz__body"; + + pub const CX: &str = "__quantum__qis__cx__body"; + pub const CY: &str = "__quantum__qis__cy__body"; + pub const CZ: &str = "__quantum__qis__cz__body"; + pub const SWAP: &str = 
"__quantum__qis__swap__body"; + + pub const RXX: &str = "__quantum__qis__rxx__body"; + pub const RYY: &str = "__quantum__qis__ryy__body"; + pub const RZZ: &str = "__quantum__qis__rzz__body"; + + pub const CCX: &str = "__quantum__qis__ccx__body"; + + pub const M: &str = "__quantum__qis__m__body"; + pub const MZ: &str = "__quantum__qis__mz__body"; + pub const MRESETZ: &str = "__quantum__qis__mresetz__body"; + pub const RESET: &str = "__quantum__qis__reset__body"; + + pub const BARRIER: &str = "__quantum__qis__barrier__body"; + pub const MOVE: &str = "__quantum__qis__move__body"; + pub const READ_RESULT: &str = "__quantum__qis__read_result__body"; +} diff --git a/source/compiler/qsc_llvm/src/test_utils.rs b/source/compiler/qsc_llvm/src/test_utils.rs new file mode 100644 index 0000000000..d688d17401 --- /dev/null +++ b/source/compiler/qsc_llvm/src/test_utils.rs @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::path::PathBuf; +use std::process::Command; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +static TEMP_FILE_COUNTER: AtomicU64 = AtomicU64::new(0); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum LlvmCompatLaneKind { + LegacyTyped, + BridgeDualMode, + OpaquePreferred, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct LlvmCompatLane { + pub(crate) version: u8, + pub(crate) kind: LlvmCompatLaneKind, +} + +impl LlvmCompatLane { + pub(crate) const LLVM_14: Self = Self { + version: 14, + kind: LlvmCompatLaneKind::LegacyTyped, + }; + pub(crate) const LLVM_15: Self = Self { + version: 15, + kind: LlvmCompatLaneKind::BridgeDualMode, + }; + pub(crate) const LLVM_16: Self = Self { + version: 16, + kind: LlvmCompatLaneKind::OpaquePreferred, + }; + pub(crate) const LLVM_21: Self = Self { + version: 21, + kind: LlvmCompatLaneKind::OpaquePreferred, + }; + pub(crate) const FAST_MATRIX: [Self; 4] = + [Self::LLVM_14, Self::LLVM_15, 
Self::LLVM_16, Self::LLVM_21]; + + #[must_use] + pub(crate) fn tool_path(self, tool: &str) -> PathBuf { + PathBuf::from(format!( + "/opt/homebrew/opt/llvm@{}/bin/{tool}", + self.version + )) + } + + #[must_use] + pub(crate) fn tool_command(self, tool: &str) -> Command { + Command::new(self.tool_path(tool)) + } + + #[must_use] + pub(crate) fn has_tool(self, tool: &str) -> bool { + self.tool_command(tool) + .arg("--version") + .output() + .is_ok_and(|output| output.status.success()) + } + + #[must_use] + pub(crate) fn is_available(self) -> bool { + ["llvm-as", "llvm-dis", "opt"] + .into_iter() + .all(|tool| self.has_tool(tool)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum PointerProbe { + TypedText, + OpaqueText, +} + +impl PointerProbe { + #[must_use] + pub(crate) const fn tool_args(self, lane: LlvmCompatLane) -> &'static [&'static str] { + match (lane.version, self) { + (14, Self::OpaqueText) => &["-opaque-pointers"], + _ => &[], + } + } +} + +#[must_use] +pub(crate) fn available_fast_matrix_lanes() -> Vec<LlvmCompatLane> { + LlvmCompatLane::FAST_MATRIX + .into_iter() + .filter(|lane| lane.is_available()) + .collect() +} + +pub(crate) fn assemble_text_ir( + lane: LlvmCompatLane, + probe: PointerProbe, + text: &str, +) -> Result<Vec<u8>, String> { + let tmp_ll = unique_temp_path(&format!("qsc-llvm{}", lane.version), "ll"); + let tmp_bc = unique_temp_path(&format!("qsc-llvm{}", lane.version), "bc"); + + std::fs::write(&tmp_ll, text) + .map_err(|error| format!("write {}: {error}", tmp_ll.display()))?; + + let mut command = lane.tool_command("llvm-as"); + for arg in probe.tool_args(lane) { + command.arg(arg); + } + let output = command + .arg(&tmp_ll) + .arg("-o") + .arg(&tmp_bc) + .output() + .map_err(|error| format!("spawn llvm-as {}: {error}", lane.version))?; + + std::fs::remove_file(&tmp_ll).ok(); + + if !output.status.success() { + std::fs::remove_file(&tmp_bc).ok(); + return Err(String::from_utf8_lossy(&output.stderr).into_owned()); + } + + let bitcode = +
std::fs::read(&tmp_bc).map_err(|error| format!("read {}: {error}", tmp_bc.display()))?; + std::fs::remove_file(&tmp_bc).ok(); + Ok(bitcode) +} + +pub(crate) fn disassemble_bitcode( + lane: LlvmCompatLane, + probe: PointerProbe, + bitcode: &[u8], +) -> Result<String, String> { + let tmp_bc = unique_temp_path(&format!("qsc-llvm{}", lane.version), "bc"); + std::fs::write(&tmp_bc, bitcode) + .map_err(|error| format!("write {}: {error}", tmp_bc.display()))?; + + let mut command = lane.tool_command("llvm-dis"); + for arg in probe.tool_args(lane) { + command.arg(arg); + } + let output = command + .arg(&tmp_bc) + .arg("-o") + .arg("-") + .output() + .map_err(|error| format!("spawn llvm-dis {}: {error}", lane.version))?; + + std::fs::remove_file(&tmp_bc).ok(); + + if !output.status.success() { + return Err(String::from_utf8_lossy(&output.stderr).into_owned()); + } + + String::from_utf8(output.stdout).map_err(|error| error.to_string()) +} + +pub(crate) fn verify_bitcode( + lane: LlvmCompatLane, + probe: PointerProbe, + bitcode: &[u8], +) -> Result<(), String> { + let tmp_bc = unique_temp_path(&format!("qsc-llvm{}", lane.version), "bc"); + std::fs::write(&tmp_bc, bitcode) + .map_err(|error| format!("write {}: {error}", tmp_bc.display()))?; + + let mut command = lane.tool_command("opt"); + for arg in probe.tool_args(lane) { + command.arg(arg); + } + let output = command + .arg("-passes=verify") + .arg(&tmp_bc) + .arg("-disable-output") + .output() + .map_err(|error| format!("spawn opt {}: {error}", lane.version))?; + + std::fs::remove_file(&tmp_bc).ok(); + + if output.status.success() { + Ok(()) + } else { + Err(String::from_utf8_lossy(&output.stderr).into_owned()) + } +} + +pub(crate) fn analyze_bitcode(lane: LlvmCompatLane, bitcode: &[u8]) -> Result<String, String> { + let tmp_bc = unique_temp_path(&format!("qsc-llvm{}", lane.version), "bc"); + std::fs::write(&tmp_bc, bitcode) + .map_err(|error| format!("write {}: {error}", tmp_bc.display()))?; + + let output = lane + .tool_command("llvm-bcanalyzer") +
.arg("-dump") + .arg("--disable-histogram") + .arg(&tmp_bc) + .output() + .map_err(|error| format!("spawn llvm-bcanalyzer {}: {error}", lane.version))?; + + std::fs::remove_file(&tmp_bc).ok(); + + if !output.status.success() { + return Err(String::from_utf8_lossy(&output.stderr).into_owned()); + } + + String::from_utf8(output.stdout).map_err(|error| error.to_string()) +} + +#[must_use] +fn unique_temp_path(prefix: &str, extension: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let counter = TEMP_FILE_COUNTER.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!( + "{prefix}-{}-{nanos}-{counter}.{extension}", + std::process::id() + )) +} diff --git a/source/compiler/qsc_llvm/src/tests.rs b/source/compiler/qsc_llvm/src/tests.rs new file mode 100644 index 0000000000..34a62f59bf --- /dev/null +++ b/source/compiler/qsc_llvm/src/tests.rs @@ -0,0 +1,4189 @@ +use super::model::Type; +use super::model::test_helpers::*; +use super::model::*; +use super::qir::{QirEmitTarget, QirProfile}; +use super::test_utils::{ + LlvmCompatLane, PointerProbe, analyze_bitcode, assemble_text_ir, available_fast_matrix_lanes, + disassemble_bitcode, verify_bitcode, +}; +use super::{ + QirProfilePreset, QirSmithConfig, ReadPolicy, generate_module_from_bytes, + parse_bitcode_compatibility_report, parse_bitcode_detailed, parse_module, + write_bitcode_for_target, write_module_to_string, +}; + +#[test] +fn round_trip_empty_module() { + let m = empty_module(); + let text1 = write_module_to_string(&m); + let parsed = parse_module(&text1).expect("failed to parse module"); + let text2 = write_module_to_string(&parsed); + assert_eq!(text1, text2); +} + +#[test] +fn round_trip_bell_module_v2() { + let m = bell_module_v2(); + let text1 = write_module_to_string(&m); + let parsed = parse_module(&text1).expect("failed to parse module"); + let text2 = write_module_to_string(&parsed); + assert_eq!(text1, text2); + 
assert_eq!(m, parsed); +} + +#[allow(clippy::too_many_lines)] +#[test] +fn text_to_model_preserves_all_constructs() { + // Build a comprehensive module with every construct type + let m = Module { + source_filename: Some("qir".to_string()), + target_datalayout: None, + target_triple: None, + struct_types: vec![StructType { + name: "Qubit".to_string(), + is_opaque: true, + }], + globals: vec![GlobalVariable { + name: "0".to_string(), + ty: Type::Array(4, Box::new(Type::Integer(8))), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::CString("0_r".to_string())), + }], + functions: vec![ + Function { + name: "__quantum__qis__h__body".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "ENTRYPOINT__main".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: vec![0], + basic_blocks: vec![ + BasicBlock { + name: "block_0".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".to_string(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "var_0".to_string(), + }, + Instruction::ICmp { + pred: IntPredicate::Slt, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 10), + result: "var_1".to_string(), + }, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::LocalRef("var_1".to_string()), + true_dest: "block_1".to_string(), + false_dest: "block_2".to_string(), + }, + ], + }, + BasicBlock { + name: "block_1".to_string(), + instructions: vec![Instruction::Jump { + 
dest: "block_2".to_string(), + }], + }, + BasicBlock { + name: "block_2".to_string(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }, + ], + }, + ], + attribute_groups: vec![AttributeGroup { + id: 0, + attributes: vec![ + Attribute::StringAttr("entry_point".to_string()), + Attribute::KeyValue("qir_profiles".to_string(), "adaptive_profile".to_string()), + ], + }], + named_metadata: vec![NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![0, 1], + }], + metadata_nodes: vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".to_string()), + MetadataValue::Int(Type::Integer(32), 2), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String("int_computations".to_string()), + MetadataValue::SubList(vec![MetadataValue::String("i64".to_string())]), + ], + }, + ], + }; + + let text1 = write_module_to_string(&m); + let parsed = parse_module(&text1).expect("failed to parse comprehensive module"); + let text2 = write_module_to_string(&parsed); + + // Text round-trip + assert_eq!(text1, text2); + + // Model equality + assert_eq!(m.source_filename, parsed.source_filename); + assert_eq!(m.struct_types, parsed.struct_types); + assert_eq!(m.globals, parsed.globals); + assert_eq!(m.functions.len(), parsed.functions.len()); + assert_eq!(m.attribute_groups, parsed.attribute_groups); + assert_eq!(m.named_metadata, parsed.named_metadata); + assert_eq!(m.metadata_nodes, parsed.metadata_nodes); + assert_eq!(m, parsed); +} + +// --- Cross-format round-trip tests (Step 6.2) --- + +#[test] +fn text_to_bitcode_to_text_round_trip() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + // Build a module, emit text, parse text, write bitcode, read bitcode, emit text again + let m = bell_module_v2(); + let text1 = write_module_to_string(&m); 
+ let parsed_text = parse_module(&text1).expect("text parse failed"); + let bc = write_bitcode(&parsed_text); + let parsed_bc = parse_bitcode(&bc).expect("bitcode parse failed"); + + // Compare structural properties that survive bitcode round-trip + assert_eq!(parsed_bc.functions.len(), m.functions.len()); + for (orig, parsed) in m.functions.iter().zip(parsed_bc.functions.iter()) { + assert_eq!(orig.name, parsed.name); + assert_eq!(orig.is_declaration, parsed.is_declaration); + assert_eq!(orig.params.len(), parsed.params.len()); + assert_eq!(orig.attribute_group_refs, parsed.attribute_group_refs); + } + assert_eq!(parsed_bc.attribute_groups, m.attribute_groups); + assert_eq!(parsed_bc.named_metadata, m.named_metadata); + assert_eq!(parsed_bc.metadata_nodes, m.metadata_nodes); +} + +#[test] +fn bitcode_roundtrip_attribute_groups() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + Function { + name: "__quantum__qis__m__body".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: vec![1], + basic_blocks: Vec::new(), + }, + Function { + name: "ENTRYPOINT__main".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: vec![0], + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }], + }, + ], + attribute_groups: vec![ + AttributeGroup { + id: 0, + attributes: vec![ + Attribute::StringAttr("entry_point".to_string()), + Attribute::StringAttr("output_labeling_schema".to_string()), + Attribute::KeyValue("qir_profiles".to_string(), "adaptive_profile".to_string()), + 
Attribute::KeyValue("required_num_qubits".to_string(), "2".to_string()), + Attribute::KeyValue("required_num_results".to_string(), "2".to_string()), + ], + }, + AttributeGroup { + id: 1, + attributes: vec![Attribute::StringAttr("irreversible".to_string())], + }, + ], + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let bc1 = write_bitcode(&m); + let parsed = parse_bitcode(&bc1).expect("parse failed"); + assert_eq!(parsed.attribute_groups, m.attribute_groups); + assert_eq!( + parsed.functions[0].attribute_group_refs, + m.functions[0].attribute_group_refs + ); + assert_eq!( + parsed.functions[1].attribute_group_refs, + m.functions[1].attribute_group_refs + ); + let bc2 = write_bitcode(&parsed); + assert_eq!(bc1, bc2); +} + +#[test] +fn bitcode_roundtrip_call_site_attr_refs() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + Function { + name: "callee".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "caller".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "callee".to_string(), + args: Vec::new(), + result: None, + attr_refs: vec![0, 1], + }, + Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: vec![ + AttributeGroup { + id: 0, + attributes: vec![Attribute::StringAttr("alwaysinline".to_string())], + }, + AttributeGroup { + id: 1, + attributes: vec![Attribute::StringAttr("noreturn".to_string())], + }, + ], + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let bc1 = 
write_bitcode(&m); + let parsed = parse_bitcode(&bc1).expect("parse failed"); + + assert_eq!(parsed.attribute_groups, m.attribute_groups); + assert!(matches!( + &parsed.functions[1].basic_blocks[0].instructions[0], + Instruction::Call { attr_refs, .. } if attr_refs == &vec![0, 1] + )); + + let bc2 = write_bitcode(&parsed); + assert_eq!(bc1, bc2); +} + +#[test] +fn module_flags_roundtrip() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let m = bell_module_v2(); + let bc1 = write_bitcode(&m); + let parsed = parse_bitcode(&bc1).expect("parse failed"); + + assert_eq!(parsed.named_metadata, m.named_metadata); + assert_eq!(parsed.metadata_nodes, m.metadata_nodes); + assert!( + parsed + .get_flag("qir_major_version") + .is_some_and(|v| *v == MetadataValue::Int(Type::Integer(32), 2)) + ); + assert!( + parsed + .get_flag("backwards_branching") + .is_some_and(|v| *v == MetadataValue::Int(Type::Integer(2), 3)) + ); + assert!( + parsed + .get_flag("int_computations") + .is_some_and(|v| matches!(v, MetadataValue::SubList(_))) + ); + assert_eq!(parsed.attribute_groups, m.attribute_groups); + for (i, (a, b)) in m.functions.iter().zip(parsed.functions.iter()).enumerate() { + assert_eq!( + a.attribute_group_refs, b.attribute_group_refs, + "function {i} attribute_group_refs mismatch" + ); + } +} + +#[test] +fn bitcode_round_trip_empty_module() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let m = empty_module(); + let bc1 = write_bitcode(&m); + let parsed = parse_bitcode(&bc1).expect("first parse failed"); + let bc2 = write_bitcode(&parsed); + // Second bitcode should be identical to first + assert_eq!(bc1, bc2); +} + +#[test] +fn bitcode_round_trip_declarations() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + 
globals: Vec::new(), + functions: vec![ + Function { + name: "__quantum__qis__h__body".to_string(), + return_type: Type::Void, + params: vec![super::model::Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "main".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Ret(None)], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let bc1 = write_bitcode(&m); + let parsed = parse_bitcode(&bc1).expect("parse failed"); + assert_eq!(parsed.functions.len(), 2); + assert_eq!(parsed.functions[0].name, "__quantum__qis__h__body"); + assert!(parsed.functions[0].is_declaration); + assert_eq!(parsed.functions[1].name, "main"); + assert!(!parsed.functions[1].is_declaration); + assert_eq!(parsed.functions[1].basic_blocks.len(), 1); +} + +// --- LLVM tool verification tests (Step 6.3) --- + +#[test] +fn bitcode_accepted_by_llvm_dis() { + use super::bitcode::writer::write_bitcode; + use std::io::Write; + use std::process::Command; + + // Check if llvm-dis is available + let llvm_dis = Command::new("llvm-dis").arg("--version").output(); + if llvm_dis.is_err() || !llvm_dis.expect("llvm-dis check failed").status.success() { + eprintln!("llvm-dis not available, skipping test"); + return; + } + + let m = bell_module_v2(); + let bc = write_bitcode(&m); + + let tmp = std::env::temp_dir().join("qsc_test_llvm_dis.bc"); + let mut f = std::fs::File::create(&tmp).expect("failed to create temp file"); + f.write_all(&bc).expect("failed to write bitcode"); + drop(f); + + let output = Command::new("llvm-dis") + .arg(&tmp) + .arg("-o") + .arg("-") + .output() + .expect("failed to run llvm-dis"); + + std::fs::remove_file(&tmp).ok(); + + assert!( + 
output.status.success(), + "llvm-dis failed: {}", + String::from_utf8_lossy(&output.stderr) + ); +} + +#[test] +fn text_ir_accepted_by_llvm_as() { + use std::io::Write; + use std::process::Command; + + let llvm_as = Command::new("llvm-as").arg("--version").output(); + if llvm_as.is_err() || !llvm_as.expect("llvm-as check failed").status.success() { + eprintln!("llvm-as not available, skipping test"); + return; + } + + let m = bell_module_v2(); + let text = write_module_to_string(&m); + + let tmp = std::env::temp_dir().join("qsc_test_llvm_as.ll"); + let mut f = std::fs::File::create(&tmp).expect("failed to create temp file"); + f.write_all(text.as_bytes()) + .expect("failed to write text IR"); + drop(f); + + let output = Command::new("llvm-as") + .arg(&tmp) + .arg("-o") + .arg("/dev/null") + .output() + .expect("failed to run llvm-as"); + + std::fs::remove_file(&tmp).ok(); + + assert!( + output.status.success(), + "llvm-as failed: {}", + String::from_utf8_lossy(&output.stderr) + ); +} + +#[test] +fn bitcode_analyzable_by_llvm_bcanalyzer() { + use super::bitcode::writer::write_bitcode; + use std::io::Write; + use std::process::Command; + + let llvm_bc = Command::new("llvm-bcanalyzer").arg("--version").output(); + if llvm_bc.is_err() + || !llvm_bc + .expect("llvm-bcanalyzer check failed") + .status + .success() + { + eprintln!("llvm-bcanalyzer not available, skipping test"); + return; + } + + let m = bell_module_v2(); + let bc = write_bitcode(&m); + + let tmp = std::env::temp_dir().join("qsc_test_bcanalyzer.bc"); + let mut f = std::fs::File::create(&tmp).expect("failed to create temp file"); + f.write_all(&bc).expect("failed to write bitcode"); + drop(f); + + let output = Command::new("llvm-bcanalyzer") + .arg(&tmp) + .output() + .expect("failed to run llvm-bcanalyzer"); + + std::fs::remove_file(&tmp).ok(); + + assert!( + output.status.success(), + "llvm-bcanalyzer failed: {}", + String::from_utf8_lossy(&output.stderr) + ); +} + +// --- Phase 1 round-trip tests for 
new variants --- + +fn round_trip_text(ir: &str) { + let parsed = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + let text = write_module_to_string(&parsed); + let reparsed = parse_module(&text).unwrap_or_else(|e| panic!("reparse failed: {e}")); + let text2 = write_module_to_string(&reparsed); + assert_eq!(text, text2); +} + +fn wrap_instr(body: &str) -> String { + format!( + "\ +declare void @dummy() + +define void @test() {{ +entry: +{body} + ret void +}} +" + ) +} + +fn wrap_instr_i64(body: &str) -> String { + format!( + "\ +declare void @dummy() + +define i64 @test(i64 %a, i64 %b) {{ +entry: +{body} + ret i64 %r +}} +" + ) +} + +#[test] +fn round_trip_select() { + let ir = "\ +declare void @dummy() + +define i64 @test(i64 %a, i64 %b) { +entry: + %cond = icmp slt i64 %a, %b + %r = select i1 %cond, i64 %a, i64 %b + ret i64 %r +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_switch() { + let ir = "\ +declare void @dummy() + +define void @test(i32 %val) { +entry: + switch i32 %val, label %default [ + i32 0, label %case0 + i32 1, label %case1 + ] +case0: + ret void +case1: + ret void +default: + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_unreachable() { + let ir = wrap_instr(" unreachable"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_udiv() { + let ir = wrap_instr_i64(" %r = udiv i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_urem() { + let ir = wrap_instr_i64(" %r = urem i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_lshr() { + let ir = wrap_instr_i64(" %r = lshr i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_zext() { + let ir = "\ +declare void @dummy() + +define i64 @test(i32 %a) { +entry: + %r = zext i32 %a to i64 + ret i64 %r +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_sext() { + let ir = "\ +declare void @dummy() + +define i64 @test(i32 %a) { +entry: + %r = sext i32 %a to i64 + ret i64 %r +} +"; + round_trip_text(ir); +} + 
+#[test] +fn round_trip_trunc() { + let ir = "\ +declare void @dummy() + +define i32 @test(i64 %a) { +entry: + %r = trunc i64 %a to i32 + ret i32 %r +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fpext() { + let ir = "\ +declare void @use_double(double) + +define void @test() { +entry: + %r = fpext double 1.0 to double + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fptrunc() { + let ir = "\ +declare void @use_double(double) + +define void @test() { +entry: + %r = fptrunc double 1.0 to double + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_inttoptr_cast() { + let ir = "\ +declare void @dummy() + +define ptr @test(i64 %a) { +entry: + %r = inttoptr i64 %a to ptr + ret ptr %r +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_ptrtoint() { + let ir = "\ +declare void @dummy() + +define i64 @test(ptr %a) { +entry: + %r = ptrtoint ptr %a to i64 + ret i64 %r +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_bitcast() { + let ir = "\ +declare void @dummy() + +define ptr @test(ptr %a) { +entry: + %r = bitcast ptr %a to ptr + ret ptr %r +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_icmp_ult() { + let ir = wrap_instr_i64(" %c = icmp ult i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_icmp_ule() { + let ir = wrap_instr_i64(" %c = icmp ule i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_icmp_ugt() { + let ir = wrap_instr_i64(" %c = icmp ugt i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_icmp_uge() { + let ir = wrap_instr_i64(" %c = icmp uge i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn module_get_flag() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: 
Vec::new(), + globals: Vec::new(), + functions: Vec::new(), + attribute_groups: Vec::new(), + named_metadata: vec![NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![0, 1], + }], + metadata_nodes: vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".to_string()), + MetadataValue::Int(Type::Integer(32), 2), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String("qir_minor_version".to_string()), + MetadataValue::Int(Type::Integer(32), 0), + ], + }, + ], + }; + + assert_eq!( + m.get_flag("qir_major_version"), + Some(&MetadataValue::Int(Type::Integer(32), 2)) + ); + assert_eq!( + m.get_flag("qir_minor_version"), + Some(&MetadataValue::Int(Type::Integer(32), 0)) + ); + assert_eq!(m.get_flag("nonexistent"), None); +} + +#[test] +fn module_get_flag_skips_dangling_module_flag_refs() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: Vec::new(), + attribute_groups: Vec::new(), + named_metadata: vec![NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![999, 0, 1], + }], + metadata_nodes: vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".to_string()), + MetadataValue::Int(Type::Integer(32), 2), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), 7), + MetadataValue::String("qir_minor_version".to_string()), + MetadataValue::Int(Type::Integer(32), 0), + ], + }, + ], + }; + + assert_eq!( + m.get_flag("qir_major_version"), + Some(&MetadataValue::Int(Type::Integer(32), 2)) + ); + assert_eq!( + m.get_flag("qir_minor_version"), + Some(&MetadataValue::Int(Type::Integer(32), 0)) + ); +} + +#[test] +fn bitcode_self_round_trip_comprehensive() { + // Verify 
text→parse→write_bitcode→parse_bitcode→write_text round-trip + // with a comprehensive module exercising many construct types. + // Our writer produces UNABBREV_RECORD only; this verifies the new + // abbreviation infrastructure (scope tracking etc.) does not break + // existing record reading. + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let m = Module { + source_filename: None, + target_datalayout: Some( + "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128".to_string(), + ), + target_triple: Some("x86_64-unknown-linux-gnu".to_string()), + struct_types: vec![StructType { + name: "Qubit".to_string(), + is_opaque: true, + }], + globals: Vec::new(), + functions: vec![ + Function { + name: "__quantum__qis__h__body".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "main".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".to_string(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "var_0".to_string(), + }, + Instruction::Ret(Some(Operand::LocalRef("var_0".to_string()))), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let bc1 = write_bitcode(&m); + let parsed1 = parse_bitcode(&bc1).expect("first bitcode parse failed"); + + // Verify structural properties survive the round-trip + 
assert_eq!(parsed1.functions.len(), m.functions.len()); + assert_eq!(parsed1.target_triple, m.target_triple); + assert_eq!(parsed1.target_datalayout, m.target_datalayout); + for (orig, parsed) in m.functions.iter().zip(parsed1.functions.iter()) { + assert_eq!(orig.name, parsed.name); + assert_eq!(orig.is_declaration, parsed.is_declaration); + assert_eq!(orig.params.len(), parsed.params.len()); + assert_eq!(orig.basic_blocks.len(), parsed.basic_blocks.len()); + } + + // Verify re-encoding the parsed module produces structurally equivalent output + let bc2 = write_bitcode(&parsed1); + let parsed2 = parse_bitcode(&bc2).expect("second bitcode parse failed"); + assert_eq!(parsed1.functions.len(), parsed2.functions.len()); + assert_eq!(parsed1.target_triple, parsed2.target_triple); + assert_eq!(parsed1.target_datalayout, parsed2.target_datalayout); + for (f1, f2) in parsed1.functions.iter().zip(parsed2.functions.iter()) { + assert_eq!(f1.name, f2.name); + assert_eq!(f1.is_declaration, f2.is_declaration); + assert_eq!(f1.basic_blocks.len(), f2.basic_blocks.len()); + } +} + +#[test] +fn bitcode_llvm_as_round_trip() { + // If llvm-as is available, produce bitcode with LLVM (which uses abbreviations) + // and verify our reader can parse it back. + use super::bitcode::reader::parse_bitcode; + let lane = LlvmCompatLane::LLVM_21; + if !lane.is_available() { + eprintln!("llvm@21 toolchain not available, skipping test"); + return; + } + + let text = "\ +; ModuleID = 'test'\n\ +target triple = \"x86_64-unknown-linux-gnu\"\n\ +\n\ +%Qubit = type opaque\n\ +\n\ +declare void @__quantum__qis__h__body(ptr)\n\ +\n\ +define void @main() {\n\ +entry:\n\ + call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr))\n\ + ret void\n\ +}\n"; + + let bc_data = assemble_text_ir(lane, PointerProbe::OpaqueText, text) + .unwrap_or_else(|error| panic!("llvm@21 llvm-as failed: {error}")); + + // The LLVM-produced bitcode will contain abbreviations. 
+ // This tests our new DEFINE_ABBREV + abbreviated record support. + let parsed = parse_bitcode(&bc_data).expect("failed to parse LLVM-produced bitcode"); + assert_eq!(parsed.functions.len(), 2); + assert!( + parsed + .functions + .iter() + .any(|f| f.name == "__quantum__qis__h__body") + ); + assert!(parsed.functions.iter().any(|f| f.name == "main")); + + let main = parsed + .functions + .iter() + .find(|function| function.name == "main") + .expect("missing main function"); + assert!(matches!( + &main.basic_blocks[0].instructions[0], + Instruction::Call { callee, .. } if callee == "__quantum__qis__h__body" + )); +} + +fn llvm_modern_module_naming_fixture_ir() -> &'static str { + "\ +target triple = \"x86_64-unknown-linux-gnu\"\n\ +\n\ +define i64 @test(i64 %arg, i64 %other) {\n\ +entry:\n\ + %sum = add i64 %arg, %other\n\ + br label %loop\n\ +loop:\n\ + %acc = phi i64 [ %sum, %entry ], [ %next, %loop ]\n\ + %next = add i64 %acc, 1\n\ + %cond = icmp slt i64 %next, 10\n\ + br i1 %cond, label %loop, label %exit\n\ +exit:\n\ + ret i64 %next\n\ +}\n" +} + +#[test] +fn bitcode_llvm_modern_module_naming_fixture_preserves_names() { + let lane = LlvmCompatLane::LLVM_21; + if !lane.is_available() || !lane.has_tool("llvm-bcanalyzer") { + eprintln!("llvm@21 with llvm-bcanalyzer not available, skipping test"); + return; + } + + let bc_data = assemble_text_ir( + lane, + PointerProbe::OpaqueText, + llvm_modern_module_naming_fixture_ir(), + ) + .unwrap_or_else(|error| panic!("llvm@21 llvm-as failed: {error}")); + + let analysis = analyze_bitcode(lane, &bc_data) + .unwrap_or_else(|error| panic!("llvm@21 llvm-bcanalyzer failed: {error}")); + for expected in [ + "", + " &'static str { + "\ +declare void @__quantum__qis__h__body(ptr)\n\ +\n\ +define void @main() {\n\ +entry:\n\ + call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr))\n\ + ret void\n\ +}\n" +} + +fn qir_typed_pointer_smoke_module() -> Module { + Module { + source_filename: None, + target_datalayout: None, + 
target_triple: None, + struct_types: QirProfile::AdaptiveV1.struct_types(), + globals: Vec::new(), + functions: vec![ + Function { + name: "__quantum__qis__h__body".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::NamedPtr("Qubit".to_string()), + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "main".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".to_string(), + args: vec![( + Type::NamedPtr("Qubit".to_string()), + Operand::int_to_named_ptr(0, "Qubit"), + )], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } +} + +fn qir_opaque_pointer_smoke_module() -> Module { + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + Function { + name: "__quantum__qis__h__body".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "main".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".to_string(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + 
metadata_nodes: Vec::new(), + } +} + +fn qir_typed_pointer_gep_smoke_module() -> Module { + let array_ty = Type::Array(4, Box::new(Type::Integer(8))); + + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: vec![GlobalVariable { + name: "0".to_string(), + ty: array_ty.clone(), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::CString("abc".to_string())), + }], + functions: vec![ + Function { + name: "use_ptr".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "main".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Call { + return_ty: None, + callee: "use_ptr".to_string(), + args: vec![( + Type::Ptr, + Operand::GetElementPtr { + ty: array_ty.clone(), + ptr: "0".to_string(), + ptr_ty: Type::TypedPtr(Box::new(array_ty.clone())), + indices: vec![ + Operand::IntConst(Type::Integer(64), 0), + Operand::IntConst(Type::Integer(64), 0), + ], + }, + )], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } +} + +fn main_gep_ptr_ty(module: &Module) -> &Type { + match &module.functions[1].basic_blocks[0].instructions[0] { + Instruction::Call { args, .. } => match &args[0].1 { + Operand::GetElementPtr { ptr_ty, .. 
} => ptr_ty, + other => panic!("expected getelementptr operand, found {other:?}"), + }, + other => panic!("expected leading call instruction, found {other:?}"), + } +} + +fn adaptive_half_only_float_external_fixture_ir() -> &'static str { + "declare void @use_half(half)\n\ +\n\ +define i64 @ENTRYPOINT__main() #0 {\n\ +entry:\n\ + %half_sum = fadd half 1.5, 2.25\n\ + call void @use_half(half %half_sum)\n\ + ret i64 0\n\ +}\n\ +\n\ +attributes #0 = { \"entry_point\" \"output_labeling_schema\" \"qir_profiles\"=\"adaptive_profile\" \"required_num_qubits\"=\"0\" \"required_num_results\"=\"0\" }\n\ +\n\ +!llvm.module.flags = !{!0, !1, !2, !3, !4}\n\ +!0 = !{i32 1, !\"qir_major_version\", i32 2}\n\ +!1 = !{i32 7, !\"qir_minor_version\", i32 0}\n\ +!2 = !{i32 1, !\"dynamic_qubit_management\", i1 false}\n\ +!3 = !{i32 1, !\"dynamic_result_management\", i1 false}\n\ +!4 = !{i32 5, !\"float_computations\", !{!\"half\"}}\n" +} + +fn adaptive_no_float_external_fixture_ir() -> &'static str { + "define i64 @ENTRYPOINT__main() #0 {\n\ +entry:\n\ + ret i64 0\n\ +}\n\ +\n\ +attributes #0 = { \"entry_point\" \"output_labeling_schema\" \"qir_profiles\"=\"adaptive_profile\" \"required_num_qubits\"=\"0\" \"required_num_results\"=\"0\" }\n\ +\n\ +!llvm.module.flags = !{!0, !1, !2, !3}\n\ +!0 = !{i32 1, !\"qir_major_version\", i32 2}\n\ +!1 = !{i32 7, !\"qir_minor_version\", i32 0}\n\ +!2 = !{i32 1, !\"dynamic_qubit_management\", i1 false}\n\ +!3 = !{i32 1, !\"dynamic_result_management\", i1 false}\n" +} + +fn assert_opaque_qir_text_fixture_survives_fast_matrix( + text: &str, + expected_substrings: &[&str], + absent_substrings: &[&str], +) { + let lanes = available_fast_matrix_lanes(); + if lanes.is_empty() { + eprintln!("external LLVM fast matrix not available, skipping test"); + return; + } + + for lane in lanes { + let bitcode = + assemble_text_ir(lane, PointerProbe::OpaqueText, text).unwrap_or_else(|error| { + panic!("llvm@{} opaque assemble failed: {error}", lane.version) + }); + + 
verify_bitcode(lane, PointerProbe::OpaqueText, &bitcode) + .unwrap_or_else(|error| panic!("llvm@{} opaque verify failed: {error}", lane.version)); + + let disassembly = disassemble_bitcode(lane, PointerProbe::OpaqueText, &bitcode) + .unwrap_or_else(|error| { + panic!("llvm@{} opaque disassembly failed: {error}", lane.version) + }); + + for expected in expected_substrings { + assert!( + disassembly.contains(expected), + "llvm@{} disassembly should contain {expected:?}, got:\n{disassembly}", + lane.version + ); + } + + for absent in absent_substrings { + assert!( + !disassembly.contains(absent), + "llvm@{} disassembly should not contain {absent:?}, got:\n{disassembly}", + lane.version + ); + } + } +} + +#[test] +fn qir_explicit_typed_emit_target_roundtrips_named_pointer_module() { + use super::bitcode::reader::parse_bitcode; + + let module = qir_typed_pointer_smoke_module(); + let bitcode = write_bitcode_for_target(&module, QirEmitTarget::QirV1Typed); + let round_tripped = parse_bitcode(&bitcode).expect("typed lane parse failed"); + + assert_eq!( + round_tripped.functions[0].params[0].ty, + Type::NamedPtr("Qubit".into()) + ); + match &round_tripped.functions[1].basic_blocks[0].instructions[0] { + Instruction::Call { args, .. } => assert_eq!( + args, + &vec![( + Type::NamedPtr("Qubit".into()), + Operand::int_to_named_ptr(0, "Qubit"), + )] + ), + other => panic!("expected typed-lane call instruction, found {other:?}"), + } +} + +#[test] +fn qir_explicit_opaque_emit_target_roundtrips_opaque_pointer_module() { + use super::bitcode::reader::parse_bitcode; + + let module = qir_opaque_pointer_smoke_module(); + let bitcode = write_bitcode_for_target(&module, QirEmitTarget::QirV2Opaque); + let round_tripped = parse_bitcode(&bitcode).expect("opaque lane parse failed"); + + assert_eq!(round_tripped.functions[0].params[0].ty, Type::Ptr); + match &round_tripped.functions[1].basic_blocks[0].instructions[0] { + Instruction::Call { callee, args, .. 
} => { + assert_eq!(callee, "__quantum__qis__h__body"); + assert_eq!(args, &vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))]); + } + other => panic!("expected opaque-lane call instruction, found {other:?}"), + } +} + +#[test] +fn typed_pointer_gep_text_roundtrip_preserves_base_pointer_type() { + let module = qir_typed_pointer_gep_smoke_module(); + let text = write_module_to_string(&module); + let parsed = parse_module(&text).expect("typed-pointer GEP text should parse"); + let expected = Type::TypedPtr(Box::new(Type::Array(4, Box::new(Type::Integer(8))))); + + assert_eq!(main_gep_ptr_ty(&module), &expected); + assert_eq!(main_gep_ptr_ty(&parsed), &expected); +} + +#[test] +fn qir_emitted_opaque_bitcode_verifies_across_external_opaque_lanes() { + const REQUIRED_OPAQUE_LANES: [u8; 2] = [16, 21]; + + let available_lanes = available_fast_matrix_lanes(); + let missing_lanes: Vec<_> = REQUIRED_OPAQUE_LANES + .into_iter() + .filter(|version| !available_lanes.iter().any(|lane| lane.version == *version)) + .collect(); + if !missing_lanes.is_empty() { + let missing = missing_lanes + .into_iter() + .map(|version| format!("llvm@{version}")) + .collect::>() + .join(", "); + eprintln!( + "required external LLVM opaque lanes not available, skipping emitted opaque verification: {missing}" + ); + return; + } + + let module = qir_opaque_pointer_smoke_module(); + let bitcode = write_bitcode_for_target(&module, QirEmitTarget::QirV2Opaque); + + for lane in available_lanes + .into_iter() + .filter(|lane| REQUIRED_OPAQUE_LANES.contains(&lane.version)) + { + verify_bitcode(lane, PointerProbe::OpaqueText, &bitcode).unwrap_or_else(|error| { + let reproducer = disassemble_bitcode(lane, PointerProbe::OpaqueText, &bitcode) + .map(|disassembly| format!("disassembly:\n{disassembly}")) + .or_else(|disassembly_error| { + analyze_bitcode(lane, &bitcode).map(|dump| { + format!( + "disassembly failed: {disassembly_error}\nllvm-bcanalyzer dump:\n{dump}" + ) + }) + }) + 
.unwrap_or_else(|analyzer_error| { + format!( + "disassembly and llvm-bcanalyzer both failed while preparing a reproducer: {analyzer_error}" + ) + }); + panic!( + "llvm@{} rejected qsc_llvm-emitted opaque bitcode: {error}\n{reproducer}", + lane.version + ); + }); + } +} + +#[test] +fn qir_emitted_opaque_bitcode_uses_modern_module_function_naming_records() { + let lane = LlvmCompatLane::LLVM_21; + if !lane.is_available() || !lane.has_tool("llvm-bcanalyzer") { + eprintln!("llvm@21 with llvm-bcanalyzer not available, skipping test"); + return; + } + + let module = qir_opaque_pointer_smoke_module(); + let bitcode = write_bitcode_for_target(&module, QirEmitTarget::QirV2Opaque); + + let analysis = analyze_bitcode(lane, &bitcode) + .unwrap_or_else(|error| panic!("llvm@21 llvm-bcanalyzer failed: {error}")); + + for expected in [" { + let typed_bc = + assemble_text_ir(lane, PointerProbe::TypedText, qir_typed_pointer_smoke_ir()) + .unwrap_or_else(|error| panic!("llvm@14 typed assemble failed: {error}")); + verify_bitcode(lane, PointerProbe::TypedText, &typed_bc) + .unwrap_or_else(|error| panic!("llvm@14 typed verify failed: {error}")); + + let opaque_bc = assemble_text_ir( + lane, + PointerProbe::OpaqueText, + qir_opaque_pointer_smoke_ir(), + ) + .unwrap_or_else(|error| panic!("llvm@14 opaque assemble failed: {error}")); + verify_bitcode(lane, PointerProbe::OpaqueText, &opaque_bc) + .unwrap_or_else(|error| panic!("llvm@14 opaque verify failed: {error}")); + } + 15 => { + let typed_bc = + assemble_text_ir(lane, PointerProbe::TypedText, qir_typed_pointer_smoke_ir()) + .unwrap_or_else(|error| panic!("llvm@15 typed assemble failed: {error}")); + verify_bitcode(lane, PointerProbe::TypedText, &typed_bc) + .unwrap_or_else(|error| panic!("llvm@15 typed verify failed: {error}")); + let typed_disassembly = + disassemble_bitcode(lane, PointerProbe::TypedText, &typed_bc).unwrap_or_else( + |error| panic!("llvm@15 typed disassembly failed: {error}"), + ); + assert!( + 
typed_disassembly.contains("%Qubit*"), + "llvm@15 bridge lane should preserve typed spelling, got:\n{typed_disassembly}" + ); + + let opaque_bc = assemble_text_ir( + lane, + PointerProbe::OpaqueText, + qir_opaque_pointer_smoke_ir(), + ) + .unwrap_or_else(|error| panic!("llvm@15 opaque assemble failed: {error}")); + verify_bitcode(lane, PointerProbe::OpaqueText, &opaque_bc) + .unwrap_or_else(|error| panic!("llvm@15 opaque verify failed: {error}")); + let opaque_disassembly = + disassemble_bitcode(lane, PointerProbe::OpaqueText, &opaque_bc).unwrap_or_else( + |error| panic!("llvm@15 opaque disassembly failed: {error}"), + ); + assert!( + opaque_disassembly.contains("ptr"), + "llvm@15 bridge lane should preserve opaque spelling, got:\n{opaque_disassembly}" + ); + } + 16 | 21 => { + let opaque_bc = assemble_text_ir( + lane, + PointerProbe::OpaqueText, + qir_opaque_pointer_smoke_ir(), + ) + .unwrap_or_else(|error| { + panic!("llvm@{} opaque assemble failed: {error}", lane.version) + }); + verify_bitcode(lane, PointerProbe::OpaqueText, &opaque_bc).unwrap_or_else( + |error| panic!("llvm@{} opaque verify failed: {error}", lane.version), + ); + } + other => panic!("unexpected fast-matrix lane llvm@{other}"), + } + } +} + +#[test] +fn qir_external_llvm_fast_matrix_accepts_half_only_float_metadata_artifact() { + assert_opaque_qir_text_fixture_survives_fast_matrix( + adaptive_half_only_float_external_fixture_ir(), + &["fadd half", "!\"float_computations\"", "!{!\"half\"}"], + &["!{!\"half\","], + ); +} + +#[test] +fn qir_external_llvm_fast_matrix_accepts_no_float_artifact() { + assert_opaque_qir_text_fixture_survives_fast_matrix( + adaptive_no_float_external_fixture_ir(), + &["ret i64 0", "!\"qir_major_version\""], + &["!\"float_computations\""], + ); +} + +#[test] +fn external_global_initializer_bitcode_is_preserved_in_strict_mode() { + let Some(lane) = available_fast_matrix_lanes().into_iter().next() else { + eprintln!( + "no external LLVM fast-matrix lane is available, 
skipping unsupported-input bitcode fixture" + ); + return; + }; + + let bitcode = assemble_text_ir( + lane, + PointerProbe::OpaqueText, + "@0 = internal constant [4 x i8] c\"0_r\\00\"\n", + ) + .unwrap_or_else(|error| { + panic!( + "llvm@{} should assemble external global initializer fixture: {error}", + lane.version + ) + }); + + let module = parse_bitcode_detailed(&bitcode, ReadPolicy::QirSubsetStrict) + .expect("strict bitcode import should preserve supported global initializers"); + + assert_eq!(module.globals.len(), 1); + assert_eq!( + module.globals[0].ty, + Type::Array(4, Box::new(Type::Integer(8))) + ); + assert!(module.globals[0].is_constant); + assert_eq!( + module.globals[0].initializer, + Some(Constant::CString("0_r".to_string())) + ); +} + +#[test] +fn external_global_initializer_bitcode_has_no_compatibility_diagnostics() { + let Some(lane) = available_fast_matrix_lanes().into_iter().next() else { + eprintln!( + "no external LLVM fast-matrix lane is available, skipping unsupported-input bitcode fixture" + ); + return; + }; + + let bitcode = assemble_text_ir( + lane, + PointerProbe::OpaqueText, + "@0 = internal constant [4 x i8] c\"0_r\\00\"\n", + ) + .unwrap_or_else(|error| { + panic!( + "llvm@{} should assemble compatibility-report fixture: {error}", + lane.version + ) + }); + + let report = parse_bitcode_compatibility_report(&bitcode) + .expect("compatibility bitcode import should preserve supported global initializers"); + + assert!( + report.diagnostics.is_empty(), + "unexpected diagnostics: {:?}", + report.diagnostics + ); + assert_eq!(report.module.globals.len(), 1); + assert_eq!( + report.module.globals[0].initializer, + Some(Constant::CString("0_r".to_string())) + ); +} + +#[test] +fn round_trip_gep_inbounds() { + let ir = "\ +@str = internal constant [5 x i8] c\"hello\"\n\ +\n\ +declare void @use_ptr(ptr)\n\ +\n\ +define void @test_fn() {\n\ +entry:\n\ + %0 = getelementptr inbounds [5 x i8], ptr @str, i64 0, i64 0\n\ + call void @use_ptr(ptr 
%0)\n\ + ret void\n\ +}\n"; + let parsed = parse_module(ir).expect("failed to parse GEP IR"); + let text1 = write_module_to_string(&parsed); + let reparsed = parse_module(&text1).expect("failed to reparse GEP IR"); + let text2 = write_module_to_string(&reparsed); + assert_eq!(text1, text2); + assert_eq!(parsed, reparsed); + + // Verify the instruction is a GetElementPtr + let func = parsed + .functions + .iter() + .find(|f| f.name == "test_fn") + .expect("missing test_fn"); + let instrs = &func.basic_blocks[0].instructions; + assert!( + matches!( + &instrs[0], + Instruction::GetElementPtr { inbounds: true, .. } + ), + "expected GEP inbounds instruction" + ); +} + +#[test] +fn round_trip_gep_no_inbounds() { + let ir = "\ +@label = internal constant [10 x i8] c\"some_label\"\n\ +\n\ +declare void @use_ptr(ptr)\n\ +\n\ +define void @test_fn() {\n\ +entry:\n\ + %0 = getelementptr [10 x i8], ptr @label, i64 0, i64 0\n\ + call void @use_ptr(ptr %0)\n\ + ret void\n\ +}\n"; + let parsed = parse_module(ir).expect("failed to parse GEP IR"); + let text1 = write_module_to_string(&parsed); + let reparsed = parse_module(&text1).expect("failed to reparse GEP IR"); + let text2 = write_module_to_string(&reparsed); + assert_eq!(text1, text2); + assert_eq!(parsed, reparsed); + + // Verify the instruction is a GetElementPtr without inbounds + let func = parsed + .functions + .iter() + .find(|f| f.name == "test_fn") + .expect("missing test_fn"); + let instrs = &func.basic_blocks[0].instructions; + assert!( + matches!( + &instrs[0], + Instruction::GetElementPtr { + inbounds: false, + .. 
+ } + ), + "expected GEP non-inbounds instruction" + ); +} + +#[test] +fn parse_adaptive_profile_ir() { + // Test patterns from the adaptive profile tests that caused failures + let ir = r#"%Result = type opaque +%Qubit = type opaque + +define void @make_bell(%Qubit* %q0, %Qubit* %q1) { +entry: + call void @__quantum__qis__h__body(%Qubit* %q0) + call void @__quantum__qis__cx__body(%Qubit* %q0, %Qubit* %q1) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) +declare void @__quantum__qis__cx__body(%Qubit*, %Qubit*) +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.functions.len(), 3); +} + +#[test] +fn parse_icmp_ult_in_context() { + let ir = r#"define void @test() { +entry: + %cond = icmp ult i64 %i, 8 + ret void +} +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.functions.len(), 1); +} + +#[test] +fn parse_block_with_comment() { + let ir = r#"define void @test() { +block_0: + br label %loop_cond +loop_cond: ; preds = %loop_body, %block_0 + ret void +} +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.functions[0].basic_blocks.len(), 2); +} + +#[test] +fn parse_declare_i1_return() { + let ir = r#"declare i1 @__quantum__rt__read_loss(%Result*) +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.functions.len(), 1); +} + +#[test] +fn parse_metadata_i1_bool() { + let ir = r#"!llvm.module.flags = !{!0} +!0 = !{i32 1, !"dynamic_qubit_management", i1 false} +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.metadata_nodes.len(), 1); +} + +#[test] +fn parse_phi_with_named_blocks() { + let ir = r#"define void @test() { +block_0: + br label %loop_cond +loop_cond: + %i = phi i64 [ 0, %block_0 ], [ %i_next, %loop_body ] + ret void +loop_body: + %i_next = add i64 %i, 1 + br label %loop_cond +} +"#; + let m = 
parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.functions[0].basic_blocks.len(), 3); +} + +#[test] +fn parse_metadata_nested_group() { + let ir = r#"!llvm.module.flags = !{!0} +!0 = !{i32 5, !"int_computations", !{!"i64"}} +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.metadata_nodes.len(), 1); +} + +#[test] +fn parse_two_declares_same_line() { + // Some IR has declarations on consecutive lines without blank lines + let ir = r#"declare void @__quantum__rt__bool_record_output(i1, i8*) +declare void @__quantum__rt__int_record_output(i64, i8*) +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.functions.len(), 2); +} + +#[test] +fn parse_full_bell_loop_funcs() { + let ir = r#"%Result = type opaque +%Qubit = type opaque + +define i64 @ENTRYPOINT__main() #0 { +block_0: + br label %loop_cond +loop_cond: ; preds = %loop_body, %block_0 + %i = phi i64 [ 0, %block_0 ], [ %i_next, %loop_body ] + %cond = icmp ult i64 %i, 8 + br i1 %cond, label %loop_body, label %loop_cond2 +loop_body: ; preds = %loop_cond + %q0 = inttoptr i64 %i to %Qubit* + %i1 = add i64 %i, 1 + %q1 = inttoptr i64 %i1 to %Qubit* + call void @make_bell(%Qubit* %q0, %Qubit* %q1) + %i_next = add i64 %i, 2 + br label %loop_cond +loop_cond2: ; preds = %loop_cond + %i3 = phi i64 [ 0, %loop_cond ], [ %i_next2, %loop_body2 ] + %cond2 = icmp ult i64 %i3, 16 + br i1 %cond2, label %loop_body2, label %end +loop_body2: ; preds = %loop_cond2 + %q2 = inttoptr i64 %i3 to %Qubit* + %r = inttoptr i64 %i3 to %Result* + call void @__quantum__qis__mresetz__body(%Qubit* %q2, %Result* %r) + %i_next2 = add i64 %i3, 1 + br label %loop_cond2 +end: ; preds = %loop_cond2 + call void @__quantum__rt__array_record_output(i64 8, i8* null) + ret i64 0 +} + +define void @make_bell(%Qubit* %q0, %Qubit* %q1) { +entry: + call void @__quantum__qis__h__body(%Qubit* %q0) + call void @__quantum__qis__cx__body(%Qubit* %q0, 
%Qubit* %q1) + ret void +} + +declare void @__quantum__qis__h__body(%Qubit*) +declare void @__quantum__qis__cx__body(%Qubit*, %Qubit*) +declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1 +declare void @__quantum__rt__array_record_output(i64, i8*) + +attributes #0 = { "entry_point" "qir_profiles"="adaptive_profile" "required_num_qubits"="16" "required_num_results"="16" } +attributes #1 = { "irreversible" } + +!llvm.module.flags = !{!0, !1, !2, !3} + +!0 = !{i32 1, !"qir_major_version", i32 1} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} +"#; + let m = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + assert_eq!(m.functions.len(), 6); + assert_eq!(m.functions[0].basic_blocks.len(), 6); +} + +// --- Phase 2: Text roundtrip tests for untested BinOp variants --- + +#[test] +fn round_trip_sub() { + let ir = wrap_instr_i64(" %r = sub i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_mul() { + let ir = wrap_instr_i64(" %r = mul i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_sdiv() { + let ir = wrap_instr_i64(" %r = sdiv i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_srem() { + let ir = wrap_instr_i64(" %r = srem i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_shl() { + let ir = wrap_instr_i64(" %r = shl i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_ashr() { + let ir = wrap_instr_i64(" %r = ashr i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_and() { + let ir = wrap_instr_i64(" %r = and i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_or() { + let ir = wrap_instr_i64(" %r = or i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_xor() { + let ir = wrap_instr_i64(" %r = xor i64 %a, %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_fadd() { + let ir = "\ +declare void 
@use_double(double) + +define void @test(double %a, double %b) { +entry: + %r = fadd double %a, %b + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fsub() { + let ir = "\ +declare void @use_double(double) + +define void @test(double %a, double %b) { +entry: + %r = fsub double %a, %b + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fmul() { + let ir = "\ +declare void @use_double(double) + +define void @test(double %a, double %b) { +entry: + %r = fmul double %a, %b + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fdiv() { + let ir = "\ +declare void @use_double(double) + +define void @test(double %a, double %b) { +entry: + %r = fdiv double %a, %b + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +// --- Phase 2: Text roundtrip tests for untested ICmp predicates --- + +#[test] +fn round_trip_icmp_eq() { + let ir = wrap_instr_i64(" %c = icmp eq i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_icmp_ne() { + let ir = wrap_instr_i64(" %c = icmp ne i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_icmp_sgt() { + let ir = wrap_instr_i64(" %c = icmp sgt i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_icmp_sge() { + let ir = wrap_instr_i64(" %c = icmp sge i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +#[test] +fn round_trip_icmp_sle() { + let ir = wrap_instr_i64(" %c = icmp sle i64 %a, %b\n %r = select i1 %c, i64 %a, i64 %b"); + round_trip_text(&ir); +} + +// --- Phase 2: Text roundtrip tests for FCmp predicates --- + +#[test] +fn round_trip_fcmp_oeq() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp oeq double %a, %b + call void 
@use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_ogt() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp ogt double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_oge() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp oge double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_olt() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp olt double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_ole() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp ole double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_one() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp one double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_ord() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp ord double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_ueq() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp ueq double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_ugt() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp ugt double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_uge() { + let ir = "\ +declare void 
@use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp uge double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_ult() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp ult double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_ule() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp ule double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_une() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp une double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fcmp_uno() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp uno double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + round_trip_text(ir); +} + +// --- Phase 2: Text roundtrip tests for missing instruction types --- + +#[test] +fn round_trip_phi() { + let ir = "\ +define i64 @test() { +block_0: + br label %loop +loop: + %i = phi i64 [ 0, %block_0 ], [ %next, %loop ] + %next = add i64 %i, 1 + %cond = icmp slt i64 %next, 10 + br i1 %cond, label %loop, label %exit +exit: + ret i64 %i +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_alloca_load_store() { + let ir = "\ +define i64 @test() { +entry: + %ptr = alloca i64 + store i64 42, ptr %ptr + %val = load i64, ptr %ptr + ret i64 %val +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_sitofp() { + let ir = "\ +declare void @use_double(double) + +define void @test(i64 %a) { +entry: + %r = sitofp i64 %a to double + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_fptosi() { + let ir = "\ +define i64 @test(double 
%a) { +entry: + %r = fptosi double %a to i64 + ret i64 %r +} +"; + round_trip_text(ir); +} + +// --- Phase 2: Text roundtrip tests for missing operand types --- + +#[test] +fn round_trip_float_const() { + let ir = "\ +declare void @use_double(double) + +define void @test() { +entry: + %r = fadd double 1.0, 2.0 + call void @use_double(double %r) + ret void +} +"; + round_trip_text(ir); +} + +#[test] +fn round_trip_null_ptr() { + let ir = "\ +declare void @use_ptr(ptr) + +define void @test() { +entry: + call void @use_ptr(ptr null) + ret void +} +"; + round_trip_text(ir); +} + +// --- Phase 3: Bitcode roundtrip tests for instruction types --- + +/// Helper: parse text IR, write to bitcode, parse bitcode back, return both modules. +fn bitcode_roundtrip(ir: &str) -> (Module, Module) { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let module = parse_module(ir).unwrap_or_else(|e| panic!("text parse failed: {e}")); + let bc = write_bitcode(&module); + let rt = parse_bitcode(&bc).unwrap_or_else(|e| panic!("bitcode parse failed: {e}")); + (module, rt) +} + +/// Helper: assert structural equivalence of functions (non-lossy fields only). 
+fn assert_bitcode_roundtrip_structure(orig: &Module, rt: &Module) { + assert_eq!(orig.functions.len(), rt.functions.len()); + for (fo, fr) in orig.functions.iter().zip(rt.functions.iter()) { + assert_eq!(fo.name, fr.name); + assert_eq!(fo.is_declaration, fr.is_declaration); + assert_eq!(fo.params.len(), fr.params.len()); + assert_eq!(fo.basic_blocks.len(), fr.basic_blocks.len()); + for (bo, br) in fo.basic_blocks.iter().zip(fr.basic_blocks.iter()) { + assert_eq!( + bo.instructions.len(), + br.instructions.len(), + "instruction count mismatch in block '{}' of function '{}'", + bo.name, + fo.name, + ); + } + } +} + +fn qir_smith_v1_config(profile: QirProfilePreset) -> QirSmithConfig { + QirSmithConfig { + max_blocks_per_func: 4, + max_instrs_per_block: 6, + ..QirSmithConfig::for_profile(profile) + } +} + +fn assert_qir_smith_v1_type(ty: &Type) { + match ty { + Type::NamedPtr(_) | Type::TypedPtr(_) => { + panic!("typed-pointer types are outside the v1 qir_smith test boundary") + } + Type::Array(_, element) => assert_qir_smith_v1_type(element), + Type::Function(result, params) => { + assert_qir_smith_v1_type(result); + for param in params { + assert_qir_smith_v1_type(param); + } + } + Type::Void + | Type::Integer(_) + | Type::Half + | Type::Float + | Type::Double + | Type::Label + | Type::Ptr + | Type::Named(_) => {} + } +} + +fn assert_qir_smith_v1_operand(operand: &Operand) { + match operand { + Operand::IntConst(ty, _) | Operand::IntToPtr(_, ty) => assert_qir_smith_v1_type(ty), + Operand::GetElementPtr { + ty, + ptr_ty, + indices, + .. 
+ } => { + assert_qir_smith_v1_type(ty); + assert_qir_smith_v1_type(ptr_ty); + for index in indices { + assert_qir_smith_v1_operand(index); + } + } + Operand::LocalRef(_) + | Operand::TypedLocalRef(_, _) + | Operand::FloatConst(_, _) + | Operand::NullPtr + | Operand::GlobalRef(_) => {} + } +} + +#[allow(clippy::too_many_lines)] +fn assert_qir_smith_v1_module(module: &Module) { + for global in &module.globals { + assert_qir_smith_v1_type(&global.ty); + } + + for function in &module.functions { + assert_qir_smith_v1_type(&function.return_type); + for param in &function.params { + assert_qir_smith_v1_type(&param.ty); + } + + for block in &function.basic_blocks { + for instruction in &block.instructions { + match instruction { + Instruction::Ret(Some(operand)) => assert_qir_smith_v1_operand(operand), + Instruction::Ret(None) + | Instruction::Jump { .. } + | Instruction::Unreachable => {} + Instruction::Br { cond_ty, cond, .. } => { + assert_qir_smith_v1_type(cond_ty); + assert_qir_smith_v1_operand(cond); + } + Instruction::BinOp { ty, lhs, rhs, .. } + | Instruction::ICmp { ty, lhs, rhs, .. } + | Instruction::FCmp { ty, lhs, rhs, .. } => { + assert_qir_smith_v1_type(ty); + assert_qir_smith_v1_operand(lhs); + assert_qir_smith_v1_operand(rhs); + } + Instruction::Cast { + from_ty, + to_ty, + value, + .. + } => { + assert_qir_smith_v1_type(from_ty); + assert_qir_smith_v1_type(to_ty); + assert_qir_smith_v1_operand(value); + } + Instruction::Call { + return_ty, args, .. + } => { + if let Some(return_ty) = return_ty { + assert_qir_smith_v1_type(return_ty); + } + for (arg_ty, operand) in args { + assert_qir_smith_v1_type(arg_ty); + assert_qir_smith_v1_operand(operand); + } + } + Instruction::Phi { .. } => { + panic!("phi instructions are outside the v1 qir_smith test boundary") + } + Instruction::Alloca { ty, .. } => assert_qir_smith_v1_type(ty), + Instruction::Load { + ty, ptr_ty, ptr, ..
+ } => { + assert_qir_smith_v1_type(ty); + assert_qir_smith_v1_type(ptr_ty); + assert_qir_smith_v1_operand(ptr); + } + Instruction::Store { + ty, + value, + ptr_ty, + ptr, + } => { + assert_qir_smith_v1_type(ty); + assert_qir_smith_v1_operand(value); + assert_qir_smith_v1_type(ptr_ty); + assert_qir_smith_v1_operand(ptr); + } + Instruction::Select { + cond, + true_val, + false_val, + ty, + .. + } => { + assert_qir_smith_v1_operand(cond); + assert_qir_smith_v1_operand(true_val); + assert_qir_smith_v1_operand(false_val); + assert_qir_smith_v1_type(ty); + } + Instruction::Switch { .. } => { + panic!("switch instructions are outside the v1 qir_smith test boundary") + } + Instruction::GetElementPtr { + pointee_ty, + ptr_ty, + ptr, + indices, + .. + } => { + assert_qir_smith_v1_type(pointee_ty); + assert_qir_smith_v1_type(ptr_ty); + assert_qir_smith_v1_operand(ptr); + for index in indices { + assert_qir_smith_v1_operand(index); + } + } + } + } + } + } +} + +fn has_opaque_ptr_in_type(ty: &Type) -> bool { + match ty { + Type::Ptr => true, + Type::Array(_, element) => has_opaque_ptr_in_type(element), + Type::Function(result, params) => { + has_opaque_ptr_in_type(result) || params.iter().any(has_opaque_ptr_in_type) + } + _ => false, + } +} + +fn has_opaque_ptr_in_operand(op: &Operand) -> bool { + match op { + Operand::IntConst(ty, _) | Operand::IntToPtr(_, ty) => has_opaque_ptr_in_type(ty), + Operand::GetElementPtr { + ty, + ptr_ty, + indices, + .. + } => { + has_opaque_ptr_in_type(ty) + || has_opaque_ptr_in_type(ptr_ty) + || indices.iter().any(has_opaque_ptr_in_operand) + } + _ => false, + } +} + +fn instr_has_opaque_ptr(instr: &Instruction) -> bool { + match instr { + Instruction::Call { + return_ty, args, .. + } => { + if let Some(rt) = return_ty + && has_opaque_ptr_in_type(rt) + { + return true; + } + args.iter() + .any(|(ty, op)| has_opaque_ptr_in_type(ty) || has_opaque_ptr_in_operand(op)) + } + Instruction::Br { cond_ty, cond, .. 
} => { + has_opaque_ptr_in_type(cond_ty) || has_opaque_ptr_in_operand(cond) + } + Instruction::BinOp { ty, lhs, rhs, .. } + | Instruction::ICmp { ty, lhs, rhs, .. } + | Instruction::FCmp { ty, lhs, rhs, .. } => { + has_opaque_ptr_in_type(ty) + || has_opaque_ptr_in_operand(lhs) + || has_opaque_ptr_in_operand(rhs) + } + Instruction::Cast { + from_ty, + to_ty, + value, + .. + } => { + has_opaque_ptr_in_type(from_ty) + || has_opaque_ptr_in_type(to_ty) + || has_opaque_ptr_in_operand(value) + } + Instruction::Ret(Some(op)) => has_opaque_ptr_in_operand(op), + Instruction::Alloca { ty, .. } => has_opaque_ptr_in_type(ty), + Instruction::Load { + ty, ptr_ty, ptr, .. + } => { + has_opaque_ptr_in_type(ty) + || has_opaque_ptr_in_type(ptr_ty) + || has_opaque_ptr_in_operand(ptr) + } + Instruction::Store { + ty, + value, + ptr_ty, + ptr, + } => { + has_opaque_ptr_in_type(ty) + || has_opaque_ptr_in_operand(value) + || has_opaque_ptr_in_type(ptr_ty) + || has_opaque_ptr_in_operand(ptr) + } + Instruction::Select { + cond, + true_val, + false_val, + ty, + .. + } => { + has_opaque_ptr_in_operand(cond) + || has_opaque_ptr_in_operand(true_val) + || has_opaque_ptr_in_operand(false_val) + || has_opaque_ptr_in_type(ty) + } + Instruction::GetElementPtr { + pointee_ty, + ptr_ty, + ptr, + indices, + .. + } => { + has_opaque_ptr_in_type(pointee_ty) + || has_opaque_ptr_in_type(ptr_ty) + || has_opaque_ptr_in_operand(ptr) + || indices.iter().any(has_opaque_ptr_in_operand) + } + _ => false, + } +} + +/// Run profile validation on a generated module and verify it completes +/// without panicking. Does not assert that violations are empty because +/// ``BareRoundtrip`` intentionally lacks QIR metadata (MS-*/MF-* violations). +fn assert_qir_smith_profile_validation_does_not_panic(module: &Module) { + // Validate — must not panic regardless of violation count. 
+ let _result = super::validate_qir_profile(module); +} + +/// Run profile validation on a generated module and assert that all +/// errors are empty. Use this for QIR-profile-conformant modules +fn assert_qir_smith_profile_valid(module: &Module) { + let result = super::validate_qir_profile(module); + assert!( + result.errors.is_empty(), + "Generated module has profile errors: {:#?}", + result.errors + ); +} + +fn assert_qir_smith_typed_pointer_module(module: &Module) { + assert!( + module.struct_types.iter().any(|s| s.name == "Qubit"), + "v1 module must define %Qubit struct type" + ); + assert!( + module.struct_types.iter().any(|s| s.name == "Result"), + "v1 module must define %Result struct type" + ); + + for global in &module.globals { + assert!( + !has_opaque_ptr_in_type(&global.ty), + "opaque Ptr in global {}: {:?}", + global.name, + global.ty + ); + } + + for function in &module.functions { + assert!( + !has_opaque_ptr_in_type(&function.return_type), + "opaque Ptr in return type of {}: {:?}", + function.name, + function.return_type + ); + for (pi, param) in function.params.iter().enumerate() { + assert!( + !has_opaque_ptr_in_type(&param.ty), + "opaque Ptr in param {pi} of {}: {:?}", + function.name, + param.ty + ); + } + + for (bi, block) in function.basic_blocks.iter().enumerate() { + for (ii, instruction) in block.instructions.iter().enumerate() { + assert!( + !instr_has_opaque_ptr(instruction), + "opaque Ptr in function {} block {bi} instruction {ii}: {instruction:?}", + function.name, + ); + } + } + } +} + +fn assert_generated_qir_smith_roundtrips(config: &QirSmithConfig, seed: &[u8]) { + let generated = generate_module_from_bytes(config, seed) + .unwrap_or_else(|err| panic!("qir_smith generation failed: {err}")); + assert_qir_smith_v1_module(&generated); + + if matches!(config.profile, QirProfilePreset::BareRoundtrip) { + // BareRoundtrip intentionally lacks QIR metadata.
+ assert_qir_smith_profile_validation_does_not_panic(&generated); + } else { + assert_qir_smith_profile_valid(&generated); + } + + let ir = write_module_to_string(&generated); + round_trip_text(&ir); + + let orig = parse_module(&ir).unwrap_or_else(|e| panic!("text parse failed: {e}")); + let bc = write_bitcode_for_target(&orig, QirEmitTarget::QirV2Opaque); + let report = parse_bitcode_compatibility_report(&bc).unwrap_or_else(|diagnostics| { + panic!("bitcode compatibility parse failed: {diagnostics:?}") + }); + if orig + .globals + .iter() + .any(|global| global.initializer.is_some()) + { + assert!( + report.diagnostics.is_empty(), + "unexpected compatibility diagnostics for qir_smith globals with initializers: {:?}", + report.diagnostics + ); + } + let rt = report.module; + assert_qir_smith_v1_module(&orig); + assert_qir_smith_v1_module(&rt); + assert_bitcode_roundtrip_structure(&orig, &rt); +} + +#[test] +fn qir_smith_adaptive_v2_roundtrips_with_fixed_seed() { + let config = qir_smith_v1_config(QirProfilePreset::AdaptiveV2); + assert_generated_qir_smith_roundtrips(&config, b"qir-smith-opaque-adaptive-v1"); +} + +#[test] +fn qir_smith_bare_roundtrip_roundtrips_with_fixed_seed() { + let config = qir_smith_v1_config(QirProfilePreset::BareRoundtrip); + assert_generated_qir_smith_roundtrips(&config, b"qir-smith-bare-roundtrip-v1"); +} + +#[test] +fn qir_smith_base_v1_text_roundtrips_with_fixed_seed() { + let config = qir_smith_v1_config(QirProfilePreset::BaseV1); + let generated = generate_module_from_bytes(&config, b"base-v1-seed-000") + .unwrap_or_else(|err| panic!("qir_smith BaseV1 generation failed: {err}")); + assert_qir_smith_typed_pointer_module(&generated); + assert_qir_smith_profile_valid(&generated); + + let ir = write_module_to_string(&generated); + let parsed = parse_module(&ir).unwrap_or_else(|e| panic!("BaseV1 text parse failed: {e}")); + assert_qir_smith_typed_pointer_module(&parsed); + + round_trip_text(&ir); +} + +#[test] +fn 
qir_smith_adaptive_v1_text_roundtrips_with_fixed_seed() { + let config = qir_smith_v1_config(QirProfilePreset::AdaptiveV1); + let generated = generate_module_from_bytes(&config, b"qir-smith-adaptive-v1") + .unwrap_or_else(|err| panic!("qir_smith AdaptiveV1 generation failed: {err}")); + assert_qir_smith_typed_pointer_module(&generated); + assert_qir_smith_profile_valid(&generated); + + let ir = write_module_to_string(&generated); + let parsed = parse_module(&ir).unwrap_or_else(|e| panic!("AdaptiveV1 text parse failed: {e}")); + assert_qir_smith_typed_pointer_module(&parsed); + + round_trip_text(&ir); + + let bc = write_bitcode_for_target(&parsed, QirEmitTarget::QirV1Typed); + let report = parse_bitcode_compatibility_report(&bc).unwrap_or_else(|diagnostics| { + panic!("bitcode compatibility parse failed: {diagnostics:?}") + }); + assert!( + report.diagnostics.is_empty(), + "unexpected compatibility diagnostics for AdaptiveV1 globals with initializers: {:?}", + report.diagnostics + ); + let rt = report.module; + // This compatibility-report roundtrip remains structural-only; the direct + // text parse above covers typed-pointer fidelity. 
+ assert_bitcode_roundtrip_structure(&parsed, &rt); +} + +#[test] +fn qir_v2_qubit_and_result_operands_remain_opaque_ptrs() { + let ir = r#"%Result = type opaque +%Qubit = type opaque + +@0 = internal constant [4 x i8] c"0_r\00" + +define i64 @ENTRYPOINT__main() #0 { +entry: + call void @__quantum__rt__initialize(ptr null) + call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__m__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @0) + ret i64 0 +} + +declare void @__quantum__rt__initialize(ptr) +declare void @__quantum__qis__h__body(ptr) +declare void @__quantum__qis__m__body(ptr, ptr) #1 +declare void @__quantum__rt__result_record_output(ptr, ptr) + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" } +attributes #1 = { "irreversible" } + +!llvm.module.flags = !{!0, !1, !2, !3} +!0 = !{i32 1, !"qir_major_version", i32 2} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} +"#; + + let module = parse_module(ir).unwrap_or_else(|e| panic!("parse failed: {e}")); + + assert_eq!( + module.struct_types, + vec![ + StructType { + name: "Result".to_string(), + is_opaque: true, + }, + StructType { + name: "Qubit".to_string(), + is_opaque: true, + }, + ] + ); + + for name in [ + "__quantum__rt__initialize", + "__quantum__qis__h__body", + "__quantum__qis__m__body", + "__quantum__rt__result_record_output", + ] { + let function = module + .functions + .iter() + .find(|function| function.name == name) + .unwrap_or_else(|| panic!("missing function {name}")); + for (index, param) in function.params.iter().enumerate() { + assert_eq!( + param.ty, + Type::Ptr, + "{name} param {index} should remain ptr, got {:?}", + param.ty + ); + } + } + + let entry = module + 
.functions + .iter() + .find(|function| function.name == "ENTRYPOINT__main") + .expect("missing entry point"); + for instruction in &entry.basic_blocks[0].instructions { + if let Instruction::Call { callee, args, .. } = instruction { + for (index, (arg_ty, operand)) in args.iter().enumerate() { + if let Operand::IntToPtr(_, cast_ty) = operand { + assert_eq!( + *arg_ty, + Type::Ptr, + "{callee} arg {index} should use ptr, got {arg_ty:?}" + ); + assert_eq!( + *cast_ty, + Type::Ptr, + "{callee} arg {index} inttoptr should target ptr, got {cast_ty:?}" + ); + } + } + } + } + + let round_tripped = parse_module(&write_module_to_string(&module)) + .unwrap_or_else(|e| panic!("roundtrip parse failed: {e}")); + assert_eq!(module, round_tripped); +} + +#[allow(clippy::too_many_lines)] +#[test] +fn adaptive_float_computations_roundtrip_preserves_half_float_and_double() { + let module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + Function { + name: "use_half".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Half, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "use_float".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Float, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "use_double".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Double, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }, + Function { + name: "ENTRYPOINT__main".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: vec![0], + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::BinOp { 
+ op: BinOpKind::Fadd, + ty: Type::Half, + lhs: Operand::float_const(Type::Half, 1.5), + rhs: Operand::float_const(Type::Half, 2.25), + result: "half_sum".to_string(), + }, + Instruction::Cast { + op: CastKind::FpExt, + from_ty: Type::Half, + to_ty: Type::Float, + value: Operand::LocalRef("half_sum".to_string()), + result: "as_float".to_string(), + }, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Float, + lhs: Operand::LocalRef("as_float".to_string()), + rhs: Operand::float_const(Type::Float, 0.5), + result: "float_sum".to_string(), + }, + Instruction::Cast { + op: CastKind::FpExt, + from_ty: Type::Float, + to_ty: Type::Double, + value: Operand::LocalRef("float_sum".to_string()), + result: "as_double".to_string(), + }, + Instruction::Cast { + op: CastKind::FpTrunc, + from_ty: Type::Double, + to_ty: Type::Half, + value: Operand::LocalRef("as_double".to_string()), + result: "back_to_half".to_string(), + }, + Instruction::Call { + return_ty: None, + callee: "use_half".to_string(), + args: vec![(Type::Half, Operand::LocalRef("back_to_half".to_string()))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Call { + return_ty: None, + callee: "use_float".to_string(), + args: vec![(Type::Float, Operand::LocalRef("float_sum".to_string()))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Call { + return_ty: None, + callee: "use_double".to_string(), + args: vec![(Type::Double, Operand::LocalRef("as_double".to_string()))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))), + ], + }], + }, + ], + attribute_groups: vec![AttributeGroup { + id: 0, + attributes: vec![ + Attribute::StringAttr("entry_point".to_string()), + Attribute::StringAttr("output_labeling_schema".to_string()), + Attribute::KeyValue("qir_profiles".to_string(), "adaptive_profile".to_string()), + Attribute::KeyValue("required_num_qubits".to_string(), "0".to_string()), + 
Attribute::KeyValue("required_num_results".to_string(), "0".to_string()), + ], + }], + named_metadata: vec![NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![0, 1, 2, 3, 4], + }], + metadata_nodes: vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".to_string()), + MetadataValue::Int(Type::Integer(32), 2), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), 7), + MetadataValue::String("qir_minor_version".to_string()), + MetadataValue::Int(Type::Integer(32), 0), + ], + }, + MetadataNode { + id: 2, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("dynamic_qubit_management".to_string()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }, + MetadataNode { + id: 3, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("dynamic_result_management".to_string()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }, + MetadataNode { + id: 4, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String("float_computations".to_string()), + MetadataValue::SubList(vec![ + MetadataValue::String("half".to_string()), + MetadataValue::String("float".to_string()), + MetadataValue::String("double".to_string()), + ]), + ], + }, + ], + }; + + let text = write_module_to_string(&module); + let parsed = parse_module(&text).unwrap_or_else(|e| panic!("parse failed: {e}")); + + assert_eq!( + parsed.get_flag("float_computations"), + Some(&MetadataValue::SubList(vec![ + MetadataValue::String("half".to_string()), + MetadataValue::String("float".to_string()), + MetadataValue::String("double".to_string()), + ])) + ); + + for (name, expected_ty) in [ + ("use_half", Type::Half), + ("use_float", Type::Float), + ("use_double", Type::Double), + ] { + let function = parsed + .functions + .iter() + .find(|function| function.name == name) + .unwrap_or_else(|| 
panic!("missing function {name}")); + assert_eq!( + function.params, + vec![Param { + ty: expected_ty, + name: None, + }] + ); + } + + let entry = parsed + .functions + .iter() + .find(|function| function.name == "ENTRYPOINT__main") + .expect("missing entry point"); + let instructions = &entry.basic_blocks[0].instructions; + assert!(matches!( + &instructions[0], + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Half, + lhs: Operand::FloatConst(Type::Half, _), + rhs: Operand::FloatConst(Type::Half, _), + result, + } if result == "half_sum" + )); + assert!(matches!( + &instructions[1], + Instruction::Cast { + op: CastKind::FpExt, + from_ty: Type::Half, + to_ty: Type::Float, + value: Operand::TypedLocalRef(name, Type::Half), + result, + } if name == "half_sum" && result == "as_float" + )); + assert!(matches!( + &instructions[2], + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Float, + lhs: Operand::TypedLocalRef(name, Type::Float), + rhs: Operand::FloatConst(Type::Float, _), + result, + } if name == "as_float" && result == "float_sum" + )); + assert!(matches!( + &instructions[3], + Instruction::Cast { + op: CastKind::FpExt, + from_ty: Type::Float, + to_ty: Type::Double, + value: Operand::TypedLocalRef(name, Type::Float), + result, + } if name == "float_sum" && result == "as_double" + )); + assert!(matches!( + &instructions[4], + Instruction::Cast { + op: CastKind::FpTrunc, + from_ty: Type::Double, + to_ty: Type::Half, + value: Operand::TypedLocalRef(name, Type::Double), + result, + } if name == "as_double" && result == "back_to_half" + )); + + let (from_text, from_bc) = bitcode_roundtrip(&text); + let from_text_entry = from_text + .functions + .iter() + .find(|function| function.name == "ENTRYPOINT__main") + .expect("missing text entry point"); + let from_bc_entry = from_bc + .functions + .iter() + .find(|function| function.name == "ENTRYPOINT__main") + .expect("missing bitcode entry point"); + + assert_eq!( + 
from_text.get_flag("float_computations"), + from_bc.get_flag("float_computations") + ); + assert_eq!(from_text.functions.len(), from_bc.functions.len()); + assert_eq!( + from_bc.functions[0] + .params + .iter() + .map(|param| param.ty.clone()) + .collect::<Vec<_>>(), + vec![Type::Half] + ); + assert_eq!( + from_bc.functions[1] + .params + .iter() + .map(|param| param.ty.clone()) + .collect::<Vec<_>>(), + vec![Type::Float] + ); + assert_eq!( + from_bc.functions[2] + .params + .iter() + .map(|param| param.ty.clone()) + .collect::<Vec<_>>(), + vec![Type::Double] + ); + + let from_bc_instructions = &from_bc_entry.basic_blocks[0].instructions; + assert_eq!( + from_text_entry.basic_blocks[0].instructions.len(), + from_bc_instructions.len() + ); + assert!(matches!( + &from_bc_instructions[0], + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Half, + lhs: Operand::FloatConst(Type::Half, _), + rhs: Operand::FloatConst(Type::Half, _), + .. + } + )); + assert!(matches!( + &from_bc_instructions[1], + Instruction::Cast { + op: CastKind::FpExt, + from_ty: Type::Half, + to_ty: Type::Float, + value: Operand::TypedLocalRef(_, Type::Half), + .. + } + )); + assert!(matches!( + &from_bc_instructions[2], + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Float, + lhs: Operand::TypedLocalRef(_, Type::Float), + rhs: Operand::FloatConst(Type::Float, _), + .. + } + )); + assert!(matches!( + &from_bc_instructions[3], + Instruction::Cast { + op: CastKind::FpExt, + from_ty: Type::Float, + to_ty: Type::Double, + value: Operand::TypedLocalRef(_, Type::Float), + .. + } + )); + assert!(matches!( + &from_bc_instructions[4], + Instruction::Cast { + op: CastKind::FpTrunc, + from_ty: Type::Double, + to_ty: Type::Half, + value: Operand::TypedLocalRef(_, Type::Double), + .. + } + )); + assert!(matches!( + &from_bc_instructions[5], + Instruction::Call { + return_ty: None, + args, + ..
+ } if args.len() == 1 + && args[0].0 == Type::Half + && matches!(&args[0].1, Operand::TypedLocalRef(_, Type::Half)) + )); + assert!(matches!( + &from_bc_instructions[6], + Instruction::Call { + return_ty: None, + args, + .. + } if args.len() == 1 + && args[0].0 == Type::Float + && matches!(&args[0].1, Operand::TypedLocalRef(_, Type::Float)) + )); + assert!(matches!( + &from_bc_instructions[7], + Instruction::Call { + return_ty: None, + args, + .. + } if args.len() == 1 + && args[0].0 == Type::Double + && matches!(&args[0].1, Operand::TypedLocalRef(_, Type::Double)) + )); +} + +#[test] +fn bitcode_roundtrip_preserves_nested_metadata_sublists_and_node_refs() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let complex_flag = MetadataValue::SubList(vec![ + MetadataValue::NodeRef(0), + MetadataValue::SubList(vec![ + MetadataValue::String("half".to_string()), + MetadataValue::NodeRef(0), + ]), + MetadataValue::String("leaf".to_string()), + ]); + + let module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: Vec::new(), + attribute_groups: Vec::new(), + named_metadata: vec![NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![0, 1], + }], + metadata_nodes: vec![ + MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".to_string()), + MetadataValue::Int(Type::Integer(32), 2), + ], + }, + MetadataNode { + id: 1, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String("complex_flag".to_string()), + complex_flag.clone(), + ], + }, + ], + }; + + let bitcode = write_bitcode(&module); + let parsed = parse_bitcode(&bitcode).expect("nested metadata roundtrip should parse"); + + assert_eq!(parsed.metadata_nodes, module.metadata_nodes); + assert_eq!(parsed.get_flag("complex_flag"), Some(&complex_flag)); +} + 
+#[test] +fn bitcode_roundtrip_binop_sub() { + let ir = "\ +define i64 @test(i64 %a, i64 %b) { +entry: + %r = sub i64 %a, %b + ret i64 %r +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::BinOp { + op: BinOpKind::Sub, + .. + } + )); +} + +#[test] +fn bitcode_roundtrip_binop_fadd() { + let ir = "\ +declare void @use_double(double) + +define void @test(double %a, double %b) { +entry: + %r = fadd double %a, %b + call void @use_double(double %r) + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_fn = rt + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing test function"); + assert_eq!(rt_fn.params[0].name.as_deref(), Some("a")); + assert_eq!(rt_fn.params[1].name.as_deref(), Some("b")); + + let rt_instrs = &rt_fn.basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::BinOp { + op: BinOpKind::Fadd, + ty, + lhs, + rhs, + result, + } if ty == &Type::Double + && result == "r" + && matches!(lhs, Operand::TypedLocalRef(name, lhs_ty) if name == "a" && lhs_ty == &Type::Double) + && matches!(rhs, Operand::TypedLocalRef(name, rhs_ty) if name == "b" && rhs_ty == &Type::Double) + )); + assert!(matches!( + &rt_instrs[1], + Instruction::Call { + return_ty: None, + args, + result: None, + .. 
+ } if matches!( + args.as_slice(), + [(Type::Double, Operand::TypedLocalRef(name, ty))] + if name == "r" && ty == &Type::Double + ) + )); +} + +#[test] +fn bitcode_roundtrip_icmp_slt() { + let ir = "\ +define i64 @test(i64 %a, i64 %b) { +entry: + %c = icmp slt i64 %a, %b + %r = select i1 %c, i64 %a, i64 %b + ret i64 %r +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::ICmp { + pred: IntPredicate::Slt, + .. + } + )); +} + +#[test] +fn bitcode_roundtrip_fcmp_oeq() { + let ir = "\ +declare void @use_i1(i1) + +define void @test(double %a, double %b) { +entry: + %c = fcmp oeq double %a, %b + call void @use_i1(i1 %c) + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[1].basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::FCmp { + pred: FloatPredicate::Oeq, + .. + } + )); +} + +#[test] +fn bitcode_roundtrip_call() { + let ir = "\ +declare void @callee(i64) + +define void @test(i64 %a) { +entry: + call void @callee(i64 %a) + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let test_fn = rt + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing test function"); + let rt_instrs = &test_fn.basic_blocks[0].instructions; + assert!( + matches!(&rt_instrs[0], Instruction::Call { callee, .. 
} if callee == "callee"), + "expected Call instruction targeting @callee, got {:?}", + rt_instrs[0] + ); +} + +#[test] +fn bitcode_roundtrip_global_ref_operand() { + let ir = "\ +@message = internal constant [5 x i8] c\"hello\"\n\ +\n\ +declare void @use_ptr(ptr)\n\ +\n\ +define void @test() {\n\ +entry:\n\ + call void @use_ptr(ptr @message)\n\ + ret void\n\ +}\n"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let test_fn = rt + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing test function"); + assert!(matches!( + &test_fn.basic_blocks[0].instructions[0], + Instruction::Call { args, .. } + if matches!(args.first(), Some((Type::Ptr, Operand::GlobalRef(name))) if name == "message") + )); +} + +#[test] +fn bitcode_roundtrip_operand_gep_base_name() { + let ir = "\ +@str = internal constant [5 x i8] c\"hello\"\n\ +\n\ +declare void @use_ptr(ptr)\n\ +\n\ +define void @test() {\n\ +entry:\n\ + call void @use_ptr(ptr getelementptr inbounds ([5 x i8], ptr @str, i64 0, i64 0))\n\ + ret void\n\ +}\n"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let test_fn = rt + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing test function"); + assert!(matches!( + &test_fn.basic_blocks[0].instructions[0], + Instruction::Call { args, .. } + if matches!( + args.first(), + Some(( + Type::Ptr, + Operand::GetElementPtr { ptr, .. 
}, + )) if ptr == "str" + ) + )); +} + +#[test] +fn bitcode_roundtrip_ret_value() { + let ir = "\ +define i64 @test() { +entry: + ret i64 42 +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!(matches!(&rt_instrs[0], Instruction::Ret(Some(_)))); +} + +#[test] +fn bitcode_roundtrip_ret_void() { + let ir = "\ +define void @test() { +entry: + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!(matches!(&rt_instrs[0], Instruction::Ret(None))); +} + +#[test] +fn bitcode_roundtrip_br_conditional() { + let ir = "\ +define void @test(i64 %a, i64 %b) { +entry: + %cond = icmp slt i64 %a, %b + br i1 %cond, label %then, label %else +then: + ret void +else: + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_fn = &rt.functions[0]; + assert_eq!(rt_fn.params[0].name.as_deref(), Some("a")); + assert_eq!(rt_fn.params[1].name.as_deref(), Some("b")); + assert_eq!( + rt_fn + .basic_blocks + .iter() + .map(|bb| bb.name.as_str()) + .collect::>(), + vec!["entry", "then", "else"] + ); + + let rt_instrs = &rt_fn.basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::ICmp { lhs, rhs, result, .. } + if result == "cond" + && matches!(lhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "a") + && matches!(rhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "b") + )); + assert!(matches!( + &rt_instrs[1], + Instruction::Br { + cond, + true_dest, + false_dest, + .. 
+ } if true_dest == "then" + && false_dest == "else" + && matches!(cond, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "cond") + )); +} + +#[test] +fn bitcode_roundtrip_jump() { + let ir = "\ +define void @test() { +entry: + br label %exit +exit: + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_fn = &rt.functions[0]; + assert_eq!( + rt_fn + .basic_blocks + .iter() + .map(|bb| bb.name.as_str()) + .collect::>(), + vec!["entry", "exit"] + ); + + let rt_instrs = &rt_fn.basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::Jump { dest } if dest == "exit" + )); +} + +#[test] +fn bitcode_roundtrip_phi() { + let ir = "\ +define i64 @test() { +block_0: + br label %loop +loop: + %i = phi i64 [ 0, %block_0 ], [ %next, %loop ] + %next = add i64 %i, 1 + %cond = icmp slt i64 %next, 10 + br i1 %cond, label %loop, label %exit +exit: + ret i64 %i +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_fn = &rt.functions[0]; + assert_eq!( + rt_fn + .basic_blocks + .iter() + .map(|bb| bb.name.as_str()) + .collect::>(), + vec!["block_0", "loop", "exit"] + ); + + let rt_instrs = &rt_fn.basic_blocks[1].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::Phi { + ty, + incoming, + result, + } if ty == &Type::Integer(64) + && result == "i" + && incoming.len() == 2 + && matches!(&incoming[0], (Operand::IntConst(phi_ty, 0), from) if phi_ty == &Type::Integer(64) && from == "block_0") + && matches!(&incoming[1], (Operand::LocalRef(name) | Operand::TypedLocalRef(name, _), from) if name == "next" && from == "loop") + )); + assert!(matches!( + &rt_instrs[1], + Instruction::BinOp { lhs, rhs, result, .. 
} + if result == "next" + && matches!(lhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "i") + && matches!(rhs, Operand::IntConst(ty, 1) if ty == &Type::Integer(64)) + )); + assert!(matches!( + &rt_instrs[2], + Instruction::ICmp { lhs, rhs, result, .. } + if result == "cond" + && matches!(lhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "next") + && matches!(rhs, Operand::IntConst(ty, 10) if ty == &Type::Integer(64)) + )); + assert!(matches!( + &rt_instrs[3], + Instruction::Br { + cond, + true_dest, + false_dest, + .. + } if true_dest == "loop" + && false_dest == "exit" + && matches!(cond, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "cond") + )); + assert!(matches!( + &rt_fn.basic_blocks[2].instructions[0], + Instruction::Ret(Some(Operand::LocalRef(name) | Operand::TypedLocalRef(name, _))) if name == "i" + )); +} + +#[test] +fn bitcode_roundtrip_alloca() { + let ir = "\ +define void @test() { +entry: + %ptr = alloca i64 + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!( + matches!(&rt_instrs[0], Instruction::Alloca { .. 
}), + "expected Alloca instruction, got {:?}", + rt_instrs[0] + ); +} + +#[test] +fn bitcode_roundtrip_load() { + let ir = "\ +define i64 @test() { +entry: + %ptr = alloca i64 + %val = load i64, ptr %ptr + ret i64 %val +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::Alloca { ty, result } if ty == &Type::Integer(64) && result == "ptr" + )); + assert!(matches!( + &rt_instrs[1], + Instruction::Load { + ty, + ptr_ty, + ptr, + result, + } if ty == &Type::Integer(64) + && ptr_ty == &Type::Ptr + && result == "val" + && matches!(ptr, Operand::TypedLocalRef(name, local_ty) if name == "ptr" && local_ty == &Type::Ptr) + )); + assert!(matches!( + &rt_instrs[2], + Instruction::Ret(Some(Operand::TypedLocalRef(name, ty))) + if name == "val" && ty == &Type::Integer(64) + )); +} + +#[test] +fn bitcode_roundtrip_store() { + let ir = "\ +define void @test() { +entry: + %ptr = alloca i64 + store i64 42, ptr %ptr + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!( + matches!(&rt_instrs[1], Instruction::Store { .. }), + "expected Store instruction, got {:?}", + rt_instrs[1] + ); +} + +#[test] +fn bitcode_roundtrip_select() { + let ir = "\ +define i64 @test(i64 %a, i64 %b) { +entry: + %cond = icmp slt i64 %a, %b + %r = select i1 %cond, i64 %a, i64 %b + ret i64 %r +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!( + matches!(&rt_instrs[1], Instruction::Select { .. 
}), + "expected Select instruction, got {:?}", + rt_instrs[1] + ); +} + +#[test] +fn bitcode_roundtrip_switch() { + let ir = "\ +define void @test(i32 %val) { +entry: + switch i32 %val, label %default [ + i32 0, label %case0 + i32 1, label %case1 + ] +case0: + ret void +case1: + ret void +default: + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_fn = &rt.functions[0]; + assert_eq!(rt_fn.params[0].name.as_deref(), Some("val")); + assert_eq!( + rt_fn + .basic_blocks + .iter() + .map(|bb| bb.name.as_str()) + .collect::>(), + vec!["entry", "case0", "case1", "default"] + ); + + let rt_instrs = &rt_fn.basic_blocks[0].instructions; + assert!(matches!( + &rt_instrs[0], + Instruction::Switch { + ty, + value, + default_dest, + cases, + } if ty == &Type::Integer(32) + && default_dest == "default" + && cases == &vec![(0, "case0".to_string()), (1, "case1".to_string())] + && matches!(value, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "val") + )); +} + +#[test] +fn bitcode_roundtrip_gep() { + let ir = "\ +@str = internal constant [5 x i8] c\"hello\" + +declare void @use_ptr(ptr) + +define void @test() { +entry: + %0 = getelementptr inbounds [5 x i8], ptr @str, i64 0, i64 0 + call void @use_ptr(ptr %0) + ret void +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + // The test function is after the declare + let test_fn = rt + .functions + .iter() + .find(|f| f.name == "test") + .expect("missing test function"); + let rt_instrs = &test_fn.basic_blocks[0].instructions; + assert!( + matches!( + &rt_instrs[0], + Instruction::GetElementPtr { inbounds: true, .. 
} + ), + "expected GetElementPtr instruction, got {:?}", + rt_instrs[0] + ); +} + +#[test] +fn bitcode_roundtrip_unreachable() { + let ir = "\ +define void @test() { +entry: + unreachable +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!( + matches!(&rt_instrs[0], Instruction::Unreachable), + "expected Unreachable instruction, got {:?}", + rt_instrs[0] + ); +} + +#[test] +fn bitcode_roundtrip_cast_zext() { + let ir = "\ +define i64 @test(i32 %a) { +entry: + %r = zext i32 %a to i64 + ret i64 %r +} +"; + let (orig, rt) = bitcode_roundtrip(ir); + assert_bitcode_roundtrip_structure(&orig, &rt); + let rt_instrs = &rt.functions[0].basic_blocks[0].instructions; + assert!( + matches!( + &rt_instrs[0], + Instruction::Cast { + op: CastKind::Zext, + .. + } + ), + "expected Cast/Zext instruction, got {:?}", + rt_instrs[0] + ); +} + +// --- Phase 3: Cross-format roundtrip tests --- + +#[test] +fn bitcode_roundtrip_global_type_preserved() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let ir = "\ +@str = internal constant [5 x i8] c\"hello\" + +declare void @use_ptr(ptr) + +define void @test() { +entry: + %0 = getelementptr inbounds [5 x i8], ptr @str, i64 0, i64 0 + call void @use_ptr(ptr %0) + ret void +} +"; + let module = parse_module(ir).expect("text parse"); + assert_eq!( + module.globals[0].ty, + Type::Array(5, Box::new(Type::Integer(8))) + ); + + let bc = write_bitcode(&module); + let rt = parse_bitcode(&bc).expect("bitcode parse"); + assert_eq!( + rt.globals[0].ty, + Type::Array(5, Box::new(Type::Integer(8))), + "global type should be preserved through bitcode roundtrip" + ); + assert!(rt.globals[0].is_constant); +} + +#[test] +fn cross_format_text_to_bitcode_to_text_call() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let ir = "\ +declare void 
@callee(i64) + +define void @test(i64 %a) { +entry: + call void @callee(i64 %a) + ret void +} +"; + let m1 = parse_module(ir).expect("text parse"); + let bc = write_bitcode(&m1); + let m2 = parse_bitcode(&bc).expect("bitcode parse"); + let text2 = write_module_to_string(&m2); + let m3 = parse_module(&text2).expect("re-text parse"); + + // Structural comparison + assert_eq!(m2.functions.len(), m3.functions.len()); + for (f2, f3) in m2.functions.iter().zip(m3.functions.iter()) { + assert_eq!(f2.name, f3.name); + assert_eq!(f2.is_declaration, f3.is_declaration); + assert_eq!(f2.basic_blocks.len(), f3.basic_blocks.len()); + } + + let bitcode_test_fn = m2 + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing bitcode test function"); + assert!(matches!( + &bitcode_test_fn.basic_blocks[0].instructions[0], + Instruction::Call { callee, .. } if callee == "callee" + )); + + let reparsed_test_fn = m3 + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing reparsed test function"); + assert!(matches!( + &reparsed_test_fn.basic_blocks[0].instructions[0], + Instruction::Call { callee, .. 
} if callee == "callee" + )); +} + +#[test] +fn cross_format_text_to_bitcode_to_text_branch() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let ir = "\ +define void @test(i64 %a, i64 %b) { +entry: + %cond = icmp slt i64 %a, %b + br i1 %cond, label %then, label %else +then: + ret void +else: + ret void +} +"; + let m1 = parse_module(ir).expect("text parse"); + let bc = write_bitcode(&m1); + let m2 = parse_bitcode(&bc).expect("bitcode parse"); + let text2 = write_module_to_string(&m2); + let m3 = parse_module(&text2).expect("re-text parse"); + + assert_eq!(m2.functions.len(), m3.functions.len()); + let assert_branch_function = |function: &Function| { + assert_eq!(function.params[0].name.as_deref(), Some("a")); + assert_eq!(function.params[1].name.as_deref(), Some("b")); + assert_eq!( + function + .basic_blocks + .iter() + .map(|bb| bb.name.as_str()) + .collect::>(), + vec!["entry", "then", "else"] + ); + assert!(matches!( + &function.basic_blocks[0].instructions[0], + Instruction::ICmp { lhs, rhs, result, .. } + if result == "cond" + && matches!(lhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "a") + && matches!(rhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "b") + )); + assert!(matches!( + &function.basic_blocks[0].instructions[1], + Instruction::Br { + cond, + true_dest, + false_dest, + .. 
+ } if true_dest == "then" + && false_dest == "else" + && matches!(cond, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "cond") + )); + }; + + let bitcode_test_fn = m2 + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing bitcode test function"); + assert_branch_function(bitcode_test_fn); + + let reparsed_test_fn = m3 + .functions + .iter() + .find(|function| function.name == "test") + .expect("missing reparsed test function"); + assert_branch_function(reparsed_test_fn); +} + +#[test] +fn cross_format_text_to_bitcode_to_text_binop() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let ir = "\ +define i64 @test(i64 %a, i64 %b) { +entry: + %r = add i64 %a, %b + ret i64 %r +} +"; + let m1 = parse_module(ir).expect("text parse"); + let bc = write_bitcode(&m1); + let m2 = parse_bitcode(&bc).expect("bitcode parse"); + let text2 = write_module_to_string(&m2); + let m3 = parse_module(&text2).expect("re-text parse"); + + assert_eq!(m2.functions.len(), m3.functions.len()); + for (f2, f3) in m2.functions.iter().zip(m3.functions.iter()) { + assert_eq!(f2.name, f3.name); + assert_eq!(f2.basic_blocks.len(), f3.basic_blocks.len()); + for (b2, b3) in f2.basic_blocks.iter().zip(f3.basic_blocks.iter()) { + assert_eq!(b2.instructions.len(), b3.instructions.len()); + } + } +} + +#[test] +fn cross_format_bitcode_to_text_to_bitcode() { + use super::bitcode::reader::parse_bitcode; + use super::bitcode::writer::write_bitcode; + + let ir = "\ +define i64 @test(i64 %a, i64 %b) { +entry: + %sum = add i64 %a, %b + %cond = icmp slt i64 %sum, 100 + br i1 %cond, label %then, label %else +then: + ret i64 %sum +else: + ret i64 0 +} +"; + // text -> module -> bitcode -> module -> text -> module -> bitcode + let m1 = parse_module(ir).expect("text parse"); + let bc1 = write_bitcode(&m1); + let m2 = parse_bitcode(&bc1).expect("bitcode parse 1"); + let text2 = write_module_to_string(&m2); + let m3 = 
parse_module(&text2).expect("re-text parse"); + let bc2 = write_bitcode(&m3); + let m4 = parse_bitcode(&bc2).expect("bitcode parse 2"); + + // m2 and m4 should be structurally equivalent (both went through bitcode) + assert_eq!(m2.functions.len(), m4.functions.len()); + for (f2, f4) in m2.functions.iter().zip(m4.functions.iter()) { + assert_eq!(f2.name, f4.name); + assert_eq!(f2.is_declaration, f4.is_declaration); + assert_eq!(f2.basic_blocks.len(), f4.basic_blocks.len()); + for (b2, b4) in f2.basic_blocks.iter().zip(f4.basic_blocks.iter()) { + assert_eq!(b2.instructions.len(), b4.instructions.len()); + } + } +} diff --git a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests.rs b/source/compiler/qsc_llvm/src/text.rs similarity index 58% rename from source/compiler/qsc_codegen/src/qir/v1/instruction_tests.rs rename to source/compiler/qsc_llvm/src/text.rs index 1951ad088c..676a88496d 100644 --- a/source/compiler/qsc_codegen/src/qir/v1/instruction_tests.rs +++ b/source/compiler/qsc_llvm/src/text.rs @@ -1,8 +1,5 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -mod bool; -mod double; -mod int; -mod invalid; -mod phi; +pub mod reader; +pub mod writer; diff --git a/source/compiler/qsc_llvm/src/text/reader.rs b/source/compiler/qsc_llvm/src/text/reader.rs new file mode 100644 index 0000000000..a517eff4eb --- /dev/null +++ b/source/compiler/qsc_llvm/src/text/reader.rs @@ -0,0 +1,1537 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+#[cfg(test)]
+mod tests;
+
+use crate::model::Type;
+use crate::model::{
+    Attribute, AttributeGroup, BasicBlock, BinOpKind, CastKind, Constant, FloatPredicate, Function,
+    GlobalVariable, Instruction, IntPredicate, Linkage, MetadataNode, MetadataValue, Module,
+    NamedMetadata, Operand, Param, StructType,
+};
+use crate::{ReadDiagnostic, ReadDiagnosticKind, ReadPolicy};
+use winnow::combinator::opt;
+use winnow::error::{ContextError, ErrMode, StrContext};
+use winnow::prelude::*;
+use winnow::token::{any, literal, one_of, take_while};
+
+type Input<'a> = &'a str;
+type PResult<T> = winnow::ModalResult<T>;
+
+fn ws_no_newline(input: &mut Input<'_>) -> PResult<()> {
+    take_while(0.., |c: char| c == ' ' || c == '\t' || c == '\r')
+        .void()
+        .parse_next(input)
+}
+
+fn ws(input: &mut Input<'_>) -> PResult<()> {
+    take_while(0.., |c: char| c.is_ascii_whitespace())
+        .void()
+        .parse_next(input)
+}
+
+fn line_comment(input: &mut Input<'_>) -> PResult<()> {
+    (';', take_while(0.., |c: char| c != '\n'), opt('\n'))
+        .void()
+        .parse_next(input)
+}
+
+fn ws_and_comments(input: &mut Input<'_>) -> PResult<()> {
+    loop {
+        ws(input)?;
+        if input.starts_with(';') {
+            line_comment(input)?;
+        } else {
+            break;
+        }
+    }
+    Ok(())
+}
+
+fn identifier_chars(input: &mut Input<'_>) -> PResult<String> {
+    take_while(1.., |c: char| {
+        c.is_ascii_alphanumeric() || c == '_' || c == '.'
|| c == '-'
+    })
+    .map(String::from)
+    .context(StrContext::Label("identifier"))
+    .parse_next(input)
+}
+
+/// Parse unsigned decimal integer
+fn parse_u64(input: &mut Input<'_>) -> PResult<u64> {
+    take_while(1.., |c: char| c.is_ascii_digit())
+        .try_map(|s: &str| s.parse::<u64>())
+        .context(StrContext::Label("unsigned integer"))
+        .parse_next(input)
+}
+
+fn parse_u32(input: &mut Input<'_>) -> PResult<u32> {
+    take_while(1.., |c: char| c.is_ascii_digit())
+        .try_map(|s: &str| s.parse::<u32>())
+        .context(StrContext::Label("u32"))
+        .parse_next(input)
+}
+
+/// Parse a possibly-negative decimal integer
+fn parse_integer(input: &mut Input<'_>) -> PResult<i64> {
+    let neg = opt('-').parse_next(input)?;
+    let digits: &str = take_while(1.., |c: char| c.is_ascii_digit())
+        .context(StrContext::Label("integer digits"))
+        .parse_next(input)?;
+    let val: i64 = digits
+        .parse()
+        .map_err(|_| ErrMode::Backtrack(ContextError::new()))?;
+    Ok(if neg.is_some() { -val } else { val })
+}
+
+/// Parse a possibly-negative float (digits with optional dot)
+fn parse_float_literal(input: &mut Input<'_>) -> PResult<f64> {
+    // LLVM hex float: 0xHHHHHHHHHHHHHHHH (16 hex digits = IEEE 754 double)
+    if input.starts_with("0x") {
+        literal("0x").parse_next(input)?;
+        let hex_str: &str = take_while(1.., |c: char| c.is_ascii_hexdigit())
+            .context(StrContext::Label("hex float digits"))
+            .parse_next(input)?;
+        let bits = u64::from_str_radix(hex_str, 16)
+            .map_err(|_| ErrMode::Backtrack(ContextError::new()))?;
+        return Ok(f64::from_bits(bits));
+    }
+
+    let neg = opt('-').parse_next(input)?;
+    let int_part: &str = take_while(1.., |c: char| c.is_ascii_digit())
+        .context(StrContext::Label("float digits"))
+        .parse_next(input)?;
+    let frac = opt(('.', take_while(0.., |c: char| c.is_ascii_digit()))).parse_next(input)?;
+    let exp = opt((
+        'e',
+        opt(one_of(['+', '-'])),
+        take_while(1.., |c: char| c.is_ascii_digit()),
+    ))
+    .parse_next(input)?;
+    let mut s = String::new();
+    if neg.is_some() {
+        s.push('-');
+    }
+    
s.push_str(int_part); + if let Some((_, frac_digits)) = frac { + s.push('.'); + s.push_str(frac_digits); + } + if let Some((_, sign, exp_digits)) = exp { + s.push('e'); + if let Some(sign_ch) = sign { + s.push(sign_ch); + } + s.push_str(exp_digits); + } + let val: f64 = s + .parse() + .map_err(|_| ErrMode::Backtrack(ContextError::new()))?; + Ok(val) +} + +/// Parse a quoted string: "..." +fn parse_quoted_string(input: &mut Input<'_>) -> PResult { + '"'.parse_next(input)?; + let mut s = String::new(); + loop { + let ch = any + .context(StrContext::Label("char in quoted string")) + .parse_next(input)?; + match ch { + '"' => return Ok(s), + '\\' => { + let next = any.parse_next(input)?; + match next { + '\\' => s.push('\\'), + '"' => s.push('"'), + 'n' => s.push('\n'), + c if c.is_ascii_hexdigit() => { + let mut hex = String::new(); + hex.push(c); + // Peek for a second hex digit + if input.chars().next().is_some_and(|h| h.is_ascii_hexdigit()) { + let h2 = any.parse_next(input)?; + hex.push(h2); + } + let byte = u8::from_str_radix(&hex, 16) + .map_err(|_| ErrMode::Backtrack(ContextError::new()))?; + s.push(byte as char); + } + other => { + s.push('\\'); + s.push(other); + } + } + } + other => s.push(other), + } + } +} + +/// Parse a C-string: c"...\00" +fn parse_c_string(input: &mut Input<'_>) -> PResult { + 'c'.parse_next(input)?; + '"'.parse_next(input)?; + let mut s = String::new(); + loop { + let ch = any.parse_next(input)?; + match ch { + '"' => return Ok(s), + '\\' => { + let h1 = any.parse_next(input)?; + let h2 = any.parse_next(input)?; + let mut hex = String::new(); + hex.push(h1); + hex.push(h2); + let byte = u8::from_str_radix(&hex, 16) + .map_err(|_| ErrMode::Backtrack(ContextError::new()))?; + if byte == 0 { + continue; // null terminator - skip + } + s.push(byte as char); + } + other => s.push(other), + } + } +} + +fn parse_type(input: &mut Input<'_>) -> PResult { + let ty = parse_base_type(input)?; + // Check for pointer suffix `*` + if 
input.starts_with('*') { + '*'.parse_next(input)?; + match ty { + Type::Named(name) => Ok(Type::NamedPtr(name)), + other => Ok(Type::TypedPtr(Box::new(other))), + } + } else { + Ok(ty) + } +} + +fn parse_base_type(input: &mut Input<'_>) -> PResult { + if input.starts_with("void") { + literal("void").parse_next(input)?; + return Ok(Type::Void); + } + if input.starts_with("half") { + literal("half").parse_next(input)?; + return Ok(Type::Half); + } + if input.starts_with("float") { + literal("float").parse_next(input)?; + return Ok(Type::Float); + } + if input.starts_with("double") { + literal("double").parse_next(input)?; + return Ok(Type::Double); + } + if input.starts_with("ptr") { + literal("ptr").parse_next(input)?; + return Ok(Type::Ptr); + } + if input.starts_with('i') { + 'i'.parse_next(input)?; + let n = parse_u32(input)?; + return Ok(Type::Integer(n)); + } + if input.starts_with('[') { + '['.parse_next(input)?; + ws_no_newline(input)?; + let count = parse_u64(input)?; + ws_no_newline(input)?; + literal("x").parse_next(input)?; + ws_no_newline(input)?; + let elem = parse_type(input)?; + ws_no_newline(input)?; + ']'.parse_next(input)?; + return Ok(Type::Array(count, Box::new(elem))); + } + if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + return Ok(Type::Named(name)); + } + Err(ErrMode::Backtrack(ContextError::new())) +} + +/// Parse a typed operand like `i64 42`, `ptr null`, `ptr @name`, etc. 
+fn parse_typed_operand(input: &mut Input<'_>) -> PResult { + // LocalRef shorthand: %name (no type prefix when used in `ret %name`) + if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + return Ok(Operand::LocalRef(name)); + } + + let ty = parse_type(input)?; + ws_no_newline(input)?; + + parse_operand_value_with_type(&ty, input) +} + +fn parse_operand_value_with_type(ty: &Type, input: &mut Input<'_>) -> PResult { + match ty { + Type::Ptr => { + if input.starts_with("null") { + literal("null").parse_next(input)?; + Ok(Operand::NullPtr) + } else if input.starts_with("inttoptr") { + parse_inttoptr_expr(input) + } else if input.starts_with("getelementptr") { + parse_gep_expr(input) + } else if input.starts_with('@') { + '@'.parse_next(input)?; + let name = identifier_chars(input)?; + Ok(Operand::GlobalRef(name)) + } else if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + Ok(Operand::TypedLocalRef(name, ty.clone())) + } else { + Err(ErrMode::Backtrack(ContextError::new())) + } + } + Type::NamedPtr(_) => { + if input.starts_with("inttoptr") { + parse_inttoptr_expr_with_type(ty, input) + } else if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + Ok(Operand::TypedLocalRef(name, ty.clone())) + } else { + Err(ErrMode::Backtrack(ContextError::new())) + } + } + Type::TypedPtr(_) => { + if input.starts_with("null") { + literal("null").parse_next(input)?; + Ok(Operand::NullPtr) + } else if input.starts_with("getelementptr") { + parse_gep_expr(input) + } else if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + Ok(Operand::TypedLocalRef(name, ty.clone())) + } else { + Err(ErrMode::Backtrack(ContextError::new())) + } + } + Type::Integer(_) => { + if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + Ok(Operand::TypedLocalRef(name, ty.clone())) + } else { + 
parse_int_or_bool_value(ty, input) + } + } + Type::Half | Type::Float | Type::Double => { + if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + Ok(Operand::TypedLocalRef(name, ty.clone())) + } else { + let f = parse_float_literal(input)?; + Ok(Operand::float_const(ty.clone(), f)) + } + } + _ => Err(ErrMode::Backtrack(ContextError::new())), + } +} + +/// Parse an untyped operand given the expected type context +fn parse_untyped_operand(ty: &Type, input: &mut Input<'_>) -> PResult { + if input.starts_with('%') { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + return Ok(Operand::TypedLocalRef(name, ty.clone())); + } + if input.starts_with('@') { + '@'.parse_next(input)?; + let name = identifier_chars(input)?; + return Ok(Operand::GlobalRef(name)); + } + if input.starts_with("null") { + literal("null").parse_next(input)?; + return Ok(Operand::NullPtr); + } + if input.starts_with("inttoptr") { + if matches!(ty, Type::Ptr | Type::NamedPtr(_) | Type::TypedPtr(_)) { + return parse_inttoptr_expr_with_type(ty, input); + } + return parse_inttoptr_expr(input); + } + if input.starts_with("getelementptr") { + return parse_gep_expr(input); + } + if input.starts_with("true") { + literal("true").parse_next(input)?; + return Ok(Operand::IntConst(ty.clone(), 1)); + } + if input.starts_with("false") { + literal("false").parse_next(input)?; + return Ok(Operand::IntConst(ty.clone(), 0)); + } + + if ty.is_floating_point() { + let f = parse_float_literal(input)?; + Ok(Operand::float_const(ty.clone(), f)) + } else { + let val = parse_integer(input)?; + Ok(Operand::IntConst(ty.clone(), val)) + } +} + +fn parse_int_or_bool_value(ty: &Type, input: &mut Input<'_>) -> PResult { + if input.starts_with("true") { + literal("true").parse_next(input)?; + return Ok(Operand::IntConst(ty.clone(), 1)); + } + if input.starts_with("false") { + literal("false").parse_next(input)?; + return Ok(Operand::IntConst(ty.clone(), 0)); + } + let val = 
parse_integer(input)?; + Ok(Operand::IntConst(ty.clone(), val)) +} + +fn parse_inttoptr_expr(input: &mut Input<'_>) -> PResult { + literal("inttoptr").parse_next(input)?; + ws_no_newline(input)?; + '('.parse_next(input)?; + ws_no_newline(input)?; + literal("i64").parse_next(input)?; + ws_no_newline(input)?; + let val = parse_integer(input)?; + ws_no_newline(input)?; + literal("to").parse_next(input)?; + ws_no_newline(input)?; + let target_ty = parse_type(input)?; + ws_no_newline(input)?; + ')'.parse_next(input)?; + Ok(Operand::IntToPtr(val, target_ty)) +} + +fn parse_inttoptr_expr_with_type(ty: &Type, input: &mut Input<'_>) -> PResult { + literal("inttoptr").parse_next(input)?; + ws_no_newline(input)?; + '('.parse_next(input)?; + ws_no_newline(input)?; + literal("i64").parse_next(input)?; + ws_no_newline(input)?; + let val = parse_integer(input)?; + ws_no_newline(input)?; + literal("to").parse_next(input)?; + ws_no_newline(input)?; + let target_ty = parse_type(input)?; + if target_ty != *ty { + return Err(ErrMode::Cut(ContextError::new())); + } + ws_no_newline(input)?; + ')'.parse_next(input)?; + Ok(Operand::IntToPtr(val, ty.clone())) +} + +fn parse_gep_expr(input: &mut Input<'_>) -> PResult { + literal("getelementptr").parse_next(input)?; + ws_no_newline(input)?; + literal("inbounds").parse_next(input)?; + ws_no_newline(input)?; + '('.parse_next(input)?; + ws_no_newline(input)?; + let ty = parse_type(input)?; + ws_no_newline(input)?; + ','.parse_next(input)?; + ws_no_newline(input)?; + let ptr_ty = parse_type(input)?; + ws_no_newline(input)?; + '@'.parse_next(input)?; + let ptr = identifier_chars(input)?; + + let mut indices = Vec::new(); + ws_no_newline(input)?; + while input.starts_with(',') { + ','.parse_next(input)?; + ws_no_newline(input)?; + let idx = parse_typed_operand(input)?; + indices.push(idx); + ws_no_newline(input)?; + } + ')'.parse_next(input)?; + + Ok(Operand::GetElementPtr { + ty, + ptr, + ptr_ty, + indices, + }) +} + +/// Parse GEP instruction 
body (keyword already consumed by dispatch):
/// `[inbounds] <ty>, <ptr_ty> <ptr>, <idx>, ...`
fn parse_gep_instruction_body(result: &str, input: &mut Input<'_>) -> PResult {
    ws_no_newline(input)?;
    // Optional `inbounds` flag.
    let inbounds = if input.starts_with("inbounds") {
        literal("inbounds").parse_next(input)?;
        ws_no_newline(input)?;
        true
    } else {
        false
    };
    let pointee_ty = parse_type(input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    let ptr_ty = parse_type(input)?;
    ws_no_newline(input)?;
    let ptr = parse_untyped_operand(&ptr_ty, input)?;

    // Remaining comma-separated typed operands are the GEP indices.
    let mut indices = Vec::new();
    ws_no_newline(input)?;
    while input.starts_with(',') {
        ','.parse_next(input)?;
        ws_no_newline(input)?;
        let idx = parse_typed_operand(input)?;
        indices.push(idx);
        ws_no_newline(input)?;
    }

    Ok(Instruction::GetElementPtr {
        inbounds,
        pointee_ty,
        ptr_ty,
        ptr,
        indices,
        result: result.to_string(),
    })
}

/// Parse a `ret` terminator: `ret void` or `ret <ty> <operand>`.
fn parse_ret(input: &mut Input<'_>) -> PResult {
    literal("ret").parse_next(input)?;
    ws_no_newline(input)?;
    if input.starts_with("void") {
        literal("void").parse_next(input)?;
        Ok(Instruction::Ret(None))
    } else {
        let operand = parse_typed_operand(input)?;
        Ok(Instruction::Ret(Some(operand)))
    }
}

/// Parse a `br` terminator — either an unconditional jump or a conditional
/// two-way branch.
fn parse_br(input: &mut Input<'_>) -> PResult {
    literal("br").parse_next(input)?;
    ws_no_newline(input)?;

    if input.starts_with("label") {
        // Unconditional: br label %dest
        literal("label").parse_next(input)?;
        ws_no_newline(input)?;
        '%'.parse_next(input)?;
        let dest = identifier_chars(input)?;
        return Ok(Instruction::Jump { dest });
    }

    // Conditional: br i1 %cond, label %true, label %false
    let cond_ty = parse_type(input)?;
    ws_no_newline(input)?;
    let cond = parse_untyped_operand(&cond_ty, input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    literal("label").parse_next(input)?;
    ws_no_newline(input)?;
    '%'.parse_next(input)?;
    let true_dest = identifier_chars(input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    literal("label").parse_next(input)?;
    ws_no_newline(input)?;
    '%'.parse_next(input)?;
    let false_dest = identifier_chars(input)?;

    Ok(Instruction::Br {
        cond_ty,
        cond,
        true_dest,
        false_dest,
    })
}

/// Parse a binary-operation body `<ty> <lhs>, <rhs>`; the opcode keyword was
/// already consumed by the dispatcher and arrives as `op`.
fn parse_binop(op: BinOpKind, result: &str, input: &mut Input<'_>) -> PResult {
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;
    let lhs = parse_untyped_operand(&ty, input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    let rhs = parse_untyped_operand(&ty, input)?;

    Ok(Instruction::BinOp {
        op,
        ty,
        lhs,
        rhs,
        result: result.to_string(),
    })
}

/// Parse an integer comparison predicate keyword (`eq`, `ne`, `slt`, ...).
fn parse_int_predicate(input: &mut Input<'_>) -> PResult {
    let kw: &str = take_while(1.., |c: char| c.is_ascii_alphabetic()).parse_next(input)?;
    match kw {
        "eq" => Ok(IntPredicate::Eq),
        "ne" => Ok(IntPredicate::Ne),
        "sgt" => Ok(IntPredicate::Sgt),
        "sge" => Ok(IntPredicate::Sge),
        "slt" => Ok(IntPredicate::Slt),
        "sle" => Ok(IntPredicate::Sle),
        "ult" => Ok(IntPredicate::Ult),
        "ule" => Ok(IntPredicate::Ule),
        "ugt" => Ok(IntPredicate::Ugt),
        "uge" => Ok(IntPredicate::Uge),
        // NOTE(review): the keyword has already been consumed when this
        // Backtrack is returned; winnow does not rewind automatically, so
        // callers must either treat this as fatal or restore a checkpoint —
        // confirm call sites.
        _ => Err(ErrMode::Backtrack(ContextError::new())),
    }
}

/// Parse a floating-point comparison predicate keyword (`oeq`, `ult`, ...).
fn parse_float_predicate(input: &mut Input<'_>) -> PResult {
    let kw: &str = take_while(1.., |c: char| c.is_ascii_alphabetic()).parse_next(input)?;
    match kw {
        "oeq" => Ok(FloatPredicate::Oeq),
        "ogt" => Ok(FloatPredicate::Ogt),
        "oge" => Ok(FloatPredicate::Oge),
        "olt" => Ok(FloatPredicate::Olt),
        "ole" => Ok(FloatPredicate::Ole),
        "one" => Ok(FloatPredicate::One),
        "ord" => Ok(FloatPredicate::Ord),
        "uno" => Ok(FloatPredicate::Uno),
        "ueq" => Ok(FloatPredicate::Ueq),
        "ugt" => Ok(FloatPredicate::Ugt),
        "uge" => Ok(FloatPredicate::Uge),
        "ult" => Ok(FloatPredicate::Ult),
        "ule" => Ok(FloatPredicate::Ule),
        "une" => Ok(FloatPredicate::Une),
        // Same checkpoint caveat as parse_int_predicate above.
        _ => Err(ErrMode::Backtrack(ContextError::new())),
    }
}

fn
parse_cast(kind: CastKind, result: &str, input: &mut Input<'_>) -> PResult<Instruction> {
    // Cast body: `<from_ty> <value> to <to_ty>`; the cast opcode keyword was
    // already consumed by the dispatcher and arrives as `kind`.
    ws_no_newline(input)?;
    let from_ty = parse_type(input)?;
    ws_no_newline(input)?;
    let value = parse_untyped_operand(&from_ty, input)?;
    ws_no_newline(input)?;
    literal("to").parse_next(input)?;
    ws_no_newline(input)?;
    let to_ty = parse_type(input)?;

    Ok(Instruction::Cast {
        op: kind,
        from_ty,
        to_ty,
        value,
        result: result.to_string(),
    })
}

/// Parse a `switch` terminator:
/// `switch <ty> <value>, label %default [ <ty> <val>, label %dest ... ]`
fn parse_switch(input: &mut Input<'_>) -> PResult<Instruction> {
    literal("switch").parse_next(input)?;
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;
    let value = parse_untyped_operand(&ty, input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    literal("label").parse_next(input)?;
    ws_no_newline(input)?;
    '%'.parse_next(input)?;
    let default_dest = identifier_chars(input)?;
    ws_no_newline(input)?;
    '['.parse_next(input)?;

    let mut cases = Vec::new();
    loop {
        ws_and_comments(input)?;
        if input.starts_with(']') {
            ']'.parse_next(input)?;
            break;
        }
        // The per-case type is required by the syntax but not stored; the
        // switch's own `ty` governs the value width.
        let _case_ty = parse_type(input)?;
        ws_no_newline(input)?;
        let case_val = parse_integer(input)?;
        ws_no_newline(input)?;
        ','.parse_next(input)?;
        ws_no_newline(input)?;
        literal("label").parse_next(input)?;
        ws_no_newline(input)?;
        '%'.parse_next(input)?;
        let dest = identifier_chars(input)?;
        cases.push((case_val, dest));
    }

    Ok(Instruction::Switch {
        ty,
        value,
        default_dest,
        cases,
    })
}

/// Parse a statement-position `call` (no `%result =` prefix).
///
/// Everything after the `call` keyword is identical to the assignment form,
/// so delegate to `parse_call_body` instead of duplicating its ~40 lines.
fn parse_call(result: Option<&str>, input: &mut Input<'_>) -> PResult<Instruction> {
    literal("call").parse_next(input)?;
    parse_call_body(result, input)
}

/// Parse `store <ty> <value>, <ptr_ty> <ptr>`.
fn parse_store(input: &mut Input<'_>) -> PResult<Instruction> {
    literal("store").parse_next(input)?;
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;
    let value = parse_untyped_operand(&ty, input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    let ptr_ty = parse_type(input)?;
    ws_no_newline(input)?;
    let ptr = parse_untyped_operand(&ptr_ty, input)?;

    Ok(Instruction::Store {
        ty,
        value,
        ptr_ty,
        ptr,
    })
}

/// Parse `icmp <pred> <ty> <lhs>, <rhs>` (keyword consumed by dispatch).
fn parse_icmp_body(result: &str, input: &mut Input<'_>) -> PResult<Instruction> {
    ws_no_newline(input)?;
    let pred = parse_int_predicate(input)?;
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;
    let lhs = parse_untyped_operand(&ty, input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    let rhs = parse_untyped_operand(&ty, input)?;

    Ok(Instruction::ICmp {
        pred,
        ty,
        lhs,
        rhs,
        result: result.to_string(),
    })
}

/// Parse `fcmp <pred> <ty> <lhs>, <rhs>` (keyword consumed by dispatch).
fn parse_fcmp_body(result: &str, input: &mut Input<'_>) -> PResult<Instruction> {
    ws_no_newline(input)?;
    let pred = parse_float_predicate(input)?;
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;
    let lhs = parse_untyped_operand(&ty, input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    let rhs = parse_untyped_operand(&ty, input)?;

    Ok(Instruction::FCmp {
        pred,
        ty,
        lhs,
        rhs,
        result: result.to_string(),
    })
}

/// Parse `select <cond_ty> <cond>, <ty> <tval>, <ty> <fval>` (keyword consumed
/// by dispatch).
fn parse_select_body(result: &str, input: &mut Input<'_>) -> PResult {
    ws_no_newline(input)?;
    // NOTE(review): the condition type is parsed but the condition operand is
    // read with a hard-coded i1 context — confirm that's intended.
    let _cond_ty = parse_type(input)?;
    ws_no_newline(input)?;
    let cond = parse_untyped_operand(&Type::Integer(1), input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;
    let true_val = parse_untyped_operand(&ty, input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    // The false-arm type is parsed but not checked against `ty`.
    let _false_ty = parse_type(input)?;
    ws_no_newline(input)?;
    let false_val = parse_untyped_operand(&ty, input)?;

    Ok(Instruction::Select {
        cond,
        true_val,
        false_val,
        ty,
        result: result.to_string(),
    })
}

/// Parse a call's tail after the `call` keyword: return type, `@callee`,
/// parenthesized argument list, and trailing `#N` attribute-group references.
fn parse_call_body(result: Option<&str>, input: &mut Input<'_>) -> PResult {
    ws_no_newline(input)?;
    let ret_type = parse_type(input)?;
    ws_no_newline(input)?;
    '@'.parse_next(input)?;
    let callee = identifier_chars(input)?;
    '('.parse_next(input)?;

    let mut args = Vec::new();
    ws_no_newline(input)?;
    if !input.starts_with(')') {
        loop {
            ws_no_newline(input)?;
            let ty = parse_type(input)?;
            ws_no_newline(input)?;
            let op = parse_untyped_operand(&ty, input)?;
            args.push((ty, op));
            ws_no_newline(input)?;
            if input.starts_with(',') {
                ','.parse_next(input)?;
            } else {
                break;
            }
        }
    }
    ')'.parse_next(input)?;

    // Trailing attribute-group references: `#0 #1 ...`
    let mut attr_refs = Vec::new();
    ws_no_newline(input)?;
    while input.starts_with('#') {
        '#'.parse_next(input)?;
        let id = parse_u32(input)?;
        attr_refs.push(id);
        ws_no_newline(input)?;
    }

    // A void call carries neither a return type nor a result binding.
    let (return_ty, result_name) = if ret_type == Type::Void {
        (None, None)
    } else {
        (Some(ret_type), result.map(String::from))
    };

    Ok(Instruction::Call {
        return_ty,
        callee,
        args,
        result: result_name,
        attr_refs,
    })
}

/// Parse `phi <ty> [ <val>, %block ], ...` (keyword consumed by dispatch).
fn parse_phi_body(result: &str, input: &mut Input<'_>) -> PResult {
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;

    let mut incoming = Vec::new();
    loop {
        ws_no_newline(input)?;
        if !input.starts_with('[') {
            break;
        }
        '['.parse_next(input)?;
        ws_no_newline(input)?;
        let val = parse_untyped_operand(&ty, input)?;
        ws_no_newline(input)?;
        ','.parse_next(input)?;
        ws_no_newline(input)?;
        '%'.parse_next(input)?;
        let block = identifier_chars(input)?;
        ws_no_newline(input)?;
        ']'.parse_next(input)?;
        incoming.push((val, block));
        ws_no_newline(input)?;
        if input.starts_with(',') {
            ','.parse_next(input)?;
        }
    }

    Ok(Instruction::Phi {
        ty,
        incoming,
        result: result.to_string(),
    })
}

/// Parse `alloca <ty>` (keyword consumed by dispatch).
fn parse_alloca_body(result: &str, input: &mut Input<'_>) -> PResult {
    ws_no_newline(input)?;
    let ty = parse_type(input)?;

    Ok(Instruction::Alloca {
        ty,
        result: result.to_string(),
    })
}

/// Parse `load <ty>, <ptr_ty> <ptr>` (keyword consumed by dispatch).
fn parse_load_body(result: &str, input: &mut Input<'_>) -> PResult {
    ws_no_newline(input)?;
    let ty = parse_type(input)?;
    ws_no_newline(input)?;
    ','.parse_next(input)?;
    ws_no_newline(input)?;
    let ptr_ty = parse_type(input)?;
    ws_no_newline(input)?;
    let ptr = parse_untyped_operand(&ptr_ty, input)?;

    Ok(Instruction::Load {
        ty,
        ptr_ty,
        ptr,
        result: result.to_string(),
    })
}

/// Unified assignment RHS dispatcher — keyword already consumed
fn dispatch_assignment_rhs(kw: &str, result: &str, input: &mut Input<'_>) -> PResult {
    match kw {
        "add" => parse_binop(BinOpKind::Add, result, input),
        "sub" => parse_binop(BinOpKind::Sub, result, input),
        "mul" => parse_binop(BinOpKind::Mul, result, input),
        "sdiv" => parse_binop(BinOpKind::Sdiv, result, input),
        "srem" => parse_binop(BinOpKind::Srem, result, input),
        "shl" => parse_binop(BinOpKind::Shl, result, input),
        "ashr" => parse_binop(BinOpKind::Ashr, result, input),
        "and" => parse_binop(BinOpKind::And, result, input),
        "or" => parse_binop(BinOpKind::Or, result, input),
        "xor" => parse_binop(BinOpKind::Xor, result, input),
        "fadd" =>
parse_binop(BinOpKind::Fadd, result, input), + "fsub" => parse_binop(BinOpKind::Fsub, result, input), + "fmul" => parse_binop(BinOpKind::Fmul, result, input), + "fdiv" => parse_binop(BinOpKind::Fdiv, result, input), + "udiv" => parse_binop(BinOpKind::Udiv, result, input), + "urem" => parse_binop(BinOpKind::Urem, result, input), + "lshr" => parse_binop(BinOpKind::Lshr, result, input), + "icmp" => parse_icmp_body(result, input), + "fcmp" => parse_fcmp_body(result, input), + "sitofp" => parse_cast(CastKind::Sitofp, result, input), + "fptosi" => parse_cast(CastKind::Fptosi, result, input), + "zext" => parse_cast(CastKind::Zext, result, input), + "sext" => parse_cast(CastKind::Sext, result, input), + "trunc" => parse_cast(CastKind::Trunc, result, input), + "fpext" => parse_cast(CastKind::FpExt, result, input), + "fptrunc" => parse_cast(CastKind::FpTrunc, result, input), + "inttoptr" => parse_cast(CastKind::IntToPtr, result, input), + "ptrtoint" => parse_cast(CastKind::PtrToInt, result, input), + "bitcast" => parse_cast(CastKind::Bitcast, result, input), + "select" => parse_select_body(result, input), + "call" => parse_call_body(Some(result), input), + "phi" => parse_phi_body(result, input), + "alloca" => parse_alloca_body(result, input), + "load" => parse_load_body(result, input), + "getelementptr" => parse_gep_instruction_body(result, input), + _ => Err(ErrMode::Backtrack(ContextError::new())), + } +} + +fn parse_instruction(input: &mut Input<'_>) -> PResult { + ws_no_newline(input)?; + + // Check for assignment: %result = ... 
+ if input.starts_with('%') { + // Try parsing as assignment + let checkpoint = *input; + '%'.parse_next(input)?; + let result_name = identifier_chars(input)?; + ws_no_newline(input)?; + if input.starts_with('=') { + '='.parse_next(input)?; + ws_no_newline(input)?; + // Read keyword + let kw: String = take_while(1.., |c: char| c.is_ascii_alphanumeric()) + .map(String::from) + .parse_next(input)?; + return dispatch_assignment_rhs(&kw, &result_name, input); + } + // Not an assignment, restore + *input = checkpoint; + } + + // Non-assignment instructions + if input.starts_with("ret") { + return parse_ret(input); + } + if input.starts_with("br") { + return parse_br(input); + } + if input.starts_with("call") { + return parse_call(None, input); + } + if input.starts_with("store") { + return parse_store(input); + } + if input.starts_with("switch") { + return parse_switch(input); + } + if input.starts_with("unreachable") { + literal("unreachable").parse_next(input)?; + return Ok(Instruction::Unreachable); + } + + Err(ErrMode::Backtrack(ContextError::new())) +} + +fn parse_block_label(input: &mut Input<'_>) -> PResult { + let label: &str = take_while(1.., |c: char| c != ':' && c != '\n' && c != '}') + .context(StrContext::Label("block label")) + .parse_next(input)?; + ':'.parse_next(input)?; + Ok(label.to_string()) +} + +fn is_label_start(input: &Input<'_>) -> bool { + // Labels start at column 0. Instructions are indented (leading space). 
+ // A label is: non-whitespace chars followed by ':' + if let Some(ch) = input.chars().next() + && (ch.is_ascii_alphanumeric() || ch == '_') + { + // Scan ahead for ':' + for c in input.chars() { + if c == ':' { + return true; + } + if c == '\n' || c == ' ' || c == '=' { + return false; + } + } + } + false +} + +fn parse_instructions(input: &mut Input<'_>) -> PResult> { + let mut instructions = Vec::new(); + + loop { + ws_and_comments(input)?; + + if input.is_empty() || input.starts_with('}') { + break; + } + + // Check if this is a label (next block) + if is_label_start(input) { + break; + } + + instructions.push(parse_instruction(input)?); + } + + Ok(instructions) +} + +fn parse_basic_blocks(input: &mut Input<'_>) -> PResult> { + let mut blocks = Vec::new(); + + loop { + ws_and_comments(input)?; + + if input.starts_with('}') || input.is_empty() { + break; + } + + // If this looks like a label (identifier at column 0 followed by ':'), + // parse it. Otherwise, the first block has an implicit label. + let label = if is_label_start(input) { + parse_block_label(input)? + } else if blocks.is_empty() { + // LLVM IR allows the first basic block to omit the label. + // Use an implicit "0" label (LLVM default for unnamed blocks). + "0".to_string() + } else { + // Non-first block without a label is unexpected — stop. 
+ break; + }; + + let instructions = parse_instructions(input)?; + blocks.push(BasicBlock { + name: label, + instructions, + }); + } + + Ok(blocks) +} + +fn parse_source_filename(input: &mut Input<'_>) -> PResult { + literal("source_filename").parse_next(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + parse_quoted_string(input) +} + +fn parse_target_directive(module: &mut Module, input: &mut Input<'_>) -> PResult<()> { + literal("target").parse_next(input)?; + ws_no_newline(input)?; + if input.starts_with("datalayout") { + literal("datalayout").parse_next(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + module.target_datalayout = Some(parse_quoted_string(input)?); + } else if input.starts_with("triple") { + literal("triple").parse_next(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + module.target_triple = Some(parse_quoted_string(input)?); + } else { + return Err(ErrMode::Backtrack(ContextError::new())); + } + Ok(()) +} + +fn parse_struct_type(input: &mut Input<'_>) -> PResult { + '%'.parse_next(input)?; + let name = identifier_chars(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + literal("type").parse_next(input)?; + ws_no_newline(input)?; + let is_opaque = if input.starts_with("opaque") { + literal("opaque").parse_next(input)?; + true + } else { + '{'.parse_next(input)?; + '}'.parse_next(input)?; + false + }; + Ok(StructType { name, is_opaque }) +} + +fn parse_global(input: &mut Input<'_>) -> PResult { + '@'.parse_next(input)?; + let name = identifier_chars(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + + let linkage = if input.starts_with("internal") { + literal("internal").parse_next(input)?; + Linkage::Internal + } else if input.starts_with("external") { + literal("external").parse_next(input)?; + Linkage::External + } else { + return 
Err(ErrMode::Backtrack(ContextError::new())); + }; + + ws_no_newline(input)?; + + let is_constant = if input.starts_with("constant") { + literal("constant").parse_next(input)?; + true + } else if input.starts_with("global") { + literal("global").parse_next(input)?; + false + } else { + return Err(ErrMode::Backtrack(ContextError::new())); + }; + + ws_no_newline(input)?; + let ty = parse_type(input)?; + ws_no_newline(input)?; + + let initializer = if input.starts_with('c') && input.get(1..2) == Some("\"") { + Some(Constant::CString(parse_c_string(input)?)) + } else if input.starts_with("null") { + literal("null").parse_next(input)?; + Some(Constant::Null) + } else if ty.is_floating_point() { + Some(Constant::float(ty.clone(), parse_float_literal(input)?)) + } else if input + .chars() + .next() + .is_some_and(|c| c.is_ascii_digit() || c == '-') + { + Some(Constant::Int(parse_integer(input)?)) + } else { + None + }; + + Ok(GlobalVariable { + name, + ty, + linkage, + is_constant, + initializer, + }) +} + +fn parse_declaration(input: &mut Input<'_>) -> PResult { + literal("declare").parse_next(input)?; + ws_no_newline(input)?; + let return_type = parse_type(input)?; + ws_no_newline(input)?; + '@'.parse_next(input)?; + let name = identifier_chars(input)?; + '('.parse_next(input)?; + let params = parse_param_list(input)?; + ')'.parse_next(input)?; + + let mut attribute_group_refs = Vec::new(); + ws_no_newline(input)?; + while input.starts_with('#') { + '#'.parse_next(input)?; + let id = parse_u32(input)?; + attribute_group_refs.push(id); + ws_no_newline(input)?; + } + + Ok(Function { + name, + return_type, + params, + is_declaration: true, + attribute_group_refs, + basic_blocks: Vec::new(), + }) +} + +fn parse_definition(input: &mut Input<'_>) -> PResult { + literal("define").parse_next(input)?; + ws_no_newline(input)?; + let return_type = parse_type(input)?; + ws_no_newline(input)?; + '@'.parse_next(input)?; + let name = identifier_chars(input)?; + '('.parse_next(input)?; 
+ let params = parse_param_list(input)?; + ')'.parse_next(input)?; + + let mut attribute_group_refs = Vec::new(); + ws_no_newline(input)?; + while input.starts_with('#') { + '#'.parse_next(input)?; + let id = parse_u32(input)?; + attribute_group_refs.push(id); + ws_no_newline(input)?; + } + + ws_no_newline(input)?; + '{'.parse_next(input)?; + + let basic_blocks = parse_basic_blocks(input)?; + + ws_and_comments(input)?; + '}'.parse_next(input)?; + + Ok(Function { + name, + return_type, + params, + is_declaration: false, + attribute_group_refs, + basic_blocks, + }) +} + +fn parse_param_list(input: &mut Input<'_>) -> PResult> { + let mut params = Vec::new(); + ws_no_newline(input)?; + if input.starts_with(')') { + return Ok(params); + } + loop { + ws_no_newline(input)?; + let ty = parse_type(input)?; + ws_no_newline(input)?; + let name = if input.starts_with('%') { + '%'.parse_next(input)?; + Some(identifier_chars(input)?) + } else { + None + }; + params.push(Param { ty, name }); + ws_no_newline(input)?; + if input.starts_with(',') { + ','.parse_next(input)?; + } else { + break; + } + } + Ok(params) +} + +fn parse_attribute_group(input: &mut Input<'_>) -> PResult { + literal("attributes").parse_next(input)?; + ws_no_newline(input)?; + '#'.parse_next(input)?; + let id = parse_u32(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + '{'.parse_next(input)?; + ws_no_newline(input)?; + + let mut attributes = Vec::new(); + while !input.starts_with('}') { + let key = parse_quoted_string(input)?; + ws_no_newline(input)?; + if input.starts_with('=') { + '='.parse_next(input)?; + let value = parse_quoted_string(input)?; + attributes.push(Attribute::KeyValue(key, value)); + } else { + attributes.push(Attribute::StringAttr(key)); + } + ws_no_newline(input)?; + } + '}'.parse_next(input)?; + + Ok(AttributeGroup { id, attributes }) +} + +fn parse_named_metadata(input: &mut Input<'_>) -> PResult { + '!'.parse_next(input)?; + let name = 
identifier_chars(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + '!'.parse_next(input)?; + '{'.parse_next(input)?; + + let mut node_refs = Vec::new(); + ws_no_newline(input)?; + if !input.starts_with('}') { + loop { + ws_no_newline(input)?; + '!'.parse_next(input)?; + let id = parse_u32(input)?; + node_refs.push(id); + ws_no_newline(input)?; + if input.starts_with(',') { + ','.parse_next(input)?; + } else { + break; + } + } + } + '}'.parse_next(input)?; + + Ok(NamedMetadata { name, node_refs }) +} + +fn parse_metadata_node(input: &mut Input<'_>) -> PResult { + '!'.parse_next(input)?; + let id = parse_u32(input)?; + ws_no_newline(input)?; + '='.parse_next(input)?; + ws_no_newline(input)?; + '!'.parse_next(input)?; + '{'.parse_next(input)?; + + let values = parse_metadata_values(input)?; + + '}'.parse_next(input)?; + + Ok(MetadataNode { id, values }) +} + +fn parse_metadata_values(input: &mut Input<'_>) -> PResult> { + let mut values = Vec::new(); + ws_no_newline(input)?; + if input.starts_with('}') { + return Ok(values); + } + loop { + ws_no_newline(input)?; + let val = parse_metadata_value(input)?; + values.push(val); + ws_no_newline(input)?; + if input.starts_with(',') { + ','.parse_next(input)?; + } else { + break; + } + } + Ok(values) +} + +fn parse_metadata_value(input: &mut Input<'_>) -> PResult { + if input.starts_with('!') { + '!'.parse_next(input)?; + if input.starts_with('"') { + // !"string" + let s = parse_quoted_string(input)?; + return Ok(MetadataValue::String(s)); + } + if input.starts_with('{') { + // !{...} sublist + '{'.parse_next(input)?; + let vals = parse_metadata_values(input)?; + '}'.parse_next(input)?; + return Ok(MetadataValue::SubList(vals)); + } + // !N node reference + let id = parse_u32(input)?; + return Ok(MetadataValue::NodeRef(id)); + } + + // Typed value: i32 42, i1 true, etc. 
+ let ty = parse_type(input)?; + ws_no_newline(input)?; + + if let Type::Integer(1) = &ty { + if input.starts_with("true") { + literal("true").parse_next(input)?; + return Ok(MetadataValue::Int(ty, 1)); + } + if input.starts_with("false") { + literal("false").parse_next(input)?; + return Ok(MetadataValue::Int(ty, 0)); + } + } + + let val = parse_integer(input)?; + Ok(MetadataValue::Int(ty, val)) +} + +fn parse_module_inner(input: &mut Input<'_>) -> PResult { + let mut module = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: Vec::new(), + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + loop { + ws_and_comments(input)?; + + if input.is_empty() { + break; + } + + if input.starts_with("source_filename") { + module.source_filename = Some(parse_source_filename(input)?); + } else if input.starts_with("target") { + parse_target_directive(&mut module, input)?; + } else if input.starts_with('%') { + module.struct_types.push(parse_struct_type(input)?); + } else if input.starts_with('@') { + module.globals.push(parse_global(input)?); + } else if input.starts_with("declare") { + module.functions.push(parse_declaration(input)?); + } else if input.starts_with("define") { + module.functions.push(parse_definition(input)?); + } else if input.starts_with("attributes") { + module.attribute_groups.push(parse_attribute_group(input)?); + } else if input.starts_with('!') { + // Could be named metadata or numbered metadata node + if input + .get(1..2) + .is_some_and(|c| c.chars().next().is_some_and(|ch| ch.is_ascii_digit())) + { + module.metadata_nodes.push(parse_metadata_node(input)?); + } else { + module.named_metadata.push(parse_named_metadata(input)?); + } + } else { + return Err(ErrMode::Backtrack(ContextError::new())); + } + } + + Ok(module) +} + +fn format_read_diagnostics(diagnostics: &[ReadDiagnostic]) -> String { + diagnostics + 
.iter() + .map(std::string::ToString::to_string) + .collect::>() + .join("; ") +} + +fn parse_module_with_policy(input: &str, _policy: ReadPolicy) -> Result { + let mut inp: Input<'_> = input; + parse_module_inner + .parse_next(&mut inp) + .map_err(|error| ReadDiagnostic { + kind: ReadDiagnosticKind::MalformedInput, + offset: Some(input.len().saturating_sub(inp.len())), + context: "text IR", + message: format!("parse error: {error}"), + }) +} + +pub fn parse_module_detailed( + input: &str, + policy: ReadPolicy, +) -> Result> { + parse_module_with_policy(input, policy).map_err(|error| vec![error]) +} + +/// Parse LLVM text IR into a `Module`. +/// +/// This is a drop-in replacement for the hand-written recursive-descent parser +/// in `text_reader.rs`, implemented using winnow combinators. +pub fn parse_module(input: &str) -> Result { + parse_module_detailed(input, ReadPolicy::QirSubsetStrict) + .map_err(|diagnostics| format_read_diagnostics(&diagnostics)) +} + +pub fn parse_module_compatibility(input: &str) -> Result { + parse_module_detailed(input, ReadPolicy::Compatibility) + .map_err(|diagnostics| format_read_diagnostics(&diagnostics)) +} diff --git a/source/compiler/qsc_llvm/src/text/reader/tests.rs b/source/compiler/qsc_llvm/src/text/reader/tests.rs new file mode 100644 index 0000000000..fbb8fcdc44 --- /dev/null +++ b/source/compiler/qsc_llvm/src/text/reader/tests.rs @@ -0,0 +1,849 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use crate::model::test_helpers::*; +use crate::text::writer::write_module_to_string; +use crate::{ReadDiagnosticKind, ReadPolicy}; + +// Helper: round-trip an instruction through write → parse → write +fn round_trip_instruction(instr: Instruction) { + let m = single_instruction_module(instr); + let text = write_module_to_string(&m); + let parsed = parse_module(&text).expect("winnow parse failed"); + let text2 = write_module_to_string(&parsed); + assert_eq!(text, text2, "text round-trip mismatch"); + assert_eq!(m, parsed, "model equality mismatch"); +} + +// --- Source filename --- + +#[test] +fn parse_source_filename_test() { + let input = "source_filename = \"qir\"\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.source_filename.as_deref(), Some("qir")); +} + +// --- Target directives --- + +#[test] +fn parse_target_datalayout_test() { + let input = "target datalayout = \"e-m:e-i64:64-f80:128\"\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.target_datalayout.as_deref(), Some("e-m:e-i64:64-f80:128")); +} + +#[test] +fn parse_target_triple_test() { + let input = "target triple = \"x86_64-unknown-linux-gnu\"\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.target_triple.as_deref(), Some("x86_64-unknown-linux-gnu")); +} + +// --- Struct types --- + +#[test] +fn parse_opaque_struct_test() { + let input = "%Qubit = type opaque\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.struct_types.len(), 1); + assert_eq!(m.struct_types[0].name, "Qubit"); + assert!(m.struct_types[0].is_opaque); +} + +// --- Global variables --- + +#[test] +fn parse_global_string_constant_test() { + let input = r#"@0 = internal constant [4 x i8] c"0_r\00""#; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.globals.len(), 1); + assert_eq!(m.globals[0].name, "0"); + assert!(m.globals[0].is_constant); + assert!(matches!(m.globals[0].linkage, Linkage::Internal)); + 
assert_eq!( + m.globals[0].initializer, + Some(Constant::CString("0_r".to_string())) + ); +} + +// --- Function declarations --- + +#[test] +fn parse_void_single_ptr_declaration_test() { + let input = "declare void @__quantum__rt__initialize(ptr)\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.functions.len(), 1); + let f = &m.functions[0]; + assert!(f.is_declaration); + assert_eq!(f.name, "__quantum__rt__initialize"); + assert_eq!(f.return_type, Type::Void); + assert_eq!(f.params.len(), 1); + assert_eq!(f.params[0].ty, Type::Ptr); +} + +#[test] +fn parse_declaration_with_attr_ref_test() { + let input = "declare void @__quantum__qis__m__body(ptr, ptr) #1\n"; + let m = parse_module(input).expect("parse failed"); + let f = &m.functions[0]; + assert_eq!(f.attribute_group_refs, vec![1]); +} + +// --- Function definitions --- + +#[test] +fn parse_simple_definition_test() { + let input = r#"define i64 @ENTRYPOINT__main() #0 { +block_0: + ret i64 0 +} +"#; + let m = parse_module(input).expect("parse failed"); + let f = &m.functions[0]; + assert!(!f.is_declaration); + assert_eq!(f.name, "ENTRYPOINT__main"); + assert_eq!(f.return_type, Type::Integer(64)); + assert_eq!(f.attribute_group_refs, vec![0]); + assert_eq!(f.basic_blocks.len(), 1); + assert_eq!(f.basic_blocks[0].name, "block_0"); +} + +// --- Instructions --- + +#[test] +fn parse_ret_void_test() { + let input = "define void @f() {\nentry:\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert!(matches!(instr, Instruction::Ret(None))); +} + +#[test] +fn parse_ret_i64_test() { + let input = "define i64 @f() {\nentry:\n ret i64 0\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))) + ); +} + +#[test] +fn parse_br_conditional_test() { + let input = 
"define void @f() {\nentry:\n br i1 %var_0, label %block_1, label %block_2\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::LocalRef("var_0".into()), + true_dest: "block_1".into(), + false_dest: "block_2".into(), + } + ); +} + +#[test] +fn parse_br_bool_const_test() { + let input = "define void @f() {\nentry:\n br i1 true, label %block_1, label %block_2\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 1), + true_dest: "block_1".into(), + false_dest: "block_2".into(), + } + ); +} + +#[test] +fn parse_jump_test() { + let input = "define void @f() {\nentry:\n br label %block_1\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Jump { + dest: "block_1".into() + } + ); +} + +#[test] +fn parse_binop_add_test() { + let input = "define void @f() {\nentry:\n %var_2 = add i64 %var_0, %var_1\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + } + ); +} + +#[test] +fn parse_binop_sub_const_test() { + let input = "define void @f() {\nentry:\n %var_1 = sub i64 %var_0, 1\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::BinOp { + op: BinOpKind::Sub, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), 
+ rhs: Operand::IntConst(Type::Integer(64), 1), + result: "var_1".into(), + } + ); +} + +#[test] +fn parse_binop_xor_bool_test() { + let input = "define void @f() {\nentry:\n %var_1 = xor i1 %var_0, true\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::BinOp { + op: BinOpKind::Xor, + ty: Type::Integer(1), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::IntConst(Type::Integer(1), 1), + result: "var_1".into(), + } + ); +} + +#[test] +fn parse_binop_fadd_test() { + let input = + "define void @f() {\nentry:\n %var_2 = fadd double %var_0, %var_1\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Double, + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + } + ); +} + +#[test] +fn parse_icmp_eq_test() { + let input = "define void @f() {\nentry:\n %var_1 = icmp eq i64 %var_0, 42\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::ICmp { + pred: IntPredicate::Eq, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::IntConst(Type::Integer(64), 42), + result: "var_1".into(), + } + ); +} + +#[test] +fn parse_fcmp_oeq_test() { + let input = + "define void @f() {\nentry:\n %var_2 = fcmp oeq double %var_0, %var_1\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::FCmp { + pred: FloatPredicate::Oeq, + ty: Type::Double, + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + } + ); +} + +#[test] +fn 
parse_cast_sitofp_test() { + let input = + "define void @f() {\nentry:\n %var_1 = sitofp i64 %var_0 to double\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Cast { + op: CastKind::Sitofp, + from_ty: Type::Integer(64), + to_ty: Type::Double, + value: Operand::LocalRef("var_0".into()), + result: "var_1".into(), + } + ); +} + +#[test] +fn parse_call_void_test() { + let input = "define void @f() {\nentry:\n call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr))\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".into(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: None, + attr_refs: vec![], + } + ); +} + +#[test] +fn parse_call_with_return_test() { + let input = "define void @f() {\nentry:\n %var_0 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr))\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Call { + return_ty: Some(Type::Integer(1)), + callee: "__quantum__rt__read_result".into(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: Some("var_0".into()), + attr_refs: vec![], + } + ); +} + +#[test] +fn parse_call_with_attr_ref_test() { + let input = "define void @f() {\nentry:\n call void @__quantum__qis__m__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) #1\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Call { + return_ty: None, + callee: "__quantum__qis__m__body".into(), + args: vec![ + (Type::Ptr, Operand::IntToPtr(0, 
Type::Ptr)), + (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)), + ], + result: None, + attr_refs: vec![1], + } + ); +} + +#[test] +fn parse_call_named_ptr_inttoptr_test() { + let input = "%Qubit = type opaque\n\ndeclare void @takes_qubit(%Qubit*)\ndefine void @f() {\nentry:\n call void @takes_qubit(%Qubit* inttoptr (i64 0 to %Qubit*))\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.functions[0].params[0].ty, Type::NamedPtr("Qubit".into())); + let instr = &m.functions[1].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Call { + return_ty: None, + callee: "takes_qubit".into(), + args: vec![( + Type::NamedPtr("Qubit".into()), + Operand::int_to_named_ptr(0, "Qubit"), + )], + result: None, + attr_refs: vec![], + } + ); +} + +#[test] +fn parse_call_named_ptr_inttoptr_mismatched_target_rejected_test() { + let input = "%Qubit = type opaque\n%Result = type opaque\n\ndeclare void @takes_qubit(%Qubit*)\ndefine void @f() {\nentry:\n call void @takes_qubit(%Qubit* inttoptr (i64 0 to %Result*))\n ret void\n}\n"; + let result = parse_module(input); + assert!( + result.is_err(), + "mismatched named-pointer inttoptr should fail to parse" + ); +} + +#[test] +fn parse_call_with_global_ref_test() { + let input = "define void @f() {\nentry:\n call void @__quantum__rt__array_record_output(i64 2, ptr @0)\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Call { + return_ty: None, + callee: "__quantum__rt__array_record_output".into(), + args: vec![ + (Type::Integer(64), Operand::IntConst(Type::Integer(64), 2)), + (Type::Ptr, Operand::GlobalRef("0".into())), + ], + result: None, + attr_refs: vec![], + } + ); +} + +#[test] +fn parse_phi_i1_test() { + let input = "define void @f() {\nentry:\n %var_3 = phi i1 [true, %block_0], [%var_2, %block_1]\n ret void\n}\n"; + let m = parse_module(input).expect("parse 
failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Phi { + ty: Type::Integer(1), + incoming: vec![ + (Operand::IntConst(Type::Integer(1), 1), "block_0".into()), + (Operand::LocalRef("var_2".into()), "block_1".into()), + ], + result: "var_3".into(), + } + ); +} + +#[test] +fn parse_alloca_test() { + let input = "define void @f() {\nentry:\n %var_0 = alloca i1\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Alloca { + ty: Type::Integer(1), + result: "var_0".into(), + } + ); +} + +#[test] +fn parse_load_test() { + let input = "define void @f() {\nentry:\n %var_1 = load i1, ptr %var_0\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Load { + ty: Type::Integer(1), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + result: "var_1".into(), + } + ); +} + +#[test] +fn parse_store_test() { + let input = "define void @f() {\nentry:\n store i1 true, ptr %var_0\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Store { + ty: Type::Integer(1), + value: Operand::IntConst(Type::Integer(1), 1), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + } + ); +} + +#[test] +fn parse_store_half_const_test() { + let input = "define void @f() {\nentry:\n store half 1.5, ptr %var_0\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Store { + ty: Type::Half, + value: Operand::float_const(Type::Half, 1.5), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + } + ); +} + +#[test] +fn 
parse_store_float_const_test() { + let input = "define void @f() {\nentry:\n store float 2.5, ptr %var_0\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Store { + ty: Type::Float, + value: Operand::float_const(Type::Float, 2.5), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + } + ); +} + +#[test] +fn parse_store_double_const_test() { + let input = "define void @f() {\nentry:\n store double 3.5, ptr %var_0\n ret void\n}\n"; + let m = parse_module(input).expect("parse failed"); + let instr = &m.functions[0].basic_blocks[0].instructions[0]; + assert_eq!( + *instr, + Instruction::Store { + ty: Type::Double, + value: Operand::float_const(Type::Double, 3.5), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + } + ); +} + +// --- Attribute groups --- + +#[test] +fn parse_attribute_group_test() { + let input = r#"attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="1" } +"#; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.attribute_groups.len(), 1); + let ag = &m.attribute_groups[0]; + assert_eq!(ag.id, 0); + assert_eq!(ag.attributes.len(), 5); + assert_eq!( + ag.attributes[0], + Attribute::StringAttr("entry_point".into()) + ); + assert_eq!( + ag.attributes[2], + Attribute::KeyValue("qir_profiles".into(), "adaptive_profile".into()) + ); +} + +// --- Metadata --- + +#[test] +fn parse_named_metadata_test() { + let input = "!llvm.module.flags = !{!0, !1, !2}\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.named_metadata.len(), 1); + let nm = &m.named_metadata[0]; + assert_eq!(nm.name, "llvm.module.flags"); + assert_eq!(nm.node_refs, vec![0, 1, 2]); +} + +#[test] +fn parse_metadata_node_int_and_string_test() { + let input = "!0 = !{i32 1, !\"qir_major_version\", i32 2}\n"; + let m = 
parse_module(input).expect("parse failed"); + let node = &m.metadata_nodes[0]; + assert_eq!(node.id, 0); + assert_eq!(node.values.len(), 3); + assert_eq!(node.values[0], MetadataValue::Int(Type::Integer(32), 1)); + assert_eq!( + node.values[1], + MetadataValue::String("qir_major_version".into()) + ); + assert_eq!(node.values[2], MetadataValue::Int(Type::Integer(32), 2)); +} + +#[test] +fn parse_metadata_node_with_bool_test() { + let input = "!2 = !{i32 1, !\"dynamic_qubit_management\", i1 false}\n"; + let m = parse_module(input).expect("parse failed"); + let node = &m.metadata_nodes[0]; + assert_eq!(node.values[2], MetadataValue::Int(Type::Integer(1), 0)); +} + +#[test] +fn parse_metadata_node_with_sublist_test() { + let input = "!4 = !{i32 5, !\"int_computations\", !{!\"i64\"}}\n"; + let m = parse_module(input).expect("parse failed"); + let node = &m.metadata_nodes[0]; + assert_eq!( + node.values[2], + MetadataValue::SubList(vec![MetadataValue::String("i64".into())]) + ); +} + +// --- Comments --- + +#[test] +fn parse_skips_comments_test() { + let input = "; this is a comment\nsource_filename = \"qir\"\n; another comment\n"; + let m = parse_module(input).expect("parse failed"); + assert_eq!(m.source_filename.as_deref(), Some("qir")); +} + +// --- Error tests --- + +#[test] +fn parse_error_on_invalid_input_test() { + let result = parse_module("invalid_keyword at module level"); + assert!(result.is_err()); +} + +#[test] +fn parse_module_detailed_reports_structured_error_on_invalid_input_test() { + let diagnostics = parse_module_detailed( + "invalid_keyword at module level", + ReadPolicy::QirSubsetStrict, + ) + .expect_err("invalid text IR should surface a structured diagnostic"); + + assert_eq!(diagnostics.len(), 1); + assert_eq!(diagnostics[0].kind, ReadDiagnosticKind::MalformedInput); + assert_eq!(diagnostics[0].context, "text IR"); +} + +#[test] +fn strict_text_import_rejects_non_opaque_struct_body_fixture() { + let diagnostics = + parse_module_detailed("%Pair = 
type { i64, i64 }\n", ReadPolicy::QirSubsetStrict) + .expect_err("non-opaque struct bodies should remain unsupported"); + + assert_eq!(diagnostics.len(), 1); + assert_eq!(diagnostics[0].kind, ReadDiagnosticKind::MalformedInput); + assert_eq!(diagnostics[0].context, "text IR"); +} + +// --- Round-trip tests --- + +#[test] +fn round_trip_empty_module_test() { + let m = empty_module(); + let text = write_module_to_string(&m); + let parsed = parse_module(&text).expect("parse failed"); + let text2 = write_module_to_string(&parsed); + assert_eq!(text, text2); +} + +#[test] +fn round_trip_source_filename_test() { + let input = "source_filename = \"qir\"\n"; + let m = parse_module(input).expect("parse failed"); + let text = write_module_to_string(&m); + let m2 = parse_module(&text).expect("parse failed"); + assert_eq!(m, m2); +} + +#[test] +fn round_trip_global_variables_test() { + let input = r#"@0 = internal constant [4 x i8] c"0_r\00" +@1 = internal constant [6 x i8] c"1_a0r\00" +"#; + let m = parse_module(input).expect("parse failed"); + let text = write_module_to_string(&m); + let m2 = parse_module(&text).expect("parse failed"); + assert_eq!(m, m2); +} + +#[test] +fn round_trip_ret_void_test() { + round_trip_instruction(Instruction::Ret(None)); +} + +#[test] +fn round_trip_ret_value_test() { + round_trip_instruction(Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))); +} + +#[test] +fn round_trip_binop_add_test() { + round_trip_instruction(Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }); +} + +#[test] +fn round_trip_icmp_test() { + round_trip_instruction(Instruction::ICmp { + pred: IntPredicate::Eq, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::IntConst(Type::Integer(64), 42), + result: "var_1".into(), + }); +} + +#[test] +fn round_trip_fcmp_test() { + 
round_trip_instruction(Instruction::FCmp { + pred: FloatPredicate::Oeq, + ty: Type::Double, + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }); +} + +#[test] +fn round_trip_cast_test() { + round_trip_instruction(Instruction::Cast { + op: CastKind::Sitofp, + from_ty: Type::Integer(64), + to_ty: Type::Double, + value: Operand::LocalRef("var_0".into()), + result: "var_1".into(), + }); +} + +#[test] +fn round_trip_call_void_test() { + round_trip_instruction(Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".into(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: None, + attr_refs: vec![], + }); +} + +#[test] +fn round_trip_call_with_return_test() { + round_trip_instruction(Instruction::Call { + return_ty: Some(Type::Integer(1)), + callee: "__quantum__rt__read_result".into(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: Some("var_0".into()), + attr_refs: vec![], + }); +} + +#[test] +fn round_trip_phi_test() { + round_trip_instruction(Instruction::Phi { + ty: Type::Integer(1), + incoming: vec![ + (Operand::IntConst(Type::Integer(1), 1), "block_0".into()), + (Operand::LocalRef("var_2".into()), "block_1".into()), + ], + result: "var_3".into(), + }); +} + +#[test] +fn round_trip_alloca_test() { + round_trip_instruction(Instruction::Alloca { + ty: Type::Integer(1), + result: "var_0".into(), + }); +} + +#[test] +fn round_trip_load_test() { + round_trip_instruction(Instruction::Load { + ty: Type::Integer(1), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + result: "var_1".into(), + }); +} + +#[test] +fn round_trip_store_test() { + round_trip_instruction(Instruction::Store { + ty: Type::Integer(1), + value: Operand::IntConst(Type::Integer(1), 1), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + }); +} + +#[test] +fn round_trip_select_test() { + round_trip_instruction(Instruction::Select { + cond: 
Operand::LocalRef("var_0".into()), + true_val: Operand::IntConst(Type::Integer(64), 1), + false_val: Operand::IntConst(Type::Integer(64), 2), + ty: Type::Integer(64), + result: "var_1".into(), + }); +} + +#[test] +fn round_trip_switch_test() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Switch { + ty: Type::Integer(64), + value: Operand::LocalRef("var_0".into()), + default_dest: "block_default".into(), + cases: vec![(0, "block_0".into()), (1, "block_1".into())], + }], + }, + BasicBlock { + name: "block_0".to_string(), + instructions: vec![Instruction::Ret(None)], + }, + BasicBlock { + name: "block_1".to_string(), + instructions: vec![Instruction::Ret(None)], + }, + BasicBlock { + name: "block_default".to_string(), + instructions: vec![Instruction::Ret(None)], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let text = write_module_to_string(&m); + let parsed = parse_module(&text).expect("winnow parse failed"); + let text2 = write_module_to_string(&parsed); + assert_eq!(text, text2); + assert_eq!(m, parsed); +} + +#[test] +fn round_trip_unreachable_test() { + round_trip_instruction(Instruction::Unreachable); +} + +#[test] +fn round_trip_bell_module_v2_test() { + let m = bell_module_v2(); + let text = write_module_to_string(&m); + let parsed = parse_module(&text).expect("winnow parse failed"); + let text2 = write_module_to_string(&parsed); + assert_eq!(text, text2); + assert_eq!(m, parsed); +} diff --git a/source/compiler/qsc_llvm/src/text/writer.rs b/source/compiler/qsc_llvm/src/text/writer.rs new file mode 100644 index 
0000000000..cf39d62f80 --- /dev/null +++ b/source/compiler/qsc_llvm/src/text/writer.rs @@ -0,0 +1,570 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#[cfg(test)] +mod tests; + +use std::fmt::Write; + +use crate::model::Type; +use crate::model::{ + Attribute, AttributeGroup, BasicBlock, BinOpKind, CastKind, Constant, FloatPredicate, Function, + GlobalVariable, Instruction, IntPredicate, Linkage, MetadataNode, MetadataValue, Module, + NamedMetadata, Operand, StructType, +}; + +#[must_use] +pub fn write_module_to_string(module: &Module) -> String { + let mut buf = String::new(); + write_module(&mut buf, module).expect("writing to string should succeed"); + buf +} + +pub fn write_module(w: &mut dyn Write, module: &Module) -> Result<(), std::fmt::Error> { + // 1. source_filename + if let Some(ref name) = module.source_filename { + writeln!(w, "source_filename = \"{name}\"")?; + } + + // 2. struct types + for st in &module.struct_types { + write_struct_type(w, st)?; + writeln!(w)?; + } + + // 3. blank line + globals + if !module.globals.is_empty() { + if !module.struct_types.is_empty() { + writeln!(w)?; + } + for g in &module.globals { + write_global(w, g)?; + writeln!(w)?; + } + } + + // 4. functions (declarations and definitions in original order) + if !module.functions.is_empty() { + writeln!(w)?; + } + + for f in &module.functions { + write_function(w, f)?; + writeln!(w)?; + writeln!(w)?; + } + + // 6. attribute groups + for ag in &module.attribute_groups { + write_attribute_group(w, ag)?; + writeln!(w)?; + } + + // 7. named metadata header comment + named metadata + if !module.named_metadata.is_empty() { + writeln!(w)?; + writeln!(w, "; module flags")?; + writeln!(w)?; + for nm in &module.named_metadata { + write_named_metadata(w, nm)?; + writeln!(w)?; + } + } + + // 8. 
metadata nodes + if !module.metadata_nodes.is_empty() { + writeln!(w)?; + for node in &module.metadata_nodes { + write_metadata_node(w, node)?; + writeln!(w)?; + } + } + + Ok(()) +} + +fn write_struct_type(w: &mut dyn Write, st: &StructType) -> Result<(), std::fmt::Error> { + if st.is_opaque { + write!(w, "%{} = type opaque", st.name) + } else { + write!(w, "%{} = type {{}}", st.name) + } +} + +fn write_global(w: &mut dyn Write, g: &GlobalVariable) -> Result<(), std::fmt::Error> { + let linkage = match g.linkage { + Linkage::Internal => "internal", + Linkage::External => "external", + }; + let kind = if g.is_constant { "constant" } else { "global" }; + write!(w, "@{} = {linkage} {kind} {}", g.name, g.ty)?; + if let Some(ref init) = g.initializer { + write!(w, " ")?; + write_constant(w, init)?; + } + Ok(()) +} + +fn write_constant(w: &mut dyn Write, c: &Constant) -> Result<(), std::fmt::Error> { + match c { + Constant::CString(s) => { + write!(w, "c\"{s}\\00\"") + } + Constant::Int(i) => write!(w, "{i}"), + Constant::Float(_, f) => write_float(w, *f), + Constant::Null => write!(w, "null"), + } +} + +fn write_float(w: &mut dyn Write, f: f64) -> Result<(), std::fmt::Error> { + if (f.floor() - f.ceil()).abs() < f64::EPSILON { + write!(w, "{f:.1}") + } else { + write!(w, "{f}") + } +} + +fn write_function(w: &mut dyn Write, f: &Function) -> Result<(), std::fmt::Error> { + if f.is_declaration { + write!(w, "declare {} @{}(", f.return_type, f.name)?; + write_param_list(w, f)?; + write!(w, ")")?; + for attr_ref in &f.attribute_group_refs { + write!(w, " #{attr_ref}")?; + } + } else { + write!(w, "define {} @{}(", f.return_type, f.name)?; + write_param_list(w, f)?; + write!(w, ")")?; + for attr_ref in &f.attribute_group_refs { + write!(w, " #{attr_ref}")?; + } + writeln!(w, " {{")?; + for bb in &f.basic_blocks { + write_basic_block(w, bb)?; + } + write!(w, "}}")?; + } + Ok(()) +} + +fn write_param_list(w: &mut dyn Write, f: &Function) -> Result<(), std::fmt::Error> { + for 
(i, p) in f.params.iter().enumerate() { + if i > 0 { + write!(w, ", ")?; + } + write!(w, "{}", p.ty)?; + if let Some(ref name) = p.name { + write!(w, " %{name}")?; + } + } + Ok(()) +} + +fn write_basic_block(w: &mut dyn Write, bb: &BasicBlock) -> Result<(), std::fmt::Error> { + writeln!(w, "{}:", bb.name)?; + for instr in &bb.instructions { + write_instruction(w, instr)?; + writeln!(w)?; + } + Ok(()) +} + +#[allow(clippy::too_many_lines)] +fn write_instruction(w: &mut dyn Write, instr: &Instruction) -> Result<(), std::fmt::Error> { + match instr { + Instruction::Ret(None) => write!(w, " ret void"), + Instruction::Ret(Some(operand)) => { + write!(w, " ret ")?; + write_typed_operand(w, operand) + } + Instruction::Br { + cond_ty, + cond, + true_dest, + false_dest, + } => { + write!(w, " br {cond_ty} ")?; + write_untyped_operand(w, cond)?; + write!(w, ", label %{true_dest}, label %{false_dest}") + } + Instruction::Jump { dest } => { + write!(w, " br label %{dest}") + } + Instruction::BinOp { + op, + ty, + lhs, + rhs, + result, + } => { + let op_str = binop_name(op); + write!(w, " %{result} = {op_str} {ty} ")?; + write_untyped_operand(w, lhs)?; + write!(w, ", ")?; + write_untyped_operand(w, rhs) + } + Instruction::ICmp { + pred, + ty, + lhs, + rhs, + result, + } => { + let pred_str = icmp_pred_name(pred); + write!(w, " %{result} = icmp {pred_str} {ty} ")?; + write_untyped_operand(w, lhs)?; + write!(w, ", ")?; + write_untyped_operand(w, rhs) + } + Instruction::FCmp { + pred, + ty, + lhs, + rhs, + result, + } => { + let pred_str = fcmp_pred_name(pred); + write!(w, " %{result} = fcmp {pred_str} {ty} ")?; + write_untyped_operand(w, lhs)?; + write!(w, ", ")?; + write_untyped_operand(w, rhs) + } + Instruction::Cast { + op, + from_ty, + to_ty, + value, + result, + } => { + let op_str = cast_name(op); + write!(w, " %{result} = {op_str} {from_ty} ")?; + write_untyped_operand(w, value)?; + write!(w, " to {to_ty}") + } + Instruction::Call { + return_ty, + callee, + args, + result, 
+ attr_refs, + } => { + let ret_ty = return_ty.as_ref().map_or(Type::Void, Clone::clone); + if let Some(r) = result { + write!(w, " %{r} = call {ret_ty} @{callee}(")?; + } else { + write!(w, " call {ret_ty} @{callee}(")?; + } + for (i, (ty, op)) in args.iter().enumerate() { + if i > 0 { + write!(w, ", ")?; + } + write!(w, "{ty} ")?; + write_untyped_operand(w, op)?; + } + write!(w, ")")?; + for attr_ref in attr_refs { + write!(w, " #{attr_ref}")?; + } + Ok(()) + } + Instruction::Phi { + ty, + incoming, + result, + } => { + write!(w, " %{result} = phi {ty} ")?; + for (i, (val, block)) in incoming.iter().enumerate() { + if i > 0 { + write!(w, ", ")?; + } + write!(w, "[")?; + write_untyped_operand(w, val)?; + write!(w, ", %{block}]")?; + } + Ok(()) + } + Instruction::Alloca { ty, result } => { + write!(w, " %{result} = alloca {ty}") + } + Instruction::Load { + ty, + ptr_ty, + ptr, + result, + } => { + write!(w, " %{result} = load {ty}, {ptr_ty} ")?; + write_untyped_operand(w, ptr) + } + Instruction::Store { + ty, + value, + ptr_ty, + ptr, + } => { + write!(w, " store {ty} ")?; + write_untyped_operand(w, value)?; + write!(w, ", {ptr_ty} ")?; + write_untyped_operand(w, ptr) + } + Instruction::Select { + cond, + true_val, + false_val, + ty, + result, + } => { + write!(w, " %{result} = select i1 ")?; + write_untyped_operand(w, cond)?; + write!(w, ", {ty} ")?; + write_untyped_operand(w, true_val)?; + write!(w, ", {ty} ")?; + write_untyped_operand(w, false_val) + } + Instruction::Switch { + ty, + value, + default_dest, + cases, + } => { + write!(w, " switch {ty} ")?; + write_untyped_operand(w, value)?; + writeln!(w, ", label %{default_dest} [")?; + for (val, dest) in cases { + writeln!(w, " {ty} {val}, label %{dest}")?; + } + write!(w, " ]") + } + Instruction::Unreachable => write!(w, " unreachable"), + Instruction::GetElementPtr { + inbounds, + pointee_ty, + ptr_ty, + ptr, + indices, + result, + } => { + write!(w, " %{result} = getelementptr ")?; + if *inbounds { + 
write!(w, "inbounds ")?; + } + write!(w, "{pointee_ty}, {ptr_ty} ")?; + write_untyped_operand(w, ptr)?; + for idx in indices { + write!(w, ", ")?; + write_typed_operand(w, idx)?; + } + Ok(()) + } + } +} + +fn write_typed_operand(w: &mut dyn Write, op: &Operand) -> Result<(), std::fmt::Error> { + match op { + Operand::LocalRef(name) => { + write!(w, "%{name}") + } + Operand::TypedLocalRef(name, ty) => { + write!(w, "{ty} %{name}") + } + Operand::IntConst(ty, val) => { + write!(w, "{ty} ")?; + write_int_value(w, ty, *val) + } + Operand::FloatConst(ty, f) => { + write!(w, "{ty} ")?; + write_float(w, *f) + } + Operand::NullPtr => write!(w, "ptr null"), + Operand::IntToPtr(val, ty) => { + write!(w, "{ty} inttoptr (i64 {val} to {ty})") + } + Operand::GetElementPtr { + ty, + ptr, + ptr_ty, + indices, + } => { + write!(w, "{ptr_ty} getelementptr inbounds ({ty}, {ptr_ty} @{ptr}")?; + for idx in indices { + write!(w, ", ")?; + write_typed_operand(w, idx)?; + } + write!(w, ")") + } + Operand::GlobalRef(name) => write!(w, "ptr @{name}"), + } +} + +fn write_untyped_operand(w: &mut dyn Write, op: &Operand) -> Result<(), std::fmt::Error> { + match op { + Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => write!(w, "%{name}"), + Operand::IntConst(ty, val) => write_int_value(w, ty, *val), + Operand::FloatConst(_, f) => write_float(w, *f), + Operand::NullPtr => write!(w, "null"), + Operand::IntToPtr(val, ty) => { + write!(w, "inttoptr (i64 {val} to {ty})") + } + Operand::GetElementPtr { + ty, + ptr, + ptr_ty, + indices, + } => { + write!(w, "getelementptr inbounds ({ty}, {ptr_ty} @{ptr}")?; + for idx in indices { + write!(w, ", ")?; + write_typed_operand(w, idx)?; + } + write!(w, ")") + } + Operand::GlobalRef(name) => write!(w, "@{name}"), + } +} + +fn write_int_value(w: &mut dyn Write, ty: &Type, val: i64) -> Result<(), std::fmt::Error> { + if let Type::Integer(1) = ty { + if val == 0 { + write!(w, "false") + } else { + write!(w, "true") + } + } else { + write!(w, 
"{val}") + } +} + +fn write_attribute_group(w: &mut dyn Write, ag: &AttributeGroup) -> Result<(), std::fmt::Error> { + write!(w, "attributes #{} = {{ ", ag.id)?; + for (i, attr) in ag.attributes.iter().enumerate() { + if i > 0 { + write!(w, " ")?; + } + match attr { + Attribute::StringAttr(s) => write!(w, "\"{s}\"")?, + Attribute::KeyValue(k, v) => write!(w, "\"{k}\"=\"{v}\"")?, + } + } + write!(w, " }}") +} + +fn write_named_metadata(w: &mut dyn Write, nm: &NamedMetadata) -> Result<(), std::fmt::Error> { + write!(w, "!{} = !{{", nm.name)?; + for (i, node_ref) in nm.node_refs.iter().enumerate() { + if i > 0 { + write!(w, ", ")?; + } + write!(w, "!{node_ref}")?; + } + write!(w, "}}") +} + +fn write_metadata_node(w: &mut dyn Write, node: &MetadataNode) -> Result<(), std::fmt::Error> { + write!(w, "!{} = !{{", node.id)?; + for (i, val) in node.values.iter().enumerate() { + if i > 0 { + write!(w, ", ")?; + } + write_metadata_value(w, val)?; + } + write!(w, "}}") +} + +fn write_metadata_value(w: &mut dyn Write, val: &MetadataValue) -> Result<(), std::fmt::Error> { + match val { + MetadataValue::Int(ty, v) => { + if let Type::Integer(1) = ty { + if *v == 0 { + write!(w, "{ty} false") + } else { + write!(w, "{ty} true") + } + } else { + write!(w, "{ty} {v}") + } + } + MetadataValue::String(s) => write!(w, "!\"{s}\""), + MetadataValue::NodeRef(id) => write!(w, "!{id}"), + MetadataValue::SubList(vals) => { + write!(w, "!{{")?; + for (i, v) in vals.iter().enumerate() { + if i > 0 { + write!(w, ", ")?; + } + write_metadata_value(w, v)?; + } + write!(w, "}}") + } + } +} + +fn binop_name(op: &BinOpKind) -> &'static str { + match op { + BinOpKind::Add => "add", + BinOpKind::Sub => "sub", + BinOpKind::Mul => "mul", + BinOpKind::Sdiv => "sdiv", + BinOpKind::Srem => "srem", + BinOpKind::Shl => "shl", + BinOpKind::Ashr => "ashr", + BinOpKind::And => "and", + BinOpKind::Or => "or", + BinOpKind::Xor => "xor", + BinOpKind::Fadd => "fadd", + BinOpKind::Fsub => "fsub", + BinOpKind::Fmul 
=> "fmul", + BinOpKind::Fdiv => "fdiv", + BinOpKind::Udiv => "udiv", + BinOpKind::Urem => "urem", + BinOpKind::Lshr => "lshr", + } +} + +fn icmp_pred_name(pred: &IntPredicate) -> &'static str { + match pred { + IntPredicate::Eq => "eq", + IntPredicate::Ne => "ne", + IntPredicate::Sgt => "sgt", + IntPredicate::Sge => "sge", + IntPredicate::Slt => "slt", + IntPredicate::Sle => "sle", + IntPredicate::Ult => "ult", + IntPredicate::Ule => "ule", + IntPredicate::Ugt => "ugt", + IntPredicate::Uge => "uge", + } +} + +fn fcmp_pred_name(pred: &FloatPredicate) -> &'static str { + match pred { + FloatPredicate::Oeq => "oeq", + FloatPredicate::Ogt => "ogt", + FloatPredicate::Oge => "oge", + FloatPredicate::Olt => "olt", + FloatPredicate::Ole => "ole", + FloatPredicate::One => "one", + FloatPredicate::Ord => "ord", + FloatPredicate::Uno => "uno", + FloatPredicate::Ueq => "ueq", + FloatPredicate::Ugt => "ugt", + FloatPredicate::Uge => "uge", + FloatPredicate::Ult => "ult", + FloatPredicate::Ule => "ule", + FloatPredicate::Une => "une", + } +} + +fn cast_name(op: &CastKind) -> &'static str { + match op { + CastKind::Sitofp => "sitofp", + CastKind::Fptosi => "fptosi", + CastKind::Zext => "zext", + CastKind::Sext => "sext", + CastKind::Trunc => "trunc", + CastKind::FpExt => "fpext", + CastKind::FpTrunc => "fptrunc", + CastKind::IntToPtr => "inttoptr", + CastKind::PtrToInt => "ptrtoint", + CastKind::Bitcast => "bitcast", + } +} diff --git a/source/compiler/qsc_llvm/src/text/writer/tests.rs b/source/compiler/qsc_llvm/src/text/writer/tests.rs new file mode 100644 index 0000000000..178bf36b9e --- /dev/null +++ b/source/compiler/qsc_llvm/src/text/writer/tests.rs @@ -0,0 +1,779 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use crate::model::test_helpers::*; +use crate::model::*; +use expect_test::expect; + +// --- Struct type emission --- + +#[test] +fn struct_type_opaque() { + let st = StructType { + name: "Qubit".into(), + is_opaque: true, + }; + let mut buf = String::new(); + write_struct_type(&mut buf, &st).expect("failed to write"); + expect!["%Qubit = type opaque"].assert_eq(&buf); +} + +// --- Global variable emission --- + +#[test] +fn global_string_constant() { + let g = GlobalVariable { + name: "0".into(), + ty: Type::Array(4, Box::new(Type::Integer(8))), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::CString("0_r".into())), + }; + let mut buf = String::new(); + write_global(&mut buf, &g).expect("failed to write"); + expect![r#"@0 = internal constant [4 x i8] c"0_r\00""#].assert_eq(&buf); +} + +#[test] +fn global_string_longer() { + let g = GlobalVariable { + name: "1".into(), + ty: Type::Array(6, Box::new(Type::Integer(8))), + linkage: Linkage::Internal, + is_constant: true, + initializer: Some(Constant::CString("1_a0r".into())), + }; + let mut buf = String::new(); + write_global(&mut buf, &g).expect("failed to write"); + expect![r#"@1 = internal constant [6 x i8] c"1_a0r\00""#].assert_eq(&buf); +} + +// --- Function declaration emission --- + +#[test] +fn void_single_ptr_declaration() { + let f = Function { + name: "__quantum__rt__initialize".into(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: vec![], + basic_blocks: vec![], + }; + let mut buf = String::new(); + write_function(&mut buf, &f).expect("failed to write"); + expect!["declare void @__quantum__rt__initialize(ptr)"].assert_eq(&buf); +} + +#[test] +fn void_two_ptr_declaration() { + let f = Function { + name: "__quantum__qis__cx__body".into(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Ptr, + name: None, + }, + Param { + ty: Type::Ptr, + name: None, + }, 
+ ], + is_declaration: true, + attribute_group_refs: vec![], + basic_blocks: vec![], + }; + let mut buf = String::new(); + write_function(&mut buf, &f).expect("failed to write"); + expect!["declare void @__quantum__qis__cx__body(ptr, ptr)"].assert_eq(&buf); +} + +#[test] +fn declaration_with_return_type() { + let f = Function { + name: "__quantum__rt__read_result".into(), + return_type: Type::Integer(1), + params: vec![Param { + ty: Type::Ptr, + name: None, + }], + is_declaration: true, + attribute_group_refs: vec![], + basic_blocks: vec![], + }; + let mut buf = String::new(); + write_function(&mut buf, &f).expect("failed to write"); + expect!["declare i1 @__quantum__rt__read_result(ptr)"].assert_eq(&buf); +} + +#[test] +fn declaration_with_attr_ref() { + let f = Function { + name: "__quantum__qis__m__body".into(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Ptr, + name: None, + }, + Param { + ty: Type::Ptr, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: vec![1], + basic_blocks: vec![], + }; + let mut buf = String::new(); + write_function(&mut buf, &f).expect("failed to write"); + expect!["declare void @__quantum__qis__m__body(ptr, ptr) #1"].assert_eq(&buf); +} + +// --- Function definition emission --- + +#[test] +fn simple_definition() { + let f = Function { + name: "ENTRYPOINT__main".into(), + return_type: Type::Integer(64), + params: vec![], + is_declaration: false, + attribute_group_refs: vec![0], + basic_blocks: vec![BasicBlock { + name: "block_0".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }], + }; + let mut buf = String::new(); + write_function(&mut buf, &f).expect("failed to write"); + expect![[r#" + define i64 @ENTRYPOINT__main() #0 { + block_0: + ret i64 0 + }"#]] + .assert_eq(&buf); +} + +// --- Instruction emission --- + +#[test] +fn ret_void() { + let mut buf = String::new(); + write_instruction(&mut buf, &Instruction::Ret(None)).expect("failed 
to write"); + expect![" ret void"].assert_eq(&buf); +} + +#[test] +fn ret_i64() { + let mut buf = String::new(); + write_instruction( + &mut buf, + &Instruction::Ret(Some(Operand::IntConst(Type::Integer(64), 0))), + ) + .expect("failed to write"); + expect![" ret i64 0"].assert_eq(&buf); +} + +#[test] +fn br_conditional() { + let instr = Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::LocalRef("var_0".into()), + true_dest: "block_1".into(), + false_dest: "block_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" br i1 %var_0, label %block_1, label %block_2"].assert_eq(&buf); +} + +#[test] +fn br_conditional_with_bool_const() { + let instr = Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 1), + true_dest: "block_1".into(), + false_dest: "block_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" br i1 true, label %block_1, label %block_2"].assert_eq(&buf); +} + +#[test] +fn jump_unconditional() { + let instr = Instruction::Jump { + dest: "block_1".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" br label %block_1"].assert_eq(&buf); +} + +#[test] +fn binop_add_i64() { + let instr = Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_2 = add i64 %var_0, %var_1"].assert_eq(&buf); +} + +#[test] +fn binop_sub_i64() { + let instr = Instruction::BinOp { + op: BinOpKind::Sub, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: "var_1".into(), + }; + let mut buf = String::new(); + 
write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_1 = sub i64 %var_0, 1"].assert_eq(&buf); +} + +#[test] +fn binop_mul_i64() { + let instr = Instruction::BinOp { + op: BinOpKind::Mul, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_2 = mul i64 %var_0, %var_1"].assert_eq(&buf); +} + +#[test] +fn binop_and_i1() { + let instr = Instruction::BinOp { + op: BinOpKind::And, + ty: Type::Integer(1), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_2 = and i1 %var_0, %var_1"].assert_eq(&buf); +} + +#[test] +fn binop_xor_not_i1() { + let instr = Instruction::BinOp { + op: BinOpKind::Xor, + ty: Type::Integer(1), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::IntConst(Type::Integer(1), 1), + result: "var_1".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_1 = xor i1 %var_0, true"].assert_eq(&buf); +} + +#[test] +fn binop_xor_not_i64() { + let instr = Instruction::BinOp { + op: BinOpKind::Xor, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::IntConst(Type::Integer(64), -1), + result: "var_1".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_1 = xor i64 %var_0, -1"].assert_eq(&buf); +} + +#[test] +fn binop_fadd_double() { + let instr = Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Double, + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, 
&instr).expect("failed to write"); + expect![" %var_2 = fadd double %var_0, %var_1"].assert_eq(&buf); +} + +#[test] +fn icmp_eq_i64() { + let instr = Instruction::ICmp { + pred: IntPredicate::Eq, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_2 = icmp eq i64 %var_0, %var_1"].assert_eq(&buf); +} + +#[test] +fn icmp_slt_i64() { + let instr = Instruction::ICmp { + pred: IntPredicate::Slt, + ty: Type::Integer(64), + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::IntConst(Type::Integer(64), 10), + result: "var_1".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_1 = icmp slt i64 %var_0, 10"].assert_eq(&buf); +} + +#[test] +fn fcmp_oeq_double() { + let instr = Instruction::FCmp { + pred: FloatPredicate::Oeq, + ty: Type::Double, + lhs: Operand::LocalRef("var_0".into()), + rhs: Operand::LocalRef("var_1".into()), + result: "var_2".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_2 = fcmp oeq double %var_0, %var_1"].assert_eq(&buf); +} + +#[test] +fn cast_sitofp() { + let instr = Instruction::Cast { + op: CastKind::Sitofp, + from_ty: Type::Integer(64), + to_ty: Type::Double, + value: Operand::LocalRef("var_0".into()), + result: "var_1".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_1 = sitofp i64 %var_0 to double"].assert_eq(&buf); +} + +#[test] +fn cast_fptosi() { + let instr = Instruction::Cast { + op: CastKind::Fptosi, + from_ty: Type::Double, + to_ty: Type::Integer(64), + value: Operand::LocalRef("var_0".into()), + result: "var_1".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to 
write"); + expect![" %var_1 = fptosi double %var_0 to i64"].assert_eq(&buf); +} + +#[test] +fn call_void_no_return() { + let instr = Instruction::Call { + return_ty: None, + callee: "__quantum__qis__h__body".into(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: None, + attr_refs: vec![], + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr))"].assert_eq(&buf); +} + +#[test] +fn call_void_two_args() { + let instr = Instruction::Call { + return_ty: None, + callee: "__quantum__qis__cx__body".into(), + args: vec![ + (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)), + (Type::Ptr, Operand::IntToPtr(1, Type::Ptr)), + ], + result: None, + attr_refs: vec![], + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr))"].assert_eq(&buf); +} + +#[test] +fn call_with_return() { + let instr = Instruction::Call { + return_ty: Some(Type::Integer(1)), + callee: "__quantum__rt__read_result".into(), + args: vec![(Type::Ptr, Operand::IntToPtr(0, Type::Ptr))], + result: Some("var_0".into()), + attr_refs: vec![], + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_0 = call i1 @__quantum__rt__read_result(ptr inttoptr (i64 0 to ptr))"] + .assert_eq(&buf); +} + +#[test] +fn call_with_attr_ref() { + let instr = Instruction::Call { + return_ty: None, + callee: "__quantum__qis__m__body".into(), + args: vec![ + (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)), + (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)), + ], + result: None, + attr_refs: vec![1], + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" call void @__quantum__qis__m__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to 
ptr)) #1"].assert_eq(&buf); +} + +#[test] +fn call_with_i64_arg() { + let instr = Instruction::Call { + return_ty: None, + callee: "__quantum__rt__array_record_output".into(), + args: vec![ + (Type::Integer(64), Operand::IntConst(Type::Integer(64), 2)), + (Type::Ptr, Operand::GlobalRef("0".into())), + ], + result: None, + attr_refs: vec![], + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" call void @__quantum__rt__array_record_output(i64 2, ptr @0)"].assert_eq(&buf); +} + +#[test] +fn call_result_record_output() { + let instr = Instruction::Call { + return_ty: None, + callee: "__quantum__rt__result_record_output".into(), + args: vec![ + (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)), + (Type::Ptr, Operand::GlobalRef("1".into())), + ], + result: None, + attr_refs: vec![], + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![ + " call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @1)" + ] + .assert_eq(&buf); +} + +#[test] +fn phi_i1() { + let instr = Instruction::Phi { + ty: Type::Integer(1), + incoming: vec![ + (Operand::IntConst(Type::Integer(1), 1), "block_0".into()), + (Operand::LocalRef("var_2".into()), "block_1".into()), + ], + result: "var_3".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_3 = phi i1 [true, %block_0], [%var_2, %block_1]"].assert_eq(&buf); +} + +#[test] +fn alloca_i1() { + let instr = Instruction::Alloca { + ty: Type::Integer(1), + result: "var_0".into(), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_0 = alloca i1"].assert_eq(&buf); +} + +#[test] +fn load_i1() { + let instr = Instruction::Load { + ty: Type::Integer(1), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + result: "var_1".into(), + }; + let mut buf = String::new(); + 
write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" %var_1 = load i1, ptr %var_0"].assert_eq(&buf); +} + +#[test] +fn store_i1_true() { + let instr = Instruction::Store { + ty: Type::Integer(1), + value: Operand::IntConst(Type::Integer(1), 1), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" store i1 true, ptr %var_0"].assert_eq(&buf); +} + +#[test] +fn store_half_constant() { + let instr = Instruction::Store { + ty: Type::Half, + value: Operand::float_const(Type::Half, 1.5), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" store half 1.5, ptr %var_0"].assert_eq(&buf); +} + +#[test] +fn store_float_constant() { + let instr = Instruction::Store { + ty: Type::Float, + value: Operand::float_const(Type::Float, 2.5), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" store float 2.5, ptr %var_0"].assert_eq(&buf); +} + +#[test] +fn store_double_constant() { + let instr = Instruction::Store { + ty: Type::Double, + value: Operand::float_const(Type::Double, 3.5), + ptr_ty: Type::Ptr, + ptr: Operand::LocalRef("var_0".into()), + }; + let mut buf = String::new(); + write_instruction(&mut buf, &instr).expect("failed to write"); + expect![" store double 3.5, ptr %var_0"].assert_eq(&buf); +} + +// --- Operand formatting --- + +#[test] +fn operand_inttoptr_named() { + let op = Operand::IntToPtr(0, Type::NamedPtr("Qubit".into())); + let mut buf = String::new(); + write_typed_operand(&mut buf, &op).expect("failed to write"); + expect!["%Qubit* inttoptr (i64 0 to %Qubit*)"].assert_eq(&buf); +} + +#[test] +fn operand_float_whole() { + let mut buf = String::new(); + write_float(&mut buf, 
3.0).expect("failed to write"); + expect!["3.0"].assert_eq(&buf); +} + +#[test] +fn operand_float_fractional() { + let mut buf = String::new(); + write_float(&mut buf, std::f64::consts::PI).expect("failed to write"); + expect!["3.141592653589793"].assert_eq(&buf); +} + +// --- Attribute group emission --- + +#[test] +fn attribute_group_entry_point() { + let ag = AttributeGroup { + id: 0, + attributes: vec![ + Attribute::StringAttr("entry_point".into()), + Attribute::StringAttr("output_labeling_schema".into()), + Attribute::KeyValue("qir_profiles".into(), "adaptive_profile".into()), + Attribute::KeyValue("required_num_qubits".into(), "2".into()), + Attribute::KeyValue("required_num_results".into(), "1".into()), + ], + }; + let mut buf = String::new(); + write_attribute_group(&mut buf, &ag).expect("failed to write"); + expect![r#"attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="1" }"#].assert_eq(&buf); +} + +#[test] +fn attribute_group_irreversible() { + let ag = AttributeGroup { + id: 1, + attributes: vec![Attribute::StringAttr("irreversible".into())], + }; + let mut buf = String::new(); + write_attribute_group(&mut buf, &ag).expect("failed to write"); + expect![r#"attributes #1 = { "irreversible" }"#].assert_eq(&buf); +} + +// --- Metadata emission --- + +#[test] +fn named_metadata_output() { + let nm = NamedMetadata { + name: "llvm.module.flags".into(), + node_refs: vec![0, 1, 2, 3, 4, 5, 6, 7], + }; + let mut buf = String::new(); + write_named_metadata(&mut buf, &nm).expect("failed to write"); + expect!["!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7}"].assert_eq(&buf); +} + +#[test] +fn metadata_node_with_int_and_string() { + let node = MetadataNode { + id: 0, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("qir_major_version".into()), + MetadataValue::Int(Type::Integer(32), 2), + ], + }; + let mut buf = String::new(); + 
write_metadata_node(&mut buf, &node).expect("failed to write"); + expect![r#"!0 = !{i32 1, !"qir_major_version", i32 2}"#].assert_eq(&buf); +} + +#[test] +fn metadata_node_with_bool_false() { + let node = MetadataNode { + id: 2, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("dynamic_qubit_management".into()), + MetadataValue::Int(Type::Integer(1), 0), + ], + }; + let mut buf = String::new(); + write_metadata_node(&mut buf, &node).expect("failed to write"); + expect![r#"!2 = !{i32 1, !"dynamic_qubit_management", i1 false}"#].assert_eq(&buf); +} + +#[test] +fn metadata_node_with_bool_true() { + let node = MetadataNode { + id: 7, + values: vec![ + MetadataValue::Int(Type::Integer(32), 1), + MetadataValue::String("arrays".into()), + MetadataValue::Int(Type::Integer(1), 1), + ], + }; + let mut buf = String::new(); + write_metadata_node(&mut buf, &node).expect("failed to write"); + expect![r#"!7 = !{i32 1, !"arrays", i1 true}"#].assert_eq(&buf); +} + +#[test] +fn metadata_node_with_i2() { + let node = MetadataNode { + id: 6, + values: vec![ + MetadataValue::Int(Type::Integer(32), 7), + MetadataValue::String("backwards_branching".into()), + MetadataValue::Int(Type::Integer(2), 3), + ], + }; + let mut buf = String::new(); + write_metadata_node(&mut buf, &node).expect("failed to write"); + expect![r#"!6 = !{i32 7, !"backwards_branching", i2 3}"#].assert_eq(&buf); +} + +#[test] +fn metadata_node_with_sublist() { + let node = MetadataNode { + id: 4, + values: vec![ + MetadataValue::Int(Type::Integer(32), 5), + MetadataValue::String("int_computations".into()), + MetadataValue::SubList(vec![MetadataValue::String("i64".into())]), + ], + }; + let mut buf = String::new(); + write_metadata_node(&mut buf, &node).expect("failed to write"); + expect![r#"!4 = !{i32 5, !"int_computations", !{!"i64"}}"#].assert_eq(&buf); +} + +// --- Full module emission --- + +#[test] +fn empty_module_output() { + let m = empty_module(); + let output = 
write_module_to_string(&m); + expect![""].assert_eq(&output); +} + +#[test] +fn bell_module_v2_output() { + let m = bell_module_v2(); + let output = write_module_to_string(&m); + expect![[r#" + @0 = internal constant [4 x i8] c"0_a\00" + @1 = internal constant [6 x i8] c"1_a0r\00" + @2 = internal constant [6 x i8] c"2_a1r\00" + + declare void @__quantum__qis__h__body(ptr) + + declare void @__quantum__qis__cx__body(ptr, ptr) + + declare void @__quantum__qis__m__body(ptr, ptr) #1 + + declare void @__quantum__rt__array_record_output(i64, ptr) + + declare void @__quantum__rt__result_record_output(ptr, ptr) + + define i64 @ENTRYPOINT__main() #0 { + block_0: + call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__cx__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 1 to ptr)) + call void @__quantum__qis__m__body(ptr inttoptr (i64 0 to ptr), ptr inttoptr (i64 0 to ptr)) + call void @__quantum__qis__m__body(ptr inttoptr (i64 1 to ptr), ptr inttoptr (i64 1 to ptr)) + call void @__quantum__rt__array_record_output(i64 2, ptr @0) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @1) + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 1 to ptr), ptr @2) + ret i64 0 + } + + attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="2" "required_num_results"="2" } + attributes #1 = { "irreversible" } + + ; module flags + + !llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6} + + !0 = !{i32 1, !"qir_major_version", i32 2} + !1 = !{i32 7, !"qir_minor_version", i32 1} + !2 = !{i32 1, !"dynamic_qubit_management", i1 false} + !3 = !{i32 1, !"dynamic_result_management", i1 false} + !4 = !{i32 5, !"int_computations", !{!"i64"}} + !5 = !{i32 7, !"backwards_branching", i2 3} + !6 = !{i32 1, !"arrays", i1 true} + "#]] + .assert_eq(&output); +} diff --git a/source/compiler/qsc_llvm/src/validation.rs b/source/compiler/qsc_llvm/src/validation.rs new 
file mode 100644 index 0000000000..0c3685fa41 --- /dev/null +++ b/source/compiler/qsc_llvm/src/validation.rs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +pub mod llvm; +pub mod qir; + +pub use llvm::{LlvmIrError, validate_ir}; +pub use qir::{ + Capabilities, DetectedProfile, QirProfileError, QirProfileValidation, validate_qir_profile, +}; diff --git a/source/compiler/qsc_llvm/src/validation/llvm.rs b/source/compiler/qsc_llvm/src/validation/llvm.rs new file mode 100644 index 0000000000..2532966ec9 --- /dev/null +++ b/source/compiler/qsc_llvm/src/validation/llvm.rs @@ -0,0 +1,1726 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#[cfg(test)] +mod tests; + +use crate::model::Type; +use crate::model::{BinOpKind, CastKind, Function, Instruction, MetadataValue, Module, Operand}; +use miette::Diagnostic; +use rustc_hash::{FxHashMap, FxHashSet}; +use thiserror::Error; + +#[derive(Clone, Debug, Diagnostic, Error, PartialEq, Eq)] +pub enum LlvmIrError { + // Structure (3 variants) + #[error("function `{function}` is not a declaration but has no basic blocks")] + #[diagnostic(code("Qsc.Llvm.IrValidator.MissingBasicBlocks"))] + MissingBasicBlocks { function: String }, + + #[error("declaration `{function}` should not have basic blocks")] + #[diagnostic(code("Qsc.Llvm.IrValidator.DeclarationHasBlocks"))] + DeclarationHasBlocks { function: String }, + + #[error("basic block `{block}` in function `{function}` is empty")] + #[diagnostic(code("Qsc.Llvm.IrValidator.EmptyBasicBlock"))] + EmptyBasicBlock { function: String, block: String }, + + // Terminators (2 variants) + #[error("basic block `{block}` in function `{function}` does not end with a terminator")] + #[diagnostic(code("Qsc.Llvm.IrValidator.MissingTerminator"))] + MissingTerminator { function: String, block: String }, + + #[error( + "terminator at index {instr_idx} in block `{block}` of function `{function}` is not the last instruction" + 
)] + #[diagnostic(code("Qsc.Llvm.IrValidator.MidBlockTerminator"))] + MidBlockTerminator { + function: String, + block: String, + instr_idx: usize, + }, + + // Type consistency (7 variants) + #[error("{instruction}: type mismatch — expected `{expected}`, found `{found}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.TypeMismatch"))] + TypeMismatch { + instruction: String, + expected: String, + found: String, + location: String, + }, + + #[error( + "branch condition in block `{block}` of function `{function}` is `{found_type}`, expected `i1`" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.BrCondNotI1"))] + BrCondNotI1 { + function: String, + block: String, + found_type: String, + }, + + #[error( + "select condition in block `{block}` of function `{function}` is `{found_type}`, expected `i1`" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.SelectCondNotI1"))] + SelectCondNotI1 { + function: String, + block: String, + found_type: String, + }, + + #[error( + "return type mismatch in function `{function}` — expected `{expected}`, found `{found}`" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.RetTypeMismatch"))] + RetTypeMismatch { + function: String, + expected: String, + found: String, + }, + + #[error("{instruction}: integer operation on non-integer type `{ty}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.IntOpOnNonInt"))] + IntOpOnNonInt { + instruction: String, + ty: String, + location: String, + }, + + #[error("{instruction}: floating-point operation on non-float type `{ty}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.FloatOpOnNonFloat"))] + FloatOpOnNonFloat { + instruction: String, + ty: String, + location: String, + }, + + #[error("{instruction}: expected pointer type, found `{found}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.PtrExpected"))] + PtrExpected { + instruction: String, + found: String, + location: String, + }, + + // Switch (2 variants) + #[error( + "switch in block `{block}` of function 
`{function}` declares non-integer type `{found_type}`" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.SwitchTypeNotInteger"))] + SwitchTypeNotInteger { + function: String, + block: String, + found_type: String, + }, + + #[error( + "switch in block `{block}` of function `{function}` has duplicate case value `{case_value}`" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.SwitchDuplicateCaseValue"))] + SwitchDuplicateCaseValue { + function: String, + block: String, + case_value: i64, + }, + + // References (4 variants) + #[error("undefined local reference `%{name}` in block `{block}` of function `{function}`")] + #[diagnostic(code("Qsc.Llvm.IrValidator.UndefinedLocalRef"))] + UndefinedLocalRef { + name: String, + function: String, + block: String, + }, + + #[error("undefined global reference `@{name}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.UndefinedGlobalRef"))] + UndefinedGlobalRef { name: String, location: String }, + + #[error("undefined callee `@{name}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.UndefinedCallee"))] + UndefinedCallee { name: String, location: String }, + + #[error( + "branch target `{target}` does not exist in function `{function}` (from block `{block}`)" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.InvalidBranchTarget"))] + InvalidBranchTarget { + target: String, + function: String, + block: String, + }, + + // SSA (1 variant) + #[error("duplicate definition of `%{name}` in function `{function}`")] + #[diagnostic(code("Qsc.Llvm.IrValidator.DuplicateDefinition"))] + DuplicateDefinition { name: String, function: String }, + + // Cast (1 variant) + #[error("invalid cast `{cast_kind}` from `{from_ty}` to `{to_ty}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.InvalidCast"))] + InvalidCast { + cast_kind: String, + from_ty: String, + to_ty: String, + location: String, + }, + + // Call (2 variants) + #[error("call to `{callee}`: expected {expected} arguments, found {found} in {location}")] + 
#[diagnostic(code("Qsc.Llvm.IrValidator.ArgCountMismatch"))] + ArgCountMismatch { + callee: String, + expected: usize, + found: usize, + location: String, + }, + + #[error( + "call to `{callee}`: argument {param_idx} type mismatch — expected `{expected}`, found `{found}` in {location}" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.ArgTypeMismatch"))] + ArgTypeMismatch { + callee: String, + param_idx: usize, + expected: String, + found: String, + location: String, + }, + + // PHI (6 variants) + #[error( + "PHI `%{result}` in block `{block}` of function `{function}` is not at the start of the block" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.PhiNotAtBlockStart"))] + PhiNotAtBlockStart { + function: String, + block: String, + result: String, + }, + + #[error("PHI `%{result}` in block `{block}` of function `{function}` has void type")] + #[diagnostic(code("Qsc.Llvm.IrValidator.PhiVoidType"))] + PhiVoidType { + function: String, + block: String, + result: String, + }, + + #[error("PHI `%{result}` in entry block of function `{function}`")] + #[diagnostic(code("Qsc.Llvm.IrValidator.PhiInEntryBlock"))] + PhiInEntryBlock { function: String, result: String }, + + #[error( + "PHI `%{result}` in block `{block}` of function `{function}`: expected {expected} incoming entries, found {found}" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.PhiPredCountMismatch"))] + PhiPredCountMismatch { + function: String, + block: String, + result: String, + expected: usize, + found: usize, + }, + + #[error( + "PHI `%{result}` in block `{block}` of function `{function}`: incoming block `{incoming_block}` is not a predecessor" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.PhiIncomingNotPredecessor"))] + PhiIncomingNotPredecessor { + function: String, + block: String, + result: String, + incoming_block: String, + }, + + #[error( + "PHI `%{result}` in block `{block}` of function `{function}`: duplicate incoming block `{dup_block}` with different values" + )] + 
#[diagnostic(code("Qsc.Llvm.IrValidator.PhiDuplicateBlockDiffValue"))] + PhiDuplicateBlockDiffValue { + function: String, + block: String, + result: String, + dup_block: String, + }, + + // Alloca (1 variant) + #[error( + "alloca `%{result}` in block `{block}` of function `{function}` uses unsized type `{ty}`" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.AllocaUnsizedType"))] + AllocaUnsizedType { + function: String, + block: String, + result: String, + ty: String, + }, + + // GEP (2 variants) + #[error("GEP in block `{block}` of function `{function}` has no indices")] + #[diagnostic(code("Qsc.Llvm.IrValidator.GepNoIndices"))] + GepNoIndices { function: String, block: String }, + + #[error("{instruction}: unsized pointee type `{pointee_ty}` in {location}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.UnsizedPointeeType"))] + UnsizedPointeeType { + instruction: String, + pointee_ty: String, + location: String, + }, + + // Typed pointer consistency (1 variant) + #[error( + "{instruction}: typed pointer inner type `{ptr_inner_ty}` does not match expected type `{expected_ty}` in {location}" + )] + #[diagnostic(code("Qsc.Llvm.IrValidator.TypedPtrMismatch"))] + TypedPtrMismatch { + instruction: String, + ptr_inner_ty: String, + expected_ty: String, + location: String, + }, + + // Dominance (1 variant) + #[error( + "use of `%{name}` in block `{use_block}` of function `{function}` is not dominated by its definition in block `{def_block}`" + )] + #[diagnostic( + code("Qsc.Llvm.IrValidator.UseNotDominatedByDef"), + help("ensure the definition of `%{name}` dominates all its uses") + )] + UseNotDominatedByDef { + name: String, + def_block: String, + use_block: String, + function: String, + }, + + // Attribute group integrity (2 variants) + #[error("duplicate attribute group ID #{id}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.DuplicateAttributeGroupId"))] + DuplicateAttributeGroupId { id: u32 }, + + #[error("function `{function}` references undefined attribute group 
#{ref_id}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.InvalidAttributeGroupRef"))] + InvalidAttributeGroupRef { function: String, ref_id: u32 }, + + // Metadata integrity (3 variants) + #[error("duplicate metadata node ID !{id}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.DuplicateMetadataNodeId"))] + DuplicateMetadataNodeId { id: u32 }, + + #[error("undefined metadata node reference !{ref_id} in {context}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.InvalidMetadataNodeRef"))] + InvalidMetadataNodeRef { context: String, ref_id: u32 }, + + #[error("metadata reference cycle detected at node !{node_id}")] + #[diagnostic(code("Qsc.Llvm.IrValidator.MetadataRefCycle"))] + MetadataRefCycle { node_id: u32 }, +} + +type CfgMap<'a> = FxHashMap<&'a str, Vec<&'a str>>; + +/// Validates general LLVM IR structural correctness. +/// Returns a list of all errors found. An empty list means the module is well-formed. +#[must_use] +pub fn validate_ir(module: &Module) -> Vec { + let mut errors = Vec::new(); + + errors.extend(validate_attribute_groups(module)); + errors.extend(validate_metadata(module)); + + for func in &module.functions { + errors.extend(validate_function_structure(func)); + } + + for func in module.functions.iter().filter(|f| !f.is_declaration) { + errors.extend(validate_terminators(func)); + + let (ssa_env, ssa_errors) = build_ssa_env(func); + errors.extend(ssa_errors); + errors.extend(validate_intra_block_ordering(func)); + + let (successors, predecessors) = build_cfg(func); + + errors.extend(validate_references(func, &ssa_env, module)); + errors.extend(validate_types(func, &ssa_env, module)); + errors.extend(validate_casts(func, &ssa_env)); + errors.extend(validate_switches(func, &ssa_env)); + errors.extend(validate_phis(func, &predecessors, &ssa_env)); + errors.extend(validate_allocas(func, module)); + errors.extend(validate_gep(func, &ssa_env)); + + if !func.basic_blocks.is_empty() { + let rpo = reverse_postorder(&func.basic_blocks[0].name, &successors); + let 
idom = compute_dominators(&func.basic_blocks[0].name, &rpo, &predecessors); + errors.extend(validate_dominance(func, &ssa_env, &idom)); + } + } + + errors +} + +// --------------------------------------------------------------------------- +// Helper functions +// --------------------------------------------------------------------------- + +fn is_terminator(instr: &Instruction) -> bool { + matches!( + instr, + Instruction::Ret(_) + | Instruction::Br { .. } + | Instruction::Jump { .. } + | Instruction::Switch { .. } + | Instruction::Unreachable + ) +} + +fn instruction_result(instr: &Instruction) -> Option<(String, Type)> { + match instr { + Instruction::BinOp { result, ty, .. } + | Instruction::Phi { result, ty, .. } + | Instruction::Load { result, ty, .. } + | Instruction::Select { result, ty, .. } => Some((result.clone(), ty.clone())), + Instruction::ICmp { result, .. } | Instruction::FCmp { result, .. } => { + Some((result.clone(), Type::Integer(1))) + } + Instruction::Cast { result, to_ty, .. } => Some((result.clone(), to_ty.clone())), + Instruction::Call { + result: Some(r), + return_ty: Some(ty), + .. + } => Some((r.clone(), ty.clone())), + Instruction::Alloca { result, .. } | Instruction::GetElementPtr { result, .. } => { + Some((result.clone(), Type::Ptr)) + } + _ => None, + } +} + +fn instruction_operands(instr: &Instruction) -> Vec<&Operand> { + match instr { + Instruction::Ret(Some(op)) => vec![op], + Instruction::Ret(None) + | Instruction::Unreachable + | Instruction::Jump { .. } + | Instruction::Alloca { .. } => vec![], + Instruction::Br { cond, .. } => vec![cond], + Instruction::BinOp { lhs, rhs, .. } + | Instruction::ICmp { lhs, rhs, .. } + | Instruction::FCmp { lhs, rhs, .. } => vec![lhs, rhs], + Instruction::Cast { value, .. } | Instruction::Switch { value, .. } => vec![value], + Instruction::Call { args, .. } => args.iter().map(|(_, op)| op).collect(), + Instruction::Phi { incoming, .. 
/// Returns true for any of the three pointer flavors (opaque `ptr`,
/// named pointer, typed pointer).
fn is_ptr_type(ty: &Type) -> bool {
    matches!(ty, Type::Ptr | Type::NamedPtr(_) | Type::TypedPtr(_))
}

/// Bit width of an integer or floating-point type; `None` for types
/// without a fixed scalar width. Used by bitcast validation.
fn bit_width(ty: &Type) -> Option<u32> {
    match ty {
        Type::Integer(n) => Some(*n),
        Type::Half => Some(16),
        Type::Float => Some(32),
        Type::Double => Some(64),
        _ => None,
    }
}

/// Structural type equality. Kept as a named helper so the equivalence
/// rule has a single place to change.
fn types_equivalent(a: &Type, b: &Type) -> bool {
    a == b
}

/// Whether `ty` may be allocated by `alloca`: void, label and function
/// types are unsized; a named type is unsized only when the module
/// declares an opaque struct with that name (unknown names are treated
/// as sized here).
fn is_sized_alloca_type(ty: &Type, module: &Module) -> bool {
    match ty {
        Type::Void | Type::Label | Type::Function(..) => false,
        Type::Named(name) => !module
            .struct_types
            .iter()
            .any(|struct_ty| struct_ty.name == *name && struct_ty.is_opaque),
        _ => true,
    }
}

/// The local SSA name an operand refers to, if it is a local reference.
fn local_ref_name(operand: &Operand) -> Option<&str> {
    match operand {
        Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => Some(name.as_str()),
        _ => None,
    }
}

/// Looser compatibility than `types_equivalent`: also accepts two pointer
/// types when at least one side is the opaque `ptr` type.
fn operand_types_compatible(actual: &Type, expected: &Type) -> bool {
    types_equivalent(actual, expected)
        || (is_ptr_type(actual)
            && is_ptr_type(expected)
            && (matches!(actual, Type::Ptr) || matches!(expected, Type::Ptr)))
}

// ---------------------------------------------------------------------------
// SSA environment builder
// ---------------------------------------------------------------------------

/// Collects every SSA definition in `func` (parameters plus instruction
/// results) into a name -> type map, reporting `DuplicateDefinition` for
/// names defined more than once. Unnamed parameters are keyed by their
/// positional index rendered as a string. Later definitions overwrite the
/// recorded type of a duplicate.
fn build_ssa_env(func: &Function) -> (FxHashMap<String, Type>, Vec<LlvmIrError>) {
    let mut env: FxHashMap<String, Type> = FxHashMap::default();
    let mut errors = Vec::new();

    for (i, param) in func.params.iter().enumerate() {
        let name = param.name.clone().unwrap_or_else(|| i.to_string());
        env.insert(name, param.ty.clone());
    }

    for bb in &func.basic_blocks {
        for instr in &bb.instructions {
            if let Some((name, ty)) = instruction_result(instr) {
                if env.contains_key(&name) {
                    errors.push(LlvmIrError::DuplicateDefinition {
                        name: name.clone(),
                        function: func.name.clone(),
                    });
                }
                env.insert(name, ty);
            }
        }
    }

    (env, errors)
}

// ---------------------------------------------------------------------------
// CFG builder
// ---------------------------------------------------------------------------

/// Builds successor and predecessor adjacency maps from each block's final
/// instruction. Every block gets an (possibly empty) entry in both maps;
/// only br/jump/switch contribute edges. Duplicate edges (e.g. a `br` with
/// both arms targeting the same block) are kept, which `validate_phis`
/// relies on for its multiset predecessor check.
fn build_cfg<'a>(func: &'a Function) -> (CfgMap<'a>, CfgMap<'a>) {
    let mut successors: CfgMap<'a> = FxHashMap::default();
    let mut predecessors: CfgMap<'a> = FxHashMap::default();

    for bb in &func.basic_blocks {
        successors.entry(bb.name.as_str()).or_default();
        predecessors.entry(bb.name.as_str()).or_default();
    }

    for bb in &func.basic_blocks {
        let targets: Vec<&str> = match bb.instructions.last() {
            Some(Instruction::Br {
                true_dest,
                false_dest,
                ..
            }) => vec![true_dest.as_str(), false_dest.as_str()],
            Some(Instruction::Jump { dest }) => vec![dest.as_str()],
            Some(Instruction::Switch {
                default_dest,
                cases,
                ..
            }) => {
                let mut t = vec![default_dest.as_str()];
                t.extend(cases.iter().map(|(_, d)| d.as_str()));
                t
            }
            _ => vec![],
        };
        for target in &targets {
            successors.entry(bb.name.as_str()).or_default().push(target);
            predecessors
                .entry(target)
                .or_default()
                .push(bb.name.as_str());
        }
    }

    (successors, predecessors)
}

// ---------------------------------------------------------------------------
// Operand type resolution
// ---------------------------------------------------------------------------

/// Resolves the type of an operand: local references are looked up in the
/// SSA environment (their inline annotation, if any, is ignored), constants
/// report their annotated type, and null/GEP/global references are opaque
/// pointers. Returns `None` for an unknown local name.
fn resolve_operand_type(operand: &Operand, locals: &FxHashMap<String, Type>) -> Option<Type> {
    match operand {
        Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => locals.get(name).cloned(),
        Operand::IntConst(ty, _) | Operand::IntToPtr(_, ty) | Operand::FloatConst(ty, _) => {
            Some(ty.clone())
        }
        Operand::NullPtr | Operand::GetElementPtr { .. } | Operand::GlobalRef(_) => Some(Type::Ptr),
    }
}
// ---------------------------------------------------------------------------
// Validation passes
// ---------------------------------------------------------------------------

/// Declarations must have no body; definitions must have at least one block.
fn validate_function_structure(func: &Function) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();
    if func.is_declaration {
        if !func.basic_blocks.is_empty() {
            errors.push(LlvmIrError::DeclarationHasBlocks {
                function: func.name.clone(),
            });
        }
    } else if func.basic_blocks.is_empty() {
        errors.push(LlvmIrError::MissingBasicBlocks {
            function: func.name.clone(),
        });
    }
    errors
}

/// Every block must be non-empty, end in a terminator, and contain no
/// terminator before its last instruction. An empty block short-circuits
/// the other two checks for that block.
fn validate_terminators(func: &Function) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();
    for bb in &func.basic_blocks {
        if bb.instructions.is_empty() {
            errors.push(LlvmIrError::EmptyBasicBlock {
                function: func.name.clone(),
                block: bb.name.clone(),
            });
            continue;
        }
        if !is_terminator(bb.instructions.last().expect("non-empty")) {
            errors.push(LlvmIrError::MissingTerminator {
                function: func.name.clone(),
                block: bb.name.clone(),
            });
        }
        for (idx, instr) in bb.instructions.iter().enumerate() {
            if idx < bb.instructions.len() - 1 && is_terminator(instr) {
                errors.push(LlvmIrError::MidBlockTerminator {
                    function: func.name.clone(),
                    block: bb.name.clone(),
                    instr_idx: idx,
                });
            }
        }
    }
    errors
}

/// Within a single block, a value defined in that block may not be used
/// before its defining instruction. Phi nodes are exempt from the operand
/// check (their operands flow in from predecessors) but still count as
/// definitions. Cross-block ordering is handled by `validate_dominance`.
fn validate_intra_block_ordering(func: &Function) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();
    for bb in &func.basic_blocks {
        // All names defined anywhere in this block.
        let block_defs: FxHashSet<String> = bb
            .instructions
            .iter()
            .filter_map(|i| instruction_result(i).map(|(name, _)| name))
            .collect();

        let mut defined_so_far: FxHashSet<String> = FxHashSet::default();
        for instr in &bb.instructions {
            if matches!(instr, Instruction::Phi { .. }) {
                if let Some((name, _)) = instruction_result(instr) {
                    defined_so_far.insert(name);
                }
                continue;
            }

            for op in instruction_operands(instr) {
                if let Some(name) = local_ref_name(op)
                    && block_defs.contains(name)
                    && !defined_so_far.contains(name)
                {
                    errors.push(LlvmIrError::UndefinedLocalRef {
                        name: name.to_string(),
                        function: func.name.clone(),
                        block: bb.name.clone(),
                    });
                }
            }
            if let Some((name, _)) = instruction_result(instr) {
                defined_so_far.insert(name);
            }
        }
    }
    errors
}

/// Checks that every name an instruction mentions resolves: local operands
/// against the SSA environment, global operands against module globals or
/// functions, branch/switch targets against this function's block names,
/// and call callees against module functions.
fn validate_references(
    func: &Function,
    ssa_env: &FxHashMap<String, Type>,
    module: &Module,
) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();
    let block_names: FxHashSet<&str> = func
        .basic_blocks
        .iter()
        .map(|bb| bb.name.as_str())
        .collect();
    let func_names: FxHashSet<&str> = module.functions.iter().map(|f| f.name.as_str()).collect();
    let global_names: FxHashSet<&str> = module.globals.iter().map(|g| g.name.as_str()).collect();

    for bb in &func.basic_blocks {
        let location = format!("function `{}`, block `{}`", func.name, bb.name);

        for instr in &bb.instructions {
            for operand in instruction_operands(instr) {
                match operand {
                    Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => {
                        if !ssa_env.contains_key(name) {
                            errors.push(LlvmIrError::UndefinedLocalRef {
                                name: name.clone(),
                                function: func.name.clone(),
                                block: bb.name.clone(),
                            });
                        }
                    }
                    Operand::GlobalRef(name) => {
                        // A global reference may name either a global
                        // variable or a function.
                        if !global_names.contains(name.as_str())
                            && !func_names.contains(name.as_str())
                        {
                            errors.push(LlvmIrError::UndefinedGlobalRef {
                                name: name.clone(),
                                location: location.clone(),
                            });
                        }
                    }
                    _ => {}
                }
            }

            match instr {
                Instruction::Br {
                    true_dest,
                    false_dest,
                    ..
                } => {
                    for target in [true_dest, false_dest] {
                        if !block_names.contains(target.as_str()) {
                            errors.push(LlvmIrError::InvalidBranchTarget {
                                target: target.clone(),
                                function: func.name.clone(),
                                block: bb.name.clone(),
                            });
                        }
                    }
                }
                Instruction::Jump { dest } => {
                    if !block_names.contains(dest.as_str()) {
                        errors.push(LlvmIrError::InvalidBranchTarget {
                            target: dest.clone(),
                            function: func.name.clone(),
                            block: bb.name.clone(),
                        });
                    }
                }
                Instruction::Switch {
                    default_dest,
                    cases,
                    ..
                } => {
                    if !block_names.contains(default_dest.as_str()) {
                        errors.push(LlvmIrError::InvalidBranchTarget {
                            target: default_dest.clone(),
                            function: func.name.clone(),
                            block: bb.name.clone(),
                        });
                    }
                    for (_, dest) in cases {
                        if !block_names.contains(dest.as_str()) {
                            errors.push(LlvmIrError::InvalidBranchTarget {
                                target: dest.clone(),
                                function: func.name.clone(),
                                block: bb.name.clone(),
                            });
                        }
                    }
                }
                _ => {}
            }

            if let Instruction::Call { callee, .. } = instr
                && !func_names.contains(callee.as_str())
            {
                errors.push(LlvmIrError::UndefinedCallee {
                    name: callee.clone(),
                    location: location.clone(),
                });
            }
        }
    }
    errors
}
/// Whether a binary opcode operates on integers (everything except the
/// floating-point arithmetic ops).
fn is_int_binop(op: &BinOpKind) -> bool {
    matches!(
        op,
        BinOpKind::Add
            | BinOpKind::Sub
            | BinOpKind::Mul
            | BinOpKind::Sdiv
            | BinOpKind::Srem
            | BinOpKind::Shl
            | BinOpKind::Ashr
            | BinOpKind::And
            | BinOpKind::Or
            | BinOpKind::Xor
            | BinOpKind::Udiv
            | BinOpKind::Urem
            | BinOpKind::Lshr
    )
}

/// Per-instruction type checking: operand types must agree with each
/// instruction's annotated type, conditions must be `i1`, returns must
/// match the function signature, memory ops need pointer operands, and
/// calls are checked against the callee's signature (return type, arity,
/// and per-argument types). Unknown callees are skipped here — they are
/// reported by `validate_references`.
#[allow(clippy::too_many_lines)]
fn validate_types(
    func: &Function,
    ssa_env: &FxHashMap<String, Type>,
    module: &Module,
) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();

    for bb in &func.basic_blocks {
        let location = format!("function `{}`, block `{}`", func.name, bb.name);

        for instr in &bb.instructions {
            match instr {
                Instruction::BinOp {
                    op, ty, lhs, rhs, ..
                } => {
                    let instr_name = format!("{op:?}");
                    // The opcode family fixes whether the annotated type
                    // must be integer or floating point.
                    if is_int_binop(op) {
                        if !matches!(ty, Type::Integer(_)) {
                            errors.push(LlvmIrError::IntOpOnNonInt {
                                instruction: instr_name.clone(),
                                ty: ty.to_string(),
                                location: location.clone(),
                            });
                        }
                    } else if !ty.is_floating_point() {
                        errors.push(LlvmIrError::FloatOpOnNonFloat {
                            instruction: instr_name.clone(),
                            ty: ty.to_string(),
                            location: location.clone(),
                        });
                    }
                    for (side, operand) in [("lhs", lhs), ("rhs", rhs)] {
                        if let Some(resolved) = resolve_operand_type(operand, ssa_env)
                            && !types_equivalent(&resolved, ty)
                        {
                            errors.push(LlvmIrError::TypeMismatch {
                                instruction: instr_name.clone(),
                                expected: ty.to_string(),
                                found: resolved.to_string(),
                                location: format!("{location}, {side}"),
                            });
                        }
                    }
                }

                Instruction::ICmp { ty, lhs, rhs, .. } => {
                    // icmp compares integers or pointers.
                    if !matches!(ty, Type::Integer(_)) && !is_ptr_type(ty) {
                        errors.push(LlvmIrError::TypeMismatch {
                            instruction: "ICmp".to_string(),
                            expected: "integer or pointer type".to_string(),
                            found: ty.to_string(),
                            location: location.clone(),
                        });
                    }
                    for (side, operand) in [("lhs", lhs), ("rhs", rhs)] {
                        if let Some(resolved) = resolve_operand_type(operand, ssa_env)
                            && !types_equivalent(&resolved, ty)
                        {
                            errors.push(LlvmIrError::TypeMismatch {
                                instruction: "ICmp".to_string(),
                                expected: ty.to_string(),
                                found: resolved.to_string(),
                                location: format!("{location}, {side}"),
                            });
                        }
                    }
                }

                Instruction::FCmp { ty, lhs, rhs, .. } => {
                    if !ty.is_floating_point() {
                        errors.push(LlvmIrError::FloatOpOnNonFloat {
                            instruction: "FCmp".to_string(),
                            ty: ty.to_string(),
                            location: location.clone(),
                        });
                    }
                    for (side, operand) in [("lhs", lhs), ("rhs", rhs)] {
                        if let Some(resolved) = resolve_operand_type(operand, ssa_env)
                            && !types_equivalent(&resolved, ty)
                        {
                            errors.push(LlvmIrError::TypeMismatch {
                                instruction: "FCmp".to_string(),
                                expected: ty.to_string(),
                                found: resolved.to_string(),
                                location: format!("{location}, {side}"),
                            });
                        }
                    }
                }

                Instruction::Br { cond_ty, cond, .. } => {
                    // Both the annotated condition type and the resolved
                    // operand type must be i1; each gets its own error.
                    if *cond_ty != Type::Integer(1) {
                        errors.push(LlvmIrError::BrCondNotI1 {
                            function: func.name.clone(),
                            block: bb.name.clone(),
                            found_type: cond_ty.to_string(),
                        });
                    }
                    if let Some(resolved) = resolve_operand_type(cond, ssa_env)
                        && resolved != Type::Integer(1)
                    {
                        errors.push(LlvmIrError::BrCondNotI1 {
                            function: func.name.clone(),
                            block: bb.name.clone(),
                            found_type: resolved.to_string(),
                        });
                    }
                }

                Instruction::Select {
                    cond,
                    true_val,
                    false_val,
                    ty,
                    ..
                } => {
                    if let Some(cond_ty) = resolve_operand_type(cond, ssa_env)
                        && cond_ty != Type::Integer(1)
                    {
                        errors.push(LlvmIrError::SelectCondNotI1 {
                            function: func.name.clone(),
                            block: bb.name.clone(),
                            found_type: cond_ty.to_string(),
                        });
                    }
                    for (side, operand) in [("true_val", true_val), ("false_val", false_val)] {
                        if let Some(resolved) = resolve_operand_type(operand, ssa_env)
                            && !types_equivalent(&resolved, ty)
                        {
                            errors.push(LlvmIrError::TypeMismatch {
                                instruction: "Select".to_string(),
                                expected: ty.to_string(),
                                found: resolved.to_string(),
                                location: format!("{location}, {side}"),
                            });
                        }
                    }
                }

                Instruction::Ret(Some(operand)) => {
                    if let Some(resolved) = resolve_operand_type(operand, ssa_env)
                        && !types_equivalent(&resolved, &func.return_type)
                    {
                        errors.push(LlvmIrError::RetTypeMismatch {
                            function: func.name.clone(),
                            expected: func.return_type.to_string(),
                            found: resolved.to_string(),
                        });
                    }
                }
                Instruction::Ret(None) => {
                    // A bare `ret` is only legal in a void function.
                    if func.return_type != Type::Void {
                        errors.push(LlvmIrError::RetTypeMismatch {
                            function: func.name.clone(),
                            expected: func.return_type.to_string(),
                            found: "void".to_string(),
                        });
                    }
                }

                Instruction::Store {
                    ty,
                    value,
                    ptr_ty,
                    ptr,
                } => {
                    if !is_ptr_type(ptr_ty) {
                        errors.push(LlvmIrError::PtrExpected {
                            instruction: "Store".to_string(),
                            found: ptr_ty.to_string(),
                            location: location.clone(),
                        });
                    }
                    if let Some(resolved) = resolve_operand_type(value, ssa_env)
                        && !types_equivalent(&resolved, ty)
                    {
                        errors.push(LlvmIrError::TypeMismatch {
                            instruction: "Store".to_string(),
                            expected: ty.to_string(),
                            found: resolved.to_string(),
                            location: location.clone(),
                        });
                    }
                    // Pointer operand check is gated on ptr_ty being a
                    // pointer at all, to avoid cascading errors.
                    if is_ptr_type(ptr_ty)
                        && let Some(resolved_ptr_ty) = resolve_operand_type(ptr, ssa_env)
                        && !operand_types_compatible(&resolved_ptr_ty, ptr_ty)
                    {
                        errors.push(LlvmIrError::TypeMismatch {
                            instruction: "Store".to_string(),
                            expected: ptr_ty.to_string(),
                            found: resolved_ptr_ty.to_string(),
                            location: format!("{location}, ptr"),
                        });
                    }
                    // A typed pointer's pointee must match the stored type.
                    if let Type::TypedPtr(inner) = ptr_ty
                        && !types_equivalent(inner, ty)
                    {
                        errors.push(LlvmIrError::TypedPtrMismatch {
                            instruction: "Store".to_string(),
                            ptr_inner_ty: inner.to_string(),
                            expected_ty: ty.to_string(),
                            location: location.clone(),
                        });
                    }
                }

                Instruction::Load {
                    ty, ptr_ty, ptr, ..
                } => {
                    if !is_ptr_type(ptr_ty) {
                        errors.push(LlvmIrError::PtrExpected {
                            instruction: "Load".to_string(),
                            found: ptr_ty.to_string(),
                            location: location.clone(),
                        });
                    }
                    if is_ptr_type(ptr_ty)
                        && let Some(resolved_ptr_ty) = resolve_operand_type(ptr, ssa_env)
                        && !operand_types_compatible(&resolved_ptr_ty, ptr_ty)
                    {
                        errors.push(LlvmIrError::TypeMismatch {
                            instruction: "Load".to_string(),
                            expected: ptr_ty.to_string(),
                            found: resolved_ptr_ty.to_string(),
                            location: format!("{location}, ptr"),
                        });
                    }
                    if let Type::TypedPtr(inner) = ptr_ty
                        && !types_equivalent(inner, ty)
                    {
                        errors.push(LlvmIrError::TypedPtrMismatch {
                            instruction: "Load".to_string(),
                            ptr_inner_ty: inner.to_string(),
                            expected_ty: ty.to_string(),
                            location: location.clone(),
                        });
                    }
                }

                Instruction::Call {
                    callee,
                    args,
                    return_ty,
                    ..
                } => {
                    if let Some(target_func) = module.functions.iter().find(|f| f.name == *callee) {
                        // Return type: `None` means the call site expects void.
                        match (return_ty, &target_func.return_type) {
                            (None, Type::Void) => {}
                            (Some(found), expected) if types_equivalent(found, expected) => {}
                            (Some(found), expected) => {
                                errors.push(LlvmIrError::TypeMismatch {
                                    instruction: format!("Call @{callee}"),
                                    expected: expected.to_string(),
                                    found: found.to_string(),
                                    location: format!("{location}, return type"),
                                });
                            }
                            (None, expected) => {
                                errors.push(LlvmIrError::TypeMismatch {
                                    instruction: format!("Call @{callee}"),
                                    expected: expected.to_string(),
                                    found: "void".to_string(),
                                    location: format!("{location}, return type"),
                                });
                            }
                        }

                        // Per-argument checks only when arity matches; an
                        // arity mismatch is reported instead.
                        if args.len() == target_func.params.len() {
                            for (i, ((arg_ty, operand), param)) in
                                args.iter().zip(&target_func.params).enumerate()
                            {
                                // Operand value vs. call-site annotation.
                                if let Some(resolved) = resolve_operand_type(operand, ssa_env)
                                    && !operand_types_compatible(&resolved, arg_ty)
                                {
                                    errors.push(LlvmIrError::TypeMismatch {
                                        instruction: format!("Call @{callee}"),
                                        expected: arg_ty.to_string(),
                                        found: resolved.to_string(),
                                        location: format!("{location}, arg {i}"),
                                    });
                                }
                                // Call-site annotation vs. callee signature.
                                if !types_equivalent(arg_ty, &param.ty) {
                                    errors.push(LlvmIrError::ArgTypeMismatch {
                                        callee: callee.clone(),
                                        param_idx: i,
                                        expected: param.ty.to_string(),
                                        found: arg_ty.to_string(),
                                        location: location.clone(),
                                    });
                                }
                            }
                        } else {
                            errors.push(LlvmIrError::ArgCountMismatch {
                                callee: callee.clone(),
                                expected: target_func.params.len(),
                                found: args.len(),
                                location: location.clone(),
                            });
                        }
                    }
                }

                _ => {}
            }
        }
    }
    errors
}

/// Legality of each cast opcode: zext/sext must widen integers, trunc must
/// narrow, fpext/fptrunc likewise for floats, int<->ptr casts need the
/// matching operand kinds, and bitcast is ptr<->ptr or between non-pointer
/// scalars of identical bit width.
fn is_valid_cast(kind: &CastKind, from: &Type, to: &Type) -> bool {
    match kind {
        CastKind::Zext | CastKind::Sext => matches!(
            (from, to),
            (Type::Integer(fw), Type::Integer(tw)) if fw < tw
        ),
        CastKind::Trunc => matches!(
            (from, to),
            (Type::Integer(fw), Type::Integer(tw)) if fw > tw
        ),
        CastKind::Sitofp => matches!(from, Type::Integer(_)) && to.is_floating_point(),
        CastKind::Fptosi => from.is_floating_point() && matches!(to, Type::Integer(_)),
        CastKind::FpExt => matches!(
            (from.floating_point_bit_width(), to.floating_point_bit_width()),
            (Some(from_width), Some(to_width)) if from_width < to_width
        ),
        CastKind::FpTrunc => matches!(
            (from.floating_point_bit_width(), to.floating_point_bit_width()),
            (Some(from_width), Some(to_width)) if from_width > to_width
        ),
        CastKind::IntToPtr => matches!(from, Type::Integer(_)) && is_ptr_type(to),
        CastKind::PtrToInt => is_ptr_type(from) && matches!(to, Type::Integer(_)),
        CastKind::Bitcast => match (is_ptr_type(from), is_ptr_type(to)) {
            (true, true) => true,
            (false, false) => bit_width(from) == bit_width(to) && bit_width(from).is_some(),
            _ => false,
        },
    }
}
/// Checks every cast: the operand's resolved type must match the declared
/// source type, and the (opcode, from, to) triple must be legal per
/// `is_valid_cast`.
fn validate_casts(func: &Function, ssa_env: &FxHashMap<String, Type>) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();
    for bb in &func.basic_blocks {
        let location = format!("function `{}`, block `{}`", func.name, bb.name);
        for instr in &bb.instructions {
            if let Instruction::Cast {
                op,
                from_ty,
                to_ty,
                value,
                ..
            } = instr
            {
                if let Some(resolved) = resolve_operand_type(value, ssa_env)
                    && !types_equivalent(&resolved, from_ty)
                {
                    errors.push(LlvmIrError::TypeMismatch {
                        instruction: "Cast".to_string(),
                        expected: from_ty.to_string(),
                        found: resolved.to_string(),
                        location: location.clone(),
                    });
                }
                if !is_valid_cast(op, from_ty, to_ty) {
                    errors.push(LlvmIrError::InvalidCast {
                        cast_kind: format!("{op:?}"),
                        from_ty: from_ty.to_string(),
                        to_ty: to_ty.to_string(),
                        location: location.clone(),
                    });
                }
            }
        }
    }
    errors
}

/// Switch checks: the scrutinee type must be an integer, the operand must
/// resolve to that type, and no case value may repeat.
fn validate_switches(func: &Function, ssa_env: &FxHashMap<String, Type>) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();

    for bb in &func.basic_blocks {
        let location = format!("function `{}`, block `{}`", func.name, bb.name);

        for instr in &bb.instructions {
            if let Instruction::Switch {
                ty, value, cases, ..
            } = instr
            {
                if !matches!(ty, Type::Integer(_)) {
                    errors.push(LlvmIrError::SwitchTypeNotInteger {
                        function: func.name.clone(),
                        block: bb.name.clone(),
                        found_type: ty.to_string(),
                    });
                }

                if let Some(resolved) = resolve_operand_type(value, ssa_env)
                    && !types_equivalent(&resolved, ty)
                {
                    errors.push(LlvmIrError::TypeMismatch {
                        instruction: "Switch".to_string(),
                        expected: ty.to_string(),
                        found: resolved.to_string(),
                        location: location.clone(),
                    });
                }

                let mut seen_case_values = FxHashSet::default();
                for (case_value, _) in cases {
                    if !seen_case_values.insert(*case_value) {
                        errors.push(LlvmIrError::SwitchDuplicateCaseValue {
                            function: func.name.clone(),
                            block: bb.name.clone(),
                            case_value: *case_value,
                        });
                    }
                }
            }
        }
    }

    errors
}

/// Enforces the seven phi-node well-formedness rules (numbered inline).
/// Incoming labels are matched against the block's predecessors as an
/// exact multiset, so a duplicated CFG edge (e.g. a conditional branch
/// with both arms to the same block) requires a matching duplicate entry.
fn validate_phis(
    func: &Function,
    predecessors: &CfgMap<'_>,
    ssa_env: &FxHashMap<String, Type>,
) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();

    for (bb_idx, bb) in func.basic_blocks.iter().enumerate() {
        let mut seen_non_phi = false;

        for instr in &bb.instructions {
            if let Instruction::Phi {
                ty,
                incoming,
                result,
            } = instr
            {
                // Rule 7: No PHI in entry block
                if bb_idx == 0 {
                    errors.push(LlvmIrError::PhiInEntryBlock {
                        function: func.name.clone(),
                        result: result.clone(),
                    });
                }

                // Rule 1: PHIs must be grouped at start of block
                if seen_non_phi {
                    errors.push(LlvmIrError::PhiNotAtBlockStart {
                        function: func.name.clone(),
                        block: bb.name.clone(),
                        result: result.clone(),
                    });
                }

                // Rule 2: PHI type must not be Void
                if *ty == Type::Void {
                    errors.push(LlvmIrError::PhiVoidType {
                        function: func.name.clone(),
                        block: bb.name.clone(),
                        result: result.clone(),
                    });
                }

                // Rule 3: Incoming value types must match PHI type
                for (operand, _label) in incoming {
                    if let Some(resolved) = resolve_operand_type(operand, ssa_env)
                        && !types_equivalent(&resolved, ty)
                    {
                        errors.push(LlvmIrError::TypeMismatch {
                            instruction: format!("Phi %{result}"),
                            expected: ty.to_string(),
                            found: resolved.to_string(),
                            location: format!("function `{}`, block `{}`", func.name, bb.name),
                        });
                    }
                }

                // Rule 4: Incoming count == predecessor count
                let preds = predecessors.get(bb.name.as_str());
                let pred_count = preds.map_or(0, Vec::len);
                if incoming.len() != pred_count {
                    errors.push(LlvmIrError::PhiPredCountMismatch {
                        function: func.name.clone(),
                        block: bb.name.clone(),
                        result: result.clone(),
                        expected: pred_count,
                        found: incoming.len(),
                    });
                }

                // Rule 5: Incoming labels must match predecessors as an exact multiset.
                let mut pred_counts: FxHashMap<&str, usize> = FxHashMap::default();
                if let Some(preds) = preds {
                    for &pred in preds {
                        *pred_counts.entry(pred).or_default() += 1;
                    }
                }
                for (_operand, label) in incoming {
                    match pred_counts.get_mut(label.as_str()) {
                        Some(remaining) if *remaining > 0 => *remaining -= 1,
                        _ => {
                            errors.push(LlvmIrError::PhiIncomingNotPredecessor {
                                function: func.name.clone(),
                                block: bb.name.clone(),
                                result: result.clone(),
                                incoming_block: label.clone(),
                            });
                        }
                    }
                }

                // Rule 6: Duplicate incoming blocks must carry identical values
                let mut seen_labels: FxHashMap<&str, &Operand> = FxHashMap::default();
                for (operand, label) in incoming {
                    if let Some(prev_op) = seen_labels.get(label.as_str()) {
                        if *prev_op != operand {
                            errors.push(LlvmIrError::PhiDuplicateBlockDiffValue {
                                function: func.name.clone(),
                                block: bb.name.clone(),
                                result: result.clone(),
                                dup_block: label.clone(),
                            });
                        }
                    } else {
                        seen_labels.insert(label.as_str(), operand);
                    }
                }
            } else {
                seen_non_phi = true;
            }
        }
    }
    errors
}

/// `alloca` may only allocate sized types (see `is_sized_alloca_type`).
fn validate_allocas(func: &Function, module: &Module) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();

    for bb in &func.basic_blocks {
        for instr in &bb.instructions {
            if let Instruction::Alloca { ty, result } = instr
                && !is_sized_alloca_type(ty, module)
            {
                errors.push(LlvmIrError::AllocaUnsizedType {
                    function: func.name.clone(),
                    block: bb.name.clone(),
                    result: result.clone(),
                    ty: ty.to_string(),
                });
            }
        }
    }

    errors
}
func.name.clone(), + block: bb.name.clone(), + result: result.clone(), + ty: ty.to_string(), + }); + } + } + } + + errors +} + +fn validate_gep(func: &Function, ssa_env: &FxHashMap) -> Vec { + let mut errors = Vec::new(); + for bb in &func.basic_blocks { + let location = format!("function `{}`, block `{}`", func.name, bb.name); + for instr in &bb.instructions { + if let Instruction::GetElementPtr { + pointee_ty, + ptr_ty, + indices, + .. + } = instr + { + if !is_ptr_type(ptr_ty) { + errors.push(LlvmIrError::PtrExpected { + instruction: "GetElementPtr".to_string(), + found: ptr_ty.to_string(), + location: location.clone(), + }); + } + + if matches!(pointee_ty, Type::Void | Type::Function(..)) { + errors.push(LlvmIrError::UnsizedPointeeType { + instruction: "GetElementPtr".to_string(), + pointee_ty: pointee_ty.to_string(), + location: location.clone(), + }); + } + + if indices.is_empty() { + errors.push(LlvmIrError::GepNoIndices { + function: func.name.clone(), + block: bb.name.clone(), + }); + } + + for (index, operand) in indices.iter().enumerate() { + if let Some(resolved) = resolve_operand_type(operand, ssa_env) + && !matches!(resolved, Type::Integer(_)) + { + errors.push(LlvmIrError::TypeMismatch { + instruction: "GetElementPtr".to_string(), + expected: "integer type".to_string(), + found: resolved.to_string(), + location: format!("{location}, index {index}"), + }); + } + } + + if let Type::TypedPtr(inner) = ptr_ty + && !types_equivalent(inner, pointee_ty) + { + errors.push(LlvmIrError::TypedPtrMismatch { + instruction: "GetElementPtr".to_string(), + ptr_inner_ty: inner.to_string(), + expected_ty: pointee_ty.to_string(), + location: location.clone(), + }); + } + } + } + } + errors +} + +// --------------------------------------------------------------------------- +// Dominance analysis +// --------------------------------------------------------------------------- + +fn reverse_postorder<'a>( + entry: &'a str, + successors: &FxHashMap<&'a str, Vec<&'a str>>, +) 
-> Vec<&'a str> { + fn dfs<'a>( + block: &'a str, + successors: &FxHashMap<&'a str, Vec<&'a str>>, + visited: &mut FxHashSet<&'a str>, + postorder: &mut Vec<&'a str>, + ) { + if !visited.insert(block) { + return; + } + if let Some(succs) = successors.get(block) { + for &s in succs { + dfs(s, successors, visited, postorder); + } + } + postorder.push(block); + } + + let mut visited = FxHashSet::default(); + let mut postorder = Vec::new(); + dfs(entry, successors, &mut visited, &mut postorder); + postorder.reverse(); + postorder +} + +fn compute_dominators<'a>( + entry: &'a str, + rpo: &[&'a str], + predecessors: &FxHashMap<&'a str, Vec<&'a str>>, +) -> FxHashMap<&'a str, &'a str> { + let rpo_number: FxHashMap<&str, usize> = rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect(); + let mut idom: FxHashMap<&str, &str> = FxHashMap::default(); + + let mut changed = true; + while changed { + changed = false; + for &block in rpo { + if block == entry { + continue; + } + let Some(preds) = predecessors.get(block) else { + continue; + }; + let mut new_idom = None; + for &pred in preds { + if pred == entry || idom.contains_key(pred) { + new_idom = Some(pred); + break; + } + } + let Some(mut new_idom_val) = new_idom else { + continue; + }; + for &pred in preds { + if pred == new_idom_val { + continue; + } + if pred == entry || idom.contains_key(pred) { + new_idom_val = intersect(pred, new_idom_val, &idom, &rpo_number, entry); + } + } + if idom.get(block) != Some(&new_idom_val) { + idom.insert(block, new_idom_val); + changed = true; + } + } + } + idom +} + +fn intersect<'a>( + mut b1: &'a str, + mut b2: &'a str, + idom: &FxHashMap<&'a str, &'a str>, + rpo_number: &FxHashMap<&str, usize>, + entry: &'a str, +) -> &'a str { + while b1 != b2 { + while rpo_number.get(b1).copied().unwrap_or(0) > rpo_number.get(b2).copied().unwrap_or(0) { + b1 = if b1 == entry { + entry + } else { + idom.get(b1).copied().unwrap_or(entry) + }; + } + while rpo_number.get(b2).copied().unwrap_or(0) > 
rpo_number.get(b1).copied().unwrap_or(0) { + b2 = if b2 == entry { + entry + } else { + idom.get(b2).copied().unwrap_or(entry) + }; + } + } + b1 +} + +fn dominates(def: &str, use_block: &str, idom: &FxHashMap<&str, &str>, entry: &str) -> bool { + if def == use_block { + return true; + } + let mut current = use_block; + while current != entry { + current = match idom.get(current) { + Some(&dom) => dom, + None => return false, + }; + if current == def { + return true; + } + } + def == entry +} + +fn validate_dominance( + func: &Function, + ssa_env: &FxHashMap, + idom: &FxHashMap<&str, &str>, +) -> Vec { + let mut errors = Vec::new(); + let entry = func.basic_blocks[0].name.as_str(); + + // Build def_block map: SSA name → block where it's defined + let mut def_block: FxHashMap = FxHashMap::default(); + for (i, param) in func.params.iter().enumerate() { + let name = param.name.clone().unwrap_or_else(|| i.to_string()); + def_block.insert(name, entry.to_string()); + } + for bb in &func.basic_blocks { + for instr in &bb.instructions { + if let Some((name, _)) = instruction_result(instr) { + def_block.insert(name, bb.name.clone()); + } + } + } + + for bb in &func.basic_blocks { + for instr in &bb.instructions { + // PHI incoming values: check that def dominates the incoming block + if let Instruction::Phi { incoming, .. 
// ---------------------------------------------------------------------------
// Attribute group validation
// ---------------------------------------------------------------------------

/// Attribute-group IDs must be unique, and every reference to a group —
/// from a function's attribute list or from a call site's `attr_refs` —
/// must name a declared group.
fn validate_attribute_groups(module: &Module) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();
    let mut seen_ids = FxHashSet::default();
    let mut valid_ids = FxHashSet::default();

    for group in &module.attribute_groups {
        if !seen_ids.insert(group.id) {
            errors.push(LlvmIrError::DuplicateAttributeGroupId { id: group.id });
        }
        valid_ids.insert(group.id);
    }

    for func in &module.functions {
        for &ref_id in &func.attribute_group_refs {
            if !valid_ids.contains(&ref_id) {
                errors.push(LlvmIrError::InvalidAttributeGroupRef {
                    function: func.name.clone(),
                    ref_id,
                });
            }
        }

        for bb in &func.basic_blocks {
            for instr in &bb.instructions {
                if let Instruction::Call { attr_refs, .. } = instr {
                    for &ref_id in attr_refs {
                        if !valid_ids.contains(&ref_id) {
                            errors.push(LlvmIrError::InvalidAttributeGroupRef {
                                function: func.name.clone(),
                                ref_id,
                            });
                        }
                    }
                }
            }
        }
    }

    errors
}

// ---------------------------------------------------------------------------
// Metadata validation
// ---------------------------------------------------------------------------

/// Collects every `NodeRef` ID appearing in a metadata value list,
/// descending into nested `SubList`s.
/// NOTE(review): this recurses on `SubList` nesting depth and
/// `detect_metadata_cycles` below recurses on node-reference chains —
/// with fuzzer-generated input, deeply nested metadata could exhaust the
/// call stack; consider an iterative rewrite. TODO confirm input depth is
/// bounded upstream.
fn extract_node_refs(values: &[MetadataValue]) -> Vec<u32> {
    let mut refs = Vec::new();
    for v in values {
        match v {
            MetadataValue::NodeRef(id) => refs.push(*id),
            MetadataValue::SubList(sub) => refs.extend(extract_node_refs(sub)),
            _ => {}
        }
    }
    refs
}

/// DFS cycle detection over metadata node references. `in_stack` holds the
/// current DFS path (a node found on it means a cycle); `visited` persists
/// across top-level calls so each node is explored once — a cycle is
/// reported at the first node on it that the traversal reaches.
fn detect_metadata_cycles(
    node_id: u32,
    nodes: &FxHashMap<u32, &[MetadataValue]>,
    visited: &mut FxHashSet<u32>,
    in_stack: &mut FxHashSet<u32>,
    errors: &mut Vec<LlvmIrError>,
) {
    if in_stack.contains(&node_id) {
        errors.push(LlvmIrError::MetadataRefCycle { node_id });
        return;
    }
    if visited.contains(&node_id) {
        return;
    }
    visited.insert(node_id);
    in_stack.insert(node_id);
    if let Some(values) = nodes.get(&node_id) {
        for ref_id in extract_node_refs(values) {
            detect_metadata_cycles(ref_id, nodes, visited, in_stack, errors);
        }
    }
    in_stack.remove(&node_id);
}

/// Metadata integrity: node IDs must be unique, every reference (from
/// named metadata or from other nodes) must resolve, and the reference
/// graph must be acyclic.
fn validate_metadata(module: &Module) -> Vec<LlvmIrError> {
    let mut errors = Vec::new();
    let mut seen_ids = FxHashSet::default();
    let mut valid_ids = FxHashSet::default();
    let mut node_map: FxHashMap<u32, &[MetadataValue]> = FxHashMap::default();

    // ID uniqueness
    for node in &module.metadata_nodes {
        if !seen_ids.insert(node.id) {
            errors.push(LlvmIrError::DuplicateMetadataNodeId { id: node.id });
        }
        valid_ids.insert(node.id);
        node_map.insert(node.id, &node.values);
    }

    // Reference validity: named metadata
    for nm in &module.named_metadata {
        for &ref_id in &nm.node_refs {
            if !valid_ids.contains(&ref_id) {
                errors.push(LlvmIrError::InvalidMetadataNodeRef {
                    context: format!("named metadata `!{}`", nm.name),
                    ref_id,
                });
            }
        }
    }

    // Reference validity: node-to-node references
    for node in &module.metadata_nodes {
        for ref_id in extract_node_refs(&node.values) {
            if !valid_ids.contains(&ref_id) {
                errors.push(LlvmIrError::InvalidMetadataNodeRef {
                    context: format!("metadata node !{}", node.id),
                    ref_id,
                });
            }
        }
    }

    // Cycle detection
    let mut visited = FxHashSet::default();
    let mut in_stack = FxHashSet::default();
    for node in &module.metadata_nodes {
        detect_metadata_cycles(node.id, &node_map, &mut visited, &mut in_stack, &mut errors);
    }

    errors
}
// NOTE(review): the remainder of this chunk is a separate diff hunk — the
// beginning of the new file source/compiler/qsc_llvm/src/validation/llvm/tests.rs,
// truncated mid-definition. It is reproduced unchanged below.
diff --git a/source/compiler/qsc_llvm/src/validation/llvm/tests.rs b/source/compiler/qsc_llvm/src/validation/llvm/tests.rs
new file mode 100644
index 0000000000..ad6371ecec
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/validation/llvm/tests.rs
@@ -0,0 +1,2702 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use super::*;
+use crate::model::Type;
+use crate::model::{
+    Attribute, AttributeGroup, BasicBlock, Function, IntPredicate, MetadataNode, Module,
+    NamedMetadata, Param, StructType,
+};
+
+// -----------------------------------------------------------------------
+// Baseline helpers
+// -----------------------------------------------------------------------
+
+fn valid_module() -> Module {
+    Module {
+        source_filename: None,
+        target_datalayout: None,
+        target_triple: None,
+        struct_types: Vec::new(),
+        globals: Vec::new(),
+        functions: vec![Function {
+            name: "test_fn".to_string(),
+            return_type: Type::Integer(64),
+            params: vec![Param {
+                ty: Type::Integer(64),
+                name: Some("x".to_string()),
+            }],
+            is_declaration: false,
+            attribute_group_refs: Vec::new(),
+            basic_blocks: vec![BasicBlock {
+                name: "entry".to_string(),
+                instructions: vec![
+                    Instruction::BinOp {
+                        op: BinOpKind::Add,
+                        ty: Type::Integer(64),
+                        lhs: Operand::LocalRef("x".to_string()),
+                        rhs: Operand::IntConst(Type::Integer(64), 1),
+                        result: "sum".to_string(),
+                    },
+                    Instruction::Ret(Some(Operand::LocalRef("sum".to_string()))),
+                ],
+ }], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } +} + +fn valid_switch_module() -> Module { + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Integer(64), + name: Some("x".to_string()), + }], + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Switch { + ty: Type::Integer(64), + value: Operand::LocalRef("x".to_string()), + default_dest: "default".to_string(), + cases: vec![(0, "zero".to_string())], + }], + }, + BasicBlock { + name: "zero".to_string(), + instructions: vec![Instruction::Ret(None)], + }, + BasicBlock { + name: "default".to_string(), + instructions: vec![Instruction::Ret(None)], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } +} + +fn valid_alloca_module() -> Module { + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::Alloca { + ty: Type::Integer(64), + result: "slot".to_string(), + }, + Instruction::Ret(None), + ], + }], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } +} + +fn two_block_module_with_phi(phi_instr: Instruction) -> Module { + Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: 
"test_fn".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Jump { + dest: "merge".to_string(), + }], + }, + BasicBlock { + name: "merge".to_string(), + instructions: vec![phi_instr, Instruction::Ret(None)], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + } +} + +fn has_error bool>(errors: &[LlvmIrError], pred: F) -> bool { + errors.iter().any(pred) +} + +fn declaration(name: &str, return_type: Type, params: Vec) -> Function { + Function { + name: name.to_string(), + return_type, + params, + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + } +} + +fn typed_ptr(inner: Type) -> Type { + Type::TypedPtr(Box::new(inner)) +} + +// ----------------------------------------------------------------------- +// Step 5.6: Positive baseline test +// ----------------------------------------------------------------------- + +#[test] +fn valid_module_passes() { + let m = valid_module(); + let errors = validate_ir(&m); + assert!(errors.is_empty(), "unexpected errors: {errors:?}"); +} + +// ----------------------------------------------------------------------- +// Step 5.2: Structure and terminator tests +// ----------------------------------------------------------------------- + +#[test] +fn missing_basic_blocks() { + let mut m = valid_module(); + m.functions[0].is_declaration = false; + m.functions[0].basic_blocks = vec![]; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::MissingBasicBlocks { .. } + ))); +} + +#[test] +fn declaration_has_blocks() { + let mut m = valid_module(); + m.functions[0].is_declaration = true; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::DeclarationHasBlocks { .. 
} + ))); +} + +#[test] +fn empty_basic_block() { + let mut m = valid_module(); + m.functions[0].basic_blocks.push(BasicBlock { + name: "empty_bb".to_string(), + instructions: vec![], + }); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::EmptyBasicBlock { .. } + ))); +} + +#[test] +fn missing_terminator() { + let mut m = valid_module(); + // Replace the Ret with a non-terminator + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.pop(); // remove Ret + bb.instructions.push(Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 0), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: "no_term".to_string(), + }); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::MissingTerminator { .. } + ))); +} + +#[test] +fn mid_block_terminator() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + // Insert Unreachable before the final Ret + bb.instructions + .insert(bb.instructions.len() - 1, Instruction::Unreachable); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::MidBlockTerminator { .. } + ))); +} + +// ----------------------------------------------------------------------- +// Step 5.3: SSA and reference validation tests +// ----------------------------------------------------------------------- + +#[test] +fn duplicate_definition() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + // Insert a second instruction with the same result name "sum" + bb.instructions.insert( + 0, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("x".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "sum".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::DuplicateDefinition { .. 
}
    )));
}

// Referencing a name that is never defined anywhere must be reported.
#[test]
fn undefined_local_ref() {
    let mut module = valid_module();
    let block = &mut module.functions[0].basic_blocks[0];
    block.instructions.insert(
        0,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Integer(64),
            lhs: Operand::LocalRef("nonexistent".to_string()),
            rhs: Operand::IntConst(Type::Integer(64), 1),
            result: "undef_use".to_string(),
        },
    );
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::UndefinedLocalRef { .. }
    )));
}

// Same check, but the dangling reference is a typed local ref.
#[test]
fn typed_local_ref_undefined_local() {
    let mut module = valid_module();
    let block = &mut module.functions[0].basic_blocks[0];
    block.instructions.insert(
        0,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Integer(64),
            lhs: Operand::TypedLocalRef("typed_missing".to_string(), Type::Integer(64)),
            rhs: Operand::IntConst(Type::Integer(64), 1),
            result: "typed_undef_use".to_string(),
        },
    );
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::UndefinedLocalRef { name, .. } if name == "typed_missing"
    )));
}

// Using a typed local ref before the defining instruction must be reported.
#[test]
fn typed_local_ref_use_before_definition() {
    let mut module = valid_module();
    let block = &mut module.functions[0].basic_blocks[0];
    // First the use of "late"...
    block.instructions.insert(
        0,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Integer(64),
            lhs: Operand::TypedLocalRef("late".to_string(), Type::Integer(64)),
            rhs: Operand::IntConst(Type::Integer(64), 1),
            result: "early_use".to_string(),
        },
    );
    // ...then its definition one slot later.
    block.instructions.insert(
        1,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Integer(64),
            lhs: Operand::IntConst(Type::Integer(64), 2),
            rhs: Operand::IntConst(Type::Integer(64), 3),
            result: "late".to_string(),
        },
    );
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::UndefinedLocalRef { name, .. } if name == "late"
    )));
}

// Same use-before-def scenario with a plain (untyped) local ref.
#[test]
fn local_ref_use_before_definition() {
    let mut module = valid_module();
    let block = &mut module.functions[0].basic_blocks[0];
    block.instructions.insert(
        0,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Integer(64),
            lhs: Operand::LocalRef("late".to_string()),
            rhs: Operand::IntConst(Type::Integer(64), 1),
            result: "early_use".to_string(),
        },
    );
    block.instructions.insert(
        1,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Integer(64),
            lhs: Operand::IntConst(Type::Integer(64), 2),
            rhs: Operand::IntConst(Type::Integer(64), 3),
            result: "late".to_string(),
        },
    );
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::UndefinedLocalRef { name, .. } if name == "late"
    )));
}

// Loading through an alloca result before the alloca itself is use-before-def.
#[test]
fn alloca_result_used_before_definition() {
    let mut module = valid_module();
    let block = &mut module.functions[0].basic_blocks[0];
    block.instructions.insert(
        0,
        Instruction::Load {
            ty: Type::Integer(64),
            ptr_ty: Type::Ptr,
            ptr: Operand::LocalRef("stack_slot".to_string()),
            result: "loaded".to_string(),
        },
    );
    block.instructions.insert(
        1,
        Instruction::Alloca {
            ty: Type::Integer(64),
            result: "stack_slot".to_string(),
        },
    );
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::UndefinedLocalRef { name, .. } if name == "stack_slot"
    )));
}

#[test]
fn undefined_global_ref() {
    let mut m = valid_module();
    let bb = &mut m.functions[0].basic_blocks[0];
    bb.instructions.insert(
        0,
        Instruction::Load {
            ty: Type::Integer(64),
            ptr_ty: Type::Ptr,
            ptr: Operand::GlobalRef("missing_global".to_string()),
            result: "loaded".to_string(),
        },
    );
    let errors = validate_ir(&m);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::UndefinedGlobalRef { name, ..
} if name == "missing_global"
    )));
}

// A conditional branch to a non-existent block must be reported.
#[test]
fn invalid_branch_target() {
    let mut module = valid_module();
    module.functions[0].basic_blocks[0].instructions = vec![Instruction::Br {
        cond_ty: Type::Integer(1),
        cond: Operand::IntConst(Type::Integer(1), 0),
        true_dest: "no_such_block".to_string(),
        false_dest: "entry".to_string(),
    }];
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::InvalidBranchTarget { .. }
    )));
}

// An unconditional jump to a non-existent block must be reported.
#[test]
fn invalid_jump_target() {
    let mut module = valid_module();
    module.functions[0].basic_blocks[0].instructions = vec![Instruction::Jump {
        dest: "no_such_block".to_string(),
    }];
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::InvalidBranchTarget { target, .. } if target == "no_such_block"
    )));
}

// A switch case pointing at a non-existent block must be reported.
#[test]
fn invalid_switch_target() {
    let mut module = valid_module();
    module.functions[0].basic_blocks[0].instructions = vec![Instruction::Switch {
        ty: Type::Integer(64),
        value: Operand::IntConst(Type::Integer(64), 0),
        default_dest: "entry".to_string(),
        cases: vec![(1, "no_such_block".to_string())],
    }];
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::InvalidBranchTarget { target, .. } if target == "no_such_block"
    )));
}

// The switch default destination is also checked for existence.
#[test]
fn invalid_switch_default_target() {
    let mut module = valid_switch_module();
    module.functions[0].basic_blocks[0].instructions[0] = Instruction::Switch {
        ty: Type::Integer(64),
        value: Operand::LocalRef("x".to_string()),
        default_dest: "no_such_block".to_string(),
        cases: vec![(0, "zero".to_string())],
    };
    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::InvalidBranchTarget { target, .. } if target == "no_such_block"
    )));
}

// Baseline: the unmodified switch module validates cleanly.
#[test]
fn valid_switch_passes() {
    let module = valid_switch_module();
    let errors = validate_ir(&module);
    assert!(errors.is_empty(), "unexpected errors: {errors:?}");
}

// A switch selector must have an integer type.
#[test]
fn switch_type_not_integer() {
    let mut module = valid_switch_module();
    module.functions[0].basic_blocks[0].instructions[0] = Instruction::Switch {
        ty: Type::Double,
        value: Operand::float_const(Type::Double, 0.0),
        default_dest: "default".to_string(),
        cases: vec![(0, "zero".to_string())],
    };

    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::SwitchTypeNotInteger { found_type, .. } if found_type == "double"
    )));
}

// The declared switch type must match the selector operand's type.
#[test]
fn switch_selector_type_mismatch() {
    let mut module = valid_switch_module();
    module.functions[0].basic_blocks[0].instructions[0] = Instruction::Switch {
        ty: Type::Integer(32),
        value: Operand::LocalRef("x".to_string()),
        default_dest: "default".to_string(),
        cases: vec![(0, "zero".to_string())],
    };

    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::TypeMismatch {
            instruction,
            expected,
            found,
            ..
        } if instruction == "Switch" && expected == "i32" && found == "i64"
    )));
}

// Two cases with the same selector value must be reported.
#[test]
fn switch_duplicate_case_value() {
    let mut module = valid_switch_module();
    module.functions[0].basic_blocks[0].instructions[0] = Instruction::Switch {
        ty: Type::Integer(64),
        value: Operand::LocalRef("x".to_string()),
        default_dest: "default".to_string(),
        cases: vec![(0, "zero".to_string()), (0, "default".to_string())],
    };

    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::SwitchDuplicateCaseValue { case_value: 0, .. }
    )));
}

// Baseline: the unmodified alloca module validates cleanly.
#[test]
fn valid_alloca_passes() {
    let module = valid_alloca_module();
    let errors = validate_ir(&module);
    assert!(errors.is_empty(), "unexpected errors: {errors:?}");
}

// `alloca void` is not a sized allocation.
#[test]
fn alloca_void_type() {
    let mut module = valid_alloca_module();
    module.functions[0].basic_blocks[0].instructions[0] = Instruction::Alloca {
        ty: Type::Void,
        result: "void_slot".to_string(),
    };

    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::AllocaUnsizedType { result, .. } if result == "void_slot"
    )));
}

// Function types are unsized and cannot be stack-allocated.
#[test]
fn alloca_function_type() {
    let mut module = valid_alloca_module();
    module.functions[0].basic_blocks[0].instructions[0] = Instruction::Alloca {
        ty: Type::Function(Box::new(Type::Void), vec![Type::Integer(64)]),
        result: "fn_slot".to_string(),
    };

    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::AllocaUnsizedType { result, .. } if result == "fn_slot"
    )));
}

// An opaque named struct has no known size, so alloca of it is invalid.
#[test]
fn alloca_named_opaque_struct_type() {
    let mut module = valid_alloca_module();
    module.struct_types.push(StructType {
        name: "Opaque".to_string(),
        is_opaque: true,
    });
    module.functions[0].basic_blocks[0].instructions[0] = Instruction::Alloca {
        ty: Type::Named("Opaque".to_string()),
        result: "opaque_slot".to_string(),
    };

    let errors = validate_ir(&module);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::AllocaUnsizedType { result, .. } if result == "opaque_slot"
    )));
}

#[test]
fn undefined_callee() {
    let mut m = valid_module();
    let bb = &mut m.functions[0].basic_blocks[0];
    bb.instructions.insert(
        0,
        Instruction::Call {
            return_ty: Some(Type::Void),
            callee: "nonexistent_fn".to_string(),
            args: vec![],
            result: None,
            attr_refs: vec![],
        },
    );
    let errors = validate_ir(&m);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::UndefinedCallee { ..
}
    )));
}

// Round-trip through bitcode must keep callee names intact so the validator
// can still resolve call targets afterwards.
#[test]
fn bitcode_roundtrip_call_preserves_callee_name_for_validator() {
    use crate::{parse_bitcode, write_bitcode};

    let module = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![
            Function {
                name: "callee".to_string(),
                return_type: Type::Void,
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![BasicBlock {
                    name: "entry".to_string(),
                    instructions: vec![Instruction::Ret(None)],
                }],
            },
            Function {
                name: "caller".to_string(),
                return_type: Type::Void,
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![BasicBlock {
                    name: "entry".to_string(),
                    instructions: vec![
                        Instruction::Call {
                            return_ty: None,
                            callee: "callee".to_string(),
                            args: vec![],
                            result: None,
                            attr_refs: vec![],
                        },
                        Instruction::Ret(None),
                    ],
                }],
            },
        ],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let orig_errors = validate_ir(&module);
    assert!(orig_errors.is_empty(), "original: {orig_errors:?}");

    let bitcode = write_bitcode(&module);
    let parsed = parse_bitcode(&bitcode).expect("parse failed");

    let caller = parsed
        .functions
        .iter()
        .find(|function| function.name == "caller")
        .expect("missing caller function");
    assert!(matches!(
        &caller.basic_blocks[0].instructions[0],
        Instruction::Call { callee, .. } if callee == "callee"
    ));

    let rt_errors = validate_ir(&parsed);
    assert!(rt_errors.is_empty(), "round-tripped: {rt_errors:?}");
}

// Round-trip through bitcode must keep call-site attribute-group references.
#[test]
fn bitcode_roundtrip_call_preserves_attr_refs_for_validator() {
    use crate::{parse_bitcode, write_bitcode};

    let module = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![
            Function {
                name: "callee".to_string(),
                return_type: Type::Void,
                params: Vec::new(),
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "caller".to_string(),
                return_type: Type::Void,
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![BasicBlock {
                    name: "entry".to_string(),
                    instructions: vec![
                        Instruction::Call {
                            return_ty: None,
                            callee: "callee".to_string(),
                            args: Vec::new(),
                            result: None,
                            attr_refs: vec![0, 1],
                        },
                        Instruction::Ret(None),
                    ],
                }],
            },
        ],
        attribute_groups: vec![
            AttributeGroup {
                id: 0,
                attributes: vec![Attribute::StringAttr("alwaysinline".to_string())],
            },
            AttributeGroup {
                id: 1,
                attributes: vec![Attribute::StringAttr("noreturn".to_string())],
            },
        ],
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let orig_errors = validate_ir(&module);
    assert!(orig_errors.is_empty(), "original: {orig_errors:?}");

    let bitcode = write_bitcode(&module);
    let parsed = parse_bitcode(&bitcode).expect("parse failed");

    let caller = parsed
        .functions
        .iter()
        .find(|function| function.name == "caller")
        .expect("missing caller function");
    assert!(matches!(
        &caller.basic_blocks[0].instructions[0],
        Instruction::Call { attr_refs, .. } if attr_refs == &vec![0, 1]
    ));

    let rt_errors = validate_ir(&parsed);
    assert!(rt_errors.is_empty(), "round-tripped: {rt_errors:?}");
}

// Round-trip through bitcode must keep global names so GlobalRef operands
// still resolve for the validator.
#[test]
fn bitcode_roundtrip_global_ref_preserves_name_for_validator() {
    use crate::model::{Constant, GlobalVariable, Linkage};
    use crate::{parse_bitcode, write_bitcode};

    let module = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: vec![GlobalVariable {
            name: "message".to_string(),
            ty: Type::Array(5, Box::new(Type::Integer(8))),
            linkage: Linkage::Internal,
            is_constant: true,
            initializer: Some(Constant::CString("hello".to_string())),
        }],
        functions: vec![
            Function {
                name: "use_ptr".to_string(),
                return_type: Type::Void,
                params: vec![Param {
                    ty: Type::Ptr,
                    name: None,
                }],
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "caller".to_string(),
                return_type: Type::Void,
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![BasicBlock {
                    name: "entry".to_string(),
                    instructions: vec![
                        Instruction::Call {
                            return_ty: None,
                            callee: "use_ptr".to_string(),
                            args: vec![(Type::Ptr, Operand::GlobalRef("message".to_string()))],
                            result: None,
                            attr_refs: vec![],
                        },
                        Instruction::Ret(None),
                    ],
                }],
            },
        ],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let orig_errors = validate_ir(&module);
    assert!(orig_errors.is_empty(), "original: {orig_errors:?}");

    let bitcode = write_bitcode(&module);
    let parsed = parse_bitcode(&bitcode).expect("parse failed");

    let caller = parsed
        .functions
        .iter()
        .find(|function| function.name == "caller")
        .expect("missing caller function");
    assert!(matches!(
        &caller.basic_blocks[0].instructions[0],
        Instruction::Call { args, .. }
            if matches!(args.first(), Some((Type::Ptr, Operand::GlobalRef(name))) if name == "message")
    ));

    let rt_errors = validate_ir(&parsed);
    assert!(rt_errors.is_empty(), "round-tripped: {rt_errors:?}");
}

#[test]
fn bitcode_roundtrip_preserves_local_names_for_validator() {
    use crate::{parse_bitcode, write_bitcode};

    let m = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![Function {
            name: "chooser".to_string(),
            return_type: Type::Integer(64),
            params: vec![
                Param {
                    ty: Type::Integer(64),
                    name: Some("a".to_string()),
                },
                Param {
                    ty: Type::Integer(64),
                    name: Some("b".to_string()),
                },
            ],
            is_declaration: false,
            attribute_group_refs: Vec::new(),
            basic_blocks: vec![
                BasicBlock {
                    name: "entry".to_string(),
                    instructions: vec![
                        Instruction::ICmp {
                            pred: IntPredicate::Slt,
                            ty: Type::Integer(64),
                            lhs: Operand::LocalRef("a".to_string()),
                            rhs: Operand::LocalRef("b".to_string()),
                            result: "cond".to_string(),
                        },
                        Instruction::Br {
                            cond_ty: Type::Integer(1),
                            cond: Operand::LocalRef("cond".to_string()),
                            true_dest: "then".to_string(),
                            false_dest: "else".to_string(),
                        },
                    ],
                },
                BasicBlock {
                    name: "then".to_string(),
                    instructions: vec![
                        Instruction::BinOp {
                            op: BinOpKind::Add,
                            ty: Type::Integer(64),
                            lhs: Operand::LocalRef("a".to_string()),
                            rhs: Operand::IntConst(Type::Integer(64), 1),
                            result: "then_value".to_string(),
                        },
                        Instruction::Jump {
                            dest: "merge".to_string(),
                        },
                    ],
                },
                BasicBlock {
                    name: "else".to_string(),
                    instructions: vec![
                        Instruction::BinOp {
                            op: BinOpKind::Add,
                            ty: Type::Integer(64),
                            lhs: Operand::LocalRef("b".to_string()),
                            rhs: Operand::IntConst(Type::Integer(64), 2),
                            result: "else_value".to_string(),
                        },
                        Instruction::Jump {
                            dest: "merge".to_string(),
                        },
                    ],
                },
                BasicBlock {
                    name: "merge".to_string(),
                    instructions: vec![
Instruction::Phi {
                            ty: Type::Integer(64),
                            incoming: vec![
                                (
                                    Operand::LocalRef("then_value".to_string()),
                                    "then".to_string(),
                                ),
                                (
                                    Operand::LocalRef("else_value".to_string()),
                                    "else".to_string(),
                                ),
                            ],
                            result: "result".to_string(),
                        },
                        Instruction::Ret(Some(Operand::LocalRef("result".to_string()))),
                    ],
                },
            ],
        }],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let orig_errors = validate_ir(&m);
    assert!(orig_errors.is_empty(), "original: {orig_errors:?}");

    let bc = write_bitcode(&m);
    let parsed = parse_bitcode(&bc).expect("parse failed");
    let chooser = parsed
        .functions
        .iter()
        .find(|function| function.name == "chooser")
        .expect("missing chooser function");

    assert_eq!(chooser.params[0].name.as_deref(), Some("a"));
    assert_eq!(chooser.params[1].name.as_deref(), Some("b"));
    assert_eq!(
        chooser
            .basic_blocks
            .iter()
            .map(|bb| bb.name.as_str())
            .collect::<Vec<_>>(),
        vec!["entry", "then", "else", "merge"]
    );
    // The parser may surface locals as plain or typed refs; accept either.
    assert!(matches!(
        &chooser.basic_blocks[0].instructions[0],
        Instruction::ICmp { lhs, rhs, result, .. }
            if result == "cond"
            && matches!(lhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "a")
            && matches!(rhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "b")
    ));
    assert!(matches!(
        &chooser.basic_blocks[0].instructions[1],
        Instruction::Br {
            cond,
            true_dest,
            false_dest,
            ..
        } if true_dest == "then"
            && false_dest == "else"
            && matches!(cond, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "cond")
    ));
    assert!(matches!(
        &chooser.basic_blocks[1].instructions[0],
        Instruction::BinOp { lhs, result, .. }
            if result == "then_value"
            && matches!(lhs, Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) if name == "a")
    ));
    assert!(matches!(
        &chooser.basic_blocks[1].instructions[1],
        Instruction::Jump { dest } if dest == "merge"
    ));
    assert!(matches!(
        &chooser.basic_blocks[3].instructions[0],
        Instruction::Phi {
            incoming,
            result,
            ..
        } if result == "result"
            && incoming.len() == 2
            && matches!(&incoming[0], (Operand::LocalRef(name) | Operand::TypedLocalRef(name, _), from) if name == "then_value" && from == "then")
            && matches!(&incoming[1], (Operand::LocalRef(name) | Operand::TypedLocalRef(name, _), from) if name == "else_value" && from == "else")
    ));
    assert!(matches!(
        &chooser.basic_blocks[3].instructions[1],
        Instruction::Ret(Some(Operand::LocalRef(name) | Operand::TypedLocalRef(name, _))) if name == "result"
    ));

    let rt_errors = validate_ir(&parsed);
    assert!(rt_errors.is_empty(), "round-tripped: {rt_errors:?}");
}

// Round-trip must preserve the `double` types carried on typed local refs.
#[test]
#[allow(clippy::too_many_lines)]
fn bitcode_roundtrip_preserves_float_local_ref_types_for_validator() {
    use crate::{parse_bitcode, write_bitcode};

    let m = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![
            Function {
                name: "use_double".to_string(),
                return_type: Type::Void,
                params: vec![Param {
                    ty: Type::Double,
                    name: None,
                }],
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "test".to_string(),
                return_type: Type::Void,
                params: vec![
                    Param {
                        ty: Type::Double,
                        name: Some("a".to_string()),
                    },
                    Param {
                        ty: Type::Double,
                        name: Some("b".to_string()),
                    },
                ],
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![BasicBlock {
                    name: "entry".to_string(),
                    instructions: vec![
                        Instruction::BinOp {
                            op: BinOpKind::Fadd,
                            ty: Type::Double,
                            lhs: Operand::TypedLocalRef("a".to_string(), Type::Double),
                            rhs: Operand::TypedLocalRef("b".to_string(), Type::Double),
                            result: "r".to_string(),
                        },
                        Instruction::Call {
                            return_ty: None,
                            callee: "use_double".to_string(),
                            args: vec![(
                                Type::Double,
                                Operand::TypedLocalRef("r".to_string(), Type::Double),
                            )],
                            result: None,
                            attr_refs: vec![],
                        },
                        Instruction::Ret(None),
                    ],
                }],
            },
        ],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let orig_errors = validate_ir(&m);
    assert!(orig_errors.is_empty(), "original: {orig_errors:?}");

    let bc = write_bitcode(&m);
    let parsed = parse_bitcode(&bc).expect("parse failed");

    let test_fn = parsed
        .functions
        .iter()
        .find(|function| function.name == "test")
        .expect("missing test function");
    assert_eq!(test_fn.params[0].name.as_deref(), Some("a"));
    assert_eq!(test_fn.params[1].name.as_deref(), Some("b"));
    assert!(matches!(
        &test_fn.basic_blocks[0].instructions[0],
        Instruction::BinOp {
            op: BinOpKind::Fadd,
            ty,
            lhs,
            rhs,
            result,
        } if ty == &Type::Double
            && result == "r"
            && matches!(lhs, Operand::TypedLocalRef(name, local_ty) if name == "a" && local_ty == &Type::Double)
            && matches!(rhs, Operand::TypedLocalRef(name, local_ty) if name == "b" && local_ty == &Type::Double)
    ));
    assert!(matches!(
        &test_fn.basic_blocks[0].instructions[1],
        Instruction::Call {
            return_ty: None,
            args,
            result: None,
            ..
        } if matches!(
            args.as_slice(),
            [(Type::Double, Operand::TypedLocalRef(name, ty))]
                if name == "r" && ty == &Type::Double
        )
    ));

    let rt_errors = validate_ir(&parsed);
    assert!(rt_errors.is_empty(), "round-tripped: {rt_errors:?}");
}

// Round-trip must preserve pointer/integer types on load-related locals.
#[test]
fn bitcode_roundtrip_preserves_load_local_ref_types_for_validator() {
    use crate::{parse_bitcode, write_bitcode};

    let m = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: Vec::new(),
        globals: Vec::new(),
        functions: vec![Function {
            name: "test".to_string(),
            return_type: Type::Integer(64),
            params: Vec::new(),
            is_declaration: false,
            attribute_group_refs: Vec::new(),
            basic_blocks: vec![BasicBlock {
                name: "entry".to_string(),
                instructions: vec![
                    Instruction::Alloca {
                        ty: Type::Integer(64),
                        result: "ptr".to_string(),
                    },
                    Instruction::Load {
                        ty: Type::Integer(64),
                        ptr_ty: Type::Ptr,
                        ptr: Operand::TypedLocalRef("ptr".to_string(), Type::Ptr),
                        result: "val".to_string(),
                    },
                    Instruction::Ret(Some(Operand::TypedLocalRef(
                        "val".to_string(),
                        Type::Integer(64),
                    ))),
                ],
            }],
        }],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let orig_errors = validate_ir(&m);
    assert!(orig_errors.is_empty(), "original: {orig_errors:?}");

    let bc = write_bitcode(&m);
    let parsed = parse_bitcode(&bc).expect("parse failed");

    let test_fn = parsed
        .functions
        .iter()
        .find(|function| function.name == "test")
        .expect("missing test function");
    assert!(matches!(
        &test_fn.basic_blocks[0].instructions[0],
        Instruction::Alloca { ty, result } if ty == &Type::Integer(64) && result == "ptr"
    ));
    assert!(matches!(
        &test_fn.basic_blocks[0].instructions[1],
        Instruction::Load {
            ty,
            ptr_ty,
            ptr,
            result,
        } if ty == &Type::Integer(64)
            && ptr_ty == &Type::Ptr
            && result == "val"
            && matches!(ptr, Operand::TypedLocalRef(name, local_ty) if name == "ptr" && local_ty == &Type::Ptr)
    ));
    assert!(matches!(
        &test_fn.basic_blocks[0].instructions[2],
        Instruction::Ret(Some(Operand::TypedLocalRef(name, ty)))
            if name == "val" && ty == &Type::Integer(64)
    ));

    let rt_errors = validate_ir(&parsed);
    assert!(rt_errors.is_empty(), "round-tripped: {rt_errors:?}");
}

// -----------------------------------------------------------------------
// Step 5.4: Type consistency and cast validation tests
// -----------------------------------------------------------------------

// Operand types that disagree with the instruction's declared type are flagged.
#[test]
fn binop_type_mismatch() {
    let mut m = valid_module();
    let bb = &mut m.functions[0].basic_blocks[0];
    bb.instructions.insert(
        0,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Integer(64),
            lhs: Operand::IntConst(Type::Integer(32), 1),
            rhs: Operand::IntConst(Type::Integer(64), 2),
            result: "mismatch".to_string(),
        },
    );
    let errors = validate_ir(&m);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::TypeMismatch { .. }
    )));
}

// An integer op (`add`) over floating-point operands is flagged.
#[test]
fn int_op_on_float() {
    let mut m = valid_module();
    let bb = &mut m.functions[0].basic_blocks[0];
    bb.instructions.insert(
        0,
        Instruction::BinOp {
            op: BinOpKind::Add,
            ty: Type::Double,
            lhs: Operand::float_const(Type::Double, 1.0),
            rhs: Operand::float_const(Type::Double, 2.0),
            result: "bad_int_op".to_string(),
        },
    );
    let errors = validate_ir(&m);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::IntOpOnNonInt { .. }
    )));
}

#[test]
fn float_op_on_int() {
    let mut m = valid_module();
    let bb = &mut m.functions[0].basic_blocks[0];
    bb.instructions.insert(
        0,
        Instruction::BinOp {
            op: BinOpKind::Fadd,
            ty: Type::Integer(64),
            lhs: Operand::IntConst(Type::Integer(64), 1),
            rhs: Operand::IntConst(Type::Integer(64), 2),
            result: "bad_float_op".to_string(),
        },
    );
    let errors = validate_ir(&m);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::FloatOpOnNonFloat { ..
} + ))); +} + +#[test] +fn br_cond_not_i1() { + let mut m = valid_module(); + m.functions[0].basic_blocks[0].instructions = vec![Instruction::Br { + cond_ty: Type::Integer(32), + cond: Operand::IntConst(Type::Integer(32), 0), + true_dest: "entry".to_string(), + false_dest: "entry".to_string(), + }]; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::BrCondNotI1 { .. } + ))); +} + +#[test] +fn ret_type_mismatch() { + let mut m = valid_module(); + // Function returns i64 but we return a double + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.pop(); // remove Ret + bb.instructions + .push(Instruction::Ret(Some(Operand::float_const( + Type::Double, + 1.0, + )))); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::RetTypeMismatch { .. } + ))); +} + +#[test] +#[allow(clippy::too_many_lines)] +fn widened_float_surface_accepts_half_float_and_double() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![ + declaration( + "consume_half", + Type::Void, + vec![Param { + ty: Type::Half, + name: Some("value".to_string()), + }], + ), + declaration( + "consume_float", + Type::Void, + vec![Param { + ty: Type::Float, + name: Some("value".to_string()), + }], + ), + declaration( + "consume_double", + Type::Void, + vec![Param { + ty: Type::Double, + name: Some("value".to_string()), + }], + ), + Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Half, + lhs: Operand::float_const(Type::Half, 1.5), + rhs: Operand::float_const(Type::Half, 2.25), + result: "half_sum".to_string(), + }, + Instruction::Cast { + op: CastKind::FpExt, + 
from_ty: Type::Half, + to_ty: Type::Float, + value: Operand::LocalRef("half_sum".to_string()), + result: "as_float".to_string(), + }, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Float, + lhs: Operand::LocalRef("as_float".to_string()), + rhs: Operand::float_const(Type::Float, 0.5), + result: "float_sum".to_string(), + }, + Instruction::FCmp { + pred: crate::model::FloatPredicate::Oeq, + ty: Type::Float, + lhs: Operand::LocalRef("float_sum".to_string()), + rhs: Operand::float_const(Type::Float, 4.25), + result: "float_eq".to_string(), + }, + Instruction::Cast { + op: CastKind::FpExt, + from_ty: Type::Float, + to_ty: Type::Double, + value: Operand::LocalRef("float_sum".to_string()), + result: "as_double".to_string(), + }, + Instruction::Cast { + op: CastKind::FpTrunc, + from_ty: Type::Double, + to_ty: Type::Half, + value: Operand::LocalRef("as_double".to_string()), + result: "back_to_half".to_string(), + }, + Instruction::Call { + return_ty: None, + callee: "consume_half".to_string(), + args: vec![(Type::Half, Operand::LocalRef("back_to_half".to_string()))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Call { + return_ty: None, + callee: "consume_float".to_string(), + args: vec![(Type::Float, Operand::LocalRef("float_sum".to_string()))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Call { + return_ty: None, + callee: "consume_double".to_string(), + args: vec![(Type::Double, Operand::LocalRef("as_double".to_string()))], + result: None, + attr_refs: Vec::new(), + }, + Instruction::Ret(None), + ], + }], + }, + ], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let errors = validate_ir(&m); + assert!( + errors.is_empty(), + "valid half/float/double flow failed: {errors:?}" + ); +} + +#[test] +fn typed_local_ref_use_site_type_masking() { + let mut m = valid_module(); + m.functions.push(declaration( + "consume_i1", + Type::Void, + vec![Param { + ty: Type::Integer(1), + name: 
Some("flag".to_string()), + }], + )); + + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "wide".to_string(), + }, + ); + bb.instructions.insert( + 1, + Instruction::Call { + return_ty: None, + callee: "consume_i1".to_string(), + args: vec![( + Type::Integer(1), + Operand::TypedLocalRef("wide".to_string(), Type::Integer(1)), + )], + result: None, + attr_refs: Vec::new(), + }, + ); + + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::TypeMismatch { + instruction, + expected, + found, + .. + } if instruction == "Call @consume_i1" && expected == "i1" && found == "i64" + ))); +} + +#[test] +fn load_pointer_operand_mismatch() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Load { + ty: Type::Integer(64), + ptr_ty: typed_ptr(Type::Integer(64)), + ptr: Operand::IntToPtr(0, typed_ptr(Type::Integer(8))), + result: "loaded".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::TypeMismatch { instruction, .. } if instruction == "Load" + ))); +} + +#[test] +fn store_pointer_operand_mismatch() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Store { + ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(64), 0), + ptr_ty: typed_ptr(Type::Integer(64)), + ptr: Operand::IntToPtr(0, typed_ptr(Type::Integer(8))), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::TypeMismatch { instruction, .. 
} if instruction == "Store" + ))); +} + +#[test] +fn call_arg_operand_type_mismatch() { + let mut m = valid_module(); + m.functions.push(declaration( + "consume_i64", + Type::Void, + vec![Param { + ty: Type::Integer(64), + name: Some("value".to_string()), + }], + )); + + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Call { + return_ty: None, + callee: "consume_i64".to_string(), + args: vec![(Type::Integer(64), Operand::IntConst(Type::Integer(1), 1))], + result: None, + attr_refs: Vec::new(), + }, + ); + + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::TypeMismatch { + instruction, + expected, + found, + .. + } if instruction == "Call @consume_i64" && expected == "i64" && found == "i1" + ))); +} + +#[test] +fn invalid_cast_trunc_wider() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Cast { + op: CastKind::Trunc, + from_ty: Type::Integer(32), + to_ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(32), 0), + result: "bad_trunc".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::InvalidCast { .. } + ))); +} + +#[test] +fn valid_cast_zext() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Cast { + op: CastKind::Zext, + from_ty: Type::Integer(32), + to_ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(32), 0), + result: "good_zext".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!( + !has_error(&errors, |e| matches!(e, LlvmIrError::InvalidCast { .. 
})), + "zext i32→i64 should be valid, but got: {errors:?}" + ); +} + +#[test] +fn invalid_cast_int_to_ptr_non_int() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Cast { + op: CastKind::IntToPtr, + from_ty: Type::Double, + to_ty: Type::Ptr, + value: Operand::float_const(Type::Double, 0.0), + result: "bad_inttoptr".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::InvalidCast { .. } + ))); +} + +// ----------------------------------------------------------------------- +// Step 5.5: PHI validation tests +// ----------------------------------------------------------------------- + +#[test] +fn phi_not_at_block_start() { + let mut m = two_block_module_with_phi(Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![(Operand::IntConst(Type::Integer(64), 0), "entry".to_string())], + result: "p".to_string(), + }); + // Insert a non-PHI instruction before the PHI in block "merge" + let merge = &mut m.functions[0].basic_blocks[1]; + merge.instructions.insert( + 0, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 0), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: "dummy".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PhiNotAtBlockStart { .. } + ))); +} + +#[test] +fn phi_void_type() { + let m = two_block_module_with_phi(Instruction::Phi { + ty: Type::Void, + incoming: vec![(Operand::IntConst(Type::Integer(64), 0), "entry".to_string())], + result: "p".to_string(), + }); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PhiVoidType { .. 
} + ))); +} + +#[test] +fn phi_in_entry_block() { + let mut m = valid_module(); + // Put a PHI node in the entry block + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![], + result: "phi_entry".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PhiInEntryBlock { .. } + ))); +} + +#[test] +fn phi_backedge_value_defined_later_in_same_block_is_valid() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Jump { + dest: "loop".to_string(), + }], + }, + BasicBlock { + name: "loop".to_string(), + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + (Operand::IntConst(Type::Integer(64), 0), "entry".to_string()), + (Operand::LocalRef("next".to_string()), "loop".to_string()), + ], + result: "acc".to_string(), + }, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("acc".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: "next".to_string(), + }, + Instruction::ICmp { + pred: IntPredicate::Slt, + ty: Type::Integer(64), + lhs: Operand::LocalRef("next".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 10), + result: "cond".to_string(), + }, + Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::LocalRef("cond".to_string()), + true_dest: "loop".to_string(), + false_dest: "exit".to_string(), + }, + ], + }, + BasicBlock { + name: "exit".to_string(), + instructions: vec![Instruction::Ret(Some(Operand::LocalRef( + "next".to_string(), + )))], + }, + 
], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let errors = validate_ir(&m); + assert!( + errors.is_empty(), + "expected no validation errors, got: {errors:?}" + ); +} + +#[test] +fn phi_pred_count_mismatch() { + // "merge" has 1 predecessor ("entry"), but PHI has 2 incoming entries + let m = two_block_module_with_phi(Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + (Operand::IntConst(Type::Integer(64), 0), "entry".to_string()), + (Operand::IntConst(Type::Integer(64), 1), "other".to_string()), + ], + result: "p".to_string(), + }); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PhiPredCountMismatch { .. } + ))); +} + +#[test] +fn phi_incoming_not_predecessor() { + // "merge" has 1 predecessor ("entry"), list only a non-predecessor label + let m = two_block_module_with_phi(Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![( + Operand::IntConst(Type::Integer(64), 0), + "no_such_pred".to_string(), + )], + result: "p".to_string(), + }); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PhiIncomingNotPredecessor { .. 
} + ))); +} + +#[test] +fn phi_predecessor_multiset_mismatch() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: "left".to_string(), + false_dest: "right".to_string(), + }], + }, + BasicBlock { + name: "left".to_string(), + instructions: vec![Instruction::Jump { + dest: "merge".to_string(), + }], + }, + BasicBlock { + name: "right".to_string(), + instructions: vec![Instruction::Jump { + dest: "merge".to_string(), + }], + }, + BasicBlock { + name: "merge".to_string(), + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + (Operand::IntConst(Type::Integer(64), 0), "left".to_string()), + (Operand::IntConst(Type::Integer(64), 0), "left".to_string()), + ], + result: "p".to_string(), + }, + Instruction::Ret(None), + ], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PhiIncomingNotPredecessor { + incoming_block, + .. 
+ } if incoming_block == "left" + ))); +} + +#[test] +fn phi_duplicate_block_diff_value() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: "merge".to_string(), + false_dest: "merge".to_string(), + }], + }, + BasicBlock { + name: "merge".to_string(), + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + (Operand::IntConst(Type::Integer(64), 0), "entry".to_string()), + (Operand::IntConst(Type::Integer(64), 1), "entry".to_string()), + ], + result: "p".to_string(), + }, + Instruction::Ret(None), + ], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PhiDuplicateBlockDiffValue { dup_block, .. } if dup_block == "entry" + ))); +} + +// ----------------------------------------------------------------------- +// Step 5.5: GEP validation tests +// ----------------------------------------------------------------------- + +#[test] +fn gep_no_indices() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::GetElementPtr { + inbounds: true, + pointee_ty: Type::Integer(8), + ptr_ty: Type::Ptr, + ptr: Operand::NullPtr, + indices: vec![], + result: "gep_empty".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::GepNoIndices { .. 
} + ))); +} + +#[test] +fn gep_void_pointee() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::GetElementPtr { + inbounds: true, + pointee_ty: Type::Void, + ptr_ty: Type::Ptr, + ptr: Operand::NullPtr, + indices: vec![Operand::IntConst(Type::Integer(32), 0)], + result: "gep_void".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::UnsizedPointeeType { .. } + ))); +} + +#[test] +fn gep_non_ptr() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::GetElementPtr { + inbounds: true, + pointee_ty: Type::Integer(8), + ptr_ty: Type::Integer(64), + ptr: Operand::IntConst(Type::Integer(64), 0), + indices: vec![Operand::IntConst(Type::Integer(32), 0)], + result: "gep_notptr".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::PtrExpected { .. } + ))); +} + +#[test] +fn gep_non_integer_index() { + let mut m = valid_module(); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::GetElementPtr { + inbounds: true, + pointee_ty: Type::Integer(8), + ptr_ty: Type::Ptr, + ptr: Operand::NullPtr, + indices: vec![Operand::float_const(Type::Double, 0.0)], + result: "gep_bad_index".to_string(), + }, + ); + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::TypeMismatch { + instruction, + expected, + .. 
+ } if instruction == "GetElementPtr" && expected == "integer type" + ))); +} + +// ----------------------------------------------------------------------- +// Step 8.6: Dominance validation tests +// ----------------------------------------------------------------------- + +#[test] +fn cross_block_non_dominating_use() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: vec![], + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: "then_bb".to_string(), + false_dest: "else_bb".to_string(), + }], + }, + BasicBlock { + name: "then_bb".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "x".to_string(), + }, + Instruction::Jump { + dest: "merge".to_string(), + }, + ], + }, + BasicBlock { + name: "else_bb".to_string(), + instructions: vec![Instruction::Jump { + dest: "merge".to_string(), + }], + }, + BasicBlock { + name: "merge".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("x".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 0), + result: "y".to_string(), + }, + Instruction::Ret(None), + ], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::UseNotDominatedByDef { .. 
} + ))); +} + +#[test] +fn phi_from_non_dominating_block() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: vec![], + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: "left".to_string(), + false_dest: "right".to_string(), + }], + }, + BasicBlock { + name: "left".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "x".to_string(), + }, + Instruction::Jump { + dest: "merge".to_string(), + }, + ], + }, + BasicBlock { + name: "right".to_string(), + instructions: vec![Instruction::Jump { + dest: "merge".to_string(), + }], + }, + BasicBlock { + name: "merge".to_string(), + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + (Operand::LocalRef("x".to_string()), "right".to_string()), + (Operand::IntConst(Type::Integer(64), 0), "left".to_string()), + ], + result: "p".to_string(), + }, + Instruction::Ret(None), + ], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::UseNotDominatedByDef { .. 
} + ))); +} + +#[test] +fn valid_diamond_with_phi() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: vec![], + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 0), + true_dest: "then_bb".to_string(), + false_dest: "else_bb".to_string(), + }], + }, + BasicBlock { + name: "then_bb".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "x".to_string(), + }, + Instruction::Jump { + dest: "merge".to_string(), + }, + ], + }, + BasicBlock { + name: "else_bb".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 3), + rhs: Operand::IntConst(Type::Integer(64), 4), + result: "y_val".to_string(), + }, + Instruction::Jump { + dest: "merge".to_string(), + }, + ], + }, + BasicBlock { + name: "merge".to_string(), + instructions: vec![ + Instruction::Phi { + ty: Type::Integer(64), + incoming: vec![ + (Operand::LocalRef("x".to_string()), "then_bb".to_string()), + ( + Operand::LocalRef("y_val".to_string()), + "else_bb".to_string(), + ), + ], + result: "p".to_string(), + }, + Instruction::Ret(None), + ], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let errors = validate_ir(&m); + assert!( + !has_error(&errors, |e| matches!( + e, + LlvmIrError::UseNotDominatedByDef { .. 
} + )), + "valid diamond with PHI should not report dominance errors: {errors:?}" + ); +} + +#[test] +fn unreachable_block_def() { + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: Vec::new(), + globals: Vec::new(), + functions: vec![Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: vec![], + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Jump { + dest: "reachable".to_string(), + }], + }, + BasicBlock { + name: "unreachable_bb".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "x".to_string(), + }, + Instruction::Jump { + dest: "reachable".to_string(), + }, + ], + }, + BasicBlock { + name: "reachable".to_string(), + instructions: vec![ + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::LocalRef("x".to_string()), + rhs: Operand::IntConst(Type::Integer(64), 0), + result: "z".to_string(), + }, + Instruction::Ret(None), + ], + }, + ], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::UseNotDominatedByDef { .. 
} + ))); +} + +#[test] +fn dominance_helpers_ignore_unreachable_blocks() { + let func = Function { + name: "test_fn".to_string(), + return_type: Type::Void, + params: Vec::new(), + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![ + BasicBlock { + name: "entry".to_string(), + instructions: vec![Instruction::Jump { + dest: "reachable".to_string(), + }], + }, + BasicBlock { + name: "unreachable_bb".to_string(), + instructions: vec![Instruction::Jump { + dest: "reachable".to_string(), + }], + }, + BasicBlock { + name: "reachable".to_string(), + instructions: vec![Instruction::Ret(None)], + }, + ], + }; + + let (successors, predecessors) = build_cfg(&func); + let entry = func.basic_blocks[0].name.as_str(); + let rpo = reverse_postorder(entry, &successors); + let idom = compute_dominators(entry, &rpo, &predecessors); + + assert!(!rpo.contains(&"unreachable_bb")); + assert_eq!(idom.get("reachable").copied(), Some("entry")); + assert!(dominates("entry", "reachable", &idom, entry)); + assert!(!dominates("entry", "unreachable_bb", &idom, entry)); +} + +// ----------------------------------------------------------------------- +// Step 9.5: Attribute group and metadata validation tests +// ----------------------------------------------------------------------- + +#[test] +fn duplicate_attribute_group_id() { + let mut m = valid_module(); + m.attribute_groups = vec![ + AttributeGroup { + id: 0, + attributes: Vec::new(), + }, + AttributeGroup { + id: 0, + attributes: Vec::new(), + }, + ]; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::DuplicateAttributeGroupId { id: 0 } + ))); +} + +#[test] +fn invalid_attribute_group_ref() { + let mut m = valid_module(); + m.attribute_groups = vec![AttributeGroup { + id: 0, + attributes: Vec::new(), + }]; + m.functions[0].attribute_group_refs = vec![99]; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + 
LlvmIrError::InvalidAttributeGroupRef { ref_id: 99, .. } + ))); +} + +#[test] +fn invalid_call_site_attribute_group_ref() { + let mut m = valid_module(); + m.functions + .push(declaration("callee", Type::Void, Vec::new())); + let bb = &mut m.functions[0].basic_blocks[0]; + bb.instructions.insert( + 0, + Instruction::Call { + return_ty: None, + callee: "callee".to_string(), + args: Vec::new(), + result: None, + attr_refs: vec![99], + }, + ); + + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::InvalidAttributeGroupRef { ref_id: 99, .. } + ))); +} + +#[test] +fn duplicate_metadata_node_id() { + let mut m = valid_module(); + m.metadata_nodes = vec![ + MetadataNode { + id: 0, + values: Vec::new(), + }, + MetadataNode { + id: 0, + values: Vec::new(), + }, + ]; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::DuplicateMetadataNodeId { id: 0 } + ))); +} + +#[test] +fn invalid_metadata_node_ref() { + let mut m = valid_module(); + m.named_metadata = vec![NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![42], + }]; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::InvalidMetadataNodeRef { ref_id: 42, .. } + ))); +} + +#[test] +fn metadata_cycle() { + let mut m = valid_module(); + m.metadata_nodes = vec![ + MetadataNode { + id: 0, + values: vec![MetadataValue::NodeRef(1)], + }, + MetadataNode { + id: 1, + values: vec![MetadataValue::NodeRef(0)], + }, + ]; + let errors = validate_ir(&m); + assert!(has_error(&errors, |e| matches!( + e, + LlvmIrError::MetadataRefCycle { .. 
} + ))); +} + +#[test] +fn valid_metadata_passes() { + let mut m = valid_module(); + m.attribute_groups = vec![ + AttributeGroup { + id: 0, + attributes: Vec::new(), + }, + AttributeGroup { + id: 1, + attributes: Vec::new(), + }, + ]; + m.functions[0].attribute_group_refs = vec![0]; + m.metadata_nodes = vec![ + MetadataNode { + id: 0, + values: vec![MetadataValue::Int(Type::Integer(32), 1)], + }, + MetadataNode { + id: 1, + values: vec![MetadataValue::NodeRef(0)], + }, + ]; + m.named_metadata = vec![NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![0, 1], + }]; + let errors = validate_ir(&m); + assert!( + !has_error(&errors, |e| matches!( + e, + LlvmIrError::DuplicateAttributeGroupId { .. } + | LlvmIrError::InvalidAttributeGroupRef { .. } + | LlvmIrError::DuplicateMetadataNodeId { .. } + | LlvmIrError::InvalidMetadataNodeRef { .. } + | LlvmIrError::MetadataRefCycle { .. } + )), + "valid metadata and attributes should not trigger errors: {errors:?}" + ); +} + +// ----------------------------------------------------------------------- +// Step 11.1: Bitcode round-trip NamedPtr regression test +// ----------------------------------------------------------------------- + +#[test] +fn bitcode_roundtrip_preserves_named_ptr_types() { + use crate::{parse_bitcode, write_bitcode}; + + let m = Module { + source_filename: None, + target_datalayout: None, + target_triple: None, + struct_types: vec![StructType { + name: "Qubit".into(), + is_opaque: true, + }], + globals: Vec::new(), + functions: vec![Function { + name: "test_func".into(), + return_type: Type::Void, + params: vec![Param { + ty: Type::NamedPtr("Qubit".into()), + name: Some("q".into()), + }], + is_declaration: false, + attribute_group_refs: Vec::new(), + basic_blocks: vec![BasicBlock { + name: "entry".into(), + instructions: vec![Instruction::Ret(None)], + }], + }], + attribute_groups: Vec::new(), + named_metadata: Vec::new(), + metadata_nodes: Vec::new(), + }; + + let orig_errors = 
validate_ir(&m);
    assert!(orig_errors.is_empty(), "original: {orig_errors:?}");

    let bc = write_bitcode(&m);
    let parsed = parse_bitcode(&bc).expect("parse failed");

    assert_eq!(
        m.functions[0].params[0].ty,
        parsed.functions[0].params[0].ty
    );

    let rt_errors = validate_ir(&parsed);
    assert!(rt_errors.is_empty(), "round-tripped: {rt_errors:?}");
}

// A `%Qubit`-valued inttoptr target is rejected; the pointer form `%Qubit*`
// is the valid spelling for typed-pointer QIR.
#[test]
fn named_ptr_inttoptr_cast_rejects_named_target_and_accepts_named_ptr_target() {
    let mut m = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: vec![StructType {
            name: "Qubit".into(),
            is_opaque: true,
        }],
        globals: Vec::new(),
        functions: vec![Function {
            name: "test_fn".into(),
            return_type: Type::Void,
            params: Vec::new(),
            is_declaration: false,
            attribute_group_refs: Vec::new(),
            basic_blocks: vec![BasicBlock {
                name: "entry".into(),
                instructions: vec![
                    Instruction::Cast {
                        op: CastKind::IntToPtr,
                        from_ty: Type::Integer(64),
                        to_ty: Type::Named("Qubit".into()),
                        value: Operand::IntConst(Type::Integer(64), 0),
                        result: "bad_qubit".into(),
                    },
                    Instruction::Ret(None),
                ],
            }],
        }],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let errors = validate_ir(&m);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::InvalidCast {
            cast_kind,
            from_ty,
            to_ty,
            ..
        } if cast_kind == "IntToPtr" && from_ty == "i64" && to_ty == "%Qubit"
    )));

    // Same cast with the pointer form of the named type must validate cleanly.
    m.functions[0].basic_blocks[0].instructions[0] = Instruction::Cast {
        op: CastKind::IntToPtr,
        from_ty: Type::Integer(64),
        to_ty: Type::NamedPtr("Qubit".into()),
        value: Operand::IntConst(Type::Integer(64), 0),
        result: "good_qubit".into(),
    };

    let errors = validate_ir(&m);
    assert!(
        errors.is_empty(),
        "valid named-pointer inttoptr cast failed: {errors:?}"
    );
}

// A call argument produced by inttoptr must carry the pointer type `%Qubit*`;
// the bare named type `%Qubit` is a type mismatch against the callee signature.
#[test]
fn named_ptr_call_rejects_named_inttoptr_operand() {
    let qubit_ty = Type::NamedPtr("Qubit".into());
    let mut m = Module {
        source_filename: None,
        target_datalayout: None,
        target_triple: None,
        struct_types: vec![StructType {
            name: "Qubit".into(),
            is_opaque: true,
        }],
        globals: Vec::new(),
        functions: vec![
            Function {
                name: "callee".into(),
                return_type: Type::Void,
                params: vec![Param {
                    ty: qubit_ty.clone(),
                    name: None,
                }],
                is_declaration: true,
                attribute_group_refs: Vec::new(),
                basic_blocks: Vec::new(),
            },
            Function {
                name: "caller".into(),
                return_type: Type::Void,
                params: Vec::new(),
                is_declaration: false,
                attribute_group_refs: Vec::new(),
                basic_blocks: vec![BasicBlock {
                    name: "entry".into(),
                    instructions: vec![
                        Instruction::Call {
                            return_ty: None,
                            callee: "callee".into(),
                            args: vec![(
                                qubit_ty,
                                Operand::IntToPtr(0, Type::Named("Qubit".into())),
                            )],
                            result: None,
                            attr_refs: Vec::new(),
                        },
                        Instruction::Ret(None),
                    ],
                }],
            },
        ],
        attribute_groups: Vec::new(),
        named_metadata: Vec::new(),
        metadata_nodes: Vec::new(),
    };

    let errors = validate_ir(&m);
    assert!(has_error(&errors, |e| matches!(
        e,
        LlvmIrError::TypeMismatch {
            instruction,
            expected,
            found,
            ..
        } if instruction == "Call @callee" && expected == "%Qubit*" && found == "%Qubit"
    )));

    // The helper constructor yields the correctly typed named-pointer operand.
    m.functions[1].basic_blocks[0].instructions[0] = Instruction::Call {
        return_ty: None,
        callee: "callee".into(),
        args: vec![(
            Type::NamedPtr("Qubit".into()),
            Operand::int_to_named_ptr(0, "Qubit"),
        )],
        result: None,
        attr_refs: Vec::new(),
    };

    let errors = validate_ir(&m);
    assert!(
        errors.is_empty(),
        "valid named-pointer call failed: {errors:?}"
    );
}
diff --git a/source/compiler/qsc_llvm/src/validation/qir.rs b/source/compiler/qsc_llvm/src/validation/qir.rs
new file mode 100644
index 0000000000..2cc89962d8
--- /dev/null
+++ b/source/compiler/qsc_llvm/src/validation/qir.rs
@@ -0,0 +1,1700 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! QIR profile validation module.
//!
//! Validates a [`Module`] against the QIR profile it claims to conform to,
//! returning structured [`QirProfileError`] diagnostics for all rule
//! violations found. Supports Base Profile v1 and Adaptive Profile variants.

#[cfg(test)]
mod tests;

use crate::model::Type;
use crate::model::{
    BasicBlock, BinOpKind, CastKind, Constant, Function, Instruction, Module, Operand,
};
use crate::qir::{
    self, DYNAMIC_QUBIT_MGMT_KEY, DYNAMIC_RESULT_MGMT_KEY, IRREVERSIBLE_ATTR,
    OUTPUT_LABELING_SCHEMA_ATTR, QIR_MAJOR_VERSION_KEY, QIR_MINOR_VERSION_KEY, QIR_PROFILES_ATTR,
    QirProfile, REQUIRED_NUM_QUBITS_ATTR, REQUIRED_NUM_RESULTS_ATTR, inspect,
};
use miette::Diagnostic;
use rustc_hash::FxHashSet;
use thiserror::Error;

/// Detected profile and capabilities from module introspection.
#[derive(Debug)]
pub struct DetectedProfile {
    pub profile: QirProfile,
    pub capabilities: Capabilities,
}

/// Capability flags extracted from module flags metadata.
+#[derive(Debug, Default)] +#[allow(clippy::struct_excessive_bools)] +pub struct Capabilities { + pub int_computations: Vec, + pub float_computations: Vec, + pub ir_functions: bool, + pub backwards_branching: u8, + pub multiple_target_branching: bool, + pub multiple_return_points: bool, + pub dynamic_qubit_management: bool, + pub dynamic_result_management: bool, + pub arrays: bool, +} + +/// QIR profile validation error with Miette diagnostic support. +#[derive(Clone, Debug, Diagnostic, Error, PartialEq, Eq)] +pub enum QirProfileError { + #[error("missing opaque `{type_name}` struct type definition")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MissingOpaqueType"), + help( + "QIR profiles using typed pointers require opaque Qubit and Result struct type definitions" + ) + )] + MissingOpaqueType { type_name: String }, + + #[error("expected exactly 1 entry point, found {count}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.EntryPointCount"), + help("a QIR module must define exactly one function with the `entry_point` attribute") + )] + EntryPointCount { count: usize }, + + #[error("missing required module flag `{flag_name}`")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MissingModuleFlag"), + help( + "all QIR profiles require qir_major_version, qir_minor_version, dynamic_qubit_management, and dynamic_result_management flags" + ) + )] + MissingModuleFlag { flag_name: String }, + + #[error("llvm.module.flags references missing metadata node !{node_ref}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.DanglingModuleFlagReference"), + help("remove or repair the dangling llvm.module.flags reference") + )] + DanglingModuleFlagReference { node_ref: u32 }, + + #[error("malformed llvm.module.flags node !{node_ref}: {reason}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MalformedModuleFlagNode"), + help("repair the malformed llvm.module.flags metadata node structure") + )] + MalformedModuleFlagNode { node_ref: u32, reason: String }, + + #[error("malformed module flag 
`{flag_name}`: {reason}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MalformedModuleFlag"), + help("repair the module flag payload so validation can interpret it") + )] + MalformedModuleFlag { flag_name: String, reason: String }, + + #[error("entry point missing required attribute `{attr_name}`")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MissingEntryPointAttr"), + help( + "entry points require: entry_point, qir_profiles, output_labeling_schema, and required_num_qubits/results unless the matching dynamic management flag is enabled" + ) + )] + MissingEntryPointAttr { attr_name: String }, + + #[error("qir_profiles attribute value `{found}` does not match detected profile `{expected}`")] + #[diagnostic(code("Qsc.Llvm.QirValidator.ProfileMismatch"), severity(Warning))] + ProfileMismatch { expected: String, found: String }, + + #[error( + "unsupported qir_profiles value `{profile_name}` with qir_major_version `{major_version}`" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.UnsupportedProfileMetadata"), + help( + "use a supported profile/major-version pair such as base_profile+1 or adaptive_profile+1/2" + ) + )] + UnsupportedProfileMetadata { + profile_name: String, + major_version: i64, + }, + + #[error("base profile requires `{flag_name}` to be false")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.BaseDynamicMgmtEnabled"), + help("the base profile does not support dynamic qubit or result management") + )] + BaseDynamicMgmtEnabled { flag_name: String }, + + #[error("entry point must have no parameters, found {param_count}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.EntryPointParams"), + help( + "QIR entry points take no parameters; qubit and result allocation is expressed through inttoptr casts" + ) + )] + EntryPointParams { param_count: usize }, + + #[error("entry point must return i64, found {found_type}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.EntryPointReturnType"), + help("the return type represents an exit code: 0 = success") + )] + 
EntryPointReturnType { found_type: String }, + + #[error( + "base profile requires exactly 4 basic blocks (entry, body, measurements, output), found {block_count}" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.BaseBlockCount"), + help( + "base profile programs follow a fixed 4-block structure connected by unconditional branches" + ) + )] + BaseBlockCount { block_count: usize }, + + #[error("instruction `{instruction}` is not allowed in {profile} profile{context}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.InstructionNotAllowed"), + help("{suggestion}") + )] + InstructionNotAllowed { + instruction: String, + profile: String, + context: String, + suggestion: String, + }, + + #[error( + "conditional branch is not allowed in base profile in function `{function}` block {block}" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.BaseConditionalBranch"), + help("base profile only allows unconditional branching between its 4 fixed blocks") + )] + BaseConditionalBranch { function: String, block: usize }, + + #[error("{instruction} requires `{capability}` capability flag in {location}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MissingCapability"), + help( + "set the `{capability}` module flag to enable this instruction in the adaptive profile" + ) + )] + MissingCapability { + instruction: String, + capability: String, + location: String, + }, + + #[error("missing required declaration for `{function_name}`")] + #[diagnostic(code("Qsc.Llvm.QirValidator.MissingDeclaration"), help("{help_text}"))] + MissingDeclaration { + function_name: String, + help_text: String, + }, + + #[error( + "incorrect signature for `{function_name}`: expected {expected_sig}, found {found_sig}" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.WrongSignature"), + help("runtime function signatures must match the QIR spec") + )] + WrongSignature { + function_name: String, + expected_sig: String, + found_sig: String, + }, + + #[error( + "`__quantum__rt__initialize` has incorrect 
signature: expected void(ptr), found {found_sig}" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.InitializeWrongSignature"), + help("initialize must accept a single ptr argument and return void") + )] + InitializeWrongSignature { found_sig: String }, + + #[error("QIS function `{function_name}` must return void in base profile, found {found_type}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.QisNonVoidReturn"), + help( + "base profile requires all QIS functions to return void; measurement results are communicated via writeonly result pointers" + ) + )] + QisNonVoidReturn { + function_name: String, + found_type: String, + }, + + #[error("measurement function `{function_name}` must have `irreversible` attribute")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MissingIrreversible"), + help("measurement functions must be marked irreversible per the QIR spec") + )] + MissingIrreversible { function_name: String }, + + #[error("measurement function `{function_name}` result parameter must be `writeonly`")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.MissingWriteonly"), + help( + "measurement result pointers must be writeonly to ensure results are only consumed by output recording functions" + ) + )] + MissingWriteonly { function_name: String }, + + #[error( + "base profile requires linear control flow: block {block_idx} does not jump to block {expected_next}" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.NonLinearFlow"), + help( + "base profile blocks must form a linear sequence: entry -> body -> measurements -> output" + ) + )] + NonLinearFlow { + block_idx: usize, + expected_next: usize, + }, + + #[error("control flow graph contains a cycle without `backwards_branching` capability")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.UnauthorizedCycle"), + help("set the `backwards_branching` module flag to enable loops in the adaptive profile") + )] + UnauthorizedCycle { function: String }, + + #[error( + "multiple return points without `multiple_return_points` 
capability (found {ret_count} ret instructions)" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.UnauthorizedMultipleReturns"), + help("set the `multiple_return_points` module flag to enable multiple ret statements") + )] + UnauthorizedMultipleReturns { function: String, ret_count: usize }, + + #[error("{feature} instructions are used but `{flag_name}` capability is not declared")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.CapabilityNotDeclared"), + help("add the `{flag_name}` module flag to declare this capability") + )] + CapabilityNotDeclared { feature: String, flag_name: String }, + + #[error("float width `{width_name}` is used but not declared in `float_computations`")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.FloatWidthNotDeclared"), + help( + "add `{width_name}` to the `float_computations` module flag or remove the float-typed IR surface" + ) + )] + FloatWidthNotDeclared { width_name: String }, + + #[error( + "`float_computations` is declared but the module contains no floating-point operations" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.FloatCapabilityWithoutOperation"), + help("remove the `float_computations` module flag or add a floating-point operation") + )] + FloatCapabilityWithoutOperation, + + #[error("array instruction `{instruction}` requires `arrays` capability in {location}")] + #[diagnostic( + code("Qsc.Llvm.QirValidator.ArraysNotEnabled"), + help( + "set the `arrays` module flag to enable array types and operations in the adaptive profile" + ) + )] + ArraysNotEnabled { + instruction: String, + location: String, + }, + + #[error( + "output-recording call `{function_name}` requires a string label operand in {location}, found {found_operand}" + )] + #[diagnostic( + code("Qsc.Llvm.QirValidator.InvalidOutputLabelOperand"), + help("use a global constant string or a getelementptr derived from one for output labels") + )] + InvalidOutputLabelOperand { + function_name: String, + location: String, + found_operand: String, + }, +} + 
+/// Result of profile validation. +#[derive(Debug)] +pub struct QirProfileValidation { + pub detected: DetectedProfile, + pub errors: Vec, +} + +/// Validate a [`Module`] against the QIR profile it claims to conform to. +/// +/// Returns the detected profile + capabilities and all errors found. +#[must_use] +pub fn validate_qir_profile(module: &Module) -> QirProfileValidation { + let (detected, malformed_flags, mut errors) = detect_profile(module); + errors.extend(validate_module_structure( + module, + &detected, + &malformed_flags, + )); + errors.extend(validate_entry_point(module, &detected)); + errors.extend(validate_instructions(module, &detected, &malformed_flags)); + errors.extend(validate_declarations(module, &detected)); + errors.extend(validate_output_recording_calls(module)); + errors.extend(validate_cfg(module, &detected, &malformed_flags)); + errors.extend(validate_consistency(module, &detected, &malformed_flags)); + QirProfileValidation { detected, errors } +} + +fn detect_profile(module: &Module) -> (DetectedProfile, FxHashSet, Vec) { + let mut errors = convert_module_flag_issues(&inspect::inspect_module_flag_metadata(module)); + let mut malformed_flags = FxHashSet::default(); + + let major_flag = inspect::inspect_module_flag_int(module, QIR_MAJOR_VERSION_KEY); + malformed_flags.extend(collect_malformed_flag_names(&major_flag.issues)); + errors.extend(convert_module_flag_issues(&major_flag.issues)); + + let minor_flag = inspect::inspect_module_flag_int(module, QIR_MINOR_VERSION_KEY); + malformed_flags.extend(collect_malformed_flag_names(&minor_flag.issues)); + errors.extend(convert_module_flag_issues(&minor_flag.issues)); + + let (profile_name, major, minor) = detect_profile_metadata(module); + + let profile = match (profile_name.as_deref(), major, minor) { + (Some("base_profile"), Some(1), _) => QirProfile::BaseV1, + (Some("adaptive_profile"), Some(1), _) => QirProfile::AdaptiveV1, + (Some("adaptive_profile"), Some(2), _) => QirProfile::AdaptiveV2, 
+ (Some(profile_name), Some(major_version), _) => { + errors.push(QirProfileError::UnsupportedProfileMetadata { + profile_name: profile_name.to_string(), + major_version, + }); + QirProfile::BaseV1 + } + _ => QirProfile::BaseV1, + }; + + let (capabilities, capability_malformed_flags, capability_errors) = + extract_capabilities(module); + malformed_flags.extend(capability_malformed_flags); + errors.extend(capability_errors); + + ( + DetectedProfile { + profile, + capabilities, + }, + malformed_flags, + errors, + ) +} + +fn detect_profile_metadata(module: &Module) -> (Option, Option, Option) { + // Try to read qir_profiles from entry-point attributes. + let profile_name = inspect::find_entry_point(module) + .and_then(|idx| inspect::get_function_attribute(module, idx, QIR_PROFILES_ATTR)) + .map(String::from); + + let major = inspect::inspect_module_flag_int(module, QIR_MAJOR_VERSION_KEY).value; + let minor = inspect::inspect_module_flag_int(module, QIR_MINOR_VERSION_KEY).value; + + (profile_name, major, minor) +} + +fn extract_capabilities( + module: &Module, +) -> (Capabilities, FxHashSet, Vec) { + let mut capabilities = Capabilities::default(); + let mut malformed_flags = FxHashSet::default(); + let mut errors = Vec::new(); + + let int_computations = inspect::inspect_module_flag_string_list(module, "int_computations"); + malformed_flags.extend(collect_malformed_flag_names(&int_computations.issues)); + errors.extend(convert_module_flag_issues(&int_computations.issues)); + capabilities.int_computations = int_computations.value.unwrap_or_default(); + + let float_computations = inspect::inspect_module_flag_string_list(module, "float_computations"); + malformed_flags.extend(collect_malformed_flag_names(&float_computations.issues)); + errors.extend(convert_module_flag_issues(&float_computations.issues)); + capabilities.float_computations = float_computations.value.unwrap_or_default(); + + let ir_functions = inspect::inspect_module_flag_bool(module, "ir_functions"); + 
malformed_flags.extend(collect_malformed_flag_names(&ir_functions.issues)); + errors.extend(convert_module_flag_issues(&ir_functions.issues)); + capabilities.ir_functions = ir_functions.value.unwrap_or(false); + + let backwards_branching = inspect::inspect_module_flag_int(module, "backwards_branching"); + malformed_flags.extend(collect_malformed_flag_names(&backwards_branching.issues)); + errors.extend(convert_module_flag_issues(&backwards_branching.issues)); + if let Some(value) = backwards_branching.value { + if let Ok(value) = u8::try_from(value) { + capabilities.backwards_branching = value; + } else { + malformed_flags.insert("backwards_branching".to_string()); + errors.push(QirProfileError::MalformedModuleFlag { + flag_name: "backwards_branching".into(), + reason: format!("expected an integer between 0 and 255, found `{value}`"), + }); + } + } + + let multiple_target_branching = + inspect::inspect_module_flag_bool(module, "multiple_target_branching"); + malformed_flags.extend(collect_malformed_flag_names( + &multiple_target_branching.issues, + )); + errors.extend(convert_module_flag_issues( + &multiple_target_branching.issues, + )); + capabilities.multiple_target_branching = multiple_target_branching.value.unwrap_or(false); + + let multiple_return_points = + inspect::inspect_module_flag_bool(module, "multiple_return_points"); + malformed_flags.extend(collect_malformed_flag_names(&multiple_return_points.issues)); + errors.extend(convert_module_flag_issues(&multiple_return_points.issues)); + capabilities.multiple_return_points = multiple_return_points.value.unwrap_or(false); + + let dynamic_qubit_management = + inspect::inspect_module_flag_bool(module, DYNAMIC_QUBIT_MGMT_KEY); + malformed_flags.extend(collect_malformed_flag_names( + &dynamic_qubit_management.issues, + )); + errors.extend(convert_module_flag_issues(&dynamic_qubit_management.issues)); + capabilities.dynamic_qubit_management = dynamic_qubit_management.value.unwrap_or(false); + + let 
dynamic_result_management = + inspect::inspect_module_flag_bool(module, DYNAMIC_RESULT_MGMT_KEY); + malformed_flags.extend(collect_malformed_flag_names( + &dynamic_result_management.issues, + )); + errors.extend(convert_module_flag_issues( + &dynamic_result_management.issues, + )); + capabilities.dynamic_result_management = dynamic_result_management.value.unwrap_or(false); + + let arrays = inspect::inspect_module_flag_bool(module, "arrays"); + malformed_flags.extend(collect_malformed_flag_names(&arrays.issues)); + errors.extend(convert_module_flag_issues(&arrays.issues)); + capabilities.arrays = arrays.value.unwrap_or(false); + + (capabilities, malformed_flags, errors) +} + +fn convert_module_flag_issues(issues: &[inspect::ModuleFlagIssue]) -> Vec { + issues + .iter() + .map(|issue| match issue { + inspect::ModuleFlagIssue::DanglingReference { node_ref } => { + QirProfileError::DanglingModuleFlagReference { + node_ref: *node_ref, + } + } + inspect::ModuleFlagIssue::MalformedNode { node_ref, reason } => { + QirProfileError::MalformedModuleFlagNode { + node_ref: *node_ref, + reason: (*reason).to_string(), + } + } + inspect::ModuleFlagIssue::InvalidBehavior { + flag_name, + node_id, + found, + } => QirProfileError::MalformedModuleFlag { + flag_name: flag_name.clone(), + reason: format!("node !{node_id} has non-integer merge behavior payload `{found}`"), + }, + inspect::ModuleFlagIssue::InvalidValue { + flag_name, + node_id, + expected, + found, + } => QirProfileError::MalformedModuleFlag { + flag_name: flag_name.clone(), + reason: format!("node !{node_id} expected {expected} payload, found `{found}`"), + }, + inspect::ModuleFlagIssue::InvalidStringListItem { + flag_name, + node_id, + index, + found, + } => QirProfileError::MalformedModuleFlag { + flag_name: flag_name.clone(), + reason: format!( + "node !{node_id} has non-string string-list item at index {index}: `{found}`" + ), + }, + }) + .collect() +} + +fn collect_malformed_flag_names(issues: 
&[inspect::ModuleFlagIssue]) -> FxHashSet<String> {
    issues
        .iter()
        .filter_map(|issue| issue.flag_name().map(str::to_string))
        .collect()
}

fn validate_module_structure(
    module: &Module,
    detected: &DetectedProfile,
    malformed_flags: &FxHashSet<String>,
) -> Vec<QirProfileError> {
    let mut v = Vec::new();

    // MS-01: Typed-pointer profiles need Qubit/Result struct types.
    if detected.profile.uses_typed_pointers() {
        let has_qubit = module.struct_types.iter().any(|s| s.name == "Qubit");
        let has_result = module.struct_types.iter().any(|s| s.name == "Result");
        if !has_qubit {
            v.push(QirProfileError::MissingOpaqueType {
                type_name: "Qubit".into(),
            });
        }
        if !has_result {
            v.push(QirProfileError::MissingOpaqueType {
                type_name: "Result".into(),
            });
        }
    }

    // MS-03: Exactly one entry point.
    let entry_count = inspect::count_entry_points(module);
    if entry_count != 1 {
        v.push(QirProfileError::EntryPointCount { count: entry_count });
    }

    // MF-01..04: Required module flags.
    check_required_flag(&mut v, module, QIR_MAJOR_VERSION_KEY);
    check_required_flag(&mut v, module, QIR_MINOR_VERSION_KEY);
    check_required_flag(&mut v, module, DYNAMIC_QUBIT_MGMT_KEY);
    check_required_flag(&mut v, module, DYNAMIC_RESULT_MGMT_KEY);

    // MF-05: qir_major_version behavior must be Error (1).
    let qir_major_behavior = inspect::inspect_module_flag_behavior(module, QIR_MAJOR_VERSION_KEY);
    v.extend(convert_module_flag_issues(&qir_major_behavior.issues));
    if let Some(behavior) = qir_major_behavior.value
        && behavior != qir::FLAG_BEHAVIOR_ERROR
    {
        v.push(QirProfileError::MissingModuleFlag {
            flag_name: format!(
                "{QIR_MAJOR_VERSION_KEY} behavior must be Error (1), found {behavior}"
            ),
        });
    }

    // MF-06: qir_minor_version behavior must be Max (7).
    let qir_minor_behavior = inspect::inspect_module_flag_behavior(module, QIR_MINOR_VERSION_KEY);
    v.extend(convert_module_flag_issues(&qir_minor_behavior.issues));
    if let Some(behavior) = qir_minor_behavior.value
        && behavior != qir::FLAG_BEHAVIOR_MAX
    {
        v.push(QirProfileError::MissingModuleFlag {
            flag_name: format!(
                "{QIR_MINOR_VERSION_KEY} behavior must be Max (7), found {behavior}"
            ),
        });
    }

    // AT-01..05: Entry point attribute checks.
    if let Some(ep_idx) = inspect::find_entry_point(module) {
        let func = &module.functions[ep_idx];
        check_entry_point_attrs(&mut v, module, ep_idx, func, detected, malformed_flags);
    }

    // DT-04/DT-05: Base profile must have dynamic_*_management = false.
    if detected.profile == QirProfile::BaseV1 {
        if detected.capabilities.dynamic_qubit_management {
            v.push(QirProfileError::BaseDynamicMgmtEnabled {
                flag_name: "dynamic_qubit_management".into(),
            });
        }
        if detected.capabilities.dynamic_result_management {
            v.push(QirProfileError::BaseDynamicMgmtEnabled {
                flag_name: "dynamic_result_management".into(),
            });
        }
    }

    v
}

fn check_required_flag(v: &mut Vec<QirProfileError>, module: &Module, key: &str) {
    if inspect::get_module_flag(module, key).is_none() {
        v.push(QirProfileError::MissingModuleFlag {
            flag_name: key.into(),
        });
    }
}

fn check_entry_point_attrs(
    v: &mut Vec<QirProfileError>,
    module: &Module,
    ep_idx: usize,
    _func: &Function,
    detected: &DetectedProfile,
    malformed_flags: &FxHashSet<String>,
) {
    // AT-01: entry_point attribute (implicitly satisfied if find_entry_point found it).

    // AT-02: qir_profiles matches.
+ let expected_name = detected.profile.profile_name(); + if let Some(actual) = inspect::get_function_attribute(module, ep_idx, QIR_PROFILES_ATTR) { + if actual != expected_name { + v.push(QirProfileError::ProfileMismatch { + expected: expected_name.into(), + found: actual.into(), + }); + } + } else { + v.push(QirProfileError::MissingEntryPointAttr { + attr_name: "qir_profiles".into(), + }); + } + + let require_qubit_count = !detected.capabilities.dynamic_qubit_management + && !malformed_flags.contains(DYNAMIC_QUBIT_MGMT_KEY); + let require_result_count = !detected.capabilities.dynamic_result_management + && !malformed_flags.contains(DYNAMIC_RESULT_MGMT_KEY); + + // AT-03 + AT-06: required_num_qubits (must exist and parse as u64 unless dynamic). + match inspect::get_function_attribute(module, ep_idx, REQUIRED_NUM_QUBITS_ATTR) { + None if require_qubit_count => { + v.push(QirProfileError::MissingEntryPointAttr { + attr_name: "required_num_qubits".into(), + }); + } + Some(val) if val.parse::().is_err() => { + v.push(QirProfileError::MissingEntryPointAttr { + attr_name: "required_num_qubits (must be a non-negative integer)".into(), + }); + } + _ => {} + } + + // AT-04 + AT-07: required_num_results (must exist and parse as u64 unless dynamic). + match inspect::get_function_attribute(module, ep_idx, REQUIRED_NUM_RESULTS_ATTR) { + None if require_result_count => { + v.push(QirProfileError::MissingEntryPointAttr { + attr_name: "required_num_results".into(), + }); + } + Some(val) if val.parse::().is_err() => { + v.push(QirProfileError::MissingEntryPointAttr { + attr_name: "required_num_results (must be a non-negative integer)".into(), + }); + } + _ => {} + } + + // AT-05: output_labeling_schema. 
+ if !inspect::has_function_attribute(module, ep_idx, OUTPUT_LABELING_SCHEMA_ATTR) { + v.push(QirProfileError::MissingEntryPointAttr { + attr_name: "output_labeling_schema".into(), + }); + } +} + +fn validate_entry_point(module: &Module, detected: &DetectedProfile) -> Vec { + let mut v = Vec::new(); + let Some(ep_idx) = inspect::find_entry_point(module) else { + return v; // Already reported as MS-03. + }; + let func = &module.functions[ep_idx]; + + // EP-01: No parameters. + if !func.params.is_empty() { + v.push(QirProfileError::EntryPointParams { + param_count: func.params.len(), + }); + } + + // EP-02: Return type = i64. + if func.return_type != Type::Integer(64) { + v.push(QirProfileError::EntryPointReturnType { + found_type: func.return_type.to_string(), + }); + } + + // EP-03: Base profile requires exactly 4 blocks. + if detected.profile == QirProfile::BaseV1 && func.basic_blocks.len() != 4 { + v.push(QirProfileError::BaseBlockCount { + block_count: func.basic_blocks.len(), + }); + } + + // EP-08 / CI-04: Base profile — all branches must be unconditional. + if detected.profile == QirProfile::BaseV1 { + for (bi, bb) in func.basic_blocks.iter().enumerate() { + for instr in &bb.instructions { + if matches!(instr, Instruction::Br { .. 
}) {
                    v.push(QirProfileError::BaseConditionalBranch {
                        function: func.name.clone(),
                        block: bi,
                    });
                }
            }
        }
    }

    v
}

fn validate_instructions(
    module: &Module,
    detected: &DetectedProfile,
    malformed_flags: &FxHashSet<String>,
) -> Vec<QirProfileError> {
    let mut v = Vec::new();
    let Some(ep_idx) = inspect::find_entry_point(module) else {
        return v;
    };
    let func = &module.functions[ep_idx];

    for (bi, bb) in func.basic_blocks.iter().enumerate() {
        for (ii, instr) in bb.instructions.iter().enumerate() {
            let context = format!(" in function '{}' block {bi} instruction {ii}", func.name);
            check_instruction_allowed(instr, detected, malformed_flags, &context, &mut v);
        }
    }

    v
}

#[allow(clippy::too_many_lines)]
fn check_instruction_allowed(
    instr: &Instruction,
    detected: &DetectedProfile,
    malformed_flags: &FxHashSet<String>,
    context: &str,
    v: &mut Vec<QirProfileError>,
) {
    let profile_name = detected.profile.profile_name();
    // Malformed capability flags suppress the matching capability errors so a
    // single broken flag does not cascade into spurious per-instruction reports.
    let int_flag_malformed = malformed_flags.contains("int_computations");
    let float_flag_malformed = malformed_flags.contains("float_computations");
    let backwards_branching_malformed = malformed_flags.contains("backwards_branching");
    let multiple_target_branching_malformed = malformed_flags.contains("multiple_target_branching");
    let arrays_flag_malformed = malformed_flags.contains("arrays");
    match instr {
        // Always allowed in all profiles.
        Instruction::Call { .. }
        | Instruction::Ret(_)
        | Instruction::Jump { .. }
        | Instruction::GetElementPtr { .. }
        | Instruction::Unreachable => {}

        // Cast — only IntToPtr allowed in base; expanded in adaptive.
        Instruction::Cast { op, .. } => match op {
            CastKind::IntToPtr => {} // Allowed in all profiles.
            CastKind::Zext | CastKind::Sext | CastKind::Trunc
                if !matches!(detected.profile, QirProfile::BaseV1)
                    && !int_flag_malformed
                    && !detected.capabilities.int_computations.is_empty() => {}
            CastKind::FpExt | CastKind::FpTrunc
                if !matches!(detected.profile, QirProfile::BaseV1)
                    && !float_flag_malformed
                    && !detected.capabilities.float_computations.is_empty() => {}
            CastKind::Sitofp | CastKind::Fptosi
                if !matches!(detected.profile, QirProfile::BaseV1)
                    && !int_flag_malformed
                    && !float_flag_malformed
                    && !detected.capabilities.int_computations.is_empty()
                    && !detected.capabilities.float_computations.is_empty() => {}
            CastKind::Zext | CastKind::Sext | CastKind::Trunc
                if !matches!(detected.profile, QirProfile::BaseV1) && int_flag_malformed => {}
            CastKind::FpExt | CastKind::FpTrunc
                if !matches!(detected.profile, QirProfile::BaseV1) && float_flag_malformed => {}
            CastKind::Sitofp | CastKind::Fptosi
                if !matches!(detected.profile, QirProfile::BaseV1)
                    && (int_flag_malformed || float_flag_malformed) => {}
            _ => {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: format!("cast {op:?}"),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: "only inttoptr casts are allowed in base profile; adaptive profiles require appropriate capability flags".into(),
                });
            }
        },

        // Conditional branch — not allowed in base.
        Instruction::Br { .. } => {
            if detected.profile == QirProfile::BaseV1 {
                // Already reported under CI-04 in Pass 2; avoid double-report.
            }
            // For adaptive — conditional branch is always allowed.
        }

        // BinOp — not allowed in base; adaptive depends on capabilities.
        Instruction::BinOp { op, .. } => {
            if detected.profile == QirProfile::BaseV1 {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: format!("binop {op:?}"),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: "binary operations are not allowed in base profile".into(),
                });
            } else if is_int_binop(op)
                && !int_flag_malformed
                && detected.capabilities.int_computations.is_empty()
            {
                v.push(QirProfileError::MissingCapability {
                    instruction: format!("integer binop {op:?}"),
                    capability: "int_computations".into(),
                    location: context.into(),
                });
            } else if is_float_binop(op)
                && !float_flag_malformed
                && detected.capabilities.float_computations.is_empty()
            {
                v.push(QirProfileError::MissingCapability {
                    instruction: format!("float binop {op:?}"),
                    capability: "float_computations".into(),
                    location: context.into(),
                });
            }
        }

        // ICmp — not allowed in base; needs int cap in adaptive.
        Instruction::ICmp { .. } => {
            if detected.profile == QirProfile::BaseV1 {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: "icmp".into(),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: "integer comparison is not allowed in base profile".into(),
                });
            } else if !int_flag_malformed && detected.capabilities.int_computations.is_empty() {
                v.push(QirProfileError::MissingCapability {
                    instruction: "icmp".into(),
                    capability: "int_computations".into(),
                    location: context.into(),
                });
            }
        }

        // FCmp — not allowed in base; needs float cap in adaptive.
        Instruction::FCmp { .. } => {
            if detected.profile == QirProfile::BaseV1 {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: "fcmp".into(),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: "float comparison is not allowed in base profile".into(),
                });
            } else if !float_flag_malformed && detected.capabilities.float_computations.is_empty() {
                v.push(QirProfileError::MissingCapability {
                    instruction: "fcmp".into(),
                    capability: "float_computations".into(),
                    location: context.into(),
                });
            }
        }

        // Phi — only adaptive with backwards_branching or int_computations.
        Instruction::Phi { .. } => {
            if detected.profile == QirProfile::BaseV1 {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: "phi".into(),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: "phi nodes are not allowed in base profile".into(),
                });
            } else if detected.capabilities.backwards_branching == 0
                && detected.capabilities.int_computations.is_empty()
                && !backwards_branching_malformed
                && !int_flag_malformed
            {
                v.push(QirProfileError::MissingCapability {
                    instruction: "phi".into(),
                    capability: "backwards_branching or int_computations".into(),
                    location: context.into(),
                });
            }
        }

        // Select — adaptive with int_computations.
        Instruction::Select { .. } => {
            if detected.profile == QirProfile::BaseV1 {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: "select".into(),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: "select is not allowed in base profile".into(),
                });
            } else if !int_flag_malformed && detected.capabilities.int_computations.is_empty() {
                v.push(QirProfileError::MissingCapability {
                    instruction: "select".into(),
                    capability: "int_computations".into(),
                    location: context.into(),
                });
            }
        }

        // Switch — adaptive with multiple_target_branching.
        Instruction::Switch { .. } => {
            if detected.profile == QirProfile::BaseV1 {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: "switch".into(),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: "switch is not allowed in base profile".into(),
                });
            } else if !multiple_target_branching_malformed
                && !detected.capabilities.multiple_target_branching
            {
                v.push(QirProfileError::MissingCapability {
                    instruction: "switch".into(),
                    capability: "multiple_target_branching".into(),
                    location: context.into(),
                });
            }
        }

        // Alloca, Load, Store — not in base; may appear in ir_functions context.
        Instruction::Alloca { .. } | Instruction::Load { .. } | Instruction::Store { .. } => {
            if detected.profile == QirProfile::BaseV1 {
                v.push(QirProfileError::InstructionNotAllowed {
                    instruction: instruction_name(instr).into(),
                    profile: profile_name.into(),
                    context: context.into(),
                    suggestion: format!(
                        "{} is not allowed in base profile",
                        instruction_name(instr)
                    ),
                });
            } else if !arrays_flag_malformed && !detected.capabilities.arrays {
                // AR-02: Alloca/Load/Store require arrays capability in adaptive profiles.
                v.push(QirProfileError::ArraysNotEnabled {
                    instruction: instruction_name(instr).into(),
                    location: context.into(),
                });
            }
        }
    }
}

fn is_int_binop(op: &BinOpKind) -> bool {
    matches!(
        op,
        BinOpKind::Add
            | BinOpKind::Sub
            | BinOpKind::Mul
            | BinOpKind::Sdiv
            | BinOpKind::Udiv
            | BinOpKind::Srem
            | BinOpKind::Urem
            | BinOpKind::Shl
            | BinOpKind::Ashr
            | BinOpKind::Lshr
            | BinOpKind::And
            | BinOpKind::Or
            | BinOpKind::Xor
    )
}

fn is_float_binop(op: &BinOpKind) -> bool {
    matches!(
        op,
        BinOpKind::Fadd | BinOpKind::Fsub | BinOpKind::Fmul | BinOpKind::Fdiv
    )
}

fn instruction_name(instr: &Instruction) -> &'static str {
    match instr {
        Instruction::Ret(_) => "ret",
        Instruction::Br { .. } => "br",
        Instruction::Jump { .. } => "jump",
        Instruction::BinOp { ..
} => "binop", + Instruction::ICmp { .. } => "icmp", + Instruction::FCmp { .. } => "fcmp", + Instruction::Cast { .. } => "cast", + Instruction::Call { .. } => "call", + Instruction::Phi { .. } => "phi", + Instruction::Alloca { .. } => "alloca", + Instruction::Load { .. } => "load", + Instruction::Store { .. } => "store", + Instruction::Select { .. } => "select", + Instruction::Switch { .. } => "switch", + Instruction::GetElementPtr { .. } => "getelementptr", + Instruction::Unreachable => "unreachable", + } +} + +#[allow(clippy::too_many_lines)] +fn validate_declarations(module: &Module, detected: &DetectedProfile) -> Vec { + let mut v = Vec::new(); + + // RT-01: __quantum__rt__initialize must be declared. + let has_initialize = module + .functions + .iter() + .any(|f| f.is_declaration && f.name == qir::rt::INITIALIZE); + if !has_initialize { + v.push(QirProfileError::MissingDeclaration { + function_name: qir::rt::INITIALIZE.into(), + help_text: "the runtime initialize function must be declared for all QIR profiles" + .into(), + }); + } + + // RT-02: If __quantum__rt__initialize is declared, validate signature void(ptr). + if let Some(init_func) = module + .functions + .iter() + .find(|f| f.is_declaration && f.name == qir::rt::INITIALIZE) + { + let ok = init_func.return_type == Type::Void && init_func.params.len() == 1; + if !ok { + v.push(QirProfileError::InitializeWrongSignature { + found_sig: format_function_sig(init_func), + }); + } + } + + // AP-MC-02: Adaptive profiles require __quantum__rt__read_result. 
+ if detected.profile != QirProfile::BaseV1 { + let has_read_result = module.functions.iter().any(|f| { + f.is_declaration + && f.name == qir::rt::READ_RESULT + && f.return_type == Type::Integer(1) + && f.params.len() == 1 + }); + if !has_read_result { + v.push(QirProfileError::MissingDeclaration { + function_name: qir::rt::READ_RESULT.into(), + help_text: + "adaptive profile requires __quantum__rt__read_result(ptr) → i1 declaration" + .into(), + }); + } + } + + // QIS-01: QIS functions should return void (base) or void/data type (adaptive). + // QIS-02: Measurement QIS functions must have irreversible attribute. + for (func_idx, func) in module.functions.iter().enumerate() { + if !func.is_declaration { + continue; + } + + if func.name.starts_with("__quantum__qis__") { + // Base profile: all QIS must return void. + if detected.profile == QirProfile::BaseV1 && func.return_type != Type::Void { + v.push(QirProfileError::QisNonVoidReturn { + function_name: func.name.clone(), + found_type: func.return_type.to_string(), + }); + } + + // QIS-02: Measurement functions must have irreversible attribute. + if func.name.to_lowercase().contains("measure") { + if !inspect::has_function_attribute(module, func_idx, IRREVERSIBLE_ATTR) { + v.push(QirProfileError::MissingIrreversible { + function_name: func.name.clone(), + }); + } + } + } + + // RT-03: Check output recording function signatures. 
+ if func.name == qir::rt::TUPLE_RECORD_OUTPUT || func.name == qir::rt::ARRAY_RECORD_OUTPUT { + // Expected: void(i64, ptr) + let ok = func.return_type == Type::Void + && func.params.len() == 2 + && func.params[0].ty == Type::Integer(64); + if !ok { + v.push(QirProfileError::WrongSignature { + function_name: func.name.clone(), + expected_sig: "void(i64, ptr)".into(), + found_sig: format_function_sig(func), + }); + } + } + + if func.name == qir::rt::RESULT_RECORD_OUTPUT { + // Expected: void(ptr, ptr) + let ok = func.return_type == Type::Void && func.params.len() == 2; + if !ok { + v.push(QirProfileError::WrongSignature { + function_name: func.name.clone(), + expected_sig: "void(ptr, ptr)".into(), + found_sig: format_function_sig(func), + }); + } + } + + if func.name == qir::rt::RESULT_ARRAY_RECORD_OUTPUT { + // Expected: void(i64, ptr, ptr) + let ok = func.return_type == Type::Void + && func.params.len() == 3 + && func.params[0].ty == Type::Integer(64); + if !ok { + v.push(QirProfileError::WrongSignature { + function_name: func.name.clone(), + expected_sig: "void(i64, ptr, ptr)".into(), + found_sig: format_function_sig(func), + }); + } + } + + // RT-04: __quantum__rt__bool_record_output signature void(i1, ptr). + if func.name == qir::rt::BOOL_RECORD_OUTPUT { + let ok = func.return_type == Type::Void + && func.params.len() == 2 + && func.params[0].ty == Type::Integer(1); + if !ok { + v.push(QirProfileError::WrongSignature { + function_name: func.name.clone(), + expected_sig: "void(i1, ptr)".into(), + found_sig: format_function_sig(func), + }); + } + } + + // RT-05: __quantum__rt__int_record_output signature void(i64, ptr). 
+ if func.name == qir::rt::INT_RECORD_OUTPUT { + let ok = func.return_type == Type::Void + && func.params.len() == 2 + && func.params[0].ty == Type::Integer(64); + if !ok { + v.push(QirProfileError::WrongSignature { + function_name: func.name.clone(), + expected_sig: "void(i64, ptr)".into(), + found_sig: format_function_sig(func), + }); + } + } + + // RT-06: __quantum__rt__double_record_output + if func.name == qir::rt::DOUBLE_RECORD_OUTPUT { + let ok = func.return_type == Type::Void + && func.params.len() == 2 + && func.params[0].ty == Type::Double; + if !ok { + v.push(QirProfileError::WrongSignature { + function_name: func.name.clone(), + expected_sig: "void(double, ptr)".into(), + found_sig: format_function_sig(func), + }); + } + } + } + + // RT-07: dynamic_qubit_management requires qubit_allocate. + if detected.capabilities.dynamic_qubit_management { + check_required_declaration( + &mut v, + module, + qir::rt::QUBIT_ALLOCATE, + "dynamic_qubit_management requires __quantum__rt__qubit_allocate declaration", + ); + } + + // RT-08: dynamic_qubit_management requires qubit_release. + if detected.capabilities.dynamic_qubit_management { + check_required_declaration( + &mut v, + module, + qir::rt::QUBIT_RELEASE, + "dynamic_qubit_management requires __quantum__rt__qubit_release declaration", + ); + } + + // RT-09: dynamic_result_management requires result_allocate. + if detected.capabilities.dynamic_result_management { + check_required_declaration( + &mut v, + module, + qir::rt::RESULT_ALLOCATE, + "dynamic_result_management requires __quantum__rt__result_allocate declaration", + ); + } + + // RT-10: dynamic_result_management requires result_release. + if detected.capabilities.dynamic_result_management { + check_required_declaration( + &mut v, + module, + qir::rt::RESULT_RELEASE, + "dynamic_result_management requires __quantum__rt__result_release declaration", + ); + } + + // AR-03: arrays capability requires array RT functions. 
+ if detected.capabilities.arrays { + if detected.capabilities.dynamic_qubit_management { + check_required_declaration( + &mut v, + module, + qir::rt::QUBIT_ARRAY_ALLOCATE, + "arrays + dynamic_qubit_management requires __quantum__rt__qubit_array_allocate declaration", + ); + check_required_declaration( + &mut v, + module, + qir::rt::QUBIT_ARRAY_RELEASE, + "arrays + dynamic_qubit_management requires __quantum__rt__qubit_array_release declaration", + ); + } + if detected.capabilities.dynamic_result_management { + check_required_declaration( + &mut v, + module, + qir::rt::RESULT_ARRAY_ALLOCATE, + "arrays + dynamic_result_management requires __quantum__rt__result_array_allocate declaration", + ); + check_required_declaration( + &mut v, + module, + qir::rt::RESULT_ARRAY_RELEASE, + "arrays + dynamic_result_management requires __quantum__rt__result_array_release declaration", + ); + } + } + + if module_uses_function(module, qir::rt::RESULT_ARRAY_RECORD_OUTPUT) { + check_required_declaration( + &mut v, + module, + qir::rt::RESULT_ARRAY_RECORD_OUTPUT, + "calls to __quantum__rt__result_array_record_output require a matching declaration", + ); + } + + v +} + +fn check_required_declaration( + v: &mut Vec, + module: &Module, + name: &str, + help: &str, +) { + let found = module + .functions + .iter() + .any(|f| f.is_declaration && f.name == name); + if !found { + v.push(QirProfileError::MissingDeclaration { + function_name: name.into(), + help_text: help.into(), + }); + } +} + +fn module_uses_function(module: &Module, name: &str) -> bool { + module + .functions + .iter() + .filter(|func| !func.is_declaration) + .flat_map(|func| &func.basic_blocks) + .flat_map(|block| &block.instructions) + .any(|instr| matches!(instr, Instruction::Call { callee, .. 
} if callee == name)) +} + +fn format_function_sig(func: &Function) -> String { + let params: Vec = func.params.iter().map(|p| p.ty.to_string()).collect(); + format!("{}({})", func.return_type, params.join(", ")) +} + +fn validate_output_recording_calls(module: &Module) -> Vec { + let mut v = Vec::new(); + + for func in module.functions.iter().filter(|func| !func.is_declaration) { + for (block_idx, block) in func.basic_blocks.iter().enumerate() { + for (instr_idx, instr) in block.instructions.iter().enumerate() { + let Instruction::Call { callee, args, .. } = instr else { + continue; + }; + + let Some(label_arg_index) = qir::output_label_arg_index(callee) else { + continue; + }; + + let Some((_, label_operand)) = args.get(label_arg_index) else { + continue; + }; + + if !is_string_label_operand(module, label_operand) { + v.push(QirProfileError::InvalidOutputLabelOperand { + function_name: callee.clone(), + location: format!( + "function '{}' block {block_idx} instruction {instr_idx}", + func.name + ), + found_operand: describe_operand(label_operand), + }); + } + } + } + } + + v +} +fn is_string_label_operand(module: &Module, operand: &Operand) -> bool { + match operand { + Operand::GlobalRef(name) => is_string_label_global(module, name), + Operand::GetElementPtr { ptr, .. 
} => is_string_label_global(module, ptr), + _ => false, + } +} + +fn is_string_label_global(module: &Module, name: &str) -> bool { + module.globals.iter().any(|global| { + global.name == name + && global.is_constant + && matches!(global.initializer, Some(Constant::CString(_))) + }) +} + +fn describe_operand(operand: &Operand) -> String { + match operand { + Operand::LocalRef(name) => format!("local %{name}"), + Operand::TypedLocalRef(name, ty) => format!("local %{name} ({ty})"), + Operand::IntConst(_, value) => format!("integer constant {value}"), + Operand::FloatConst(ty, value) => format!("{ty} constant {value}"), + Operand::NullPtr => "null pointer".into(), + Operand::IntToPtr(value, _) => format!("inttoptr({value})"), + Operand::GetElementPtr { ptr, .. } => format!("getelementptr from @{ptr}"), + Operand::GlobalRef(name) => format!("global @{name}"), + } +} + +fn validate_cfg( + module: &Module, + detected: &DetectedProfile, + malformed_flags: &FxHashSet, +) -> Vec { + let mut v = Vec::new(); + let Some(ep_idx) = inspect::find_entry_point(module) else { + return v; + }; + let func = &module.functions[ep_idx]; + + if func.basic_blocks.is_empty() { + return v; + } + + if detected.profile == QirProfile::BaseV1 { + // CF-01: Linear flow — each block jumps to the next sequentially. + validate_base_linear_flow(func, &mut v); + } else { + // CF-03: Cycle detection when backwards_branching = 0. + if detected.capabilities.backwards_branching == 0 + && !malformed_flags.contains("backwards_branching") + { + detect_cycles(func, &mut v); + } + // CF-04: Multiple ret only if multiple_return_points. 
+ if !detected.capabilities.multiple_return_points + && !malformed_flags.contains("multiple_return_points") + { + let ret_count = count_ret_instructions(func); + if ret_count > 1 { + v.push(QirProfileError::UnauthorizedMultipleReturns { + function: func.name.clone(), + ret_count, + }); + } + } + } + + v +} + +fn validate_base_linear_flow(func: &Function, v: &mut Vec) { + // Each block (except last) should end with an unconditional jump to the next block. + for (bi, bb) in func.basic_blocks.iter().enumerate() { + if bi == func.basic_blocks.len() - 1 { + continue; // Last block should end with ret, not checked here. + } + let term = bb.instructions.last(); + let next_idx = bi + 1; + let next_name = &func.basic_blocks[next_idx].name; + match term { + Some(Instruction::Jump { dest }) if dest == next_name => {} // OK + _ => { + v.push(QirProfileError::NonLinearFlow { + block_idx: bi, + expected_next: next_idx, + }); + } + } + } +} + +fn detect_cycles(func: &Function, v: &mut Vec) { + let n = func.basic_blocks.len(); + if n == 0 { + return; + } + + // Build block-name → index map. + let block_index: FxHashSet<_> = func + .basic_blocks + .iter() + .enumerate() + .map(|(i, bb)| (bb.name.as_str(), i)) + .collect::>() + .into_iter() + .collect(); + + // Build adjacency list. + let mut adj: Vec> = vec![Vec::new(); n]; + for (bi, bb) in func.basic_blocks.iter().enumerate() { + for dest_name in block_successors(bb) { + // Look up block index by name. + if let Some(&(_, idx)) = block_index.iter().find(|&&(name, _)| name == dest_name) { + adj[bi].push(idx); + } + } + } + + // DFS back-edge detection. 
+ let mut color = vec![0u8; n]; // 0=white, 1=gray, 2=black + let mut has_cycle = false; + dfs_cycle_check(0, &adj, &mut color, &mut has_cycle); + + if has_cycle { + v.push(QirProfileError::UnauthorizedCycle { + function: func.name.clone(), + }); + } +} + +fn dfs_cycle_check(node: usize, adj: &[Vec], color: &mut [u8], has_cycle: &mut bool) { + color[node] = 1; // Gray + for &next in &adj[node] { + if color[next] == 1 { + *has_cycle = true; + return; + } + if color[next] == 0 { + dfs_cycle_check(next, adj, color, has_cycle); + if *has_cycle { + return; + } + } + } + color[node] = 2; // Black +} + +fn block_successors(bb: &BasicBlock) -> Vec<&str> { + let Some(term) = bb.instructions.last() else { + return Vec::new(); + }; + match term { + Instruction::Jump { dest } => vec![dest.as_str()], + Instruction::Br { + true_dest, + false_dest, + .. + } => vec![true_dest.as_str(), false_dest.as_str()], + Instruction::Switch { + default_dest, + cases, + .. + } => { + let mut dests = vec![default_dest.as_str()]; + for (_, dest) in cases { + dests.push(dest.as_str()); + } + dests + } + _ => Vec::new(), + } +} + +fn count_ret_instructions(func: &Function) -> usize { + func.basic_blocks + .iter() + .flat_map(|bb| &bb.instructions) + .filter(|i| matches!(i, Instruction::Ret(_))) + .count() +} + +fn validate_consistency( + module: &Module, + detected: &DetectedProfile, + malformed_flags: &FxHashSet, +) -> Vec { + let mut v = Vec::new(); + + // Only applies to adaptive profiles. 
+ if detected.profile == QirProfile::BaseV1 { + return v; + } + + let used = scan_used_capabilities(module, detected); + let float_analysis = inspect::analyze_float_surface(module); + let used_float_widths = float_analysis.surface_width_names(); + let float_flag_present = inspect::get_module_flag(module, "float_computations").is_some() + && !malformed_flags.contains("float_computations"); + let declared_float_widths: FxHashSet<_> = detected + .capabilities + .float_computations + .iter() + .map(String::as_str) + .collect(); + + // CR-01: int instructions used → int_computations must be declared. + if used.has_int_instructions + && detected.capabilities.int_computations.is_empty() + && !malformed_flags.contains("int_computations") + { + v.push(QirProfileError::CapabilityNotDeclared { + feature: "integer computation".into(), + flag_name: "int_computations".into(), + }); + } + + // CR-02: any float-typed IR surface requires float_computations to be declared. + if !used_float_widths.is_empty() + && !float_flag_present + && !malformed_flags.contains("float_computations") + { + v.push(QirProfileError::CapabilityNotDeclared { + feature: "floating-point type usage".into(), + flag_name: "float_computations".into(), + }); + } + + // CR-02a: float_computations may only be declared when the module has a float op. + if float_flag_present && !float_analysis.has_float_op { + v.push(QirProfileError::FloatCapabilityWithoutOperation); + } + + // CR-02b: every used float width must be declared in float_computations. + if float_flag_present { + for width_name in used_float_widths { + if !declared_float_widths.contains(width_name) { + v.push(QirProfileError::FloatWidthNotDeclared { + width_name: width_name.to_string(), + }); + } + } + } + + // CR-03: Non-entry-point function definitions → ir_functions must be true. 
+ if used.has_ir_functions + && !detected.capabilities.ir_functions + && !malformed_flags.contains("ir_functions") + { + v.push(QirProfileError::CapabilityNotDeclared { + feature: "non-entry-point function definition".into(), + flag_name: "ir_functions".into(), + }); + } + + // CR-05: switch used → multiple_target_branching must be true. + if used.has_switch + && !detected.capabilities.multiple_target_branching + && !malformed_flags.contains("multiple_target_branching") + { + v.push(QirProfileError::CapabilityNotDeclared { + feature: "switch".into(), + flag_name: "multiple_target_branching".into(), + }); + } + + // CR-06: Multiple ret → multiple_return_points must be true. + if used.ret_count > 1 + && !detected.capabilities.multiple_return_points + && !malformed_flags.contains("multiple_return_points") + { + v.push(QirProfileError::CapabilityNotDeclared { + feature: "multiple return point".into(), + flag_name: "multiple_return_points".into(), + }); + } + + // CR-07: phi used → backwards_branching > 0. + if used.has_phi + && detected.capabilities.backwards_branching == 0 + && !malformed_flags.contains("backwards_branching") + { + v.push(QirProfileError::CapabilityNotDeclared { + feature: "phi".into(), + flag_name: "backwards_branching".into(), + }); + } + + v +} + +#[allow(clippy::struct_excessive_bools)] +struct UsedCapabilities { + has_int_instructions: bool, + has_switch: bool, + has_phi: bool, + has_ir_functions: bool, + ret_count: usize, +} + +fn scan_used_capabilities(module: &Module, _detected: &DetectedProfile) -> UsedCapabilities { + let mut used = UsedCapabilities { + has_int_instructions: false, + has_switch: false, + has_phi: false, + has_ir_functions: false, + ret_count: 0, + }; + + let ep_idx = inspect::find_entry_point(module); + + // CR-03: Check for non-entry-point defined functions. 
+ for (i, func) in module.functions.iter().enumerate() { + if !func.is_declaration && Some(i) != ep_idx { + used.has_ir_functions = true; + } + } + + // Scan instructions in all defined functions. + for func in &module.functions { + if func.is_declaration { + continue; + } + for bb in &func.basic_blocks { + for instr in &bb.instructions { + match instr { + Instruction::BinOp { op, .. } => { + if is_int_binop(op) { + used.has_int_instructions = true; + } + } + Instruction::ICmp { .. } | Instruction::Select { .. } => { + used.has_int_instructions = true; + } + Instruction::Cast { op, .. } => match op { + CastKind::Zext | CastKind::Sext | CastKind::Trunc => { + used.has_int_instructions = true; + } + CastKind::Sitofp | CastKind::Fptosi => { + used.has_int_instructions = true; + } + _ => {} + }, + Instruction::Phi { .. } => { + used.has_phi = true; + } + Instruction::Switch { .. } => { + used.has_switch = true; + } + Instruction::Ret(_) => { + used.ret_count += 1; + } + _ => {} + } + } + } + } + + used +} diff --git a/source/compiler/qsc_llvm/src/validation/qir/tests.rs b/source/compiler/qsc_llvm/src/validation/qir/tests.rs new file mode 100644 index 0000000000..831f7ca399 --- /dev/null +++ b/source/compiler/qsc_llvm/src/validation/qir/tests.rs @@ -0,0 +1,1418 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use super::*; +use crate::model::Type; +use crate::model::*; +use crate::text::reader::parse_module; + +// -- Test helper: build a minimal valid base profile v1 module -- +fn base_v1_module() -> Module { + parse_module( + r#"%Result = type opaque +%Qubit = type opaque + +@0 = internal constant [4 x i8] c"0_r\00" + +define i64 @ENTRYPOINT__main() #0 { +entry: + call void @__quantum__rt__initialize(ptr null) + br label %body +body: + call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + br label %measurements +measurements: + call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + br label %output +output: + call void @__quantum__rt__result_record_output(ptr inttoptr (i64 0 to ptr), ptr @0) + ret i64 0 +} + +declare void @__quantum__rt__initialize(ptr) +declare void @__quantum__qis__h__body(%Qubit*) +declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1 +declare void @__quantum__rt__tuple_record_output(i64, ptr) +declare void @__quantum__rt__result_record_output(ptr, ptr) + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="1" "required_num_results"="1" } +attributes #1 = { "irreversible" } + +!llvm.module.flags = !{!0, !1, !2, !3} +!0 = !{i32 1, !"qir_major_version", i32 1} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} +"#, + ) + .expect("base_v1_module IR should parse") +} + +// -- Test helper: build a minimal valid adaptive v2 module -- +fn adaptive_v2_1_module() -> Module { + parse_module( + r#"@0 = internal constant [4 x i8] c"0_r\00" + +define i64 @ENTRYPOINT__main() #0 { +entry: + call void @__quantum__rt__initialize(ptr null) + call void @__quantum__qis__h__body(ptr inttoptr (i64 0 to ptr)) + ret i64 0 +} + +declare void @__quantum__rt__initialize(ptr) +declare i1 @__quantum__rt__read_result(ptr) +declare void 
@__quantum__qis__h__body(ptr) +declare void @__quantum__rt__result_record_output(ptr, ptr) + +attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="0" } + +!llvm.module.flags = !{!0, !1, !2, !3} +!0 = !{i32 1, !"qir_major_version", i32 2} +!1 = !{i32 7, !"qir_minor_version", i32 0} +!2 = !{i32 1, !"dynamic_qubit_management", i1 false} +!3 = !{i32 1, !"dynamic_result_management", i1 false} +"#, + ) + .expect("adaptive_v2_1_module IR should parse") +} + +fn has_error bool>(result: &QirProfileValidation, pred: F) -> bool { + result.errors.iter().any(pred) +} + +fn set_float_computations(module: &mut Module, widths: &[&str]) { + let values = widths + .iter() + .map(|width| MetadataValue::String((*width).to_string())) + .collect(); + + set_flag_value( + module, + "float_computations", + 5, + MetadataValue::SubList(values), + ); +} + +fn set_flag_value(module: &mut Module, key: &str, behavior: i64, value: MetadataValue) { + if let Some(node) = module.metadata_nodes.iter_mut().find(|node| { + node.values + .iter() + .any(|entry| matches!(entry, MetadataValue::String(text) if text == key)) + }) { + node.values[0] = MetadataValue::Int(Type::Integer(32), behavior); + node.values[2] = value; + return; + } + + let next_id = module + .metadata_nodes + .iter() + .map(|node| node.id) + .max() + .unwrap_or(0) + + 1; + + module.metadata_nodes.push(MetadataNode { + id: next_id, + values: vec![ + MetadataValue::Int(Type::Integer(32), behavior), + MetadataValue::String(key.to_string()), + value, + ], + }); + + if let Some(module_flags) = module + .named_metadata + .iter_mut() + .find(|metadata| metadata.name == "llvm.module.flags") + { + module_flags.node_refs.push(next_id); + } else { + module.named_metadata.push(NamedMetadata { + name: "llvm.module.flags".to_string(), + node_refs: vec![next_id], + }); + } +} + +fn set_bool_flag(module: &mut Module, key: &str, value: bool) { + set_flag_value( + 
module, + key, + 1, + MetadataValue::Int(Type::Integer(1), i64::from(value)), + ); +} + +fn set_backwards_branching_flag(module: &mut Module, value: u8) { + set_flag_value( + module, + "backwards_branching", + 7, + MetadataValue::Int(Type::Integer(8), i64::from(value)), + ); +} + +fn push_declaration(module: &mut Module, name: &str, return_type: Type, params: Vec) { + module.functions.push(Function { + name: name.to_string(), + return_type, + params: params + .into_iter() + .map(|ty| Param { ty, name: None }) + .collect(), + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); +} + +fn assert_adaptive_single_float_width_is_allowed(width_name: &str, ty: &Type) { + let mut m = adaptive_v2_1_module(); + set_float_computations(&mut m, &[width_name]); + + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: ty.clone(), + lhs: Operand::float_const(ty.clone(), 1.0), + rhs: Operand::float_const(ty.clone(), 2.0), + result: format!("{width_name}_sum"), + }, + ); + + let result = validate_qir_profile(&m); + assert_eq!( + result.detected.capabilities.float_computations, + vec![width_name.to_string()] + ); + assert!( + result.errors.is_empty(), + "expected {width_name}-only float surface to validate, got: {:#?}", + result.errors + ); +} + +// ---- Profile detection tests ---- + +#[test] +fn detect_base_v1_profile() { + let m = base_v1_module(); + let result = validate_qir_profile(&m); + assert_eq!(result.detected.profile, QirProfile::BaseV1); +} + +#[test] +fn detect_adaptive_v2_profile() { + let m = adaptive_v2_1_module(); + let result = validate_qir_profile(&m); + assert_eq!(result.detected.profile, QirProfile::AdaptiveV2); +} + +// ---- Base profile valid module ---- + +#[test] +fn base_v1_valid_module_no_violations() { + let m = base_v1_module(); + let result = validate_qir_profile(&m); + assert!( + 
result.errors.is_empty(), + "expected no errors, got: {:#?}", + result.errors + ); +} + +// ---- Adaptive valid module ---- + +#[test] +fn adaptive_v2_1_valid_module_no_violations() { + let m = adaptive_v2_1_module(); + let result = validate_qir_profile(&m); + assert!( + result.errors.is_empty(), + "expected no errors, got: {:#?}", + result.errors + ); +} + +// ---- MS-01: Missing struct types ---- + +#[test] +fn base_v1_missing_qubit_struct() { + let mut m = base_v1_module(); + m.struct_types.retain(|s| s.name != "Qubit"); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::MissingOpaqueType { .. } + ))); +} + +// ---- MS-03: Multiple entry points ---- + +#[test] +fn multiple_entry_points_violation() { + let mut m = base_v1_module(); + // Add a second entry point. + m.functions.push(Function { + name: "ENTRYPOINT__other".into(), + return_type: Type::Integer(64), + params: Vec::new(), + is_declaration: false, + attribute_group_refs: vec![0], + basic_blocks: vec![BasicBlock { + name: "entry".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }], + }); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::EntryPointCount { .. 
} + ))); +} + +// ---- MF-01: Missing module flags ---- + +#[test] +fn missing_module_flags_violation() { + let mut m = base_v1_module(); + m.named_metadata.clear(); + m.metadata_nodes.clear(); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingModuleFlag { flag_name } if flag_name == "qir_major_version") + )); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingModuleFlag { flag_name } if flag_name == "qir_minor_version") + )); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingModuleFlag { flag_name } if flag_name == "dynamic_qubit_management") + )); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingModuleFlag { flag_name } if flag_name == "dynamic_result_management") + )); +} + +// ---- EP-01: Entry point with parameters ---- + +#[test] +fn entry_point_with_params_violation() { + let mut m = base_v1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].params.push(Param { + ty: Type::Integer(32), + name: None, + }); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::EntryPointParams { .. } + ))); +} + +// ---- EP-02: Wrong return type ---- + +#[test] +fn entry_point_wrong_return_type() { + let mut m = base_v1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].return_type = Type::Void; + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::EntryPointReturnType { .. } + ))); +} + +// ---- EP-03: Wrong block count for base ---- + +#[test] +fn base_v1_wrong_block_count() { + let mut m = base_v1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + // Remove blocks to have only 2. 
+ m.functions[ep_idx].basic_blocks.truncate(2); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::BaseBlockCount { .. } + ))); +} + +// ---- CI-01: BinOp in base profile ---- + +#[test] +fn base_v1_binop_not_allowed() { + let mut m = base_v1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + // Insert an Add instruction in the body block. + m.functions[ep_idx].basic_blocks[1].instructions.insert( + 0, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "sum".into(), + }, + ); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::InstructionNotAllowed { .. } + ))); +} + +// ---- CI-04: Conditional branch in base profile ---- + +#[test] +fn base_v1_conditional_branch_violation() { + let mut m = base_v1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + // Replace a jump with conditional branch. + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 1), + true_dest: "body".into(), + false_dest: "body".into(), + }; + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::BaseConditionalBranch { .. } + ))); +} + +// ---- AP-CI-02: Int instructions without capability ---- + +#[test] +fn adaptive_int_binop_without_capability() { + let mut m = adaptive_v2_1_module(); + // No int_computations flag in metadata → should fail. 
+ let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::BinOp { + op: BinOpKind::Add, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 1), + rhs: Operand::IntConst(Type::Integer(64), 2), + result: "sum".into(), + }, + ); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingCapability { capability, .. } if capability == "int_computations") + )); +} + +// ---- AP-CI-03: Float instructions without capability ---- + +#[test] +fn adaptive_float_binop_without_capability() { + let mut m = adaptive_v2_1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Double, + lhs: Operand::float_const(Type::Double, 1.0), + rhs: Operand::float_const(Type::Double, 2.0), + result: "fsum".into(), + }, + ); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingCapability { capability, .. } if capability == "float_computations") + )); +} + +#[test] +fn adaptive_float_signature_without_capability_triggers_cr_02() { + let mut m = adaptive_v2_1_module(); + m.functions.push(Function { + name: "use_double".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Double, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); + + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::CapabilityNotDeclared { flag_name, .. 
} if flag_name == "float_computations") + )); +} + +#[test] +fn adaptive_undeclared_float_width_triggers_allow_list_violation() { + let mut m = adaptive_v2_1_module(); + set_float_computations(&mut m, &["half"]); + + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Double, + lhs: Operand::float_const(Type::Double, 1.0), + rhs: Operand::float_const(Type::Double, 2.0), + result: "fsum".into(), + }, + ); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::FloatWidthNotDeclared { width_name } if width_name == "double" + ))); +} + +#[test] +fn adaptive_float_capability_without_operation_triggers_contract_violation() { + let mut m = adaptive_v2_1_module(); + set_float_computations(&mut m, &["double"]); + m.functions.push(Function { + name: "use_double".to_string(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Double, + name: None, + }], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::FloatCapabilityWithoutOperation + ))); +} + +#[test] +fn adaptive_over_declared_float_widths_are_allowed() { + let mut m = adaptive_v2_1_module(); + set_float_computations(&mut m, &["half", "double"]); + + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Double, + lhs: Operand::float_const(Type::Double, 1.0), + rhs: Operand::float_const(Type::Double, 2.0), + result: "fsum".into(), + }, + ); + + let result = validate_qir_profile(&m); + assert!( + result.errors.is_empty(), + "expected no errors, got: {:#?}", + result.errors + ); +} + +#[test] +fn adaptive_half_only_float_width_is_allowed() { + 
assert_adaptive_single_float_width_is_allowed("half", &Type::Half); +} + +#[test] +fn adaptive_float_only_float_width_is_allowed() { + assert_adaptive_single_float_width_is_allowed("float", &Type::Float); +} + +#[test] +fn adaptive_double_only_float_width_is_allowed() { + assert_adaptive_single_float_width_is_allowed("double", &Type::Double); +} + +// ---- AP-CI-04: Switch without capability ---- + +#[test] +fn adaptive_switch_without_capability() { + let mut m = adaptive_v2_1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + // Replace ret with switch + extra blocks. + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Switch { + ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(64), 0), + default_dest: "exit".into(), + cases: Vec::new(), + }; + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "exit".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingCapability { capability, .. } if capability == "multiple_target_branching") + )); +} + +// ---- RT-01: Missing rt::initialize ---- + +#[test] +fn missing_rt_initialize() { + let mut m = base_v1_module(); + m.functions + .retain(|f| f.name != "__quantum__rt__initialize"); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingDeclaration { function_name, .. 
} if function_name == "__quantum__rt__initialize") + )); +} + +// ---- AP-MC-02: Missing read_result for adaptive ---- + +#[test] +fn adaptive_missing_read_result() { + let mut m = adaptive_v2_1_module(); + m.functions + .retain(|f| f.name != "__quantum__rt__read_result"); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingDeclaration { function_name, .. } if function_name == "__quantum__rt__read_result") + )); +} + +// ---- QIS-01: QIS non-void return in base ---- + +#[test] +fn base_v1_qis_non_void_return() { + let mut m = base_v1_module(); + // Change h gate to return i1. + if let Some(f) = m + .functions + .iter_mut() + .find(|f| f.name == "__quantum__qis__h__body") + { + f.return_type = Type::Integer(1); + } + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::QisNonVoidReturn { .. } + ))); +} + +// ---- DT-04: Base profile with dynamic_qubit_management = true ---- + +#[test] +fn base_v1_dynamic_qubit_management_true() { + let mut m = base_v1_module(); + // Change dynamic_qubit_management flag to 1. + if let Some(node) = m.metadata_nodes.iter_mut().find(|n| { + n.values + .iter() + .any(|v| matches!(v, MetadataValue::String(s) if s == "dynamic_qubit_management")) + }) { + node.values[2] = MetadataValue::Int(Type::Integer(1), 1); + } + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::BaseDynamicMgmtEnabled { .. } + ))); +} + +// ---- CF-01: Non-linear flow in base profile ---- + +#[test] +fn base_v1_non_linear_flow() { + let mut m = base_v1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + // Make block 0 jump to block 2 (skipping block 1). 
+ let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Jump { + dest: "measurements".into(), + }; + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::NonLinearFlow { .. } + ))); +} + +// ---- CF-03: Cycle detection ---- + +#[test] +fn adaptive_cycle_without_backwards_branching() { + let mut m = adaptive_v2_1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + // Create a cycle: entry → loop → entry (back-edge). + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Jump { + dest: "loop".into(), + }; + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "loop".into(), + instructions: vec![Instruction::Jump { + dest: "entry".into(), + }], + }); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::UnauthorizedCycle { .. 
} + ))); +} + +#[test] +fn adaptive_cycle_with_backwards_branching_is_allowed() { + let mut m = adaptive_v2_1_module(); + set_backwards_branching_flag(&mut m, 1); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Jump { + dest: "loop".into(), + }; + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "loop".into(), + instructions: vec![Instruction::Jump { + dest: "entry".into(), + }], + }); + + let result = validate_qir_profile(&m); + assert!( + result.errors.is_empty(), + "expected backwards_branching-enabled cycle to validate, got: {:#?}", + result.errors + ); +} + +// ---- CF-04: Multiple ret without capability ---- + +#[test] +fn adaptive_multiple_ret_without_capability() { + let mut m = adaptive_v2_1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 1), + true_dest: "then".into(), + false_dest: "else".into(), + }; + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "then".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }); + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "else".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 1, + )))], + }); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::UnauthorizedMultipleReturns { .. 
} + ))); +} + +#[test] +fn adaptive_multiple_ret_with_capability_is_allowed() { + let mut m = adaptive_v2_1_module(); + set_bool_flag(&mut m, "multiple_return_points", true); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 1), + true_dest: "then".into(), + false_dest: "else".into(), + }; + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "then".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }); + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "else".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 1, + )))], + }); + + let result = validate_qir_profile(&m); + assert!( + result.errors.is_empty(), + "expected multiple_return_points-enabled control flow to validate, got: {:#?}", + result.errors + ); +} + +#[test] +fn unsupported_profile_version_pair_is_reported_explicitly() { + let mut m = base_v1_module(); + set_flag_value( + &mut m, + "qir_major_version", + 1, + MetadataValue::Int(Type::Integer(32), 2), + ); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::UnsupportedProfileMetadata { + profile_name, + major_version, + } if profile_name == "base_profile" && *major_version == 2 + ))); +} + +#[test] +fn dangling_module_flag_reference_is_reported_without_hiding_later_flags() { + let mut m = adaptive_v2_1_module(); + m.named_metadata[0].node_refs.insert(0, 999); + + let result = validate_qir_profile(&m); + assert_eq!(result.detected.profile, QirProfile::AdaptiveV2); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::DanglingModuleFlagReference { node_ref } if *node_ref == 999 + ))); + assert!( + !has_error(&result, |e| matches!( + e, + 
QirProfileError::MissingModuleFlag { flag_name } if flag_name == "qir_major_version" + )), + "dangling refs should not hide later valid qir_major_version flags" + ); +} + +#[test] +fn malformed_float_capability_flag_is_reported_instead_of_missing_capability() { + let mut m = adaptive_v2_1_module(); + set_flag_value( + &mut m, + "float_computations", + 5, + MetadataValue::SubList(vec![MetadataValue::Int(Type::Integer(32), 1)]), + ); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::BinOp { + op: BinOpKind::Fadd, + ty: Type::Double, + lhs: Operand::float_const(Type::Double, 1.0), + rhs: Operand::float_const(Type::Double, 2.0), + result: "fsum".into(), + }, + ); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::MalformedModuleFlag { flag_name, .. } if flag_name == "float_computations" + ))); + assert!(!has_error(&result, |e| matches!( + e, + QirProfileError::MissingCapability { capability, .. } if capability == "float_computations" + ))); + assert!(!has_error(&result, |e| matches!( + e, + QirProfileError::CapabilityNotDeclared { flag_name, .. 
} if flag_name == "float_computations" + ))); +} + +#[test] +fn malformed_multiple_return_points_flag_is_reported_instead_of_defaulting_to_missing() { + let mut m = adaptive_v2_1_module(); + set_flag_value( + &mut m, + "multiple_return_points", + 1, + MetadataValue::String("true".to_string()), + ); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Br { + cond_ty: Type::Integer(1), + cond: Operand::IntConst(Type::Integer(1), 1), + true_dest: "then".into(), + false_dest: "else".into(), + }; + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "then".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }); + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "else".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 1, + )))], + }); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::MalformedModuleFlag { flag_name, .. } if flag_name == "multiple_return_points" + ))); + assert!(!has_error(&result, |e| matches!( + e, + QirProfileError::UnauthorizedMultipleReturns { .. } + ))); + assert!(!has_error(&result, |e| matches!( + e, + QirProfileError::CapabilityNotDeclared { flag_name, .. 
} if flag_name == "multiple_return_points" + ))); +} + +// ---- CR-01: Int instructions without capability flag ---- + +#[test] +fn adaptive_consistency_int_without_flag() { + let mut m = adaptive_v2_1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::ICmp { + pred: crate::model::IntPredicate::Eq, + ty: Type::Integer(64), + lhs: Operand::IntConst(Type::Integer(64), 0), + rhs: Operand::IntConst(Type::Integer(64), 1), + result: "cmp".into(), + }, + ); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::CapabilityNotDeclared { flag_name, .. } if flag_name == "int_computations") + )); +} + +// ---- CR-05: Switch without multiple_target_branching ---- + +#[test] +fn adaptive_consistency_switch_without_flag() { + let mut m = adaptive_v2_1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + let last = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions[last] = Instruction::Switch { + ty: Type::Integer(64), + value: Operand::IntConst(Type::Integer(64), 0), + default_dest: "exit".into(), + cases: Vec::new(), + }; + m.functions[ep_idx].basic_blocks.push(BasicBlock { + name: "exit".into(), + instructions: vec![Instruction::Ret(Some(Operand::IntConst( + Type::Integer(64), + 0, + )))], + }); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::CapabilityNotDeclared { flag_name, .. } if flag_name == "multiple_target_branching") + )); +} + +// ---- AT-02: Missing qir_profiles attribute ---- + +#[test] +fn missing_qir_profiles_attribute() { + let mut m = base_v1_module(); + // Remove qir_profiles from attribute group. 
+ m.attribute_groups[0] + .attributes + .retain(|a| !matches!(a, Attribute::KeyValue(k, _) if k == "qir_profiles")); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingEntryPointAttr { attr_name } if attr_name == "qir_profiles") + )); +} + +// ---- AT-03: Missing required_num_qubits ---- + +#[test] +fn missing_required_num_qubits() { + let mut m = base_v1_module(); + m.attribute_groups[0] + .attributes + .retain(|a| !matches!(a, Attribute::KeyValue(k, _) if k == "required_num_qubits")); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingEntryPointAttr { attr_name } if attr_name == "required_num_qubits") + )); +} + +#[test] +fn dynamic_management_flags_allow_missing_required_counts() { + let mut m = adaptive_v2_1_module(); + set_bool_flag(&mut m, "dynamic_qubit_management", true); + set_bool_flag(&mut m, "dynamic_result_management", true); + m.attribute_groups[0].attributes.retain(|attr| { + !matches!(attr, Attribute::KeyValue(key, _) if key == "required_num_qubits" || key == "required_num_results") + }); + + push_declaration(&mut m, qir::rt::QUBIT_ALLOCATE, Type::Ptr, Vec::new()); + push_declaration(&mut m, qir::rt::QUBIT_RELEASE, Type::Void, vec![Type::Ptr]); + push_declaration(&mut m, qir::rt::RESULT_ALLOCATE, Type::Ptr, Vec::new()); + push_declaration(&mut m, qir::rt::RESULT_RELEASE, Type::Void, vec![Type::Ptr]); + + let result = validate_qir_profile(&m); + assert!( + result.errors.is_empty(), + "expected dynamic-management entry point counts to be optional, got: {:#?}", + result.errors + ); +} + +// ---- RT-03: Wrong output recording signature ---- + +#[test] +fn wrong_tuple_record_output_signature() { + let mut m = base_v1_module(); + if let Some(f) = m + .functions + .iter_mut() + .find(|f| f.name == "__quantum__rt__tuple_record_output") + { + f.params.clear(); // Wrong: no params. 
+ } + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::WrongSignature { .. } + ))); +} + +// ---- RT-03: result_record_output wrong return type ---- + +#[test] +fn result_record_output_wrong_return_type() { + let mut m = base_v1_module(); + if let Some(f) = m + .functions + .iter_mut() + .find(|f| f.name == "__quantum__rt__result_record_output") + { + f.return_type = Type::Integer(64); // Wrong: should be void. + } + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::WrongSignature { function_name, .. } if function_name == "__quantum__rt__result_record_output") + )); +} + +// ---- RT-03: result_record_output wrong param count ---- + +#[test] +fn result_record_output_wrong_param_count() { + let mut m = base_v1_module(); + if let Some(f) = m + .functions + .iter_mut() + .find(|f| f.name == "__quantum__rt__result_record_output") + { + f.params = vec![Param { + ty: Type::Ptr, + name: None, + }]; // Wrong: 1 param instead of 2. + } + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::WrongSignature { function_name, .. } if function_name == "__quantum__rt__result_record_output") + )); +} + +#[test] +fn result_array_record_output_wrong_signature() { + let mut m = adaptive_v2_1_module(); + push_declaration( + &mut m, + qir::rt::RESULT_ARRAY_RECORD_OUTPUT, + Type::Void, + vec![Type::Ptr, Type::Ptr], + ); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::WrongSignature { function_name, .. } + if function_name == qir::rt::RESULT_ARRAY_RECORD_OUTPUT + ))); +} + +// ---- RT-03: array_record_output wrong return type ---- + +#[test] +fn array_record_output_wrong_return_type() { + let mut m = base_v1_module(); + // Add an array_record_output with wrong return type. 
+ m.functions.push(Function { + name: "__quantum__rt__array_record_output".into(), + return_type: Type::Integer(64), // Wrong: should be void. + params: vec![ + Param { + ty: Type::Integer(64), + name: None, + }, + Param { + ty: Type::Ptr, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::WrongSignature { function_name, .. } if function_name == "__quantum__rt__array_record_output") + )); +} + +// ---- RT-03: array_record_output wrong param count ---- + +#[test] +fn array_record_output_wrong_param_count() { + let mut m = base_v1_module(); + // Add an array_record_output with wrong param count. + m.functions.push(Function { + name: "__quantum__rt__array_record_output".into(), + return_type: Type::Void, + params: vec![Param { + ty: Type::Integer(64), + name: None, + }], // Wrong: 1 param instead of 2. + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::WrongSignature { function_name, .. } if function_name == "__quantum__rt__array_record_output") + )); +} + +// ---- RT-02: initialize wrong signature ---- + +#[test] +fn initialize_wrong_sig_triggers_rt_02() { + let mut m = base_v1_module(); + if let Some(f) = m + .functions + .iter_mut() + .find(|f| f.name == "__quantum__rt__initialize") + { + f.return_type = Type::Integer(32); // Wrong: should be void. + } + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::InitializeWrongSignature { .. } + ))); +} + +// ---- RT-04: bool_record_output wrong signature ---- + +#[test] +fn bool_record_output_wrong_sig_triggers_rt_04() { + let mut m = adaptive_v2_1_module(); + // Add a bool_record_output with incorrect signature. 
+ m.functions.push(Function { + name: "__quantum__rt__bool_record_output".into(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::Integer(64), // Wrong: should be i1. + name: None, + }, + Param { + ty: Type::Ptr, + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::WrongSignature { function_name, .. } if function_name == "__quantum__rt__bool_record_output") + )); +} + +// ---- RT-07: qubit_allocate missing when dynamic_qubit_management ---- + +#[test] +fn qubit_allocate_missing_when_dynamic_mgmt_triggers_rt_07() { + let mut m = adaptive_v2_1_module(); + // Enable dynamic_qubit_management. + if let Some(node) = m.metadata_nodes.iter_mut().find(|n| { + n.values + .iter() + .any(|v| matches!(v, MetadataValue::String(s) if s == "dynamic_qubit_management")) + }) { + node.values[2] = MetadataValue::Int(Type::Integer(1), 1); + } + // Do not add qubit_allocate declaration. + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingDeclaration { function_name, .. } if function_name == "__quantum__rt__qubit_allocate") + )); +} + +// ---- RT-09: result_allocate missing when dynamic_result_management ---- + +#[test] +fn result_allocate_missing_when_dynamic_mgmt_triggers_rt_09() { + let mut m = adaptive_v2_1_module(); + // Enable dynamic_result_management. + if let Some(node) = m.metadata_nodes.iter_mut().find(|n| { + n.values + .iter() + .any(|v| matches!(v, MetadataValue::String(s) if s == "dynamic_result_management")) + }) { + node.values[2] = MetadataValue::Int(Type::Integer(1), 1); + } + // Do not add result_allocate declaration. + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingDeclaration { function_name, .. 
} if function_name == "__quantum__rt__result_allocate") + )); +} + +// ---- AR-02: alloca in adaptive without arrays flag ---- + +#[test] +fn arrays_instructions_without_flag_trigger_ar_02() { + let mut m = adaptive_v2_1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + // Insert an alloca instruction. + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::Alloca { + ty: Type::Integer(64), + result: "alloc".into(), + }, + ); + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::ArraysNotEnabled { .. } + ))); +} + +#[test] +fn arrays_instructions_with_arrays_flag_are_allowed() { + let mut m = adaptive_v2_1_module(); + set_bool_flag(&mut m, "arrays", true); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + m.functions[ep_idx].basic_blocks[0].instructions.insert( + 1, + Instruction::Alloca { + ty: Type::Integer(64), + result: "alloc".into(), + }, + ); + + let result = validate_qir_profile(&m); + assert!( + result.errors.is_empty(), + "expected arrays-enabled alloca to validate, got: {:#?}", + result.errors + ); +} + +#[test] +fn result_array_record_output_requires_declaration_when_used() { + let mut m = adaptive_v2_1_module(); + set_bool_flag(&mut m, "arrays", true); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + let insert_at = m.functions[ep_idx].basic_blocks[0].instructions.len() - 1; + m.functions[ep_idx].basic_blocks[0].instructions.insert( + insert_at, + Instruction::Call { + return_ty: None, + callee: qir::rt::RESULT_ARRAY_RECORD_OUTPUT.to_string(), + args: vec![ + (Type::Integer(64), Operand::IntConst(Type::Integer(64), 1)), + (Type::Ptr, Operand::IntToPtr(0, Type::Ptr)), + (Type::Ptr, Operand::GlobalRef("0".into())), + ], + result: None, + attr_refs: Vec::new(), + }, + ); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::MissingDeclaration { function_name, .. 
} + if function_name == qir::rt::RESULT_ARRAY_RECORD_OUTPUT + ))); +} + +#[test] +fn result_record_output_requires_string_label_operand() { + let mut m = base_v1_module(); + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + let output_block = m.functions[ep_idx] + .basic_blocks + .iter_mut() + .find(|block| block.name == "output") + .expect("output block"); + + let Instruction::Call { args, .. } = &mut output_block.instructions[0] else { + panic!("expected result_record_output call"); + }; + args[1] = (Type::Ptr, Operand::NullPtr); + + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::InvalidOutputLabelOperand { function_name, .. } + if function_name == qir::rt::RESULT_RECORD_OUTPUT + ))); +} + +// ---- QIS-02: measurement missing irreversible ---- + +#[test] +fn measurement_missing_irreversible_triggers_qis_02() { + let mut m = base_v1_module(); + // Rename measurement function to contain "measure" so it triggers QIS-02 check, + // and remove its irreversible attribute reference. + if let Some(f) = m + .functions + .iter_mut() + .find(|f| f.name == "__quantum__qis__m__body") + { + f.name = "__quantum__qis__measure__body".to_string(); + f.attribute_group_refs.clear(); + } + // Also update the call instruction in the entry point to use the new name. + let ep_idx = qir::find_entry_point(&m).expect("entry point"); + for bb in &mut m.functions[ep_idx].basic_blocks { + for instr in &mut bb.instructions { + if let Instruction::Call { callee, .. } = instr + && callee == "__quantum__qis__m__body" + { + *callee = "__quantum__qis__measure__body".to_string(); + } + } + } + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::MissingIrreversible { .. 
} + ))); +} + +// ---- AT-06: required_num_qubits non-integer ---- + +#[test] +fn required_num_qubits_non_integer_triggers_at_06() { + let mut m = base_v1_module(); + // Replace required_num_qubits value with a non-integer string. + for ag in &mut m.attribute_groups { + for attr in &mut ag.attributes { + if let Attribute::KeyValue(k, v) = attr + && k == "required_num_qubits" + { + *v = "not_a_number".to_string(); + } + } + } + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingEntryPointAttr { attr_name } if attr_name.contains("required_num_qubits")) + )); +} + +// ---- AT-05: missing output_labeling_schema ---- + +#[test] +fn missing_output_labeling_schema_triggers_at_05() { + let mut m = base_v1_module(); + // Remove output_labeling_schema from attribute group. + m.attribute_groups[0] + .attributes + .retain(|a| !matches!(a, Attribute::StringAttr(s) if s == "output_labeling_schema")); + let result = validate_qir_profile(&m); + assert!(has_error( + &result, + |e| matches!(e, QirProfileError::MissingEntryPointAttr { attr_name } if attr_name == "output_labeling_schema") + )); +} + +// ---- DT-05: base profile with dynamic_result_management = true ---- + +#[test] +fn base_dynamic_result_mgmt_enabled_triggers_dt_05() { + let mut m = base_v1_module(); + // Change dynamic_result_management flag to 1. + if let Some(node) = m.metadata_nodes.iter_mut().find(|n| { + n.values + .iter() + .any(|v| matches!(v, MetadataValue::String(s) if s == "dynamic_result_management")) + }) { + node.values[2] = MetadataValue::Int(Type::Integer(1), 1); + } + let result = validate_qir_profile(&m); + assert!(has_error(&result, |e| matches!( + e, + QirProfileError::BaseDynamicMgmtEnabled { .. 
} + ))); +} diff --git a/source/fuzz/Cargo.toml b/source/fuzz/Cargo.toml index 8240cddc8a..c8b7d9c61d 100644 --- a/source/fuzz/Cargo.toml +++ b/source/fuzz/Cargo.toml @@ -15,6 +15,7 @@ cargo-fuzz = true [dependencies] libfuzzer-sys = { workspace = true, optional = true } qsc = { path = "../compiler/qsc" } +qsc_llvm = { path = "../compiler/qsc_llvm" } [target.'cfg(not(any(target_family = "wasm")))'.dependencies] allocator = { path = "../allocator" } @@ -36,3 +37,9 @@ name = "qasm" path = "fuzz_targets/qasm.rs" test = false doc = false + +[[bin]] +name = "qir" +path = "fuzz_targets/qir.rs" +test = false +doc = false diff --git a/source/fuzz/README.md b/source/fuzz/README.md index 9e7bef7657..3ddc0d7a6e 100644 --- a/source/fuzz/README.md +++ b/source/fuzz/README.md @@ -1,420 +1,73 @@ -# Fuzzing +## Overview -Based on [Fuzzing with cargo-fuzz](https://rust-fuzz.github.io/book/cargo-fuzz.html). +Use two separate workflows from the repository root: -For running locally you need the following steps. - -(**On Windows use [WSL](https://learn.microsoft.com/windows/wsl/).** Tested in WSL Ubuntu 22.04) +- `qir` is the real libFuzzer target. Run it through `cargo +nightly fuzz` and pass `--fuzz-dir source/fuzz`. +- `qir_matrix` is a deterministic replay binary. Run it through `cargo run -p fuzz --bin qir_matrix`; do not invoke it through `cargo fuzz`. ## Prerequisites +For libFuzzer runs: + ```bash rustup install nightly -rustup default nightly - cargo install cargo-fuzz ``` -## Running - -**NOTE:** All the commands below are executed in WSL in the directory that contains this "fuzz" directory -in a clean local copy of the repo (the whole repo is not built). - -```bash -cargo fuzz list # Optional. See the available fuzzing targets. -# compile # This fuzzing target fuzzes the `compile()` function. -cargo fuzz run compile --features do_fuzz -- -seed_inputs=@fuzz/seed_inputs/compile/list.txt - # Build and run the fuzzing target "compile". -# The build takes a few minutes. 
You may get an impression that the build takes place twice, -# that is expected. -# The run/fuzzing can last indefinitely without bumping into a panic. Stop with . - -cargo fuzz run compile --features do_fuzz -- -help=1 # Optional. See the available run settings. - -# Optional. Run fuzzing for 10 runs at most, 5 seconds at most, generate ASCII-only fuzzing sequences. -cargo fuzz run compile --features do_fuzz -- -seed_inputs=@fuzz/seed_inputs/compile/list.txt -runs=10 -max_total_time=5 -only_ascii=1 -``` - -## Purifying the Bugs Found with Fuzzing - -
The commands below were executed in a branch based on the following commit in "main" (click this line). - -```log -commit e51a8b6f145be23fc2358b2cf0bab6707a7a46a0 (origin/main, origin/HEAD, main) -Author: Bill Ticehurst -Date: Wed Apr 19 10:42:03 2023 -0700 - - Fix mapping of spans for non-ASCII code (#182) - - This builds on the branch for the PR at - https://github.com/microsoft/qdk/pull/180 (which fixes the code - sharing issue with non-ASCII chars), not not strictly dependent. - - The excessive comments on the `mapUtf8UnitsToUtf16Units` function should - outline why this is needed and what it fixes. -``` - -
- -The fuzzing run `cargo fuzz run compile`, if hits a panic, reports the panic message - -```log -thread '' panicked at 'local variable should have inferred type', \ - .../qsharp/source/compiler/qsc_frontend/src/typeck/rules.rs:326:30 -``` - -Among the last few lines the log lists the following commands of interest: - -```log -Reproduce with: - cargo fuzz run compile fuzz/artifacts/compile/crash-22fc59256083904ead44f3ce8f5f04a251d7cc23 -Minimize test case with: - cargo fuzz tmin compile fuzz/artifacts/compile/crash-22fc59256083904ead44f3ce8f5f04a251d7cc23 -``` - -The first thing you may typically do is to look at the input that caused the panic: -`cat fuzz/artifacts/compile/crash-22fc59256083904ead44f3ce8f5f04a251d7cc23`. -The input may be longer than sufficient to cause the panic. So the next thing you may want to do -is to shorten the input (see the "Minimize test case with" above), but it is recommended to -add `-r 10000` after `tmin` (which results in a longer run but much shorter input, in any case the run takes within one minute): -`cargo fuzz tmin -r 10000 compile fuzz/artifacts/compile/crash-22fc59256083904ead44f3ce8f5f04a251d7cc23`. -This command makes a number of runs with shorter input to figure out a shorter sequence that causes the panic. -The log fragments of interest are in the end: - -```log -Minimized artifact: - fuzz/artifacts/compile/minimized-from-b665e6267c297608e85c5948481cd353107a07fa -``` - -```log -Reproduce with: - cargo fuzz run compile fuzz/artifacts/compile/minimized-from-b665e6267c297608e85c5948481cd353107a07fa -``` - -**NOTE:** This command of automated input shortening can end up in a _different panic_. -That panic can be a new bug found or a previously known bug. 
- -Make sure that you are still on track, reproduce the panic of interest with the shortened input (see "Reproduce with" command above): +For deterministic replay on macOS: ```bash -cargo fuzz run compile --features do_fuzz fuzz/artifacts/compile/minimized-from-b665e6267c297608e85c5948481cd353107a07fa +brew install llvm@14 llvm@15 llvm@16 llvm@21 ``` -Right below the panic message the log gives you a stack trace hint: -`note: run with 'RUST_BACKTRACE=1' environment variable to display a backtrace`. - -You can enable the stack trace display, when reproducing the panic, like this: +## List Targets ```bash -RUST_BACKTRACE=1 cargo fuzz run compile --features do_fuzz fuzz/artifacts/compile/minimized-from-b665e6267c297608e85c5948481cd353107a07fa -``` - -(you repeat the repro command but you set the environment variable `RUST_BACKTRACE` for that run only). - -**NOTE:** See the stack trace shown not in the end of the repro log, but immediately after the panic message. - -
Example (click this line). - -```log -thread 'unnamed' panicked at 'local variable should have inferred type', /mnt/c/ed/dev/QSharpCompiler/qsharp-runtime/qsharp/source/compiler/qsc_frontend/src/typeck/rules.rs:326:30 -stack backtrace: - 0: rust_begin_unwind - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/std/src/panicking.rs:577:5 - 1: core::panicking::panic_fmt - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/core/src/panicking.rs:67:14 - 2: core::panicking::panic_display - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/core/src/panicking.rs:150:5 - 3: core::panicking::panic_str - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/core/src/panicking.rs:134:5 - 4: core::option::expect_failed - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/core/src/option.rs:2025:5 - 5: core::option::Option{T}::expect - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/core/src/option.rs:913:21 - 6: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:326:30 - 7: qsc_frontend::typeck::rules::Context::infer_binop - at ./src/typeck/rules.rs:445:32 - 8: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:217:56 - 9: qsc_frontend::typeck::rules::Context::infer_binop - at ./src/typeck/rules.rs:444:32 - 10: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:217:56 - 11: qsc_frontend::typeck::rules::Context::infer_update - at ./src/typeck/rules.rs:509:38 - 12: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 13: qsc_frontend::typeck::rules::Context::infer_update - at ./src/typeck/rules.rs:509:38 - 14: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 15: qsc_frontend::typeck::rules::Context::infer_update - at ./src/typeck/rules.rs:509:38 - 16: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 17: qsc_frontend::typeck::rules::Context::infer_update - 
at ./src/typeck/rules.rs:509:38 - 18: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 19: qsc_frontend::typeck::rules::Context::infer_update - at ./src/typeck/rules.rs:509:38 - 20: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 21: qsc_frontend::typeck::rules::Context::infer_update - at ./src/typeck/rules.rs:509:38 - 22: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 23: qsc_frontend::typeck::rules::Context::infer_update - at ./src/typeck/rules.rs:509:38 - 24: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 25: qsc_frontend::typeck::rules::Context::infer_update - at ./src/typeck/rules.rs:509:38 - 26: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:368:27 - 27: qsc_frontend::typeck::rules::Context::infer_expr - at ./src/typeck/rules.rs:351:26 - 28: qsc_frontend::typeck::rules::Context::infer_stmt - at ./src/typeck/rules.rs:172:27 - 29: qsc_frontend::typeck::rules::Context::infer_block - at ./src/typeck/rules.rs:143:35 - 30: qsc_frontend::typeck::rules::Context::infer_spec - at ./src/typeck/rules.rs:106:21 - 31: qsc_frontend::typeck::rules::spec - at ./src/typeck/rules.rs:610:5 - 32: qsc_frontend::typeck::check::Checker::check_spec - at ./src/typeck/check.rs:105:22 - 33: {qsc_frontend::typeck::check::Checker as qsc_ast::visit::Visitor}::visit_callable_decl - at ./src/typeck/check.rs:131:48 - 34: qsc_ast::visit::walk_item - at /mnt/c/ed/dev/QSharpCompiler/qsharp-runtime/qsharp/source/compiler/qsc_ast/src/visit.rs:94:37 - 35: qsc_ast::visit::Visitor::visit_item - at /mnt/c/ed/dev/QSharpCompiler/qsharp-runtime/qsharp/source/compiler/qsc_ast/src/visit.rs:20:9 - 36: qsc_ast::visit::walk_namespace::{{closure}} - at /mnt/c/ed/dev/QSharpCompiler/qsharp-runtime/qsharp/source/compiler/qsc_ast/src/visit.rs:86:41 - 37: {core::slice::iter::Iter{T} as core::iter::traits::iterator::Iterator}::for_each - at 
/rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/core/src/slice/iter/macros.rs:201:21 - 38: qsc_ast::visit::walk_namespace - at /mnt/c/ed/dev/QSharpCompiler/qsharp-runtime/qsharp/source/compiler/qsc_ast/src/visit.rs:86:5 - 39: qsc_ast::visit::Visitor::visit_namespace - at /mnt/c/ed/dev/QSharpCompiler/qsharp-runtime/qsharp/source/compiler/qsc_ast/src/visit.rs:16:9 - 40: {qsc_frontend::typeck::check::Checker as qsc_ast::visit::Visitor}::visit_package - at ./src/typeck/check.rs:118:13 - 41: qsc_frontend::compile::typeck_all - at ./src/compile.rs:318:5 - 42: qsc_frontend::compile::compile - at ./src/compile.rs:175:28 - 43: compile::_::__libfuzzer_sys_run - at ./fuzz/fuzz_targets/compile.rs:10:17 - 44: rust_fuzzer_test_input - at /home/rokuzmin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/libfuzzer-sys-0.4.6/src/lib.rs:224:17 - 45: libfuzzer_sys::test_input_wrap::{{closure}} - at /home/rokuzmin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/libfuzzer-sys-0.4.6/src/lib.rs:61:9 - 46: std::panicking::try::do_call - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/std/src/panicking.rs:485:40 - 47: __rust_try - 48: std::panicking::try - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/std/src/panicking.rs:449:19 - 49: std::panic::catch_unwind - at /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/std/src/panic.rs:140:14 - 50: LLVMFuzzerTestOneInput - at /home/rokuzmin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/libfuzzer-sys-0.4.6/src/lib.rs:59:22 - 51: _ZN6fuzzer6Fuzzer15ExecuteCallbackEPKhm - at /home/rokuzmin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/libfuzzer-sys-0.4.6/libfuzzer/FuzzerLoop.cpp:612:13 - 52: _ZN6fuzzer10RunOneTestEPNS_6FuzzerEPKcm - at /home/rokuzmin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/libfuzzer-sys-0.4.6/libfuzzer/FuzzerDriver.cpp:324:6 - 53: _ZN6fuzzer12FuzzerDriverEPiPPPcPFiPKhmE - at 
/home/rokuzmin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/libfuzzer-sys-0.4.6/libfuzzer/FuzzerDriver.cpp:860:9 - 54: main - at /home/rokuzmin/.cargo/registry/src/index.crates.io-6f17d22bba15001f/libfuzzer-sys-0.4.6/libfuzzer/FuzzerMain.cpp:20:10 - 55: __libc_start_main - at /build/glibc-SzIz7B/glibc-2.31/csu/../csu/libc-start.c:308:16 - 56: _start -note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace. -==16693== ERROR: libFuzzer: deadly signal +cargo +nightly fuzz list --fuzz-dir source/fuzz ``` -
- -After the steps above you typically get all the necessary information to file (or to work on) a bug - the short enough input and the panic stack trace. - -If the input is still too long then you may want to shorten it manually (e.g. remove the Q# code comments from the Q# input). - -If you believe that the input is still longer than sufficient to reproduce the panic, e.g. the panic complains about a local variable in the Q# input, -and in the Q# input you have a dozen of functions with a few dozens of nested scopes with local variables, then you will likely want to break in the debugger -upon panic and see the particular local variable that caused the panic. - -To achieve that, you need to rebuild the fuzzing binary with the debugging information (`--dev`): -`cargo fuzz build --dev compile`. -The resulting binary "compile" should be in the "debug", not "release", directory - -```bash -ls fuzz/target/x86_64-unknown-linux-gnu/debug/ -# ... compile ... -``` +## Run LibFuzzer -In your WSL session go to the root directory ("qsharp") of this repo and launch VSCode +Run the real QIR fuzz target against the repo-root corpus: ```bash -code . +cargo +nightly fuzz run --fuzz-dir source/fuzz qir --features do_fuzz -- -runs=200 -max_total_time=30 ``` -(Assuming your VSCode has CodeLLDB extension installed) - -- Click the "Run and Debug" view on the left (or press ``). -- Click the "create a launch.json file" link. Select debugger "LLDB". -- Feel free to reply "No" to the question - "Cargo.toml has been detected in the workspace. - Would you like to generate launch configurations for its targets?". - -
Change the contents of "launch.json" to look like this (click this line). - -```json -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "type": "lldb", - "request": "launch", - "name": "Debug", - "program": "${workspaceFolder}/fuzz/target/x86_64-unknown-linux-gnu/debug/compile", - "args": [ - "fuzz/artifacts/compile/minimized-from-b665e6267c297608e85c5948481cd353107a07fa" - ], - "cwd": "${workspaceFolder}" - } - ] -} -``` - -
- -- Press `` to run the debugging session. The debugger will stop upon panic. -- Look at the call stack, click the stack frames of interest and inspect the local variables and - parameters in those frames to figure out the exact input fragment that caused the panic. - -Then you can manually minimize the input around the fragment of interest. - -## Adding More Fuzzing Targets +Inspect available libFuzzer options: ```bash -cargo fuzz add -cargo fuzz list # Optional. See the available fuzzing targets. -# Edit the "fuzz/fuzz_targets/.rs". -cargo fuzz build # Optional. Build the fuzzing targets. -cargo fuzz run # Build and run. -# See "Running" section for fine-tuning the runs. +cargo +nightly fuzz run --fuzz-dir source/fuzz qir --features do_fuzz -- -help=1 ``` -## Adding More Seed Inputs for Fuzzing - -Add more files with input sequences to the -"fuzz/seed_inputs/\/" directory and add their paths to the list in -"fuzz/seed_inputs/\/list.txt". - -Details +On macOS, if the default AddressSanitizer-backed smoke stalls before the harness starts, use this local diagnostic variant to confirm the corpus and harness path independently: ```bash -cargo fuzz run compile --features do_fuzz -- -help=1 2>&1 | grep seed_inputs -# seed_inputs 0 A comma-separated list of input files to use as an additional seed corpus. -# Alternatively, an "@" followed by the name of a file containing the comma-separated list. +cargo +nightly fuzz run --fuzz-dir source/fuzz --sanitizer none qir --features do_fuzz -- -runs=1 -max_total_time=15 ``` -See more in [LibFuzzer Corpus](https://llvm.org/docs/LibFuzzer.html#corpus). - -## Code Coverage During Fuzzing - -Based on [Code Coverage](https://rust-fuzz.github.io/book/cargo-fuzz/coverage.html#code-coverage). - -Tested in WSL Ubuntu 22.04. +This no-sanitizer command is a local proof aid only. 
The default `cargo +nightly fuzz run --fuzz-dir source/fuzz qir --features do_fuzz -- -max_total_time=15` command remains the automation gate, and the repository has not promoted the no-sanitizer variant into CI. -### Code Coverage Prerequisites +## Run Deterministic Replay -Note: The command `sudo apt install clang` installed `clang-10` and created the executables `clang` and `clang++` available in the `PATH`. -The installation of other versions, like `sudo apt install clang-14` was installing the executables `clang-14` and `clang++-14`, -but the executables `clang` and `clang++` were still of version 10. - -For the subsequent steps to succeed the executables `llvm-profdata` and `llvm-cov` need to be in the `PATH`. +Replay the checked seed corpus across the fast external LLVM matrix: ```bash -which llvm-profdata # See if the `llvm-profdata` is available. -# llvm-profdata-10 # Not available, but version 10 is installed. -which llvm-profdata-10 # See the path of version 10. -# /usr/bin/llvm-profdata-10 # The path of version 10. -pushd /usr/bin # Temporarily enter the dir where `llvm-profdata-10` is located. -sudo ln -s llvm-profdata-10 llvm-profdata # Create symlink `llvm-profdata` -> `llvm-profdata-10`. -sudo ln -s llvm-cov-10 llvm-cov # Create symlink `llvm-cov` -> `llvm-cov-10`. -popd # Get back to the original dir. - -llvm-cov --version # Optional. See the version. -#LLVM (http://llvm.org/): -# LLVM version 10.0.0 -# Optimized build. -# Default target: x86_64-pc-linux-gnu -# Host CPU: skylake +cargo run -p fuzz --bin qir_matrix -- --toolchains 14,15,16,21 ``` -The executables `llvm-profdata` and `llvm-cov` also need to be in the nightly toolchain. +Write replay artifacts to a known directory: ```bash -rustup default # Make sure that the nightly toolchain is the default. 
-# nightly-x86_64-unknown-linux-gnu (default) - -# See if the executables `llvm-profdata` and `llvm-cov` are installed in the nightly toolchain: -ls /home/$USER/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/x86_64-unknown-linux-gnu/bin/ -# gcc-ld llc # Nothing starting with `llvm-`. - -# Not installed, install: -rustup component add --toolchain nightly llvm-tools-preview - -# Make sure they are installed: -ls /home/$USER/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/x86_64-unknown-linux-gnu/bin/ -# gcc-ld llc llvm-ar llvm-as llvm-cov llvm-dis llvm-nm llvm-objcopy llvm-objdump llvm-profdata llvm-readobj llvm-size llvm-strip opt rust-lld - -# Optional. See the version: -/home/$USER/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/x86_64-unknown-linux-gnu/bin/llvm-cov --version -#LLVM (http://llvm.org/): -# LLVM version 16.0.2-rust-1.70.0-nightly -# Optimized build. +cargo run -p fuzz --bin qir_matrix -- --toolchains 14,15,16,21 --output-dir /tmp/qir-matrix ``` -Install the Rust demangler. -`cargo install rustfilt` - -### Running the Code Coverage Tool - -```bash -# In "qsc_frontend" directory: - -# Make sure that fuzzing still works OK: -cargo fuzz list # Optional. See the fuzzing targets. -#compile -cargo fuzz run compile --features do_fuzz -- -seed_inputs=@fuzz/seed_inputs/compile/list.txt -max_total_time=1 - # Run the fuzzing for at least 1 second. - # It is assumed that earlier you were running `cargo fuzz run compile` for a long time to gather - # the execution statistics. +The replay harness reads `.seed` files from `source/fuzz/corpus/qir`, exports checked text artifacts, and runs `llvm-as`, `opt -passes=verify`, `llvm-dis`, and a second `llvm-as` per lane. -cargo fuzz coverage compile # Gather the code coverage info. - # The run takes a few minutes. 
-# Later you will likely need the following data: -# One of the first log lines shows the absolute path to the executable the code coverage is gathered for: -# .../target/x86_64-unknown-linux-gnu/coverage/x86_64-unknown-linux-gnu/release/compile -# The last log line shows the absolute path to the file containing the code coverage info: -# .../fuzz/coverage/compile/coverage.profdata +## Notes -# Generate the HTML-report showing the code coverage for the fuzzing executable: -/home/$USER/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/x86_64-unknown-linux-gnu/bin/llvm-cov \ - show -Xdemangler=rustfilt -show-line-counts-or-regions -show-instantiations --ignore-filename-regex="/home/$USER/.cargo/.*" \ - -format=html \ - -instr-profile=fuzz/coverage/compile/coverage.profdata \ - target/x86_64-unknown-linux-gnu/coverage/x86_64-unknown-linux-gnu/release/compile \ - > index.html -# The unrelated error and warning that were observed (did not affect the result): -#error: /rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/std/src/sys/common/thread_local/fast_local.rs: No such file or directory -#warning: The file '/rustc/88fb1b922b047981fc0cfc62aa1418b4361ae72e/library/std/src/sys/common/thread_local/fast_local.rs' isn't covered. - -# Open the "index.html" in the web-browser to see the code coverage report (not necessarily in WSL). -``` +- Run all commands from the repository root. +- `qir_matrix` intentionally keeps deterministic replay separate from libFuzzer. +- The current deterministic seed bank includes BaseV1, AdaptiveV1, AdaptiveV2, and BareRoundtrip replay coverage. +- Failures from the real `qir` target are written under `source/fuzz/artifacts/qir/`. +- See [corpus/README.md](corpus/README.md) for corpus layout and seed naming. 
diff --git a/source/fuzz/fuzz_targets/qir.rs b/source/fuzz/fuzz_targets/qir.rs new file mode 100644 index 0000000000..2aa2198da6 --- /dev/null +++ b/source/fuzz/fuzz_targets/qir.rs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#![no_main] + +allocator::assign_global!(); + +#[cfg(feature = "do_fuzz")] +use libfuzzer_sys::fuzz_target; + +use fuzz::qir_seed_bank::{ + compile, compile_base_v1, compile_mutated_adaptive_v1, compile_mutated_adaptive_v2, +}; +use qsc_llvm::fuzz::mutation::compile_raw_parser_lanes; + +#[cfg(feature = "do_fuzz")] +fuzz_target!(|data: &[u8]| { + compile_raw_parser_lanes(data); + compile(data); + compile_base_v1(data); + compile_mutated_adaptive_v1(data); + compile_mutated_adaptive_v2(data); +}); + +#[cfg(not(feature = "do_fuzz"))] +#[unsafe(no_mangle)] +pub extern "C" fn main() { + compile_raw_parser_lanes(&[]); + compile(&[]); + compile_base_v1(&[]); + compile_mutated_adaptive_v1(&[]); + compile_mutated_adaptive_v2(&[]); +} diff --git a/source/fuzz/src/bin/qir_matrix.rs b/source/fuzz/src/bin/qir_matrix.rs new file mode 100644 index 0000000000..58e005a843 --- /dev/null +++ b/source/fuzz/src/bin/qir_matrix.rs @@ -0,0 +1,476 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::env; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::{Command, ExitCode, Output}; +use std::time::{SystemTime, UNIX_EPOCH}; + +allocator::assign_global!(); + +use fuzz::qir_seed_bank::{ + QirSeedInput, default_qir_corpus_dir, generate_checked_seed_artifact, load_seed_inputs, +}; +use qsc_llvm::{GeneratedArtifact, write_module_to_string}; + +const FAST_TOOLCHAINS: [u8; 4] = [14, 15, 16, 21]; +const HOMEBREW_OPT_PREFIXES: [&str; 2] = ["/opt/homebrew/opt", "/usr/local/opt"]; +const MIN_LLVM_VERSION: u8 = 14; +const MAX_LLVM_VERSION: u8 = 21; + +#[derive(Debug)] +struct Options { + toolchains: Vec, + corpus_dir: PathBuf, + output_dir: PathBuf, +} + +#[derive(Debug)] +struct ReplaySummary { + exported_artifacts: usize, + replay_steps: usize, + output_dir: PathBuf, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct LlvmToolchain { + version: u8, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PointerMode { + Typed, + Opaque, +} + +impl PointerMode { + fn from_artifact(artifact: &GeneratedArtifact) -> Self { + if artifact.effective_config.allow_typed_pointers { + Self::Typed + } else { + Self::Opaque + } + } + + const fn extra_args(self, version: u8) -> &'static [&'static str] { + match (version, self) { + (14, Self::Opaque) => &["-opaque-pointers"], + _ => &[], + } + } +} + +impl LlvmToolchain { + fn tool_path(self, tool: &str) -> Result { + let mut candidates = Vec::new(); + + if let Some(prefix) = env::var_os("HOMEBREW_PREFIX") { + let path = PathBuf::from(prefix) + .join("opt") + .join(format!("llvm@{}", self.version)) + .join("bin") + .join(tool); + if !candidates.contains(&path) { + candidates.push(path); + } + } + + for prefix in HOMEBREW_OPT_PREFIXES { + let path = PathBuf::from(prefix) + .join(format!("llvm@{}", self.version)) + .join("bin") + .join(tool); + if !candidates.contains(&path) { + candidates.push(path); + } + } + + if let Some(path) = candidates.iter().find(|path| path.exists()) { + Ok(path.clone()) + } 
else { + Err(format!( + "unable to find llvm@{} {} under any known Homebrew prefix: {}", + self.version, + tool, + candidates + .iter() + .map(|path| path.display().to_string()) + .collect::>() + .join(", ") + )) + } + } + + fn command(self, tool: &str, pointer_mode: PointerMode) -> Result { + let mut command = Command::new(self.tool_path(tool)?); + for arg in pointer_mode.extra_args(self.version) { + command.arg(arg); + } + Ok(command) + } + + fn ensure_available(self) -> Result<(), String> { + for tool in ["llvm-as", "llvm-dis", "opt"] { + let output = Command::new(self.tool_path(tool)?) + .arg("--version") + .output() + .map_err(|error| { + format!( + "failed to run llvm@{} {} --version: {error}", + self.version, tool + ) + })?; + if !output.status.success() { + return Err(format!( + "llvm@{} {} --version failed: {}", + self.version, + tool, + summarize_output(&output) + )); + } + } + + Ok(()) + } + + fn verify_bitcode(self, pointer_mode: PointerMode, bitcode_path: &Path) -> Result<(), String> { + let output = self + .command("opt", pointer_mode)? + .arg("-passes=verify") + .arg(bitcode_path) + .arg("-disable-output") + .output() + .map_err(|error| { + format!( + "failed to run llvm@{} opt on {}: {error}", + self.version, + bitcode_path.display() + ) + })?; + + if output.status.success() { + Ok(()) + } else { + Err(format!( + "llvm@{} opt verify failed for {}: {}", + self.version, + bitcode_path.display(), + summarize_output(&output) + )) + } + } + + fn disassemble_bitcode( + self, + pointer_mode: PointerMode, + input_path: &Path, + output_path: &Path, + ) -> Result<(), String> { + let output = self + .command("llvm-dis", pointer_mode)? 
+ .arg(input_path) + .arg("-o") + .arg(output_path) + .output() + .map_err(|error| { + format!( + "failed to run llvm@{} llvm-dis on {}: {error}", + self.version, + input_path.display() + ) + })?; + + if output.status.success() { + Ok(()) + } else { + Err(format!( + "llvm@{} llvm-dis failed for {}: {}", + self.version, + input_path.display(), + summarize_output(&output) + )) + } + } + + fn assemble_text( + self, + pointer_mode: PointerMode, + input_path: &Path, + output_path: &Path, + ) -> Result<(), String> { + let output = self + .command("llvm-as", pointer_mode)? + .arg(input_path) + .arg("-o") + .arg(output_path) + .output() + .map_err(|error| { + format!( + "failed to run llvm@{} llvm-as on {}: {error}", + self.version, + input_path.display() + ) + })?; + + if output.status.success() { + Ok(()) + } else { + Err(format!( + "llvm@{} llvm-as failed for {}: {}", + self.version, + input_path.display(), + summarize_output(&output) + )) + } + } +} + +fn main() -> ExitCode { + match run() { + Ok(Some(summary)) => { + println!( + "exported {} valid artifacts across {} replay steps into {}", + summary.exported_artifacts, + summary.replay_steps, + summary.output_dir.display() + ); + ExitCode::SUCCESS + } + Ok(None) => ExitCode::SUCCESS, + Err(error) => { + eprintln!("qir_matrix failed: {error}"); + ExitCode::FAILURE + } + } +} + +fn run() -> Result, String> { + let args = env::args().skip(1).collect::>(); + if args + .iter() + .any(|arg| matches!(arg.as_str(), "--help" | "-h")) + { + println!("{}", usage()); + return Ok(None); + } + + let options = Options::parse(&args)?; + for toolchain in &options.toolchains { + toolchain.ensure_available()?; + } + + fs::create_dir_all(&options.output_dir) + .map_err(|error| format!("create {}: {error}", options.output_dir.display()))?; + + let seeds = load_seed_inputs(&options.corpus_dir)?; + run_replay_matrix(&options, &seeds).map(Some) +} + +impl Options { + fn parse(args: &[String]) -> Result { + let mut toolchains = 
parse_toolchains( + &FAST_TOOLCHAINS + .iter() + .map(u8::to_string) + .collect::>() + .join(","), + )?; + let mut corpus_dir = default_qir_corpus_dir(); + let mut output_dir = default_output_dir(); + + let mut index = 0; + while index < args.len() { + match args[index].as_str() { + "--toolchains" => { + index += 1; + let value = args + .get(index) + .ok_or_else(|| "missing value for --toolchains".to_string())?; + toolchains = parse_toolchains(value)?; + } + "--corpus-dir" => { + index += 1; + let value = args + .get(index) + .ok_or_else(|| "missing value for --corpus-dir".to_string())?; + corpus_dir = PathBuf::from(value); + } + "--output-dir" => { + index += 1; + let value = args + .get(index) + .ok_or_else(|| "missing value for --output-dir".to_string())?; + output_dir = PathBuf::from(value); + } + other => { + return Err(format!("unrecognized argument {other:?}\n\n{}", usage())); + } + } + index += 1; + } + + Ok(Self { + toolchains, + corpus_dir, + output_dir, + }) + } +} + +fn run_replay_matrix(options: &Options, seeds: &[QirSeedInput]) -> Result { + let mut exported_artifacts = 0; + let mut replay_steps = 0; + + for seed in seeds { + let artifact = + generate_checked_seed_artifact(seed.profile, &seed.bytes).map_err(|error| { + format!( + "checked artifact generation failed for {} ({}): {error}", + seed.name, + seed.path.display() + ) + })?; + let pointer_mode = PointerMode::from_artifact(&artifact); + let text = artifact + .text + .clone() + .unwrap_or_else(|| write_module_to_string(&artifact.module)); + let artifact_path = options.output_dir.join(format!("{}.ll", seed.name)); + fs::write(&artifact_path, text) + .map_err(|error| format!("write {}: {error}", artifact_path.display()))?; + exported_artifacts += 1; + + for toolchain in &options.toolchains { + replay_artifact( + *toolchain, + pointer_mode, + &seed.name, + &artifact_path, + &options.output_dir, + ) + .map_err(|error| format!("seed {}: {error}", seed.name))?; + replay_steps += 1; + } + } + + if 
exported_artifacts == 0 { + return Err(format!( + "no checked-valid seed artifacts were found under {}", + options.corpus_dir.display() + )); + } + + Ok(ReplaySummary { + exported_artifacts, + replay_steps, + output_dir: options.output_dir.clone(), + }) +} + +fn replay_artifact( + toolchain: LlvmToolchain, + pointer_mode: PointerMode, + seed_name: &str, + artifact_path: &Path, + output_dir: &Path, +) -> Result<(), String> { + let assembled_bitcode_path = + output_dir.join(format!("{seed_name}.llvm{}.bc", toolchain.version)); + let disassembly_path = output_dir.join(format!("{seed_name}.llvm{}.ll", toolchain.version)); + let roundtrip_bitcode_path = + output_dir.join(format!("{seed_name}.llvm{}.rt.bc", toolchain.version)); + + toolchain.assemble_text(pointer_mode, artifact_path, &assembled_bitcode_path)?; + toolchain.verify_bitcode(pointer_mode, &assembled_bitcode_path)?; + toolchain.disassemble_bitcode(pointer_mode, &assembled_bitcode_path, &disassembly_path)?; + toolchain.assemble_text(pointer_mode, &disassembly_path, &roundtrip_bitcode_path)?; + toolchain.verify_bitcode(pointer_mode, &roundtrip_bitcode_path) +} + +fn parse_toolchains(csv: &str) -> Result, String> { + let mut toolchains = Vec::new(); + for raw_version in csv.split(',') { + let version_text = raw_version.trim(); + if version_text.is_empty() { + continue; + } + + let version = version_text.parse::().map_err(|error| { + format!("failed to parse LLVM toolchain version {version_text:?}: {error}") + })?; + if !(MIN_LLVM_VERSION..=MAX_LLVM_VERSION).contains(&version) { + return Err(format!( + "LLVM toolchain version {version} is out of range; expected {MIN_LLVM_VERSION} through {MAX_LLVM_VERSION}" + )); + } + if !toolchains.contains(&LlvmToolchain { version }) { + toolchains.push(LlvmToolchain { version }); + } + } + + if toolchains.is_empty() { + return Err("no LLVM toolchains requested".to_string()); + } + + Ok(toolchains) +} + +fn default_output_dir() -> PathBuf { + let nanos = SystemTime::now() + 
.duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + env::temp_dir().join(format!("qir-matrix-{}-{nanos}", std::process::id())) +} + +fn summarize_output(output: &Output) -> String { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + if !stderr.is_empty() { + return stderr; + } + + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + if stdout.is_empty() { + "command produced no output".to_string() + } else { + stdout + } +} + +fn usage() -> String { + format!( + "Usage: cargo run -p fuzz --bin qir_matrix -- [--toolchains 14,15,16,21] [--corpus-dir PATH] [--output-dir PATH]\n\nDefault corpus directory: {}\nFast matrix default: 14,15,16,21\nFull matrix example: --toolchains 14,15,16,17,18,19,20,21", + default_qir_corpus_dir().display() + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_toolchains_accepts_fast_and_full_matrix() { + let fast = parse_toolchains("14,15,16,21").expect("fast matrix should parse"); + assert_eq!( + fast.iter() + .map(|toolchain| toolchain.version) + .collect::>(), + vec![14, 15, 16, 21] + ); + + let full = parse_toolchains("14,15,16,17,18,19,20,21").expect("full matrix should parse"); + assert_eq!( + full.iter() + .map(|toolchain| toolchain.version) + .collect::>(), + vec![14, 15, 16, 17, 18, 19, 20, 21] + ); + } + + #[test] + fn parse_toolchains_rejects_versions_outside_supported_range() { + assert!(parse_toolchains("13").is_err()); + assert!(parse_toolchains("22").is_err()); + } +} diff --git a/source/fuzz/src/lib.rs b/source/fuzz/src/lib.rs new file mode 100644 index 0000000000..6c03de6dbb --- /dev/null +++ b/source/fuzz/src/lib.rs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +pub mod qir_seed_bank; diff --git a/source/fuzz/src/qir_seed_bank.rs b/source/fuzz/src/qir_seed_bank.rs new file mode 100644 index 0000000000..bc84d43db2 --- /dev/null +++ b/source/fuzz/src/qir_seed_bank.rs @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::fs; +use std::path::{Path, PathBuf}; + +use qsc_llvm::fuzz::mutation::{ + MutationKind, SeedMutator, dispatch_mutation_family, mutation_selector, + validate_mutated_module, validate_seed_artifact, +}; +use qsc_llvm::fuzz::qir_mutations::mutate_adaptive_v1_typed_pointer_seed; +use qsc_llvm::{ + GeneratedArtifact, Module, QirProfilePreset, QirSmithConfig, QirSmithError, + generate_checked_from_bytes, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QirSeedInput { + pub name: String, + pub profile: QirProfilePreset, + pub bytes: Vec, + pub path: PathBuf, +} + +#[must_use] +pub fn default_qir_corpus_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("corpus/qir") +} + +pub fn compile_seed_profile(profile: QirProfilePreset, profile_name: &str, data: &[u8]) { + let artifact = generate_seed_artifact(profile, profile_name, data); + validate_seed_artifact(&artifact); +} + +pub fn compile(data: &[u8]) { + compile_seed_profile(QirProfilePreset::AdaptiveV2, "AdaptiveV2", data); +} + +pub fn compile_base_v1(data: &[u8]) { + compile_seed_profile(QirProfilePreset::BaseV1, "BaseV1", data); +} + +fn compile_mutated_profile( + profile: QirProfilePreset, + profile_name: &str, + data: &[u8], + mutator: SeedMutator, +) { + let artifact = generate_seed_artifact(profile, profile_name, data); + validate_seed_artifact(&artifact); + + let mutated = mutator(&artifact.module, data); + validate_mutated_module(&mutated); +} + +fn mutate_adaptive_v2_seed(seed: &Module, data: &[u8]) -> Module { + let mut mutated = seed.clone(); + dispatch_mutation_family( + &mut mutated, + MutationKind::from_data(data), + mutation_selector(data, 1), + ); + mutated +} + +fn 
mutate_typed_pointer_seed(seed: &Module, data: &[u8]) -> Module { + mutate_adaptive_v1_typed_pointer_seed(seed, mutation_selector(data, 0)) +} + +pub fn compile_mutated_adaptive_v1(data: &[u8]) { + compile_mutated_profile( + QirProfilePreset::AdaptiveV1, + "AdaptiveV1", + data, + mutate_typed_pointer_seed, + ); +} + +pub fn compile_mutated_adaptive_v2(data: &[u8]) { + compile_mutated_profile( + QirProfilePreset::AdaptiveV2, + "AdaptiveV2", + data, + mutate_adaptive_v2_seed, + ); +} + +fn generate_seed_artifact( + profile: QirProfilePreset, + profile_name: &str, + data: &[u8], +) -> GeneratedArtifact { + generate_checked_seed_artifact(profile, data) + .unwrap_or_else(|err| panic!("qir_smith {profile_name} checked generation failed: {err}")) +} + +pub fn load_seed_inputs(corpus_dir: &Path) -> Result, String> { + let mut entries = fs::read_dir(corpus_dir) + .map_err(|error| format!("read {}: {error}", corpus_dir.display()))? + .collect::, _>>() + .map_err(|error| format!("read {}: {error}", corpus_dir.display()))?; + entries.sort_by_key(std::fs::DirEntry::file_name); + + let mut seeds = Vec::new(); + for entry in entries { + if !entry + .file_type() + .map_err(|error| format!("inspect {}: {error}", entry.path().display()))? 
+ .is_file() + { + continue; + } + + let path = entry.path(); + let Some(file_name) = path.file_name().and_then(|value| value.to_str()) else { + continue; + }; + if path.extension().and_then(|value| value.to_str()) != Some("seed") { + continue; + } + + let Some(profile) = profile_from_seed_file_name(file_name) else { + return Err(format!( + "seed file {} must start with base-v1-, adaptive-v1-, adaptive-v2-, or bare-roundtrip-", + path.display() + )); + }; + + let bytes = fs::read(&path).map_err(|error| format!("read {}: {error}", path.display()))?; + if bytes.is_empty() { + return Err(format!("seed file {} is empty", path.display())); + } + + let name = file_name.trim_end_matches(".seed").to_string(); + seeds.push(QirSeedInput { + name, + profile, + bytes, + path, + }); + } + + if seeds.is_empty() { + return Err(format!( + "no .seed inputs found under {}", + corpus_dir.display() + )); + } + + Ok(seeds) +} + +pub fn generate_checked_seed_artifact( + profile: QirProfilePreset, + bytes: &[u8], +) -> Result { + let config = QirSmithConfig::for_profile(profile); + generate_checked_from_bytes(&config, bytes) +} + +fn profile_from_seed_file_name(file_name: &str) -> Option { + if file_name.starts_with("base-v1-") { + Some(QirProfilePreset::BaseV1) + } else if file_name.starts_with("adaptive-v1-") { + Some(QirProfilePreset::AdaptiveV1) + } else if file_name.starts_with("adaptive-v2-") { + Some(QirProfilePreset::AdaptiveV2) + } else if file_name.starts_with("bare-roundtrip-") { + Some(QirProfilePreset::BareRoundtrip) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use qsc_llvm::validate_qir_profile; + + const ADAPTIVE_V1_TYPED_FIXTURE: &[u8] = include_bytes!("../corpus/qir/adaptive-v1-typed.seed"); + + #[test] + fn profile_prefixes_map_to_expected_qir_profiles() { + assert_eq!( + profile_from_seed_file_name("base-v1-smoke.seed"), + Some(QirProfilePreset::BaseV1) + ); + assert_eq!( + profile_from_seed_file_name("adaptive-v1-smoke.seed"), + 
Some(QirProfilePreset::AdaptiveV1) + ); + assert_eq!( + profile_from_seed_file_name("adaptive-v2-smoke.seed"), + Some(QirProfilePreset::AdaptiveV2) + ); + assert_eq!( + profile_from_seed_file_name("bare-roundtrip-smoke.seed"), + Some(QirProfilePreset::BareRoundtrip) + ); + assert_eq!(profile_from_seed_file_name("unexpected.seed"), None); + } + + #[test] + fn checked_seed_artifact_emits_text_for_typed_profiles() { + let seed_bytes: Vec = (0_u8..=127).collect(); + let artifact = generate_checked_seed_artifact(QirProfilePreset::AdaptiveV1, &seed_bytes) + .expect("typed replay artifact generation should succeed"); + + assert!(artifact.text.is_some()); + assert!(artifact.bitcode.is_none()); + } + + #[test] + fn checked_seed_artifact_replays_base_v1_fixture() { + let seed_bytes = + b"base-v1|entry|0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ|0001"; + let artifact = generate_checked_seed_artifact(QirProfilePreset::BaseV1, seed_bytes) + .expect("BaseV1 replay artifact generation should succeed"); + + assert!(artifact.text.is_some()); + assert!(artifact.bitcode.is_none()); + assert!( + validate_qir_profile(&artifact.module).errors.is_empty(), + "BaseV1 replay artifact should satisfy the QIR profile" + ); + } + + #[test] + fn adaptive_v2_checked_generation_accepts_adaptive_v1_typed_fixture() { + let artifact = + generate_checked_seed_artifact(QirProfilePreset::AdaptiveV2, ADAPTIVE_V1_TYPED_FIXTURE) + .expect("AdaptiveV2 generation should accept the adaptive-v1 typed corpus seed"); + + assert!(artifact.text.is_some()); + assert!(artifact.bitcode.is_some()); + assert!( + validate_qir_profile(&artifact.module).errors.is_empty(), + "AdaptiveV2 artifact should satisfy the QIR profile for the typed fixture" + ); + } +} diff --git a/source/pip/Cargo.toml b/source/pip/Cargo.toml index 70c4f0888e..06b3d49465 100644 --- a/source/pip/Cargo.toml +++ b/source/pip/Cargo.toml @@ -16,6 +16,8 @@ num-bigint = { workspace = true } num-complex = { workspace = true } num-traits 
= { workspace = true } qsc = { path = "../compiler/qsc" } +qsc_codegen = { path = "../compiler/qsc_codegen" } +qsc_llvm = { path = "../compiler/qsc_llvm" } qdk_simulators = { path = "../simulators" } resource_estimator = { path = "../resource_estimator" } miette = { workspace = true, features = ["fancy"] } diff --git a/source/pip/qsharp/_adaptive_pass.py b/source/pip/qsharp/_adaptive_pass.py deleted file mode 100644 index 81c7c4b6f6..0000000000 --- a/source/pip/qsharp/_adaptive_pass.py +++ /dev/null @@ -1,986 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -"""AdaptiveProfilePass: walks Adaptive Profile QIR and emits the intermediate -format consumed by Rust. - -Unlike ``AggregateGatesPass`` (which subclasses ``pyqir.QirModuleVisitor`` and -only dispatches CALL instructions), this pass iterates basic blocks and -instructions directly so it can handle *all* LLVM IR opcodes required by the -Adaptive Profile specification. -""" - -from __future__ import annotations -from dataclasses import dataclass, astuple -import pyqir -import struct -from typing import Any, Dict, List, Optional, Tuple, TypeAlias, cast -from ._adaptive_bytecode import * - -# --------------------------------------------------------------------------- -# Gate name → OpID mapping (must match shader_types.rs OpID enum) -# --------------------------------------------------------------------------- - -GATE_MAP: Dict[str, int] = { - "reset": 1, - "x": 2, - "y": 3, - "z": 4, - "h": 5, - "s": 6, - "s__adj": 7, - "t": 8, - "t__adj": 9, - "sx": 10, - "sx__adj": 11, - "rx": 12, - "ry": 13, - "rz": 14, - "cnot": 15, - "cx": 15, - "cz": 16, - "cy": 29, - "rxx": 17, - "ryy": 18, - "rzz": 19, - "ccx": 20, - "m": 21, - "mz": 21, - "mresetz": 22, - "swap": 24, -} - -# Gates that take a result ID as a second argument -MEASURE_GATES = {"m", "mz", "mresetz"} - -# Gates that reset a qubit (single qubit argument, no result) -RESET_GATES = {"reset"} - -# Rotation gates that take an 
angle parameter as first argument -ROTATION_GATES = {"rx", "ry", "rz", "rxx", "ryy", "rzz"} - -# --------------------------------------------------------------------------- -# ICmp / FCmp predicate mappings -# --------------------------------------------------------------------------- - -ICMP_MAP = { - pyqir.IntPredicate.EQ: ICMP_EQ, - pyqir.IntPredicate.NE: ICMP_NE, - pyqir.IntPredicate.SLT: ICMP_SLT, - pyqir.IntPredicate.SLE: ICMP_SLE, - pyqir.IntPredicate.SGT: ICMP_SGT, - pyqir.IntPredicate.SGE: ICMP_SGE, - pyqir.IntPredicate.ULT: ICMP_ULT, - pyqir.IntPredicate.ULE: ICMP_ULE, - pyqir.IntPredicate.UGT: ICMP_UGT, - pyqir.IntPredicate.UGE: ICMP_UGE, -} - -FCMP_MAP = { - pyqir.FloatPredicate.FALSE: FCMP_FALSE, - pyqir.FloatPredicate.OEQ: FCMP_OEQ, - pyqir.FloatPredicate.OGT: FCMP_OGT, - pyqir.FloatPredicate.OGE: FCMP_OGE, - pyqir.FloatPredicate.OLT: FCMP_OLT, - pyqir.FloatPredicate.OLE: FCMP_OLE, - pyqir.FloatPredicate.ONE: FCMP_ONE, - pyqir.FloatPredicate.ORD: FCMP_ORD, - pyqir.FloatPredicate.UNO: FCMP_UNO, - pyqir.FloatPredicate.UEQ: FCMP_UEQ, - pyqir.FloatPredicate.UGT: FCMP_UGT, - pyqir.FloatPredicate.UGE: FCMP_UGE, - pyqir.FloatPredicate.ULT: FCMP_ULT, - pyqir.FloatPredicate.ULE: FCMP_ULE, - pyqir.FloatPredicate.UNE: FCMP_UNE, - pyqir.FloatPredicate.TRUE: FCMP_TRUE, -} - - -@dataclass -class AdaptiveProgram: - num_qubits: int - num_results: int - num_registers: int - entry_block: int - blocks: List[Block] - instructions: List[Instruction] - quantum_ops: List[QuantumOp] - functions: List[Function] - phi_entries: List[PhiNodeEntry] - switch_cases: List[SwitchCase] - call_args: List[CallArg] - labels: List[Label] - register_types: List[RegisterType] - - def as_dict(self): - """ - Transforms the program to a dictionary, and each of - the helper dataclasses to a tuple. This format is intended - to be used in the FFI between Python and Rust. 
- """ - return { - "num_qubits": self.num_qubits, - "num_results": self.num_results, - "num_registers": self.num_registers, - "entry_block": self.entry_block, - "blocks": [astuple(x) for x in self.blocks], - "instructions": [astuple(x) for x in self.instructions], - "quantum_ops": [astuple(x) for x in self.quantum_ops], - "functions": [astuple(x) for x in self.functions], - "phi_entries": [astuple(x) for x in self.phi_entries], - "switch_cases": [astuple(x) for x in self.switch_cases], - "call_args": self.call_args, - "labels": self.labels, - "register_types": self.register_types, - } - - -@dataclass -class Block: - block_id: int - instr_offset: int - instr_count: int - - -@dataclass -class Instruction: - opcode: int - dst: int - src0: int - src1: int - aux0: int - aux1: int - aux2: int - aux3: int - - -@dataclass -class QuantumOp: - op_id: int - q1: int - q2: int - q3: int - angle: float - - -@dataclass -class Function: - func_entry_block: int - num_params: int - param_base: int - - -@dataclass -class PhiNodeEntry: - block_id: int - val_reg: int - - -@dataclass -class SwitchCase: - case_val: int - target_block: int - - -# OpID for correlated noise (must match shader_types.rs OpID::CorrelatedNoise) -CORRELATED_NOISE_OP_ID = 131 - -CallArg: TypeAlias = int -Label: TypeAlias = str -RegisterType: TypeAlias = int - - -@dataclass -class IntOperand: - val: int = 0 - - def __post_init__(self): - # Mask to u32 range so negative Python ints become their - # two's-complement u32 representation (e.g. -7 → 0xFFFFFFF9). 
- self.val = self.val & 0xFFFFFFFF - - -class FloatOperand: - def __init__(self, val: float = 0.0) -> None: - self.val: int = encode_float_as_bits(val) - - -@dataclass -class Reg: - val: int # index in the registers table - - -def is_immediate(arg) -> bool: - return isinstance(arg, (IntOperand, FloatOperand)) - - -def prepare_immediate_flags( - *, dst=None, src0=None, src1=None, aux0=None, aux1=None, aux2=None, aux3=None -): - flags = 0 - if is_immediate(dst): - flags |= FLAG_DST_IMM - if is_immediate(src0): - flags |= FLAG_SRC0_IMM - if is_immediate(src1): - flags |= FLAG_SRC1_IMM - if is_immediate(aux0): - flags |= FLAG_AUX0_IMM - if is_immediate(aux1): - flags |= FLAG_AUX1_IMM - if is_immediate(aux2): - flags |= FLAG_AUX2_IMM - if is_immediate(aux3): - flags |= FLAG_AUX3_IMM - return flags - - -def unwrap_operands( - dst, src0, src1, aux0, aux1, aux2, aux3 -) -> Tuple[int, int, int, int, int, int, int]: - if not isinstance(dst, int): - dst = dst.val - if not isinstance(src0, int): - src0 = src0.val - if not isinstance(src1, int): - src1 = src1.val - if not isinstance(aux0, int): - aux0 = aux0.val - if not isinstance(aux1, int): - aux1 = aux1.val - if not isinstance(aux2, int): - aux2 = aux2.val - if not isinstance(aux3, int): - aux3 = aux3.val - return (dst, src0, src1, aux0, aux1, aux2, aux3) - - -def encode_float_as_bits(val: float) -> int: - return struct.unpack(" AdaptiveProgram: - """Process module and return the AdaptiveProgram. - - Args: - mod: The QIR module to process. - noise: Optional NoiseConfig. When provided, noise intrinsic calls - are resolved to correlated noise ops using the intrinsics table. - noise_intrinsics: Optional dict mapping noise intrinsic callee names - to noise table IDs. Takes precedence over ``noise`` if both are - given. 
- """ - if mod.get_flag("arrays"): - raise ValueError("QIR arrays are not currently supported.") - - if noise_intrinsics is not None: - self._noise_intrinsics = noise_intrinsics - elif noise is not None: - # Build {name: table_id} mapping from the NoiseConfig intrinsics - intrinsics = noise.intrinsics - self._noise_intrinsics = {} - for callee_name in mod.functions: - name = callee_name.name - if name in intrinsics: - self._noise_intrinsics[name] = intrinsics.get_intrinsic_id(name) - - errors = mod.verify() - if errors is not None: - raise ValueError(f"Module verification failed: {errors}") - - # Pass 1: Assign block IDs and function IDs for all defined functions - for func in mod.functions: - if len(func.basic_blocks) > 0: - self._assign_function(func) - - # Pass 2: Walk instructions and emit encoding - for func in mod.functions: - if len(func.basic_blocks) > 0: - self._walk_function(func) - - entry_func = next(filter(pyqir.is_entry_point, mod.functions)) - num_qubits = pyqir.required_num_qubits(entry_func) - num_results = pyqir.required_num_results(entry_func) - assert isinstance(num_qubits, int) - assert isinstance(num_results, int) - - return AdaptiveProgram( - num_qubits=num_qubits, - num_results=num_results, - num_registers=self._next_reg, - entry_block=self._block_to_id[entry_func.basic_blocks[0]], - blocks=self.blocks, - instructions=self.instructions, - quantum_ops=self.quantum_ops, - functions=self.functions, - phi_entries=self.phi_entries, - switch_cases=self.switch_cases, - call_args=self.call_args, - labels=self.labels, - register_types=self.register_types, - ) - - # ------------------------------------------------------------------ - # Register allocation - # ------------------------------------------------------------------ - - def _alloc_reg(self, value: Any, type_tag: int) -> Reg: - """Allocate a new register for `value` and record its type. - - If `value` was already pre-allocated (e.g. 
as a forward reference from - a phi node), return the existing register instead of allocating a new - one. - """ - if value is not None and value in self._value_to_reg: - return self._value_to_reg[value] - reg = Reg(self._next_reg) - self._next_reg += 1 - if value is not None: - self._value_to_reg[value] = reg - self.register_types.append(type_tag) - return reg - - # ------------------------------------------------------------------ - # Instruction emission - # ------------------------------------------------------------------ - - def _emit( - self, - opcode: int, - *, - dst: int | IntOperand | FloatOperand | Reg = 0, - src0: int | IntOperand | FloatOperand | Reg = 0, - src1: int | IntOperand | FloatOperand | Reg = 0, - aux0: int | IntOperand | FloatOperand | Reg = 0, - aux1: int | IntOperand | FloatOperand | Reg = 0, - aux2: int | IntOperand | FloatOperand | Reg = 0, - aux3: int | IntOperand | FloatOperand | Reg = 0, - ) -> None: - imm_flags = prepare_immediate_flags( - dst=dst, src0=src0, src1=src1, aux0=aux0, aux1=aux1, aux2=aux2, aux3=aux3 - ) - (dst, src0, src1, aux0, aux1, aux2, aux3) = unwrap_operands( - dst, src0, src1, aux0, aux1, aux2, aux3 - ) - ins = Instruction(opcode | imm_flags, dst, src0, src1, aux0, aux1, aux2, aux3) - self.instructions.append(ins) - - def _emit_quantum_op( - self, - op_id: int, - q1: int = 0, - q2: int = 0, - q3: int = 0, - angle: float = 0.0, - ) -> int: - idx = self._next_qop - self._next_qop += 1 - qop = QuantumOp(op_id, q1, q2, q3, angle) - self.quantum_ops.append(qop) - return idx - - # ------------------------------------------------------------------ - # Operand resolution - # ------------------------------------------------------------------ - - def _resolve_operand(self, value: pyqir.Value) -> IntOperand | FloatOperand | Reg: - """Resolve a pyqir Value to a register index. - - If `value` is an already-assigned SSA register, return its index. 
- If `value` is an integer constant, allocate a register and emit - ``OP_CONST`` to materialise it. - """ - if value in self._value_to_reg: - return self._value_to_reg[value] - - if isinstance(value, pyqir.IntConstant): - val = value.value - return IntOperand(val) - - if isinstance(value, pyqir.FloatConstant): - val = value.value - return FloatOperand(val) - - # Forward reference (e.g. phi incoming from a later block). - # Pre-allocate a register; the defining instruction will reuse it - # via _alloc_reg's dedup check. - if isinstance(value, pyqir.Instruction): - return self._alloc_reg(value, self._type_tag(value.type)) - - # Constant expressions (e.g. inttoptr (i64 N to %Qubit*)). - if isinstance(value, pyqir.Constant): - # Try extracting as a qubit/result pointer constant. - qid = pyqir.qubit_id(value) - if qid is not None: - return IntOperand(qid) - rid = pyqir.result_id(value) - if rid is not None: - return IntOperand(rid) - # Null pointer - if value.is_null: - reg = self._alloc_reg(value, REG_TYPE_PTR) - self._emit(OP_CONST | FLAG_SRC0_IMM, dst=reg.val, src0=0) - return reg - - raise ValueError(f"Cannot resolve operand: {type(value).__name__}") - - def _type_tag(self, ty: Any) -> int: - """Map a pyqir Type to a register type tag.""" - if isinstance(ty, pyqir.IntType): - w = ty.width - if w == 1: - return REG_TYPE_BOOL - if w <= 32: - return REG_TYPE_I32 - return REG_TYPE_I64 - if isinstance(ty, pyqir.PointerType): - return REG_TYPE_PTR - if ty.is_double: - return REG_TYPE_F64 - # Remaining floating-point types (e.g. 
float/f32) - return REG_TYPE_F32 - - # ------------------------------------------------------------------ - # Binary / unary helpers - # ------------------------------------------------------------------ - - def _emit_binary(self, opcode: int, instr: Any) -> None: - """Emit a binary arithmetic/bitwise instruction.""" - dst = self._alloc_reg(instr, self._type_tag(instr.type)) - src0 = self._resolve_operand(instr.operands[0]) - src1 = self._resolve_operand(instr.operands[1]) - self._emit(opcode, dst=dst, src0=src0, src1=src1) - - def _emit_unary(self, opcode: int, instr: Any) -> None: - """Emit a unary conversion instruction.""" - dst = self._alloc_reg(instr, self._type_tag(instr.type)) - src0 = self._resolve_operand(instr.operands[0]) - self._emit(opcode, dst=dst, src0=src0) - - def _emit_sext(self, instr: Any) -> None: - """Emit OP_SEXT with source bit width in aux0.""" - dst = self._alloc_reg(instr, self._type_tag(instr.type)) - src0 = self._resolve_operand(instr.operands[0]) - src_type = instr.operands[0].type - src_bits = src_type.width if isinstance(src_type, pyqir.IntType) else 32 - self._emit(OP_SEXT, dst=dst, src0=src0, aux0=src_bits) - - # ------------------------------------------------------------------ - # Function assignment (Pass 1) - # ------------------------------------------------------------------ - - def _assign_function(self, func: pyqir.Function) -> None: - """Assign block IDs and function IDs for a function.""" - if not pyqir.is_entry_point(func) and func.name not in self._func_to_id: - func_id = len(self._func_to_id) - self._func_to_id[func.name] = func_id - for block in func.basic_blocks: - self._block_to_id[block] = self._next_block - self._next_block += 1 - - # ------------------------------------------------------------------ - # Function walking (Pass 2) - # ------------------------------------------------------------------ - - def _walk_function(self, func: pyqir.Function) -> None: - """Walk all blocks and instructions in a function, 
emitting bytecode.""" - self._current_func_is_entry = pyqir.is_entry_point(func) - - # For non-entry functions, register parameters as registers - if not self._current_func_is_entry: - param_base = self._next_reg - for param in func.params: - self._alloc_reg( - param, REG_TYPE_PTR - ) # params are pointers (%Qubit*, %Result*) - # Record function entry in the function table - if func.name in self._func_to_id: - func_entry_block = self._block_to_id[func.basic_blocks[0]] - f = Function(func_entry_block, len(func.params), param_base) - self.functions.append(f) - - for block in func.basic_blocks: - block_id = self._block_to_id[block] - instr_offset = len(self.instructions) - for instr in block.instructions: - self._on_instruction(instr) - # NOTE: block.terminator is already included in block.instructions - # in pyqir, so we do NOT separately process it. - instr_count = len(self.instructions) - instr_offset - blk = Block(block_id, instr_offset, instr_count) - self.blocks.append(blk) - - # ------------------------------------------------------------------ - # Instruction dispatch - # ------------------------------------------------------------------ - - def _on_instruction(self, instr: pyqir.Instruction) -> None: - """Dispatch a single instruction by opcode.""" - match instr.opcode: - case pyqir.Opcode.CALL: - self._emit_call(cast(pyqir.Call, instr)) - case pyqir.Opcode.PHI: - self._emit_phi(cast(pyqir.Phi, instr)) - case pyqir.Opcode.ICMP: - self._emit_icmp(cast(pyqir.ICmp, instr)) - case pyqir.Opcode.FCMP: - self._emit_fcmp(cast(pyqir.FCmp, instr)) - case pyqir.Opcode.SWITCH: - self._emit_switch(cast(pyqir.Switch, instr)) - case pyqir.Opcode.BR: - self._emit_branch(instr) - case pyqir.Opcode.RET: - self._emit_ret(instr) - case pyqir.Opcode.SELECT: - self._emit_select(instr) - case pyqir.Opcode.ADD: - self._emit_binary(OP_ADD, instr) - case pyqir.Opcode.SUB: - self._emit_binary(OP_SUB, instr) - case pyqir.Opcode.MUL: - self._emit_binary(OP_MUL, instr) - case 
pyqir.Opcode.UDIV: - self._emit_binary(OP_UDIV, instr) - case pyqir.Opcode.SDIV: - self._emit_binary(OP_SDIV, instr) - case pyqir.Opcode.UREM: - self._emit_binary(OP_UREM, instr) - case pyqir.Opcode.SREM: - self._emit_binary(OP_SREM, instr) - case pyqir.Opcode.AND: - self._emit_binary(OP_AND, instr) - case pyqir.Opcode.OR: - self._emit_binary(OP_OR, instr) - case pyqir.Opcode.XOR: - self._emit_binary(OP_XOR, instr) - case pyqir.Opcode.SHL: - self._emit_binary(OP_SHL, instr) - case pyqir.Opcode.LSHR: - self._emit_binary(OP_LSHR, instr) - case pyqir.Opcode.ASHR: - self._emit_binary(OP_ASHR, instr) - case pyqir.Opcode.ZEXT: - self._emit_unary(OP_ZEXT, instr) - case pyqir.Opcode.SEXT: - self._emit_sext(instr) - case pyqir.Opcode.TRUNC: - self._emit_unary(OP_TRUNC, instr) - case pyqir.Opcode.FADD: - self._emit_binary(OP_FADD | FLAG_FLOAT, instr) - case pyqir.Opcode.FSUB: - self._emit_binary(OP_FSUB | FLAG_FLOAT, instr) - case pyqir.Opcode.FMUL: - self._emit_binary(OP_FMUL | FLAG_FLOAT, instr) - case pyqir.Opcode.FDIV: - self._emit_binary(OP_FDIV | FLAG_FLOAT, instr) - case pyqir.Opcode.FP_EXT: - self._emit_unary(OP_FPEXT | FLAG_FLOAT, instr) - case pyqir.Opcode.FP_TRUNC: - self._emit_unary(OP_FPTRUNC | FLAG_FLOAT, instr) - case pyqir.Opcode.FP_TO_SI: - self._emit_unary(OP_FPTOSI, instr) - case pyqir.Opcode.SI_TO_FP: - self._emit_unary(OP_SITOFP | FLAG_FLOAT, instr) - case pyqir.Opcode.INT_TO_PTR: - self._emit_inttoptr(instr) - case _: - raise ValueError(f"Unsupported instruction: {instr.opcode}") - - # ------------------------------------------------------------------ - # Call dispatch - # ------------------------------------------------------------------ - - def _emit_call(self, call: pyqir.Call) -> None: - """Dispatch a CALL instruction based on callee name.""" - callee = call.callee.name - - match callee: - case "__quantum__qis__read_result__body" | "__quantum__rt__read_result": - dst = self._alloc_reg(call, REG_TYPE_BOOL) - result_reg = 
self._resolve_result_operand(call.args[0]) - self._emit(OP_READ_RESULT, dst=dst, src0=result_reg) - case _ if callee.startswith("__quantum__qis__"): - self._emit_quantum_call(call) - case "__quantum__rt__result_record_output": - result_reg = self._resolve_result_operand(call.args[0]) - label_str = self._extract_label(call.args[1]) - label_idx = len(self.labels) - self.labels.append(label_str) - self._emit(OP_RECORD_OUTPUT, src0=result_reg, aux0=label_idx) - case "__quantum__rt__array_record_output": - # Record structure output — pass through as-is for output formatting - count = ( - call.args[0].value - if isinstance(call.args[0], pyqir.IntConstant) - else 0 - ) - label_str = self._extract_label(call.args[1]) - label_idx = len(self.labels) - self.labels.append(label_str) - self._emit( - OP_RECORD_OUTPUT, src0=count, aux0=label_idx, aux1=1 - ) # aux1=1 -> array - case "__quantum__rt__tuple_record_output": - count = ( - call.args[0].value - if isinstance(call.args[0], pyqir.IntConstant) - else 0 - ) - label_str = self._extract_label(call.args[1]) - label_idx = len(self.labels) - self.labels.append(label_str) - self._emit( - OP_RECORD_OUTPUT, src0=count, aux0=label_idx, aux1=2 - ) # aux1=2 -> tuple - case "__quantum__rt__bool_record_output": - # Bool record output - pass through - src = self._resolve_operand(call.args[0]) - label_str = self._extract_label(call.args[1]) - label_idx = len(self.labels) - self.labels.append(label_str) - self._emit( - OP_RECORD_OUTPUT, src0=src, aux0=label_idx, aux1=3 - ) # aux1=3 -> bool - case "__quantum__rt__int_record_output": - src = self._resolve_operand(call.args[0]) - label_str = self._extract_label(call.args[1]) - label_idx = len(self.labels) - self.labels.append(label_str) - self._emit( - OP_RECORD_OUTPUT, src0=src, aux0=label_idx, aux1=4 - ) # aux1=4 -> int - case ( - "__quantum__rt__initialize" - | "__quantum__rt__begin_parallel" - | "__quantum__rt__end_parallel" - | "__quantum__qis__barrier__body" - | 
"__quantum__rt__read_loss" - ): - pass # No-op - case _ if callee in self._func_to_id: - self._emit_ir_function_call(call) - case _ if "qdk_noise" in call.callee.attributes.func: - # Check if this is a noise intrinsic (custom gate with qdk_noise attribute) - self._emit_noise_intrinsic_call(call) - case _: - raise ValueError(f"Unsupported call: {callee}") - - # ------------------------------------------------------------------ - # Quantum call dispatch - # ------------------------------------------------------------------ - - def _resolve_qubit_operands( - self, args: List[pyqir.Value] - ) -> Tuple[IntOperand | Reg, IntOperand | Reg, IntOperand | Reg]: - qs: List[IntOperand | Reg] = [IntOperand(), IntOperand(), IntOperand()] - for i, arg in enumerate(args): - qs[i] = self._resolve_qubit_operand(arg) - return (qs[0], qs[1], qs[2]) - - def _resolve_qubit_operand(self, arg: pyqir.Value) -> IntOperand | Reg: - a = self._resolve_operand(arg) - assert isinstance(a, (IntOperand, Reg)) - return a - - def _resolve_result_operand(self, arg: pyqir.Value) -> IntOperand | Reg: - a = self._resolve_operand(arg) - assert isinstance(a, (IntOperand, Reg)) - return a - - def _resolve_angle_operand(self, arg: pyqir.Value) -> FloatOperand | Reg: - a = self._resolve_operand(arg) - assert isinstance(a, (FloatOperand, Reg)) - return a - - def _emit_quantum_call(self, call: pyqir.Call) -> None: - """Emit a quantum gate, measure, or reset from a ``__quantum__qis__*`` call.""" - callee_name = call.callee.name - gate_name = callee_name.replace("__quantum__qis__", "").replace("__body", "") - op_id = GATE_MAP[gate_name] - if gate_name in MEASURE_GATES: - q = self._resolve_qubit_operand(call.args[0]) - r = self._resolve_result_operand(call.args[1]) - qop_idx = self._emit_quantum_op(op_id, q.val, r.val) - self._emit( - OP_MEASURE, - aux0=qop_idx, - aux1=q, - aux2=r, - ) - return - if gate_name in RESET_GATES: - q = self._resolve_qubit_operand(call.args[0]) - qop_idx = self._emit_quantum_op(op_id, 
q.val) - self._emit( - OP_RESET, - aux0=qop_idx, - aux1=q, - ) - return - if gate_name in ROTATION_GATES: - qubit_arg_offset = 1 - angle = self._resolve_angle_operand(call.args[0]) - else: - qubit_arg_offset = 0 - angle = FloatOperand() - qubit_arg_offset = 1 if gate_name in ROTATION_GATES else 0 - q1, q2, q3 = self._resolve_qubit_operands(call.args[qubit_arg_offset:]) - qop_idx = self._emit_quantum_op(op_id, q1.val, q2.val, q3.val, angle.val) - self._emit( - OP_QUANTUM_GATE, - aux0=qop_idx, - aux1=q1, - aux2=q2, - aux3=q3, - ) - - def _emit_noise_intrinsic_call(self, call: pyqir.Call) -> None: - """Emit a noise intrinsic call. - - When a noise config is provided and the callee is a known intrinsic, - store qubit register indices in ``call_args`` (following the same - pattern as ``_emit_ir_function_call``), then emit a single - ``OP_QUANTUM_GATE`` whose ``aux1`` = qubit count and ``aux2`` = - offset into ``call_args``. The shader reads qubit IDs from - ``call_arg_table`` at runtime, supporting arbitrarily many qubits. - - When no noise config is provided, emit an identity gate (no-op). - """ - callee_name = call.callee.name - if self._noise_intrinsics is not None and callee_name in self._noise_intrinsics: - table_id = self._noise_intrinsics[callee_name] - qubit_count = len(call.args) - # Store qubit register indices in call_args, materializing - # immediates into registers (same pattern as _emit_ir_function_call). - arg_offset = len(self.call_args) - for arg in call.args: - operand = self._resolve_qubit_operand(arg) - if isinstance(operand, Reg): - self.call_args.append(operand.val) - else: - reg = self._alloc_reg(None, REG_TYPE_PTR) - self._emit(OP_MOV | FLAG_SRC0_IMM, dst=reg, src0=operand.val) - self.call_args.append(reg.val) - # QuantumOp stores table_id in q1 and qubit_count in q2. 
- qop_idx = self._emit_quantum_op( - CORRELATED_NOISE_OP_ID, table_id, qubit_count - ) - self._emit( - OP_QUANTUM_GATE, - aux0=qop_idx, - aux1=IntOperand(qubit_count), - aux2=IntOperand(arg_offset), - ) - elif self._noise_intrinsics is not None: - raise ValueError(f"Missing noise intrinsic: {callee_name}") - else: - # No noise config — no-op - pass - - # ------------------------------------------------------------------ - # Control flow emitters - # ------------------------------------------------------------------ - - def _emit_branch(self, instr: pyqir.Instruction) -> None: - """Emit jump or conditional branch.""" - operands = instr.operands - if len(operands) == 1: - # Unconditional: br label %target - target = self._block_to_id[operands[0]] - self._emit(OP_JUMP, dst=target) - else: - # Conditional: br i1 %cond, label %true, label %false - # pyqir operands: [condition, FALSE_block, TRUE_block] - cond_reg = self._resolve_operand(operands[0]) - false_block = self._block_to_id[operands[1]] - true_block = self._block_to_id[operands[2]] - self._emit(OP_BRANCH, src0=cond_reg, aux0=true_block, aux1=false_block) - - def _emit_phi(self, phi_instr: pyqir.Phi) -> None: - """Emit a PHI node with side table entries.""" - dst_reg = self._alloc_reg(phi_instr, self._type_tag(phi_instr.type)) - phi_offset = len(self.phi_entries) - for value, block in phi_instr.incoming: - operand = self._resolve_operand(value) - if isinstance(operand, Reg): - val_reg = operand.val - else: - # Immediate values must be materialized into a register - # because the GPU phi_table stores register indices. 
- reg = self._alloc_reg(None, self._type_tag(phi_instr.type)) - self._emit(OP_MOV | FLAG_SRC0_IMM, dst=reg, src0=operand.val) - val_reg = reg.val - block_id = self._block_to_id[block] - phi_entry = PhiNodeEntry(block_id, val_reg) - self.phi_entries.append(phi_entry) - count = len(phi_instr.incoming) - self._emit(OP_PHI, dst=dst_reg, aux0=phi_offset, aux1=count) - - def _emit_select(self, instr: pyqir.Instruction) -> None: - """Emit a SELECT instruction.""" - dst = self._alloc_reg(instr, self._type_tag(instr.type)) - cond = self._resolve_operand(instr.operands[0]) - true_val = self._resolve_operand(instr.operands[1]) - false_val = self._resolve_operand(instr.operands[2]) - self._emit(OP_SELECT, dst=dst, src0=cond, aux0=true_val, aux1=false_val) - - def _emit_switch(self, switch_instr: pyqir.Switch) -> None: - """Emit a SWITCH instruction with case table entries. - - NOTE: We use ``operands`` instead of the ``.cond`` / ``.cases`` - helpers because pyqir's ``Switch.cond`` returns a stale ``Function`` - reference when ``mod.functions`` has already been iterated (two-pass - compilation). ``operands`` is not affected by this behavior. - """ - # operands layout: [cond, default_block, case_val0, case_block0, ...] 
- ops = switch_instr.operands - cond_reg = self._resolve_operand(ops[0]) - default_block = self._block_to_id[ops[1]] - case_offset = len(self.switch_cases) - num_case_pairs = (len(ops) - 2) // 2 - for i in range(num_case_pairs): - case_val = ops[2 + 2 * i] - case_block = ops[2 + 2 * i + 1] - target_block = self._block_to_id[case_block] - switch_case = SwitchCase(case_val.value, target_block) - self.switch_cases.append(switch_case) - case_count = num_case_pairs - self._emit( - OP_SWITCH, - src0=cond_reg, - aux0=default_block, - aux1=case_offset, - aux2=case_count, - ) - - def _emit_ret(self, instr: Any) -> None: - """Emit RET or CALL_RETURN.""" - if not self._current_func_is_entry: - # Return from IR-defined function - if len(instr.operands) > 0: - ret_reg = self._resolve_operand(instr.operands[0]) - self._emit(OP_CALL_RETURN, src0=ret_reg) - else: - self._emit(OP_CALL_RETURN) - else: - # Return from entry point - if len(instr.operands) > 0: - ret_reg = self._resolve_operand(instr.operands[0]) - self._emit(OP_RET, dst=ret_reg) - else: - # Void return — use immediate 0 as exit code. 
- self._emit(OP_RET, dst=IntOperand(0)) - - # ------------------------------------------------------------------ - # Comparison emitters - # ------------------------------------------------------------------ - - def _emit_icmp(self, instr: Any) -> None: - """Emit an integer comparison.""" - cond_code = ICMP_MAP.get(instr.predicate, 0) - dst = self._alloc_reg(instr, REG_TYPE_BOOL) - src0 = self._resolve_operand(instr.operands[0]) - src1 = self._resolve_operand(instr.operands[1]) - self._emit(OP_ICMP | (cond_code << 8), dst=dst, src0=src0, src1=src1) - - def _emit_fcmp(self, instr: Any) -> None: - """Emit a float comparison.""" - cond_code = FCMP_MAP.get(instr.predicate, 0) - dst = self._alloc_reg(instr, REG_TYPE_BOOL) - src0 = self._resolve_operand(instr.operands[0]) - src1 = self._resolve_operand(instr.operands[1]) - self._emit( - OP_FCMP | (cond_code << 8) | FLAG_FLOAT, - dst=dst, - src0=src0, - src1=src1, - ) - - # ------------------------------------------------------------------ - # inttoptr handling - # ------------------------------------------------------------------ - - def _emit_inttoptr(self, instr: Any) -> None: - """Handle ``inttoptr`` — just propagate the source register. - - ``inttoptr i64 %v to %Qubit*`` is a no-op cast; the integer value - is the qubit/result ID. We use OP_MOV to alias the value. 
- """ - src_operand = instr.operands[0] - src_reg = self._resolve_operand(src_operand) - # Register the inttoptr result as pointing to the same register - dst = self._alloc_reg(instr, REG_TYPE_PTR) - self._emit(OP_MOV, dst=dst, src0=src_reg) - - # ------------------------------------------------------------------ - # IR-defined function call/return - # ------------------------------------------------------------------ - - def _emit_ir_function_call(self, call: Any) -> None: - """Emit OP_CALL for an IR-defined function.""" - func_name = call.callee.name - func_id = self._func_to_id[func_name] - arg_offset = len(self.call_args) - for arg in call.args: - operand = self._resolve_operand(arg) - if isinstance(operand, Reg): - self.call_args.append(operand.val) - else: - # Immediate values must be materialized into a register - # because the GPU call_arg_table stores register indices. - reg = self._alloc_reg(None, REG_TYPE_PTR) - self._emit(OP_MOV | FLAG_SRC0_IMM, dst=reg, src0=operand.val) - self.call_args.append(reg.val) - # Allocate return register if function has non-void return type - if call.type.is_void: - return_reg = VOID_RETURN # no return - else: - return_reg = self._alloc_reg(call, REG_TYPE_I32) - self._emit( - OP_CALL, - dst=return_reg, - aux0=func_id, - aux1=len(call.args), - aux2=arg_offset, - ) - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _extract_label(self, value: Any) -> str: - """Extract a label string from a call argument.""" - bs = pyqir.extract_byte_string(value) - if bs is not None: - return bs.decode("utf-8") - return "" diff --git a/source/pip/qsharp/_device/_atom/_decomp.py b/source/pip/qsharp/_device/_atom/_decomp.py index 56451113de..ed30473fef 100644 --- a/source/pip/qsharp/_device/_atom/_decomp.py +++ b/source/pip/qsharp/_device/_atom/_decomp.py @@ -1,20 +1,23 @@ # Copyright (c) Microsoft Corporation. 
# Licensed under the MIT License. -from pyqir import ( - FloatConstant, - const, - Function, - FunctionType, - Type, - qubit_type, - result_type, - result, - Context, - Linkage, - QirModuleVisitor, - required_num_results, -) +try: + from pyqir import ( + FloatConstant, + const, + Function, + FunctionType, + Type, + qubit_type, + result_type, + result, + Context, + Linkage, + QirModuleVisitor, + required_num_results, + ) +except ImportError: + pass # PyQIR required only for neutral atom features from math import pi from ._utils import TOLERANCE diff --git a/source/pip/qsharp/_device/_atom/_optimize.py b/source/pip/qsharp/_device/_atom/_optimize.py index a6f6058fc6..d95ff5d267 100644 --- a/source/pip/qsharp/_device/_atom/_optimize.py +++ b/source/pip/qsharp/_device/_atom/_optimize.py @@ -1,19 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from pyqir import ( - Type, - Function, - FunctionType, - FloatConstant, - Linkage, - const, - qubit_type, - qubit_id, - result_type, - is_entry_point, - QirModuleVisitor, -) +try: + from pyqir import ( + Type, + Function, + FunctionType, + FloatConstant, + Linkage, + const, + qubit_type, + qubit_id, + result_type, + is_entry_point, + QirModuleVisitor, + ) +except ImportError: + pass # PyQIR required only for neutral atom features from math import pi from ._utils import TOLERANCE diff --git a/source/pip/qsharp/_device/_atom/_reorder.py b/source/pip/qsharp/_device/_atom/_reorder.py index 3efed6a4f0..6626f95cf9 100644 --- a/source/pip/qsharp/_device/_atom/_reorder.py +++ b/source/pip/qsharp/_device/_atom/_reorder.py @@ -3,12 +3,16 @@ from ._utils import as_qis_gate, get_used_values, uses_any_value from .._device import Device -from pyqir import ( - Call, - Instruction, - Function, - QirModuleVisitor, -) + +try: + from pyqir import ( + Call, + Instruction, + Function, + QirModuleVisitor, + ) +except ImportError: + pass # PyQIR required only for neutral atom features def is_output_recording(instr: 
Instruction): diff --git a/source/pip/qsharp/_device/_atom/_scheduler.py b/source/pip/qsharp/_device/_atom/_scheduler.py index 77282010fb..5a6d80285f 100644 --- a/source/pip/qsharp/_device/_atom/_scheduler.py +++ b/source/pip/qsharp/_device/_atom/_scheduler.py @@ -2,19 +2,23 @@ # Licensed under the MIT License. from ._utils import as_qis_gate, get_used_values, uses_any_value -from pyqir import ( - Call, - Instruction, - Function, - QirModuleVisitor, - FunctionType, - Type, - Linkage, - qubit_type, - qubit_id, - IntType, - Value, -) + +try: + from pyqir import ( + Call, + Instruction, + Function, + QirModuleVisitor, + FunctionType, + Type, + Linkage, + qubit_type, + qubit_id, + IntType, + Value, + ) +except ImportError: + pass # PyQIR required only for neutral atom features from .._device import Device, Zone, ZoneType from collections import defaultdict from dataclasses import dataclass diff --git a/source/pip/qsharp/_device/_atom/_trace.py b/source/pip/qsharp/_device/_atom/_trace.py index 7bae2649d1..e34def1df9 100644 --- a/source/pip/qsharp/_device/_atom/_trace.py +++ b/source/pip/qsharp/_device/_atom/_trace.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from pyqir import QirModuleVisitor, qubit_id, required_num_qubits +try: + from pyqir import QirModuleVisitor, qubit_id, required_num_qubits +except ImportError: + pass # PyQIR required only for neutral atom features from .._device import Device diff --git a/source/pip/qsharp/_device/_atom/_utils.py b/source/pip/qsharp/_device/_atom/_utils.py index 683f5cd5f4..b7e4417c11 100644 --- a/source/pip/qsharp/_device/_atom/_utils.py +++ b/source/pip/qsharp/_device/_atom/_utils.py @@ -1,16 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from pyqir import ( - Instruction, - Call, - Constant, - Value, - qubit_id, - is_qubit_type, - result_id, - is_result_type, -) +try: + from pyqir import ( + Instruction, + Call, + Constant, + Value, + qubit_id, + is_qubit_type, + result_id, + is_result_type, + ) +except ImportError: + pass # PyQIR required only for neutral atom features from typing import Dict TOLERANCE: float = 1.1920929e-7 # Machine epsilon for 32-bit IEEE FP numbers. diff --git a/source/pip/qsharp/_device/_atom/_validate.py b/source/pip/qsharp/_device/_atom/_validate.py index 0ebab719f8..59c953800d 100644 --- a/source/pip/qsharp/_device/_atom/_validate.py +++ b/source/pip/qsharp/_device/_atom/_validate.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from pyqir import QirModuleVisitor, is_entry_point, Opcode +try: + from pyqir import QirModuleVisitor, is_entry_point, Opcode +except ImportError: + pass # PyQIR required only for neutral atom features class ValidateAllowedIntrinsics(QirModuleVisitor): diff --git a/source/pip/qsharp/_simulation.py b/source/pip/qsharp/_simulation.py index 084b7ca625..19084acc97 100644 --- a/source/pip/qsharp/_simulation.py +++ b/source/pip/qsharp/_simulation.py @@ -1,10 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from pathlib import Path +import re import random from typing import Callable, Literal, List, Optional, Tuple, TypeAlias, Union -import pyqir from ._native import ( QirInstructionId, QirInstruction, @@ -15,474 +14,125 @@ NoiseConfig, GpuContext, try_create_gpu_adapter, -) -from pyqir import ( - Function, - FunctionType, - Type, - qubit_type, - Linkage, + get_qir_profile, + parse_base_profile_qir, + compile_adaptive_program, ) from ._qsharp import QirInputData, Result from typing import TYPE_CHECKING -from ._adaptive_pass import AdaptiveProfilePass, OP_RECORD_OUTPUT if TYPE_CHECKING: # This is in the pyi file only from ._native import GpuShotResults -class AggregateGatesPass(pyqir.QirModuleVisitor): - def __init__(self): - super().__init__() - self.gates: List[QirInstruction | Tuple] = [] - self.required_num_qubits = None - self.required_num_results = None - - def _get_value_as_string(self, value: pyqir.Value) -> str: - value = pyqir.extract_byte_string(value) - if value is None: - return "" - value = value.decode("utf-8") - return value - - def run(self, mod: pyqir.Module) -> Tuple[List[QirInstruction | Tuple], int, int]: - errors = mod.verify() - if errors is not None: - raise ValueError(f"Module verification failed: {errors}") - - # verify that the module is base profile - func = next(filter(pyqir.is_entry_point, mod.functions)) - self.required_num_qubits = pyqir.required_num_qubits(func) - self.required_num_results = pyqir.required_num_results(func) - - super().run(mod) - return (self.gates, self.required_num_qubits, self.required_num_results) - - def _on_block(self, block): - if ( - block.terminator - and block.terminator.opcode == pyqir.Opcode.BR - and len(block.terminator.operands) > 1 - ): - raise ValueError( - "simulation of programs with branching control flow is not supported" - ) - super()._on_block(block) - - def _on_call_instr(self, call: pyqir.Call) -> None: - callee_name = call.callee.name - if callee_name == "__quantum__qis__ccx__body": - 
self.gates.append( - ( - QirInstructionId.CCX, - pyqir.qubit_id(call.args[0]), - pyqir.qubit_id(call.args[1]), - pyqir.qubit_id(call.args[2]), - ) - ) - elif callee_name == "__quantum__qis__cx__body": - self.gates.append( - ( - QirInstructionId.CX, - pyqir.qubit_id(call.args[0]), - pyqir.qubit_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__cy__body": - self.gates.append( - ( - QirInstructionId.CY, - pyqir.qubit_id(call.args[0]), - pyqir.qubit_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__cz__body": - self.gates.append( - ( - QirInstructionId.CZ, - pyqir.qubit_id(call.args[0]), - pyqir.qubit_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__swap__body": - self.gates.append( - ( - QirInstructionId.SWAP, - pyqir.qubit_id(call.args[0]), - pyqir.qubit_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__rx__body": - self.gates.append( - ( - QirInstructionId.RX, - call.args[0].value, - pyqir.qubit_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__rxx__body": - self.gates.append( - ( - QirInstructionId.RXX, - call.args[0].value, - pyqir.qubit_id(call.args[1]), - pyqir.qubit_id(call.args[2]), - ) - ) - elif callee_name == "__quantum__qis__ry__body": - self.gates.append( - ( - QirInstructionId.RY, - call.args[0].value, - pyqir.qubit_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__ryy__body": - self.gates.append( - ( - QirInstructionId.RYY, - call.args[0].value, - pyqir.qubit_id(call.args[1]), - pyqir.qubit_id(call.args[2]), - ) - ) - elif callee_name == "__quantum__qis__rz__body": - self.gates.append( - ( - QirInstructionId.RZ, - call.args[0].value, - pyqir.qubit_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__rzz__body": - self.gates.append( - ( - QirInstructionId.RZZ, - call.args[0].value, - pyqir.qubit_id(call.args[1]), - pyqir.qubit_id(call.args[2]), - ) - ) - elif callee_name == "__quantum__qis__h__body": - self.gates.append((QirInstructionId.H, 
pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__s__body": - self.gates.append((QirInstructionId.S, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__s__adj": - self.gates.append((QirInstructionId.SAdj, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__sx__body": - self.gates.append((QirInstructionId.SX, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__t__body": - self.gates.append((QirInstructionId.T, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__t__adj": - self.gates.append((QirInstructionId.TAdj, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__x__body": - self.gates.append((QirInstructionId.X, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__y__body": - self.gates.append((QirInstructionId.Y, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__z__body": - self.gates.append((QirInstructionId.Z, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__m__body": - self.gates.append( - ( - QirInstructionId.M, - pyqir.qubit_id(call.args[0]), - pyqir.result_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__mz__body": - self.gates.append( - ( - QirInstructionId.MZ, - pyqir.qubit_id(call.args[0]), - pyqir.result_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__mresetz__body": - self.gates.append( - ( - QirInstructionId.MResetZ, - pyqir.qubit_id(call.args[0]), - pyqir.result_id(call.args[1]), - ) - ) - elif callee_name == "__quantum__qis__reset__body": - self.gates.append((QirInstructionId.RESET, pyqir.qubit_id(call.args[0]))) - elif callee_name == "__quantum__qis__move__body": - self.gates.append( - ( - QirInstructionId.Move, - pyqir.qubit_id(call.args[0]), - ) - ) - elif callee_name == "__quantum__rt__result_record_output": - tag = self._get_value_as_string(call.args[1]) - self.gates.append( - ( - QirInstructionId.ResultRecordOutput, - 
str(pyqir.result_id(call.args[0])), - tag, - ) - ) - elif callee_name == "__quantum__rt__tuple_record_output": - tag = self._get_value_as_string(call.args[1]) - self.gates.append( - (QirInstructionId.TupleRecordOutput, str(call.args[0].value), tag) - ) - elif callee_name == "__quantum__rt__array_record_output": - tag = self._get_value_as_string(call.args[1]) - self.gates.append( - (QirInstructionId.ArrayRecordOutput, str(call.args[0].value), tag) - ) - elif ( - callee_name == "__quantum__rt__initialize" - or callee_name == "__quantum__rt__begin_parallel" - or callee_name == "__quantum__rt__end_parallel" - or callee_name == "__quantum__qis__barrier__body" - # We only hit this during noiseless simulations - or "qdk_noise" in call.callee.attributes.func - ): - pass - else: - raise ValueError(f"Unsupported call instruction: {callee_name}") - - -class CorrelatedNoisePass(AggregateGatesPass): - """ - This pass replaces the QIR intrinsics that are in the provided NoiseConfig - by correlated noise instructions that the simulator understands. - """ - - def __init__(self, noise_config: NoiseConfig): - super().__init__() - self.noise_intrinsics_table = noise_config.intrinsics - - def _on_call_instr(self, call: pyqir.Call) -> None: - callee_name = call.callee.name - if callee_name in self.noise_intrinsics_table: - self.gates.append( - ( - QirInstructionId.CorrelatedNoise, - self.noise_intrinsics_table.get_intrinsic_id(callee_name), - [pyqir.qubit_id(arg) for arg in call.args], - ) - ) - elif "qdk_noise" in call.callee.attributes.func: - # If we are running a noisy simulation, we treat - # missing noise intrinsics as an error. - raise ValueError(f"Missing noise intrinsic: {callee_name}") - else: - super()._on_call_instr(call) - - -class GpuCorrelatedNoisePass(AggregateGatesPass): - """ - A special case of the CorrelatedNoisePass that uses data loaded - directly from rust instead of a NoiseConfig object to detect the - correlated noise intrinsics. 
- """ - - def __init__(self, noise_table: List[Tuple[int, str, int]]): - super().__init__() - self.noise_table = dict() - for table_id, name, _count in noise_table: - self.noise_table[name] = table_id - - def _on_call_instr(self, call: pyqir.Call) -> None: - callee_name = call.callee.name - if callee_name in self.noise_table: - self.gates.append( - ( - QirInstructionId.CorrelatedNoise, - int(self.noise_table[callee_name]), # Noise table ID - [pyqir.qubit_id(qubit) for qubit in call.args], # qubit args - ) - ) - elif "qdk_noise" in call.callee.attributes.func: - # If we are running a noisy simulation, we treat - # missing noise intrinsics as an error. - raise ValueError(f"Missing noise intrinsic: {callee_name}") - else: - super()._on_call_instr(call) - - -class OutputRecordingPass(pyqir.QirModuleVisitor): - _output_str = "" - _closers = [] - _counters = [] - - def process_output(self, bitstring: str): - return eval( - self._output_str, - { - "o": [ - Result.Zero if x == "0" else Result.One if x == "1" else Result.Loss - for x in bitstring - ] - }, +def _normalize_input(input: Union[QirInputData, str, bytes]) -> str: + """Normalize QIR input to text IR string.""" + if isinstance(input, QirInputData): + return str(input) + elif isinstance(input, str): + return input + else: + raise ValueError( + "Bitcode input is not supported without PyQIR. " "Provide text IR instead." 
) - def _on_function(self, function): - if pyqir.is_entry_point(function): - super()._on_function(function) - while len(self._closers) > 0: - self._output_str += self._closers.pop() - self._counters.pop() - - def _on_rt_result_record_output(self, call, result, target): - self._output_str += f"o[{pyqir.result_id(result)}]" - while len(self._counters) > 0: - self._output_str += "," - self._counters[-1] -= 1 - if self._counters[-1] == 0: - self._output_str += self._closers[-1] - self._closers.pop() - self._counters.pop() - else: - break - - def _on_rt_array_record_output(self, call, value, target): - self._output_str += "[" - self._closers.append("]") - # if len(self._counters) > 0: - # self._counters[-1] -= 1 - self._counters.append(value.value) - - def _on_rt_tuple_record_output(self, call, value, target): - self._output_str += "(" - self._closers.append(")") - # if len(self._counters) > 0: - # self._counters[-1] -= 1 - self._counters.append(value.value) - - -class DecomposeCcxPass(pyqir.QirModuleVisitor): - - h_func: Function - t_func: Function - tadj_func: Function - cz_func: Function - def __init__(self): - super().__init__() - - def _on_module(self, module): - void = Type.void(module.context) - qubit_ty = qubit_type(module.context) - - # Find or create all the needed functions. 
- for func in module.functions: - match func.name: - case "__quantum__qis__h__body": - self.h_func = func - case "__quantum__qis__t__body": - self.t_func = func - case "__quantum__qis__t__adj": - self.tadj_func = func - case "__quantum__qis__cz__body": - self.cz_func = func - if not hasattr(self, "h_func"): - self.h_func = Function( - FunctionType(void, [qubit_ty]), - Linkage.EXTERNAL, - "__quantum__qis__h__body", - module, - ) - if not hasattr(self, "t_func"): - self.t_func = Function( - FunctionType(void, [qubit_ty]), - Linkage.EXTERNAL, - "__quantum__qis__t__body", - module, - ) - if not hasattr(self, "tadj_func"): - self.tadj_func = Function( - FunctionType(void, [qubit_ty]), - Linkage.EXTERNAL, - "__quantum__qis__t__adj", - module, - ) - if not hasattr(self, "cz_func"): - self.cz_func = Function( - FunctionType(void, [qubit_ty, qubit_ty]), - Linkage.EXTERNAL, - "__quantum__qis__cz__body", - module, - ) - super()._on_module(module) - - def _on_qis_ccx(self, call, ctrl1, ctrl2, target): - self.builder.insert_before(call) - self.builder.call(self.h_func, [target]) - self.builder.call(self.tadj_func, [ctrl1]) - self.builder.call(self.tadj_func, [ctrl2]) - self.builder.call(self.h_func, [ctrl1]) - self.builder.call(self.cz_func, [target, ctrl1]) - self.builder.call(self.h_func, [ctrl1]) - self.builder.call(self.t_func, [ctrl1]) - self.builder.call(self.h_func, [target]) - self.builder.call(self.cz_func, [ctrl2, target]) - self.builder.call(self.h_func, [target]) - self.builder.call(self.h_func, [ctrl1]) - self.builder.call(self.cz_func, [ctrl2, ctrl1]) - self.builder.call(self.h_func, [ctrl1]) - self.builder.call(self.t_func, [target]) - self.builder.call(self.tadj_func, [ctrl1]) - self.builder.call(self.h_func, [target]) - self.builder.call(self.cz_func, [ctrl2, target]) - self.builder.call(self.h_func, [target]) - self.builder.call(self.h_func, [ctrl1]) - self.builder.call(self.cz_func, [target, ctrl1]) - self.builder.call(self.h_func, [ctrl1]) - 
self.builder.call(self.tadj_func, [target]) - self.builder.call(self.t_func, [ctrl1]) - self.builder.call(self.h_func, [ctrl1]) - self.builder.call(self.cz_func, [ctrl2, ctrl1]) - self.builder.call(self.h_func, [ctrl1]) - self.builder.call(self.h_func, [target]) - call.erase() - - -Simulator: TypeAlias = Callable[ - [List[QirInstruction], int, int, int, NoiseConfig, int], str -] - - -def preprocess_simulation_input( - input: Union[QirInputData, str, bytes], +def _prepare_params( shots: Optional[int] = 1, noise: Optional[NoiseConfig] = None, seed: Optional[int] = None, -) -> tuple[pyqir.Module, int, Optional[NoiseConfig], int]: +) -> tuple[int, Optional[NoiseConfig], int]: if shots is None: shots = 1 - # If no seed specified, generate a random u32 to use if seed is None: seed = random.randint(0, 2**32 - 1) if isinstance(noise, tuple): raise ValueError( "Specifying Pauli noise via a tuple is not supported. Use a NoiseConfig instead." ) + return (shots, noise, seed) + + +def _process_output(output_fmt: str, bitstring: str): + """Evaluate the output format string against a bitstring of results.""" + return eval( + output_fmt, + { + "o": [ + Result.Zero if x == "0" else Result.One if x == "1" else Result.Loss + for x in bitstring + ] + }, + ) - context = pyqir.Context() - if isinstance(input, QirInputData): - mod = pyqir.Module.from_ir(context, str(input)) - elif isinstance(input, str): - mod = pyqir.Module.from_ir(context, input) - else: - mod = pyqir.Module.from_bitcode(context, input) - return (mod, shots, noise, seed) +def _build_noise_dict(ir: str, noise_config: NoiseConfig) -> dict: + """Build noise intrinsics dict by scanning IR for function declarations + and checking them against the noise config intrinsics table.""" + noise_dict = {} + intrinsics = noise_config.intrinsics + for match in re.finditer(r"(?:declare|define)\s+\S+\s+@([^\s(]+)", ir): + name = match.group(1) + if name in intrinsics: + noise_dict[name] = intrinsics.get_intrinsic_id(name) + return 
noise_dict + + +def _decompose_ccx_in_gates(gates: list) -> list: + """Replace CCX gates with the decomposed gate sequence using H, T, TAdj, + and CZ gates (needed for GPU simulator compatibility).""" + result = [] + for gate in gates: + if ( + isinstance(gate, tuple) + and len(gate) >= 1 + and gate[0] == QirInstructionId.CCX + ): + _, ctrl1, ctrl2, target = gate + result.extend( + [ + (QirInstructionId.H, target), + (QirInstructionId.TAdj, ctrl1), + (QirInstructionId.TAdj, ctrl2), + (QirInstructionId.H, ctrl1), + (QirInstructionId.CZ, target, ctrl1), + (QirInstructionId.H, ctrl1), + (QirInstructionId.T, ctrl1), + (QirInstructionId.H, target), + (QirInstructionId.CZ, ctrl2, target), + (QirInstructionId.H, target), + (QirInstructionId.H, ctrl1), + (QirInstructionId.CZ, ctrl2, ctrl1), + (QirInstructionId.H, ctrl1), + (QirInstructionId.T, target), + (QirInstructionId.TAdj, ctrl1), + (QirInstructionId.H, target), + (QirInstructionId.CZ, ctrl2, target), + (QirInstructionId.H, target), + (QirInstructionId.H, ctrl1), + (QirInstructionId.CZ, target, ctrl1), + (QirInstructionId.H, ctrl1), + (QirInstructionId.TAdj, target), + (QirInstructionId.T, ctrl1), + (QirInstructionId.H, ctrl1), + (QirInstructionId.CZ, ctrl2, ctrl1), + (QirInstructionId.H, ctrl1), + (QirInstructionId.H, target), + ] + ) + else: + result.append(gate) + return result + + +Simulator: TypeAlias = Callable[ + [List[QirInstruction], int, int, int, NoiseConfig, int], str +] -def is_adaptive(mod: pyqir.Module) -> bool: - """Check if the QIR module uses the Adaptive Profile.""" - entry = next(filter(pyqir.is_entry_point, mod.functions), None) - if entry is None: - return False - func_attrs = entry.attributes.func - if "qir_profiles" not in func_attrs: - return False - return func_attrs["qir_profiles"].string_value == "adaptive_profile" +def is_adaptive(ir: str) -> bool: + """Check if the QIR uses the Adaptive Profile.""" + return get_qir_profile(ir) == "adaptive_profile" def run_qir_clifford( @@ -491,17 +141,17 
@@ def run_qir_clifford( noise: Optional[NoiseConfig] = None, seed: Optional[int] = None, ) -> List: - (mod, shots, noise, seed) = preprocess_simulation_input(input, shots, noise, seed) - if noise is None: - (gates, num_qubits, num_results) = AggregateGatesPass().run(mod) - else: - (gates, num_qubits, num_results) = CorrelatedNoisePass(noise).run(mod) - recorder = OutputRecordingPass() - recorder.run(mod) + ir = _normalize_input(input) + (shots, noise, seed) = _prepare_params(shots, noise, seed) + + noise_intrinsics = _build_noise_dict(ir, noise) if noise else None + (gates, num_qubits, num_results, output_fmt) = parse_base_profile_qir( + ir, noise_intrinsics + ) return list( map( - recorder.process_output, + lambda bs: _process_output(output_fmt, bs), run_clifford(gates, num_qubits, num_results, shots, noise, seed), ) ) @@ -513,17 +163,17 @@ def run_qir_cpu( noise: Optional[NoiseConfig] = None, seed: Optional[int] = None, ) -> List: - (mod, shots, noise, seed) = preprocess_simulation_input(input, shots, noise, seed) - if noise is None: - (gates, num_qubits, num_results) = AggregateGatesPass().run(mod) - else: - (gates, num_qubits, num_results) = CorrelatedNoisePass(noise).run(mod) - recorder = OutputRecordingPass() - recorder.run(mod) + ir = _normalize_input(input) + (shots, noise, seed) = _prepare_params(shots, noise, seed) + + noise_intrinsics = _build_noise_dict(ir, noise) if noise else None + (gates, num_qubits, num_results, output_fmt) = parse_base_profile_qir( + ir, noise_intrinsics + ) return list( map( - recorder.process_output, + lambda bs: _process_output(output_fmt, bs), run_cpu_full_state(gates, num_qubits, num_results, shots, noise, seed), ) ) @@ -547,35 +197,35 @@ def run_qir_gpu( noise: Optional[NoiseConfig] = None, seed: Optional[int] = None, ) -> List: - (mod, shots, noise, seed) = preprocess_simulation_input(input, shots, noise, seed) - # Ccx is not support in the GPU simulator, decompose it - DecomposeCcxPass().run(mod) - if is_adaptive(mod): - 
program = AdaptiveProfilePass().run(mod, noise) - results = run_adaptive_parallel_shots(program.as_dict(), shots, noise, seed) + ir = _normalize_input(input) + (shots, noise, seed) = _prepare_params(shots, noise, seed) + + if is_adaptive(ir): + noise_intrinsics = _build_noise_dict(ir, noise) if noise else None + program = compile_adaptive_program(ir, noise_intrinsics) + results = run_adaptive_parallel_shots(program, shots, noise, seed) # Extract recorded output result indices from the bytecode. - # OP_RECORD_OUTPUT with aux1=0 is result_record_output where + # OP_RECORD_OUTPUT (0x14) with aux1=0 is result_record_output where # src0 is the result index in the results buffer. recorded_result_indices = [] - for ins in program.instructions: - if (ins.opcode & 0xFF) == OP_RECORD_OUTPUT and ins.aux1 == 0: - recorded_result_indices.append(ins.src0) + for ins in program["instructions"]: + if (ins[0] & 0xFF) == 0x14 and ins[5] == 0: + recorded_result_indices.append(ins[2]) # Filter shot_results to only include recorded output indices filtered = [] for s in results: filtered.append([str_to_result(s[i]) for i in recorded_result_indices]) return filtered else: - if noise is None: - (gates, num_qubits, num_results) = AggregateGatesPass().run(mod) - else: - (gates, num_qubits, num_results) = CorrelatedNoisePass(noise).run(mod) - recorder = OutputRecordingPass() - recorder.run(mod) + noise_intrinsics = _build_noise_dict(ir, noise) if noise else None + (gates, num_qubits, num_results, output_fmt) = parse_base_profile_qir( + ir, noise_intrinsics + ) + gates = _decompose_ccx_in_gates(gates) return list( map( - recorder.process_output, + lambda bs: _process_output(output_fmt, bs), run_parallel_shots(gates, shots, num_qubits, num_results, noise, seed), ) ) @@ -585,18 +235,13 @@ def prepare_qir_with_correlated_noise( input: Union[QirInputData, str, bytes], noise_tables: List[Tuple[int, str, int]], ) -> Tuple[List[QirInstruction], int, int]: - # Turn the input into a QIR module - (mod, 
_, _, _) = preprocess_simulation_input(input, None, None, None) - - # Ccx is not support in the GPU simulator, decompose it - DecomposeCcxPass().run(mod) - - # Extract the gates including correlated noise instructions - (gates, required_num_qubits, required_num_results) = GpuCorrelatedNoisePass( - noise_tables - ).run(mod) - - return (gates, required_num_qubits, required_num_results) + ir = _normalize_input(input) + noise_dict = {name: table_id for table_id, name, _count in noise_tables} + (gates, num_qubits, num_results, _) = parse_base_profile_qir( + ir, noise_dict if noise_dict else None + ) + gates = _decompose_ccx_in_gates(gates) + return (gates, num_qubits, num_results) class GpuSimulator: @@ -638,28 +283,28 @@ def set_program(self, input: Union[QirInputData, str, bytes]): multiple programs sequentially by calling this method multiple times before calling `run_shots` without needing to create a new simulator instance or reloading noise tables. """ - # Parse the QIR module to detect profile - (mod, _, _, _) = preprocess_simulation_input(input, None, None, None) - if is_adaptive(mod): + ir = _normalize_input(input) + if is_adaptive(ir): self._is_adaptive = True + # Build noise_intrinsics dict from loaded noise tables (if any) noise_intrinsics = None if self.tables is not None: noise_intrinsics = {name: table_id for table_id, name, _ in self.tables} - program = AdaptiveProfilePass().run(mod, noise_intrinsics=noise_intrinsics) - self.gpu_context.set_adaptive_program(program.as_dict()) + program = compile_adaptive_program(ir, noise_intrinsics) + self.gpu_context.set_adaptive_program(program) # Extract recorded output result indices from the bytecode. - # OP_RECORD_OUTPUT with aux1=0 is result_record_output where + # OP_RECORD_OUTPUT (0x14) with aux1=0 is result_record_output where # src0 is the result index in the results buffer. 
self._recorded_result_indices = [] - for instr in program.instructions: - if instr.opcode & 0xFF == OP_RECORD_OUTPUT and instr.aux1 == 0: - self._recorded_result_indices.append(instr.src0) + for instr in program["instructions"]: + if instr[0] & 0xFF == 0x14 and instr[5] == 0: + self._recorded_result_indices.append(instr[2]) else: (self.gates, self.required_num_qubits, self.required_num_results) = ( prepare_qir_with_correlated_noise( - input, self.tables if not self.tables is None else [] + input, self.tables if self.tables is not None else [] ) ) self.gpu_context.set_program( diff --git a/source/pip/src/interpreter.rs b/source/pip/src/interpreter.rs index d9272eaa2f..a66788cfca 100644 --- a/source/pip/src/interpreter.rs +++ b/source/pip/src/interpreter.rs @@ -24,10 +24,22 @@ use crate::{ noisy_simulator::register_noisy_simulator_submodule, qir_simulation::{ IdleNoiseParams, NoiseConfig, NoiseTable, QirInstruction, QirInstructionId, + adaptive_pass::compile_adaptive_program, + atom_decomp::{ + atom_decompose_multi_qubit_to_cz, atom_decompose_rz_to_clifford, + atom_decompose_single_qubit_to_rz_sx, atom_decompose_single_rotation_to_rz, + atom_replace_reset_with_mresetz, + }, + atom_optimize::{atom_optimize_single_qubit_gates, atom_prune_unused_functions}, + atom_reorder::atom_reorder, + atom_scheduler::atom_schedule, + atom_trace::trace_atom_program, + atom_validate::{validate_allowed_intrinsics, validate_no_conditional_branches}, cpu_simulators::{run_clifford, run_cpu_full_state}, gpu_full_state::{ GpuContext, run_adaptive_parallel_shots, run_parallel_shots, try_create_gpu_adapter, }, + native_qir_parser::{get_qir_profile, parse_base_profile_qir}, unbind_noise_config, }, }; @@ -146,6 +158,23 @@ fn _native<'a>(py: Python<'a>, m: &Bound<'a, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(circuit_qasm_program, m)?)?; m.add_function(wrap_pyfunction!(compile_qasm_program_to_qir, m)?)?; m.add_function(wrap_pyfunction!(compile_qasm_to_qsharp, m)?)?; + 
m.add_function(wrap_pyfunction!(get_qir_profile, m)?)?; + m.add_function(wrap_pyfunction!(parse_base_profile_qir, m)?)?; + m.add_function(wrap_pyfunction!(compile_adaptive_program, m)?)?; + // Neutral atom passes + m.add_function(wrap_pyfunction!(validate_allowed_intrinsics, m)?)?; + m.add_function(wrap_pyfunction!(validate_no_conditional_branches, m)?)?; + m.add_function(wrap_pyfunction!(trace_atom_program, m)?)?; + // Neutral atom mutation passes + m.add_function(wrap_pyfunction!(atom_decompose_multi_qubit_to_cz, m)?)?; + m.add_function(wrap_pyfunction!(atom_decompose_single_rotation_to_rz, m)?)?; + m.add_function(wrap_pyfunction!(atom_decompose_single_qubit_to_rz_sx, m)?)?; + m.add_function(wrap_pyfunction!(atom_decompose_rz_to_clifford, m)?)?; + m.add_function(wrap_pyfunction!(atom_replace_reset_with_mresetz, m)?)?; + m.add_function(wrap_pyfunction!(atom_optimize_single_qubit_gates, m)?)?; + m.add_function(wrap_pyfunction!(atom_prune_unused_functions, m)?)?; + m.add_function(wrap_pyfunction!(atom_reorder, m)?)?; + m.add_function(wrap_pyfunction!(atom_schedule, m)?)?; Ok(()) } diff --git a/source/pip/src/qir_simulation.rs b/source/pip/src/qir_simulation.rs index 6e41c2da97..39c871e3cb 100644 --- a/source/pip/src/qir_simulation.rs +++ b/source/pip/src/qir_simulation.rs @@ -1,9 +1,18 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+pub(crate) mod adaptive_pass; +pub(crate) mod atom_decomp; +pub(crate) mod atom_optimize; +pub(crate) mod atom_reorder; +pub(crate) mod atom_scheduler; +pub(crate) mod atom_trace; +pub(crate) mod atom_utils; +pub(crate) mod atom_validate; mod correlated_noise; pub(crate) mod cpu_simulators; pub(crate) mod gpu_full_state; +pub(crate) mod native_qir_parser; use crate::qir_simulation::correlated_noise::parse_noise_table; diff --git a/source/pip/src/qir_simulation/adaptive_pass.rs b/source/pip/src/qir_simulation/adaptive_pass.rs new file mode 100644 index 0000000000..5827d37916 --- /dev/null +++ b/source/pip/src/qir_simulation/adaptive_pass.rs @@ -0,0 +1,1434 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Adaptive Profile Pass — walks adaptive-profile QIR and emits bytecode +//! consumed by the GPU/CPU parallel shot simulators. +//! +//! This is the Rust equivalent of `_adaptive_pass.py`. It takes a parsed +//! `Module`, traverses functions → basic blocks → instructions, and produces +//! the same dict-like structure that `AdaptiveProgram.as_dict()` returns. 
+ +use pyo3::{ + Bound, IntoPyObjectExt, PyResult, Python, + exceptions::PyValueError, + pyfunction, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyTuple}, +}; +use qsc_llvm::{ + model::Type, + model::{ + Attribute, BinOpKind, CastKind, Constant, FloatPredicate, Function, Instruction, + IntPredicate, Module, Operand, + }, + parse_module, + qir::{find_entry_point, get_function_attribute, qis, rt}, +}; +use rustc_hash::FxHashMap; + +// ── Bytecode opcodes — must match `_adaptive_bytecode.py` and GPU shader ──── + +const FLAG_DST_IMM: u32 = 1 << 18; +const FLAG_SRC0_IMM: u32 = 1 << 16; +const FLAG_SRC1_IMM: u32 = 1 << 17; +const FLAG_AUX0_IMM: u32 = 1 << 19; +const FLAG_AUX1_IMM: u32 = 1 << 20; +const FLAG_AUX2_IMM: u32 = 1 << 21; +#[allow(dead_code)] +const FLAG_AUX3_IMM: u32 = 1 << 22; +const FLAG_FLOAT: u32 = 1 << 23; + +const OP_RET: u32 = 0x02; +const OP_JUMP: u32 = 0x04; +const OP_BRANCH: u32 = 0x05; +const OP_SWITCH: u32 = 0x06; +const OP_CALL: u32 = 0x07; +const OP_CALL_RETURN: u32 = 0x08; + +const OP_QUANTUM_GATE: u32 = 0x10; +const OP_MEASURE: u32 = 0x11; +const OP_RESET: u32 = 0x12; +const OP_READ_RESULT: u32 = 0x13; +const OP_RECORD_OUTPUT: u32 = 0x14; + +const OP_ADD: u32 = 0x20; +const OP_SUB: u32 = 0x21; +const OP_MUL: u32 = 0x22; +const OP_UDIV: u32 = 0x23; +const OP_SDIV: u32 = 0x24; +const OP_UREM: u32 = 0x25; +const OP_SREM: u32 = 0x26; + +const OP_AND: u32 = 0x28; +const OP_OR: u32 = 0x29; +const OP_XOR: u32 = 0x2A; +const OP_SHL: u32 = 0x2B; +const OP_LSHR: u32 = 0x2C; +const OP_ASHR: u32 = 0x2D; + +const OP_ICMP: u32 = 0x30; +const OP_FCMP: u32 = 0x31; + +const OP_FADD: u32 = 0x38; +const OP_FSUB: u32 = 0x39; +const OP_FMUL: u32 = 0x3A; +const OP_FDIV: u32 = 0x3B; + +const OP_ZEXT: u32 = 0x40; +const OP_SEXT: u32 = 0x41; +const OP_TRUNC: u32 = 0x42; +const OP_FPEXT: u32 = 0x43; +const OP_FPTRUNC: u32 = 0x44; +#[allow(dead_code)] +const OP_INTTOPTR: u32 = 0x45; +const OP_FPTOSI: u32 = 0x46; +const OP_SITOFP: u32 = 0x47; + 
+const OP_PHI: u32 = 0x50;
+const OP_SELECT: u32 = 0x51;
+const OP_MOV: u32 = 0x52;
+const OP_CONST: u32 = 0x53;
+
+// ICmp condition codes
+const ICMP_EQ: u32 = 0;
+const ICMP_NE: u32 = 1;
+const ICMP_SLT: u32 = 2;
+const ICMP_SLE: u32 = 3;
+const ICMP_SGT: u32 = 4;
+const ICMP_SGE: u32 = 5;
+const ICMP_ULT: u32 = 6;
+const ICMP_ULE: u32 = 7;
+const ICMP_UGT: u32 = 8;
+const ICMP_UGE: u32 = 9;
+
+// FCmp condition codes
+const FCMP_OEQ: u32 = 1;
+const FCMP_OGT: u32 = 2;
+const FCMP_OGE: u32 = 3;
+const FCMP_OLT: u32 = 4;
+const FCMP_OLE: u32 = 5;
+const FCMP_ONE: u32 = 6;
+const FCMP_ORD: u32 = 7;
+const FCMP_UNO: u32 = 8;
+const FCMP_UEQ: u32 = 9;
+const FCMP_UGT: u32 = 10;
+const FCMP_UGE: u32 = 11;
+const FCMP_ULT: u32 = 12;
+const FCMP_ULE: u32 = 13;
+const FCMP_UNE: u32 = 14;
+
+// Register type tags
+const REG_TYPE_BOOL: u32 = 0;
+const REG_TYPE_I32: u32 = 1;
+const REG_TYPE_I64: u32 = 2;
+const REG_TYPE_F32: u32 = 3;
+const REG_TYPE_F64: u32 = 4;
+const REG_TYPE_PTR: u32 = 5;
+
+const VOID_RETURN: u32 = 0xFFFF_FFFF;
+
+// Correlated noise op ID (must match shader_types.rs)
+const CORRELATED_NOISE_OP_ID: u32 = 131;
+
+// ── Gate mapping ────────────────────────────────────────────────────────────
+
+fn gate_name_to_op_id(name: &str) -> Option<u32> {
+    match name {
+        "reset" => Some(1),
+        "x" => Some(2),
+        "y" => Some(3),
+        "z" => Some(4),
+        "h" => Some(5),
+        "s" => Some(6),
+        "s__adj" => Some(7),
+        "t" => Some(8),
+        "t__adj" => Some(9),
+        "sx" => Some(10),
+        "sx__adj" => Some(11),
+        "rx" => Some(12),
+        "ry" => Some(13),
+        "rz" => Some(14),
+        "cnot" | "cx" => Some(15),
+        "cz" => Some(16),
+        "cy" => Some(29),
+        "rxx" => Some(17),
+        "ryy" => Some(18),
+        "rzz" => Some(19),
+        "ccx" => Some(20),
+        "m" | "mz" => Some(21),
+        "mresetz" => Some(22),
+        "swap" => Some(24),
+        _ => None,
+    }
+}
+
+fn is_measure_gate(name: &str) -> bool {
+    matches!(name, "m" | "mz" | "mresetz")
+}
+
+fn is_reset_gate(name: &str) -> bool {
+    name == "reset"
+}
+
+fn is_rotation_gate(name: 
&str) -> bool {
+    matches!(name, "rx" | "ry" | "rz" | "rxx" | "ryy" | "rzz")
+}
+
+// ── Operand wrapper ─────────────────────────────────────────────────────────
+
+/// An operand that either refers to a register or carries an immediate value.
+#[derive(Clone, Copy)]
+enum OpVal {
+    Reg(u32),
+    IntImm(u32),
+    FloatImm(u32),
+}
+
+impl OpVal {
+    fn raw(self) -> u32 {
+        match self {
+            OpVal::Reg(v) | OpVal::IntImm(v) | OpVal::FloatImm(v) => v,
+        }
+    }
+
+    fn is_imm(self) -> bool {
+        matches!(self, OpVal::IntImm(_) | OpVal::FloatImm(_))
+    }
+}
+
+fn encode_float_as_bits(val: f64) -> u32 {
+    (val as f32).to_bits()
+}
+
+fn i64_to_u32_masked(val: i64) -> u32 {
+    (val as u32) & 0xFFFF_FFFF
+}
+
+// ── Block / instruction / quantum op tuples ─────────────────────────────────
+
+#[derive(Clone, Copy)]
+struct BcBlock {
+    block_id: u32,
+    instr_offset: u32,
+    instr_count: u32,
+}
+
+#[derive(Clone, Copy)]
+struct BcInstr {
+    opcode: u32,
+    dst: u32,
+    src0: u32,
+    src1: u32,
+    aux0: u32,
+    aux1: u32,
+    aux2: u32,
+    aux3: u32,
+}
+
+#[derive(Clone, Copy)]
+struct BcQuantumOp {
+    op_id: u32,
+    q1: u32,
+    q2: u32,
+    q3: u32,
+    angle: f64,
+}
+
+#[derive(Clone, Copy)]
+struct BcFunction {
+    entry_block: u32,
+    num_params: u32,
+    param_base: u32,
+}
+
+#[derive(Clone, Copy)]
+struct BcPhiEntry {
+    block_id: u32,
+    val_reg: u32,
+}
+
+#[derive(Clone, Copy)]
+struct BcSwitchCase {
+    case_val: u32,
+    target_block: u32,
+}
+
+// ── Pass state ──────────────────────────────────────────────────────────────
+
+struct AdaptivePass<'m> {
+    module: &'m Module,
+
+    // Output tables
+    blocks: Vec<BcBlock>,
+    instructions: Vec<BcInstr>,
+    quantum_ops: Vec<BcQuantumOp>,
+    functions: Vec<BcFunction>,
+    phi_entries: Vec<BcPhiEntry>,
+    switch_cases: Vec<BcSwitchCase>,
+    call_args: Vec<u32>,
+    labels: Vec<String>,
+    register_types: Vec<u32>,
+
+    // Internal
+    next_reg: u32,
+    next_block: u32,
+    next_qop: u32,
+    /// SSA name (%name) → register ID
+    value_to_reg: FxHashMap<String, u32>,
+    /// Basic block name → block ID (function-qualified: "func::block")
+    block_to_id: FxHashMap<String, u32>,
+    /// Function 
name → function table index
+    func_to_id: FxHashMap<String, u32>,
+    current_func_is_entry: bool,
+    current_func_name: String,
+    noise_intrinsics: Option<FxHashMap<String, u32>>,
+}
+
+impl<'m> AdaptivePass<'m> {
+    fn new(module: &'m Module, noise_intrinsics: Option<FxHashMap<String, u32>>) -> Self {
+        Self {
+            module,
+            blocks: Vec::new(),
+            instructions: Vec::new(),
+            quantum_ops: Vec::new(),
+            functions: Vec::new(),
+            phi_entries: Vec::new(),
+            switch_cases: Vec::new(),
+            call_args: Vec::new(),
+            labels: Vec::new(),
+            register_types: Vec::new(),
+            next_reg: 0,
+            next_block: 0,
+            next_qop: 0,
+            value_to_reg: FxHashMap::default(),
+            block_to_id: FxHashMap::default(),
+            func_to_id: FxHashMap::default(),
+            current_func_is_entry: true,
+            current_func_name: String::new(),
+            noise_intrinsics,
+        }
+    }
+
+    fn run(&mut self, entry_idx: usize) -> PyResult<()> {
+        // Check for arrays module flag
+        if self.module.get_flag("arrays").is_some() {
+            return Err(PyValueError::new_err(
+                "QIR arrays are not currently supported.",
+            ));
+        }
+
+        // Pass 1: assign block IDs and function IDs
+        for func in &self.module.functions {
+            if !func.basic_blocks.is_empty() {
+                self.assign_function(func, entry_idx);
+            }
+        }
+
+        // Pass 2: walk instructions and emit bytecode
+        for (idx, func) in self.module.functions.iter().enumerate() {
+            if !func.basic_blocks.is_empty() {
+                self.walk_function(func, idx == entry_idx)?;
+            }
+        }
+
+        Ok(())
+    }
+
+    // ── Register allocation ─────────────────────────────────────────────
+
+    fn alloc_reg(&mut self, name: Option<&str>, type_tag: u32) -> u32 {
+        if let Some(n) = name {
+            if let Some(&existing) = self.value_to_reg.get(n) {
+                return existing;
+            }
+        }
+        let reg = self.next_reg;
+        self.next_reg += 1;
+        if let Some(n) = name {
+            self.value_to_reg.insert(n.to_string(), reg);
+        }
+        self.register_types.push(type_tag);
+        reg
+    }
+
+    /// Build a qualified name for a block: "func_name::block_name"
+    fn qualified_block_name(func_name: &str, block_name: &str) -> String {
+        let mut s = String::with_capacity(func_name.len() + 
2 + block_name.len());
+        s.push_str(func_name);
+        s.push_str("::");
+        s.push_str(block_name);
+        s
+    }
+
+    // ── Instruction emission ────────────────────────────────────────────
+
+    fn emit(
+        &mut self,
+        opcode: u32,
+        dst: OpVal,
+        src0: OpVal,
+        src1: OpVal,
+        aux0: OpVal,
+        aux1: OpVal,
+        aux2: OpVal,
+        aux3: OpVal,
+    ) {
+        let mut flags: u32 = 0;
+        if dst.is_imm() {
+            flags |= FLAG_DST_IMM;
+        }
+        if src0.is_imm() {
+            flags |= FLAG_SRC0_IMM;
+        }
+        if src1.is_imm() {
+            flags |= FLAG_SRC1_IMM;
+        }
+        if aux0.is_imm() {
+            flags |= FLAG_AUX0_IMM;
+        }
+        if aux1.is_imm() {
+            flags |= FLAG_AUX1_IMM;
+        }
+        if aux2.is_imm() {
+            flags |= FLAG_AUX2_IMM;
+        }
+        if aux3.is_imm() {
+            // Note: FLAG_AUX3_IMM is not commonly used but is tracked
+        }
+        self.instructions.push(BcInstr {
+            opcode: opcode | flags,
+            dst: dst.raw(),
+            src0: src0.raw(),
+            src1: src1.raw(),
+            aux0: aux0.raw(),
+            aux1: aux1.raw(),
+            aux2: aux2.raw(),
+            aux3: aux3.raw(),
+        });
+    }
+
+    /// Helper for common case: emit with minimal operands
+    fn emit_simple(&mut self, opcode: u32, dst: OpVal, src0: OpVal, src1: OpVal) {
+        let z = OpVal::Reg(0);
+        self.emit(opcode, dst, src0, src1, z, z, z, z);
+    }
+
+    fn emit_quantum_op(&mut self, op_id: u32, q1: u32, q2: u32, q3: u32, angle: f64) -> u32 {
+        let idx = self.next_qop;
+        self.next_qop += 1;
+        self.quantum_ops.push(BcQuantumOp {
+            op_id,
+            q1,
+            q2,
+            q3,
+            angle,
+        });
+        idx
+    }
+
+    // ── Operand resolution ──────────────────────────────────────────────
+
+    fn resolve_operand(&mut self, operand: &Operand) -> PyResult<OpVal> {
+        match operand {
+            Operand::LocalRef(name) | Operand::TypedLocalRef(name, _) => {
+                if let Some(&reg) = self.value_to_reg.get(name.as_str()) {
+                    Ok(OpVal::Reg(reg))
+                } else {
+                    // Forward reference — pre-allocate a register
+                    let reg = self.alloc_reg(Some(name), REG_TYPE_I64);
+                    Ok(OpVal::Reg(reg))
+                }
+            }
+            Operand::IntConst(_, val) => Ok(OpVal::IntImm(i64_to_u32_masked(*val))),
+            Operand::FloatConst(_, val) => 
Ok(OpVal::FloatImm(encode_float_as_bits(*val))), + Operand::IntToPtr(val, _) => Ok(OpVal::IntImm(i64_to_u32_masked(*val))), + Operand::NullPtr => { + // Null pointer — materialize as register with value 0 + let reg = self.alloc_reg(None, REG_TYPE_PTR); + self.emit( + OP_CONST | FLAG_SRC0_IMM, + OpVal::Reg(reg), + OpVal::IntImm(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + ); + Ok(OpVal::Reg(reg)) + } + Operand::GlobalRef(_) => { + // Global reference — look up string initializer for label extraction + // Return as immediate 0 (globals are typically used for labels, not values) + Ok(OpVal::IntImm(0)) + } + Operand::GetElementPtr { .. } => Err(PyValueError::new_err( + "GEP operands not supported in adaptive pass", + )), + } + } + + fn type_tag(ty: &Type) -> u32 { + match ty { + Type::Integer(1) => REG_TYPE_BOOL, + Type::Integer(w) if *w <= 32 => REG_TYPE_I32, + Type::Integer(_) => REG_TYPE_I64, + Type::Ptr | Type::NamedPtr(_) | Type::TypedPtr(_) => REG_TYPE_PTR, + Type::Double => REG_TYPE_F64, + _ => REG_TYPE_F32, + } + } + + // ── Function assignment (Pass 1) ──────────────────────────────────── + + fn assign_function(&mut self, func: &Function, entry_idx: usize) { + let is_entry = self + .module + .functions + .iter() + .position(|f| std::ptr::eq(f, func)) + .is_some_and(|idx| idx == entry_idx); + + if !is_entry && !self.func_to_id.contains_key(&func.name) { + let func_id = self.func_to_id.len() as u32; + self.func_to_id.insert(func.name.clone(), func_id); + } + + for block in &func.basic_blocks { + let qname = Self::qualified_block_name(&func.name, &block.name); + let id = self.next_block; + self.next_block += 1; + self.block_to_id.insert(qname, id); + } + } + + // ── Function walking (Pass 2) ─────────────────────────────────────── + + fn walk_function(&mut self, func: &Function, is_entry: bool) -> PyResult<()> { + self.current_func_is_entry = is_entry; + self.current_func_name = func.name.clone(); + + // Clear 
per-function register name mapping so names from previous + // functions don't leak (e.g. %q0 in one function != %q0 in another). + self.value_to_reg.clear(); + + // For non-entry functions, register parameters as registers + if !is_entry { + let param_base = self.next_reg; + for param in &func.params { + let name = param.name.as_deref(); + self.alloc_reg(name, REG_TYPE_PTR); + } + if let Some(&_func_id) = self.func_to_id.get(&func.name) { + let func_entry_block = self.block_to_id + [&Self::qualified_block_name(&func.name, &func.basic_blocks[0].name)]; + self.functions.push(BcFunction { + entry_block: func_entry_block, + num_params: func.params.len() as u32, + param_base, + }); + } + } + + for block in &func.basic_blocks { + let qname = Self::qualified_block_name(&func.name, &block.name); + let block_id = self.block_to_id[&qname]; + let instr_offset = self.instructions.len() as u32; + for instr in &block.instructions { + self.on_instruction(instr)?; + } + let instr_count = self.instructions.len() as u32 - instr_offset; + self.blocks.push(BcBlock { + block_id, + instr_offset, + instr_count, + }); + } + + Ok(()) + } + + // ── Instruction dispatch ──────────────────────────────────────────── + + fn on_instruction(&mut self, instr: &Instruction) -> PyResult<()> { + match instr { + Instruction::Call { + callee, + args, + result, + return_ty, + .. + } => self.emit_call(callee, args, result.as_deref(), return_ty.as_ref()), + Instruction::Phi { + ty, + incoming, + result, + } => self.emit_phi(ty, incoming, result), + Instruction::ICmp { + pred, + ty: _, + lhs, + rhs, + result, + } => self.emit_icmp(pred, lhs, rhs, result), + Instruction::FCmp { + pred, + ty: _, + lhs, + rhs, + result, + } => self.emit_fcmp(pred, lhs, rhs, result), + Instruction::Switch { + ty: _, + value, + default_dest, + cases, + } => self.emit_switch(value, default_dest, cases), + Instruction::Br { + cond, + true_dest, + false_dest, + .. 
+ } => self.emit_cond_branch(cond, true_dest, false_dest), + Instruction::Jump { dest } => self.emit_jump(dest), + Instruction::Ret(operand) => self.emit_ret(operand.as_ref()), + Instruction::Select { + cond, + true_val, + false_val, + ty, + result, + } => self.emit_select(cond, true_val, false_val, ty, result), + Instruction::BinOp { + op, + ty, + lhs, + rhs, + result, + } => self.emit_binop(op, ty, lhs, rhs, result), + Instruction::Cast { + op, + from_ty, + to_ty, + value, + result, + } => self.emit_cast(op, from_ty, to_ty, value, result), + Instruction::Alloca { .. } + | Instruction::Load { .. } + | Instruction::Store { .. } + | Instruction::GetElementPtr { .. } => { + // Memory instructions not expected in adaptive profile QIR + Err(PyValueError::new_err(format!( + "Unsupported memory instruction in adaptive pass: {instr:?}" + ))) + } + Instruction::Unreachable => Ok(()), + } + } + + // ── BinOp dispatch ────────────────────────────────────────────────── + + fn emit_binop( + &mut self, + op: &BinOpKind, + ty: &Type, + lhs: &Operand, + rhs: &Operand, + result: &str, + ) -> PyResult<()> { + let opcode = match op { + BinOpKind::Add => OP_ADD, + BinOpKind::Sub => OP_SUB, + BinOpKind::Mul => OP_MUL, + BinOpKind::Udiv => OP_UDIV, + BinOpKind::Sdiv => OP_SDIV, + BinOpKind::Urem => OP_UREM, + BinOpKind::Srem => OP_SREM, + BinOpKind::And => OP_AND, + BinOpKind::Or => OP_OR, + BinOpKind::Xor => OP_XOR, + BinOpKind::Shl => OP_SHL, + BinOpKind::Lshr => OP_LSHR, + BinOpKind::Ashr => OP_ASHR, + BinOpKind::Fadd => OP_FADD | FLAG_FLOAT, + BinOpKind::Fsub => OP_FSUB | FLAG_FLOAT, + BinOpKind::Fmul => OP_FMUL | FLAG_FLOAT, + BinOpKind::Fdiv => OP_FDIV | FLAG_FLOAT, + }; + let dst_reg = self.alloc_reg(Some(result), Self::type_tag(ty)); + let s0 = self.resolve_operand(lhs)?; + let s1 = self.resolve_operand(rhs)?; + self.emit_simple(opcode, OpVal::Reg(dst_reg), s0, s1); + Ok(()) + } + + // ── Cast dispatch ─────────────────────────────────────────────────── + + fn emit_cast( + &mut 
self, + op: &CastKind, + from_ty: &Type, + to_ty: &Type, + value: &Operand, + result: &str, + ) -> PyResult<()> { + match op { + CastKind::Zext => { + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_ZEXT, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + CastKind::Sext => { + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + let src_bits = match from_ty { + Type::Integer(w) => *w, + _ => 32, + }; + let z = OpVal::Reg(0); + self.emit( + OP_SEXT, + OpVal::Reg(dst), + src, + z, + OpVal::IntImm(src_bits), + z, + z, + z, + ); + } + CastKind::Trunc => { + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_TRUNC, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + CastKind::FpExt => { + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_FPEXT | FLAG_FLOAT, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + CastKind::FpTrunc => { + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_FPTRUNC | FLAG_FLOAT, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + CastKind::Fptosi => { + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_FPTOSI, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + CastKind::Sitofp => { + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_SITOFP | FLAG_FLOAT, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + CastKind::IntToPtr => { + // inttoptr is essentially a no-op cast; alias via MOV + let dst = self.alloc_reg(Some(result), REG_TYPE_PTR); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_MOV, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + CastKind::PtrToInt | 
CastKind::Bitcast => { + // Pass-through casts + let dst = self.alloc_reg(Some(result), Self::type_tag(to_ty)); + let src = self.resolve_operand(value)?; + self.emit_simple(OP_MOV, OpVal::Reg(dst), src, OpVal::Reg(0)); + } + } + Ok(()) + } + + // ── Call dispatch ─────────────────────────────────────────────────── + + fn emit_call( + &mut self, + callee: &str, + args: &[(Type, Operand)], + result: Option<&str>, + return_ty: Option<&Type>, + ) -> PyResult<()> { + match callee { + qis::READ_RESULT | rt::READ_RESULT => { + let dst = self.alloc_reg(result, REG_TYPE_BOOL); + let result_reg = self.resolve_operand(&args[0].1)?; + let z = OpVal::Reg(0); + self.emit(OP_READ_RESULT, OpVal::Reg(dst), result_reg, z, z, z, z, z); + } + name if name.starts_with("__quantum__qis__") => { + self.emit_quantum_call(name, args, result)?; + } + rt::RESULT_RECORD_OUTPUT => { + let result_reg = self.resolve_operand(&args[0].1)?; + let label_str = self.extract_label(&args[1].1); + let label_idx = self.labels.len() as u32; + self.labels.push(label_str); + let z = OpVal::Reg(0); + self.emit( + OP_RECORD_OUTPUT, + z, + result_reg, + z, + OpVal::IntImm(label_idx), + z, + z, + z, + ); + } + rt::ARRAY_RECORD_OUTPUT => { + let count = match args[0].1 { + Operand::IntConst(_, v) => u32::try_from(v).expect("Array length out of range"), + _ => 0, + }; + let label_str = self.extract_label(&args[1].1); + let label_idx = self.labels.len() as u32; + self.labels.push(label_str); + let z = OpVal::Reg(0); + self.emit( + OP_RECORD_OUTPUT, + z, + OpVal::IntImm(count), + z, + OpVal::IntImm(label_idx), + OpVal::IntImm(1), // aux1=1 → array + z, + z, + ); + } + rt::TUPLE_RECORD_OUTPUT => { + let count = match args[0].1 { + Operand::IntConst(_, v) => u32::try_from(v).expect("Tuple length out of range"), + _ => 0, + }; + let label_str = self.extract_label(&args[1].1); + let label_idx = self.labels.len() as u32; + self.labels.push(label_str); + let z = OpVal::Reg(0); + self.emit( + OP_RECORD_OUTPUT, + z, + 
OpVal::IntImm(count), + z, + OpVal::IntImm(label_idx), + OpVal::IntImm(2), // aux1=2 → tuple + z, + z, + ); + } + rt::BOOL_RECORD_OUTPUT => { + let src = self.resolve_operand(&args[0].1)?; + let label_str = self.extract_label(&args[1].1); + let label_idx = self.labels.len() as u32; + self.labels.push(label_str); + let z = OpVal::Reg(0); + self.emit( + OP_RECORD_OUTPUT, + z, + src, + z, + OpVal::IntImm(label_idx), + OpVal::IntImm(3), // aux1=3 → bool + z, + z, + ); + } + rt::INT_RECORD_OUTPUT => { + let src = self.resolve_operand(&args[0].1)?; + let label_str = self.extract_label(&args[1].1); + let label_idx = self.labels.len() as u32; + self.labels.push(label_str); + let z = OpVal::Reg(0); + self.emit( + OP_RECORD_OUTPUT, + z, + src, + z, + OpVal::IntImm(label_idx), + OpVal::IntImm(4), // aux1=4 → int + z, + z, + ); + } + rt::INITIALIZE + | rt::BEGIN_PARALLEL + | rt::END_PARALLEL + | qis::BARRIER + | rt::READ_LOSS => { + // No-op + } + name if self.func_to_id.contains_key(name) => { + self.emit_ir_function_call(name, args, result, return_ty)?; + } + name if self.is_noise_intrinsic(name) => { + self.emit_noise_intrinsic_call(name, args)?; + } + _ => { + return Err(PyValueError::new_err(format!("Unsupported call: {callee}"))); + } + } + Ok(()) + } + + // ── Quantum call ──────────────────────────────────────────────────── + + fn emit_quantum_call( + &mut self, + callee_name: &str, + args: &[(Type, Operand)], + _result: Option<&str>, + ) -> PyResult<()> { + let gate_name = callee_name + .replace("__quantum__qis__", "") + .replace("__body", ""); + let op_id = gate_name_to_op_id(&gate_name) + .ok_or_else(|| PyValueError::new_err(format!("Unknown quantum gate: {gate_name}")))?; + + if is_measure_gate(&gate_name) { + let q = self.resolve_operand(&args[0].1)?; + let r = self.resolve_operand(&args[1].1)?; + let qop_idx = self.emit_quantum_op(op_id, q.raw(), r.raw(), 0, 0.0); + let z = OpVal::Reg(0); + self.emit(OP_MEASURE, z, z, z, OpVal::IntImm(qop_idx), q, r, z); + return 
Ok(()); + } + if is_reset_gate(&gate_name) { + let q = self.resolve_operand(&args[0].1)?; + let qop_idx = self.emit_quantum_op(op_id, q.raw(), 0, 0, 0.0); + let z = OpVal::Reg(0); + self.emit(OP_RESET, z, z, z, OpVal::IntImm(qop_idx), q, z, z); + return Ok(()); + } + + let (qubit_offset, angle) = if is_rotation_gate(&gate_name) { + let a = self.resolve_operand(&args[0].1)?; + (1, a) + } else { + (0, OpVal::FloatImm(0)) + }; + + let qubit_args = &args[qubit_offset..]; + let mut qs = [OpVal::IntImm(0); 3]; + for (i, (_, operand)) in qubit_args.iter().enumerate().take(3) { + qs[i] = self.resolve_operand(operand)?; + } + + let qop_idx = self.emit_quantum_op( + op_id, + qs[0].raw(), + qs[1].raw(), + qs[2].raw(), + if angle.is_imm() { + // Decode float bits back to f64 for the quantum op table + f32::from_bits(angle.raw()).into() + } else { + 0.0 // Register-based angle — will be handled at runtime + }, + ); + let z = OpVal::Reg(0); + self.emit( + OP_QUANTUM_GATE, + z, + z, + z, + OpVal::IntImm(qop_idx), + qs[0], + qs[1], + qs[2], + ); + Ok(()) + } + + // ── Noise intrinsic call ──────────────────────────────────────────── + + fn is_noise_intrinsic(&self, name: &str) -> bool { + // Check if the callee has qdk_noise attribute + self.module.functions.iter().any(|f| { + f.name == name + && f.attribute_group_refs.iter().any(|&group_ref| { + self.module + .attribute_groups + .iter() + .find(|ag| ag.id == group_ref) + .is_some_and(|ag| { + ag.attributes.iter().any(|attr| { + matches!(attr, Attribute::StringAttr(s) if s.contains("qdk_noise")) + }) + }) + }) + }) + } + + fn emit_noise_intrinsic_call( + &mut self, + callee_name: &str, + args: &[(Type, Operand)], + ) -> PyResult<()> { + if let Some(noise_map) = &self.noise_intrinsics { + if let Some(&table_id) = noise_map.get(callee_name) { + let qubit_count = args.len() as u32; + let arg_offset = self.call_args.len() as u32; + for (_, operand) in args { + let op = self.resolve_operand(operand)?; + if let OpVal::Reg(r) = op { + 
self.call_args.push(r); + } else { + let reg = self.alloc_reg(None, REG_TYPE_PTR); + self.emit( + OP_MOV | FLAG_SRC0_IMM, + OpVal::Reg(reg), + OpVal::IntImm(op.raw()), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + ); + self.call_args.push(reg); + } + } + let qop_idx = + self.emit_quantum_op(CORRELATED_NOISE_OP_ID, table_id, qubit_count, 0, 0.0); + let z = OpVal::Reg(0); + self.emit( + OP_QUANTUM_GATE, + z, + z, + z, + OpVal::IntImm(qop_idx), + OpVal::IntImm(qubit_count), + OpVal::IntImm(arg_offset), + z, + ); + } else { + return Err(PyValueError::new_err(format!( + "Missing noise intrinsic: {callee_name}" + ))); + } + } + // No noise config → no-op + Ok(()) + } + + // ── Control flow ──────────────────────────────────────────────────── + + fn emit_jump(&mut self, dest: &str) -> PyResult<()> { + let qname = Self::qualified_block_name(&self.current_func_name, dest); + let target = self + .block_to_id + .get(&qname) + .copied() + .ok_or_else(|| PyValueError::new_err(format!("Unknown block: {dest}")))?; + let z = OpVal::Reg(0); + self.emit(OP_JUMP, OpVal::IntImm(target), z, z, z, z, z, z); + Ok(()) + } + + fn emit_cond_branch( + &mut self, + cond: &Operand, + true_dest: &str, + false_dest: &str, + ) -> PyResult<()> { + let cond_reg = self.resolve_operand(cond)?; + let true_block = + self.block_to_id[&Self::qualified_block_name(&self.current_func_name, true_dest)]; + let false_block = + self.block_to_id[&Self::qualified_block_name(&self.current_func_name, false_dest)]; + let z = OpVal::Reg(0); + self.emit( + OP_BRANCH, + z, + cond_reg, + z, + OpVal::IntImm(true_block), + OpVal::IntImm(false_block), + z, + z, + ); + Ok(()) + } + + fn emit_phi( + &mut self, + ty: &Type, + incoming: &[(Operand, String)], + result: &str, + ) -> PyResult<()> { + let dst_reg = self.alloc_reg(Some(result), Self::type_tag(ty)); + let phi_offset = self.phi_entries.len() as u32; + for (value, block_name) in incoming { + let operand = 
self.resolve_operand(value)?; + let val_reg = match operand { + OpVal::Reg(r) => r, + _ => { + // Immediate → materialize into register + let reg = self.alloc_reg(None, Self::type_tag(ty)); + self.emit( + OP_MOV | FLAG_SRC0_IMM, + OpVal::Reg(reg), + OpVal::IntImm(operand.raw()), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + ); + reg + } + }; + let qname = Self::qualified_block_name(&self.current_func_name, block_name); + let block_id = self.block_to_id.get(&qname).copied().ok_or_else(|| { + PyValueError::new_err(format!("Unknown phi source block: {block_name}")) + })?; + self.phi_entries.push(BcPhiEntry { block_id, val_reg }); + } + let count = incoming.len() as u32; + let z = OpVal::Reg(0); + self.emit( + OP_PHI, + OpVal::Reg(dst_reg), + z, + z, + OpVal::IntImm(phi_offset), + OpVal::IntImm(count), + z, + z, + ); + Ok(()) + } + + fn emit_select( + &mut self, + cond: &Operand, + true_val: &Operand, + false_val: &Operand, + ty: &Type, + result: &str, + ) -> PyResult<()> { + let dst = self.alloc_reg(Some(result), Self::type_tag(ty)); + let c = self.resolve_operand(cond)?; + let t = self.resolve_operand(true_val)?; + let f = self.resolve_operand(false_val)?; + let z = OpVal::Reg(0); + self.emit(OP_SELECT, OpVal::Reg(dst), c, z, t, f, z, z); + Ok(()) + } + + fn emit_switch( + &mut self, + value: &Operand, + default_dest: &str, + cases: &[(i64, String)], + ) -> PyResult<()> { + let cond_reg = self.resolve_operand(value)?; + let default_block = + self.block_to_id[&Self::qualified_block_name(&self.current_func_name, default_dest)]; + let case_offset = self.switch_cases.len() as u32; + for (case_val, block_name) in cases { + let qname = Self::qualified_block_name(&self.current_func_name, block_name); + let target_block = self.block_to_id[&qname]; + self.switch_cases.push(BcSwitchCase { + case_val: i64_to_u32_masked(*case_val), + target_block, + }); + } + let case_count = cases.len() as u32; + let z = OpVal::Reg(0); + self.emit( + 
OP_SWITCH, + z, + cond_reg, + z, + OpVal::IntImm(default_block), + OpVal::IntImm(case_offset), + OpVal::IntImm(case_count), + z, + ); + Ok(()) + } + + fn emit_ret(&mut self, operand: Option<&Operand>) -> PyResult<()> { + if !self.current_func_is_entry { + // Return from IR-defined function + if let Some(op) = operand { + let ret_reg = self.resolve_operand(op)?; + let z = OpVal::Reg(0); + self.emit(OP_CALL_RETURN, z, ret_reg, z, z, z, z, z); + } else { + let z = OpVal::Reg(0); + self.emit(OP_CALL_RETURN, z, z, z, z, z, z, z); + } + } else if let Some(op) = operand { + let ret_reg = self.resolve_operand(op)?; + let z = OpVal::Reg(0); + self.emit(OP_RET, ret_reg, z, z, z, z, z, z); + } else { + let z = OpVal::Reg(0); + self.emit(OP_RET, OpVal::IntImm(0), z, z, z, z, z, z); + } + Ok(()) + } + + // ── Comparison ────────────────────────────────────────────────────── + + fn emit_icmp( + &mut self, + pred: &IntPredicate, + lhs: &Operand, + rhs: &Operand, + result: &str, + ) -> PyResult<()> { + let cond_code = match pred { + IntPredicate::Eq => ICMP_EQ, + IntPredicate::Ne => ICMP_NE, + IntPredicate::Slt => ICMP_SLT, + IntPredicate::Sle => ICMP_SLE, + IntPredicate::Sgt => ICMP_SGT, + IntPredicate::Sge => ICMP_SGE, + IntPredicate::Ult => ICMP_ULT, + IntPredicate::Ule => ICMP_ULE, + IntPredicate::Ugt => ICMP_UGT, + IntPredicate::Uge => ICMP_UGE, + }; + let dst = self.alloc_reg(Some(result), REG_TYPE_BOOL); + let s0 = self.resolve_operand(lhs)?; + let s1 = self.resolve_operand(rhs)?; + self.emit_simple(OP_ICMP | (cond_code << 8), OpVal::Reg(dst), s0, s1); + Ok(()) + } + + fn emit_fcmp( + &mut self, + pred: &FloatPredicate, + lhs: &Operand, + rhs: &Operand, + result: &str, + ) -> PyResult<()> { + let cond_code = match pred { + FloatPredicate::Oeq => FCMP_OEQ, + FloatPredicate::Ogt => FCMP_OGT, + FloatPredicate::Oge => FCMP_OGE, + FloatPredicate::Olt => FCMP_OLT, + FloatPredicate::Ole => FCMP_OLE, + FloatPredicate::One => FCMP_ONE, + FloatPredicate::Ord => FCMP_ORD, + 
FloatPredicate::Uno => FCMP_UNO, + FloatPredicate::Ueq => FCMP_UEQ, + FloatPredicate::Ugt => FCMP_UGT, + FloatPredicate::Uge => FCMP_UGE, + FloatPredicate::Ult => FCMP_ULT, + FloatPredicate::Ule => FCMP_ULE, + FloatPredicate::Une => FCMP_UNE, + }; + let dst = self.alloc_reg(Some(result), REG_TYPE_BOOL); + let s0 = self.resolve_operand(lhs)?; + let s1 = self.resolve_operand(rhs)?; + self.emit_simple( + OP_FCMP | (cond_code << 8) | FLAG_FLOAT, + OpVal::Reg(dst), + s0, + s1, + ); + Ok(()) + } + + // ── IR-defined function call/return ───────────────────────────────── + + fn emit_ir_function_call( + &mut self, + func_name: &str, + args: &[(Type, Operand)], + result: Option<&str>, + return_ty: Option<&Type>, + ) -> PyResult<()> { + let func_id = self.func_to_id[func_name]; + let arg_offset = self.call_args.len() as u32; + for (_, operand) in args { + let op = self.resolve_operand(operand)?; + if let OpVal::Reg(r) = op { + self.call_args.push(r); + } else { + let reg = self.alloc_reg(None, REG_TYPE_PTR); + self.emit( + OP_MOV | FLAG_SRC0_IMM, + OpVal::Reg(reg), + OpVal::IntImm(op.raw()), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + OpVal::Reg(0), + ); + self.call_args.push(reg); + } + } + let is_void = return_ty.is_none() || matches!(return_ty, Some(Type::Void)); + let return_reg = if is_void { + VOID_RETURN + } else { + self.alloc_reg(result, REG_TYPE_I32) + }; + let z = OpVal::Reg(0); + self.emit( + OP_CALL, + OpVal::IntImm(return_reg), + z, + z, + OpVal::IntImm(func_id), + OpVal::IntImm(args.len() as u32), + OpVal::IntImm(arg_offset), + z, + ); + Ok(()) + } + + // ── Helpers ───────────────────────────────────────────────────────── + + fn extract_label(&self, operand: &Operand) -> String { + match operand { + Operand::GlobalRef(name) => { + // Look up global's string initializer + for global in &self.module.globals { + if global.name == *name { + if let Some(Constant::CString(s)) = &global.initializer { + return s.clone(); + } + } + } + 
String::new() + } + _ => String::new(), + } + } +} + +// ── Python-facing function ────────────────────────────────────────────────── + +/// Compile adaptive-profile QIR text IR into the bytecode dict consumed by +/// `run_adaptive_parallel_shots`. +/// +/// Returns a Python dict with the same keys as `AdaptiveProgram.as_dict()`. +#[pyfunction] +#[pyo3(signature = (ir, noise_intrinsics=None))] +pub fn compile_adaptive_program<'py>( + py: Python<'py>, + ir: &str, + noise_intrinsics: Option<&Bound<'py, PyDict>>, +) -> PyResult<Bound<'py, PyDict>> { + let module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("failed to parse IR: {e}")))?; + + let entry_idx = find_entry_point(&module) + .ok_or_else(|| PyValueError::new_err("no entry point function found in IR"))?; + + let num_qubits: u32 = get_function_attribute(&module, entry_idx, "required_num_qubits") + .ok_or_else(|| PyValueError::new_err("missing required_num_qubits attribute"))? + .parse() + .map_err(|e| PyValueError::new_err(format!("invalid required_num_qubits: {e}")))?; + + let num_results: u32 = get_function_attribute(&module, entry_idx, "required_num_results") + .ok_or_else(|| PyValueError::new_err("missing required_num_results attribute"))?
+ .parse() + .map_err(|e| PyValueError::new_err(format!("invalid required_num_results: {e}")))?; + + // Build noise intrinsics lookup from Python dict + let noise_map: Option<FxHashMap<String, u32>> = noise_intrinsics.map(|dict| { + let mut map = FxHashMap::default(); + for (key, value) in dict.iter() { + if let (Ok(k), Ok(v)) = (key.extract::<String>(), value.extract::<u32>()) { + map.insert(k, v); + } + } + map + }); + + let mut pass = AdaptivePass::new(&module, noise_map); + pass.run(entry_idx)?; + + let entry_func = &module.functions[entry_idx]; + let entry_block_name = + AdaptivePass::qualified_block_name(&entry_func.name, &entry_func.basic_blocks[0].name); + let entry_block = pass.block_to_id[&entry_block_name]; + + // Build the Python dict + let dict = PyDict::new(py); + + dict.set_item("num_qubits", num_qubits)?; + dict.set_item("num_results", num_results)?; + dict.set_item("num_registers", pass.next_reg)?; + dict.set_item("entry_block", entry_block)?; + + // blocks: list of (block_id, instr_offset, instr_count) + let blocks = PyList::empty(py); + for b in &pass.blocks { + let t = PyTuple::new(py, [b.block_id, b.instr_offset, b.instr_count])?; + blocks.append(t)?; + } + dict.set_item("blocks", blocks)?; + + // instructions: list of (opcode, dst, src0, src1, aux0, aux1, aux2, aux3) + let instrs = PyList::empty(py); + for i in &pass.instructions { + let t = PyTuple::new( + py, + [ + i.opcode, i.dst, i.src0, i.src1, i.aux0, i.aux1, i.aux2, i.aux3, + ], + )?; + instrs.append(t)?; + } + dict.set_item("instructions", instrs)?; + + // quantum_ops: list of (op_id, q1, q2, q3, angle) + let qops = PyList::empty(py); + for q in &pass.quantum_ops { + let t = PyTuple::new( + py, + &[ + q.op_id.into_py_any(py)?, + q.q1.into_py_any(py)?, + q.q2.into_py_any(py)?, + q.q3.into_py_any(py)?, + q.angle.into_py_any(py)?, + ], + )?; + qops.append(t)?; + } + dict.set_item("quantum_ops", qops)?; + + // functions: list of (entry_block, num_params, param_base) + let funcs = PyList::empty(py); + for f in &pass.functions 
{ + let t = PyTuple::new(py, [f.entry_block, f.num_params, f.param_base])?; + funcs.append(t)?; + } + dict.set_item("functions", funcs)?; + + // phi_entries: list of (block_id, val_reg) + let phis = PyList::empty(py); + for p in &pass.phi_entries { + let t = PyTuple::new(py, [p.block_id, p.val_reg])?; + phis.append(t)?; + } + dict.set_item("phi_entries", phis)?; + + // switch_cases: list of (case_val, target_block) + let cases = PyList::empty(py); + for s in &pass.switch_cases { + let t = PyTuple::new(py, [s.case_val, s.target_block])?; + cases.append(t)?; + } + dict.set_item("switch_cases", cases)?; + + // call_args: list of u32 + let cargs = PyList::new(py, &pass.call_args)?; + dict.set_item("call_args", cargs)?; + + // labels: list of str + let lbls = PyList::new(py, &pass.labels)?; + dict.set_item("labels", lbls)?; + + // register_types: list of u32 + let rtypes = PyList::new(py, &pass.register_types)?; + dict.set_item("register_types", rtypes)?; + + Ok(dict) +} diff --git a/source/pip/src/qir_simulation/atom_decomp.rs b/source/pip/src/qir_simulation/atom_decomp.rs new file mode 100644 index 0000000000..0b5d4870b3 --- /dev/null +++ b/source/pip/src/qir_simulation/atom_decomp.rs @@ -0,0 +1,553 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Neutral-atom gate decomposition pass. +//! +//! Rust equivalent of `_decomp.py`. Decomposes multi-qubit gates to CZ +//! primitives, single rotations to Rz, single-qubit gates to Rz+SX, +//! Rz Clifford angles to named gates, and Reset to MResetZ. 
+ +use pyo3::{PyResult, exceptions::PyValueError, pyfunction}; +use qsc_llvm::{ + model::Type, + model::{Function, Instruction, Module, Operand, Param}, + parse_module, + qir::{self, double_op, qis, qubit_op, result_op, rt, void_call}, + write_module_to_string, +}; +use std::f64::consts::PI; + +use super::atom_utils::{TOLERANCE, extract_float, extract_id}; + +/// Shorthand: `call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 q to %Qubit*))`. +fn h(q: u32) -> Instruction { + void_call(qis::H, vec![qubit_op(q)]) +} + +fn s(q: u32) -> Instruction { + void_call(qis::S, vec![qubit_op(q)]) +} + +fn s_adj(q: u32) -> Instruction { + void_call(qis::S_ADJ, vec![qubit_op(q)]) +} + +fn t(q: u32) -> Instruction { + void_call(qis::T, vec![qubit_op(q)]) +} + +fn t_adj(q: u32) -> Instruction { + void_call(qis::T_ADJ, vec![qubit_op(q)]) +} + +fn rz(angle: f64, q: u32) -> Instruction { + void_call(qis::RZ, vec![double_op(angle), qubit_op(q)]) +} + +fn rz_passthrough(angle_op: (Type, Operand), q: u32) -> Instruction { + void_call(qis::RZ, vec![angle_op, qubit_op(q)]) +} + +fn cz(q1: u32, q2: u32) -> Instruction { + void_call(qis::CZ, vec![qubit_op(q1), qubit_op(q2)]) +} + +fn sx(q: u32) -> Instruction { + void_call(qis::SX, vec![qubit_op(q)]) +} + +fn z(q: u32) -> Instruction { + void_call(qis::Z, vec![qubit_op(q)]) +} + +fn mresetz(q: u32, r: u32) -> Instruction { + void_call(qis::MRESETZ, vec![qubit_op(q), result_op(r)]) +} + +/// Ensure a declaration for given function name exists in the module. +/// If missing, add a void(...) declaration with the appropriate signature. 
+pub(crate) fn ensure_declaration(module: &mut Module, name: &str) { + if module.functions.iter().any(|f| f.name == name) { + return; + } + let (ret, params) = match name { + qis::H | qis::S | qis::S_ADJ | qis::T | qis::T_ADJ | qis::SX | qis::Z => ( + Type::Void, + vec![Type::NamedPtr(qir::QUBIT_TYPE_NAME.to_string())], + ), + qis::RZ => ( + Type::Void, + vec![ + Type::Double, + Type::NamedPtr(qir::QUBIT_TYPE_NAME.to_string()), + ], + ), + qis::CZ => ( + Type::Void, + vec![ + Type::NamedPtr(qir::QUBIT_TYPE_NAME.to_string()), + Type::NamedPtr(qir::QUBIT_TYPE_NAME.to_string()), + ], + ), + qis::MRESETZ => ( + Type::Void, + vec![ + Type::NamedPtr(qir::QUBIT_TYPE_NAME.to_string()), + Type::NamedPtr(qir::RESULT_TYPE_NAME.to_string()), + ], + ), + rt::BEGIN_PARALLEL | rt::END_PARALLEL => (Type::Void, Vec::new()), + qis::MOVE => ( + Type::Void, + vec![ + Type::NamedPtr(qir::QUBIT_TYPE_NAME.to_string()), + Type::Integer(64), + Type::Integer(64), + ], + ), + _ => (Type::Void, Vec::new()), + }; + module.functions.push(Function { + name: name.to_string(), + return_type: ret, + params: params + .into_iter() + .map(|ty| Param { ty, name: None }) + .collect(), + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); +} + +fn decompose_multi_qubit_to_cz(module: &mut Module) { + let needed = [ + qis::H, + qis::S, + qis::S_ADJ, + qis::T, + qis::T_ADJ, + qis::RZ, + qis::CZ, + ]; + for n in &needed { + ensure_declaration(module, n); + } + + for func in &mut module.functions { + if func.is_declaration { + continue; + } + for bb in &mut func.basic_blocks { + let old_instrs = std::mem::take(&mut bb.instructions); + let mut new_instrs = Vec::with_capacity(old_instrs.len()); + for instr in old_instrs { + if let Instruction::Call { callee, args, .. 
} = &instr { + match callee.as_str() { + name if name == qis::CCX => { + // CCX(ctrl1, ctrl2, target) + if let (Some(c1), Some(c2), Some(tgt)) = ( + extract_id(&args[0].1), + extract_id(&args[1].1), + extract_id(&args[2].1), + ) { + new_instrs.extend([ + h(tgt), + t_adj(c1), + t_adj(c2), + h(c1), + cz(tgt, c1), + h(c1), + t(c1), + h(tgt), + cz(c2, tgt), + h(tgt), + h(c1), + cz(c2, c1), + h(c1), + t(tgt), + t_adj(c1), + h(tgt), + cz(c2, tgt), + h(tgt), + h(c1), + cz(tgt, c1), + h(c1), + t_adj(tgt), + t(c1), + h(c1), + cz(c2, c1), + h(c1), + h(tgt), + ]); + continue; + } + } + name if name == qis::CX => { + if let (Some(ctrl), Some(tgt)) = + (extract_id(&args[0].1), extract_id(&args[1].1)) + { + new_instrs.extend([h(tgt), cz(ctrl, tgt), h(tgt)]); + continue; + } + } + name if name == qis::CY => { + if let (Some(ctrl), Some(tgt)) = + (extract_id(&args[0].1), extract_id(&args[1].1)) + { + new_instrs.extend([ + s_adj(tgt), + h(tgt), + cz(ctrl, tgt), + h(tgt), + s(tgt), + ]); + continue; + } + } + name if name == qis::RXX => { + // rxx(angle, q1, q2) + if let (Some(q1), Some(q2)) = + (extract_id(&args[1].1), extract_id(&args[2].1)) + { + let angle_op = args[0].clone(); + new_instrs.extend([ + h(q2), + cz(q2, q1), + h(q1), + rz_passthrough(angle_op, q1), + h(q1), + cz(q2, q1), + h(q2), + ]); + continue; + } + } + name if name == qis::RYY => { + if let (Some(q1), Some(q2)) = + (extract_id(&args[1].1), extract_id(&args[2].1)) + { + let angle_op = args[0].clone(); + new_instrs.extend([ + s_adj(q1), + s_adj(q2), + h(q2), + cz(q2, q1), + h(q1), + rz_passthrough(angle_op, q1), + h(q1), + cz(q2, q1), + h(q2), + s(q2), + s(q1), + ]); + continue; + } + } + name if name == qis::RZZ => { + if let (Some(q1), Some(q2)) = + (extract_id(&args[1].1), extract_id(&args[2].1)) + { + let angle_op = args[0].clone(); + new_instrs.extend([ + h(q1), + cz(q2, q1), + h(q1), + rz_passthrough(angle_op, q1), + h(q1), + cz(q2, q1), + h(q1), + ]); + continue; + } + } + name if name == qis::SWAP => { + 
if let (Some(q1), Some(q2)) = + (extract_id(&args[0].1), extract_id(&args[1].1)) + { + new_instrs.extend([ + h(q2), + cz(q1, q2), + h(q2), + h(q1), + cz(q2, q1), + h(q1), + h(q2), + cz(q1, q2), + h(q2), + ]); + continue; + } + } + _ => {} + } + } + // Instruction not decomposed — keep as-is. + new_instrs.push(instr); + } + bb.instructions = new_instrs; + } + } +} + +fn decompose_single_rotation_to_rz(module: &mut Module) { + let needed = [qis::H, qis::S, qis::S_ADJ, qis::RZ]; + for n in &needed { + ensure_declaration(module, n); + } + + for func in &mut module.functions { + if func.is_declaration { + continue; + } + for bb in &mut func.basic_blocks { + let old_instrs = std::mem::take(&mut bb.instructions); + let mut new_instrs = Vec::with_capacity(old_instrs.len()); + for instr in old_instrs { + if let Instruction::Call { callee, args, .. } = &instr { + match callee.as_str() { + name if name == qis::RX => { + // rx(angle, target) + if let Some(tgt) = extract_id(&args[1].1) { + let angle_op = args[0].clone(); + new_instrs.extend([h(tgt), rz_passthrough(angle_op, tgt), h(tgt)]); + continue; + } + } + name if name == qis::RY => { + if let Some(tgt) = extract_id(&args[1].1) { + let angle_op = args[0].clone(); + new_instrs.extend([ + s_adj(tgt), + h(tgt), + rz_passthrough(angle_op, tgt), + h(tgt), + s(tgt), + ]); + continue; + } + } + _ => {} + } + } + new_instrs.push(instr); + } + bb.instructions = new_instrs; + } + } +} + +fn decompose_single_qubit_to_rz_sx(module: &mut Module) { + let needed = [qis::SX, qis::RZ]; + for n in &needed { + ensure_declaration(module, n); + } + + for func in &mut module.functions { + if func.is_declaration { + continue; + } + for bb in &mut func.basic_blocks { + let old_instrs = std::mem::take(&mut bb.instructions); + let mut new_instrs = Vec::with_capacity(old_instrs.len()); + for instr in old_instrs { + if let Instruction::Call { callee, args, .. 
} = &instr { + if let Some(first_arg) = args.first() { + if let Some(tgt) = extract_id(&first_arg.1) { + match callee.as_str() { + name if name == qis::H => { + new_instrs.extend([ + rz(PI / 2.0, tgt), + sx(tgt), + rz(PI / 2.0, tgt), + ]); + continue; + } + name if name == qis::S => { + new_instrs.push(rz(PI / 2.0, tgt)); + continue; + } + name if name == qis::S_ADJ => { + new_instrs.push(rz(-PI / 2.0, tgt)); + continue; + } + name if name == qis::T => { + new_instrs.push(rz(PI / 4.0, tgt)); + continue; + } + name if name == qis::T_ADJ => { + new_instrs.push(rz(-PI / 4.0, tgt)); + continue; + } + name if name == qis::X => { + new_instrs.extend([sx(tgt), sx(tgt)]); + continue; + } + name if name == qis::Y => { + new_instrs.extend([sx(tgt), sx(tgt), rz(PI, tgt)]); + continue; + } + name if name == qis::Z => { + new_instrs.push(rz(PI, tgt)); + continue; + } + _ => {} + } + } + } + } + new_instrs.push(instr); + } + bb.instructions = new_instrs; + } + } +} + +fn decompose_rz_angles_to_clifford(module: &mut Module) { + let needed = [qis::S, qis::S_ADJ, qis::Z]; + for n in &needed { + ensure_declaration(module, n); + } + + let three_pi_over_2 = 3.0 * PI / 2.0; + let pi_over_2 = PI / 2.0; + let two_pi = 2.0 * PI; + + for func in &mut module.functions { + if func.is_declaration { + continue; + } + for bb in &mut func.basic_blocks { + let old_instrs = std::mem::take(&mut bb.instructions); + let mut new_instrs = Vec::with_capacity(old_instrs.len()); + for instr in old_instrs { + if let Instruction::Call { callee, args, .. 
} = &instr { + if callee == qis::RZ { + if let (Some(angle), Some(tgt)) = + (extract_float(&args[0].1), extract_id(&args[1].1)) + { + if (angle - three_pi_over_2).abs() < TOLERANCE + || (angle + pi_over_2).abs() < TOLERANCE + { + new_instrs.push(s_adj(tgt)); + } else if (angle - PI).abs() < TOLERANCE + || (angle + PI).abs() < TOLERANCE + { + new_instrs.push(z(tgt)); + } else if (angle - pi_over_2).abs() < TOLERANCE + || (angle + three_pi_over_2).abs() < TOLERANCE + { + new_instrs.push(s(tgt)); + } else if angle.abs() < TOLERANCE + || (angle - two_pi).abs() < TOLERANCE + || (angle + two_pi).abs() < TOLERANCE + { + // Identity — drop. + } else { + // Non-Clifford angle — keep instruction as is. + new_instrs.push(instr); + } + continue; + } + } + } + new_instrs.push(instr); + } + bb.instructions = new_instrs; + } + } +} + +fn replace_reset_with_mresetz(module: &mut Module) { + ensure_declaration(module, qis::MRESETZ); + + for func in &mut module.functions { + if func.is_declaration { + continue; + } + // Find the maximum result id used in this function so we can allocate new ones. + let mut max_result_id: u32 = 0; + for bb in &func.basic_blocks { + for instr in &bb.instructions { + if let Instruction::Call { args, .. } = &instr { + for (ty, op) in args { + if matches!(ty, Type::NamedPtr(n) if n == qir::RESULT_TYPE_NAME) { + if let Some(id) = extract_id(op) { + if id >= max_result_id { + max_result_id = id + 1; + } + } + } + } + } + } + } + // Also check entry_point attribute for required_num_results. + let mut next_result_id = max_result_id; + + for bb in &mut func.basic_blocks { + let old_instrs = std::mem::take(&mut bb.instructions); + let mut new_instrs = Vec::with_capacity(old_instrs.len()); + for instr in old_instrs { + if let Instruction::Call { callee, args, .. 
} = &instr { + if callee == qis::RESET { + if let Some(q) = extract_id(&args[0].1) { + new_instrs.push(mresetz(q, next_result_id)); + next_result_id += 1; + continue; + } + } + } + new_instrs.push(instr); + } + bb.instructions = new_instrs; + } + } +} + +/// Decompose multi-qubit gates to CZ, single rotations to Rz, and +/// single-qubit gates to Rz+SX. Also replace Reset with MResetZ. +/// +/// This chains the four decomposition sub-passes that `_decomp.py` exposes. +/// The caller (`__init__.py`) invokes individual decompositions via the +/// Python pipeline; this function exposes them as a single native function +/// or individually. +#[pyfunction] +pub fn atom_decompose_multi_qubit_to_cz(ir: &str) -> PyResult<String> { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + decompose_multi_qubit_to_cz(&mut module); + Ok(write_module_to_string(&module)) +} + +#[pyfunction] +pub fn atom_decompose_single_rotation_to_rz(ir: &str) -> PyResult<String> { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + decompose_single_rotation_to_rz(&mut module); + Ok(write_module_to_string(&module)) +} + +#[pyfunction] +pub fn atom_decompose_single_qubit_to_rz_sx(ir: &str) -> PyResult<String> { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + decompose_single_qubit_to_rz_sx(&mut module); + Ok(write_module_to_string(&module)) +} + +#[pyfunction] +pub fn atom_decompose_rz_to_clifford(ir: &str) -> PyResult<String> { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + decompose_rz_angles_to_clifford(&mut module); + Ok(write_module_to_string(&module)) +} + +#[pyfunction] +pub fn atom_replace_reset_with_mresetz(ir: &str) -> PyResult<String> { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + replace_reset_with_mresetz(&mut module); + 
Ok(write_module_to_string(&module)) +} diff --git a/source/pip/src/qir_simulation/atom_optimize.rs b/source/pip/src/qir_simulation/atom_optimize.rs new file mode 100644 index 0000000000..be3ee2207c --- /dev/null +++ b/source/pip/src/qir_simulation/atom_optimize.rs @@ -0,0 +1,456 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Neutral-atom single-qubit gate optimization pass. +//! +//! Rust equivalent of `_optimize.py`. Combines adjacent Rz gates, +//! removes identity rotations, cancels adjoint gate pairs, replaces +//! h-s-h → sx, and converts trailing m+reset → mresetz. + +use pyo3::{PyResult, exceptions::PyValueError, pyfunction}; +use qsc_llvm::{ + model::Type, + model::{Instruction, Module, Operand}, + parse_module, + qir::{qis, qubit_op, rt, void_call}, + write_module_to_string, +}; +use rustc_hash::FxHashMap; +use std::f64::consts::PI; + +use super::atom_utils::{TOLERANCE, extract_float, extract_id}; + +/// Classify a call instruction as a named gate on a single qubit, returning +/// `(gate_name, qubit_id, is_rotation)`. +fn classify_single_qubit_gate( + callee: &str, + args: &[(Type, Operand)], +) -> Option<(String, u32, bool)> { + if !callee.starts_with("__quantum__qis__") { + return None; + } + let parts: Vec<&str> = callee.split("__").collect(); + if parts.len() < 5 { + return None; + } + let gate_name = parts[3]; + let suffix = if parts.len() > 4 { parts[4] } else { "" }; + let full_name = if suffix == "adj" { + format!("{gate_name}_adj") + } else { + gate_name.to_string() + }; + // Single-qubit gates: the qubit is the last arg. + let qubit_arg = args.last()?; + if !matches!(&qubit_arg.0, Type::NamedPtr(n) if n == "Qubit") { + return None; + } + let q = extract_id(&qubit_arg.1)?; + let is_rotation = matches!(gate_name, "rx" | "ry" | "rz"); + Some((full_name, q, is_rotation)) +} + +/// Return the "adjoint name" for named gates that cancel with themselves/adjoints. 
+fn adjoint_of(name: &str) -> &str { + match name { + "h" => "h", + "s" => "s_adj", + "s_adj" => "s", + "t" => "t_adj", + "t_adj" => "t", + "x" => "x", + "y" => "y", + "z" => "z", + _ => "", + } +} + +/// A tracked pending operation on a single qubit. +#[derive(Clone)] +struct PendingOp { + instr: Instruction, + gate: String, +} + +fn optimize_single_qubit_gates(module: &mut Module) { + // Ensure SX and MResetZ declarations exist. + super::atom_decomp::ensure_declaration(module, qis::SX); + super::atom_decomp::ensure_declaration(module, qis::MRESETZ); + + for func in &mut module.functions { + if func.is_declaration { + continue; + } + let mut last_meas: FxHashMap<u32, (Instruction, (Type, Operand), (Type, Operand))> = + FxHashMap::default(); + + for bb in &mut func.basic_blocks { + let mut qubit_ops: FxHashMap<u32, Vec<PendingOp>> = FxHashMap::default(); + let mut used_qubits: rustc_hash::FxHashSet<u32> = rustc_hash::FxHashSet::default(); + let mut local_last_meas: FxHashMap< + u32, + (Instruction, (Type, Operand), (Type, Operand)), + > = FxHashMap::default(); + + let old_instrs = std::mem::take(&mut bb.instructions); + let mut new_instrs: Vec<Instruction> = Vec::with_capacity(old_instrs.len()); + + for instr in old_instrs { + if let Instruction::Call { callee, args, .. } = &instr { + // Special cases. + match callee.as_str() { + s if s == qis::SX || s == qis::MOVE => { + // Drop tracked ops for involved qubits. + if let Some(q) = extract_id(&args[0].1) { + qubit_ops.remove(&q); + local_last_meas.remove(&q); + used_qubits.insert(q); + } + new_instrs.push(instr); + continue; + } + s if s == qis::BARRIER => { + qubit_ops.clear(); + local_last_meas.clear(); + new_instrs.push(instr); + continue; + } + _ => {} + } + + // Measurement: m, mz, mresetz. + if callee == qis::M || callee == qis::MZ || callee == qis::MRESETZ { + if let Some(q) = extract_id(&args[0].1) { + qubit_ops.remove(&q); + used_qubits.insert(q); + local_last_meas + .insert(q, (instr.clone(), args[0].clone(), args[1].clone())); + } + new_instrs.push(instr); + continue; + } + + // Reset. 
+ if callee == qis::RESET + && let Some(q) = extract_id(&args[0].1) + { + if let Some((meas_instr, target_op, result_op_val)) = + local_last_meas.remove(&q) + { + // Replace the last measurement with mresetz. + // First, find and replace the measurement instruction in new_instrs. + if let Some(pos) = new_instrs.iter().position(|i| *i == meas_instr) { + new_instrs[pos] = void_call( + qis::MRESETZ, + vec![target_op.clone(), result_op_val.clone()], + ); + let new_mresetz = new_instrs[pos].clone(); + local_last_meas.insert(q, (new_mresetz, target_op, result_op_val)); + } + // Drop the reset. + continue; + } else if !used_qubits.contains(&q) { + // Qubit was never used; drop the reset. + continue; + } else if qubit_ops + .get(&q) + .is_some_and(|ops| ops.last().is_some_and(|op| op.gate == "reset")) + { + // Last op was also a reset; drop duplicate. + continue; + } else { + qubit_ops.remove(&q); + used_qubits.insert(q); + let ops = qubit_ops.entry(q).or_default(); + ops.push(PendingOp { + instr: instr.clone(), + gate: "reset".to_string(), + }); + new_instrs.push(instr); + continue; + } + } + + // Two-qubit gates: drop tracked ops for both qubits. + if callee.starts_with("__quantum__qis__") + && args.len() >= 2 + && matches!(&args[0].0, Type::NamedPtr(n) if n == "Qubit") + && matches!(&args[1].0, Type::NamedPtr(n) if n == "Qubit") + { + for a in args { + if let Some(q) = extract_id(&a.1) { + qubit_ops.remove(&q); + local_last_meas.remove(&q); + used_qubits.insert(q); + } + } + new_instrs.push(instr); + continue; + } + + // Try to classify as a single qubit gate. + if let Some((gate, q, is_rotation)) = classify_single_qubit_gate(callee, args) { + if is_rotation { + // Rotation folding. + let angle = extract_float(&args[0].1); + if let Some(angle_val) = angle { + if let Some(ops) = qubit_ops.get_mut(&q) { + if let Some(last) = ops.last() { + if last.gate == gate { + // Same rotation type — try to fold. + if let Instruction::Call { + args: prev_args, .. 
+ } = &last.instr + { + if let Some(prev_angle) = + extract_float(&prev_args[0].1) + { + let mut new_angle = angle_val + prev_angle; + let sign = + if new_angle < 0.0 { -1.0 } else { 1.0 }; + let mut abs_angle = new_angle.abs(); + while abs_angle > 2.0 * PI { + abs_angle -= 2.0 * PI; + } + new_angle = sign * abs_angle; + + // Remove the previous instruction from output. + let prev_instr = + ops.pop().expect("just checked").instr; + if let Some(pos) = new_instrs + .iter() + .rposition(|i| *i == prev_instr) + { + new_instrs.remove(pos); + } + + if new_angle.abs() > TOLERANCE + && (new_angle.abs() - 2.0 * PI).abs() + > TOLERANCE + { + // Insert folded rotation. + let folded = void_call( + callee, + vec![ + ( + Type::Double, + Operand::float_const( + Type::Double, + new_angle, + ), + ), + qubit_op(q), + ], + ); + ops.push(PendingOp { + instr: folded.clone(), + gate: gate.clone(), + }); + used_qubits.insert(q); + local_last_meas.remove(&q); + new_instrs.push(folded); + } else if ops.is_empty() { + qubit_ops.remove(&q); + } + continue; + } + } + } + } + } + // Can't fold — just add. + let ops = qubit_ops.entry(q).or_default(); + ops.push(PendingOp { + instr: instr.clone(), + gate, + }); + used_qubits.insert(q); + local_last_meas.remove(&q); + new_instrs.push(instr); + continue; + } + // Non-constant angle — keep. + let ops = qubit_ops.entry(q).or_default(); + ops.push(PendingOp { + instr: instr.clone(), + gate, + }); + used_qubits.insert(q); + local_last_meas.remove(&q); + new_instrs.push(instr); + continue; + } + + // Non-rotation single qubit gate: check for cancellation / h-s-h → sx. + let adj = adjoint_of(&gate); + if let Some(ops) = qubit_ops.get_mut(&q) { + if let Some(last) = ops.last() { + if last.gate == adj { + // Cancel pair. 
+ let prev_instr = ops.pop().expect("just checked").instr; + if let Some(pos) = + new_instrs.iter().rposition(|i| *i == prev_instr) + { + new_instrs.remove(pos); + } + if ops.is_empty() { + qubit_ops.remove(&q); + } + continue; + } + // h-s-h → sx pattern. + if ops.len() >= 2 + && gate == "h" + && last.gate == "s" + && ops[ops.len() - 2].gate == "h" + { + let s_instr = ops.pop().expect("just checked").instr; + let h_instr = ops.pop().expect("just checked").instr; + // Remove the s and first h from output. + if let Some(pos) = + new_instrs.iter().rposition(|i| *i == s_instr) + { + new_instrs.remove(pos); + } + if let Some(pos) = + new_instrs.iter().rposition(|i| *i == h_instr) + { + new_instrs.remove(pos); + } + // Insert sx instead of the second h. + let sx_instr = void_call(qis::SX, vec![qubit_op(q)]); + new_instrs.push(sx_instr); + // Don't track further — drop ops for this qubit. + if ops.is_empty() { + qubit_ops.remove(&q); + } + continue; + } + } + // No cancellation — append. + ops.push(PendingOp { + instr: instr.clone(), + gate, + }); + used_qubits.insert(q); + local_last_meas.remove(&q); + new_instrs.push(instr); + continue; + } + // First operation on this qubit. + qubit_ops.insert( + q, + vec![PendingOp { + instr: instr.clone(), + gate, + }], + ); + used_qubits.insert(q); + local_last_meas.remove(&q); + new_instrs.push(instr); + continue; + } + } + // Non-call instruction — keep. + new_instrs.push(instr); + } + bb.instructions = new_instrs; + // Propagate local_last_meas to function-level last_meas. + for (k, v) in local_last_meas { + last_meas.insert(k, v); + } + } + + // Post-function: convert trailing measurements to mresetz. + for bb in &mut func.basic_blocks { + for (_, (meas_instr, target, res)) in &last_meas { + if let Some(pos) = bb.instructions.iter().position(|i| i == meas_instr) { + bb.instructions[pos] = + void_call(qis::MRESETZ, vec![target.clone(), res.clone()]); + } + } + // Remove trailing resets. 
+ for (_q, ops_vec) in std::iter::empty::<(u32, Vec)>() {
+ // NOTE(review): dead code — iterating std::iter::empty never executes; this loop can be deleted.
+ let _ = ops_vec;
+ }
+ }
+ }
+}
+
+fn prune_unused_functions(module: &mut Module) {
+ // Collect names of functions called from entry points.
+ let mut called: rustc_hash::FxHashSet = rustc_hash::FxHashSet::default();
+
+ // NOTE(review): entry_points is filled but never read below — candidate for removal.
+ let mut entry_points: rustc_hash::FxHashSet = rustc_hash::FxHashSet::default();
+ for func in &module.functions {
+ if !func.is_declaration && !func.basic_blocks.is_empty() {
+ // Check if this function has an entry_point attribute.
+ // For simplicity, treat all non-declaration functions as potential entry points.
+ entry_points.insert(func.name.clone());
+ }
+ }
+
+ // Collect all function calls.
+ for func in &module.functions {
+ if func.is_declaration {
+ continue;
+ }
+ for bb in &func.basic_blocks {
+ for instr in &bb.instructions {
+ if let Instruction::Call { callee, .. } = instr {
+ called.insert(callee.clone());
+ // (Calls to __quantum__rt__initialize and __quantum__qis__barrier__body are removed below.)
+ }
+ }
+ }
+ }
+
+ // Remove instructions that call init/barrier.
+ for func in &mut module.functions {
+ if func.is_declaration {
+ continue;
+ }
+ for bb in &mut func.basic_blocks {
+ bb.instructions.retain(|instr| {
+ if let Instruction::Call { callee, .. } = instr {
+ callee != rt::INITIALIZE && callee != qis::BARRIER
+ } else {
+ true
+ }
+ });
+ }
+ }
+
+ // Drop declarations that are never called; every defined function is kept.
+ module.functions.retain(|f| {
+ if f.is_declaration {
+ // Keep declarations that are called.
+ called.contains(&f.name)
+ } else {
+ // Keep all defined functions.
+ true
+ }
+ });
+}
+
+/// Optimize single-qubit gate sequences: cancel adjoints, fold rotations,
+/// replace h-s-h with sx, convert m+reset to mresetz.
+#[pyfunction] +pub fn atom_optimize_single_qubit_gates(ir: &str) -> PyResult { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + optimize_single_qubit_gates(&mut module); + Ok(write_module_to_string(&module)) +} + +/// Remove unused function declarations and calls to +/// `__quantum__rt__initialize` and `__quantum__qis__barrier__body`. +#[pyfunction] +pub fn atom_prune_unused_functions(ir: &str) -> PyResult { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + prune_unused_functions(&mut module); + Ok(write_module_to_string(&module)) +} diff --git a/source/pip/src/qir_simulation/atom_reorder.rs b/source/pip/src/qir_simulation/atom_reorder.rs new file mode 100644 index 0000000000..e21fac2b51 --- /dev/null +++ b/source/pip/src/qir_simulation/atom_reorder.rs @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Neutral-atom instruction reordering pass. +//! +//! Rust equivalent of `_reorder.py`. Reorders instructions within each +//! basic block to find contiguous sequences of the same gate on different +//! qubits, enabling better scheduling during execution. + +use pyo3::{PyResult, exceptions::PyValueError, pyfunction}; +use qsc_llvm::{ + model::Type, + model::{Instruction, Module}, + parse_module, + qir::{operand_key, qis, rt}, + write_module_to_string, +}; +use rustc_hash::FxHashSet; + +use super::atom_utils::extract_id; + +/// A stand-in for value identity. In the Python code every `Value` object +/// has identity; here we use the operand's debug representation as a key +/// (sufficient for inttoptr-based qubit/result pointers). +type ValKey = String; + +/// Return (value_keys, result_keys) used by an instruction — mirrors `get_used_values`. +fn get_used_values(instr: &Instruction) -> (Vec, Vec) { + let mut vals = Vec::new(); + let mut meas = Vec::new(); + + if let Instruction::Call { callee, args, .. 
} = instr {
+ match callee.as_str() {
+ s if s == qis::MRESETZ || s == qis::M || s == qis::MZ => {
+ // First arg is qubit value, rest are result values.
+ if let Some(first) = args.first() {
+ vals.push(operand_key(&first.1));
+ }
+ for a in args.iter().skip(1) {
+ meas.push(operand_key(&a.1));
+ }
+ }
+ s if s == qis::READ_RESULT || s == rt::READ_RESULT || s == rt::READ_ATOM_RESULT => {
+ for a in args {
+ meas.push(operand_key(&a.1));
+ }
+ }
+ _ => {
+ for a in args {
+ vals.push(operand_key(&a.1));
+ }
+ }
+ }
+ }
+ // Also include the instruction itself as a produced value (for result-producing calls).
+ // In the Python code this is `vals.append(instr)` — we approximate with a unique repr.
+ vals.push(format!("{instr:?}"));
+ (vals, meas)
+}
+
+fn uses_any_value(used: &[ValKey], existing: &FxHashSet) -> bool {
+ used.iter().any(|v| existing.contains(v))
+}
+
+fn is_output_recording(instr: &Instruction) -> bool {
+ if let Instruction::Call { callee, .. } = instr {
+ callee.ends_with("_record_output")
+ } else {
+ false
+ }
+}
+
+/// Compute a sort key for an instruction based on its first qubit's home ordering.
+/// `ordering` maps qubit_id → device ordering index (passed from Python via device config).
+fn instr_sort_key(instr: &Instruction, ordering: &[u32]) -> u32 {
+ if let Instruction::Call { callee, args, .. } = instr {
+ if callee.starts_with("__quantum__qis__") {
+ // Find the first qubit argument and use its ordering.
+ for (ty, op) in args {
+ if matches!(ty, Type::NamedPtr(n) if n == "Qubit") {
+ if let Some(id) = extract_id(op) {
+ return ordering.get(id as usize).copied().unwrap_or(0);
+ }
+ }
+ }
+ }
+ }
+ 0
+}
+
+fn reorder_block_instructions(instrs: Vec, ordering: &[u32]) -> Vec {
+ // Separate instructions into steps, preserving dependencies.
+ let mut steps: Vec> = Vec::new(); + let mut vals_per_step: Vec> = Vec::new(); + let mut results_per_step: Vec> = Vec::new(); + let mut outputs: Vec = Vec::new(); + let mut terminator: Option = None; + + let mut to_process: Vec = Vec::with_capacity(instrs.len()); + for instr in instrs { + // Check if this is a terminator. + if matches!( + instr, + Instruction::Ret(_) + | Instruction::Jump { .. } + | Instruction::Br { .. } + | Instruction::Switch { .. } + | Instruction::Unreachable + ) { + terminator = Some(instr); + continue; + } + if is_output_recording(&instr) { + outputs.push(instr); + continue; + } + to_process.push(instr); + } + + for instr in to_process { + let (used_vals, used_results) = get_used_values(&instr); + + // Find the last step this instruction depends on. + let mut last_dep = steps.len() as i64 - 1; + while last_dep >= 0 { + let idx = last_dep as usize; + if uses_any_value(&used_vals, &vals_per_step[idx]) + || uses_any_value(&used_results, &results_per_step[idx]) + { + break; + } + last_dep -= 1; + } + + // For Call instructions, push forward past steps with different callees + // to group same-gate operations together. + if let Instruction::Call { callee, .. } = &instr { + while (last_dep as usize) < steps.len().saturating_sub(1) { + let next_idx = (last_dep + 1) as usize; + if let Some(first) = steps[next_idx].first() { + if let Instruction::Call { + callee: other_callee, + .. + } = first + { + if callee != other_callee { + last_dep += 1; + continue; + } + } + } + break; + } + } + + let target_step = (last_dep + 1) as usize; + if target_step >= steps.len() { + steps.push(vec![instr]); + vals_per_step.push(used_vals.into_iter().collect()); + results_per_step.push(used_results.into_iter().collect()); + } else { + steps[target_step].push(instr); + vals_per_step[target_step].extend(used_vals); + results_per_step[target_step].extend(used_results); + } + } + + // Flatten steps, sorting within each step by qubit ordering. 
+ let mut result = Vec::new(); + for step in &mut steps { + step.sort_by_key(|i| instr_sort_key(i, ordering)); + result.extend(step.drain(..)); + } + result.extend(outputs); + if let Some(term) = terminator { + result.push(term); + } + result +} + +fn reorder_module(module: &mut Module, ordering: &[u32]) { + for func in &mut module.functions { + if func.is_declaration { + continue; + } + for bb in &mut func.basic_blocks { + let instrs = std::mem::take(&mut bb.instructions); + bb.instructions = reorder_block_instructions(instrs, ordering); + } + } +} + +/// Reorder instructions within each basic block to group contiguous +/// sequences of the same gate on different qubits. +/// +/// `ordering` is a list of u32 values mapping qubit ID → device ordering +/// index, passed from the Python `Device.get_ordering()` method. +#[pyfunction] +pub fn atom_reorder(ir: &str, ordering: Vec) -> PyResult { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + reorder_module(&mut module, &ordering); + Ok(write_module_to_string(&module)) +} diff --git a/source/pip/src/qir_simulation/atom_scheduler.rs b/source/pip/src/qir_simulation/atom_scheduler.rs new file mode 100644 index 0000000000..64c752990c --- /dev/null +++ b/source/pip/src/qir_simulation/atom_scheduler.rs @@ -0,0 +1,656 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Neutral-atom hardware scheduling pass. +//! +//! Rust equivalent of `_scheduler.py`. Groups quantum gates into parallel +//! layers constrained by hardware capabilities: +//! - Inserts `__quantum__rt__begin_parallel` / `end_parallel` markers. +//! - Inserts `__quantum__qis__move__body` instructions for atom rearrangement. +//! +//! The scheduling algorithm operates on a single basic block at a time, +//! batching CZ operations into interaction-zone rows and measurements into +//! measurement-zone slots, with appropriate qubit movement. 
+ +use pyo3::{PyResult, exceptions::PyValueError, pyfunction}; +use qsc_llvm::{ + model::Type, + model::{Function, Instruction, Module, Param}, + parse_module, + qir::{self, i64_op, operand_key, qis, qubit_op, rt, void_call}, + write_module_to_string, +}; +use rustc_hash::FxHashSet; +use std::collections::BTreeMap; + +use super::atom_utils::as_qis_gate; + +const MOVE_GROUPS_PER_PARALLEL_SECTION: usize = 1; + +fn begin_parallel() -> Instruction { + void_call(rt::BEGIN_PARALLEL, vec![]) +} + +fn end_parallel() -> Instruction { + void_call(rt::END_PARALLEL, vec![]) +} + +fn move_instr(qubit_id: u32, row: i64, col: i64) -> Instruction { + void_call( + qis::MOVE, + vec![qubit_op(qubit_id), i64_op(row), i64_op(col)], + ) +} + +#[derive(Clone)] +struct ZoneInfo { + row_count: usize, + offset: usize, // offset in cells (= zone_row_offset * column_count) + zone_type: ZoneKind, +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum ZoneKind { + Register, + Interaction, + Measurement, +} + +struct DeviceConfig { + column_count: usize, + zones: Vec, + home_locs: Vec<(i64, i64)>, +} + +impl DeviceConfig { + fn interaction_zone(&self) -> &ZoneInfo { + self.zones + .iter() + .find(|z| z.zone_type == ZoneKind::Interaction) + .expect("device must have an interaction zone") + } + + fn measurement_zone(&self) -> &ZoneInfo { + self.zones + .iter() + .find(|z| z.zone_type == ZoneKind::Measurement) + .expect("device must have a measurement zone") + } + + fn get_home_loc(&self, q: u32) -> (i64, i64) { + self.home_locs.get(q as usize).copied().unwrap_or((0, 0)) + } + + fn get_ordering(&self, q: u32) -> u32 { + let (row, col) = self.get_home_loc(q); + #[allow(clippy::cast_sign_loss)] + let val = row as u32 * self.column_count as u32 + col as u32; + val + } +} + +type Location = (i64, i64); + +struct MoveOp { + qubit_id: u32, + src_loc: Location, + dst_loc: Location, +} + +/// Schedule moves for a set of qubits into a target zone. +/// Returns groups of moves that can be executed in parallel. 
+/// +/// This is a simplified version of the Python `MoveScheduler` that focuses +/// on assigning destinations and grouping, without the sophisticated +/// scale-factor-based move-group optimization (which can be added later +/// for optimal parallelism). +fn schedule_moves( + device: &DeviceConfig, + zone: &ZoneInfo, + qubits_to_move: &[QubitToMove], +) -> Vec> { + let zone_row_offset = zone.offset / device.column_count; + let mut available: BTreeMap = BTreeMap::new(); + for row in zone_row_offset..(zone_row_offset + zone.row_count) { + for col in 0..device.column_count { + available.insert((row as i64, col as i64), ()); + } + } + + let mut all_moves: Vec = Vec::new(); + + for qtm in qubits_to_move { + match qtm { + QubitToMove::Single(q) => { + let src = device.get_home_loc(*q); + // Prefer straight-up/down move (same column). + let mut dst = None; + for row in zone_row_offset..(zone_row_offset + zone.row_count) { + let loc = (row as i64, src.1); + if available.contains_key(&loc) { + dst = Some(loc); + break; + } + } + if dst.is_none() { + // Fallback: any available. + dst = available.keys().next().copied(); + } + if let Some(d) = dst { + available.remove(&d); + all_moves.push(MoveOp { + qubit_id: *q, + src_loc: src, + dst_loc: d, + }); + } + } + QubitToMove::Pair(q1, q2) => { + let src1 = device.get_home_loc(*q1); + // CZ pair: place on adjacent even/odd columns in same row. + let mut dst1 = None; + let mut dst2 = None; + let src_col = if src1.1 % 2 == 0 { src1.1 } else { src1.1 - 1 }; + for row in zone_row_offset..(zone_row_offset + zone.row_count) { + let loc1 = (row as i64, src_col); + let loc2 = (row as i64, src_col + 1); + if available.contains_key(&loc1) && available.contains_key(&loc2) { + dst1 = Some(loc1); + dst2 = Some(loc2); + break; + } + } + // Fallback: find any adjacent pair. 
+ if dst1.is_none() { + for row in zone_row_offset..(zone_row_offset + zone.row_count) { + for col in (0..device.column_count).step_by(2) { + let loc1 = (row as i64, col as i64); + let loc2 = (row as i64, col as i64 + 1); + if available.contains_key(&loc1) && available.contains_key(&loc2) { + dst1 = Some(loc1); + dst2 = Some(loc2); + break; + } + } + if dst1.is_some() { + break; + } + } + } + if let (Some(d1), Some(d2)) = (dst1, dst2) { + available.remove(&d1); + available.remove(&d2); + all_moves.push(MoveOp { + qubit_id: *q1, + src_loc: device.get_home_loc(*q1), + dst_loc: d1, + }); + all_moves.push(MoveOp { + qubit_id: *q2, + src_loc: device.get_home_loc(*q2), + dst_loc: d2, + }); + } + } + } + } + + // Group moves: for simplicity, put all moves in one group sorted by qubit ID. + // The Python code uses a sophisticated MoveGroupPool; this simplified version + // just returns a single group, which is correct but less optimal for parallelism. + if all_moves.is_empty() { + Vec::new() + } else { + all_moves.sort_by_key(|m| m.qubit_id); + vec![all_moves] + } +} + +#[derive(Clone)] +enum QubitToMove { + Single(u32), + Pair(u32, u32), +} + +struct SchedulerState<'a> { + device: &'a DeviceConfig, + num_qubits: usize, + single_qubit_ops: Vec>, // per-qubit queued ops + curr_cz_ops: Vec, + measurements: Vec<(Instruction, String)>, // (instr, gate_name) + pending_qubits_to_move: Vec, + pending_moves: Vec>, + vals_used_in_cz: FxHashSet, + vals_used_in_meas: FxHashSet, + output: Vec, +} + +fn get_call_used_values(instr: &Instruction) -> (Vec, Vec) { + let mut vals = Vec::new(); + let mut meas = Vec::new(); + if let Instruction::Call { callee, args, .. 
} = instr { + match callee.as_str() { + s if s == qis::MRESETZ || s == qis::M || s == qis::MZ => { + if let Some(first) = args.first() { + vals.push(operand_key(&first.1)); + } + for a in args.iter().skip(1) { + meas.push(operand_key(&a.1)); + } + } + _ => { + for a in args { + vals.push(operand_key(&a.1)); + } + } + } + } + (vals, meas) +} + +impl<'a> SchedulerState<'a> { + fn new(device: &'a DeviceConfig) -> Self { + let n = device.home_locs.len(); + Self { + device, + num_qubits: n, + single_qubit_ops: vec![Vec::new(); n], + curr_cz_ops: Vec::new(), + measurements: Vec::new(), + pending_qubits_to_move: Vec::new(), + pending_moves: Vec::new(), + vals_used_in_cz: FxHashSet::default(), + vals_used_in_meas: FxHashSet::default(), + output: Vec::new(), + } + } + + fn any_pending_sq(&self) -> bool { + self.single_qubit_ops.iter().any(|ops| !ops.is_empty()) + } + + fn any_pending_cz(&self) -> bool { + !self.curr_cz_ops.is_empty() + } + + fn any_pending_meas(&self) -> bool { + !self.measurements.is_empty() + } + + fn any_pending(&self) -> bool { + self.any_pending_cz() || self.any_pending_sq() || self.any_pending_meas() + } + + fn insert_moves(&mut self) { + let mut group_id = 0usize; + for group in &self.pending_moves { + if group_id == 0 { + self.output.push(begin_parallel()); + } + for m in group { + self.output + .push(move_instr(m.qubit_id, m.dst_loc.0, m.dst_loc.1)); + } + group_id += 1; + if group_id >= MOVE_GROUPS_PER_PARALLEL_SECTION { + group_id = 0; + self.output.push(end_parallel()); + } + } + if group_id != 0 { + self.output.push(end_parallel()); + } + } + + fn insert_moves_back(&mut self) { + let mut group_id = 0usize; + for group in &self.pending_moves { + if group_id == 0 { + self.output.push(begin_parallel()); + } + for m in group { + self.output + .push(move_instr(m.qubit_id, m.src_loc.0, m.src_loc.1)); + } + group_id += 1; + if group_id >= MOVE_GROUPS_PER_PARALLEL_SECTION { + group_id = 0; + self.output.push(end_parallel()); + } + } + if group_id != 0 
{ + self.output.push(end_parallel()); + } + self.pending_moves.clear(); + } + + fn target_qubits_by_row(&self, zone: &ZoneInfo) -> Vec> { + let zone_row_offset = zone.offset / self.device.column_count; + let mut by_row: Vec> = vec![Vec::new(); zone.row_count]; + for group in &self.pending_moves { + for m in group { + let row_idx = (m.dst_loc.0 as usize).saturating_sub(zone_row_offset); + if row_idx < zone.row_count { + by_row[row_idx].push(m.qubit_id); + } + } + } + for row in &mut by_row { + row.sort_unstable(); + } + by_row + } + + fn flush_single_qubit_ops(&mut self, target_qubits: &[u32]) { + // Gather ops to flush. + let mut ops_to_flush: Vec> = Vec::new(); + for &q in target_qubits { + let idx = q as usize; + if idx < self.num_qubits { + let mut ops = std::mem::take(&mut self.single_qubit_ops[idx]); + ops.reverse(); + ops_to_flush.push(ops); + } else { + ops_to_flush.push(Vec::new()); + } + } + + while ops_to_flush.iter().any(|ops| !ops.is_empty()) { + // Collect rz ops. + let mut rz_ops = Vec::new(); + for q_ops in &mut ops_to_flush { + if let Some(last) = q_ops.last() { + if last.1 == "rz" { + rz_ops.push(q_ops.pop().expect("just checked").0); + } + } + } + if !rz_ops.is_empty() { + self.output.push(begin_parallel()); + self.output.extend(rz_ops); + self.output.push(end_parallel()); + } + + // Collect sx ops. 
+ let mut sx_ops = Vec::new(); + for q_ops in &mut ops_to_flush { + if let Some(last) = q_ops.last() { + if last.1 == "sx" { + sx_ops.push(q_ops.pop().expect("just checked").0); + } + } + } + if !sx_ops.is_empty() { + self.output.push(begin_parallel()); + self.output.extend(sx_ops); + self.output.push(end_parallel()); + } + } + } + + fn schedule_pending_moves(&mut self, zone: &ZoneInfo) { + let moves = schedule_moves(self.device, zone, &self.pending_qubits_to_move); + self.pending_moves.extend(moves); + self.pending_qubits_to_move.clear(); + } + + fn flush_pending(&mut self) { + let iz = self.device.interaction_zone().clone(); + let mz = self.device.measurement_zone().clone(); + + if self.any_pending_cz() { + self.schedule_pending_moves(&iz); + self.insert_moves(); + let qubits_by_row = self.target_qubits_by_row(&iz); + for row_qubits in &qubits_by_row { + self.flush_single_qubit_ops(row_qubits); + } + self.output.push(begin_parallel()); + let cz_ops = std::mem::take(&mut self.curr_cz_ops); + self.output.extend(cz_ops); + self.output.push(end_parallel()); + self.insert_moves_back(); + self.vals_used_in_cz.clear(); + } else if self.any_pending_meas() { + self.schedule_pending_moves(&mz); + self.insert_moves(); + self.output.push(begin_parallel()); + let meas = std::mem::take(&mut self.measurements); + for (instr, _) in meas { + self.output.push(instr); + } + self.output.push(end_parallel()); + self.vals_used_in_meas.clear(); + self.insert_moves_back(); + } else { + // Single-qubit ops only: move to IZ, execute, move back. 
+ while self.any_pending_sq() { + let mut target_qubits_by_row: Vec> = vec![Vec::new(); iz.row_count]; + let mut curr_row = 0usize; + for q in 0..self.num_qubits { + if !self.single_qubit_ops[q].is_empty() { + target_qubits_by_row[curr_row].push(q as u32); + if target_qubits_by_row[curr_row].len() >= self.device.column_count { + curr_row += 1; + if curr_row >= iz.row_count { + break; + } + } + } + } + for target_qs in &target_qubits_by_row { + for &q in target_qs { + let idx = q as usize; + if idx < self.num_qubits && !self.single_qubit_ops[idx].is_empty() { + // Determine qubit operand from the first instruction. + let gate_name = &self.single_qubit_ops[idx][0].1; + let _qubit_arg_idx = if gate_name == "rz" { 1 } else { 0 }; + self.pending_qubits_to_move.push(QubitToMove::Single(q)); + } + } + } + self.schedule_pending_moves(&iz); + self.insert_moves(); + let qubits_by_row = self.target_qubits_by_row(&iz); + for row_qubits in &qubits_by_row { + self.flush_single_qubit_ops(row_qubits); + } + self.insert_moves_back(); + } + } + } + + fn schedule_block(&mut self, instrs: Vec) { + let iz = self.device.interaction_zone().clone(); + let mz = self.device.measurement_zone().clone(); + let max_iz_pairs = (self.device.column_count / 2) * iz.row_count; + let max_measurements = self.device.column_count * mz.row_count; + + self.single_qubit_ops = vec![Vec::new(); self.num_qubits]; + self.curr_cz_ops.clear(); + self.measurements.clear(); + self.pending_qubits_to_move.clear(); + self.vals_used_in_cz.clear(); + self.vals_used_in_meas.clear(); + + for instr in instrs { + if let Instruction::Call { callee, args, .. } = &instr { + if let Some(gate) = as_qis_gate(callee, args) { + // Single-qubit gate (no result args). + if gate.qubit_args.len() == 1 && gate.result_args.is_empty() { + let q = gate.qubit_args[0]; + + // Check if qubit is involved in pending moves. 
+ let involved_in_moves = + self.pending_qubits_to_move.iter().any(|qtm| match qtm { + QubitToMove::Single(id) => *id == q, + QubitToMove::Pair(id1, id2) => *id1 == q || *id2 == q, + }); + if involved_in_moves { + self.flush_pending(); + } + + if (q as usize) < self.num_qubits { + self.single_qubit_ops[q as usize].push((instr, gate.gate)); + } else { + self.output.push(instr); + } + continue; + } + + // Two-qubit gate (CZ after decomposition). + if gate.qubit_args.len() == 2 { + let (vals, _) = get_call_used_values(&instr); + let val_set: FxHashSet<_> = vals.into_iter().collect(); + if self.any_pending_meas() + || val_set.iter().any(|v| self.vals_used_in_cz.contains(v)) + || self.curr_cz_ops.len() >= max_iz_pairs + { + self.flush_pending(); + } + self.curr_cz_ops.push(instr.clone()); + self.vals_used_in_cz.extend(val_set); + + let q0 = gate.qubit_args[0]; + let q1 = gate.qubit_args[1]; + let home0 = self.device.get_home_loc(q0); + let home1 = self.device.get_home_loc(q1); + if home0.1 > home1.1 { + self.pending_qubits_to_move.push(QubitToMove::Pair(q1, q0)); + } else { + self.pending_qubits_to_move.push(QubitToMove::Pair(q0, q1)); + } + continue; + } + + // Measurement. + if !gate.result_args.is_empty() { + let (vals, _) = get_call_used_values(&instr); + let val_set: FxHashSet<_> = vals.into_iter().collect(); + if !self.measurements.is_empty() + && (self.measurements.len() >= max_measurements + || val_set.iter().any(|v| self.vals_used_in_meas.contains(v))) + { + self.flush_pending(); + } + + // Flush pending single-qubit ops for qubit being measured. 
+ let q = gate.qubit_args[0]; + if (q as usize) < self.num_qubits + && !self.single_qubit_ops[q as usize].is_empty() + { + let temp_meas = std::mem::take(&mut self.measurements); + let temp_moves = std::mem::take(&mut self.pending_qubits_to_move); + self.flush_pending(); + self.measurements = temp_meas; + self.pending_qubits_to_move = temp_moves; + } + + self.measurements.push((instr.clone(), gate.gate)); + self.vals_used_in_meas.extend(val_set); + self.pending_qubits_to_move + .push(QubitToMove::Single(gate.qubit_args[0])); + continue; + } + } + } + + // Non-gate instruction: flush everything, then emit. + while self.any_pending() { + self.flush_pending(); + } + self.output.push(instr); + } + } +} + +fn schedule_module(module: &mut Module, device: &DeviceConfig) { + // Ensure declarations for parallel markers and move function. + super::atom_decomp::ensure_declaration(module, rt::BEGIN_PARALLEL); + super::atom_decomp::ensure_declaration(module, rt::END_PARALLEL); + // Add move function declaration with correct signature. + if !module.functions.iter().any(|f| f.name == qis::MOVE) { + module.functions.push(Function { + name: qis::MOVE.to_string(), + return_type: Type::Void, + params: vec![ + Param { + ty: Type::NamedPtr(qir::QUBIT_TYPE_NAME.to_string()), + name: None, + }, + Param { + ty: Type::Integer(64), + name: None, + }, + Param { + ty: Type::Integer(64), + name: None, + }, + ], + is_declaration: true, + attribute_group_refs: Vec::new(), + basic_blocks: Vec::new(), + }); + } + + for func in &mut module.functions { + if func.is_declaration { + continue; + } + for bb in &mut func.basic_blocks { + let instrs = std::mem::take(&mut bb.instructions); + let mut state = SchedulerState::new(device); + state.schedule_block(instrs); + bb.instructions = state.output; + } + } +} + +/// Schedule quantum operations for neutral-atom hardware. +/// +/// Parameters: +/// - `ir`: LLVM IR text. +/// - `column_count`: device column count. 
+/// - `zone_row_counts`: list of row counts for each zone. +/// - `zone_types`: list of zone types (0=register, 1=interaction, 2=measurement). +/// - `home_locs`: flat list of (row, col) tuples for each qubit's home location. +#[pyfunction] +pub fn atom_schedule( + ir: &str, + column_count: usize, + zone_row_counts: Vec, + zone_types: Vec, + home_locs: Vec<(i64, i64)>, +) -> PyResult { + let mut module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("parse error: {e}")))?; + + // Build zones. + let mut zones = Vec::new(); + let mut offset = 0usize; + for (rc, zt) in zone_row_counts.iter().zip(zone_types.iter()) { + let kind = match zt { + 0 => ZoneKind::Register, + 1 => ZoneKind::Interaction, + 2 => ZoneKind::Measurement, + _ => { + return Err(PyValueError::new_err(format!("unknown zone type: {zt}"))); + } + }; + zones.push(ZoneInfo { + row_count: *rc, + offset: offset * column_count, + zone_type: kind, + }); + offset += rc; + } + + let device = DeviceConfig { + column_count, + zones, + home_locs, + }; + + schedule_module(&mut module, &device); + Ok(write_module_to_string(&module)) +} diff --git a/source/pip/src/qir_simulation/atom_trace.rs b/source/pip/src/qir_simulation/atom_trace.rs new file mode 100644 index 0000000000..f18ba99d1f --- /dev/null +++ b/source/pip/src/qir_simulation/atom_trace.rs @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Neutral-atom QIR trace pass (read-only). +//! +//! Rust equivalent of `_trace.py`. Walks the entry-point function and +//! builds a trace structure describing the sequence of parallel/serial +//! steps and operations (move, sx, rz, cz, mz). 
+ +use pyo3::{ + Bound, IntoPyObject, IntoPyObjectExt, PyResult, Python, + exceptions::PyValueError, + pyfunction, + types::{PyDict, PyDictMethods, PyList, PyListMethods}, +}; +use qsc_llvm::{ + model::{Instruction, Operand}, + parse_module, + qir::{find_entry_point, get_function_attribute, qis, rt}, +}; +use rustc_hash::FxHashMap; + +use super::atom_utils::extract_id; + +/// Build an execution trace for a neutral-atom QIR program. +/// +/// Returns a Python dict with keys: +/// - `"qubits"`: list of `(row, col)` tuples (home locations, truncated to +/// `required_num_qubits` if present) +/// - `"steps"`: list of step dicts, each with `"id"` (int) and `"ops"` (list of str) +/// +/// `home_locs` is the full device home-location list passed from Python. +#[pyfunction] +pub fn trace_atom_program<'py>( + py: Python<'py>, + ir: &str, + home_locs: Vec<(i64, i64)>, +) -> PyResult> { + let module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("failed to parse IR: {e}")))?; + + let entry_idx = find_entry_point(&module) + .ok_or_else(|| PyValueError::new_err("no entry point function found in IR"))?; + + // Optionally truncate home_locs to required_num_qubits + let num_qubits: Option = + get_function_attribute(&module, entry_idx, "required_num_qubits") + .and_then(|s| s.parse().ok()); + + let used_locs = match num_qubits { + Some(n) if n < home_locs.len() => &home_locs[..n], + _ => &home_locs, + }; + + // Build the Python "qubits" list + let py_qubits = PyList::empty(py); + for &(row, col) in used_locs { + let tup = (row, col).into_pyobject(py)?; + py_qubits.append(tup)?; + } + + // Walk the entry function + let mut steps: Vec = Vec::new(); + let mut in_parallel = false; + let mut q_cols: FxHashMap = FxHashMap::default(); + + let entry_func = &module.functions[entry_idx]; + for block in &entry_func.basic_blocks { + for instr in &block.instructions { + if let Instruction::Call { callee, args, .. 
} = instr { + match callee.as_str() { + s if s == rt::BEGIN_PARALLEL => { + steps.push(Step::new(steps.len())); + in_parallel = true; + } + s if s == rt::END_PARALLEL => { + in_parallel = false; + } + s if s == qis::MOVE => { + if !in_parallel { + steps.push(Step::new(steps.len())); + } + if let (Some(q), Some(row_val), Some(col_val)) = ( + args.first().and_then(|(_, op)| extract_id(op)), + args.get(1).and_then(|(_, op)| extract_int_val(op)), + args.get(2).and_then(|(_, op)| extract_int_val(op)), + ) { + q_cols.insert(q, col_val); + if let Some(step) = steps.last_mut() { + step.ops.push(format!("move({row_val}, {col_val}) {q}")); + } + } + } + s if s == qis::SX => { + if !in_parallel { + steps.push(Step::new(steps.len())); + } + if let Some(q) = args.first().and_then(|(_, op)| extract_id(op)) { + if let Some(step) = steps.last_mut() { + step.ops.push(format!("sx {q}")); + } + } + } + s if s == qis::RZ => { + if !in_parallel { + steps.push(Step::new(steps.len())); + } + if let (Some(angle), Some(q)) = ( + args.first().and_then(|(_, op)| extract_float_val(op)), + args.get(1).and_then(|(_, op)| extract_id(op)), + ) { + if let Some(step) = steps.last_mut() { + step.ops.push(format!("rz({angle}) {q}")); + } + } + } + s if s == qis::CZ => { + if !in_parallel { + steps.push(Step::new(steps.len())); + } + if let (Some(mut q1), Some(mut q2)) = ( + args.first().and_then(|(_, op)| extract_id(op)), + args.get(1).and_then(|(_, op)| extract_id(op)), + ) { + // Sort by column so lower-column qubit comes first + let c1 = q_cols.get(&q1).copied().unwrap_or(-1); + let c2 = q_cols.get(&q2).copied().unwrap_or(-1); + if c1 > c2 { + std::mem::swap(&mut q1, &mut q2); + } + if let Some(step) = steps.last_mut() { + step.ops.push(format!("cz {q1}, {q2}")); + } + } + } + s if s == qis::MRESETZ => { + if !in_parallel { + steps.push(Step::new(steps.len())); + } + if let Some(q) = args.first().and_then(|(_, op)| extract_id(op)) { + if let Some(step) = steps.last_mut() { + 
step.ops.push(format!("mz {q}"));
+                        }
+                    }
+                }
+                _ => {}
+            }
+        }
+    }
+
+    // Build final Python dict
+    let py_steps = PyList::empty(py);
+    for step in &steps {
+        let d = PyDict::new(py);
+        d.set_item("id", step.id)?;
+        let ops = PyList::empty(py);
+        for op in &step.ops {
+            ops.append(op.into_py_any(py)?)?;
+        }
+        d.set_item("ops", ops)?;
+        py_steps.append(d)?;
+    }
+
+    let result = PyDict::new(py);
+    result.set_item("qubits", py_qubits)?;
+    result.set_item("steps", py_steps)?;
+    Ok(result)
+}
+
+struct Step {
+    id: usize,
+    ops: Vec<String>,
+}
+
+impl Step {
+    fn new(id: usize) -> Self {
+        Self {
+            id,
+            ops: Vec::new(),
+        }
+    }
+}
+
+fn extract_int_val(operand: &Operand) -> Option<i64> {
+    match operand {
+        Operand::IntConst(_, v) => Some(*v),
+        _ => None,
+    }
+}
+
+fn extract_float_val(operand: &Operand) -> Option<f64> {
+    match operand {
+        Operand::FloatConst(_, v) => Some(*v),
+        _ => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{extract_float_val, extract_int_val};
+
+    /// Verify that extract helpers work correctly.
+    #[test]
+    fn extract_int_and_float() {
+        use qsc_llvm::{model::Operand, model::Type};
+
+        assert_eq!(
+            extract_int_val(&Operand::IntConst(Type::Integer(64), 42)),
+            Some(42)
+        );
+        assert_eq!(extract_int_val(&Operand::NullPtr), None);
+
+        let f = extract_float_val(&Operand::float_const(Type::Double, std::f64::consts::PI));
+        assert!((f.expect("should be Some") - std::f64::consts::PI).abs() < 1e-10);
+        assert_eq!(extract_float_val(&Operand::NullPtr), None);
+    }
+}
diff --git a/source/pip/src/qir_simulation/atom_utils.rs b/source/pip/src/qir_simulation/atom_utils.rs
new file mode 100644
index 0000000000..1cceb6dd2b
--- /dev/null
+++ b/source/pip/src/qir_simulation/atom_utils.rs
@@ -0,0 +1,113 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+//! Utility helpers for inspecting neutral-atom QIR instructions.
+//!
+//! Rust equivalent of `_utils.py`. These are used internally by
+//!
`atom_validate` and `atom_trace`; they are *not* exposed to Python.
+
+use qsc_llvm::{model::Operand, model::Type, qir};
+
+pub(crate) const TOLERANCE: f64 = 1.192_092_9e-7; // Machine epsilon for f32
+
+/// Description of a `__quantum__qis__*` gate call.
+pub(crate) struct QisGate {
+    pub gate: String,
+    pub qubit_args: Vec<u32>,
+    pub result_args: Vec<u32>,
+    pub other_args: Vec<OtherArg>,
+}
+
+/// Catch-all for non-qubit, non-result arguments (angles, ints, …).
+#[derive(Clone)]
+pub(crate) enum OtherArg {
+    Float(f64),
+    Int(i64),
+}
+
+/// If the callee name matches `__quantum__qis__<gate>__<suffix>`, return
+/// a [`QisGate`] with qubit/result IDs extracted from `args`.
+pub(crate) fn as_qis_gate(callee: &str, args: &[(Type, Operand)]) -> Option<QisGate> {
+    if !callee.starts_with("__quantum__qis__") {
+        return None;
+    }
+    let parts: Vec<&str> = callee.split("__").collect();
+    // parts: ["", "", "quantum", "", "qis", "", "<gate>", "", "<suffix>", ...]
+    if parts.len() < 5 {
+        return None;
+    }
+    let gate_name = parts[3];
+    let suffix = if parts.len() > 4 { parts[4] } else { "" };
+    let gate = if suffix == "adj" {
+        format!("{gate_name}_adj")
+    } else {
+        gate_name.to_string()
+    };
+
+    let mut qubit_args = Vec::new();
+    let mut result_args = Vec::new();
+    let mut other_args = Vec::new();
+
+    for (ty, operand) in args {
+        if is_qubit_type(ty) {
+            if let Some(id) = extract_id(operand) {
+                qubit_args.push(id);
+            }
+        } else if is_result_type(ty) {
+            if let Some(id) = extract_id(operand) {
+                result_args.push(id);
+            }
+        } else {
+            match operand {
+                Operand::FloatConst(_, v) => other_args.push(OtherArg::Float(*v)),
+                Operand::IntConst(_, v) => other_args.push(OtherArg::Int(*v)),
+                _ => {}
+            }
+        }
+    }
+
+    Some(QisGate {
+        gate,
+        qubit_args,
+        result_args,
+        other_args,
+    })
+}
+
+/// Check whether a type is `%Qubit*` (pointer to opaque `Qubit` struct).
+fn is_qubit_type(ty: &Type) -> bool {
+    matches!(ty, Type::NamedPtr(n) if n == qir::QUBIT_TYPE_NAME)
+}
+
+/// Check whether a type is `%Result*` (pointer to opaque `Result` struct).
+fn is_result_type(ty: &Type) -> bool {
+    matches!(ty, Type::NamedPtr(n) if n == qir::RESULT_TYPE_NAME)
+}
+
+/// Extract an integer ID from an `inttoptr` or `null` operand.
+/// `NullPtr` is treated as ID 0 (pyqir normalizes `inttoptr(i64 0)` to `null`).
+pub(crate) fn extract_id(operand: &Operand) -> Option<u32> {
+    match operand {
+        Operand::IntToPtr(val, _) => u32::try_from(*val).ok(),
+        Operand::NullPtr => Some(0),
+        _ => None,
+    }
+}
+
+/// Extract a float constant from an operand.
+pub(crate) fn extract_float(operand: &Operand) -> Option<f64> {
+    match operand {
+        Operand::FloatConst(_, val) => Some(*val),
+        _ => None,
+    }
+}
+
+/// Check if a callee is a measurement gate.
+pub(crate) fn is_measurement(callee: &str) -> bool {
+    matches!(callee, s if s == qir::qis::MRESETZ || s == qir::qis::M || s == qir::qis::MZ)
+}
+
+/// Check if a callee is a quantum instruction (starts with `__quantum__qis__`).
+pub(crate) fn is_qubit_instruction(callee: &str) -> bool {
+    callee.starts_with("__quantum__qis__")
+}
diff --git a/source/pip/src/qir_simulation/atom_validate.rs b/source/pip/src/qir_simulation/atom_validate.rs
new file mode 100644
index 0000000000..961292955b
--- /dev/null
+++ b/source/pip/src/qir_simulation/atom_validate.rs
@@ -0,0 +1,186 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+//! Neutral-atom QIR validation passes (read-only).
+//!
+//! Rust equivalent of `_validate.py`. Two checks:
+//!
+//! 1. **`validate_allowed_intrinsics`** – every non-entry-point function
+//!    must be in the allow-list.
+//! 2. **`validate_no_conditional_branches`** – the entry point must not
+//!    contain any conditional `br` instructions.
+ +use pyo3::{PyResult, exceptions::PyValueError, pyfunction}; +use qsc_llvm::{ + model::{Instruction, Module}, + parse_module, + qir::{find_entry_point, qis, rt}, +}; + +/// Allowed function names that are not the entry point. +const ALLOWED_INTRINSICS: &[&str] = &[ + rt::BEGIN_PARALLEL, + rt::END_PARALLEL, + qis::READ_RESULT, + rt::READ_RESULT, + qis::MOVE, + qis::CZ, + qis::SX, + qis::RZ, + qis::MRESETZ, +]; + +/// Check all defined functions are allowed intrinsics. +/// Returns `Err(function_name)` on the first disallowed function. +fn check_allowed_intrinsics(module: &Module) -> Result<(), String> { + let entry_idx = find_entry_point(module); + + for (idx, func) in module.functions.iter().enumerate() { + if func.is_declaration { + continue; + } + if Some(idx) == entry_idx { + continue; + } + let name = &func.name; + if name.ends_with("_record_output") { + continue; + } + if !ALLOWED_INTRINSICS.contains(&name.as_str()) { + return Err(name.clone()); + } + } + Ok(()) +} + +/// Check that the entry-point has no conditional branches. +/// Returns `Ok(())` or `Err(message)`. +fn check_no_conditional_branches(module: &Module) -> Result<(), &'static str> { + let entry_idx = find_entry_point(module).ok_or("no entry point function found in IR")?; + let entry_func = &module.functions[entry_idx]; + + for block in &entry_func.basic_blocks { + for instr in &block.instructions { + if matches!(instr, Instruction::Br { .. }) { + return Err("programs with branching control flow are not supported"); + } + } + } + Ok(()) +} + +/// Validate that the module only contains allowed intrinsics. +/// +/// Raises `ValueError` if a function is found that is not the entry point, +/// not an output-recording intrinsic, and not in the allow-list. 
+#[pyfunction] +pub fn validate_allowed_intrinsics(ir: &str) -> PyResult<()> { + let module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("failed to parse IR: {e}")))?; + check_allowed_intrinsics(&module) + .map_err(|name| PyValueError::new_err(format!("{name} is not a supported intrinsic"))) +} + +/// Validate that the entry-point function contains only unconditional branches. +/// +/// Raises `ValueError` if any basic block has a conditional `br`. +#[pyfunction] +pub fn validate_no_conditional_branches(ir: &str) -> PyResult<()> { + let module = + parse_module(ir).map_err(|e| PyValueError::new_err(format!("failed to parse IR: {e}")))?; + check_no_conditional_branches(&module).map_err(|msg| PyValueError::new_err(msg)) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Minimal base-profile IR with no branching and only allowed intrinsics. + const VALID_IR: &str = r#" +; ModuleID = 'test' +source_filename = "test" + +%Qubit = type opaque +%Result = type opaque + +define void @main() #0 { +entry: + call void @__quantum__qis__sx__body(%Qubit* inttoptr (i64 0 to %Qubit*)) + call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Qubit* inttoptr (i64 1 to %Qubit*)) + call void @__quantum__qis__mresetz__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*)) + ret void +} + +declare void @__quantum__qis__sx__body(%Qubit*) +declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*) +declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) + +attributes #0 = { "entry_point" "required_num_qubits"="2" "required_num_results"="1" } +"#; + + #[test] + fn valid_ir_passes_both_checks() { + let module = parse_module(VALID_IR).expect("parse"); + check_allowed_intrinsics(&module).expect("should pass"); + check_no_conditional_branches(&module).expect("should pass"); + } + + #[test] + fn disallowed_intrinsic_is_rejected() { + let ir = r#" +; ModuleID = 'test' +source_filename = "test" + +%Qubit = type opaque +%Result 
= type opaque + +define void @main() #0 { +entry: + ret void +} + +define void @__quantum__qis__h__body(%Qubit* %q) { +entry: + ret void +} + +declare void @__quantum__qis__sx__body(%Qubit*) + +attributes #0 = { "entry_point" "required_num_qubits"="1" "required_num_results"="0" } +"#; + let err = check_allowed_intrinsics(&parse_module(ir).expect("parse")).unwrap_err(); + assert!( + err.contains("h__body"), + "error should mention the disallowed function: {err}" + ); + } + + #[test] + fn conditional_branch_is_rejected() { + let ir = r#" +; ModuleID = 'test' +source_filename = "test" + +%Qubit = type opaque +%Result = type opaque + +define void @main() #0 { +entry: + br i1 true, label %then, label %else + +then: + ret void + +else: + ret void +} + +attributes #0 = { "entry_point" "required_num_qubits"="1" "required_num_results"="0" } +"#; + let err = check_no_conditional_branches(&parse_module(ir).expect("parse")).unwrap_err(); + assert!( + err.contains("branching control flow"), + "unexpected error: {err}" + ); + } +} diff --git a/source/pip/src/qir_simulation/native_qir_parser.rs b/source/pip/src/qir_simulation/native_qir_parser.rs new file mode 100644 index 0000000000..bd7b4940b2 --- /dev/null +++ b/source/pip/src/qir_simulation/native_qir_parser.rs @@ -0,0 +1,566 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::fmt::Write; + +use pyo3::{ + Bound, IntoPyObject, PyResult, Python, + exceptions::PyValueError, + pyfunction, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyTuple}, +}; +use qsc_llvm::{ + model::Type, + model::{Attribute, Constant, Instruction, Module, Operand}, + parse_module, + qir::{find_entry_point, get_function_attribute, qis, rt}, +}; +use rustc_hash::FxHashMap; + +use super::QirInstructionId; + +/// Check if a function (by name) has `qdk_noise` in its attribute groups. 
+fn function_has_qdk_noise(module: &Module, func_name: &str) -> bool {
+    module
+        .functions
+        .iter()
+        .find(|f| f.name == func_name)
+        .is_some_and(|func| {
+            func.attribute_group_refs.iter().any(|&group_ref| {
+                module
+                    .attribute_groups
+                    .iter()
+                    .find(|ag| ag.id == group_ref)
+                    .is_some_and(|ag| {
+                        ag.attributes
+                            .iter()
+                            .any(|attr| matches!(attr, Attribute::StringAttr(s) if s.contains("qdk_noise")))
+                    })
+            })
+        })
+}
+
+/// Extract qubit/result ID from an `Operand`.
+/// In QIR, qubit and result references use `inttoptr` patterns.
+/// PyQIR also normalizes `inttoptr (i64 0 to %T*)` to `null`, so we
+/// handle `NullPtr` as ID 0.
+fn extract_id(operand: &Operand) -> PyResult<u32> {
+    match operand {
+        Operand::IntToPtr(val, _) => Ok(u32::try_from(*val).map_err(|_| {
+            PyValueError::new_err(format!("qubit/result ID {val} out of range for u32"))
+        })?),
+        Operand::NullPtr => Ok(0),
+        other => Err(PyValueError::new_err(format!(
+            "expected inttoptr operand for qubit/result ID, got {other:?}"
+        ))),
+    }
+}
+
+/// Extract a float value from an operand (for rotation gate angles).
+fn extract_float(operand: &Operand) -> PyResult<f64> {
+    match operand {
+        Operand::FloatConst(_, val) => Ok(*val),
+        other => Err(PyValueError::new_err(format!(
+            "expected float constant for rotation angle, got {other:?}"
+        ))),
+    }
+}
+
+/// Extract an integer value from an operand (for array/tuple record output count).
+fn extract_int(operand: &Operand) -> PyResult<i64> {
+    match operand {
+        Operand::IntConst(_, val) => Ok(*val),
+        other => Err(PyValueError::new_err(format!(
+            "expected integer constant, got {other:?}"
+        ))),
+    }
+}
+
+/// Look up a global variable's string initializer by name.
+fn lookup_global_string<'a>(module: &'a Module, name: &str) -> &'a str {
+    for global in &module.globals {
+        if global.name == name
+            && let Some(Constant::CString(s)) = &global.initializer
+        {
+            return s.as_str();
+        }
+    }
+    ""
+}
+
+/// Extract a tag string from an operand, which is typically a `GlobalRef`
+/// pointing to a global with a `CString` initializer.
+fn extract_tag(module: &Module, operand: &Operand) -> String {
+    match operand {
+        Operand::GlobalRef(name) => lookup_global_string(module, name).to_string(),
+        _ => String::new(),
+    }
+}
+
+/// Detect the QIR profile from text IR.
+///
+/// Parses the IR, finds the entry point function, and reads
+/// the "qir_profiles" attribute value.
+///
+/// Returns `base_profile`, `adaptive_profile`, or `unknown`.
+#[pyfunction]
+pub fn get_qir_profile(ir: &str) -> PyResult<String> {
+    let module =
+        parse_module(ir).map_err(|e| PyValueError::new_err(format!("failed to parse IR: {e}")))?;
+
+    let entry_idx = find_entry_point(&module)
+        .ok_or_else(|| PyValueError::new_err("no entry point function found in IR"))?;
+
+    let profile = get_function_attribute(&module, entry_idx, "qir_profiles").unwrap_or("unknown");
+
+    Ok(profile.to_string())
+}
+
+/// Parse Base Profile QIR and extract gate sequence, qubit/result counts,
+/// and output format string.
+///
+/// Returns a tuple: `(gates_list, num_qubits, num_results, output_format_str)`
+///
+/// The `gates_list` contains Python tuples matching the format produced by
+/// the `AggregateGatesPass` and `CorrelatedNoisePass` in the Python code.
+#[pyfunction]
+#[pyo3(signature = (ir, noise_intrinsics=None))]
+pub fn parse_base_profile_qir<'py>(
+    py: Python<'py>,
+    ir: &str,
+    noise_intrinsics: Option<&Bound<'py, PyDict>>,
+) -> PyResult<Bound<'py, PyTuple>> {
+    let module =
+        parse_module(ir).map_err(|e| PyValueError::new_err(format!("failed to parse IR: {e}")))?;
+
+    let entry_idx = find_entry_point(&module)
+        .ok_or_else(|| PyValueError::new_err("no entry point function found in IR"))?;
+
+    let num_qubits = get_function_attribute(&module, entry_idx, "required_num_qubits")
+        .ok_or_else(|| PyValueError::new_err("missing required_num_qubits attribute"))?
+        .parse::<usize>()
+        .map_err(|e| PyValueError::new_err(format!("invalid required_num_qubits: {e}")))?;
+
+    let num_results = get_function_attribute(&module, entry_idx, "required_num_results")
+        .ok_or_else(|| PyValueError::new_err("missing required_num_results attribute"))?
+        .parse::<usize>()
+        .map_err(|e| PyValueError::new_err(format!("invalid required_num_results: {e}")))?;
+
+    // Build noise intrinsics lookup: gate_name -> table_id
+    let noise_map: Option<FxHashMap<String, i64>> =
+        noise_intrinsics.map(|dict: &Bound<'_, PyDict>| {
+            let mut map = FxHashMap::default();
+            for (key, value) in dict.iter() {
+                if let (Ok(k), Ok(v)) = (key.extract::<String>(), value.extract::<i64>()) {
+                    map.insert(k, v);
+                }
+            }
+            map
+        });
+
+    let entry_func = &module.functions[entry_idx];
+    let gates = PyList::empty(py);
+    let mut output_str = String::new();
+    let mut closers: Vec<&str> = Vec::new();
+    let mut counters: Vec<i64> = Vec::new();
+
+    for block in &entry_func.basic_blocks {
+        // Check for branching control flow
+        if let Some(last_instr) = block.instructions.last() {
+            if matches!(last_instr, Instruction::Br { .. }) {
+                return Err(PyValueError::new_err(
+                    "simulation of programs with branching control flow is not supported",
+                ));
+            }
+        }
+
+        for instr in &block.instructions {
+            if let Instruction::Call { callee, args, ..
} = instr {
+                process_call_instruction(
+                    py,
+                    &module,
+                    callee,
+                    args,
+                    noise_map.as_ref(),
+                    &gates,
+                    &mut output_str,
+                    &mut closers,
+                    &mut counters,
+                )?;
+            }
+        }
+    }
+
+    // Close any remaining output format closers
+    while let Some(closer) = closers.pop() {
+        output_str.push_str(closer);
+        counters.pop();
+    }
+
+    let result = PyTuple::new(
+        py,
+        &[
+            gates.into_any(),
+            num_qubits.into_pyobject(py)?.into_any(),
+            num_results.into_pyobject(py)?.into_any(),
+            output_str.into_pyobject(py)?.into_any(),
+        ],
+    )?;
+
+    Ok(result)
+}
+
+/// Process a single call instruction, appending to the gate list and/or
+/// updating the output format string.
+#[allow(clippy::too_many_arguments)]
+fn process_call_instruction<'py>(
+    py: Python<'py>,
+    module: &Module,
+    callee: &str,
+    args: &[(Type, Operand)],
+    noise_map: Option<&FxHashMap<String, i64>>,
+    gates: &Bound<'py, PyList>,
+    output_str: &mut String,
+    closers: &mut Vec<&str>,
+    counters: &mut Vec<i64>,
+) -> PyResult<()> {
+    // Check noise intrinsics first
+    if let Some(map) = noise_map {
+        if let Some(&table_id) = map.get(callee) {
+            let qubit_ids = PyList::empty(py);
+            for (_, operand) in args {
+                qubit_ids.append(extract_id(operand)?)?;
+            }
+            let gate_tuple = PyTuple::new(
+                py,
+                &[
+                    QirInstructionId::CorrelatedNoise
+                        .into_pyobject(py)?
+                        .into_any(),
+                    table_id.into_pyobject(py)?.into_any(),
+                    qubit_ids.into_any(),
+                ],
+            )?;
+            gates.append(gate_tuple)?;
+            return Ok(());
+        }
+        // If running noisy sim and callee is a noise intrinsic but not in the table, error
+        if function_has_qdk_noise(module, callee) {
+            return Err(PyValueError::new_err(format!(
+                "Missing noise intrinsic: {callee}"
+            )));
+        }
+    }
+
+    if let Some(gate_tuple) = map_quantum_gate(py, callee, args)? {
+        gates.append(gate_tuple)?;
+    } else {
+        process_output_or_runtime_call(
+            py, module, callee, args, noise_map, gates, output_str, closers, counters,
+        )?;
+    }
+    Ok(())
+}
+
+/// Map a quantum gate callee name to its Python tuple representation.
+/// Returns `None` if the callee is not a recognized quantum gate. +fn map_quantum_gate<'py>( + py: Python<'py>, + callee: &str, + args: &[(Type, Operand)], +) -> PyResult>> { + let tuple = match callee { + qis::CCX => PyTuple::new( + py, + &[ + QirInstructionId::CCX.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[2].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::CX => PyTuple::new( + py, + &[ + QirInstructionId::CX.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::CY => PyTuple::new( + py, + &[ + QirInstructionId::CY.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::CZ => PyTuple::new( + py, + &[ + QirInstructionId::CZ.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::SWAP => PyTuple::new( + py, + &[ + QirInstructionId::SWAP.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::RX => PyTuple::new( + py, + &[ + QirInstructionId::RX.into_pyobject(py)?.into_any(), + extract_float(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::RXX => PyTuple::new( + py, + &[ + QirInstructionId::RXX.into_pyobject(py)?.into_any(), + extract_float(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[2].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::RY => PyTuple::new( + py, + &[ + QirInstructionId::RY.into_pyobject(py)?.into_any(), + 
extract_float(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::RYY => PyTuple::new( + py, + &[ + QirInstructionId::RYY.into_pyobject(py)?.into_any(), + extract_float(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[2].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::RZ => PyTuple::new( + py, + &[ + QirInstructionId::RZ.into_pyobject(py)?.into_any(), + extract_float(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::RZZ => PyTuple::new( + py, + &[ + QirInstructionId::RZZ.into_pyobject(py)?.into_any(), + extract_float(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[2].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::H => PyTuple::new( + py, + &[ + QirInstructionId::H.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::S => PyTuple::new( + py, + &[ + QirInstructionId::S.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::S_ADJ => PyTuple::new( + py, + &[ + QirInstructionId::SAdj.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::SX => PyTuple::new( + py, + &[ + QirInstructionId::SX.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::T => PyTuple::new( + py, + &[ + QirInstructionId::T.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::T_ADJ => PyTuple::new( + py, + &[ + QirInstructionId::TAdj.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::X => PyTuple::new( + py, + &[ + QirInstructionId::X.into_pyobject(py)?.into_any(), + 
extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::Y => PyTuple::new( + py, + &[ + QirInstructionId::Y.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::Z => PyTuple::new( + py, + &[ + QirInstructionId::Z.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::M | qis::MZ => { + let id = if callee == qis::M { + QirInstructionId::M + } else { + QirInstructionId::MZ + }; + PyTuple::new( + py, + &[ + id.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )? + } + qis::MRESETZ => PyTuple::new( + py, + &[ + QirInstructionId::MResetZ.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + extract_id(&args[1].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::RESET => PyTuple::new( + py, + &[ + QirInstructionId::RESET.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + qis::MOVE => PyTuple::new( + py, + &[ + QirInstructionId::Move.into_pyobject(py)?.into_any(), + extract_id(&args[0].1)?.into_pyobject(py)?.into_any(), + ], + )?, + _ => return Ok(None), + }; + Ok(Some(tuple)) +} + +/// Process output recording and runtime calls that are not quantum gates. +#[allow(clippy::too_many_arguments)] +fn process_output_or_runtime_call<'py>( + py: Python<'py>, + module: &Module, + callee: &str, + args: &[(Type, Operand)], + noise_map: Option<&FxHashMap>, + gates: &Bound<'py, PyList>, + output_str: &mut String, + closers: &mut Vec<&str>, + counters: &mut Vec, +) -> PyResult<()> { + match callee { + rt::RESULT_RECORD_OUTPUT => { + let result_id = extract_id(&args[0].1)?; + let tag = extract_tag(module, &args[1].1); + let gate_tuple = PyTuple::new( + py, + &[ + QirInstructionId::ResultRecordOutput + .into_pyobject(py)? 
+ .into_any(), + result_id.to_string().into_pyobject(py)?.into_any(), + tag.into_pyobject(py)?.into_any(), + ], + )?; + gates.append(gate_tuple)?; + + write!(output_str, "o[{result_id}]").expect("write to string should succeed"); + while !counters.is_empty() { + output_str.push(','); + let last = counters.last_mut().expect("counters should not be empty"); + *last -= 1; + if *last == 0 { + output_str.push_str(closers.pop().expect("closers should match counters")); + counters.pop(); + } else { + break; + } + } + } + rt::TUPLE_RECORD_OUTPUT => { + let count = extract_int(&args[0].1)?; + let tag = extract_tag(module, &args[1].1); + let gate_tuple = PyTuple::new( + py, + &[ + QirInstructionId::TupleRecordOutput + .into_pyobject(py)? + .into_any(), + count.to_string().into_pyobject(py)?.into_any(), + tag.into_pyobject(py)?.into_any(), + ], + )?; + gates.append(gate_tuple)?; + + // Output recording logic + output_str.push('('); + closers.push(")"); + counters.push(count); + } + rt::ARRAY_RECORD_OUTPUT => { + let count = extract_int(&args[0].1)?; + let tag = extract_tag(module, &args[1].1); + let gate_tuple = PyTuple::new( + py, + &[ + QirInstructionId::ArrayRecordOutput + .into_pyobject(py)? 
+ .into_any(), + count.to_string().into_pyobject(py)?.into_any(), + tag.into_pyobject(py)?.into_any(), + ], + )?; + gates.append(gate_tuple)?; + + // Output recording logic + output_str.push('['); + closers.push("]"); + counters.push(count); + } + rt::INITIALIZE | rt::BEGIN_PARALLEL | rt::END_PARALLEL | qis::BARRIER => { + // Skip runtime/barrier calls + } + _ => { + // For noiseless simulation, skip noise intrinsics silently + if noise_map.is_none() && function_has_qdk_noise(module, callee) { + return Ok(()); + } + return Err(PyValueError::new_err(format!( + "Unsupported call instruction: {callee}" + ))); + } + } + Ok(()) +} diff --git a/source/pip/tests-integration/devices/validation/__init__.py b/source/pip/tests-integration/devices/validation/__init__.py index 17793e75fc..d4637606c2 100644 --- a/source/pip/tests-integration/devices/validation/__init__.py +++ b/source/pip/tests-integration/devices/validation/__init__.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from pyqir import QirModuleVisitor, is_entry_point, qubit_id, required_num_qubits +try: + from pyqir import QirModuleVisitor, is_entry_point, qubit_id, required_num_qubits +except ImportError: + pass # PyQIR required only for neutral atom validation tests class ValidateBeginEndParallel(QirModuleVisitor): diff --git a/source/pip/tests-integration/test_requirements.txt b/source/pip/tests-integration/test_requirements.txt index e616b40b17..1277c70b9e 100644 --- a/source/pip/tests-integration/test_requirements.txt +++ b/source/pip/tests-integration/test_requirements.txt @@ -1,6 +1,5 @@ pytest==8.2.2 qirrunner==0.9.0 -pyqir<0.12.0 qiskit-aer==0.17.2 qiskit_qasm3_import==0.6.0 expecttest==0.3.0 diff --git a/source/pip/tests/test_adaptive_pass.py b/source/pip/tests/test_adaptive_pass.py index 8e1d0b89fa..6ffabc6322 100644 --- a/source/pip/tests/test_adaptive_pass.py +++ b/source/pip/tests/test_adaptive_pass.py @@ -1,30 +1,65 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Unit tests for AdaptiveProfilePass. +"""Unit tests for the Rust-native adaptive profile compiler. -Tests verify the Python QIR-to-bytecode translation pass by feeding -LLVM IR strings through the pass and checking the resulting -instruction dict encoding. +Tests verify the QIR-to-bytecode compilation by feeding +LLVM IR strings through compile_adaptive_program and checking +the resulting instruction dict encoding. """ -from dataclasses import astuple, asdict -import pyqir +from collections import namedtuple import pytest -from qsharp._adaptive_pass import AdaptiveProfilePass, AdaptiveProgram +from qsharp._native import compile_adaptive_program from qsharp._adaptive_bytecode import * +# --------------------------------------------------------------------------- +# Lightweight wrappers for named field access on the dict returned by +# compile_adaptive_program (which uses tuples for sub-structures). 
+# --------------------------------------------------------------------------- + +Block = namedtuple("Block", ["block_id", "instr_offset", "instr_count"]) +Instruction = namedtuple( + "Instruction", ["opcode", "dst", "src0", "src1", "aux0", "aux1", "aux2", "aux3"] +) +QuantumOp = namedtuple("QuantumOp", ["op_id", "q1", "q2", "q3", "angle"]) +Function = namedtuple("Function", ["func_entry_block", "num_params", "param_base"]) +PhiNodeEntry = namedtuple("PhiNodeEntry", ["block_id", "val_reg"]) +SwitchCase = namedtuple("SwitchCase", ["case_val", "target_block"]) + + +class AdaptiveProgram: + """Thin wrapper around the dict returned by compile_adaptive_program.""" + + def __init__(self, d: dict): + self._dict = d + self.num_qubits = d["num_qubits"] + self.num_results = d["num_results"] + self.num_registers = d["num_registers"] + self.entry_block = d["entry_block"] + self.blocks = [Block(*t) for t in d["blocks"]] + self.instructions = [Instruction(*t) for t in d["instructions"]] + self.quantum_ops = [QuantumOp(*t) for t in d["quantum_ops"]] + self.functions = [Function(*t) for t in d["functions"]] + self.phi_entries = [PhiNodeEntry(*t) for t in d["phi_entries"]] + self.switch_cases = [SwitchCase(*t) for t in d["switch_cases"]] + self.call_args = list(d["call_args"]) + self.labels = list(d["labels"]) + self.register_types = list(d["register_types"]) + + def as_dict(self): + return self._dict + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _run_pass(ir: str, name: str = "test.ll") -> AdaptiveProgram: - """Parse an LLVM IR string and run through AdaptiveProfilePass.""" - mod = pyqir.Module.from_ir(pyqir.Context(), ir, name) - return AdaptiveProfilePass().run(mod) +def _run_pass(ir: str) -> AdaptiveProgram: + """Compile an LLVM IR string via the native Rust adaptive pass.""" + return AdaptiveProgram(compile_adaptive_program(ir)) def 
_primary(opcode_word: int) -> int: @@ -734,8 +769,8 @@ def test_bell_loop_funcs_structure(): def test_bell_loop_funcs_function_entry(): """The make_bell function table entry has correct param count and entry block.""" r = _run_pass(BELL_LOOP_FUNCS_QIR) - func = r.functions[0] # (entry_block, num_params, param_base, reserved) - entry_block, num_params, param_base = astuple(func) + func = r.functions[0] # (entry_block, num_params, param_base) + entry_block, num_params, param_base = func assert num_params == 2, "make_bell takes 2 params (%Qubit*, %Qubit*)" # entry_block should be a valid block ID valid_block_ids = {b.block_id for b in r.blocks} @@ -800,7 +835,7 @@ def test_bell_loop_funcs_param_registers(): """make_bell params get allocated registers with PTR type tag.""" r = _run_pass(BELL_LOOP_FUNCS_QIR) func = r.functions[0] - _, num_params, param_base = astuple(func) + _, num_params, param_base = func for i in range(num_params): reg = param_base + i assert ( @@ -918,18 +953,14 @@ def test_instruction_tuple_length(): """All instructions are 8-tuples.""" r = _run_pass(LINEAR_QIR) for i, inst in enumerate(r.instructions): - assert ( - len(astuple(inst)) == 8 - ), f"Instruction {i} has {len(astuple(inst))} fields, expected 8" + assert len(inst) == 8, f"Instruction {i} has {len(inst)} fields, expected 8" def test_quantum_op_tuple_length(): """All quantum ops are 5-tuples.""" r = _run_pass(LINEAR_QIR) for i, qop in enumerate(r.quantum_ops): - assert ( - len(astuple(qop)) == 5 - ), f"Quantum op {i} has {len(astuple(qop))} fields, expected 5" + assert len(qop) == 5, f"Quantum op {i} has {len(qop)} fields, expected 5" ADAPTIVE_RIFLA_QIR = r""" diff --git a/source/pip/tests/test_clifford_simulator.py b/source/pip/tests/test_clifford_simulator.py index 2c47fc0c8e..af6eff745f 100644 --- a/source/pip/tests/test_clifford_simulator.py +++ b/source/pip/tests/test_clifford_simulator.py @@ -2,13 +2,14 @@ # Licensed under the MIT License. 
from pathlib import Path -import pyqir import qsharp from qsharp._simulation import run_qir_clifford, NoiseConfig +from qsharp._native import ( + validate_no_conditional_branches, + atom_decompose_rz_to_clifford, +) from qsharp._device._atom import NeutralAtomDevice -from qsharp._device._atom._decomp import DecomposeRzAnglesToCliffordGates -from qsharp._device._atom._validate import ValidateNoConditionalBranches from qsharp import TargetProfile, Result current_file_path = Path(__file__) @@ -20,10 +21,10 @@ def transform_to_clifford(input) -> str: native_qir = NeutralAtomDevice().compile(input) - module = pyqir.Module.from_ir(pyqir.Context(), str(native_qir)) - ValidateNoConditionalBranches().run(module) - DecomposeRzAnglesToCliffordGates().run(module) - return str(module) + ir = str(native_qir) + validate_no_conditional_branches(ir) + ir = atom_decompose_rz_to_clifford(ir) + return ir def read_file(file_name: str) -> str: diff --git a/source/qdk_package/pyproject.toml b/source/qdk_package/pyproject.toml index 958c4ad49f..69a24cae95 100644 --- a/source/qdk_package/pyproject.toml +++ b/source/qdk_package/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" authors = [ { name = "Microsoft" } ] license = { file = "LICENSE.txt" } requires-python = ">=3.10" -dependencies = ["qsharp==0.0.0", "pyqir<0.12"] +dependencies = ["qsharp==0.0.0"] [project.optional-dependencies] jupyter = ["qsharp-widgets==0.0.0", "qsharp-jupyterlab==0.0.0"]