package com.afanticar.service.impl;

import com.afanticar.service.ICkBatchService;
import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;

import java.lang.reflect.Field;
import java.sql.*;
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.StringJoiner;
import java.util.concurrent.ConcurrentHashMap;

/**
 * MyBatis-Plus service base that adds a ClickHouse bulk insert through the {@code input()} table function.
 *
 * <p>On first use per mapper, the target table's columns are read from {@code system.columns} and turned into
 * {@code insert into <table>_all select <cols> from input('<col type,...>')}. The statement and the camelCase
 * entity field names (derived from the snake_case column names) are cached per service instance.
 *
 * @author chin
 * @contact chenyan@afanticar.com
 * @since 2022/8/23/023
 */
public class CkBatchServiceImpl<M extends BaseMapper<T>, T> extends ServiceImpl<M, T> implements ICkBatchService<T> {
    /** Mapper type; its simple name forms the cache key. */
    protected Class<M> mapperClass = this.currentMapperClass();

    /** Entity type; must carry {@link TableName}. Its fields are read reflectively when binding. */
    protected Class<T> entityClass = this.currentModelClass();

    @Autowired
    private SqlSessionTemplate sqlSessionTemplate;
    @Autowired
    protected M myBaseMapper;

    protected Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Cached insert statement per key. Always written after {@link #columnMap}, so a reader that finds a
     * statement here also finds its field list.
     */
    private final ConcurrentHashMap<String, String> sqlMap = new ConcurrentHashMap<>();

    /** camelCase entity field names per key, in table column order (= statement parameter order). */
    private final ConcurrentHashMap<String, List<String>> columnMap = new ConcurrentHashMap<>();


    private static final String TABLE_SUFFIX_ALL = "_all";
    private static final String METHOD_NAME_KEY = ".saveBatchRecordsByInput";
    private static final String INSERT_SQL_FORMAT = "insert into %s select %s from input('%s') ";
    private static final String QUERY_SQL_FORMAT = "select name as col_name, `type` as data_type from  system.columns where table = '%s' and database = '%s' order by position asc";

    /**
     * Writes {@code records} to the {@code <table>_all} table as a single JDBC batch.
     *
     * <p>Each column is bound from the entity field whose name is the camelCase form of the column name; the
     * field is looked up on the entity class and its superclasses.
     *
     * @param records entities to insert; {@code null} or empty is a no-op
     * @throws IllegalStateException if the entity lacks {@link TableName} or the table has no columns
     * @throws RuntimeException      wrapping any reflection or JDBC failure during the insert
     */
    @Override
    public void saveBatchRecordsByInput(List<T> records) {
        if (records == null || records.isEmpty()) {
            // Nothing to write: don't borrow a connection or send an empty batch.
            return;
        }
        String key = this.mapperClass.getSimpleName() + METHOD_NAME_KEY;
        if (!sqlMap.containsKey(key)) {
            buildSql(resolveBaseTableName(), key);
        }
        String mapperSql = sqlMap.get(key);
        List<String> fieldNames = columnMap.get(key);
        Connection connection = null;
        PreparedStatement ps = null;
        try {
            // Resolve reflection metadata once per call rather than once per record and column.
            List<Field> fields = resolveFields(fieldNames);
            connection = getConnection();
            ps = connection.prepareStatement(mapperSql);
            for (T record : records) {
                int index = 1;
                for (Field field : fields) {
                    bindValue(connection, ps, index++, field, field.get(record));
                }
                ps.addBatch();
            }
            ps.executeBatch();
            ps.clearBatch();
        } catch (Exception e) {
            throw new RuntimeException("批量写入CK异常:" + e.getMessage(), e);
        } finally {
            // Close failures are logged, not thrown, so they never mask the outcome of the batch.
            closeQuietly(ps);
            closeQuietly(connection);
        }
    }

    /**
     * Returns the entity's {@link TableName} value with a trailing {@code _all} removed.
     *
     * @throws IllegalStateException if the entity class has no {@link TableName} annotation
     */
    private String resolveBaseTableName() {
        TableName tableNameAnn = this.entityClass.getAnnotation(TableName.class);
        if (tableNameAnn == null) {
            throw new IllegalStateException("实体类【" + this.entityClass.getName() + "】缺少@TableName注解");
        }
        String value = tableNameAnn.value();
        return value.endsWith(TABLE_SUFFIX_ALL)
                ? value.substring(0, value.length() - TABLE_SUFFIX_ALL.length())
                : value;
    }

    /** Looks up (and makes accessible) the entity field for each name, preserving order. */
    private List<Field> resolveFields(List<String> fieldNames) throws NoSuchFieldException {
        List<Field> fields = new ArrayList<>(fieldNames.size());
        for (String name : fieldNames) {
            fields.add(findField(name));
        }
        return fields;
    }

    /**
     * Finds a field declared on the entity class or on any of its superclasses below {@link Object}.
     *
     * @throws NoSuchFieldException if no class in the hierarchy declares {@code name}
     */
    private Field findField(String name) throws NoSuchFieldException {
        for (Class<?> type = entityClass; type != null && type != Object.class; type = type.getSuperclass()) {
            try {
                Field field = type.getDeclaredField(name);
                field.setAccessible(true);
                return field;
            } catch (NoSuchFieldException ignored) {
                // Not declared at this level; keep walking up the hierarchy.
            }
        }
        throw new NoSuchFieldException(name);
    }

    /** Binds one field value to parameter {@code index}, converting arrays and dates for the driver. */
    private void bindValue(Connection connection, PreparedStatement ps, int index, Field field, Object val)
            throws SQLException {
        if (val == null) {
            // Date-typed fields get a TIMESTAMP null, everything else a CHAR null.
            ps.setNull(index, field.getType().equals(Date.class) ? Types.TIMESTAMP : Types.CHAR);
        } else if (val instanceof String[]) {
            ps.setArray(index, connection.createArrayOf("String", (String[]) val));
        } else if (val instanceof Long[]) {
            ps.setArray(index, connection.createArrayOf("Long", (Long[]) val));
        } else if (val instanceof Integer[]) {
            ps.setArray(index, connection.createArrayOf("Integer", (Integer[]) val));
        } else if (val instanceof Date) {
            Date date = (Date) val;
            // java.sql.Date / java.sql.Time throw UnsupportedOperationException from toInstant(), so they go
            // through epoch millis; other Dates (including Timestamp) keep their full precision.
            Instant instant = (date instanceof java.sql.Date || date instanceof Time)
                    ? Instant.ofEpochMilli(date.getTime())
                    : date.toInstant();
            // LocalDateTime is the officially recommended binding type here, not Timestamp.
            // Converted in the JVM default time zone.
            ps.setObject(index, instant.atZone(ZoneId.systemDefault()).toLocalDateTime());
        } else {
            ps.setObject(index, val);
        }
    }

    /**
     * Builds and caches the insert statement and field list for {@code key} from {@code system.columns} of
     * the connection's current database. Re-checks the cache under the lock so concurrent first callers
     * query the metadata only once.
     *
     * @param tableName table name without the {@code _all} suffix
     * @param key       cache key for {@link #sqlMap} and {@link #columnMap}
     * @throws RuntimeException      if the metadata query fails
     * @throws IllegalStateException if no columns are found (presumably the table does not exist), so a
     *                               broken statement is never cached
     */
    private synchronized void buildSql(String tableName, String key) {
        if (sqlMap.containsKey(key)) {
            return;
        }
        StringJoiner columns = new StringJoiner(",");
        StringJoiner columnAndTypes = new StringJoiner(",");
        List<String> columnList = new ArrayList<>();
        Connection connection = null;
        PreparedStatement ps = null;
        ResultSet set = null;
        try {
            connection = getConnection();
            ps = connection.prepareStatement(String.format(QUERY_SQL_FORMAT, tableName, connection.getSchema()));
            set = ps.executeQuery();
            while (set.next()) {
                String column = set.getString("col_name");
                String dataType = set.getString("data_type");
                columns.add(column);
                columnAndTypes.add(column + " " + dataType);
                columnList.add(toHumpString(column));
            }
        } catch (Exception e) {
            logger.error("初始化CK表【{}】批量写入语句异常:{}", tableName, e.getMessage());
            throw new RuntimeException("初始化CK表【" + tableName + "】批量写入语句异常:" + e.getMessage(), e);
        } finally {
            closeQuietly(set);
            closeQuietly(ps);
            closeQuietly(connection);
        }
        if (columnList.isEmpty()) {
            throw new IllegalStateException("初始化CK表【" + tableName + "】批量写入语句异常:未查询到字段");
        }
        String querySt = String.format(INSERT_SQL_FORMAT, tableName + TABLE_SUFFIX_ALL, columns.toString(), columnAndTypes.toString());
        logger.info("初始化CK表【{}】批量写入语句:{}", tableName, querySt);
        // Publish the field list first: callers treat presence in sqlMap as "ready to use".
        columnMap.put(key, columnList);
        sqlMap.put(key, querySt);
    }

    /**
     * Converts snake_case to camelCase ({@code user_id -> userId}). Empty segments from leading, repeated or
     * trailing underscores are skipped instead of failing.
     */
    private static String toHumpString(String string) {
        StringBuilder builder = new StringBuilder(string.length());
        for (String part : string.split("_")) {
            if (part.isEmpty()) {
                continue;
            }
            if (builder.length() == 0) {
                builder.append(part);
            } else {
                builder.append(Character.toUpperCase(part.charAt(0))).append(part, 1, part.length());
            }
        }
        return builder.toString();
    }

    /** Closes {@code resource} if non-null, logging (never throwing) any failure. */
    private void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            logger.error("关闭CK连接异常:{}", e.getMessage(), e);
        }
    }

    /**
     * Borrows a JDBC connection from the {@code DataSource} of the MyBatis environment. The caller owns the
     * returned connection and must close it.
     *
     * <p>NOTE(review): the connection is taken straight from the DataSource, outside any Spring-managed
     * transaction; confirm that is intended.
     *
     * @throws RuntimeException if no connection can be obtained
     */
    public Connection getConnection() {
        try {
            // Read the DataSource from the factory configuration; opening a SqlSession just to reach it
            // would leave a session that nothing ever closes.
            return sqlSessionTemplate.getSqlSessionFactory().getConfiguration()
                    .getEnvironment().getDataSource().getConnection();
        } catch (Exception e) {
            logger.error("获取CK连接异常:{}", e.getMessage());
            throw new RuntimeException("获取CK连接异常::" + e.getMessage(), e);
        }
    }
}
