Spark读写MySQL

Author Avatar
山小杰 3月 12, 2022
  • 在其它设备中阅读本文章

简介

本文就简单介绍一下Spark读写MySQL的几种方式,不同场景下可以灵活选取。

环境说明:

CDH版本:CDH 6.3.1
Spark版本:2.4.0
MySQL版本:5.7

读取MySQL表

Spark读取MySQL数据可以有单分区和多分区两种方式,一般读取小数据量的表采用简单的单分区模式就可以,对于比较大的单分区抽取需要消耗时间较长的表来说,采用多分区模式读取性能会更好。

单分区单线程读取

此种方式是最简单的读取方式,但只有单线程,仅限于小数据量表,需要谨慎在生产库中使用,指定连接地址和表名即可:

val jdbcUrl = "jdbc:mysql://xxx.xxx.xxx.xxx:3306/test?useUnicode=true&characterEncoding=utf8&tinyInt1isBit=false"
val jdbcUser = ""
val jdbcPass = ""
// 可以直接指定表名,也可以写 SELECT 语句(必须要有临时表包裹),比如:
// val table = "(select * from test.test_table where status = 1) tmp"
val table = ""
val df = spark.read
  .format("jdbc")
  .option("url", jdbcUrl)
  .option("user", jdbcUser)
  .option("password", jdbcPass)
  .option("driver", "com.mysql.jdbc.Driver")
  .option("dbtable", table)
  .load()

多分区并行读取

此种方式对于抽取数据量较大的表有很好的性能提升,但仅限于有连续数值型主键(比如自增id)的数据表:

val jdbcUrl = "jdbc:mysql://xxx.xxx.xxx.xxx:3306/test?useUnicode=true&characterEncoding=utf8&tinyInt1isBit=false"
val jdbcUser = ""
val jdbcPass = ""
// 可以直接指定表名,也可以写 SELECT 语句(必须要有临时表包裹),比如:
// val table = "(select * from test.test_table where status = 1) tmp"
val table = ""
val partitionNum = 6
val minId = 1
val maxId = 6000000
val df = spark.read
  .format("jdbc")
  .option("url", jdbcUrl)
  .option("user", jdbcUser)
  .option("password", jdbcPass)
  .option("driver", "com.mysql.jdbc.Driver")
  .option("dbtable", table)
  // 以下4个配置项必须同时使用
  // 分区数量,可以理解为读取并行度、线程数
  .option("numPartitions", partitionNum)
  // 分区字段,必须为数字、日期、时间戳字段
  .option("partitionColumn", "id")
  // lowerBound 和 upperBound 仅用于计算每个分区的取数步长,不用于数据过滤
  // 分区字段的最小值
  .option("lowerBound", minId)
  // 分区字段的最大值
  .option("upperBound", maxId)
  .load()

写入MySQL表

追加写和覆盖写比较简单,但要注意覆盖写表可能会出现删表重建的操作。

追加写

df.write
  .format("jdbc")
  .mode(SaveMode.Append)
  .option("driver", "com.mysql.jdbc.Driver")
  .option("url", jdbcUrl)
  .option("dbtable", table)
  .option("user", jdbcUser)
  .option("password", jdbcPass)
  // JDBC批大小,默认 1000,灵活调整该值可以提高写入性能
  .option("batchsize", 10000)
  // 事务级别,默认为 READ_UNCOMMITTED,无事务要求可以填 NONE 以提高性能
  .option("isolationLevel", "NONE")
  .save()

覆盖写

df.write
  .format("jdbc")
  .mode(SaveMode.Overwrite)
  .option("driver", "com.mysql.jdbc.Driver")
  .option("url", jdbcUrl)
  .option("dbtable", table)
  .option("user", jdbcUser)
  .option("password", jdbcPass)
  // JDBC批大小,默认 1000,灵活调整该值可以提高写入性能
  .option("batchsize", 10000)
  // 事务级别,默认为 READ_UNCOMMITTED,无事务要求可以填 NONE 以提高性能
  .option("isolationLevel", "NONE")
  // 需要注意该配置项,Overwrite 模式下,不设置为 true 会删表重建
  .option("truncate", "true")
  .save()

更新写(UPSERT/INSERT OR UPDATE)

更新接入比较复杂一些,一般结合 foreachPartition 使用。同时需要目标表创建 UNIQUE KEY,因为需要基于UNIQUE KEY来实现UPSERT

