Mybatis插件进行数据库操作记录.md

Mybatis插件进行数据库操作记录

需求介绍

先介绍一下需求,
我们需要对用户的一些操作进行记录,比如那个用户修改了商户的一些设置,或者修改了手续费率等。

那么我们很自然的想到,可以使用 AOP 对可能产生修改的地方进行拦截,获取前后参数然后再入库。但是这样需要知道有哪些地方会做出哪些修改,还需要在方法上加注解。

这里使用 Mybatis Plugin 来做操作记录。

Mybatis plugin 介绍

这里简要介绍一下Mybatis plugin。
Mybatis plugin 也就是 mybatis interception, 其实就是一个拦截器,可以拦截 Mybatis 运行各个阶段的方法。

1
2
3
4
5
6
7
8
@Intercepts({
@Signature(method = "update", type = Executor.class, args = {MappedStatement.class, Object.class}),
@Signature(method = "update", type = StatementHandler.class, args = {Statement.class}),
@Signature(type=StatementHandler.class,method="prepare",args={Connection.class})
})
@Intercepts({
@Signature(type = StatementHandler.class, method = "update", args = {Statement.class})
})

这里不做详细介绍,参考文章如下
Mybatis interceptor 参考文章

实现

代码

下面先贴一下代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274

@Slf4j
@Intercepts({
@Signature(type = StatementHandler.class, method = "update", args = {Statement.class})
})
public class OperationLogPlugin implements Interceptor {

/**
* 监控列表
* 监控以下表的修改语句,写入redis中
*/
private static final Set<String> MONITOR_TABLE = new HashSet<>();

static {
String[] table = {
"agency_merchant",
"pay_merchantChannel",
"agency_merchantSettle",
"agency_vendor",
"agency_merchantWechat",
"agency_outlet"
};
for (String s : table) {
MONITOR_TABLE.add(s.toLowerCase());
}
}


@Override
public Object intercept(Invocation invocation) throws Throwable {
BoundSql boundSql = null;
StatementHandler statementHandler;
if (invocation.getTarget() instanceof StatementHandler) {
statementHandler = (StatementHandler) invocation.getTarget();
boundSql = statementHandler.getBoundSql();
} else {
return invocation.proceed();
}

Object[] args = invocation.getArgs();
Statement statement = (Statement) args[0];

String preSql = boundSql.getSql();
MySqlStatementParser mySqlStatementParser = new MySqlStatementParser(preSql);
SQLStatement sqlStatement = mySqlStatementParser.parseStatement();
SQLExprTableSource sqlTableSource = null;
SQLExpr where = null;
// 获取 SQL 的 表名 以及 where 条件。仅处理Update语句,其他情况直接执行原SQL
if (sqlStatement instanceof SQLSelectStatement) {
// 查询语句不做任何处理,直接结束
return invocation.proceed();
} else if (sqlStatement instanceof SQLInsertStatement) {
// SQLInsertStatement sqlInsertStatement = (SQLInsertStatement) sqlStatement;
// sqlTableSource = sqlInsertStatement.getTableSource();
return invocation.proceed();
} else if (sqlStatement instanceof SQLUpdateStatement) {
SQLUpdateStatement sqlUpdateStatement = (SQLUpdateStatement) sqlStatement;
sqlTableSource = (SQLExprTableSource) sqlUpdateStatement.getTableSource();
where = sqlUpdateStatement.getWhere();
} else if (sqlStatement instanceof SQLDeleteStatement) {
// SQLDeleteStatement sqlDeleteStatement = (SQLDeleteStatement) sqlStatement;
// sqlTableSource = (SQLExprTableSource) sqlDeleteStatement.getTableSource();
// where = sqlDeleteStatement.getWhere();
return invocation.proceed();
} else {
return invocation.proceed();
}
String tableName = sqlTableSource.toString();
// where 条件为空,说明是一个批量修改,直接执行,不记录。这种情况记录SQL最好
if (where == null) {
return invocation.proceed();
}
// 获取 Connection
Connection connection = statement.getConnection();
// Update影响的主键
Long id = 0L;
List<Map<String, Object>> beforeImage;
List<Map<String, Object>> afterImage;
if (MONITOR_TABLE.contains(tableName.toLowerCase())) {
// 查询前镜像
String beforeImagePreSql = String.format("SELECT * FROM %s where %s", tableName, where.toString());
RoutingStatementHandler routingStatementHandler = (RoutingStatementHandler) getJdkDynamicProxyTargetObject(statementHandler);
StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(routingStatementHandler, "delegate");
//获取mapper方法与xml的映射信息
MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
Object parameterObject = boundSql.getParameterObject();
List<SQLObject> children = where.getChildren();
int whereConditionCount = getWhereConditionCount(children, 0);
List<ParameterMapping> newParameterMapping = new ArrayList<>(parameterMappings.size());
// Where语句一般都在最后,所以根据 whereConditionCount 取最后几个参数
for (int i = parameterMappings.size() - whereConditionCount; i < parameterMappings.size(); i++) {
newParameterMapping.add(parameterMappings.get(i));
}
BoundSql beforeImageBoundSql = new BoundSql(mappedStatement.getConfiguration(), beforeImagePreSql, newParameterMapping, parameterObject);
//StatementHandler、ResultSetHandler、ParameterHandler运行在这几个中插入自己的代码
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, beforeImageBoundSql);
//创建一个 PreparedStatement 对象来将参数化的 SQL 语句发送到数据库。
PreparedStatement prepareStatement = connection.prepareStatement(beforeImagePreSql);
parameterHandler.setParameters(prepareStatement);
ResultSet rs = prepareStatement.executeQuery();
beforeImage = convertList(rs);
id = (Long) beforeImage.get(0).get("id");
} else {
// 不需要监控的表直接结束
return invocation.proceed();
}
Object proceed = invocation.proceed();
if (beforeImage.size() == 1) {
// 只在更新数据为一条记录时记录日志。已经可以满足绝大部分业务需求了。
// 查询新数据,此时使用主键查询
String afterImageSql = String.format("SELECT * from %s where id = %s", tableName, id);
PreparedStatement prepareStatement = connection.prepareStatement(afterImageSql);
ResultSet rs = prepareStatement.executeQuery();
afterImage = convertList(rs);
// 新旧数据对比
List<Map<String, Object>> difference = difference(beforeImage.get(0), afterImage.get(0));
// System.out.println(difference);
// 获取用户信息
AuthenticationDto authentication = null;
try {
AuthUtil bean = SpringUtil.getBean(AuthUtil.class);
authentication = bean.getAuthentication();
} catch (Exception e) {
log.warn("获取用户信息失败");
}

String userId = authentication == null ? "" : String.valueOf(authentication.getUserId());
String merchantId = authentication == null ? "" : String.valueOf(authentication.getMerchantId());

// 插入 opreation_log
String operationLogSQL = "INSERT INTO `pay_1`.`operation_log`(`uid`, `merchantId`, `tableName`, `operationCode`, `primaryKey`, `oriValue`, `newValue`, `recordTime`)" +
"VALUES(?, ?, ?, ?, ?, ?, ?, ?);";
PreparedStatement operationLogStatement = connection.prepareStatement(operationLogSQL); //创建一个 PreparedStatement 对象来将参数化的 SQL 语句发送到数据库。
// uid
operationLogStatement.setString(1, userId);
// merchantId
operationLogStatement.setString(2, merchantId);
//tableName
operationLogStatement.setString(3, tableName);
//operationCode
operationLogStatement.setString(4, "UPDATE");
//primaryKey
operationLogStatement.setString(5, String.valueOf(id));
//oriValue
operationLogStatement.setString(6, JsonUtil.entity2Json(difference.get(0)));
//newValue
operationLogStatement.setString(7, JsonUtil.entity2Json(difference.get(1)));
//recordTime
operationLogStatement.setInt(8, TimeUtil.getNow());
operationLogStatement.execute();
}
return proceed;
}

