diff --git a/pygsti/forwardsims/torchfwdsim.py b/pygsti/forwardsims/torchfwdsim.py
index 1285e51de..5079b26e1 100644
--- a/pygsti/forwardsims/torchfwdsim.py
+++ b/pygsti/forwardsims/torchfwdsim.py
@@ -31,6 +31,7 @@
 try:
     import torch
+    from torch.profiler import profile, record_function, ProfilerActivity
     TORCH_ENABLED = True
 except ImportError:
     TORCH_ENABLED = False
@@ -89,6 +90,7 @@ def __init__(self, model: ExplicitOpModel, layout: CircuitOutcomeProbabilityArra
         # framed in terms of the "layout._element_indicies" dict.
         eind = layout._element_indices
         assert isinstance(eind, dict)
+        assert len(eind) > 0
         items = iter(eind.items())
         k_prev, v_prev = next(items)
         assert k_prev == 0
@@ -159,6 +161,17 @@ def get_torch_bases(self, free_params: Tuple[torch.Tensor]) -> Dict[Label, torch
         fp in free_params. This can be done by calling fp._requires_grad(True) before
         calling this function.
         """
+        # The closest analog to this function in tgst is the first couple lines in
+        # tgst.gst.MachineModel.circuit_outcome_probs(...).
+        # Those lines just assign values a-la new_machine.params[i][:] = fp[:].
+        #
+        # The variables new_machine.params[i] are just references to Tensors
+        # that are attached to tgst.abstractions objects (Gate, Measurement, State).
+        #
+        # Calling abstr.rep_array for a given abstraction performs a computation on
+        # its attached Tensor, and that computation is roughly analogous to
+        # torchable.torch_base(...).
+        #
         assert len(free_params) == len(self.param_metadata)
         # ^ A sanity check that we're being called with the correct number of arguments.
         torch_bases = dict()
@@ -202,8 +215,23 @@ def circuit_probs_from_free_params(self, *free_params: Tuple[torch.Tensor], enab
         if enable_backward:
             for fp in free_params:
                 fp._requires_grad(True)
-        torch_bases = self.get_torch_bases(free_params)
-        probs = self.circuit_probs_from_torch_bases(torch_bases)
+
+        torch_bases = dict()
+        for i, val in enumerate(free_params):
+            label, type_handle, stateless_data = self.param_metadata[i]
+            param_t = type_handle.torch_base(stateless_data, val)
+            torch_bases[label] = param_t
+
+        probs = []
+        for c in self.circuits:
+            superket = torch_bases[c.prep_label]
+            superops = [torch_bases[ol] for ol in c.op_labels]
+            povm_mat = torch_bases[c.povm_label]
+            for superop in superops:
+                superket = superop @ superket
+            circuit_probs = povm_mat @ superket
+            probs.append(circuit_probs)
+        probs = torch.concat(probs)
         return probs
@@ -248,8 +276,10 @@ def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
         if slm.default_to_reverse_ad:
             # Then slm.circuit_probs_from_free_params will automatically construct the
             # torch_base dict to support reverse-mode AD.
+            # print('USING REVERSE-MODE AD')
             J_func = torch.func.jacrev(slm.circuit_probs_from_free_params, argnums=argnums)
         else:
+            # print('USING FORWARD-MODE AD')
             # Then slm.circuit_probs_from_free_params will automatically skip the extra
             # steps needed for torch_base to support reverse-mode AD.
             J_func = torch.func.jacfwd(slm.circuit_probs_from_free_params, argnums=argnums)
@@ -258,7 +288,14 @@ def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
         # have a need to override the default in the future then we'd need to override
         # the ForwardSimulator function(s) that call self._bulk_fill_dprobs(...).

+        # import time
+        # print('Calling J_func at current free_params')
+        # tic = time.time()
+        # with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
         J_val = J_func(*free_params)
+        # toc = time.time()
+        # print()
+        # print(f'Done! --> {toc - tic} seconds elapsed')
         J_val = torch.column_stack(J_val)
         array_to_fill[:] = J_val.cpu().detach().numpy()
         return
diff --git a/pygsti/modelmembers/operations/fulltpop.py b/pygsti/modelmembers/operations/fulltpop.py
index 16866b893..f77307ad7 100644
--- a/pygsti/modelmembers/operations/fulltpop.py
+++ b/pygsti/modelmembers/operations/fulltpop.py
@@ -164,16 +164,18 @@ def from_vector(self, v, close=False, dirty_value=True):
         self._ptr_has_changed()  # because _rep.base == _ptr (same memory)
         self.dirty = dirty_value

-    def stateless_data(self) -> Tuple[int]:
-        return (self.dim,)
-
-    @staticmethod
-    def torch_base(sd: Tuple[int], t_param: _torch.Tensor) -> _torch.Tensor:
-        dim = sd[0]
+    def stateless_data(self) -> Tuple[int, _torch.Tensor]:
+        dim = self.dim
         t_const = _torch.zeros(size=(1, dim), dtype=_torch.double)
         t_const[0,0] = 1.0
-        t_param_mat = t_param.reshape((dim - 1, dim))
+        return (dim, t_const)
+
+    @staticmethod
+    def torch_base(sd: Tuple[int, _torch.Tensor], t_param: _torch.Tensor) -> _torch.Tensor:
+        dim, t_const = sd
+        t_param_mat = t_param.view(dim - 1, dim)
         t = _torch.row_stack((t_const, t_param_mat))
+        # TODO: cache the row of all zeros?
         return t
diff --git a/pygsti/modelmembers/povms/conjugatedeffect.py b/pygsti/modelmembers/povms/conjugatedeffect.py
index 5af305a44..3b0b5ddec 100644
--- a/pygsti/modelmembers/povms/conjugatedeffect.py
+++ b/pygsti/modelmembers/povms/conjugatedeffect.py
@@ -80,6 +80,12 @@ def __setitem__(self, key, val):
         ret = self.columnvec.__setitem__(key, val)
         self._ptr_has_changed()
         return ret
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)

     def __getattr__(self, attr):
         #use __dict__ so no chance for recursive __getattr__
diff --git a/pygsti/modelmembers/povms/tppovm.py b/pygsti/modelmembers/povms/tppovm.py
index 80753385f..1183f5e3e 100644
--- a/pygsti/modelmembers/povms/tppovm.py
+++ b/pygsti/modelmembers/povms/tppovm.py
@@ -102,29 +102,28 @@ def to_vector(self):
         vec = _np.concatenate(effect_vecs)
         return vec

-    def stateless_data(self) -> Tuple[int, _np.ndarray]:
+    def stateless_data(self) -> Tuple[int, _torch.Tensor, int]:
         num_effects = len(self)
         complement_effect = self[self.complement_label]
         identity = complement_effect.identity.to_vector()
-        return (num_effects, identity)
-
-    @staticmethod
-    def torch_base(sd: Tuple[int, _np.ndarray], t_param: _torch.Tensor) -> _torch.Tensor:
-        num_effects, identity = sd
+        identity = identity.reshape((1, -1))  # make into a row vector
+        t_identity = _torch.from_numpy(identity)
+
         dim = identity.size
-
-        first_basis_vec = _np.zeros(dim)
-        first_basis_vec[0] = dim ** 0.25
+        first_basis_vec = _np.zeros((1,dim))
+        first_basis_vec[0,0] = dim ** 0.25

         TOL = 1e-15 * _np.sqrt(dim)
         if _np.linalg.norm(first_basis_vec - identity) > TOL:
             # Don't error out. The documentation for the class
             # clearly indicates that the meaning of "identity"
             # can be nonstandard.
             warnings.warn('Unexpected normalization!')
+        return (num_effects, t_identity, dim)

-        identity = identity.reshape((1, -1))  # make into a row vector
-        t_identity = _torch.from_numpy(identity)
-        t_param_mat = t_param.reshape((num_effects - 1, dim))
+    @staticmethod
+    def torch_base(sd: Tuple[int, _torch.Tensor, int], t_param: _torch.Tensor) -> _torch.Tensor:
+        num_effects, t_identity, dim = sd
+        t_param_mat = t_param.view(num_effects - 1, dim)
         t_func = t_identity - t_param_mat.sum(axis=0, keepdim=True)
         t = _torch.row_stack((t_param_mat, t_func))
         return t
diff --git a/pygsti/modelmembers/states/densestate.py b/pygsti/modelmembers/states/densestate.py
index 3c7df543f..2d9b17fc0 100644
--- a/pygsti/modelmembers/states/densestate.py
+++ b/pygsti/modelmembers/states/densestate.py
@@ -100,6 +100,12 @@ def __setitem__(self, key, val):
         ret = self.columnvec.__setitem__(key, val)
         self._ptr_has_changed()
         return ret
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)

     def __getattr__(self, attr):
         #use __dict__ so no chance for recursive __getattr__
diff --git a/pygsti/modelmembers/states/tpstate.py b/pygsti/modelmembers/states/tpstate.py
index 659d6da24..c74c49d78 100644
--- a/pygsti/modelmembers/states/tpstate.py
+++ b/pygsti/modelmembers/states/tpstate.py
@@ -166,13 +166,14 @@ def from_vector(self, v, close=False, dirty_value=True):
         self._ptr_has_changed()
         self.dirty = dirty_value

-    def stateless_data(self) -> Tuple[int]:
-        return (self.dim,)
+    def stateless_data(self) -> Tuple[_torch.Tensor]:
+        dim = self.dim
+        t_const = (dim ** -0.25) * _torch.ones(1, dtype=_torch.double)
+        return (t_const,)

     @staticmethod
-    def torch_base(sd: Tuple[int], t_param: _torch.Tensor) -> _torch.Tensor:
-        dim = sd[0]
-        t_const = (dim ** -0.25) * _torch.ones(1, dtype=_torch.double)
+    def torch_base(sd: Tuple[_torch.Tensor], t_param: _torch.Tensor) -> _torch.Tensor:
+        t_const = sd[0]
         t = _torch.concat((t_const, t_param))
         return t
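
Note on the stateless_data / torch_base split in the hunks above: constant tensors (the fixed first row of a TP operation, the TPPOVM identity row, the dim ** -0.25 entry of a TP state) are now built once in stateless_data and only reused inside torch_base, so they are not reconstructed on every evaluation that torch.func.jacfwd / jacrev traces. The sketch below is a simplified, self-contained illustration of that pattern; the free-standing stateless_data / torch_base functions and the driver code are hypothetical stand-ins for this illustration, not pyGSTi API.

# Illustrative sketch only (hypothetical names, not pyGSTi API).
import torch


def stateless_data(dim):
    # Constant part: built once, outside the function being differentiated.
    t_const = (dim ** -0.25) * torch.ones(1, dtype=torch.double)
    return (t_const,)


def torch_base(sd, t_param):
    # Differentiable part: only tensor algebra on the traced parameter tensor.
    (t_const,) = sd
    return torch.concat((t_const, t_param))


dim = 4
sd = stateless_data(dim)                      # reused across every Jacobian call
free_params = torch.zeros(dim - 1, dtype=torch.double)

# Forward-mode Jacobian of the parameter -> superket map (jacrev works the same way).
J = torch.func.jacfwd(lambda fp: torch_base(sd, fp))(free_params)
print(J.shape)  # torch.Size([4, 3])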