1
0
mirror of https://github.com/Radarr/Radarr.git synced 2024-09-17 15:02:34 +02:00

Revert "Revert "Fixed: Rename more than 999 movies in one go""

This reverts commit c0b80696bc.
This commit is contained in:
ta264 2020-05-14 20:56:26 +01:00 committed by Qstick
parent ebc72bfba9
commit 32a6c9fe2a
16 changed files with 840 additions and 263 deletions

View File

@ -45,7 +45,7 @@ public void test_mappable_types()
{ {
var properties = typeof(TypeWithAllMappableProperties).GetProperties(); var properties = typeof(TypeWithAllMappableProperties).GetProperties();
properties.Should().NotBeEmpty(); properties.Should().NotBeEmpty();
properties.Should().OnlyContain(c => ColumnMapper<int>.IsMappableProperty(c)); properties.Should().OnlyContain(c => c.IsMappableProperty());
} }
[Test] [Test]
@ -53,7 +53,7 @@ public void test_un_mappable_types()
{ {
var properties = typeof(TypeWithNoMappableProperties).GetProperties(); var properties = typeof(TypeWithNoMappableProperties).GetProperties();
properties.Should().NotBeEmpty(); properties.Should().NotBeEmpty();
properties.Should().NotContain(c => ColumnMapper<int>.IsMappableProperty(c)); properties.Should().NotContain(c => c.IsMappableProperty());
} }
} }
} }

View File

@ -24,7 +24,7 @@ public void MapTables()
private WhereBuilder Where(Expression<Func<Movie, bool>> filter) private WhereBuilder Where(Expression<Func<Movie, bool>> filter)
{ {
return new WhereBuilder(filter, true); return new WhereBuilder(filter, true, 0);
} }
[Test] [Test]
@ -32,9 +32,8 @@ public void where_equal_const()
{ {
_subject = Where(x => x.Id == 10); _subject = Where(x => x.Id == 10);
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"Id\" = @Clause1_P1)");
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" = @{name})"); _subject.Parameters.Get<int>("Clause1_P1").Should().Be(10);
_subject.Parameters.Get<int>(name).Should().Be(10);
} }
[Test] [Test]
@ -43,44 +42,71 @@ public void where_equal_variable()
var id = 10; var id = 10;
_subject = Where(x => x.Id == id); _subject = Where(x => x.Id == id);
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"Id\" = @Clause1_P1)");
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" = @{name})"); _subject.Parameters.Get<int>("Clause1_P1").Should().Be(id);
_subject.Parameters.Get<int>(name).Should().Be(id); }
[Test]
public void where_equal_property()
{
var movie = new Movie { Id = 10 };
_subject = Where(x => x.Id == movie.Id);
_subject.Parameters.ParameterNames.Should().HaveCount(1);
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" = @Clause1_P1)");
_subject.Parameters.Get<int>("Clause1_P1").Should().Be(movie.Id);
}
[Test]
public void where_equal_joined_property()
{
_subject = Where(x => x.Profile.Id == 1);
_subject.Parameters.ParameterNames.Should().HaveCount(1);
_subject.ToString().Should().Be($"(\"Profiles\".\"Id\" = @Clause1_P1)");
_subject.Parameters.Get<int>("Clause1_P1").Should().Be(1);
} }
[Test] [Test]
public void where_throws_without_concrete_condition_if_requiresConcreteCondition() public void where_throws_without_concrete_condition_if_requiresConcreteCondition()
{ {
var movie = new Movie(); Expression<Func<Movie, Movie, bool>> filter = (x, y) => x.Id == y.Id;
Expression<Func<Movie, bool>> filter = (x) => x.Id == movie.Id; _subject = new WhereBuilder(filter, true, 0);
_subject = new WhereBuilder(filter, true);
Assert.Throws<InvalidOperationException>(() => _subject.ToString()); Assert.Throws<InvalidOperationException>(() => _subject.ToString());
} }
[Test] [Test]
public void where_allows_abstract_condition_if_not_requiresConcreteCondition() public void where_allows_abstract_condition_if_not_requiresConcreteCondition()
{ {
var movie = new Movie(); Expression<Func<Movie, Movie, bool>> filter = (x, y) => x.Id == y.Id;
Expression<Func<Movie, bool>> filter = (x) => x.Id == movie.Id; _subject = new WhereBuilder(filter, false, 0);
_subject = new WhereBuilder(filter, false);
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")"); _subject.ToString().Should().Be($"(\"Movies\".\"Id\" = \"Movies\".\"Id\")");
} }
[Test] [Test]
public void where_string_is_null() public void where_string_is_null()
{ {
_subject = Where(x => x.ImdbId == null); _subject = Where(x => x.CleanTitle == null);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)"); _subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" IS NULL)");
} }
[Test] [Test]
public void where_string_is_null_value() public void where_string_is_null_value()
{ {
string imdb = null; string cleanTitle = null;
_subject = Where(x => x.ImdbId == imdb); _subject = Where(x => x.CleanTitle == cleanTitle);
_subject.ToString().Should().Be($"(\"Movies\".\"ImdbId\" IS NULL)"); _subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" IS NULL)");
}
[Test]
public void where_equal_null_property()
{
var movie = new Movie { CleanTitle = null };
_subject = Where(x => x.CleanTitle == movie.CleanTitle);
_subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" IS NULL)");
} }
[Test] [Test]
@ -89,9 +115,8 @@ public void where_column_contains_string()
var test = "small"; var test = "small";
_subject = Where(x => x.CleanTitle.Contains(test)); _subject = Where(x => x.CleanTitle.Contains(test));
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" LIKE '%' || @Clause1_P1 || '%')");
_subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" LIKE '%' || @{name} || '%')"); _subject.Parameters.Get<string>("Clause1_P1").Should().Be(test);
_subject.Parameters.Get<string>(name).Should().Be(test);
} }
[Test] [Test]
@ -100,9 +125,8 @@ public void where_string_contains_column()
var test = "small"; var test = "small";
_subject = Where(x => test.Contains(x.CleanTitle)); _subject = Where(x => test.Contains(x.CleanTitle));
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(@Clause1_P1 LIKE '%' || \"Movies\".\"CleanTitle\" || '%')");
_subject.ToString().Should().Be($"(@{name} LIKE '%' || \"Movies\".\"CleanTitle\" || '%')"); _subject.Parameters.Get<string>("Clause1_P1").Should().Be(test);
_subject.Parameters.Get<string>(name).Should().Be(test);
} }
[Test] [Test]
@ -111,9 +135,8 @@ public void where_column_starts_with_string()
var test = "small"; var test = "small";
_subject = Where(x => x.CleanTitle.StartsWith(test)); _subject = Where(x => x.CleanTitle.StartsWith(test));
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" LIKE @Clause1_P1 || '%')");
_subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" LIKE @{name} || '%')"); _subject.Parameters.Get<string>("Clause1_P1").Should().Be(test);
_subject.Parameters.Get<string>(name).Should().Be(test);
} }
[Test] [Test]
@ -122,9 +145,8 @@ public void where_column_ends_with_string()
var test = "small"; var test = "small";
_subject = Where(x => x.CleanTitle.EndsWith(test)); _subject = Where(x => x.CleanTitle.EndsWith(test));
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" LIKE '%' || @Clause1_P1)");
_subject.ToString().Should().Be($"(\"Movies\".\"CleanTitle\" LIKE '%' || @{name})"); _subject.Parameters.Get<string>("Clause1_P1").Should().Be(test);
_subject.Parameters.Get<string>(name).Should().Be(test);
} }
[Test] [Test]
@ -133,11 +155,9 @@ public void where_in_list()
var list = new List<int> { 1, 2, 3 }; var list = new List<int> { 1, 2, 3 };
_subject = Where(x => list.Contains(x.Id)); _subject = Where(x => list.Contains(x.Id));
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"Id\" IN (1, 2, 3))");
_subject.ToString().Should().Be($"(\"Movies\".\"Id\" IN @{name})");
var param = _subject.Parameters.Get<List<int>>(name); _subject.Parameters.ParameterNames.Should().BeEmpty();
param.Should().BeEquivalentTo(list);
} }
[Test] [Test]
@ -146,37 +166,33 @@ public void where_in_list_2()
var list = new List<int> { 1, 2, 3 }; var list = new List<int> { 1, 2, 3 };
_subject = Where(x => x.CleanTitle == "test" && list.Contains(x.Id)); _subject = Where(x => x.CleanTitle == "test" && list.Contains(x.Id));
var names = _subject.Parameters.ParameterNames.ToList(); _subject.ToString().Should().Be($"((\"Movies\".\"CleanTitle\" = @Clause1_P1) AND (\"Movies\".\"Id\" IN (1, 2, 3)))");
_subject.ToString().Should().Be($"((\"Movies\".\"CleanTitle\" = @{names[0]}) AND (\"Movies\".\"Id\" IN @{names[1]}))");
} }
[Test] [Test]
public void enum_as_int() public void enum_as_int()
{ {
_subject = Where(x => x.Status == MovieStatusType.Released); _subject = Where(x => x.Status == MovieStatusType.Announced);
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"Status\" = @Clause1_P1)");
_subject.ToString().Should().Be($"(\"Movies\".\"Status\" = @{name})");
} }
[Test] [Test]
public void enum_in_list() public void enum_in_list()
{ {
var allowed = new List<MovieStatusType> { MovieStatusType.InCinemas, MovieStatusType.Released }; var allowed = new List<MovieStatusType> { MovieStatusType.Announced, MovieStatusType.InCinemas };
_subject = Where(x => allowed.Contains(x.Status)); _subject = Where(x => allowed.Contains(x.Status));
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"Status\" IN @Clause1_P1)");
_subject.ToString().Should().Be($"(\"Movies\".\"Status\" IN @{name})");
} }
[Test] [Test]
public void enum_in_array() public void enum_in_array()
{ {
var allowed = new MovieStatusType[] { MovieStatusType.InCinemas, MovieStatusType.Released }; var allowed = new MovieStatusType[] { MovieStatusType.Announced, MovieStatusType.InCinemas };
_subject = Where(x => allowed.Contains(x.Status)); _subject = Where(x => allowed.Contains(x.Status));
var name = _subject.Parameters.ParameterNames.First(); _subject.ToString().Should().Be($"(\"Movies\".\"Status\" IN @Clause1_P1)");
_subject.ToString().Should().Be($"(\"Movies\".\"Status\" IN @{name})");
} }
} }
} }

