diff --git a/src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs b/src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs index 8482d63a307..c6300a97b1e 100644 --- a/src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs @@ -24,7 +24,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { var visitedExpression = base.VisitBinary(binaryExpression); - return TryOptimizeConditionalEquality(visitedExpression) ?? visitedExpression; + return TryOptimizeConditionalEquality(visitedExpression) + ?? TryOptimizeQueryableNullCheck(visitedExpression) + ?? visitedExpression; } /// @@ -115,6 +117,34 @@ protected override Expression VisitConditional(ConditionalExpression conditional return null; } + private static Expression? TryOptimizeQueryableNullCheck(Expression expression) + { + // Optimize IQueryable/DbSet null checks: + // * IQueryable != null => true + // * IQueryable == null => false + if (expression is BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + } binaryExpression) + { + var isLeftNull = IsNullConstant(binaryExpression.Left); + var isRightNull = IsNullConstant(binaryExpression.Right); + + if (isLeftNull != isRightNull) + { + var nonNullExpression = isLeftNull ? binaryExpression.Right : binaryExpression.Left; + + if (nonNullExpression.Type.IsAssignableTo(typeof(IQueryable))) + { + var result = binaryExpression.NodeType == ExpressionType.NotEqual; + return Expression.Constant(result); + } + } + } + + return null; + } + private sealed class NullSafeAccessVerifyingExpressionVisitor : ExpressionVisitor { private readonly ISet _nullSafeAccesses = new HashSet(ExpressionEqualityComparer.Instance); diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs index f91ad1dee97..2a93b5d8fba 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs @@ -1699,6 +1699,22 @@ public override async Task Where_Queryable_AsEnumerable_Contains_negated(bool as AssertSql(); } + public override async Task Where_Queryable_not_null_check_with_Contains(bool async) + { + // Cosmos client evaluation. Issue #17246. + await AssertTranslationFailed(() => base.Where_Queryable_not_null_check_with_Contains(async)); + + AssertSql(); + } + + public override async Task Where_Queryable_null_check_with_Contains(bool async) + { + // Cosmos client evaluation. Issue #17246. + await AssertTranslationFailed(() => base.Where_Queryable_null_check_with_Contains(async)); + + AssertSql(); + } + public override Task Where_list_object_contains_over_value_type(bool async) => Fixture.NoSyncTest( async, async a => diff --git a/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs index e4beae2c27e..6abf5422980 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs @@ -1438,6 +1438,27 @@ public virtual Task Where_Queryable_ToArray_Length_member(bool async) assertOrder: true, elementAsserter: (e, a) => AssertCollection(e, a)); + [ConditionalTheory, MemberData(nameof(IsAsyncData))] + public virtual Task Where_Queryable_not_null_check_with_Contains(bool async) + => AssertQuery( + async, + ss => + { + var ids = ss.Set().Select(c => c.CustomerID); + return ss.Set().Where(c => ids != null && ids.Contains(c.CustomerID)); + }); + + [ConditionalTheory, MemberData(nameof(IsAsyncData))] + public virtual Task Where_Queryable_null_check_with_Contains(bool async) + => AssertQuery( + async, + ss => + { + var ids = ss.Set().Select(c => c.CustomerID); + return ss.Set().Where(c => ids == null || !ids.Contains(c.CustomerID)); + }, + assertEmpty: true); + [ConditionalTheory, MemberData(nameof(IsAsyncData))] public virtual Task Where_collection_navigation_ToList_Count(bool async) => AssertQuery( diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs index 63a234e7e7c..2cc60852b77 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs @@ -1890,6 +1890,36 @@ ORDER BY [c].[CustomerID] """); } + public override async Task Where_Queryable_not_null_check_with_Contains(bool async) + { + await base.Where_Queryable_not_null_check_with_Contains(async); + + AssertSql( + """ +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] IN ( + SELECT [c0].[CustomerID] + FROM [Customers] AS [c0] +) +"""); + } + + public override async Task Where_Queryable_null_check_with_Contains(bool async) + { + await base.Where_Queryable_null_check_with_Contains(async); + + AssertSql( + """ +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] NOT IN ( + SELECT [c0].[CustomerID] + FROM [Customers] AS [c0] +) +"""); + } + public override async Task Where_collection_navigation_ToList_Count(bool async) { await base.Where_collection_navigation_ToList_Count(async);