Hexo

点滴积累 豁达处之

0%

Mybatis分页插件(Plugin+Interceptor)

Mybatis的拦截器实现机制,使用的是JDK的InvocationHandler

示例:

拦截器:PageInterceptor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
/**
 * MyBatis plugin that rewrites the SQL of a statement whose parameter object is a {@link Page}
 * so that only the requested page is fetched.
 *
 * <p>It intercepts {@code StatementHandler.prepare(Connection, Integer)}. For a {@code Page}
 * parameter it (1) optionally runs a {@code count} query and stores the total on the page,
 * (2) stores the database's current time on the page, and (3) replaces the SQL held by the
 * {@code BoundSql} with a database-specific paginated statement.
 *
 * <p>Only MySQL and Oracle dialects are supported; for any other {@code databaseType} the SQL is
 * left unchanged.
 */
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class}) })
public class PageInterceptor implements Interceptor {

    /** Database type read from the {@code databaseType} plugin property; selects the paging dialect. */
    private String databaseType;

    /**
     * Rewrites the statement's SQL when its parameter object is a {@link Page}, then proceeds.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call; its first argument
     *                   is the JDBC connection
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // NOTE(review): assumes the target is the raw RoutingStatementHandler (not another
        // plugin's proxy) — confirm when several StatementHandler plugins are registered.
        RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
        StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
        BoundSql boundSql = delegate.getBoundSql();
        Object parameter = boundSql.getParameterObject();
        if (parameter instanceof Page<?>) {
            Page<?> page = (Page<?>) parameter;
            MappedStatement mappedStatement =
                    (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
            Connection connection = (Connection) invocation.getArgs()[0];
            String sql = boundSql.getSql();

            // The sort settings are cleared while the count SQL is generated and restored
            // afterwards. NOTE(review): presumably so the mapper's dynamic SQL emits no
            // ORDER BY for the count query — confirm against the mapper XML.
            String sortField = page.getSortField();
            String sortValue = page.getSortValue();
            page.setSortField(null);
            page.setSortValue(null);
            try {
                if (page.isPage()) {
                    this.setTotalCount(page, mappedStatement, connection);
                }
                this.setTimestamp(page, connection);
            } finally {
                // Restore even if the steps above throw, so the caller's Page is not left altered.
                page.setSortField(sortField);
                page.setSortValue(sortValue);
            }

            String pageSql = this.getPageSql(page, sql);
            ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
        }
        return invocation.proceed();
    }

    /**
     * Wraps the target object with this interceptor.
     *
     * @param target the object MyBatis is about to use
     * @return the result of {@code Plugin.wrap(target, this)}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Receives the properties configured where the interceptor is registered.
     *
     * @param properties plugin properties; {@code databaseType} selects the dialect
     */
    @Override
    public void setProperties(Properties properties) {
        this.databaseType = properties.getProperty("databaseType");
    }

    /**
     * Builds the paginated SQL for the configured database type. Only MySQL and Oracle are
     * handled; for any other type the original SQL is returned unchanged.
     *
     * @param page paging parameters
     * @param sql  original SQL statement
     * @return the SQL to execute
     */
    private String getPageSql(Page<?> page, String sql) {
        StringBuilder sqlBuilder = new StringBuilder(sql);
        if ("mysql".equalsIgnoreCase(databaseType)) {
            return getMysqlPageSql(page, sqlBuilder);
        } else if ("oracle".equalsIgnoreCase(databaseType)) {
            return getOraclePageSql(page, sqlBuilder);
        }
        return sqlBuilder.toString();
    }

    /**
     * Builds the MySQL paginated statement by appending {@code limit offset,size}.
     *
     * @param page       paging parameters
     * @param sqlBuilder builder holding the original SQL
     * @return the MySQL statement; the original SQL when paging is disabled on the page
     */
    private String getMysqlPageSql(Page<?> page, StringBuilder sqlBuilder) {
        if (page.isPage()) {
            // getBeginIndex() yields 0 for pageNo <= 0, so the offset can never be negative.
            int offset = page.getBeginIndex();
            sqlBuilder.append(" limit ").append(offset).append(",").append(page.getPageSize());
        }
        return sqlBuilder.toString();
    }

