Exploring FileStreamSink and ConsoleSinkProvider through StreamExecution: Two Ways to Implement a StreamingSink

This post walks through part of the Structured Streaming source code: the two ways to implement a StreamingSink, and the path from readStream to writeStream.

The two ways to implement a StreamingSink:
1. Extend Sink (a trait that implements BaseStreamingSink)
2. Implement StreamWriteSupport (an interface that extends BaseStreamingSink)

Approach 1: extend Sink to implement an XxxSink [StreamingSink]

Sinks implemented by extending Sink include, for example:
FileStreamSink
KafkaSink

The trait itself is declared as:
trait Sink extends BaseStreamingSink

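Before digging into Spark's built-in FileStreamSink and KafkaSink, here is a minimal sketch of this approach. Assumptions: Spark 2.4; DemoSink and DemoSinkProvider are illustrative names, not classes that exist in Spark; the println stands in for real write logic.

import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.StreamSinkProvider
import org.apache.spark.sql.streaming.OutputMode

// Illustrative custom sink: writes each micro-batch to stdout.
class DemoSink extends Sink {
  @volatile private var latestBatchId = -1L

  // Called once per micro-batch on the driver. A batch can be replayed
  // after recovery, so skip ids we have already handled (best-effort).
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
      // already committed, skip for idempotence
    } else {
      data.collect().foreach(println) // replace with the real write logic
      latestBatchId = batchId
    }
  }
}

// DataSource.createSink(outputMode) instantiates the sink through this
// provider, making it reachable via .format("<fully qualified class name>").
class DemoSinkProvider extends StreamSinkProvider {
  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = new DemoSink
}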
FileStreamSink
/**
 * A sink that writes out results to parquet files. Each batch is written out to a unique
 * directory. After all of the files in a batch have been successfully written, the list of
 * file paths is appended to the log atomically. In the case of partial failures, some duplicate
 * data may be present in the target directory, but only one copy of each file will be present
 * in the log.
 */
class FileStreamSink(
    sparkSession: SparkSession,
    path: String,
    fileFormat: FileFormat,
    partitionColumnNames: Seq[String],
    options: Map[String, String]) extends Sink with Logging {

  private val basePath = new Path(path)
  private val logPath = new Path(basePath, FileStreamSink.metadataDir)
  private val fileLog =
    new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString)
  private val hadoopConf = sparkSession.sessionState.newHadoopConf()

  private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = {
    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
    new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics)
  }

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) {
      logInfo(s"Skipping already committed batch $batchId")
    } else {
      val committer = FileCommitProtocol.instantiate(
        className = sparkSession.sessionState.conf.streamingFileCommitProtocolClass,
        jobId = batchId.toString,
        outputPath = path)

      committer match {
        case manifestCommitter: ManifestFileCommitProtocol =>
          manifestCommitter.setupManifestOptions(fileLog, batchId)
        case _ => // Do nothing
      }

      // Get the actual partition columns as attributes after matching them by name with
      // the given columns names.
      val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col =>
        val nameEquality = data.sparkSession.sessionState.conf.resolver
        data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse {
          throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}")
        }
      }
      val qe = data.queryExecution

      FileFormatWriter.write(
        sparkSession = sparkSession,
        plan = qe.executedPlan,
        fileFormat = fileFormat,
        committer = committer,
        outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output),
        hadoopConf = hadoopConf,
        partitionColumns = partitionColumns,
        bucketSpec = None,
        statsTrackers = Seq(basicWriteJobStatsTracker),
        options = options)
    }
  }

  override def toString: String = s"FileSink[$path]"
}
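For reference, a typical way to drive FileStreamSink from user code, as a sketch (the paths and the dt partition column are illustrative); the file sink requires both path and checkpointLocation and only supports append mode:

val query = df.writeStream
  .format("parquet")
  .option("path", "/tmp/stream/out")
  .option("checkpointLocation", "/tmp/stream/ckpt")
  .partitionBy("dt")
  .outputMode("append") // the file sink supports append only
  .start()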
KafkaSink
private[kafka010] class KafkaSink(
    sqlContext: SQLContext,
    executorKafkaParams: ju.Map[String, Object],
    topic: Option[String]) extends Sink with Logging {
  @volatile private var latestBatchId = -1L

  override def toString(): String = "KafkaSink"

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
      logInfo(s"Skipping already committed batch $batchId")
    } else {
      KafkaWriter.write(sqlContext.sparkSession,
        data.queryExecution, executorKafkaParams, topic)
      latestBatchId = batchId
    }
  }
}
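The corresponding usage for KafkaSink, again as a sketch (broker address, topic, and paths are illustrative); the input must expose a value column, and optionally key:

val query = df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
  .writeStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("topic", "out_topic")
  .option("checkpointLocation", "/tmp/stream/kafka-ckpt")
  .start()

Note the difference in the replay guard: FileStreamSink deduplicates through its on-disk metadata log, while KafkaSink only keeps latestBatchId in memory, so a batch replayed after a restart can be written to Kafka again (at-least-once).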

Approach 2: implement StreamWriteSupport [StreamingSink]

public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink

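Before looking at Spark's built-in ConsoleSinkProvider, here is a minimal sketch of this second approach (assuming the Spark 2.4 DataSourceV2 interfaces; the Demo* names are illustrative, not part of Spark). Unlike a Sink, the writer is split into a driver-side StreamWriter and executor-side DataWriter instances:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

class DemoWriteSupportProvider extends DataSourceV2
  with StreamWriteSupport
  with DataSourceRegister {

  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter = new DemoStreamWriter

  override def shortName(): String = "demo"
}

// Driver side: commit/abort are called once per epoch (micro-batch).
class DemoStreamWriter extends StreamWriter {
  override def createWriterFactory(): DataWriterFactory[InternalRow] =
    new DemoWriterFactory
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
    println(s"committed epoch $epochId (${messages.length} tasks)")
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
    println(s"aborted epoch $epochId")
}

case object DemoCommitMessage extends WriterCommitMessage

// Serialized and shipped to executors; one DataWriter per partition per epoch.
class DemoWriterFactory extends DataWriterFactory[InternalRow] {
  override def createDataWriter(
      partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
    new DataWriter[InternalRow] {
      private var rows = 0L
      override def write(record: InternalRow): Unit = rows += 1 // a real sink writes here
      override def commit(): WriterCommitMessage = DemoCommitMessage
      override def abort(): Unit = ()
    }
}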
ConsoleSinkProvider
case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)
  extends BaseRelation {
  override def schema: StructType = data.schema
}

class ConsoleSinkProvider extends DataSourceV2
  with StreamWriteSupport
  with DataSourceRegister
  with CreatableRelationProvider {

  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter = {
    new ConsoleWriter(schema, options)
  }

  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // Number of rows to display, by default 20 rows
    val numRowsToShow = parameters.get("numRows").map(_.toInt).getOrElse(20)

    // Truncate the displayed data if it is too long, by default it is true
    val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
    data.show(numRowsToShow, isTruncated)

    ConsoleRelation(sqlContext, data)
  }

  def shortName(): String = "console"
}
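DataSourceRegister is what makes the short name work: format("console") resolves to this provider via shortName(). The numRows and truncate options read in createRelation above are honored on the streaming path as well. A usage sketch, with df assumed to be a streaming DataFrame:

df.writeStream
  .format("console")           // resolved through shortName()
  .option("numRows", "50")     // default is 20
  .option("truncate", "false") // default is true
  .outputMode("append")
  .start()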

The core code that handles the sink

Structured Streaming from readStream to writeStream: here we only look at the writeStream side (readStream is omitted), starting from final class DataStreamWriter.

For example:

xxx_agg_df.writeStream \
    .option("checkpointLocation", "/tmp/checkpoint/douban_gallery_click_cnt") \
    .option("catalog", catalog) \
    .format('streaming.DemoSinkProvider') \
    .trigger(processingTime='10 seconds') \
    .outputMode('complete') \
    .start()
org.apache.spark.sql.streaming.DataStreamWriter.scala

Look at the start() function: it ends up calling df.sparkSession.sessionState.streamingQueryManager.startQuery.

/**
 * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
 * key-value stores, etc). Use `Dataset.writeStream` to access this.
 *
 * @since 2.0.0
 */
@InterfaceStability.Evolving
final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

  def start(): StreamingQuery = {
    if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
      throw new AnalysisException("Hive data source can only be used with tables, you can not " +
        "write files of Hive data source directly.")
    }

    if (source == "memory") {
      assertNotPartitioned("memory")
      if (extraOptions.get("queryName").isEmpty) {
        throw new AnalysisException("queryName must be specified for memory sink")
      }
      val (sink, resultDf) = trigger match {
        case _: ContinuousTrigger =>
          val s = new MemorySinkV2()
          val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
          (s, r)
        case _ =>
          val s = new MemorySink(df.schema, outputMode)
          val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
          (s, r)
      }
      val chkpointLoc = extraOptions.get("checkpointLocation")
      val recoverFromChkpoint = outputMode == OutputMode.Complete()
      val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
        extraOptions.get("queryName"),
        chkpointLoc,
        df,
        extraOptions.toMap,
        sink,
        outputMode,
        useTempCheckpointLocation = true,
        recoverFromCheckpointLocation = recoverFromChkpoint,
        trigger = trigger)
      resultDf.createOrReplaceTempView(query.name)
      query
    } else if (source == "foreach") {
      assertNotPartitioned("foreach")
      val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc)
      df.sparkSession.sessionState.streamingQueryManager.startQuery(
        extraOptions.get("queryName"),
        extraOptions.get("checkpointLocation"),
        df,
        extraOptions.toMap,
        sink,
        outputMode,
        useTempCheckpointLocation = true,
        trigger = trigger)
    } else {
      val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
      val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
      val sink = ds.newInstance() match {
        case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w
        case _ =>
          val ds = DataSource(
            df.sparkSession,
            className = source,
            options = extraOptions.toMap,
            partitionColumns = normalizedParCols.getOrElse(Nil))
          ds.createSink(outputMode)
      }

      df.sparkSession.sessionState.streamingQueryManager.startQuery(
        extraOptions.get("queryName"),
        extraOptions.get("checkpointLocation"),
        df,
        extraOptions.toMap,
        sink,
        outputMode,
        useTempCheckpointLocation = source == "console",
        recoverFromCheckpointLocation = true,
        trigger = trigger)
    }
  }
}
org.apache.spark.sql.streaming.StreamingQueryManager.scala

Look at the startQuery function: it calls its own createQuery to build the query. Inside createQuery, a pattern match wraps the execution in a StreamingQueryWrapper, i.e. StreamingQueryWrapper(StreamExecution). StreamExecution (which implements StreamingQuery) has two concrete implementations: ContinuousExecution and MicroBatchExecution. startQuery then calls query.streamingQuery.start().

/**
 * A class to manage all the [[StreamingQuery]] active in a `SparkSession`.
 *
 * @since 2.0.0
 */
@InterfaceStability.Evolving
class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {

  private def createQuery(
      userSpecifiedName: Option[String],
      userSpecifiedCheckpointLocation: Option[String],
      df: DataFrame,
      extraOptions: Map[String, String],
      sink: BaseStreamingSink,
      outputMode: OutputMode,
      useTempCheckpointLocation: Boolean,
      recoverFromCheckpointLocation: Boolean,
      trigger: Trigger,
      triggerClock: Clock): StreamingQueryWrapper = {
    var deleteCheckpointOnStop = false
    val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified =>
      new Path(userSpecified).toUri.toString
    }.orElse {
      df.sparkSession.sessionState.conf.checkpointLocation.map { location =>
        new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toUri.toString
      }
    }.getOrElse {
      if (useTempCheckpointLocation) {
        // Delete the temp checkpoint when a query is being stopped without errors.
        deleteCheckpointOnStop = true
        Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath
      } else {
        throw new AnalysisException(
          "checkpointLocation must be specified either " +
            """through option("checkpointLocation", ...) or """ +
            s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""")
      }
    }

    // If offsets have already been created, we trying to resume a query.
    if (!recoverFromCheckpointLocation) {
      val checkpointPath = new Path(checkpointLocation, "offsets")
      val fs = checkpointPath.getFileSystem(df.sparkSession.sessionState.newHadoopConf())
      if (fs.exists(checkpointPath)) {
        throw new AnalysisException(
          s"This query does not support recovering from checkpoint location. " +
            s"Delete $checkpointPath to start over.")
      }
    }

    val analyzedPlan = df.queryExecution.analyzed
    df.queryExecution.assertAnalyzed()

    if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
      UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode)
    }

    if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
      logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " +
        "is not supported in streaming DataFrames/Datasets and will be disabled.")
    }

    (sink, trigger) match {
      case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) =>
        if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
          UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
        }
        new StreamingQueryWrapper(new ContinuousExecution(
          sparkSession,
          userSpecifiedName.orNull,
          checkpointLocation,
          analyzedPlan,
          v2Sink,
          trigger,
          triggerClock,
          outputMode,
          extraOptions,
          deleteCheckpointOnStop))
      case _ =>
        new StreamingQueryWrapper(new MicroBatchExecution(
          sparkSession,
          userSpecifiedName.orNull,
          checkpointLocation,
          analyzedPlan,
          sink,
          trigger,
          triggerClock,
          outputMode,
          extraOptions,
          deleteCheckpointOnStop))
    }
  }

  /**
   * Start a [[StreamingQuery]].
   *
   * @param userSpecifiedName Query name optionally specified by the user.
   * @param userSpecifiedCheckpointLocation Checkpoint location optionally specified by the user.
   * @param df Streaming DataFrame.
   * @param sink Sink to write the streaming outputs.
   * @param outputMode Output mode for the sink.
   * @param useTempCheckpointLocation Whether to use a temporary checkpoint location when the user
   *                                  has not specified one. If false, then error will be thrown.
   * @param recoverFromCheckpointLocation Whether to recover query from the checkpoint location.
   *                                      If false and the checkpoint location exists, then error
   *                                      will be thrown.
   * @param trigger [[Trigger]] for the query.
   * @param triggerClock [[Clock]] to use for the triggering.
   */
  private[sql] def startQuery(
      userSpecifiedName: Option[String],
      userSpecifiedCheckpointLocation: Option[String],
      df: DataFrame,
      extraOptions: Map[String, String],
      sink: BaseStreamingSink,
      outputMode: OutputMode,
      useTempCheckpointLocation: Boolean = false,
      recoverFromCheckpointLocation: Boolean = true,
      trigger: Trigger = ProcessingTime(0),
      triggerClock: Clock = new SystemClock()): StreamingQuery = {
    val query = createQuery(
      userSpecifiedName,
      userSpecifiedCheckpointLocation,
      df,
      extraOptions,
      sink,
      outputMode,
      useTempCheckpointLocation,
      recoverFromCheckpointLocation,
      trigger,
      triggerClock)

    activeQueriesLock.synchronized {
      // Make sure no other query with same name is active
      userSpecifiedName.foreach { name =>
        if (activeQueries.values.exists(_.name == name)) {
          throw new IllegalArgumentException(
            s"Cannot start query with name $name as a query with that name is already active")
        }
      }

      // Make sure no other query with same id is active
      if (activeQueries.values.exists(_.id == query.id)) {
        throw new IllegalStateException(
          s"Cannot start query with id ${query.id} as another query with same id is " +
            s"already active. Perhaps you are attempting to restart a query from checkpoint " +
            s"that is already active.")
      }

      activeQueries.put(query.id, query)
    }
    try {
      // When starting a query, it will call `StreamingQueryListener.onQueryStarted` synchronously.
      // As it's provided by the user and can run arbitrary codes, we must not hold any lock here.
      // Otherwise, it's easy to cause dead-lock, or block too long if the user codes take a long
      // time to finish.
      query.streamingQuery.start()
    } catch {
      case e: Throwable =>
        activeQueriesLock.synchronized {
          activeQueries -= query.id
        }
        throw e
    }
    query
  }
}
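The (sink, trigger) match above is the fork in the road: a ContinuousTrigger combined with a StreamWriteSupport sink yields a ContinuousExecution, everything else a MicroBatchExecution. From the user's side the choice is made through the trigger alone, as this sketch shows (df is assumed to be a streaming DataFrame whose source supports continuous processing, e.g. the rate source):

import org.apache.spark.sql.streaming.Trigger

// ContinuousTrigger + StreamWriteSupport sink => ContinuousExecution
df.writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .start()

// Any other trigger => MicroBatchExecution
df.writeStream
  .format("console")
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .start()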
org.apache.spark.sql.execution.streaming.StreamExecution.scala

StreamExecution is an abstract class; as mentioned above it has two concrete implementations, and we will look at one of them below.
The start() function creates and starts a QueryExecutionThread. The thread's run method calls runStream(), which covers event notification, the logical plan, and so on; what we care about here is runActivatedStream(sparkSessionForStream).
When consuming a Kafka stream, the thread must be an UninterruptibleThread to work around the endless-loop problem of KAFKA-1894, as the comment in the code explains.
runActivatedStream(sparkSessionForStream) is implemented by the concrete subclasses.

/**
 * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
 * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any
 * [[Source]] present in the query plan. Whenever new data arrives, a [[QueryExecution]] is created
 * and the results are committed transactionally to the given [[Sink]].
 *
 * @param deleteCheckpointOnStop whether to delete the checkpoint if the query is stopped without
 *                               errors
 */
abstract class StreamExecution(
    override val sparkSession: SparkSession,
    override val name: String,
    private val checkpointRoot: String,
    analyzedPlan: LogicalPlan,
    val sink: BaseStreamingSink,
    val trigger: Trigger,
    val triggerClock: Clock,
    val outputMode: OutputMode,
    deleteCheckpointOnStop: Boolean)
  extends StreamingQuery with ProgressReporter with Logging {

  /**
   * Activate the stream and then wrap a callout to runActivatedStream, handling start and stop.
   *
   * Note that this method ensures that [[QueryStartedEvent]] and [[QueryTerminatedEvent]] are
   * posted such that listeners are guaranteed to get a start event before a termination.
   * Furthermore, this method also ensures that [[QueryStartedEvent]] event is posted before the
   * `start()` method returns.
   */
  private def runStream(): Unit = {
    // Other code omitted; state is an AtomicReference[State](INITIALIZING)
    if (state.compareAndSet(INITIALIZING, ACTIVE)) {
      // Unblock `awaitInitialization`
      initializationLatch.countDown()
      runActivatedStream(sparkSessionForStream)
      updateStatusMessage("Stopped")
    } else {
      // `stop()` is already called. Let `finally` finish the cleanup.
    }
  }

  /**
   * The thread that runs the micro-batches of this stream. Note that this thread must be
   * [[org.apache.spark.util.UninterruptibleThread]] to workaround KAFKA-1894: interrupting a
   * running `KafkaConsumer` may cause endless loop.
   */
  val queryExecutionThread: QueryExecutionThread =
    new QueryExecutionThread(s"stream execution thread for $prettyIdString") {
      override def run(): Unit = {
        // To fix call site like "run at <unknown>:0", we bridge the call site from the caller
        // thread to this micro batch thread
        sparkSession.sparkContext.setCallSite(callSite)
        runStream()
      }
    }

  /**
   * Starts the execution. This returns only after the thread has started and [[QueryStartedEvent]]
   * has been posted to all the listeners.
   */
  def start(): Unit = {
    logInfo(s"Starting $prettyIdString. Use $resolvedCheckpointRoot to store the query checkpoint.")
    queryExecutionThread.setDaemon(true)
    queryExecutionThread.start()
    startLatch.await() // Wait until thread started and QueryStart event has been posted
  }
}
org.apache.spark.sql.execution.streaming.MicroBatchExecution.scala

Go straight to runActivatedStream: it calls runBatch(sparkSessionForStream). Inside runBatch you can see addBatch; this is where Sink and StreamWriteSupport are matched and executed. A V1 Sink has its addBatch invoked directly, while for a StreamWriteSupport sink the write is already wired into the plan, so collect() merely forces the micro-batch writer to run.

class MicroBatchExecution(
    sparkSession: SparkSession,
    name: String,
    checkpointRoot: String,
    analyzedPlan: LogicalPlan,
    sink: BaseStreamingSink,
    trigger: Trigger,
    triggerClock: Clock,
    outputMode: OutputMode,
    extraOptions: Map[String, String],
    deleteCheckpointOnStop: Boolean)
  extends StreamExecution(
    sparkSession, name, checkpointRoot, analyzedPlan, sink,
    trigger, triggerClock, outputMode, deleteCheckpointOnStop) {

  /**
   * Processes any data available between `availableOffsets` and `committedOffsets`.
   * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
   */
  private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
    // The part we care about
    reportTimeTaken("addBatch") {
      SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
        sink match {
          case s: Sink => s.addBatch(currentBatchId, nextBatch)
          case _: StreamWriteSupport =>
            // This doesn't accumulate any data - it just forces execution of the microbatch writer.
            nextBatch.collect()
        }
      }
    }
  }

  /**
   * Repeatedly attempts to run batches as data arrives.
   */
  protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {

    val noDataBatchesEnabled =
      sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled

    triggerExecutor.execute(() => {
      if (isActive) {
        var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run
        var currentBatchHasNewData = false // Whether the current batch had new data

        startTrigger()

        reportTimeTaken("triggerExecution") {
          // We'll do this initialization only once every start / restart
          if (currentBatchId < 0) {
            populateStartOffsets(sparkSessionForStream)
            logInfo(s"Stream started from $committedOffsets")
          }

          // Set this before calling constructNextBatch() so any Spark jobs executed by sources
          // while getting new data have the correct description
          sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)

          // Try to construct the next batch. This will return true only if the next batch is
          // ready and runnable. Note that the current batch may be runnable even without
          // new data to process as `constructNextBatch` may decide to run a batch for
          // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data
          // is available or not.
          currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled)

          // Remember whether the current batch has data or not. This will be required later
          // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed
          // to false as the batch would have already processed the available data.
          currentBatchHasNewData = isNewDataAvailable

          currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable)
          if (currentBatchIsRunnable) {
            if (currentBatchHasNewData) updateStatusMessage("Processing new data")
            else updateStatusMessage("No new data but cleaning up state")
            runBatch(sparkSessionForStream)
          } else {
            updateStatusMessage("Waiting for data to arrive")
          }
        }

        finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded

        // If the current batch has been executed, then increment the batch id, else there was
        // no data to execute the batch
        if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs)
      }
      updateStatusMessage("Waiting for next trigger")
      isActive
    })
  }
}
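To close the loop, here is the Scala counterpart of the pyspark snippet at the top, as a sketch (the socket source, provider class name, and paths are illustrative). Which branch of the sink match in runBatch fires depends only on whether the format class extends Sink or implements StreamWriteSupport:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("sink-demo")
  .getOrCreate()

val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", "9999")
  .load()

val query = lines.writeStream
  .format("streaming.DemoSinkProvider") // fully qualified class name of a custom provider
  .option("checkpointLocation", "/tmp/checkpoint/demo")
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .outputMode("append")
  .start()

query.awaitTermination()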