Skip to content

Commit d3aae1b

Browse files
authored
Add ToPy instances for Py and functions returning Py (#25)
1 parent 569a753 commit d3aae1b

File tree

2 files changed

+88
-44
lines changed

2 files changed

+88
-44
lines changed

src/Python/Inline/Literal.hs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,6 @@ instance (ToPy b) => ToPy (IO b) where
577577
--
578578
[CU.exp| PyObject* { inline_py_callback_METH_NOARGS($(PyCFunction f_ptr)) } |]
579579

580-
581580
-- | Only accepts positional parameters
582581
instance (FromPy a, Show a, ToPy b) => ToPy (a -> IO b) where
583582
basicToPy f = Py $ do
@@ -600,6 +599,40 @@ instance (FromPy a1, FromPy a2, ToPy b) => ToPy (a1 -> a2 -> IO b) where
600599
--
601600
[CU.exp| PyObject* { inline_py_callback_METH_FASTCALL($(PyCFunctionFast f_ptr)) } |]
602601

602+
603+
-- | Converted to 0-ary function
604+
instance (ToPy b) => ToPy (Py b) where
605+
basicToPy f = Py $ do
606+
--
607+
f_ptr <- wrapCFunction $ \_ _ -> pyCallback $ do
608+
progPy $ basicToPy =<< f
609+
--
610+
[CU.exp| PyObject* { inline_py_callback_METH_NOARGS($(PyCFunction f_ptr)) } |]
611+
612+
-- | Only accepts positional parameters
613+
instance (FromPy a, Show a, ToPy b) => ToPy (a -> Py b) where
614+
basicToPy f = Py $ do
615+
--
616+
f_ptr <- wrapCFunction $ \_ p_a -> pyCallback $ do
617+
a <- loadArg p_a 0 1
618+
progPy $ basicToPy =<< f a
619+
--
620+
[CU.exp| PyObject* { inline_py_callback_METH_O($(PyCFunction f_ptr)) } |]
621+
622+
-- | Only accepts positional parameters
623+
instance (FromPy a1, FromPy a2, ToPy b) => ToPy (a1 -> a2 -> Py b) where
624+
basicToPy f = Py $ do
625+
--
626+
f_ptr <- wrapFastcall $ \_ p_arr n -> pyCallback $ do
627+
when (n /= 2) $ abortM $ raiseBadNArgs 2 n
628+
a1 <- loadArgFastcall p_arr 0 n
629+
a2 <- loadArgFastcall p_arr 1 n
630+
progPy $ basicToPy =<< f a1 a2
631+
--
632+
[CU.exp| PyObject* { inline_py_callback_METH_FASTCALL($(PyCFunctionFast f_ptr)) } |]
633+
634+
635+
603636
----------------------------------------------------------------
604637
-- Helpers
605638
----------------------------------------------------------------

test/TST/Callbacks.hs

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -54,48 +54,59 @@ tests = testGroup "Callbacks"
5454
except TypeError as e:
5555
pass
5656
|]
57-
, testCase "Haskell exception in callback(arity=1)" $ runPy $ do
58-
let foo :: Int -> IO Int
59-
foo y = pure $ 10 `div` y
60-
throwsPy [py_| foo_hs(0) |]
61-
, testCase "Haskell exception in callback(arity=2)" $ runPy $ do
62-
let foo :: Int -> Int -> IO Int
63-
foo x y = pure $ x `div` y
64-
throwsPy [py_| foo_hs(1, 0) |]
6557
----------------------------------------
66-
, testCase "Call python in callback (arity=1)" $ runPy $ do
67-
let foo :: Int -> IO Int
68-
foo x = do Just x' <- runPy $ fromPy =<< [pye| 100 // x_hs |]
69-
pure x'
70-
[py_|
71-
assert foo_hs(5) == 20
72-
|]
73-
, testCase "Call python in callback (arity=2" $ runPy $ do
74-
let foo :: Int -> Int -> IO Int
75-
foo x y = do Just x' <- runPy $ fromPy =<< [pye| x_hs // y_hs |]
76-
pure x'
77-
[py_|
78-
assert foo_hs(100,5) == 20
79-
|]
58+
, testCase "Py function(arity 0)" $ do
59+
let fun = [pye| 0 |]
60+
runPy [py_| assert fun_hs() == 0 |]
61+
, testCase "Py function(arity=1)" $ runPy $ do
62+
let double (n::Int) = [pye| n_hs * 2 |]
63+
[py_| assert double_hs(3) == 6 |]
64+
, testCase "Py function(arity=2)" $ runPy $ do
65+
let foo (x::Int) (y::Int) = [pye| x_hs * y_hs |]
66+
[py_| assert foo_hs(3, 100) == 300 |]
8067
----------------------------------------
81-
, testCase "No leaks (arity=1)" $ runPy $ do
82-
let foo :: Int -> IO Int
83-
foo y = pure $ 10 * y
84-
[py_|
85-
import sys
86-
x = 123456
87-
old_refcount = sys.getrefcount(x)
88-
foo_hs(x)
89-
assert old_refcount == sys.getrefcount(x)
90-
|]
91-
, testCase "No leaks (arity=2)" $ runPy $ do
92-
let foo :: Int -> Int -> IO Int
93-
foo x y = pure $ x * y
94-
[py_|
95-
import sys
96-
x = 123456
97-
old_refcount = sys.getrefcount(x)
98-
foo_hs(1,x)
99-
assert old_refcount == sys.getrefcount(x)
100-
|]
101-
]
68+
, testCase "Haskell exception in callback(arity=1)" $ runPy $ do
69+
let foo :: Int -> IO Int
70+
foo y = pure $ 10 `div` y
71+
throwsPy [py_| foo_hs(0) |]
72+
, testCase "Haskell exception in callback(arity=2)" $ runPy $ do
73+
let foo :: Int -> Int -> IO Int
74+
foo x y = pure $ x `div` y
75+
throwsPy [py_| foo_hs(1, 0) |]
76+
----------------------------------------
77+
, testCase "Call python in callback (arity=1)" $ runPy $ do
78+
let foo :: Int -> IO Int
79+
foo x = do Just x' <- runPy $ fromPy =<< [pye| 100 // x_hs |]
80+
pure x'
81+
[py_|
82+
assert foo_hs(5) == 20
83+
|]
84+
, testCase "Call python in callback (arity=2" $ runPy $ do
85+
let foo :: Int -> Int -> IO Int
86+
foo x y = do Just x' <- runPy $ fromPy =<< [pye| x_hs // y_hs |]
87+
pure x'
88+
[py_|
89+
assert foo_hs(100,5) == 20
90+
|]
91+
----------------------------------------
92+
, testCase "No leaks (arity=1)" $ runPy $ do
93+
let foo :: Int -> IO Int
94+
foo y = pure $ 10 * y
95+
[py_|
96+
import sys
97+
x = 123456
98+
old_refcount = sys.getrefcount(x)
99+
foo_hs(x)
100+
assert old_refcount == sys.getrefcount(x)
101+
|]
102+
, testCase "No leaks (arity=2)" $ runPy $ do
103+
let foo :: Int -> Int -> IO Int
104+
foo x y = pure $ x * y
105+
[py_|
106+
import sys
107+
x = 123456
108+
old_refcount = sys.getrefcount(x)
109+
foo_hs(1,x)
110+
assert old_refcount == sys.getrefcount(x)
111+
|]
112+
]

0 commit comments

Comments
 (0)