Skip to content

Commit 2a766d2

Browse files
authored
Merge pull request #6 from mwien/fix-flower-construction
Fix flower construction bug
2 parents fbebdaf + 62120e0 commit 2a766d2

File tree

4 files changed

+105
-20
lines changed

4 files changed

+105
-20
lines changed

cliquepicking_python/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "cliquepicking"
3-
version = "0.2.4"
3+
version = "0.2.5"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

cliquepicking_rs/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "cliquepicking_rs"
3-
version = "0.1.0"
3+
version = "0.2.5"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

cliquepicking_rs/src/clique_tree.rs

+8-13
Original file line numberDiff line numberDiff line change
@@ -121,33 +121,28 @@ impl CliqueTree {
121121
if flowers[edge_id].is_empty() {
122122
let mut flower = Vec::new();
123123
flower.push(t);
124-
let mut add_ids = Vec::new();
125-
add_ids.push(edge_id);
126124
visited[s] = true;
127125
visited[t] = true;
128126
let mut q = VecDeque::new();
129127
q.push_back(t);
130128
while !q.is_empty() {
131129
let u = q.pop_front().unwrap();
132130
for &v in self.tree.neighbors(u) {
133-
if !visited[v] && st_sep.is_subset(&self.cliques[v]) {
134-
if separators[self.get_edge_id(u, v)] == *st_sep {
135-
add_ids.push(self.get_edge_id(u, v));
136-
} else {
137-
flower.push(v);
138-
visited[v] = true;
139-
q.push_back(v);
140-
}
131+
if !visited[v]
132+
&& st_sep.is_subset(&self.cliques[v])
133+
&& separators[self.get_edge_id(u, v)] != *st_sep
134+
{
135+
flower.push(v);
136+
visited[v] = true;
137+
q.push_back(v);
141138
}
142139
}
143140
}
144141
visited[s] = false;
145142
for &f in &flower {
146143
visited[f] = false;
147144
}
148-
for &id in &add_ids {
149-
flowers[id] = IndexSet::from(flower.clone());
150-
}
145+
flowers[edge_id] = IndexSet::from(flower.clone());
151146
}
152147
}
153148
}

cliquepicking_rs/src/sample.rs

+95-5
Original file line numberDiff line numberDiff line change
@@ -530,11 +530,10 @@ pub fn sample_cpdag_orders(g: &PartiallyDirectedGraph, k: usize) -> Vec<Vec<usiz
530530
mod tests {
531531
use std::collections::{HashMap, HashSet};
532532

533-
use crate::graph::Graph;
533+
use crate::{graph::Graph, partially_directed_graph::PartiallyDirectedGraph};
534534

535-
#[test]
536-
fn sample_amos_basic_check() {
537-
let g = Graph::from_edge_list(
535+
fn get_paper_graph() -> Graph {
536+
Graph::from_edge_list(
538537
vec![
539538
(0, 1),
540539
(0, 2),
@@ -549,7 +548,74 @@ mod tests {
549548
(4, 5),
550549
],
551550
6,
552-
);
551+
)
552+
}
553+
554+
fn get_issue4_graph() -> PartiallyDirectedGraph {
555+
PartiallyDirectedGraph::from_edge_list(
556+
vec![
557+
(0, 1),
558+
(1, 0),
559+
(0, 2),
560+
(2, 0),
561+
(1, 2),
562+
(2, 1),
563+
(1, 3),
564+
(3, 1),
565+
(1, 4),
566+
(4, 1),
567+
],
568+
5,
569+
)
570+
}
571+
572+
fn get_issue5_graph() -> PartiallyDirectedGraph {
573+
PartiallyDirectedGraph::from_edge_list(
574+
vec![
575+
(9, 10),
576+
(9, 13),
577+
(9, 7),
578+
(10, 9),
579+
(10, 11),
580+
(10, 12),
581+
(13, 9),
582+
(4, 5),
583+
(4, 12),
584+
(5, 4),
585+
(0, 1),
586+
(0, 3),
587+
(1, 0),
588+
(1, 19),
589+
(6, 7),
590+
(6, 14),
591+
(6, 19),
592+
(7, 6),
593+
(7, 9),
594+
(7, 8),
595+
(14, 6),
596+
(14, 15),
597+
(8, 7),
598+
(8, 19),
599+
(16, 15),
600+
(16, 18),
601+
(16, 17),
602+
(15, 16),
603+
(15, 14),
604+
(18, 16),
605+
(18, 19),
606+
(11, 10),
607+
(11, 19),
608+
(3, 17),
609+
(3, 19),
610+
(2, 3),
611+
],
612+
5,
613+
)
614+
}
615+
616+
#[test]
617+
fn sample_amos_basic_check() {
618+
let g = get_paper_graph();
553619
let sample_size = 10_000;
554620
let amos = super::sample_amos(&g, sample_size);
555621
assert_eq!(amos.len(), sample_size);
@@ -567,4 +633,28 @@ mod tests {
567633
}
568634
assert_eq!(dags.len(), 54);
569635
}
636+
637+
#[test]
638+
fn sample_cpdag_basic_check() {
639+
let g = get_issue4_graph();
640+
let sample_size = 10_000;
641+
let dags = super::sample_cpdag(&g, sample_size);
642+
assert_eq!(dags.len(), sample_size);
643+
let mut count_dags = HashMap::new();
644+
for a in dags.iter() {
645+
count_dags.entry(a).and_modify(|cnt| *cnt += 1).or_insert(1);
646+
}
647+
assert_eq!(count_dags.len(), 10);
648+
let g = get_issue5_graph();
649+
let sample_size = 10_000;
650+
let dags = super::sample_cpdag(&g, sample_size);
651+
assert_eq!(dags.len(), sample_size);
652+
let mut count_dags = HashMap::new();
653+
for a in dags.iter() {
654+
count_dags.entry(a).and_modify(|cnt| *cnt += 1).or_insert(1);
655+
}
656+
assert_eq!(count_dags.len(), 44);
657+
}
658+
659+
// TODO: test orders as well
570660
}

0 commit comments

Comments
 (0)