Skip to content

Commit d1b87a8

Browse files
committed
feat: Allow to update dims and coords in stan model
1 parent ae7e9f1 commit d1b87a8

File tree

4 files changed

+60
-36
lines changed

4 files changed

+60
-36
lines changed

Cargo.lock

Lines changed: 30 additions & 30 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
[package]
22
name = "nutpie"
3-
version = "0.5.1"
4-
authors = ["Adrian Seyboldt <[email protected]>"]
3+
version = "0.6.0-beta.1"
4+
authors = [
5+
"Adrian Seyboldt <[email protected]>",
6+
"PyMC Developers <[email protected]>"
7+
]
58
edition = "2021"
69
license = "MIT"
710
repository = "https://github.com/pymc-devs/nutpie"

nutpie/compile_stan.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import pathlib
33
import tempfile
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, replace
55
from typing import Any, Dict, Optional
66

77
import numpy as np
@@ -58,6 +58,22 @@ def with_data(self, *, seed=None, **updates):
5858
model=model,
5959
)
6060

61+
def with_coords(self, **coords):
62+
if self.coords is None:
63+
coords_new = {}
64+
else:
65+
coords_new = self.coords.copy()
66+
coords_new.update(coords)
67+
return replace(self, coords=coords_new)
68+
69+
def with_dims(self, **dims):
70+
if self.dims is None:
71+
dims_new = {}
72+
else:
73+
dims_new = self.dims.copy()
74+
dims_new.update(dims)
75+
return replace(self, dims=dims_new)
76+
6177
def _make_model(self, init_mean):
6278
if self.model is None:
6379
return self.with_data().model

pyproject.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[build-system]
2-
requires = ["maturin>=0.14,<0.15"]
2+
requires = ["maturin>=1.1,<2.0"]
33
build-backend = "maturin"
44

55
[tool.maturin]
6-
bindings = "pyo3"
6+
features = ["pyo3/extension-module"]
77

88
[project]
99
name = "nutpie"
@@ -12,7 +12,11 @@ authors = [{name = "PyMC Developers", email = "[email protected]"}]
1212
readme = "README.md"
1313
requires-python = ">=3.9"
1414
license = { text = "MIT" }
15-
version = "0.6.0"
15+
classifiers = [
16+
"Programming Language :: Rust",
17+
"Programming Language :: Python :: Implementation :: CPython",
18+
"Programming Language :: Python :: Implementation :: PyPy",
19+
]
1620

1721
dependencies = [
1822
"pyarrow >= 12.0.0",
@@ -30,3 +34,4 @@ all = [
3034
"pymc >= 5.5.0",
3135
"numba >= 0.57.1",
3236
]
37+

0 commit comments

Comments
 (0)