Skip to content

Commit 17507e9

Browse files
author
mhtess
committed
uncertain parse model (replacing uncertain has a threshold model)
1 parent 17f4ddf commit 17507e9

File tree

1 file changed

+172
-41
lines changed

1 file changed

+172
-41
lines changed

models/model-understanding.Rmd

+172-41
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,42 @@ library(tidyverse)
1212
library(knitr)
1313
theme_set(theme_few())
1414
```
15+
```{r rsaBins}
16+
rsaBinsCoarse <- '
17+
var lowerBins = [
18+
0,
19+
0.01,
20+
0.1,
21+
0.2,
22+
0.3,
23+
0.4,
24+
0.5,
25+
0.6,
26+
0.7,
27+
0.8,
28+
0.9,
29+
0.99
30+
];
1531
16-
```{r utils}
17-
utils <- '
18-
var round = function(x){
19-
return Math.round(x*100)/100
20-
}
21-
22-
var isNegation = function(utt){
23-
return (utt.split("_")[0] == "not")
24-
};
25-
26-
var hasNegModifier = function(utt){
27-
return (utt.split("_")[0] == "not")
28-
};
29-
var hasNegMorph = function(utt){
30-
return (utt.indexOf("un") > -1)
31-
};
32-
var roundTo3 = function(x){
33-
return Math.round(x * 1000) / 1000
34-
}
32+
var upperBins = [
33+
0.01,
34+
0.1,
35+
0.2,
36+
0.3,
37+
0.4,
38+
0.5,
39+
0.6,
40+
0.7,
41+
0.8,
42+
0.9,
43+
0.99,
44+
1
45+
];
46+
'
47+
```
3548

49+
```{r rsaBinsFine}
50+
rsaBinsFine <- '
3651
var lowerBins = [
3752
0,
3853
0.01,
@@ -82,6 +97,28 @@ var upperBins = [
8297
0.99,
8398
1
8499
];
100+
'
101+
```
102+
103+
```{r utils}
104+
utils <- '
105+
var round = function(x){
106+
return Math.round(x*100)/100
107+
}
108+
109+
var isNegation = function(utt){
110+
return (utt.split("_")[0] == "not")
111+
};
112+
113+
var hasNegModifier = function(utt){
114+
return (utt.split("_")[0] == "not")
115+
};
116+
var hasNegMorph = function(utt){
117+
return (utt.indexOf("un") > -1)
118+
};
119+
var roundTo3 = function(x){
120+
return Math.round(x * 1000) / 1000
121+
}
85122
86123
var midBins = map2(function(b1,b2){
87124
return roundTo3((b2 - b1)/2 + b1)
@@ -121,7 +158,6 @@ var DiscreteBeta = cache(function(a, b){
121158
'
122159
```
123160

