十、Map 子接口之 ConcurrentHashMap

ConcurrentHashMap 底层是基于
数组 +链表组成的,JDK 1.7 和 JDK 1.8 中具体的实现稍有不同。

10.1 JDK 1.7 中的 ConcurrentHashMap

01、JDK 1.7 中 ConcurrentHashMap 底层结构

JDK 1.7 的 ConcurrentHashMap 类所采用的是分段锁的思想,将 HashMap 进行切割,把 HashMap 中的哈希数组切分成小数组,每个小数组由 n 个 HashEntry 组成,其中小数组(Segment)继承自 ReentrantLock(可重入锁)。

JDK 1.7 中的数据结构:
在这里插入图片描述

如图所示,是由 Segment 数组、HashEntry 组成,与 HashMap 一样,仍然是数组 + 链表

ConcurrentHashMap 由很多个 Segment 组合,而每一个 Segment 是一个类似于 HashMap 的结构(子哈希表)。所以每个 HashMap 的内部都可以进行扩容。但是 Segment 的个数是 16 个,也可以认为 ConcurrentHashMap 默认支持最多 16 个线程并发。其中,Segment 里维护了一个 HashEntry 数组,Segment 继承自 ReentrantLock,并发环境下,对于不同的 Segment 数据进行操作是不用考虑锁竞争的,因此不会像 Hashtable 那样不管是添加、删除、查询操作都需要同步处理。

扒一下 ConcurrentHashMap 类:

public class ConcurrentHashMap<K,V> extends AbstractMap<K,V>
    implements ConcurrentMap<K,V>, Serializable {
    /**
     * The segments, each of which is a specialized hash table.
     */
    final Segment<K, V>[] segments;
    ...
}

Segment 类是 ConcurrentHashMap 的一个静态内部类,内部结构跟 HashMap 差不多:

static final class Segment<K,V> extends ReentrantLock implements Serializable {
	private static final long serialVersionUID = 2249069246763182397L;
	// 1. 和 HashMap 中的 HashEntry 作⽤⼀样,是真正存放数据的数组位桶
	transient volatile HashEntry<K,V>[] table;
	// 2. table 的数组容量
	transient int count;
	// 3. 记录修改次数的变量
	transient int modCount;
	// 4. 阈值⼤⼩,所能容纳的元素极限,用于扩容判断
	transient int threshold;
	// 5. 负载因⼦,用于扩容
	final float loadFactor;
}

存放元素的 HashEntry,也是一个静态内部类:

static final class HashEntry<K, V> {
	// hash 值
	final int hash;
	// 键
	final K key;
	// 值
	volatile V value;
	// 下一个节点
	volatile HashEntry<K, V> next;

	HashEntry(int hash, K key, V value, HashEntry<K, V> next) {
		this.hash = hash;
		this.key = key;
		this.value = value;
		this.next = next;
	}
}

HashEntry 和 HashMap 中的 Entry 非常类似,唯一的区别就是其中的核心数据,比如 value、next 都使用了 修饰,这就保证了多线程环境下数据获取时的可见性。

volatile 关键字的特性:

  1. 保证了不同线程对这个变量进行操作时的可见性,即一个线程修改了某个变量的值,这个新的值对其他线程来说是可见的(实时可见性);
  2. 禁止进行指令重排序(实现有序性);
  3. volatile 只能保证对单次读写的原子性,像 i++ 这种操作不能保证原子性。

从类的定义上可以看到,Segment 这个静态内部类继承了 ReetrantLock 类

02、ConcurrentHashMap 的常量

// 初始初始容量
static final int DEFAULT_INITIAL_CAPACITY = 16;
// 默认负载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;
// 初始的并发等级
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// 最大容量
static final int MAXIMUM_CAPACITY = 1 << 30;
// segment最小容量
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
// 一个segment最大容量
static final int MAX_SEGMENTS = 1 << 16; 
// 锁之前重试次数
static final int RETRIES_BEFORE_LOCK = 2;

03、初始化

public ConcurrentHashMap() {
	// DEFAULT_INITIAL_CAPACITY 表示初始化容量,默认为 16
	// DEFAULT_LOAD_FACTOR 表示负载因子,默认为 0.75
	// DEFAULT_CONCURRENCY_LEVEL 表示 Segment[] 初始并发等级,默认为 16
    this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
 
public ConcurrentHashMap(int initialCapacity) {
    this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
 
public ConcurrentHashMap(int initialCapacity, float loadFactor) {
    this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
}

this 调用对应的构造方法:

// 通过指定的容量、负载因子和并发等级创建一个新的ConcurrentHashMap
public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
     // 1. 参数校验:对容量、负载因子和并发等级做限制
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    // 2. MAX_SEGMENTS = 1 << 16 = 65536,限制并发等级不可以大于最大等级,如果并发量大于 65536,concurrencyLevel = 65536
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
    // 下面即通过并发等级来确定Segment的大小
    // 3. sshift用来记录向左按位移动的次数,2的多少次方
    int sshift = 0;
    // 4. ssize用来记录Segment数组的大小
    int ssize = 1;
    // Segment的大小为大于等于concurrencyLevel的第一个2的n次方的数(也就是concurrencyLevel之上最近的2的次方值)
    while (ssize < concurrencyLevel) {
    	// 新增移动次数,直到 ssize >= concurrencyLevel 为止,concurrencyLevel 为 16,循环之后 ssize = 16
        ++sshift;
        // 位移动,向左移动 1
        ssize <<= 1;
    }
    
	// 5. segmentShift、segmentMask 用于元素在Segment[]数组的定位
	// 记录段偏移量
    this.segmentShift = 32 - sshift;
    // segmentMask的值等于ssize - 1(这个值很重要)
    // 记录段掩码
    this.segmentMask = ssize - 1;
    
    // 6. 传入初始化的容量值大于最大容量值,则默认为最大容量值
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    // 7. c记录每个Segment上要放置多少个元素,即Segment中HashEntry的数组长度,c也一定为 2 的 n 次方,这里的计算类似于 HashMap 的容量
    int c = initialCapacity / ssize;
    // 假如有余数,则Segment数量加1
    if (c * ssize < initialCapacity)
        ++c;
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    // Segment中的类似于HashMap的容量,至少是2或者2的倍数
    while (cap < c)
        cap <<= 1;

    // 8. 创建第一个Segment对象,并放入Segment[]数组中,作为第一个Segment,默认数组长度为 2
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]);
    // 9. 创建Segment[],指定segment数组的长度,默认数组长度为 16
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];

	// 10. 使用CAS方式,将上面创建的segment对象放入segment[]数组中
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    this.segments = ss;
}

总结一下在 JDK 1.7 中 ConcurrnetHashMap 的初始化逻辑。

  1. 必要参数校验;
  2. 校验并发级别 concurrencyLevel 大小,如果大于最大值,重置为最大值。无参构造默认值是 16;
  3. 寻找并发级别 concurrencyLevel 之上最近的 2 的幂次方值,作为初始化容量大小,默认是 16;
  4. 记录 segmentShift 偏移量,这个值为【容量 = 2 的N次方】中的 N,在后面 Put 时计算位置时会用到。默认是 32 - sshift = 28;
  5. 记录 segmentMask,默认是 ssize - 1 = 16 -1 = 15;
  6. 初始化 segments[0],默认大小为 2,负载因子 0.75,扩容阀值是 2*0.75=1.5,插入第二个值时才会进行扩容。

从源码上可以看出,ConcurrentHashMap 初始化方法有三个参数:initialCapacity(初始化容量)为16、loadFactor(负载因子)为0.75、concurrentLevel(并发等级)为16,如果不指定则会使用默认值。

其中,值得注意的是 concurrentLevel 这个参数,虽然 Segment 数组大小 ssize 是由 concurrentLevel 来决定的,但是却不一定等于 concurrentLevel,ssize 通过位移动运算,一定是大于或者等于 concurrentLevel 的最小的 2 的次幂!