@Override
public Object plugin(Object target) {
if (target instanceof StatementHandler) {
return Plugin.wrap(target, this);
} else {
return target;
}
}

@Override
public void setProperties(Properties properties) {
Interceptor.super.setProperties(properties);
}


/**
* 获取JDK动态代理的被代理对象
*
* @param proxy
* @return
* @throws Exception
*/
private static Object getJdkDynamicProxyTargetObject(Object proxy) throws Exception {
Field field = proxy.getClass().getSuperclass().getDeclaredField("h");
field.setAccessible(true);
//获取指定对象中此字段的值
Plugin pluginProxy = (Plugin) field.get(proxy); //获取Proxy对象中的此字段的值
Field target = pluginProxy.getClass().getDeclaredField("target");
target.setAccessible(true);
return target.get(pluginProxy);
}


/**
* 解析where条件抽象语法树,获取where条件数量
*
* @param listCommon
* @param sum
* @return
*/
private int getWhereConditionCount(List<SQLObject> listCommon, int sum) {
for (SQLObject sqlObject : listCommon) {
if (sqlObject instanceof SQLBinaryOpExpr) {
SQLBinaryOpExpr sqlObject1 = (SQLBinaryOpExpr) sqlObject;
List<SQLObject> chList = sqlObject1.getChildren();
sum = getWhereConditionCount(chList, sum);
} else if (sqlObject instanceof SQLIdentifierExpr) {
sum += 1;
}
}
return sum;
}

