diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index 38996dd..d854956 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -190,6 +190,7 @@ async fn build_prompt( file_url: &str, language_id: LanguageId, snippet_retriever: Arc>, + target_workspace: &str, ) -> Result { let t = Instant::now(); if fim.enabled { @@ -259,6 +260,7 @@ async fn build_prompt( Compare::Neq, file_url.into(), )), + target_workspace ) .await?; let context_header = build_context_header(language_id, snippets); @@ -307,6 +309,7 @@ async fn build_prompt( Compare::Neq, file_url.into(), )), + target_workspace ) .await?; let context_header = build_context_header(language_id, snippets); @@ -508,6 +511,25 @@ fn build_url(backend: Backend, model: &str) -> String { } } +fn file_uri_to_workspace(workspace_folders: Option<&Vec>, file_uri: &str) -> String { + // let folders = self.workspace_folders.read().await; + match workspace_folders { + Some(folders) => { + let parent_workspace = folders + .clone() + .into_iter() + .filter(|folder| file_uri.contains(folder.uri.path())) + .collect::>(); + if parent_workspace.is_empty() { + folders[0].name.clone() + } else { + parent_workspace[0].name.clone() + } + } + None => "".to_string(), + } +} + impl LlmService { async fn get_completions( &self, @@ -520,6 +542,9 @@ impl LlmService { let document_map = self.document_map.read().await; let file_url = params.text_document_position.text_document.uri.as_str(); + let target_workspace = file_uri_to_workspace( + self.workspace_folders.read().await.as_ref(), + file_url); let document = match document_map.get(file_url) { Some(doc) => doc, @@ -578,6 +603,7 @@ impl LlmService { &file_url.replace("file://", ""), document.language_id, self.snippet_retriever.clone(), + &target_workspace, ).await?; let http_client = if params.tls_skip_verify_insecure { @@ -700,12 +726,15 @@ impl LanguageServer for LlmService { .await; } let mut guard = snippet_retriever.write().await; + let workspace_path = 
#[cfg(test)]
mod tests {
    use super::*;

    // `file_uri_to_workspace` is a synchronous, pure function, so these tests call it
    // directly instead of building a whole `LlmService` (which would also construct a
    // `SnippetRetriever` and its model and tokenizer).

    /// Builds a workspace folder called `name` rooted at the absolute directory `dir`.
    fn workspace_folder(name: &str, dir: &str) -> WorkspaceFolder {
        WorkspaceFolder {
            name: name.to_owned(),
            uri: Url::from_directory_path(dir)
                .expect("test directory should be an absolute path"),
        }
    }

    #[test]
    fn test_file_uri_to_workspace_without_folders() {
        assert_eq!(file_uri_to_workspace(None, "/home/test"), "");
    }

    #[test]
    fn test_file_uri_to_workspace_picks_containing_folder() {
        let folders = vec![
            workspace_folder("other_repo", "/home/other_test"),
            workspace_folder("test_repo", "/home/test"),
        ];
        assert_eq!(
            file_uri_to_workspace(Some(&folders), "/home/test/src/lib/main.py"),
            "test_repo"
        );
    }

    #[test]
    fn test_file_uri_to_workspace_falls_back_to_first_folder() {
        let folders = vec![workspace_folder("other_repo", "/home/other_test")];
        assert_eq!(
            file_uri_to_workspace(Some(&folders), "/home/test/src/lib/main.py"),
            "other_repo"
        );
    }
}
cache_path: PathBuf, - collection_name: String, db: Option, model: Arc, model_config: ModelConfig, @@ -260,13 +259,12 @@ impl SnippetRetriever { window_size: usize, window_step: usize, ) -> Result { - let collection_name = "code-slices".to_owned(); let (model, tokenizer) = build_model_and_tokenizer(model_config.id.clone(), model_config.revision.clone()) .await?; Ok(Self { cache_path, - collection_name, + // collection_name, db: None, model: Arc::new(model), model_config, @@ -276,12 +274,20 @@ impl SnippetRetriever { }) } - pub(crate) async fn initialise_database(&mut self, db_name: &str) -> Result { + fn workspace_name_to_snippet_collections(&self, workspace_root: &str) -> String { + format!("{}--{}", "workspace", workspace_root).to_string() + } + + pub(crate) async fn initialise_database( + &mut self, + db_name: &str, + workspace_root: &str, + ) -> Result { let uri = self.cache_path.join(db_name); let mut db = Db::open(uri).await.expect("failed to open database"); match db .create_collection( - self.collection_name.clone(), + self.workspace_name_to_snippet_collections(workspace_root), self.model_config.embeddings_size, Distance::Cosine, ) @@ -303,24 +309,24 @@ impl SnippetRetriever { config: Arc, token: NumberOrString, workspace_root: &str, + workspace_path: &str, ) -> Result<()> { - debug!("building workspace snippets"); + debug!("building snippets for workspace {workspace_root}"); let start = Instant::now(); - let workspace_root = PathBuf::from(workspace_root); if self.db.is_none() { - self.initialise_database(&format!( - "{}--{}", - workspace_root - .file_name() - .ok_or_else(|| Error::NoFinalPath(workspace_root.clone()))? 
- .to_str() - .ok_or(Error::NonUnicode)?, - self.model_config.id.replace('/', "--"), - )) + self.initialise_database( + &format!( + "{}--{}", + "code-slices".to_owned(), + self.model_config.id.replace('/', "--"), + ), + workspace_root, + ) .await?; } + let workspace_path: PathBuf = PathBuf::from(workspace_path); let mut files = Vec::new(); - let mut gitignore = Gitignore::parse(&workspace_root).ok(); + let mut gitignore = Gitignore::parse(&workspace_path).ok(); for pattern in config.ignored_paths.iter() { if let Some(gitignore) = gitignore.as_mut() { if let Err(err) = gitignore.add_rule(pattern.clone()) { @@ -341,7 +347,7 @@ impl SnippetRetriever { }) .await; let mut stack = VecDeque::new(); - stack.push_back(workspace_root.clone()); + stack.push_back(workspace_path.clone()); while let Some(src) = stack.pop_back() { let mut entries = tokio::fs::read_dir(&src).await?; while let Some(entry) = entries.next_entry().await? { @@ -367,7 +373,7 @@ impl SnippetRetriever { } for (i, file) in files.iter().enumerate() { let file_url = file.to_str().expect("file path should be utf8").to_owned(); - self.add_document(file_url).await?; + self.add_document(file_url, workspace_root).await?; client .send_notification::(ProgressParams { token: token.clone(), @@ -376,7 +382,7 @@ impl SnippetRetriever { message: Some(format!( "{i}/{} ({})", files.len(), - file.strip_prefix(workspace_root.as_path())? + file.strip_prefix(workspace_path.as_path())? 
.to_str() .expect("expect file name to be valid unicode") )), @@ -393,16 +399,23 @@ impl SnippetRetriever { Ok(()) } - pub(crate) async fn add_document(&self, file_url: String) -> Result<()> { - self.build_and_add_snippets(file_url, 0, None).await?; + pub(crate) async fn add_document(&self, file_url: String, workspace_root: &str) -> Result<()> { + self.build_and_add_snippets(file_url, 0, None, workspace_root) + .await?; Ok(()) } - pub(crate) async fn update_document(&mut self, file_url: String, range: Range) -> Result<()> { + pub(crate) async fn update_document( + &mut self, + file_url: String, + range: Range, + workspace_root: &str, + ) -> Result<()> { self.build_and_add_snippets( file_url, range.start.line as usize, Some(range.end.line as usize), + workspace_root, ) .await?; Ok(()) @@ -459,12 +472,14 @@ impl SnippetRetriever { &self, query: &[f32], filter: Option, + workspace_root: &str, ) -> Result> { let db = match self.db.as_ref() { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection(&self.collection_name).await?; + let target_collection_name = self.workspace_name_to_snippet_collections(workspace_root); + let col = db.get_collection(&target_collection_name).await?; let result = col .read() .await @@ -485,12 +500,19 @@ impl SnippetRetriever { Ok(()) } - pub(crate) async fn remove(&self, file_url: String, range: Range) -> Result<()> { + pub(crate) async fn remove( + &self, + file_url: String, + range: Range, + target_workspace: &str, + ) -> Result<()> { let db = match self.db.as_ref() { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection(&self.collection_name).await?; + let col = db + .get_collection(&self.workspace_name_to_snippet_collections(&target_workspace)) + .await?; col.write().await.remove(Some( Collection::filter() .comparison( @@ -542,12 +564,14 @@ impl SnippetRetriever { file_url: String, start: usize, end: Option, + workspace_root: &str, ) 
-> Result<()> { let db = match self.db.as_ref() { Some(db) => db.clone(), None => return Err(Error::UninitialisedDatabase), }; - let col = db.get_collection("code-slices").await?; + let collection_name = self.workspace_name_to_snippet_collections(workspace_root); + let col = db.get_collection(&collection_name).await?; let file = tokio::fs::read_to_string(&file_url).await?; let lines = file.split('\n').collect::>(); let end = end.unwrap_or(lines.len()).min(lines.len()); diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index dd20995..a6dfdb5 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -267,6 +267,18 @@ impl TryInto for &Value { } } +impl TryInto for &Value { + type Error = Error; + + fn try_into(self) -> Result { + if let Value::Number(n) = self { + Ok(n.clone() as usize) + } else { + Err(Error::ValueNotNumber(self.to_owned())) + } + } +} + impl From for Value { fn from(value: usize) -> Self { Self::Number(value as f32)