通过计算可以看出,按默认的 initialCapacity 初始容量为16,concurrentLevel 并发等级为16,理论上就允许 16 个线程并发执行,并且每一个线程独占一把锁访问 Segment,不影响其它的 Segment 操作。

从之前的文章中,我们了解到 HashMap 在多线程环境下操作可能会导致程序死循环,仔细想想就会发现,造成这个问题无非是 put 和扩容阶段发生的,所以在 ConcurrentHashMap 中对 put 操作做了一些改变。

04、put 操作

扒一下 put() 方法的源码:

public V put(K key, V value) {
    Segment<K,V> s;
    // 1. ConcurrentHashMap中key和value都不能为null
    if (value == null)
        throw new NullPointerException();

    // 2. 计算key的哈希值
    int hash = hash(key);
    // 3. 通过key的哈希值,定位ConcurrentHashMap中Segment[]的角标
    // hash 值无符号右移28位(初始化时获得),然后与segmentMask=15做与运算
    int j = (hash >>> segmentShift) & segmentMask;

	// 4. 使用CAS的方式,从Segment[]中获取j角标下的Segment对象,并判断是否存在
    if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
         (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
        // 如果查到的Segment为空,初始化
        s = ensureSegment(j);
    // 5. 底层使用了Segment对象的put()方法
    return s.put(key, hash, value, false);
}

从源码中可以看出,put 操作主要分为两步:

  1. 获取要 put 的 key 的位置,获取指定位置的 Segment;
  2. 如果指定位置的 Segment 为空,则初始化这个 Segment;
  3. 调用 Segment 的 put() 方法。

扒一下初始化 Segment() 方法的源码:

@SuppressWarnings("unchecked")
private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    // 判断 u 位置的 Segment 是否为null
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        Segment<K,V> proto = ss[0]; // use segment 0 as prototype
        // 获取0号 segment 里的 HashEntry<K,V> 初始化长度
        int cap = proto.table.length;
        // 获取0号 segment 里的 hash 表里的扩容负载因子,所有的 segment 的 loadFactor 是相同的
        float lf = proto.loadFactor;
        // 计算扩容阀值
        int threshold = (int)(cap * lf);
        // 创建一个 cap 容量的 HashEntry 数组
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // recheck
            // 再次检查 u 位置的 Segment 是否为null,因为这时可能有其他线程进行了操作
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            // 自旋检查 u 位置的 Segment 是否为null
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                // 使用CAS 赋值,只会成功一次
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}

初始化 Segment 流程:

  1. 检查计算得到的位置的 Segment 是否为 null;
  2. 为 null 继续初始化,使用 Segment[0] 的容量和负载因子创建一个 HashEntry 数组;
  3. 再次检查计算得到的指定位置的 Segment 是否为 null;
  4. 使用创建的 HashEntry 数组初始化这个 Segment;
  5. 自旋判断计算得到的指定位置的 Segment 是否为null,使用 CAS 在这个位置赋值为 Segment。

真正插入元素的 put() 方法:

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
	// 1. tryLock是ReentrantLock类中的方法,表示尝试从CPU手上获取到锁。如果锁没有被另外一个线程持有,获得锁并返回true;否则返回false,并调用scanAndLockForPut()方法
    // 这里是并发的关键,每一个Segment进行put时,都会加锁
    HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        // 2. 获取Segment对象中的HashEntry[]数组
        HashEntry<K,V>[] tab = table;
        // 3. 确定key的hash值所在HashEntry数组的索引位置
        int index = (tab.length - 1) & hash;
        // 4. 根据索引获取HashEntry对象
        HashEntry<K,V> first = entryAt(tab, index);
        // 5. 遍历当前HashEntry链
        for (HashEntry<K,V> e = first;;) {
        	// 判断逻辑与HashMap相似
            // 如果链头不为null
            if (e != null) {
                K k;
                // 如果在该链中找到相同的key,则用新值替换旧值,并退出循环
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    if (!onlyIfAbsent) {
                        e.value = value;
                        ++modCount;
                    }
                    break;
                }
                // 如果没有和key相同的,一直遍历到链尾,链尾的next为null,进入到else
                e = e.next;
            }
            else {
            	// 如果key不存在,则把当前Entry插入到链头(头插法)
                if (node != null)
                    node.setNext(first);
                else
                    node = new HashEntry<K,V>(hash, key, value, first);
                // 此时数量+1
                int c = count + 1;
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                    // 需要注意的地方:如果超出了HashEntry的阈值,就要对HashEntry[]进行扩容
                    rehash(node);
                else
                    setEntryAt(tab, index, node);
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        // 6. 操作完成后,释放对象锁
        unlock();
    }
    return oldValue;
}

从 put() 源码中可以看出,真正的 put 操作主要分为以下几步:

  1. 尝试获取对象锁,如果获取到就返回 true,否则执行 scanAndLockForPut() 方法,这个方法也是尝试获取对象锁;

  2. 获取到锁之后,类似于 hashMap 的 put() 方法,通过 key 计算所在 HashEntry 数组的下标,然后获取这个位置上的 HashEntry;

  3. 获取到数组下标之后遍历链表内容,为什么要遍历?因为这里获取的 HashEntry 可能是一个空元素,也可能是链表已存在,所以要区别对待。

    ① 如果这个位置上的 HashEntry 不存在:

    a. 如果当前容量大于扩容阀值,小于最大容量,进行扩容。
    b. 直接头插法插入。
    ② 如果这个位置上的 HashEntry 存在:

    a. 判断链表当前元素 key 和 hash 值是否和要 put 的 key 和 hash 值一致。一致则替换值;
    b. 不一致,获取链表下一个节点,直到发现相同进行值替换,或者链表表里没有相同的:

     a) 如果当前容量大于扩容阀值,小于最大容量,进行扩容。
     b) 直接链表头插法插入。
    
  4. 如果要插入的位置之前已经存在,替换后返回旧值,否则返回 null;

  5. 最后操作完整之后,释放对象锁;

再来扒一下 scanAndLockForPut() 方法的源码:

private HashEntry<K, V> scanAndLockForPut(K key, int hash, V value) {
	// 1. 定位HashEntry数组位置,获取第一个节点
	HashEntry<K,V> first = entryForHash(this, hash); 
	HashEntry<K,V> e = first;
	HashEntry<K,V> node = null;
	// 2. 重试次数
	int retries = -1; // negative while locating node
	// 自旋获取锁
	while (!tryLock()) {
		HashEntry<K,V> f; // to recheck first below
		if (retries < 0)
			if(e == null)
				if (node == null) // speculatively create node
					// 3. 构造新节点
					node = new HashEntry<K, V> (hash, key, value, null);
				retries = 0;
			}
			else if (key. equals(e. key))
				retries = 0;
			else
				e = e.next;
		}
		else if (++retries > MAX_ SCAN _RETRIES) {
			// 4. 重试次数+1,如果大于最大次数,调用lock()方法获取锁,如果没有获取当前线程就被阻塞,直到获取并跳出循环
			Lock();
			break;
		}
		else if ((retries & 1) == 0 &&
				(f = entryForHash(this, hash)) != first) {
			// 5. HashEntry存储内容发生变化,重置重试次数
			e = first = f; // re-traverse if entry changed
			retries = -1;
		}
	}
	return node;
}	

scanAndLockForPut() 方法的操作也是分以下几步:

  1. 当前线程尝试去获得锁,查找 key 是否已经存在,如果不存在,就创建一个HashEntry 对象;
  2. 如果重试次数大于最大次数,就调用 lock() 方法获取对象锁,如果依然没有获取到,当前线程就阻塞,直到获取之后退出循环;
  3. 在这个过程中,key 可能被别的线程给插入,所以在第 5 步中,如果 HashEntry 存储内容发生变化,重置重试次数;

通过 scanAndLockForPut() 方法,当前线程就可以在即使获取不到 segment 锁的情况下,完成需要添加节点的实例化工作,当获取锁后,就可以直接将该节点插入链表即可。

