Skip to content

Commit

Permalink
Remove validation methods from primitive base classes (backport #11052)…
Browse files Browse the repository at this point in the history
… (#11532)

* Remove validation methods from primitive base classes (#11052)

* Remove validation methods from primitive base classes

This deprecates the argument validation methods from primitive base classes and moves them to separate helper functions. These methods unnecessarily bloat the base classes, and are odd to have when the BasePrimitive doesn't even define a run method to validate. There is no reason primitive implementations need to use the same validation as these base classes either. A follow up will be to remove the validation from the base `run` methods and have subclasses implement their own validation.

* Apply suggestions from code review

* Update qiskit/primitives/base/base_estimator.py

---------

Co-authored-by: Ikko Hamamura <[email protected]>
(cherry picked from commit 05d958b)

* Update qiskit/primitives/base/base_estimator.py

* Add missing import

---------

Co-authored-by: Christopher J. Wood <[email protected]>
Co-authored-by: Matthew Treinish <[email protected]>
Co-authored-by: Jake Lishman <[email protected]>
4 people authored Feb 1, 2024
1 parent 3d3edd0 commit 6177feb
Showing 6 changed files with 273 additions and 145 deletions.
39 changes: 10 additions & 29 deletions qiskit/primitives/base/base_estimator.py
Original file line number Diff line number Diff line change
@@ -91,9 +91,10 @@
from qiskit.providers import JobV1 as Job
from qiskit.quantum_info.operators import SparsePauliOp
from qiskit.quantum_info.operators.base_operator import BaseOperator
from qiskit.utils.deprecation import deprecate_func

from ..utils import init_observable
from .base_primitive import BasePrimitive
from . import validation

if typing.TYPE_CHECKING:
from qiskit.opflow import PauliSumOp
@@ -175,18 +176,11 @@ def run(
TypeError: Invalid argument type given.
ValueError: Invalid argument values given.
"""
# Singular validation
circuits = self._validate_circuits(circuits)
observables = self._validate_observables(observables)
parameter_values = self._validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
# Validation
circuits, observables, parameter_values = validation._validate_estimator_args(
circuits, observables, parameter_values
)

# Cross-validation
self._cross_validate_circuits_parameter_values(circuits, parameter_values)
self._cross_validate_circuits_observables(circuits, observables)

# Options
run_opts = copy(self.options)
run_opts.update_options(**run_options)
@@ -206,34 +200,21 @@ def _run(
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> T:
raise NotImplementedError("The subclass of BaseEstimator must implment `_run` method.")
raise NotImplementedError("The subclass of BaseEstimator must implement `_run` method.")

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_observables(
observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str,
) -> tuple[SparsePauliOp, ...]:
if isinstance(observables, str) or not isinstance(observables, Sequence):
observables = (observables,)
if len(observables) == 0:
raise ValueError("No observables were provided.")
return tuple(init_observable(obs) for obs in observables)
return validation._validate_observables(observables)

@staticmethod
@deprecate_func(since="0.46.0")
def _cross_validate_circuits_observables(
circuits: tuple[QuantumCircuit, ...], observables: tuple[BaseOperator | PauliSumOp, ...]
) -> None:
if len(circuits) != len(observables):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of observables ({len(observables)})."
)
for i, (circuit, observable) in enumerate(zip(circuits, observables)):
if circuit.num_qubits != observable.num_qubits:
raise ValueError(
f"The number of qubits of the {i}-th circuit ({circuit.num_qubits}) does "
f"not match the number of qubits of the {i}-th observable "
f"({observable.num_qubits})."
)
return validation._cross_validate_circuits_observables(circuits, observables)

@property
def circuits(self) -> tuple[QuantumCircuit, ...]:
79 changes: 11 additions & 68 deletions qiskit/primitives/base/base_primitive.py
Original file line number Diff line number Diff line change
@@ -17,10 +17,11 @@
from abc import ABC
from collections.abc import Sequence

import numpy as np

from qiskit.circuit import QuantumCircuit
from qiskit.providers import Options
from qiskit.utils.deprecation import deprecate_func

from . import validation


class BasePrimitive(ABC):
@@ -49,83 +50,25 @@ def set_options(self, **fields):
self._run_options.update_options(**fields)

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_circuits(
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
) -> tuple[QuantumCircuit, ...]:
if isinstance(circuits, QuantumCircuit):
circuits = (circuits,)
elif not isinstance(circuits, Sequence) or not all(
isinstance(cir, QuantumCircuit) for cir in circuits
):
raise TypeError("Invalid circuits, expected Sequence[QuantumCircuit].")
elif not isinstance(circuits, tuple):
circuits = tuple(circuits)
if len(circuits) == 0:
raise ValueError("No circuits were provided.")
return circuits
return validation._validate_circuits(circuits)

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_parameter_values(
parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None,
default: Sequence[Sequence[float]] | Sequence[float] | None = None,
) -> tuple[tuple[float, ...], ...]:
# Allow optional (if default)
if parameter_values is None:
if default is None:
raise ValueError("No default `parameter_values`, optional input disallowed.")
parameter_values = default

# Support numpy ndarray
if isinstance(parameter_values, np.ndarray):
parameter_values = parameter_values.tolist()
elif isinstance(parameter_values, Sequence):
parameter_values = tuple(
vector.tolist() if isinstance(vector, np.ndarray) else vector
for vector in parameter_values
)

# Allow single value
if _isreal(parameter_values):
parameter_values = ((parameter_values,),)
elif isinstance(parameter_values, Sequence) and not any(
isinstance(vector, Sequence) for vector in parameter_values
):
parameter_values = (parameter_values,)

# Validation
if (
not isinstance(parameter_values, Sequence)
or not all(isinstance(vector, Sequence) for vector in parameter_values)
or not all(all(_isreal(value) for value in vector) for vector in parameter_values)
):
raise TypeError("Invalid parameter values, expected Sequence[Sequence[float]].")

return tuple(tuple(float(value) for value in vector) for vector in parameter_values)
return validation._validate_parameter_values(parameter_values, default=default)

@staticmethod
@deprecate_func(since="0.46.0")
def _cross_validate_circuits_parameter_values(
circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...]
) -> None:
if len(circuits) != len(parameter_values):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of parameter value sets ({len(parameter_values)})."
)
for i, (circuit, vector) in enumerate(zip(circuits, parameter_values)):
if len(vector) != circuit.num_parameters:
raise ValueError(
f"The number of values ({len(vector)}) does not match "
f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit."
)


def _isint(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is int."""
int_types = (int, np.integer)
return isinstance(obj, int_types) and not isinstance(obj, bool)


def _isreal(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is a real number: int or float except ``±Inf`` and ``NaN``."""
float_types = (float, np.floating)
return _isint(obj) or isinstance(obj, float_types) and float("-Inf") < obj < float("Inf")
return validation._cross_validate_circuits_parameter_values(
circuits, parameter_values=parameter_values
)
44 changes: 8 additions & 36 deletions qiskit/primitives/base/base_sampler.py
Original file line number Diff line number Diff line change
@@ -80,11 +80,13 @@
from copy import copy
from typing import Generic, TypeVar

from qiskit.circuit import ControlFlowOp, Measure, QuantumCircuit
from qiskit.utils.deprecation import deprecate_func
from qiskit.circuit import QuantumCircuit
from qiskit.circuit.parametertable import ParameterView
from qiskit.providers import JobV1 as Job

from .base_primitive import BasePrimitive
from . import validation

T = TypeVar("T", bound=Job)

@@ -130,15 +132,8 @@ def run(
Raises:
ValueError: Invalid arguments are given.
"""
# Singular validation
circuits = self._validate_circuits(circuits)
parameter_values = self._validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
)

# Cross-validation
self._cross_validate_circuits_parameter_values(circuits, parameter_values)
# Validation
circuits, parameter_values = validation._validate_sampler_args(circuits, parameter_values)

# Options
run_opts = copy(self.options)
@@ -157,27 +152,15 @@ def _run(
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> T:
raise NotImplementedError("The subclass of BaseSampler must implment `_run` method.")
raise NotImplementedError("The subclass of BaseSampler must implement `_run` method.")

@classmethod
@deprecate_func(since="0.46.0")
def _validate_circuits(
cls,
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
) -> tuple[QuantumCircuit, ...]:
circuits = super()._validate_circuits(circuits)
for i, circuit in enumerate(circuits):
if circuit.num_clbits == 0:
raise ValueError(
f"The {i}-th circuit does not have any classical bit. "
"Sampler requires classical bits, plus measurements "
"on the desired qubits."
)
if not _has_measure(circuit):
raise ValueError(
f"The {i}-th circuit does not have Measure instruction. "
"Without measurements, the circuit cannot be sampled from."
)
return circuits
return validation._validate_circuits(circuits, requires_measure=True)

@property
def circuits(self) -> tuple[QuantumCircuit, ...]:
@@ -196,14 +179,3 @@ def parameters(self) -> tuple[ParameterView, ...]:
List of the parameters in each quantum circuit.
"""
return tuple(self._parameters)


def _has_measure(circuit: QuantumCircuit) -> bool:
for instruction in reversed(circuit):
if isinstance(instruction.operation, Measure):
return True
elif isinstance(instruction.operation, ControlFlowOp):
for block in instruction.operation.blocks:
if _has_measure(block):
return True
return False
231 changes: 231 additions & 0 deletions qiskit/primitives/base/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Primitive validation methods.
Note that these are not intended to be part of the public API of base primitives
but are here for backwards compatibility with deprecated functions.
"""

from __future__ import annotations

from collections.abc import Sequence
import typing
import numpy as np

from qiskit.circuit import QuantumCircuit, ControlFlowOp, Measure
from qiskit.quantum_info.operators import SparsePauliOp
from qiskit.quantum_info.operators.base_operator import BaseOperator

from ..utils import init_observable

if typing.TYPE_CHECKING:
from qiskit.opflow import PauliSumOp


def _validate_estimator_args(
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str,
parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None = None,
) -> tuple[tuple[QuantumCircuit], tuple[BaseOperator], tuple[tuple[float]]]:
"""Validate run arguments for a reference Estimator.
Args:
circuits: one or more circuit objects.
observables: one or more observable objects.
parameter_values: concrete parameters to be bound.
Returns:
The formatted arguments ``(circuits, observables, parameter_values)``.
Raises:
TypeError: If input arguments are invalid types.
ValueError: if input arguments are invalid values.
"""
# Singular validation
circuits = _validate_circuits(circuits)
observables = _validate_observables(observables)
parameter_values = _validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
)

# Cross-validation
_cross_validate_circuits_parameter_values(circuits, parameter_values)
_cross_validate_circuits_observables(circuits, observables)

return circuits, observables, parameter_values


def _validate_sampler_args(
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None = None,
) -> tuple[tuple[QuantumCircuit], tuple[BaseOperator], tuple[tuple[float]]]:
"""Validate run arguments for a reference Sampler.
Args:
circuits: one or more circuit objects.
parameter_values: concrete parameters to be bound.
Returns:
The formatted arguments ``(circuits, parameter_values)``.
Raises:
TypeError: If input arguments are invalid types.
ValueError: if input arguments are invalid values.
"""
# Singular validation
circuits = _validate_circuits(circuits, requires_measure=True)
parameter_values = _validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
)

# Cross-validation
_cross_validate_circuits_parameter_values(circuits, parameter_values)

return circuits, parameter_values


def _validate_circuits(
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
requires_measure: bool = False,
) -> tuple[QuantumCircuit, ...]:
if isinstance(circuits, QuantumCircuit):
circuits = (circuits,)
elif not isinstance(circuits, Sequence) or not all(
isinstance(cir, QuantumCircuit) for cir in circuits
):
raise TypeError("Invalid circuits, expected Sequence[QuantumCircuit].")
elif not isinstance(circuits, tuple):
circuits = tuple(circuits)
if len(circuits) == 0:
raise ValueError("No circuits were provided.")

if requires_measure:
for i, circuit in enumerate(circuits):
if circuit.num_clbits == 0:
raise ValueError(
f"The {i}-th circuit does not have any classical bit. "
"Sampler requires classical bits, plus measurements "
"on the desired qubits."
)
if not _has_measure(circuit):
raise ValueError(
f"The {i}-th circuit does not have Measure instruction. "
"Without measurements, the circuit cannot be sampled from."
)
return circuits


def _validate_parameter_values(
parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None,
default: Sequence[Sequence[float]] | Sequence[float] | None = None,
) -> tuple[tuple[float, ...], ...]:
# Allow optional (if default)
if parameter_values is None:
if default is None:
raise ValueError("No default `parameter_values`, optional input disallowed.")
parameter_values = default

# Support numpy ndarray
if isinstance(parameter_values, np.ndarray):
parameter_values = parameter_values.tolist()
elif isinstance(parameter_values, Sequence):
parameter_values = tuple(
vector.tolist() if isinstance(vector, np.ndarray) else vector
for vector in parameter_values
)

# Allow single value
if _isreal(parameter_values):
parameter_values = ((parameter_values,),)
elif isinstance(parameter_values, Sequence) and not any(
isinstance(vector, Sequence) for vector in parameter_values
):
parameter_values = (parameter_values,)

# Validation
if (
not isinstance(parameter_values, Sequence)
or not all(isinstance(vector, Sequence) for vector in parameter_values)
or not all(all(_isreal(value) for value in vector) for vector in parameter_values)
):
raise TypeError("Invalid parameter values, expected Sequence[Sequence[float]].")

return tuple(tuple(float(value) for value in vector) for vector in parameter_values)


def _validate_observables(
observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str,
) -> tuple[SparsePauliOp, ...]:
if isinstance(observables, str) or not isinstance(observables, Sequence):
observables = (observables,)
if len(observables) == 0:
raise ValueError("No observables were provided.")
return tuple(init_observable(obs) for obs in observables)


def _cross_validate_circuits_parameter_values(
circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...]
) -> None:
if len(circuits) != len(parameter_values):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of parameter value sets ({len(parameter_values)})."
)
for i, (circuit, vector) in enumerate(zip(circuits, parameter_values)):
if len(vector) != circuit.num_parameters:
raise ValueError(
f"The number of values ({len(vector)}) does not match "
f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit."
)


def _cross_validate_circuits_observables(
circuits: tuple[QuantumCircuit, ...], observables: tuple[BaseOperator | PauliSumOp, ...]
) -> None:
if len(circuits) != len(observables):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of observables ({len(observables)})."
)
for i, (circuit, observable) in enumerate(zip(circuits, observables)):
if circuit.num_qubits != observable.num_qubits:
raise ValueError(
f"The number of qubits of the {i}-th circuit ({circuit.num_qubits}) does "
f"not match the number of qubits of the {i}-th observable "
f"({observable.num_qubits})."
)


