循環依存関係の検出を高速化

寝付けなかったので書いてみた。rt.jarも1分以内に終わる。

以下、プログラム。

ちなみに、↓とまったく同じ手順でやってる。
http://www.hgc.jp/~tshibuya/classes/shibuya20050225.pdf

import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import org.objectweb.asm.ClassReader;

public class DependencyAnalyzer {
    
    public static void main(String[] args) throws IOException {
        if (args.length != 1) {
            System.err.println("usage: java DependencyAnalyzer <class-library>");
            throw new IllegalArgumentException(Arrays.deepToString(args));
        }
        Map<String, Integer> dictionary;
        BitSet[] direct;
        ZipFile zip = new ZipFile(args[0]);
        try {
            // 辞書を作る
            dictionary = buildDictionary(zip);
            // 直接依存関係の表を作る
            direct = buildDirectDependencies(zip, dictionary);
        }
        finally {
            zip.close();
        }

        // ループを検出し、循環依存関係の表に変換する
        BitSet[] cyclic = toCyclicDependencies2(direct);
        
        // ループの一覧を出力
        BitSet saw = new BitSet();
        String[] classes = reverse(dictionary);
        for (int i = 0, n = dictionary.size(); i < n; i++) {
            if (!saw.get(i) && cyclic[i].cardinality() >= 2 && cyclic[i].get(i)) {
                System.out.println(extract(cyclic[i], classes));
                // このループを「一度見た」に追加
                saw.or(cyclic[i]);
            }
        }
    }
    
    /**
     * ZIPに含まれるクラス名の辞書を返す。
     * クラス名は「バイナリ名」と呼ばれる表現で保持され、マップの各値は0から連続する自然数である。
     * @param zip 対象のZIPファイル
     * @return クラス名の辞書
     */
    private static Map<String, Integer> buildDictionary(ZipFile zip) {
        Map<String, Integer> dictionary = new HashMap<String, Integer>();
        Enumeration<? extends ZipEntry> entries = zip.entries();
        while (entries.hasMoreElements()) {
            ZipEntry elem = entries.nextElement();
            String name = elem.getName();
            if (name.endsWith(".class") && name.indexOf('$') < 0) {
                // .class を除去
                name = name.substring(0, name.length() - 6/*".class".length*/);
                // 辞書にシリアル番号とともに追加
                dictionary.put(name, dictionary.size());
            }
        }
        return dictionary;
    }
    
