From bea9741473d03aefffcdea7dc6a4885a60ba5a80 Mon Sep 17 00:00:00 2001 From: Jesper Cockx Date: Mon, 6 Oct 2025 17:56:06 +0200 Subject: [PATCH 1/4] Allow inlining of any function, do check dynamically --- src/Agda2Hs/Compile.hs | 4 +++- src/Agda2Hs/Compile/Function.hs | 28 ++++------------------------ src/Agda2Hs/Compile/Term.hs | 10 +++++++--- src/Agda2Hs/Compile/Types.hs | 2 ++ src/Agda2Hs/Compile/Utils.hs | 24 +++++++++++++++++++++++- test/Fail/Inline.agda | 5 +++++ test/Fail/Inline2.agda | 5 +++++ test/golden/Inline.err | 5 +++-- test/golden/Inline2.err | 5 +++-- 9 files changed, 55 insertions(+), 33 deletions(-) diff --git a/src/Agda2Hs/Compile.hs b/src/Agda2Hs/Compile.hs index 55815176..f76cc499 100644 --- a/src/Agda2Hs/Compile.hs +++ b/src/Agda2Hs/Compile.hs @@ -11,6 +11,7 @@ import Data.IORef import Data.List ( isPrefixOf, group, sort ) import qualified Data.Map as M +import qualified Data.Set as S import Agda.Compiler.Backend import Agda.Compiler.Common ( curIF ) @@ -41,7 +42,8 @@ globalSetup :: Options -> TCM GlobalEnv globalSetup opts = do opts <- checkConfig opts ctMap <- liftIO $ newIORef M.empty - return $ GlobalEnv opts ctMap + ilMap <- liftIO $ newIORef S.empty + return $ GlobalEnv opts ctMap ilMap initCompileEnv :: GlobalEnv -> TopLevelModuleName -> SpecialRules -> CompileEnv initCompileEnv genv tlm rewrites = CompileEnv diff --git a/src/Agda2Hs/Compile/Function.hs b/src/Agda2Hs/Compile/Function.hs index 8478b3a8..a5a425b9 100644 --- a/src/Agda2Hs/Compile/Function.hs +++ b/src/Agda2Hs/Compile/Function.hs @@ -394,31 +394,11 @@ checkTransparentPragma def = compileFun False def >>= \case "A transparent function must have exactly one non-erased argument and return it unchanged." --- | Ensure a definition can be defined as inline. +-- | Mark a definition as one that should be inlined. checkInlinePragma :: Definition -> C () -checkInlinePragma def@Defn{defName = f} = do - let Function{funClauses = cs} = theDef def - case filter (isJust . clauseBody) cs of - [c] -> - unlessM (allowedPats (namedClausePats c)) $ agda2hsErrorM $ - "Cannot make function" <+> prettyTCM (defName def) <+> "inlinable." <+> - "Inline functions can only use variable patterns or transparent record constructor patterns." - _ -> - agda2hsErrorM $ - "Cannot make function" <+> prettyTCM f <+> "inlinable." <+> - "An inline function must have exactly one clause." - - where allowedPat :: DeBruijnPattern -> C Bool - allowedPat VarP{} = pure True - -- only allow matching on (unboxed) record constructors - allowedPat (ConP ch ci cargs) = - isUnboxConstructor (conName ch) >>= \case - Just _ -> allowedPats cargs - Nothing -> pure False - allowedPat _ = pure False - - allowedPats :: NAPs -> C Bool - allowedPats pats = allM (allowedPat . dget . dget) pats +checkInlinePragma def@(Defn { defName = q , theDef = df }) = do + let qs = fromMaybe [] $ getMutual_ df + addInlineSymbols $ q : qs checkCompileToFunctionPragma :: Definition -> String -> C () checkCompileToFunctionPragma def s = noCheckNames $ do diff --git a/src/Agda2Hs/Compile/Term.hs b/src/Agda2Hs/Compile/Term.hs index dea0a24f..1b93bb8e 100644 --- a/src/Agda2Hs/Compile/Term.hs +++ b/src/Agda2Hs/Compile/Term.hs @@ -171,8 +171,8 @@ compileDef f ty args | Just sem <- isSpecialDef f = do sem ty args compileDef f ty args = - ifM (isTransparentFunction f) (compileErasedApp ty args) $ - ifM (isInlinedFunction f) (compileInlineFunctionApp f ty args) $ do + ifM (isTransparentFunction f) (compileErasedApp ty args) $ do + reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of regular function:" <+> prettyTCM f let defMod = qnameModule f @@ -458,6 +458,9 @@ compileTerm ty v = do v <- instantiate v + toInline <- getInlineSymbols + v <- locallyReduceDefs (OnlyReduceDefs toInline) $ reduce v + let bad s t = agda2hsErrorM $ vcat [ text "cannot compile" <+> text (s ++ ":") , nest 2 $ prettyTCM t @@ -465,7 +468,8 @@ compileTerm ty v = do reduceProjectionLike v >>= \case - Def f es -> do + v@(Def f es) -> do + whenM (isInlinedFunction f) $ bad "inlined function" v ty <- defType <$> getConstInfo f compileSpined (compileDef f ty) (Def f) ty es diff --git a/src/Agda2Hs/Compile/Types.hs b/src/Agda2Hs/Compile/Types.hs index eed48235..bb87924f 100644 --- a/src/Agda2Hs/Compile/Types.hs +++ b/src/Agda2Hs/Compile/Types.hs @@ -29,6 +29,8 @@ data GlobalEnv = GlobalEnv { globalOptions :: Options , compileToMap :: IORef (Map QName QName) -- ^ names with a compile-to pragma + , inlineSymbols :: IORef (Set QName) + -- ^ names of functions that should be inlined } type ModuleEnv = TopLevelModuleName diff --git a/src/Agda2Hs/Compile/Utils.hs b/src/Agda2Hs/Compile/Utils.hs index a1ec2164..b228d0be 100644 --- a/src/Agda2Hs/Compile/Utils.hs +++ b/src/Agda2Hs/Compile/Utils.hs @@ -12,6 +12,8 @@ import Data.List ( isPrefixOf, stripPrefix ) import Data.Maybe ( isJust ) import qualified Data.Map as M import Data.String ( IsString(..) ) +import Data.Set ( Set ) +import qualified Data.Set as S import GHC.Stack (HasCallStack) @@ -281,8 +283,28 @@ isTupleProjection q = isTransparentFunction :: QName -> C Bool isTransparentFunction q = (== TransparentPragma) <$> getPragma q +getInlineSymbols :: C (Set QName) +getInlineSymbols = do + ilSetRef <- asks $ inlineSymbols . globalEnv + liftIO $ readIORef ilSetRef + +debugInlineSymbols :: C () +debugInlineSymbols = do + ilSetRef <- asks $ inlineSymbols . globalEnv + ilSet <- liftIO $ readIORef ilSetRef + reportSDoc "agda2hs.compile.inline" 50 $ text $ + show $ map prettyShow $ S.toList ilSet + isInlinedFunction :: QName -> C Bool -isInlinedFunction q = (== InlinePragma) <$> getPragma q +isInlinedFunction q = S.member q <$> getInlineSymbols + +addInlineSymbols :: [QName] -> C () +addInlineSymbols qs = do + reportSDoc "agda2hs.compile.inline" 15 $ + "Adding inline rules for" <+> pretty qs + ilSetRef <- asks $ inlineSymbols . globalEnv + liftIO $ modifyIORef ilSetRef $ \s -> foldr S.insert s qs + debugCompileToMap :: C () debugCompileToMap = do diff --git a/test/Fail/Inline.agda b/test/Fail/Inline.agda index acc0b041..fe848239 100644 --- a/test/Fail/Inline.agda +++ b/test/Fail/Inline.agda @@ -6,3 +6,8 @@ tail' : List a → List a tail' (x ∷ xs) = xs tail' [] = [] {-# COMPILE AGDA2HS tail' inline #-} + +test : List a → List a +test = tail' + +{-# COMPILE AGDA2HS test #-} diff --git a/test/Fail/Inline2.agda b/test/Fail/Inline2.agda index 336f13f0..cb224337 100644 --- a/test/Fail/Inline2.agda +++ b/test/Fail/Inline2.agda @@ -5,3 +5,8 @@ open import Haskell.Prelude tail' : (xs : List a) → @0 {{ NonEmpty xs }} → List a tail' (x ∷ xs) = xs {-# COMPILE AGDA2HS tail' inline #-} + +test : (xs : List a) → @0 {{ NonEmpty xs }} → List a +test = tail' + +{-# COMPILE AGDA2HS test #-} diff --git a/test/golden/Inline.err b/test/golden/Inline.err index 88267052..aaf4ce67 100644 --- a/test/golden/Inline.err +++ b/test/golden/Inline.err @@ -1,3 +1,4 @@ -test/Fail/Inline.agda:5.1-6: error: [CustomBackendError] +test/Fail/Inline.agda:10.1-5: error: [CustomBackendError] agda2hs: - Cannot make function tail' inlinable. An inline function must have exactly one clause. + cannot compile inlined function: + tail' diff --git a/test/golden/Inline2.err b/test/golden/Inline2.err index 69a73d65..047e1a12 100644 --- a/test/golden/Inline2.err +++ b/test/golden/Inline2.err @@ -1,3 +1,4 @@ -test/Fail/Inline2.agda:5.1-6: error: [CustomBackendError] +test/Fail/Inline2.agda:9.1-5: error: [CustomBackendError] agda2hs: - Cannot make function tail' inlinable. Inline functions can only use variable patterns or transparent record constructor patterns. + cannot compile inlined function: + tail' From 086bb62f28479971c23cf7839d226c48bbbf9230 Mon Sep 17 00:00:00 2001 From: Jesper Cockx Date: Mon, 6 Oct 2025 17:57:49 +0200 Subject: [PATCH 2/4] Add test RuntimeCast.agda relying on advanced inlining --- test/AllTests.agda | 2 + test/RuntimeCast.agda | 85 ++++++++++++++++++++++++++++++++++++++ test/golden/AllTests.hs | 1 + test/golden/RuntimeCast.hs | 15 +++++++ 4 files changed, 103 insertions(+) create mode 100644 test/RuntimeCast.agda create mode 100644 test/golden/RuntimeCast.hs diff --git a/test/AllTests.agda b/test/AllTests.agda index 6bbc1ae8..753b4067 100644 --- a/test/AllTests.agda +++ b/test/AllTests.agda @@ -101,6 +101,7 @@ import Issue409 import Issue346 import Issue408 import CompileTo +import RuntimeCast {-# FOREIGN AGDA2HS import Issue14 @@ -199,4 +200,5 @@ import Issue409 import Issue346 import Issue408 import CompileTo +import RuntimeCast #-} diff --git a/test/RuntimeCast.agda b/test/RuntimeCast.agda new file mode 100644 index 00000000..2cac8300 --- /dev/null +++ b/test/RuntimeCast.agda @@ -0,0 +1,85 @@ +{-# OPTIONS --erasure #-} + +open import Haskell.Prelude +open import Haskell.Control.Exception +open import Haskell.Extra.Dec +open import Haskell.Extra.Refinement +open import Haskell.Law.Ord + +variable + A A' B B' C C' : Set + P P' Q Q' : A → Set + +it : {{A}} → A +it {{x}} = x + +data _~_ : (A : Set) (B : Set) → Set₁ +cast : A ~ B → A → B +cast' : A ~ B → B → A + +data _~_ where + refl : A ~ A + + assert-pre-left : ∀ {A : Set} {B : @0 A → Set} + → {{Dec A}} + → ({{@0 x : A}} → B x ~ B') + → ({{@0 x : A}} → B x) ~ B' + + assert-pre-right : ∀ {A : Set} {B' : @0 A → Set} + → {{Dec A}} + → ({{@0 x : A}} → B ~ B' x) + → B ~ ({{@0 x : A}} → B' x) + + assert-post-left : ∀ {A : Set} {@0 B : A → Set} + → {{∀ {x} → Dec (B x)}} + → A ~ A' + → ∃ A B ~ A' + + assert-post-right : ∀ {A : Set} {@0 B' : A' → Set} + → {{∀ {x'} → Dec (B' x')}} + → A ~ A' + → A ~ ∃ A' B' + + cong-pi : {B : @0 A → Set} {B' : @0 A' → Set} + → (eA : A ~ A') → (eB : (x : A) (x' : A') → B x ~ B' x') + → ((x : A) → B x) ~ ((x : A') → B' x) + +cast refl x = x +cast (assert-pre-left {A = A} eB) x = assert A (cast eB x) +cast (assert-pre-right eB) x = cast eB x +cast (assert-post-left eA) (x ⟨ _ ⟩) = cast eA x +cast (assert-post-right {B' = B'} eA) x = assert (B' x') (x' ⟨⟩) + where x' = cast eA x +cast (cong-pi {A = A} eA eB) f x' = cast (eB x x') (f x) + where x = cast' eA x' + +cast' refl x' = x' +cast' (assert-pre-left eB) x' = cast' eB x' +cast' (assert-pre-right {A = A} eB) x' = assert A (cast' eB x') +cast' (assert-post-left {B = B} eA) x' = assert (B x) (x ⟨⟩) + where x = cast' eA x' +cast' (assert-post-right eA) (x' ⟨ _ ⟩) = cast' eA x' +cast' (cong-pi eA eB) f x = cast' (eB x x') (f x') + where x' = cast eA x + +{-# COMPILE AGDA2HS cast inline #-} +{-# COMPILE AGDA2HS cast' inline #-} + +module Sqrt where + + postulate + mySqrt : (x : Double) → @0 {{IsTrue (x >= 0)}} → Double + + {-# COMPILE AGDA2HS mySqrt #-} + + eqr : ((x : Double) → @0 {{IsTrue (x >= 0)}} → Double) ~ + ((x : Double) → ∃ Double (λ v → IsTrue (v >= 0) × IsTrue ((abs (x - v * v) <= 0.01)))) + eqr = cong-pi refl (λ x x' → assert-pre-left (assert-post-right refl)) + + {-# COMPILE AGDA2HS eqr inline #-} + + checkedSqrt : (x : Double) → ∃ Double (λ y → IsTrue (y >= 0) × IsTrue (abs (x - y * y) <= 0.01)) + checkedSqrt = cast eqr mySqrt + + {-# COMPILE AGDA2HS checkedSqrt #-} + diff --git a/test/golden/AllTests.hs b/test/golden/AllTests.hs index 8968a53a..e4eb808b 100644 --- a/test/golden/AllTests.hs +++ b/test/golden/AllTests.hs @@ -96,4 +96,5 @@ import Issue409 import Issue346 import Issue408 import CompileTo +import RuntimeCast diff --git a/test/golden/RuntimeCast.hs b/test/golden/RuntimeCast.hs new file mode 100644 index 00000000..9cfb0203 --- /dev/null +++ b/test/golden/RuntimeCast.hs @@ -0,0 +1,15 @@ +module RuntimeCast where + +import Control.Exception (assert) + +mySqrt :: Double -> Double +mySqrt = error "postulate: Double -> Double" + +checkedSqrt :: Double -> Double +checkedSqrt + = \ x' -> + assert (x' >= 0) + (assert + (mySqrt x' >= 0 && abs (x' - mySqrt x' * mySqrt x') <= 1.0e-2) + (mySqrt x')) + From 7b20b3c8ce38aac6e848720dada37ffe0164a221 Mon Sep 17 00:00:00 2001 From: Jesper Cockx Date: Fri, 10 Oct 2025 15:46:41 +0200 Subject: [PATCH 3/4] [ #431 ] Eta-expand definition of rTail in EraseType.agda test case --- test/EraseType.agda | 2 +- test/golden/EraseType.hs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/EraseType.agda b/test/EraseType.agda index a49b98c2..ffaf732e 100644 --- a/test/EraseType.agda +++ b/test/EraseType.agda @@ -29,6 +29,6 @@ testCong = singCong (1 +_) testSingleton {-# COMPILE AGDA2HS testCong #-} rTail : ∀ {@0 x xs} → Singleton {a = List Int} (x ∷ xs) → Singleton xs -rTail = singTail +rTail ys = singTail ys {-# COMPILE AGDA2HS rTail #-} diff --git a/test/golden/EraseType.hs b/test/golden/EraseType.hs index fb05e37f..85e52ed7 100644 --- a/test/golden/EraseType.hs +++ b/test/golden/EraseType.hs @@ -16,5 +16,5 @@ testCong :: Int testCong = 1 + testSingleton rTail :: [Int] -> [Int] -rTail = \ ys -> tail ys +rTail ys = tail ys From b1ce1ef42cb6344bd7eef404ad3e437a5db0919f Mon Sep 17 00:00:00 2001 From: Jesper Cockx Date: Fri, 10 Oct 2025 16:24:17 +0200 Subject: [PATCH 4/4] [ #431 ] Comment out partial application examples in Inlining.agda test case --- test/Inlining.agda | 14 +++++++------- test/golden/Inlining.hs | 6 ------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/test/Inlining.agda b/test/Inlining.agda index 602343e9..9181a09c 100644 --- a/test/Inlining.agda +++ b/test/Inlining.agda @@ -34,10 +34,10 @@ test2 x y = mapWrap2 _+_ x y {-# COMPILE AGDA2HS test2 #-} -- partial application of inline function -test3 : Wrap Int → Wrap Int → Wrap Int -test3 x = mapWrap2 _+_ x -{-# COMPILE AGDA2HS test3 #-} - -test4 : Wrap Int → Wrap Int → Wrap Int -test4 = mapWrap2 _+_ -{-# COMPILE AGDA2HS test4 #-} +-- test3 : Wrap Int → Wrap Int → Wrap Int +-- test3 x = mapWrap2 _+_ x +-- {-# COMPILE AGDA2HS test3 #-} +-- +-- test4 : Wrap Int → Wrap Int → Wrap Int +-- test4 = mapWrap2 _+_ +-- {-# COMPILE AGDA2HS test4 #-} diff --git a/test/golden/Inlining.hs b/test/golden/Inlining.hs index 6df0af7f..40ef1c08 100644 --- a/test/golden/Inlining.hs +++ b/test/golden/Inlining.hs @@ -9,9 +9,3 @@ test1 x = 1 + x test2 :: Int -> Int -> Int test2 x y = x + y -test3 :: Int -> Int -> Int -test3 x = \ y -> x + y - -test4 :: Int -> Int -> Int -test4 = \ x y -> x + y -