def _isint(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is int."""
int_types = (int, np.integer)
return isinstance(obj, int_types) and not isinstance(obj, bool)


def _isreal(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is a real number: int or float except ``±Inf`` and ``NaN``."""
float_types = (float, np.floating)
return _isint(obj) or isinstance(obj, float_types) and float("-Inf") < obj < float("Inf")


def _has_measure(circuit: QuantumCircuit) -> bool:
for instruction in reversed(circuit):
if isinstance(instruction.operation, Measure):
return True
elif isinstance(instruction.operation, ControlFlowOp):
for block in instruction.operation.blocks:
if _has_measure(block):
return True
return False
9 changes: 5 additions & 4 deletions test/python/primitives/test_estimator.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,8 @@
from qiskit.circuit.library import RealAmplitudes
from qiskit.exceptions import QiskitError
from qiskit.opflow import PauliSumOp
from qiskit.primitives import BaseEstimator, Estimator, EstimatorResult
from qiskit.primitives import Estimator, EstimatorResult
from qiskit.primitives.base import validation, BaseEstimator
from qiskit.primitives.utils import _observable_key
from qiskit.providers import JobV1
from qiskit.quantum_info import Operator, Pauli, PauliList, SparsePauliOp
@@ -388,7 +389,7 @@ class TestObservableValidation(QiskitTestCase):
@unpack
def test_validate_observables(self, obsevables, expected):
"""Test obsevables standardization."""
self.assertEqual(BaseEstimator._validate_observables(obsevables), expected)
self.assertEqual(validation._validate_observables(obsevables), expected)

@data(
(PauliList("IXYZ"), (SparsePauliOp("IXYZ"),)),
@@ -407,13 +408,13 @@ def test_validate_observables_deprecated(self, obsevables, expected):
def test_qiskit_error(self, observables):
"""Test qiskit error if invalid input."""
with self.assertRaises(QiskitError):
BaseEstimator._validate_observables(observables)
validation._validate_observables(observables)

@data((), [])
def test_value_error(self, observables):
"""Test value error if no obsevables are provided."""
with self.assertRaises(ValueError):
BaseEstimator._validate_observables(observables)
validation._validate_observables(observables)


if __name__ == "__main__":
16 changes: 8 additions & 8 deletions test/python/primitives/test_primitive.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@

from qiskit import QuantumCircuit, pulse, transpile
from qiskit.circuit.random import random_circuit
from qiskit.primitives.base.base_primitive import BasePrimitive
from qiskit.primitives.base import validation
from qiskit.primitives.utils import _circuit_key
from qiskit.providers.fake_provider import FakeAlmaden
from qiskit.test import QiskitTestCase
@@ -39,19 +39,19 @@ class TestCircuitValidation(QiskitTestCase):
@unpack
def test_validate_circuits(self, circuits, expected):
"""Test circuits standardization."""
self.assertEqual(BasePrimitive._validate_circuits(circuits), expected)
self.assertEqual(validation._validate_circuits(circuits), expected)

@data(None, "ERROR", True, 0, 1.0, 1j, [0.0])
def test_type_error(self, circuits):
"""Test type error if invalid input."""
with self.assertRaises(TypeError):
BasePrimitive._validate_circuits(circuits)
validation._validate_circuits(circuits)

@data((), [], "")
def test_value_error(self, circuits):
"""Test value error if no circuits are provided."""
with self.assertRaises(ValueError):
BasePrimitive._validate_circuits(circuits)
validation._validate_circuits(circuits)


@ddt
@@ -87,9 +87,9 @@ class TestParameterValuesValidation(QiskitTestCase):
def test_validate_parameter_values(self, _parameter_values, expected):
"""Test parameter_values standardization."""
for parameter_values in [_parameter_values, array(_parameter_values)]: # Numpy
self.assertEqual(BasePrimitive._validate_parameter_values(parameter_values), expected)
self.assertEqual(validation._validate_parameter_values(parameter_values), expected)
self.assertEqual(
BasePrimitive._validate_parameter_values(None, default=parameter_values), expected
validation._validate_parameter_values(None, default=parameter_values), expected
)

@data(
@@ -108,12 +108,12 @@ def test_validate_parameter_values(self, _parameter_values, expected):
def test_type_error(self, parameter_values):
"""Test type error if invalid input."""
with self.assertRaises(TypeError):
BasePrimitive._validate_parameter_values(parameter_values)
validation._validate_parameter_values(parameter_values)

def test_value_error(self):
"""Test value error if no parameter_values or default are provided."""
with self.assertRaises(ValueError):
BasePrimitive._validate_parameter_values(None)
validation._validate_parameter_values(None)


class TestCircuitKey(QiskitTestCase):

0 comments on commit 6177feb

Please sign in to comment.