/**
* MYSQL结果集转对象
*
* @param rs
* @return
*/
private static List<Map<String, Object>> convertList(ResultSet rs) {
List<Map<String, Object>> list = new ArrayList<Map<String, Object>>();
try {
ResultSetMetaData md = rs.getMetaData();
int columnCount = md.getColumnCount();
while (rs.next()) {
Map<String, Object> rowData = new HashMap<String, Object>();
for (int i = 1; i <= columnCount; i++) {
Object object = rs.getObject(i);
if (object instanceof BigInteger) {
rowData.put(md.getColumnName(i), ((BigInteger) object).longValue());
} else {
rowData.put(md.getColumnName(i), object);
}
}
list.add(rowData);
}
} catch (SQLException e) {
// Auto-generated catch block
e.printStackTrace();
} finally {
try {
if (rs != null) {
rs.close();
}
rs = null;
} catch (SQLException e) {
e.printStackTrace();
}
}
return list;
}

/**
* 结果集对比
* 返回一个List map
* List.get(0) 存储前镜像
* List.get(1) 存储后镜像
*
* @param beforeImage
* @param afterImage
* @return
*/
private List<Map<String, Object>> difference(Map<String, Object> beforeImage, Map<String, Object> afterImage) {
List<Map<String, Object>> result = new ArrayList<>();
result.add(new HashMap<>());
result.add(new HashMap<>());
for (Map.Entry<String, Object> beforeEntry : beforeImage.entrySet()) {
String key = beforeEntry.getKey();
Object beforeValue = beforeEntry.getValue();
Object afterValue = afterImage.get(key);
if (!Objects.equals(beforeValue, afterValue)) {
result.get(0).put(key, beforeValue);
result.get(1).put(key, afterValue);
}
}
return result;
}
}

解析

首先,这个需求我做了一定的简化,事实上可以做的更加完善。
比如,我只处理 update 语句,只处理单条更新的语句。
也只是简单的测试通过,还没有经过生产的验证,如果以后生产验证有问题再修改一下Bug。

说一下整体思路。
首先我们的项目是强依赖 Mybatis 的,所以所有的数据库操作都是经过 Mybatis 的。那我们可以拦截所有的SQL语句,如果发现有更新我们特定表的SQL语句就做一些处理。
处理如下:

  1. 根据 update 语句获取它的 where 条件,然后生成反向的查询 SQL 语句。在执行更新语句前查出前置的SQL镜像,这里称为 beforeImage.
  2. 执行 update 语句。
  3. 我们所有的表都是有主键的,且一定叫id。把 beforeImage 的 id 取出来,根据Id 获取后置镜像,这里称为 afterImage.
  4. 根据前后镜像比较,取出有变更的字段。再加上从 ThreadLocal 中获取到的用户身份信息,生成一条SQL记录。

下面是代码详解:

1
2
3
@Intercepts({
@Signature(type = StatementHandler.class, method = "update", args = {Statement.class})
})

上面这一个注解,代表我拦截的是 StatementHandler 阶段。拦截这个阶段的目的是可以拿到 Statement, 然后就可以拿到 Connection。

MySqlStatementParser 这个类是alibaba druid的包,这里我肯定不可能去自己写SQL语句的解析的,当然是选择白嫖。
生成 MySqlStatementParser 后可以取到其中的 where 条件,然后获取 where 条件参数个数。
同时还可以获取到表名。

当我们判断这张表在我们需要监控的表中时。查询前置镜像。

1
2
3
4
if (MONITOR_TABLE.contains(tableName.toLowerCase())) {
// 查询前镜像
// .....
}

前置镜像的查询方法

核心方法如下。就是使用 Mybatis 的方法去创建。

1
2
3
4
5
6
7
BoundSql beforeImageBoundSql = new BoundSql(mappedStatement.getConfiguration(), beforeImagePreSql, newParameterMapping, parameterObject);
//StatementHandler、ResultSetHandler、ParameterHandler运行在这几个中插入自己的代码
ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, beforeImageBoundSql);
//创建一个 PreparedStatement 对象来将参数化的 SQL 语句发送到数据库。
PreparedStatement prepareStatement = connection.prepareStatement(beforeImagePreSql);
parameterHandler.setParameters(prepareStatement);
ResultSet rs = prepareStatement.executeQuery();

