2020/01/28, ASP.NET Core 3.1, VS2019, Microsoft.EntityFrameworkCore.Relational 3.1.1

Summary: Building a multi-layer back-end architecture with ASP.NET Core 3.1 WebApi [Part 4: Unit of Work and Repository Design]

Use the generic Repository and UnitOfWork patterns to encapsulate the basic create, read, update and delete (CRUD) methods of the data access layer.


Project code for this branch

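About the unit-of-work pattern in this chapter:

The generic repository encapsulates the common CRUD methods, and the unit of work manages all repositories centrally so that they share one consistent database context.

Repositories are always obtained from the unit of work; after data has been changed through a repository, the unit of work commits the changes.

The code follows the design of Arch/UnitOfWork: most of it is adapted from that project, with comments added and the distributed multi-database support removed.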
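This division of responsibilities can be illustrated with a minimal, in-memory C# sketch. It does not use EF Core, and `SketchUnitOfWork`, `SketchRepository<T>`, `GetRepository<T>()` and `SaveChanges()` are hypothetical names chosen for illustration, not the API built in this article. The unit of work hands out one cached repository per entity type, and pending changes only become visible after a single commit call.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical, in-memory illustration of the pattern only:
// it is not EF Core and not the MS.UnitOfWork API.
public class SketchRepository<T>
{
    private readonly List<T> _store = new List<T>();   // "committed" data
    private readonly List<T> _pending = new List<T>(); // changes waiting for a commit

    public void Insert(T entity) => _pending.Add(entity);

    // Only committed data is visible, like rows already written to the database.
    public IReadOnlyList<T> GetAll() => _store;

    // Called by the unit of work; moves pending changes into the store.
    internal int Flush()
    {
        int count = _pending.Count;
        _store.AddRange(_pending);
        _pending.Clear();
        return count;
    }
}

public class SketchUnitOfWork
{
    // One repository per entity type, all owned by this unit of work.
    private readonly Dictionary<Type, object> _repositories = new Dictionary<Type, object>();
    private readonly List<Func<int>> _flushers = new List<Func<int>>();

    public SketchRepository<T> GetRepository<T>()
    {
        if (!_repositories.TryGetValue(typeof(T), out object repository))
        {
            var created = new SketchRepository<T>();
            _repositories[typeof(T)] = created;
            _flushers.Add(created.Flush);
            repository = created;
        }
        return (SketchRepository<T>)repository;
    }

    // Commits the pending changes of every repository in one call
    // and returns the number of items written.
    public int SaveChanges() => _flushers.Sum(flush => flush());
}
```

In the article's design the shared piece is the EF Core `DbContext`, and the commit step corresponds to `DbContext.SaveChanges`, which writes all tracked changes together.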

Add the package reference

Add a reference to the Microsoft.EntityFrameworkCore.Relational package to the MS.UnitOfWork project:

<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="3.1.1" />
</ItemGroup>

Pagination helpers

In the MS.UnitOfWork project, add a Collections folder, and in it add the classes IPagedList.cs, PagedList.cs, IEnumerablePagedListExtensions.cs and IQueryablePageListExtensions.cs.

IPagedList.cs

using System.Collections.Generic;

namespace MS.UnitOfWork.Collections
{
/// <summary>
/// Provides a paging interface for any type of data
/// </summary>
/// <typeparam name="T">The type of the data being paged</typeparam>
public interface IPagedList<T>
{
/// <summary>
/// Index of the first page (the numbering base, e.g. 0 or 1)
/// </summary>
int IndexFrom { get; }
/// <summary>
/// Current page index
/// </summary>
int PageIndex { get; }
/// <summary>
/// Page size
/// </summary>
int PageSize { get; }
/// <summary>
/// Total number of items
/// </summary>
int TotalCount { get; }
/// <summary>
/// Total number of pages
/// </summary>
int TotalPages { get; }
/// <summary>
/// Items on the current page
/// </summary>
IList<T> Items { get; }
/// <summary>
/// Whether there is a previous page
/// </summary>
bool HasPreviousPage { get; }
/// <summary>
/// Whether there is a next page
/// </summary>
bool HasNextPage { get; }
}
}

PagedList.cs

using System;
using System.Collections.Generic;
using System.Linq;

namespace MS.UnitOfWork.Collections
{
/// <summary>
/// Pages data; the default implementation of <see cref="IPagedList{T}"/>
/// </summary>
/// <typeparam name="T">The type of the data being paged</typeparam>
public class PagedList<T> : IPagedList<T>
{
/// <summary>
/// Current page index
/// </summary>
public int PageIndex { get; set; }
/// <summary>
/// Page size
/// </summary>
public int PageSize { get; set; }
/// <summary>
/// Total number of items
/// </summary>
public int TotalCount { get; set; }
/// <summary>
/// Total number of pages
/// </summary>
public int TotalPages { get; set; }
/// <summary>
/// Index of the first page
/// </summary>
public int IndexFrom { get; set; }
/// <summary>
/// Items on the current page
/// </summary>
public IList<T> Items { get; set; }
/// <summary>
/// Whether there is a previous page
/// </summary>
public bool HasPreviousPage => PageIndex - IndexFrom > 0;
/// <summary>
/// Whether there is a next page
/// </summary>
public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;

/// <summary>
/// Initializes a new instance
/// </summary>
/// <param name="source">The data source.</param>
/// <param name="pageIndex">The index of the current page.</param>
/// <param name="pageSize">The size of the page.</param>
/// <param name="indexFrom">The index of the first page.</param>
internal PagedList(IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom)
{
if (indexFrom > pageIndex)
{
throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");
}

if (source is IQueryable<T> queryable)
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = queryable.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
Items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
}
else
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = source.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
Items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
}
}

/// <summary>
/// Initializes a new, empty instance of the <see cref="PagedList{T}" /> class.
/// </summary>
internal PagedList() => Items = new T[0];
}

/// <summary>
/// Pages data and converts it to another type
/// </summary>
/// <typeparam name="TSource">The type of the source data</typeparam>
/// <typeparam name="TResult">The type of the output data</typeparam>
internal class PagedList<TSource, TResult> : IPagedList<TResult>
{
/// <summary>
/// Current page index
/// </summary>
public int PageIndex { get; set; }
/// <summary>
/// Page size
/// </summary>
public int PageSize { get; set; }
/// <summary>
/// Total number of items
/// </summary>
public int TotalCount { get; set; }
/// <summary>
/// Total number of pages
/// </summary>
public int TotalPages { get; set; }
/// <summary>
/// Index of the first page
/// </summary>
public int IndexFrom { get; set; }
/// <summary>
/// Items on the current page
/// </summary>
public IList<TResult> Items { get; set; }
/// <summary>
/// Whether there is a previous page
/// </summary>
public bool HasPreviousPage => PageIndex - IndexFrom > 0;
/// <summary>
/// Whether there is a next page
/// </summary>
public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;

/// <summary>
/// Initializes a new instance
/// </summary>
/// <param name="source">The data source.</param>
/// <param name="converter">Converts the items of the current page to the output type.</param>
/// <param name="pageIndex">The index of the current page.</param>
/// <param name="pageSize">The size of the page.</param>
/// <param name="indexFrom">The index of the first page.</param>
public PagedList(IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom)
{
if (indexFrom > pageIndex)
{
throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");
}

if (source is IQueryable<TSource> queryable)
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = queryable.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
var items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();
Items = new List<TResult>(converter(items));
}
else
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = source.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
var items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();
Items = new List<TResult>(converter(items));
}
}

/// <summary>
/// Initializes a new instance of the <see cref="PagedList{TSource, TResult}" /> class from an existing page.
/// </summary>
/// <param name="source">The source page.</param>
/// <param name="converter">Converts the items of the source page to the output type.</param>
public PagedList(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter)
{
PageIndex = source.PageIndex;
PageSize = source.PageSize;
IndexFrom = source.IndexFrom;
TotalCount = source.TotalCount;
TotalPages = source.TotalPages;
Items = new List<TResult>(converter(source.Items));
}
}

/// <summary>
/// Provides helper methods for the <see cref="IPagedList{T}"/> interface.
/// </summary>
public static class PagedList
{
/// <summary>
/// Creates an empty <see cref="IPagedList{T}"/>.
/// </summary>
/// <typeparam name="T">The type of the data being paged</typeparam>
/// <returns>An empty instance of <see cref="IPagedList{T}"/>.</returns>
public static IPagedList<T> Empty<T>() => new PagedList<T>();
/// <summary>
/// Creates a new instance of <see cref="IPagedList{TResult}"/> from a source <see cref="IPagedList{TSource}"/> instance.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <typeparam name="TSource">The type of the source.</typeparam>
/// <param name="source">The source page.</param>
/// <param name="converter">Converts the items of the source page to the output type.</param>
/// <returns>An instance of <see cref="IPagedList{TResult}"/>.</returns>
public static IPagedList<TResult> From<TResult, TSource>(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter) => new PagedList<TSource, TResult>(source, converter);
}
}
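The paging arithmetic in PagedList<T> can be checked by hand. The sketch below mirrors its formulas in a standalone helper; `PageMath.Describe` is a hypothetical name, and the pageSize guard is an addition of the sketch (PagedList<T> itself does not validate pageSize).

```csharp
using System;

// Hypothetical helper mirroring the formulas in PagedList<T>:
// skip = (pageIndex - indexFrom) * pageSize, totalPages = ceil(totalCount / pageSize).
public static class PageMath
{
    public static (int Skip, int TotalPages, bool HasPrevious, bool HasNext) Describe(
        int totalCount, int pageIndex, int pageSize, int indexFrom = 1)
    {
        if (indexFrom > pageIndex)
            throw new ArgumentException("indexFrom must be less than or equal to pageIndex");
        if (pageSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(pageSize)); // guard added by this sketch

        int skip = (pageIndex - indexFrom) * pageSize;                // rows before this page
        int totalPages = (int)Math.Ceiling(totalCount / (double)pageSize);
        bool hasPrevious = pageIndex - indexFrom > 0;                 // not the first page
        bool hasNext = pageIndex - indexFrom + 1 < totalPages;        // pages remain after this one
        return (skip, totalPages, hasPrevious, hasNext);
    }
}
```

For example, with 45 rows and a page size of 20, page 3 (1-based) skips 40 rows and is the last of 3 pages; with indexFrom = 0 the same page is requested as pageIndex = 2.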

IEnumerablePagedListExtensions.cs

using System;
using System.Collections.Generic;

namespace MS.UnitOfWork.Collections
{
/// <summary>
/// Adds extension methods to <see cref="IEnumerable{T}"/> to support paging
/// </summary>
public static class IEnumerablePagedListExtensions
{
/// <summary>
/// Gets one page of data from the source
/// </summary>
/// <typeparam name="T">The type of the data</typeparam>
/// <param name="source">The data source</param>
/// <param name="pageIndex">The current page index</param>
/// <param name="pageSize">The page size</param>
/// <param name="indexFrom">The index of the first page</param>
/// <returns>The requested page.</returns>
public static IPagedList<T> ToPagedList<T>(this IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<T>(source, pageIndex, pageSize, indexFrom);

/// <summary>
/// Gets one page of data from the source and converts it to the specified type
/// </summary>
/// <typeparam name="TSource">The type of the source data</typeparam>
/// <typeparam name="TResult">The type of the output data</typeparam>
/// <param name="source">The data source</param>
/// <param name="converter">Converts the items of the page to the output type</param>
/// <param name="pageIndex">The current page index</param>
/// <param name="pageSize">The page size</param>
/// <param name="indexFrom">The index of the first page</param>
/// <returns>The requested page, converted to <typeparamref name="TResult"/>.</returns>
public static IPagedList<TResult> ToPagedList<TSource, TResult>(this IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<TSource, TResult>(source, converter, pageIndex, pageSize, indexFrom);
}
}
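The converter overload pages first and converts second, so the converter only ever sees the items of the requested page. The sketch below reproduces that order with plain LINQ; `ConverterPagingSketch`, `PageThenConvert` and `Demo` are hypothetical names, not part of MS.UnitOfWork.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ConverterPagingSketch
{
    // Hypothetical stand-in for PagedList<TSource, TResult>: take the page, then convert it.
    public static IList<TResult> PageThenConvert<TSource, TResult>(
        IEnumerable<TSource> source,
        Func<IEnumerable<TSource>, IEnumerable<TResult>> converter,
        int pageIndex, int pageSize, int indexFrom = 1)
    {
        TSource[] page = source.Skip((pageIndex - indexFrom) * pageSize).Take(pageSize).ToArray();
        return new List<TResult>(converter(page)); // converter runs on the page items only
    }

    // Usage: page 2 of 1..100 with 10 items per page, counting converter calls.
    public static (IList<string> Names, int Converted) Demo()
    {
        int converted = 0;
        IList<string> names = PageThenConvert(
            Enumerable.Range(1, 100),
            items => items.Select(i => { converted++; return $"Item {i}"; }),
            pageIndex: 2, pageSize: 10);
        return (names, converted);
    }
}
```

Demo returns 10 names starting with "Item 11", and the converter runs 10 times rather than 100, so the cost of converting grows with the page size rather than with the size of the source.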

IQueryablePageListExtensions.cs

using Microsoft.EntityFrameworkCore;
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork.Collections
{
/// <summary>
/// Adds asynchronous paging extension methods to <see cref="IQueryable{T}"/>
/// </summary>
public static class IQueryablePageListExtensions
{
/// <summary>
/// Gets one page of data from the source (asynchronously)
/// </summary>
/// <typeparam name="T">The type of the data</typeparam>
/// <param name="source">The data source</param>
/// <param name="pageIndex">The current page index</param>
/// <param name="pageSize">The page size</param>
/// <param name="indexFrom">The index of the first page</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The requested page.</returns>
public static async Task<IPagedList<T>> ToPagedListAsync<T>(this IQueryable<T> source, int pageIndex, int pageSize, int indexFrom = 1, CancellationToken cancellationToken = default(CancellationToken))
{
if (indexFrom > pageIndex)
{
throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, must indexFrom <= pageIndex");
}

var count = await source.CountAsync(cancellationToken).ConfigureAwait(false);
var items = await source.Skip((pageIndex - indexFrom) * pageSize)
.Take(pageSize).ToListAsync(cancellationToken).ConfigureAwait(false);

var pagedList = new PagedList<T>()
{
PageIndex = pageIndex,
PageSize = pageSize,
IndexFrom = indexFrom,
TotalCount = count,
Items = items,
TotalPages = (int)Math.Ceiling(count / (double)pageSize)
};

return pagedList;
}
}
}

These are paging extension methods for IQueryable and IEnumerable data, mainly used to fetch one page at a time when querying the database.

泛型仓储

In the MS.UnitOfWork project, add a Repository folder, and in it add the classes IRepository.cs and Repository.cs.

IRepository.cs

using MS.UnitOfWork.Collections;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
/// <summary>
/// Generic repository interface
/// </summary>
/// <typeparam name="TEntity">The entity type</typeparam>
public interface IRepository<TEntity> where TEntity : class
{
#region GetAll
/// <summary>
/// Gets all entities.
/// Mind performance: nothing is filtered, so materializing this query loads the whole table!
/// </summary>
/// <returns>The <see cref="IQueryable{TEntity}"/>.</returns>
IQueryable<TEntity> GetAll();

/// <summary>
/// Gets all entities
/// </summary>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The <see cref="IQueryable{TEntity}"/>.</returns>
IQueryable<TEntity> GetAll(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);

/// <summary>
/// Gets all entities, projected with <paramref name="selector"/>
/// </summary>
/// <typeparam name="TResult">The type of the output data</typeparam>
/// <param name="selector">The projection selector</param>
/// <param name="predicate">The filter predicate</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The <see cref="IQueryable{TResult}"/>.</returns>
IQueryable<TResult> GetAll<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false
) where TResult : class;

/// <summary>
/// Gets all entities (asynchronously)
/// </summary>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The list of matching entities.</returns>
Task<IList<TEntity>> GetAllAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);
#endregion

#region GetPagedList
/// <summary>
/// Gets a page of data.
/// Change tracking is disabled by default (the returned entities are not tracked, so changes made to them are not saved automatically).
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="pageIndex">The current page index. Defaults to the first page</param>
/// <param name="pageSize">The page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The requested page.</returns>
IPagedList<TEntity> GetPagedList(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false);

/// <summary>
/// Gets a page of data (asynchronously).
/// Change tracking is disabled by default (the returned entities are not tracked, so changes made to them are not saved automatically).
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="pageIndex">The current page index. Defaults to the first page</param>
/// <param name="pageSize">The page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The requested page.</returns>
Task<IPagedList<TEntity>> GetPagedListAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default);

/// <summary>
/// Gets a page of data, projected with <paramref name="selector"/>.
/// Change tracking is disabled by default (the returned entities are not tracked, so changes made to them are not saved automatically).
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">The type of the output data</typeparam>
/// <param name="selector">The projection selector</param>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="pageIndex">The current page index. Defaults to the first page</param>
/// <param name="pageSize">The page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The requested page.</returns>
IPagedList<TResult> GetPagedList<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false
) where TResult : class;

/// <summary>
/// Gets a page of data, projected with <paramref name="selector"/> (asynchronously).
/// Change tracking is disabled by default (the returned entities are not tracked, so changes made to them are not saved automatically).
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">The type of the output data</typeparam>
/// <param name="selector">The projection selector</param>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="pageIndex">The current page index. Defaults to the first page</param>
/// <param name="pageSize">The page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The requested page.</returns>
Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default) where TResult : class;
#endregion

#region GetFirstOrDefault
/// <summary>
/// Gets the first element of the sequence that satisfies the condition,
/// or the default value if no element does.
/// Change tracking is disabled by default (the returned entity is not tracked, so changes made to it are not saved automatically).
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The first matching entity, or null.</returns>
TEntity GetFirstOrDefault(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);

/// <summary>
/// Gets the first element of the sequence that satisfies the condition,
/// or the default value if no element does (asynchronously).
/// Change tracking is disabled by default (the returned entity is not tracked, so changes made to it are not saved automatically).
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The first matching entity, or null.</returns>
Task<TEntity> GetFirstOrDefaultAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default);

/// <summary>
/// Gets the first element of the sequence that satisfies the condition, projected with <paramref name="selector"/>,
/// or the default value if no element does.
/// Change tracking is disabled by default.
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">The type of the output data</typeparam>
/// <param name="selector">The projection selector</param>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The first matching result, or the default value.</returns>
TResult GetFirstOrDefault<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);

/// <summary>
/// Gets the first element of the sequence that satisfies the condition, projected with <paramref name="selector"/>,
/// or the default value if no element does (asynchronously).
/// Change tracking is disabled by default.
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">The type of the output data</typeparam>
/// <param name="selector">The projection selector</param>
/// <param name="predicate">The filter expression</param>
/// <param name="orderBy">The ordering</param>
/// <param name="include">The navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The first matching result, or the default value.</returns>
Task<TResult> GetFirstOrDefaultAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default);
#endregion

#region Find
/// <summary>
/// Finds an entity with the given primary key values. If found, it is attached to the context and returned. If no entity is found, null is returned.
/// </summary>
/// <param name="keyValues">The values of the primary key for the entity to be found.</param>
/// <returns>The found entity or null.</returns>
TEntity Find(params object[] keyValues);

/// <summary>
/// Finds an entity with the given primary key values. If found, it is attached to the context and returned. If no entity is found, null is returned.
/// </summary>
/// <param name="keyValues">The values of the primary key for the entity to be found.</param>
/// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
ValueTask<TEntity> FindAsync(params object[] keyValues);

/// <summary>
/// Finds an entity with the given primary key values. If found, it is attached to the context and returned. If no entity is found, null is returned.
/// </summary>
/// <param name="keyValues">The values of the primary key for the entity to be found.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
/// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken);
#endregion

#region SQL, Count, Exists
/// <summary>
/// Queries entities with a raw SQL statement
/// </summary>
/// <param name="sql">The raw SQL query; pass user input through parameters rather than string concatenation</param>
/// <param name="parameters">The parameter values</param>
/// <returns>A query over the matching entities.</returns>
IQueryable<TEntity> FromSql(string sql, params object[] parameters);

/// <summary>
/// Counts the entities that satisfy the condition
/// </summary>
/// <param name="predicate">The filter condition (optional)</param>
/// <returns>The number of matching entities.</returns>
int Count(Expression<Func<TEntity, bool>> predicate = null);

/// <summary>
/// Counts the entities that satisfy the condition (asynchronously)
/// </summary>
/// <param name="predicate">The filter condition (optional)</param>
/// <returns>The number of matching entities.</returns>
Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null);

/// <summary>
/// Determines whether an entity that satisfies the condition exists
/// </summary>
/// <param name="predicate">The filter condition (optional)</param>
/// <returns>true if a matching entity exists; otherwise false.</returns>
bool Exists(Expression<Func<TEntity, bool>> predicate = null);
#endregion

#region Insert
/// <summary>
/// Inserts a new entity synchronously.
/// </summary>
/// <param name="entity">The entity to insert.</param>
/// <returns>The inserted entity.</returns>
TEntity Insert(TEntity entity);

/// <summary>
/// Inserts a range of entities synchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
void Insert(params TEntity[] entities);

/// <summary>
/// Inserts a range of entities synchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
void Insert(IEnumerable<TEntity> entities);

/// <summary>
/// Inserts a new entity asynchronously.
/// </summary>
/// <param name="entity">The entity to insert.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
/// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous insert operation; its result is the <see cref="EntityEntry{TEntity}"/> of the inserted entity.</returns>
ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default);

/// <summary>
/// Inserts a range of entities asynchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
Task InsertAsync(params TEntity[] entities);

/// <summary>
/// Inserts a range of entities asynchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default);
#endregion

#region Update
/// <summary>
/// Updates the specified entity.
/// </summary>
/// <param name="entity">The entity.</param>
void Update(TEntity entity);

/// <summary>
/// Updates the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Update(params TEntity[] entities);

/// <summary>
/// Updates the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Update(IEnumerable<TEntity> entities);
#endregion

#region Delete
/// <summary>
/// Deletes the entity with the specified primary key.
/// </summary>
/// <param name="id">The primary key value.</param>
void Delete(object id);

/// <summary>
/// Deletes the specified entity.
/// </summary>
/// <param name="entity">The entity to delete.</param>
void Delete(TEntity entity);

/// <summary>
/// Deletes the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Delete(params TEntity[] entities);

/// <summary>
/// Deletes the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Delete(IEnumerable<TEntity> entities);
#endregion
}
}
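The optional query parameters are applied in a fixed order by the Repository implementation below: filter, then sort, then project. The sketch reproduces that order with LINQ to Objects via AsQueryable(), so it runs without a database. `Product`, `QueryOptionsSketch`, `Compose` and `Demo` are hypothetical names, and Include, AsNoTracking and IgnoreQueryFilters are left out because they need EF Core. It also shows why orderBy is typed as Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>>: the caller can chain OrderBy and ThenBy inside one delegate.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity used only for this sketch.
public class Product
{
    public int Id { get; set; }
    public string Name { get; set; }
    public decimal Price { get; set; }
}

public static class QueryOptionsSketch
{
    // Applies the options in the same order as the repository: Where, orderBy, Select.
    public static IQueryable<TResult> Compose<TEntity, TResult>(
        IQueryable<TEntity> query,
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null)
    {
        if (predicate != null)
            query = query.Where(predicate);   // filter first
        if (orderBy != null)
            query = orderBy(query);           // then sort (the caller may chain ThenBy)
        return query.Select(selector);        // project last
    }

    // Usage: names of products priced above 5, by price and then by name.
    public static List<string> Demo()
    {
        IQueryable<Product> products = new List<Product>
        {
            new Product { Id = 1, Name = "Pen", Price = 2m },
            new Product { Id = 2, Name = "Book", Price = 12m },
            new Product { Id = 3, Name = "Bag", Price = 30m },
            new Product { Id = 4, Name = "Ink", Price = 12m },
        }.AsQueryable();

        return Compose(products, p => p.Name,
            predicate: p => p.Price > 5m,
            orderBy: q => q.OrderBy(p => p.Price).ThenBy(p => p.Name)).ToList();
    }
}
```

Demo returns Book, Ink, Bag. With EF Core, the composed IQueryable would instead be translated to SQL when it is enumerated.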

Repository.cs

using MS.UnitOfWork.Collections;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
/// <summary>
/// Default implementation of the generic repository
/// </summary>
/// <typeparam name="TEntity">The entity type</typeparam>
public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
{
protected readonly DbContext _dbContext;
protected readonly DbSet<TEntity> _dbSet;

public Repository(DbContext dbContext)
{
_dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
_dbSet = _dbContext.Set<TEntity>();
}

#region GetAll
public IQueryable<TEntity> GetAll() => _dbSet;

public IQueryable<TEntity> GetAll(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}

if (include != null)
{
query = include(query);
}

if (predicate != null)
{
query = query.Where(predicate);
}

if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}

if (orderBy != null)
{
return orderBy(query);
}
else
{
return query;
}
}

public IQueryable<TResult> GetAll<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false) where TResult : class
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}

if (include != null)
{
query = include(query);
}

if (predicate != null)
{
query = query.Where(predicate);
}

if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}

if (orderBy != null)
{
return orderBy(query).Select(selector);
}
else
{
return query.Select(selector);
}
}

public async Task<IList<TEntity>> GetAllAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;

if (disableTracking)
{
query = query.AsNoTracking();
}

if (include != null)
{
query = include(query);
}

if (predicate != null)
{
query = query.Where(predicate);
}

if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}

if (orderBy != null)
{
return await orderBy(query).ToListAsync();
}
else
{
return await query.ToListAsync();
}
}
#endregion

#region GetPagedList
public virtual IPagedList<TEntity> GetPagedList(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}

if (include != null)
{
query = include(query);
}

if (predicate != null)
{
query = query.Where(predicate);
}

if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}

if (orderBy != null)
{
return orderBy(query).ToPagedList(pageIndex, pageSize);
}
else
{
return query.ToPagedList(pageIndex, pageSize);
}
}

public virtual async Task<IPagedList<TEntity>> GetPagedListAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}

if (include != null)
{
query = include(query);
}

if (predicate != null)
{
query = query.Where(predicate);
}

if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}

if (orderBy != null)
{
return await orderBy(query).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
else
{
return await query.ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
}

public virtual IPagedList<TResult> GetPagedList<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false)
where TResult : class
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}

if (include != null)
{
query = include(query);
}

if (predicate != null)
{
query = query.Where(predicate);
}

if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}

if (orderBy != null)
{
return orderBy(query).Select(selector).ToPagedList(pageIndex, pageSize);
}
else
{
return query.Select(selector).ToPagedList(pageIndex, pageSize);
}
}

public virtual async Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
where TResult : class
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}

if (include != null)
{
query = include(query);
}

if (predicate != null)
{
query = query.Where(predicate);
}

if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}

if (orderBy != null)
{
return await orderBy(query).Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
else
{
return await query.Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
}
#endregion

#region GetFirstOrDefault

public virtual TEntity GetFirstOrDefault(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
} if (include != null)
{
query = include(query);
} if (predicate != null)
{
query = query.Where(predicate);
} if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
} if (orderBy != null)
{
return orderBy(query).FirstOrDefault();
}
else
{
return query.FirstOrDefault();
}
} public virtual async Task<TEntity> GetFirstOrDefaultAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return await orderBy(query).FirstOrDefaultAsync(cancellationToken);
}
else
{
return await query.FirstOrDefaultAsync(cancellationToken);
}
}

public virtual TResult GetFirstOrDefault<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return orderBy(query).Select(selector).FirstOrDefault();
}
else
{
return query.Select(selector).FirstOrDefault();
}
}

public virtual async Task<TResult> GetFirstOrDefaultAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return await orderBy(query).Select(selector).FirstOrDefaultAsync(cancellationToken);
}
else
{
return await query.Select(selector).FirstOrDefaultAsync(cancellationToken);
}
}
#endregion

#region Find

public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues);

public virtual ValueTask<TEntity> FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues);

public virtual ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken);
#endregion

#region sql, count, exist
public virtual IQueryable<TEntity> FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters);

public virtual int Count(Expression<Func<TEntity, bool>> predicate = null)
{
if (predicate == null)
{
return _dbSet.Count();
}
else
{
return _dbSet.Count(predicate);
}
}

public virtual async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
{
if (predicate == null)
{
return await _dbSet.CountAsync();
}
else
{
return await _dbSet.CountAsync(predicate);
}
}
public virtual bool Exists(Expression<Func<TEntity, bool>> predicate = null)
{
if (predicate == null)
{
return _dbSet.Any();
}
else
{
return _dbSet.Any(predicate);
}
}
#endregion

#region Insert
public virtual TEntity Insert(TEntity entity)
{
return _dbSet.Add(entity).Entity;
}

public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities);

public virtual void Insert(IEnumerable<TEntity> entities) => _dbSet.AddRange(entities);

public virtual ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken))
{
return _dbSet.AddAsync(entity, cancellationToken);
// Shadow properties?
//var property = _dbContext.Entry(entity).Property("Created");
//if (property != null) {
//property.CurrentValue = DateTime.Now;
//}
}

public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities);

public virtual Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken);
#endregion

#region Update
public virtual void Update(TEntity entity)
{
_dbSet.Update(entity);
}

// Despite its name this method is synchronous: DbSet.Update only changes the entity's tracking state and does not touch the database
public virtual void UpdateAsync(TEntity entity)
{
_dbSet.Update(entity);
}

public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities);

public virtual void Update(IEnumerable<TEntity> entities) => _dbSet.UpdateRange(entities);
#endregion

#region Delete
public virtual void Delete(TEntity entity) => _dbSet.Remove(entity);

public virtual void Delete(object id)
{
var entity = _dbSet.Find(id);
if (entity != null)
{
Delete(entity);
}
}

public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities);

public virtual void Delete(IEnumerable<TEntity> entities) => _dbSet.RemoveRange(entities);
#endregion
}
}

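Notes

  • Wraps the common create, read, update and delete operations
  • Methods whose names end in Async are asynchronous (the exception is UpdateAsync, which is synchronous)
  • The method documentation comments live in the interface
  • Queries:
    • GetAll returns every entity that matches the conditions (watch out for performance)
    • GetPagedList performs a paged query
    • GetFirstOrDefault returns the first element that matches the conditions
    • Find looks up an element by primary key, for example an Id value
    • FromSql runs a raw SQL query
    • Count returns the number of matching rows
    • Exists checks whether any matching row exists
  • The query methods accept a number of options:
    • Paged queries return 20 rows per page by default
    • Change tracking is disabled by default (disableTracking = true)
    • Global query filters are applied by default (ignoreQueryFilters = false)
    • The selector parameter projects the query results into another type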

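The query options above can be tried out without a database. What follows is a minimal sketch, not the repository code. `RoleRow` and `QuerySketch.Page` are hypothetical names, and an in-memory `IQueryable` stands in for `_dbSet`. The `include`, `disableTracking` and `ignoreQueryFilters` options are left out because they only have an effect against EF Core. The sketch applies the options in the same order as GetPagedList (filter, sort, project, page). It assumes paging starts at `indexFrom = 1`, matching the repository's `ToPagedListAsync(pageIndex, pageSize, 1, ...)` call.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity used only for this sketch
public class RoleRow
{
    public long Id { get; set; }
    public string Name { get; set; }
}

public static class QuerySketch
{
    // Same option order as GetPagedList: predicate -> orderBy -> selector -> paging.
    // indexFrom = 1 means pageIndex 1 is the first page (assumption, mirrors the repository call).
    public static List<TResult> Page<TEntity, TResult>(
        IQueryable<TEntity> source,
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        int pageIndex = 1,
        int pageSize = 20,
        int indexFrom = 1)
    {
        IQueryable<TEntity> query = source;
        if (predicate != null)
        {
            query = query.Where(predicate);   // WHERE clause
        }
        if (orderBy != null)
        {
            query = orderBy(query);           // ORDER BY, supplied by the caller
        }
        // Project first, then skip the earlier pages and take one page
        return query.Select(selector)
            .Skip((pageIndex - indexFrom) * pageSize)
            .Take(pageSize)
            .ToList();
    }
}
```

For example, with 45 rows, `pageIndex: 3` and the default page size of 20, the query skips 40 rows and returns the last 5.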
Unit of Work

Add a UnitOfWork folder to the MS.UnitOfWork project, and add the IUnitOfWork.cs and UnitOfWork.cs classes to that folder.

IUnitOfWork.cs

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using System;
using System.Linq;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
/// <summary>
/// Defines the unit of work interface
/// </summary>
public interface IUnitOfWork<TContext> : IDisposable where TContext : DbContext
{
/// <summary>
/// Gets the DbContext
/// </summary>
/// <returns></returns>
TContext DbContext { get; }
/// <summary>
/// Begins a transaction
/// </summary>
/// <returns></returns>
IDbContextTransaction BeginTransaction();

/// <summary>
/// Gets the repository for the specified entity type
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="hasCustomRepository">Set to true if a custom repository exists</param>
/// <returns></returns>
IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class;

/// <summary>
/// Saves the changes made in the DbContext
/// </summary>
/// <returns></returns>
int SaveChanges();

/// <summary>
/// Saves the changes made in the DbContext (asynchronous)
/// </summary>
/// <returns></returns>
Task<int> SaveChangesAsync();

/// <summary>
/// Executes a raw SQL command
/// </summary>
/// <param name="sql">SQL statement</param>
/// <param name="parameters">Parameters</param>
/// <returns></returns>
int ExecuteSqlCommand(string sql, params object[] parameters);

/// <summary>
/// Retrieves data with a raw SQL query
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="sql"></param>
/// <param name="parameters">Parameters</param>
/// <returns></returns>
IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class;
}
}

UnitOfWork.cs

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
/// <summary>
/// Default implementation of the unit of work.
/// </summary>
/// <typeparam name="TContext"></typeparam>
public class UnitOfWork<TContext> : IUnitOfWork<TContext> where TContext : DbContext
{
protected readonly TContext _context;
protected bool _disposed = false;
protected Dictionary<Type, object> _repositories;

public UnitOfWork(TContext context)
{
_context = context ?? throw new ArgumentNullException(nameof(context));
}

/// <summary>
/// Gets the DbContext
/// </summary>
public TContext DbContext => _context;
/// <summary>
/// Begins a transaction
/// </summary>
/// <returns></returns>
public IDbContextTransaction BeginTransaction()
{
return _context.Database.BeginTransaction();
}

/// <summary>
/// Gets the repository for the specified entity type
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="hasCustomRepository"></param>
/// <returns></returns>
public IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class
{
// Note: this implementation ignores hasCustomRepository and always returns the generic Repository<TEntity>
if (_repositories == null)
{
_repositories = new Dictionary<Type, object>();
}
// One repository instance per entity type, all sharing the same _context
Type type = typeof(IRepository<TEntity>);
if (!_repositories.TryGetValue(type, out object repo))
{
IRepository<TEntity> newRepo = new Repository<TEntity>(_context);
_repositories.Add(type, newRepo);
return newRepo;
}
return (IRepository<TEntity>)repo;
}

/// <summary>
/// Executes a raw SQL command
/// </summary>
/// <param name="sql">SQL statement</param>
/// <param name="parameters">Parameters</param>
/// <returns></returns>
public int ExecuteSqlCommand(string sql, params object[] parameters) => _context.Database.ExecuteSqlRaw(sql, parameters);

/// <summary>
/// Retrieves data with a raw SQL query
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="sql"></param>
/// <param name="parameters">Parameters</param>
/// <returns></returns>
public IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class => _context.Set<TEntity>().FromSqlRaw(sql, parameters);

/// <summary>
/// Saves the changes made in the DbContext
/// </summary>
/// <returns></returns>
public int SaveChanges()
{
return _context.SaveChanges();
}

/// <summary>
/// Saves the changes made in the DbContext (asynchronous)
/// </summary>
/// <returns></returns>
public async Task<int> SaveChangesAsync()
{
return await _context.SaveChangesAsync();
}

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
// clear repositories
if (_repositories != null)
{
_repositories.Clear();
}
// dispose the db context.
_context.Dispose();
}
}
_disposed = true;
}
}
}

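Notes

  • Repositories and the DbContext are both obtained from the unit of work
  • Transactions are also started from the unit of work
  • After modifying data through a repository, call the unit of work's SaveChanges to commit the changes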

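The caching in GetRepository is what keeps every repository on the same DbContext. The sketch below reproduces only that idea with stand-ins. `FakeContext`, `Repo<TEntity>` and `UnitOfWorkSketch` are hypothetical illustration names, not EF Core or project types. A repository is created lazily per entity type and stored in a `Dictionary<Type, object>`. Every repository receives the same context instance.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for a DbContext: the object every repository shares
public class FakeContext { }

// Hypothetical minimal repository that only remembers its context
public class Repo<TEntity> where TEntity : class
{
    public Repo(FakeContext context) => Context = context;
    public FakeContext Context { get; }
}

// Same caching strategy as UnitOfWork.GetRepository
public class UnitOfWorkSketch
{
    private readonly FakeContext _context = new FakeContext();
    private Dictionary<Type, object> _repositories;

    public FakeContext Context => _context;

    public Repo<TEntity> GetRepository<TEntity>() where TEntity : class
    {
        if (_repositories == null)
        {
            _repositories = new Dictionary<Type, object>();
        }
        Type type = typeof(Repo<TEntity>);
        if (!_repositories.TryGetValue(type, out object repo))
        {
            // first request for this entity type: create and cache the repository
            repo = new Repo<TEntity>(_context);
            _repositories.Add(type, repo);
        }
        return (Repo<TEntity>)repo;
    }
}
```

Calling `GetRepository<Role>()` twice on the same unit of work therefore returns the same object. Changes made through different repositories all land in the one shared context and are saved by a single `SaveChanges` call.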
Wrapping the IoC registration

Add a UnitOfWorkServiceExtensions.cs class to the MS.UnitOfWork project:

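using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;

namespace MS.UnitOfWork
{
/// <summary>
/// Extension methods that register the unit of work for dependency injection in <see cref="IServiceCollection"/>
/// </summary>
public static class UnitOfWorkServiceExtensions
{
/// <summary>
/// Registers the unit of work for the given context as a service in <see cref="IServiceCollection"/>.
/// The DbContext itself is registered as well.
/// </summary>
/// <typeparam name="TContext"></typeparam>
/// <param name="services"></param>
/// <param name="action">Configures the DbContext options (database provider, connection string, etc.)</param>
/// <remarks>Call this once per context type: calling it again for the same TContext adds a duplicate <see cref="IUnitOfWork{TContext}"/> registration instead of replacing it.</remarks>
/// <returns></returns>
public static IServiceCollection AddUnitOfWorkService<TContext>(this IServiceCollection services, System.Action<DbContextOptionsBuilder> action) where TContext : DbContext
{
// register the DbContext
services.AddDbContext<TContext>(action);
// register the unit of work (scoped, so it shares the scope's DbContext instance)
services.AddScoped<IUnitOfWork<TContext>, UnitOfWork<TContext>>();
return services;
}
}
}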

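With this in place, any project that wants to use the unit of work only needs to call AddUnitOfWorkService in Startup to register it.

When the project is finished, it looks like the figure below:

Usage examples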

using (var tran = _unitOfWork.BeginTransaction())// begin a transaction
{
Role newRow = _mapper.Map<Role>(viewModel);
newRow.Id = _idWorker.NextId();// get a Snowflake Id
newRow.Creator = 1219490056771866624;// login is not implemented yet, so the current user is unknown; hard-code a value for now and fix it later
newRow.CreateTime = DateTime.Now;
_unitOfWork.GetRepository<Role>().Insert(newRow);
await _unitOfWork.SaveChangesAsync();
await tran.CommitAsync();// commit the transaction
}

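The code above opens a transaction from the unit of work and wraps it in a using block. The transaction is committed only when tran.CommitAsync() runs. If an exception is thrown before that, the transaction is disposed without being committed and is rolled back automatically.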

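The rollback-on-error behaviour comes from disposing a transaction that was never committed. The sketch below models only that control flow. `FakeTransaction` and `TransactionSketch` are hypothetical classes standing in for `IDbContextTransaction`, and they write to a log instead of a database. With a real EF Core relational transaction, disposing it without a commit disposes the underlying ADO.NET transaction, which the common providers roll back. That is the behaviour this model imitates.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for IDbContextTransaction:
// Commit() marks the work as committed; Dispose() without a prior Commit() rolls back.
public sealed class FakeTransaction : IDisposable
{
    private readonly List<string> _log;
    private bool _committed;

    public FakeTransaction(List<string> log) => _log = log;

    public void Commit()
    {
        _committed = true;
        _log.Add("commit");
    }

    public void Dispose()
    {
        if (!_committed)
        {
            _log.Add("rollback");
        }
    }
}

public static class TransactionSketch
{
    // Same shape as the Role insert example: work inside using, commit as the last step
    public static List<string> Run(Action work)
    {
        var log = new List<string>();
        try
        {
            using (var tran = new FakeTransaction(log))
            {
                work();          // e.g. Insert + SaveChangesAsync
                tran.Commit();   // only reached if work() did not throw
            }                    // Dispose runs here, also when work() throws
        }
        catch (InvalidOperationException)
        {
            log.Add("caught");
        }
        return log;
    }
}
```

If you prefer to make the rollback explicit, `IDbContextTransaction` also provides `Rollback()` and `RollbackAsync()`, which can be called from a `catch` block.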
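// load the record from the database
// viewModel.CheckField already loaded this row once for validation. FindAsync first looks in the entities already tracked by the DbContext,
// so as long as that earlier load was tracked by the same DbContext, no second database query is made here
var row = await _unitOfWork.GetRepository<Role>().FindAsync(viewModel.Id);
// set the new values
row.Name = viewModel.Name;
row.DisplayName = viewModel.DisplayName;
row.Remark = viewModel.Remark;
row.Modifier = 1219490056771866624;// login is not implemented yet, so the current user is unknown; hard-code a value for now and fix it later
row.ModifyTime = DateTime.Now;
_unitOfWork.GetRepository<Role>().Update(row);
await _unitOfWork.SaveChangesAsync();// commit

  • The code above fetches a row by its primary key Id and updates it.
  • Alternatively, fetch the row with GetFirstOrDefault and set the disableTracking parameter to false to turn tracking on. You can then modify the returned entity and call SaveChangesAsync directly, with no Update call. Because the entity is tracked, EF Core detects the changes by itself.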

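Why tracking removes the need for Update can be shown with a toy model of snapshot change tracking, which is EF Core's default strategy. A tracked entity remembers its original values, and SaveChanges compares them with the current values. `RoleEntity` and `MiniTracker` below are hypothetical and far simpler than EF Core's change tracker; the sketch only shows which columns would end up in the UPDATE. It also shows one difference worth knowing. `Update` marks every property as modified, so a save after `Update` writes all columns, while a tracked entity writes only the ones that changed.

```csharp
using System.Collections.Generic;

// Hypothetical entity used only for this sketch
public class RoleEntity
{
    public long Id { get; set; }
    public string Name { get; set; }
    public string Remark { get; set; }
}

// Toy model of snapshot change tracking; NOT EF Core's ChangeTracker
public class MiniTracker
{
    // original values captured when an entity starts being tracked
    private readonly Dictionary<RoleEntity, (string Name, string Remark)> _snapshots =
        new Dictionary<RoleEntity, (string Name, string Remark)>();
    // entities passed to Update(): every property counts as modified
    private readonly HashSet<RoleEntity> _markedModified = new HashSet<RoleEntity>();

    // Like a tracking query (disableTracking = false): remember the loaded values
    public RoleEntity Track(RoleEntity entity)
    {
        _snapshots[entity] = (entity.Name, entity.Remark);
        return entity;
    }

    // Like DbSet.Update: start tracking if needed and mark all properties as modified
    public void Update(RoleEntity entity)
    {
        if (!_snapshots.ContainsKey(entity))
        {
            _snapshots[entity] = (entity.Name, entity.Remark);
        }
        _markedModified.Add(entity);
    }

    // Like the change detection done by SaveChanges: which columns would be written
    public List<string> ModifiedColumns(RoleEntity entity)
    {
        var columns = new List<string>();
        if (_markedModified.Contains(entity))
        {
            columns.Add(nameof(RoleEntity.Name));
            columns.Add(nameof(RoleEntity.Remark));
            return columns;
        }
        if (!_snapshots.TryGetValue(entity, out var original))
        {
            return columns; // not tracked: SaveChanges would ignore it
        }
        if (original.Name != entity.Name)
        {
            columns.Add(nameof(RoleEntity.Name));
        }
        if (original.Remark != entity.Remark)
        {
            columns.Add(nameof(RoleEntity.Remark));
        }
        return columns;
    }
}
```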