这个方法还实现了类似于自旋锁的功能,循环式的判断对象锁是否能够被成功获取,直到获取到锁才会退出循环,防止执行 put 操作的线程频繁阻塞,这些优化都提升了 put 操作的性能。

05、get 操作

get() 方法因为不涉及增、删、改操作,所以不存在并发故障问题。

扒一下 get() 方法的源码:

public V get(Object key) {
	Segment<K,V> s; // manually integrate access methods to reduce overhead
	HashEntry<K,V>[] tab;
	
	// 1. 计算key的hash值
	int h = hash(key); 
	// 2. 计算该hash值所属的Segment[]的角标
	long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE ;
	// 3. 获取Segment[]中u角标下的Segment对象,不存在直接返回
	if ((s = (Segment<K, V> )UNSAFE. getObjectVolatile(segments, u)) != null &&
		(tab = s.table) != null) {
		// 4. 再根据hash值,从Segment对象中的HashEntry[]获取HashEntry对象,并对HashEntry对象进行链表遍历
		for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE. getObjectVolatile
			(tab, ((Long)(((tab. Length - 1) & h)) << TSHIFT) + TBASE);
		e != null; e = e.next) {
			K k;
			// 5. 在链表中找到对应元素,并返回
			if ((k = e.key) === key || (e.hash == h && key.equals(k)))
				return e. value;
		}
	}
	return null;
}

get() 方法只需要两步就可以实现:

  1. 计算得到 key 的存放位置;
  2. 遍历指定位置查找相同 key 的 value 值。

由于 HashEntry 涉及到的共享变量都使用 volatile 修饰,volatile 可以保证内存可见性,所以不会读取到过期数据。

06、remove 操作

remove() 操作和 put() 方法差不多,都需要获取锁对象才能操作,通过 key 找到元素所在的 Segment 对象,然后移除即可。

扒一下 remove() 方法的源码:

public V remove(0bject key) {
	// 1. 计算key的hash值
	int hash = hash( key);
	// 2. 计算该hash值所属的Segment[]的角标,再通过角标下的Segment对象
	Segment<K,V> S = segmentForHash(hash); 
	// 3. 执行移除方法
	return s == null ? null : s. remove(key, hash, null);
}

与 get() 方法类似,都是先获取 Segment 数组所在的 Segment 对象,然后再调用 Segment 对象的 remove() 方法区移除。

扒一下 Segment 对象的 remove() 方法:

final V remove(0bject key, int hash, 0bject value) {
	// 1. 尝试获取对象锁
	if (!tryLock())
		scanAndLock(key, hash);
	V oLdValue = null;
	try {
		HashEntry<K,V>[] tab = table; 
		// 2. 计算key的hash值在HashEntry[]中的角标
		int index = (tab. length - 1) & hash;
		// 3. 根据index角标获取HashEntry对象
		HashEntry<K,V> e = entryAt (tab, index);
		HashEntry<K,V> pred = null;
		// 4. 循环遍历HashEntry对象,HashEntry为单向链表结构
		while (e != null) {
			K k;
			HashEntry<K,V> next = e. next;
			// 5. 通过key和hash判断key是否存在
			if ((k = e.key) == key |
				(e.hash == hash && key. equals(k))) {
				V v = e.value;
				// 6. 移除元素,将下节点向上移
				if (value == null|| value == v H value. equals(v)) {
					if (pred == null)
						setEntryAt(tab, index, next);
					else
						pred. setNext(next);
					++modCount ;
					count;
					oldValue = V;
				}
				break;
			}
			pred = e
			e = next;
		}
	} finally {
	// 7. 释放锁
	unLock(); 
	return oldValue;
}

先获取对象锁,如果获取到之后执行移除操作,之后的操作类似于 HashMap 的移除方法,步骤如下:

  1. 先获取对象锁;
  2. 计算 key 的 hash 值在 HashEntry[] 中的角标;
  3. 根据 index 角标获取 HashEntry 对象;
  4. 循环遍历 HashEntry 对象,HashEntry 为单向链表结构;
  5. 通过 key 和 hash 判断 key 是否存在,如果存在,就移除元素,并将需要移除的元素节点的下一个向上移动;
  6. 最后就是释放对象锁,以便其他线程使用。

07、扩容 rehash

ConcurrentHashMap 的扩容只会扩容到原来的两倍。老数组里的数据移动到新的数组时,位置要么不变,要么变为 index+ oldSize,参数里的 node 会在扩容之后使用链表头插法插入到指定位置。

扒一下 rehash() 方法的源码:

private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    // 老容量
    int oldCapacity = oldTable.length;
    // 新容量,扩大两倍
    int newCapacity = oldCapacity << 1;
    // 新的扩容阀值 
    threshold = (int)(newCapacity * loadFactor);
    // 创建新的数组
    HashEntry<K,V>[] newTable = (HashEntry<K,V>[]) new HashEntry[newCapacity];
    // 新的掩码,默认2扩容后是4,-1是3,二进制就是11。
    int sizeMask = newCapacity - 1;
    for (int i = 0; i < oldCapacity ; i++) {
        // 遍历老数组
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            // 计算新的位置,新的位置只可能是不变或者是老的位置+老的容量。
            int idx = e.hash & sizeMask;
            if (next == null)   //  Single node on list
                // 如果当前位置还不是链表,只是一个元素,直接赋值
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
                // 如果是链表了
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                // 新的位置只可能是不变或者是老的位置+老的容量。
                // 遍历结束后,lastRun 后面的元素位置都是相同的
                for (HashEntry<K,V> last = next; last != null; last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                // lastRun 后面的元素位置都是相同的,直接作为链表赋值到新位置。
                newTable[lastIdx] = lastRun;
                // Clone remaining nodes
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    // 遍历剩余元素,头插法到指定 k 位置。
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    // 头插法插入新的节点
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

这里有两个 for 循环:第一个 for 是为了寻找这样一个节点,这个节点后面的所有 next 节点的新位置都是相同的。然后把这个作为一个链表赋值到新位置。第二个 for 循环是为了把剩余的元素通过头插法插入到指定位置链表。这样实现的原因可能是基于概率统计。

10.2 JDK 1.8 中的 ConcurrentHashMap

01、ConcurrentHashMap 的关键属性

  1. tablevolatile Node<K,V>[] table:

    装载 Node 的数组,作为 ConcurrentHashMap 的数据容器,采用懒加载的方式,直到第一次插入数据的时候才会进行初始化操作,数组的大小总是为 2 的幂次方。

  2. nextTablevolatile Node<K,V>[] nextTable;

    扩容时使用,平时为 null,只有在扩容的时候才为非 null。

  3. sizeCtlvolatile int sizeCtl;

    该属性用来控制 table 数组的大小,根据是否初始化和是否正在扩容有几种情况:

    • 当值为负数时: 如果为 -1 表示正在初始化,如果为 -N 则表示当前正有 N-1 个线程进行扩容操作;
    • 当值为正数时: 如果当前数组为 null 的话表示 table 在初始化过程中,sizeCtl 表示为需要新建数组的长度;
    • 若已经初始化了,表示当前数据容器(table 数组)可用容量也可以理解成临界值(插入节点数超过了该临界值就需要扩容),具体指为数组的长度 n 乘以加载因子 loadFactor;
    • 当值为0时,即数组长度为默认初始值。
  4. sun.misc.Unsafe U

    在 ConcurrentHashMapde 的实现中可以看到大量的 U.compareAndSwapXXXX 的方法去修改 ConcurrentHashMap 的一些属性。

    这些方法实际上是利用了 CAS 算法保证了线程安全性,这是一种乐观策略,假设每一次操作都不会产生冲突,当且仅当冲突发生的时候再去尝试。

    而 CAS 操作依赖于现代处理器指令集,通过底层 CMPXCHG 指令实现。CAS(V,O,N) 核心思想为:若当前变量实际值 V 与期望的旧值 O 相同,则表明该变量没被其他线程进行修改,因此可以安全的将新值 N 赋值给变量;若当前变量实际值 V 与期望的旧值 O 不相同,则表明该变量已经被其他线程做了处理,此时将新值 N 赋给变量操作就是不安全的,再进行重试。

    而在大量的同步组件和并发容器的实现中使用 CAS 是通过 sun.misc.Unsafe 类实现的,该类提供了一些可以直接操控内存和线程的底层操作,可以理解为 java 中的“指针”。该成员变量的获取是在静态代码块中:

    static {
        try {
            U = sun.misc.Unsafe.getUnsafe();
    		.......
        } catch (Exception e) {
            throw new Error(e);
        }
    }
    

02、ConcurrentHashMap 中关键内部类

  1. Node

    Node 实现了 Map.Entry 接口,主要存放 key-value 键值对,并且具有 next 域:

    static class Node<K,V> implements Map.Entry<K,V> {
            final int hash;
            final K key;
            volatile V val;
            volatile Node<K,V> next;
    
            Node(int hash, K key, V val, Node<K,V> next) {
                this.hash = hash;
                this.key = key;
                this.val = val;
                this.next = next;
            }
            ...
    }
    

    可以看出很多属性都是用 volatile 进行修饰的,也就是为了保证内存可见性。

  2. TreeNode

    树节点,继承于承载数据的 Node 类。而红黑树的操作是针对 TreeBin 类的,从该类的注释也可以看出,也就是 TreeBin 会将 TreeNode 进行再一次封装:

    static final class TreeNode<K,V> extends Node<K,V> {
            TreeNode<K,V> parent;  // red-black tree links
            TreeNode<K,V> left;
            TreeNode<K,V> right;
            TreeNode<K,V> prev;    // needed to unlink next upon deletion
            boolean red;
    		......
    }
    
  3. TreeBin

    这个类并不负责包装用户的 key、value 信息,而是包装的很多 TreeNode 节点。实际的 ConcurrentHashMap “数组” 中,存放的是 TreeBin 对象,而不是 TreeNode 对象:

    static final class TreeBin<K,V> extends Node<K,V> {
            TreeNode<K,V> root;
            volatile TreeNode<K,V> first;
            volatile Thread waiter;
            volatile int lockState;
            // values for lockState
            static final int WRITER = 1; // set while holding write lock
            static final int WAITER = 2; // set when waiting for write lock
            static final int READER = 4; // increment value for setting read lock
    		......
    }
    
  4. ForwardingNode

    在扩容时才会出现的特殊节点,其 key、value、hash 全部为 null。并拥有 nextTable 指针引用新的 table 数组:

    static final class ForwardingNode<K,V> extends Node<K,V> {
        final Node<K,V>[] nextTable;
        ForwardingNode(Node<K,V>[] tab) {
            super(MOVED, null, null, null);
            this.nextTable = tab;
        }
       .....
    }
    

03、JDK 1.8 中 ConcurrentHashMap 底层结构

虽然 JDK 1.7 中的 ConcurrentHashMap 解决了 HashMap 并发的安全性,但是当冲突的链表过长时,在查询遍历的时候依然很慢。

所以在 JDK 1.8 中,引入了 红黑树。当冲突的链表长度大于 8 时,会将链表转化成红黑树,红黑树又被称为平衡二叉树,在查询效率方面,又有所提升。

JDK 1.8 中的数据结构:

JavaSE进阶之(十)Map 子接口之 ConcurrentHashMap-小白菜博客
可以发现 JDK 1.8 的 ConcurrentHashMap 相对于 JDK 1.7 来说变化比较大,不再是之前的 Segment 数组 + HashEntry 数组 + 链表,而是 Node 数组 + 链表 / 红黑树。当冲突链表达到一定长度时,链表会转换成红黑树。

与 JDK1.7 中的 ConcurrentHashMap 相比, 它抛弃了原有的 Segment 分段锁实现,采用了 CAS + synchronized 来保证并发的安全性。

JDK 1.8 中的 ConcurrentHashMap 对节点 Node 类中的共享变量也使用了 volatile 关键字,保证多线程操作时变量的可见性。

扒一下 Node 类的源码:

static class Node<K,V> implements Map.Entry<K,V> {
		// hash 值
        final int hash;
        // 键
        final K key;
        // 值
        volatile V val;
        // 下一个节点
        volatile Node<K,V> next;

        Node(int hash, K key, V val, Node<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.val = val;
            this.next = next;
        }
        ...
}

04、实例构造器方法

在使用 ConcurrentHashMap 第一件事自然而然就是 new 出来一个 ConcurrentHashMap 对象,一共提供了几个构造器方法:

// 1. 构造一个空的map,即table数组还未初始化,初始化放在第一次插入数据时,默认大小为16
ConcurrentHashMap()
// 2. 给定map的大小
ConcurrentHashMap(int initialCapacity) 
// 3. 给定一个map
ConcurrentHashMap(Map<? extends K, ? extends V> m)
// 4. 给定map的大小以及加载因子
ConcurrentHashMap(int initialCapacity, float loadFactor)
// 5. 给定map大小,加载因子以及并发度(预计同时操作数据的线程)
ConcurrentHashMap(int initialCapacity,float loadFactor, int concurrencyLevel)

当传入了指定大小的 map 时,该构造器的源码为:

public ConcurrentHashMap(int initialCapacity) {
	// 1. 小于0直接抛异常
    if (initialCapacity < 0)
        throw new IllegalArgumentException();
	// 2. 判断是否超过了允许的最大值,超过了话则取最大值,否则再对该值进一步处理
    int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
               MAXIMUM_CAPACITY :
               tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
	// 3. 赋值给sizeCtl
    this.sizeCtl = cap;
}

这段代码的逻辑很容易理解:如果小于 0 就直接抛出异常;如果指定值大于了所允许的最大值的话就取最大值;否则,再对指定值做进一步处理。最后将 cap 赋值给 sizeCtl,当调用构造器方法之后,sizeCtl 的大小应该就代表了 ConcurrentHashMap 的大小,即 table 数组长度。

扒一下其中 tableSizeFor() 方法的源码:

/**
 * Returns a power of two table size for the given desired capacity.
 * See Hackers Delight, sec 3.2
 */
private static final int tableSizeFor(int c) {
    int n = c - 1;
    n |= n >>> 1;
    n |= n >>> 2;
    n |= n >>> 4;
    n |= n >>> 8;
    n |= n >>> 16;
    return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1;
}

该方法会将调用构造器方法时指定的大小转换成一个 2 的幂次方数,也就是说 ConcurrentHashMap 的大小一定是 2 的幂次方。比如,当指定大小为 18 时,为了满足 2 的幂次方特性,实际上 concurrentHashMap 的大小为 2 的 5 次方(32)。

另外,需要注意的是,调用构造器方法的时候并未构造出 table 数组(可以理解为 ConcurrentHashMap 的数据容器),只是算出 table 数组的长度,当第一次向 ConcurrentHashMap 插入数据的时候才真正的完成初始化创建 table 数组的工作。

05、put 操作

扒一下 put() 方法的源码:

public V put(K key, V value) {
    return putVal(key, value, false);
}

/** Implementation for put and putIfAbsent */
final V putVal(K key, V value, boolean onlyIfAbsent) {
	// 1. key、value不允许为空
    if (key == null || value == null) throw new NullPointerException();
    // 2. 通过 key 获取到hash值
    int hash = spread(key.hashCode());
    int binCount = 0;
    for (Node<K,V>[] tab = table;;) {
    	// f = 目标位置元素
        Node<K,V> f; int n, i, fh;
        // 如果当前table还没有初始化,先调用initTable()方法将tab进行初始化
        if (tab == null || (n = tab.length) == 0)
        	// 3. 如果tab为空,就初始化node数组(自旋+CAS)
            tab = initTable();
        else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
        	// 4. 如果f为null,说明table中这个位置第一次插入元素,利用Unsafe.compareAndSwapObject()方法插入Node节点
        	// 桶内为空,CAS放入,不加锁,成功了就直接break
            if (casTabAt(tab, i, null,
                         new Node<K,V>(hash, key, value, null)))
                break;                   // no lock when adding to empty bin
        }
        // 当前正在扩容
        else if ((fh = f.hash) == MOVED)
        	// 5. MOVED等于-1,如果f.hash等于-1,说明当前f是ForwardingNode节点,意味着其他线程正在扩容,则一起进行扩容操作
            tab = helpTransfer(tab, f);
        else {
        	// 6. 其余情况都是把新的Node节点按链表或红黑树的方式插入到合适的位置
            V oldVal = null;
            // 7. 采用同步内置锁实现并发控制
            synchronized (f) {
            	// 7.1 节点插入之前,再次利用tabAt(tab, i)==f判断,防止被其他线程修改
                if (tabAt(tab, i) == f) {
                	// 7.2 如果fh=f.hash >= 0,说明当前为链表,在链表中插入新的键值对
                    if (fh >= 0) {
                    	// 7.3 遍历链表,如果找到对应的node节点,则修改value;否则直接在链表尾部加入节点
                        binCount = 1;
                        // 循环加入新的或者覆盖节点
                        for (Node<K,V> e = f;; ++binCount) {
                            K ek;
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                oldVal = e.val;
                                if (!onlyIfAbsent)
                                    e.val = value;
                                break;
                            }
                            Node<K,V> pred = e;
                            if ((e = e.next) == null) {
                                pred.next = new Node<K,V>(hash, key,
                                                          value, null);
                                break;
                            }
                        }
                    }
                    else if (f instanceof TreeBin) {
                    	// 7.4 如果f是TreeBin类型节点,说明f是红黑树根节点,则在树结构上遍历元素,更新或增加节点
                        Node<K,V> p;
                        binCount = 2;
                        if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                       value)) != null) {
                            oldVal = p.val;
                            if (!onlyIfAbsent)
                                p.val = value;
                        }
                    }
                }
            }
            // 插入完键值对后再根据实际大小看是否需要转换成红黑树
            if (binCount != 0) {
            	// 8. 如果链表中节点数binCount>=TREEIFY_THRESHOLD(默认是8),则把链表转化为红黑树结构
                if (binCount >= TREEIFY_THRESHOLD)
                    treeifyBin(tab, i);
                if (oldVal != null)
                    return oldVal;
                break;
            }
        }
    }
    // 9. 插入完成之后,扩容判断(如果超过了临界值:实际大小*加载因子,就需要扩容)
    addCount(1L, binCount);
    return null;
}

