MyBatis 的拦截器(插件)机制基于 JDK 动态代理实现,核心是 InvocationHandler。
示例:
拦截器:PageInterceptor
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
| @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class}) }) public class PageInterceptor implements Interceptor {
private String databaseType;
@Override public Object intercept(Invocation invocation) throws Throwable { RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget(); StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate"); BoundSql boundSql = delegate.getBoundSql(); Object obj = boundSql.getParameterObject(); if (obj instanceof Page<?>) { Page<?> page = (Page<?>) obj; MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement"); Connection connection = (Connection) invocation.getArgs()[0]; String sql = boundSql.getSql(); String sortFild = page.getSortField(); String sortValue = page.getSortValue(); page.setSortField(null); page.setSortValue(null); if(page.isPage()) { this.settotalCount(page, mappedStatement, connection); } this.setTimestamp(page, connection); page.setSortField(sortFild); page.setSortValue(sortValue); String pageSql = this.getPageSql(page, sql); ReflectUtil.setFieldValue(boundSql, "sql", pageSql); } return invocation.proceed(); }
@Override public Object plugin(Object target) { return Plugin.wrap(target, this); }
@Override public void setProperties(Properties properties) { this.databaseType = properties.getProperty("databaseType"); }
private String getPageSql(Page<?> page, String sql) { StringBuffer sqlBuffer = new StringBuffer(sql); if ("mysql".equalsIgnoreCase(databaseType)) { return getMysqlPageSql(page, sqlBuffer); } else if ("oracle".equalsIgnoreCase(databaseType)) { return getOraclePageSql(page, sqlBuffer); } return sqlBuffer.toString(); }
private String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) { int offset = (page.getPageNo() - 1) * page.getPageSize(); if(page.isPage()){ sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize()); } return sqlBuffer.toString(); }
private String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) { int offset = (page.getPageNo() - 1) * page.getPageSize() + 1; sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize()); sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset); return sqlBuffer.toString(); }
private void settotalCount(Page<?> page, MappedStatement mappedStatement, Connection connection) { BoundSql boundSql = mappedStatement.getBoundSql(page); String sql = boundSql.getSql(); String countSql = this.getCountSql(sql); List<ParameterMapping> parameterMappings = boundSql.getParameterMappings(); BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page); ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql); PreparedStatement pstmt = null; ResultSet rs = null; try { pstmt = connection.prepareStatement(countSql); parameterHandler.setParameters(pstmt); rs = pstmt.executeQuery(); if (rs.next()) { int totalCount = rs.getInt(1); page.setTotalCount(totalCount); } } catch (SQLException e) { e.printStackTrace(); } finally { try { if (rs != null) rs.close(); if (pstmt != null) pstmt.close(); } catch (SQLException e) { e.printStackTrace(); } } }
private void setTimestamp(Page<?> page, Connection connection) { Statement stmt = null; ResultSet rs = null; try { stmt = connection.createStatement(); rs = stmt.executeQuery(getTimestampSql()); if (rs.next()) { Date timestamp = rs.getTime(1); page.setTimestamp(new SimpleDateFormat("yyyyMMddHHmmss").format(timestamp)); } } catch (SQLException e) { e.printStackTrace(); } finally { try { if (rs != null) rs.close(); if (stmt != null) stmt.close(); } catch (SQLException e) { e.printStackTrace(); } } }
private String getCountSql(String sql) { return "select count(1) from ( " + sql+") aaa"; }
private String getTimestampSql() { if ("mysql".equalsIgnoreCase(databaseType)) { return "select now()"; } else if ("oracle".equalsIgnoreCase(databaseType)) { return "select sysdate from dual"; } else { return "select now()"; } }
private static class ReflectUtil {
public static Object getFieldValue(Object obj, String fieldName) { Object result = null; Field field = ReflectUtil.getField(obj, fieldName); if (field != null) { field.setAccessible(true); try { result = field.get(obj); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } } return result; }
private static Field getField(Object obj, String fieldName) { Field field = null; for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) { try { field = clazz.getDeclaredField(fieldName); break; } catch (NoSuchFieldException e) { } } return field; }
public static void setFieldValue(Object obj, String fieldName, String fieldValue) { Field field = ReflectUtil.getField(obj, fieldName); if (field != null) { try { field.setAccessible(true); field.set(obj, fieldValue); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } } } }
}
|
Page 对象
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
| public class Page<T> implements Paginable<T> {
private static final long serialVersionUID = 8933698230226183372L;
public static final int DEFAULT_PAGE_SIZE = 15;
public static final int PAGE_COUNT = 10;
protected int pageNo = 1;
private int pageSize = DEFAULT_PAGE_SIZE;
private int totalCount = 0;
private int totalPage = 0; private String timestamp = null;
private boolean full = true; private boolean isPage = true; private String sortField; private String sortValue; public String getSortField() { return sortField; }
public void setSortField(String sortField) { this.sortField = sortField; }
public String getSortValue() { return sortValue; }
public void setSortValue(String sortValue) { this.sortValue = sortValue; } public int getPageNo() { return pageNo; }
public void setPageNo(int pageNo) { this.pageNo = pageNo; }
public int getPageSize() { return pageSize; }
public void setPageSize(int pageSize) { this.pageSize = pageSize; }
public int getTotalCount() { return totalCount; }
public void setTotalCount(int totalCount) { this.totalCount = totalCount; int totalPage = totalCount % pageSize == 0 ? totalCount / pageSize : totalCount / pageSize + 1; this.setTotalPage(totalPage); }
public int getTotalPage() { return totalPage; }
public void setTotalPage(int totalPage) { this.totalPage = totalPage; }
@Override public boolean isFirstPage() { return pageNo <= 1; }
@Override public boolean isLastPage() { return pageNo >= totalPage; }
@Override public int getNextPage() { return isLastPage() ? pageNo : (pageNo + 1); }
@Override public int getPrePage() { return isFirstPage() ? pageNo : (pageNo - 1); }
@Override public int getBeginIndex() { if (pageNo > 0) { return (pageSize * (pageNo - 1)); } else { return 0; } }
@Override public int getEndIndex() { if (pageNo > 0) { return Math.min(pageSize * pageNo, totalCount); } else { return 0; } }
public int getBeginPage() { if (pageNo > 0) { return (PAGE_COUNT * ((pageNo - 1) / PAGE_COUNT)) + 1; } else { return 0; } }
public int getEndPage() { if (pageNo > 0) { return Math.min(PAGE_COUNT * ((pageNo - 1) / PAGE_COUNT + 1), getTotalPage()); } else { return 0; } }
public boolean isFull() { return full; }
public void setFull(boolean full) { this.full = full; }
public String getTimestamp() { return timestamp; }
public void setTimestamp(String timestamp) { this.timestamp = timestamp; }
public boolean isPage() { return isPage; }
public void setIsPage(boolean isPage) { this.isPage = isPage; } }
|
Paginable 接口
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/**
 * Read-only view of paging state: totals, the current position, and the navigation values
 * derived from them.
 *
 * @param <T> type parameter available to implementations; not used by any method declared here
 */