    /**
     * Builds the Oracle paginated statement using a nested {@code rownum} filter.
     *
     * @param page       paging parameters
     * @param sqlBuilder builder holding the original SQL
     * @return the Oracle statement; the original SQL when paging is disabled on the page
     */
    private String getOraclePageSql(Page<?> page, StringBuilder sqlBuilder) {
        if (!page.isPage()) {
            // Same rule as the MySQL branch: no paging requested, no rewrite.
            return sqlBuilder.toString();
        }
        // rownum is 1-based, so the first row of the page is beginIndex + 1.
        int offset = page.getBeginIndex() + 1;
        sqlBuilder.insert(0, "select u.*, rownum r from (").append(") u where rownum < ")
                .append(offset + page.getPageSize());
        sqlBuilder.insert(0, "select * from (").append(") where r >= ").append(offset);
        return sqlBuilder.toString();
    }

    /**
     * Runs a count query derived from the mapped statement and stores the total on the page.
     * A {@link SQLException} is reported on stderr and otherwise ignored (best effort), in which
     * case the page's total is left as it was.
     *
     * @param page            parameter object of the mapped statement
     * @param mappedStatement the mapped statement being executed
     * @param connection      current database connection
     */
    private void setTotalCount(Page<?> page, MappedStatement mappedStatement, Connection connection) {
        BoundSql boundSql = mappedStatement.getBoundSql(page);
        String countSql = this.getCountSql(boundSql.getSql());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        BoundSql countBoundSql =
                new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);
        // try-with-resources closes the statement even if closing the result set fails.
        try (PreparedStatement pstmt = connection.prepareStatement(countSql)) {
            parameterHandler.setParameters(pstmt);
            try (ResultSet rs = pstmt.executeQuery()) {
                if (rs.next()) {
                    page.setTotalCount(rs.getInt(1));
                }
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * Queries the database's current time and stores it on the page formatted as
     * {@code yyyyMMddHHmmss}. A {@link SQLException} is reported on stderr and otherwise ignored.
     *
     * @param page       page that receives the timestamp
     * @param connection current database connection
     */
    private void setTimestamp(Page<?> page, Connection connection) {
        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery(getTimestampSql())) {
            if (rs.next()) {
                // getTimestamp keeps both date and time; getTime would drop the date part and
                // make the yyyyMMdd portion of the formatted value meaningless.
                Date timestamp = rs.getTimestamp(1);
                if (timestamp != null) {
                    page.setTimestamp(new SimpleDateFormat("yyyyMMddHHmmss").format(timestamp));
                }
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * Wraps the original SQL in a {@code select count(1)} sub-query.
     *
     * @param sql original SQL
     * @return SQL that counts the rows produced by {@code sql}
     */
    private String getCountSql(String sql) {
        return "select count(1) from ( " + sql + ") aaa";
    }

    /**
     * Returns the statement that selects the database's current time for the configured
     * database type; {@code select now()} is used for MySQL and for any unrecognized type.
     *
     * @return the current-time query
     */
    private String getTimestampSql() {
        if ("oracle".equalsIgnoreCase(databaseType)) {
            return "select sysdate from dual";
        }
        return "select now()";
    }

    /**
     * Small reflection helper for reading and writing (possibly private) fields.
     */
    private static class ReflectUtil {

        private ReflectUtil() {
            // utility class, not instantiable
        }

        /**
         * Reads a field of the given object via reflection.
         *
         * @param obj       target object
         * @param fieldName name of the field to read
         * @return the field's value, or {@code null} if the field is not found or not readable
         */
        public static Object getFieldValue(Object obj, String fieldName) {
            Object result = null;
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                field.setAccessible(true);
                try {
                    result = field.get(obj);
                } catch (IllegalArgumentException | IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
            return result;
        }

        /**
         * Looks up a declared field on the object's class or the nearest superclass declaring it.
         *
         * @param obj       target object
         * @param fieldName name of the field to find
         * @return the field, or {@code null} if no class below {@code Object} declares it
         */
        private static Field getField(Object obj, String fieldName) {
            for (Class<?> clazz = obj.getClass(); clazz != null && clazz != Object.class;
                    clazz = clazz.getSuperclass()) {
                try {
                    return clazz.getDeclaredField(fieldName);
                } catch (NoSuchFieldException ignored) {
                    // not declared on this class; continue with the superclass
                }
            }
            return null;
        }

        /**
         * Sets a field of the given object via reflection. Does nothing if the field is not found.
         *
         * @param obj        target object
         * @param fieldName  name of the field to set
         * @param fieldValue value to assign
         */
        public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                try {
                    field.setAccessible(true);
                    field.set(obj, fieldValue);
                } catch (IllegalArgumentException | IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
        }
    }

}

Page 对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/**
 * Mutable paging parameter/result holder passed to mapper methods.
 *
 * <p>Carries the requested page number and size, the total record count (from which the total
 * page count is derived), optional sort settings, a query timestamp and two flags.
 *
 * @param <T> element type of the paged result
 */
public class Page<T> implements Paginable<T> {


    private static final long serialVersionUID = 8933698230226183372L;

    /** Default number of records per page. */
    public static final int DEFAULT_PAGE_SIZE = 15;

    /** Width of the page-number window used by {@link #getBeginPage()} and {@link #getEndPage()}. */
    public static final int PAGE_COUNT = 10;

    /** Current page number (1-based). */
    protected int pageNo = 1;

    /** Number of records per page. */
    private int pageSize = DEFAULT_PAGE_SIZE;

    /** Total number of records. */
    private int totalCount = 0;

    /** Total number of pages. */
    private int totalPage = 0;

    /** Query timestamp. */
    private String timestamp = null;

    /** Whether this is a full update. */
    private boolean full = true;

    /** Whether paging is enabled. */
    private boolean isPage = true;

    /** Sort field. */
    private String sortField;

    /** Sort direction: asc/desc. */
    private String sortValue;

    public String getSortField() {
        return sortField;
    }

    public void setSortField(String sortField) {
        this.sortField = sortField;
    }

    public String getSortValue() {
        return sortValue;
    }

    public void setSortValue(String sortValue) {
        this.sortValue = sortValue;
    }

    @Override
    public int getPageNo() {
        return pageNo;
    }

    public void setPageNo(int pageNo) {
        this.pageNo = pageNo;
    }

    @Override
    public int getPageSize() {
        return pageSize;
    }

    public void setPageSize(int pageSize) {
        this.pageSize = pageSize;
    }

    @Override
    public int getTotalCount() {
        return totalCount;
    }

    /**
     * Sets the total record count and recomputes the total page count from the current page size.
     *
     * @param totalCount total number of records
     */
    public void setTotalCount(int totalCount) {
        this.totalCount = totalCount;
        int pages;
        if (pageSize <= 0) {
            // A non-positive page size cannot partition the records; avoid dividing by zero.
            pages = 0;
        } else {
            // Ceiling division: a partially filled last page still counts as a page.
            pages = totalCount % pageSize == 0 ? totalCount / pageSize : totalCount / pageSize + 1;
        }
        this.setTotalPage(pages);
    }

    @Override
    public int getTotalPage() {
        return totalPage;
    }

    public void setTotalPage(int totalPage) {
        this.totalPage = totalPage;
    }

    /** Returns {@code true} when the current page number is 1 or less. */
    @Override
    public boolean isFirstPage() {
        return pageNo <= 1;
    }

    /** Returns {@code true} when the current page number has reached the total page count. */
    @Override
    public boolean isLastPage() {
        return pageNo >= totalPage;
    }

    /** Returns the next page number, or the current one when already on the last page. */
    @Override
    public int getNextPage() {
        return isLastPage() ? pageNo : (pageNo + 1);
    }

    /** Returns the previous page number, or the current one when already on the first page. */
    @Override
    public int getPrePage() {
        return isFirstPage() ? pageNo : (pageNo - 1);
    }

    /** Returns the 0-based index of the first record on the current page; 0 if {@code pageNo <= 0}. */
    @Override
    public int getBeginIndex() {
        if (pageNo > 0) {
            return (pageSize * (pageNo - 1));
        } else {
            return 0;
        }
    }

    /**
     * Returns the index just past the last record on the current page, capped at the total
     * record count; 0 if {@code pageNo <= 0}.
     */
    @Override
    public int getEndIndex() {
        if (pageNo > 0) {
            return Math.min(pageSize * pageNo, totalCount);
        } else {
            return 0;
        }
    }

    /** Returns the first page number of the {@link #PAGE_COUNT}-wide window containing the current page. */
    @Override
    public int getBeginPage() {
        if (pageNo > 0) {
            return (PAGE_COUNT * ((pageNo - 1) / PAGE_COUNT)) + 1;
        } else {
            return 0;
        }
    }

    /**
     * Returns the last page number of the {@link #PAGE_COUNT}-wide window containing the current
     * page, capped at the total page count.
     */
    @Override
    public int getEndPage() {
        if (pageNo > 0) {
            return Math.min(PAGE_COUNT * ((pageNo - 1) / PAGE_COUNT + 1), getTotalPage());
        } else {
            return 0;
        }
    }

    public boolean isFull() {
        return full;
    }

    public void setFull(boolean full) {
        this.full = full;
    }

    public String getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(String timestamp) {
        this.timestamp = timestamp;
    }

    public boolean isPage() {
        return isPage;
    }

    public void setIsPage(boolean isPage) {
        this.isPage = isPage;
    }

}

Paginable 接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/**
 * Read-only view of a paged result: counts, current position and navigation helpers.
 *
 * @param <T> element type of the paged result
 */
public interface Paginable<T> extends Serializable {

    // ---- position ----

    /** @return the current page number */
    int getPageNo();

    /** @return the number of records per page */
    int getPageSize();

    // ---- totals ----

    /** @return the total number of records */
    int getTotalCount();

    /** @return the total number of pages */
    int getTotalPage();

    // ---- navigation ----

    /** @return whether the current page is the first page */
    boolean isFirstPage();

    /** @return whether the current page is the last page */
    boolean isLastPage();

    /** @return the page number of the previous page */
    int getPrePage();

    /** @return the page number of the next page */
    int getNextPage();

    // ---- record indexes on the current page ----

    /** @return the starting index of the items shown on the current page */
    int getBeginIndex();

    /** @return the index of the last item shown on the current page */
    int getEndIndex();

    // ---- page-number window ----

    /** @return the first page number of the page-number window */
    int getBeginPage();

    /** @return the last page number of the page-number window */
    int getEndPage();
}

测试:

1
2
3
4
5
6
7
8
9
10
/**
 * Smoke test: builds a SqlSessionFactory from {@code qingsong-mybatis.xml}, opens a session,
 * looks up the user with id 1 through {@code UserMapper} and prints it.
 *
 * @throws IOException if the configuration resource cannot be read
 */
@Test
public void test() throws IOException {
    String resource = "qingsong-mybatis.xml";
    InputStream inputStream = Resources.getResourceAsStream(resource);
    SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream);
    // Close the session when done so its underlying connection is released.
    try (SqlSession session = sqlSessionFactory.openSession()) {
        UserMapper userMapper = session.getMapper(UserMapper.class);
        User result = userMapper.selectByid(1);
        System.out.println(result.toString());
    }
}

分页插件执行步骤:

1,PageInterceptor:注册到Configuration的拦截器链上,拦截StatementHandler的prepare方法(所有SQL语句在预编译时都会经过它)

2,判断是否是分页查询对象if (obj instanceof Page<?>) 如果是,则利用反射拿出StatementHandler先拼接sql查询出总条数,封装到page对象中,之后在原sql语句上拼接limit参数,再执行后续流程

注解说明

@Intercepts 在实现Interceptor接口的类声明,使该类注册成为拦截器
Signature[] value//定义需要拦截哪些类,哪些方法
@Signature 定义哪些类(4种),方法,参数需要被拦截
Class type() // 被拦截的类型:ParameterHandler,ResultSetHandler,StatementHandler,Executor
String method() // 被拦截的方法名
Class[] args() // 被拦截方法的参数类型列表

调用分析

当我们调用ParameterHandler,ResultSetHandler,StatementHandler,Executor的对象的时候,实际上使用的是Plugin这个代理类的对象,

Plugin类实现了InvocationHandler接口,在调用上述被代理类的方法的时候,就会执行Plugin的invoke方法.
Plugin在invoke方法中根据@Intercepts的配置信息(方法名,参数等)动态判断是否需要拦截该方法.
再然后把需要拦截的方法Method封装成Invocation,并调用Interceptor的intercept方法;拦截器处理完成后,再通过Invocation的proceed方法调用目标方法.

执行大概是这样的流程:拦截器代理类对象->拦截器->目标方法

Executor -> Plugin -> Interceptor -> Invocation
Executor.Method -> Plugin.invoke -> Interceptor.intercept -> Invocation.proceed -> method.invoke