发布于 

深入了解与使用ThreadLocal

什么是ThreadLocal

This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its get or set method) has its own, independently initialized copy of the variable. ThreadLocal instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID).

For example, the class below generates unique identifiers local to each thread. A thread’s id is assigned the first time it invokes ThreadId.get() and remains unchanged on subsequent calls.

ThreadLocal提供了线程的局部变量,每个线程都可以通过set()get()方法操作此变量,并且不会与其他线程的局部变量进行冲突,保证了线程间局部变量的隔离行。

翻译一下:ThrealLocal类型的变量属于当前线程,哪怕是同一个方法内的同一个ThreadLocal变量,他们的值在不同的业务运行下也是不一样的,线程安全。

ThreadLocal业务应用场景

数据库连接

jdbc时代,数据库连接connection需要我们手动来维护,同一个库连接我们可以用数据库连接池来实现。但是不同数据库连接的时候,不同的线程需要获取到不同的连接。这时候就可以使用ThreadLocal来维护线程池中不同数据源连接。

示例代码:

来源:https://blog.csdn.net/u013360850/article/details/78861442
public class DynamicDataSourceContextHolder {

private static final Logger logger = LoggerFactory.getLogger(DynamicDataSourceContextHolder.class);

/**
* 用于在切换数据源时保证不会被其他线程修改
*/
private static Lock lock = new ReentrantLock();

/**
* 用于轮循的计数器
*/
private static int counter = 0;

/**
* Maintain variable for every thread, to avoid effect other thread
*/
private static final ThreadLocal<Object> CONTEXT_HOLDER = ThreadLocal.withInitial(DataSourceKey.master);


/**
* All DataSource List
*/
public static List<Object> dataSourceKeys = new ArrayList<>();

/**
* The constant slaveDataSourceKeys.
*/
public static List<Object> slaveDataSourceKeys = new ArrayList<>();

/**
* To switch DataSource
*
* @param key the key
*/
public static void setDataSourceKey(String key) {
CONTEXT_HOLDER.set(key);
}

/**
* Use master data source.
*/
public static void useMasterDataSource() {
CONTEXT_HOLDER.set(DataSourceKey.master);
}

/**
* 当使用只读数据源时通过轮循方式选择要使用的数据源
*/
public static void useSlaveDataSource() {
lock.lock();

try {
int datasourceKeyIndex = counter % slaveDataSourceKeys.size();
CONTEXT_HOLDER.set(String.valueOf(slaveDataSourceKeys.get(datasourceKeyIndex)));
counter++;
} catch (Exception e) {
logger.error("Switch slave datasource failed, error message is {}", e.getMessage());
useMasterDataSource();
e.printStackTrace();
} finally {
lock.unlock();
}
}

/**
* Get current DataSource
*
* @return data source key
*/
public static String getDataSourceKey() {
return CONTEXT_HOLDER.get();
}

/**
* To set DataSource as default
*/
public static void clearDataSourceKey() {
CONTEXT_HOLDER.remove();
}

/**
* Check if give DataSource is in current DataSource list
*
* @param key the key
* @return boolean boolean
*/
public static boolean containDataSourceKey(String key) {
return dataSourceKeys.contains(key);
}
}

以上代码用于数据源上下文配置,用于切换数据源。

全局变量传递

最典型的例子就是用户数据的传递,请求的时候通过Header将用户的token传递,通过拦截器(过滤器、AOP)等方式将token放到ThreadLocal,让当前线程的相关方法共享该变量,减少参数传递。

示意图:

image-20210428223107168

示例代码:

public class ContextHolder {

private static final ThreadLocal<AccountLoginInfoBo> ACCOUNT_LOGIN_INFO_HOLDER = new ThreadLocal<>();

private static final ThreadLocal<String> REQUEST_ID_HOLDER = new ThreadLocal<>();

public static AccountLoginInfoBo getAccountLoginInfo() {
return ACCOUNT_LOGIN_INFO_HOLDER.get();
}
// ... 省略其他代码
}

链路追踪

微服务架构下,多个服务之间调用如果出现了报错,使用链路追踪是一个常见的手段。在请求的线程变量中嵌入traceId,根据这个traceId就可以找到对应请求分别在各个业务应用的日志。

示例代码:

public class GeneralInterceptorHandler extends HandlerInterceptorAdapter {

@Override
public boolean preHandle(HttpServletRequest request,
javax.servlet.http.HttpServletResponse response, Object handler) throws Exception {
request.setAttribute("_timestamp", System.currentTimeMillis());

String requestId = RandomStringUtils.randomNumeric(10);
MDC.put("requestId", requestId);

BaseContextHolder.setRequestId(requestId);
String requestUri = request.getRequestURI();
BaseContextHolder.setRequestUrl(requestUri);
log.info("request uri: {}", requestUri);

return super.preHandle(request, response, handler);
}

@Override
public void postHandle(HttpServletRequest request,
javax.servlet.http.HttpServletResponse response, Object handler,
ModelAndView modelAndView) throws Exception {

long begin = (long) request.getAttribute("_timestamp");
long end = System.currentTimeMillis();
log.info("process success. cost {}ms. ", end - begin);
MDC.remove("requestId");
BaseContextHolder.remove();
super.postHandle(request, response, handler, modelAndView);
}
}

示例结果:

image-20210428223736918

在以上代码中,除了将生成的请求ID放到MDC中,还将其放入到Context类中定义的ThreadLocal中,方便在代码中直接拿出来使用,这样通过日志就能打印出每次的请求ID,在微服务请求链路中查找日志只需要使用cat xxx.log | rep 8475913673就可以找到本次请求的所有链路日志。

ThreadLocal原理解析

顺便了解一下ThreadLocal原理,看一下set()get()remove()方法的实现。

set()方法

/**
* Sets the current thread's copy of this thread-local variable
* to the specified value. Most subclasses will have no need to
* override this method, relying solely on the {@link #initialValue}
* method to set the values of thread-locals.
*
* @param value the value to be stored in the current thread's copy of
* this thread-local.
*/
public void set(T value) {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取维护当前线程变量的ThreadLocalMap数据,一种类似于HashMap的数据结构
ThreadLocalMap map = getMap(t);
// 如果当前线程已经存在了Map,直接调用map.set
if (map != null)
map.set(this, value);
// 不存在Map,则先进行新增map,再进行set
else
createMap(t, value);
}

查看源码发现set()方法中使用到了ThreadLocalMap类。

展开查看ThreadLocalMap类源码
static class ThreadLocalMap {

/**
* The entries in this hash map extend WeakReference, using
* its main ref field as the key (which is always a
* ThreadLocal object). Note that null keys (i.e. entry.get()
* == null) mean that the key is no longer referenced, so the
* entry can be expunged from table. Such entries are referred to
* as "stale entries" in the code that follows.
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

/**
* The initial capacity -- MUST be a power of two.
*/
private static final int INITIAL_CAPACITY = 16;

/**
* The table, resized as necessary.
* table.length MUST always be a power of two.
*/
private Entry[] table;

/**
* The number of entries in the table.
*/
private int size = 0;

/**
* The next size value at which to resize.
*/
private int threshold; // Default to 0

/**
* Set the resize threshold to maintain at worst a 2/3 load factor.
*/
private void setThreshold(int len) {
threshold = len * 2 / 3;
}

/**
* Increment i modulo len.
*/
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

/**
* Decrement i modulo len.
*/
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

/**
* Construct a new map initially containing (firstKey, firstValue).
* ThreadLocalMaps are constructed lazily, so we only create
* one when we have at least one entry to put in it.
*/
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}

/**
* Construct a new map including all Inheritable ThreadLocals
* from given parent map. Called only by createInheritedMap.
*
* @param parentMap the map associated with parent thread.
*/
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

/**
* Get the entry associated with key. This method
* itself handles only the fast path: a direct hit of existing
* key. It otherwise relays to getEntryAfterMiss. This is
* designed to maximize performance for direct hits, in part
* by making this method readily inlinable.
*
* @param key the thread local object
* @return the entry associated with key, or null if no such
*/
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}

/**
* Version of getEntry method for use when key is not found in
* its direct hash slot.
*
* @param key the thread local object
* @param i the table index for key's hash code
* @param e the entry at table[i]
* @return the entry associated with key, or null if no such
*/
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

/**
* Set the value associated with key.
*
* @param key the thread local object
* @param value the value to be set
*/
private void set(ThreadLocal<?> key, Object value) {

// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.

Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

if (k == key) {
e.value = value;
return;
}

if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}

tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

/**
* Remove the entry for key.
*/
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}

