Spark读写MySQL
简介
本文就简单介绍一下Spark读写MySQL的几种方式,不同场景下可以灵活选取。
环境说明:
CDH版本:CDH 6.3.1
Spark版本:2.4.0
MySQL版本:5.7
读取MySQL表
Spark读取MySQL数据可以有单分区和多分区两种方式,一般读取小数据量的表采用简单的单分区模式就可以,对于比较大的单分区抽取需要消耗时间较长的表来说,采用多分区模式读取性能会更好。
单分区单线程读取
此种方式是最简单的读取方式,但只有单线程,仅限于小数据量表,需要谨慎在生产库中使用,指定连接地址和表名即可:
val jdbcUrl = "jdbc:mysql://xxx.xxx.xxx.xxx:3306/test?useUnicode=true&characterEncoding=utf8&tinyInt1isBit=false"
val jdbcUser = ""
val jdbcPass = ""
// 可以直接指定表名,也可以写 SELECT 语句(必须要有临时表包裹),比如:
// val table = "(select * from test.test_table where status = 1) tmp"
val table = ""
val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("user", jdbcUser)
.option("password", jdbcPass)
.option("driver", "com.mysql.jdbc.Driver")
.option("dbtable", table)
.load()
多分区并行读取
此种方式对于抽取数据量较大的表有很好的性能提升,但仅限于有连续数值型主键(比如自增id)的数据表:
val jdbcUrl = "jdbc:mysql://xxx.xxx.xxx.xxx:3306/test?useUnicode=true&characterEncoding=utf8&tinyInt1isBit=false"
val jdbcUser = ""
val jdbcPass = ""
// 可以直接指定表名,也可以写 SELECT 语句(必须要有临时表包裹),比如:
// val table = "(select * from test.test_table where status = 1) tmp"
val table = ""
val partitionNum = 6
val minId = 1
val maxId = 6000000
val df = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("user", jdbcUser)
.option("password", jdbcPass)
.option("driver", "com.mysql.jdbc.Driver")
.option("dbtable", table)
// 以下4个配置项必须同时使用
// 分区数量,可以理解为读取并行度、线程数
.option("numPartitions", partitionNum)
// 分区字段,必须为数字、日期、时间戳字段
.option("partitionColumn", "id")
// lowerBound 和 upperBound 仅用于计算每个分区的取数步长,不用于数据过滤
// 分区字段的最小值
.option("lowerBound", minId)
// 分区字段的最大值
.option("upperBound", maxId)
.load()
写入MySQL表
追加写和覆盖写比较简单,但要注意覆盖写表可能会出现删表重建的操作。
追加写
df.write
.format("jdbc")
.mode(SaveMode.Append)
.option("driver", "com.mysql.jdbc.Driver")
.option("url", jdbcUrl)
.option("dbtable", table)
.option("user", jdbcUser)
.option("password", jdbcPass)
// JDBC批大小,默认 1000,灵活调整该值可以提高写入性能
.option("batchsize", 10000)
// 事务级别,默认为 READ_UNCOMMITTED,无事务要求可以填 NONE 以提高性能
.option("isolationLevel", "NONE")
.save()
覆盖写
df.write
.format("jdbc")
.mode(SaveMode.Overwrite)
.option("driver", "com.mysql.jdbc.Driver")
.option("url", jdbcUrl)
.option("dbtable", table)
.option("user", jdbcUser)
.option("password", jdbcPass)
// JDBC批大小,默认 1000,灵活调整该值可以提高写入性能
.option("batchsize", 10000)
// 事务级别,默认为 READ_UNCOMMITTED,无事务要求可以填 NONE 以提高性能
.option("isolationLevel", "NONE")
// 需要注意该配置项,Overwrite 模式下,不设置为 true 会删表重建
.option("truncate", "true")
.save()
更新写(UPSERT/INSERT OR UPDATE)
更新接入比较复杂一些,一般结合 foreachPartition
使用。同时需要目标表创建 UNIQUE KEY
,因为需要基于UNIQUE KEY
来实现UPSERT
。
df.foreachPartition(iter => {
val conn = ds.getConnection
val sql =
"""
|INSERT INTO test_table (uid,a,b,c,d,e)
|VALUES (?,?,?,?,?,?)
|ON DUPLICATE KEY
|UPDATE c = ?, d = ?
|""".stripMargin
val ps = conn.prepareStatement(sql)
iter.foreach(row => {
val uid = row.getAs[Long]("pid")
val a = row.getAs[Long]("a")
val b = row.getAs[String]("b")
val c = row.getAs[java.math.BigDecimal]("c")
val d = row.getAs[java.math.BigDecimal]("d")
val e = row.getAs[Byte]("e")
ps.setLong(1, uid)
ps.setLong(2, a)
ps.setString(3, b)
ps.setBigDecimal(4, c)
ps.setBigDecimal(5, d)
ps.setByte(6, e)
ps.setBigDecimal(7, c)
ps.setBigDecimal(8, d)
ps.executeUpdate()
})
DbUtil.close(conn)
})
代码封装示例
基于上面的介绍,可以将Spark读写MySQL进行一个简单地封装,使用起来会更加方便:
/**
* 读取 MySQL 表,并行读取时固定 id 为分区字段
*
* @param spark SparkSession
* @param table 表名
* @param partitionNum 分区数量
* @param filterKey 过滤字段
* @param filterMin 过滤条件最小值
* @param filterMax 过滤条件最大值
* @param jdbcUrl url
* @param jdbcUser user
* @param jdbcPass password
* @return
*/
def readMysqlPartById(spark: SparkSession, table: String, partitionNum: Int = 0,
filterKey: String = "", filterMin: String = "", filterMax: String = "",
jdbcUrl: String = url, jdbcUser: String = user, jdbcPass: String = pass): DataFrame = {
val conn = spark.read
.format("jdbc")
.option("url", jdbcUrl)
.option("user", jdbcUser)
.option("password", jdbcPass)
.option("driver", "com.mysql.jdbc.Driver")
if (partitionNum == 0) {
conn.option("dbtable", table).load()
} esle {
if ("".equals(filterKey)) {
// 读取最大id
val ids = conn.option("dbtable", s"(select min(id),max(id) from $table) tmp")
.load()
.first()
if (ids.isNullAt(0)) {
spark.emptyDataFrame
} else {
val minId = String.valueOf(ids.get(0)).toLong
val maxId = String.valueOf(ids.get(1)).toLong
conn.option("dbtable", table)
.option("numPartitions", partitionNum)
.option("partitionColumn", "id")
.option("lowerBound", minId)
.option("upperBound", maxId)
.load()
}
} else {
val filter = s"where $filterKey between '$filterMin' and '$filterMax'"
// 读取最大id
val ids = conn.option("dbtable", s"(select min(id),max(id) from $table $filter) tmp")
.load()
.first()
if (ids.isNullAt(0)) {
spark.emptyDataFrame
} else {
val minId = String.valueOf(ids.get(0)).toLong
val maxId = String.valueOf(ids.get(1)).toLong
conn.option("dbtable", s"(select * from $table $filter) tmp")
.option("numPartitions", partitionNum)
.option("partitionColumn", "id")
.option("lowerBound", minId)
.option("upperBound", maxId)
.load()
}
}
}
}
/**
* Spark写入MySQL通用方法
*
* @param df Spark DataFrame
* @param saveMode 写入模式,覆盖 or 追加
* @param table 目标表名
* @param jdbcUrl 目标数据库jdbc连接
* @param jdbcUser 目标数据库访问用户
* @param jdbcPass 目标数据库访问密码
*/
def writeMysql(df: DataFrame, saveMode: SaveMode, table: String,
jdbcUrl: String = url, jdbcUser: String = user, jdbcPass: String = pass): Unit = {
val dfWriter = df.write
.format("jdbc")
.mode(saveMode)
.option("driver", "com.mysql.jdbc.Driver")
.option("url", jdbcUrl)
.option("dbtable", table)
.option("user", jdbcUser)
.option("password", jdbcPass)
.option("batchsize", 10000)
.option("isolationLevel", "NONE")
if (saveMode == SaveMode.Overwrite) {
// 如果覆盖写入,采用 truncate 模式,避免重建表
dfWriter.option("truncate", "true").save()
} else {
dfWriter.save()
}
}