diff --git a/cmaes/_cma.py b/cmaes/_cma.py index 717168e..6e7b6c4 100644 --- a/cmaes/_cma.py +++ b/cmaes/_cma.py @@ -378,9 +378,13 @@ def tell(self, solutions: list[tuple[np.ndarray, float]]) -> None: # (eq.47) rank_one = np.outer(self._pc, self._pc) + rank_mu = np.sum( - np.array([w * np.outer(y, y) for w, y in zip(w_io, y_k)]), axis=0 + w_io.reshape(-1, 1, 1) * np.einsum("...i,...j->...ij", y_k, y_k), axis=0 ) + # The above line is equivalent to: + # rank_mu = np.sum(np.array([w * np.outer(y, y) for w, y in zip(w_io, y_k)]), axis=0) + self._C = ( ( 1