diff --git a/qiskit/primitives/base/base_estimator.py b/qiskit/primitives/base/base_estimator.py index a980e08b8dad..935813a33c43 100644 --- a/qiskit/primitives/base/base_estimator.py +++ b/qiskit/primitives/base/base_estimator.py @@ -91,9 +91,10 @@ from qiskit.providers import JobV1 as Job from qiskit.quantum_info.operators import SparsePauliOp from qiskit.quantum_info.operators.base_operator import BaseOperator +from qiskit.utils.deprecation import deprecate_func -from ..utils import init_observable from .base_primitive import BasePrimitive +from . import validation if typing.TYPE_CHECKING: from qiskit.opflow import PauliSumOp @@ -175,18 +176,11 @@ def run( TypeError: Invalid argument type given. ValueError: Invalid argument values given. """ - # Singular validation - circuits = self._validate_circuits(circuits) - observables = self._validate_observables(observables) - parameter_values = self._validate_parameter_values( - parameter_values, - default=[()] * len(circuits), + # Validation + circuits, observables, parameter_values = validation._validate_estimator_args( + circuits, observables, parameter_values ) - # Cross-validation - self._cross_validate_circuits_parameter_values(circuits, parameter_values) - self._cross_validate_circuits_observables(circuits, observables) - # Options run_opts = copy(self.options) run_opts.update_options(**run_options) @@ -206,34 +200,21 @@ def _run( parameter_values: tuple[tuple[float, ...], ...], **run_options, ) -> T: - raise NotImplementedError("The subclass of BaseEstimator must implment `_run` method.") + raise NotImplementedError("The subclass of BaseEstimator must implement `_run` method.") @staticmethod + @deprecate_func(since="0.46.0") def _validate_observables( observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str, ) -> tuple[SparsePauliOp, ...]: - if isinstance(observables, str) or not isinstance(observables, Sequence): - observables = (observables,) - if len(observables) == 0: - raise ValueError("No observables were provided.") - return tuple(init_observable(obs) for obs in observables) + return validation._validate_observables(observables) @staticmethod + @deprecate_func(since="0.46.0") def _cross_validate_circuits_observables( circuits: tuple[QuantumCircuit, ...], observables: tuple[BaseOperator | PauliSumOp, ...] ) -> None: - if len(circuits) != len(observables): - raise ValueError( - f"The number of circuits ({len(circuits)}) does not match " - f"the number of observables ({len(observables)})." - ) - for i, (circuit, observable) in enumerate(zip(circuits, observables)): - if circuit.num_qubits != observable.num_qubits: - raise ValueError( - f"The number of qubits of the {i}-th circuit ({circuit.num_qubits}) does " - f"not match the number of qubits of the {i}-th observable " - f"({observable.num_qubits})." - ) + return validation._cross_validate_circuits_observables(circuits, observables) @property def circuits(self) -> tuple[QuantumCircuit, ...]: diff --git a/qiskit/primitives/base/base_primitive.py b/qiskit/primitives/base/base_primitive.py index a2f8dacdb828..c161ca8094fa 100644 --- a/qiskit/primitives/base/base_primitive.py +++ b/qiskit/primitives/base/base_primitive.py @@ -17,10 +17,11 @@ from abc import ABC from collections.abc import Sequence -import numpy as np - from qiskit.circuit import QuantumCircuit from qiskit.providers import Options +from qiskit.utils.deprecation import deprecate_func + +from . import validation class BasePrimitive(ABC): @@ -49,83 +50,25 @@ def set_options(self, **fields): self._run_options.update_options(**fields) @staticmethod + @deprecate_func(since="0.46.0") def _validate_circuits( circuits: Sequence[QuantumCircuit] | QuantumCircuit, ) -> tuple[QuantumCircuit, ...]: - if isinstance(circuits, QuantumCircuit): - circuits = (circuits,) - elif not isinstance(circuits, Sequence) or not all( - isinstance(cir, QuantumCircuit) for cir in circuits - ): - raise TypeError("Invalid circuits, expected Sequence[QuantumCircuit].") - elif not isinstance(circuits, tuple): - circuits = tuple(circuits) - if len(circuits) == 0: - raise ValueError("No circuits were provided.") - return circuits + return validation._validate_circuits(circuits) @staticmethod + @deprecate_func(since="0.46.0") def _validate_parameter_values( parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None, default: Sequence[Sequence[float]] | Sequence[float] | None = None, ) -> tuple[tuple[float, ...], ...]: - # Allow optional (if default) - if parameter_values is None: - if default is None: - raise ValueError("No default `parameter_values`, optional input disallowed.") - parameter_values = default - - # Support numpy ndarray - if isinstance(parameter_values, np.ndarray): - parameter_values = parameter_values.tolist() - elif isinstance(parameter_values, Sequence): - parameter_values = tuple( - vector.tolist() if isinstance(vector, np.ndarray) else vector - for vector in parameter_values - ) - - # Allow single value - if _isreal(parameter_values): - parameter_values = ((parameter_values,),) - elif isinstance(parameter_values, Sequence) and not any( - isinstance(vector, Sequence) for vector in parameter_values - ): - parameter_values = (parameter_values,) - - # Validation - if ( - not isinstance(parameter_values, Sequence) - or not all(isinstance(vector, Sequence) for vector in parameter_values) - or not all(all(_isreal(value) for value in vector) for vector in parameter_values) - ): - raise TypeError("Invalid parameter values, expected Sequence[Sequence[float]].") - - return tuple(tuple(float(value) for value in vector) for vector in parameter_values) + return validation._validate_parameter_values(parameter_values, default=default) @staticmethod + @deprecate_func(since="0.46.0") def _cross_validate_circuits_parameter_values( circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...] ) -> None: - if len(circuits) != len(parameter_values): - raise ValueError( - f"The number of circuits ({len(circuits)}) does not match " - f"the number of parameter value sets ({len(parameter_values)})." - ) - for i, (circuit, vector) in enumerate(zip(circuits, parameter_values)): - if len(vector) != circuit.num_parameters: - raise ValueError( - f"The number of values ({len(vector)}) does not match " - f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit." - ) - - -def _isint(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool: - """Check if object is int.""" - int_types = (int, np.integer) - return isinstance(obj, int_types) and not isinstance(obj, bool) - - -def _isreal(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool: - """Check if object is a real number: int or float except ``±Inf`` and ``NaN``.""" - float_types = (float, np.floating) - return _isint(obj) or isinstance(obj, float_types) and float("-Inf") < obj < float("Inf") + return validation._cross_validate_circuits_parameter_values( + circuits, parameter_values=parameter_values + ) diff --git a/qiskit/primitives/base/base_sampler.py b/qiskit/primitives/base/base_sampler.py index fb6cf557ca59..e3f1ae7ba08f 100644 --- a/qiskit/primitives/base/base_sampler.py +++ b/qiskit/primitives/base/base_sampler.py @@ -80,11 +80,13 @@ from copy import copy from typing import Generic, TypeVar -from qiskit.circuit import ControlFlowOp, Measure, QuantumCircuit +from qiskit.utils.deprecation import deprecate_func +from qiskit.circuit import QuantumCircuit from qiskit.circuit.parametertable import ParameterView from qiskit.providers import JobV1 as Job from .base_primitive import BasePrimitive +from . import validation T = TypeVar("T", bound=Job) @@ -130,15 +132,8 @@ def run( Raises: ValueError: Invalid arguments are given. """ - # Singular validation - circuits = self._validate_circuits(circuits) - parameter_values = self._validate_parameter_values( - parameter_values, - default=[()] * len(circuits), - ) - - # Cross-validation - self._cross_validate_circuits_parameter_values(circuits, parameter_values) + # Validation + circuits, parameter_values = validation._validate_sampler_args(circuits, parameter_values) # Options run_opts = copy(self.options) @@ -157,27 +152,15 @@ def _run( parameter_values: tuple[tuple[float, ...], ...], **run_options, ) -> T: - raise NotImplementedError("The subclass of BaseSampler must implment `_run` method.") + raise NotImplementedError("The subclass of BaseSampler must implement `_run` method.") @classmethod + @deprecate_func(since="0.46.0") def _validate_circuits( cls, circuits: Sequence[QuantumCircuit] | QuantumCircuit, ) -> tuple[QuantumCircuit, ...]: - circuits = super()._validate_circuits(circuits) - for i, circuit in enumerate(circuits): - if circuit.num_clbits == 0: - raise ValueError( - f"The {i}-th circuit does not have any classical bit. " - "Sampler requires classical bits, plus measurements " - "on the desired qubits." - ) - if not _has_measure(circuit): - raise ValueError( - f"The {i}-th circuit does not have Measure instruction. " - "Without measurements, the circuit cannot be sampled from." - ) - return circuits + return validation._validate_circuits(circuits, requires_measure=True) @property def circuits(self) -> tuple[QuantumCircuit, ...]: @@ -196,14 +179,3 @@ def parameters(self) -> tuple[ParameterView, ...]: List of the parameters in each quantum circuit. """ return tuple(self._parameters) - - -def _has_measure(circuit: QuantumCircuit) -> bool: - for instruction in reversed(circuit): - if isinstance(instruction.operation, Measure): - return True - elif isinstance(instruction.operation, ControlFlowOp): - for block in instruction.operation.blocks: - if _has_measure(block): - return True - return False diff --git a/qiskit/primitives/base/validation.py b/qiskit/primitives/base/validation.py new file mode 100644 index 000000000000..7e1bccf2f89c --- /dev/null +++ b/qiskit/primitives/base/validation.py @@ -0,0 +1,231 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Primitive validation methods. + +Note that these are not intended to be part of the public API of base primitives +but are here for backwards compatibility with deprecated functions. +""" + +from __future__ import annotations + +from collections.abc import Sequence +import typing +import numpy as np + +from qiskit.circuit import QuantumCircuit, ControlFlowOp, Measure +from qiskit.quantum_info.operators import SparsePauliOp +from qiskit.quantum_info.operators.base_operator import BaseOperator + +from ..utils import init_observable + +if typing.TYPE_CHECKING: + from qiskit.opflow import PauliSumOp + + +def _validate_estimator_args( + circuits: Sequence[QuantumCircuit] | QuantumCircuit, + observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str, + parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None = None, +) -> tuple[tuple[QuantumCircuit], tuple[BaseOperator], tuple[tuple[float]]]: + """Validate run arguments for a reference Estimator. + + Args: + circuits: one or more circuit objects. + observables: one or more observable objects. + parameter_values: concrete parameters to be bound. + + Returns: + The formatted arguments ``(circuits, observables, parameter_values)``. + + Raises: + TypeError: If input arguments are invalid types. + ValueError: if input arguments are invalid values. + """ + # Singular validation + circuits = _validate_circuits(circuits) + observables = _validate_observables(observables) + parameter_values = _validate_parameter_values( + parameter_values, + default=[()] * len(circuits), + ) + + # Cross-validation + _cross_validate_circuits_parameter_values(circuits, parameter_values) + _cross_validate_circuits_observables(circuits, observables) + + return circuits, observables, parameter_values + + +def _validate_sampler_args( + circuits: Sequence[QuantumCircuit] | QuantumCircuit, + parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None = None, +) -> tuple[tuple[QuantumCircuit], tuple[BaseOperator], tuple[tuple[float]]]: + """Validate run arguments for a reference Sampler. + + Args: + circuits: one or more circuit objects. + parameter_values: concrete parameters to be bound. + + Returns: + The formatted arguments ``(circuits, parameter_values)``. + + Raises: + TypeError: If input arguments are invalid types. + ValueError: if input arguments are invalid values. + """ + # Singular validation + circuits = _validate_circuits(circuits, requires_measure=True) + parameter_values = _validate_parameter_values( + parameter_values, + default=[()] * len(circuits), + ) + + # Cross-validation + _cross_validate_circuits_parameter_values(circuits, parameter_values) + + return circuits, parameter_values + + +def _validate_circuits( + circuits: Sequence[QuantumCircuit] | QuantumCircuit, + requires_measure: bool = False, +) -> tuple[QuantumCircuit, ...]: + if isinstance(circuits, QuantumCircuit): + circuits = (circuits,) + elif not isinstance(circuits, Sequence) or not all( + isinstance(cir, QuantumCircuit) for cir in circuits + ): + raise TypeError("Invalid circuits, expected Sequence[QuantumCircuit].") + elif not isinstance(circuits, tuple): + circuits = tuple(circuits) + if len(circuits) == 0: + raise ValueError("No circuits were provided.") + + if requires_measure: + for i, circuit in enumerate(circuits): + if circuit.num_clbits == 0: + raise ValueError( + f"The {i}-th circuit does not have any classical bit. " + "Sampler requires classical bits, plus measurements " + "on the desired qubits." + ) + if not _has_measure(circuit): + raise ValueError( + f"The {i}-th circuit does not have Measure instruction. " + "Without measurements, the circuit cannot be sampled from." + ) + return circuits + + +def _validate_parameter_values( + parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None, + default: Sequence[Sequence[float]] | Sequence[float] | None = None, +) -> tuple[tuple[float, ...], ...]: + # Allow optional (if default) + if parameter_values is None: + if default is None: + raise ValueError("No default `parameter_values`, optional input disallowed.") + parameter_values = default + + # Support numpy ndarray + if isinstance(parameter_values, np.ndarray): + parameter_values = parameter_values.tolist() + elif isinstance(parameter_values, Sequence): + parameter_values = tuple( + vector.tolist() if isinstance(vector, np.ndarray) else vector + for vector in parameter_values + ) + + # Allow single value + if _isreal(parameter_values): + parameter_values = ((parameter_values,),) + elif isinstance(parameter_values, Sequence) and not any( + isinstance(vector, Sequence) for vector in parameter_values + ): + parameter_values = (parameter_values,) + + # Validation + if ( + not isinstance(parameter_values, Sequence) + or not all(isinstance(vector, Sequence) for vector in parameter_values) + or not all(all(_isreal(value) for value in vector) for vector in parameter_values) + ): + raise TypeError("Invalid parameter values, expected Sequence[Sequence[float]].") + + return tuple(tuple(float(value) for value in vector) for vector in parameter_values) + + +def _validate_observables( + observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str, +) -> tuple[SparsePauliOp, ...]: + if isinstance(observables, str) or not isinstance(observables, Sequence): + observables = (observables,) + if len(observables) == 0: + raise ValueError("No observables were provided.") + return tuple(init_observable(obs) for obs in observables) + + +def _cross_validate_circuits_parameter_values( + circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...] +) -> None: + if len(circuits) != len(parameter_values): + raise ValueError( + f"The number of circuits ({len(circuits)}) does not match " + f"the number of parameter value sets ({len(parameter_values)})." + ) + for i, (circuit, vector) in enumerate(zip(circuits, parameter_values)): + if len(vector) != circuit.num_parameters: + raise ValueError( + f"The number of values ({len(vector)}) does not match " + f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit." + ) + + +def _cross_validate_circuits_observables( + circuits: tuple[QuantumCircuit, ...], observables: tuple[BaseOperator | PauliSumOp, ...] +) -> None: + if len(circuits) != len(observables): + raise ValueError( + f"The number of circuits ({len(circuits)}) does not match " + f"the number of observables ({len(observables)})." + ) + for i, (circuit, observable) in enumerate(zip(circuits, observables)): + if circuit.num_qubits != observable.num_qubits: + raise ValueError( + f"The number of qubits of the {i}-th circuit ({circuit.num_qubits}) does " + f"not match the number of qubits of the {i}-th observable " + f"({observable.num_qubits})." + ) + + +def _isint(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool: + """Check if object is int.""" + int_types = (int, np.integer) + return isinstance(obj, int_types) and not isinstance(obj, bool) + + +def _isreal(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool: + """Check if object is a real number: int or float except ``±Inf`` and ``NaN``.""" + float_types = (float, np.floating) + return _isint(obj) or isinstance(obj, float_types) and float("-Inf") < obj < float("Inf") + + +def _has_measure(circuit: QuantumCircuit) -> bool: + for instruction in reversed(circuit): + if isinstance(instruction.operation, Measure): + return True + elif isinstance(instruction.operation, ControlFlowOp): + for block in instruction.operation.blocks: + if _has_measure(block): + return True + return False diff --git a/test/python/primitives/test_estimator.py b/test/python/primitives/test_estimator.py index 1fc6e84d8dfb..f438135230bc 100644 --- a/test/python/primitives/test_estimator.py +++ b/test/python/primitives/test_estimator.py @@ -21,7 +21,8 @@ from qiskit.circuit.library import RealAmplitudes from qiskit.exceptions import QiskitError from qiskit.opflow import PauliSumOp -from qiskit.primitives import BaseEstimator, Estimator, EstimatorResult +from qiskit.primitives import Estimator, EstimatorResult +from qiskit.primitives.base import validation, BaseEstimator from qiskit.primitives.utils import _observable_key from qiskit.providers import JobV1 from qiskit.quantum_info import Operator, Pauli, PauliList, SparsePauliOp @@ -388,7 +389,7 @@ class TestObservableValidation(QiskitTestCase): @unpack def test_validate_observables(self, obsevables, expected): """Test obsevables standardization.""" - self.assertEqual(BaseEstimator._validate_observables(obsevables), expected) + self.assertEqual(validation._validate_observables(obsevables), expected) @data( (PauliList("IXYZ"), (SparsePauliOp("IXYZ"),)), @@ -407,13 +408,13 @@ def test_validate_observables_deprecated(self, obsevables, expected): def test_qiskit_error(self, observables): """Test qiskit error if invalid input.""" with self.assertRaises(QiskitError): - BaseEstimator._validate_observables(observables) + validation._validate_observables(observables) @data((), []) def test_value_error(self, observables): """Test value error if no obsevables are provided.""" with self.assertRaises(ValueError): - BaseEstimator._validate_observables(observables) + validation._validate_observables(observables) if __name__ == "__main__": diff --git a/test/python/primitives/test_primitive.py b/test/python/primitives/test_primitive.py index cc60a17abc7e..5ae93df82a51 100644 --- a/test/python/primitives/test_primitive.py +++ b/test/python/primitives/test_primitive.py @@ -19,7 +19,7 @@ from qiskit import QuantumCircuit, pulse, transpile from qiskit.circuit.random import random_circuit -from qiskit.primitives.base.base_primitive import BasePrimitive +from qiskit.primitives.base import validation from qiskit.primitives.utils import _circuit_key from qiskit.providers.fake_provider import FakeAlmaden from qiskit.test import QiskitTestCase @@ -39,19 +39,19 @@ class TestCircuitValidation(QiskitTestCase): @unpack def test_validate_circuits(self, circuits, expected): """Test circuits standardization.""" - self.assertEqual(BasePrimitive._validate_circuits(circuits), expected) + self.assertEqual(validation._validate_circuits(circuits), expected) @data(None, "ERROR", True, 0, 1.0, 1j, [0.0]) def test_type_error(self, circuits): """Test type error if invalid input.""" with self.assertRaises(TypeError): - BasePrimitive._validate_circuits(circuits) + validation._validate_circuits(circuits) @data((), [], "") def test_value_error(self, circuits): """Test value error if no circuits are provided.""" with self.assertRaises(ValueError): - BasePrimitive._validate_circuits(circuits) + validation._validate_circuits(circuits) @ddt @@ -87,9 +87,9 @@ class TestParameterValuesValidation(QiskitTestCase): def test_validate_parameter_values(self, _parameter_values, expected): """Test parameter_values standardization.""" for parameter_values in [_parameter_values, array(_parameter_values)]: # Numpy - self.assertEqual(BasePrimitive._validate_parameter_values(parameter_values), expected) + self.assertEqual(validation._validate_parameter_values(parameter_values), expected) self.assertEqual( - BasePrimitive._validate_parameter_values(None, default=parameter_values), expected + validation._validate_parameter_values(None, default=parameter_values), expected ) @data( @@ -108,12 +108,12 @@ def test_validate_parameter_values(self, _parameter_values, expected): def test_type_error(self, parameter_values): """Test type error if invalid input.""" with self.assertRaises(TypeError): - BasePrimitive._validate_parameter_values(parameter_values) + validation._validate_parameter_values(parameter_values) def test_value_error(self): """Test value error if no parameter_values or default are provided.""" with self.assertRaises(ValueError): - BasePrimitive._validate_parameter_values(None) + validation._validate_parameter_values(None) class TestCircuitKey(QiskitTestCase):