ThreadLocal源码分析

发布于:2025-02-15 ⋅ 阅读:(17) ⋅ 点赞:(0)

1.核心数据结构 ThreadLocalMap

1.静态内部类 Entry
// 继承了弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
    // value就是平常定义ThreadLocal中存的那个东西
    Object value;
   
    // key就是ThreadLocal(是弱引用)
    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}
2.真正存储数据的是table数组

CleanShot 2025-02-02 at 11.57.50@2x

2.ThreadLocal.set()方法源码详解

1.set
    public void set(T value) {
        // 获取当前线程
        Thread t = Thread.currentThread();
        // 获取当前线程的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        // 如果map不为空,就将当前的ThreadLocal作为key,ThreadLocal存的值作为value设置到map中
        if (map != null) {
            map.set(this, value);
        } else {
            // 如果map为空,就使用当前线程和ThreadLocal存的值去创建一个map
            createMap(t, value);
        }
    }
2.getMap
    ThreadLocalMap getMap(Thread t) {
        // 返回当前线程的ThreadLocalMap
        return t.threadLocals;
    }

CleanShot 2025-02-02 at 12.10.06@2x

3.ThreadLocalMap.set
private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    // 存储数据的Entry数组
    Entry[] tab = table;
    // 数组的长度
    int len = tab.length;
    // 计算当前的ThreadLocal对象在Entry数组中的位置,其实跟hashmap一样,都是(n - 1) & hash
    int i = key.threadLocalHashCode & (len-1);

    // 只要目标位置不为空,就进行循环
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
				// 期间如果遇到key相同的,则替换value
        if (k == key) {
            e.value = value;
            return;
        }
				// 如果遇到key为null的,就会执行替换过期数据的逻辑
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
		
 		// 如果目标位置为空,就直接将这个Entry对象放进去
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 如果清理过期槽位失败并且元素数量大于等于阈值(数组长度乘2/3)就进行rehash逻辑
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}
4.createMap
void createMap(Thread t, T firstValue) {
    // 使用ThreadLocal对象和value去创建一个map
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    // table的初始容量为16
    table = new Entry[INITIAL_CAPACITY];
    // 使用(n - 1) & hash去找到指定的table数组的位置
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 给指定位置的table数组设置一个Entry对象
    table[i] = new Entry(firstKey, firstValue);
    // 设置数组大小为1
    size = 1;
    // 设置阈值
    setThreshold(INITIAL_CAPACITY);
}
private void setThreshold(int len) {
    // 阈值为16 * 2/3
    threshold = len * 2 / 3;
}
5.rehash
private void rehash() {
    // 进行探测式清理工作
    expungeStaleEntries();

    // 如果清理后 size >= threshold * 3/4 就进行扩容
    if (size >= threshold - threshold / 4)
        resize();
}
6.resize
private void resize() {
    // 旧的Entry数组
    Entry[] oldTab = table;
    // 旧的容量
    int oldLen = oldTab.length;
    // 新的长度为旧长度的两倍
    int newLen = oldLen * 2;
    // 创建一个新的Entry数组
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    // 遍历旧的数组,将旧数组中的元素重新hash后放到新数组中,如果哈希冲突就往后放
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

3.ThreadLocalMap.get()详解

1.get
public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程的ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    // 只要map不为空
    if (map != null) {
        // 通过当前的ThreadLocal作为key去获取到对应的ThreadLocalMap.Entry
        ThreadLocalMap.Entry e = map.getEntry(this);
        // 只要不为空就返回这个value
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // 如果map为空就返回一个初始值null
    return setInitialValue();
}
2.ThreadLocalMap.getEntry
private Entry getEntry(ThreadLocal<?> key) {
    // (n - 1) & hash 找到table数组的位置
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    // 如果key相同,直接返回
    if (e != null && e.get() == key)
        return e;
    else
   		  // 如果key不同,就向后找,如果最后找到了就返回,找不到就返回null
        return getEntryAfterMiss(key, i, e);
}
3.getEntryAfterMiss
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;
					
            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

4.ThreadLocalMap过期 key 的探测式清理流程

遍历散列数组,从开始位置向后探测清理过期数据,将过期数据的Entry设置为null

沿途中碰到未过期的数据则将此数据rehash后重新在table数组中定位,如果定位的位置已经有了数据,则会将未过期的数据放到最靠近此位置的Entry=null的桶中,使rehash后的Entry数据距离正确的桶的位置更近一些。

5.InheritableThreadLocal原理

实现原理是子线程是通过在父线程中通过调用new Thread()方法来创建子线程,Thread#init方法在Thread的构造方法中被调用。在init方法中拷贝父线程数据到子线程中:

private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    this.stackSize = stackSize;
    tid = nextThreadID();
}