後始末のできるThreadLocal

/**
 * GCでインスタンスが解放される前に任意の処理を行うためのThreadLocal。<br/>
 * 必要な解放処理は{@link Closeable#close()}に実装する。<br/>
 * 注意点:<br/>
 * <ul>
 * <li>このThreadLocal以外でインスタンスの強い参照を持っていても、スレッドがなくなると{@link Closeable#close()}が呼ばれる。</li>
 * <li>set(A)の後にset(B)とすると、Aの{@link Closeable#close()}が呼ばれる。</li>
 * <li>インスタンスが解放の対象になっても、いずれかのスレッドで{@link #set(T)}/{@link #get()}/{@link #remove()}が呼ばれなければ解放されない。</li>
 * </ul>
 *
 * @param <T> 保持するインスタンスの型
 */
public class CloseableThreadLocal<T extends Closeable> extends ThreadLocal<T> {

  // リフレクションを使って別スレッドのThreadLocal変数をクリアするためのアクセサ
  private static Field THREADLOCALS_FIELD = null;
  private static Method THREAD_LOCAL_MAP_REMOVE_METHOD = null;
  private static boolean REFLECTION_OK = false;

  static {
    try {
      THREADLOCALS_FIELD = Thread.class.getDeclaredField("threadLocals");
      THREADLOCALS_FIELD.setAccessible(true);
      THREAD_LOCAL_MAP_REMOVE_METHOD = THREADLOCALS_FIELD.getType().getDeclaredMethod("remove", ThreadLocal.class);
      THREAD_LOCAL_MAP_REMOVE_METHOD.setAccessible(true);
      REFLECTION_OK = true;
    } catch (Exception e) {
    }
  }

  // GCで解放されたスレッドが参照していたインスタンスが格納されるキュー
  private final ReferenceQueue<Thread> _queue = new ReferenceQueue<>();

  /**
   * スレッドとインスタンスを関連づけるクラス
   */
  private class CloseableHolder extends WeakReference<Thread> {

    private AtomicReference<T> _value;

    public CloseableHolder(T closeable) {
      super(Thread.currentThread(), _queue);
      _value = new AtomicReference<>(closeable);
    }

    public void close() {
      @SuppressWarnings("resource")
      T closeable = _value.getAndSet(null);
      if (closeable != null) {
        try {
          closeable.close();
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }
    }

    public T getCloseable() {
      return _value.get();
    }

    public void setCloseable(T closeable) {
      _value.set(closeable);
    }
  }

  @SuppressWarnings("unchecked")
  private void cleanup() {
    CloseableHolder ref = null;
    while ((ref = (CloseableHolder)_queue.poll()) != null) {
      ref.close();
    }
  }

  private final Map<Thread, CloseableHolder> _threads = Collections.synchronizedMap(new WeakHashMap<Thread, CloseableHolder>());

  @Override
  public T get() {
    cleanup();
    T ret = super.get();
    if (ret == null) {
      return null;
    }
    CloseableHolder ref = _threads.get(Thread.currentThread());
    if (ref != null) {
      if (ref.getCloseable() != ret) {
        ref.close();
        ref.setCloseable(ret);
      }
    } else {
      ref = new CloseableHolder(ret);
      _threads.put(Thread.currentThread(), ref);
    }
    return ret;
  }

  @Override
  public void set(T value) {
    cleanup();
    CloseableHolder ref = _threads.get(Thread.currentThread());
    if (ref != null) {
      if (ref.getCloseable() != value) {
        ref.close();
        ref.setCloseable(value);
      }
    } else {
      if (value != null) {
        ref = new CloseableHolder(value);
        _threads.put(Thread.currentThread(), ref);
      }
    }
    super.set(value);
  }

  @Override
  public void remove() {
    cleanup();
    CloseableHolder ref = _threads.get(Thread.currentThread());
    if (ref != null) {
      ref.close();
    }
    super.remove();
  }

  /**
   * 保持しているすべてのインスタンスを解放する
   */
  public void clear() {
    // ThreadLocalMapにアクセスできない場合はそっちと矛盾が出るのでエラーとする
    if (!REFLECTION_OK) {
      throw new IllegalStateException("ThreadLocalMap reflection access failed.");
    }
    List<Thread> keys = new ArrayList<>(_threads.keySet());
    for (Thread t : keys) {
      CloseableHolder ref = _threads.remove(t);
      if (ref == null) {
        continue;
      }
      ref.close();
      forceRemove(t);
    }
  }

  private void forceRemove(Thread thread) {
    try {
      Object threadLocals = THREADLOCALS_FIELD.get(thread);
      if (threadLocals != null) {
        THREAD_LOCAL_MAP_REMOVE_METHOD.invoke(threadLocals, this);
      }
    } catch (Exception e) {
    }
  }

}