1
0
mirror of https://github.com/Radarr/Radarr.git synced 2024-09-17 15:02:34 +02:00

Fixed: WhereBuilder exception when string variable null

This commit is contained in:
ta264 2020-01-22 12:31:31 +00:00
parent df101258c5
commit 01a03e9baf
3 changed files with 119 additions and 60 deletions

View File

@ -66,6 +66,23 @@ public void where_allows_abstract_condition_if_not_requiresConcreteCondition()
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")"); _subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")");
} }
[Test]
public void where_string_is_null()
{
_subject = Where(x => x.ImdbId == null);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)");
}
[Test]
public void where_string_is_null_value()
{
string imdb = null;
_subject = Where(x => x.ImdbId == imdb);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)");
}
[Test] [Test]
public void where_column_contains_string() public void where_column_contains_string()
{ {

View File

@ -11,15 +11,13 @@ namespace NzbDrone.Core.Datastore
{ {
public class WhereBuilder : ExpressionVisitor public class WhereBuilder : ExpressionVisitor
{ {
private const DbType EnumerableMultiParameter = (DbType)(-1);
private readonly string _paramNamePrefix;
private int _paramCount = 0;
private bool _requireConcreteValue = false;
private bool _gotConcreteValue = false;
protected StringBuilder _sb; protected StringBuilder _sb;
public DynamicParameters Parameters { get; private set; } private const DbType EnumerableMultiParameter = (DbType)(-1);
private readonly string _paramNamePrefix;
private readonly bool _requireConcreteValue = false;
private int _paramCount = 0;
private bool _gotConcreteValue = false;
public WhereBuilder(Expression filter, bool requireConcreteValue) public WhereBuilder(Expression filter, bool requireConcreteValue)
{ {
@ -35,6 +33,8 @@ public WhereBuilder(Expression filter, bool requireConcreteValue)
} }
} }
public DynamicParameters Parameters { get; private set; }
private string AddParameter(object value, DbType? dbType = null) private string AddParameter(object value, DbType? dbType = null)
{ {
_gotConcreteValue = true; _gotConcreteValue = true;
@ -61,7 +61,7 @@ protected override Expression VisitBinary(BinaryExpression expression)
protected override Expression VisitMethodCall(MethodCallExpression expression) protected override Expression VisitMethodCall(MethodCallExpression expression)
{ {
string method = expression.Method.Name; var method = expression.Method.Name;
switch (expression.Method.Name) switch (expression.Method.Name)
{ {
@ -78,7 +78,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression)
break; break;
default: default:
string msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method); var msg = string.Format("'{0}' expressions are not yet implemented in the where clause expression tree parser.", method);
throw new NotImplementedException(msg); throw new NotImplementedException(msg);
} }
@ -87,7 +87,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression)
protected override Expression VisitMemberAccess(MemberExpression expression) protected override Expression VisitMemberAccess(MemberExpression expression)
{ {
string tableName = TableMapping.Mapper.TableNameMapping(expression.Expression.Type); var tableName = expression != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null;
if (tableName != null) if (tableName != null)
{ {
@ -95,27 +95,26 @@ protected override Expression VisitMemberAccess(MemberExpression expression)
} }
else else
{ {
object value = GetRightValue(expression); var value = GetRightValue(expression);
// string is IEnumerable<Char> but we don't want to pick up that case if (value != null)
var type = value.GetType();
var typeInfo = type.GetTypeInfo();
bool isEnumerable =
type != typeof(string) && (
typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) ||
(typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>)));
string paramName;
if (isEnumerable)
{ {
paramName = AddParameter(value, EnumerableMultiParameter); // string is IEnumerable<Char> but we don't want to pick up that case
var type = value.GetType();
var typeInfo = type.GetTypeInfo();
var isEnumerable =
type != typeof(string) && (
typeInfo.ImplementedInterfaces.Any(ti => ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(IEnumerable<>)) ||
(typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>)));
var paramName = isEnumerable ? AddParameter(value, EnumerableMultiParameter) : AddParameter(value);
_sb.Append(paramName);
} }
else else
{ {
paramName = AddParameter(value); _gotConcreteValue = true;
_sb.Append("NULL");
} }
_sb.Append(paramName);
} }
return expression; return expression;
@ -125,51 +124,78 @@ protected override Expression VisitConstant(ConstantExpression expression)
{ {
if (expression.Value != null) if (expression.Value != null)
{ {
string paramName = AddParameter(expression.Value); var paramName = AddParameter(expression.Value);
_sb.Append(paramName); _sb.Append(paramName);
} }
else else
{ {
_gotConcreteValue = true;
_sb.Append("NULL"); _sb.Append("NULL");
} }
return expression; return expression;
} }
private object GetRightValue(Expression rightExpression) private bool TryGetConstantValue(Expression expression, out object result)
{ {
object rightValue = null; if (expression is ConstantExpression constExp)
var right = rightExpression as ConstantExpression;
// Value is not directly passed in as a constant
if (right == null)
{ {
var rightMemberExp = rightExpression as MemberExpression; result = constExp.Value;
var parentMemberExpression = rightMemberExp.Expression as MemberExpression; return true;
}
result = null;
return false;
}
private bool TryGetPropertyValue(MemberExpression expression, out object result)
{
if (expression.Expression is MemberExpression nested)
{
// Value is passed in as a property on a parent entity // Value is passed in as a property on a parent entity
if (parentMemberExpression != null) var container = (nested.Expression as ConstantExpression).Value;
{ var entity = GetFieldValue(container, nested.Member);
var memberInfo = (rightMemberExp.Expression as MemberExpression).Member; result = GetFieldValue(entity, expression.Member);
var container = ((rightMemberExp.Expression as MemberExpression).Expression as ConstantExpression).Value; return true;
var entity = GetFieldValue(container, memberInfo);
rightValue = GetFieldValue(entity, rightMemberExp.Member);
}
else
{
// Value is passed in as a variable
var parent = (rightMemberExp.Expression as ConstantExpression).Value;
rightValue = GetFieldValue(parent, rightMemberExp.Member);
}
}
else
{
// Value is passed in directly as a constant
rightValue = right.Value;
} }
return rightValue; result = null;
return false;
}
private bool TryGetVariableValue(MemberExpression expression, out object result)
{
// Value is passed in as a variable
if (expression.Expression is ConstantExpression nested)
{
result = GetFieldValue(nested.Value, expression.Member);
return true;
}
result = null;
return false;
}
private object GetRightValue(Expression expression)
{
if (TryGetConstantValue(expression, out var constValue))
{
return constValue;
}
var memberExp = expression as MemberExpression;
if (TryGetPropertyValue(memberExp, out var propValue))
{
return propValue;
}
if (TryGetVariableValue(memberExp, out var variableValue))
{
return variableValue;
}
return null;
} }
private object GetFieldValue(object entity, MemberInfo member) private object GetFieldValue(object entity, MemberInfo member)
@ -187,13 +213,29 @@ private object GetFieldValue(object entity, MemberInfo member)
throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name)); throw new ArgumentException(string.Format("WhereBuilder could not get the value for {0}.{1}.", entity.GetType().Name, member.Name));
} }
private bool IsNullVariable(Expression expression)
{
if (expression.NodeType == ExpressionType.Constant &&
TryGetConstantValue(expression, out var constResult) &&
constResult == null)
{
return true;
}
if (expression.NodeType == ExpressionType.MemberAccess &&
expression is MemberExpression member &&
TryGetVariableValue(member, out var variableResult) &&
variableResult == null)
{
return true;
}
return false;
}
private string Decode(BinaryExpression expression) private string Decode(BinaryExpression expression)
{ {
bool isRightSideNullConstant = expression.Right.NodeType == if (IsNullVariable(expression.Right))
ExpressionType.Constant &&
((ConstantExpression)expression.Right).Value == null;
if (isRightSideNullConstant)
{ {
switch (expression.NodeType) switch (expression.NodeType)
{ {

View File

@ -120,7 +120,7 @@ public List<Movie> FindByTitleInexact(string cleanTitle)
public Movie FindByImdbId(string imdbid) public Movie FindByImdbId(string imdbid)
{ {
var imdbIdWithPrefix = Parser.Parser.NormalizeImdbId(imdbid); var imdbIdWithPrefix = Parser.Parser.NormalizeImdbId(imdbid);
return Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault(); return imdbIdWithPrefix == null ? null : Query(x => x.ImdbId == imdbIdWithPrefix).FirstOrDefault();
} }
public Movie FindByTmdbId(int tmdbid) public Movie FindByTmdbId(int tmdbid)