1
0
mirror of https://github.com/Radarr/Radarr.git synced 2024-09-11 03:52:33 +02:00

Fixed: WhereBuilder exception when string variable null

This commit is contained in:
ta264 2020-01-22 12:31:31 +00:00
parent df101258c5
commit 01a03e9baf
3 changed files with 119 additions and 60 deletions

View File

@ -66,6 +66,23 @@ public void where_allows_abstract_condition_if_not_requiresConcreteCondition()
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")");
}
[Test]
public void where_string_is_null()
{
_subject = Where(x => x.ImdbId == null);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)");
}
[Test]
public void where_string_is_null_value()
{
string imdb = null;
_subject = Where(x => x.ImdbId == imdb);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)");
}
[Test]
public void where_column_contains_string()
{

View File

@ -11,15 +11,13 @@ namespace NzbDrone.Core.Datastore
{
public class WhereBuilder : ExpressionVisitor
{
private const DbType EnumerableMultiParameter = (DbType)(-1);
private readonly string _paramNamePrefix;
private int _paramCount = 0;
private bool _requireConcreteValue = false;
private bool _gotConcreteValue = false;
protected StringBuilder _sb;
public DynamicParameters Parameters { get; private set; }
private const DbType EnumerableMultiParameter = (DbType)(-1);
private readonly string _paramNamePrefix;
private readonly bool _requireConcreteValue = false;
private int _paramCount = 0;
private bool _gotConcreteValue = false;
public WhereBuilder(Expression filter, bool requireConcreteValue)
{
@ -35,6 +33,8 @@ public WhereBuilder(Expression filter, bool requireConcreteValue)
}
}
public DynamicParameters Parameters { get; private set; }
private string AddParameter(object value, DbType? dbType = null)
{
_gotConcreteValue = true;
@ -61,7 +61,7 @@ protected override Expression VisitBinary(BinaryExpression expression)
protected override Expression VisitMethodCall(MethodCallExpression expression)
{
string method = expression.Method.Name;
var method = expression.Method.Name;
switch (expression.Method.Name)
{
@ -78,7 +78,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression)
break;
default:
string msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method);
var msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method);
throw new NotImplementedException(msg);
}
@ -87,7 +87,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression)
protected override Expression VisitMemberAccess(MemberExpression expression)
{
string tableName = TableMapping.Mapper.TableNameMapping(expression.Expression.Type);
var tableName = expression != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null;
if (tableName != null)
{
@ -95,27 +95,26 @@ protected override Expression VisitMemberAccess(MemberExpression expression)
}
else
{
object value = GetRightValue(expression);
var value = GetRightValue(expression);
// string is IEnumerable<Char> but we don't want to pick up that case
var type = value.GetType();
var typeInfo = type.GetTypeInfo();
bool isEnumerable =
type != typeof(string) && (
typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) ||
(typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>)));
string paramName;
if (isEnumerable)
if (value != null)
{
paramName = AddParameter(value, EnumerableMultiParameter);
// string is IEnumerable<Char> but we don't want to pick up that case
var type = value.GetType();
var typeInfo = type.GetTypeInfo();
var isEnumerable =
type != typeof(string) && (
typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) ||
(typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>)));
var paramName = isEnumerable ? AddParameter(value, EnumerableMultiParameter) : AddParameter(value);
_sb.Append(paramName);
}
else
{
paramName = AddParameter(value);
_gotConcreteValue = true;
_sb.Append("NULL");
}
_sb.Append(paramName);
}
return expression;
@ -125,51 +124,78 @@ protected override Expression VisitConstant(ConstantExpression expression)
{
if (expression.Value != null)
{
string paramName = AddParameter(expression.Value);
var paramName = AddParameter(expression.Value);
_sb.Append(paramName);
}
else
{
_gotConcreteValue = true;
_sb.Append("NULL");
}
return expression;
}
private object GetRightValue(Expression rightExpression)
private bool TryGetConstantValue(Expression expression, out object result)
{
object rightValue = null;
var right = rightExpression as ConstantExpression;
// Value is not directly passed in as a constant
if (right == null)
if (expression is ConstantExpression constExp)
{
var rightMemberExp = rightExpression as MemberExpression;
var parentMemberExpression = rightMemberExp.Expression as MemberExpression;
result = constExp.Value;
return true;
}
result = null;
return false;
}
private bool TryGetPropertyValue(MemberExpression expression, out object result)
{
if (expression.Expression is MemberExpression nested)
{
// Value is passed in as a property on a parent entity
if (parentMemberExpression != null)
{
var memberInfo = (rightMemberExp.Expression as MemberExpression).Member;
var container = ((rightMemberExp.Expression as MemberExpression).Expression as ConstantExpression).Value;
var entity = GetFieldValue(container, memberInfo);
rightValue = GetFieldValue(entity, rightMemberExp.Member);
}
else
{
// Value is passed in as a variable
var parent = (rightMemberExp.Expression as ConstantExpression).Value;
rightValue = GetFieldValue(parent, rightMemberExp.Member);
}
}
else
{
// Value is passed in directly as a constant
rightValue = right.Value;
var container = (nested.Expression as ConstantExpression).Value;
var entity = GetFieldValue(container, nested.Member);
result = GetFieldValue(entity, expression.Member);
return true;
}
return rightValue;
result = null;
return false;
}
private bool TryGetVariableValue(MemberExpression expression, out object result)
{
// Value is passed in as a variable
if (expression.Expression is ConstantExpression nested)
{
result = GetFieldValue(nested.Value, expression.Member);
return true;
}
result = null;
return false;
}
private object GetRightValue(Expression expression)
{
if (TryGetConstantValue(expression, out var constValue))
{
return constValue;
}
var memberExp = expression as MemberExpression;
if (TryGetPropertyValue(memberExp, out var propValue))
{
return propValue;
}
if (TryGetVariableValue(memberExp, out var variableValue))
{
return variableValue;
}
return null;
}
private object GetFieldValue(object entity, MemberInfo member)
@ -187,13 +213,29 @@ private object GetFieldValue(object entity, MemberInfo member)
throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name));
}
private bool IsNullVariable(Expression expression)
{
if (expression.NodeType == ExpressionType.Constant &&
TryGetConstantValue(expression, out var constResult) &&
constResult == null)
{
return true;
}
if (expression.NodeType == ExpressionType.MemberAccess &&
expression is MemberExpression member &&
TryGetVariableValue(member, out var variableResult) &&
variableResult == null)
{
return true;
}
return false;
}
private string Decode(BinaryExpression expression)
{
bool isRightSideNullConstant = expression.Right.NodeType ==
ExpressionType.Constant &&
((ConstantExpression)expression.Right).Value == null;
if (isRightSideNullConstant)
if (IsNullVariable(expression.Right))
{
switch (expression.NodeType)
{

View File

@ -120,7 +120,7 @@ public List<Movie> FindByTitleInexact(string cleanTitle)
public Movie FindByImdbId(string imdbid)
{
var imdbIdWithPrefix = Parser.Parser.NormalizeImdbId(imdbid);
return Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault();
return imdbIdWithPrefix == null ? null : Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault();
}
public Movie FindByTmdbId(int tmdbid)