Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions examples/cky.dx
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@

' # CKY Algorithm for enumerating binary trees.


' The CKY algorithm is one of the most celebrated algorithms in
' Natural Language Processing. First described as early as 1961
' it gives a dynamic programming formulation for counting binary
' over a fixed length sequence.

' https://en.wikipedia.org/wiki/CYK_algorithm

' Historically the algorithm is typically descibed in terms of finding the '
'All possible parse trees of a sentence under a Chomsky Normal Form
' grammar' This allows for determining possible attachment of clauses in
' sentences such as "John hit the ball".

' ![](https://upload.wikimedia.org/wikipedia/en/4/4b/ParseTree.jpg)

' Let us develop some notation for this example. Each of green symbols is
' in the set of labels. For simplicity we can have 10 of them.

Label = Fin 10

' In this parse, the label VP covers the span over 'hit the ball'

Len = Fin 5
sentence = ["John", "hit", "the", "ball", "eos"]

I = 1@Len
J = 4@Len

' We introduce a slice range function to allow us to pull out a slice of the sentence.

-- Slice a range from a table. Used for viewing spans.
def sliceRange (i:a) ?-> (j:a) ?-> (xs : a=>b) : (i..<j) => b = slice xs (ordinal i) (i..<j)
res : (I..<J) => String = sliceRange sentence
toList res


' ## Index Helpers


' In order to make our implementation of the CKY algorithm simpler we introduce
some basic index manipulation functions.

-- Changes type without changing position
def rebase (i: a) ?-> (j: a) ?-> (x:(i<..)) : (j<..) =
((ordinal x) - ((ordinal j) - (ordinal i)))@(j<..)

K = 2@Len
rebase (0@(K<..)) : (I<..)


-- Cast based on ordinal value
def cast (d:a) : m = (ordinal d)@_
-- Shift over from a starting point
def shift (j:a) ?-> (x: a) : (j<..) = cast x

shift (1@_) : (I<..)

-- Index arithmetic
def start : a = 0 @ a
def end : a = (size a - 1) @ a
instance Add (Fin a)
add = \a b. ((ordinal a) + (ordinal b))@_
sub = \a b. ((ordinal a) - (ordinal b))@_
zero = start



' ## Chart Manipulation

' The modern incarnation of CKY abstracts the inference algorithm away from the
underlying grammar. The core focus of the algorithm is to enumerate all binary
trees.

' To do this we start with a dynamic programming chart.

def Chart (a:Type) (b:Type) : Type = i:a => (i<..) => b
def Params (a:Type) (b:Type) (labels:Type) : Type = labels => i:a => (i<..) => b
def chart (ref:Ref h (Chart a b )) (i: a) (j: (i<..)) : Ref h b =
d = %indexRef ref i
d!j

def cky [Add pos, Add semi, Mul semi] (weights' : Params pos semi labels) : (semi & Chart pos semi) =
-- Initialize the chart to all zeros
c_init : Chart pos semi = for i. for j. zero
(first, last) = (start, end)
-- Sum out the labels
weights = for i j. sum for k. weights'.k.i.j
out = runState c_init $ \ c.
C = chart c

-- Enumerate over all spans d
-- Each of these needs to be done in order
for_ d.
boundary = last - d
v = case ordinal d == 0 of
-- Size 1 spans are mapped are initialized as 1
True -> one
False ->
-- Main loop. No writes
c' = get c
for i' : (..<boundary).
-- Calculate span (i, j)
i = %inject i'
j':(i<..) = shift d
j = %inject j'
w = weights.i.j'
-- Sum over k in i<..<j
sum for k' : (i<..<j).
k = %inject k'
c'.i.(cast k') * c'.k.(rebase j') * w

-- Fill (i, j) in the chart
for i' : (..<boundary).
i = %inject i'
j':(i<..) = cast d
C i j' := v.i'
()
get $ C first end
out


' Run an example

length = 10
WordPos = Fin 10
Labels = Fin 10

key = (newKey 0)
param : Params WordPos Float Labels = for i j k. rand (ixkey2 key i (j, k))
(v, table) = cky param

import plot

' Todo : Want this to work
-- :t (grad \p. log $ fst $ cky p) param



:t table

-- :html matshow for i:WordPos. for j:WordPos. case (ordinal i) < (ordinal j) of
-- True -> table.i.(shift (j - i- 1@_))
-- False -> 0.0