public interface Paginable<T> extends Serializable {

    /** @return the total number of records */
    int getTotalCount();

    /** @return the total number of pages */
    int getTotalPage();

    /** @return the number of records per page */
    int getPageSize();

    /** @return the current page number */
    int getPageNo();

    /** @return whether the current page is the first one */
    boolean isFirstPage();

    /** @return whether the current page is the last one */
    boolean isLastPage();

    /** @return the page number that follows the current page */
    int getNextPage();

    /** @return the page number that precedes the current page */
    int getPrePage();

    /** @return the index at which the current page begins */
    int getBeginIndex();

    /** @return the index at which the current page ends */
    int getEndIndex();

    /** @return the first page number to show in a page-number window */
    int getBeginPage();

    /** @return the last page number to show in a page-number window */
    int getEndPage();
}
|
测试:
1 2 3 4 5 6 7 8 9 10
| @Test public void test() throws IOException { String resource = "qingsong-mybatis.xml"; InputStream inputStream = Resources.getResourceAsStream(resource); SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream); SqlSession session = sqlSessionFactory.openSession(); UserMapper userMapper = session.getMapper(UserMapper.class); User result = userMapper.selectByid(1); System.out.println(result.toString()); }
|
分页插件执行步骤:
1,PageInterceptor:注册到Configuration的拦截器链上,拦截StatementHandler的prepare方法(所有语句在预编译时都会经过这里)
2,判断是否是分页查询对象if (obj instanceof Page<?>)
如果是,则利用反射拿出StatementHandler
先拼接sql查询出总条数,封装到page对象中,之后在原sql语句上拼接limit参数,再执行后续流程
注解说明
@Intercepts 在实现Interceptor接口的类声明,使该类注册成为拦截器
Signature[] value//定义需要拦截哪些类,哪些方法
@Signature 定义哪些类(4种),方法,参数需要被拦截
Class<?> type() // 被拦截的类型:ParameterHandler、ResultSetHandler、StatementHandler、Executor 之一
String method() // 被拦截的方法名
Class<?>[] args() // 被拦截方法的参数类型列表,用于区分重载方法
调用分析
当我们调用ParameterHandler,ResultSetHandler,StatementHandler,Executor的对象的时候,实际上使用的是Plugin这个代理类的对象,
Plugin类实现了InvocationHandler接口.,在调用上述被代理类的方法的时候,就会执行Plugin的invoke方法.
Plugin在invoke方法中根据@Intercepts的配置信息(方法名,参数等)动态判断是否需要拦截该方法.
然后把目标对象、被拦截的方法Method和调用参数封装成Invocation,并调用Interceptor的intercept方法;拦截器内部再通过Invocation的proceed方法执行目标方法.
执行大概是这样的流程:拦截器代理类对象->拦截器->目标方法
Executor -> Plugin -> Interceptor -> Invocation
Executor.Method -> Plugin.invoke -> Interceptor.intercept -> Invocation.proceed -> method.invoke