From 01a03e9baffbfbfbf8611f0c61d02ea85613521a Mon Sep 17 00:00:00 2001 From: ta264 Date: Wed, 22 Jan 2020 12:31:31 +0000 Subject: [PATCH] Fixed: WhereBuilder exception when string variable null --- .../Datastore/WhereBuilderFixture.cs | 17 ++ src/NzbDrone.Core/Datastore/WhereBuilder.cs | 160 +++++++++++------- src/NzbDrone.Core/Movies/MovieRepository.cs | 2 +- 3 files changed, 119 insertions(+), 60 deletions(-) diff --git a/src/NzbDrone.Core.Test/Datastore/WhereBuilderFixture.cs b/src/NzbDrone.Core.Test/Datastore/WhereBuilderFixture.cs index d90111c56..f6882efb8 100644 --- a/src/NzbDrone.Core.Test/Datastore/WhereBuilderFixture.cs +++ b/src/NzbDrone.Core.Test/Datastore/WhereBuilderFixture.cs @@ -66,6 +66,23 @@ public void where_allows_abstract_condition_if_not_requiresConcreteCondition() _subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")"); } + [Test] + public void where_string_is_null() + { + _subject = Where(x => x.ImdbId == null); + + _subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)"); + } + + [Test] + public void where_string_is_null_value() + { + string imdb = null; + _subject = Where(x => x.ImdbId == imdb); + + _subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)"); + } + [Test] public void where_column_contains_string() { diff --git a/src/NzbDrone.Core/Datastore/WhereBuilder.cs b/src/NzbDrone.Core/Datastore/WhereBuilder.cs index f4c2fdd21..24b3b17e3 100644 --- a/src/NzbDrone.Core/Datastore/WhereBuilder.cs +++ b/src/NzbDrone.Core/Datastore/WhereBuilder.cs @@ -11,15 +11,13 @@ namespace NzbDrone.Core.Datastore { public class WhereBuilder : ExpressionVisitor { - private const DbType EnumerableMultiParameter = (DbType)(-1); - - private readonly string _paramNamePrefix; - private int _paramCount = 0; - private bool _requireConcreteValue = false; - private bool _gotConcreteValue = false; protected StringBuilder _sb; - public DynamicParameters Parameters { get; private set; } + private const DbType EnumerableMultiParameter = (DbType)(-1); + private readonly string _paramNamePrefix; + private readonly bool _requireConcreteValue = false; + private int _paramCount = 0; + private bool _gotConcreteValue = false; public WhereBuilder(Expression filter, bool requireConcreteValue) { @@ -35,6 +33,8 @@ public WhereBuilder(Expression filter, bool requireConcreteValue) } } + public DynamicParameters Parameters { get; private set; } + private string AddParameter(object value, DbType? dbType = null) { _gotConcreteValue = true; @@ -61,7 +61,7 @@ protected override Expression VisitBinary(BinaryExpression expression) protected override Expression VisitMethodCall(MethodCallExpression expression) { - string method = expression.Method.Name; + var method = expression.Method.Name; switch (expression.Method.Name) { @@ -78,7 +78,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) break; default: - string msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method); + var msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method); throw new NotImplementedException(msg); } @@ -87,7 +87,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) protected override Expression VisitMemberAccess(MemberExpression expression) { - string tableName = TableMapping.Mapper.TableNameMapping(expression.Expression.Type); + var tableName = expression != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null; if (tableName != null) { @@ -95,27 +95,26 @@ protected override Expression VisitMemberAccess(MemberExpression expression) } else { - object value = GetRightValue(expression); + var value = GetRightValue(expression); - // string is IEnumerable but we don't want to pick up that case - var type = value.GetType(); - var typeInfo = type.GetTypeInfo(); - bool isEnumerable = - type != typeof(string) && ( - typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) || - (typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>))); - - string paramName; - if (isEnumerable) + if (value != null) { - paramName = AddParameter(value, EnumerableMultiParameter); + // string is IEnumerable but we don't want to pick up that case + var type = value.GetType(); + var typeInfo = type.GetTypeInfo(); + var isEnumerable = + type != typeof(string) && ( + typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) || + (typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>))); + + var paramName = isEnumerable ? AddParameter(value, EnumerableMultiParameter) : AddParameter(value); + _sb.Append(paramName); } else { - paramName = AddParameter(value); + _gotConcreteValue = true; + _sb.Append("NULL"); } - - _sb.Append(paramName); } return expression; @@ -125,51 +124,78 @@ protected override Expression VisitConstant(ConstantExpression expression) { if (expression.Value != null) { - string paramName = AddParameter(expression.Value); + var paramName = AddParameter(expression.Value); _sb.Append(paramName); } else { + _gotConcreteValue = true; _sb.Append("NULL"); } return expression; } - private object GetRightValue(Expression rightExpression) + private bool TryGetConstantValue(Expression expression, out object result) { - object rightValue = null; - - var right = rightExpression as ConstantExpression; - - // Value is not directly passed in as a constant - if (right == null) + if (expression is ConstantExpression constExp) { - var rightMemberExp = rightExpression as MemberExpression; - var parentMemberExpression = rightMemberExp.Expression as MemberExpression; + result = constExp.Value; + return true; + } + result = null; + return false; + } + + private bool TryGetPropertyValue(MemberExpression expression, out object result) + { + if (expression.Expression is MemberExpression nested) + { // Value is passed in as a property on a parent entity - if (parentMemberExpression != null) - { - var memberInfo = (rightMemberExp.Expression as MemberExpression).Member; - var container = ((rightMemberExp.Expression as MemberExpression).Expression as ConstantExpression).Value; - var entity = GetFieldValue(container, memberInfo); - rightValue = GetFieldValue(entity, rightMemberExp.Member); - } - else - { - // Value is passed in as a variable - var parent = (rightMemberExp.Expression as ConstantExpression).Value; - rightValue = GetFieldValue(parent, rightMemberExp.Member); - } - } - else - { - // Value is passed in directly as a constant - rightValue = right.Value; + var container = (nested.Expression as ConstantExpression).Value; + var entity = GetFieldValue(container, nested.Member); + result = GetFieldValue(entity, expression.Member); + return true; } - return rightValue; + result = null; + return false; + } + + private bool TryGetVariableValue(MemberExpression expression, out object result) + { + // Value is passed in as a variable + if (expression.Expression is ConstantExpression nested) + { + result = GetFieldValue(nested.Value, expression.Member); + return true; + } + + result = null; + return false; + } + + private object GetRightValue(Expression expression) + { + if (TryGetConstantValue(expression, out var constValue)) + { + return constValue; + } + + var memberExp = expression as MemberExpression; + + if (TryGetPropertyValue(memberExp, out var propValue)) + { + return propValue; + } + + if (TryGetVariableValue(memberExp, out var variableValue)) + { + return variableValue; + } + + return null; } private object GetFieldValue(object entity, MemberInfo member) @@ -187,13 +213,29 @@ private object GetFieldValue(object entity, MemberInfo member) throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name)); } + private bool IsNullVariable(Expression expression) + { + if (expression.NodeType == ExpressionType.Constant && + TryGetConstantValue(expression, out var constResult) && + constResult == null) + { + return true; + } + + if (expression.NodeType == ExpressionType.MemberAccess && + expression is MemberExpression member && + TryGetVariableValue(member, out var variableResult) && + variableResult == null) + { + return true; + } + + return false; + } + private string Decode(BinaryExpression expression) { - bool isRightSideNullConstant = expression.Right.NodeType == - ExpressionType.Constant && - ((ConstantExpression)expression.Right).Value == null; - - if (isRightSideNullConstant) + if (IsNullVariable(expression.Right)) { switch (expression.NodeType) { diff --git a/src/NzbDrone.Core/Movies/MovieRepository.cs b/src/NzbDrone.Core/Movies/MovieRepository.cs index 59c8fa307..dedcbf4dd 100644 --- a/src/NzbDrone.Core/Movies/MovieRepository.cs +++ b/src/NzbDrone.Core/Movies/MovieRepository.cs @@ -120,7 +120,7 @@ public List FindByTitleInexact(string cleanTitle) public Movie FindByImdbId(string imdbid) { var imdbIdWithPrefix = Parser.Parser.NormalizeImdbId(imdbid); - return Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault(); + return imdbIdWithPrefix == null ? null : Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault(); } public Movie FindByTmdbId(int tmdbid)