124-
125161
```{r meaningFn}
126162
meaningFn <- '
127163
var meaning = function(words, state, thresholds){
@@ -525,19 +561,22 @@ rs.listener.wp.tidy %>%
525561

526562
# Uncertain "has threshold" RSA
527563

564+
01/16/18: This is being refashioned to be analagous to the "uncertain parsing" model (yet to be implemented)
565+
528566
```{r rsa-uncertainHasThresholds}
529567
uncertainHasThresholdsRSA <- '
530568
var utterances = [
531569
"happy",
532570
"not_unhappy",
533571
"not_happy",
534-
"unhappy",
572+
"unhappy"
573+
// "silence"
535574
// "neither_nor"
536575
];
537576
538577
var cost_yes = 0;
539-
var cost_not = 3;
540-
var cost_un = 3;
578+
var cost_not = 2;
579+
var cost_un = 2;
541580
542581
var uttCosts = map(function(u) {
543582
var notCost = hasNegModifier(u) ? cost_not : 0
@@ -553,22 +592,50 @@ var utterancePrior = Infer({model: function(){
553592
var speakerOptimality = 1;
554593
var speakerOptimality2 = 1;
555594
556-
var has_an_unhappy_threshold_prior = 0.2;
595+
var meaning = function(words, state, thresholds, parsing){
596+
words == "happy" ? state > thresholds.happy :
597+
words == "not_happy" ? parsing.compositional_not ?
598+
!(state > thresholds.happy) :
599+
(state < thresholds.not_happy) :
600+
words == "unhappy" ? parsing.compositional_un ?
601+
!(state > thresholds.happy) :
602+
(state < thresholds.unhappy) :
603+
// words == "not_unhappy" ? parsing.compositional_not ?
604+
// parsing.compositional_un ? (state > thresholds.happy) :
605+
// !(state < thresholds.unhappy) :
606+
// (state > thresholds.not_unhappy) :
607+
words == "not_unhappy" ? parsing.compositional_un ?
608+
(state > thresholds.happy) : !(state < thresholds.unhappy) :
609+
words == "sad" ? state < thresholds.sad :
610+
words == "not_sad" ? !(state < thresholds.sad) :
611+
words == "neither_nor" ? (
612+
!(state > thresholds.happy) &&
613+
!(state < thresholds.unhappy)
614+
) :
615+
true
616+
};
557617
558-
var listener0 = cache(function(utterance, thresholds) {
618+
var compositional_un_prior = 0.5;
619+
var compositional_not_prior = 0.5;
620+
// var un_not_lexical_prior = not_lexical_prior*un_lexical_prior;
621+
622+
var listener0 = cache(function(utterance, thresholds, parsing) {
559623
Infer({model: function(){
560624
var state = sample(DiscreteBeta(1, 1));
625+
// display(JSON.stringify(thresholds))
561626
// var state = sample(DiscreteGaussian(0, 0.5));
562-
var m = meaning(utterance, state, thresholds);
627+
var m = meaning(utterance, state, thresholds, parsing);
628+
// display("l0 " + state + " " + m + " " + JSON.stringify(parsing))
563629
condition(m);
564630
return state;
565631
}})
566632
}, 10000);
567633
568-
var speaker1 = cache(function(state, thresholds) {
634+
var speaker1 = cache(function(state, thresholds, parsing) {
569635
Infer({model: function(){
570636
var utterance = sample(utterancePrior);
571-
var L0 = listener0(utterance, thresholds);
637+
// display(utterance)
638+
var L0 = listener0(utterance, thresholds, parsing);
572639
factor(speakerOptimality*L0.score(state));
573640
return utterance;
574641
}})
@@ -577,23 +644,31 @@ var speaker1 = cache(function(state, thresholds) {
577644
var listener1 = cache(function(utterance) {
578645
Infer({model: function(){
579646
580-
var happy_threshold = uniformDraw(thetaBins)
581-
var has_an_unhappy_threshold = flip(has_an_unhappy_threshold_prior)
582-
var unhappy_threshold = has_an_unhappy_threshold ?
583-
uniformDraw(thetaBins) :
584-
happy_threshold
647+
var happy_threshold = uniformDraw(thetaBins);
648+
var compositional_un = flip(compositional_un_prior)
649+
var compositional_not = flip(compositional_not_prior)
650+
651+
var unhappy_threshold = compositional_un ? "happy_threshold" : uniformDraw(thetaBins)
652+
var not_happy_threshold = compositional_not ? "happy_threshold" : uniformDraw(thetaBins);
653+
var not_unhappy_threshold = -99;
654+
// compositional_not ? compositional_un ? "happy_threshold" :
655+
// "unhappy_threshold" : uniformDraw(thetaBins)
585656
586657
var thresholds = {
587658
happy: happy_threshold,
588-
unhappy: unhappy_threshold
659+
unhappy: unhappy_threshold,
660+
not_happy: not_happy_threshold,
661+
not_unhappy: not_unhappy_threshold
589662
}
590663
664+
var parsing = {compositional_un, compositional_not}
665+
591666
var state = sample(DiscreteBeta(1, 1));
592667
// var state = sample(DiscreteGaussian(0, 0.5));
593668
594-
var S1 = speaker1(state, thresholds)
669+
var S1 = speaker1(state, thresholds, parsing)
595670
observe(S1, utterance)
596-
return state
671+
return extend(parsing, {state})
597672
}})
598673
}, 10000);
599674
'
@@ -602,14 +677,19 @@ var listener1 = cache(function(utterance) {
602677
```{r wpplCalls-uncertainHasThresholds}
603678
uncertainHasThresholdListenerCall <- '
604679
_.fromPairs(map(function(u){
680+
display(u)
605681
var post = listener1(u)
606-
return [u, post]
682+
display(u + " __ Comp(un) = " + expectation(post, function(x){return x.compositional_un}))
683+
display(u + " __ Comp(not) = " + expectation(post, function(x){return x.compositional_not}))
684+
return [u, marginalize(post, "state")]
607685
}, utterances))
608686
'
687+
#uncertainHasThresholdListenerCall<- 'listener1("unhappy")'
609688
```
610689

611690
```{r runUncertainHasThresholdListener}
612-
rs.listener.wp.2 <- webppl(paste(utils, meaningFn, uncertainHasThresholdsRSA, uncertainHasThresholdListenerCall, sep = '\n'))
691+
rs.listener.wp.2 <- webppl(paste(rsaBinsCoarse,
692+
utils, uncertainHasThresholdsRSA, uncertainHasThresholdListenerCall, sep = '\n'))
613693
614694
rs.listener.wp.tidy.2 <- bind_rows(
615695
data.frame(rs.listener.wp.2$happy) %>%
@@ -633,16 +713,67 @@ rs.listener.wp.tidy.samples.2 <- get_samples(
633713
634714
ggplot(rs.listener.wp.tidy.samples.2,
635715
aes( x = support,fill = utterance, color = utterance))+
636-
geom_density(alpha = 0.4, size = 1.3)+
716+
# geom_density(alpha = 0.4, size = 1.3)+
637717
scale_fill_solarized()+
718+
geom_histogram(alpha = 0.4, size = 1.3)+
638719
scale_color_solarized()+
639720
xlab("Degree of happiness")+
721+
facet_wrap(~utterance)+
640722
ylab("Posterior probability density")+
641723
scale_x_continuous(breaks =c(0, 1))+
642724
scale_y_continuous(breaks = c(0, 2))
643725
644726
#ggsave("figs/L1_posteriors_wCost3_alpha1.png", width = 6, height = 4)
645727
```
728+
```{r}
729+
rs.listener.wp.tidy.2 %>%
730+
group_by(utterance) %>%
731+
summarize(interpretation = sum(probs * support)) %>%
732+
mutate(utterance = factor(utterance,
733+
levels = c("unhappy",
734+
"not_happy",
735+
"not_unhappy",
736+
"happy"))) %>%
737+
ggplot(., aes( x = utterance, y=interpretation,
738+
fill = utterance, color = utterance))+
739+
geom_col(position = position_dodge(0.8),
740+
width = 0.8,
741+
alpha =0.8, color = 'black')+
742+
#coord_flip()+
743+
geom_hline(yintercept = 0.5, lty = 3)+
744+
scale_fill_solarized()+
745+
guides(fill = F)+
746+
scale_y_continuous(limits = c(0, 1), breaks = c(0, 0.5, 1))+
747+
xlab("")+
748+
theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1))
749+
750+
#ggsave("figs/L1_means_wCost3_alpha1.png", width = 4, height = 3.5)
751+
```
752+
```{r}
753+
rs.listener.wp.tidy.2 %>%
754+
group_by(utterance) %>%
755+
summarize(interpretation = sum(probs * support)) %>%
756+
mutate(utterance = factor(utterance,
757+
levels = c("unhappy",
758+
"not_happy",
759+
"not_unhappy",
760+
"happy"))) %>%
761+
kable(.)
762+
```
763+
764+
765+
#### parameters
766+
767+
- lower the lexical "un-" probability
768+
- the more "unhappy" and "not happy" get squished together, but also "not unhappy" and "happy"
769+
- including "un" cost
770+
- bring "unhappy" and "not happy" closer together than "happy" and "not unhappy"
771+
- with speaker opt = 1
772+
- "happy" looks kind of weak?
773+
- with higher speaker optimality:
774+
- "not unhappy" > "happy" (because super costly)
775+
776+
646777

647778
# Uncertain alternatives RSA
648779

@@ -1008,7 +1139,7 @@ var listener1 = cache(function(utterance) {
10081139
'
10091140
```
10101141

1011-
```{r wpplCalls-uncertainAlternatives}
1142+
```{r wpplCalls-uncertainParser}
10121143
uncertainAlternativesListenerCall <- '
10131144
_.fromPairs(map(function(u){
10141145
var post = listener1(u)
@@ -1017,7 +1148,7 @@ _.fromPairs(map(function(u){
10171148
'
10181149
```
10191150

1020-
```{r runUncertainAlternativesListener}
1151+
```{r runUncertainParseListener}
10211152
rs.listener.wp.2 <- webppl(paste(utils, uncertainAlternativesRSA, uncertainAlternativesListenerCall, sep = '\n'),
10221153
data = c(1,5,5), data_var = "alternativesPriorProbs")
10231154
```

0 commit comments

Comments
 (0)