Re[4]: Параллельный статистический алгоритм решения задачи о сумме подмножества
От: maxkar  
Дата: 13.06.13 18:00
Оценка: 7 (1)
Здравствуйте, virusmxa, Вы писали:

V>Я не сообразил, как динамикой вытащить результат, думал в сторону сокращения перебора с помощью статистики.

V>Суммы были 100 — 1000 рублей, динамика с извлечением результирующего подмножества была бы в тему, только как?

Да запросто. Стандартное решение для большинства подобных задач — вместо битового флага "состояние достижимо" хранить информацию о том, "как пришли". Если информация отсутствует, значит "не пришли". В зависимости от пожеланий в ячейке хранится либо "любой" путь (если достаточно любого решения), либо список "предыдущих" ячеек (не пути, а именно список!). В этом случае все пути — это объединение списков от предыдущих вершин с добавленной текущей. Что-то вроде allPathsTo(x) = allPossible1StepPreceedors(x).map(p => allPathsTo(p).map(y => x +: y)).flatten (это не конкретный язык, а общая идея, ближе всего к scala по синтаксису). В олимпиадном контексте пути не строятся, а рекурсивно генерируются (т.е. выбрали 1 шаг, прошли, сгенерировали все хвосты, выбрали следующих шаг, сгенерировали еще хвосты). 99% олимпиадных задач на динамику решаются именно так.

Под катом простенькие примеры (разные), которые можно погонять.
  Скрытый текст
package ru.maxkar.java;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class Generator {

    private static final int PAR = 8;
    private static final int PAR_MIN = 300;

    private static final ExecutorService exec = Executors
            .newFixedThreadPool(PAR);

    private static void doRange(int[] cur, int[] next, int id, int size,
            int min, int max) {
        for (int i = min; i < max; i++)
            if (cur[i] >= 0)
                next[i + size] = id;
    }

    public static void fillProducts(int[] cur, int[] next, int id, int size) {
        doRange(cur, next, id, size, 0, cur.length - size);
    }

    public static void fillProductsPar(final int[] cur, final int[] next,
            final int id, final int size) {
        final int realLast = (cur.length - size);
        final int step = realLast / PAR;
        int ptr = 0;
        final List<Future<?>> futures = new ArrayList<>();
        for (int i = 0; i < PAR - 1; i++) {
            final int start = ptr;
            final int end = ptr + step;
            futures.add(exec.submit(new Runnable() {
                @Override
                public void run() {
                    doRange(cur, next, id, size, start, end);
                }
            }));
            ptr = end;
        }
        final int lastPtr = ptr;
        futures.add(exec.submit(new Runnable() {
            @Override
            public void run() {
                doRange(cur, next, id, size, lastPtr, realLast);
            }
        }));

        for (Future<?> f : futures)
            try {
                f.get();
            } catch (InterruptedException e) {
                // FIXME: !
            } catch (ExecutionException e) {
                // FIXME: !
            }
    }

    public static void fillProductsAny(int[] cur, int[] next, int id, int size) {
        final int realLast = (cur.length - size);
        if (realLast <= PAR_MIN)
            fillProducts(cur, next, id, size);
        else
            fillProductsPar(cur, next, id, size);
    }

    public static void printAnswer(int[][] ans, int sum, int[] weights) {
        if (ans[ans.length - 1][sum] == -1) {
            System.out.println("Not found");
            return;
        }

        System.out.print("Found :");
        int rest = sum;
        int ptr = ans[ans.length - 1][rest];
        while (ptr > 0) {
            final int wname = ptr - 1;
            System.out.print(" " + weights[wname]);
            rest -= weights[wname];
            ptr = ans[wname][rest];
        }
        System.out.println();
    }

    public static void printAnswerAlt(int[][] ans, int sum, int[] weights) {
        if (ans[ans.length - 1][sum] == -1) {
            System.out.println("Not found");
            return;
        }

        System.out.print("Found :");
        int rest = sum;
        for (int i = weights.length; i -- > 0;) {
            if (ans[i][rest] < 0) {
                final int w = weights[i];
                System.out.print(" " + w);
                rest -= w;
            }
        }
        System.out.println();
    }
    
    public static void qsolve(int sum, int[] weights) {
        final int[][] steps = new int[weights.length + 1][];
        final int[] start = new int[sum + 1];
        Arrays.fill(start, -1);
        start[0] = 0;
        steps[0] = start;
        
        int[] queue = new int[sum + 1];
        int queec = 1;
        queue[0] = 0;

        for (int i = 0; i < weights.length; i++) {
            final int[] next = steps[i].clone();
            final int id = i + 1;
            steps[id] = next;
            
            final int w = weights[i];
            
            final int queel = queec;
            for (int queep = 0; queep < queel; queep++) {
                final int qw = queue[queep];
                final int nw = qw + w;
                if (nw <= sum && next[nw] < 0) {
                    next[nw] = id;
                    queue[queec++] = nw;
                }
            }
        }
        printAnswer(steps, sum, weights);
    }

    public static void solve(int sum, int[] weights) {
        final int[][] steps = new int[weights.length + 1][];
        int[] current = new int[sum + 1];
        Arrays.fill(current, -1);
        current[0] = 0;
        steps[0] = current;

        for (int i = 0; i < weights.length; i++) {
            final int id = i + 1;
            final int[] next = current.clone();
            steps[id] = next;
            fillProductsAny(current, next, id, weights[i]);
            current = next;
        }
        printAnswer(steps, sum, weights);
        printAnswerAlt(steps, sum, weights);
    }

    /**
     * @param args
     */
    public static void main(String[] args) {
        solve(1, new int[] { 1, 2, 4, 6, 8, 16, 32, 64, 128, 256 });
        solve(100, new int[] { 1, 2, 4, 6, 8, 16, 32, 64, 128, 256 });
        
        solve(321230, new int[] { 16648, 28450, 13524, 12991, 13593, 4797,
                18596, 10680, 17865, 25254, 23808, 27379, 6657, 3325, 7573,
                5126, 14361, 17311, 4578, 9632, 12375, 15291, 13814, 7935,
                22872, 26001, 7895, 9854, 20650, 3317, 23612, 5258, 15092,
                1171, 7528, 26422, 19202, 23453, 18628, 4293, 15008, 24880,
                24119, 21476, 29331, 3093, 8820, 11034, 6879, 29859, 24913,
                22327, 29880, 24401, 383, 28490, 1960, 7673, 29307, 19311,
                24713, 2632, 11248, 2261, 17632, 16322, 26870, 17694, 23047,
                13155, 15627, 11247, 25238, 8480, 6236, 4343, 16317, 2668,
                11793, 9485, 10799, 4823, 16498, 23618, 14526, 12048, 18346,
                28581, 7984, 731, 17158, 26949, 21531, 23965, 17754, 12625,
                7170, 1924, 26609, 21243, 18748, 13556, 19710, 25036, 8157,
                23372, 26228, 353, 24819, 11351, 28590, 10866, 14783, 895,
                15260, 15715, 2238, 26571, 14970, 13764, 19156, 6041, 1615,
                2947, 24500, 19495, 9775, 4290, 9661, 1083 });
        exec.shutdown();
    }
}


qsolve — это не "quick solve", это "queued solve"! На моей машине показало нулевое преимущество перед обычным последовательным решением (это fillProducts вместо fillProductsAny в solve). Фокус в том, что прямой проход по массиву (doRange) очень хорошо оптимизируется процессором (prefetch работает). На очереди prefetch не работает (доступ к допустимым значениям — случайный), поэтому хотя алгоритмическая сложность в начале гораздо лучше, в результате выходит примерно так же. Приведен потому, что был написан. Ну и в некоторых случаях подобный подход все же может оказаться лучше, чем полное сканирование.