View File

@ -1,6 +1,4 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using Dapper;
using NzbDrone.Core.Datastore; using NzbDrone.Core.Datastore;
using NzbDrone.Core.Messaging.Events; using NzbDrone.Core.Messaging.Events;
using NzbDrone.Core.Movies; using NzbDrone.Core.Movies;
@ -36,23 +34,11 @@ public List<Blacklist> BlacklistedByMovie(int movieId)
return Query(x => x.MovieId == movieId); return Query(x => x.MovieId == movieId);
} }
private IEnumerable<Blacklist> SelectJoined(SqlBuilder.Template sql) protected override SqlBuilder PagedBuilder() => new SqlBuilder().Join<Blacklist, Movie>((b, m) => b.MovieId == m.Id);
{ protected override IEnumerable<Blacklist> PagedQuery(SqlBuilder sql) => _database.QueryJoined<Blacklist, Movie>(sql, (bl, movie) =>
using (var conn = _database.OpenConnection())
{
return conn.Query<Blacklist, Movie, Blacklist>(
sql.RawSql,
(bl, movie) =>
{ {
bl.Movie = movie; bl.Movie = movie;
return bl; return bl;
}, });
sql.Parameters)
.ToList();
}
}
protected override SqlBuilder PagedBuilder() => new SqlBuilder().Join<Blacklist, Movie>((b, m) => b.MovieId == m.Id);
protected override IEnumerable<Blacklist> PagedSelector(SqlBuilder.Template sql) => SelectJoined(sql);
} }
} }

View File

