Skip to content

Commit

Permalink
[DAGCircuit Oxidation] Refactor bit management in CircuitData (#12372)
Browse files Browse the repository at this point in the history
* Checkpoint before rebase.

* Refactor bit management in CircuitData.

* Revert changes not for this PR.

* Add doc comment for InstructionPacker, rename Interner::lookup.

* Add CircuitData::map_* and refactor.

* Fix merge issue.

* CircuitInstruction::new returns Self.

* Use unbind.

* Make bit types pub.

* Fix merge.
  • Loading branch information
kevinhartman authored Jun 6, 2024
1 parent 1798c3f commit d1c8404
Show file tree
Hide file tree
Showing 7 changed files with 513 additions and 317 deletions.
192 changes: 192 additions & 0 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use crate::BitType;
use hashbrown::HashMap;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyList;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};

/// Private wrapper for Python-side Bit instances that implements
/// [Hash] and [Eq], allowing them to be used in Rust hash-based
/// sets and maps.
///
/// Python's `hash()` is called on the wrapped Bit instance during
/// construction and returned from Rust's [Hash] trait impl.
/// The impl of [PartialEq] first compares the native Py pointers
/// to determine equality. If these are not equal, only then does
/// it call `repr()` on both sides, which has a significant
/// performance advantage.
#[derive(Clone, Debug)]
struct BitAsKey {
/// Python's `hash()` of the wrapped instance.
hash: isize,
/// The wrapped instance.
bit: PyObject,
}

impl BitAsKey {
pub fn new(bit: &Bound<PyAny>) -> Self {
BitAsKey {
// This really shouldn't fail, but if it does,
// we'll just use 0.
hash: bit.hash().unwrap_or(0),
bit: bit.clone().unbind(),
}
}
}

impl Hash for BitAsKey {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write_isize(self.hash);
}
}

impl PartialEq for BitAsKey {
fn eq(&self, other: &Self) -> bool {
self.bit.is(&other.bit)
|| Python::with_gil(|py| {
self.bit
.bind(py)
.repr()
.unwrap()
.eq(other.bit.bind(py).repr().unwrap())
.unwrap()
})
}
}

impl Eq for BitAsKey {}

#[derive(Clone, Debug)]
pub(crate) struct BitData<T> {
/// The public field name (i.e. `qubits` or `clbits`).
description: String,
/// Registered Python bits.
bits: Vec<PyObject>,
/// Maps Python bits to native type.
indices: HashMap<BitAsKey, T>,
/// The bits registered, cached as a PyList.
cached: Py<PyList>,
}

pub(crate) struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);

impl<T> BitData<T>
where
T: From<BitType> + Copy,
BitType: From<T>,
{
pub fn new(py: Python<'_>, description: String) -> Self {
BitData {
description,
bits: Vec::new(),
indices: HashMap::new(),
cached: PyList::empty_bound(py).unbind(),
}
}

/// Gets the number of bits.
pub fn len(&self) -> usize {
self.bits.len()
}

/// Gets a reference to the underlying vector of Python bits.
#[inline]
pub fn bits(&self) -> &Vec<PyObject> {
&self.bits
}

/// Gets a reference to the cached Python list, maintained by
/// this instance.
#[inline]
pub fn cached(&self) -> &Py<PyList> {
&self.cached
}

/// Finds the native bit index of the given Python bit.
#[inline]
pub fn find(&self, bit: &Bound<PyAny>) -> Option<T> {
self.indices.get(&BitAsKey::new(bit)).copied()
}

/// Map the provided Python bits to their native indices.
/// An error is returned if any bit is not registered.
pub fn map_bits<'py>(
&self,
bits: impl IntoIterator<Item = Bound<'py, PyAny>>,
) -> Result<impl Iterator<Item = T>, BitNotFoundError<'py>> {
let v: Result<Vec<_>, _> = bits
.into_iter()
.map(|b| {
self.indices
.get(&BitAsKey::new(&b))
.copied()
.ok_or_else(|| BitNotFoundError(b))
})
.collect();
v.map(|x| x.into_iter())
}

/// Map the provided native indices to the corresponding Python
/// bit instances.
/// Panics if any of the indices are out of range.
pub fn map_indices(&self, bits: &[T]) -> impl Iterator<Item = &Py<PyAny>> + ExactSizeIterator {
let v: Vec<_> = bits.iter().map(|i| self.get(*i).unwrap()).collect();
v.into_iter()
}

/// Gets the Python bit corresponding to the given native
/// bit index.
#[inline]
pub fn get(&self, index: T) -> Option<&PyObject> {
self.bits.get(<BitType as From<T>>::from(index) as usize)
}

/// Adds a new Python bit.
pub fn add(&mut self, py: Python, bit: &Bound<PyAny>, strict: bool) -> PyResult<()> {
if self.bits.len() != self.cached.bind(bit.py()).len() {
return Err(PyRuntimeError::new_err(
format!("This circuit's {} list has become out of sync with the circuit data. Did something modify it?", self.description)
));
}
let idx: BitType = self.bits.len().try_into().map_err(|_| {
PyRuntimeError::new_err(format!(
"The number of {} in the circuit has exceeded the maximum capacity",
self.description
))
})?;
if self
.indices
.try_insert(BitAsKey::new(bit), idx.into())
.is_ok()
{
self.bits.push(bit.into_py(py));
self.cached.bind(py).append(bit)?;
} else if strict {
return Err(PyValueError::new_err(format!(
"Existing bit {:?} cannot be re-added in strict mode.",
bit
)));
}
Ok(())
}

/// Called during Python garbage collection, only!.
/// Note: INVALIDATES THIS INSTANCE.
pub fn dispose(&mut self) {
self.indices.clear();
self.bits.clear();
}
}
Loading

0 comments on commit d1c8404

Please sign in to comment.