PAR — уровень параллелизма для чуть усложненного алгоритма. 8 — только потому что у меня 4 ядра с HyperThreading. И 8 там не нужно. После 4-х никакого роста не видно. Собственно, основной рост — с однопроцессорного (fillProducts всегда) до двухпроцессорного. С 2 до 4 рост слабее, чем с 1 до 2. Накладные расходы на синхронизацию получаются слишком большие (а прогон выполняется очень и очень быстро). Последний solve у меня 100 раз выполняется за 5-6 секунд (без "прогрева" JVM, я не ставил задачу хоть сколько нибудь точно сравнивать алгоритмы).

Можно еще интересное наблюдение сделать. У вас задача хорошо ложится на динамику. При этом для "восстановления" решения можно даже и не хранить точные "координаты" шага, их можно вычислить вручную при построении решения. Именно этим и занимается printAnswerAlt. Он строит некоторый достижимый путь (он старается начальные значения использовать). Подобный подход можно и для более общей формы динамики использовать. Но там может больше итераций потребоваться на поиск предыдущего шага. В общем случае если не хранить информацию о шаге, потребуется в два раза больше времени. Фактически поиск пути получения нужного "next step" может потребовать все те же шаги, что использовались при генерации матрицы достижимости, причем это только при поиске "любого" пути, все пути генерировать будет долго.

Из простоты восстановления предыдущего шага следует еще более замечательный факт. На достаточно ограниченной памяти можно считать очень сложные и объемные задачи при условии, что мы ищем "любое" решение. Смотрите. Определение: шаг — итерация внешнего цикла динамики (не процесс добавления конкретного веса к уже построенной цепочке). Для восстановления одного шага нам достаточно одного состояния и целевой вершины (при условии, что она достижима). Два шага (и какой-то путь) мы можем восстановить, сгенерировав заново матрицу для одного шага. После этого мы можем найти, из какого состояния в ней достижимо финальное, а затем (за один шаг) — как достигнуть этого промежуточного состояния. Потребуется 2*M памяти (M — память для состояния одного шага). Одно состояние — начальное, второе — через один шаг. Сколько памяти нужно для вычисления пути за 4 шага? Правильно, 3M. Вычисляем состояние через 2 шага (игнорируем промежуточное). Последние два шага мы решаем за добавочную память M (см. предыдущий шаг), и в промежуточной матрице мы знаем состояние, ведущее к цели. Теперь уже на этом интервале (начальная->промежуточная) аналогично за одну дополнительную (уже к первой!) матрице вычисляем еще два шага. Путь "на 8 шагов" вычисляется за 4*M памяти (аналогично, вычисляем сначала среднее состояние, потом шаги в каждом из двух). Общая сложность (по сравнению построить все за одних проход) возрастает линейно от количества таких "удвоений". Каждую позицию мы проходим по разу на степень двойки (на 1 итерацию, на 2, на 4, на 8 и т.д.). Т.е. по сравнению с обычным построением у нас сложность вырастает в log2(items) раз. При этом требуемая память будет порядка log2(items) * M. Если брать M порядка 100000000 (сто миллионов, это миллион рублей с точностью копеек, 100Мб памяти на "уровень" итерации), мы сможем вычислять объемы items порядка 1024-4096 на достаточно обычной машине (2-4Гб памяти). Там по моим прикидкам выходят часы/дни. При меньшей максимальной сумме меньше проходов и мы можем еще больше items посчитать (но при этом уже само количество items может стать решающим, итоговая сложность получается O(sum * items * log2(items)). Так что решение с динамикой в разумных пределах получается хорошо масштабируемо. Мне кажется, на таких числах ваше решение с перебором уйдет либо в глубокие переборы, либо у него кончится память . Ах да, на практике вместо "микрошагов" может быть выгоднее взять упакованные данные (по 8 значений в одной памяти) и убрать несколько проходов генерации результата. Но это нужно измерять на целевой машине (работа с упакованными данными медленее, чем с обычными).
 
Подождите ...
Wait...
Пока на собственное сообщение не было ответов, его можно удалить.