ThreadLocal及线程间变量传递详解
ThreadLocal及线程间变量传递详解
面试速记
每个线程内部都有一个专属的 ThreadLocalMap,用来存放线程局部变量。ThreadLocalMap 的键是 ThreadLocal 对象,值是实际存储的数据。调用 get() 或 set() 时,会通过当前线程的 ThreadLocalMap 根据 ThreadLocal 对象的 hash 值快速定位并操作对应的值,从而实现数据隔离。
在多线程编程中,我们常常需要保证每个线程拥有自己的变量副本,避免线程间数据互相干扰。Java 提供的 ThreadLocal 就是一种用于存储线程局部变量的工具,它允许每个线程拥有自己独立的变量值。下面我们将从三个角度进行详细探讨。
概念及使用
public class ThreadLocalExample {
// 使用 withInitial() 定义初始值
private static final ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 0);
public static void main(String[] args) {
Thread thread1 = new Thread(() -> {
threadLocal.set(100);
System.out.println("Thread1: " + threadLocal.get());
});
Thread thread2 = new Thread(() -> {
threadLocal.set(200);
System.out.println("Thread2: " + threadLocal.get());
});
thread1.start();
thread2.start();
}
}
创建ThreadLocal对象,然后调用其get(),set()方法就可以在线程内任意存取变量
注意线程数据隔离,代码中的thread1,thread2通过set不同的值,取出不同的值,是因为底层维护还是依靠自己线程的一个属性ThreadLocalMap,其实就是在两个map中用相同的键(threadLocal)存储两个值,这样再取也是从两个map中取,详见下文解释
实现原理
每个线程内部存在一个ThreadLocalMap,键是ThreadLocal,值为对应的值
2.1 内部结构
-
ThreadLocalMap 每个线程对象(
Thread
)内部都有一个ThreadLocalMap
,用于存储该线程所有的 ThreadLocal 变量。ThreadLocal 本身只是一个访问接口,其具体数据则保存在该 Map 中。//Thread源码 public class Thread implements Runnable { /* *与此线程相关的 ThreadLocal 值。此映射由 ThreadLocal 类维护。 */ ThreadLocal.ThreadLocalMap threadLocals = null; }
-
键值对存储
- 键(Key): 实际上是 ThreadLocal 对象的弱引用(WeakReference),可以避免内存泄漏(当 ThreadLocal 对象不再被外部引用时,GC 可以回收它)。
- 值(Value): 即我们通过
set()
方法传入的数据。
2.2 关键方法
- set() 方法
当调用
ThreadLocal.set(value)
时,系统会获取当前线程的 ThreadLocalMap,将当前 ThreadLocal 对象作为键,value 作为值存入 Map 中。如果当前线程没有创建 ThreadLocalMap,则会进行初始化。 - get() 方法
调用
ThreadLocal.get()
时,同样会先获取当前线程的 ThreadLocalMap,然后通过当前 ThreadLocal 对象作为键查找对应的值。如果不存在,则调用initialValue()
方法(如果重写了该方法)来设置初始值。 - remove() 方法 用于手动移除当前线程中对应的线程局部变量,防止长时间持有引用造成内存泄漏,尤其在使用线程池等复用线程场景中尤为重要。
来从源码角度逐步拆分解析下
ThreadLocal的get()
public T get() {
Thread t = Thread.currentThread();//获取当前线程对象
ThreadLocal.ThreadLocalMap map = getMap(t);//取出当前线程的ThreadLocalMap
if (map != null) {
ThreadLocal.ThreadLocalMap.Entry e = map.getEntry(this);//ThreadLocal为key获取entry
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;//已经存在该entry(执行过set()或初始化),则直接返回
return result;
}
}
//如果 map 不存在或没有找到对应的 Entry,则调用初始值方法初始化
return setInitialValue();
}
从getMap()源码也可以看出ThreadLocalMap作为Thread类的一个属性
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
然后来看看这个简单的初始化方法setInitialValue()
private T setInitialValue() {
T value = initialValue();//初始化值,可重写
Thread t = Thread.currentThread();
ThreadLocal.ThreadLocalMap map = getMap(t);
if (map != null) {
map.set(this, value);//存在map则初始化值
} else {
createMap(t, value);//为当前线程创建ThreadLocalMap
}
if (this instanceof TerminatingThreadLocal) {
//jdk17的拓展类,提供线程终止时能自动通知并处理线程局部变量资源清理的机制
TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
}
return value;
}
Thread中创建Map的方法
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocal.ThreadLocalMap(this, firstValue);
}
然后是set():
public void set(T value) {
//同理将ThreadLocal作为键向map中set
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
map.set(this, value);
} else {
createMap(t, value);
}
}
此外每个ThreadLocal内部还有一个hashcode:
public class ThreadLocal<T> {
// 每个 ThreadLocal 实例都有一个唯一的 hashCode,用于在 ThreadLocalMap 中定位存储位置
final int threadLocalHashCode = nextHashCode();
// 用于生成唯一的 hashCode(源码中采用原子递增方式,这里简化处理)
private static int nextHashCode() {
// 实际实现会考虑并发安全,这里为了说明原理,省略具体实现
return (int) (Math.random() * Integer.MAX_VALUE);
}
通过 ThreadLocal,每个线程都持有独立的变量副本,从而避免了数据竞争问题,无需进行同步(synchronized)的额外开销。
异步线程任务中上下文数据的传播
在线程池或异步任务(如 CompletableFuture、RxJava 等)中,线程并非父子线程一一对应,而是可能被重复使用。默认的 ThreadLocal 绑定的是线程本身,无法自动将父线程中的上下文数据传递到被复用的子线程中。
**解决方案: TransmittableThreadLocal (TTL)/InheritableThreadLocal **
传递时机 | 仅在子线程创建时一次性复制父线程数据,后续父线程更新不影响子线程。 | 在任务提交时捕获父线程数据,并在每次任务执行前复制最新数据到子线程(即使线程复用,也能保证数据是最新的)。 |
适用场景 | 适合父子线程关系明确的场景,子线程创建时即可获得父线程数据。 | 专为线程池和异步任务设计,确保每个任务启动时都能获得正确的上下文数据,适用于线程被复用的场景。 |
实现机制 | 基于 JDK 的简单复制机制,在子线程创建时自动复制父线程的 ThreadLocalMap 部分数据。 | 通过包装 Runnable/Callable,在任务提交时捕获数据、在执行前复制到子线程,并在任务结束后自动清理上下文数据,防止数据残留引起内存泄漏。 |
清理机制 | 没有内置的清理机制,可能导致线程长时间持有不再需要的数据。 | 内置任务包装器会在任务执行后自动清理传递的上下文数据,从而降低内存 |
有关TTL:
-
数据捕获与传递: 在任务提交时,捕获父线程中的 ThreadLocal 数据;在子线程(或线程池中任务执行前)复制该数据,从而确保每个任务都能拿到最新的上下文信息。
-
自动清理: 任务执行完毕后,自动清除子线程中的上下文数据,降低内存泄露的风险。
-
使用示例:
java复制编辑TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>(); context.set("父线程数据"); Runnable task = () -> { // 子线程能够获取到父线程传递的数据 System.out.println("子线程数据:" + context.get()); }; // 使用包装器将 Runnable 中的 ThreadLocal 数据传递到子线程中 ExecutorService executor = Executors.newFixedThreadPool(2); executor.submit(TtlRunnable.get(task)); executor.shutdown();
其他方案: 除 TTL 外,一些异步框架或利用 AOP、拦截器手段也能实现上下文数据的捕获和传递。另外,也可以在业务层设计一个 Context 对象,通过方法参数进行数据传递。
应用场景:
-
分布式链路追踪: 传递统一的上下文信息以实现调用链追踪(日志、监控等)。
-
安全与权限控制: 确保异步任务中能够获取到调用链中的安全上下文或权限数据。
两个运行线程的变量共享传递
TTL 和 InheritableThreadLocal 的设计目标都是在新建线程或任务提交时,将父线程中的数据复制到子线程中。这种传递是“一次性”的,在任务启动时完成数据复制,而不是持续的双向共享。 如果两个线程已经在运行中且不存在明确的父子关系,这些机制不会自动共享数据。也就是说,它们适合用来解决新创建线程或线程池中任务的上下文传递问题,而不能用于已经运行的独立线程之间的数据共享。 对于运行中的线程需要共享数据,可以采用以下方式:
- 共享数据结构: 利用线程安全的集合(如 ConcurrentHashMap)存储共享数据。
- 消息队列或事件机制: 通过消息或事件通知实现线程间通信。
- 显式传递参数: 在方法调用时将上下文数据作为参数传递。
下面是一个利用ConcurrentHashMap的例子:
为了让两个或多个线程能够共享数据,必须保证它们访问的是同一个 ConcurrentHashMap 实例。通常的做法是:
- 静态变量或单例: 将 ConcurrentHashMap 定义为静态变量或放在一个单例对象中,确保所有线程都能获得同一份数据。
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class SharedDataExample {
// 创建一个共享的 ConcurrentHashMap
private static ConcurrentHashMap<String, Object> sharedMap = new ConcurrentHashMap<>();
public static void main(String[] args) {
// 使用固定线程池,确保两个线程共享同一 Map 实例
ExecutorService executor = Executors.newFixedThreadPool(2);
// 第一个线程:更新共享计数器
executor.submit(() -> {
for (int i = 0; i < 5; i++) {
sharedMap.compute("counter", (key, value) -> {
if (value == null) {
return 1;
} else {
return ((Integer) value) + 1;
}
});
System.out.println("线程1更新后 counter 值:" + sharedMap.get("counter"));
try {
Thread.sleep(500); // 模拟工作负载
} catch (InterruptedException e) {
e.printStackTrace();
}
}
});
// 第二个线程:读取共享计数器的值
executor.submit(() -> {
for (int i = 0; i < 5; i++) {
System.out.println("线程2读取到 counter 值:" + sharedMap.get("counter"));
try {
Thread.sleep(700); // 模拟不同的工作节奏
} catch (InterruptedException e) {
e.printStackTrace();
}
}
});
// 关闭线程池
executor.shutdown();
}
}