diff --git a/source/pip/qsharp/__init__.py b/source/pip/qsharp/__init__.py index a6161f518c..b66097bc2d 100644 --- a/source/pip/qsharp/__init__.py +++ b/source/pip/qsharp/__init__.py @@ -21,8 +21,15 @@ BitFlipNoise, PhaseFlipNoise, CircuitGenerationMethod, + QdkContext, + new_context, + get_context, + context_of, ) +# Backward-compatible alias +QSharpContext = QdkContext + telemetry_events.on_import() from ._native import Result, Pauli, QSharpError, TargetProfile, estimate_custom @@ -62,4 +69,8 @@ "BitFlipNoise", "PhaseFlipNoise", "CircuitGenerationMethod", + "QdkContext", + "new_context", + "get_context", + "context_of", ] diff --git a/source/pip/qsharp/_ipython.py b/source/pip/qsharp/_ipython.py index c010befe72..67107ad02e 100644 --- a/source/pip/qsharp/_ipython.py +++ b/source/pip/qsharp/_ipython.py @@ -12,7 +12,7 @@ from IPython.display import display, Javascript, clear_output from IPython.core.magic import register_cell_magic from ._native import QSharpError -from ._qsharp import get_interpreter, qsharp_value_to_python_value +from ._qsharp import _get_default_ctx, qsharp_value_to_python_value from . import telemetry_events import pathlib @@ -37,7 +37,7 @@ def callback(output): try: results = qsharp_value_to_python_value( - get_interpreter().interpret(cell, callback) + _get_default_ctx()._interpreter.interpret(cell, callback) ) durationMs = (monotonic() - start_time) * 1000 diff --git a/source/pip/qsharp/_qsharp.py b/source/pip/qsharp/_qsharp.py index 57b122f279..f3309a9b1f 100644 --- a/source/pip/qsharp/_qsharp.py +++ b/source/pip/qsharp/_qsharp.py @@ -115,8 +115,7 @@ def python_args_to_interpreter_args(args): return lower_python_obj(args) -_interpreter: Union["Interpreter", None] = None -_config: Union["Config", None] = None +_default_ctx: Union["QdkContext", None] = None # Check if we are running in a Jupyter notebook to use the IPython display function _in_jupyter = False @@ -195,78 +194,21 @@ def get_target_profile(self) -> str: return self._config.get("targetProfile", "unspecified") -class PauliNoise(Tuple[float, float, float]): - """ - The Pauli noise to use in simulation represented - as probabilities of Pauli-X, Pauli-Y, and Pauli-Z errors - """ - - def __new__(cls, x: float, y: float, z: float): - if x < 0 or y < 0 or z < 0: - raise ValueError("Pauli noise probabilities must be non-negative.") - if x + y + z > 1: - raise ValueError("The sum of Pauli noise probabilities must be at most 1.") - return super().__new__(cls, (x, y, z)) - - -class DepolarizingNoise(PauliNoise): - """ - The depolarizing noise to use in simulation. - """ - - def __new__(cls, p: float): - return super().__new__(cls, p / 3, p / 3, p / 3) - - -class BitFlipNoise(PauliNoise): - """ - The bit flip noise to use in simulation. - """ - - def __new__(cls, p: float): - return super().__new__(cls, p, 0, 0) - - -class PhaseFlipNoise(PauliNoise): +def _create_interpreter( + target_profile: TargetProfile, + language_features: Optional[List[str]], + project_root: Optional[str], + target_name: Optional[str], + trace_circuit: Optional[bool], + make_callable_fn, + make_class_fn, +) -> Tuple["Interpreter", "Config"]: """ - The phase flip noise to use in simulation. - """ - - def __new__(cls, p: float): - return super().__new__(cls, 0, 0, p) - - -def init( - *, - target_profile: TargetProfile = TargetProfile.Unrestricted, - target_name: Optional[str] = None, - project_root: Optional[str] = None, - language_features: Optional[List[str]] = None, - trace_circuit: Optional[bool] = None, -) -> Config: - """ - Initializes the Q# interpreter. - - :param target_profile: Setting the target profile allows the Q# - interpreter to generate programs that are compatible - with a specific target. See :py:class: `qsharp.TargetProfile`. - - :param target_name: An optional name of the target machine to use for inferring the compatible - target_profile setting. - - :param project_root: An optional path to a root directory with a Q# project to include. - It must contain a qsharp.json project manifest. - - :param trace_circuit: Enables tracing of circuit during execution. - Passing `True` is required for the `dump_circuit` function to return a circuit. - The `circuit` function is *NOT* affected by this parameter will always generate a circuit. + Shared helper that creates an Interpreter and Config from the given parameters. """ from ._fs import read_file, list_directory, exists, join, resolve from ._http import fetch_github - global _interpreter - global _config - if isinstance(target_name, str): target = target_name.split(".")[0].lower() if target == "ionq" or target == "rigetti": @@ -280,7 +222,6 @@ def init( manifest_contents = None if project_root is not None: - # Normalize the project path (i.e. fix file separators and remove unnecessary '.' and '..') project_root = resolve(".", project_root) qsharp_json = join(project_root, "qsharp.json") if not exists(qsharp_json): @@ -295,10 +236,30 @@ def init( f"Error reading {qsharp_json}. qsharp.json should exist at the project root and be a valid JSON file." ) from e - # Loop through the environment module and remove any dynamically added attributes that represent - # Q# callables or structs. This is necessary to avoid conflicts with the new interpreter instance. + interpreter = Interpreter( + target_profile, + language_features, + project_root, + read_file, + list_directory, + resolve, + fetch_github, + make_callable_fn, + make_class_fn, + trace_circuit, + ) + + config = Config(target_profile, language_features, manifest_contents, project_root) + return interpreter, config + + +def _clear_code_module(code_module: types.ModuleType, module_prefix: str): + """ + Removes dynamically added Q# callables, structs, and namespace modules from + a code module and sys.modules. + """ keys_to_remove = [] - for key, val in code.__dict__.items(): + for key, val in code_module.__dict__.items(): if ( hasattr(val, "__global_callable") or hasattr(val, "__qsharp_class") @@ -306,59 +267,55 @@ def init( ): keys_to_remove.append(key) for key in keys_to_remove: - code.__delattr__(key) + code_module.__delattr__(key) - # Also remove any namespace modules dynamically added to the system. keys_to_remove = [] for key in sys.modules: - if key.startswith("qsharp.code."): + if key.startswith(module_prefix + "."): keys_to_remove.append(key) for key in keys_to_remove: sys.modules.__delitem__(key) - _interpreter = Interpreter( - target_profile, - language_features, - project_root, - read_file, - list_directory, - resolve, - fetch_github, - _make_callable, - _make_class, - trace_circuit, - ) - _config = Config(target_profile, language_features, manifest_contents, project_root) - # Return the configuration information to provide a hint to the - # language service through the cell output. - return _config +class PauliNoise(Tuple[float, float, float]): + """ + The Pauli noise to use in simulation represented + as probabilities of Pauli-X, Pauli-Y, and Pauli-Z errors + """ + def __new__(cls, x: float, y: float, z: float): + if x < 0 or y < 0 or z < 0: + raise ValueError("Pauli noise probabilities must be non-negative.") + if x + y + z > 1: + raise ValueError("The sum of Pauli noise probabilities must be at most 1.") + return super().__new__(cls, (x, y, z)) -def get_interpreter() -> Interpreter: - """ - Returns the Q# interpreter. - :returns: The Q# interpreter. +class DepolarizingNoise(PauliNoise): + """ + The depolarizing noise to use in simulation. """ - global _interpreter - if _interpreter is None: - init() - assert _interpreter is not None, "Failed to initialize the Q# interpreter." - return _interpreter + def __new__(cls, p: float): + return super().__new__(cls, p / 3, p / 3, p / 3) -def get_config() -> Config: + +class BitFlipNoise(PauliNoise): + """ + The bit flip noise to use in simulation. """ - Returns the Q# interpreter configuration. - :returns: The Q# interpreter configuration. + def __new__(cls, p: float): + return super().__new__(cls, p, 0, 0) + + +class PhaseFlipNoise(PauliNoise): + """ + The phase flip noise to use in simulation. """ - global _config - if _config is None: - init() - assert _config is not None, "Failed to initialize the Q# interpreter." - return _config + + def __new__(cls, p: float): + return super().__new__(cls, 0, 0, p) class StateDump: @@ -449,132 +406,28 @@ class ShotResult(TypedDict): dumps: List[StateDump] -def eval( - source: str, - *, - save_events: bool = False, -) -> Any: - """ - Evaluates Q# source code. - - Output is printed to console. - - :param source: The Q# source code to evaluate. - :param save_events: If true, all output will be saved and returned. If false, they will be printed. - :returns value: The value returned by the last statement in the source code or the saved output if `save_events` is true. - :raises QSharpError: If there is an error evaluating the source code. - """ - ipython_helper() - - results: ShotResult = { - "events": [], - "result": None, - "messages": [], - "matrices": [], - "dumps": [], - } - - def on_save_events(output: Output) -> None: - # Append the output to the last shot's output list - if output.is_matrix(): - results["events"].append(output) - results["matrices"].append(output) - elif output.is_state_dump(): - dump_data = cast(StateDumpData, output.state_dump()) - state_dump = StateDump(dump_data) - results["events"].append(state_dump) - results["dumps"].append(state_dump) - elif output.is_message(): - stringified = str(output) - results["events"].append(stringified) - results["messages"].append(stringified) - - def callback(output: Output) -> None: - if _in_jupyter: - try: - display(output) - return - except: - # If IPython is not available, fall back to printing the output - pass - print(output, flush=True) - - telemetry_events.on_eval() - start_time = monotonic() - - output = get_interpreter().interpret( - source, on_save_events if save_events else callback - ) - results["result"] = qsharp_value_to_python_value(output) - - durationMs = (monotonic() - start_time) * 1000 - telemetry_events.on_eval_end(durationMs) - - if save_events: - return results - else: - return results["result"] - - -# Helper function that knows how to create a function that invokes a callable. This will be -# used by the underlying native code to create functions for callables on the fly that know -# how to get the currently initialized global interpreter instance. -def _make_callable(callable: GlobalCallable, namespace: List[str], callable_name: str): - module = code - # Create a name that will be used to collect the hierarchy of namespace identifiers if they exist and use that - # to register created modules with the system. - accumulated_namespace = "qsharp.code" - accumulated_namespace += "." - for name in namespace: - accumulated_namespace += name - # Use the existing entry, which should already be a module. - if hasattr(module, name): - module = module.__getattribute__(name) - if sys.modules.get(accumulated_namespace) is None: - # This is an existing entry that is not yet registered in sys.modules, so add it. - # This can happen if a callable with the same name as this namespace is already - # defined. - sys.modules[accumulated_namespace] = module - else: - # This namespace entry doesn't exist as a module yet, so create it, add it to the environment, and - # add it to sys.modules so it supports import properly. - new_module = types.ModuleType(accumulated_namespace) - module.__setattr__(name, new_module) - sys.modules[accumulated_namespace] = new_module - module = new_module - accumulated_namespace += "." - - def _callable(*args): - ipython_helper() - - def callback(output: Output) -> None: - if _in_jupyter: - try: - display(output) - return - except: - # If IPython is not available, fall back to printing the output - pass - print(output, flush=True) - - args = python_args_to_interpreter_args(args) +# Class that wraps generated QIR, which can be used by +# azure-quantum as input data. +# +# This class must implement the QirRepresentable protocol +# that is defined by the azure-quantum package. +# See: https://github.com/microsoft/qdk-python/blob/fcd63c04aa871e49206703bbaa792329ffed13c4/azure-quantum/azure/quantum/target/target.py#L21 +class QirInputData: + # The name of this variable is defined + # by the protocol and must remain unchanged. + _name: str - output = get_interpreter().invoke(callable, args, callback) - return qsharp_value_to_python_value(output) + def __init__(self, name: str, ll_str: str): + self._name = name + self._ll_str = ll_str - # Each callable is annotated so that we know it is auto-generated and can be removed on a re-init of the interpreter. - _callable.__global_callable = callable + # The name of this method is defined + # by the protocol and must remain unchanged. + def _repr_qir_(self, **kwargs) -> bytes: + return self._ll_str.encode("utf-8") - # Add the callable to the module. - if module.__dict__.get(callable_name) is None: - module.__setattr__(callable_name, _callable) - else: - # Preserve any existing attributes on the attribute with the matching name, - # since this could be a collision with an existing namespace/module. - for key, val in module.__dict__.get(callable_name).__dict__.items(): - if key != "__global_callable": - _callable.__dict__[key] = val - module.__setattr__(callable_name, _callable) + def __str__(self) -> str: + return self._ll_str def qsharp_value_to_python_value(obj): @@ -655,39 +508,695 @@ def make_class_rec(qsharp_type: TypeIR) -> type: ) -def _make_class(qsharp_type: TypeIR, namespace: List[str], class_name: str): +def _check_same_context(ctx: "QdkContext", callable_fn: Callable) -> None: + """Raise if a callable belongs to a different context than *ctx*.""" + getter = getattr(callable_fn, "_qdk_get_context", None) + if getter is not None: + origin = getter() + if origin is not ctx: + raise QSharpError( + "This callable belongs to a different QdkContext. " + "Use qsharp.context_of(callable) to get the correct context, " + "or operate on the callable within the context that created it." + ) + + +class QdkContext: """ - Helper function to create a python class given a description of it. This will be - used by the underlying native code to create classes on the fly corresponding to - the currently initialized interpreter instance. + An isolated Q# interpreter context. Created via ``qsharp.new_context(...)``. + + Each context has its own interpreter, configuration, and code namespace. + Instance methods mirror the module-level functions (``eval``, ``run``, + ``compile``, etc.) but operate on this context's interpreter. """ - module = code - # Create a name that will be used to collect the hierarchy of namespace identifiers if they exist and use that - # to register created modules with the system. - accumulated_namespace = "qsharp.code" - accumulated_namespace += "." - for name in namespace: - accumulated_namespace += name - # Use the existing entry, which should already be a module. - if hasattr(module, name): - module = module.__getattribute__(name) + _interpreter: Interpreter + _config: Config + code: types.ModuleType + _code_prefix: str + _disposed: bool + + def __init__( + self, + *, + target_profile: TargetProfile = TargetProfile.Unrestricted, + target_name: Optional[str] = None, + project_root: Optional[str] = None, + language_features: Optional[List[str]] = None, + trace_circuit: Optional[bool] = None, + _code_module: Optional[types.ModuleType] = None, + _code_prefix: Optional[str] = None, + ): + self._disposed = False + + if _code_module is not None: + self.code = _code_module + self._code_prefix = _code_prefix or "qsharp.code" + else: + self._code_prefix = f"qsharp._ctx_{id(self)}" + self.code = types.ModuleType(self._code_prefix) + + self._interpreter, self._config = _create_interpreter( + target_profile=target_profile, + language_features=language_features, + project_root=project_root, + target_name=target_name, + trace_circuit=trace_circuit, + make_callable_fn=self._make_callable, + make_class_fn=self._make_class, + ) + + def _make_callable( + self, callable: GlobalCallable, namespace: List[str], callable_name: str + ): + """Registers a Q# callable in this context's code module.""" + module = self.code + accumulated_namespace = self._code_prefix + "." + for name in namespace: + accumulated_namespace += name + if hasattr(module, name): + module = module.__getattribute__(name) + if sys.modules.get(accumulated_namespace) is None: + sys.modules[accumulated_namespace] = module + else: + new_module = types.ModuleType(accumulated_namespace) + module.__setattr__(name, new_module) + sys.modules[accumulated_namespace] = new_module + module = new_module + accumulated_namespace += "." + + def _callable_fn(*args): + if self._disposed: + raise QSharpError( + "This callable belongs to a disposed Q# context. " + "Re-evaluate the callable in a current context." + ) + ipython_helper() + + def callback(output: Output) -> None: + if _in_jupyter: + try: + display(output) + return + except: + pass + print(output, flush=True) + + args = python_args_to_interpreter_args(args) + output = self._interpreter.invoke(callable, args, callback) + return qsharp_value_to_python_value(output) + + _callable_fn._qdk_get_interpreter = lambda: self._interpreter + _callable_fn._qdk_get_context = lambda: self + setattr(_callable_fn, "__global_callable", callable) + + if module.__dict__.get(callable_name) is None: + module.__setattr__(callable_name, _callable_fn) + else: + for key, val in module.__dict__.get(callable_name).__dict__.items(): + if key != "__global_callable": + _callable_fn.__dict__[key] = val + module.__setattr__(callable_name, _callable_fn) + + def _make_class(self, qsharp_type: TypeIR, namespace: List[str], class_name: str): + """Registers a Q# type as a Python dataclass in this context's code module.""" + module = self.code + accumulated_namespace = self._code_prefix + "." + for name in namespace: + accumulated_namespace += name + if hasattr(module, name): + module = module.__getattribute__(name) + else: + new_module = types.ModuleType(accumulated_namespace) + module.__setattr__(name, new_module) + sys.modules[accumulated_namespace] = new_module + module = new_module + accumulated_namespace += "." + + QSharpClass = make_class_rec(qsharp_type) + QSharpClass.__qsharp_class = True + module.__setattr__(class_name, QSharpClass) + + @property + def config(self) -> Config: + """The interpreter configuration (read-only).""" + return self._config + + def __repr__(self) -> str: + return repr(self._config) + + def _repr_mimebundle_( + self, include: Union[Any, None] = None, exclude: Union[Any, None] = None + ) -> Dict[str, Dict[str, Any]]: + return self._config._repr_mimebundle_(include, exclude) + + def eval( + self, + source: str, + *, + save_events: bool = False, + ) -> Any: + """ + Evaluates Q# source code in this context. + """ + ipython_helper() + + results: ShotResult = { + "events": [], + "result": None, + "messages": [], + "matrices": [], + "dumps": [], + } + + def on_save_events(output: Output) -> None: + if output.is_matrix(): + results["events"].append(output) + results["matrices"].append(output) + elif output.is_state_dump(): + dump_data = cast(StateDumpData, output.state_dump()) + state_dump = StateDump(dump_data) + results["events"].append(state_dump) + results["dumps"].append(state_dump) + elif output.is_message(): + stringified = str(output) + results["events"].append(stringified) + results["messages"].append(stringified) + + def callback(output: Output) -> None: + if _in_jupyter: + try: + display(output) + return + except: + pass + print(output, flush=True) + + telemetry_events.on_eval() + start_time = monotonic() + + output = self._interpreter.interpret( + source, on_save_events if save_events else callback + ) + results["result"] = qsharp_value_to_python_value(output) + + durationMs = (monotonic() - start_time) * 1000 + telemetry_events.on_eval_end(durationMs) + + if save_events: + return results + else: + return results["result"] + + def run( + self, + entry_expr: Union[str, Callable, GlobalCallable, Closure], + shots: int, + *args, + on_result: Optional[Callable[[ShotResult], None]] = None, + save_events: bool = False, + noise: Optional[ + Union[ + Tuple[float, float, float], + PauliNoise, + BitFlipNoise, + PhaseFlipNoise, + DepolarizingNoise, + ] + ] = None, + qubit_loss: Optional[float] = None, + ) -> List[Any]: + """ + Runs the given Q# expression for the given number of shots in this context. + """ + ipython_helper() + + if shots < 1: + raise ValueError("The number of shots must be greater than 0.") + + telemetry_events.on_run( + shots, + noise=(noise is not None and noise != (0.0, 0.0, 0.0)), + qubit_loss=(qubit_loss is not None and qubit_loss > 0.0), + ) + start_time = monotonic() + + results: List[ShotResult] = [] + + def print_output(output: Output) -> None: + if _in_jupyter: + try: + display(output) + return + except: + pass + print(output, flush=True) + + def on_save_events(output: Output) -> None: + results[-1]["events"].append(output) + if output.is_matrix(): + results[-1]["matrices"].append(output) + elif output.is_state_dump(): + dump_data = cast(StateDumpData, output.state_dump()) + results[-1]["dumps"].append(StateDump(dump_data)) + elif output.is_message(): + results[-1]["messages"].append(str(output)) + + callable = None + run_entry_expr = None + if isinstance(entry_expr, Callable) and hasattr( + entry_expr, "__global_callable" + ): + _check_same_context(self, entry_expr) + args = python_args_to_interpreter_args(args) + callable = getattr(entry_expr, "__global_callable") + elif isinstance(entry_expr, (GlobalCallable, Closure)): + args = python_args_to_interpreter_args(args) + callable = entry_expr + else: + assert isinstance(entry_expr, str) + run_entry_expr = entry_expr + + for shot in range(shots): + results.append( + { + "result": None, + "events": [], + "messages": [], + "matrices": [], + "dumps": [], + } + ) + run_results = self._interpreter.run( + run_entry_expr, + on_save_events if save_events else print_output, + noise, + qubit_loss, + callable, + args, + ) + run_results = qsharp_value_to_python_value(run_results) + results[-1]["result"] = run_results + if on_result: + on_result(results[-1]) + run_entry_expr = None + + durationMs = (monotonic() - start_time) * 1000 + telemetry_events.on_run_end(durationMs, shots) + + if save_events: + return results + else: + return [shot["result"] for shot in results] + + def compile( + self, + entry_expr: Union[str, Callable, GlobalCallable, Closure], + *args, + ) -> QirInputData: + """ + Compiles the Q# source code into a program that can be submitted to a target. + """ + ipython_helper() + start = monotonic() + target_profile = self._config.get_target_profile() + telemetry_events.on_compile(target_profile) + if isinstance(entry_expr, Callable) and hasattr( + entry_expr, "__global_callable" + ): + _check_same_context(self, entry_expr) + args = python_args_to_interpreter_args(args) + ll_str = self._interpreter.qir( + callable=getattr(entry_expr, "__global_callable"), args=args + ) + elif isinstance(entry_expr, (GlobalCallable, Closure)): + args = python_args_to_interpreter_args(args) + ll_str = self._interpreter.qir(callable=entry_expr, args=args) + else: + assert isinstance(entry_expr, str) + ll_str = self._interpreter.qir(entry_expr=entry_expr) + res = QirInputData("main", ll_str) + durationMs = (monotonic() - start) * 1000 + telemetry_events.on_compile_end(durationMs, target_profile) + return res + + def circuit( + self, + entry_expr: Optional[Union[str, Callable, GlobalCallable, Closure]] = None, + *args, + operation: Optional[str] = None, + generation_method: Optional[CircuitGenerationMethod] = None, + max_operations: Optional[int] = None, + source_locations: bool = False, + group_by_scope: bool = True, + prune_classical_qubits: bool = False, + ) -> Circuit: + """ + Synthesizes a circuit for a Q# program in this context. + """ + ipython_helper() + start = monotonic() + telemetry_events.on_circuit() + config = CircuitConfig( + max_operations=max_operations, + generation_method=generation_method, + source_locations=source_locations, + group_by_scope=group_by_scope, + prune_classical_qubits=prune_classical_qubits, + ) + + if isinstance(entry_expr, Callable) and hasattr( + entry_expr, "__global_callable" + ): + _check_same_context(self, entry_expr) + args = python_args_to_interpreter_args(args) + res = self._interpreter.circuit( + config=config, + callable=getattr(entry_expr, "__global_callable"), + args=args, + ) + elif isinstance(entry_expr, (GlobalCallable, Closure)): + args = python_args_to_interpreter_args(args) + res = self._interpreter.circuit( + config=config, callable=entry_expr, args=args + ) + else: + assert entry_expr is None or isinstance(entry_expr, str) + res = self._interpreter.circuit(config, entry_expr, operation=operation) + + durationMs = (monotonic() - start) * 1000 + telemetry_events.on_circuit_end(durationMs) + + return res + + def estimate( + self, + entry_expr: Union[str, Callable, GlobalCallable, Closure], + params: Optional[Union[Dict[str, Any], List, EstimatorParams]] = None, + *args, + ) -> EstimatorResult: + """ + Estimates resources for Q# source code in this context. + """ + ipython_helper() + + def _coerce_estimator_params( + params: Optional[ + Union[Dict[str, Any], List[Dict[str, Any]], EstimatorParams] + ] = None, + ) -> List[Dict[str, Any]]: + if params is None: + return [{}] + elif isinstance(params, EstimatorParams): + if params.has_items: + return cast(List[Dict[str, Any]], params.as_dict()["items"]) + else: + return [params.as_dict()] + elif isinstance(params, dict): + return [params] + return params + + params = _coerce_estimator_params(params) + param_str = json.dumps(params) + telemetry_events.on_estimate() + start = monotonic() + if isinstance(entry_expr, Callable) and hasattr( + entry_expr, "__global_callable" + ): + _check_same_context(self, entry_expr) + args = python_args_to_interpreter_args(args) + res_str = self._interpreter.estimate( + param_str, callable=getattr(entry_expr, "__global_callable"), args=args + ) + elif isinstance(entry_expr, (GlobalCallable, Closure)): + args = python_args_to_interpreter_args(args) + res_str = self._interpreter.estimate( + param_str, callable=entry_expr, args=args + ) + else: + assert isinstance(entry_expr, str) + res_str = self._interpreter.estimate(param_str, entry_expr=entry_expr) + res = json.loads(res_str) + + try: + qubits = res[0]["logicalCounts"]["numQubits"] + except (KeyError, IndexError): + qubits = "unknown" + + durationMs = (monotonic() - start) * 1000 + telemetry_events.on_estimate_end(durationMs, qubits) + return EstimatorResult(res) + + def logical_counts( + self, + entry_expr: Union[str, Callable, GlobalCallable, Closure], + *args, + ) -> LogicalCounts: + """ + Extracts logical resource counts from Q# source code in this context. + """ + ipython_helper() + + if isinstance(entry_expr, Callable) and hasattr( + entry_expr, "__global_callable" + ): + _check_same_context(self, entry_expr) + args = python_args_to_interpreter_args(args) + res_dict = self._interpreter.logical_counts( + callable=getattr(entry_expr, "__global_callable"), args=args + ) + elif isinstance(entry_expr, (GlobalCallable, Closure)): + args = python_args_to_interpreter_args(args) + res_dict = self._interpreter.logical_counts(callable=entry_expr, args=args) else: - # This namespace entry doesn't exist as a module yet, so create it, add it to the environment, and - # add it to sys.modules so it supports import properly. - new_module = types.ModuleType(accumulated_namespace) - module.__setattr__(name, new_module) - sys.modules[accumulated_namespace] = new_module - module = new_module - accumulated_namespace += "." + assert isinstance(entry_expr, str) + res_dict = self._interpreter.logical_counts(entry_expr=entry_expr) + return LogicalCounts(res_dict) + + def set_quantum_seed(self, seed: Optional[int]) -> None: + """ + Sets the seed for the random number generator used for quantum measurements. + """ + self._interpreter.set_quantum_seed(seed) + + def set_classical_seed(self, seed: Optional[int]) -> None: + """ + Sets the seed for the random number generator used for standard + library classical random number operations. + """ + self._interpreter.set_classical_seed(seed) + + def dump_machine(self) -> StateDump: + """ + Returns the sparse state vector of the simulator as a StateDump object. + """ + ipython_helper() + return StateDump(self._interpreter.dump_machine()) + + def dump_circuit(self) -> Circuit: + """ + Dumps a circuit showing the current state of the simulator. + """ + ipython_helper() + return self._interpreter.dump_circuit() + + def import_openqasm( + self, + source: str, + **kwargs: Any, + ) -> Any: + """ + Imports OpenQASM source code into this context's interpreter. + + Args: + source (str): An OpenQASM program or fragment. + **kwargs: Additional keyword arguments (name, search_path, + output_semantics, program_type). + + Returns: + value: The value returned by the last statement in the source code. + """ + from .openqasm._ipython import display_or_print + from ._fs import read_file, list_directory, resolve + from ._http import fetch_github + + ipython_helper() + + telemetry_events.on_import_qasm() + start_time = monotonic() + + kwargs = {k: v for k, v in kwargs.items() if k is not None and v is not None} + if "search_path" not in kwargs: + kwargs["search_path"] = "." + + res = self._interpreter.import_qasm( + source, + display_or_print, + read_file, + list_directory, + resolve, + fetch_github, + **kwargs, + ) + + durationMs = (monotonic() - start_time) * 1000 + telemetry_events.on_import_qasm_end(durationMs) - QSharpClass = make_class_rec(qsharp_type) + return res - # Each class is annotated so that we know it is auto-generated and can be removed on a re-init of the interpreter. - QSharpClass.__qsharp_class = True - # Add the class to the module. - module.__setattr__(class_name, QSharpClass) +def new_context( + *, + target_profile: TargetProfile = TargetProfile.Unrestricted, + target_name: Optional[str] = None, + project_root: Optional[str] = None, + language_features: Optional[List[str]] = None, + trace_circuit: Optional[bool] = None, +) -> QdkContext: + """ + Creates an isolated Q# interpreter context. + + :param target_profile: The target profile for the interpreter. + :param target_name: An optional target machine name. + :param project_root: An optional path to a Q# project root. + :param language_features: Optional language features to enable. + :param trace_circuit: Enables tracing of circuit during execution. + :returns: A new ``QdkContext``. + """ + return QdkContext( + target_profile=target_profile, + target_name=target_name, + project_root=project_root, + language_features=language_features, + trace_circuit=trace_circuit, + ) + + +def init( + *, + target_profile: TargetProfile = TargetProfile.Unrestricted, + target_name: Optional[str] = None, + project_root: Optional[str] = None, + language_features: Optional[List[str]] = None, + trace_circuit: Optional[bool] = None, +) -> QdkContext: + """ + Initializes the Q# interpreter. + + :param target_profile: Setting the target profile allows the Q# + interpreter to generate programs that are compatible + with a specific target. See :py:class: `qsharp.TargetProfile`. + + :param target_name: An optional name of the target machine to use for inferring the compatible + target_profile setting. + + :param project_root: An optional path to a root directory with a Q# project to include. + It must contain a qsharp.json project manifest. + + :param trace_circuit: Enables tracing of circuit during execution. + Passing `True` is required for the `dump_circuit` function to return a circuit. + The `circuit` function is *NOT* affected by this parameter will always generate a circuit. + + :returns: The ``QdkContext`` that is now the global default. + """ + global _default_ctx + + # Dispose the old context so its callables fail gracefully. + if _default_ctx is not None: + _default_ctx._disposed = True + + # Clean up the global code namespace before creating a new context. + _clear_code_module(code, "qsharp.code") + + _default_ctx = QdkContext( + target_profile=target_profile, + target_name=target_name, + project_root=project_root, + language_features=language_features, + trace_circuit=trace_circuit, + _code_module=code, + _code_prefix="qsharp.code", + ) + # Return the context, which supports __repr__ and _repr_mimebundle_ + # for language service hints through notebook cell output. + return _default_ctx + + +def _get_default_ctx() -> QdkContext: + """ + Returns the global default context, lazily initializing if needed. + """ + global _default_ctx + if _default_ctx is None: + init() + assert _default_ctx is not None, "Failed to initialize the Q# interpreter." + return _default_ctx + + +def get_context() -> QdkContext: + """ + Returns the current global context without reinitializing. + + If no context exists yet, one is created lazily (equivalent to calling + ``init()`` with default parameters). + + :returns: The global default ``QdkContext``. + """ + return _get_default_ctx() + + +def context_of(obj: Callable) -> QdkContext: + """ + Returns the ``QdkContext`` that created a QDK callable. + + :param obj: A callable obtained from a ``QdkContext``'s ``code`` namespace + (e.g. ``ctx.code.MyOp`` or ``qsharp.code.MyOp``). + :returns: The ``QdkContext`` that compiled the callable. + :raises TypeError: If the object is not a QDK callable. + """ + getter = getattr(obj, "_qdk_get_context", None) + if getter is None: + raise TypeError( + "Expected a QDK callable (from ctx.code.* or qsharp.code.*), " + f"got {type(obj).__name__}" + ) + return getter() + + +def get_interpreter() -> Interpreter: + """ + Returns the Q# interpreter. + + :returns: The Q# interpreter. + """ + return _get_default_ctx()._interpreter + + +def get_config() -> Config: + """ + Returns the Q# interpreter configuration. + + :returns: The Q# interpreter configuration. + """ + return _get_default_ctx()._config + + +def eval( + source: str, + *, + save_events: bool = False, +) -> Any: + """ + Evaluates Q# source code. + + Output is printed to console. + + :param source: The Q# source code to evaluate. + :param save_events: If true, all output will be saved and returned. If false, they will be printed. + :returns value: The value returned by the last statement in the source code or the saved output if `save_events` is true. + :raises QSharpError: If there is an error evaluating the source code. + """ + return _get_default_ctx().eval(source, save_events=save_events) def run( @@ -726,105 +1235,15 @@ def run( :raises QSharpError: If there is an error interpreting the input. :raises ValueError: If the number of shots is less than 1. """ - ipython_helper() - - if shots < 1: - raise ValueError("The number of shots must be greater than 0.") - - telemetry_events.on_run( + return _get_default_ctx().run( + entry_expr, shots, - noise=(noise is not None and noise != (0.0, 0.0, 0.0)), - qubit_loss=(qubit_loss is not None and qubit_loss > 0.0), + *args, + on_result=on_result, + save_events=save_events, + noise=noise, + qubit_loss=qubit_loss, ) - start_time = monotonic() - - results: List[ShotResult] = [] - - def print_output(output: Output) -> None: - if _in_jupyter: - try: - display(output) - return - except: - # If IPython is not available, fall back to printing the output - pass - print(output, flush=True) - - def on_save_events(output: Output) -> None: - # Append the output to the last shot's output list - results[-1]["events"].append(output) - if output.is_matrix(): - results[-1]["matrices"].append(output) - elif output.is_state_dump(): - dump_data = cast(StateDumpData, output.state_dump()) - results[-1]["dumps"].append(StateDump(dump_data)) - elif output.is_message(): - results[-1]["messages"].append(str(output)) - - callable = None - run_entry_expr = None - if isinstance(entry_expr, Callable) and hasattr(entry_expr, "__global_callable"): - args = python_args_to_interpreter_args(args) - callable = entry_expr.__global_callable - elif isinstance(entry_expr, (GlobalCallable, Closure)): - args = python_args_to_interpreter_args(args) - callable = entry_expr - else: - assert isinstance(entry_expr, str) - run_entry_expr = entry_expr - - for shot in range(shots): - results.append( - {"result": None, "events": [], "messages": [], "matrices": [], "dumps": []} - ) - run_results = get_interpreter().run( - run_entry_expr, - on_save_events if save_events else print_output, - noise, - qubit_loss, - callable, - args, - ) - run_results = qsharp_value_to_python_value(run_results) - results[-1]["result"] = run_results - if on_result: - on_result(results[-1]) - # For every shot after the first, treat the entry expression as None to trigger - # a rerun of the last executed expression without paying the cost for any additional - # compilation. - run_entry_expr = None - - durationMs = (monotonic() - start_time) * 1000 - telemetry_events.on_run_end(durationMs, shots) - - if save_events: - return results - else: - return [shot["result"] for shot in results] - - -# Class that wraps generated QIR, which can be used by -# azure-quantum as input data. -# -# This class must implement the QirRepresentable protocol -# that is defined by the azure-quantum package. -# See: https://github.com/microsoft/qdk-python/blob/fcd63c04aa871e49206703bbaa792329ffed13c4/azure-quantum/azure/quantum/target/target.py#L21 -class QirInputData: - # The name of this variable is defined - # by the protocol and must remain unchanged. - _name: str - - def __init__(self, name: str, ll_str: str): - self._name = name - self._ll_str = ll_str - - # The name of this method is defined - # by the protocol and must remain unchanged. - def _repr_qir_(self, **kwargs) -> bytes: - return self._ll_str.encode("utf-8") - - def __str__(self) -> str: - return self._ll_str def compile( @@ -849,24 +1268,7 @@ def compile( with open('myfile.ll', 'w') as file: file.write(str(program)) """ - ipython_helper() - start = monotonic() - interpreter = get_interpreter() - target_profile = get_config().get_target_profile() - telemetry_events.on_compile(target_profile) - if isinstance(entry_expr, Callable) and hasattr(entry_expr, "__global_callable"): - args = python_args_to_interpreter_args(args) - ll_str = interpreter.qir(callable=entry_expr.__global_callable, args=args) - elif isinstance(entry_expr, (GlobalCallable, Closure)): - args = python_args_to_interpreter_args(args) - ll_str = interpreter.qir(callable=entry_expr, args=args) - else: - assert isinstance(entry_expr, str) - ll_str = interpreter.qir(entry_expr=entry_expr) - res = QirInputData("main", ll_str) - durationMs = (monotonic() - start) * 1000 - telemetry_events.on_compile_end(durationMs, target_profile) - return res + return _get_default_ctx().compile(entry_expr, *args) def circuit( @@ -894,34 +1296,17 @@ def circuit( :raises QSharpError: If there is an error synthesizing the circuit. """ - ipython_helper() - start = monotonic() - telemetry_events.on_circuit() - config = CircuitConfig( - max_operations=max_operations, + return _get_default_ctx().circuit( + entry_expr, + *args, + operation=operation, generation_method=generation_method, + max_operations=max_operations, source_locations=source_locations, group_by_scope=group_by_scope, prune_classical_qubits=prune_classical_qubits, ) - if isinstance(entry_expr, Callable) and hasattr(entry_expr, "__global_callable"): - args = python_args_to_interpreter_args(args) - res = get_interpreter().circuit( - config=config, callable=entry_expr.__global_callable, args=args - ) - elif isinstance(entry_expr, (GlobalCallable, Closure)): - args = python_args_to_interpreter_args(args) - res = get_interpreter().circuit(config=config, callable=entry_expr, args=args) - else: - assert entry_expr is None or isinstance(entry_expr, str) - res = get_interpreter().circuit(config, entry_expr, operation=operation) - - durationMs = (monotonic() - start) * 1000 - telemetry_events.on_circuit_end(durationMs) - - return res - def estimate( entry_expr: Union[str, Callable, GlobalCallable, Closure], @@ -938,50 +1323,7 @@ def estimate( :returns `EstimatorResult`: The estimated resources. """ - - ipython_helper() - - def _coerce_estimator_params( - params: Optional[ - Union[Dict[str, Any], List[Dict[str, Any]], EstimatorParams] - ] = None, - ) -> List[Dict[str, Any]]: - if params is None: - return [{}] - elif isinstance(params, EstimatorParams): - if params.has_items: - return cast(List[Dict[str, Any]], params.as_dict()["items"]) - else: - return [params.as_dict()] - elif isinstance(params, dict): - return [params] - return params - - params = _coerce_estimator_params(params) - param_str = json.dumps(params) - telemetry_events.on_estimate() - start = monotonic() - if isinstance(entry_expr, Callable) and hasattr(entry_expr, "__global_callable"): - args = python_args_to_interpreter_args(args) - res_str = get_interpreter().estimate( - param_str, callable=entry_expr.__global_callable, args=args - ) - elif isinstance(entry_expr, (GlobalCallable, Closure)): - args = python_args_to_interpreter_args(args) - res_str = get_interpreter().estimate(param_str, callable=entry_expr, args=args) - else: - assert isinstance(entry_expr, str) - res_str = get_interpreter().estimate(param_str, entry_expr=entry_expr) - res = json.loads(res_str) - - try: - qubits = res[0]["logicalCounts"]["numQubits"] - except (KeyError, IndexError): - qubits = "unknown" - - durationMs = (monotonic() - start) * 1000 - telemetry_events.on_estimate_end(durationMs, qubits) - return EstimatorResult(res) + return _get_default_ctx().estimate(entry_expr, params, *args) def logical_counts( @@ -997,21 +1339,7 @@ def logical_counts( :returns `LogicalCounts`: Program resources in terms of logical gate counts. """ - - ipython_helper() - - if isinstance(entry_expr, Callable) and hasattr(entry_expr, "__global_callable"): - args = python_args_to_interpreter_args(args) - res_dict = get_interpreter().logical_counts( - callable=entry_expr.__global_callable, args=args - ) - elif isinstance(entry_expr, (GlobalCallable, Closure)): - args = python_args_to_interpreter_args(args) - res_dict = get_interpreter().logical_counts(callable=entry_expr, args=args) - else: - assert isinstance(entry_expr, str) - res_dict = get_interpreter().logical_counts(entry_expr=entry_expr) - return LogicalCounts(res_dict) + return _get_default_ctx().logical_counts(entry_expr, *args) def set_quantum_seed(seed: Optional[int]) -> None: @@ -1022,7 +1350,7 @@ def set_quantum_seed(seed: Optional[int]) -> None: :param seed: The seed to use for the quantum random number generator. If None, the seed will be generated from entropy. """ - get_interpreter().set_quantum_seed(seed) + _get_default_ctx().set_quantum_seed(seed) def set_classical_seed(seed: Optional[int]) -> None: @@ -1034,7 +1362,7 @@ def set_classical_seed(seed: Optional[int]) -> None: :param seed: The seed to use for the classical random number generator. If None, the seed will be generated from entropy. """ - get_interpreter().set_classical_seed(seed) + _get_default_ctx().set_classical_seed(seed) def dump_machine() -> StateDump: @@ -1043,8 +1371,7 @@ def dump_machine() -> StateDump: :returns: The state of the simulator. """ - ipython_helper() - return StateDump(get_interpreter().dump_machine()) + return _get_default_ctx().dump_machine() def dump_circuit() -> Circuit: @@ -1056,5 +1383,4 @@ def dump_circuit() -> Circuit: Requires the interpreter to be initialized with `trace_circuit=True`. """ - ipython_helper() - return get_interpreter().dump_circuit() + return _get_default_ctx().dump_circuit() diff --git a/source/pip/qsharp/openqasm/_circuit.py b/source/pip/qsharp/openqasm/_circuit.py index bcd77707bf..c8e8a2ce4b 100644 --- a/source/pip/qsharp/openqasm/_circuit.py +++ b/source/pip/qsharp/openqasm/_circuit.py @@ -60,9 +60,8 @@ def circuit( if isinstance(source, Callable) and hasattr(source, "__global_callable"): args = python_args_to_interpreter_args(args) - res = get_interpreter().circuit( - config, callable=source.__global_callable, args=args - ) + interp = getattr(source, "_qdk_get_interpreter", get_interpreter)() + res = interp.circuit(config, callable=source.__global_callable, args=args) else: # remove any entries from kwargs with a None key or None value kwargs = {k: v for k, v in kwargs.items() if k is not None and v is not None} diff --git a/source/pip/qsharp/openqasm/_compile.py b/source/pip/qsharp/openqasm/_compile.py index 850b9621dc..8b61cae955 100644 --- a/source/pip/qsharp/openqasm/_compile.py +++ b/source/pip/qsharp/openqasm/_compile.py @@ -70,7 +70,8 @@ def compile( if isinstance(source, Callable) and hasattr(source, "__global_callable"): args = python_args_to_interpreter_args(args) - ll_str = get_interpreter().qir( + interp = getattr(source, "_qdk_get_interpreter", get_interpreter)() + ll_str = interp.qir( entry_expr=None, callable=source.__global_callable, args=args ) elif isinstance(source, str): @@ -91,7 +92,9 @@ def compile( **kwargs, ) else: - raise ValueError("source must be a string or a callable with __global_callable attribute") + raise ValueError( + "source must be a string or a callable with __global_callable attribute" + ) res = QirInputData("main", ll_str) durationMs = (monotonic() - start) * 1000 diff --git a/source/pip/qsharp/openqasm/_estimate.py b/source/pip/qsharp/openqasm/_estimate.py index 1555e58549..35e6515ba6 100644 --- a/source/pip/qsharp/openqasm/_estimate.py +++ b/source/pip/qsharp/openqasm/_estimate.py @@ -71,7 +71,8 @@ def _coerce_estimator_params( start = monotonic() if isinstance(source, Callable) and hasattr(source, "__global_callable"): args = python_args_to_interpreter_args(args) - res_str = get_interpreter().estimate( + interp = getattr(source, "_qdk_get_interpreter", get_interpreter)() + res_str = interp.estimate( param_str, entry_expr=None, callable=source.__global_callable, args=args ) elif isinstance(source, str): diff --git a/source/pip/qsharp/openqasm/_run.py b/source/pip/qsharp/openqasm/_run.py index 25a351f616..e21199dfc1 100644 --- a/source/pip/qsharp/openqasm/_run.py +++ b/source/pip/qsharp/openqasm/_run.py @@ -106,6 +106,7 @@ def on_save_events(output: Output) -> None: source_str = source if callable: + interp = getattr(source, "_qdk_get_interpreter", get_interpreter)() for _ in range(shots): results.append( { @@ -116,7 +117,7 @@ def on_save_events(output: Output) -> None: "messages": [], } ) - run_results = get_interpreter().run( + run_results = interp.run( source_str, on_save_events if save_events else display_or_print, noise, diff --git a/source/pip/qsharp/utils/_utils.py b/source/pip/qsharp/utils/_utils.py index 6268c1801a..60dd4ea3d9 100644 --- a/source/pip/qsharp/utils/_utils.py +++ b/source/pip/qsharp/utils/_utils.py @@ -2,11 +2,16 @@ # Licensed under the MIT License. from .._qsharp import run -from typing import List +from typing import List, Optional, TYPE_CHECKING import math +if TYPE_CHECKING: + from .._qsharp import QdkContext -def dump_operation(operation: str, num_qubits: int) -> List[List[complex]]: + +def dump_operation( + operation: str, num_qubits: int, *, ctx: Optional["QdkContext"] = None +) -> List[List[complex]]: """ Returns a square matrix of complex numbers representing the operation performed. @@ -27,7 +32,7 @@ def dump_operation(operation: str, num_qubits: int) -> List[List[complex]]: Microsoft.Quantum.Diagnostics.DumpMachine(); ResetAll(targets + extra); }}""" - result = run(code, shots=1, save_events=True)[0] + result = (ctx.run if ctx else run)(code, shots=1, save_events=True)[0] state = result["events"][-1].state_dump().get_dict() num_entries = pow(2, num_qubits) factor = math.sqrt(num_entries) diff --git a/source/pip/tests/test_qsharp.py b/source/pip/tests/test_qsharp.py index 9ffd92a592..a149ca778a 100644 --- a/source/pip/tests/test_qsharp.py +++ b/source/pip/tests/test_qsharp.py @@ -508,16 +508,16 @@ def test_compile_qir_str_from_qsharp_callable_with_multiple_args_passed_as_tuple def test_init_from_provider_name() -> None: - config = qsharp.init(target_name="ionq.simulator") - assert config._config["targetProfile"] == "base" - config = qsharp.init(target_name="rigetti.sim.qvm") - assert config._config["targetProfile"] == "base" - config = qsharp.init(target_name="quantinuum.sim") - assert config._config["targetProfile"] == "adaptive_ri" - config = qsharp.init(target_name="Quantinuum") - assert config._config["targetProfile"] == "adaptive_ri" - config = qsharp.init(target_name="IonQ") - assert config._config["targetProfile"] == "base" + ctx = qsharp.init(target_name="ionq.simulator") + assert ctx.config._config["targetProfile"] == "base" + ctx = qsharp.init(target_name="rigetti.sim.qvm") + assert ctx.config._config["targetProfile"] == "base" + ctx = qsharp.init(target_name="quantinuum.sim") + assert ctx.config._config["targetProfile"] == "adaptive_ri" + ctx = qsharp.init(target_name="Quantinuum") + assert ctx.config._config["targetProfile"] == "adaptive_ri" + ctx = qsharp.init(target_name="IonQ") + assert ctx.config._config["targetProfile"] == "base" def test_run_with_result(capsys) -> None: @@ -1201,3 +1201,181 @@ def test_swap_label_circuit_from_callable() -> None: q_1 ────────────── """ ) + + +# --- QdkContext tests --- + + +def test_context_eval() -> None: + ctx = qsharp.new_context() + result = ctx.eval("1 + 2") + assert result == 3 + + +def test_context_isolation() -> None: + ctx1 = qsharp.new_context() + ctx2 = qsharp.new_context() + ctx1.eval("function Foo() : Int { 42 }") + result1 = ctx1.eval("Foo()") + assert result1 == 42 + # ctx2 should not have Foo defined + with pytest.raises(Exception): + ctx2.eval("Foo()") + + +def test_context_run() -> None: + ctx = qsharp.new_context() + ctx.eval('operation Foo() : Result { Message("hi"); Zero }') + results = ctx.run("Foo()", 3) + assert results == [qsharp.Result.Zero, qsharp.Result.Zero, qsharp.Result.Zero] + + +def test_module_level_backward_compat() -> None: + qsharp.init() + result = qsharp.eval("1 + 1") + assert result == 2 + + +def test_init_returns_context() -> None: + ctx = qsharp.init() + assert isinstance(ctx, qsharp.QdkContext) + # The context should be usable directly + result = ctx.eval("3 + 4") + assert result == 7 + # Module-level eval should use the same context + result2 = qsharp.eval("3 + 4") + assert result2 == 7 + + +def test_context_callable_has_interpreter_ref() -> None: + """Callables created via eval carry a _qdk_get_interpreter attribute.""" + ctx = qsharp.new_context() + ctx.eval("function Add(a : Int, b : Int) : Int { a + b }") + add_fn = ctx.code.Add + assert hasattr(add_fn, "_qdk_get_interpreter") + assert add_fn._qdk_get_interpreter() is ctx._interpreter + + +def test_context_import_openqasm() -> None: + """import_openqasm loads an OpenQASM program into the context.""" + ctx = qsharp.new_context() + ctx.import_openqasm( + """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit q; + h q; + """ + ) + + +def test_context_dump_operation() -> None: + """dump_operation works with an explicit context.""" + from qsharp.utils import dump_operation + + ctx = qsharp.new_context() + matrix = dump_operation("qs => X(qs[0])", 1, ctx=ctx) + # X gate matrix should swap |0> and |1> + assert len(matrix) == 2 + assert len(matrix[0]) == 2 + + +def test_backward_compat_alias() -> None: + """QSharpContext alias still works for backward compatibility.""" + assert qsharp.QSharpContext is qsharp.QdkContext + + +def test_get_context_returns_global() -> None: + """get_context() returns the global default context.""" + ctx1 = qsharp.init() + ctx2 = qsharp.get_context() + assert ctx2 is ctx1 + + +def test_context_of_returns_origin() -> None: + """context_of() returns the context that compiled the callable.""" + ctx = qsharp.new_context() + ctx.eval("function Hello() : Int { 1 }") + fn = ctx.code.Hello + assert qsharp.context_of(fn) is ctx + + +def test_context_of_global_callable() -> None: + """context_of() works for callables in the global context.""" + ctx = qsharp.init() + qsharp.eval("function Hi() : Int { 2 }") + fn = qsharp.code.Hi + assert qsharp.context_of(fn) is ctx + + +def test_context_of_rejects_non_callable() -> None: + """context_of() raises TypeError for non-QDK objects.""" + with pytest.raises(TypeError, match="Expected a QDK callable"): + qsharp.context_of(lambda: None) + + +def test_cross_context_run_raises() -> None: + """Passing a callable from one context to another's run() raises.""" + ctx_a = qsharp.new_context() + ctx_b = qsharp.new_context() + ctx_a.eval("operation Foo() : Result { use q = Qubit(); M(q) }") + foo = ctx_a.code.Foo + with pytest.raises(Exception, match="different QdkContext"): + ctx_b.run(foo, 1) + + +def test_cross_context_compile_raises() -> None: + """Passing a callable from one context to another's compile() raises.""" + ctx_a = qsharp.new_context(target_profile=qsharp.TargetProfile.Base) + ctx_b = qsharp.new_context(target_profile=qsharp.TargetProfile.Base) + ctx_a.eval("operation Bar() : Result { use q = Qubit(); M(q) }") + bar = ctx_a.code.Bar + with pytest.raises(Exception, match="different QdkContext"): + ctx_b.compile(bar) + + +def test_cross_context_circuit_raises() -> None: + """Passing a callable from one context to another's circuit() raises.""" + ctx_a = qsharp.new_context() + ctx_b = qsharp.new_context() + ctx_a.eval("operation Baz() : Unit { use q = Qubit(); H(q); }") + baz = ctx_a.code.Baz + with pytest.raises(Exception, match="different QdkContext"): + ctx_b.circuit(baz) + + +def test_cross_context_estimate_raises() -> None: + """Passing a callable from one context to another's estimate() raises.""" + ctx_a = qsharp.new_context() + ctx_b = qsharp.new_context() + ctx_a.eval("operation Qux() : Unit { use q = Qubit(); H(q); }") + qux = ctx_a.code.Qux + with pytest.raises(Exception, match="different QdkContext"): + ctx_b.estimate(qux) + + +def test_cross_context_logical_counts_raises() -> None: + """Passing a callable from one context to another's logical_counts() raises.""" + ctx_a = qsharp.new_context() + ctx_b = qsharp.new_context() + ctx_a.eval("operation Corge() : Unit { use q = Qubit(); H(q); }") + corge = ctx_a.code.Corge + with pytest.raises(Exception, match="different QdkContext"): + ctx_b.logical_counts(corge) + + +def test_stale_callable_after_reinit() -> None: + """Callables from a prior init() become invalid after re-initialization.""" + qsharp.init() + qsharp.eval("function Stale() : Int { 99 }") + old_fn = qsharp.code.Stale + # Reinitialize — old callable should now be stale + qsharp.init() + with pytest.raises(Exception, match="disposed"): + old_fn() + + +def test_context_config_property() -> None: + """QdkContext exposes a .config property with the target profile.""" + ctx = qsharp.new_context(target_profile=qsharp.TargetProfile.Base) + assert ctx.config.get_target_profile() == "base" diff --git a/source/qdk_package/src/qdk/__init__.py b/source/qdk_package/src/qdk/__init__.py index f0f6daf17b..3e7c332a2d 100644 --- a/source/qdk_package/src/qdk/__init__.py +++ b/source/qdk_package/src/qdk/__init__.py @@ -34,6 +34,10 @@ DepolarizingNoise, BitFlipNoise, PhaseFlipNoise, + QdkContext, + new_context, + get_context, + context_of, ) # utilities lifted from qsharp @@ -51,4 +55,8 @@ "DepolarizingNoise", "BitFlipNoise", "PhaseFlipNoise", + "QdkContext", + "new_context", + "get_context", + "context_of", ] diff --git a/source/qdk_package/tests/mocks.py b/source/qdk_package/tests/mocks.py index f4cb7f1938..445f9420c2 100644 --- a/source/qdk_package/tests/mocks.py +++ b/source/qdk_package/tests/mocks.py @@ -49,6 +49,10 @@ class _T: # placeholder types stub.DepolarizingNoise = _T stub.BitFlipNoise = _T stub.PhaseFlipNoise = _T + stub.QdkContext = _T + stub.new_context = _not_impl + stub.get_context = _not_impl + stub.context_of = _not_impl stub.__all__ = [ "run", "estimate", @@ -65,6 +69,10 @@ class _T: # placeholder types "DepolarizingNoise", "BitFlipNoise", "PhaseFlipNoise", + "QdkContext", + "new_context", + "get_context", + "context_of", "estimator", "openqasm", "utils",