|
|
|
package com.zhonglai.luhui.datasource.plugin;
|
|
|
|
|
|
|
|
import org.apache.ibatis.executor.Executor;
|
|
|
|
import org.apache.ibatis.executor.statement.StatementHandler;
|
|
|
|
import org.apache.ibatis.mapping.MappedStatement;
|
|
|
|
import org.apache.ibatis.mapping.SqlCommandType;
|
|
|
|
import org.apache.ibatis.plugin.*;
|
|
|
|
import org.apache.ibatis.reflection.MetaObject;
|
|
|
|
import org.apache.ibatis.reflection.SystemMetaObject;
|
|
|
|
|
|
|
|
import java.lang.annotation.Annotation;
|
|
|
|
import java.lang.reflect.Field;
|
|
|
|
import java.sql.ResultSet;
|
|
|
|
import java.sql.Statement;
|
|
|
|
import java.util.*;
|
|
|
|
|
|
|
|
@Intercepts({
|
|
|
|
@Signature(
|
|
|
|
type = StatementHandler.class,
|
|
|
|
method = "update",
|
|
|
|
args = {Statement.class}
|
|
|
|
)
|
|
|
|
})
|
|
|
|
public class GeneratedKeyInterceptor implements Interceptor {
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public Object intercept(Invocation invocation) throws Throwable {
|
|
|
|
StatementHandler handler = (StatementHandler) invocation.getTarget();
|
|
|
|
MetaObject meta = SystemMetaObject.forObject(handler);
|
|
|
|
MappedStatement ms = (MappedStatement) meta.getValue("delegate.mappedStatement");
|
|
|
|
if (ms == null || ms.getSqlCommandType() != SqlCommandType.INSERT) {
|
|
|
|
return invocation.proceed();
|
|
|
|
}
|
|
|
|
|
|
|
|
Object paramObj = meta.getValue("delegate.boundSql.parameterObject");
|
|
|
|
List<?> list = extractList(paramObj);
|
|
|
|
if (list == null || list.isEmpty()) {
|
|
|
|
return invocation.proceed();
|
|
|
|
}
|
|
|
|
|
|
|
|
// 执行插入
|
|
|
|
Object result = invocation.proceed();
|
|
|
|
|
|
|
|
// 从 Statement 里读取生成的 keys,并回写到集合对象
|
|
|
|
Statement stmt = (Statement) invocation.getArgs()[0];
|
|
|
|
try (ResultSet rs = stmt.getGeneratedKeys()) {
|
|
|
|
String pkName = resolvePkName(list.get(0));
|
|
|
|
if (pkName == null) {
|
|
|
|
return result; // 无法解析主键名,跳过
|
|
|
|
}
|
|
|
|
int idx = 0;
|
|
|
|
while (rs != null && rs.next() && idx < list.size()) {
|
|
|
|
Object keyVal = rs.getObject(1);
|
|
|
|
setFieldValue(list.get(idx), pkName, keyVal);
|
|
|
|
idx++;
|
|
|
|
}
|
|
|
|
} catch (Exception ignore) {
|
|
|
|
// 不应该抛出,保证插入不受影响;若需要可记录日志
|
|
|
|
}
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
@SuppressWarnings("rawtypes")
|
|
|
|
private List extractList(Object param) {
|
|
|
|
if (param == null) return null;
|
|
|
|
if (param instanceof List) return (List) param;
|
|
|
|
if (param instanceof Map) {
|
|
|
|
Map map = (Map) param;
|
|
|
|
// MyBatis 对未指定 name 的 collection 参数可能为 "list" 或 "collection"
|
|
|
|
if (map.get("list") instanceof List) return (List) map.get("list");
|
|
|
|
if (map.get("collection") instanceof List) return (List) map.get("collection");
|
|
|
|
// 支持 @Param("list")
|
|
|
|
for (Object v : map.values()) {
|
|
|
|
if (v instanceof List) return (List) v;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return null;
|
|
|
|
}
|
|
|
|
|
|
|
|
private String resolvePkName(Object obj) {
|
|
|
|
Class<?> cls = obj.getClass();
|
|
|
|
// 查找带 @Id 注解的字段(javax.persistence 或其它同名注解)
|
|
|
|
List<Field> list = getAllFields(cls);
|
|
|
|
for (Field f : list) {
|
|
|
|
for (Annotation a : f.getAnnotations()) {
|
|
|
|
String an = a.annotationType().getSimpleName();
|
|
|
|
if ("Id".equals(an) || "TableId".equals(an)) {
|
|
|
|
return f.getName();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// 回退到常见命名
|
|
|
|
for (Field f : list) {
|
|
|
|
if ("id".equalsIgnoreCase(f.getName())) return f.getName();
|
|
|
|
}
|
|
|
|
return null;
|
|
|
|
}
|
|
|
|
|
|
|
|
private List<Field> getAllFields(Class<?> cls) {
|
|
|
|
List<Field> fields = new ArrayList<>();
|
|
|
|
Class<?> cur = cls;
|
|
|
|
while (cur != null && cur != Object.class) {
|
|
|
|
Collections.addAll(fields, cur.getDeclaredFields());
|
|
|
|
cur = cur.getSuperclass();
|
|
|
|
}
|
|
|
|
return fields;
|
|
|
|
}
|
|
|
|
|
|
|
|
private void setFieldValue(Object target, String fieldName, Object value) {
|
|
|
|
try {
|
|
|
|
Field f = findField(target.getClass(), fieldName);
|
|
|
|
if (f == null) return;
|
|
|
|
f.setAccessible(true);
|
|
|
|
Class<?> type = f.getType();
|
|
|
|
Object converted = convertValue(value, type);
|
|
|
|
f.set(target, converted);
|
|
|
|
} catch (Exception ignored) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
private Field findField(Class<?> cls, String name) {
|
|
|
|
Class<?> cur = cls;
|
|
|
|
while (cur != null && cur != Object.class) {
|
|
|
|
try {
|
|
|
|
return cur.getDeclaredField(name);
|
|
|
|
} catch (NoSuchFieldException e) {
|
|
|
|
cur = cur.getSuperclass();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return null;
|
|
|
|
}
|
|
|
|
|
|
|
|
private Object convertValue(Object val, Class<?> targetType) {
|
|
|
|
if (val == null) return null;
|
|
|
|
if (targetType.isAssignableFrom(val.getClass())) return val;
|
|
|
|
if (Number.class.isAssignableFrom(val.getClass()) || val instanceof Number) {
|
|
|
|
Number n = (Number) val;
|
|
|
|
if (targetType == Long.class || targetType == long.class) return n.longValue();
|
|
|
|
if (targetType == Integer.class || targetType == int.class) return n.intValue();
|
|
|
|
if (targetType == Short.class || targetType == short.class) return n.shortValue();
|
|
|
|
if (targetType == Byte.class || targetType == byte.class) return n.byteValue();
|
|
|
|
}
|
|
|
|
// 最后回退到字符串再尝试解析简单类型
|
|
|
|
String s = val.toString();
|
|
|
|
if (targetType == String.class) return s;
|
|
|
|
try {
|
|
|
|
if (targetType == Long.class || targetType == long.class) return Long.parseLong(s);
|
|
|
|
if (targetType == Integer.class || targetType == int.class) return Integer.parseInt(s);
|
|
|
|
if (targetType == Short.class || targetType == short.class) return Short.parseShort(s);
|
|
|
|
if (targetType == Byte.class || targetType == byte.class) return Byte.parseByte(s);
|
|
|
|
} catch (Exception ignored) {
|
|
|
|
}
|
|
|
|
return val;
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public Object plugin(Object target) {
|
|
|
|
return Plugin.wrap(target, this);
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public void setProperties(Properties properties) {
|
|
|
|
// no-op
|
|
|
|
}
|
|
|
|
} |
...
|
...
|
|