当进行 put 操作时,流程大概可以分如下几个步骤:

  1. 首先对于每一个放入的值,首先利用 spread() 方法对 key 的 hashcode 进行一次 hash 计算,由此来确定这个值在 table 中的位置;
  2. 如果当前 table 数组还未初始化,先将 table 数组进行初始化操作;
  3. 如果这个位置是 null 的,那么使用 CAS 操作直接放入;
  4. 如果这个位置存在结点,说明发生了 hash 碰撞,首先判断这个节点的类型。如果该节点 fh==MOVED(代表 forwardingNode,数组正在进行扩容)的话,说明正在进行扩容;
  5. 如果是链表节点(fh>0),则得到的结点就是 hash 值相同的节点组成的链表的头节点。需要依次向后遍历确定这个新加入的值所在位置。如果遇到 key 相同的节点,则只需要覆盖该结点的 value 值即可。否则依次向后遍历,直到链表尾插入这个结点;
  6. 如果这个节点的类型是 TreeBin 的话,直接调用红黑树的插入方法进行插入新的节点;
  7. 插入完节点之后再次检查链表长度,如果长度大于 8,就把这个链表转换成红黑树;
  8. 对当前容量大小进行检查,如果超过了临界值(实际大小*加载因子)就需要扩容。

put 操作大致的流程,就是这样的,可以看的出,复杂程度比 JDK1.7 上了一个台阶。

1、initTable 初始化数组

put() 方法的第 3 步中调用了 initTable()方法进行初始化数组:

private final Node<K,V>[] initTable() {
	// 1. 创建临时变量tab、sc,tab表示Node数组,sc表示临时变量
    Node<K,V>[] tab; int sc;
    while ((tab = table) == null || tab.length == 0) {
    	// 2. sizeCtl默认为0,用来控制table的初始化和扩容操作,使用了volatile关键字修饰保证并发的可见性
    	// 保证只有一个线程正在进行初始化操作
    	// 如果 sizeCtl < 0,说明另外的线程执行CAS操作成功,当前线程只需要让出CPU时间片
        if ((sc = sizeCtl) < 0)
        	// 让出 CPU 使用权
            Thread.yield(); // lost initialization race; just spin
        else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
        	// 3. 通过CAS方法,将sizeCtl修改为-1,有且只有一个线程能够修改成功
            try {
            	// 4. 初始化Node数组
                if ((tab = table) == null || tab.length == 0) {
                	// 得出数组的大小
                    int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                    @SuppressWarnings("unchecked")
                    // 这里才真正的初始化数组
                    Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                    table = tab = nt;
                   	// 计算数组中可用的大小:实际大小 n*0.75
                   	// n-(n>>2) = n-(1/4)n = (3/4)n = 0.75
                    sc = n - (n >>> 2);
                }
            } finally {
                sizeCtl = sc;
            }
            break;
        }
    }
    return tab;
}

这里需要注意:计算数组中可用的大小(数组实际大小)时:n-(n>>2) = n-(1/4)n = (3/4)n = 0.75,位运算效率会高一些。

如果选择是无参的构造器的话,这里在 new Node 数组的时候会使用默认大小为 DEFAULT_CAPACITY(16),然后乘以加载因子 0.75 为 12,也就是说数组的可用大小为 12。

sizeCtl 是一个对象属性,使用了volatile关键字修饰保证并发的可见性,默认为 0,当第一次执行 put 操作时,通过 Unsafe.compareAndSwapInt() 方法,俗称CAS,将 sizeCtl修改为 -1,有且只有一个线程能够修改成功,接着执行 table 初始化任务。

如果别的线程发现 sizeCtl<0,意味着有另外的线程执行 CAS 操作成功,当前线程通过执行 Thread.yield() 让出 CPU 时间片等待 table 初始化完成。

从源码中可以发现 ConcurrentHashMap 的初始化是通过自旋和 CAS 操作完成的。里面的变量 sizeCtl,它的值决定着当前的初始化状态

  1. -1 说明正在初始化
  2. -N 说明有 N-1 个线程正在进行扩容
  3. 表示 table 初始化大小,如果 table 没有初始化
  4. 表示 table 容量,如果 table 已经初始化

2、helpTransfer() 帮组扩容

put() 方法中的第 5 步调用了 helpTransfer() 方法,如果 f.hash == -1,说明当前 f 是 ForwardingNode 节点,意味着有其他线程正在扩容,则一起进行扩容操作:

final Node<K,V>[] helpTransfer(Node<K,V>[] tab, Node<K,V> f) {
    Node<K,V>[] nextTab; int sc;
    // 1. 如果table不为空且node节点是转移类型,同时node节点的nextTable(新table)不为空,进行数据校验
    if (tab != null && (f instanceof ForwardingNode) &&
        (nextTab = ((ForwardingNode<K,V>)f).nextTable) != null) {
        // 2. 满足以上条件之后,尝试帮助扩容
        // 数据数组的length得到一个标识符号
        int rs = resizeStamp(tab.length);
        // 3. 如果nextTable没有被并发修改且tab也没有被并发修改,同时sizeCtl<0,说明还在扩容
        while (nextTab == nextTable && table == tab &&
               (sc = sizeCtl) < 0) {
            // 4. 对sizeCtl参数值进行分析判断
            // 判断1:如果sizeCtl无符号右移16不等于rs,则标识符变化了
            // 判断2:如果sizeCtl == rs+1,标识扩容结束了,不再线程进行扩容
            // 判断3:如果sizeCtl == rs+65535,表示达到最大帮助线程的数量,即65535
            // 判断4:如果转移下标transferIndex <= 0,表示扩容结束
            // 满足任何一个判断,结束循环,返回table
            if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                sc == rs + MAX_RESIZERS || transferIndex <= 0)
                break;
            // 5. 如果以上都不是,将sizeCtl + 1,表示增加了一个线程帮助其扩容
            if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1)) {
            	// 6. 对数组进行转移,执行完之后结束循环。
                transfer(tab, nextTab);
                break;
            }
        }
        return nextTab;
    }
    return table;
}