@ -47,8 +47,6 @@ public class BasicRepository<TModel> : IBasicRepository<TModel>
protected readonly IDatabase _database; protected readonly IDatabase _database;
protected readonly string _table; protected readonly string _table;
protected string _selectTemplate;
protected string _deleteTemplate;
public BasicRepository(IDatabase database, IEventAggregator eventAggregator) public BasicRepository(IDatabase database, IEventAggregator eventAggregator)
{ {
@ -62,42 +60,19 @@ public BasicRepository(IDatabase database, IEventAggregator eventAggregator)
var excluded = TableMapping.Mapper.ExcludeProperties(type).Select(x => x.Name).ToList(); var excluded = TableMapping.Mapper.ExcludeProperties(type).Select(x => x.Name).ToList();
excluded.Add(_keyProperty.Name); excluded.Add(_keyProperty.Name);
_properties = type.GetProperties().Where(x => !excluded.Contains(x.Name)).ToList(); _properties = type.GetProperties().Where(x => x.IsMappableProperty() && !excluded.Contains(x.Name)).ToList();
_insertSql = GetInsertSql(); _insertSql = GetInsertSql();
_updateSql = GetUpdateSql(_properties); _updateSql = GetUpdateSql(_properties);
_selectTemplate = $"SELECT /**select**/ FROM {_table} /**join**/ /**innerjoin**/ /**leftjoin**/ /**where**/ /**orderby**/";
_deleteTemplate = $"DELETE FROM {_table} /**where**/";
} }
protected virtual SqlBuilder BuilderBase() => new SqlBuilder(); protected virtual SqlBuilder Builder() => new SqlBuilder();
protected virtual SqlBuilder Builder() => BuilderBase().SelectAll();
protected virtual IEnumerable<TModel> GetResults(SqlBuilder.Template sql) protected virtual List<TModel> Query(SqlBuilder builder) => _database.Query<TModel>(builder).ToList();
{
using (var conn = _database.OpenConnection())
{
return conn.Query<TModel>(sql.RawSql, sql.Parameters);
}
}
protected List<TModel> Query(Expression<Func<TModel, bool>> where) protected List<TModel> Query(Expression<Func<TModel, bool>> where) => Query(Builder().Where(where));
{
return Query(Builder().Where<TModel>(where));
}
protected List<TModel> Query(SqlBuilder builder) protected virtual List<TModel> QueryDistinct(SqlBuilder builder) => _database.QueryDistinct<TModel>(builder).ToList();
{
return Query(builder, GetResults);
}
protected List<TModel> Query(SqlBuilder builder, Func<SqlBuilder.Template, IEnumerable<TModel>> queryFunc)
{
var sql = builder.AddTemplate(_selectTemplate).LogQuery();
return queryFunc(sql).ToList();
}
public int Count() public int Count()
{ {
@ -197,6 +172,7 @@ private string GetInsertSql()
private TModel Insert(IDbConnection connection, IDbTransaction transaction, TModel model) private TModel Insert(IDbConnection connection, IDbTransaction transaction, TModel model)
{ {
SqlBuilderExtensions.LogQuery(_insertSql, model);
var multi = connection.QueryMultiple(_insertSql, model, transaction); var multi = connection.QueryMultiple(_insertSql, model, transaction);
var id = (int)multi.Read().First().id; var id = (int)multi.Read().First().id;
_keyProperty.SetValue(model, id); _keyProperty.SetValue(model, id);
@ -262,7 +238,7 @@ protected void Delete(Expression<Func<TModel, bool>> where)
protected void Delete(SqlBuilder builder) protected void Delete(SqlBuilder builder)
{ {
var sql = builder.AddTemplate(_deleteTemplate).LogQuery(); var sql = builder.AddDeleteTemplate(typeof(TModel)).LogQuery();
using (var conn = _database.OpenConnection()) using (var conn = _database.OpenConnection())
{ {
@ -368,7 +344,7 @@ public void SetFields(IList<TModel> models, params Expression<Func<TModel, objec
private string GetUpdateSql(List<PropertyInfo> propertiesToUpdate) private string GetUpdateSql(List<PropertyInfo> propertiesToUpdate)
{ {
var sb = new StringBuilder(); var sb = new StringBuilder();
sb.AppendFormat("update {0} set ", _table); sb.AppendFormat("UPDATE {0} SET ", _table);
for (var i = 0; i < propertiesToUpdate.Count; i++) for (var i = 0; i < propertiesToUpdate.Count; i++)
{ {
@ -380,7 +356,7 @@ private string GetUpdateSql(List<PropertyInfo> propertiesToUpdate)
} }
} }
sb.Append($" where \"{_keyProperty.Name}\" = @{_keyProperty.Name}"); sb.Append($" WHERE \"{_keyProperty.Name}\" = @{_keyProperty.Name}");
return sb.ToString(); return sb.ToString();
} }
@ -389,6 +365,8 @@ private void UpdateFields(IDbConnection connection, IDbTransaction transaction,
{ {
var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate); var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate);
SqlBuilderExtensions.LogQuery(sql, model);
connection.Execute(sql, model, transaction: transaction); connection.Execute(sql, model, transaction: transaction);
} }
@ -396,15 +374,20 @@ private void UpdateFields(IDbConnection connection, IDbTransaction transaction,
{ {
var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate); var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate);
foreach (var model in models)
{
SqlBuilderExtensions.LogQuery(sql, model);
}
connection.Execute(sql, models, transaction: transaction); connection.Execute(sql, models, transaction: transaction);
} }
protected virtual SqlBuilder PagedBuilder() => BuilderBase(); protected virtual SqlBuilder PagedBuilder() => Builder();
protected virtual IEnumerable<TModel> PagedSelector(SqlBuilder.Template sql) => GetResults(sql); protected virtual IEnumerable<TModel> PagedQuery(SqlBuilder sql) => Query(sql);
public virtual PagingSpec<TModel> GetPaged(PagingSpec<TModel> pagingSpec) public virtual PagingSpec<TModel> GetPaged(PagingSpec<TModel> pagingSpec)
{ {
pagingSpec.Records = GetPagedRecords(PagedBuilder().SelectAll(), pagingSpec, PagedSelector); pagingSpec.Records = GetPagedRecords(PagedBuilder(), pagingSpec, PagedQuery);
pagingSpec.TotalRecords = GetPagedRecordCount(PagedBuilder().SelectCount(), pagingSpec); pagingSpec.TotalRecords = GetPagedRecordCount(PagedBuilder().SelectCount(), pagingSpec);
return pagingSpec; return pagingSpec;
@ -420,7 +403,7 @@ private void AddFilters(SqlBuilder builder, PagingSpec<TModel> pagingSpec)
} }
} }
protected List<TModel> GetPagedRecords(SqlBuilder builder, PagingSpec<TModel> pagingSpec, Func<SqlBuilder.Template, IEnumerable<TModel>> queryFunc) protected List<TModel> GetPagedRecords(SqlBuilder builder, PagingSpec<TModel> pagingSpec, Func<SqlBuilder, IEnumerable<TModel>> queryFunc)
{ {
AddFilters(builder, pagingSpec); AddFilters(builder, pagingSpec);
@ -428,16 +411,22 @@ protected List<TModel> GetPagedRecords(SqlBuilder builder, PagingSpec<TModel> pa
var pagingOffset = (pagingSpec.Page - 1) * pagingSpec.PageSize; var pagingOffset = (pagingSpec.Page - 1) * pagingSpec.PageSize;
builder.OrderBy($"{pagingSpec.SortKey} {sortDirection} LIMIT {pagingSpec.PageSize} OFFSET {pagingOffset}"); builder.OrderBy($"{pagingSpec.SortKey} {sortDirection} LIMIT {pagingSpec.PageSize} OFFSET {pagingOffset}");
var sql = builder.AddTemplate(_selectTemplate).LogQuery(); return queryFunc(builder).ToList();
return queryFunc(sql).ToList();
} }
protected int GetPagedRecordCount(SqlBuilder builder, PagingSpec<TModel> pagingSpec) protected int GetPagedRecordCount(SqlBuilder builder, PagingSpec<TModel> pagingSpec, string template = null)
{ {
AddFilters(builder, pagingSpec); AddFilters(builder, pagingSpec);
var sql = builder.AddTemplate(_selectTemplate).LogQuery(); SqlBuilder.Template sql;
if (template != null)
{
sql = builder.AddTemplate(template).LogQuery();
}
else
{
sql = builder.AddPageCountTemplate(typeof(TModel));
}
using (var conn = _database.OpenConnection()) using (var conn = _database.OpenConnection())
{ {

View File

@ -14,12 +14,18 @@ namespace NzbDrone.Core.Datastore
{ {
public static class SqlBuilderExtensions public static class SqlBuilderExtensions
{ {
public static bool LogSql { get; set; }
private static readonly Logger Logger = NzbDroneLogger.GetLogger(typeof(SqlBuilderExtensions)); private static readonly Logger Logger = NzbDroneLogger.GetLogger(typeof(SqlBuilderExtensions));
public static SqlBuilder SelectAll(this SqlBuilder builder) public static bool LogSql { get; set; }
public static SqlBuilder Select(this SqlBuilder builder, params Type[] types)
{ {
return builder.Select("*"); return builder.Select(types.Select(x => TableMapping.Mapper.TableNameMapping(x) + ".*").Join(", "));
}
public static SqlBuilder SelectDistinct(this SqlBuilder builder, params Type[] types)
{
return builder.Select("DISTINCT " + types.Select(x => TableMapping.Mapper.TableNameMapping(x) + ".*").Join(", "));
} }
public static SqlBuilder SelectCount(this SqlBuilder builder) public static SqlBuilder SelectCount(this SqlBuilder builder)
@ -27,23 +33,30 @@ public static SqlBuilder SelectCount(this SqlBuilder builder)
return builder.Select("COUNT(*)"); return builder.Select("COUNT(*)");
} }
public static SqlBuilder SelectCountDistinct<TModel>(this SqlBuilder builder, Expression<Func<TModel, object>> property)
{
var table = TableMapping.Mapper.TableNameMapping(typeof(TModel));
var propName = property.GetMemberName().Name;
return builder.Select($"COUNT(DISTINCT \"{table}\".\"{propName}\")");
}
public static SqlBuilder Where<TModel>(this SqlBuilder builder, Expression<Func<TModel, bool>> filter) public static SqlBuilder Where<TModel>(this SqlBuilder builder, Expression<Func<TModel, bool>> filter)
{ {
var wb = new WhereBuilder(filter, true); var wb = new WhereBuilder(filter, true, builder.Sequence);
return builder.Where(wb.ToString(), wb.Parameters); return builder.Where(wb.ToString(), wb.Parameters);
} }
public static SqlBuilder OrWhere<TModel>(this SqlBuilder builder, Expression<Func<TModel, bool>> filter) public static SqlBuilder OrWhere<TModel>(this SqlBuilder builder, Expression<Func<TModel, bool>> filter)
{ {
var wb = new WhereBuilder(filter, true); var wb = new WhereBuilder(filter, true, builder.Sequence);
return builder.OrWhere(wb.ToString(), wb.Parameters); return builder.OrWhere(wb.ToString(), wb.Parameters);
} }
public static SqlBuilder Join<TLeft, TRight>(this SqlBuilder builder, Expression<Func<TLeft, TRight, bool>> filter) public static SqlBuilder Join<TLeft, TRight>(this SqlBuilder builder, Expression<Func<TLeft, TRight, bool>> filter)
{ {
var wb = new WhereBuilder(filter, false); var wb = new WhereBuilder(filter, false, builder.Sequence);
var rightTable = TableMapping.Mapper.TableNameMapping(typeof(TRight)); var rightTable = TableMapping.Mapper.TableNameMapping(typeof(TRight));
@ -52,41 +65,76 @@ public static SqlBuilder Join<TLeft, TRight>(this SqlBuilder builder, Expression
public static SqlBuilder LeftJoin<TLeft, TRight>(this SqlBuilder builder, Expression<Func<TLeft, TRight, bool>> filter) public static SqlBuilder LeftJoin<TLeft, TRight>(this SqlBuilder builder, Expression<Func<TLeft, TRight, bool>> filter)
{ {
var wb = new WhereBuilder(filter, false); var wb = new WhereBuilder(filter, false, builder.Sequence);
var rightTable = TableMapping.Mapper.TableNameMapping(typeof(TRight)); var rightTable = TableMapping.Mapper.TableNameMapping(typeof(TRight));
return builder.LeftJoin($"{rightTable} ON {wb.ToString()}"); return builder.LeftJoin($"{rightTable} ON {wb.ToString()}");
} }
public static SqlBuilder GroupBy<TModel>(this SqlBuilder builder, Expression<Func<TModel, object>> property)
{
var table = TableMapping.Mapper.TableNameMapping(typeof(TModel));
var propName = property.GetMemberName().Name;
return builder.GroupBy($"{table}.{propName}");
}
public static SqlBuilder.Template AddSelectTemplate(this SqlBuilder builder, Type type)
{
return builder.AddTemplate(TableMapping.Mapper.SelectTemplate(type)).LogQuery();
}
public static SqlBuilder.Template AddPageCountTemplate(this SqlBuilder builder, Type type)
{
return builder.AddTemplate(TableMapping.Mapper.PageCountTemplate(type)).LogQuery();
}
public static SqlBuilder.Template AddDeleteTemplate(this SqlBuilder builder, Type type)
{
return builder.AddTemplate(TableMapping.Mapper.DeleteTemplate(type)).LogQuery();
}
public static SqlBuilder.Template LogQuery(this SqlBuilder.Template template) public static SqlBuilder.Template LogQuery(this SqlBuilder.Template template)
{ {
if (LogSql) if (LogSql)
{ {
var sb = new StringBuilder(); LogQuery(template.RawSql, (DynamicParameters)template.Parameters);
sb.AppendLine();
sb.AppendLine("==== Begin Query Trace ====");
sb.AppendLine();
sb.AppendLine("QUERY TEXT:");
sb.AppendLine(template.RawSql);
sb.AppendLine();
sb.AppendLine("PARAMETERS:");
foreach (var p in ((DynamicParameters)template.Parameters).ToDictionary())
{
object val = (p.Value is string) ? string.Format("\"{0}\"", p.Value) : p.Value;
sb.AppendFormat("{0} = [{1}]", p.Key, val.ToJson() ?? "NULL").AppendLine();
}
sb.AppendLine();
sb.AppendLine("==== End Query Trace ====");
sb.AppendLine();
Logger.Trace(sb.ToString());
} }
return template; return template;
} }
public static void LogQuery(string sql, object parameters)
{
if (LogSql)
{
LogQuery(sql, new DynamicParameters(parameters));
}
}
private static void LogQuery(string sql, DynamicParameters parameters)
{
var sb = new StringBuilder();
sb.AppendLine();
sb.AppendLine("==== Begin Query Trace ====");
sb.AppendLine();
sb.AppendLine("QUERY TEXT:");
sb.AppendLine(sql);
sb.AppendLine();
sb.AppendLine("PARAMETERS:");
foreach (var p in parameters.ToDictionary())
{
var val = (p.Value is string) ? string.Format("\"{0}\"", p.Value) : p.Value;
sb.AppendFormat("{0} = [{1}]", p.Key, val.ToJson() ?? "NULL").AppendLine();
}
sb.AppendLine();
sb.AppendLine("==== End Query Trace ====");
sb.AppendLine();
Logger.Trace(sb.ToString());
}
private static Dictionary<string, object> ToDictionary(this DynamicParameters dynamicParams) private static Dictionary<string, object> ToDictionary(this DynamicParameters dynamicParams)
{ {
var argsDictionary = new Dictionary<string, object>(); var argsDictionary = new Dictionary<string, object>();
@ -99,32 +147,21 @@ private static Dictionary<string, object> ToDictionary(this DynamicParameters dy
} }
var templates = dynamicParams.GetType().GetField("templates", BindingFlags.NonPublic | BindingFlags.Instance); var templates = dynamicParams.GetType().GetField("templates", BindingFlags.NonPublic | BindingFlags.Instance);
if (templates != null) if (templates != null && templates.GetValue(dynamicParams) is List<object> list)
{ {
var list = templates.GetValue(dynamicParams) as List<object>; foreach (var objProps in list.Select(obj => obj.GetPropertyValuePairs().ToList()))
if (list != null)
{ {
foreach (var objProps in list.Select(obj => obj.GetPropertyValuePairs().ToList())) objProps.ForEach(p => argsDictionary.Add(p.Key, p.Value));
{
objProps.ForEach(p => argsDictionary.Add(p.Key, p.Value));
}
} }
} }
return argsDictionary; return argsDictionary;
} }
private static Dictionary<string, object> GetPropertyValuePairs(this object obj, string[] hidden = null) private static Dictionary<string, object> GetPropertyValuePairs(this object obj)
{ {
var type = obj.GetType(); var type = obj.GetType();
var pairs = hidden == null var pairs = type.GetProperties().Where(x => x.IsMappableProperty())
? type.GetProperties()
.DistinctBy(propertyInfo => propertyInfo.Name)
.ToDictionary(
propertyInfo => propertyInfo.Name,
propertyInfo => propertyInfo.GetValue(obj, null))
: type.GetProperties()
.Where(it => !hidden.Contains(it.Name))
.DistinctBy(propertyInfo => propertyInfo.Name) .DistinctBy(propertyInfo => propertyInfo.Name)
.ToDictionary( .ToDictionary(
propertyInfo => propertyInfo.Name, propertyInfo => propertyInfo.Name,

View File

@ -0,0 +1,47 @@
using System;
using System.Linq.Expressions;
using System.Reflection;
using Dapper;
using NzbDrone.Common.Reflection;
namespace NzbDrone.Core.Datastore
{
public static class MappingExtensions
{
public static PropertyInfo GetMemberName<T, TChild>(this Expression<Func<T, TChild>> member)
{
if (!(member.Body is MemberExpression memberExpression))
{
memberExpression = (member.Body as UnaryExpression).Operand as MemberExpression;
}
return (PropertyInfo)memberExpression.Member;
}
public static bool IsMappableProperty(this MemberInfo memberInfo)
{
var propertyInfo = memberInfo as PropertyInfo;
if (propertyInfo == null)
{
return false;
}
if (!propertyInfo.IsReadable() || !propertyInfo.IsWritable())
{
return false;
}
// This is a bit of a hack but is the only way to see if a type has a handler set in Dapper
#pragma warning disable 618
SqlMapper.LookupDbType(propertyInfo.PropertyType, "", false, out var handler);
#pragma warning restore 618
if (propertyInfo.PropertyType.IsSimpleType() || handler != null)
{
return true;
}
return false;
}
}
}

View File

@ -0,0 +1,176 @@
using System;
using System.Collections.Generic;
using System.Data;
using Dapper;
namespace NzbDrone.Core.Datastore
{
public static class SqlMapperExtensions
{
public static IEnumerable<T> Query<T>(this IDatabase db, string sql, object param = null)
{
using (var conn = db.OpenConnection())
{
var items = SqlMapper.Query<T>(conn, sql, param);
if (TableMapping.Mapper.LazyLoadList.TryGetValue(typeof(T), out var lazyProperties))
{
foreach (var item in items)
{
ApplyLazyLoad(db, item, lazyProperties);
}
}
return items;
}
}
public static IEnumerable<TReturn> Query<TFirst, TSecond, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null)
{
TReturn MapWithLazy(TFirst first, TSecond second)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
return map(first, second);
}
IEnumerable<TReturn> result = null;
using (var conn = db.OpenConnection())
{
result = SqlMapper.Query<TFirst, TSecond, TReturn>(conn, sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
return result;
}
public static IEnumerable<TReturn> Query<TFirst, TSecond, TThird, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TThird, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null)
{
TReturn MapWithLazy(TFirst first, TSecond second, TThird third)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
ApplyLazyLoad(db, third);
return map(first, second, third);
}
IEnumerable<TReturn> result = null;
using (var conn = db.OpenConnection())
{
result = SqlMapper.Query<TFirst, TSecond, TThird, TReturn>(conn, sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
return result;
}
public static IEnumerable<TReturn> Query<TFirst, TSecond, TThird, TFourth, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TThird, TFourth, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null)
{
TReturn MapWithLazy(TFirst first, TSecond second, TThird third, TFourth fourth)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
ApplyLazyLoad(db, third);
ApplyLazyLoad(db, fourth);
return map(first, second, third, fourth);
}
IEnumerable<TReturn> result = null;
using (var conn = db.OpenConnection())
{
result = SqlMapper.Query<TFirst, TSecond, TThird, TFourth, TReturn>(conn, sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
return result;
}
public static IEnumerable<TReturn> Query<TFirst, TSecond, TThird, TFourth, TFifth, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TThird, TFourth, TFifth, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null)
{
TReturn MapWithLazy(TFirst first, TSecond second, TThird third, TFourth fourth, TFifth fifth)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
ApplyLazyLoad(db, third);
ApplyLazyLoad(db, fourth);
ApplyLazyLoad(db, fifth);
return map(first, second, third, fourth, fifth);
}
IEnumerable<TReturn> result = null;
using (var conn = db.OpenConnection())
{
result = SqlMapper.Query<TFirst, TSecond, TThird, TFourth, TFifth, TReturn>(conn, sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
return result;
}
public static IEnumerable<T> Query<T>(this IDatabase db, SqlBuilder builder)
{
var type = typeof(T);
var sql = builder.Select(type).AddSelectTemplate(type);
return db.Query<T>(sql.RawSql, sql.Parameters);
}
public static IEnumerable<T> QueryDistinct<T>(this IDatabase db, SqlBuilder builder)
{
var type = typeof(T);
var sql = builder.SelectDistinct(type).AddSelectTemplate(type);
return db.Query<T>(sql.RawSql, sql.Parameters);
}
public static IEnumerable<T> QueryJoined<T, T2>(this IDatabase db, SqlBuilder builder, Func<T, T2, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2)).AddSelectTemplate(type);
return db.Query(sql.RawSql, mapper, sql.Parameters);
}
public static IEnumerable<T> QueryJoined<T, T2, T3>(this IDatabase db, SqlBuilder builder, Func<T, T2, T3, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2), typeof(T3)).AddSelectTemplate(type);
return db.Query(sql.RawSql, mapper, sql.Parameters);
}
public static IEnumerable<T> QueryJoined<T, T2, T3, T4>(this IDatabase db, SqlBuilder builder, Func<T, T2, T3, T4, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2), typeof(T3), typeof(T4)).AddSelectTemplate(type);
return db.Query(sql.RawSql, mapper, sql.Parameters);
}
public static IEnumerable<T> QueryJoined<T, T2, T3, T4, T5>(this IDatabase db, SqlBuilder builder, Func<T, T2, T3, T4, T5, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2), typeof(T3), typeof(T4), typeof(T5)).AddSelectTemplate(type);
return db.Query(sql.RawSql, mapper, sql.Parameters);
}
private static void ApplyLazyLoad<TModel>(IDatabase db, TModel model)
{
if (TableMapping.Mapper.LazyLoadList.TryGetValue(typeof(TModel), out var lazyProperties))
{
ApplyLazyLoad(db, model, lazyProperties);
}
}
private static void ApplyLazyLoad<TModel>(IDatabase db, TModel model, List<LazyLoadedProperty> lazyProperties)
{
if (model == null)
{
return;
}
foreach (var lazyProperty in lazyProperties)
{
var lazy = (ILazyLoaded)lazyProperty.LazyLoad.Clone();
lazy.Prepare(db, model);
lazyProperty.Property.SetValue(model, lazy);
}
}
}
}

View File

@ -0,0 +1,139 @@
using System;
using NLog;
using NzbDrone.Common.Instrumentation;
namespace NzbDrone.Core.Datastore
{
public interface ILazyLoaded : ICloneable
{
bool IsLoaded { get; }
void Prepare(IDatabase database, object parent);
void LazyLoad();
}
/// <summary>
/// Allows a field to be lazy loaded.
/// </summary>
/// <typeparam name="TChild"></typeparam>
public class LazyLoaded<TChild> : ILazyLoaded
{
protected TChild _value;
public LazyLoaded()
{
}
public LazyLoaded(TChild val)
{
_value = val;
IsLoaded = true;
}
public TChild Value
{
get
{
LazyLoad();
return _value;
}
}
public bool IsLoaded { get; protected set; }
public static implicit operator LazyLoaded<TChild>(TChild val)
{
return new LazyLoaded<TChild>(val);
}
public static implicit operator TChild(LazyLoaded<TChild> lazy)
{
return lazy.Value;
}
public virtual void Prepare(IDatabase database, object parent)
{
}
public virtual void LazyLoad()
{
}
public object Clone()
{
return MemberwiseClone();
}
public bool ShouldSerializeValue()
{
return IsLoaded;
}
}
/// <summary>
/// This is the lazy loading proxy.
/// </summary>
/// <typeparam name="TParent">The parent entity that contains the lazy loaded entity.</typeparam>
/// <typeparam name="TChild">The child entity that is being lazy loaded.</typeparam>
internal class LazyLoaded<TParent, TChild> : LazyLoaded<TChild>
{
private static readonly Logger Logger = NzbDroneLogger.GetLogger(typeof(LazyLoaded<TParent, TChild>));
private readonly Func<IDatabase, TParent, TChild> _query;
private readonly Func<TParent, bool> _condition;
private IDatabase _database;
private TParent _parent;
public LazyLoaded(TChild val)
: base(val)
{
_value = val;
IsLoaded = true;
}
internal LazyLoaded(Func<IDatabase, TParent, TChild> query, Func<TParent, bool> condition = null)
{
_query = query;
_condition = condition;
}
public static implicit operator LazyLoaded<TParent, TChild>(TChild val)
{
return new LazyLoaded<TParent, TChild>(val);
}
public static implicit operator TChild(LazyLoaded<TParent, TChild> lazy)
{
return lazy.Value;
}
public override void Prepare(IDatabase database, object parent)
{
_database = database;
_parent = (TParent)parent;
}
public override void LazyLoad()
{
if (!IsLoaded)
{
if (_condition != null && _condition(_parent))
{
if (SqlBuilderExtensions.LogSql)
{
Logger.Trace($"Lazy loading {typeof(TChild)} for {typeof(TParent)}");
Logger.Trace("StackTrace: '{0}'", Environment.StackTrace);
}
_value = _query(_database, _parent);
}
else
{
_value = default;
}
IsLoaded = true;
}
}
}
}

View File

@ -0,0 +1,168 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Dapper;
namespace NzbDrone.Core.Datastore
{
public class SqlBuilder
{
private readonly Dictionary<string, Clauses> _data = new Dictionary<string, Clauses>();
public int Sequence { get; private set; }
public Template AddTemplate(string sql, dynamic parameters = null) =>
new Template(this, sql, parameters);
public SqlBuilder Intersect(string sql, dynamic parameters = null) =>
AddClause("intersect", sql, parameters, "\nINTERSECT\n ", "\n ", "\n", false);
public SqlBuilder InnerJoin(string sql, dynamic parameters = null) =>
AddClause("innerjoin", sql, parameters, "\nINNER JOIN ", "\nINNER JOIN ", "\n", false);
public SqlBuilder LeftJoin(string sql, dynamic parameters = null) =>
AddClause("leftjoin", sql, parameters, "\nLEFT JOIN ", "\nLEFT JOIN ", "\n", false);
public SqlBuilder RightJoin(string sql, dynamic parameters = null) =>
AddClause("rightjoin", sql, parameters, "\nRIGHT JOIN ", "\nRIGHT JOIN ", "\n", false);
public SqlBuilder Where(string sql, dynamic parameters = null) =>
AddClause("where", sql, parameters, " AND ", "WHERE ", "\n", false);
public SqlBuilder OrWhere(string sql, dynamic parameters = null) =>
AddClause("where", sql, parameters, " OR ", "WHERE ", "\n", true);
public SqlBuilder OrderBy(string sql, dynamic parameters = null) =>
AddClause("orderby", sql, parameters, " , ", "ORDER BY ", "\n", false);
public SqlBuilder Select(string sql, dynamic parameters = null) =>
AddClause("select", sql, parameters, " , ", "", "\n", false);
public SqlBuilder AddParameters(dynamic parameters) =>
AddClause("--parameters", "", parameters, "", "", "", false);
public SqlBuilder Join(string sql, dynamic parameters = null) =>
AddClause("join", sql, parameters, "\nJOIN ", "\nJOIN ", "\n", false);
public SqlBuilder GroupBy(string sql, dynamic parameters = null) =>
AddClause("groupby", sql, parameters, " , ", "\nGROUP BY ", "\n", false);
public SqlBuilder Having(string sql, dynamic parameters = null) =>
AddClause("having", sql, parameters, "\nAND ", "HAVING ", "\n", false);
protected SqlBuilder AddClause(string name, string sql, object parameters, string joiner, string prefix = "", string postfix = "", bool isInclusive = false)
{
if (!_data.TryGetValue(name, out var clauses))
{
clauses = new Clauses(joiner, prefix, postfix);
_data[name] = clauses;
}
clauses.Add(new Clause { Sql = sql, Parameters = parameters, IsInclusive = isInclusive });
Sequence++;
return this;
}
public class Template
{
private static readonly Regex _regex = new Regex(@"\/\*\*.+?\*\*\/", RegexOptions.Compiled | RegexOptions.Multiline);
private readonly string _sql;
private readonly SqlBuilder _builder;
private readonly object _initParams;
private int _dataSeq = -1; // Unresolved
private string _rawSql;
private object _parameters;
public Template(SqlBuilder builder, string sql, dynamic parameters)
{
_initParams = parameters;
_sql = sql;
_builder = builder;
}
public string RawSql
{
get
{
ResolveSql();
return _rawSql;
}
}
public object Parameters
{
get
{
ResolveSql();
return _parameters;
}
}
private void ResolveSql()
{
if (_dataSeq != _builder.Sequence)
{
var p = new DynamicParameters(_initParams);
_rawSql = _sql;
foreach (var pair in _builder._data)
{
_rawSql = _rawSql.Replace("/**" + pair.Key + "**/", pair.Value.ResolveClauses(p));
}
_parameters = p;
// replace all that is left with empty
_rawSql = _regex.Replace(_rawSql, "");
_dataSeq = _builder.Sequence;
}
}
}
private class Clause
{
public string Sql { get; set; }
public object Parameters { get; set; }
public bool IsInclusive { get; set; }
}
private class Clauses : List<Clause>
{
private readonly string _joiner;
private readonly string _prefix;
private readonly string _postfix;
public Clauses(string joiner, string prefix = "", string postfix = "")
{
_joiner = joiner;
_prefix = prefix;
_postfix = postfix;
}
public string ResolveClauses(DynamicParameters p)
{
foreach (var item in this)
{
p.AddDynamicParams(item.Parameters);
}
return this.Any(a => a.IsInclusive)
? _prefix +
string.Join(_joiner,
this.Where(a => !a.IsInclusive)
.Select(c => c.Sql)
.Union(new[]
{
" ( " +
string.Join(" OR ", this.Where(a => a.IsInclusive).Select(c => c.Sql).ToArray()) +
" ) "
}).ToArray()) + _postfix
: _prefix + string.Join(_joiner, this.Select(c => c.Sql).ToArray()) + _postfix;
}
}
}
}

View File

@ -3,48 +3,36 @@
using System.Linq; using System.Linq;
using System.Linq.Expressions; using System.Linq.Expressions;
using System.Reflection; using System.Reflection;
using Dapper;
using NzbDrone.Common.Reflection;
namespace NzbDrone.Core.Datastore namespace NzbDrone.Core.Datastore
{ {
public static class MappingExtensions
{
public static PropertyInfo GetMemberName<T>(this Expression<Func<T, object>> member)
{
var memberExpression = member.Body as MemberExpression;
if (memberExpression == null)
{
memberExpression = (member.Body as UnaryExpression).Operand as MemberExpression;
}
return (PropertyInfo)memberExpression.Member;
}
}
public class TableMapper public class TableMapper
{ {
public TableMapper() public TableMapper()
{ {
IgnoreList = new Dictionary<Type, List<PropertyInfo>>(); IgnoreList = new Dictionary<Type, List<PropertyInfo>>();
LazyLoadList = new Dictionary<Type, List<LazyLoadedProperty>>();
TableMap = new Dictionary<Type, string>(); TableMap = new Dictionary<Type, string>();
} }
public Dictionary<Type, List<PropertyInfo>> IgnoreList { get; set; } public Dictionary<Type, List<PropertyInfo>> IgnoreList { get; set; }
public Dictionary<Type, List<LazyLoadedProperty>> LazyLoadList { get; set; }
public Dictionary<Type, string> TableMap { get; set; } public Dictionary<Type, string> TableMap { get; set; }
public ColumnMapper<TEntity> Entity<TEntity>(string tableName) public ColumnMapper<TEntity> Entity<TEntity>(string tableName)
where TEntity : ModelBase
{ {
TableMap.Add(typeof(TEntity), tableName); var type = typeof(TEntity);
TableMap.Add(type, tableName);
if (IgnoreList.TryGetValue(typeof(TEntity), out var list)) if (IgnoreList.TryGetValue(type, out var list))
{ {
return new ColumnMapper<TEntity>(list); return new ColumnMapper<TEntity>(list, LazyLoadList[type]);
} }
list = new List<PropertyInfo>(); IgnoreList[type] = new List<PropertyInfo>();
IgnoreList[typeof(TEntity)] = list; LazyLoadList[type] = new List<LazyLoadedProperty>();
return new ColumnMapper<TEntity>(list); return new ColumnMapper<TEntity>(IgnoreList[type], LazyLoadList[type]);
} }
public List<PropertyInfo> ExcludeProperties(Type x) public List<PropertyInfo> ExcludeProperties(Type x)
@ -56,21 +44,44 @@ public string TableNameMapping(Type x)
{ {
return TableMap.ContainsKey(x) ? TableMap[x] : null; return TableMap.ContainsKey(x) ? TableMap[x] : null;
} }
public string SelectTemplate(Type x)
{
return $"SELECT /**select**/ FROM {TableMap[x]} /**join**/ /**innerjoin**/ /**leftjoin**/ /**where**/ /**groupby**/ /**having**/ /**orderby**/";
}
public string DeleteTemplate(Type x)
{
return $"DELETE FROM {TableMap[x]} /**where**/";
}
public string PageCountTemplate(Type x)
{
return $"SELECT /**select**/ FROM {TableMap[x]} /**join**/ /**innerjoin**/ /**leftjoin**/ /**where**/";
}
}
public class LazyLoadedProperty
{
public PropertyInfo Property { get; set; }
public ILazyLoaded LazyLoad { get; set; }
} }
public class ColumnMapper<T> public class ColumnMapper<T>
where T : ModelBase
{ {
private readonly List<PropertyInfo> _ignoreList; private readonly List<PropertyInfo> _ignoreList;
private readonly List<LazyLoadedProperty> _lazyLoadList;
public ColumnMapper(List<PropertyInfo> ignoreList) public ColumnMapper(List<PropertyInfo> ignoreList, List<LazyLoadedProperty> lazyLoadList)
{ {
_ignoreList = ignoreList; _ignoreList = ignoreList;
_lazyLoadList = lazyLoadList;
} }
public ColumnMapper<T> AutoMapPropertiesWhere(Func<PropertyInfo, bool> predicate) public ColumnMapper<T> AutoMapPropertiesWhere(Func<PropertyInfo, bool> predicate)
{ {
Type entityType = typeof(T); var properties = typeof(T).GetProperties();
var properties = entityType.GetProperties();
_ignoreList.AddRange(properties.Where(x => !predicate(x))); _ignoreList.AddRange(properties.Where(x => !predicate(x)));
return this; return this;
@ -78,7 +89,7 @@ public ColumnMapper<T> AutoMapPropertiesWhere(Func<PropertyInfo, bool> predicate
public ColumnMapper<T> RegisterModel() public ColumnMapper<T> RegisterModel()
{ {
return AutoMapPropertiesWhere(IsMappableProperty); return AutoMapPropertiesWhere(x => x.IsMappableProperty());
} }
public ColumnMapper<T> Ignore(Expression<Func<T, object>> property) public ColumnMapper<T> Ignore(Expression<Func<T, object>> property)
@ -87,30 +98,31 @@ public ColumnMapper<T> Ignore(Expression<Func<T, object>> property)
return this; return this;
} }
public static bool IsMappableProperty(MemberInfo memberInfo) public ColumnMapper<T> LazyLoad<TChild>(Expression<Func<T, LazyLoaded<TChild>>> property, Func<IDatabase, T, TChild> query, Func<T, bool> condition)
{ {
var propertyInfo = memberInfo as PropertyInfo; var lazyLoad = new LazyLoaded<T, TChild>(query, condition);
if (propertyInfo == null) var item = new LazyLoadedProperty
{ {
return false; Property = property.GetMemberName(),
} LazyLoad = lazyLoad
};
if (!propertyInfo.IsReadable() || !propertyInfo.IsWritable()) _lazyLoadList.Add(item);
{
return false;
}
// This is a bit of a hack but is the only way to see if a type has a handler set in Dapper return this;
#pragma warning disable 618 }
SqlMapper.LookupDbType(propertyInfo.PropertyType, "", false, out var handler);
#pragma warning restore 618
if (propertyInfo.PropertyType.IsSimpleType() || handler != null)
{
return true;
}
return false; public ColumnMapper<T> HasOne<TChild>(Expression<Func<T, LazyLoaded<TChild>>> portalExpression, Func<T, int> childIdSelector)
where TChild : ModelBase
{
return LazyLoad(portalExpression,
(db, parent) =>
{
var id = childIdSelector(parent);
return db.Query<TChild>(new SqlBuilder().Where<TChild>(x => x.Id == id)).SingleOrDefault();
},
parent => childIdSelector(parent) > 0);
} }
} }
} }

View File

@ -19,9 +19,9 @@ public class WhereBuilder : ExpressionVisitor
private int _paramCount = 0; private int _paramCount = 0;
private bool _gotConcreteValue = false; private bool _gotConcreteValue = false;
public WhereBuilder(Expression filter, bool requireConcreteValue) public WhereBuilder(Expression filter, bool requireConcreteValue, int seq)
{ {
_paramNamePrefix = Guid.NewGuid().ToString().Replace("-", "_"); _paramNamePrefix = string.Format("Clause{0}", seq + 1);
_requireConcreteValue = requireConcreteValue; _requireConcreteValue = requireConcreteValue;
_sb = new StringBuilder(); _sb = new StringBuilder();
@ -87,16 +87,16 @@ protected override Expression VisitMethodCall(MethodCallExpression expression)
protected override Expression VisitMemberAccess(MemberExpression expression) protected override Expression VisitMemberAccess(MemberExpression expression)
{ {
var tableName = expression != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null; var tableName = expression?.Expression?.Type != null ? TableMapping.Mapper.TableNameMapping(expression.Expression.Type) : null;
var gotValue = TryGetRightValue(expression, out var value);
if (tableName != null) // Only use the SQL condition if the expression didn't resolve to an actual value
if (tableName != null && !gotValue)
{ {
_sb.Append($"\"{tableName}\".\"{expression.Member.Name}\""); _sb.Append($"\"{tableName}\".\"{expression.Member.Name}\"");
} }
else else
{ {
var value = GetRightValue(expression);
if (value != null) if (value != null)
{ {
// string is IEnumerable<Char> but we don't want to pick up that case // string is IEnumerable<Char> but we don't want to pick up that case
@ -138,33 +138,43 @@ protected override Expression VisitConstant(ConstantExpression expression)
private bool TryGetConstantValue(Expression expression, out object result) private bool TryGetConstantValue(Expression expression, out object result)
{ {
result = null;
if (expression is ConstantExpression constExp) if (expression is ConstantExpression constExp)
{ {
result = constExp.Value; result = constExp.Value;
return true; return true;
} }
result = null;
return false; return false;
} }
private bool TryGetPropertyValue(MemberExpression expression, out object result) private bool TryGetPropertyValue(MemberExpression expression, out object result)
{ {
result = null;
if (expression.Expression is MemberExpression nested) if (expression.Expression is MemberExpression nested)
{ {
// Value is passed in as a property on a parent entity // Value is passed in as a property on a parent entity
var container = (nested.Expression as ConstantExpression).Value; var container = (nested.Expression as ConstantExpression)?.Value;
if (container == null)
{
return false;
}
var entity = GetFieldValue(container, nested.Member); var entity = GetFieldValue(container, nested.Member);
result = GetFieldValue(entity, expression.Member); result = GetFieldValue(entity, expression.Member);
return true; return true;
} }
result = null;
return false; return false;
} }
private bool TryGetVariableValue(MemberExpression expression, out object result) private bool TryGetVariableValue(MemberExpression expression, out object result)
{ {
result = null;
// Value is passed in as a variable // Value is passed in as a variable
if (expression.Expression is ConstantExpression nested) if (expression.Expression is ConstantExpression nested)
{ {
@ -172,30 +182,31 @@ private bool TryGetVariableValue(MemberExpression expression, out object result)
return true; return true;
} }
result = null;
return false; return false;
} }
private object GetRightValue(Expression expression) private bool TryGetRightValue(Expression expression, out object value)
{ {
if (TryGetConstantValue(expression, out var constValue)) value = null;
if (TryGetConstantValue(expression, out value))
{ {
return constValue; return true;
} }
var memberExp = expression as MemberExpression; var memberExp = expression as MemberExpression;
if (TryGetPropertyValue(memberExp, out var propValue)) if (TryGetPropertyValue(memberExp, out value))
{ {
return propValue; return true;
} }
if (TryGetVariableValue(memberExp, out var variableValue)) if (TryGetVariableValue(memberExp, out value))
{ {
return variableValue; return true;
} }
return null; return false;
} }
private object GetFieldValue(object entity, MemberInfo member) private object GetFieldValue(object entity, MemberInfo member)
@ -224,8 +235,8 @@ private bool IsNullVariable(Expression expression)
if (expression.NodeType == ExpressionType.MemberAccess && if (expression.NodeType == ExpressionType.MemberAccess &&
expression is MemberExpression member && expression is MemberExpression member &&
TryGetVariableValue(member, out var variableResult) && ((TryGetPropertyValue(member, out var result) && result == null) ||
variableResult == null) (TryGetVariableValue(member, out result) && result == null)))
{ {
return true; return true;
} }
@ -264,7 +275,7 @@ private void ParseContainsExpression(MethodCallExpression expression)
{ {
var list = expression.Object; var list = expression.Object;
if (list != null && list.Type == typeof(string)) if (list != null && (list.Type == typeof(string) || list.Type == typeof(List<string>)))
{ {
ParseStringContains(expression); ParseStringContains(expression);
return; return;
@ -304,7 +315,20 @@ private void ParseEnumerableContains(MethodCallExpression body)
_sb.Append(" IN "); _sb.Append(" IN ");
Visit(list); // hardcode the integer list if it exists to bypass parameter limit
if (item.Type == typeof(int) && TryGetRightValue(list, out var value))
{
var items = (IEnumerable<int>)value;
_sb.Append("(");
_sb.Append(string.Join(", ", items));
_sb.Append(")");
_gotConcreteValue = true;
}
else
{
Visit(list);
}
_sb.Append(")"); _sb.Append(")");
} }

View File

@ -1,7 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Dapper;
using NzbDrone.Core.Datastore; using NzbDrone.Core.Datastore;
using NzbDrone.Core.Messaging.Events; using NzbDrone.Core.Messaging.Events;
using NzbDrone.Core.Movies; using NzbDrone.Core.Movies;
@ -74,28 +73,17 @@ public void DeleteForMovie(int movieId)
Delete(c => c.MovieId == movieId); Delete(c => c.MovieId == movieId);
} }
private IEnumerable<MovieHistory> SelectJoined(SqlBuilder.Template sql)
{
using (var conn = _database.OpenConnection())
{
return conn.Query<MovieHistory, Movie, Profile, MovieHistory>(
sql.RawSql,
(hist, movie, profile) =>
{
hist.Movie = movie;
hist.Movie.Profile = profile;
return hist;
},
sql.Parameters)
.ToList();
}
}
protected override SqlBuilder PagedBuilder() => new SqlBuilder() protected override SqlBuilder PagedBuilder() => new SqlBuilder()
.Join<MovieHistory, Movie>((h, m) => h.MovieId == m.Id) .Join<MovieHistory, Movie>((h, m) => h.MovieId == m.Id)
.Join<Movie, Profile>((m, p) => m.ProfileId == p.Id); .Join<Movie, Profile>((m, p) => m.ProfileId == p.Id);
protected override IEnumerable<MovieHistory> PagedSelector(SqlBuilder.Template sql) => SelectJoined(sql); protected override IEnumerable<MovieHistory> PagedQuery(SqlBuilder sql) =>
_database.QueryJoined<MovieHistory, Movie, Profile>(sql, (hist, movie, profile) =>
{
hist.Movie = movie;
hist.Movie.Profile = profile;
return hist;
});
public MovieHistory MostRecentForMovie(int movieId) public MovieHistory MostRecentForMovie(int movieId)
{ {

View File

@ -40,7 +40,7 @@ public MovieRepository(IMainDatabase database,
_profileRepository = profileRepository; _profileRepository = profileRepository;
} }
protected override SqlBuilder BuilderBase() => new SqlBuilder() protected override SqlBuilder Builder() => new SqlBuilder()
.Join<Movie, Profile>((m, p) => m.ProfileId == p.Id) .Join<Movie, Profile>((m, p) => m.ProfileId == p.Id)
.LeftJoin<Movie, AlternativeTitle>((m, t) => m.Id == t.MovieId) .LeftJoin<Movie, AlternativeTitle>((m, t) => m.Id == t.MovieId)
.LeftJoin<Movie, MovieFile>((m, f) => m.Id == f.MovieId); .LeftJoin<Movie, MovieFile>((m, f) => m.Id == f.MovieId);
@ -65,40 +65,33 @@ private Movie Map(Dictionary<int, Movie> dict, Movie movie, Profile profile, Alt
return movieEntry; return movieEntry;
} }
protected override IEnumerable<Movie> GetResults(SqlBuilder.Template sql) protected override List<Movie> Query(SqlBuilder builder)
{ {
var movieDictionary = new Dictionary<int, Movie>(); var movieDictionary = new Dictionary<int, Movie>();
using (var conn = _database.OpenConnection()) _ = _database.QueryJoined<Movie, Profile, AlternativeTitle, MovieFile>(
{ builder,
conn.Query<Movie, Profile, AlternativeTitle, MovieFile, Movie>( (movie, profile, altTitle, file) => Map(movieDictionary, movie, profile, altTitle, file));
sql.RawSql,
(movie, profile, altTitle, file) => Map(movieDictionary, movie, profile, altTitle, file),
sql.Parameters);
}
return movieDictionary.Values; return movieDictionary.Values.ToList();
} }
public override IEnumerable<Movie> All() public override IEnumerable<Movie> All()
{ {
// the skips the join on profile and populates manually // the skips the join on profile and populates manually
// to avoid repeatedly deserializing the same profile // to avoid repeatedly deserializing the same profile
var noProfileTemplate = $"SELECT /**select**/ FROM {_table} /**leftjoin**/ /**where**/ /**orderby**/"; var builder = new SqlBuilder()
var sql = Builder().AddTemplate(noProfileTemplate).LogQuery(); .LeftJoin<Movie, AlternativeTitle>((m, t) => m.Id == t.MovieId)
.LeftJoin<Movie, MovieFile>((m, f) => m.Id == f.MovieId);
var movieDictionary = new Dictionary<int, Movie>(); var movieDictionary = new Dictionary<int, Movie>();
var profiles = _profileRepository.All().ToDictionary(x => x.Id); var profiles = _profileRepository.All().ToDictionary(x => x.Id);
using (var conn = _database.OpenConnection()) _ = _database.QueryJoined<Movie, AlternativeTitle, MovieFile>(
{ builder,
conn.Query<Movie, AlternativeTitle, MovieFile, Movie>( (movie, altTitle, file) => Map(movieDictionary, movie, profiles[movie.ProfileId], altTitle, file));
sql.RawSql,
(movie, altTitle, file) => Map(movieDictionary, movie, profiles[movie.ProfileId], altTitle, file),
sql.Parameters);
}
return movieDictionary.Values; return movieDictionary.Values.ToList();
} }
public bool MoviePathExists(string path) public bool MoviePathExists(string path)
@ -163,23 +156,24 @@ public List<Movie> MoviesBetweenDates(DateTime start, DateTime end, bool include
return Query(builder); return Query(builder);
} }
public SqlBuilder MoviesWithoutFilesBuilder() => BuilderBase().Where<Movie>(x => x.MovieFileId == 0); public SqlBuilder MoviesWithoutFilesBuilder() => Builder()
.Where<Movie>(x => x.MovieFileId == 0);
public PagingSpec<Movie> MoviesWithoutFiles(PagingSpec<Movie> pagingSpec) public PagingSpec<Movie> MoviesWithoutFiles(PagingSpec<Movie> pagingSpec)
{ {
pagingSpec.Records = GetPagedRecords(MoviesWithoutFilesBuilder().SelectAll(), pagingSpec, PagedSelector); pagingSpec.Records = GetPagedRecords(MoviesWithoutFilesBuilder(), pagingSpec, PagedQuery);
pagingSpec.TotalRecords = GetPagedRecordCount(MoviesWithoutFilesBuilder().SelectCount(), pagingSpec); pagingSpec.TotalRecords = GetPagedRecordCount(MoviesWithoutFilesBuilder().SelectCount(), pagingSpec);
return pagingSpec; return pagingSpec;
} }
public SqlBuilder MoviesWhereCutoffUnmetBuilder(List<QualitiesBelowCutoff> qualitiesBelowCutoff) => BuilderBase() public SqlBuilder MoviesWhereCutoffUnmetBuilder(List<QualitiesBelowCutoff> qualitiesBelowCutoff) => Builder()
.Where<Movie>(x => x.MovieFileId != 0) .Where<Movie>(x => x.MovieFileId != 0)
.Where(BuildQualityCutoffWhereClause(qualitiesBelowCutoff)); .Where(BuildQualityCutoffWhereClause(qualitiesBelowCutoff));
public PagingSpec<Movie> MoviesWhereCutoffUnmet(PagingSpec<Movie> pagingSpec, List<QualitiesBelowCutoff> qualitiesBelowCutoff) public PagingSpec<Movie> MoviesWhereCutoffUnmet(PagingSpec<Movie> pagingSpec, List<QualitiesBelowCutoff> qualitiesBelowCutoff)
{ {
pagingSpec.Records = GetPagedRecords(MoviesWhereCutoffUnmetBuilder(qualitiesBelowCutoff).SelectAll(), pagingSpec, PagedSelector); pagingSpec.Records = GetPagedRecords(MoviesWhereCutoffUnmetBuilder(qualitiesBelowCutoff), pagingSpec, PagedQuery);
pagingSpec.TotalRecords = GetPagedRecordCount(MoviesWhereCutoffUnmetBuilder(qualitiesBelowCutoff).SelectCount(), pagingSpec); pagingSpec.TotalRecords = GetPagedRecordCount(MoviesWhereCutoffUnmetBuilder(qualitiesBelowCutoff).SelectCount(), pagingSpec);
return pagingSpec; return pagingSpec;

View File

@ -1,6 +1,5 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Dapper;
using NzbDrone.Core.CustomFormats; using NzbDrone.Core.CustomFormats;
using NzbDrone.Core.Datastore; using NzbDrone.Core.Datastore;
using NzbDrone.Core.Messaging.Events; using NzbDrone.Core.Messaging.Events;
@ -24,11 +23,11 @@ public ProfileRepository(IMainDatabase database,
_customFormatService = customFormatService; _customFormatService = customFormatService;
} }
protected override IEnumerable<Profile> GetResults(SqlBuilder.Template sql) protected override List<Profile> Query(SqlBuilder builder)
{ {
var cfs = _customFormatService.All().ToDictionary(c => c.Id); var cfs = _customFormatService.All().ToDictionary(c => c.Id);
var profiles = base.GetResults(sql); var profiles = base.Query(builder);
// Do the conversions from Id to full CustomFormat object here instead of in // Do the conversions from Id to full CustomFormat object here instead of in
// CustomFormatIntConverter to remove need to for a static property containing // CustomFormatIntConverter to remove need to for a static property containing

View File

@ -4,7 +4,6 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Dapper" Version="2.0.30" /> <PackageReference Include="Dapper" Version="2.0.30" />
<PackageReference Include="Dapper.SqlBuilder" Version="2.0.30" />
<PackageReference Include="System.Memory" Version="4.5.4" /> <PackageReference Include="System.Memory" Version="4.5.4" />
<PackageReference Include="System.ServiceModel.Syndication" Version="4.7.0" /> <PackageReference Include="System.ServiceModel.Syndication" Version="4.7.0" />
<PackageReference Include="FluentMigrator.Runner" Version="4.0.0-alpha.268" /> <PackageReference Include="FluentMigrator.Runner" Version="4.0.0-alpha.268" />

View File

@ -16,15 +16,18 @@ protected ProviderRepository(IMainDatabase database, IEventAggregator eventAggre
{ {
} }
protected override IEnumerable<TProviderDefinition> GetResults(SqlBuilder.Template sql) protected override List<TProviderDefinition> Query(SqlBuilder builder)
{ {
var type = typeof(TProviderDefinition);
var sql = builder.Select(type).AddSelectTemplate(type);
var results = new List<TProviderDefinition>(); var results = new List<TProviderDefinition>();
using (var conn = _database.OpenConnection()) using (var conn = _database.OpenConnection())
using (var reader = conn.ExecuteReader(sql.RawSql, sql.Parameters)) using (var reader = conn.ExecuteReader(sql.RawSql, sql.Parameters))
{ {
var parser = reader.GetRowParser<TProviderDefinition>(typeof(TProviderDefinition)); var parser = reader.GetRowParser<TProviderDefinition>(typeof(TProviderDefinition));
var settingsIndex = reader.GetOrdinal("Settings"); var settingsIndex = reader.GetOrdinal(nameof(ProviderDefinition.Settings));
var serializerSettings = new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; var serializerSettings = new JsonSerializerOptions { PropertyNameCaseInsensitive = true };
while (reader.Read()) while (reader.Read())