diff --git a/README.md b/README.md
index ae359df..0efd383 100644
--- a/README.md
+++ b/README.md
@@ -3,12 +3,10 @@
Install |
-Quickstart |
+Quickstart (Colab) |
+Tutorials |
+Docs |
Testimonials |
-Demos |
-Usage |
-Gotchas |
-Tutorials
# Computations that save, query and version themselves
@@ -38,20 +36,12 @@ two tools:
pip install git+https://github.com/amakelov/mandala
```
-# Quickstart
-
-[Run in Colab](https://colab.research.google.com/github/amakelov/mandala/blob/master/mandala/_next/tutorials/hello.ipynb)
-
-# Documentation
-TODO: link
-
# Tutorials
- see the ["Hello world!"
- tutorial](https://github.com/amakelov/mandala/blob/master/tutorials/00_hello.ipynb)
+ tutorial](https://github.com/amakelov/mandala/blob/master/tutorials/hello.ipynb)
for a 2-minute introduction to the library's main features
-- See [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/01_random_forest.ipynb)
-for a more realistic example of a machine learning project managed by Mandala.
-- TODO: dependency tracking
+- see [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/ml.ipynb)
+for a more realistic example of managing a machine learning project.
# FAQs
diff --git a/docs/docs/04_versions.md b/docs/docs/04_versions.md
index 810f0a6..bdc65c7 100644
--- a/docs/docs/04_versions.md
+++ b/docs/docs/04_versions.md
@@ -1,5 +1,22 @@
# Changing `@op`s and managing versions
+It should be easy to change your code and have the storage respond in a correct
+way (e.g., recompute a call **only** when the logic behind it has changed).
+`mandala` provides the following mechanisms to do that:
+- **automatic per-call dependency tracking**: every `@op` call records the
+functions it called along the way. This allows the `storage` to automatically
+know, given some inputs, whether a past call for these inputs can be reused
+given the current state of the code. This is a very fine-grained notion of
+reuse.
+- **marking changes as breaking vs non-breaking**: when a change to an `@op` or
+its dependencies is detected, you can choose to mark it as breaking the calls that depend on it
+or not. However, **breaking changes are generally more fool-proof**; see [caveats of non-breaking changes](#caveats-of-marking-changes-as-non-breaking).
+- **content-based versioning**: the current state of the codebase uniquely
+determines the version each `@op` is in. There are no arbitrary names attached
+to versions. The versions for each `@op` can be inspected in a `git`-like data
+structure.
+
+## Enabling and configuring versioning
Passing a value to the `deps_path` parameter of the `Storage` class enables
dependency tracking and versioning. This means that any time a memoized function
*actually executes* (instead of reusing a past call's results), it keeps track
@@ -7,10 +24,20 @@ of the functions and global variables it accesses along the way.
Usually, the functions we want to track are limited to user-defined ones (you
typically don't want to track changes in installed libraries!):
-- Setting `deps_path` to `"__main__"` will only look for dependencies `f` defined in the current interactive session or process (as determined by `f.__module__`).
-- Setting it to a folder will only look for dependencies defined in this folder.
-### Caveat: The `@track` decorator
+- Setting `deps_path` to `"__main__"` will only look for dependencies `f`
+defined in the current interactive session or process (as determined by
+`f.__module__`).
+- Setting it to a folder will only look for dependencies defined in this folder.
+
+
+```python
+from mandala.imports import Storage, op, track
+
+storage = Storage(deps_path='__main__')
+```
+
+## The `@track` decorator
The most efficient and reliable implementation of dependency tracking currently
requires you to explicitly put `@track` on non-memoized functions and classes
you want to track. This limitation may be lifted in the future, but at the cost
@@ -19,26 +46,17 @@ current local scope that originate in given paths).
The alternative (experimental) decorator implementation is based on
`sys.settrace`. Limitations are described in this [blog
-post](https://amakelov.github.io/blog/deps/#syssettrace).
+post](https://amakelov.github.io/blog/deps/#syssettrace))
-### What is a version of an `@op`?
-A **version** for an `@op` is (to a first approximation) a collection of
-- hashes of the source code of functions and methods;
-- hashes of values of global variables
+## Examining the captured versions
+Let's run a small ML pipeline, where we optionally apply scaling to the data,
+introducing a non-`@op` dependency for some of the calls:
-accessed when a call to this `@op` was executed. Even if you don't change
-anything in the code, a single function can have multiple versions if it invokes
-different dependencies for different calls.
-### Versioning in action
-For example, consider this code:
```python
-import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
-from mandala.imports import Storage, op, track
-from typing import Tuple, Any
N_CLASS = 10
@@ -55,7 +73,7 @@ def load_data():
def train_model(X, y, scale=False):
if scale:
X = scale_data(X)
- return LogisticRegression().fit(X, y)
+ return LogisticRegression(max_iter=1000, solver='liblinear').fit(X, y)
@op
def eval_model(model, X, y, scale=False):
@@ -63,18 +81,94 @@ def eval_model(model, X, y, scale=False):
X = scale_data(X)
return model.score(X, y)
-storage = Storage(deps_path='__main__')
-
with storage:
X, y = load_data()
for scale in [False, True]:
model = train_model(X, y, scale=scale)
acc = eval_model(model, X, y, scale=scale)
```
-When you run it, `train_model` and `eval_model` will each have two versions -
-one that depends on `scale_data` and one that doesn't. You can confirm this by
-calling `storage.versions(train_model)`. Now suppose we make some changes
-and re-run:
+
+Now `train_model` and `eval_model` will each have two versions - one that
+depends on `scale_data` and one that doesn't. You can confirm this by calling
+e.g. `storage.versions(train_model)`:
+
+
+```python
+storage.versions(train_model)
+```
+
+
+
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ### Dependencies for version of function train_model from module __main__ │ +│ ### content_version_id=db93a1e9c60fb37868575845a7afe47d │ +│ ### semantic_version_id=2acaa8919ddd4b5d8846f1f2d15bc971 │ +│ │ +│ ################################################################################ │ +│ ### IN MODULE "__main__" │ +│ ################################################################################ │ +│ │ +│ @op │ +│ def train_model(X, y, scale=False): │ +│ if scale: │ +│ X = scale_data(X) │ +│ return LogisticRegression(max_iter=1000, solver='liblinear').fit(X, y) │ +│ │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ### Dependencies for version of function train_model from module __main__ │ +│ ### content_version_id=2717c55fbbbb60442535a8dea0c81f67 │ +│ ### semantic_version_id=4674057d19bbf217687dd9dabe01df36 │ +│ │ +│ ################################################################################ │ +│ ### IN MODULE "__main__" │ +│ ################################################################################ │ +│ │ +│ @track # to track a non-memoized function as a dependency │ +│ def scale_data(X): │ +│ return StandardScaler(with_mean=True, with_std=False).fit_transform(X) │ +│ │ +│ @op │ +│ def train_model(X, y, scale=False): │ +│ if scale: │ +│ X = scale_data(X) │ +│ return LogisticRegression(max_iter=1000, solver='liblinear').fit(X, y) │ +│ │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ++ + + +## Making changes and sorting them into breaking and non-breaking +Now suppose we make some changes and re-run: + +- we change the value of the global variable `N_CLASS`; +- we change the code of `scale_data` in a semantically meaningful (i.e., +breaking) way +- we change the code of `eval_model` in a "cosmetic" way that can be considered +non-breaking. + +When entering the `storage` block, the storage will detect the changes in +the tracked components, and for each change will present you with the functions +affected: + +- `N_CLASS` is a dependency for `load_data`; +- `scale_data` is a dependency for the calls to `train_model` and `eval_model` + which had `scale=True`; +- `eval_model` is a dependency for itself. + + +```python +### simulate user input non-interactively +from unittest.mock import patch + +def mock_input(prompts): + it = iter(prompts) + def mock_input_func(*args): + return next(it) + return mock_input_func +``` + + ```python N_CLASS = 5 @@ -88,40 +182,209 @@ def eval_model(model, X, y, scale=False): X = scale_data(X) return round(model.score(X, y), 2) -with storage: - X, y = load_data() - for scale in [False, True]: - model = train_model(X, y, scale=scale) - acc = eval_model(model, X, y, scale=scale) +answers = ['y', 'n', 'y'] + +with patch('builtins.input', mock_input(answers)): + with storage: + X, y = load_data() + for scale in [False, True]: + model = train_model(X, y, scale=scale) + acc = eval_model(model, X, y, scale=scale) ``` -When entering the `storage` block, the storage will detect the changes in -the tracked components, and for each change will present you with the functions -affected: -- `N_CLASS` is a dependency for `load_data`; -- `scale_data` is a dependency for the calls to `train_model` and `eval_model` - which had `scale=True`; -- `eval_model` is a dependency for itself. -### Semantic vs content changes and versions -For each change to the content of some dependency (the source code of a function -or the value of a global variable), you can choose whether this content change -is also a **semantic** change. A semantic change will cause all calls that -have accessed this dependency to not appear memoized **with respect to the new -state of the code**. The content versions of a single dependency are organized -in a `git`-like DAG (currently, tree) that can be inspected using -`storage.sources(f)` for functions. + CHANGE DETECTED in N_CLASS from module __main__ + Dependent components: + Version of "load_data" from module "__main__" (content: 426c4c8c56d7c0d6374095c7d4a4974f, semantic: 7d0732b9bfb31e5e2211e0122651a624) + + + +
╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮ +│ 1 -10 │ +│ 2 +5 │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ++ + + + Does this change require recomputation of dependent calls? + WARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them. + Answer: [y]es/[n]o/[a]bort + You answered: "y" + CHANGE DETECTED in eval_model from module __main__ + Dependent components: + Version of "eval_model" from module "__main__" (content: 955b2a683de8dacf624047c0e020140a, semantic: c847d6dc3f23c176e6c8bf9e7006576a) + Version of "eval_model" from module "__main__" (content: 5bdcd6ffc4888990d8922aa85795198d, semantic: 4e1d702e9797ebba156831294de46425) + + + +
╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮ +│ 1 if scale: │ +│ 2 X = scale_data(X) │ +│ 3 - return model.score(X, y) │ +│ 4 + return round(model.score(X, y), 2) │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ++ + + + Does this change require recomputation of dependent calls? + WARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them. + Answer: [y]es/[n]o/[a]bort + You answered: "n" + CHANGE DETECTED in scale_data from module __main__ + Dependent components: + Version of "train_model" from module "__main__" (content: 2717c55fbbbb60442535a8dea0c81f67, semantic: 4674057d19bbf217687dd9dabe01df36) + Version of "eval_model" from module "__main__" (content: 5bdcd6ffc4888990d8922aa85795198d, semantic: 4e1d702e9797ebba156831294de46425) + + + +
╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮ +│ 1 -@track # to track a non-memoized function as a dependency │ +│ 2 +@track │ +│ 3 def scale_data(X): │ +│ 4 - return StandardScaler(with_mean=True, with_std=False).fit_transform(X) │ +│ 5 + return StandardScaler(with_mean=True, with_std=True).fit_transform(X) │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ++ + + + Does this change require recomputation of dependent calls? + WARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them. + Answer: [y]es/[n]o/[a]bort + You answered: "y" + + +When a change is detected, the UI: + +- shows the diffs in each function, +- gives you a list of which `@op`s' versions are affected by each change +- lets you choose if the change is breaking or non-breaking + +We can check what happened by constructing a computation frame: + + +```python +cf = storage.cf(eval_model).expand_all() +cf.draw(verbose=True) +``` + + + + + + + +We see that `load_data` has two versions in use, whereas `train_model` and +`eval_model` both have three. Which ones? Again, call `versions` to find out. +For example, with `eval_model`, we have 4 different content versions, that +overall span 3 semantically different versions: + + +```python +storage.versions(eval_model) +``` + + +
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ### Dependencies for version of function eval_model from module __main__ │ +│ ### content_version_id=955b2a683de8dacf624047c0e020140a │ +│ ### semantic_version_id=c847d6dc3f23c176e6c8bf9e7006576a │ +│ │ +│ ################################################################################ │ +│ ### IN MODULE "__main__" │ +│ ################################################################################ │ +│ │ +│ @op │ +│ def eval_model(model, X, y, scale=False): │ +│ if scale: │ +│ X = scale_data(X) │ +│ return model.score(X, y) │ +│ │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ### Dependencies for version of function eval_model from module __main__ │ +│ ### content_version_id=5bdcd6ffc4888990d8922aa85795198d │ +│ ### semantic_version_id=4e1d702e9797ebba156831294de46425 │ +│ │ +│ ################################################################################ │ +│ ### IN MODULE "__main__" │ +│ ################################################################################ │ +│ │ +│ @op │ +│ def eval_model(model, X, y, scale=False): │ +│ if scale: │ +│ X = scale_data(X) │ +│ return model.score(X, y) │ +│ │ +│ @track # to track a non-memoized function as a dependency │ +│ def scale_data(X): │ +│ return StandardScaler(with_mean=True, with_std=False).fit_transform(X) │ +│ │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ### Dependencies for version of function eval_model from module __main__ │ +│ ### content_version_id=b50e3e2529b811e226d2bb39a572a5e4 │ +│ ### semantic_version_id=c847d6dc3f23c176e6c8bf9e7006576a │ +│ │ +│ ################################################################################ │ +│ ### IN MODULE "__main__" │ +│ ################################################################################ │ +│ │ +│ @op │ +│ def eval_model(model, X, y, scale=False): │ +│ if scale: │ +│ X = scale_data(X) │ +│ return model.score(X, y) │ +│ │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ### Dependencies for version of function eval_model from module __main__ │ +│ ### content_version_id=136129b20d9a3a3795e88ba8cf89b115 │ +│ ### semantic_version_id=f2573de2a6c25b390fc86d665ea85687 │ +│ │ +│ ################################################################################ │ +│ ### IN MODULE "__main__" │ +│ ################################################################################ │ +│ │ +│ @op │ +│ def eval_model(model, X, y, scale=False): │ +│ if scale: │ +│ X = scale_data(X) │ +│ return model.score(X, y) │ +│ │ +│ @track │ +│ def scale_data(X): │ +│ return StandardScaler(with_mean=True, with_std=True).fit_transform(X) │ +│ │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ++ + + +### So what really is a version of an `@op`? +A **version** for an `@op` is a collection of + +- (hashes of) the source code of functions and methods +- (hashes of) values of global variables + +at the time when a call to this `@op` was executed. Even if you don't change +anything in the code, a single function can have multiple versions if it invokes +different dependencies for different calls. ### Going back in time Since the versioning system is content-based, simply restoring an old state of the code makes the storage automatically recognize which "world" it's in, and which calls are memoized in this world. -### A warning about non-semantic changes -The main motivation for allowing non-semantic changes is to maintain clarity in -the storage when doing routine code improvements (refactoring, comments, -logging). **However**, non-semantic changes should be applied with care. Apart from -being prone to errors (you wrongly conclude that a change has no effect on -semantics when it does), they can also introduce **invisible dependencies**: -suppose you factor a function out of some dependency and mark the change -non-semantic. Then the newly extracted function may in reality be a dependency -of the existing calls, but this goes unnoticed by the system. +### Caveats of marking changes as non-breaking +The main motivation for allowing non-breaking changes is to maintain the storage +when doing routine code improvements (refactoring, comments, logging). + +**However**, non-semantic changes should be applied with care. Apart from being +prone to errors (you wrongly conclude that a change has no effect on semantics +when it does), they can also introduce **invisible dependencies**: suppose you +factor a function out of some dependency and mark the change non-semantic. Then +the newly extracted function may in reality be a dependency of the existing +calls, but this goes unnoticed by the system. Consequently, changes in this +dependency may go unnoticed by the versioning algorithm. diff --git a/docs/docs/04_versions_files/04_versions_12_0.svg b/docs/docs/04_versions_files/04_versions_12_0.svg new file mode 100644 index 0000000..423fe2a --- /dev/null +++ b/docs/docs/04_versions_files/04_versions_12_0.svg @@ -0,0 +1,160 @@ + + + + + diff --git a/docs/docs/stylesheets/extra.css b/docs/docs/stylesheets/extra.css index d9ee1d1..135969c 100644 --- a/docs/docs/stylesheets/extra.css +++ b/docs/docs/stylesheets/extra.css @@ -10,3 +10,7 @@ --md-code-bg-color: #eee8d5; --md-code-fg-color: #657b83; } + +a { + text-decoration: underline; +} \ No newline at end of file diff --git a/mandala/deps/shallow_versions.py b/mandala/deps/shallow_versions.py index 98282f8..8324734 100644 --- a/mandala/deps/shallow_versions.py +++ b/mandala/deps/shallow_versions.py @@ -235,11 +235,27 @@ def commit(self, content: T, is_semantic_change: Optional[bool] = None) -> str: self.content_adapter.get_presentable_content(content), self.get_presentable_content(commit=self.head), ) - print( - _get_colorized_diff( - current=presentable_diff[1], new=presentable_diff[0] + if Config.has_rich: + colorized_diff = _get_colorized_diff( + current=presentable_diff[1], new=presentable_diff[0], + colorize=False, + ) + panel = Panel( + Syntax( + colorized_diff, + lexer="diff", + line_numbers=True, + theme="solarized-light", + ), + title="Diff", ) - ) + rich.print(panel) + else: + colorized_diff = _get_colorized_diff( + current=presentable_diff[1], new=presentable_diff[0], + colorize=True, + ) + print(colorized_diff) answer = ask_user( question="Does this change require recomputation of dependent calls?\nWARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\nAnswer: [y]es/[n]o/[a]bort", valid_options=["y", "n", "a"], diff --git a/mandala/deps/versioner.py b/mandala/deps/versioner.py index 3552080..8ee9eef 100644 --- a/mandala/deps/versioner.py +++ b/mandala/deps/versioner.py @@ -258,7 +258,7 @@ def sync_codebase(self, code_state: CodeState): ) print(f"CHANGE DETECTED in {component[1]} from module {component[0]}") print(f"Dependent components:\n{dependent_versions_presentation}") - print(f"===DIFF===:") + # print(f"===DIFF===:") dag.sync(content=content) # update the DAGs if all commits succeeded self.component_dags = dags diff --git a/mandala/docs/01_storage_and_ops.ipynb b/mandala/docs/01_storage_and_ops.ipynb index de22dc8..4a9270e 100644 --- a/mandala/docs/01_storage_and_ops.ipynb +++ b/mandala/docs/01_storage_and_ops.ipynb @@ -90,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-07-02T21:40:11.743630Z", @@ -104,13 +104,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "AtomRef(17, hid='43b...', cid='89c...')\n" + "AtomRef(42, hid='168...', cid='d92...')\n" ] } ], "source": [ "with storage: # all `@op` calls inside this block use `storage`\n", - " s = sum_args(1, 2, 3, 4, c=6,)\n", + " s = sum_args(6, 7, 8, 9, c=11,)\n", " print(s)" ] }, diff --git a/mandala/docs/04_versions.ipynb b/mandala/docs/04_versions.ipynb index da77288..2d731d3 100644 --- a/mandala/docs/04_versions.ipynb +++ b/mandala/docs/04_versions.ipynb @@ -5,7 +5,29 @@ "metadata": {}, "source": [ "# Changing `@op`s and managing versions\n", + "It should be easy to change your code and have the storage respond in a correct\n", + "way (e.g., recompute a call **only** when the logic behind it has changed).\n", + "`mandala` provides the following mechanisms to do that:\n", "\n", + "- **automatic per-call dependency tracking**: every `@op` call records the\n", + "functions it called along the way. This allows the `storage` to automatically\n", + "know, given some inputs, whether a past call for these inputs can be reused \n", + "given the current state of the code. This is a very fine-grained notion of\n", + "reuse.\n", + "- **marking changes as breaking vs non-breaking**: when a change to an `@op` or\n", + "its dependencies is detected, you can choose to mark it as breaking the calls that depend on it\n", + "or not. However, **breaking changes are generally more fool-proof**; see [caveats of non-breaking changes](#caveats-of-marking-changes-as-non-breaking).\n", + "- **content-based versioning**: the current state of the codebase uniquely\n", + "determines the version each `@op` is in. There are no arbitrary names attached\n", + "to versions. The versions for each `@op` can be inspected in a `git`-like data\n", + "structure." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Enabling and configuring versioning\n", "Passing a value to the `deps_path` parameter of the `Storage` class enables\n", "dependency tracking and versioning. This means that any time a memoized function\n", "*actually executes* (instead of reusing a past call's results), it keeps track\n", @@ -13,10 +35,36 @@ "\n", "Usually, the functions we want to track are limited to user-defined ones (you\n", "typically don't want to track changes in installed libraries!):\n", - "- Setting `deps_path` to `\"__main__\"` will only look for dependencies `f` defined in the current interactive session or process (as determined by `f.__module__`).\n", - "- Setting it to a folder will only look for dependencies defined in this folder. \n", "\n", - "### Caveat: The `@track` decorator\n", + "- Setting `deps_path` to `\"__main__\"` will only look for dependencies `f`\n", + "defined in the current interactive session or process (as determined by\n", + "`f.__module__`).\n", + "- Setting it to a folder will only look for dependencies defined in this folder." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2024-07-02T23:37:32.003071Z", + "iopub.status.busy": "2024-07-02T23:37:32.002508Z", + "iopub.status.idle": "2024-07-02T23:37:34.082928Z", + "shell.execute_reply": "2024-07-02T23:37:34.082372Z" + } + }, + "outputs": [], + "source": [ + "from mandala.imports import Storage, op, track\n", + "\n", + "storage = Storage(deps_path='__main__')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `@track` decorator\n", "The most efficient and reliable implementation of dependency tracking currently\n", "requires you to explicitly put `@track` on non-memoized functions and classes\n", "you want to track. This limitation may be lifted in the future, but at the cost\n", @@ -25,26 +73,34 @@ "\n", "The alternative (experimental) decorator implementation is based on\n", "`sys.settrace`. Limitations are described in this [blog\n", - "post](https://amakelov.github.io/blog/deps/#syssettrace).\n", - "\n", - "### What is a version of an `@op`?\n", - "A **version** for an `@op` is (to a first approximation) a collection of\n", - "- hashes of the source code of functions and methods;\n", - "- hashes of values of global variables\n", - "\n", - "accessed when a call to this `@op` was executed. Even if you don't change\n", - "anything in the code, a single function can have multiple versions if it invokes\n", - "different dependencies for different calls. \n", - "\n", - "### Versioning in action\n", - "For example, consider this code:\n", - "```python\n", - "import numpy as np\n", + "post](https://amakelov.github.io/blog/deps/#syssettrace))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examining the captured versions\n", + "Let's run a small ML pipeline, where we optionally apply scaling to the data,\n", + "introducing a non-`@op` dependency for some of the calls:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2024-07-02T23:37:34.086211Z", + "iopub.status.busy": "2024-07-02T23:37:34.085833Z", + "iopub.status.idle": "2024-07-02T23:37:35.163955Z", + "shell.execute_reply": "2024-07-02T23:37:35.163223Z" + } + }, + "outputs": [], + "source": [ "from sklearn.datasets import load_digits\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.preprocessing import StandardScaler\n", - "from mandala.imports import Storage, op, track\n", - "from typing import Tuple, Any\n", "\n", "N_CLASS = 10\n", "\n", @@ -61,7 +117,7 @@ "def train_model(X, y, scale=False):\n", " if scale:\n", " X = scale_data(X)\n", - " return LogisticRegression().fit(X, y)\n", + " return LogisticRegression(max_iter=1000, solver='liblinear').fit(X, y)\n", "\n", "@op\n", "def eval_model(model, X, y, scale=False):\n", @@ -69,19 +125,296 @@ " X = scale_data(X)\n", " return model.score(X, y)\n", "\n", - "storage = Storage(deps_path='__main__')\n", - "\n", "with storage:\n", " X, y = load_data()\n", " for scale in [False, True]:\n", " model = train_model(X, y, scale=scale)\n", - " acc = eval_model(model, X, y, scale=scale)\n", - "```\n", - "When you run it, `train_model` and `eval_model` will each have two versions -\n", - "one that depends on `scale_data` and one that doesn't. You can confirm this by\n", - "calling `storage.versions(train_model)`. Now suppose we make some changes\n", - "and re-run:\n", - "```python\n", + " acc = eval_model(model, X, y, scale=scale)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now `train_model` and `eval_model` will each have two versions - one that\n", + "depends on `scale_data` and one that doesn't. You can confirm this by calling\n", + "e.g. `storage.versions(train_model)`:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2024-07-02T23:37:35.167116Z", + "iopub.status.busy": "2024-07-02T23:37:35.166713Z", + "iopub.status.idle": "2024-07-02T23:37:35.241357Z", + "shell.execute_reply": "2024-07-02T23:37:35.240816Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ ### Dependencies for version of function train_model from module __main__ │\n", + "│ ### content_version_id=db93a1e9c60fb37868575845a7afe47d │\n", + "│ ### semantic_version_id=2acaa8919ddd4b5d8846f1f2d15bc971 │\n", + "│ │\n", + "│ ################################################################################ │\n", + "│ ### IN MODULE \"__main__\" │\n", + "│ ################################################################################ │\n", + "│ │\n", + "│ @op │\n", + "│ def train_model(X, y, scale=False): │\n", + "│ if scale: │\n", + "│ X = scale_data(X) │\n", + "│ return LogisticRegression(max_iter=1000, solver='liblinear').fit(X, y) │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ ### Dependencies for version of function train_model from module __main__ │\n", + "│ ### content_version_id=2717c55fbbbb60442535a8dea0c81f67 │\n", + "│ ### semantic_version_id=4674057d19bbf217687dd9dabe01df36 │\n", + "│ │\n", + "│ ################################################################################ │\n", + "│ ### IN MODULE \"__main__\" │\n", + "│ ################################################################################ │\n", + "│ │\n", + "│ @track # to track a non-memoized function as a dependency │\n", + "│ def scale_data(X): │\n", + "│ return StandardScaler(with_mean=True, with_std=False).fit_transform(X) │\n", + "│ │\n", + "│ @op │\n", + "│ def train_model(X, y, scale=False): │\n", + "│ if scale: │\n", + "│ X = scale_data(X) │\n", + "│ return LogisticRegression(max_iter=1000, solver='liblinear').fit(X, y) │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function train_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=db93a1e9c60fb37868575845a7afe47d\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=2acaa8919ddd4b5d8846f1f2d15bc971\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mLogisticRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmax_iter\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1000\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227msolver\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mliblinear\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function train_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=2717c55fbbbb60442535a8dea0c81f67\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=4674057d19bbf217687dd9dabe01df36\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@track\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[3;38;2;147;161;161;48;2;253;246;227m# to track a non-memoized function as a dependency\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mStandardScaler\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mwith_mean\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mTrue\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mwith_std\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit_transform\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mLogisticRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmax_iter\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1000\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227msolver\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mliblinear\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "storage.versions(train_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Making changes and sorting them into breaking and non-breaking\n", + "Now suppose we make some changes and re-run:\n", + "\n", + "- we change the value of the global variable `N_CLASS`;\n", + "- we change the code of `scale_data` in a semantically meaningful (i.e.,\n", + "breaking) way\n", + "- we change the code of `eval_model` in a \"cosmetic\" way that can be considered\n", + "non-breaking.\n", + "\n", + "When entering the `storage` block, the storage will detect the changes in\n", + "the tracked components, and for each change will present you with the functions\n", + "affected:\n", + "\n", + "- `N_CLASS` is a dependency for `load_data`;\n", + "- `scale_data` is a dependency for the calls to `train_model` and `eval_model`\n", + " which had `scale=True`;\n", + "- `eval_model` is a dependency for itself." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2024-07-02T23:37:35.287994Z", + "iopub.status.busy": "2024-07-02T23:37:35.287434Z", + "iopub.status.idle": "2024-07-02T23:37:35.315147Z", + "shell.execute_reply": "2024-07-02T23:37:35.314507Z" + } + }, + "outputs": [], + "source": [ + "### simulate user input non-interactively\n", + "from unittest.mock import patch\n", + "\n", + "def mock_input(prompts):\n", + " it = iter(prompts)\n", + " def mock_input_func(*args):\n", + " return next(it)\n", + " return mock_input_func" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2024-07-02T23:37:35.318043Z", + "iopub.status.busy": "2024-07-02T23:37:35.317548Z", + "iopub.status.idle": "2024-07-02T23:37:35.518390Z", + "shell.execute_reply": "2024-07-02T23:37:35.517687Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CHANGE DETECTED in N_CLASS from module __main__\n", + "Dependent components:\n", + " Version of \"load_data\" from module \"__main__\" (content: 426c4c8c56d7c0d6374095c7d4a4974f, semantic: 7d0732b9bfb31e5e2211e0122651a624)\n" + ] + }, + { + "data": { + "text/html": [ + "
╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮\n", + "│ 1 -10 │\n", + "│ 2 +5 │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m1 \u001b[0m\u001b[38;2;220;50;47;48;2;253;246;227m-10\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m2 \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227m+5\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Does this change require recomputation of dependent calls?\n", + "WARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\n", + "Answer: [y]es/[n]o/[a]bort \n", + "You answered: \"y\"\n", + "CHANGE DETECTED in eval_model from module __main__\n", + "Dependent components:\n", + " Version of \"eval_model\" from module \"__main__\" (content: 955b2a683de8dacf624047c0e020140a, semantic: c847d6dc3f23c176e6c8bf9e7006576a)\n", + " Version of \"eval_model\" from module \"__main__\" (content: 5bdcd6ffc4888990d8922aa85795198d, semantic: 4e1d702e9797ebba156831294de46425)\n" + ] + }, + { + "data": { + "text/html": [ + "
╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮\n", + "│ 1 if scale: │\n", + "│ 2 X = scale_data(X) │\n", + "│ 3 - return model.score(X, y) │\n", + "│ 4 + return round(model.score(X, y), 2) │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m1 \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m if scale:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m2 \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m X = scale_data(X)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m3 \u001b[0m\u001b[38;2;220;50;47;48;2;253;246;227m- return model.score(X, y)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m4 \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227m+ return round(model.score(X, y), 2)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Does this change require recomputation of dependent calls?\n", + "WARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\n", + "Answer: [y]es/[n]o/[a]bort \n", + "You answered: \"n\"\n", + "CHANGE DETECTED in scale_data from module __main__\n", + "Dependent components:\n", + " Version of \"train_model\" from module \"__main__\" (content: 2717c55fbbbb60442535a8dea0c81f67, semantic: 4674057d19bbf217687dd9dabe01df36)\n", + " Version of \"eval_model\" from module \"__main__\" (content: 5bdcd6ffc4888990d8922aa85795198d, semantic: 4e1d702e9797ebba156831294de46425)\n" + ] + }, + { + "data": { + "text/html": [ + "
╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮\n", + "│ 1 -@track # to track a non-memoized function as a dependency │\n", + "│ 2 +@track │\n", + "│ 3 def scale_data(X): │\n", + "│ 4 - return StandardScaler(with_mean=True, with_std=False).fit_transform(X) │\n", + "│ 5 + return StandardScaler(with_mean=True, with_std=True).fit_transform(X) │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "╭───────────────────────────────────────────────────── Diff ──────────────────────────────────────────────────────╮\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m1 \u001b[0m\u001b[38;2;220;50;47;48;2;253;246;227m-@track # to track a non-memoized function as a dependency\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m2 \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227m+@track\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m3 \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mdef scale_data(X):\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m4 \u001b[0m\u001b[38;2;220;50;47;48;2;253;246;227m- return StandardScaler(with_mean=True, with_std=False).fit_transform(X)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[1;38;2;116;135;140;48;2;253;246;227m \u001b[0m\u001b[38;2;207;209;198;48;2;253;246;227m5 \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227m+ return StandardScaler(with_mean=True, with_std=True).fit_transform(X)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Does this change require recomputation of dependent calls?\n", + "WARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\n", + "Answer: [y]es/[n]o/[a]bort \n", + "You answered: \"y\"\n" + ] + } + ], + "source": [ "N_CLASS = 5\n", "\n", "@track\n", @@ -94,43 +427,432 @@ " X = scale_data(X)\n", " return round(model.score(X, y), 2)\n", "\n", - "with storage:\n", - " X, y = load_data()\n", - " for scale in [False, True]:\n", - " model = train_model(X, y, scale=scale)\n", - " acc = eval_model(model, X, y, scale=scale)\n", - "```\n", - "When entering the `storage` block, the storage will detect the changes in\n", - "the tracked components, and for each change will present you with the functions\n", - "affected:\n", - "- `N_CLASS` is a dependency for `load_data`;\n", - "- `scale_data` is a dependency for the calls to `train_model` and `eval_model`\n", - " which had `scale=True`;\n", - "- `eval_model` is a dependency for itself.\n", + "answers = ['y', 'n', 'y']\n", + "\n", + "with patch('builtins.input', mock_input(answers)):\n", + " with storage:\n", + " X, y = load_data()\n", + " for scale in [False, True]:\n", + " model = train_model(X, y, scale=scale)\n", + " acc = eval_model(model, X, y, scale=scale)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When a change is detected, the UI:\n", + "\n", + "- shows the diffs in each function,\n", + "- gives you a list of which `@op`s' versions are affected by each change\n", + "- lets you choose if the change is breaking or non-breaking\n", + "\n", + "We can check what happened by constructing a computation frame:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2024-07-02T23:37:35.521514Z", + "iopub.status.busy": "2024-07-02T23:37:35.521159Z", + "iopub.status.idle": "2024-07-02T23:37:35.907219Z", + "shell.execute_reply": "2024-07-02T23:37:35.906644Z" + } + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ ### Dependencies for version of function eval_model from module __main__ │\n", + "│ ### content_version_id=955b2a683de8dacf624047c0e020140a │\n", + "│ ### semantic_version_id=c847d6dc3f23c176e6c8bf9e7006576a │\n", + "│ │\n", + "│ ################################################################################ │\n", + "│ ### IN MODULE \"__main__\" │\n", + "│ ################################################################################ │\n", + "│ │\n", + "│ @op │\n", + "│ def eval_model(model, X, y, scale=False): │\n", + "│ if scale: │\n", + "│ X = scale_data(X) │\n", + "│ return model.score(X, y) │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ ### Dependencies for version of function eval_model from module __main__ │\n", + "│ ### content_version_id=5bdcd6ffc4888990d8922aa85795198d │\n", + "│ ### semantic_version_id=4e1d702e9797ebba156831294de46425 │\n", + "│ │\n", + "│ ################################################################################ │\n", + "│ ### IN MODULE \"__main__\" │\n", + "│ ################################################################################ │\n", + "│ │\n", + "│ @op │\n", + "│ def eval_model(model, X, y, scale=False): │\n", + "│ if scale: │\n", + "│ X = scale_data(X) │\n", + "│ return model.score(X, y) │\n", + "│ │\n", + "│ @track # to track a non-memoized function as a dependency │\n", + "│ def scale_data(X): │\n", + "│ return StandardScaler(with_mean=True, with_std=False).fit_transform(X) │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ ### Dependencies for version of function eval_model from module __main__ │\n", + "│ ### content_version_id=b50e3e2529b811e226d2bb39a572a5e4 │\n", + "│ ### semantic_version_id=c847d6dc3f23c176e6c8bf9e7006576a │\n", + "│ │\n", + "│ ################################################################################ │\n", + "│ ### IN MODULE \"__main__\" │\n", + "│ ################################################################################ │\n", + "│ │\n", + "│ @op │\n", + "│ def eval_model(model, X, y, scale=False): │\n", + "│ if scale: │\n", + "│ X = scale_data(X) │\n", + "│ return model.score(X, y) │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ ### Dependencies for version of function eval_model from module __main__ │\n", + "│ ### content_version_id=136129b20d9a3a3795e88ba8cf89b115 │\n", + "│ ### semantic_version_id=f2573de2a6c25b390fc86d665ea85687 │\n", + "│ │\n", + "│ ################################################################################ │\n", + "│ ### IN MODULE \"__main__\" │\n", + "│ ################################################################################ │\n", + "│ │\n", + "│ @op │\n", + "│ def eval_model(model, X, y, scale=False): │\n", + "│ if scale: │\n", + "│ X = scale_data(X) │\n", + "│ return model.score(X, y) │\n", + "│ │\n", + "│ @track │\n", + "│ def scale_data(X): │\n", + "│ return StandardScaler(with_mean=True, with_std=True).fit_transform(X) │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function eval_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=955b2a683de8dacf624047c0e020140a\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=c847d6dc3f23c176e6c8bf9e7006576a\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227meval_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function eval_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=5bdcd6ffc4888990d8922aa85795198d\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=4e1d702e9797ebba156831294de46425\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227meval_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@track\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[3;38;2;147;161;161;48;2;253;246;227m# to track a non-memoized function as a dependency\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mStandardScaler\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mwith_mean\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mTrue\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mwith_std\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit_transform\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function eval_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=b50e3e2529b811e226d2bb39a572a5e4\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=c847d6dc3f23c176e6c8bf9e7006576a\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227meval_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function eval_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=136129b20d9a3a3795e88ba8cf89b115\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=f2573de2a6c25b390fc86d665ea85687\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227meval_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mFalse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;38;139;210;48;2;253;246;227m@track\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mscale_data\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mStandardScaler\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mwith_mean\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mTrue\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mwith_std\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mTrue\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit_transform\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", + "│ \u001b[48;2;253;246;227m \u001b[0m │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "storage.versions(eval_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Additional notes\n", + "\n", + "### So what really is a version of an `@op`?\n", + "A **version** for an `@op` is a collection of\n", + "\n", + "- (hashes of) the source code of functions and methods\n", + "- (hashes of) values of global variables\n", "\n", - "### Semantic vs content changes and versions\n", - "For each change to the content of some dependency (the source code of a function\n", - "or the value of a global variable), you can choose whether this content change\n", - "is also a **semantic** change. A semantic change will cause all calls that\n", - "have accessed this dependency to not appear memoized **with respect to the new\n", - "state of the code**. The content versions of a single dependency are organized\n", - "in a `git`-like DAG (currently, tree) that can be inspected using\n", - "`storage.sources(f)` for functions. \n", + "at the time when a call to this `@op` was executed. Even if you don't change\n", + "anything in the code, a single function can have multiple versions if it invokes\n", + "different dependencies for different calls. \n", "\n", "### Going back in time\n", "Since the versioning system is content-based, simply restoring an old state of\n", "the code makes the storage automatically recognize which \"world\" it's in, and\n", "which calls are memoized in this world.\n", "\n", - "### A warning about non-semantic changes\n", - "The main motivation for allowing non-semantic changes is to maintain clarity in\n", - "the storage when doing routine code improvements (refactoring, comments,\n", - "logging). **However**, non-semantic changes should be applied with care. Apart from\n", - "being prone to errors (you wrongly conclude that a change has no effect on\n", - "semantics when it does), they can also introduce **invisible dependencies**:\n", - "suppose you factor a function out of some dependency and mark the change\n", - "non-semantic. Then the newly extracted function may in reality be a dependency\n", - "of the existing calls, but this goes unnoticed by the system." + "### Caveats of marking changes as non-breaking\n", + "The main motivation for allowing non-breaking changes is to maintain the storage\n", + "when doing routine code improvements (refactoring, comments, logging).\n", + "\n", + "**However**, non-semantic changes should be applied with care. Apart from being\n", + "prone to errors (you wrongly conclude that a change has no effect on semantics\n", + "when it does), they can also introduce **invisible dependencies**: suppose you\n", + "factor a function out of some dependency and mark the change non-semantic. Then\n", + "the newly extracted function may in reality be a dependency of the existing\n", + "calls, but this goes unnoticed by the system. Consequently, changes in this \n", + "dependency may go unnoticed by the versioning algorithm." ] } ], diff --git a/mandala/viz.py b/mandala/viz.py index 1d7667b..95bed42 100644 --- a/mandala/viz.py +++ b/mandala/viz.py @@ -81,7 +81,8 @@ def _get_diff(current: str, new: str) -> str: def _get_colorized_diff( - current: str, new: str, style: str = "multiline", context_lines: int = 2 + current: str, new: str, style: str = "multiline", context_lines: int = 2, + colorize: bool = True ) -> str: """ Return a line-by-line colorized diff of the changes between `current` and @@ -101,11 +102,15 @@ def _get_colorized_diff( if line.startswith("-"): if style == "inline": line = line[1:] - lines.append(_colorize(line, "red")) + if colorize: + line = _colorize(line, "red") + lines.append(line) elif line.startswith("+"): if style == "inline": line = line[1:] - lines.append(_colorize(line, "green")) + if colorize: + line = _colorize(line, "green") + lines.append(line) else: lines.append(line) if style == "multiline": diff --git a/requirements.txt b/requirements.txt index 278a607..f74c524 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,4 @@ # for binder numpy >= 1.18 pandas >= 1.0 -joblib >= 1.0 -pypika >= 0.48 -pyarrow >= 8.0.0 -cityhash >= 0.2.2 -rich -scikit-learn \ No newline at end of file +joblib >= 1.0 \ No newline at end of file diff --git a/setup.py b/setup.py index f372bf0..1ac47c9 100644 --- a/setup.py +++ b/setup.py @@ -5,18 +5,10 @@ "numpy >= 1.18", "pandas >= 1.0", "joblib >= 1.0", - "pypika >= 0.48", - "pyarrow >= 8.0.0", ] extras_require = { "base": [], - "performance": [ - "cityhash >= 0.2.2", # for faster content hashing - ], - "integrations": [ - "dask[complete]", - ], "ui": [ "rich", ], @@ -24,16 +16,10 @@ "pytest >= 6.0.0", "hypothesis >= 6.0.0", "ipython", - "mongomock", - "duckdb >= 0.6", ], "demos": [ - "torch", "scikit-learn", ], - "server": [ - "pymongo", - ], } @@ -41,20 +27,13 @@ packages = [ "mandala", - "mandala.core", "mandala.deps", - "mandala.deps.tracers", - "mandala.queries", - "mandala.storages", - "mandala.storages.rel_impls", - "mandala.storages.remote_impls", - "mandala.ui", "mandala.tests", ] setup( name="mandala", - version="0.1.0", + version="v0.2.0-alpha", description="", url="https://github.com/amakelov/mandala", license="Apache 2.0", diff --git a/tutorials/00_hello.ipynb b/tutorials/00_hello.ipynb deleted file mode 100644 index 22f2ab2..0000000 --- a/tutorials/00_hello.ipynb +++ /dev/null @@ -1,202 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Hello world\n", - "In this notebook, you'll run a minimal interesting example of `mandala`. It's a\n", - "great way to quickly get a feel for the library and play with it yourself!\n", - "\n", - "If you want a more in-depth introduction with a real ML project, check out the\n", - "[the next tutorial](01_logistic.ipynb).\n", - "\n", - "## Create the storage and computational primitives\n", - "A `Storage` instance is where the results of all computations you run in a\n", - "project are stored. Importantly, **the only way to put data into a `Storage` is\n", - "to call a function**: you decorate your functions with the `@op` decorator, and\n", - "then any time you call them, the inputs and outputs for this call are stored in\n", - "the `Storage`.\n", - "\n", - "Go ahead and create a storage and two `@op`-decorated functions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from mandala.imports import *\n", - "\n", - "# create a storage for results\n", - "storage = Storage()\n", - "\n", - "@op # memoization decorator\n", - "def inc(x) -> int:\n", - " print('Hi from inc!')\n", - " return x + 1 \n", - "\n", - "@op\n", - "def add(x: int, y: int) -> int:\n", - " print('Hi from add!')\n", - " return x + y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### A note on function inputs/outputs\n", - "Currently, **`mandala` only supports functions with a fixed number of inputs and\n", - "outputs**. To help make this explicit for outputs, you must specify\n", - "the number of outputs in the return type annotation. For example, `def f() ->\n", - "int` means that `f` returns a single integer, and `def f() -> Tuple[int, int]` means\n", - "that `f` returns two integers. Functions that return nothing can leave the\n", - "return type annotation empty." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Your first `mandala`-tracked computation\n", - "The main way you use `mandala` is through \"workflows\", i.e. compositions of\n", - "`@op`-decorated functions. Running a workflow for the first time inside a\n", - "`storage.run()` block will execute the workflow and store the results in the \n", - "`storage`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run():\n", - " x = inc(20)\n", - " y = add(21, x)\n", - " print(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Running this workflow a **second** time will not re-execute it, but instead\n", - "retrieve the results from the `storage` at each function call along the way:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run():\n", - " x = inc(20)\n", - " y = add(21, x)\n", - " print(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Adding more logic to the workflow will not re-execute the parts that have\n", - "already been executed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run():\n", - " for a in [10, 20, 30]:\n", - " x = inc(a)\n", - " y = add(21, x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The workflow just executed can also be used as a jumping-off point for issuing\n", - "queries. For example, `storage.similar(...)` can be used to query for a table of\n", - "values that were computed in an analogous way to given variables:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "storage.similar(y, context=True) # use `context=True` to also get the values of dependencies" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `storage.similar` method prints out the query extracted from the\n", - "computation. For more control (or if you dislike how implicit the interface\n", - "above is), you can directly copy-paste this code into a `storage.query()` block:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.query():\n", - " a = Q() # input to computation; can match anything\n", - " a0 = Q() # input to computation; can match anything\n", - " x = inc(x=a)\n", - " y = add(x=a0, y=x)\n", - "storage.df(a, a0, x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Those are the main patterns you need to know to start playing around with\n", - "`mandala`! We invite you to go back and modify the code above by creating new\n", - "computational primitives and workflows, and see how `mandala` handles it.\n", - "\n", - "TODO: talk about versioning!" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.8", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "30c0510467e0bc33a523a84a8acb20ce0730b8eb0ee254a4b0039140f094f217" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tutorials/01_logistic.ipynb b/tutorials/01_logistic.ipynb deleted file mode 100644 index e095f40..0000000 --- a/tutorials/01_logistic.ipynb +++ /dev/null @@ -1,1048 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Logistic regression in pytorch\n", - "## What's in this tutorial?\n", - "This notebook will walk you through the basic uses of `mandala` for storing\n", - "and tracking ML experiment results. It uses logistic regression on a synthetic\n", - "dataset as a \"minimally interesting\" example of a data management use case. By\n", - "following this ML mini-project, you will learn how to\n", - "- break up an experiment into Python functions whose calls can be\n", - "tracked and queried by `mandala`;\n", - "- use `mandala`'s memoization to avoid re-running expensive computations and to\n", - "naturally interact with and grow your project (by adjusting the parameters and/or\n", - "adding new code);\n", - "- repurpose the (pure Python) code of your experiments into a *query interface*\n", - "to their results \"for free\";\n", - "- modify, or create new versions of, your experimental primitives, and have them\n", - " seamlessly interact with the results of previous runs.\n", - "\n", - "Ultimatley, the features of `mandala` work together to enable you to evolve\n", - "complex ML projects by writing only the plain-Python code that you'd write in a\n", - "temporary in-memory interactive session, yet get the benefits of a\n", - "database-backed experiment tracking system. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Import libraries" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Tuple\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "\n", - "# recommended way to import mandala functionality\n", - "from mandala.imports import *\n", - "\n", - "# for reproducibility\n", - "np.random.seed(0)\n", - "torch.random.set_rng_state(torch.manual_seed(0).get_state())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define experiment primitives\n", - "You'll break the project into two main functions: to generate the synthetic\n", - "dataset, and to train the model. Below is fairly standard `pytorch` code for\n", - "these. Note the use of `@op` to mark the functions as tracked by `mandala` -\n", - "more on that shortly:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_DIMENSION = 10\n", - "\n", - "# main `mandala` decorator; like @functools.lru_cache, but with extra functionality.\n", - "# Currently, you must specify the exact number of inputs (i.e., no *args or **kwargs),\n", - "# and the number of outputs (using a type annotation with a `Tuple` if there are\n", - "# multiple outputs).\n", - "@op\n", - "def generate_dataset() -> Tuple[TensorDataset, TensorDataset]:\n", - " \"\"\"\n", - " Generate a simple synthetic dataset for logistic regression, perform a\n", - " 80/20 train/test split, and return the results as `TensorDataset`s.\n", - " \"\"\"\n", - " n_samples = 1000\n", - " x = np.random.randn(n_samples, DATA_DIMENSION)\n", - " y = x[:, 0] > 0\n", - " x, y = torch.from_numpy(x).float(), torch.from_numpy(y).long()\n", - " train_size = int(0.8 * n_samples)\n", - " train_dataset = TensorDataset(x[:train_size], y[:train_size])\n", - " test_dataset = TensorDataset(x[train_size:], y[train_size:])\n", - " return train_dataset, test_dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "class LogisticRegression(torch.nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.linear = torch.nn.Linear(DATA_DIMENSION, 2)\n", - "\n", - " def forward(self, feature):\n", - " output = self.linear(feature)\n", - " return output\n", - "\n", - "\n", - "@op\n", - "def train_model(\n", - " train_dataset: TensorDataset,\n", - " test_dataset: TensorDataset,\n", - " learning_rate: float = 0.001,\n", - " batch_size: int = 100,\n", - " num_epochs: int = 3,\n", - ") -> Tuple[LogisticRegression, float]:\n", - " \"\"\"\n", - " Train a logistic model on the given training dataset with the given\n", - " hyperparameters.\n", - "\n", - " Prints out the train loss and test accuracy at the end of\n", - " each epoch. Returns the trained model and the final test accuracy.\n", - " \"\"\"\n", - " train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", - " test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " model = LogisticRegression().to(device)\n", - " loss = torch.nn.CrossEntropyLoss().to(device)\n", - " optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n", - " for epoch in range(num_epochs):\n", - " # train\n", - " model.train()\n", - " for xs, ys in train_loader:\n", - " xs = xs.to(device)\n", - " ys = ys.to(device)\n", - " optimizer.zero_grad()\n", - " output = model(xs)\n", - " loss_value = loss(output, ys)\n", - " loss_value.backward()\n", - " optimizer.step()\n", - " # test\n", - " model.eval()\n", - " accurate, total = 0, 0\n", - " for xs, ys in test_loader:\n", - " xs = xs.to(device)\n", - " ys = ys.to(device)\n", - " output = model(xs)\n", - " _, predicted = torch.max(output.data, 1)\n", - " total += ys.size(0)\n", - " accurate += (predicted == ys).sum()\n", - " acc = 100 * accurate / total\n", - " print(\n", - " f\"Epoch: {epoch}, Training loss: {round(loss_value.item(), 2)}. Test accuracy: {round(acc.item(), 2)}\"\n", - " )\n", - " return model, round(float(acc.item()), 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## \"Hello world\", or: run the pipeline and store the results\n", - "Now that you have defined the functions that make up your pipeline, you can\n", - "run it with the default parameters to see how well the model performs!\n", - "\n", - "The `@op` decorator on the functions above tells `mandala` to track the calls to\n", - "these functions and store their results - but this only happens when you call\n", - "these functions *in the context of a given `Storage` object*. So go ahead and\n", - "create a storage for the project: " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "storage = Storage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This storage will hold the results of all the experiments you run in this\n", - "notebook. Now, run the pipeline and save its results by wrapping the code you'd\n", - "normally write in a `storage.run()` context manager:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Final accuracy: ValueRef(41.5, uid=0b6...)\n" - ] - } - ], - "source": [ - "with storage.run():\n", - " train_dataset, test_dataset = generate_dataset()\n", - " model, acc = train_model(train_dataset, test_dataset)\n", - " print(f\"Final accuracy: {acc}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### What just happened?\n", - "A lot happened behind the scenes in these few lines of code! Let's break it\n", - "down:\n", - "- Inside the `storage.run()` block, each time an `@op`-decorated function is\n", - "called **for the first time** on a set of inputs, `mandala` stores the inputs\n", - "and outputs of this call in the storage. \n", - "- Values shared between calls are stored only once. So\n", - " `train_dataset` will appear in storage as both the output to the call to\n", - " `generate_dataset`, and the input to the call to `train_model`.\n", - "- The `acc` object (like all objects returned by `@op`-decorated functions) is a\n", - "*value reference*, which is a value wrapped with storage-related metadata. \n", - "\n", - "So, what happens when you call `@op`-decorated functions *a second time* on the\n", - "same inputs? Find out by running the cell below:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run():\n", - " train_dataset, test_dataset = generate_dataset()\n", - " model, acc = train_model(train_dataset, test_dataset)\n", - " print(f\"Final accuracy: {acc}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Note that this time the intermediate training results did not get printed!**.\n", - "This is because `mandala` recognized that the inputs to the functions were the\n", - "same as before, and so it didn't need to re-run the calls." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### So what?\n", - "This was the simplest non-trivial use case of `mandala`! However, at this point\n", - "it is just a glorified `pickle`-based memoization system. Its real power comes\n", - "from the way in which `mandala`'s memoization *composes* with the rest of the\n", - "Python language, which allow you to manage complex experiments with the minimal\n", - "amount of plain-Python code, as we'll see next." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Grow the project with new parameters\n", - "Running the pipeline once is nice, but where `mandala` really shines is in\n", - "enabling you to grow a computational project in various ways with the minimal\n", - "necessary code changes, and have the storage interfaces \"just work\". \n", - "\n", - "Let's begin exploring this by investigating the effect of changing the learning\n", - "rate of the model. So far, you have been using the default learning rate of\n", - "`0.001`. Let's try a few other values, but also see how they compare with the\n", - "default value. Thanks to memoization, this is easy to do without re-doing\n", - "expensive work: we can use a list of values for the `learning_rate` parameter\n", - "that includes the default, and compare:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "===end of run=== learning_rate: 0.001, acc: 41.5\n", - "Epoch: 0, Training loss: 0.81. Test accuracy: 61.5\n", - "Epoch: 1, Training loss: 0.73. Test accuracy: 61.5\n", - "Epoch: 2, Training loss: 0.71. Test accuracy: 61.5\n", - "===end of run=== learning_rate: 0.01, acc: 61.5\n", - "Epoch: 0, Training loss: 0.59. Test accuracy: 68.0\n", - "Epoch: 1, Training loss: 0.59. Test accuracy: 70.5\n", - "Epoch: 2, Training loss: 0.53. Test accuracy: 73.5\n", - "===end of run=== learning_rate: 0.1, acc: 73.5\n" - ] - } - ], - "source": [ - "with storage.run():\n", - " train_dataset, test_dataset = generate_dataset()\n", - " for learning_rate in [0.001, 0.01, 0.1]:\n", - " model, acc = train_model(train_dataset, test_dataset, learning_rate)\n", - " print(\n", - " # `unwrap()` is used to get the value wrapped by a `ValueRef`\n", - " f\"===end of run=== learning_rate: {learning_rate}, acc: {round(unwrap(acc), 2)}\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - " | num_epochs | \n", - "test_dataset | \n", - "batch_size | \n", - "learning_rate | \n", - "train_dataset | \n", - "output_0 | \n", - "output_1 | \n", - "
---|---|---|---|---|---|---|---|
0 | \n", - "ValueRef(in_memory=False, uid=566...) | \n", - "ValueRef(in_memory=False, uid=ff8...) | \n", - "ValueRef(in_memory=False, uid=9ce...) | \n", - "ValueRef(in_memory=False, uid=afd...) | \n", - "ValueRef(in_memory=False, uid=239...) | \n", - "ValueRef(in_memory=False, uid=1a2...) | \n", - "ValueRef(in_memory=False, uid=0b6...) | \n", - "
1 | \n", - "ValueRef(in_memory=False, uid=566...) | \n", - "ValueRef(in_memory=False, uid=ff8...) | \n", - "ValueRef(in_memory=False, uid=9ce...) | \n", - "ValueRef(in_memory=False, uid=0e1...) | \n", - "ValueRef(in_memory=False, uid=239...) | \n", - "ValueRef(in_memory=False, uid=ca2...) | \n", - "ValueRef(in_memory=False, uid=251...) | \n", - "
2 | \n", - "ValueRef(in_memory=False, uid=566...) | \n", - "ValueRef(in_memory=False, uid=ff8...) | \n", - "ValueRef(in_memory=False, uid=9ce...) | \n", - "ValueRef(in_memory=False, uid=38f...) | \n", - "ValueRef(in_memory=False, uid=239...) | \n", - "ValueRef(in_memory=False, uid=242...) | \n", - "ValueRef(in_memory=False, uid=1fa...) | \n", - "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ ### Dependencies for version of function train_model from module __main__ │\n", - "│ ### content_version_id=3186cce4d44f99c0e346bbc4668f613b │\n", - "│ ### semantic_version_id=e27ccd0e76ba6886b390d33c7ade95a3 │\n", - "│ │\n", - "│ ################################################################################ │\n", - "│ ### IN MODULE \"__main__\" │\n", - "│ ################################################################################ │\n", - "│ C = 0.1 │\n", - "│ X = array([[-0.6693561 , -1.49577819, -0.87076638, ..., -1.26733697, [...] │\n", - "│ y = array([1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, [...] │\n", - "│ │\n", - "│ ### in class MyRegression: │\n", - "│ def __init__(self): │\n", - "│ self.lr = LogisticRegression(C=C) │\n", - "│ │\n", - "│ def fit(self, X, y): │\n", - "│ self.lr.fit(X, y) │\n", - "│ │\n", - "│ def score(self, X, y): │\n", - "│ return self.lr.score(X, y) │\n", - "│ │\n", - "│ @op │\n", - "│ def train_model(model_class: str = 'lr') -> float: │\n", - "│ print(f'training model {model_class}') │\n", - "│ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) │\n", - "│ if model_class == 'lr': │\n", - "│ model = MyRegression() │\n", - "│ model.fit(X_train, y_train) │\n", - "│ acc = model.score(X_test, y_test) │\n", - "│ else: │\n", - "│ acc = train_rf(X_train, y_train, X_test, y_test) │\n", - "│ return acc │\n", - "│ │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "\n" - ], - "text/plain": [ - "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function train_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=3186cce4d44f99c0e346bbc4668f613b\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=e27ccd0e76ba6886b390d33c7ade95a3\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227mC\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.1\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227marray\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.6693561\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1.49577819\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.87076638\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1.26733697\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m]\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227marray\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m]\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### in class MyRegression:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227m__init__\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mlr\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mLogisticRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mC\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mC\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mlr\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mlr\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mstr\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m>\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfloat\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mprint\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mf\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mtraining model \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m{\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m}\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_test_split\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtest_size\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.2\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrandom_state\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m42\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m==\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mMyRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227melse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_rf\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "with storage.run():\n", - " acc = train_model()\n", - "\n", - "storage.versions(train_model) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The dependencies of the call we memoized include\n", - " - the **function's own source code**, in this case the code of `train_model`;\n", - " - the **global variables** that it accesses, in this case `X` and `y`;\n", - " - **recursively**, the **dependencies of any functions/methods that it calls**.\n", - " For example, `train_model` depends on `MyRegression.__init__`, \n", - " among other functions/methods - and `MyRegression.__init__` itself depends\n", - " on `C`. \n", - "\n", - "A change to any of these *may* make a memoized call invalid. Note that whether\n", - "or not a result is stale depends on the function's dependencies **for that\n", - "particular call**, and on whether the changes to the dependencies change the\n", - "meaning of the computation. \n", - "\n", - "Indeed, if we now train a random forest classifier instead, we will see a new\n", - "version with different dependencies show up:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "training model rf\n" - ] - }, - { - "data": { - "text/html": [ - "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ ### Dependencies for version of function train_model from module __main__ │\n", - "│ ### content_version_id=3186cce4d44f99c0e346bbc4668f613b │\n", - "│ ### semantic_version_id=e27ccd0e76ba6886b390d33c7ade95a3 │\n", - "│ │\n", - "│ ################################################################################ │\n", - "│ ### IN MODULE \"__main__\" │\n", - "│ ################################################################################ │\n", - "│ C = 0.1 │\n", - "│ X = array([[-0.6693561 , -1.49577819, -0.87076638, ..., -1.26733697, [...] │\n", - "│ y = array([1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, [...] │\n", - "│ │\n", - "│ ### in class MyRegression: │\n", - "│ def __init__(self): │\n", - "│ self.lr = LogisticRegression(C=C) │\n", - "│ │\n", - "│ def fit(self, X, y): │\n", - "│ self.lr.fit(X, y) │\n", - "│ │\n", - "│ def score(self, X, y): │\n", - "│ return self.lr.score(X, y) │\n", - "│ │\n", - "│ @op │\n", - "│ def train_model(model_class: str = 'lr') -> float: │\n", - "│ print(f'training model {model_class}') │\n", - "│ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) │\n", - "│ if model_class == 'lr': │\n", - "│ model = MyRegression() │\n", - "│ model.fit(X_train, y_train) │\n", - "│ acc = model.score(X_test, y_test) │\n", - "│ else: │\n", - "│ acc = train_rf(X_train, y_train, X_test, y_test) │\n", - "│ return acc │\n", - "│ │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ ### Dependencies for version of function train_model from module __main__ │\n", - "│ ### content_version_id=574da61ae9e58f448972ba39ef30c15e │\n", - "│ ### semantic_version_id=3c430ce7f58c07cac3490e3eafcdc5bf │\n", - "│ │\n", - "│ ################################################################################ │\n", - "│ ### IN MODULE \"__main__\" │\n", - "│ ################################################################################ │\n", - "│ N_ESTIMATORS = 100 │\n", - "│ X = array([[-0.6693561 , -1.49577819, -0.87076638, ..., -1.26733697, [...] │\n", - "│ y = array([1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, [...] │\n", - "│ │\n", - "│ @op │\n", - "│ def train_model(model_class: str = 'lr') -> float: │\n", - "│ print(f'training model {model_class}') │\n", - "│ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) │\n", - "│ if model_class == 'lr': │\n", - "│ model = MyRegression() │\n", - "│ model.fit(X_train, y_train) │\n", - "│ acc = model.score(X_test, y_test) │\n", - "│ else: │\n", - "│ acc = train_rf(X_train, y_train, X_test, y_test) │\n", - "│ return acc │\n", - "│ │\n", - "│ def train_rf(X_train, y_train, X_test, y_test) -> float: │\n", - "│ rf = RandomForestClassifier(n_estimators=N_ESTIMATORS) │\n", - "│ rf.fit(X_train, y_train) │\n", - "│ score = rf.score(X_test, y_test) │\n", - "│ return score │\n", - "│ │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "\n" - ], - "text/plain": [ - "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function train_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=3186cce4d44f99c0e346bbc4668f613b\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=e27ccd0e76ba6886b390d33c7ade95a3\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227mC\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.1\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227marray\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.6693561\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1.49577819\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.87076638\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1.26733697\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m]\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227marray\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m]\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### in class MyRegression:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227m__init__\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mlr\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mLogisticRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mC\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mC\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mlr\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mself\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mlr\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mstr\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m>\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfloat\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mprint\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mf\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mtraining model \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m{\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m}\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_test_split\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtest_size\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.2\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrandom_state\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m42\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m==\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mMyRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227melse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_rf\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### Dependencies for version of function train_model from module __main__\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### content_version_id=574da61ae9e58f448972ba39ef30c15e\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### semantic_version_id=3c430ce7f58c07cac3490e3eafcdc5bf\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m### IN MODULE \"__main__\"\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[3;38;2;147;161;161;48;2;253;246;227m################################################################################\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227mN_ESTIMATORS\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m100\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227marray\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.6693561\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1.49577819\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.87076638\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1.26733697\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m]\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227marray\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m1\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m[\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m]\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mstr\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m>\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfloat\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mprint\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mf\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mtraining model \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m{\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m}\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_test_split\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtest_size\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.2\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrandom_state\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m42\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m==\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mMyRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227melse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_rf\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_rf\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m>\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfloat\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrf\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mRandomForestClassifier\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mn_estimators\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mN_ESTIMATORS\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrf\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrf\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[48;2;253;246;227m \u001b[0m │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "with storage.run():\n", - " acc = train_model(model_class='rf')\n", - "\n", - "storage.versions(train_model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Detecting changes\n", - "Now to the fun part: automatically reacting to changes in dependencies. Let's\n", - "make three different kinds of changes to the dependencies:\n", - "- change the constant `C` used in the logistic regression;\n", - "- change the source of the method `__init__` of the `MyRegression` class.\n", - "- change the source of `train_model` itself;\n", - "\n", - "As you'll see when you run the below, you'll be presented with a diff of the\n", - "dependencies, organized by module and the kind of dependency (function or global\n", - "variable), and the memoized functions affected by each change:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CHANGE DETECTED in train_model from module __main__\n", - "Dependent components:\n", - " Version of \"train_model\" from module \"__main__\" (content: 3186cce4d44f99c0e346bbc4668f613b, semantic: e27ccd0e76ba6886b390d33c7ade95a3)\n", - " Version of \"train_model\" from module \"__main__\" (content: 574da61ae9e58f448972ba39ef30c15e, semantic: 3c430ce7f58c07cac3490e3eafcdc5bf)\n", - "===DIFF===:\n", - " else:\n", - " acc = train_rf(X_train, y_train, X_test, y_test)\n", - "\u001b[32m+ print(f'accuracy: {acc}') # CHANGE: print the accuracy\u001b[0m\n", - " return acc\n", - "Does this change require recomputation of dependent calls? [y]es/[n]o/[a]bort \n", - "You answered: \"n\"\n", - "CHANGE DETECTED in MyRegression.__init__ from module __main__\n", - "Dependent components:\n", - " Version of \"train_model\" from module \"__main__\" (content: 3186cce4d44f99c0e346bbc4668f613b, semantic: e27ccd0e76ba6886b390d33c7ade95a3)\n", - "===DIFF===:\n", - " def __init__(self):\n", - "\u001b[31m- self.lr = LogisticRegression(C=C)\u001b[0m\n", - "\u001b[32m+ self.lr = LogisticRegression(C=C, class_weight='balanced') # changed from 'l2' (default)\u001b[0m\n", - "Does this change require recomputation of dependent calls? [y]es/[n]o/[a]bort \n", - "You answered: \"y\"\n", - "CHANGE DETECTED in C from module __main__\n", - "Dependent components:\n", - " Version of \"train_model\" from module \"__main__\" (content: 3186cce4d44f99c0e346bbc4668f613b, semantic: e27ccd0e76ba6886b390d33c7ade95a3)\n", - "===DIFF===:\n", - "\u001b[31m-0.1\u001b[0m\n", - "\u001b[32m+1.0\u001b[0m\n", - "Does this change require recomputation of dependent calls? [y]es/[n]o/[a]bort \n", - "You answered: \"y\"\n", - "training model lr\n", - "accuracy: 0.855\n" - ] - } - ], - "source": [ - "N_SAMPLES = 1000 \n", - "X, y = make_classification(n_samples=N_SAMPLES, random_state=42)\n", - "\n", - "C = 1.0 # changed from 0.1\n", - "class MyRegression: # a thin wrapper around sklearn's LogisticRegression\n", - "\n", - " def __init__(self):\n", - " self.lr = LogisticRegression(C=C, class_weight='balanced') # changed from 'l2' (default)\n", - " \n", - " def fit(self, X, y):\n", - " self.lr.fit(X, y)\n", - " \n", - " def score(self, X, y):\n", - " return self.lr.score(X, y)\n", - "\n", - "N_ESTIMATORS = 100\n", - "def train_rf(X_train, y_train, X_test, y_test) -> float:\n", - " rf = RandomForestClassifier(n_estimators=N_ESTIMATORS)\n", - " rf.fit(X_train, y_train)\n", - " score = rf.score(X_test, y_test)\n", - " return score\n", - "\n", - "@op\n", - "def train_model(model_class: str = 'lr') -> float:\n", - " print(f'training model {model_class}')\n", - " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", - " if model_class == 'lr':\n", - " model = MyRegression()\n", - " model.fit(X_train, y_train)\n", - " acc = model.score(X_test, y_test)\n", - " else:\n", - " acc = train_rf(X_train, y_train, X_test, y_test)\n", - " print(f'accuracy: {acc}') # CHANGE: print the accuracy\n", - " return acc\n", - "\n", - "with storage.run():\n", - " rf_acc = train_model(model_class='rf')\n", - " lr_acc = train_model(model_class='lr')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's unpack what happened. In the above interactive dialog, we marked the\n", - "changes as follows:\n", - "- the added printing of accuracy in `train_model` was marked as not requiring\n", - " recomputation; it will only affect future calls.\n", - "- the changes to `C` and `MyRegression.__init__` were marked as requiring\n", - " recomputation, as they substantially change what gets computed!\n", - "As a result, the call that trained a logistic regression model was re-computed -\n", - "but the call using a random forest classifier was reused. This is the power of\n", - "fine-grained dependency tracking! \n", - "\n", - "Now that you have multiple versions of stuff, you can also get a look at how a\n", - "particular component evolved during a project. For example, `storage.sources()`\n", - "gives you a tree-like revision history:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Revision history for the source code of function train_model from module __main__ (\"===HEAD===\" is the current version):\n" - ] - }, - { - "data": { - "text/html": [ - "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ @op │\n", - "│ def train_model(model_class: str = 'lr') -> float: │\n", - "│ print(f'training model {model_class}') │\n", - "│ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) │\n", - "│ if model_class == 'lr': │\n", - "│ model = MyRegression() │\n", - "│ model.fit(X_train, y_train) │\n", - "│ acc = model.score(X_test, y_test) │\n", - "│ else: │\n", - "│ acc = train_rf(X_train, y_train, X_test, y_test) │\n", - "│ return acc │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "└── ╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - " │ ### ===HEAD=== │\n", - " │ @op │\n", - " │ def train_model(model_class: str = 'lr') -> float: │\n", - " │ print(f'training model {model_class}') │\n", - " │ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) │\n", - " │ if model_class == 'lr': │\n", - " │ model = MyRegression() │\n", - " │ model.fit(X_train, y_train) │\n", - " │ acc = model.score(X_test, y_test) │\n", - " │ else: │\n", - " │ acc = train_rf(X_train, y_train, X_test, y_test) │\n", - " │ print(f'accuracy: {acc}') # CHANGE: print the accuracy │\n", - " │ return acc │\n", - " ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "\n" - ], - "text/plain": [ - "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mstr\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m>\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfloat\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mprint\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mf\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mtraining model \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m{\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m}\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_test_split\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtest_size\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.2\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrandom_state\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m42\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m==\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mMyRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227melse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_rf\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "│ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "└── ╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - " │ \u001b[3;38;2;147;161;161;48;2;253;246;227m### ===HEAD===\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;38;139;210;48;2;253;246;227m@op\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;133;153;0;48;2;253;246;227mdef\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mtrain_model\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mstr\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m-\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m>\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mfloat\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mprint\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mf\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mtraining model \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m{\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m}\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_test_split\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtest_size\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m0.2\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mrandom_state\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m42\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mif\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel_class\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m==\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mlr\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mMyRegression\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mfit\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mmodel\u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m.\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mscore\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227melse\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m:\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;147;161;161;48;2;253;246;227m=\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mtrain_rf\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_train\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227mX_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m,\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227my_test\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;38;139;210;48;2;253;246;227mprint\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m(\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227mf\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227maccuracy: \u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m{\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m}\u001b[0m\u001b[38;2;42;161;152;48;2;253;246;227m'\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m)\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[3;38;2;147;161;161;48;2;253;246;227m# CHANGE: print the accuracy\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " │ \u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;133;153;0;48;2;253;246;227mreturn\u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227m \u001b[0m\u001b[38;2;101;123;131;48;2;253;246;227macc\u001b[0m\u001b[48;2;253;246;227m \u001b[0m │\n", - " ╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "storage.sources(train_model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Conclusion\n", - "This was a very brief exposure to the versioning machinery. In particular, we\n", - "didn't cover some more of the things possible with this design:\n", - "- revisit old versions and make new branches off of them;\n", - "- audit all the source code that went into a particular call; \n", - "- incorporate versioning information in declarative queries" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.6", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "e39bb3b1f45b78879464f3858f3ac405da62799496d9b7e0a39caf0b676c9a45" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tutorials/03_advanced.ipynb b/tutorials/03_advanced.ipynb deleted file mode 100644 index 5b492bd..0000000 --- a/tutorials/03_advanced.ipynb +++ /dev/null @@ -1,380 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", - "except:\n", - " IN_COLAB = False\n", - "\n", - "if IN_COLAB:\n", - " # run this cell ONLY if you are running this in Google Colab\n", - " !pip install git+https://github.com/amakelov/mandala\n", - " !pip install scikit-learn" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from mandala.imports import *\n", - "from typing import List, Tuple\n", - "from sklearn.tree import DecisionTreeClassifier\n", - "from sklearn.metrics import accuracy_score\n", - "from sklearn.datasets import make_classification, load_digits\n", - "from pathlib import Path\n", - "import numpy as np\n", - "from numpy import ndarray\n", - "Config.enable_ref_magics = True\n", - "Config.warnings = False\n", - "\n", - "OUTPUT_ROOT = Path().absolute() / \"03_advanced.db\"\n", - "OUTPUT_ROOT.unlink(missing_ok=True)\n", - "\n", - "# storage = Storage() # use this for an in-memory storage without dependency tracking\n", - "# storage = Storage(db_path=OUTPUT_ROOT) # use this for a persistent storage without dependency tracking\n", - "storage = Storage(db_path=OUTPUT_ROOT, deps_path='__main__')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@op\n", - "def generate_data() -> Tuple[ndarray, ndarray]:\n", - " return load_digits(n_class=2, return_X_y=True)\n", - "\n", - "@op\n", - "def train_and_eval_tree(X, y, seed,\n", - " max_depth=1) -> Tuple[DecisionTreeClassifier, float]:\n", - " tree = DecisionTreeClassifier(random_state=seed, \n", - " max_depth=max_depth,\n", - " max_features=1).fit(X, y)\n", - " return tree, round(accuracy_score(y_true=y, y_pred=tree.predict(X)), 2)\n", - " \n", - "@op\n", - "def eval_forest(trees:List[DecisionTreeClassifier], X, y) -> float:\n", - " majority_vote = np.array([tree.predict(X) for tree in trees]).mean(axis=0) >= 0.5\n", - " return round(accuracy_score(y_true=y, y_pred=majority_vote), 2)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): # memoization context manager\n", - " X, y = generate_data()\n", - " trees = []\n", - " for seed in range(10): # can't grow trees without seeds\n", - " tree, acc = train_and_eval_tree(X, y, seed=seed)\n", - " trees.append(tree)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (5, 10, 15):\n", - " trees = []\n", - " for seed in range(n_trees): \n", - " tree, acc = train_and_eval_tree(X, y, seed=seed)\n", - " trees.append(tree)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (5, 10, 15):\n", - " trees = []\n", - " for seed in range(n_trees): \n", - " tree, acc = train_and_eval_tree(X, y, seed=seed)\n", - " if acc > 0.8:\n", - " trees.append(tree)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@superop\n", - "def train_forest(X, y, n_trees) -> List[DecisionTreeClassifier]:\n", - " trees = []\n", - " for i in range(n_trees):\n", - " tree, acc = train_and_eval_tree(X, y, seed=i) \n", - " if acc > 0.8:\n", - " trees.append(tree)\n", - " return trees" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (5, 10, 15):\n", - " trees = train_forest(X, y, n_trees)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (5, 10, 15, 20):\n", - " trees = train_forest(X, y, n_trees)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (5, 10, ):\n", - " trees = train_forest(X, y, n_trees)\n", - " forest_acc = eval_forest(trees, X, y)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "storage.similar(forest_acc, context=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (10, 15, 20,):\n", - " trees = train_forest(X, y, n_trees)\n", - " forest_acc = eval_forest(trees[:n_trees//2], X, y)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "storage.draw_graph(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "storage.draw_graph(forest_acc, project=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "storage.print_graph(forest_acc, project=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.query():\n", - " idx = Q() # index into list\n", - " X, y = generate_data()\n", - " n_trees = Q() # input to computation; can match anything\n", - " trees = train_forest(X=X, y=y, n_trees=n_trees)\n", - " a0 = trees[idx] # a0 will match any element of a match for trees at index matching idx1\n", - " a1 = ListQ(elts=[a0], idxs=[idx]) # a1 will match any list containing a match for a0 at index idx0\n", - " forest_acc = eval_forest(trees=a1, X=X, y=y)\n", - "storage.df(n_trees, forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@superop\n", - "def train_forest(X, y, n_trees, threshold = 0.8) -> List[DecisionTreeClassifier]:\n", - " trees = []\n", - " for i in range(n_trees):\n", - " tree, acc = train_and_eval_tree(X, y, seed=i) \n", - " if acc > threshold:\n", - " trees.append(tree)\n", - " return trees" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (5, 10, 15, 20):\n", - " trees = train_forest(X, y, n_trees)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (5, 10, 15, 20):\n", - " trees = train_forest(X, y, n_trees, threshold=0.5)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# notice we changed `max_features` to 2\n", - "@op\n", - "def train_and_eval_tree(X, y, seed,\n", - " max_depth=1) -> Tuple[DecisionTreeClassifier, float]:\n", - " tree = DecisionTreeClassifier(random_state=seed, \n", - " max_depth=max_depth,\n", - " max_features=2).fit(X, y)\n", - " return tree, round(accuracy_score(y_true=y, y_pred=tree.predict(X)), 2)\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (10, 15, 20):\n", - " trees = train_forest(X, y, n_trees, threshold=0.5)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# now we change it back to 1 - the old memoized calls are used!\n", - "@op\n", - "def train_and_eval_tree(X, y, seed,\n", - " max_depth=1) -> Tuple[DecisionTreeClassifier, float]:\n", - " tree = DecisionTreeClassifier(random_state=seed, \n", - " max_depth=max_depth,\n", - " max_features=1).fit(X, y)\n", - " return tree, round(accuracy_score(y_true=y, y_pred=tree.predict(X)), 2)\n", - "\n", - "\n", - "with storage.run(): \n", - " X, y = generate_data()\n", - " for n_trees in (10, 15, 20):\n", - " trees = train_forest(X, y, n_trees, threshold=0.5)\n", - " forest_acc = eval_forest(trees, X, y)\n", - " print(forest_acc)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# look at the versions of a single dependency\n", - "storage.sources(train_and_eval_tree)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.6", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "vscode": { - "interpreter": { - "hash": "e39bb3b1f45b78879464f3858f3ac405da62799496d9b7e0a39caf0b676c9a45" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/mandala/tutorials/gotchas.ipynb b/tutorials/gotchas.ipynb similarity index 97% rename from mandala/tutorials/gotchas.ipynb rename to tutorials/gotchas.ipynb index c920671..598346f 100644 --- a/mandala/tutorials/gotchas.ipynb +++ b/tutorials/gotchas.ipynb @@ -33,7 +33,7 @@ "import random\n", "import numpy as np\n", "\n", - "from mandala._next.utils import get_content_hash, serialize, deserialize\n", + "from mandala.utils import get_content_hash, serialize, deserialize\n", "\n", "X, y = load_digits(n_class=10, return_X_y=True)\n", "\n", diff --git a/mandala/tutorials/hello.ipynb b/tutorials/hello.ipynb similarity index 99% rename from mandala/tutorials/hello.ipynb rename to tutorials/hello.ipynb index b061a8c..51da9ae 100644 --- a/mandala/tutorials/hello.ipynb +++ b/tutorials/hello.ipynb @@ -49,7 +49,7 @@ } ], "source": [ - "from mandala._next.imports import *\n", + "from mandala.imports import *\n", "import time\n", "\n", "storage = Storage( # stores all `@op` calls\n", diff --git a/mandala/tutorials/ml.ipynb b/tutorials/ml.ipynb similarity index 99% rename from mandala/tutorials/ml.ipynb rename to tutorials/ml.ipynb index 7f58895..eda329e 100644 --- a/mandala/tutorials/ml.ipynb +++ b/tutorials/ml.ipynb @@ -43,7 +43,7 @@ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "# recommended way to import mandala functionality\n", - "from mandala._next.imports import *\n", + "from mandala.imports import *\n", "\n", "np.random.seed(0)" ] diff --git a/tutorials/readme_examples.ipynb b/tutorials/readme_examples.ipynb deleted file mode 100644 index 2d11117..0000000 --- a/tutorials/readme_examples.ipynb +++ /dev/null @@ -1,167 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hi from increment!\n", - "hi from increment!\n", - "ValueRef(24, uid=f8b...)\n", - "24\n", - "hi from increment!\n", - "hi from increment!\n", - "hi from increment!\n", - "hi from increment!\n", - "hi from increment!\n", - "hi from average!\n", - "hi from average!\n", - "hi from average!\n", - "Pattern-matching to the following computational graph (all constraints apply):\n", - " idx0 = Q() # index into list\n", - " a0 = Q() # input to computation; can match anything\n", - " a1 = increment(x=a0)\n", - " a2 = ListQ(elts=[a1], idxs=[idx0]) # a2 will match any list containing a match for a1 at index idx0\n", - " result = average(nums=a2)\n", - " result = storage.df(idx0, a0, a1, a2, result)\n", - " idx0 a0 a1 a2 result\n", - "8 0 0 1 [1, 2, 3] 2.0\n", - "6 1 1 2 [1, 2, 3] 2.0\n", - "2 2 2 3 [1, 2, 3] 2.0\n", - "7 0 1 2 [2, 3, 4] 3.0\n", - "3 1 2 3 [2, 3, 4] 3.0\n", - "5 2 3 4 [2, 3, 4] 3.0\n", - "1 0 2 3 [3, 4, 5] 4.0\n", - "4 1 3 4 [3, 4, 5] 4.0\n", - "0 2 4 5 [3, 4, 5] 4.0\n", - "CHANGE DETECTED in increment from module __main__\n", - "Dependent components:\n", - " Version of \"increment\" from module \"__main__\" (content: a677f90d62d1ed62eab1b55b5197ac6a, semantic: 7795dcadf8a827e5975234a12eb9c3d5)\n", - "===DIFF===:\n", - "\u001b[31m-@op # memoization (and more) decorator\u001b[0m\n", - "\u001b[31m-def increment(x: int) -> int: # always indicate number of outputs in return type\u001b[0m\n", - "\u001b[31m- print('hi from increment!')\u001b[0m\n", - "\u001b[31m- return x + 1\u001b[0m\n", - "\u001b[32m+@op\u001b[0m\n", - "\u001b[32m+def increment(x: int) -> int:\u001b[0m\n", - "\u001b[32m+ print('hi from new increment!')\u001b[0m\n", - "\u001b[32m+ return x + 2\u001b[0m\n", - "Does this change require recomputation of dependent calls? [y]es/[n]o/[a]bort \n", - "You answered: \"y\"\n", - "hi from new increment!\n", - "hi from new increment!\n", - "hi from new increment!\n", - "hi from new increment!\n", - "hi from new increment!\n", - "hi from average!\n", - "Pattern-matching to the following computational graph (all constraints apply):\n", - " idx0 = Q() # index into list\n", - " a0 = Q() # input to computation; can match anything\n", - " a1 = increment(x=a0)\n", - " a2 = ListQ(elts=[a1], idxs=[idx0]) # a2 will match any list containing a match for a1 at index idx0\n", - " result = average(nums=a2)\n", - " result = storage.df(idx0, a0, a1, a2, result)\n", - " idx0 a0 a1 a2 result\n", - "8 0 0 2 [2, 3, 4] 3.0\n", - "6 1 1 3 [2, 3, 4] 3.0\n", - "7 2 2 4 [2, 3, 4] 3.0\n", - "3 0 1 3 [3, 4, 5] 4.0\n", - "5 1 2 4 [3, 4, 5] 4.0\n", - "4 2 3 5 [3, 4, 5] 4.0\n", - "1 0 2 4 [4, 5, 6] 5.0\n", - "0 1 3 5 [4, 5, 6] 5.0\n", - "2 2 4 6 [4, 5, 6] 5.0\n" - ] - } - ], - "source": [ - "from mandala.imports import *\n", - "\n", - "# the storage saves calls and tracks dependencies, versions, etc.\n", - "storage = Storage( \n", - " deps_path='__main__' # track dependencies in current session\n", - " ) \n", - "\n", - "@op # memoization (and more) decorator\n", - "def increment(x: int) -> int: # always indicate number of outputs in return type\n", - " print('hi from increment!')\n", - " return x + 1\n", - "\n", - "increment(23) # function acts normally\n", - "\n", - "with storage.run(): # context manager that triggers `mandala`\n", - " y = increment(23) # now it's memoized w.r.t. this version of `increment`\n", - "\n", - "print(y) # result wrapped with metadata. \n", - "print(unwrap(y)) # `unwrap` gets the raw value\n", - "\n", - "with storage.run():\n", - " y = increment(23) # loads result from `storage`; doesn't execute `increment`\n", - "\n", - "@op # type-annotate data structures to store elts separately\n", - "def average(nums: list) -> float: \n", - " print('hi from average!')\n", - " return sum(nums) / len(nums)\n", - "\n", - "# memoized functions are designed to be composed!\n", - "with storage.run(): \n", - " # sliding averages of `increment`'s results over 3 elts\n", - " nums = [increment(i) for i in range(5)]\n", - " for i in range(3):\n", - " result = average(nums[i:i+3])\n", - "\n", - "# get a table of all values similar to `result` in `storage`,\n", - "# i.e., computed as average([increment(something), ...])\n", - "# read the message this prints out!\n", - "print(storage.similar(result, context=True))\n", - "\n", - "# change implementation of `increment` and re-run\n", - "# you'll be asked if the change requires recomputing dependencies (say yes)\n", - "@op\n", - "def increment(x: int) -> int:\n", - " print('hi from new increment!')\n", - " return x + 2\n", - "\n", - "with storage.run(): \n", - " nums = [increment(i) for i in range(5)]\n", - " for i in range(3):\n", - " # only one call to `average` is executed!\n", - " result = average(nums[i:i+3])\n", - "\n", - "# query is ran against the *new* version of `increment`\n", - "print(storage.similar(result, context=True))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.6", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "e39bb3b1f45b78879464f3858f3ac405da62799496d9b7e0a39caf0b676c9a45" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}