Skip to content

feat: Add poisson primitives for :count types #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
com.cognitect/transit-clj {:mvn/version "1.0.324"}
com.github.haifengl/smile-core {:mvn/version "3.0.1"}
io.github.inferenceql/inferenceql.gpm.sppl {:git/sha "52f8316e094b3644709dccde8f0a935f9b55f187"}
io.github.inferenceql/inferenceql.inference {:git/sha "40e77dedf680b7936ce988b66186a86f5c4db6a5"}
io.github.inferenceql/inferenceql.inference {:git/sha "78b72a4f0356b705311fc003ee94bc48e5c1c142"}
io.github.inferenceql/inferenceql.query {:git/sha "933a1d1b620bb227ddcf6f649187c8759b46ac27"}
lambdaisland/regal {:mvn/version "0.0.143"}
medley/medley {:mvn/version "1.4.0"}
Expand Down
67 changes: 1 addition & 66 deletions dvc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,12 @@ stages:
--kernel alpha
--kernel view_alphas
--kernel column_hypers
--kernel rows
--kernel columns
--output data/cgpm/complete/{/}
--data data/numericalized.csv
--params params.yaml
--seed $((${seed} + {#} - 1))
--minutes ${cgpm.minutes}'
#--iterations ${cgpm.iterations}
# --iterations ${cgpm.iterations}
params:
- parallel.flags
- seed
Expand Down Expand Up @@ -259,69 +257,6 @@ stages:
outs:
- data/dep-prob.svg

save-linear-stats:
cmd: >
python scripts/linear_stats.py
--data data/ignored.csv
--schema data/schema.edn
--output data/linear-stats.json
deps:
- data/ignored.csv
- data/schema.edn
- scripts/linear_stats.py
outs:
- data/linear-stats.json

linear-stats-vl:
cmd: >
clojure -X inferenceql.structure-learning.heatmap/vega-lite
:stats-path '"data/linear-stats.json"'
:sort-path '"data/dep-prob.json"'
:domain '[1.0 0.0]'
:default 0.0
:name '"statistics"'
:field '"p-value"'
:scheme '"oranges"'
> data/linear-stats.vl.json
deps:
- data/linear-stats.json
- data/dep-prob.json
- src/clojure/inferenceql/structure_learning/heatmap.clj
outs:
- data/linear-stats.vl.json

linear-stats-vg:
cmd: >
pnpm vl2vg
< data/linear-stats.vl.json
> data/linear-stats.vg.json
deps:
- data/linear-stats.vl.json
outs:
- data/linear-stats.vg.json

linear-stats-svg:
cmd: >
pnpm vg2svg
< data/linear-stats.vg.json
> data/linear-stats.svg
deps:
- data/linear-stats.vg.json
outs:
- data/linear-stats.svg
compare-dep-prob-with-linear:
desc: "Compares results from dependency probability with standard statistical tests"
cmd: >
python scripts/compare_deps.py
--deps data/dep-prob.json
--linear data/linear-stats.json
>> data/qc-statistical-tests.txt
deps:
- data/dep-prob.json
- data/linear-stats.json
outs:
- data/qc-statistical-tests.txt

ast-export:
desc: "Exports ASTs of the parametric model programs resulting from truncating CGPM-CrossCat models."
cmd:
Expand Down
2 changes: 1 addition & 1 deletion scripts/ast_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def export_primitive(output, cctype, hypers, suffstats, distargs, categorical_ma
a = hypers["a"]
b = hypers["b"]
N = suffstats["N"]
x_sum = suffstats["x_sum"]
sum_x = suffstats["sum_x"]
# Compute the distribution.
# The implementation of Poisson.logpdf in CGPM is rather suspicious:
# https://github.com/probcomp/cgpm/issues/251
Expand Down
4 changes: 3 additions & 1 deletion scripts/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def impute_missing_features(train_dataset, test_dataset, schema):
replacements[c] = train_dataset[c].median()
elif schema[c] == "nominal":
replacements[c] = train_dataset[c].mode()[0]
elif schema[c] == "count":
replacements[c] = int(train_dataset[c].median())
else:
raise ValueError(error_message_stat_type(schema[c]))
train_dataset = train_dataset.fillna(replacements)
Expand Down Expand Up @@ -62,7 +64,7 @@ def recode_categoricals(train_dataset, test_dataset, schema, target):
)
# Add new cols to df
col_names = [
c for c in train_dataset.columns if (schema[c] == "numerical") and (c != target)
c for c in train_dataset.columns if (schema[c] not in ["nominal", "ignored"] ) and (c != target)
]
for i in range(X_transformed.shape[1]):
col_name = f"c_{i}"
Expand Down
2 changes: 2 additions & 0 deletions src/clojure/inferenceql/structure_learning/schema.clj
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"Returns the Loom schema for an InferenceQL schema."
[schema]
(let [replacements {:nominal "dd" ; discrete dirichlet
:count "gp" ; gamma-poisson
:numerical "nich"}] ; normal inverse chi squared
(into {}
(comp (remove (comp #{:ignore} val))
Expand All @@ -78,6 +79,7 @@
"Returns the CGPM schema for an InferenceQL schema."
[schema]
(let [replacements {:nominal "categorical" ; discrete dirichlet
:count "poisson"
:numerical "normal"}]
(into {}
(comp (remove (comp #{:ignore} val))
Expand Down
14 changes: 8 additions & 6 deletions src/clojure/inferenceql/structure_learning/xcat.clj
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
(defn ^:private view-name
"Returns a cluster name for view index n."
[n]
(str "view_" n))
(keyword (str "view_" n)))

(defn ^:private cluster-name
"Returns a cluster name for cluster index n."
[n]
(str "cluster_" n))
(keyword (str "cluster_" n)))

(defn ^:private map-invert
"Returns m with its vals as keys and its keys grouped into a vector as vals."
Expand All @@ -35,8 +35,9 @@

(defn ^:private data
[data-cells schema num-rows]
(let [headers (map name (first data-cells))
(let [headers (map keyword (first data-cells))
column->f (comp {:numerical am.csv/parse-number
:count am.csv/parse-number
:nominal am.csv/parse-str}
schema
name)
Expand All @@ -58,15 +59,16 @@

(defn ^:private col-names
[numericalized cgpm-model]
(mapv name (get cgpm-model :col_names (first numericalized))))
(mapv keyword (get cgpm-model :col_names (first numericalized))))

(defn ^:private spec
[numericalized schema cgpm-model]
(let [columns (col-names numericalized cgpm-model)
views (views columns cgpm-model)
types (->> schema
(medley/map-keys name)
(medley/map-keys keyword)
(medley/map-vals {:nominal :categorical
:count :poisson
:numerical :gaussian}))]
{:views views
:types types}))
Expand All @@ -93,7 +95,7 @@
(defn options
[mapping-table]
(->> mapping-table
(medley/map-keys name)
(medley/map-keys keyword)
(medley/map-vals #(->> % (sort-by val) (map key) (into [])))))

(defn xcat-model
Expand Down
37 changes: 16 additions & 21 deletions test/clojure/inferenceql/structure_learning/schema_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
(is (= [0 1 2]
(schema/column :x [{:x 0 :y "a"}
{:x 1 :y "b"}
{:x 2 :y "c"}])))
(is (= [0 1 2]
(schema/column "x" [{"x" 0 "y" "a"}
{"x" 1 "y" "b"}
{"x" 2 "y" "c"}]))))
{:x 2 :y "c"}]))))

(deftest guess-stattype
(are [stattype coll] (= stattype (schema/guess-stattype :ignore coll))
Expand All @@ -36,24 +32,23 @@
:ignore
[{:id 0 :x 0.0 :y "a"}
{:id 1 :x 1.0 :y "b"}
{:id 2 :x 2.0 :y "c"}])))
(is (= {"id" :ignore
"x" :numerical
"y" :nominal}
(schema/guess
:ignore
[{"id" 0 "x" 0.0 "y" "a"}
{"id" 1 "x" 1.0 "y" "b"}
{"id" 2 "x" 2.0 "y" "c"}]))))
{:id 2 :x 2.0 :y "c"}]))))

(deftest loom
(is (= {:x "nich"
:y "dd"}
:y "dd"
:z "gp"}
(schema/loom {:id :ignore
:x :numerical
:y :nominal})))
(is (= {"x" "nich"
"y" "dd"}
(schema/loom {"id" :ignore
"x" :numerical
"y" :nominal}))))
:y :nominal
:z :count}))))
(deftest cgpm
(is (= {"x" "normal"
"y" "categorical"
"z" "poisson"
}
(schema/cgpm {:id :ignore
:x :numerical
:y :nominal
:z :count
}))))