diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs index 2a2300329f..99bba447a9 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs @@ -4,9 +4,18 @@ module U.Codebase.Sqlite.Branch.Format HashBranchFormat, BranchLocalIds, BranchLocalIds' (..), + branchLocalIdsText_, + branchLocalIdsDefn_, + branchLocalIdsPatch_, + branchLocalIdsChildren_, HashBranchLocalIds, SyncBranchFormat, SyncBranchFormat' (..), + syncBranchFormatTexts_, + syncBranchFormatDefns_, + syncBranchFormatPatches_, + syncBranchFormatChildren_, + syncBranchFormatParents_, LocalBranchBytes (..), localToDbBranch, localToDbDiff, @@ -16,6 +25,7 @@ module U.Codebase.Sqlite.Branch.Format ) where +import Control.Lens import Data.Vector (Vector) import Data.Vector qualified as Vector import U.Codebase.HashTags @@ -103,6 +113,18 @@ data BranchLocalIds' t d p c = LocalIds } deriving (Show, Eq) +branchLocalIdsText_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t' d p c) t t' +branchLocalIdsText_ f (LocalIds t d p c) = LocalIds <$> traverse f t <*> pure d <*> pure p <*> pure c + +branchLocalIdsDefn_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t d' p c) d d' +branchLocalIdsDefn_ f (LocalIds t d p c) = LocalIds <$> pure t <*> traverse f d <*> pure p <*> pure c + +branchLocalIdsPatch_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t d p' c) p p' +branchLocalIdsPatch_ f (LocalIds t d p c) = LocalIds <$> pure t <*> pure d <*> traverse f p <*> pure c + +branchLocalIdsChildren_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t d p c') c c' +branchLocalIdsChildren_ f (LocalIds t d p c) = LocalIds <$> pure t <*> pure d <*> pure p <*> traverse f c + -- | Bytes encoding a LocalBranch newtype LocalBranchBytes = LocalBranchBytes ByteString deriving (Show, Eq, Ord) @@ -112,6 +134,31 @@ data SyncBranchFormat' parent text defn patch child | SyncDiff parent (BranchLocalIds' text defn patch child) LocalBranchBytes deriving (Eq, Show) +syncBranchFormatTexts_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text' defn patch child) text text' +syncBranchFormatTexts_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsText_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & branchLocalIdsText_ %%~ f) <*> pure bytes + +syncBranchFormatDefns_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text defn' patch child) defn defn' +syncBranchFormatDefns_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsDefn_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & branchLocalIdsDefn_ %%~ f) <*> pure bytes + +syncBranchFormatPatches_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text defn patch' child) patch patch' +syncBranchFormatPatches_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsPatch_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & branchLocalIdsPatch_ %%~ f) <*> pure bytes + +syncBranchFormatChildren_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text defn patch child') child child' +syncBranchFormatChildren_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsChildren_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & 
branchLocalIdsChildren_ %%~ f) <*> pure bytes + +syncBranchFormatParents_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent' text defn patch child) parent parent' +syncBranchFormatParents_ f = \case + SyncFull li bytes -> pure $ SyncFull li bytes + SyncDiff parent li bytes -> SyncDiff <$> f parent <*> pure li <*> pure bytes + type SyncBranchFormat = SyncBranchFormat' BranchObjectId TextId ObjectId PatchObjectId (BranchObjectId, CausalHashId) localToBranch :: (Ord t, Ord d) => BranchLocalIds' t d p c -> LocalBranch -> (Branch.Full.Branch' t d p c) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs index 87f532bf25..b8b35f1555 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs @@ -3,9 +3,12 @@ module U.Codebase.Sqlite.Causal GDbCausal (..), SyncCausalFormat, SyncCausalFormat' (..), + syncCausalFormatCausalHash_, + syncCausalFormatValueHash_, ) where +import Control.Lens import Data.Vector (Vector) import U.Codebase.Sqlite.DbId (BranchHashId, CausalHashId) import Unison.Prelude @@ -24,4 +27,10 @@ data SyncCausalFormat' causalHash valueHash = SyncCausalFormat } deriving stock (Eq, Show) +syncCausalFormatCausalHash_ :: Traversal (SyncCausalFormat' causalHash valueHash) (SyncCausalFormat' causalHash' valueHash) causalHash causalHash' +syncCausalFormatCausalHash_ f (SyncCausalFormat v p) = SyncCausalFormat v <$> traverse f p + +syncCausalFormatValueHash_ :: Lens (SyncCausalFormat' causalHash valueHash) (SyncCausalFormat' causalHash valueHash') valueHash valueHash' +syncCausalFormatValueHash_ f (SyncCausalFormat v p) = (\v' -> SyncCausalFormat v' p) <$> f v + type SyncCausalFormat = SyncCausalFormat' CausalHashId BranchHashId diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs index 5752d2dd87..6e492f5183 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs @@ -2,11 +2,13 @@ module U.Codebase.Sqlite.Decl.Format where +import Control.Lens import Data.Vector (Vector) import U.Codebase.Decl (DeclR) import U.Codebase.Reference (Reference') import U.Codebase.Sqlite.DbId (ObjectId, TextId) import U.Codebase.Sqlite.LocalIds (LocalDefnId, LocalIds', LocalTextId) +import U.Codebase.Sqlite.LocalIds qualified as LocalIds import U.Codebase.Sqlite.Symbol (Symbol) import U.Codebase.Type qualified as Type import U.Core.ABT qualified as ABT @@ -38,10 +40,24 @@ data SyncDeclFormat' t d = SyncDecl (SyncLocallyIndexedComponent' t d) deriving stock (Eq, Show) +syncDeclFormatTexts_ :: Traversal (SyncDeclFormat' t d) (SyncDeclFormat' t' d) t t' +syncDeclFormatTexts_ f (SyncDecl c) = SyncDecl <$> syncLocallyIndexedComponentTexts_ f c + +syncDeclFormatDefns_ :: Traversal (SyncDeclFormat' t d) (SyncDeclFormat' t d') d d' +syncDeclFormatDefns_ f (SyncDecl c) = SyncDecl <$> syncLocallyIndexedComponentDefns_ f c + newtype SyncLocallyIndexedComponent' t d = SyncLocallyIndexedComponent (Vector (LocalIds' t d, ByteString)) deriving stock (Eq, Show) +syncLocallyIndexedComponentTexts_ :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t' d) t t' +syncLocallyIndexedComponentTexts_ f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . 
LocalIds.t_ %%~ f) + +syncLocallyIndexedComponentDefns_ :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t d') d d' +syncLocallyIndexedComponentDefns_ f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . LocalIds.h_ %%~ f) + -- [OldDecl] ==map==> [NewDecl] ==number==> [(NewDecl, Int)] ==sort==> [(NewDecl, Int)] ==> permutation is map snd of that -- type List a = Nil | Cons (List a) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs index 92cbb58828..1d034627bc 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs @@ -1,5 +1,6 @@ module U.Codebase.Sqlite.Entity where +import Control.Lens import U.Codebase.Sqlite.Branch.Format qualified as Namespace import U.Codebase.Sqlite.Causal qualified as Causal import U.Codebase.Sqlite.DbId (BranchHashId, BranchObjectId, CausalHashId, HashId, ObjectId, PatchObjectId, TextId) @@ -33,3 +34,77 @@ entityType = \case N _ -> NamespaceType P _ -> PatchType C _ -> CausalType + +texts_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text' hash defn patch branchh branch causal) text text' +texts_ f = \case + TC tcf -> TC <$> Term.syncTermFormatTexts_ f tcf + DC dcf -> DC <$> Decl.syncDeclFormatTexts_ f dcf + N ncf -> N <$> Namespace.syncBranchFormatTexts_ f ncf + P pcf -> P <$> Patch.syncPatchFormatTexts_ f pcf + C ccf -> pure (C ccf) + +hashes_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash' defn patch branchh branch causal) hash hash' +hashes_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> pure (N ncf) + P pcf -> P <$> Patch.syncPatchFormatHashes_ f pcf + C ccf -> pure (C ccf) + +defns_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn' patch branchh branch causal) defn defn' +defns_ f = \case + TC tcf -> TC <$> Term.syncTermFormatDefns_ f tcf + DC dcf -> DC <$> Decl.syncDeclFormatDefns_ f dcf + N ncf -> N <$> Namespace.syncBranchFormatDefns_ f ncf + P pcf -> P <$> Patch.syncPatchFormatDefns_ f pcf + C ccf -> pure (C ccf) + +patches_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch' branchh branch causal) patch patch' +patches_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> N <$> Namespace.syncBranchFormatPatches_ f ncf + P pcf -> P <$> Patch.syncPatchFormatParents_ f pcf + C ccf -> pure (C ccf) + +branchHashes_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch branchh' branch causal) branchh branchh' +branchHashes_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> pure (N ncf) + P pcf -> pure (P pcf) + C ccf -> C <$> Causal.syncCausalFormatValueHash_ f ccf + +branches_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch branchh branch' causal) branch branch' +branches_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> + ( case ncf of + Namespace.SyncFull li bytes -> Namespace.SyncFull <$> (li & Namespace.branchLocalIdsChildren_ . _1 %%~ f) <*> pure bytes + Namespace.SyncDiff parent li bytes -> + Namespace.SyncDiff + <$> (f parent) + <*> (li & Namespace.branchLocalIdsChildren_ . 
_1 %%~ f) + <*> pure bytes + ) + <&> N + P pcf -> pure (P pcf) + C ccf -> pure (C ccf) + +causalHashes_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch branchh branch causal') causal causal' +causalHashes_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> + ( case ncf of + Namespace.SyncFull li bytes -> Namespace.SyncFull <$> (li & Namespace.branchLocalIdsChildren_ . _2 %%~ f) <*> pure bytes + Namespace.SyncDiff parent li bytes -> + Namespace.SyncDiff + <$> (pure parent) + <*> (li & Namespace.branchLocalIdsChildren_ . _2 %%~ f) + <*> pure bytes + ) + <&> N + P pcf -> pure (P pcf) + C ccf -> C <$> Causal.syncCausalFormatCausalHash_ f ccf diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs index 452df27904..34ad6d855f 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs @@ -2,9 +2,16 @@ module U.Codebase.Sqlite.Patch.Format ( PatchFormat (..), PatchLocalIds, PatchLocalIds' (..), + patchLocalIdsTexts_, + patchLocalIdsHashes_, + patchLocalIdsDefns_, HashPatchLocalIds, SyncPatchFormat, SyncPatchFormat' (..), + syncPatchFormatParents_, + syncPatchFormatTexts_, + syncPatchFormatHashes_, + syncPatchFormatDefns_, applyPatchDiffs, localPatchToPatch, localPatchToPatch', @@ -13,6 +20,7 @@ module U.Codebase.Sqlite.Patch.Format ) where +import Control.Lens import Data.Map.Strict qualified as Map import Data.Set qualified as Set import Data.Vector (Vector) @@ -42,6 +50,15 @@ data PatchLocalIds' t h d = LocalIds } deriving stock (Eq, Show) +patchLocalIdsTexts_ :: Traversal (PatchLocalIds' t h d) (PatchLocalIds' t' h d) t t' +patchLocalIdsTexts_ f (LocalIds t h d) = LocalIds <$> traverse f t <*> pure h <*> pure d + +patchLocalIdsHashes_ :: Traversal (PatchLocalIds' t h d) (PatchLocalIds' t h' d) h h' +patchLocalIdsHashes_ f (LocalIds t h d) = LocalIds <$> pure t <*> traverse f h <*> pure d + +patchLocalIdsDefns_ :: Traversal (PatchLocalIds' t h d) (PatchLocalIds' t h d') d d' +patchLocalIdsDefns_ f (LocalIds t h d) = LocalIds <$> pure t <*> pure h <*> traverse f d + type SyncPatchFormat = SyncPatchFormat' PatchObjectId TextId HashId ObjectId data SyncPatchFormat' parent text hash defn @@ -50,6 +67,26 @@ data SyncPatchFormat' parent text hash defn SyncDiff parent (PatchLocalIds' text hash defn) ByteString deriving stock (Eq, Show) +syncPatchFormatParents_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p' text hash defn) p p' +syncPatchFormatParents_ f = \case + (SyncDiff p li b) -> SyncDiff <$> f p <*> pure li <*> pure b + (SyncFull li b) -> SyncFull <$> pure li <*> pure b + +syncPatchFormatTexts_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p text' hash defn) text text' +syncPatchFormatTexts_ f = \case + (SyncDiff p li b) -> SyncDiff p <$> (li & patchLocalIdsTexts_ %%~ f) <*> pure b + (SyncFull li b) -> SyncFull <$> (li & patchLocalIdsTexts_ %%~ f) <*> pure b + +syncPatchFormatHashes_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p text hash' defn) hash hash' +syncPatchFormatHashes_ f = \case + (SyncDiff p li b) -> SyncDiff p <$> (li & patchLocalIdsHashes_ %%~ f) <*> pure b + (SyncFull li b) -> SyncFull <$> (li & patchLocalIdsHashes_ %%~ f) <*> pure b + +syncPatchFormatDefns_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p text hash defn') defn defn' +syncPatchFormatDefns_ f = 
\case + (SyncDiff p li b) -> SyncDiff p <$> (li & patchLocalIdsDefns_ %%~ f) <*> pure b + (SyncFull li b) -> SyncFull <$> (li & patchLocalIdsDefns_ %%~ f) <*> pure b + -- | Apply a list of patch diffs to a patch, left to right. applyPatchDiffs :: Patch -> [PatchDiff] -> Patch applyPatchDiffs = diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs index 30e56300fd..6b96067540 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs @@ -218,13 +218,16 @@ module U.Codebase.Sqlite.Queries EntityLocation (..), entityExists, entityLocation, + entityLocationSyncV3, expectEntity, syncToTempEntity, insertTempEntity, + insertTempEntitySyncV3, saveTempEntityInMain, expectTempEntity, deleteTempEntity, clearTempEntityTables, + streamTempEntitiesSyncV3, -- * elaborate hashes elaborateHashes, @@ -254,6 +257,7 @@ module U.Codebase.Sqlite.Queries addUpdateBranchTable, addDerivedDependentsByDependencyIndex, addUpgradeBranchTable, + addSyncV3TempTables, -- ** schema version currentSchemaVersion, @@ -300,6 +304,7 @@ import Data.Aeson qualified as Aeson import Data.Aeson.Text qualified as Aeson import Data.Bitraversable (bitraverse) import Data.ByteString.Lazy (LazyByteString) +import Data.ByteString.Lazy.Char8 qualified as BL import Data.Bytes.Put (runPutS) import Data.Foldable qualified as Foldable import Data.List qualified as List @@ -413,7 +418,7 @@ type TextPathSegments = [Text] -- * main squeeze currentSchemaVersion :: SchemaVersion -currentSchemaVersion = 22 +currentSchemaVersion = 23 runCreateSql :: Transaction () runCreateSql = @@ -499,6 +504,10 @@ addUpgradeBranchTable :: Transaction () addUpgradeBranchTable = executeStatements $(embedProjectStringFile "sql/019-add-upgrade-branch-table.sql") +addSyncV3TempTables :: Transaction () +addSyncV3TempTables = + executeStatements $(embedProjectStringFile "sql/020-add-sync-v3-temp-tables.sql") + schemaVersion :: Transaction SchemaVersion schemaVersion = queryOneCol @@ -2232,6 +2241,16 @@ entityLocation hash = True -> Just EntityInTempStorage False -> Nothing +entityLocationSyncV3 :: Hash32 -> Transaction (Maybe EntityLocation) +entityLocationSyncV3 hash = + entityExists hash >>= \case + True -> pure (Just EntityInMainStorage) + False -> do + let theSql = [sql| SELECT EXISTS (SELECT 1 FROM syncv3_temp_entity WHERE entity_hash = :hash) |] + queryOneCol theSql <&> \case + True -> Just EntityInTempStorage + False -> Nothing + -- | Does this entity already exist in the database, i.e. in the `object` or `causal` table? entityExists :: Hash32 -> Transaction Bool entityExists hash = do @@ -2285,6 +2304,15 @@ insertTempEntity entityHash entity missingDependencies = do entityType = Entity.entityType entity +insertTempEntitySyncV3 :: Hash32 -> Text -> Hash32 -> Int64 -> BL.ByteString -> Transaction () +insertTempEntitySyncV3 rootCausal entityKind entityHash entityDepth entityBlob = do + execute + [sql| + INSERT INTO syncv3_temp_entity (root_causal, entity_hash, entity_kind, entity_data, entity_depth) + VALUES (:rootCausal, :entityHash, :entityKind, :entityBlob, :entityDepth) + ON CONFLICT DO NOTHING + |] + -- | Delete a row from the `temp_entity` table, if it exists. 
deleteTempEntity :: Hash32 -> Transaction () deleteTempEntity hash = @@ -4005,3 +4033,14 @@ saveSquashResult bhId chId = ) ON CONFLICT DO NOTHING |] + +streamTempEntitiesSyncV3 :: Hash32 -> (Transaction (Maybe (Hash32, BL.ByteString)) -> Transaction a) -> Transaction a +streamTempEntitiesSyncV3 rootCausalHash action = do + Sqlite.queryStreamRow @(Hash32, BL.ByteString) + [sql| + SELECT entity_hash, entity_data + FROM syncv3_temp_entity + WHERE root_causal = :rootCausalHash + ORDER BY entity_depth ASC + |] + action diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs index f06fc70ec3..8e5a722dcc 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs @@ -2,6 +2,7 @@ module U.Codebase.Sqlite.Term.Format where +import Control.Lens import Data.ByteString (ByteString) import Data.Text (Text) import Data.Vector (Vector) @@ -9,6 +10,7 @@ import U.Codebase.Reference (Reference') import U.Codebase.Referent (Referent') import U.Codebase.Sqlite.DbId (ObjectId, TextId) import U.Codebase.Sqlite.LocalIds (LocalDefnId, LocalIds', LocalTextId, WatchLocalIds) +import U.Codebase.Sqlite.LocalIds qualified as LocalIds import U.Codebase.Sqlite.Reference qualified as Sqlite import U.Codebase.Sqlite.Symbol (Symbol) import U.Codebase.Term qualified as Term @@ -51,6 +53,14 @@ newtype SyncLocallyIndexedComponent' t d = SyncLocallyIndexedComponent (Vector (LocalIds' t d, ByteString)) deriving stock (Eq, Show) +syncLocallyIndexedComponentTexts_ :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t' d) t t' +syncLocallyIndexedComponentTexts_ f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . LocalIds.t_ %%~ f) + +syncLocallyIndexedComponentDefns :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t d') d d' +syncLocallyIndexedComponentDefns f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . LocalIds.h_ %%~ f) + {- message = "hello, world" -> ABT { ... { Term.F.Text "hello, world" } } -> hashes to (#abc, 0) program = printLine message -> ABT { ... 
{ Term.F.App (ReferenceBuiltin ##io.PrintLine) (Reference #abc 0) } } -> hashes to (#def, 0) @@ -130,6 +140,14 @@ type SyncTermFormat = SyncTermFormat' TextId ObjectId data SyncTermFormat' t d = SyncTerm (SyncLocallyIndexedComponent' t d) deriving stock (Eq, Show) +syncTermFormatTexts_ :: Traversal (SyncTermFormat' t d) (SyncTermFormat' t' d) t t' +syncTermFormatTexts_ f (SyncTerm slic) = + SyncTerm <$> (slic & syncLocallyIndexedComponentTexts_ %%~ f) + +syncTermFormatDefns_ :: Traversal (SyncTermFormat' t d) (SyncTermFormat' t d') d d' +syncTermFormatDefns_ f (SyncTerm slic) = + SyncTerm <$> (slic & syncLocallyIndexedComponentDefns %%~ f) + data WatchResultFormat = WatchResult WatchLocalIds Term diff --git a/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql b/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql new file mode 100644 index 0000000000..5bd3eba986 --- /dev/null +++ b/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql @@ -0,0 +1,14 @@ +-- Add a new table for storing entities which are currently being synced + +CREATE TABLE syncv3_temp_entity ( + root_causal INTEGER NOT NULL, + entity_hash TEXT NOT NULL, + entity_kind TEXT NOT NULL, + entity_data BLOB NOT NULL, + entity_depth INTEGER NOT NULL, + PRIMARY KEY (root_causal, entity_hash) +) WITHOUT ROWID; + +-- We _could_ add an index on (root_causal, entity_depth), since that's how we'll +-- be querying this table, but we only run the query exactly once per sync, so it's +-- probably faster to sort on query rather than maintaining the index. diff --git a/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal b/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal index 0e5780e7f9..923a04bfb1 100644 --- a/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal +++ b/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.37.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. 
-- -- see: https://github.com/sol/hpack @@ -29,6 +29,7 @@ extra-source-files: sql/017-add-update-branch-table.sql sql/018-add-derived-dependents-by-dependency-index.sql sql/019-add-upgrade-branch-table.sql + sql/020-add-sync-v3-temp-tables.sql sql/create.sql source-repository head diff --git a/lib/unison-hash/package.yaml b/lib/unison-hash/package.yaml index 523de61905..92be9dba76 100644 --- a/lib/unison-hash/package.yaml +++ b/lib/unison-hash/package.yaml @@ -10,6 +10,7 @@ dependencies: - deepseq - unison-prelude - unison-util-base32hex + - hashable library: source-dirs: src diff --git a/lib/unison-hash/src/Unison/Hash32.hs b/lib/unison-hash/src/Unison/Hash32.hs index 97e7c201ed..4b74d89631 100644 --- a/lib/unison-hash/src/Unison/Hash32.hs +++ b/lib/unison-hash/src/Unison/Hash32.hs @@ -18,6 +18,7 @@ module Unison.Hash32 ) where +import Data.Hashable (Hashable) import U.Util.Base32Hex (Base32Hex (..)) import Unison.Hash (Hash) import Unison.Hash qualified as Hash @@ -30,7 +31,7 @@ import Unison.Prelude -- * @unison-util-base32hex-orphans-aeson@ -- * @unison-util-base32hex-orphans-sqlite@ newtype Hash32 = UnsafeFromBase32Hex Base32Hex - deriving (Eq, Ord, Show) via (Text) + deriving (Eq, Ord, Show, Hashable) via (Text) instance From Hash32 Text where from = toText diff --git a/lib/unison-hash/unison-hash.cabal b/lib/unison-hash/unison-hash.cabal index a1a32f0b1b..d68ba201ba 100644 --- a/lib/unison-hash/unison-hash.cabal +++ b/lib/unison-hash/unison-hash.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.36.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -54,6 +54,7 @@ library base , bytestring , deepseq + , hashable , unison-prelude , unison-util-base32hex default-language: Haskell2010 diff --git a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs index 5b13751a16..63b528166e 100644 --- a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs +++ b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs @@ -89,7 +89,8 @@ migrations regionVar getDeclType termBuffer declBuffer rootCodebasePath = sqlMigration 19 Q.addMergeBranchTables, sqlMigration 20 Q.addUpdateBranchTable, sqlMigration 21 Q.addDerivedDependentsByDependencyIndex, - sqlMigration 22 Q.addUpgradeBranchTable + sqlMigration 22 Q.addUpgradeBranchTable, + sqlMigration 23 Q.addSyncV3TempTables ] where runT :: Sqlite.Transaction () -> Sqlite.Connection -> IO () diff --git a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs index e95091a164..b31ffaa2c2 100644 --- a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs +++ b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs @@ -85,6 +85,7 @@ createSchema = do Q.addUpdateBranchTable Q.addDerivedDependentsByDependencyIndex Q.addUpgradeBranchTable + Q.addSyncV3TempTables (_, emptyCausalHashId) <- emptyCausalHash (_, ProjectBranchRow {projectId, branchId}) <- insertProjectAndBranch scratchProjectName scratchBranchName emptyCausalHashId diff --git a/unison-cli/package.yaml b/unison-cli/package.yaml index 4418105d7b..844f92d83d 100644 --- a/unison-cli/package.yaml +++ b/unison-cli/package.yaml @@ -53,6 +53,7 @@ library: - megaparsec - memory - mtl + - network - network-simple - network-uri - nonempty-containers @@ -104,6 +105,8 @@ 
library: - vector - wai - warp + - websockets + - wuss - witch - witherable diff --git a/unison-cli/src/Unison/Cli/DownloadUtils.hs b/unison-cli/src/Unison/Cli/DownloadUtils.hs index 936d5f00fb..6861803dc1 100644 --- a/unison-cli/src/Unison/Cli/DownloadUtils.hs +++ b/unison-cli/src/Unison/Cli/DownloadUtils.hs @@ -27,20 +27,24 @@ import Unison.Codebase.Editor.RemoteRepo qualified as RemoteRepo import Unison.Codebase.Path qualified as Path import Unison.Codebase.ProjectPath (ProjectBranch (..)) import Unison.Core.Project (ProjectAndBranch (..)) +import Unison.Debug qualified as Debug import Unison.NameSegment.Internal qualified as NameSegment import Unison.Prelude import Unison.Share.API.Hash qualified as Share import Unison.Share.Codeserver qualified as Codeserver +import Unison.Share.Codeserver qualified as Share import Unison.Share.Sync qualified as Share import Unison.Share.Sync.Types qualified as Share import Unison.Share.SyncV2 qualified as SyncV2 +import Unison.Share.SyncV3 qualified as SyncV3 import Unison.Share.Types (codeserverBaseURL) import Unison.Sync.Common qualified as Sync.Common import Unison.Sync.Types qualified as Share import Unison.SyncV2.Types qualified as SyncV2 +import Unison.SyncV3.Types qualified as SyncV3 import UnliftIO.Environment qualified as UnliftIO -data SyncVersion = SyncV1 | SyncV2 +data SyncVersion = SyncV1 | SyncV2 | SyncV3 deriving (Eq, Show) -- | The version of the sync protocol to use. @@ -49,7 +53,8 @@ syncVersion = unsafePerformIO do UnliftIO.lookupEnv "UNISON_SYNC_VERSION" <&> \case Just "1" -> SyncV1 - _ -> SyncV2 + Just "2" -> SyncV2 + _ -> SyncV3 -- | Download a project/branch from Share. downloadProjectBranchFromShare :: @@ -70,6 +75,7 @@ downloadProjectBranchFromShare useSquashed branch isPull = (Share.NoSquashedHead, _) -> pure branch.branchHead let causalHash32 = Share.hashJWTHash causalHashJwt exists <- Cli.runTransaction (Queries.causalExistsByHash32 causalHash32) + Debug.debugM Debug.Temp "Downloading using Sync " syncVersion when (not exists) do case syncVersion of SyncV1 -> do @@ -95,6 +101,20 @@ downloadProjectBranchFromShare useSquashed branch isPull = Share.SyncError pullErr -> Output.ShareErrorPullV2 pullErr Share.TransportError err -> Output.ShareErrorTransport err + SyncV3 -> do + Debug.debugLogM Debug.Temp "Using SyncV3 protocol" + let branchRef = SyncV3.BranchRef (into @Text (ProjectAndBranch branch.projectName remoteProjectBranchName)) + let shouldValidate = Codeserver.isCustomCodeserver Codeserver.defaultCodeserver + when isPull $ do + pb <- Cli.getCurrentProjectBranch + currentCausalHash <- Cli.runTransaction $ Ops.expectProjectBranchHead pb.projectId pb.branchId + Cli.respond $ Output.SyncingFromTo currentCausalHash (Sync.Common.hash32ToCausalHash causalHash32) + result <- SyncV3.syncFromCodeserver shouldValidate Share.defaultCodeserver branchRef causalHashJwt + void result & onLeft \err0 -> do + done case err0 of + Share.SyncError _pullErr -> + error "TODO: define SyncV3 pull error and handle it here" + Share.TransportError err -> Output.ShareErrorTransport err pure (Sync.Common.hash32ToCausalHash (Share.hashJWTHash causalHashJwt)) -- | Download loose code from Share. 
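
(Commentary, not part of the diff.) With SyncV3 selected, the pull above buffers each entity it receives into the new syncv3_temp_entity table and only flushes into main storage once the stream completes, reading the rows back in ascending entity_depth order via the callback-style streamTempEntitiesSyncV3 added to U.Codebase.Sqlite.Queries. A minimal sketch of driving that callback follows; the module name and countBufferedEntities are invented for illustration, the Q.streamTempEntitiesSyncV3 call and its signature are the ones introduced in this diff, and Transaction and Hash32 are assumed to be the types exported by Unison.Sqlite and Unison.Hash32.

  {-# LANGUAGE ImportQualifiedPost #-}
  {-# LANGUAGE LambdaCase #-}

  -- Hypothetical module, for illustration only.
  module SyncV3Sketch (countBufferedEntities) where

  import U.Codebase.Sqlite.Queries qualified as Q
  import Unison.Hash32 (Hash32)
  import Unison.Sqlite (Transaction)

  -- Count how many entities are currently buffered for one root causal by
  -- repeatedly pulling the "next row" action until it yields Nothing.
  countBufferedEntities :: Hash32 -> Transaction Int
  countBufferedEntities rootCausal =
    Q.streamTempEntitiesSyncV3 rootCausal $ \next ->
      let loop n =
            next >>= \case
              Nothing -> pure n
              Just (_hash, _bytes) -> loop (n + 1)
       in loop (0 :: Int)
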
diff --git a/unison-cli/src/Unison/Cli/Monad.hs b/unison-cli/src/Unison/Cli/Monad.hs index 60a9382a1c..66621d4d11 100644 --- a/unison-cli/src/Unison/Cli/Monad.hs +++ b/unison-cli/src/Unison/Cli/Monad.hs @@ -71,6 +71,7 @@ import U.Codebase.Sqlite.DbId (ProjectBranchId, ProjectId) import U.Codebase.Sqlite.Queries qualified as Q import Unison.Auth.CredentialManager (CredentialManager) import Unison.Auth.HTTPClient (AuthenticatedHttpClient) +import Unison.Auth.Tokens (TokenProvider) import Unison.Codebase (Codebase) import Unison.Codebase qualified as Codebase import Unison.Codebase.Editor.Input (Input) @@ -158,6 +159,9 @@ type SourceName = Text -- Get the environment with 'ask'. data Env = Env { authHTTPClient :: AuthenticatedHttpClient, + -- | How to get auth tokens for a given codeserver. + -- Using AuthenticatedHttpClient takes care of this, but websocket connection need to provide auth headers manually. + tokenProvider :: TokenProvider, codebase :: Codebase IO Symbol Ann, credentialManager :: CredentialManager, -- | Generate a unique name. diff --git a/unison-cli/src/Unison/Codebase/Transcript/Runner.hs b/unison-cli/src/Unison/Codebase/Transcript/Runner.hs index b71b895d44..3b87f87d0e 100644 --- a/unison-cli/src/Unison/Codebase/Transcript/Runner.hs +++ b/unison-cli/src/Unison/Codebase/Transcript/Runner.hs @@ -96,7 +96,9 @@ withRunner :: m r withRunner isTest verbosity ucmVersion action = do credMan <- AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient credMan + let tokenProvider :: AuthN.TokenProvider + tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion -- If we're in a transcript test, configure the environment to use a non-existent fzf binary -- so that errors are consistent. @@ -130,6 +132,7 @@ withRunner isTest verbosity ucmVersion action = do ucmVersion baseUrlText authenticatedHTTPClient + tokenProvider credMan stanzas where @@ -138,11 +141,6 @@ withRunner isTest verbosity ucmVersion action = do RTI.withRuntime False RTI.Persistent ucmVersion \runtime -> RTI.withRuntime True RTI.Persistent ucmVersion \sbRuntime -> action runtime sbRuntime - initTranscriptAuthenticatedHTTPClient :: AuthN.CredentialManager -> m AuthN.AuthenticatedHttpClient - initTranscriptAuthenticatedHTTPClient credMan = liftIO $ do - let tokenProvider :: AuthN.TokenProvider - tokenProvider = AuthN.newTokenProvider credMan - AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion isGeneratedBlock :: ProcessedBlock -> Bool isGeneratedBlock = generated . 
getCommonInfoTags @@ -157,10 +155,11 @@ run :: UCMVersion -> Text -> AuthN.AuthenticatedHttpClient -> + AuthN.TokenProvider -> AuthN.CredentialManager -> Transcript -> IO (Either Error Transcript) -run isTest verbosity codebase runtime sbRuntime ucmVersion baseURL authenticatedHTTPClient credMan transcript = UnliftIO.try do +run isTest verbosity codebase runtime sbRuntime ucmVersion baseURL authenticatedHTTPClient tokenProvider credMan transcript = UnliftIO.try do let behaviors = extractBehaviors $ settings transcript let stanzas' = stanzas transcript httpManager <- HTTP.newManager HTTP.defaultManagerSettings @@ -518,6 +517,7 @@ run isTest verbosity codebase runtime sbRuntime ucmVersion baseURL authenticated let env = Cli.Env { authHTTPClient = authenticatedHTTPClient, + tokenProvider, codebase, credentialManager = credMan, generateUniqueName = do diff --git a/unison-cli/src/Unison/CommandLine/Main.hs b/unison-cli/src/Unison/CommandLine/Main.hs index 845abeee05..36c2801f98 100644 --- a/unison-cli/src/Unison/CommandLine/Main.hs +++ b/unison-cli/src/Unison/CommandLine/Main.hs @@ -26,6 +26,7 @@ import U.Codebase.Sqlite.Queries qualified as Queries import Unison.Auth.CredentialManager qualified as AuthN import Unison.Auth.HTTPClient (AuthenticatedHttpClient) import Unison.Auth.HTTPClient qualified as AuthN +import Unison.Auth.Tokens (TokenProvider) import Unison.Cli.Monad qualified as Cli import Unison.Cli.Pretty qualified as P import Unison.Cli.ProjectUtils qualified as ProjectUtils @@ -146,11 +147,12 @@ main :: Maybe Server.BaseUrl -> UCMVersion -> AuthN.AuthenticatedHttpClient -> + TokenProvider -> AuthN.CredentialManager -> (PP.ProjectPathIds -> IO ()) -> ShouldWatchFiles -> IO () -main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl ucmVersion authHTTPClient credentialManager lspCheckForChanges shouldWatchFiles = do +main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl ucmVersion authHTTPClient tokenProvider credentialManager lspCheckForChanges shouldWatchFiles = do -- we don't like FSNotify's debouncing (it seems to drop later events) -- so we will be doing our own instead let config = FSNotify.defaultConfig @@ -288,6 +290,7 @@ main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl uc { authHTTPClient, codebase, credentialManager, + tokenProvider, loadSource = loadSourceFile, lspCheckForChanges, writeSource, diff --git a/unison-cli/src/Unison/MCP/Cli.hs b/unison-cli/src/Unison/MCP/Cli.hs index 403aa7ab09..889aea5180 100644 --- a/unison-cli/src/Unison/MCP/Cli.hs +++ b/unison-cli/src/Unison/MCP/Cli.hs @@ -102,6 +102,7 @@ cliToMCP projCtx cli = do let cliEnv = Cli.Env { authHTTPClient = authenticatedHTTPClient, + tokenProvider, codebase, credentialManager = credMan, generateUniqueName = do diff --git a/unison-cli/src/Unison/Main.hs b/unison-cli/src/Unison/Main.hs index ab03c6c38d..88f94a3c26 100644 --- a/unison-cli/src/Unison/Main.hs +++ b/unison-cli/src/Unison/Main.hs @@ -54,6 +54,7 @@ import Text.Megaparsec qualified as MP import U.Codebase.Sqlite.Queries qualified as Queries import Unison.Auth.CredentialManager qualified as AuthN import Unison.Auth.HTTPClient qualified as AuthN +import Unison.Auth.Tokens (TokenProvider) import Unison.Auth.Tokens qualified as AuthN import Unison.Cli.ProjectUtils qualified as ProjectUtils import Unison.Codebase (Codebase, CodebasePath) @@ -185,7 +186,8 @@ main version = do let serverUrl = Nothing let ucmVersion = Version.gitDescribeWithDate version credMan <- liftIO $ 
AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient ucmVersion credMan + let tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion startProjectPath <- Codebase.runTransaction theCodebase Codebase.expectCurrentProjectPath launch version @@ -195,6 +197,7 @@ main version = do theCodebase [Left fileEvent, Right $ Input.ExecuteI NoProf mainName args, Right Input.QuitI] authenticatedHTTPClient + tokenProvider credMan serverUrl (PP.toIds startProjectPath) @@ -213,7 +216,8 @@ main version = do let serverUrl = Nothing let ucmVersion = Version.gitDescribeWithDate version credMan <- liftIO $ AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient ucmVersion credMan + let tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion startProjectPath <- Codebase.runTransaction theCodebase Codebase.expectCurrentProjectPath launch version @@ -223,6 +227,7 @@ main version = do theCodebase [Left fileEvent, Right $ Input.ExecuteI NoProf mainName args, Right Input.QuitI] authenticatedHTTPClient + tokenProvider credMan serverUrl (PP.toIds startProjectPath) @@ -330,7 +335,8 @@ main version = do let isTest = False let ucmVersion = Version.gitDescribeWithDate version credMan <- liftIO $ AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient ucmVersion credMan + let tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion mcpServerConfig <- MCP.initServer theCodebase runtime sbRuntime (Just currentDir) ucmVersion authenticatedHTTPClient Server.startServer isTest @@ -374,6 +380,7 @@ main version = do theCodebase [] authenticatedHTTPClient + tokenProvider credMan mayBaseUrl (PP.toIds startingProjectPath) @@ -596,6 +603,7 @@ launch :: Codebase.Codebase IO Symbol Ann -> [Either Input.Event Input.Input] -> AuthN.AuthenticatedHttpClient -> + TokenProvider -> AuthN.CredentialManager -> Maybe Server.BaseUrl -> PP.ProjectPathIds -> @@ -603,7 +611,7 @@ launch :: (PP.ProjectPathIds -> IO ()) -> CommandLine.ShouldWatchFiles -> IO () -launch version dir runtime sbRuntime codebase inputs authenticatedHTTPClient credMan serverBaseUrl startingPath initResult lspCheckForChanges shouldWatchFiles = do +launch version dir runtime sbRuntime codebase inputs authenticatedHTTPClient tokenProvider credMan serverBaseUrl startingPath initResult lspCheckForChanges shouldWatchFiles = do showWelcomeHint <- Codebase.runTransaction codebase Queries.doProjectsExist let isNewCodebase = case initResult of CreatedCodebase -> NewlyCreatedCodebase @@ -621,6 +629,7 @@ launch version dir runtime sbRuntime codebase inputs authenticatedHTTPClient cre serverBaseUrl ucmVersion authenticatedHTTPClient + tokenProvider credMan lspCheckForChanges shouldWatchFiles diff --git a/unison-cli/src/Unison/Share/Codeserver.hs b/unison-cli/src/Unison/Share/Codeserver.hs index ea7aee4b73..a569094da2 100644 --- a/unison-cli/src/Unison/Share/Codeserver.hs +++ b/unison-cli/src/Unison/Share/Codeserver.hs @@ -3,6 +3,8 @@ module Unison.Share.Codeserver defaultCodeserver, resolveCodeserver, CodeserverURI (..), + Scheme (..), + CodeserverId (..), ) where diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs new file mode 100644 index 0000000000..80386908ec --- /dev/null +++ 
b/unison-cli/src/Unison/Share/SyncV3.hs @@ -0,0 +1,230 @@ +module Unison.Share.SyncV3 + ( syncFromCodeserver, + ) +where + +import Network.Socket (withSocketsDo) +import Control.Arrow ((&&&)) +import Control.Monad.Reader +import Data.Set qualified as Set +import Data.Text.Encoding as Text +import GHC.Natural +import Ki qualified +import Network.WebSockets qualified as WS +import U.Codebase.HashTags +import U.Codebase.Sqlite.DbId +import U.Codebase.Sqlite.Queries qualified as Q +import U.Codebase.Sqlite.V2.HashHandle (v2HashHandle) +import Unison.Cli.Monad +import Unison.Cli.Monad qualified as Cli +import Unison.Codebase (Codebase) +import Unison.Codebase qualified as Codebase +import Unison.Debug qualified as Debug +import Unison.Hash32 (Hash32) +import Unison.Prelude +import Unison.Server.Orphans () +import Unison.Share.API.Hash qualified as Share +import Unison.Share.Codeserver qualified as Codeserver +import Unison.Share.Sync.Types qualified as Sync +import Unison.Share.Types +import Unison.Sync.Common qualified as Sync +import Unison.SyncV3.Types +import Unison.SyncV3.Types as SyncV3 +import Unison.SyncV3.Utils (tempEntityDependencies) +import Unison.Util.Servant.CBOR qualified as CBOR +import Unison.Util.Websockets (Queues (..), withQueues) +import UnliftIO.STM +import Wuss qualified + +-- Websocket send/receive buffer sizes +inputBuffer :: Natural +inputBuffer = 1000 + +outputBuffer :: Natural +outputBuffer = 1000 + +transactionBatchSize :: Natural +transactionBatchSize = 1000 + +syncV3ClientVersion :: Int32 +syncV3ClientVersion = 1 + +syncFromCodeserver :: + Bool -> + -- | The Unison Share URL. + Codeserver.CodeserverURI -> + -- | The branch to download from. + BranchRef -> + -- | The hash to download. + Share.HashJWT -> + Cli (Either (Sync.SyncError SyncV3.SyncError) (CausalHash, CausalHashId)) +syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do + Cli.Env {codebase, tokenProvider} <- ask + let host = Codeserver.codeserverRegName codeserver + let syncV3Path = "/ucm/v3/sync/download" + let rootCausalHash = Share.hashJWTHash hashJwt + -- Enable compression + let connectionOptions = WS.defaultConnectionOptions {WS.connectionCompressionOptions = WS.PermessageDeflateCompression WS.defaultPermessageDeflate} + headers <- + (liftIO (tokenProvider (codeserverIdFromCodeserverURI codeserver))) <&> \case + Left {} -> [] + Right token -> [("Authorization", "Bearer " <> Text.encodeUtf8 token)] + let runner = case Codeserver.codeserverScheme codeserver of + Codeserver.Https -> + let tlsPort = 443 + port = maybe tlsPort fromIntegral $ (Codeserver.codeserverPort) codeserver + in Wuss.runSecureClientWith host port + Codeserver.Http -> + let tlsPort = 443 :: Int + port = maybe tlsPort id $ (Codeserver.codeserverPort) codeserver + in WS.runClientWith host port + Debug.debugLogM Debug.Temp "Obtaining Connection" + liftIO $ withSocketsDo $ (runner syncV3Path connectionOptions headers) \conn -> do + Debug.debugLogM Debug.Temp "Obtained Connection" + withQueues inputBuffer outputBuffer conn $ \queues@Queues {send} -> do + Debug.debugLogM Debug.Temp "Obtained Queues" + let initMsg = + InitMsg + { initMsgClientVersion = syncV3ClientVersion, + initMsgBranchRef = branchRef, + initMsgRootCausal = hashJwt, + initMsgRequestedDepth = Nothing + } + Debug.debugLogM Debug.Temp "Sending init message" + atomically $ send $ Msg $ ReceiverInitStream initMsg + Debug.debugLogM Debug.Temp "Init message sent" + pendingRequestsVar <- newTVarIO (Set.singleton (CausalEntity, rootCausalHash)) + yetToRequestVar 
<- newTVarIO Set.empty + toIngestQueue <- newTBQueueIO transactionBatchSize + let initState = + SyncState + { pendingRequestsVar, + yetToRequestVar, + toIngestQueue, + rootCausalHash + } + + liftIO (doSync codebase initState queues) >>= \case + -- TODO: proper error handling + Left err -> error $ show err + Right () -> pure () + Debug.debugLogM Debug.Temp "!Done sync, flushing temp entities" + causalId <- liftIO $ flushTemp codebase (Share.hashJWTHash hashJwt) + pure $ Right (Sync.hash32ToCausalHash rootCausalHash, causalId) + +data SyncState = SyncState + { pendingRequestsVar :: TVar (Set (EntityKind, Hash32)), + yetToRequestVar :: TVar (Set (EntityKind, Hash32)), + toIngestQueue :: TBQueue (Entity Hash32 Text), + rootCausalHash :: Hash32 + } + +-- | Given a stream that's already been initialized, receive entities and issue requests as needed. +doSync :: Codebase IO v a -> SyncState -> Queues (MsgOrError SyncError (FromReceiverMessage Share.HashJWT Hash32)) (MsgOrError SyncError (FromEmitterMessage Hash32 Text)) -> IO (Either SyncError ()) +doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, rootCausalHash} (Queues {send, receive, shutdown, connectionClosed}) = Ki.scoped \scope -> do + errorVar <- newEmptyTMVarIO + let onErr err = do + atomically $ putTMVar errorVar err + shutdown + _ <- Ki.fork scope (receiverWorker onErr) + _ <- Ki.fork scope (requesterWorker onErr) + _ <- Ki.fork scope (ingestionWorker onErr) + let finished = do + pending <- readTVar pendingRequestsVar + yetToReq <- readTVar yetToRequestVar + guard $ Set.null pending && Set.null yetToReq + + Debug.debugLogM Debug.Temp "Awaiting completion" + result <- + atomically $ + (Right <$> finished) + <|> (Right <$> Ki.awaitAll scope) + <|> (Left . Left <$> readTMVar errorVar) + <|> (Left . Right <$> connectionClosed) + + Debug.debugM Debug.Temp "End result" result + case result of + Left (Left syncErr) -> pure $ Left syncErr + Left (Right mayConnErr) -> case mayConnErr of + Nothing -> pure $ Right () + Just connErr -> pure $ Left $ ConnectionError (tShow connErr) + Right () -> pure $ Right () + where + receiverWorker :: (SyncError -> IO ()) -> IO () + receiverWorker onErr = do + Debug.debugLogM Debug.Temp "Receiver waiting for message" + atomically receive >>= \case + Msg (EmitterEntityMsg entity) -> do + atomically $ do + writeTBQueue toIngestQueue entity + receiverWorker onErr + Err err -> onErr err + requesterWorker :: (SyncError -> IO ()) -> IO () + requesterWorker _onErr = forever do + Debug.debugLogM Debug.Temp "Requester waiting to send requests" + atomically $ do + requests <- readTVar yetToRequestVar + guard $ not (Set.null requests) + writeTVar yetToRequestVar Set.empty + modifyTVar' pendingRequestsVar (Set.union requests) + send $ Msg $ ReceiverEntityRequest $ EntityRequestMsg (Set.toList requests) + + ingestionWorker :: (SyncError -> IO ()) -> IO () + ingestionWorker _onErr = forever do + Debug.debugLogM Debug.Temp "Ingestion waiting for entities" + newEntities <- atomically $ do + flushTBQueue toIngestQueue + Codebase.runTransaction codebase $ do + -- TODO: do hash validation based on shouldValidate + for_ newEntities $ \(Entity {entityKind, entityHash, entityDepth, entityData = CBOR.CBORBytes entityBytes}) -> do + Q.insertTempEntitySyncV3 rootCausalHash (tShow entityKind) entityHash (unEntityDepth entityDepth) entityBytes + + tempEntities <- case for newEntities (CBOR.deserialiseOrFailCBORBytes . 
entityData) of + -- TODO: proper error handling + Left err -> error $ show err + Right tempEntities -> pure tempEntities + let allDeps = foldMap tempEntityDependencies tempEntities + -- TODO: double-check whether it's okay to have this as a separate atomic block. + alreadyRequestedEntities <- atomically $ do + pending <- readTVar pendingRequestsVar + reqs <- readTVar yetToRequestVar + pure $ Set.union pending reqs + let unrequestedDeps = Set.difference allDeps alreadyRequestedEntities + missingDeps <- + (Set.toList unrequestedDeps) & filterA \(_depKind, depHash) -> do + Codebase.runTransaction codebase (Q.entityLocationSyncV3 depHash) <&> \case + Nothing -> True + _ -> False + let newlyInserted = + newEntities + <&> (entityKind &&& entityHash) + & Set.fromList + -- Request any deps we're missing which also haven't already been requested + atomically $ do + pending <- readTVar pendingRequestsVar + let missingDepsSet = Set.fromList missingDeps + let unRequestedDeps = Set.difference missingDepsSet pending + modifyTVar' yetToRequestVar (Set.union unRequestedDeps) + modifyTVar' pendingRequestsVar (\pending -> Set.difference pending newlyInserted) + +flushTemp :: Codebase IO v a -> Hash32 -> IO CausalHashId +flushTemp codebase rootCausalHash = do + Codebase.runTransaction codebase $ do + Q.streamTempEntitiesSyncV3 rootCausalHash \next -> + do + let loop = do + next >>= \case + Nothing -> pure () + Just (hash, tempEntityBytes) -> + do + Debug.debugLogM Debug.Temp $ "Flushing temp entity: " <> show hash + tempEntity <- case CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes tempEntityBytes) of + -- TODO: proper error handling + Left err -> error $ show err + Right tempEntity -> pure tempEntity + Debug.debugLogM Debug.Temp $ "Saving in main" <> show hash + void $ Q.saveTempEntityInMain v2HashHandle hash tempEntity + loop + loop + Debug.debugLogM Debug.Temp "Flushed temp entities, getting causal hash id" + Q.expectCausalHashIdByCausalHash (Sync.hash32ToCausalHash rootCausalHash) diff --git a/unison-cli/unison-cli.cabal b/unison-cli/unison-cli.cabal index 75daa4c176..dbea97e245 100644 --- a/unison-cli/unison-cli.cabal +++ b/unison-cli/unison-cli.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.36.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. 
-- -- see: https://github.com/sol/hpack @@ -163,6 +163,7 @@ library Unison.Share.Sync Unison.Share.Sync.Types Unison.Share.SyncV2 + Unison.Share.SyncV3 Unison.Util.HTTP Unison.Version hs-source-dirs: @@ -247,6 +248,7 @@ library , megaparsec , memory , mtl + , network , network-simple , network-uri , nonempty-containers @@ -298,8 +300,10 @@ library , vector , wai , warp + , websockets , witch , witherable + , wuss default-language: Haskell2010 if !os(windows) build-depends: diff --git a/unison-share-api/package.yaml b/unison-share-api/package.yaml index 879fce940c..4bba34c9af 100644 --- a/unison-share-api/package.yaml +++ b/unison-share-api/package.yaml @@ -26,6 +26,7 @@ library: - hs-mcp - http-media - http-types + - ki-unlifted - lens - lucid - memory @@ -65,6 +66,7 @@ library: - wai - wai-cors - warp + - websockets - yaml tests: diff --git a/unison-share-api/src/Unison/SyncCommon/Types.hs b/unison-share-api/src/Unison/SyncCommon/Types.hs new file mode 100644 index 0000000000..e9900cb855 --- /dev/null +++ b/unison-share-api/src/Unison/SyncCommon/Types.hs @@ -0,0 +1,18 @@ +-- Types common to multiple versions of Sync +module Unison.SyncCommon.Types + ( BranchRef (..), + ) +where + +import Codec.Serialise (Serialise (..)) +import Data.Aeson (FromJSON (..), ToJSON (..)) +import Data.Text (Text) +import Unison.Core.Project (ProjectAndBranch (..), ProjectBranchName, ProjectName) +import Unison.Prelude (From (..)) +import Unison.Server.Orphans () + +newtype BranchRef = BranchRef {unBranchRef :: Text} + deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text + +instance From (ProjectAndBranch ProjectName ProjectBranchName) BranchRef where + from pab = BranchRef $ from pab diff --git a/unison-share-api/src/Unison/SyncV2/Types.hs b/unison-share-api/src/Unison/SyncV2/Types.hs index 82f2e95f63..1bfe3e7a6f 100644 --- a/unison-share-api/src/Unison/SyncV2/Types.hs +++ b/unison-share-api/src/Unison/SyncV2/Types.hs @@ -38,20 +38,13 @@ import Data.Text qualified as Text import Data.Word (Word16, Word64) import U.Codebase.HashTags (CausalHash) import U.Codebase.Sqlite.TempEntity (TempEntity) -import Unison.Core.Project (ProjectAndBranch (..), ProjectBranchName, ProjectName) import Unison.Hash32 (Hash32) -import Unison.Prelude (From (..)) import Unison.Server.Orphans () import Unison.Share.API.Hash (HashJWT) import Unison.Sync.Types qualified as SyncV1 +import Unison.SyncCommon.Types import Unison.Util.Servant.CBOR -newtype BranchRef = BranchRef {unBranchRef :: Text} - deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text - -instance From (ProjectAndBranch ProjectName ProjectBranchName) BranchRef where - from pab = BranchRef $ from pab - data GetCausalHashErrorTag = GetCausalHashNoReadPermissionTag | GetCausalHashUserNotFoundTag diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs new file mode 100644 index 0000000000..8ada257633 --- /dev/null +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -0,0 +1,411 @@ +module Unison.SyncV3.Types + ( InitMsg (..), + EntityRequestMsg (..), + FromReceiverMessage (..), + FromEmitterMessage (..), + MsgOrError (..), + SyncError (..), + Entity (..), + EntityKind (..), + EntityDepth (..), + HashTag (..), + BranchRef (..), + ) +where + +import Codec.Serialise (Serialise) +import Codec.Serialise qualified as CBOR +import Control.Lens hiding ((.=)) +import Data.Aeson +import Data.Aeson qualified as Aeson +import Data.ByteString qualified as BS +import Data.ByteString.Lazy.Char8 qualified as BL +import Data.Int 
(Int32, Int64) +import Data.Set (Set) +import Data.Set qualified as Set +import Data.Text (Text) +import Network.WebSockets (WebSocketsData) +import Network.WebSockets qualified as WS +import U.Codebase.Sqlite.Orphans () +import U.Codebase.Sqlite.TempEntity +import Unison.Hash32 (Hash32) +import Unison.Prelude (tShow) +import Unison.Server.Orphans () +import Unison.Sqlite qualified as Sqlite +import Unison.SyncCommon.Types +import Unison.Util.Servant.CBOR qualified as CBOR + +data InitMsg authedHash = InitMsg + { initMsgClientVersion :: Int32, + initMsgBranchRef :: BranchRef, + initMsgRootCausal :: authedHash, + initMsgRequestedDepth :: Maybe Int64 + } + deriving (Show, Eq) + +instance (ToJSON authedHash) => ToJSON (InitMsg authedHash) where + toJSON (InitMsg {initMsgClientVersion, initMsgBranchRef, initMsgRootCausal, initMsgRequestedDepth}) = + object + [ "clientVersion" .= initMsgClientVersion, + "branchRef" .= initMsgBranchRef, + "rootCausal" .= initMsgRootCausal, + "requestedDepth" .= initMsgRequestedDepth + ] + +instance (FromJSON authedHash) => FromJSON (InitMsg authedHash) where + parseJSON = withObject "InitMsg" $ \o -> + InitMsg + <$> o .: "clientVersion" + <*> o .: "branchRef" + <*> o .: "rootCausal" + <*> o .:? "requestedDepth" + +data EntityRequestMsg hash = EntityRequestMsg + { hashes :: [(EntityKind, hash)] + } + deriving (Show, Eq) + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> let msg = EntityRequestMsg {hashes = [(CausalEntity, "hash1"), (NamespaceEntity, "hash2")]} +-- >>> CBOR.deserialise (CBOR.serialise msg) == msg +-- True +instance (CBOR.Serialise sh) => CBOR.Serialise (EntityRequestMsg sh) where + encode (EntityRequestMsg {hashes}) = + CBOR.encode hashes + + decode = do + hashes <- CBOR.decode @[(EntityKind, sh)] + pure $ EntityRequestMsg {hashes} + +data FromReceiverMessageTag + = ReceiverInitStreamTag + | ReceiverEntityRequestTag + deriving (Show, Eq) + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise ReceiverInitStreamTag) == ReceiverInitStreamTag +-- True +-- >>> CBOR.deserialise (CBOR.serialise ReceiverEntityRequestTag) == ReceiverEntityRequestTag +-- True +instance CBOR.Serialise FromReceiverMessageTag where + encode = \case + ReceiverInitStreamTag -> CBOR.encode (0 :: Int) + ReceiverEntityRequestTag -> CBOR.encode (1 :: Int) + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> pure ReceiverInitStreamTag + 1 -> pure ReceiverEntityRequestTag + _ -> fail $ "Unknown FromReceiverMessageTag: " <> show tag + +-- A message sent from the downloader to the emitter. +data FromReceiverMessage ah hash + = -- Initialize the stream + ReceiverInitStream (InitMsg ah) + | -- Request more entities by hash. 
+ ReceiverEntityRequest (EntityRequestMsg hash) + deriving (Show, Eq) + +instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where + encode msg = do + -- This is dumb, but there's currently no reasonable way to encode a heterogenous Map + -- using Haskell's CBOR library :| + -- + -- See https://github.com/well-typed/cborg/issues/369 + CBOR.encode @BS.ByteString $ BL.toStrict $ Aeson.encode msg + + decode = do + bs <- CBOR.decode @BS.ByteString + case Aeson.eitherDecode $ BL.fromStrict bs of + Left err -> fail $ "Error decoding InitMsg from JSON: " <> err + Right msg -> pure msg + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> let msg = InitMsg {initMsgClientVersion = 1, initMsgBranchRef = BranchRef "main", initMsgRootCausal = "hash123", initMsgRequestedDepth = Just 10} +-- >>> CBOR.deserialise (CBOR.serialise msg) == msg +-- True +-- >>> let initMsg :: FromReceiverMessage Text Text = ReceiverInitStream msg +-- >>> CBOR.deserialise (CBOR.serialise initMsg) == initMsg +-- True +-- >>> let entityReq :: FromReceiverMessage Text Text = ReceiverEntityRequest (EntityRequestMsg {hashes = [(CausalEntity, "h1")]}) +-- >>> CBOR.deserialise (CBOR.serialise entityReq) == entityReq +-- True +instance (CBOR.Serialise h, ToJSON ah, FromJSON ah) => CBOR.Serialise (FromReceiverMessage ah h) where + encode = \case + ReceiverInitStream initMsg -> + CBOR.encode ReceiverInitStreamTag + <> CBOR.encode initMsg + ReceiverEntityRequest msg -> + CBOR.encode ReceiverEntityRequestTag + <> CBOR.encode msg + decode = do + tag <- CBOR.decode @FromReceiverMessageTag + case tag of + ReceiverInitStreamTag -> ReceiverInitStream <$> CBOR.decode @(InitMsg ah) + ReceiverEntityRequestTag -> ReceiverEntityRequest <$> CBOR.decode @(EntityRequestMsg h) + +data SyncError + = InitializationError Text + | UnexpectedMessage BL.ByteString + | EncodingFailure Text + | -- The caller asked for a Hash they shouldn't have access to. 
+ ForbiddenEntityRequest (Set (EntityKind, Hash32)) + | ConnectionError Text + | ProjectNotFound BranchRef + | UserNotFound BranchRef + | NoReadPermission BranchRef + | HashJWTVerificationError Text + | InvalidBranchRef Text BranchRef + deriving (Show, Eq) + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> import qualified Data.Set as Set +-- >>> CBOR.deserialise (CBOR.serialise (InitializationError "test")) == InitializationError "test" +-- True +-- >>> CBOR.deserialise (CBOR.serialise (EncodingFailure "fail")) == EncodingFailure "fail" +-- True +-- >>> let forbidden = ForbiddenEntityRequest (Set.fromList [(CausalEntity, undefined)]) +-- >>> CBOR.deserialise (CBOR.serialise (ConnectionError "err")) == ConnectionError "err" +-- True +instance CBOR.Serialise SyncError where + encode = \case + InitializationError msg -> + CBOR.encode (0 :: Int) <> CBOR.encode msg + UnexpectedMessage msg -> + CBOR.encode (1 :: Int) <> CBOR.encode (BL.toStrict msg) + EncodingFailure msg -> + CBOR.encode (2 :: Int) <> CBOR.encode msg + ForbiddenEntityRequest hashes -> + CBOR.encode (3 :: Int) <> CBOR.encode hashes + ConnectionError err -> + CBOR.encode (4 :: Int) <> CBOR.encode err + ProjectNotFound branchRef -> + CBOR.encode (5 :: Int) <> CBOR.encode branchRef + UserNotFound branchRef -> + CBOR.encode (6 :: Int) <> CBOR.encode branchRef + NoReadPermission branchRef -> + CBOR.encode (7 :: Int) <> CBOR.encode branchRef + HashJWTVerificationError err -> + CBOR.encode (8 :: Int) <> CBOR.encode err + InvalidBranchRef err branchRef -> + CBOR.encode (9 :: Int) <> CBOR.encode err <> CBOR.encode branchRef + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> InitializationError <$> CBOR.decode + 1 -> do + bs <- CBOR.decode @BS.ByteString + pure $ UnexpectedMessage (BL.fromStrict bs) + 2 -> EncodingFailure <$> CBOR.decode + 3 -> ForbiddenEntityRequest . Set.fromList <$> CBOR.decode + 4 -> do + err <- CBOR.decode @Text + pure $ ConnectionError err + 5 -> ProjectNotFound <$> CBOR.decode + 6 -> UserNotFound <$> CBOR.decode + 7 -> NoReadPermission <$> CBOR.decode + 8 -> HashJWTVerificationError <$> CBOR.decode + 9 -> do + err <- CBOR.decode @Text + branchRef <- CBOR.decode @BranchRef + pure $ InvalidBranchRef err branchRef + _ -> fail $ "Unknown SyncError tag: " <> show tag + +-- A message sent from the emitter to the downloader. +data FromEmitterMessage hash text + = EmitterEntityMsg (Entity hash text) + deriving (Show, Eq) + +data EntityKind + = CausalEntity + | NamespaceEntity + | DefnComponentEntity + | PatchEntity + deriving stock (Show, Eq, Ord) + +instance Sqlite.ToField EntityKind where + toField = + Sqlite.toField . 
\case + CausalEntity -> (0 :: Int) + NamespaceEntity -> 1 + DefnComponentEntity -> 2 + PatchEntity -> 3 + +instance Sqlite.FromField EntityKind where + fromField field = do + tag <- Sqlite.fromField field + case tag of + (0 :: Int) -> pure CausalEntity + 1 -> pure NamespaceEntity + 2 -> pure DefnComponentEntity + 3 -> pure PatchEntity + _ -> fail $ "Unknown EntityKind tag: " <> show tag + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise CausalEntity) == CausalEntity +-- True +-- >>> CBOR.deserialise (CBOR.serialise NamespaceEntity) == NamespaceEntity +-- True +-- >>> CBOR.deserialise (CBOR.serialise DefnComponentEntity) == DefnComponentEntity +-- True +-- >>> CBOR.deserialise (CBOR.serialise PatchEntity) == PatchEntity +-- True +instance CBOR.Serialise EntityKind where + encode = \case + CausalEntity -> CBOR.encode (0 :: Int) + NamespaceEntity -> CBOR.encode (1 :: Int) + DefnComponentEntity -> CBOR.encode (2 :: Int) + PatchEntity -> CBOR.encode (3 :: Int) + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> pure CausalEntity + 1 -> pure NamespaceEntity + 2 -> pure DefnComponentEntity + 3 -> pure PatchEntity + _ -> fail $ "Unknown EntityKind tag: " <> show tag + +-- | The number of _levels_ of dependencies an entity has, +-- this has no real semantic meaning on its own, but provides the +-- property that out of a given set of synced entities, if you process +-- them in order of increasing EntityDepth, you will always have +-- processed an entity's dependencies before you see the entity itself. +newtype EntityDepth = EntityDepth {unEntityDepth :: Int64} + deriving (Show, Eq, Ord) + deriving newtype (CBOR.Serialise) + +data Entity hash text = Entity + { entityHash :: hash, + entityKind :: EntityKind, + entityDepth :: EntityDepth, + entityData :: CBOR.CBORBytes TempEntity + } + deriving (Show, Eq) + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> import U.Codebase.Sqlite.TempEntity (TempEntity(..)) +-- >>> let ent :: Entity Text Text = Entity {entityHash = "hash", entityKind = CausalEntity, entityDepth = EntityDepth 5, entityData = CBOR.CBORBytes "abc"} +-- >>> CBOR.deserialise (CBOR.serialise ent) == ent +-- True +instance (CBOR.Serialise smallHash, CBOR.Serialise text) => CBOR.Serialise (Entity smallHash text) where + encode (Entity {entityHash, entityKind, entityDepth, entityData}) = + CBOR.encode entityHash + <> CBOR.encode entityKind + <> CBOR.encode entityDepth + <> CBOR.encode entityData + + decode = do + entityHash <- CBOR.decode @smallHash + entityKind <- CBOR.decode @EntityKind + entityDepth <- CBOR.decode @EntityDepth + entityData <- CBOR.decode @(CBOR.CBORBytes TempEntity) + + pure $ Entity {entityHash, entityKind, entityData, entityDepth} + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> import U.Codebase.Sqlite.TempEntity (TempEntity(..)) +-- >>> let ent :: Entity Text Text = Entity {entityHash = "hash", entityKind = CausalEntity, entityDepth = EntityDepth 5, entityData = CBOR.CBORBytes "abc"} +-- >>> let msg = EmitterEntityMsg ent +-- >>> CBOR.deserialise (CBOR.serialise msg) == msg +-- True +instance (CBOR.Serialise hash, CBOR.Serialise text) => CBOR.Serialise (FromEmitterMessage hash text) where + encode = \case + EmitterEntityMsg msg -> CBOR.encode EmitterEntityTag <> CBOR.encode msg + + decode = do + tag <- CBOR.decode @FromEmitterMessageTag + case tag of + EmitterEntityTag -> EmitterEntityMsg <$> CBOR.decode + +data FromEmitterMessageTag + = 
EmitterEntityTag + deriving (Show, Eq) + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise EmitterEntityTag) == EmitterEntityTag +-- True +instance CBOR.Serialise FromEmitterMessageTag where + encode = \case + EmitterEntityTag -> CBOR.encode (0 :: Int) + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> pure EmitterEntityTag + _ -> fail $ "Unknown FromEmitterMessageTag: " <> show tag + +data MsgOrError err a + = Msg a + | Err err + deriving (Show, Eq, Ord) + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise (Msg "test" :: MsgOrError Text Text)) == Msg "test" +-- True +-- >>> CBOR.deserialise (CBOR.serialise (Err "error" :: MsgOrError Text Text)) == Err "error" +-- True +instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError err a) where + encode = \case + Msg a -> CBOR.encode (0 :: Int) <> CBOR.encode a + Err e -> CBOR.encode (1 :: Int) <> CBOR.encode e + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> Msg <$> CBOR.decode + 1 -> Err <$> CBOR.decode + _ -> fail $ "Unknown MsgOrError tag: " <> show tag + +-- | Roundtrip test: +-- >>> import qualified Network.WebSockets as WS +-- >>> let msgVal = Msg "test" :: MsgOrError SyncError Text +-- >>> WS.fromLazyByteString (WS.toLazyByteString msgVal) == msgVal +-- True +-- >>> let errVal = Err (InitializationError "init error") :: MsgOrError SyncError Text +-- >>> WS.fromLazyByteString (WS.toLazyByteString errVal) == errVal +-- True +-- >>> let dataMsg = WS.Binary (WS.toLazyByteString msgVal) +-- >>> WS.fromDataMessage dataMsg == msgVal +-- True +instance (Serialise msg) => WebSocketsData (MsgOrError SyncError msg) where + fromLazyByteString bytes = + CBOR.deserialiseOrFail bytes + & either (\err -> Err . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) id + + toLazyByteString = CBOR.serialise + + fromDataMessage dm = do + case dm of + WS.Text bytes _ -> WS.fromLazyByteString bytes + WS.Binary bytes -> WS.fromLazyByteString bytes + +-- Application level compression of Hash references. +-- We can send a mapping of Hash <-> HashTag at the start of the stream, +-- and then use the smaller HashTag in all subsequent messages. 
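+--
+-- Illustrative sketch (not part of this change) of how a sender might build
+-- such a mapping, assuming a qualified Data.Map import and distinct input hashes:
+--
+-- > assignTags :: [(EntityKind, Hash32)] -> Map Hash32 HashTag
+-- > assignTags hashes =
+-- >   Map.fromList
+-- >     [ (hash, HashTag (kind, i))
+-- >     | ((kind, hash), i) <- zip hashes [0 ..]
+-- >     ]
+--
+-- The mapping would be sent once up front; later messages could then refer to
+-- entities by their (much smaller) HashTag instead of the full Hash32.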
+data HashTag = HashTag (EntityKind, Int64) + deriving (Show, Eq, Ord) + +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> let tag = HashTag (CausalEntity, 42) +-- >>> CBOR.deserialise (CBOR.serialise tag) == tag +-- True +instance CBOR.Serialise HashTag where + encode (HashTag (kind, idx)) = + CBOR.encode (kind, idx) + + decode = do + (kind, idx) <- CBOR.decode @(EntityKind, Int64) + pure $ HashTag (kind, idx) diff --git a/unison-share-api/src/Unison/SyncV3/Utils.hs b/unison-share-api/src/Unison/SyncV3/Utils.hs new file mode 100644 index 0000000000..513fba3ef7 --- /dev/null +++ b/unison-share-api/src/Unison/SyncV3/Utils.hs @@ -0,0 +1,30 @@ +module Unison.SyncV3.Utils (tempEntityDependencies, entityDependencies) where + +import Data.Set (Set) +import Data.Set qualified as Set +import Data.Set.Lens qualified as Lens +import U.Codebase.Sqlite.Entity qualified as Entity +import U.Codebase.Sqlite.TempEntity +import Unison.Hash32 (Hash32) +import Unison.SyncV3.Types +import Unison.Util.Servant.CBOR qualified as CBOR + +tempEntityDependencies :: TempEntity -> Set (EntityKind, Hash32) +tempEntityDependencies entity = do + let componentDeps = Lens.setOf Entity.defns_ entity + patchDeps = Lens.setOf Entity.patches_ entity + branchHashes = Lens.setOf Entity.branchHashes_ entity <> Lens.setOf Entity.branches_ entity + causalHashes = Lens.setOf Entity.causalHashes_ entity + in Set.unions + [ Set.map (DefnComponentEntity,) componentDeps, + Set.map (PatchEntity,) patchDeps, + Set.map (NamespaceEntity,) branchHashes, + Set.map (CausalEntity,) causalHashes + ] + +entityDependencies :: Entity hash text -> Set (EntityKind, Hash32) +entityDependencies Entity {entityData} = do + case (CBOR.deserialiseOrFailCBORBytes $ entityData) of + -- TODO: proper error handling + Left err -> error $ show err + Right tempEntity -> tempEntityDependencies tempEntity diff --git a/unison-share-api/src/Unison/Util/Websockets.hs b/unison-share-api/src/Unison/Util/Websockets.hs new file mode 100644 index 0000000000..721034ec3c --- /dev/null +++ b/unison-share-api/src/Unison/Util/Websockets.hs @@ -0,0 +1,84 @@ +{-# LANGUAGE KindSignatures #-} + +module Unison.Util.Websockets + ( withQueues, + Queues (..), + ) +where + +import Control.Applicative +import Control.Lens (Profunctor (..)) +import Control.Monad +import Data.Text (Text) +import GHC.Natural +import Ki.Unlifted qualified as Ki +import Network.WebSockets +import UnliftIO + +-- | Allows interfacing with a websocket as a pair of bounded queues. +data Queues i o = Queues + { -- Receive from the client + receive :: STM o, + -- Send to the client + send :: i -> STM (), + shutdown :: IO (), + -- This succeeds with a 'Just' value if the connection was closed due to an exception, + -- 'Nothing' if it was closed normally, or retries if the connection is still open. + connectionClosed :: STM (Maybe ConnectionException) + } + +instance Profunctor Queues where + dimap f g (Queues {receive, send, shutdown, connectionClosed}) = + Queues + { receive = g <$> receive, + send = send . f, + shutdown, + connectionClosed + } + +withQueues :: forall i o m a. 
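+  -- Buffer sizes: the first 'Natural' bounds the receive queue, the second the
+  -- send queue. Illustrative use (a sketch, not part of this change):
+  --
+  -- > withQueues 64 64 conn $ \Queues {receive, send} -> ...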
(MonadUnliftIO m, WebSocketsData i, WebSocketsData o) => Natural -> Natural -> Connection -> (Queues i o -> m a) -> m a +withQueues inputBuffer outputBuffer conn action = Ki.scoped $ \scope -> do + receiveQ <- liftIO $ newTBQueueIO inputBuffer + sendQ <- liftIO $ newTBQueueIO outputBuffer + connectionClosedMVar <- liftIO $ newEmptyTMVarIO + let receive = do readTBQueue receiveQ + let send msg = writeTBQueue sendQ msg + + let triggerClose :: forall n. (MonadIO n) => (Maybe ConnectionException) -> n () + triggerClose mayErr = do + newlyClosed <- atomically $ do + tryPutTMVar connectionClosedMVar mayErr + when newlyClosed $ do + -- If we closed due to a connection error, we don't need to send a close. + -- If we're shutting down normally, we send a close message. + case mayErr of + Nothing -> liftIO $ sendClose conn ("Server is shutting down" :: Text) + _ -> pure () + + let queues = Queues {receive, send, shutdown = (triggerClose Nothing), connectionClosed = readTMVar connectionClosedMVar} + _ <- Ki.fork scope $ recvWorker triggerClose receiveQ + _ <- Ki.fork scope $ sendWorker triggerClose sendQ + r <- action queues + -- Ensure the connection is closed when done. + liftIO $ triggerClose Nothing + pure r + where + recvWorker :: (Maybe ConnectionException -> m ()) -> TBQueue o -> m () + recvWorker triggerClose q = UnliftIO.handle (handler triggerClose) $ do + msg <- liftIO $ receiveData conn + atomically $ writeTBQueue q msg + recvWorker triggerClose q + + handler :: (Maybe ConnectionException -> m ()) -> ConnectionException -> m () + handler triggerClose = \case + CloseRequest {} -> do + -- The client requested a close, we can just close normally. + triggerClose Nothing + -- Other cases are exceptional + err -> triggerClose (Just err) + + sendWorker :: (Maybe ConnectionException -> m ()) -> TBQueue i -> m () + sendWorker triggerClose q = UnliftIO.handle (handler triggerClose) $ do + outMsgs <- atomically $ some $ readTBQueue q + liftIO $ sendBinaryDatas conn outMsgs + sendWorker triggerClose q diff --git a/unison-share-api/unison-share-api.cabal b/unison-share-api/unison-share-api.cabal index dd8a117685..e995cadf76 100644 --- a/unison-share-api/unison-share-api.cabal +++ b/unison-share-api/unison-share-api.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.37.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -49,10 +49,14 @@ library Unison.Sync.Common Unison.Sync.EntityValidation Unison.Sync.Types + Unison.SyncCommon.Types Unison.SyncV2.API Unison.SyncV2.Types + Unison.SyncV3.Types + Unison.SyncV3.Utils Unison.Util.Find Unison.Util.Servant.CBOR + Unison.Util.Websockets hs-source-dirs: src default-extensions: @@ -105,6 +109,7 @@ library , hs-mcp , http-media , http-types + , ki-unlifted , lens , lucid , memory @@ -144,6 +149,7 @@ library , wai , wai-cors , warp + , websockets , yaml default-language: Haskell2010