这个过程操作步骤如下:

  1. 第1步,对 table、node 节点、node 节点的 nextTable,进行数据校验;
  2. 第2步,根据数组的 length 得到一个标识符号;
  3. 第3步,进一步校验 nextTab、tab、sizeCtl 值,如果 nextTab 没有被并发修改并且 tab 也没有被并发修改,同时 sizeCtl < 0,说明还在扩容;
  4. 第4步,对 sizeCtl 参数值进行分析判断,如果不满足任何一个判断,将 sizeCtl + 1, 增加了一个线程帮助其扩容。

3、addCount() 扩容判断

put() 方法的第 9 步调用了 addCount() 方法,插入完成之后进行扩容判断:

private final void addCount(long x, int check) {
	// 1. 从putVal传入的参数是 x=1、check=0,只有hash冲突且它的结构是链表的结构时,check才会大于1
    CounterCell[] as; long b, s;
    // 2. 利用CAS方法更新baseCount的值
    if ((as = counterCells) != null ||
        !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
        CounterCell a; long v; int m;
        boolean uncontended = true;
        if (as == null || (m = as.length - 1) < 0 ||
            (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
            !(uncontended =
              U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
            fullAddCount(x, uncontended);
            return;
        }
        if (check <= 1)
            return;
        s = sumCount();
    }
    // 3. 检查是否需要扩容,默认check=1,需要检查
    if (check >= 0) {
        Node<K,V>[] tab, nt; int n, sc;
        // 4. 如果map.size()大于等于sizeCtl(达到扩容阈值需要扩容)并且table不是空,同时table的长度小于最大容量,可以扩容
        while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
               (n = tab.length) < MAXIMUM_CAPACITY) {
            // 5. 根据length得到一个标识
            int rs = resizeStamp(n);
            // sc=sizeCtl,如果小于0,标识正在扩容,尝试帮助扩容
            if (sc < 0) {
            	// 6. 对sizeCtl参数值进行分析判断,与帮助扩容阶段的判断一样
	            // 判断1:如果sizeCtl无符号右移16不等于rs,则标识符变化了
	            // 判断2:如果sizeCtl == rs+1,标识扩容结束了,不再线程进行扩容
	            // 判断3:如果sizeCtl == rs+65535,表示达到最大帮助线程的数量,即65535
	            // 判断4:如果转移下标transferIndex <= 0,表示扩容结束
	            // 满足任何一个判断,结束循环,返回table
                if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                    sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                    transferIndex <= 0)
                    break;
                if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                    transfer(tab, nt);
            }
            // 7. 如果不在扩容,将sizeCtl更新:标识符左移16位+2,也就是变成一个负数
            else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                         (rs << RESIZE_STAMP_SHIFT) + 2))
                // 8. 进行扩容处理
                transfer(tab, null);
            s = sumCount();
        }
    }
}

这个过程操作步骤如下:

  1. 第1步,利用 CAS 将方法更新 baseCount 的值;
  2. 第2步,检查是否需要扩容,默认 check = 1,需要检查;
  3. 第3步,如果满足扩容条件,判断当前是否正在扩容,如果是正在扩容就一起扩容;
  4. 第4步,如果不在扩容,将 sizeCtl 更新为负数,并进行扩容处理。

从 put 的流程中可以发现,里面大量的使用了 CAS 方法,CAS 表示比较与替换,里面有 3 个参数,分别是目标内存地址、旧值、新值每次判断的时候,会将旧值与目标内存地址中的值进行比较,如果相等,就将新值更新到内存地址里,如果不相等,就继续循环,直到操作成功为止。

虽然使用的了CAS这种乐观锁方法,但是里面的细节设计的是很复杂的。

06、get 操作

get() 方法不涉及并发操作,直接查询就可以了:

public V get(Object key) {
    Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
    int h = spread(key.hashCode());
    // 1. 判断数组是否为空,通过key定位到数组下标是否为空
    if ((tab = table) != null && (n = tab.length) > 0 &&
        (e = tabAt(tab, (n - 1) & h)) != null) {
        // 2. 判断node节点第一个元素是不是要找到,如果是直接返回
        if ((eh = e.hash) == h) {
            if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                return e.val;
        }
         // 3. 如果是红黑树结构,就从红黑树里面查询
        else if (eh < 0)
        	// 头节点hash值小于0,说明正在扩容或者是红黑树,find查找
            return (p = e.find(h, key)) != null ? p.val : null;
        while ((e = e.next) != null) {
        	// 是链表,遍历查询
            if (e.hash == h &&
                ((ek = e.key) == key || (ek != null && key.equals(ek))))
                return e.val;
        }
    }
    return null;
}

get() 方法的步骤如下:

  1. 第1步,判断数组是否为空,通过 key 定位到数组下标是否为空;
  2. 第2步,判断 node 节点第一个元素是不是要找到,如果是直接返回;
  3. 第3步,如果是红黑树结构,就从红黑树里面查询;
  4. 第4步,如果是链表结构,循环遍历判断;

07、remove 操作

remove() 方法和 put() 方法类似,只是方向是反的。

扒一下 remove() 方法的源码:

public V remove(Object key) {
    return replaceNode(key, null, null);
}

/**
 * Implementation for the four public remove/replace methods:
 * Replaces node value with v, conditional upon match of cv if
 * non-null.  If resulting value is null, delete.
 */
final V replaceNode(Object key, V value, Object cv) {
    int hash = spread(key.hashCode());
    // 1. 循环遍历数组
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh;
        // 2. 参数校验
        if (tab == null || (n = tab.length) == 0 ||
            (f = tabAt(tab, i = (n - 1) & hash)) == null)
            break;
        else if ((fh = f.hash) == MOVED)
        	// 3. 帮助扩容
            tab = helpTransfer(tab, f);
        else {
            V oldVal = null;
            boolean validated = false;
            // 4. 利用 synchronized 同步锁,保证并发时元素移除安全
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                	// 4.1 判断当前冲突节点是否为链表结构,如果是循环遍历移除
                    if (fh >= 0) {
                        validated = true;
                        for (Node<K,V> e = f, pred = null;;) {
                            K ek;
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                V ev = e.val;
                                if (cv == null || cv == ev ||
                                    (ev != null && cv.equals(ev))) {
                                    oldVal = ev;
                                    if (value != null)
                                        e.val = value;
                                    else if (pred != null)
                                        pred.next = e.next;
                                    else
                                        setTabAt(tab, i, e.next);
                                }
                                break;
                            }
                            pred = e;
                            if ((e = e.next) == null)
                                break;
                        }
                    }
                    // 4.2 如果是红黑树结构,利用红黑二叉树特性进行查找并移除节点,最后调整红黑树结构
                    else if (f instanceof TreeBin) {
                        validated = true;
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> r, p;
                        if ((r = t.root) != null &&
                            (p = r.findTreeNode(hash, key, null)) != null) {
                            V pv = p.val;
                            if (cv == null || cv == pv ||
                                (pv != null && cv.equals(pv))) {
                                oldVal = pv;
                                if (value != null)
                                    p.val = value;
                                else if (t.removeTreeNode(p))
                                    setTabAt(tab, i, untreeify(t.first));
                            }
                        }
                    }
                }
            }
            if (validated) {
                if (oldVal != null) {
                    if (value == null)
                    	// 5. 因为check=-1,所以不会进行扩容操作,利用CAS操作修改baseCount值
                        addCount(-1L, -1);
                    return oldVal;
                }
                break;
            }
        }
    }
    return null;
}

