Class1.cs 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. using System;
  2. using System.Data;
  3. using System.Data.Common;
  4. using System.Linq;
  5. using System.Linq.Expressions;
  6. using System.Management;
  7. using System.Text;
  namespace EntityFramework.Extensions
  {
      /// <summary>
      /// An extensions class for batch queries.
      /// </summary>
      /// <remarks>
      /// Each operation resolves the target table from the entity mapping of the source set, turns the
      /// filter query into an inner SELECT of the key columns only, and runs a single statement of the form
      /// <c>DELETE|UPDATE table ... FROM table AS j0 INNER JOIN (inner select) AS j1 ON (j0.key = j1.key ...)</c>
      /// directly on the store connection.
      /// </remarks>
      public static class BatchExtensions
      {
          /// <summary>
          /// Executes a delete statement using the query to filter the rows to be deleted.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to delete from.</param>
          /// <param name="query">The IQueryable used to generate the where clause for the delete statement.</param>
          /// <returns>The number of rows deleted.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="query"/> is null.</exception>
          /// <exception cref="ArgumentException">
          /// The source has no ObjectContext or no entity mapping, or <paramref name="query"/> is not an ObjectQuery or DbQuery.
          /// </exception>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Delete<TEntity>(
              this ObjectSet<TEntity> source,
              IQueryable<TEntity> query)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (query == null)
                  throw new ArgumentNullException("query");

              ObjectContext objectContext = source.Context;
              if (objectContext == null)
                  throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");

              EntityMap entityMap = source.GetEntityMap<TEntity>();
              if (entityMap == null)
                  throw new ArgumentException("Could not load the entity mapping information for the source ObjectSet.", "source");

              ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
              if (objectQuery == null)
                  throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");

              return Delete(objectContext, entityMap, objectQuery);
          }

          /// <summary>
          /// Executes a delete statement using an expression to filter the rows to be deleted.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to delete from.</param>
          /// <param name="filterExpression">The filter expression used to generate the where clause for the delete statement.</param>
          /// <returns>The number of rows deleted.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="filterExpression"/> is null.</exception>
          /// <example>Delete all users with email domain @test.com.
          /// <code><![CDATA[
          /// var db = new TrackerEntities();
          /// string emailDomain = "@test.com";
          /// int count = db.Users.Delete(u => u.Email.EndsWith(emailDomain));
          /// ]]></code>
          /// </example>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Delete<TEntity>(
              this ObjectSet<TEntity> source,
              Expression<Func<TEntity, bool>> filterExpression)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (filterExpression == null)
                  throw new ArgumentNullException("filterExpression");

              return source.Delete(source.Where(filterExpression));
          }

          /// <summary>
          /// Executes a delete statement using the query to filter the rows to be deleted.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to delete from.</param>
          /// <param name="query">The IQueryable used to generate the where clause for the delete statement.</param>
          /// <returns>The number of rows deleted.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="query"/> is null.</exception>
          /// <exception cref="ArgumentException">
          /// The source cannot be converted to an ObjectQuery, has no ObjectContext or no entity mapping,
          /// or <paramref name="query"/> is not an ObjectQuery or DbQuery.
          /// </exception>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Delete<TEntity>(
              this DbSet<TEntity> source,
              IQueryable<TEntity> query)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (query == null)
                  throw new ArgumentNullException("query");

              // A DbSet does not expose its ObjectContext directly; go through its underlying ObjectQuery.
              ObjectQuery<TEntity> sourceQuery = source.ToObjectQuery();
              if (sourceQuery == null)
                  throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "source");

              ObjectContext objectContext = sourceQuery.Context;
              if (objectContext == null)
                  throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");

              EntityMap entityMap = sourceQuery.GetEntityMap<TEntity>();
              if (entityMap == null)
                  throw new ArgumentException("Could not load the entity mapping information for the source ObjectSet.", "source");

              ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
              if (objectQuery == null)
                  throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");

              return Delete(objectContext, entityMap, objectQuery);
          }

          /// <summary>
          /// Executes a delete statement using an expression to filter the rows to be deleted.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to delete from.</param>
          /// <param name="filterExpression">The filter expression used to generate the where clause for the delete statement.</param>
          /// <returns>The number of rows deleted.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="filterExpression"/> is null.</exception>
          /// <example>Delete all users with email domain @test.com.
          /// <code><![CDATA[
          /// var db = new TrackerContext();
          /// string emailDomain = "@test.com";
          /// int count = db.Users.Delete(u => u.Email.EndsWith(emailDomain));
          /// ]]></code>
          /// </example>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Delete<TEntity>(
              this DbSet<TEntity> source,
              Expression<Func<TEntity, bool>> filterExpression)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (filterExpression == null)
                  throw new ArgumentNullException("filterExpression");

              return source.Delete(source.Where(filterExpression));
          }

          /// <summary>
          /// Executes an update statement using the query to filter the rows to be updated.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to update.</param>
          /// <param name="query">The query used to generate the where clause.</param>
          /// <param name="updateExpression">The MemberInitExpression used to indicate what is updated.</param>
          /// <returns>The number of rows updated.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/>, <paramref name="query"/> or <paramref name="updateExpression"/> is null.
          /// </exception>
          /// <exception cref="ArgumentException">
          /// The source has no ObjectContext or no entity mapping, <paramref name="query"/> is not an ObjectQuery
          /// or DbQuery, or <paramref name="updateExpression"/> is not a member initializer of mapped properties.
          /// </exception>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Update<TEntity>(
              this ObjectSet<TEntity> source,
              IQueryable<TEntity> query,
              Expression<Func<TEntity, TEntity>> updateExpression)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (query == null)
                  throw new ArgumentNullException("query");
              if (updateExpression == null)
                  throw new ArgumentNullException("updateExpression");

              ObjectContext objectContext = source.Context;
              if (objectContext == null)
                  throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");

              EntityMap entityMap = source.GetEntityMap<TEntity>();
              if (entityMap == null)
                  throw new ArgumentException("Could not load the entity mapping information for the source ObjectSet.", "source");

              ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
              if (objectQuery == null)
                  throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");

              return Update(objectContext, entityMap, objectQuery, updateExpression);
          }

          /// <summary>
          /// Executes an update statement using an expression to filter the rows that are updated.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to update.</param>
          /// <param name="filterExpression">The filter expression used to generate the where clause.</param>
          /// <param name="updateExpression">The MemberInitExpression used to indicate what is updated.</param>
          /// <returns>The number of rows updated.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/>, <paramref name="filterExpression"/> or <paramref name="updateExpression"/> is null.
          /// </exception>
          /// <example>Update all users in the test.com domain to be inactive.
          /// <code><![CDATA[
          /// var db = new TrackerEntities();
          /// string emailDomain = "@test.com";
          /// int count = db.Users.Update(
          /// u => u.Email.EndsWith(emailDomain),
          /// u => new User { IsApproved = false, LastActivityDate = DateTime.Now });
          /// ]]></code>
          /// </example>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Update<TEntity>(
              this ObjectSet<TEntity> source,
              Expression<Func<TEntity, bool>> filterExpression,
              Expression<Func<TEntity, TEntity>> updateExpression)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (filterExpression == null)
                  throw new ArgumentNullException("filterExpression");

              return source.Update(source.Where(filterExpression), updateExpression);
          }

          /// <summary>
          /// Executes an update statement using the query to filter the rows to be updated.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to update.</param>
          /// <param name="query">The query used to generate the where clause.</param>
          /// <param name="updateExpression">The MemberInitExpression used to indicate what is updated.</param>
          /// <returns>The number of rows updated.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/>, <paramref name="query"/> or <paramref name="updateExpression"/> is null.
          /// </exception>
          /// <exception cref="ArgumentException">
          /// The source cannot be converted to an ObjectQuery, has no ObjectContext or no entity mapping,
          /// <paramref name="query"/> is not an ObjectQuery or DbQuery, or <paramref name="updateExpression"/>
          /// is not a member initializer of mapped properties.
          /// </exception>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Update<TEntity>(
              this DbSet<TEntity> source,
              IQueryable<TEntity> query,
              Expression<Func<TEntity, TEntity>> updateExpression)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (query == null)
                  throw new ArgumentNullException("query");
              if (updateExpression == null)
                  throw new ArgumentNullException("updateExpression");

              // A DbSet does not expose its ObjectContext directly; go through its underlying ObjectQuery.
              ObjectQuery<TEntity> sourceQuery = source.ToObjectQuery();
              if (sourceQuery == null)
                  throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "source");

              ObjectContext objectContext = sourceQuery.Context;
              if (objectContext == null)
                  throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");

              EntityMap entityMap = sourceQuery.GetEntityMap<TEntity>();
              if (entityMap == null)
                  throw new ArgumentException("Could not load the entity mapping information for the source.", "source");

              ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
              if (objectQuery == null)
                  throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");

              return Update(objectContext, entityMap, objectQuery, updateExpression);
          }

          /// <summary>
          /// Executes an update statement using an expression to filter the rows that are updated.
          /// </summary>
          /// <typeparam name="TEntity">The type of the entity.</typeparam>
          /// <param name="source">The source used to determine the table to update.</param>
          /// <param name="filterExpression">The filter expression used to generate the where clause.</param>
          /// <param name="updateExpression">The MemberInitExpression used to indicate what is updated.</param>
          /// <returns>The number of rows updated.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/>, <paramref name="filterExpression"/> or <paramref name="updateExpression"/> is null.
          /// </exception>
          /// <example>Update all users in the test.com domain to be inactive.
          /// <code><![CDATA[
          /// var db = new TrackerContext();
          /// string emailDomain = "@test.com";
          /// int count = db.Users.Update(
          /// u => u.Email.EndsWith(emailDomain),
          /// u => new User { IsApproved = false, LastActivityDate = DateTime.Now });
          /// ]]></code>
          /// </example>
          /// <remarks>
          /// When executing this method, the statement is immediately executed on the database provider
          /// and is not part of the change tracking system. Also, changes will not be reflected on
          /// any entities that have already been materialized in the current context.
          /// </remarks>
          public static int Update<TEntity>(
              this DbSet<TEntity> source,
              Expression<Func<TEntity, bool>> filterExpression,
              Expression<Func<TEntity, TEntity>> updateExpression)
              where TEntity : class
          {
              if (source == null)
                  throw new ArgumentNullException("source");
              if (filterExpression == null)
                  throw new ArgumentNullException("filterExpression");

              return source.Update(source.Where(filterExpression), updateExpression);
          }

          /// <summary>
          /// Builds and runs the batch DELETE statement for the rows selected by <paramref name="query"/>.
          /// </summary>
          /// <returns>The number of rows reported by the provider.</returns>
          private static int Delete<TEntity>(ObjectContext objectContext, EntityMap entityMap, ObjectQuery<TEntity> query)
              where TEntity : class
          {
              return ExecuteBatch(objectContext, command =>
              {
                  // the inner select also copies the query's parameters onto the command
                  string innerSelect = GetSelectSql(query, entityMap, command);

                  var sqlBuilder = new StringBuilder(innerSelect.Length * 2);
                  sqlBuilder.Append("DELETE ");
                  sqlBuilder.Append(entityMap.TableName);
                  sqlBuilder.AppendLine();
                  AppendJoin(sqlBuilder, entityMap, innerSelect);
                  return sqlBuilder.ToString();
              });
          }

          /// <summary>
          /// Builds and runs the batch UPDATE statement for the rows selected by <paramref name="query"/>.
          /// </summary>
          /// <returns>The number of rows reported by the provider.</returns>
          /// <exception cref="ArgumentException">
          /// <paramref name="updateExpression"/> is not a member initializer, assigns nothing, uses a binding
          /// other than a plain assignment, or assigns a property with no mapped column.
          /// </exception>
          /// <remarks>
          /// Assigned values are evaluated on the client (constants directly, anything else by compiling it into
          /// a parameterless lambda), so a value cannot refer to the entity being updated, e.g.
          /// <c>u =&gt; new User { Count = u.Count + 1 }</c> is not supported. The expression is validated and
          /// evaluated before the connection is opened or a transaction is started.
          /// </remarks>
          private static int Update<TEntity>(ObjectContext objectContext, EntityMap entityMap, ObjectQuery<TEntity> query, Expression<Func<TEntity, TEntity>> updateExpression)
              where TEntity : class
          {
              var memberInitExpression = updateExpression.Body as MemberInitExpression;
              if (memberInitExpression == null)
                  throw new ArgumentException("The update expression must be of type MemberInitExpression.", "updateExpression");

              var bindings = memberInitExpression.Bindings;
              // an empty initializer would otherwise produce "SET" with no assignments, which is invalid SQL
              if (bindings.Count == 0)
                  throw new ArgumentException("The update expression must assign at least one property.", "updateExpression");

              var columnNames = new string[bindings.Count];
              var values = new object[bindings.Count];
              for (int i = 0; i < bindings.Count; i++)
              {
                  var memberAssignment = bindings[i] as MemberAssignment;
                  if (memberAssignment == null)
                      throw new ArgumentException("The update expression MemberBinding must only by type MemberAssignment.", "updateExpression");

                  string propertyName = memberAssignment.Member.Name;
                  string columnName = entityMap.PropertyMaps
                      .Where(p => p.PropertyName == propertyName)
                      .Select(p => p.ColumnName)
                      .FirstOrDefault();
                  // without a column the SET clause would read " = @p__update__N"
                  if (string.IsNullOrEmpty(columnName))
                      throw new ArgumentException(
                          string.Format("The property '{0}' is not mapped to a column.", propertyName),
                          "updateExpression");

                  columnNames[i] = columnName;
                  values[i] = EvaluateValue(memberAssignment.Expression);
              }

              return ExecuteBatch(objectContext, command =>
              {
                  // the inner select also copies the query's parameters onto the command
                  string innerSelect = GetSelectSql(query, entityMap, command);

                  var sqlBuilder = new StringBuilder(innerSelect.Length * 2);
                  sqlBuilder.Append("UPDATE ");
                  sqlBuilder.Append(entityMap.TableName);
                  sqlBuilder.AppendLine(" SET ");

                  for (int i = 0; i < columnNames.Length; i++)
                  {
                      if (i > 0)
                          sqlBuilder.AppendLine(", ");

                      string parameterName = "p__update__" + i;
                      var parameter = command.CreateParameter();
                      parameter.ParameterName = parameterName;
                      // ADO.NET treats a CLR null Value as "not supplied"; SQL NULL must be DBNull.Value
                      parameter.Value = values[i] ?? DBNull.Value;
                      command.Parameters.Add(parameter);

                      sqlBuilder.AppendFormat("{0} = @{1}", columnNames[i], parameterName);
                  }

                  sqlBuilder.AppendLine(" ");
                  AppendJoin(sqlBuilder, entityMap, innerSelect);
                  return sqlBuilder.ToString();
              });
          }

          /// <summary>
          /// Evaluates the right-hand side of one member assignment of an update expression on the client.
          /// </summary>
          /// <param name="valueExpression">The assigned expression.</param>
          /// <returns>The constant's value, or the result of compiling and invoking the expression with no arguments.</returns>
          private static object EvaluateValue(Expression valueExpression)
          {
              if (valueExpression.NodeType == ExpressionType.Constant)
              {
                  var constantExpression = valueExpression as ConstantExpression;
                  if (constantExpression == null)
                      throw new ArgumentException("The MemberAssignment expression is not a ConstantExpression.", "updateExpression");
                  return constantExpression.Value;
              }

              // captured variables, member accesses, method calls, etc. are evaluated client-side
              LambdaExpression lambda = Expression.Lambda(valueExpression);
              return lambda.Compile().DynamicInvoke();
          }

          /// <summary>
          /// Runs one non-query command on the store connection behind <paramref name="objectContext"/>.
          /// </summary>
          /// <param name="objectContext">The context whose store connection and current transaction are used.</param>
          /// <param name="buildCommandText">
          /// Called with the prepared command (connection, transaction and timeout already set); it adds any
          /// parameters it needs and returns the SQL text to execute.
          /// </param>
          /// <returns>The value returned by <see cref="DbCommand.ExecuteNonQuery"/>.</returns>
          /// <remarks>
          /// The connection is opened (and closed afterwards) only when it was not already open, and a
          /// transaction is started (and committed on success) only when the context has none; an existing
          /// transaction is joined and left for its owner to commit. An owned transaction that was not
          /// committed because the command failed is disposed without committing.
          /// </remarks>
          private static int ExecuteBatch(ObjectContext objectContext, Func<DbCommand, string> buildCommandText)
          {
              DbConnection connection = null;
              DbTransaction transaction = null;
              DbCommand command = null;
              bool ownConnection = false;
              bool ownTransaction = false;
              try
              {
                  // get store connection and transaction
                  var store = GetStore(objectContext);
                  connection = store.Item1;
                  transaction = store.Item2;

                  if (connection.State != ConnectionState.Open)
                  {
                      connection.Open();
                      ownConnection = true;
                  }

                  // use existing transaction or create new
                  if (transaction == null)
                  {
                      transaction = connection.BeginTransaction();
                      ownTransaction = true;
                  }

                  command = connection.CreateCommand();
                  command.Transaction = transaction;
                  if (objectContext.CommandTimeout.HasValue)
                      command.CommandTimeout = objectContext.CommandTimeout.Value;

                  command.CommandText = buildCommandText(command);
                  int result = command.ExecuteNonQuery();

                  // only commit if created transaction
                  if (ownTransaction)
                      transaction.Commit();

                  return result;
              }
              finally
              {
                  if (command != null)
                      command.Dispose();
                  if (transaction != null && ownTransaction)
                      transaction.Dispose();
                  if (connection != null && ownConnection)
                      connection.Close();
              }
          }

          /// <summary>
          /// Appends <c>FROM table AS j0 INNER JOIN (innerSelect) AS j1 ON (j0.k = j1.k AND ...)</c>,
          /// joining the target table to the key-only inner select on every key column.
          /// </summary>
          private static void AppendJoin(StringBuilder sqlBuilder, EntityMap entityMap, string innerSelect)
          {
              sqlBuilder.AppendFormat("FROM {0} AS j0 INNER JOIN (", entityMap.TableName);
              sqlBuilder.AppendLine();
              sqlBuilder.AppendLine(innerSelect);
              sqlBuilder.Append(") AS j1 ON (");

              bool wroteKey = false;
              foreach (var keyMap in entityMap.KeyMaps)
              {
                  if (wroteKey)
                      sqlBuilder.Append(" AND ");
                  sqlBuilder.AppendFormat("j0.{0} = j1.{0}", keyMap.ColumnName);
                  wroteKey = true;
              }

              sqlBuilder.Append(")");
          }

          /// <summary>
          /// Resolves the store connection to run batch statements on, plus the store transaction currently
          /// in progress on the context (null when there is none).
          /// </summary>
          /// <remarks>
          /// For an EntityConnection the wrapped store connection is used, and its current transaction is read
          /// through the project's <c>DynamicProxy</c> (presumably because EntityConnection does not expose it
          /// publicly — confirm against the EF version in use).
          /// </remarks>
          private static Tuple<DbConnection, DbTransaction> GetStore(ObjectContext objectContext)
          {
              DbConnection dbConnection = objectContext.Connection;
              var entityConnection = dbConnection as EntityConnection;

              // by-pass entity connection
              if (entityConnection == null)
                  return new Tuple<DbConnection, DbTransaction>(dbConnection, null);

              DbConnection connection = entityConnection.StoreConnection;

              // get internal transaction
              dynamic connectionProxy = new DynamicProxy(entityConnection);
              dynamic entityTransaction = connectionProxy.CurrentTransaction;
              if (entityTransaction == null)
                  return new Tuple<DbConnection, DbTransaction>(connection, null);

              DbTransaction transaction = entityTransaction.StoreTransaction;
              return new Tuple<DbConnection, DbTransaction>(connection, transaction);
          }

          /// <summary>
          /// Generates the store SQL for <paramref name="query"/> projected to its key properties only, and
          /// copies the query's parameters onto <paramref name="command"/>.
          /// </summary>
          /// <returns>The trace string of the key-only query, used as the inner select of the join.</returns>
          private static string GetSelectSql<TEntity>(ObjectQuery<TEntity> query, EntityMap entityMap, DbCommand command)
              where TEntity : class
          {
              // changing query to only select keys: builds a Dynamic LINQ projection "new(Key1, Key2, ...)"
              var selector = new StringBuilder(50);
              selector.Append("new(");
              foreach (var propertyMap in entityMap.KeyMaps)
              {
                  // anything beyond "new(" means a key was already written
                  if (selector.Length > 4)
                      selector.Append((", "));
                  selector.Append(propertyMap.PropertyName);
              }
              selector.Append(")");

              var selectQuery = DynamicQueryable.Select(query, selector.ToString());
              var objectQuery = selectQuery as ObjectQuery;
              if (objectQuery == null)
                  throw new ArgumentException("The query must be of type ObjectQuery.", "query");

              string innerJoinSql = objectQuery.ToTraceString();

              // create parameters
              foreach (var objectParameter in objectQuery.Parameters)
              {
                  var parameter = command.CreateParameter();
                  parameter.ParameterName = objectParameter.Name;
                  // ADO.NET treats a CLR null Value as "not supplied"; SQL NULL must be DBNull.Value
                  parameter.Value = objectParameter.Value ?? DBNull.Value;
                  command.Parameters.Add(parameter);
              }

              return innerJoinSql;
          }
      }
  }