/**
* Replace a stale entry encountered during a set operation
* with an entry for the specified key. The value passed in
* the value parameter is stored in the entry, whether or not
* an entry already exists for the specified key.
*
* As a side effect, this method expunges all stale entries in the
* "run" containing the stale entry. (A run is a sequence of entries
* between two null slots.)
*
* @param key the key
* @param value the value to be associated with key
* @param staleSlot index of the first stale entry encountered while
* searching for key.
*/
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

// Back up to check for prior stale entry in current run.
// We clean out whole runs at a time to avoid continual
// incremental rehashing due to garbage collector freeing
// up refs in bunches (i.e., whenever the collector runs).
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

// Find either the key or trailing null slot of run, whichever
// occurs first
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

// If we find key, then we need to swap it
// with the stale entry to maintain hash table order.
// The newly stale slot, or any other stale slot
// encountered above it, can then be sent to expungeStaleEntry
// to remove or rehash all of the other entries in run.
if (k == key) {
e.value = value;

tab[i] = tab[staleSlot];
tab[staleSlot] = e;

// Start expunge at preceding stale entry if it exists
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// If we didn't find stale entry on backward scan, the
// first stale entry seen while scanning for key is the
// first still present in the run.
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

// If there are any other stale entries in run, expunge them
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

/**
* Expunge a stale entry by rehashing any possibly colliding entries
* lying between staleSlot and the next null slot. This also expunges
* any other stale entries encountered before the trailing null. See
* Knuth, Section 6.4
*
* @param staleSlot index of slot known to have null key
* @return the index of the next null slot after staleSlot
* (all between staleSlot and this slot will have been checked
* for expunging).
*/
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// expunge entry at staleSlot
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash until we encounter null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;

// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

/**
* Heuristically scan some cells looking for stale entries.
* This is invoked when either a new element is added, or
* another stale one has been expunged. It performs a
* logarithmic number of scans, as a balance between no
* scanning (fast but retains garbage) and a number of scans
* proportional to number of elements, that would find all
* garbage but would cause some insertions to take O(n) time.
*
* @param i a position known NOT to hold a stale entry. The
* scan starts at the element after i.
*
* @param n scan control: {@code log2(n)} cells are scanned,
* unless a stale entry is found, in which case
* {@code log2(table.length)-1} additional cells are scanned.
* When called from insertions, this parameter is the number
* of elements, but when from replaceStaleEntry, it is the
* table length. (Note: all this could be changed to be either
* more or less aggressive by weighting n instead of just
* using straight log n. But this version is simple, fast, and
* seems to work well.)
*
* @return true if any stale entries have been removed.
*/
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

/**
* Re-pack and/or re-size the table. First scan the entire
* table removing stale entries. If this doesn't sufficiently
* shrink the size of the table, double the table size.
*/
private void rehash() {
expungeStaleEntries();

// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}

/**
* Double the capacity of the table.
*/
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

setThreshold(newLen);
size = count;
table = newTab;
}

/**
* Expunge all stale entries in the table.
*/
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
}

其中维护了一个entry结构用来用来维护节点的数据,细心地同学应该已经发现了Entry这个结构继承了WeakReference,从构造方法可以看出,ThreadLocalMapKey是软引用维护的。

ThreadLocal.ThreadLocalMap threadLocals = null;

上面这行代码是一个成员变量:*针对于每一个线程,都是独立维护一个ThreadLocalMap,一个线程也可以拥有多个ThreadLocal变量。*

get()方法

get()
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
setInitialValue()
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
initialValue()
protected T initialValue() {
return null;
}

get()方法整体上比较简单,贴上了关键逻辑逻辑代码,调用get()时,如果存在值,则将值返回,不存在值调用setInitialValue()获取值,其中初始化的值为null,也就是说如果ThreadLocal变量未被赋值,或者赋值后被remove掉了,直接调用get()方法不会报错,将会返回null值。

remove()方法

ThreadLocal#remove()
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
ThreadLocalMap#remove
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}

remove方法调用时会判断当前线程中ThreadLocalMap是否存在,如果存在则调用ThreadLocalMap.remove(key);遍历链表结构移除entry节点。

小结

  • Thread维护着一个ThreadLocalMap的引用,其中ThreadLocalMapkeyWeakReference维护。
  • ThreadLocal本身并不存储值,ThreadLocal通过操作ThreadLocalMap达到对线程变量的赋值,获取,删除操作。

ThreadLocal内存泄漏问题

JVM中对ThreadLocal的堆栈维护图:

image-20210426152220803.png

(图片来源见水印)

entry对于value的引用为强应用,key的引用为弱引用。

