Skip to content

add dws sink pipeline connector#4341

Open
Daishuyuan wants to merge 1 commit into apache:master from
Daishuyuan:FLINK-39327
Open

add dws sink pipeline connector#4341
Daishuyuan wants to merge 1 commit into apache:master from
Daishuyuan:FLINK-39327

Conversation

@Daishuyuan
Copy link
Copy Markdown

What is changed

  • add a new DWS sink pipeline connector module with sink factory, sink function, metadata applier, option definitions, type conversion utilities and factory service
    registration
  • integrate the new module into the pipeline connectors build
  • add unit and integration tests for data sink factory, sink function, metadata applier and pipeline execution

Why

Flink CDC currently does not provide a DWS pipeline sink connector. This change adds DWS sink pipeline connector support so DWS can be used as a pipeline sink
consistently with other supported systems.

@melin
Copy link
Copy Markdown

melin commented Mar 26, 2026

在一个项目中,通过 flink cdc 把数据同步写入到 dws 和 TBase, 如果一个 task 只有一个线程写入 dws, 数据量很大,增加 flink task 数量,比较费资源,因为写入dws 和 TBase 是 IO 问题。 最后解决办法是,一个 task `中开启多线程并行写入。

在一个项目中,通过 flink cdc 把数据同步写入到 dws 和 TBase。如果一个 task 只有一个线程写入 dws,数据量很大时,增加 flink task 数量比较费资源,因为写入 dws 和 TBase 是 IO 问题。最后的解决办法是,在一个 task 中开启多线程并行写入。
package org.apache.flink.cdc.connectors.postgres.sink.v2;

import org.apache.flink.cdc.common.event.DataChangeEvent;
import org.apache.flink.cdc.common.event.OperationType;
import org.apache.flink.cdc.common.event.TableId;
import org.apache.flink.cdc.common.pipeline.PipelineOptions;
import org.apache.flink.cdc.common.schema.Schema;
import org.apache.flink.cdc.common.utils.DirtyData;
import org.apache.flink.cdc.common.utils.KafkaHelper;
import org.apache.flink.cdc.connectors.postgres.sink.PostgresDialect;
import org.apache.flink.cdc.connectors.postgres.sink.PostgresEventSerializer;
import org.apache.flink.cdc.connectors.postgres.sink.SoftDeleteConstant;
import org.apache.flink.connector.jdbc.datasource.connections.JdbcConnectionProvider;
import org.apache.flink.connector.jdbc.internal.executor.JdbcBatchStatementExecutor;

import org.apache.flink.shaded.curator5.org.apache.curator.utils.ThreadUtils;
import org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
import org.apache.flink.shaded.guava31.com.google.common.collect.Maps;

import com.alibaba.druid.pool.DruidPooledPreparedStatement;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.postgresql.core.ParameterList;
import org.postgresql.jdbc.PgStatement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.BatchUpdateException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static org.apache.flink.cdc.common.event.OperationType.UPDATE;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.BACKUP_TABLE_NAME;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.COLUMN_MAPPING;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.CUSTOM_PRIMARY_KEY;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.DATABASE;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.DATASOURCE_NAME;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.DIRTY_DATA_TOPIC;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.EXCLUDE_COLUMN;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.KAFKA_SERVERS;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.LOGIC_DELETE_COLUMN;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.LOGIC_DELETE_DATETIME_COLUMN;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.SINK_TYPE;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.SOFT_DELETE_ENABLED;
import static org.apache.flink.cdc.connectors.postgres.sink.PostgresDataSourceOptions.THREAD_BATCH_SIZE;
import static org.apache.flink.util.Preconditions.checkArgument;

/**
 * A {@link JdbcBatchStatementExecutor} that executes supplied statement for given the records
 * (without any pre-processing).
 */
class SimpleBatchStatementExecutorV2 implements JdbcBatchStatementExecutor<DataChangeEvent> {

    private static final Logger LOG = LoggerFactory.getLogger(SimpleBatchStatementExecutorV2.class);

    // Pending INSERT/UPDATE events per table, keyed by the joined primary-key values so
    // a later event for the same key overwrites the earlier one (LinkedHashMap keeps
    // arrival order within a table).
    private final ConcurrentHashMap<TableId, LinkedHashMap<String, DataChangeEvent>>
            batchUpsertMap = new ConcurrentHashMap<>();

    // Pending DELETE events per table, same key scheme as batchUpsertMap.
    private final ConcurrentHashMap<TableId, LinkedHashMap<String, DataChangeEvent>>
            batchDeleteMap = new ConcurrentHashMap<>();

    // Source of JDBC connections; also carries the sink configuration as Properties.
    private volatile JdbcConnectionProvider connectionProvider;

    // Dirty-data reporter; null when kafka servers / topic are not configured (see ctor).
    private final KafkaHelper kafkaHelper;

    // Raw "table:pk1,pk2;table2:pk" option string, parsed into customPrimaryKeyMap.
    private final String customPrimaryKey;

    // NOTE(review): not referenced anywhere in this chunk — possibly dead state.
    private HashMap<String, List<Integer>> parameterMap = new HashMap<>();

    // When true, deleted rows are additionally copied to a backup table (see PgDeleteTask).
    private final AtomicBoolean softDelete = new AtomicBoolean(false);

    private final AtomicReference<String> backupTableName = new AtomicReference<>("");

    private final AtomicReference<String> logicDeleteColumn = new AtomicReference<>("");

    private final AtomicReference<String> logicDeleteDatetimeColumn = new AtomicReference<>("");

    // "a=b,c=d" style column-name mapping option, resolved per table in parseTableSchema.
    private final AtomicReference<String> columnMapping = new AtomicReference<>("");

    private final AtomicReference<String> excludeColumn = new AtomicReference<>("");

    private final AtomicReference<String> datasourceName = new AtomicReference<>("");

    private final AtomicReference<String> sinkType = new AtomicReference<>("");

    // Logical database name used to rewrite TableIds in dirty-data reports.
    private final AtomicReference<String> database = new AtomicReference<>("");

    // Pipeline name (PipelineOptions.PIPELINE_NAME).
    private final AtomicReference<String> name = new AtomicReference<>("");

    private Map<TableId, List<String>> customPrimaryKeyMap = new ConcurrentHashMap<>();

    private Map<TableId, List<String>> primaryKeyMap = new ConcurrentHashMap<>();

    // Serializes DataChangeEvents to column/value maps and exposes the sink schemas.
    private final PostgresEventSerializer serializer;

    private static final DateTimeFormatter FORMATTER =
            DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    // Pool running PgUpsertTask / PgDeleteTask partitions in parallel.
    private ExecutorService executorService;

    // Max events handed to a single worker task (THREAD_BATCH_SIZE option, default 200).
    private int threadBatchSize = 200;

    // Cumulative counters used only for the executeBatch() summary log line.
    private long totalRecordCount;

    private long totalTimes;

    private long globalRecordCount;

    // Reflective access to PgStatement's private batch buffers, used to capture the
    // per-row SQL for dirty-data reporting (see getPgStatementParam).
    private Field batchStatementsField;

    private Field batchParametersField;

    /**
     * Creates the executor, reading all sink options from the provider's Properties:
     * kafka dirty-data settings, custom primary keys, soft-delete options, column
     * mapping/exclusion, and the per-thread batch size. Also caches reflective access
     * to PgStatement's batch buffers and starts the worker thread pool.
     */
    SimpleBatchStatementExecutorV2(
            JdbcConnectionProvider connectionProvider, PostgresEventSerializer serializer) {
        this.connectionProvider = connectionProvider;
        LOG.info("connectionProvider connection: " + connectionProvider.getClass().getName());
        this.serializer = serializer;
        Properties properties = connectionProvider.getProperties();
        String kafkaServers = properties.getProperty(KAFKA_SERVERS.key());
        String dirtyDataTopic = properties.getProperty(DIRTY_DATA_TOPIC.key());
        this.customPrimaryKey = properties.getProperty(CUSTOM_PRIMARY_KEY.key());
        LOG.info("pg customPrimaryKey:{}", customPrimaryKey);

        String value =
                connectionProvider.getProperties().getProperty(THREAD_BATCH_SIZE.key(), "200");
        threadBatchSize = Integer.parseInt(value);
        LOG.info("pg threadBatchSize: {}", threadBatchSize);

        // Best-effort option parsing: any failure leaves the defaults in place.
        // NOTE(review): the catch logs only 'soft.delete.enabled', but this block also
        // reads backup table, logic-delete, mapping and naming options.
        try {
            Boolean softDeleteParam =
                    Boolean.valueOf(properties.getProperty(SOFT_DELETE_ENABLED.key()));
            softDelete.set(softDeleteParam);
            LOG.info("pg softDelete:{}", softDelete);

            String backupTableNameParam = properties.getProperty(BACKUP_TABLE_NAME.key());
            backupTableName.set(backupTableNameParam);
            LOG.info("pg backupTableName:{}", backupTableName);

            String logicDeleteColumnParam = properties.getProperty(LOGIC_DELETE_COLUMN.key());
            logicDeleteColumn.set(logicDeleteColumnParam);
            LOG.info("pg logicDeleteColumn:{}", logicDeleteColumn);

            String logicDeleteDatetimeColumnParam =
                    properties.getProperty(LOGIC_DELETE_DATETIME_COLUMN.key());
            logicDeleteDatetimeColumn.set(logicDeleteDatetimeColumnParam);
            LOG.info("pg logicDeleteDatetimeColumn:{}", logicDeleteDatetimeColumn);

            String columnMappingParam = properties.getProperty(COLUMN_MAPPING.key());
            columnMapping.set(columnMappingParam);
            LOG.info("pg columnMapping:{}", columnMapping);

            String excludeColumnParam = properties.getProperty(EXCLUDE_COLUMN.key());
            excludeColumn.set(excludeColumnParam);
            LOG.info("pg excludeColumn:{}", excludeColumn);

            String datasourceNameParam = properties.getProperty(DATASOURCE_NAME.key());
            datasourceName.set(datasourceNameParam);

            String sinkTypeParam = properties.getProperty(SINK_TYPE.key());
            sinkType.set(sinkTypeParam);

            String databaseParam = properties.getProperty(DATABASE.key());
            database.set(databaseParam);

            String nameParam = properties.getProperty(PipelineOptions.PIPELINE_NAME.key());
            name.set(nameParam);
        } catch (Exception e) {
            LOG.error(
                    "Failed to get the 'soft.delete.enabled' parameter, 'soft.delete.enabled' defaults to false",
                    e);
        }

        // Cache reflective handles to PgStatement's private batch buffers; these are
        // required for dirty-data reporting, so failure here is fatal.
        try {
            batchStatementsField = PgStatement.class.getDeclaredField("batchStatements");
            batchParametersField = PgStatement.class.getDeclaredField("batchParameters");

            batchStatementsField.setAccessible(true);
            batchParametersField.setAccessible(true);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        // Kafka dirty-data reporting is optional; without both settings it is disabled
        // and kafkaHelper stays null (callers must null-check).
        if (StringUtils.isBlank(kafkaServers) || StringUtils.isBlank(dirtyDataTopic)) {
            LOG.error("properties.bootstrap.servers and dirty-data parameter cannot be null");
            this.kafkaHelper = null;
        } else {
            this.kafkaHelper = new KafkaHelper(kafkaServers, dirtyDataTopic);
            LOG.info("kafkaHelper {}", kafkaHelper);
        }

        // customPrimaryKey format: "db.schema.table:pk1,pk2;other.table:pk" — entries
        // with a malformed "table:keys" pair are silently skipped.
        if (StringUtils.isNotBlank(customPrimaryKey)) {
            String[] parts = customPrimaryKey.split(";");
            for (String part : parts) {
                String[] subParts = part.split(":");
                if (subParts.length == 2) {
                    TableId tableId = TableId.parse(subParts[0]);
                    List<String> keys = Arrays.asList(subParts[1].split(","));
                    customPrimaryKeyMap.put(tableId, keys);
                }
            }
        }

        executorService =
                Executors.newCachedThreadPool(ThreadUtils.newGenericThreadFactory("pg-sink-"));
    }

    /** No-op: statements are prepared per worker task inside executeBatch(). */
    @Override
    public void prepareStatements(Connection connection) throws SQLException {}

    /**
     * Buffers one change event into the per-table upsert or delete map, deduplicating
     * by the concatenated primary-key values. Events whose primary key contains a null
     * are routed to the dirty-data topic instead of being buffered.
     *
     * @throws RuntimeException if the event cannot be serialized or no sink schema is
     *     registered for its table
     */
    @Override
    public void addToBatch(DataChangeEvent event) {
        OperationType op = event.op();
        TableId tableId = event.tableId();

        Map<String, Object> eventMap = serializer.serializeDataChangeEvent(event);
        // BUG FIX: this check previously ran after schema.primaryKeys(), which could
        // NPE first on a missing schema; validate serialization output immediately.
        if (eventMap == null) {
            throw new RuntimeException("tableId: " + tableId + " is null");
        }

        Map<TableId, Schema> schemaMaps = serializer.getSinkSchemaMaps();
        Schema schema = schemaMaps.get(tableId);
        if (schema == null) {
            // Previously an opaque NullPointerException; fail with an actionable message.
            throw new RuntimeException("No sink schema registered for tableId: " + tableId);
        }
        List<String> primaryKeys = schema.primaryKeys();

        globalRecordCount++;

        // Build the dedup key from all primary-key values; a null PK value makes the
        // row unusable, so report it as dirty data and drop it.
        List<String> list = new ArrayList<>();
        for (String primaryKey : primaryKeys) {
            Object value = eventMap.get(primaryKey);

            if (value == null) {
                String msg =
                        "tableId: " + tableId + ", " + primaryKey + " is null, Map: " + eventMap;
                dirtyDataHandle(event, msg, op.name(), null);
                return;
            }

            list.add(value.toString());
        }
        String pkValues = StringUtils.join(list, ",");

        event.setEventMap(eventMap);

        if (op == OperationType.DELETE) {
            batchDeleteMap
                    .computeIfAbsent(tableId, k -> new LinkedHashMap<>())
                    .put(pkValues, event);
        } else {
            batchUpsertMap
                    .computeIfAbsent(tableId, k -> new LinkedHashMap<>())
                    .put(pkValues, event);

            // e.g. a delete arrived first and a row with the same primary key was then
            // re-inserted: drop the now-stale pending delete so the row survives.
            batchDeleteMap.computeIfAbsent(tableId, k -> new LinkedHashMap<>()).remove(pkValues);
        }
    }

    /**
     * Drains the buffered upsert and delete maps, partitions each table's events into
     * chunks of {@code threadBatchSize}, and executes them in parallel on the worker
     * pool (upserts first, then deletes). Blocks until all workers finish, then logs a
     * timing summary.
     */
    @Override
    public void executeBatch() throws SQLException {
        long startTimeMs = System.currentTimeMillis();
        AtomicInteger upsertCount = new AtomicInteger(0);
        AtomicInteger deleteCount = new AtomicInteger(0);
        Map<TableId, CopyOnWriteArrayList<String>> upsertCountMap = Maps.newHashMap();
        Map<TableId, CopyOnWriteArrayList<String>> deleteCountMap = Maps.newHashMap();

        long upsertStartTimeMs = System.currentTimeMillis();
        int upsertThreadCount = 0;
        if (batchUpsertMap.size() > 0) {
            Map<TableId, List<List<DataChangeEvent>>> eventMap = new HashMap<>();
            for (Map.Entry<TableId, LinkedHashMap<String, DataChangeEvent>> entry :
                    batchUpsertMap.entrySet()) {
                TableId tableId = entry.getKey();
                LinkedHashMap<String, DataChangeEvent> map = batchUpsertMap.remove(tableId);
                if (map == null) {
                    // CHM iteration is weakly consistent; the entry may already be drained.
                    continue;
                }
                List<DataChangeEvent> list = new ArrayList<>(map.values());
                List<List<DataChangeEvent>> partitions = Lists.partition(list, threadBatchSize);
                upsertThreadCount += partitions.size();
                eventMap.put(tableId, partitions);
                upsertCountMap.put(tableId, new CopyOnWriteArrayList<>());
            }

            CountDownLatch upsertDownLatch = new CountDownLatch(upsertThreadCount);
            for (Map.Entry<TableId, List<List<DataChangeEvent>>> entry : eventMap.entrySet()) {
                List<List<DataChangeEvent>> partitions = entry.getValue();
                CopyOnWriteArrayList<String> list = upsertCountMap.get(entry.getKey());
                for (List<DataChangeEvent> partition : partitions) {
                    executorService.submit(
                            new PgUpsertTask(
                                    entry.getKey(), partition, upsertCount, list, upsertDownLatch));
                }
            }

            try {
                boolean flag = upsertDownLatch.await(10, TimeUnit.MINUTES);
                if (!flag) {
                    // BUG FIX: the message had one placeholder for two arguments; the
                    // remaining-count argument was silently dropped by SLF4J.
                    LOG.error(
                            "upsertDownLatch timed out, upsertThreadCount: {}, sync count: {}",
                            upsertThreadCount,
                            upsertDownLatch.getCount());
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt status for the caller before failing.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
        long upsertTimes = System.currentTimeMillis() - upsertStartTimeMs;

        int deleteThreadCount = 0;
        long deleteStartTimeMs = System.currentTimeMillis();
        if (batchDeleteMap.size() > 0) {
            Map<TableId, List<List<DataChangeEvent>>> eventMap = new HashMap<>();
            for (Map.Entry<TableId, LinkedHashMap<String, DataChangeEvent>> entry :
                    batchDeleteMap.entrySet()) {
                TableId tableId = entry.getKey();
                LinkedHashMap<String, DataChangeEvent> map = batchDeleteMap.remove(tableId);
                if (map == null) {
                    continue;
                }
                List<DataChangeEvent> list = new ArrayList<>(map.values());
                List<List<DataChangeEvent>> partitions = Lists.partition(list, threadBatchSize);
                deleteThreadCount += partitions.size();
                eventMap.put(tableId, partitions);
                deleteCountMap.put(tableId, new CopyOnWriteArrayList<>());
            }

            CountDownLatch downLatch = new CountDownLatch(deleteThreadCount);
            for (Map.Entry<TableId, List<List<DataChangeEvent>>> entry : eventMap.entrySet()) {
                List<List<DataChangeEvent>> partitions = entry.getValue();
                CopyOnWriteArrayList<String> list = deleteCountMap.get(entry.getKey());
                for (List<DataChangeEvent> partition : partitions) {
                    executorService.submit(
                            new PgDeleteTask(
                                    entry.getKey(), partition, deleteCount, list, downLatch));
                }
            }

            try {
                // NOTE(review): unlike the upsert latch this wait has no timeout —
                // confirm whether an unbounded wait is intended here.
                downLatch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }

        long deleteTimes = System.currentTimeMillis() - deleteStartTimeMs;

        long endTimeMs = System.currentTimeMillis();
        long times = endTimeMs - startTimeMs;
        int total = upsertCount.get() + deleteCount.get();
        String avgTime = formatAvgTime(times, total);

        totalRecordCount += total;
        totalTimes += times;
        String totalAvgTime = formatAvgTime(totalTimes, totalRecordCount);
        LOG.info(
                "v5-{} postgresql execution took {} ms, total count: {}/{}/{}, avgTime: {}/{}. \n\tupsert [count: {}, times: {}ms, thread: {} {}], \n\tdelete [count: {}, times: {}ms, thread: {} {}]",
                this.hashCode(),
                times,
                total,
                totalRecordCount,
                globalRecordCount,
                avgTime,
                totalAvgTime,
                upsertCount.get(),
                upsertTimes,
                upsertThreadCount,
                upsertCountMap,
                deleteCount.get(),
                deleteTimes,
                deleteThreadCount,
                deleteCountMap);
    }

    /**
     * Resolves the effective column names and primary keys for {@code tableId}: starts
     * from the registered sink schema, drops excluded columns, then renames columns and
     * keys through the configured column mapping.
     *
     * @throws RuntimeException if no schema is registered or the table has no primary key
     */
    private Result parseTableSchema(TableId tableId) {
        Map<TableId, Schema> schemaMaps = serializer.getSinkSchemaMaps();
        Schema schema = schemaMaps.get(tableId);
        if (schema == null) {
            // Previously an opaque NullPointerException on getColumnNames().
            throw new RuntimeException("No sink schema registered for table " + tableId);
        }
        List<String> columnNames = new ArrayList<>(schema.getColumnNames());
        List<String> primaryKeys = new ArrayList<>(schema.primaryKeys());
        excludeColumn(tableId, columnNames);
        if (primaryKeys.isEmpty()) {
            throw new RuntimeException("Table " + tableId + " has no primary key");
        }

        // Mapping entries look like "a=b" (comma separated); the map is keyed by the
        // right-hand name so lookups below rename right-hand names to left-hand ones.
        String columnMappingParam = getCustomColumn(columnMapping.get(), tableId);
        Map<String, String> columnMappingMap = new HashMap<>();
        if (StringUtils.isNotBlank(columnMappingParam)) {
            for (String pair : columnMappingParam.split(",")) {
                String[] entry = pair.split("=");
                if (entry.length == 2) {
                    columnMappingMap.put(entry[1], entry[0]);
                } else {
                    // BUG FIX: a malformed pair (no '=' or extra '=') previously threw
                    // ArrayIndexOutOfBoundsException; skip it with a warning instead.
                    LOG.warn("Ignoring malformed column mapping entry: {}", pair);
                }
            }
        }

        if (MapUtils.isNotEmpty(columnMappingMap)) {
            columnNames.replaceAll(column -> columnMappingMap.getOrDefault(column, column));
            primaryKeys.replaceAll(column -> columnMappingMap.getOrDefault(column, column));
        }
        return new Result(columnNames, primaryKeys);
    }

    /** Immutable pair of a table's resolved column names and primary-key names. */
    private static class Result {
        public final List<String> columnNames;
        public final List<String> primaryKeys;

        public Result(List<String> columns, List<String> keys) {
            this.columnNames = columns;
            this.primaryKeys = keys;
        }
    }

    /**
     * Worker that deletes one partition of events for a single table in its own JDBC
     * transaction. When soft delete is enabled, the deleted rows are additionally
     * written to a backup table in the same transaction. Failures are reported to the
     * dirty-data topic; the latch is counted down in all cases.
     */
    private class PgDeleteTask implements Runnable {
        private TableId tableId;
        private List<DataChangeEvent> events;
        private AtomicInteger deleteCount;
        private CopyOnWriteArrayList<String> deleteCountList;
        private CountDownLatch downLatch;

        public PgDeleteTask(
                TableId tableId,
                List<DataChangeEvent> events,
                AtomicInteger deleteCount,
                CopyOnWriteArrayList<String> deleteCountList,
                CountDownLatch downLatch) {
            this.tableId = tableId;
            this.events = events;
            this.deleteCount = deleteCount;
            this.deleteCountList = deleteCountList;
            this.downLatch = downLatch;
        }

        @Override
        public void run() {
            long start = System.currentTimeMillis();
            try {
                deleteCount.addAndGet(events.size());

                // Liaoning medical-insurance requirement: besides the normal delete,
                // when softDelete is true the deleted rows must also be written to a
                // backup table.
                List<Object> batchStatementsCopy = new ArrayList<>();
                List<Object> batchParametersCopy = new ArrayList<>();
                Result result = parseTableSchema(tableId);

                try (Connection connection = connectionProvider.getOrEstablishConnection();
                        PreparedStatement stDelete =
                                connection.prepareStatement(
                                        generateDml(tableId, OperationType.DELETE))) {
                    connection.setAutoCommit(false);

                    for (DataChangeEvent event : events) {
                        setRecordToStatementJdbc(
                                result.columnNames,
                                result.primaryKeys,
                                stDelete,
                                event,
                                DirtyData.ERROR_TYPE_DELETE);
                        stDelete.addBatch();
                    }
                    // Snapshot the driver's internal batch buffers BEFORE execution so
                    // per-row SQL is available for dirty-data reporting on failure.
                    getPgStatementParam(stDelete, batchStatementsCopy, batchParametersCopy);
                    stDelete.executeBatch();

                    if (softDelete.get() && events.size() > 0) {
                        String softDeleteDmlSql = generateDml(tableId, OperationType.DELETE, true);
                        try (PreparedStatement stDeleteBackup =
                                connection.prepareStatement(softDeleteDmlSql)) {
                            for (DataChangeEvent event : events) {
                                setRecordToStatementDeleteBackup(
                                        stDeleteBackup, event, DirtyData.ERROR_TYPE_DELETE);
                                stDeleteBackup.addBatch();
                            }
                            stDeleteBackup.executeBatch();
                        }
                    }

                    connection.commit();
                } catch (Throwable e) {
                    // Batch failed: report affected rows as dirty data, do not rethrow.
                    dirtyDataHandle(
                            events,
                            batchStatementsCopy,
                            batchParametersCopy,
                            e,
                            DirtyData.ERROR_TYPE_DELETE);
                }

                batchStatementsCopy.clear();
                batchParametersCopy.clear();
            } catch (Throwable e) {
                LOG.error(e.getMessage(), e);
                throw new RuntimeException(e);
            } finally {
                // Always release the coordinator, even on failure, and record timing.
                downLatch.countDown();

                long time = System.currentTimeMillis() - start;
                int size = events.size();
                String avgTime = formatAvgTime(time, size);

                deleteCountList.add(size + "(" + time + "/" + avgTime + "ms)");
            }
        }
    }

    /**
     * Worker that upserts one partition of events for a single table in its own JDBC
     * transaction. On a {@link BatchUpdateException} the whole batch is rolled back and
     * retried record-by-record so only the genuinely bad rows go to the dirty-data
     * topic. The latch is counted down in all cases.
     */
    private class PgUpsertTask implements Runnable {
        private final TableId tableId;
        private final List<DataChangeEvent> events;
        private final AtomicInteger upsertCount;
        private final CopyOnWriteArrayList<String> upsertCountList;
        private final CountDownLatch downLatch;

        public PgUpsertTask(
                TableId tableId,
                List<DataChangeEvent> events,
                AtomicInteger upsertCount,
                CopyOnWriteArrayList<String> upsertCountList,
                CountDownLatch downLatch) {
            this.tableId = tableId;
            this.events = events;
            this.upsertCount = upsertCount;
            this.upsertCountList = upsertCountList;
            this.downLatch = downLatch;
        }

        @Override
        public void run() {
            try {
                upsertCount.addAndGet(events.size());
                long start = System.currentTimeMillis();
                List<Object> batchStatementsCopy = new ArrayList<>();
                List<Object> batchParametersCopy = new ArrayList<>();
                Result result = parseTableSchema(tableId);

                try (Connection connection = connectionProvider.getOrEstablishConnection();
                        PreparedStatement stUpdate =
                                connection.prepareStatement(generateDml(tableId, UPDATE))) {

                    try {
                        connection.setAutoCommit(false);

                        // Duplicate primary keys are filtered upstream (addToBatch):
                        // with reWriteBatchedInserts=true the driver does not allow
                        // multiple records with the same primary key in one batch.
                        for (DataChangeEvent event : events) {
                            setRecordToStatementJdbc(
                                    result.columnNames,
                                    result.primaryKeys,
                                    stUpdate,
                                    event,
                                    DirtyData.ERROR_TYPE_INSERT);

                            stUpdate.addBatch();
                        }

                        // Snapshot the driver's batch buffers before execution so the
                        // per-row SQL is available for dirty-data reporting.
                        getPgStatementParam(stUpdate, batchStatementsCopy, batchParametersCopy);
                        stUpdate.executeBatch();

                        connection.commit();
                    } catch (BatchUpdateException e) {
                        // Isolate the bad rows: roll back and retry one record at a time.
                        connection.rollback();
                        executeIndividually(connection, result, events);
                    }
                } catch (Throwable e) {
                    dirtyDataHandle(
                            events,
                            batchStatementsCopy,
                            batchParametersCopy,
                            e,
                            DirtyData.ERROR_TYPE_UPDATE);
                } finally {
                    long time = System.currentTimeMillis() - start;
                    int size = events.size();
                    String avgTime = formatAvgTime(time, size);

                    upsertCountList.add(size + "(" + time + "ms/" + avgTime + "ms)");
                }
                batchStatementsCopy.clear();
                batchParametersCopy.clear();
            } catch (Exception e) {
                LOG.error(e.getMessage(), e);
                throw new RuntimeException(e);
            } finally {
                downLatch.countDown();
            }
        }

        /**
         * Fallback path after a batch failure: executes each event in its own mini
         * transaction so good rows still land and only failing rows become dirty data.
         */
        public void executeIndividually(
                Connection connection, Result result, List<DataChangeEvent> events)
                throws SQLException {
            String sql = generateDml(tableId, UPDATE);
            try (PreparedStatement statement = connection.prepareStatement(sql)) {
                for (DataChangeEvent event : events) {
                    Object[] values = null;
                    try {
                        values =
                                setRecordToStatementJdbc(
                                        result.columnNames,
                                        result.primaryKeys,
                                        statement,
                                        event,
                                        DirtyData.ERROR_TYPE_INSERT);

                        // BUG FIX: the statement was bound and committed but never
                        // executed, so the per-record retry silently wrote nothing.
                        statement.executeUpdate();
                        connection.commit();
                    } catch (SQLException e) {
                        try {
                            // Roll back only this record's transaction.
                            connection.rollback();
                        } catch (SQLException rollbackEx) {
                            LOG.error("回滚失败SQL事务时发生错误: {}", rollbackEx.getMessage());
                        }

                        if (values != null) {
                            dirtyDataHandle(event, sql, e, DirtyData.ERROR_TYPE_UPDATE);
                        }
                    }
                }
            }
        }
    }

    /**
     * Reports one failed event to the dirty-data Kafka topic, rewriting the table id to
     * carry the configured logical database name. No-op when Kafka reporting is not
     * configured.
     *
     * @throws RuntimeException if sending to Kafka itself fails
     */
    private void dirtyDataHandle(DataChangeEvent event, String sql, Throwable e, String errorType) {
        // BUG FIX: kafkaHelper may be null when dirty-data reporting is not configured
        // (see constructor); the other overloads guard for this but this one did not,
        // turning every dirty row into an NPE-wrapped RuntimeException.
        if (kafkaHelper == null) {
            return;
        }
        try {
            TableId tableId =
                    new TableId(
                            database.get(),
                            event.tableId().getSchemaName(),
                            event.tableId().getTableName());
            DataChangeEvent dataChangeEventNew = DataChangeEvent.replaceTableId(tableId, event);
            kafkaHelper.sendKafkaDirtyData(
                    errorType,
                    sinkType.get(),
                    datasourceName.get(),
                    ExceptionUtils.getStackTrace(e),
                    sql,
                    dataChangeEventNew);
        } catch (Exception ex) {
            throw new RuntimeException(
                    "Dirty data is sent to kafka exception: " + ex.getMessage(), ex);
        }
    }

    /** No-op: statements are opened and closed per worker task via try-with-resources. */
    @Override
    public void closeStatements() throws SQLException {}

    /**
     * After a failed batch execution, maps each {@link Statement#EXECUTE_FAILED} entry
     * of the {@link BatchUpdateException} back to its event and reflected SQL text and
     * reports it to the dirty-data topic. Only BatchUpdateException causes are handled;
     * anything else (and a missing Kafka helper) is ignored.
     */
    private void dirtyDataHandle(
            List<DataChangeEvent> batch,
            List<Object> batchStatementsCopy,
            List<Object> batchParametersCopy,
            Throwable e,
            String errorType) {

        // BUG FIX: guard against an unconfigured Kafka helper (it is null when the
        // bootstrap servers / topic options are absent — see constructor).
        if (kafkaHelper == null || !(e instanceof BatchUpdateException)) {
            return;
        }
        try {
            BatchUpdateException batchUpdateException = (BatchUpdateException) e;
            int[] updateCounts = batchUpdateException.getUpdateCounts();

            int index = 0;
            for (int count : updateCounts) {
                // Only failed rows are reported; do the (reflective) SQL rendering
                // lazily inside this branch instead of for every row.
                if (count == Statement.EXECUTE_FAILED) {
                    Object batchStatement = batchStatementsCopy.get(index);
                    // PgStatement's toString(ParameterList) renders the row's SQL with
                    // its bound parameters; it is not public, hence the reflection.
                    Method toStringMethod =
                            batchStatement.getClass().getMethod("toString", ParameterList.class);
                    toStringMethod.setAccessible(true);
                    String sqlStatement =
                            (String)
                                    toStringMethod.invoke(
                                            batchStatement, batchParametersCopy.get(index));
                    DataChangeEvent dataChangeEventOld = batch.get(index);
                    TableId tableId =
                            new TableId(
                                    database.get(),
                                    dataChangeEventOld.tableId().getSchemaName(),
                                    dataChangeEventOld.tableId().getTableName());
                    DataChangeEvent dataChangeEventNew =
                            DataChangeEvent.replaceTableId(tableId, dataChangeEventOld);
                    kafkaHelper.sendKafkaDirtyData(
                            errorType,
                            sinkType.get(),
                            datasourceName.get(),
                            ExceptionUtils.getStackTrace(e),
                            sqlStatement,
                            dataChangeEventNew);
                }
                index++;
            }
        } catch (IndexOutOfBoundsException ex) {
            // The reflected batch copies can be shorter than updateCounts; remaining
            // entries cannot be mapped back to events, so stop here (best effort).
        } catch (Exception ex) {
            throw new RuntimeException(
                    "Dirty data is sent to kafka exception: " + ex.getMessage(), ex);
        }
    }

    /**
     * Reports a single failed change event to the dirty-data Kafka topic.
     *
     * <p>Does nothing when no Kafka helper is configured.
     */
    private void dirtyDataHandle(
            DataChangeEvent dataChangeEventOld,
            String errorMessage,
            String errorType,
            PreparedStatement ps) {
        if (kafkaHelper == null) {
            return;
        }

        // Re-route the event's table id to the configured sink database before reporting.
        TableId originalId = dataChangeEventOld.tableId();
        TableId reroutedId =
                new TableId(
                        database.get(), originalId.getSchemaName(), originalId.getTableName());
        DataChangeEvent reroutedEvent =
                DataChangeEvent.replaceTableId(reroutedId, dataChangeEventOld);

        String statementText = (ps == null) ? "" : ps.toString();
        kafkaHelper.sendKafkaDirtyData(
                errorType,
                sinkType.get(),
                datasourceName.get(),
                errorMessage,
                statementText,
                reroutedEvent);
    }

    /**
     * Copies the pgjdbc driver's internal batch statement/parameter lists into the given
     * output lists via reflection.
     *
     * <p>Druid pooled statements are unwrapped to the raw statement first.
     * NOTE(review): the unconditional cast assumes the (unwrapped) statement is always a
     * pgjdbc {@code PgStatement} — confirm this holds for every configured datasource,
     * otherwise this throws {@code ClassCastException} before the try block.
     *
     * @param preparedStatement   statement whose pending batch is inspected
     * @param batchStatementsCopy receives the driver's batched statements
     * @param batchParametersCopy receives the driver's batched parameter lists
     */
    private void getPgStatementParam(
            PreparedStatement preparedStatement,
            List<Object> batchStatementsCopy,
            List<Object> batchParametersCopy) {

        PgStatement pgStatement = null;
        if (preparedStatement instanceof DruidPooledPreparedStatement) {
            // Unwrap the pool proxy to reach the underlying driver statement.
            preparedStatement =
                    ((DruidPooledPreparedStatement) preparedStatement).getRawStatement();
        }

        pgStatement = (PgStatement) preparedStatement;

        try {
            // batchStatementsField / batchParametersField are presumably pre-resolved
            // java.lang.reflect.Field handles to pgjdbc's private batch buffers —
            // resolved elsewhere in this class (outside this view).
            List<?> batchStatements = (List<?>) batchStatementsField.get(pgStatement);
            List<?> batchParameters = (List<?>) batchParametersField.get(pgStatement);
            batchStatementsCopy.addAll(batchStatements);
            batchParametersCopy.addAll(batchParameters);

        } catch (Exception e) {
            LOG.error("pg反射获取batchStatements和batchParameters失败", e);
        }
    }

    /**
     * Binds the values of {@code event} to {@code ps} and returns the bound values.
     *
     * <p>DELETE events bind only the primary-key columns; every other operation binds all
     * columns. Binding errors are not rethrown: the event is forwarded to the dirty-data
     * handler and the (possibly partially filled, or {@code null}) value array is returned
     * as-is — callers must tolerate that best-effort contract.
     *
     * @param columnNames full column list, in statement parameter order
     * @param primaryKeys primary-key columns, in statement parameter order (DELETE only)
     * @param ps          target prepared statement
     * @param event       change event supplying the values
     * @param errorType   classification tag for dirty-data reporting
     * @return the values bound to the statement, in parameter order
     */
    public Object[] setRecordToStatementJdbc(
            List<String> columnNames,
            List<String> primaryKeys,
            PreparedStatement ps,
            DataChangeEvent event,
            String errorType)
            throws SQLException {
        Map<String, Object> map = event.getEventMap();

        Object[] values = null;

        try {
            // Single loop over whichever column list this operation binds; the two
            // branches previously duplicated the same bind loop.
            List<String> boundColumns =
                    event.op() == OperationType.DELETE ? primaryKeys : columnNames;
            values = new Object[boundColumns.size()];
            for (int index = 0; index < boundColumns.size(); index++) {
                Object value = map.get(boundColumns.get(index));
                ps.setObject(index + 1, value); // JDBC parameter indices are 1-based
                values[index] = value;
            }
        } catch (Exception e) {
            // Best-effort: report as dirty data instead of failing the whole batch.
            dirtyDataHandle(event, ExceptionUtils.getStackTrace(e), errorType, ps);
        }

        return values;
    }

    /**
     * Binds all column values of a DELETE event to a backup-table upsert statement,
     * applying column mapping, and (when soft delete is enabled) injecting the logical
     * delete flag/time columns.
     *
     * <p>Binding errors are routed to the dirty-data handler rather than rethrown.
     *
     * @param ps        prepared statement for the backup-table upsert
     * @param event     the change event; only DELETE events are bound
     * @param errorType classification tag for dirty-data reporting
     */
    public synchronized void setRecordToStatementDeleteBackup(
            PreparedStatement ps, DataChangeEvent event, String errorType) {

        try {
            Map<String, Object> map = event.getEventMap();
            TableId tableId = event.tableId();
            Map<TableId, Schema> schemaMaps = serializer.getSinkSchemaMaps();
            Schema schema = schemaMaps.get(event.tableId());
            List<String> columnNames = new ArrayList<>(schema.getColumnNames());

            String columnMappingParam = getCustomColumn(columnMapping.get(), tableId);
            Map<String, String> columnMappingMap = new HashMap<>();
            if (StringUtils.isNotBlank(columnMappingParam)) {
                String[] keyValuePairs = columnMappingParam.split(",");

                for (String pair : keyValuePairs) {
                    String[] entry = pair.split("=");
                    // NOTE(review): mapping direction is inverted relative to generateDml
                    // (value=key here vs key=value there) — presumably intentional for the
                    // backup path; confirm. Also throws ArrayIndexOutOfBoundsException on a
                    // pair without '=' (caught by the outer try and sent as dirty data).
                    columnMappingMap.put(entry[1], entry[0]);
                }
            }

            if (MapUtils.isNotEmpty(columnMappingMap)) {
                // Rewrite schema column names through the mapping, in place.
                columnNames.replaceAll(
                        column ->
                                columnMappingMap.containsKey(column)
                                        ? columnMappingMap.get(column)
                                        : column);
            }

            // When softDelete is true, add the logical delete flag/time columns.
            if (softDelete.get()) {
                if (event.op() == OperationType.DELETE) {
                    String logicDeleteDatetimeColumnParam =
                            getCustomColumn(logicDeleteDatetimeColumn.get(), tableId);
                    String logicDeleteColumnParam =
                            getCustomColumn(logicDeleteColumn.get(), tableId);
                    // Fall back to the default column names only when NEITHER is configured.
                    if (StringUtils.isBlank(logicDeleteDatetimeColumnParam)
                            && StringUtils.isBlank(logicDeleteColumnParam)) {
                        logicDeleteColumnParam = SoftDeleteConstant.DELETED;
                        logicDeleteDatetimeColumnParam = SoftDeleteConstant.DELETED_TIME;
                    }
                    if (!columnNames.contains(logicDeleteDatetimeColumnParam)) {
                        columnNames.add(logicDeleteDatetimeColumnParam);
                    }
                    if (!columnNames.contains(logicDeleteColumnParam)) {
                        columnNames.add(logicDeleteColumnParam);
                    }
                    map.put(logicDeleteDatetimeColumnParam, getCurrentDateTimeAsString());
                    map.put(logicDeleteColumnParam, SoftDeleteConstant.DELETED_VALUE);
                }
            }

            // Bind in the (mapped, extended) column order; indices are 1-based per JDBC.
            if (event.op() == OperationType.DELETE) {
                for (int index = 0; index < columnNames.size(); index++) {
                    ps.setObject(index + 1, map.get(columnNames.get(index)));
                }
            }
        } catch (Exception e) {
            dirtyDataHandle(event, ExceptionUtils.getStackTrace(e), errorType, ps);
        }
    }

    /** Generates the DML for {@code op} against {@code tableId} without delete backup. */
    public String generateDml(TableId tableId, OperationType op) {
        return generateDml(tableId, op, false);
    }

    /**
     * Builds the DML statement (upsert or delete) for {@code tableId} from the sink
     * schema, honoring custom primary keys, soft-delete columns, column mapping, column
     * exclusion and the optional delete-backup table.
     *
     * <p>Named parameter positions of the final SQL are recorded into the shared
     * {@code parameterMap} via {@code parseNamedStatement}.
     *
     * @param tableId         target table
     * @param op              change operation: INSERT/UPDATE produce an upsert, DELETE a
     *                        delete — or an upsert into the backup table (see below)
     * @param softDeleteParam when true and a backup table is configured for this table,
     *                        DELETE produces an upsert into the backup table instead
     * @return the generated SQL with named parameters replaced by {@code ?}
     * @throws RuntimeException if the schema is unknown, the table has no primary key,
     *         primary keys are duplicated, or a required backup table is not configured
     */
    public synchronized String generateDml(
            TableId tableId, OperationType op, Boolean softDeleteParam) {
        Map<TableId, Schema> schemaMaps = serializer.getSinkSchemaMaps();
        Schema schema = schemaMaps.get(tableId);
        if (schema == null) {
            String msg =
                    String.format(
                            "table: %s not found, %s, %s schemaMaps: %s",
                            tableId,
                            serializer.hashCode(),
                            schemaMaps.containsKey(tableId),
                            schemaMaps.keySet());
            throw new RuntimeException(msg);
        }

        PostgresDialect dialect = new PostgresDialect();
        List<String> columnNameList = new ArrayList<>(schema.getColumnNames());
        List<String> primaryKeyList = new ArrayList<>(schema.primaryKeys());
        if (primaryKeyList.isEmpty()) {
            throw new RuntimeException("pg tableId: " + tableId + " primaryKeys is empty");
        }

        // Support user-defined primary keys: merge the configured keys into the schema
        // keys once per table and cache the result in primaryKeyMap.
        if (StringUtils.isNotBlank(customPrimaryKey) && customPrimaryKeyMap.containsKey(tableId)) {
            if (primaryKeyMap.containsKey(tableId)) {
                primaryKeyList = primaryKeyMap.get(tableId);
            } else {
                primaryKeyList = Lists.newArrayList(primaryKeyList);
                for (String pk : customPrimaryKeyMap.get(tableId)) {
                    if (!primaryKeyList.contains(pk)) {
                        primaryKeyList.add(pk);
                    }
                }

                primaryKeyMap.put(tableId, primaryKeyList);
            }
        }

        // Reject duplicated primary-key columns (can arise from the merge above).
        boolean hasDuplicates =
                primaryKeyList.stream()
                        .collect(Collectors.groupingBy(e -> e, Collectors.counting()))
                        .values()
                        .stream()
                        .anyMatch(count -> count > 1);
        if (hasDuplicates) {
            throw new RuntimeException(
                    "duplicate primary key column: "
                            + primaryKeyList
                            + ", origin primary key: "
                            + schema.primaryKeys());
        }

        // When softDelete is true, add the logical delete flag/time columns for DELETEs.
        if (softDelete.get() && op == OperationType.DELETE) {
            //            columnNameList.add(SoftDeleteConstant.DELETED_TIME);
            //            columnNameList.add(SoftDeleteConstant.DELETED);
            String logicDeleteDatetimeColumnParam =
                    getCustomColumn(logicDeleteDatetimeColumn.get(), tableId);
            String logicDeleteColumnParam = getCustomColumn(logicDeleteColumn.get(), tableId);
            // Fall back to defaults only when NEITHER column is configured.
            if (StringUtils.isBlank(logicDeleteColumnParam)
                    && StringUtils.isBlank(logicDeleteDatetimeColumnParam)) {
                logicDeleteColumnParam = SoftDeleteConstant.DELETED;
                logicDeleteDatetimeColumnParam = SoftDeleteConstant.DELETED_TIME;
            }
            if (!columnNameList.contains(logicDeleteDatetimeColumnParam)) {
                columnNameList.add(logicDeleteDatetimeColumnParam);
            }
            if (!columnNameList.contains(logicDeleteColumnParam)) {
                columnNameList.add(logicDeleteColumnParam);
            }
        }

        // Apply the source->sink column mapping ("src=dst,...", per table) to both the
        // column list and the primary-key list.
        String columnMappingParam = getCustomColumn(columnMapping.get(), tableId);
        Map<String, String> columnMappingMap = new HashMap<>();
        if (StringUtils.isNotBlank(columnMappingParam)) {
            String[] keyValuePairs = columnMappingParam.split(",");

            for (String pair : keyValuePairs) {
                String[] entry = pair.split("=");
                columnMappingMap.put(entry[0], entry[1]);
            }
        }

        if (MapUtils.isNotEmpty(columnMappingMap)) {
            columnNameList.replaceAll(
                    column ->
                            columnMappingMap.containsKey(column)
                                    ? columnMappingMap.get(column)
                                    : column);

            primaryKeyList.replaceAll(
                    column ->
                            columnMappingMap.containsKey(column)
                                    ? columnMappingMap.get(column)
                                    : column);
        }

        // Drop excluded columns AFTER mapping, so the exclude option refers to sink names.
        excludeColumn(tableId, columnNameList);
        excludeColumn(tableId, primaryKeyList);

        String[] columnNameArray = columnNameList.toArray(new String[columnNameList.size()]);
        String[] primaryKeyArray = primaryKeyList.toArray(new String[primaryKeyList.size()]);
        String dmlSql = "";
        if (op == UPDATE || op == OperationType.INSERT) {
            dmlSql =
                    dialect.getUpsertStatement(
                            tableId.identifier(),
                            columnNameArray,
                            primaryKeyArray,
                            schema.getShareKey());
        } else if (op == OperationType.DELETE) {
            // Per a customer (Liaoning medical insurance) requirement: add delete backup.
            // When softDelete is true, deleted rows must also be written to a backup table
            // in addition to being deleted.
            String backupTable = backupTableName.get();
            if (softDeleteParam && StringUtils.isNotBlank(backupTable)) {
                // backupTableName format: "table:backupTable;table:backupTable;..."
                String[] parts = backupTable.split(";");

                String currentBackupTable = null;
                for (String part : parts) {
                    String[] subParts = part.split(":");
                    if (subParts.length == 2) {
                        if (subParts[0].equals(tableId.identifier())) {
                            currentBackupTable = subParts[1];
                            break;
                        }
                    }
                }

                if (StringUtils.isBlank(currentBackupTable)) {
                    throw new RuntimeException(
                            "tableId: " + tableId.identifier() + " backupTable not exists");
                }

                dmlSql =
                        dialect.getUpsertStatement(
                                currentBackupTable,
                                columnNameArray,
                                primaryKeyArray,
                                schema.getShareKey());
            } else {
                dmlSql = dialect.getDeleteStatement(tableId.identifier(), primaryKeyArray);
            }
        }
        // Record named-parameter positions into the shared parameterMap and emit '?'s.
        dmlSql = parseNamedStatement(dmlSql, parameterMap);
        return dmlSql;
    }

    /**
     * Looks up the per-table value inside a {@code table:value;table:value;...} option.
     *
     * <p>Fix: the previous implementation used {@code part.contains(identifier)}, a
     * substring match, so table {@code a.b.c} could match the entry for {@code a.b.cd}
     * (or an identifier embedded in the value part). The key segment is now compared
     * exactly, consistent with the backup-table lookup in {@code generateDml}.
     *
     * @param customColumn raw option value, e.g. {@code db.schema.tbl:col1,col2;...}
     * @param tableId      table whose entry is wanted
     * @return the configured value for {@code tableId}, or {@code null} when absent
     */
    public String getCustomColumn(String customColumn, TableId tableId) {
        String currentCustomColumn = null;
        if (StringUtils.isBlank(customColumn)) {
            return currentCustomColumn;
        }

        String identifier = tableId.identifier();
        String[] parts = customColumn.split(";");
        for (String part : parts) {
            String[] subParts = part.split(":");
            if (subParts.length == 2 && subParts[0].equals(identifier)) {
                currentCustomColumn = subParts[1];
                break;
            }
        }
        return currentCustomColumn;
    }

    /**
     * Removes the columns configured in the exclude-column option for {@code tableId}
     * from {@code columnNames}, in place.
     *
     * <p>Simplified: {@code List.remove(Object)} is already a no-op when the element is
     * absent, so the {@code contains()} pre-check was redundant (and a second O(n) scan);
     * likewise {@code split} never returns an empty array, so the emptiness check on the
     * resulting list was dead. As before, only the first occurrence is removed.
     */
    private void excludeColumn(TableId tableId, List<String> columnNames) {
        String excludeColumnParam = getCustomColumn(excludeColumn.get(), tableId);
        if (excludeColumnParam == null) {
            return;
        }
        for (String column : excludeColumnParam.split(",")) {
            if (StringUtils.isNotBlank(column)) {
                columnNames.remove(column);
            }
        }
    }

    /** Formats the current local date-time with the class-level {@code FORMATTER}. */
    public static String getCurrentDateTimeAsString() {
        return LocalDateTime.now().format(FORMATTER);
    }

    /**
     * Rewrites {@code :name} placeholders in {@code sql} to JDBC {@code ?} markers and
     * records, for each parameter name, the 1-based positions at which it occurs.
     *
     * <p>Fix: a {@code ::} sequence (PostgreSQL cast syntax, e.g. {@code x::text}) is now
     * copied through verbatim; previously it was parsed as an empty parameter name and
     * raised "Named parameters in SQL statement must not be empty.".
     *
     * @param sql      SQL text containing {@code :name} style parameters
     * @param paramMap receives name -&gt; list of 1-based parameter indices (appended to)
     * @return the SQL with every named parameter replaced by {@code ?}
     * @throws IllegalArgumentException if a lone {@code :} is not followed by a name
     */
    public static String parseNamedStatement(String sql, Map<String, List<Integer>> paramMap) {
        StringBuilder parsedSql = new StringBuilder();
        int fieldIndex = 1; // SQL statement parameter index starts from 1
        int length = sql.length();
        for (int i = 0; i < length; i++) {
            char c = sql.charAt(i);
            if (':' == c) {
                // Pass PostgreSQL casts ("::type") through untouched.
                if (i + 1 < length && sql.charAt(i + 1) == ':') {
                    parsedSql.append("::");
                    i++;
                    continue;
                }
                int j = i + 1;
                while (j < length && Character.isJavaIdentifierPart(sql.charAt(j))) {
                    j++;
                }
                String parameterName = sql.substring(i + 1, j);
                if (parameterName.isEmpty()) {
                    throw new IllegalArgumentException(
                            "Named parameters in SQL statement must not be empty.");
                }
                paramMap.computeIfAbsent(parameterName, n -> new ArrayList<>()).add(fieldIndex);
                fieldIndex++;
                i = j - 1; // resume after the parameter name
                parsedSql.append('?');
            } else {
                parsedSql.append(c);
            }
        }
        return parsedSql.toString();
    }

    /**
     * Returns {@code numerator / denominator} formatted with two decimals, or {@code "0"}
     * when the denominator is zero.
     */
    private String formatAvgTime(long numerator, long denominator) {
        if (denominator == 0) {
            return "0";
        }
        double result = (double) numerator / denominator;
        // Locale.ROOT keeps the decimal separator a '.' regardless of the JVM's default
        // locale (e.g. avoids "1,50" under locales that use a comma).
        return String.format(java.util.Locale.ROOT, "%.2f", result);
    }

    /**
     * Creates a {@link SimpleBatchStatementExecutorV2} backed by the given connection
     * provider and event serializer.
     */
    static JdbcBatchStatementExecutor simple(
            JdbcConnectionProvider connectionProvider, PostgresEventSerializer serializer) {
        return new SimpleBatchStatementExecutorV2(connectionProvider, serializer);
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants