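// Page capture metadata preserved from the mirror this file was taken from:
// 1
// 0
// mirror of https://github.com/Radarr/Radarr.git synced 2024-11-09 12:32:31 +01:00
// Radarr/Marr.Data/DataMapper.cs
// 2013-07-26 22:02:25 -07:00
//
// 968 lines
// 31 KiB
// C#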

/* Copyright (C) 2008 - 2011 Jordan Marr
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library. If not, see <http://www.gnu.org/licenses/>. */
using System;
using System.Collections.Generic;
using System.Text;
using System.Data;
using System.Data.Common;
using System.Reflection;
using System.Collections;
using Marr.Data.Mapping;
using Marr.Data.Parameters;
using Marr.Data.QGen;
using System.Linq.Expressions;
using System.Diagnostics;
namespace Marr.Data
{
/// <summary>
/// This class is the main access point for making database related calls.
/// </summary>
public class DataMapper : IDataMapper
{
    #region - Constructor, Members -

    /// <summary>
    /// The command shared by every call on this mapper; created lazily by <see cref="Command"/>.
    /// </summary>
    private DbCommand _command;

    /// <summary>
    /// Initializes a DataMapper for the given provider type and connection string.
    /// </summary>
    /// <param name="providerName">
    /// The invariant name of an ADO.NET provider registered with <see cref="DbProviderFactories"/>
    /// (e.g. "System.Data.SqlClient").
    /// </param>
    /// <param name="connectionString">The database connection string.</param>
    public DataMapper(string providerName, string connectionString)
        : this(DbProviderFactories.GetFactory(providerName), connectionString)
    { }

    /// <summary>
    /// A database provider agnostic initialization.
    /// The mapper starts in stored procedure mode (see <see cref="SqlMode"/>).
    /// </summary>
    /// <param name="dbProviderFactory">The factory used to create connections and data adapters.</param>
    /// <param name="connectionString">The database connection string.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="dbProviderFactory"/> is null, or <paramref name="connectionString"/> is null or empty.
    /// </exception>
    public DataMapper(DbProviderFactory dbProviderFactory, string connectionString)
    {
        SqlMode = SqlModes.StoredProcedure;

        if (dbProviderFactory == null)
            throw new ArgumentNullException("dbProviderFactory");

        if (string.IsNullOrEmpty(connectionString))
            throw new ArgumentNullException("connectionString");

        ProviderFactory = dbProviderFactory;
        ConnectionString = connectionString;
    }

    /// <summary>Gets the connection string assigned to every connection this mapper creates.</summary>
    public string ConnectionString { get; private set; }

    /// <summary>Gets the factory used to create connections and data adapters.</summary>
    public DbProviderFactory ProviderFactory { get; private set; }

    /// <summary>
    /// Creates a new connection (using <see cref="ConnectionString"/>) and a command bound to it,
    /// with the command's CommandType taken from the current <see cref="SqlMode"/>.
    /// </summary>
    private DbCommand CreateNewCommand()
    {
        DbConnection conn = ProviderFactory.CreateConnection();
        conn.ConnectionString = ConnectionString;
        DbCommand cmd = conn.CreateCommand();
        SetSqlMode(cmd);
        return cmd;
    }

    /// <summary>
    /// Creates a new command utilizing the connection string with a given SQL command.
    /// </summary>
    private DbCommand CreateNewCommand(string sql)
    {
        DbCommand cmd = CreateNewCommand();
        cmd.CommandText = sql;
        return cmd;
    }

    /// <summary>
    /// Gets the shared DbCommand, creating it (and its connection) on first access.
    /// Its CommandType is re-synchronized with <see cref="SqlMode"/> on every access.
    /// </summary>
    public DbCommand Command
    {
        get
        {
            // Lazy load
            if (_command == null)
                _command = CreateNewCommand();
            else
                SetSqlMode(_command); // Set SqlMode every time.

            return _command;
        }
    }

    #endregion

    #region - Parameters -

    /// <summary>Gets the parameter collection of the shared <see cref="Command"/>.</summary>
    public DbParameterCollection Parameters
    {
        get { return Command.Parameters; }
    }

    /// <summary>
    /// Creates a <see cref="ParameterChainMethods"/> for the shared <see cref="Command"/> with the given
    /// name and value (presumably adding the parameter and allowing further configuration — see that type).
    /// </summary>
    public ParameterChainMethods AddParameter(string name, object value)
    {
        return new ParameterChainMethods(Command, name, value);
    }

    /// <summary>
    /// Adds an existing parameter to the shared <see cref="Command"/>.
    /// A null value is converted to <see cref="DBNull.Value"/> before the parameter is added.
    /// </summary>
    /// <param name="parameter">The parameter to add.</param>
    /// <returns>The same parameter instance.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="parameter"/> is null.</exception>
    public IDbDataParameter AddParameter(IDbDataParameter parameter)
    {
        if (parameter == null)
            throw new ArgumentNullException("parameter");

        // Convert null values to DBNull.Value
        if (parameter.Value == null)
            parameter.Value = DBNull.Value;

        Parameters.Add(parameter);
        return parameter;
    }

    #endregion

    #region - SP / SQL Mode -

    /// <summary>
    /// Gets or sets a value that determines whether the DataMapper will
    /// use a stored procedure or a sql text command to access
    /// the database. The default is stored procedure.
    /// </summary>
    public SqlModes SqlMode { get; set; }

    /// <summary>
    /// Sets the DbCommand object's CommandType to the current SqlMode
    /// (StoredProcedure for <c>SqlModes.StoredProcedure</c>, Text otherwise).
    /// </summary>
    /// <param name="command">The DbCommand object we are modifying.</param>
    /// <returns>Returns the same DbCommand that was passed in.</returns>
    private DbCommand SetSqlMode(DbCommand command)
    {
        if (SqlMode == SqlModes.StoredProcedure)
            command.CommandType = CommandType.StoredProcedure;
        else
            command.CommandType = CommandType.Text;

        return command;
    }

    #endregion

    #region - ExecuteScalar, ExecuteNonQuery, ExecuteReader -

    /// <summary>
    /// Executes a query or stored procedure (depending on <see cref="SqlMode"/>) that returns a scalar value.
    /// </summary>
    /// <param name="sql">The SQL command or stored procedure name to execute.</param>
    /// <returns>The value returned by DbCommand.ExecuteScalar.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    public object ExecuteScalar(string sql)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        Command.CommandText = sql;

        try
        {
            OpenConnection();
            return Command.ExecuteScalar();
        }
        finally
        {
            CloseConnection();
        }
    }

    /// <summary>
    /// Executes a non query (depending on <see cref="SqlMode"/>, SQL text or a stored procedure).
    /// </summary>
    /// <param name="sql">The SQL command or stored procedure name to execute.</param>
    /// <returns>The value returned by DbCommand.ExecuteNonQuery.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    public int ExecuteNonQuery(string sql)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        Command.CommandText = sql;

        try
        {
            OpenConnection();
            return Command.ExecuteNonQuery();
        }
        finally
        {
            CloseConnection();
        }
    }

    /// <summary>
    /// Executes a DataReader and projects every row through a Func delegate.
    /// (Note that reader.Read() will be called automatically).
    /// </summary>
    /// <typeparam name="TResult">The type that will be returned in the result set.</typeparam>
    /// <param name="sql">The sql statement that will be executed.</param>
    /// <param name="func">The function that builds one TResult from the current row.</param>
    /// <returns>A fully materialized list of TResult (the reader is closed before returning).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    public IEnumerable<TResult> ExecuteReader<TResult>(string sql, Func<DbDataReader, TResult> func)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        Command.CommandText = sql;

        try
        {
            OpenConnection();

            var list = new List<TResult>();
            using (DbDataReader reader = Command.ExecuteReader())
            {
                while (reader.Read())
                {
                    list.Add(func(reader));
                }
            }

            return list;
        }
        finally
        {
            CloseConnection();
        }
    }

    /// <summary>
    /// Executes a DataReader and invokes an Action delegate once per row.
    /// (Note that reader.Read() will be called automatically).
    /// </summary>
    /// <param name="sql">The sql statement that will be executed.</param>
    /// <param name="action">The delegate that will work with the current row.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    public void ExecuteReader(string sql, Action<DbDataReader> action)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        Command.CommandText = sql;

        try
        {
            OpenConnection();

            using (DbDataReader reader = Command.ExecuteReader())
            {
                while (reader.Read())
                {
                    action(reader);
                }
            }
        }
        finally
        {
            CloseConnection();
        }
    }

    #endregion

    #region - DataSets -

    /// <summary>
    /// Detaches the shared <see cref="Command"/> from a data adapter and disposes the adapter.
    /// The shared command is reused by later calls, so it must not be disposed along with the
    /// adapter (an adapter implementation may dispose the commands attached to it).
    /// </summary>
    /// <param name="adapter">The adapter to release; null is ignored.</param>
    private static void ReleaseAdapter(DbDataAdapter adapter)
    {
        if (adapter == null)
            return;

        adapter.SelectCommand = null;
        adapter.InsertCommand = null;
        adapter.UpdateCommand = null;
        adapter.DeleteCommand = null;
        adapter.Dispose();
    }

    /// <summary>
    /// Fills a new DataSet with the results of <paramref name="sql"/>.
    /// </summary>
    public DataSet GetDataSet(string sql)
    {
        return GetDataSet(sql, new DataSet(), null);
    }

    /// <summary>
    /// Fills <paramref name="ds"/> (or a new DataSet when it is null) with the results of <paramref name="sql"/>.
    /// </summary>
    /// <param name="sql">The SQL command or stored procedure name to execute.</param>
    /// <param name="ds">The DataSet to fill; a new one is created when null.</param>
    /// <param name="tableName">Optional source table name passed to DbDataAdapter.Fill.</param>
    /// <returns>The filled DataSet.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    public DataSet GetDataSet(string sql, DataSet ds, string tableName)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        DbDataAdapter adapter = null;
        try
        {
            adapter = ProviderFactory.CreateDataAdapter();
            Command.CommandText = sql;
            adapter.SelectCommand = Command;

            if (ds == null)
                ds = new DataSet();

            OpenConnection();

            if (string.IsNullOrEmpty(tableName))
                adapter.Fill(ds);
            else
                adapter.Fill(ds, tableName);

            return ds;
        }
        finally
        {
            ReleaseAdapter(adapter);
            CloseConnection(); // Clears parameters
        }
    }

    /// <summary>
    /// Fills a new DataTable with the results of <paramref name="sql"/>.
    /// </summary>
    public DataTable GetDataTable(string sql)
    {
        return GetDataTable(sql, null, null);
    }

    /// <summary>
    /// Fills <paramref name="dt"/> (or a new DataTable when it is null) with the results of <paramref name="sql"/>.
    /// </summary>
    /// <param name="sql">The SQL command or stored procedure name to execute.</param>
    /// <param name="dt">The DataTable to fill; a new one is created when null.</param>
    /// <param name="tableName">When not empty, assigned to the table's TableName after filling.</param>
    /// <returns>The filled DataTable.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    public DataTable GetDataTable(string sql, DataTable dt, string tableName)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        DbDataAdapter adapter = null;
        try
        {
            adapter = ProviderFactory.CreateDataAdapter();
            Command.CommandText = sql;
            adapter.SelectCommand = Command;

            if (dt == null)
                dt = new DataTable();

            // Open explicitly (as GetDataSet does) so OpeningConnection pairs with the
            // ClosingConnection raised by CloseConnection below.
            OpenConnection();
            adapter.Fill(dt);

            if (!string.IsNullOrEmpty(tableName))
                dt.TableName = tableName;

            return dt;
        }
        finally
        {
            ReleaseAdapter(adapter);
            CloseConnection(); // Clears parameters
        }
    }

    /// <summary>
    /// Sends the changes in <paramref name="ds"/> to the database using <paramref name="sql"/> as the update command.
    /// </summary>
    /// <returns>The value returned by DbDataAdapter.Update.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty, or <paramref name="ds"/> is null.</exception>
    public int UpdateDataSet(DataSet ds, string sql)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        if (ds == null)
            throw new ArgumentNullException("ds", "DataSet cannot be null.");

        DbDataAdapter adapter = null;
        try
        {
            adapter = ProviderFactory.CreateDataAdapter();
            adapter.UpdateCommand = Command;
            adapter.UpdateCommand.CommandText = sql;
            return adapter.Update(ds);
        }
        finally
        {
            // The shared Command is detached (not disposed) so later calls can keep using it.
            ReleaseAdapter(adapter);
            CloseConnection(); // Clears parameters
        }
    }

    /// <summary>
    /// Inserts the rows of <paramref name="table"/> using <paramref name="insertSP"/>, without refreshing rows from the command's results.
    /// </summary>
    public int InsertDataTable(DataTable table, string insertSP)
    {
        return InsertDataTable(table, insertSP, UpdateRowSource.None);
    }

    /// <summary>
    /// Inserts the rows of <paramref name="dt"/> using <paramref name="sql"/> as the insert command.
    /// </summary>
    /// <param name="dt">The table whose rows are inserted.</param>
    /// <param name="sql">The SQL command or stored procedure name used for inserts.</param>
    /// <param name="updateRowSource">How command results are applied back to the inserted rows.</param>
    /// <returns>The value returned by DbDataAdapter.Update.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty, or <paramref name="dt"/> is null.</exception>
    public int InsertDataTable(DataTable dt, string sql, UpdateRowSource updateRowSource)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        if (dt == null)
            throw new ArgumentNullException("dt", "DataTable cannot be null.");

        DbDataAdapter adapter = null;
        try
        {
            adapter = ProviderFactory.CreateDataAdapter();
            adapter.InsertCommand = Command;
            adapter.InsertCommand.CommandText = sql;
            adapter.InsertCommand.UpdatedRowSource = updateRowSource;
            return adapter.Update(dt);
        }
        finally
        {
            // The shared Command is detached (not disposed) so later calls can keep using it.
            ReleaseAdapter(adapter);
            CloseConnection(); // Clears parameters
        }
    }

    /// <summary>
    /// Deletes the rows of <paramref name="dt"/> using <paramref name="sql"/> as the delete command.
    /// </summary>
    /// <returns>The value returned by DbDataAdapter.Update.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty, or <paramref name="dt"/> is null.</exception>
    public int DeleteDataTable(DataTable dt, string sql)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A SQL query or stored procedure name is required");

        if (dt == null)
            throw new ArgumentNullException("dt", "DataTable cannot be null.");

        DbDataAdapter adapter = null;
        try
        {
            adapter = ProviderFactory.CreateDataAdapter();
            adapter.DeleteCommand = Command;
            adapter.DeleteCommand.CommandText = sql;
            return adapter.Update(dt);
        }
        finally
        {
            // The shared Command is detached (not disposed) so later calls can keep using it.
            ReleaseAdapter(adapter);
            CloseConnection(); // Clears parameters
        }
    }

    #endregion

    #region - Find -

    /// <summary>
    /// Returns a newly created entity of type T loaded from the first row of <paramref name="sql"/>,
    /// or default(T) when there are no rows.
    /// </summary>
    public T Find<T>(string sql)
    {
        return Find<T>(sql, default(T));
    }

    /// <summary>
    /// Returns an entity of type T loaded from the first row of the result set.
    /// </summary>
    /// <typeparam name="T">The type of entity that is to be instantiated and loaded with values.</typeparam>
    /// <param name="sql">The SQL command to execute.</param>
    /// <param name="ent">
    /// An existing entity to load; when null a new entity is created.
    /// For simple types the first column of the first row is returned instead.
    /// </param>
    /// <returns>
    /// The loaded entity, or <paramref name="ent"/> unchanged when the result set is empty.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    public T Find<T>(string sql, T ent)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A stored procedure name has not been specified for 'Find'.");

        Type entityType = typeof(T);
        Command.CommandText = sql;

        MapRepository repository = MapRepository.Instance;
        ColumnMapCollection mappings = repository.GetColumns(entityType);

        bool isSimpleType = DataHelper.IsSimpleType(typeof(T));

        try
        {
            OpenConnection();
            var mappingHelper = new MappingHelper(this);

            using (DbDataReader reader = Command.ExecuteReader())
            {
                if (reader.Read())
                {
                    if (isSimpleType)
                    {
                        return mappingHelper.LoadSimpleValueFromFirstColumn<T>(reader);
                    }
                    else
                    {
                        if (ent == null)
                            ent = (T)mappingHelper.CreateAndLoadEntity<T>(mappings, reader, false);
                        else
                            mappingHelper.LoadExistingEntity(mappings, reader, ent, false);
                    }
                }
            }
        }
        finally
        {
            CloseConnection();
        }

        return ent;
    }

    #endregion

    #region - Query -

    /// <summary>
    /// Creates a QueryBuilder that allows you to build a query.
    /// </summary>
    /// <typeparam name="T">The type of object that will be queried.</typeparam>
    /// <returns>Returns a QueryBuilder of T.</returns>
    public QueryBuilder<T> Query<T>()
    {
        var dialect = QueryFactory.CreateDialect(this);
        return new QueryBuilder<T>(this, dialect);
    }

    /// <summary>
    /// Returns the results of a query.
    /// Uses a List of type T to return the data.
    /// </summary>
    /// <returns>Returns a list of the specified type.</returns>
    public List<T> Query<T>(string sql)
    {
        return (List<T>)Query<T>(sql, new List<T>());
    }

    /// <summary>
    /// Adds one item per result row of <paramref name="sql"/> to <paramref name="entityList"/>.
    /// </summary>
    /// <returns>The same collection that was passed in.</returns>
    public ICollection<T> Query<T>(string sql, ICollection<T> entityList)
    {
        return Query<T>(sql, entityList, false);
    }

    /// <summary>
    /// Adds one item per result row to <paramref name="entityList"/>: for simple types the first
    /// column's value, otherwise a new entity loaded from the row's mapped columns.
    /// </summary>
    /// <param name="sql">The SQL command or stored procedure name to execute.</param>
    /// <param name="entityList">The collection that receives the results.</param>
    /// <param name="useAltName">Passed through to MappingHelper.CreateAndLoadEntity.</param>
    /// <returns>The same collection that was passed in.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entityList"/> is null, or <paramref name="sql"/> is null or empty.</exception>
    internal ICollection<T> Query<T>(string sql, ICollection<T> entityList, bool useAltName)
    {
        if (entityList == null)
            throw new ArgumentNullException("entityList", "ICollection instance cannot be null.");

        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A query or stored procedure has not been specified for 'Query'.");

        var mappingHelper = new MappingHelper(this);
        Type entityType = typeof(T);
        Command.CommandText = sql;
        ColumnMapCollection mappings = MapRepository.Instance.GetColumns(entityType);

        bool isSimpleType = DataHelper.IsSimpleType(typeof(T));

        try
        {
            OpenConnection();
            using (DbDataReader reader = Command.ExecuteReader())
            {
                while (reader.Read())
                {
                    if (isSimpleType)
                    {
                        entityList.Add(mappingHelper.LoadSimpleValueFromFirstColumn<T>(reader));
                    }
                    else
                    {
                        entityList.Add((T)mappingHelper.CreateAndLoadEntity<T>(mappings, reader, useAltName));
                    }
                }
            }
        }
        finally
        {
            CloseConnection();
        }

        return entityList;
    }

    #endregion

    #region - Query to Graph -

    /// <summary>
    /// Queries a view that joins multiple tables and returns the root entities of the object graph.
    /// </summary>
    public List<T> QueryToGraph<T>(string sql)
    {
        return (List<T>)QueryToGraph<T>(sql, new List<T>());
    }

    /// <summary>
    /// Queries a view that joins multiple tables and loads the root entities into <paramref name="entityList"/>.
    /// </summary>
    /// <param name="sql">The SQL command or stored procedure name to execute.</param>
    /// <param name="entityList">The root collection; it is cast to <see cref="IList"/>, so it must implement it.</param>
    public ICollection<T> QueryToGraph<T>(string sql, ICollection<T> entityList)
    {
        EntityGraph graph = new EntityGraph(typeof(T), (IList)entityList);
        return QueryToGraph<T>(sql, graph, new List<MemberInfo>());
    }

    /// <summary>
    /// Queries a view that joins multiple tables and returns an object graph.
    /// The whole graph is traversed for every record, so one record can produce several entities.
    /// </summary>
    /// <typeparam name="T">The root entity type.</typeparam>
    /// <param name="sql">The SQL command or stored procedure name to execute.</param>
    /// <param name="graph">Coordinates loading all objects in the graph.</param>
    /// <param name="childrenToLoad">
    /// When non-empty, only the root and the listed relationships are loaded; when empty, every relationship is loaded.
    /// </param>
    /// <returns>The graph's root list.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sql"/> is null or empty.</exception>
    internal ICollection<T> QueryToGraph<T>(string sql, EntityGraph graph, List<MemberInfo> childrenToLoad)
    {
        if (string.IsNullOrEmpty(sql))
            throw new ArgumentNullException("sql", "A query or stored procedure has not been specified for 'QueryToGraph'.");

        var mappingHelper = new MappingHelper(this);
        Command.CommandText = sql;

        try
        {
            OpenConnection();
            using (DbDataReader reader = Command.ExecuteReader())
            {
                while (reader.Read())
                {
                    // The entire EntityGraph is traversed for each record,
                    // and multiple entities are created from each view record.
                    foreach (EntityGraph lvl in graph)
                    {
                        if (lvl.IsParentReference)
                        {
                            // A child specified a circular reference to its previously loaded parent
                            lvl.AddParentReference();
                        }
                        else if (childrenToLoad.Count > 0 && !lvl.IsRoot && !childrenToLoad.ContainsMember(lvl.Member))
                        {
                            // A list of relationships-to-load was specified and this relationship was not included
                            continue;
                        }
                        else if (lvl.IsNewGroup(reader))
                        {
                            // Create a new entity with the data reader
                            var newEntity = mappingHelper.CreateAndLoadEntity(lvl.EntityType, lvl.Columns, reader, true);

                            // Add entity to the appropriate place in the object graph
                            lvl.AddEntity(newEntity);
                        }
                    }
                }
            }
        }
        finally
        {
            CloseConnection();
        }

        return (ICollection<T>)graph.RootList;
    }

    #endregion

    #region - Update -

    /// <summary>
    /// Creates an UpdateQueryBuilder that allows you to build an update statement.
    /// Note: You must manually call the Execute() chaining method to run the query.
    /// </summary>
    public UpdateQueryBuilder<T> Update<T>()
    {
        return new UpdateQueryBuilder<T>(this);
    }

    /// <summary>
    /// Updates <paramref name="entity"/> in the rows matching <paramref name="filter"/>.
    /// </summary>
    /// <returns>The result of UpdateQueryBuilder.Execute() (presumably the affected row count).</returns>
    public int Update<T>(T entity, Expression<Func<T, bool>> filter)
    {
        return Update<T>()
            .Entity(entity)
            .Where(filter)
            .Execute();
    }

    /// <summary>
    /// Updates <paramref name="entity"/> in the rows of <paramref name="tableName"/> matching <paramref name="filter"/>.
    /// </summary>
    /// <returns>The result of UpdateQueryBuilder.Execute() (presumably the affected row count).</returns>
    public int Update<T>(string tableName, T entity, Expression<Func<T, bool>> filter)
    {
        return Update<T>()
            .TableName(tableName)
            .Entity(entity)
            .Where(filter)
            .Execute();
    }

    /// <summary>
    /// Updates <paramref name="entity"/> using the given sql update statement.
    /// </summary>
    /// <returns>The result of UpdateQueryBuilder.Execute() (presumably the affected row count).</returns>
    public int Update<T>(T entity, string sql)
    {
        return Update<T>()
            .Entity(entity)
            .QueryText(sql)
            .Execute();
    }

    #endregion

    #region - Insert -

    /// <summary>
    /// Creates an InsertQueryBuilder that allows you to build an insert statement.
    /// This method gives you the flexibility to manually configure all options of your insert statement.
    /// Note: You must manually call the Execute() chaining method to run the query.
    /// </summary>
    public InsertQueryBuilder<T> Insert<T>()
    {
        return new InsertQueryBuilder<T>(this);
    }

    /// <summary>
    /// Generates and executes an insert query for the given entity.
    /// This overload will automatically run an identity query if you have mapped an auto-incrementing column,
    /// and if an identity query has been implemented for your current database dialect.
    /// </summary>
    public object Insert<T>(T entity)
    {
        var columns = MapRepository.Instance.GetColumns(typeof(T));
        var dialect = QueryFactory.CreateDialect(this);
        var builder = Insert<T>().Entity(entity);

        // If an auto-increment column exists and this dialect has an identity query...
        if (columns.Exists(c => c.ColumnInfo.IsAutoIncrement) && dialect.HasIdentityQuery)
        {
            builder.GetIdentity();
        }

        return builder.Execute();
    }

    /// <summary>
    /// Generates and executes an insert query for the given entity against <paramref name="tableName"/>.
    /// This overload will automatically run an identity query if you have mapped an auto-incrementing column,
    /// and if an identity query has been implemented for your current database dialect.
    /// </summary>
    public object Insert<T>(string tableName, T entity)
    {
        var columns = MapRepository.Instance.GetColumns(typeof(T));
        var dialect = QueryFactory.CreateDialect(this);
        var builder = Insert<T>().Entity(entity).TableName(tableName);

        // If an auto-increment column exists and this dialect has an identity query...
        if (columns.Exists(c => c.ColumnInfo.IsAutoIncrement) && dialect.HasIdentityQuery)
        {
            builder.GetIdentity();
        }

        return builder.Execute();
    }

    /// <summary>
    /// Executes an insert query for the given entity using the given sql insert statement.
    /// This overload will automatically run an identity query if you have mapped an auto-incrementing column,
    /// and if an identity query has been implemented for your current database dialect.
    /// </summary>
    public object Insert<T>(T entity, string sql)
    {
        var columns = MapRepository.Instance.GetColumns(typeof(T));
        var dialect = QueryFactory.CreateDialect(this);
        var builder = Insert<T>().Entity(entity).QueryText(sql);

        // If an auto-increment column exists and this dialect has an identity query...
        if (columns.Exists(c => c.ColumnInfo.IsAutoIncrement) && dialect.HasIdentityQuery)
        {
            builder.GetIdentity();
        }

        return builder.Execute();
    }

    #endregion

    #region - Delete -

    /// <summary>
    /// Deletes the rows of T's mapped table that match <paramref name="filter"/>.
    /// </summary>
    /// <returns>The number of rows affected.</returns>
    public int Delete<T>(Expression<Func<T, bool>> filter)
    {
        return Delete<T>(null, filter);
    }

    /// <summary>
    /// Deletes the rows matching <paramref name="filter"/>.
    /// The statement is generated and executed as SQL text; the previous <see cref="SqlMode"/>
    /// is restored afterwards, even when generation or execution throws.
    /// </summary>
    /// <param name="tableName">Table name; when null it is resolved from T's mapping (see NOTE in the body).</param>
    /// <param name="filter">The filter translated into the WHERE clause.</param>
    /// <returns>The number of rows affected.</returns>
    public int Delete<T>(string tableName, Expression<Func<T, bool>> filter)
    {
        // Remember sql mode: the generated statement must run as text.
        var previousSqlMode = SqlMode;
        SqlMode = SqlModes.Text;

        try
        {
            if (tableName == null)
            {
                tableName = MapRepository.Instance.GetTableName(typeof(T));
            }

            // NOTE(review): tableName is resolved above but never applied — the query below is built
            // from new Table(typeof(T)), presumably T's mapped table. Confirm whether a caller-supplied
            // table name is meant to override it.
            var dialect = QueryFactory.CreateDialect(this);
            TableCollection tables = new TableCollection();
            tables.Add(new Table(typeof(T)));

            var where = new WhereBuilder<T>(Command, dialect, filter, tables, false, false);
            IQuery query = QueryFactory.CreateDeleteQuery(dialect, tables[0], where.ToString());
            Command.CommandText = query.Generate();

            try
            {
                OpenConnection();
                return Command.ExecuteNonQuery();
            }
            finally
            {
                CloseConnection();
            }
        }
        finally
        {
            // Return to previous sql mode
            SqlMode = previousSqlMode;
        }
    }

    #endregion

    #region - Events -

    /// <summary>Raised by OpenConnection before the connection is opened.</summary>
    public event EventHandler OpeningConnection;

    /// <summary>Raised by CloseConnection before parameters are cleared; all handlers are unbound afterwards.</summary>
    public event EventHandler ClosingConnection;

    #endregion

    #region - Connections / Transactions -

    /// <summary>Raises <see cref="OpeningConnection"/>.</summary>
    protected virtual void OnOpeningConnection()
    {
        if (OpeningConnection != null)
            OpeningConnection(this, EventArgs.Empty);
    }

    /// <summary>Writes the trace log (when enabled) and raises <see cref="ClosingConnection"/>.</summary>
    protected virtual void OnClosingConnection()
    {
        WriteToTraceLog();

        if (ClosingConnection != null)
            ClosingConnection(this, EventArgs.Empty);
    }

    /// <summary>
    /// Raises <see cref="OpeningConnection"/>, then opens the shared command's connection if it is not already open.
    /// </summary>
    protected internal void OpenConnection()
    {
        OnOpeningConnection();

        if (Command.Connection.State != ConnectionState.Open)
            Command.Connection.Open();
    }

    /// <summary>
    /// Raises <see cref="ClosingConnection"/>, clears the shared command's parameters and text,
    /// closes the connection unless a transaction is attached, and unbinds all event handlers.
    /// </summary>
    protected internal void CloseConnection()
    {
        OnClosingConnection();

        Command.Parameters.Clear();
        Command.CommandText = string.Empty;

        if (Command.Transaction == null)
            Command.Connection.Close(); // Only close if no transaction is present

        UnbindEvents();
    }

    /// <summary>
    /// When trace logging is enabled on <see cref="MapRepository"/>, writes the command type, text
    /// and parameters of the shared command to <see cref="Trace"/>.
    /// </summary>
    private void WriteToTraceLog()
    {
        if (!MapRepository.Instance.EnableTraceLogging)
            return;

        var sb = new StringBuilder();
        sb.AppendLine();
        sb.AppendLine("==== Begin Query Trace ====");
        sb.AppendLine();
        sb.AppendLine("QUERY TYPE:");
        sb.AppendLine(Command.CommandType.ToString());
        sb.AppendLine();
        sb.AppendLine("QUERY TEXT:");
        sb.AppendLine(Command.CommandText);
        sb.AppendLine();
        sb.AppendLine("PARAMETERS:");

        foreach (IDbDataParameter p in Parameters)
        {
            // Quote string values so they stand out from other value types in the trace.
            object val = p.Value is string ? string.Format("\"{0}\"", p.Value) : p.Value;
            sb.AppendFormat("{0} = [{1}]", p.ParameterName, val ?? "NULL").AppendLine();
        }

        sb.AppendLine();
        sb.AppendLine("==== End Query Trace ====");
        sb.AppendLine();

        Trace.Write(sb.ToString());
    }

    /// <summary>Removes every OpeningConnection and ClosingConnection handler.</summary>
    private void UnbindEvents()
    {
        OpeningConnection = null;
        ClosingConnection = null;
    }

    /// <summary>
    /// Opens the connection and attaches a new transaction to the shared command.
    /// While it is attached, CloseConnection leaves the connection open; call
    /// <see cref="Commit"/> or <see cref="RollBack"/> to finish it.
    /// </summary>
    public void BeginTransaction()
    {
        OpenConnection();
        DbTransaction trans = Command.Connection.BeginTransaction();
        Command.Transaction = trans;
    }

    /// <summary>
    /// Rolls back the attached transaction (if any), then detaches and disposes it and closes the connection.
    /// </summary>
    public void RollBack()
    {
        try
        {
            if (Command.Transaction != null)
                Command.Transaction.Rollback();
        }
        finally
        {
            ReleaseTransaction();
            Command.Connection.Close();
        }
    }

    /// <summary>
    /// Commits the attached transaction (if any), then detaches and disposes it and closes the connection.
    /// </summary>
    public void Commit()
    {
        try
        {
            if (Command.Transaction != null)
                Command.Transaction.Commit();
        }
        finally
        {
            ReleaseTransaction();
            Command.Connection.Close();
        }
    }

    /// <summary>
    /// Detaches the transaction from the shared command and disposes it. Without this, the completed
    /// transaction stays attached and CloseConnection (which only closes when no transaction is
    /// present) would never close the connection again.
    /// </summary>
    private void ReleaseTransaction()
    {
        DbTransaction transaction = Command.Transaction;
        if (transaction == null)
            return;

        Command.Transaction = null;
        transaction.Dispose();
    }

    #endregion

    #region - IDisposable Members -

    /// <summary>
    /// Disposes the shared command together with its transaction and connection, if a command was ever created.
    /// Calling it more than once is harmless.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this); // In case a derived class implements a finalizer
    }

    /// <summary>
    /// Releases the shared command, its transaction and its connection.
    /// </summary>
    /// <param name="disposing">True when called from <see cref="Dispose()"/>; managed resources are only released then.</param>
    protected virtual void Dispose(bool disposing)
    {
        // Use the field, not the Command property: the property would lazily create a brand-new
        // connection and command just to dispose them (e.g. on a second Dispose call).
        if (!disposing || _command == null)
            return;

        if (_command.Transaction != null)
        {
            _command.Transaction.Dispose();
            _command.Transaction = null;
        }

        if (_command.Connection != null)
        {
            _command.Connection.Dispose();
            _command.Connection = null;
        }

        _command.Dispose();
        _command = null;
    }

    #endregion
}
}