Spark-Core Study Notes
SparkCore Source Code Study Notes
Application Submission and Environment Preparation in Spark: Source Code Analysis (Yarn Cluster Mode as the Example)
1. Application Submission Flow
Sample job submission:
bin/spark-submit \
--class org.apache.spark.examples.SparkPi \
--master yarn \
--deploy-mode cluster \
./examples/jars/spark-examples_2.12-3.0.0.jar \
10
Running the spark-submit command corresponds, in code, to executing the main method of the SparkSubmit class.
The freshly created SparkSubmit instance calls doSubmit, which in turn calls parseArguments.
SparkSubmit class:
// Parse the arguments that were passed in
val appArgs = parseArguments(args)
protected def parseArguments(args: Array[String]): SparkSubmitArguments = {
// Obtain a SparkSubmitArguments object
new SparkSubmitArguments(args)
}
SparkSubmitArguments class:
// The SparkSubmitArguments constructor calls parse, which actually starts parsing the arguments
parse(args.asJava)
protected final void parse(List<String> args) {
// Match each argument against a regular expression, then loop over the groups to check whether it is a known option
Pattern eqSeparatedOpt = Pattern.compile("(--[^=]+)=(.+)");
// Handle the incoming arguments: each one that matches a known option is stored
// into the SparkSubmitArguments object (the first wrapping of the submitted arguments).
// If an argument does not match, filling stops and the loop is exited.
if (!handle(name, value)) {
break;
}
// Otherwise keep matching the remaining arguments
continue;
}
// Then, still on the SparkSubmitArguments object, loadEnvironmentArguments is called:
// Use `sparkProperties` map along with env vars to fill in any missing parameters
// It also sets the action field to SUBMIT for later use
loadEnvironmentArguments()
SparkSubmit class:
appArgs.action match {
//Since action was just set to SUBMIT, the first case matches here and submit is executed
case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog)
case SparkSubmitAction.KILL => kill(appArgs)
case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)
case SparkSubmitAction.PRINT_VERSION => printVersion()
}
//Since this is yarn-cluster mode and no special extra options were passed,
//execution proceeds straight into runMain
private def runMain(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {
//Pass in the environment arguments wrapped above and prepare the submission environment
val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args)
/**Inside prepareSubmitEnvironment, because this is yarn-cluster mode,
we obtain YARN_CLUSTER_SUBMIT_CLASS = "org.apache.spark.deploy.yarn.YarnClusterApplication"
// In client mode, launch the application main class directly
//If this is client mode, your own application is launched directly
if (deployMode == CLIENT) {
//mainClass here is the field of the SparkSubmitArguments object wrapped earlier;
//it corresponds to the "--class" option of the submit command
childMainClass = args.mainClass
}
//Check whether this is yarn-cluster mode
if (isYarnCluster) {
childMainClass = YARN_CLUSTER_SUBMIT_CLASS
}
*/
//Load the childMainClass class via reflection
mainClass = Utils.classForName(childMainClass)
//Check whether mainClass is a subclass of SparkApplication; it clearly is, so a YarnClusterApplication instance is created via reflection
val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) {
mainClass.getConstructor().newInstance().asInstanceOf[SparkApplication]
} else {
new JavaMainApplication(mainClass)
}
//Call the start method of YarnClusterApplication
app.start(childArgs.toArray, sparkConf)
}
YarnClusterApplication class:
private[spark] class YarnClusterApplication extends SparkApplication {
override def start(args: Array[String], conf: SparkConf): Unit = {
// SparkSubmit would use yarn cache to distribute files & jars in yarn mode,
// so remove them from sparkConf here for yarn mode.
conf.remove(JARS)
conf.remove(FILES)
//The new ClientArguments(args) parameter wraps the already-parsed arguments a second time.
//The Client class has an important field: private val yarnClient = YarnClient.createYarnClient
//yarnClient in turn holds an rmClient used to communicate with the ResourceManager.
//Finally, run is called.
new Client(new ClientArguments(args), conf, null).run()
}
}
Client class:
//Submit an application to the ResourceManager.
def run(): Unit = {
//First call submitApplication to start submitting the application
this.appId = submitApplication()
}
def submitApplication(): ApplicationId = {
//Start the yarnClient created earlier so it can talk to the RM
launcherBackend.connect()
yarnClient.init(hadoopConf)
yarnClient.start()
// Get a new application from our RM
val newApp = yarnClient.createApplication()
val newAppResponse = newApp.getNewApplicationResponse()
// Set up the appropriate contexts to launch our AM
//Build a suitable launch context for the ApplicationMaster
val containerContext = createContainerLaunchContext(newAppResponse)
val appContext = createApplicationSubmissionContext(newApp, containerContext)
//Finally the application is submitted by yarnClient
yarnClient.submitApplication(appContext)
}
private val isClusterMode = sparkConf.get(SUBMIT_DEPLOY_MODE) == "cluster"
/**
* Set up a ContainerLaunchContext to launch our ApplicationMaster container.
* This sets up the launch environment, java options, and the command for launching the AM.
*/
private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse): ContainerLaunchContext = {
val amClass =
if (isClusterMode) {
Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName
} else {
/**
object ExecutorLauncher {
def main(args: Array[String]): Unit = {
ApplicationMaster.main(args)
}
}
As you can see, ExecutorLauncher simply reuses ApplicationMaster's main method
*/
Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName
}
// Command for the ApplicationMaster
val commands = prefixEnv ++
Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++
//Parameter notes: javaOpts wraps the JVM/resource options needed to start the AM;
//amArgs wraps the uploaded jar, the user arguments, and the AM class to launch.
//amClass decides, based on whether this is cluster mode, whether the cluster node starts ApplicationMaster or ExecutorLauncher.
javaOpts ++ amArgs ++
Seq(
"1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",
"2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
//Finally an AM container launch context is returned
amContainer
}
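For intuition, here is a rough sketch of what the assembled AM launch command ends up looking like for the sample submission at the top of these notes. The JVM options and jar path are assumptions for illustration only, not values read out of the source.
object AmCommandSketch {
  def main(args: Array[String]): Unit = {
    val amCommand = Seq(
      "{{JAVA_HOME}}/bin/java", "-server",
      "-Xmx1024m",                                        // from javaOpts (AM heap, assumed value)
      "org.apache.spark.deploy.yarn.ApplicationMaster",   // amClass, because deploy-mode is cluster
      "--class", "org.apache.spark.examples.SparkPi",     // from amArgs (--class)
      "--jar", "spark-examples_2.12-3.0.0.jar",           // from amArgs (the uploaded jar)
      "1>", "<LOG_DIR>/stdout",
      "2>", "<LOG_DIR>/stderr"
    ).mkString(" ")
    println(amCommand)
  }
}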
This concludes the submission flow; the code above is summarized in the diagram below.
Next we move on to the environment preparation flow.
2. Environment Preparation Flow
ApplicationMaster class:
object ApplicationMaster extends Logging {
private var master: ApplicationMaster = _
//The submitted command has now been launched in a YARN container, so the main method starts executing
def main(args: Array[String]): Unit = {
//Wrap the incoming arguments a third time
val amArgs = new ApplicationMasterArguments(args)
// Use the parsed arguments and configuration to create an ApplicationMaster object, master
//The ApplicationMaster class has a field: private val client = new YarnRMClient()
//The YarnRMClient class has a field: private var amClient: AMRMClient[ContainerRequest] = _
//amClient is used for AM <-> RM communication and plays the key role later when the AM registers with the RM
master = new ApplicationMaster(amArgs, sparkConf, yarnConf)
//Call master.run()
ugi.doAs(new PrivilegedExceptionAction[Unit]() {
override def run(): Unit = System.exit(master.run())
})
}
private val isClusterMode = args.userClass != null
final def run(): Int = {
//isClusterMode decides whether the Driver path or the ExecutorLauncher path is taken
//Here we are in yarn-cluster mode, so runDriver() is chosen
if (isClusterMode) {
runDriver()
} else {
//Client mode would go here and directly register the AM and create the allocator
/**
private def runExecutorLauncher(): Unit = {
val rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
amCores, true)
// The client-mode AM doesn't listen for incoming connections, so report an invalid port.
registerAM(hostname, -1, sparkConf, sparkConf.get(DRIVER_APP_UI_ADDRESS), appAttemptId)
createAllocator(driverRef, sparkConf, rpcEnv, appAttemptId, distCacheConf)
}
*/
runExecutorLauncher()
}
}
private def runDriver(): Unit = {
//The key step: start the user application (our own program) and get its thread
userClassThread = startUserApplication()
//The log message shows that we now wait for the SparkContext in the user's main method to finish initializing; the default maximum wait is 100s
logInfo("Waiting for spark context initialization...")
//Wait for and obtain the SparkContext object
/**
In SparkContext, _taskScheduler.postStartHook() calls the postStartHook overridden in YarnClusterScheduler
YarnClusterScheduler class:
override def postStartHook(): Unit = {
//This step lets the waiting code below continue
ApplicationMaster.sparkContextInitialized(sc)
//Make the Driver program wait until the environment has been fully set up; in other words, the Driver blocks here
super.postStartHook()
logInfo("YarnClusterScheduler.postStartHook done")
}
*/
val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
Duration(totalWaitTime, TimeUnit.MILLISECONDS))
//Check whether the SparkContext object was obtained; if not, an error is raised
if (sc != null) {
//Register the AM: client (used for AM <-> RM communication) calls register, which starts amClient, and amClient actually registers the AM with the RM
registerAM(host, port, userConf, sc.ui.map(_.webUrl), appAttemptId)
//Set up an endpoint reference, i.e. an outbound endpoint used to send messages to the Executor side (covered later in the communication section)
val driverRef = rpcEnv.setupEndpointRef(
RpcAddress(host, port),
YarnSchedulerBackend.ENDPOINT_NAME)
//Create the allocator, which requests and handles the obtained resources (containers)
createAllocator(driverRef, userConf, rpcEnv, appAttemptId, distCacheConf)
}
//The Driver blocked after the SparkContext was obtained; by this point the runtime environment is ready, so the Driver program may continue
resumeDriver()
userClassThread.join()
}
//Start the user class, which contains the spark driver, in a separate Thread
private def startUserApplication(): Thread = {
//Load the user class with the class loader and look up its main method
val mainMethod = userClassLoader.loadClass(args.userClass)
.getMethod("main", classOf[Array[String]])
//Create a thread that will execute the main method
val userThread = new Thread {
//Check that the main method is static; if not, it is an error
if (!Modifier.isStatic(mainMethod.getModifiers)) {
logError(s"Could not find static main method in object ${args.userClass}")
finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS)
} else {
//Invoke the main method
mainMethod.invoke(null, userArgs.toArray)
finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
logDebug("Done running user class")
}
}
//Here you can see it: the thread is named "Driver", started, and returned
userThread.setContextClassLoader(userClassLoader)
userThread.setName("Driver")
userThread.start()
userThread
}
private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf, rpcEnv: RpcEnv, appAttemptId: ApplicationAttemptId, distCacheConf: SparkConf): Unit = {
//client creates an allocator
allocator = client.createAllocator(
yarnConf,
_sparkConf,
appAttemptId,
driverUrl,
driverRef,
securityMgr,
localResources)
//Register an inbound endpoint used to receive messages from the Executor side
rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverRef))
//The allocator requests the resources, although ultimately it is amClient that issues the request
allocator.allocateResources()
}
//Request resources such that, if YARN gives us all we ask for, we'll have a number of containers
//equal to maxExecutors.
def allocateResources(): Unit = synchronized {
val progressIndicator = 0.1f
// Poll the ResourceManager. This doubles as a heartbeat if there are no pending container requests.
val allocateResponse = amClient.allocate(progressIndicator)
//Get the list of allocated containers
val allocatedContainers = allocateResponse.getAllocatedContainers()
allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes)
//If the number of allocated containers is greater than 0, start handling them
if (allocatedContainers.size > 0) {
//After processing, runAllocatedContainers is called to put these containers to use
handleAllocatedContainers(allocatedContainers.asScala)
}
}
//Launches executors in the allocated containers.
private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = {
//If the number of running executors is less than the target number of executors, start another one
if (runningExecutors.size() < targetNumExecutors) {
//Count one more executor as starting
numExecutorsStarting.incrementAndGet()
if (launchContainers) {
//Take a thread from the thread pool to launch the executor
launcherPool.execute(() =>
new ExecutorRunnable(
Some(container),
conf,
sparkConf,
driverUrl,
executorId,
executorHostname,
executorMemory,
executorCores,
appAttemptId.getApplicationId.toString,
securityMgr,
localResources,
ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID // use until fully supported
).run()
updateInternalState()
}
}
ExecutorRunnable class:
var nmClient: NMClient = _
//Use nmClient (which communicates with the NodeManager) to start the container the executor needs
def run(): Unit = {
logDebug("Starting Executor Container")
nmClient = NMClient.createNMClient()
nmClient.init(conf)
nmClient.start()
//Start the container required by the executor
startContainer()
}
def startContainer(): java.util.Map[String, ByteBuffer] = {
//Prepare the command needed to start the container and the executor
val commands = prepareCommand()
ctx.setCommands(commands.asJava)
// Send the start request to the ContainerManager
//This launches YarnCoarseGrainedExecutorBackend
nmClient.startContainer(container.get, ctx)
}
private def prepareCommand(): List[String] = {
//javaOpts holds the JVM resource options assembled earlier
//The class ultimately executed is "org.apache.spark.executor.YarnCoarseGrainedExecutorBackend": the communication backend on the Executor side
val commands = prefixEnv ++
Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++
javaOpts ++
Seq("org.apache.spark.executor.YarnCoarseGrainedExecutorBackend",
"--driver-url", masterAddress,
"--executor-id", executorId,
"--hostname", hostname,
"--cores", executorCores.toString,
"--app-id", appId,
"--resourceProfileId", resourceProfileId.toString) ++
userClassPath ++
Seq(
s"1>${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stdout",
s"2>${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stderr")
}
The Driver-side environment preparation is now complete; nmClient talks to the nodes that own the allocated containers and launches YarnCoarseGrainedExecutorBackend on them.
The code flow above is shown in the following diagram:
Launching YarnCoarseGrainedExecutorBackend
YarnCoarseGrainedExecutorBackend class:
private[spark] object YarnCoarseGrainedExecutorBackend extends Logging {
//The main method is executed
def main(args: Array[String]): Unit = {
//Create a factory for YarnCoarseGrainedExecutorBackend objects, later used as the inbound (message-receiving) endpoint
val createFn: (RpcEnv, CoarseGrainedExecutorBackend.Arguments, SparkEnv, ResourceProfile) =>
CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env, resourceProfile) =>
new YarnCoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId,
arguments.bindAddress, arguments.hostname, arguments.cores, arguments.userClassPath, env,
arguments.resourcesFileOpt, resourceProfile)
}
//Wrap the incoming arguments a fourth time
val backendArgs = CoarseGrainedExecutorBackend.parseArguments(args,
this.getClass.getCanonicalName.stripSuffix("$"))
//Call run
CoarseGrainedExecutorBackend.run(backendArgs, createFn)
System.exit(0)
}
}
CoarseGrainedExecutorBackend class:
def run( arguments: Arguments, backendCreateFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) =>
CoarseGrainedExecutorBackend): Unit = {
//Create an outbound endpoint reference used to send messages to the Driver side
driver = fetcher.setupEndpointRefByURI(arguments.driverUrl)
//Create the Executor's runtime environment; the communication environment is created here (see the communication analysis below)
// val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
// securityManager, numUsableCores, !isDriver)
val env = SparkEnv.createExecutorEnv(driverConf, arguments.executorId, arguments.bindAddress,
arguments.hostname, arguments.cores, cfg.ioEncryptionKey, isLocal = false)
//Register an inbound endpoint named "Executor", used to receive messages from the Driver side
env.rpcEnv.setupEndpoint("Executor",
backendCreateFn(env.rpcEnv, arguments, env, cfg.resourceProfile))
}
override def onStart(): Unit = {
logInfo("Connecting to driver: " + driverUrl)
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
// This is a very fast action so we can use "ThreadUtils.sameThread"
//Obtain a reference to the driver endpoint
driver = Some(ref)
//Ask the Driver side to register this executor (RegisterExecutor)
ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls,
extractAttributes, _resources, resourceProfile.id))
}(ThreadUtils.sameThread).onComplete {
case Success(_) =>
//If the Driver replies true, send ourselves a RegisteredExecutor message, which triggers starting the Executor
self.send(RegisteredExecutor)
case Failure(e) =>
exitExecutor(1, s"Cannot register with driver: $driverUrl", e, notifyDriver = false)
}(ThreadUtils.sameThread)
}
override def receive: PartialFunction[Any, Unit] = {
//Pattern matching
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
//Create the real Executor; at this point both Driver and Executor are up and the environment preparation is complete
executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false,
resources = _resources)
//Send a LaunchedExecutor message to the Driver side to indicate this executor has started; on receiving it the Driver updates its executorDataMap
driver.get.send(LaunchedExecutor(executorId))
}
NettyRpcEnv class:
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
dispatcher.registerRpcEndpoint(name, endpoint)
}
Dispatcher class:
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
//The type of endpoint decides whether a dedicated MessageLoop or a shared one is created.
//Once created, its inbox posts an OnStart message to the endpoint on this node; the endpoint then calls onStart and begins communicating with the Driver.
var messageLoop: MessageLoop = null
messageLoop = endpoint match {
case e: IsolatedRpcEndpoint =>
new DedicatedMessageLoop(name, e, this)
case _ =>
sharedLoop.register(name, endpoint)
sharedLoop
}
}
The SparkContext class has a field private var _schedulerBackend: SchedulerBackend = _, and a different SchedulerBackend is chosen depending on the deploy mode.
The classification is as follows:
Since we are running in cluster mode, the second major category, CoarseGrainedSchedulerBackend, is used. Its receiveAndReply method pattern-matches RegisterExecutor and ultimately returns
context.reply(true)
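To make the handshake concrete, the following self-contained toy (plain Scala, not Spark's actual private API; names are illustrative) mimics the RegisterExecutor ask/reply exchange described above: the driver-side handler records the executor and replies true, which is what the executor's onStart waits for before sending itself RegisteredExecutor.
object RegisterHandshakeSketch {
  sealed trait Message
  case class RegisterExecutor(executorId: String, hostname: String, cores: Int) extends Message
  case object RegisteredExecutor extends Message

  // Driver side: the role played by CoarseGrainedSchedulerBackend's receiveAndReply
  def driverReceiveAndReply(msg: Message,
                            executorDataMap: scala.collection.mutable.Map[String, Int]): Boolean =
    msg match {
      case RegisterExecutor(id, _, cores) =>
        executorDataMap(id) = cores   // record the executor, as executorDataMap does in Spark
        true                          // corresponds to context.reply(true)
      case _ => false
    }

  def main(args: Array[String]): Unit = {
    val executorDataMap = scala.collection.mutable.Map.empty[String, Int]
    // Executor side: corresponds to ref.ask[Boolean](RegisterExecutor(...)) in onStart
    val registered = driverReceiveAndReply(RegisterExecutor("1", "node01", 4), executorDataMap)
    if (registered) {
      // corresponds to self.send(RegisteredExecutor) followed by new Executor(...)
      println(s"RegisteredExecutor -> start Executor, known executors: $executorDataMap")
    }
  }
}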
The flow of the code above is shown in the following diagram:
3. Summary Flowcharts
Yarn Cluster flowchart
Yarn Client flowchart
Driver-Executor Communication in Spark: Principles and Code Analysis
Source code analysis
1. When the SparkContext object in the submitted application is initialized, the following code is invoked
SparkContext class:
_env = createSparkEnv(_conf, isLocal, listenerBus)
private[spark] def createSparkEnv(conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = {
//createDriverEnv internally calls the create method
/**
Inside create:
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
securityManager, numUsableCores, !isDriver)
The create method of the RpcEnv object:
def create() calls new NettyRpcEnvFactory().create(config)
*/
SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master, conf))
}
A brief note on the Netty-based communication framework:
//Life cycle of a communication endpoint
constructor -> onStart -> receive* -> onStop
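The following self-contained toy (plain Scala, not Spark's private RpcEndpoint trait) simply illustrates that lifecycle order: construction, onStart, one receive call per incoming message, then onStop.
object EndpointLifecycleSketch {
  trait ToyEndpoint {
    def onStart(): Unit = {}
    def receive: PartialFunction[Any, Unit]
    def onStop(): Unit = {}
  }

  class EchoEndpoint extends ToyEndpoint {                // constructor
    override def onStart(): Unit = println("onStart")
    override def receive: PartialFunction[Any, Unit] = {  // receive*
      case msg => println(s"received: $msg")
    }
    override def onStop(): Unit = println("onStop")
  }

  def main(args: Array[String]): Unit = {
    val ep = new EchoEndpoint
    ep.onStart()
    Seq("ping", "pong").foreach(ep.receive)
    ep.onStop()
  }
}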
NettyRpcEnvFactory class:
private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
def create(config: RpcEnvConfig): RpcEnv = {
//Here you can see that NettyRpcEnv is used
val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,config.securityManager, config.numUsableCores)
if (!config.clientMode) {
val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
//Start the server side
nettyEnv.startServer(config.bindAddress, actualPort)
(nettyEnv, nettyEnv.address.port)
}
try {
//This starts the communication server on a port; the code that actually starts the server is:
//val (service, port) = startService(tryPort), where startService is the startNettyRpcEnv function above
Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
} catch {
case NonFatal(e) =>
nettyEnv.shutdown()
throw e
}
}
nettyEnv
}
def startServer(bindAddress: String, port: Int): Unit = {
//Create a server: return new TransportServer(this, host, port, rpcHandler, bootstraps);
//This returns a TransportServer object, which is the server side of the communication
//The TransportServer constructor calls init(hostToBind, portToBind) to initialize it
server = transportContext.createServer(bindAddress, port, bootstraps)
//Register the RPC communication endpoint
/**In the Dispatcher class:
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
//Create a sending-side endpoint reference; internally there is an outboxes field:
//private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
//It is used to send messages to a given RpcAddress, so one sending endpoint may own several outboxes.
//The evidence: the Outbox object has a field storing the client to send to:
//private var client: TransportClient = null
//and the messages themselves are stored in: private val messages = new java.util.LinkedList[OutboxMessage]
val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
}
//Pattern-match the endpoint that was passed in to create a receiving endpoint; internally it holds an inbox object
//whose messages field, protected val messages = new java.util.LinkedList[InboxMessage](),
//stores the received messages
messageLoop = endpoint match {
case e: IsolatedRpcEndpoint =>
new DedicatedMessageLoop(name, e, this)
case _ =>
sharedLoop.register(name, endpoint)
sharedLoop
}
*/
dispatcher.registerRpcEndpoint(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
}
TransportServer class:
private void init(String hostToBind, int portToBind) {
//This obtains the communication mode, NIO by default: return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(Locale.ROOT);
IOMode ioMode = IOMode.valueOf(conf.ioMode());
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
//getServerChannelClass uses a switch-case to pick the channel class for the matching IO mode
/**
switch (mode) {
case NIO:
return NioServerSocketChannel.class;
case EPOLL:
return EpollServerSocketChannel.class;
default:
throw new IllegalArgumentException("Unknown io mode: " + mode);
}
*/
.channel(NettyUtils.getServerChannelClass(ioMode))
.option(ChannelOption.ALLOCATOR, pooledAllocator)
.option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS)
.childOption(ChannelOption.ALLOCATOR, pooledAllocator);
}
Communication component architecture diagram
Important Attributes of SparkContext in Spark
Attributes in the source code
All of these fields are initialized when the SparkContext is initialized
private var _conf: SparkConf = _
private var _env: SparkEnv = _
private var _schedulerBackend: SchedulerBackend = _
private var _taskScheduler: TaskScheduler = _
private var _dagScheduler: DAGScheduler = _
Diagram
Dependencies Between RDDs, Stage Division, Task Splitting, Scheduling and Execution in Spark: Source Code Analysis
1. Dependencies Between RDDs
OneToOneDependency ----> its parent class is NarrowDependency, so it is also called a narrow dependency
ShuffleDependency ----> its parent class is Dependency; although it is not literally named a "wide dependency", it is commonly treated as the wide dependency that stands opposite the narrow one
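Both kinds can be observed directly through the public rdd.dependencies API. A small runnable example (local mode, Spark 3.x assumed):
import org.apache.spark.{OneToOneDependency, ShuffleDependency, SparkConf, SparkContext}

object DependencyKinds {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("dep-demo"))
    val source  = sc.makeRDD(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
    val mapped  = source.map(identity)        // narrow: OneToOneDependency
    val reduced = source.reduceByKey(_ + _)   // wide:   ShuffleDependency
    println(mapped.dependencies.head.isInstanceOf[OneToOneDependency[_]])         // true
    println(reduced.dependencies.head.isInstanceOf[ShuffleDependency[_, _, _]])   // true
    sc.stop()
  }
}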
2. Creating RDDs and Their Dependencies
2.1 Creating RDDs
There are two ways to create an RDD
Creating from a file: textFile
textFile parameters
path: String // location of the source data
minPartitions: Int = defaultMinPartitions // minimum number of partitions, with a default value
// defaultParallelism is computed differently in local mode and cluster mode, as shown below
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
// CoarseGrainedSchedulerBackend, cluster mode:
override def defaultParallelism(): Int = {
// the maximum of the total number of cores in the cluster and 2
conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2))
}
// LocalSchedulerBackend, local mode:
override def defaultParallelism(): Int =
// the total number of cores of the local node
scheduler.conf.getInt("spark.default.parallelism", totalCores)
textFile reads data from files using Hadoop's TextInputFormat by default
public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException{
// iterate over the files to compute the total size in bytes
long totalSize = 0; // compute total size
for (FileStatus file: files) { // check we have valid files
if (file.isDirectory()) {
throw new IOException("Not a file: "+ file.getPath());
}
totalSize += file.getLen();
}
// numSplits is the minPartitions value passed in
long goalSize = totalSize / (numSplits == 0 ? 1 : numSplits);
// minSize = Math.max(1,1) = 1
long minSize = Math.max(job.getLong(org.apache.hadoop.mapreduce.lib.input.
FileInputFormat.SPLIT_MINSIZE, 1), minSplitSize);
// compute the split size: Math.max(minSize, Math.min(goalSize, blockSize))
// blockSize: 32M locally
//            64M on a hadoop 1.x cluster
//            128M on a hadoop 2.x cluster
long splitSize = computeSplitSize(goalSize, minSize, blockSize);
// hadoop reads each file independently: for each file, its size / splitSize is compared with SPLIT_SLOP = 1.1;
// if it is greater, the file is cut into two pieces (two partitions), otherwise into one piece (one partition)
while (((double) bytesRemaining)/splitSize > SPLIT_SLOP) {}
}
Example:
Read two files with textFile (@ marks the hidden carriage-return/line-feed characters)
1.txt is 24 bytes:            11.txt is 33 bytes:
hello spark@@                 1984949156418564@@
hello scala                   15641@@
                              1256@@
                              12
Partition calculation:
goalSize = 57 / 2 = 28 bytes, so each split holds about 28 bytes
1.txt  => 24 bytes                                 => 1 partition
11.txt => 33 bytes > 28 * 1.1 = 30.8               => 2 partitions
In total 3 partitions are produced (@ marks the hidden carriage-return/line-feed characters)
1.txt  => hello spark@@        bytes 0~12
          hello scala          bytes 13~23
11.txt => 1984949156418564@@   bytes 25~42 (0~17 within the file)
          15641@@              bytes 43~49 (18~24 within the file)
          1256@@               bytes 50~55 (25~30 within the file)
          12                   bytes 56~57 (31~32 within the file)
A partition only ever holds data from a single file; this is what we usually mean by each file being split independently
[0, 28] (can hold up to 30.8) => all of 1.txt
[0, 28] (can hold up to 30.8) => 1984949156418564@@
                                 15641@@
                                 1256@@
[0, 28] (can hold up to 30.8) => 12
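The split arithmetic above can be reproduced with a few lines of plain Scala (a sketch of the formula, not Hadoop's actual code), assuming the local 32M block size:
object SplitArithmetic {
  def main(args: Array[String]): Unit = {
    val files         = Seq("1.txt" -> 24L, "11.txt" -> 33L)
    val minPartitions = 2
    val blockSize     = 32L * 1024 * 1024                                // local-mode block size, 32M
    val minSize       = 1L
    val totalSize     = files.map(_._2).sum                              // 57
    val goalSize      = totalSize / minPartitions                        // 57 / 2 = 28
    val splitSize     = math.max(minSize, math.min(goalSize, blockSize)) // 28
    val SPLIT_SLOP    = 1.1

    files.foreach { case (name, size) =>
      var remaining = size
      var splits    = 0
      while (remaining.toDouble / splitSize > SPLIT_SLOP) { splits += 1; remaining -= splitSize }
      if (remaining > 0) splits += 1
      println(s"$name -> $splits split(s)")                              // 1.txt -> 1, 11.txt -> 2
    }
  }
}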
Creating from memory: makeRDD or parallelize
The two are the same; makeRDD calls parallelize underneath
Parameters
seq: Seq[T] // the sequence holding the source data
numSlices: Int = defaultParallelism // number of slices (not necessarily the number of partitions)
// Note that unlike reading from files there is no min-with-2 step here, so when creating an RDD from memory
// on a cluster it is best to set "spark.default.parallelism" explicitly, to avoid producing too many small files
// CoarseGrainedSchedulerBackend, cluster mode:
override def defaultParallelism(): Int = {
// the maximum of the total number of cores in the cluster and 2
conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2))
}
// LocalSchedulerBackend, local mode:
override def defaultParallelism(): Int =
// the total number of cores of the local node
scheduler.conf.getInt("spark.default.parallelism", totalCores)
The partitioning source code
// if the numSlices passed in is < 1, an error is thrown
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
// this is the actual partition computation logic
(0 until numSlices).iterator.map { i =>
val start = ((i * length) / numSlices).toInt
val end = (((i + 1) * length) / numSlices).toInt
(start, end)
}
}
Example:
Source sequence: List(5, 6, 7, 8), with numSlices set to 3
The slice computation is: (0 until numSlices) => (0, 1, 2)
For i = 0: val start = ((0 * 4) / 3) => 0
           val end   = ((1 * 4) / 3) => 1
List(5, 6, 7, 8), with indices 0 1 2 3, is sliced as:
[0, 1) => 5
[1, 2) => 6
[2, 4) => 7 8
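A quick runnable check of this slicing, reproducing the positions logic in plain Scala:
object PositionsCheck {
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end   = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }

  def main(args: Array[String]): Unit = {
    val data = List(5, 6, 7, 8)
    positions(data.length, 3).foreach { case (start, end) =>
      println(s"[$start, $end) => ${data.slice(start, end)}")
    }
    // [0, 1) => List(5)
    // [1, 2) => List(6)
    // [2, 4) => List(7, 8)
  }
}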
2.2 Creating RDD Dependencies
Operators without shuffle
They all end up calling the code below; you can clearly see that a OneToOneDependency object is passed in as the dependency of the new RDD. Code that later needs this dependency can obtain it directly via getDependencies.
def this(@transient oneParent: RDD[_]) =
this(oneParent.context, List(new OneToOneDependency(oneParent)))
protected def getDependencies: Seq[Dependency[_]] = deps
Operators with shuffle
Here the dependency argument passed to the parent constructor is Nil, an empty list, but getDependencies is overridden
class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
@transient var prev: RDD[_ <: Product2[K, V]],
part: Partitioner)
extends RDD[(K, C)](prev.context, Nil){
override def getDependencies: Seq[Dependency[_]] = {
//Here you can see a ShuffleDependency object is returned by default
//prev in the parameters is the previous RDD being depended on
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}
}
Diagram of the dependencies between RDDs
3. Job Execution, Stage Division and Task Splitting
Source code analysis
A job is only triggered when an action operator runs, which calls the runJob method, as the small example below shows.
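A minimal local-mode example of this laziness: the map transformation alone submits no job; only the action (collect) reaches SparkContext.runJob and hence the DAGScheduler.
import org.apache.spark.{SparkConf, SparkContext}

object ActionTriggersJob {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("action-demo"))
    val mapped = sc.makeRDD(1 to 4, 2).map(_ * 2)   // transformation only: no job yet
    val result = mapped.collect()                   // action: calls sc.runJob -> dagScheduler.runJob
    println(result.mkString(","))                   // 2,4,6,8
    sc.stop()
  }
}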
SparkContext class:
def runJob[T, U: ClassTag](rdd: RDD[T],func: (TaskContext, Iterator[T]) => U,partitions: Seq[Int],resultHandler: (Int, U) => Unit): Unit = {
//clean performs the closure check: it verifies the passed-in function can be serialized, otherwise it fails with "Task not serializable"
val cleanedFunc = clean(func)
// the DAGScheduler object starts running the job
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
//Here you can see the rdd tries to checkpoint its data when it runs
rdd.doCheckpoint()
}
DAGScheduler class:
def runJob[T, U](rdd: RDD[T],func: (TaskContext, Iterator[T]) => U,partitions: Seq[Int],callSite: CallSite,
resultHandler: (Int, U) => Unit,properties: Properties): Unit = {
//Submit the job
val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
}
def submitJob[T, U](rdd: RDD[T],func: (TaskContext, Iterator[T]) => U,partitions: Seq[Int],callSite: CallSite,
resultHandler: (Int, U) => Unit,properties: Properties): JobWaiter[U] = {
/**EventLoop class:
private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()
//Put the event into the event queue. The event thread will process it later.
def post(event: E): Unit = {
if (!stopped.get) {
if (eventThread.isAlive) {
eventQueue.put(event)
} else {
onError(new IllegalStateException(s"$name has already been stopped accidentally."))
}
}
}
*/
// Wrap a JobSubmitted message and put it into an event queue; eventThread later takes events from eventQueue
// and executes the doOnReceive method in the DAGScheduler class
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, callSite, waiter,
Utils.cloneProperties(properties)))
}
//doOnReceive dispatches to the matching handler via pattern matching
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
//handleJobSubmitted calls the following to start stage division and create a ResultStage:
// finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
// finalStage is then passed into submitStage to start submission
dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
...
}
private def createResultStage(rdd: RDD[_],func: (TaskContext, Iterator[_]) => _,partitions: Array[Int],jobId: Int,
callSite: CallSite): ResultStage = {
//Before creating the ResultStage, gather all upstream shuffle stages and wrap them into a List
val parents = getOrCreateParentStages(rdd, jobId)
// Pass that list in, create a ResultStage, and return it as finalStage; at this point stage division is finished
val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
}
private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
// collect all shuffle dependencies
getShuffleDependencies(rdd).map { shuffleDep =>
//create a ShuffleMapStage for each of these dependencies
/**
// getOrCreateParentStages is called recursively until the current rdd has no shuffle dependency upstream;
// from there, one ShuffleMapStage is created per shuffle dependency, working back down
val parents = getOrCreateParentStages(rdd, jobId)
val stage = new ShuffleMapStage(
id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
*/
getOrCreateShuffleMapStage(shuffleDep, firstJobId)
//wrap these stages into a List and return it as parents
}.toList
}
private[scheduler] def getShuffleDependencies(
rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
// stores all the shuffle dependencies found
val parents = new HashSet[ShuffleDependency[_, _, _]]
// rdds whose dependencies have already been examined are stored here
val visited = new HashSet[RDD[_]]
// rdds whose dependencies are about to be examined are stored here
val waitingForVisit = new ListBuffer[RDD[_]]
waitingForVisit += rdd
while (waitingForVisit.nonEmpty) {
val toVisit = waitingForVisit.remove(0)
if (!visited(toVisit)) {
visited += toVisit
// If the current rdd's dependency is a ShuffleDependency, add it to parents;
// otherwise follow the dependency to the rdd it points to (the previous rdd) and loop again,
// until there is no rdd left upstream; then stop and return parents
toVisit.dependencies.foreach {
case shuffleDep: ShuffleDependency[_, _, _] =>
parents += shuffleDep
case dependency =>
waitingForVisit.prepend(dependency.rdd)
}
}
}
parents
}
private def submitStage(stage: Stage): Unit = {
// from the current stage, find its missing parent stages
val missing = getMissingParentStages(stage).sortBy(_.id)
if (missing.isEmpty) {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
// if the current stage has no missing parent stages, submission starts from this stage
submitMissingTasks(stage, jobId.get)
} else {
for (parent <- missing) {
submitStage(parent)
}
waitingStages += stage
}
}
private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
// pattern-match on the type of the current stage to create the matching type of Task
val tasks: Seq[Task[_]] = try {
val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
stage match {
case stage: ShuffleMapStage =>
stage.pendingPartitions.clear()
/**the source of partitionsToCompute, which determines the number of Tasks
override def findMissingPartitions(): Seq[Int] = {
mapOutputTrackerMaster
.findMissingPartitions(shuffleDep.shuffleId)
.getOrElse(0 until numPartitions)
}
*/
// by default partitionsToCompute creates one Task per partition of the last RDD of the current stage
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
}
case stage: ResultStage =>
/**the source of partitionsToCompute, which determines the number of Tasks
override def findMissingPartitions(): Seq[Int] = {
val job = activeJob.get
(0 until job.numPartitions).filter(id => !job.finished(id))
}
*/
partitionsToCompute.map { id =>
val p: Int = stage.partitions(id)
val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
stage.rdd.isBarrier())
}
}
}
Diagram
The Task class is the parent class of ShuffleMapTask and ResultTask
4. Task Scheduling and Execution
Source code analysis
DAGScheduler class:
private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
if (tasks.nonEmpty) {
// pass the wrapped sequence of Tasks to the taskScheduler, which starts submitting them
taskScheduler.submitTasks(new TaskSet(
tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties))
}
}
TaskSchedulerImpl class:
override def submitTasks(taskSet: TaskSet): Unit = {
// wrap the incoming TaskSet into a TaskSetManager
/**
private[scheduler] def createTaskSetManager(
taskSet: TaskSet,
maxTaskFailures: Int): TaskSetManager = {
new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
}
*/
val manager = createTaskSetManager(taskSet, maxTaskFailures)
// put the TaskSetManager into the scheduler's rootPool (which can be thought of as a task pool)
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
/**
In the CoarseGrainedSchedulerBackend class:
val driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint())
override def reviveOffers(): Unit = {
In its makeOffers method, the resourceOffers method:
//We fill each node with tasks in a round-robin manner so that tasks are balanced across the cluster
// round-robin assignment is used so that Tasks are spread evenly over the cluster nodes
// inside resourceOffers: val sortedTaskSets = rootPool.getSortedTaskSetQueue.filterNot(_.isZombie)
Task scheduling algorithms:
FAIR   => new FairSchedulingAlgorithm()
FIFO   => new FIFOSchedulingAlgorithm()
others => an exception is thrown
//getSortedTaskSetQueue sorts the taskSets using the current taskSetSchedulingAlgorithm (the scheduling algorithm),
//then assigns a locality level to the Tasks in these TaskSets, indicating which node each Task should preferably be sent to
Task locality levels:
PROCESS_LOCAL: process-local, data and computation are in the same process
NODE_LOCAL: node-local, data and computation are on the same node
NO_PREF: no preferred location, the data is equally fast to access from anywhere
RACK_LOCAL: rack-local, data and computation are in the same rack
ANY: cross-rack, the data is on the network outside the rack, the slowest option
//scheduler.resourceOffers(workOffers)
// send the ReviveOffers message; on receiving and pattern-matching it, makeOffers is called,
// and if the task list is not empty launchTasks is called
driverEndpoint.send(ReviveOffers)
}
*/
backend.reviveOffers()
}
def initialize(backend: SchedulerBackend): Unit = {
this.backend = backend
schedulableBuilder = {
// pattern-match to create the corresponding schedulable builder; the default is FIFO
schedulingMode match {
case SchedulingMode.FIFO =>
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
new FairSchedulableBuilder(rootPool, conf)
case _ =>
throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " +
s"$schedulingMode")
}
}
schedulableBuilder.buildPools()
}
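As a usage note, the schedulingMode matched above comes from the "spark.scheduler.mode" configuration, so switching from the default FIFO pool to FAIR is just a configuration change; a minimal sketch:
import org.apache.spark.{SparkConf, SparkContext}

object FairSchedulerConfig {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("fair-demo")
      .set("spark.scheduler.mode", "FAIR")   // FIFO (default) | FAIR
    val sc = new SparkContext(conf)
    println(sc.getConf.get("spark.scheduler.mode"))
    sc.stop()
  }
}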
CoarseGrainedSchedulerBackend class:
private def launchTasks(tasks: Seq[Seq[TaskDescription]]): Unit = {
for (task <- tasks.flatten) {
// flatten the task sequences into individual tasks and serialize each one
val serializedTask = TaskDescription.encode(task)
// check whether the serialized task is larger than the maximum RPC message size
if (serializedTask.limit() >= maxRpcMessageSize) {
...
}else {
val executorData = executorDataMap(task.executorId)
// send a LaunchTask message to the executor endpoint living in the corresponding container
executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
}
}
}
CoarseGrainedExecutorBackend class:
override def receive: PartialFunction[Any, Unit] = {
//receive the LaunchTask message from the Driver side
case LaunchTask(data) =>
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
// deserialize the received Task
val taskDesc = TaskDescription.decode(data.value)
logInfo("Got assigned task " + taskDesc.taskId)
taskResources(taskDesc.taskId) = taskDesc.resources
// the executor computation object created during environment preparation calls launchTask to start the Task
executor.launchTask(this, taskDesc)
}
}
Executor class:
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
// create a TaskRunner object used to run the Task; inside it, task.run() calls the run() method of the Task class,
// which then calls the runTask() method overridden by the Task subclass (ShuffleMapTask or ResultTask) to run the Task
val tr = new TaskRunner(context, taskDescription)
runningTasks.put(taskDescription.taskId, tr)
// execute the TaskRunner created above on a thread from the thread pool to start the Task
threadPool.execute(tr)
}
Diagram
The Shuffle Mechanism in Spark
1. ShuffleMapTask Write Path: Source Code Analysis
// when a ShuffleMapTask runs, it calls the runTask method overridden in its own class
override def runTask(context: TaskContext): MapStatus = {
// shuffle has to spill data to disk; the write starts here
dep.shuffleWriterProcessor.write(rdd, dep, mapId, context, partition)
}
ShuffleWriteProcessor class:
Shuffle managers in Spark: early versions used HashShuffleManager, while current versions use SortShuffleManager
ShuffleHandle categories:
def write(rdd: RDD[_],dep: ShuffleDependency[_, _, _],mapId: Long,context: TaskContext,partition: Partition): MapStatus = {
// declare a ShuffleWriter variable that will hold the writer obtained below
var writer: ShuffleWriter[Any, Any] = null
// get the shuffleManager
val manager = SparkEnv.get.shuffleManager
// manager.getWriter returns the matching writer object based on the shuffleHandle argument
writer = manager.getWriter[Any, Any](
dep.shuffleHandle,
mapId,
context,
createMetricsReporter(context))
// For a SortShuffleWriter, it first checks whether map-side pre-aggregation is supported; if so the sorter gets an aggregator and an ordering, otherwise it has neither.
// Then sorter.insertAll(records) inserts the data for sorting,
// and sorter.writePartitionedMapOutput writes the data; finally commitAllPartitions brings the index file and data file up to date:
// after processing, the old indexFile and dataFile are deleted and indexFile_tmp and dataFile_tmp are renamed to indexFile and dataFile
writer.write(
rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
}
SortShuffleWriter class:
override def write(records: Iterator[Product2[K, V]]): Unit = {
//check whether the dependency supports map-side combine; if so the sorter gets an aggregator and an ordering, otherwise it gets neither
sorter = if (dep.mapSideCombine) {
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
// insert the data for sorting
sorter.insertAll(records)
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
// writePartitionedMapOutput behaves differently depending on whether data was spilled:
// spilled to disk => merge the in-memory data with the spilled files, using a heap-based mergeSort
// not spilled     => process the in-memory data directly
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
val partitionLengths = mapOutputWriter.commitAllPartitions()
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}
insertAll chooses a different data structure depending on whether pre-aggregation is possible, but both store data in arrays underneath
true  => PartitionedAppendOnlyMap
false => PartitionedPairBuffer
Afterwards maybeSpillCollection calls maybeSpill to decide whether the data needs to be spilled to disk
Spillable class:
/**
the maybeSpillCollection method:
if (usingMap) {
// estimate the memory currently needed
estimatedSize = map.estimateSize()
if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
buffer = new PartitionedPairBuffer[K, C]
}
}
*/
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
//myMemoryThreshold = "spark.shuffle.spill.initialMemoryThreshold", 5MB by default
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = acquireMemory(amountToRequest)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
shouldSpill = currentMemory >= myMemoryThreshold
}
//numElementsForceSpillThreshold = "spark.shuffle.spill.numElementsForceSpillThreshold", defaulting to Integer.MAX_VALUE
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
// Actually spill
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
//"spark.shuffle.file.buffer": the default buffer size for spilling to disk is 32K
spill(collection)
_elementsRead = 0
_memoryBytesSpilled += currentMemory
releaseMemory()
}
shouldSpill
}
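A small worked example of the request arithmetic in maybeSpill, with assumed numbers (5 MB initial threshold, a 7 MB in-memory collection, and a pool that only grants 2 MB):
object MaybeSpillArithmetic {
  def main(args: Array[String]): Unit = {
    val mb                = 1024L * 1024
    var myMemoryThreshold = 5 * mb            // spark.shuffle.spill.initialMemoryThreshold
    val currentMemory     = 7 * mb            // estimated size of the in-memory collection
    val amountToRequest   = 2 * currentMemory - myMemoryThreshold   // 2 * 7 - 5 = 9 MB
    val granted           = 2 * mb            // pretend the memory pool only grants 2 MB
    myMemoryThreshold    += granted           // threshold grows to 7 MB
    val shouldSpill       = currentMemory >= myMemoryThreshold      // 7 >= 7 -> spill
    println(s"requested=${amountToRequest / mb}MB, shouldSpill=$shouldSpill")
  }
}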
ShuffleDependency class:
// when a ShuffleDependency object is created, it registers a shuffle handle with the shuffleManager
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(shuffleId, this)
In the SortShuffleManager class:
override def registerShuffle[K, V, C](shuffleId: Int,dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
// the dependency does no map-side pre-aggregation and the number of partitions is at most 200: use BypassMergeSortShuffleHandle
new BypassMergeSortShuffleHandle[K, V](
shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// the dependency's serializer must support relocation of serialized objects (Java serialization does not, Kryo does),
// there is no map-side pre-aggregation, and the number of partitions is at most MAXIMUM_PARTITION_ID (16777215 = (1 << 24) - 1) + 1
// when those conditions hold: SerializedShuffleHandle
new SerializedShuffleHandle[K, V](
shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// all other cases: BaseShuffleHandle
new BaseShuffleHandle(shuffleId, dependency)
}
}
override def getWriter[K, V](handle: ShuffleHandle,mapId: Long,context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
// pattern-match on the type of the ShuffleHandle passed in and create the corresponding ShuffleWriter for writing shuffle data
handle match {
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf,
metrics,
shuffleExecutorComponents)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
bypassMergeSortHandle,
mapId,
env.conf,
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(
shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
}
}
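To connect the three handles with everyday operators, here is a hedged illustration under default configuration: reduceByKey combines on the map side, so it falls through to BaseShuffleHandle and SortShuffleWriter, while groupByKey with at most 200 (spark.shuffle.sort.bypassMergeThreshold) output partitions takes the bypass path.
import org.apache.spark.{SparkConf, SparkContext}

object HandleSelectionSketch {
  def main(args: Array[String]): Unit = {
    val sc   = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("handle-demo"))
    val data = sc.makeRDD(Seq(("a", 1), ("b", 2), ("a", 3)), 2)

    // reduceByKey sets mapSideCombine = true, so shouldBypassMergeSort is false
    // -> BaseShuffleHandle -> SortShuffleWriter
    data.reduceByKey(_ + _).collect()

    // groupByKey does not combine map-side; with <= 200 output partitions it takes
    // the bypass path -> BypassMergeSortShuffleHandle -> BypassMergeSortShuffleWriter
    data.groupByKey(4).collect()

    sc.stop()
  }
}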
2. ResultTask Read Path: Source Code Analysis
override def runTask(context: TaskContext): U = {
// inside iterator, computeOrReadCheckpoint calls compute;
// since we are describing the read side of the shuffle mechanism, the compute called here is the one in ShuffledRDD
func(context, rdd.iterator(partition, context))
}
ShuffledRDD class
Unlike the write side, which has several writers, there is only one ShuffleReader: BlockStoreShuffleReader
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
val metrics = context.taskMetrics().createTempShuffleReadMetrics()
// getReader creates a new BlockStoreShuffleReader; its read method configures the reader, for example
// "spark.reducer.maxSizeInFlight": the maximum read buffer, 48MB by default
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context, metrics)
// read the data
.read()
.asInstanceOf[Iterator[(K, C)]]
}
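Finally, a minimal local-mode example that exercises both halves described in this section: the ShuffleMapTasks of the reduceByKey stage write the shuffle files, and the ResultTasks of the collect stage read them back through BlockStoreShuffleReader.read().
import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadWriteDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("shuffle-demo"))
    val counts = sc.makeRDD(Seq("hello spark", "hello scala"), 2)
      .flatMap(_.split(" "))
      .map((_, 1))
      .reduceByKey(_ + _)    // stage boundary: ShuffleMapTasks write, ResultTasks read
      .collect()
    counts.foreach(println)  // (hello,2), (spark,1), (scala,1)
    sc.stop()
  }
}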