df.foreachPartition(iter => {
  val conn = ds.getConnection
  val sql =
      """
        |INSERT INTO test_table (uid,a,b,c,d,e)
        |VALUES (?,?,?,?,?,?)
        |ON DUPLICATE KEY
        |UPDATE c = ?, d = ?
        |""".stripMargin
  val ps = conn.prepareStatement(sql)
  iter.foreach(row => {
    val uid = row.getAs[Long]("pid")
    val a = row.getAs[Long]("a")
    val b = row.getAs[String]("b")
    val c = row.getAs[java.math.BigDecimal]("c")
    val d = row.getAs[java.math.BigDecimal]("d")
    val e = row.getAs[Byte]("e")
    ps.setLong(1, uid)
    ps.setLong(2, a)
    ps.setString(3, b)
    ps.setBigDecimal(4, c)
    ps.setBigDecimal(5, d)
    ps.setByte(6, e)
    ps.setBigDecimal(7, c)
    ps.setBigDecimal(8, d)
    ps.executeUpdate()
  })
  DbUtil.close(conn)
})

代码封装示例

基于上面的介绍,可以将Spark读写MySQL进行一个简单地封装,使用起来会更加方便:

/**
 * 读取 MySQL 表,并行读取时固定 id 为分区字段
 *
 * @param spark        SparkSession
 * @param table        表名
 * @param partitionNum 分区数量
 * @param filterKey    过滤字段
 * @param filterMin    过滤条件最小值
 * @param filterMax     过滤条件最大值
 * @param jdbcUrl      url
 * @param jdbcUser     user
 * @param jdbcPass     password
 * @return
 */
def readMysqlPartById(spark: SparkSession, table: String, partitionNum: Int = 0,
                      filterKey: String = "", filterMin: String = "", filterMax: String = "",
                      jdbcUrl: String = url, jdbcUser: String = user, jdbcPass: String = pass): DataFrame = {
  val conn = spark.read
    .format("jdbc")
    .option("url", jdbcUrl)
    .option("user", jdbcUser)
    .option("password", jdbcPass)
    .option("driver", "com.mysql.jdbc.Driver")
  if (partitionNum == 0) {
    conn.option("dbtable", table).load()
  } esle {
    if ("".equals(filterKey)) {
      // 读取最大id
      val ids = conn.option("dbtable", s"(select min(id),max(id) from $table) tmp")
        .load()
        .first()
      if (ids.isNullAt(0)) {
        spark.emptyDataFrame
      } else {
        val minId = String.valueOf(ids.get(0)).toLong
        val maxId = String.valueOf(ids.get(1)).toLong
        conn.option("dbtable", table)
          .option("numPartitions", partitionNum)
          .option("partitionColumn", "id")
          .option("lowerBound", minId)
          .option("upperBound", maxId)
          .load()
      }
    } else {
      val filter = s"where $filterKey between '$filterMin' and '$filterMax'"
      // 读取最大id
      val ids = conn.option("dbtable", s"(select min(id),max(id) from $table $filter) tmp")
        .load()
        .first()
      if (ids.isNullAt(0)) {
        spark.emptyDataFrame
      } else {
        val minId = String.valueOf(ids.get(0)).toLong
        val maxId = String.valueOf(ids.get(1)).toLong
        conn.option("dbtable", s"(select * from $table $filter) tmp")
          .option("numPartitions", partitionNum)
          .option("partitionColumn", "id")
          .option("lowerBound", minId)
          .option("upperBound", maxId)
          .load()
      }
    }
  }
}


/**
 * Spark写入MySQL通用方法
 *
 * @param df       Spark DataFrame
 * @param saveMode 写入模式,覆盖 or 追加
 * @param table    目标表名
 * @param jdbcUrl  目标数据库jdbc连接
 * @param jdbcUser 目标数据库访问用户
 * @param jdbcPass 目标数据库访问密码
 */
def writeMysql(df: DataFrame, saveMode: SaveMode, table: String,
               jdbcUrl: String = url, jdbcUser: String = user, jdbcPass: String = pass): Unit = {
  val dfWriter = df.write
    .format("jdbc")
    .mode(saveMode)
    .option("driver", "com.mysql.jdbc.Driver")
    .option("url", jdbcUrl)
    .option("dbtable", table)
    .option("user", jdbcUser)
    .option("password", jdbcPass)
    .option("batchsize", 10000)
    .option("isolationLevel", "NONE")
  if (saveMode == SaveMode.Overwrite) {
    // 如果覆盖写入,采用 truncate 模式,避免重建表
    dfWriter.option("truncate", "true").save()
  } else {
    dfWriter.save()
  }
}

参考内容

Spark JDBC data sources

本章完