一个利用反射查询数据库的小实例

October . 29 . 2018

前一段时间一直在学习反射,  最近利用反射实现了一个简单的数据库(mysql)查询, 在此做一个分享, 希望对正在学习反射的铁汁们有所帮助.

1. 首先创建一个表

table.PNG

表的结构很简单, 描述了一个学生的基本信息.

项目的包结构如下:

package.PNG

2. 构建student表的javaBean

package my;

import java.util.Date;

public class Student {
    private Integer id;
    private String name;
    private Boolean sex;
    private String phone;
    private Date birthday;

    public Integer getId() {
        return id;
    }

    public void setId(Integer id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public Boolean getSex() {
        return sex;
    }

    public void setSex(Boolean sex) {
        this.sex = sex;
    }

    public String getPhone() {
        return phone;
    }

    public void setPhone(String phone) {
        this.phone = phone;
    }

    public Date getBirthday() {
        return birthday;
    }

    public void setBirthday(Date birthday) {
        this.birthday = birthday;
    }

    @Override
    public String toString() {
        return "\t学号: " + id + "\t姓名: " + name +
                "\t性别: " + (sex ? "男" : "女") +
                "\t电话: " + phone + "\t生日: " + new SimpleDateFormat("yyyy-MM-dd").format(birthday) + "\n";
    }
}

该类用来映射数据库, 作为一个简单的 ORM

3. 创建数据库连接类

package Sql;

import java.lang.reflect.Method;
import java.sql.*;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;

public class SqlConnection {
    private String ip;                 // 数据库ip
    private int port = 3306;           // 数据库端口
    private String datebase;           // 数据库名
    private String username;           // 数据库用户名
    private String password;           // 数据库密码

    Connection conn;


    public SqlConnection(String ip, int port, String datebase, String username, String password)
    {
        this.ip = ip;
        this.port = port;
        this.datebase = datebase;
        this.username = username;
        this.password = password;
    }

    // 打开数据库连接
    public void connect() throws SQLException {
        String urlFmt = "jdbc:mysql://%s:%d/%s?useUnicode=true&characterEncoding=UTF-8&useSSL=false";
        String connectionUrl = String.format(urlFmt, ip, port, datebase);
        conn = DriverManager.getConnection(connectionUrl, username, password);
    }

