From 17396a530f9178903e007f364148cd79242b7a43 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Mon, 10 Feb 2025 05:35:17 -0800 Subject: [PATCH] feat: add hint for missing fields (#14521) * feat: add hint for missing fields * set threshold to include 0.5 * fix failed test * add diagnose * fix clippy * fix bugs fix bugs * add test * fix test * fix clippy --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/column.rs | 19 ++---- datafusion/common/src/dfschema.rs | 2 +- datafusion/common/src/error.rs | 37 ++++++++++- datafusion/common/src/utils/mod.rs | 21 ++++++ datafusion/sql/src/planner.rs | 13 +++- datafusion/sql/tests/cases/diagnostic.rs | 65 ++++++++++++++++--- datafusion/sqllogictest/test_files/errors.slt | 17 ++++- .../sqllogictest/test_files/identifiers.slt | 8 +-- .../sqllogictest/test_files/join.slt.part | 2 +- 9 files changed, 153 insertions(+), 31 deletions(-) diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 05e2dff0bd43..2c3a84229d30 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -17,7 +17,7 @@ //! Column -use crate::error::_schema_err; +use crate::error::{_schema_err, add_possible_columns_to_diag}; use crate::utils::{parse_identifiers_normalized, quote_identifier}; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow_schema::{Field, FieldRef}; @@ -273,18 +273,11 @@ impl Column { // user which columns are candidates, or which table // they come from. For now, let's list the table names // only. 
- for qualified_field in qualified_fields { - let (Some(table), _) = qualified_field else { - continue; - }; - diagnostic.add_note( - format!( - "possible reference to '{}' in table '{}'", - &self.name, table - ), - None, - ); - } + add_possible_columns_to_diag( + &mut diagnostic, + &Column::new_unqualified(&self.name), + &columns, + ); err.with_diagnostic(diagnostic) }); } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 2ac629432ce9..f6ab5acd975f 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -1069,7 +1069,7 @@ mod tests { Column names are case sensitive. \ You can use double quotes to refer to the \"\"t1.c0\"\" column \ or set the datafusion.sql_parser.enable_ident_normalization configuration. \ - Valid fields are t1.c0, t1.c1."; + Did you mean 't1.c0'?."; assert_eq!(err.strip_backtrace(), expected); Ok(()) } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 013b1d5a2cab..237a04a26ffc 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -27,6 +27,7 @@ use std::io; use std::result; use std::sync::Arc; +use crate::utils::datafusion_strsim::normalized_levenshtein; use crate::utils::quote_identifier; use crate::{Column, DFSchema, Diagnostic, TableReference}; #[cfg(feature = "avro")] @@ -190,6 +191,11 @@ impl Display for SchemaError { .iter() .map(|column| column.flat_name().to_lowercase()) .collect::<Vec<String>>(); + + let valid_fields_names = valid_fields + .iter() + .map(|column| column.flat_name()) + .collect::<Vec<String>>(); if lower_valid_fields.contains(&field.flat_name().to_lowercase()) { write!( f, @@ -198,7 +204,15 @@ impl Display for SchemaError { field.quoted_flat_name() )?; } - if !valid_fields.is_empty() { + let field_name = field.name(); + if let Some(matched) = valid_fields_names + .iter() + .filter(|str| normalized_levenshtein(str, field_name) >= 0.5) + .collect::<Vec<&String>>() + .first() + { + write!(f, ". 
Did you mean '{matched}'?")?; + } else if !valid_fields.is_empty() { write!( f, ". Valid fields are {}", @@ -827,6 +841,27 @@ pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionE }) } +pub fn add_possible_columns_to_diag( + diagnostic: &mut Diagnostic, + field: &Column, + valid_fields: &[Column], +) { + let field_names: Vec<String> = valid_fields + .iter() + .filter_map(|f| { + if normalized_levenshtein(f.name(), field.name()) >= 0.5 { + Some(f.flat_name()) + } else { + None + } + }) + .collect(); + + for name in field_names { + diagnostic.add_note(format!("possible column {}", name), None); + } +} + #[cfg(test)] mod test { use std::sync::Arc; diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index f2377cc5410a..d66cf7562278 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -735,6 +735,27 @@ pub mod datafusion_strsim { pub fn levenshtein(a: &str, b: &str) -> usize { generic_levenshtein(&StringWrapper(a), &StringWrapper(b)) } + + /// Calculates the normalized Levenshtein similarity between two strings. + /// The normalized similarity is a value between 0.0 and 1.0, where 1.0 indicates + /// that the strings are identical and 0.0 indicates no similarity. 
+ /// + /// ``` + /// use datafusion_common::utils::datafusion_strsim::normalized_levenshtein; + /// + /// assert!((normalized_levenshtein("kitten", "sitting") - 0.57142).abs() < 0.00001); + /// + /// assert!(normalized_levenshtein("", "second").abs() < 0.00001); + /// + /// assert!((normalized_levenshtein("kitten", "sitten") - 0.833).abs() < 0.001); + /// ``` + pub fn normalized_levenshtein(a: &str, b: &str) -> f64 { + if a.is_empty() && b.is_empty() { + return 1.0; + } + 1.0 - (levenshtein(a, b) as f64) + / (a.chars().count().max(b.chars().count()) as f64) + } } /// Merges collections `first` and `second`, removes duplicates and sorts the diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 85d428cae84f..b376d0c5d82e 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; +use datafusion_common::error::add_possible_columns_to_diag; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic, SchemaError, @@ -368,10 +369,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } .map_err(|err: DataFusionError| match &err { DataFusionError::SchemaError( - SchemaError::FieldNotFound { .. 
}, + SchemaError::FieldNotFound { + field, + valid_fields, + }, _, ) => { - let diagnostic = if let Some(relation) = &col.relation { + let mut diagnostic = if let Some(relation) = &col.relation { Diagnostic::new_error( format!( "column '{}' not found in '{}'", @@ -385,6 +389,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { col.spans().first(), ) }; + add_possible_columns_to_diag( + &mut diagnostic, + field, + valid_fields, + ); err.with_diagnostic(diagnostic) } _ => err, diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index 55d3a953a728..9dae2d0c3e93 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -35,6 +35,7 @@ fn do_query(sql: &'static str) -> Diagnostic { collect_spans: true, ..ParserOptions::default() }; + let state = MockSessionState::default(); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new_with_options(&context, options); @@ -200,14 +201,8 @@ fn test_ambiguous_reference() -> Result<()> { let diag = do_query(query); assert_eq!(diag.message, "column 'first_name' is ambiguous"); assert_eq!(diag.span, Some(spans["a"])); - assert_eq!( - diag.notes[0].message, - "possible reference to 'first_name' in table 'a'" - ); - assert_eq!( - diag.notes[1].message, - "possible reference to 'first_name' in table 'b'" - ); + assert_eq!(diag.notes[0].message, "possible column a.first_name"); + assert_eq!(diag.notes[1].message, "possible column b.first_name"); Ok(()) } @@ -225,3 +220,57 @@ fn test_incompatible_types_binary_arithmetic() -> Result<()> { assert_eq!(diag.notes[1].span, Some(spans["right"])); Ok(()) } + +#[test] +fn test_field_not_found_suggestion() -> Result<()> { + let query = "SELECT /*whole*/first_na/*whole*/ FROM person"; + let spans = get_spans(query); + let diag = do_query(query); + assert_eq!(diag.message, "column 'first_na' not found"); + assert_eq!(diag.span, Some(spans["whole"])); + assert_eq!(diag.notes.len(), 1); + + 
let mut suggested_fields: Vec<String> = diag + .notes + .iter() + .filter_map(|note| { + if note.message.starts_with("possible column") { + Some(note.message.replace("possible column ", "")) + } else { + None + } + }) + .collect(); + suggested_fields.sort(); + assert_eq!(suggested_fields[0], "person.first_name"); + Ok(()) +} + +#[test] +fn test_ambiguous_column_suggestion() -> Result<()> { + let query = "SELECT /*whole*/id/*whole*/ FROM test_decimal, person"; + let spans = get_spans(query); + let diag = do_query(query); + + assert_eq!(diag.message, "column 'id' is ambiguous"); + assert_eq!(diag.span, Some(spans["whole"])); + + assert_eq!(diag.notes.len(), 2); + + let mut suggested_fields: Vec<String> = diag + .notes + .iter() + .filter_map(|note| { + if note.message.starts_with("possible column") { + Some(note.message.replace("possible column ", "")) + } else { + None + } + }) + .collect(); + + suggested_fields.sort(); + assert_eq!(suggested_fields, vec!["person.id", "test_decimal.id"]); + + Ok(()) +} diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index 5a94ba9c0583..a35a4d6f28dc 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -169,4 +169,19 @@ statement ok INSERT INTO tab0 VALUES(83,0,38); query error DataFusion error: Arrow error: Divide by zero error -SELECT DISTINCT - 84 FROM tab0 AS cor0 WHERE NOT + 96 / + col1 <= NULL GROUP BY col1, col0; \ No newline at end of file +SELECT DISTINCT - 84 FROM tab0 AS cor0 WHERE NOT + 96 / + col1 <= NULL GROUP BY col1, col0; + +statement ok +create table a(timestamp int, birthday int, ts int, tokens int, amp int, staamp int); + +query error DataFusion error: Schema error: No field named timetamp\. Did you mean 'a\.timestamp'\?\. +select timetamp from a; + +query error DataFusion error: Schema error: No field named dadsada\. Valid fields are a\.timestamp, a\.birthday, a\.ts, a\.tokens, a\.amp, a\.staamp\. 
+select dadsada from a; + +query error DataFusion error: Schema error: No field named ammp\. Did you mean 'a\.amp'\?\. +select ammp from a; + +statement ok +drop table a; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/identifiers.slt b/datafusion/sqllogictest/test_files/identifiers.slt index 755d617e7a2a..e5eec3bf7f2c 100644 --- a/datafusion/sqllogictest/test_files/identifiers.slt +++ b/datafusion/sqllogictest/test_files/identifiers.slt @@ -90,16 +90,16 @@ drop table case_insensitive_test statement ok CREATE TABLE test("Column1" string) AS VALUES ('content1'); -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. SELECT COLumn1 from test -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. SELECT Column1 from test -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. SELECT column1 from test -statement error DataFusion error: Schema error: No field named column1\. Valid fields are test\."Column1"\. +statement error DataFusion error: Schema error: No field named column1\. Did you mean 'test\.Column1'\?\. 
SELECT "column1" from test statement ok diff --git a/datafusion/sqllogictest/test_files/join.slt.part b/datafusion/sqllogictest/test_files/join.slt.part index c88f419a9cb2..c7cfb6b9b026 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -94,7 +94,7 @@ statement ok set datafusion.execution.batch_size = 4096; # left semi with wrong where clause -query error DataFusion error: Schema error: No field named t2\.t2_id\. Valid fields are t1\.t1_id, t1\.t1_name, t1\.t1_int\. +query error DataFusion error: Schema error: No field named t2\.t2_id\. Did you mean 't1\.t1_id'\?\. SELECT t1.t1_id, t1.t1_name, t1.t1_int FROM t1 LEFT SEMI JOIN t2 ON t1.t1_id = t2.t2_id