    /**
     * ZIPに含まれるクラスライブラリ間の直接依存関係の表を返す。
     * クラスの番号は{@code dic.get(<class-name>)}が返す値を利用し、
     * i -> j に依存関係がある場合は result[i].get(j) がtrueとなるような行列を表現する値を返す。
     * @param zip 対象のZIPファイル
     * @param dic クラス名の辞書
     * @return 直接依存関係の表 (result[i].get(j) がtrueのとき、i -> j への直接依存関係)
     */
    private static BitSet[] buildDirectDependencies(ZipFile zip, Map<String, Integer> dic) {
        // a -> b という直接の依存がある場合、direct[index(a), index(b)] = true にする。
        BitSet[] direct = newMatrix(dic.size());
        for (Map.Entry<String, Integer> dicEntry: dic.entrySet()) {
            ZipEntry entry = zip.getEntry(dicEntry.getKey() + ".class");
            int sourceIndex = dicEntry.getValue();
            try {
                // 指定のエントリから依存するすべてのクラス名を抽出する
                Set<String> dep = collectDependentTypes(zip.getInputStream(entry));
                for (String name: dep) {
                    if (dic.containsKey(name)) {
                        Integer targetIndex = dic.get(name);
                        direct[sourceIndex].set(targetIndex);
                    }
                }
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }
        return direct;
    }

    /**
     * 指定のストリームに格納されたクラスが依存するクラス名の一覧を返す。
     * @param in 対象のストリーム
     * @return 依存するクラス名の一覧
     * @throws IOException 読み込みに失敗した場合
     */
    private static Set<String> collectDependentTypes(InputStream in) throws IOException {
        ClassReader cr;
        try {
            cr = new ClassReader(in);
        }
        finally {
            in.close();
        }
        
        Set<String> dependent = new HashSet<String>();
        // 依存しているクラスの一覧を作成 (手抜き)
        // ASMってコンスタントプール読めないのかよorz
        // see http://java.sun.com/docs/books/jvms/second_edition/html/ClassFile.doc.html
        // u4 magic; u2 minor_version; u2 major_version; u2 constant_pool_count;
        int constantPoolCount = cr.readUnsignedShort(4 + 2 + 2);
        char[] buf = new char[1024];
        for (int i = 1; i < constantPoolCount; i++) {
            int itemIndex = cr.getItem(i);
            int tag = cr.readByte(itemIndex - 1);
            if (tag == 7/* CONSTANT_Class */) {
                String name = cr.readUTF8(itemIndex, buf);
                dependent.add(name);
            }
            else if (tag == 5 || tag == 6/* CONSTANT_(Long|Double)*/) {
                i++;
            }
        }
        
        return dependent;
    }
    
    /**
     * 間接依存関係の表を、閉路(ループ)の組み合わせに変換して返す。
     * @param direct 直接依存関係の表 (result[i].get(j) がtrueのとき、i -> j への間接依存関係)
     * @return 依存関係の閉路の組み合わせ。
     *      result[i, j] がtrueのとき、i, jは循環依存関係を持つ
     * @throws NullPointerException 引数に{@code null}が指定された場合
     */
    private static BitSet[] toCyclicDependencies2(BitSet[] direct) {
        int[] postOrder = postOrder(direct);

        int[] transpose = new int[postOrder.length];
        for (int i = 0; i < transpose.length; i++) {
            transpose[transpose.length - postOrder[i]] = i;
        }
        
        BitSet[] result = new BitSet[direct.length];
        
        int[] pathStack = new int[direct.length];
        int[] branchStack = new int[direct.length];
        BitSet saw = new BitSet(direct.length);
        for (int i: transpose) {
            if (!saw.get(i)) {
                BitSet group = new BitSet(transpose.length);
                group.set(i);
                pathStack[0] = i;
                branchStack[0] = 0;
                int sp = 0;
                while (sp >= 0) {
                    int nodeNum = pathStack[sp];
                    group.set(nodeNum);
                    int next = -1;
                    for (int j = 0; j < direct.length; j++) {
                        if (direct[j].get(nodeNum) && !saw.get(j) && !group.get(j)) {
                            next = j;
                            branchStack[sp] = j + 1;
                            break;
                        }
                    }
                    if (next >= 0) {
                        sp++;
                        pathStack[sp] = next;
                        branchStack[sp] = 0;
                    }
                    else {
                        sp--;
                    }
                }
                for (int j = group.nextSetBit(0); j >= 0; j = group.nextSetBit(j + 1)) {
                    result[j] = group;
                }
                saw.or(group);
            }
        }
        return result;
    }

    /**
     * 間接依存関係の表から、ポストオーダーの列を生成して返す。
     * @param direct 間接依存関係の表
     * @return ポストオーダーの列
     */
    private static int[] postOrder(BitSet[] direct) {
        BitSet rest = new BitSet(direct.length); // 訪問前かどうか
        rest.set(0, direct.length);
        int[] result = new int[direct.length]; // ポストオーダーの結果
        
        int serial = 1;
        int[] pathStack = new int[direct.length];
        int[] branchStack = new int[direct.length];
        while (!rest.isEmpty()) {
            pathStack[0] = rest.nextSetBit(0);
            branchStack[0] = 0;
            int sp = 0;
            while (sp >= 0) {
                int nodeNum = pathStack[sp];
                rest.clear(nodeNum);
                BitSet node = direct[nodeNum];
                int next = -1;
                for (
                        int i = node.nextSetBit(branchStack[sp]);
                        i >= 0;
                        i = node.nextSetBit(i + 1)) {
                    if (rest.get(i)) {
                        next = i;
                        branchStack[sp] = i + 1;
                        break;
                    }
                }
                if (next >= 0) {
                    sp++;
                    pathStack[sp] = next;
                    branchStack[sp] = 0;
                }
                else {
                    result[nodeNum] = serial++;
                    sp--;
                }
            }
        }
        return result;
    }

    private static BitSet[] newMatrix(int size) {
        BitSet[] loops = new BitSet[size];
        for (int i = 0; i < size; i++) {
            loops[i] = new BitSet(size);
        }
        return loops;
    }
    
    private static SortedSet<String> extract(BitSet set, String[] dic) {
        SortedSet<String> result = new TreeSet<String>();
        for (int i = set.nextSetBit(0); i >= 0; i = set.nextSetBit(i + 1)) {
            result.add(dic[i]);
        }
        return result;
    }
    
    private static String[] reverse(Map<String, Integer> dictionary) {
        String[] result = new String[dictionary.size()];
        for (Map.Entry<String, Integer> entry: dictionary.entrySet()) {
            result[entry.getValue()] = entry.getKey().replace('/', '.');
        }
        return result;
    }
}