    // 关闭数据库连接
    public void close()
    {
        try {
            conn.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    // 查询, 返回结果集
    public ResultSet executeQuery(String sql) throws SQLException {
        Statement query = conn.createStatement();
        return query.executeQuery(sql);
    }

    // 执行查询, 并自动映射
    public List executeQuery(String sql, Class cls) throws SQLException, IllegalAccessException, InstantiationException {
        Statement stmt = conn.createStatement();
        ResultSet rs = stmt.executeQuery(sql);
        ResultSetMetaData rsmd = rs.getMetaData();

        int numColumns = rsmd.getColumnCount();   // 获取列数
        String[] columnNames  = new String[numColumns];  // 存储列名
        int[] columnTypes = new int[numColumns];   // 存储列类型

        for (int i=0; i<numColumns; i++)
        {
            int columnIndex = i + 1;    // 列索引从1开始
            columnNames[i] = rsmd.getColumnLabel(columnIndex);        // 列名
            columnTypes[i] = rsmd.getColumnType(columnIndex);         // 列类型
        }

        // 找出每列对应的 setters 方法
        Method[] setters = SqlReflect.findMethod(cls, columnNames);

        // 存储映射的pojo类
        List rows = new ArrayList<>();

        while (rs.next())
        {
            Object pojo = cls.newInstance();
            // 取出每一列的值并赋值
            for(int i=0; i<numColumns; i++)
            {
                int columnIndex = i + 1;
                String columnValue = rs.getString(columnIndex);   // 每列的值

                try {
                    SqlReflect.map(pojo, setters[i], columnTypes[i], columnValue);
                } catch (ParseException e) {
                    e.printStackTrace();
                }
            }
            rows.add(pojo);
        }
        return rows;
    }
}

这里我重点解释一下 executeQuery(String sql, Class cls) 这个方法, 接收的两个参数分别是 sql查询字符串 以及 需要对应映射的javaBean类, 最后的返回值为 对应 javaBean 的 List.

方法的大致思路是: (这里的 SqlReflect 是自定义的反射类稍后会做解释, 先明白它的作用即可)

  1. getMetaData() 得到表的基本信息.
  2. numColumns 存储表的总列数,  columnNames 存储表的每一列的名字, columnTypes 存储每一列的类型.
  3. 利用循环对 numColumns columnNames 数组赋值, 这里的需要注意的是两个数组的下标要相互对应.
  4. 调用 SqlReflect.findMethod(cls, columnNames) 方法的到 javaBean 类的所有 setter 方法.
  5. 循环取出每一列的值, 并利用SqlReflect.map(pojo, setters[i], columnTypes[i], columnValue) 将每一列的值一一赋值给javaBean对象.

相信到这大家都会对 SqlReflect 很好奇, 接下来我们就来揭开它神秘的面纱.

4. 创建 SqlReflect 类

package Sql;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.Types;
import java.text.ParseException;
import java.text.SimpleDateFormat;

public class SqlReflect {

    static SimpleDateFormat sdf_dt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
    static SimpleDateFormat sdf_d = new SimpleDateFormat("yyyy-MM-dd");
    static SimpleDateFormat sdf_t = new SimpleDateFormat("HH:mm:ss");

    // 寻找方法
    public static Method findMethod(Class cls, String columnName)
    {
        char firstChar = Character.toUpperCase(columnName.charAt(0));  // 首字母大写
        StringBuffer strbuf = new StringBuffer("set" + columnName);
        strbuf.setCharAt(3, firstChar);

        String methodName = strbuf.toString();
        Method[] methods = cls.getMethods();

        for (Method m: methods)
        {
            if(m.getName().equals(methodName))
            {
                return m;
            }
        }
        return null;
    }


    public static Method[] findMethod(Class cls, String[] columnNames){

        Method[] methods = new Method[ columnNames.length ];

        for (int i=0; i<columnNames.length; i++)
        {
            methods[i] = findMethod(cls, columnNames[i]);
        }

        return methods;
    }

    public static void map(Object pojo, Method method, int columnType, String columnValue) throws ParseException {

        if (method == null) return;
        if(columnValue == null) return;

        // 判断类型, 传递对应类型的值给setter
        Object arg0 = null;
        if(columnType == Types.CHAR || columnType == Types.VARCHAR )
        {
            arg0 = columnValue; // 给方法传一个String参数
        }
        else if(columnType == Types.BIT) // tinyint(1)
        {
            arg0 = (columnValue.equals("1"));
        }
        else if(columnType == Types.DATE) // date
        {
            arg0 = sdf_d.parse(columnValue);
        }
        else if(columnType == Types.TIME) // time
        {
            arg0 = sdf_t.parse(columnValue);
        }
        else if(columnType == Types.TIMESTAMP) // datetime timestamp
        {
            arg0 = sdf_dt.parse(columnValue);
        }
        else if(columnType == Types.TINYINT || columnType == Types.SMALLINT
                || columnType == Types.INTEGER || columnType == Types.BIGINT
                || columnType == Types.DOUBLE || columnType == Types.FLOAT)
        {
            // 整数类型的处理
            Class[] parameterTypes = method.getParameterTypes();
            Class c = parameterTypes[0];
            if(c.equals( int.class) || c.equals(Integer.class))
            {
                arg0 = Integer.valueOf(columnValue);
            }
            else if(c.equals( long.class) || c.equals(Long.class))
            {
                arg0 =  Long.valueOf(columnValue);
            }
            else if(c.equals( short.class) || c.equals(Short.class))
            {
                arg0 =  Short.valueOf(columnValue);
            }
            else if(c.equals( byte.class) || c.equals(Byte.class))
            {
                arg0 =  Byte.valueOf(columnValue);
            }
            else if(c.equals( float.class) || c.equals(Float.class))
            {
                arg0 =  Float.valueOf(columnValue);
            }
            else if(c.equals( double.class) || c.equals(Double.class))
            {
                arg0 =  Double.valueOf(columnValue);
            }
        }

        // 调用setter方法
        Object args[] = { arg0 };

        try {
            method.invoke(pojo, args);
        } catch (IllegalAccessException | InvocationTargetException e) {
            e.printStackTrace();
        }
    }
}

该方法具有三个静态的方法, 分别是:

  • public static Method findMethod(Class cls, String columnName)

      该方法返回一个Method, 第一个参数为对应的映射javaBean, 第二个参数为对应的列名, 具体逻辑是通过字符串的拼接得到一个标准的set方法名(如: setName), 再调用Class内置的 getMethods() 方法的到该类中所有方法, 利用循环遍历比较出相同名字的方法即可.

  • public static Method[] findMethod(Class cls, String[] columnNames)

       该方法是重载方法, 返回一个 Method 数组, 第一个参数为对应的映射javaBean, 第二个参数是列名数据, 具体逻辑就是循环调用了 findMethod(Class cls, String columnName) 方法就不多说了.

  • public static void map(Object pojo, Method method, int columnType, String columnValue)

      给对应映射关系赋值, 第一个参数是对应的映射类, 第二个参数是方法名, 第三个参数是参数的类型, 第四个参数是需要赋的值. 具体逻辑是通过判断 columnType 的类型将 columnValue 转换为对应的类型, 并调用Method中内置 invoke() 方法 , 第一个参数是对应该方法的类, 第二个参数是该方法的接收参数数组.

4. 添加测试类

package my;

import Sql.SqlConnection;

import java.sql.SQLException;
import java.util.List;

public class test {

    public static void queryTest() throws SQLException, InstantiationException, IllegalAccessException {
        SqlConnection conn = new SqlConnection(
                "127.0.0.1", 3306, "school", "root", "123456"
        );

        conn.connect();
        System.out.println("数据库已连接!");

        String sql = "SELECT * FROM student";
        List st = conn.executeQuery(sql, Student.class);
        System.out.println((ArrayList)st);
        conn.close();
        System.out.println("连接已关闭!");

    }

    public static void main(String[] args)
    {
        try {
            queryTest();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

运行结果如下图:

result.PNG

可以看到查询到了数据, 并且每一列都是一个 Student 对象, 是不是很方便.