remove() 操作的步骤如下:

  1. 第1步,循环遍历数组,接着校验参数;
  2. 第2步,判断是否有别的线程正在扩容,如果是一起扩容;
  3. 第3步,用 synchronized 同步锁,保证并发时元素移除安全;
  4. 第4步,因为 check= -1,所以不会进行扩容操作,利用 CAS 操作修改 baseCount 值。

08、transfer() 方法

当 ConcurrentHashMap 容量不足的时候,需要对 table 进行扩容。这个方法的基本思想跟 HashMap 是很像的,但是由于它是支持并发扩容的,所以要复杂的多。原因是它支持多线程进行扩容操作,而并没有加锁。这样做的目的可能不仅仅是为了满足 concurrent 的要求,而是希望利用并发处理去减少扩容带来的时间影响。

扒一下 transfer() 方法的源码:

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
        stride = MIN_TRANSFER_STRIDE; // subdivide range
	// 1. 新建Node数组,容量为之前的两倍
    if (nextTab == null) {            // initiating
        try {
            @SuppressWarnings("unchecked")
            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
            nextTab = nt;
        } catch (Throwable ex) {      // try to cope with OOME
            sizeCtl = Integer.MAX_VALUE;
            return;
        }
        nextTable = nextTab;
        transferIndex = n;
    }
    int nextn = nextTab.length;
	// 2. 新建forwardingNode引用,在之后会用到
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
    boolean advance = true;
    boolean finishing = false; // to ensure sweep before committing nextTab
    for (int i = 0, bound = 0;;) {
        Node<K,V> f; int fh;
        // 3. 确定遍历中的索引i
		while (advance) {
            int nextIndex, nextBound;
            if (--i >= bound || finishing)
                advance = false;
            else if ((nextIndex = transferIndex) <= 0) {
                i = -1;
                advance = false;
            }
            else if (U.compareAndSwapInt
                     (this, TRANSFERINDEX, nextIndex,
                      nextBound = (nextIndex > stride ?
                                   nextIndex - stride : 0))) {
                bound = nextBound;
                i = nextIndex - 1;
                advance = false;
            }
        }
		// 4.将原数组中的元素复制到新数组中去
		// 4.5 for循环退出,扩容结束修改sizeCtl属性
        if (i < 0 || i >= n || i + n >= nextn) {
            int sc;
            if (finishing) {
                nextTable = null;
                table = nextTab;
                sizeCtl = (n << 1) - (n >>> 1);
                return;
            }
            if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                    return;
                finishing = advance = true;
                i = n; // recheck before commit
            }
        }
		// 4.1 当前数组中第i个元素为null,用CAS设置成特殊节点forwardingNode(可以理解成占位符)
        else if ((f = tabAt(tab, i)) == null)
            advance = casTabAt(tab, i, null, fwd);
		// 4.2 如果遍历到ForwardingNode节点,说明这个点已经被处理过了,直接跳过  这里是控制并发扩容的核心
        else if ((fh = f.hash) == MOVED)
            advance = true; // already processed
        else {
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    Node<K,V> ln, hn;
                    if (fh >= 0) {
						// 4.3 处理当前节点为链表的头结点的情况,构造两个链表:一个是原链表、另一个是原链表的反序排列
                        int runBit = fh & n;
                        Node<K,V> lastRun = f;
                        for (Node<K,V> p = f.next; p != null; p = p.next) {
                            int b = p.hash & n;
                            if (b != runBit) {
                                runBit = b;
                                lastRun = p;
                            }
                        }
                        if (runBit == 0) {
                            ln = lastRun;
                            hn = null;
                        }
                        else {
                            hn = lastRun;
                            ln = null;
                        }
                        for (Node<K,V> p = f; p != lastRun; p = p.next) {
                            int ph = p.hash; K pk = p.key; V pv = p.val;
                            if ((ph & n) == 0)
                                ln = new Node<K,V>(ph, pk, pv, ln);
                            else
                                hn = new Node<K,V>(ph, pk, pv, hn);
                        }
                         // 在nextTable的i位置上插入一个链表
                         setTabAt(nextTab, i, ln);
                         // 在nextTable的i+n的位置上插入另一个链表
                         setTabAt(nextTab, i + n, hn);
                         // 在table的i位置上插入forwardNode节点,表示已经处理过该节点
                         setTabAt(tab, i, fwd);
                         // 设置advance为true,返回到上面的while循环中就可以执行i--操作
                         advance = true;
                    }
					// 4.4 处理当前节点是TreeBin时的情况,操作和上面的类似
                    else if (f instanceof TreeBin) {
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> lo = null, loTail = null;
                        TreeNode<K,V> hi = null, hiTail = null;
                        int lc = 0, hc = 0;
                        for (Node<K,V> e = t.first; e != null; e = e.next) {
                            int h = e.hash;
                            TreeNode<K,V> p = new TreeNode<K,V>
                                (h, e.key, e.val, null, null);
                            if ((h & n) == 0) {
                                if ((p.prev = loTail) == null)
                                    lo = p;
                                else
                                    loTail.next = p;
                                loTail = p;
                                ++lc;
                            }
                            else {
                                if ((p.prev = hiTail) == null)
                                    hi = p;
                                else
                                    hiTail.next = p;
                                hiTail = p;
                                ++hc;
                            }
                        }
                        ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                            (hc != 0) ? new TreeBin<K,V>(lo) : t;
                        hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                            (lc != 0) ? new TreeBin<K,V>(hi) : t;
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                }
            }
        }
    }
}

整个扩容操作分为两个部分:

  1. 第一部分是构建一个 nextTable,它的容量是原来的两倍,这个操作是单线程完成的。新建 table 数组的代码为:Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1],在原容量大小的基础上右移一位。

  2. 第二个部分就是将原来 table 中的元素复制到 nextTable 中,主要是遍历复制的过程。根据运算得到当前遍历的数组的位置 i,然后利用 tabAt() 方法获得 i 位置的元素再进行判断:

    • 如果这个位置为空,就在原 table 中的 i 位置放入 forwardNode 节点,这个也是触发并发扩容的关键点;
    • 如果这个位置是 Node 节点(fh>=0),如果它是一个链表的头节点,就构造一个反序链表,把他们分别放在 nextTable 的 i 和 i+n 的位置上;
    • 如果这个位置是 TreeBin 节点(fh<0),也做一个反序处理,并且判断是否需要 untreefi,把处理的结果分别放在 nextTable 的 i 和 i+n 的位置上;
    • 遍历过所有的节点以后就完成了复制工作,这时让 nextTable 作为新的 table,并且更新 sizeCtl 为新容量的 0.75 倍 ,完成扩容。设置为新容量的 0.75 倍代码为 sizeCtl = (n << 1) - (n >>> 1),仔细体会下是不是很巧妙,n<<1 相当于 n 左移一位表示 n 的两倍即 2n,n>>>1 右一位相当于 n 除以 2 即 0.5n,然后两者相减为 2n-0.5n=1.5n,就刚好等于新容量的 0.75 倍即 2n*0.75=1.5n
      在这里插入图片描述

09、与 size 相关的一些方法

对于 ConcurrentHashMap 来说,这个 table 里到底装了多少东西其实是个不确定的数量,因为不可能在调用 size() 方法的时候像 GC 的 “stop the world” 一样让其他线程都停下来让你去统计,因此只能说这个数量是个估计值。对于这个估计值,ConcurrentHashMap 也是大费周章才计算出来的。

为了统计元素个数,ConcurrentHashMap 定义了一些变量和一个内部类:

/**
 * A padded cell for distributing counts.  Adapted from LongAdder
 * and Striped64.  See their internal docs for explanation.
 */
@sun.misc.Contended static final class CounterCell {
    volatile long value;
    CounterCell(long x) { value = x; }
}

