Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jflournoy committed Oct 13, 2017
0 parents commit 06d2031
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 0 deletions.
199 changes: 199 additions & 0 deletions app.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
#
# This is a Shiny web application. You can run the application by clicking
# the 'Run App' button above.
#
# Find out more about building applications with Shiny here:
#
# http://shiny.rstudio.com/
#

library(shiny)
library(tidyverse)
library(wesanderson)

# Define UI for application that draws a histogram
ui <- fluidPage(

# Application title
titlePanel("Associative Learning Simulation"),

withMathJax(),

# Sidebar with a slider input for number of bins
sidebarLayout(
sidebarPanel(
h3('Set parameter levels'),
sliderInput("xi",
HTML("\\(\\xi\\) <span style=\"font-weight: normal; font-style: italic;\">(noise)</span>"),
min = -3,
max = 3,
value = 0,
step=.05),
sliderInput("ep",
HTML("\\(\\epsilon\\) <span style=\"font-weight: normal; font-style: italic;\">(learning rate)</span>"),
min = -3,
max = 3,
value = 0,
step=.05),
# sliderInput("b",
# "Bias",
# min = -2,
# max = 2,
# value = 0,
# step=.05),
sliderInput("rho",
HTML("\\(\\rho\\) <span style=\"font-weight: normal; font-style: italic;\">(inverse temperature)</span>"),
min = -3,
max = 3,
value = 0,
step=.05),
h3('Parameter optimality across 40 random runs'),
p('When noise parameter is set very low'),
img(src = 'optimality_plot.png', style = 'max-width: 100%')
),

# Show a plot of the generated distribution
mainPanel(
plotOutput("trialsPlot"),
helpText("RW Equation: \n
$$p(a_t|s_t) = \\text{logit}^{-1}\\Big(Q_{t-1}(a_{t},s_{t}) + \\epsilon\\big(\\rho r_{t} - Q_{t-1}(a_{t},s_{t})\\big)\\Big)\\cdot (1-\\xi) + \\frac{\\xi}{2}$$"),
uiOutput('rw_eq')
# ,
# tableOutput('runTable')
)
)
)

# Define server logic required to draw a histogram
server <- function(input, output) {

generateTrials <- function() {
p_right <- data.frame(expand.grid(cue=1:2, reward = c(1,5)), pcorrect_if_pressed_r=c(rep(.2,1), rep(.8,1)))

cue1Indxs <- sample(c(1,3), size = 60, replace = T)
cue2Indxs <- sample(c(2,4), size = 60, replace = T)

manyTrialIndxs <- c(cue1Indxs,cue2Indxs)

manyTrialIndxsShuffled <- manyTrialIndxs[sample(1:length(manyTrialIndxs),
size = length(manyTrialIndxs),
replace = F)]

Trials <- p_right[manyTrialIndxsShuffled,]
Trials$crct_if_right <- rbinom(dim(Trials)[1], size = 1, prob = Trials$pcorrect_if_pressed_r)
Trials$outcome_r <- Trials$crct_if_right*Trials$reward
Trials$outcome_l <- (1-Trials$crct_if_right)*Trials$reward
return(Trials)
}

inv_logit <- function(x) exp(x)/(1+exp(x))
Phi_approx <- function(x) pnorm(x)

rw_strategy <- function(trialdf, mu_p){
xi <- Phi_approx( mu_p[1])# + sigma[1] * xi_pr[i] )
ep <- Phi_approx( mu_p[2])# + sigma[2] * ep_pr[i] )
b <- mu_p[3]# + sigma[3] * b_pr; # vectorization
rho <- exp( mu_p[4])# + sigma[4] * rho_pr );

K <- length(unique(trialdf$cue))
Tsubj <- dim(trialdf)[1]
wv_g <- c(rep(0, K)) # action wegith for go
wv_ng <- c(rep(0, K)) # action wegith for nogo
qv_g <- c(rep(0, K)) # Q value for go
qv_ng <- c(rep(0, K)) # Q value for nogo
pGo <- c(rep(0, K)) # prob of go (press)

trialdf$pressed_r <- NA
trialdf$Qgo <- NA
trialdf$Qnogo <- NA
trialdf$Wgo <- NA
trialdf$Wnogo <- NA
trialdf$pGo <- NA
trialdf$outcome <- NA

for (t in 1:Tsubj) {
wv_g[ trialdf$cue[t] ] <- qv_g[ trialdf$cue[t] ] + b
wv_ng[ trialdf$cue[t] ] <- qv_ng[ trialdf$cue[t] ] # qv_ng is always equal to wv_ng (regardless of action)
pGo[ trialdf$cue[t] ] = inv_logit( wv_g[ trialdf$cue[t] ] - wv_ng[ trialdf$cue[t] ] )
pGo[ trialdf$cue[t] ] = pGo[ trialdf$cue[t] ] * (1 - xi) + xi/2; # noise

trialdf$pressed_r[t] <- rbinom(n = 1, size = 1, prob = , pGo[ trialdf$cue[t] ]);

trialdf$Qgo[t] <- qv_g[ trialdf$cue[t] ];
trialdf$Qnogo[t] <- qv_ng[ trialdf$cue[t] ];
trialdf$Wgo[t] <- wv_g[ trialdf$cue[t] ];
trialdf$Wnogo[t] <- wv_ng[ trialdf$cue[t] ];
trialdf$pGo[t] <- pGo[ trialdf$cue[t] ];

# update action values
if(trialdf$pressed_r[t] == 1){
qv_g[ trialdf$cue[t] ] <- qv_g[ trialdf$cue[t] ] + ep * (rho * trialdf$outcome_r[t] - qv_g[ trialdf$cue[t] ]);
trialdf$outcome[t] <- trialdf$outcome_r[t]
} else {
qv_ng[ trialdf$cue[t] ] <- qv_ng[ trialdf$cue[t] ] + ep * (rho * trialdf$outcome_l[t] - qv_ng[ trialdf$cue[t] ]);
trialdf$outcome[t] <- trialdf$outcome_l[t]
}
} # end of t loop
return(trialdf)
}

plot_RW_run <- function(trials, mu_p){
single_run <- rw_strategy(trialdf = trials,
mu_p = mu_p)

aplot <- single_run %>%
mutate(cue = factor(cue)) %>%
group_by(cue) %>%
mutate(t = 1:n(), last_outcome = as.numeric( ifelse(lag(pressed_r) == 1 & lag(outcome) == 5, 1,
ifelse(lag(pressed_r) == 1 & lag(outcome) == 1, .95,
ifelse(lag(pressed_r) == 1 & lag(outcome) == 0, .1,
ifelse(lag(pressed_r) == 0 & lag(outcome) == 5, 0,
ifelse(lag(pressed_r) == 0 & lag(outcome) == 1, .05,
.9)))))),
last_press = lag(pressed_r)) %>%
ggplot(aes(x = t, y = pGo)) +
geom_line(alpha = .1) +
geom_line(stat = 'smooth', method = 'gam', formula = y ~ s(x, k = 15, bs = "cr"), alpha = .5) +
geom_segment(aes(xend = t, yend = last_outcome), alpha = .1, color = 'black') +
geom_point(aes(y = last_outcome, shape = factor(last_press))) +
scale_shape_manual(values = c(25,24), name = 'Last press was...', breaks = c(1,0), labels = c('Right', 'left')) +
scale_y_continuous(breaks = c(0,.5, 1), labels = c('left', '', 'right'))+
geom_point() +
facet_wrap(~cue, nrow = 2)+
theme(panel.background = element_blank(),
# strip.text = element_blank(),
strip.background = element_rect(fill = '#eeeeee'))+
labs(y = "More likely to press...", x = 'Trial number')
return(list(plot = aplot, runData = single_run))
}

someTrials <- generateTrials()

simulatedTrials <- reactive({
plot_RW_run(trials = someTrials,
mu_p = c(xi = input$xi, ep = input$ep, b = 0, rho = input$rho))
})

output$trialsPlot <- renderPlot({
simulatedTrials()$plot
})

output$rw_eq <- renderUI({
xi <- Phi_approx( input$xi )
ep <- Phi_approx( input$ep )
b <- input$b
rho <- exp( input$rho )

withMathJax(sprintf("RW Equation with transformed values: \n
$$p(a_t|s_t) = \\text{logit}^{-1}\\Big(Q_{t-1}(a_{t},s_{t}) + %.02f\\big(%.02f r_{t} - Q_{t-1}(a_{t},s_{t})\\big)\\Big)\\cdot (1-%.02f) + \\frac{%.02f}{2}$$",
ep, rho, xi, xi))
})

output$runTable <- renderTable({
simulatedTrials()$runData
})
}

# Run the application
shinyApp(ui = ui, server = server)

10 changes: 10 additions & 0 deletions rsconnect/shinyapps.io/jflournoy/rw_model.dcf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: rw_model
title: rw_model
account: jflournoy
server: shinyapps.io
appId: 188785
bundleId: 909163
url: https://jflournoy.shinyapps.io/rw_model/
when: 1500486751.98579
asMultiple: FALSE
asStatic: FALSE
Binary file added www/optimality_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 06d2031

Please sign in to comment.