如果一个对象只是被弱引用引用者,那么只要发生 GC,不管内存空间是否足够,都会回收该对象。

那么问题就来了,如果操作ThreadLocal变量的方法QPS很高,疯狂被请求,这个时候调用了set(),get()方法,并未调用remove方法,那么,当GC发生。entryThreadLocal的关联关系中断,Key被回收,value还被强连接关联着。这样跟垃圾回收可达性分析,value仍旧为可达,但是从业务角度上看,这个value值将永远访问不到,出现了内存泄露。

因此在使用ThreadLocal时必须要显示的调用remove方法,否则出现了问题,排查起来都很麻烦。

  • 可以在拦截器、AOP、过滤器结束的时候调用remove方法

线程池中线程上下文丢失

ThreadLocal不能在父子线程中传递,因此最常见的做法是把父线程中的ThreadLocal值拷贝到子线程中,因此大家会经常看到类似下面的这段代码:

for(value in valueList){
Future<?> taskResult = threadPool.submit(new BizTask(ContextHolder.get()));//提交任务,并设置拷贝Context到子线程
results.add(taskResult);
}
for(result in results){
result.get();//阻塞等待任务执行完成
}

提交的任务定义长这样:

class BizTask<T> implements Callable<T>  {
private String session = null;

public BizTask(String session) {
this.session = session;
}

@Override
public T call(){
try {
ContextHolder.set(this.session);
// 执行业务逻辑
} catch(Exception e){
//log error
} finally {
ContextHolder.remove(); // 清理 ThreadLocal 的上下文,避免线程复用时context互串
}
return null;
}
}

对应的线程上下文管理类:

class ContextHolder {
private static ThreadLocal<String> localThreadCache = new ThreadLocal<>();

public static void set(String cacheValue) {
localThreadCache.set(cacheValue);
}

public static String get() {
return localThreadCache.get();
}

public static void remove() {
localThreadCache.remove();
}

}

线程池的设置:

ThreadPoolExecutor executorPool 
= new ThreadPoolExecutor(20, 40, 30, TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(40),
new XXXThreadFactory(), ThreadPoolExecutor.CallerRunsPolicy);

其中最后一个参数控制着当线程池满时,该如何处理提交的任务,内置有4种策略:

ThreadPoolExecutor.AbortPolicy //直接抛出异常
ThreadPoolExecutor.DiscardPolicy //丢弃当前任务
ThreadPoolExecutor.DiscardOldestPolicy //丢弃工作队列头部的任务
ThreadPoolExecutor.CallerRunsPolicy //转串行执行

可以看到,我们初始化线程池的时候指定如果线程池满,则新提交的任务转为串行执行,那我们之前的写法就会有问题了,串行执行的时候调用ContextHolder.remove();会将主线程的上下文也清理,即使后面线程池继续并行工作,传给子线程的上下文也已经是null了,而且这样的问题很难在预发测试的时候发现。

并行流中线程上下文丢失

如果ThreadLocal碰到并行流,也会有很多有意思的事情发生,比如有下面的代码:

class ParallelProcessor<T> {

public void process(List<T> dataList) {
// 先校验参数,篇幅限制先省略不写
dataList.parallelStream().forEach(entry -> {
doIt();
});
}

private void doIt() {
String session = ContextHolder.get();
// do something
}
}

这段代码很容易在线下测试的过程中发现不能按照预期工作,因为并行流底层的实现也是一个ForkJoin线程池,既然是线程池,那ContextHolder.get()可能取出来的就是一个null。我们顺着这个思路把代码再改一下:\

class ParallelProcessor<T> {

private String session;

public ParallelProcessor(String session) {
this.session = session;
}

public void process(List<T> dataList) {
// 先校验参数,篇幅限制先省略不写
dataList.parallelStream().forEach(entry -> {
try {
ContextHolder.set(session);
// 业务处理
doIt();
} catch (Exception e) {
// log it
} finally {
ContextHolder.remove();
}
});
}

private void doIt() {
String session = ContextHolder.get();
// do something
}
}

修改完后的这段代码可以工作吗?如果运气好,你会发现这样改又有问题,运气不好,这段代码在线下运行良好,这段代码就顺利上线了。不久你就会发现系统中会有一些其他很诡异的bug。原因在于并行流的设计比较特殊,父线程也有可能参与到并行流线程池的调度,那如果上面的process方法被父线程执行,那么父线程的上下文会被清理。导致后续拷贝到子线程的上下文都为null,同样产生丢失上下文的问题。