关键点就是 要生成一个新的 BoundSql 和 ParameterHandler。
这里的关键点又是参数列表。

我们的参数一定是包含在 update 语句的参数中的,从 Mybatis 设置参数的源码可知,它是遍历 parameterMapping 从 parameterObject 取值的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
@Override
public void setParameters(PreparedStatement ps) {
ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
if (parameterMappings != null) {
for (int i = 0; i < parameterMappings.size(); i++) {
ParameterMapping parameterMapping = parameterMappings.get(i);
if (parameterMapping.getMode() != ParameterMode.OUT) {
Object value;
String propertyName = parameterMapping.getProperty();
if (boundSql.hasAdditionalParameter(propertyName)) { // issue #448 ask first for additional params
value = boundSql.getAdditionalParameter(propertyName);
} else if (parameterObject == null) {
value = null;
} else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
value = parameterObject;
} else {
MetaObject metaObject = configuration.newMetaObject(parameterObject);
value = metaObject.getValue(propertyName);
}
TypeHandler typeHandler = parameterMapping.getTypeHandler();
JdbcType jdbcType = parameterMapping.getJdbcType();
if (value == null && jdbcType == null) {
jdbcType = configuration.getJdbcTypeForNull();
}
try {
typeHandler.setParameter(ps, i + 1, value, jdbcType);
} catch (TypeException | SQLException e) {
throw new TypeException("Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
}
}
}
}
}

所以这里我们生成一个新的 List<ParameterMapping> 然后从旧的 parameterMappings 中取最后一个值。

1
2
3
4
5
List<ParameterMapping> newParameterMapping = new ArrayList<>(parameterMappings.size());
// Where语句一般都在最后,所以根据 whereConditionCount 取最后几个参数
for (int i = parameterMappings.size() - whereConditionCount; i < parameterMappings.size(); i++) {
newParameterMapping.add(parameterMappings.get(i));
}

拿到结果集后 放入 List<Map<String, Object>> beforeImage;

后置镜像同理

1
2
3
4
String afterImageSql = String.format("SELECT * from %s where id = %s", tableName, id);
PreparedStatement prepareStatement = connection.prepareStatement(afterImageSql);
ResultSet rs = prepareStatement.executeQuery();
afterImage = convertList(rs);

调用 difference 方法拿到结果集的差异。

这一段是获取用户信息

1
2
3
4
5
6
7
8
9
10
11
 // 获取用户信息
AuthenticationDto authentication = null;
try {
AuthUtil bean = SpringUtil.getBean(AuthUtil.class);
authentication = bean.getAuthentication();
} catch (Exception e) {
log.warn("获取用户信息失败");
}

String userId = authentication == null ? "" : String.valueOf(authentication.getUserId());
String merchantId = authentication == null ? "" : String.valueOf(authentication.getMerchantId());

这一段是插入日志表中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 插入 opreation_log
String operationLogSQL = "INSERT INTO `pay_1`.`operation_log`(`uid`, `merchantId`, `tableName`, `operationCode`, `primaryKey`, `oriValue`, `newValue`, `recordTime`)" +
"VALUES(?, ?, ?, ?, ?, ?, ?, ?);";
PreparedStatement operationLogStatement = connection.prepareStatement(operationLogSQL); //创建一个 PreparedStatement 对象来将参数化的 SQL 语句发送到数据库。
// uid
operationLogStatement.setString(1, userId);
// merchantId
operationLogStatement.setString(2, merchantId);
//tableName
operationLogStatement.setString(3, tableName);
//operationCode
operationLogStatement.setString(4, "UPDATE");
//primaryKey
operationLogStatement.setString(5, String.valueOf(id));
//oriValue
operationLogStatement.setString(6, JsonUtil.entity2Json(difference.get(0)));
//newValue
operationLogStatement.setString(7, JsonUtil.entity2Json(difference.get(1)));
//recordTime
operationLogStatement.setInt(8, TimeUtil.getNow());
operationLogStatement.execute();

流程大概就这么多。

这里其实有一个思考点,关于回滚。
所以这里有一个比较讲究的地方,我们这里的 Connection 其实和update语句是同一个 Connection,所以回滚的时候会一起回滚。

这一篇写的比较匆忙,主要以代码为主


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!