/******************************************/ 

/**
 * 实际上保存的是hashmap中的元素个数  利用CAS锁进行更新
 但它并不用返回当前hashmap的元素个数 

 */
private transient volatile long baseCount;
/**
 * Spinlock (locked via CAS) used when resizing and/or creating CounterCells.
 */
private transient volatile int cellsBusy;

/**
 * Table of counter cells. When non-null, size is a power of 2.
 */
private transient volatile CounterCell[] counterCells;

mappingCount 与 size() 方法:

mappingCount 与 size() 方法的类似,从给出的注释来看,应该使用 mappingCount 代替 size 方法,两个方法都没有直接返回 basecount ,而是统计一次这个值,而这个值其实也是一个大概的数值,因此可能在统计的时候有其他线程正在执行插入或删除操作。

public int size() {
    long n = sumCount();
    return ((n < 0L) ? 0 :
            (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
            (int)n);
}
 /**
 * Returns the number of mappings. This method should be used
 * instead of {@link #size} because a ConcurrentHashMap may
 * contain more mappings than can be represented as an int. The
 * value returned is an estimate; the actual count may differ if
 * there are concurrent insertions or removals.
 *
 * @return the number of mappings
 * @since 1.8
 */
public long mappingCount() {
    long n = sumCount();
    return (n < 0L) ? 0L : n; // ignore transient negative values
}

 final long sumCount() {
    CounterCell[] as = counterCells; CounterCell a;
    long sum = baseCount;
    if (as != null) {
        for (int i = 0; i < as.length; ++i) {
            if ((a = as[i]) != null)
                sum += a.value;//所有counter的值求和
        }
    }
    return sum;
}

10.3 CAS(乐观锁)

CAS 是乐观锁的一种实现方式,是一种轻量级锁,JUC 中很多工具类的实现就是基于 CAS 的。

CAS 操作的流程如下图所示。线程在读取数据时不进行加锁,在准备写回数据时,比较原值是否被修改,若未被其他线程修改则写回,若已被修改,则重新执行读取流程。

这是一种乐观策略,认为并发操作并不总会发生。
JavaSE进阶之(十)Map 子接口之 ConcurrentHashMap-小白菜博客
乐观锁在开发场景中非常常用。

举个例子,假如我们要修改数据库中的一条数据,修改之前要先拿到它原来的值,然后还要在 SQL 中加一个判断:原来的值 == 现在拿到的它的原来的值是否一样,一样的话就可以修改了,不一样就说明被别的线程修改了,我们直接 return 错误就好了。

# oldValue 就是我们执⾏前查询出来的值
update a set value = newValue where value = #{oldValue}

01、CAS 关键操作

  1. tabAt

    static final <K,V> Node<K,V> tabAt(Node<K,V>[] tab, int i) {
        return (Node<K,V>)U.getObjectVolatile(tab, ((long)i << ASHIFT) + ABASE);
    }
    

    该方法用来获取 table 数组中索引为 i 的 Node 元素。

  2. casTabAt

    static final <K,V> boolean casTabAt(Node<K,V>[] tab, int i,
                                        Node<K,V> c, Node<K,V> v) {
        return U.compareAndSwapObject(tab, ((long)i << ASHIFT) + ABASE, c, v);
    }
    

    利用 CAS 操作设置 tab 数组中索引为 i 的元素。

  3. setTabAt

    static final <K,V> void setTabAt(Node<K,V>[] tab, int i, Node<K,V> v) {
        U.putObjectVolatile(tab, ((long)i << ASHIFT) + ABASE, v);
    }
    

    该方法用来设置 table 数组中索引为 i 的元素。

02、CAS 是否一定能保证数据没有被别的线程修改过?

当然不能。比如很经典的 ABA 问题,CAS 就无法判断。其实这一点在快速失败机制中也遇到了类似的问题。

03、什么是 ABA?

就是说,一个线程把值改为了 B,另一个线程又把值改为了 A。这个时候判断线程是否被修改,发现它的值还是 A,所以它并不知道这个值是否被修改过。其实在大多数场景下只是追求最终结果,结果正确就可以。

但是在实际开发中,是需要记录每一步的修改过程的,比如银行转账,每次的修改都应该是有记录的,方便回溯。

04、如何解决 ABA 问题?

  • 方法1:版本号验证

    用版本号去验证(可自定义组合而成)。比如说,我在修改前去查询它原来的值所对应的版本号,每次判断都将值和版本号一起进行判断,判断成功就给版本号加 1,说明记录下了这一次的操作。

# 判断原来的值和版本号是否匹配,中间有别的线程修改,值可能相等,但是版本号一定不相同
update a set value = newValue ,vision = vision + 1 where value = #{oldValue} and vision = #{vision} 

所以这样就解决了ABA问题,保证了一个线程操作的数据不被别的线程所修改。

  • 方法2:时间戳

    其实时间戳和版本号作用是一样的,异曲同工.对一个数据进行操作的时候把时间戳也带上,我们知道时间戳每一刻都在变化,所以任何时候的操作只要记录下来时间戳,后续再对该数据进行操作时带上时间戳进行判断就解决了 ABA 问题。

05、CAS乐观锁效率很高,synchronized 效率不是很好,为什么 JDK 1.8 之后反而增加了 synchronized 同步锁的数量?

synchronized 之前一直都是重量级的锁,但是后来 Java 官方是对它进行过升级了,它现在采用的是锁升级的方式去做的。

针对 sychronized 获取锁的方式,JVM 使用了锁升级的优化方式,就是先使用偏向锁优先同一线程,然后再次获取锁:

  • 如果失败,就会升级为 CAS 轻量级锁;
  • 如果失败就会短暂自旋,防止线程被系统挂起;
  • 最后如果上述都失败,就会升级为重量级锁。

所以 synchronized 升级后,刚开始不是直接就使用重量级锁,而是通过很多轻量级锁的方式一步步升级上去的。

10.4 总结

01、JDK 1.7

JDK 1.7 中 ConcurrentHashMap 使用的是分段锁来减小锁粒度,分割成若干个 Segment,然后每一个 Segment 上同时只有一个线程可以操作,每一个 Segment 都是一个类似 HashMap 数组的结构,它可以扩容,它的冲突会转化为链表。但是 Segment 的个数一但初始化就不能改变。

在 put 的时候需要锁住 Segment;get 的时候不加锁,使用 volatile 来保证可见性。当要统计全局时(size()),首先会尝试多次计算 modcount 来确定:这几次尝试中,是否有其他线程进行了修改操作。如果没有,则直接返回 size;如果有,则需要依次锁住所有的 Segment 来计算。

02、JDK 1.8

JDK 1.8 中的 ConcurrentHashMap 使用的 Synchronized 锁加 CAS 的机制。结构也由 JDK 1.7 中的 Segment 数组 + HashEntry 数组 + 链表 进化成了 Node 数组 + 链表 / 红黑树,Node 是类似于一个 HashEntry 的结构。它的冲突再达到一定大小时会转化成红黑树,在冲突小于一定数量时又退回链表。

03、对比

1.8 之前 put 定位节点时要先定位到具体的 segment,然后再在 segment 中定位到具体的桶。而在 1.8 的时候摒弃了 segment 臃肿的设计,直接针对的是 Node[] tale数组中的每一个桶,进一步减小了锁粒度。并且防止拉链过长导致性能下降,当链表长度大于 8 的时候采用红黑树的设计。

主要设计上的变化有以下几点:

  1. 不采用 segment 而采用 node,锁住 node 来实现减小锁粒度。
  2. 设计了 MOVED 状态,在 resize 的过程中线程 2 还在 put 数据的话,线程 2 会帮助 resize。
  3. 使用 3 个 CAS 操作来确保 node 的一些操作的原子性,这种方式代替了锁。
  4. sizeCtl 的不同值来代表不同含义,起到了控制的作用。
  5. 采用 synchronized 而不是 ReentrantLock。