C++ regex golf solver program much slower than original Python

Posted on

Problem

I am rewriting Peter Norvig’s Python regex golf solver in C++. I had originally planned to post the whole thing here at once when I was finished, but:

  • This program takes double the time of the Python version on the driver input shown in main(). And that’s just to complete all the functions I have so far: in half the time, the Python program is able to generate solutions both ways. C++ is supposed to be faster than Python; is there a way to speed this up?
  • The code is starting to smell in general and it needs a checkup.

Here is what I have so far:

#include <algorithm>
#include <iostream>
#include <map>
#include <regex>
#include <set>
#include <string>
#include <sstream>
#include <utility>
#include <vector>

using std::endl;

typedef std::pair<std::string, std::set<std::string>> covers_pair;
typedef std::map<std::string, std::set<std::string>> regex_covers_t;

std::string replacements(char c)
{
    return c == '^' || c == '$'? std::string{c} : std::string{c} + '.';
}


std::set<std::string> subparts(const std::string& word, size_t subpart_size=5)
{
    std::set<std::string> subparts;
    for (size_t i = 0; i < word.size(); i++) {
        for (size_t s = 0; s < subpart_size; s++) {
            subparts.insert(word.substr(i, s + 1));
        }
    }
    return subparts;
}


std::set<std::string> dotify(const std::string& part)
{
    if (part.empty()) {
        return std::set<std::string>{""};
    }
    auto result = std::set<std::string>();
    for (const auto& rest: dotify(part.substr(1))) {
        for (const auto& c: replacements(part.front())) {
            result.insert(c + rest);
        }
    }
    return result;
}


const auto rep_chars = {'*', '+', '?'};

std::set<std::string> repetitions(const std::string& part)
{
    const auto dots = std::string("..");
    auto repetitions = std::set<std::string>();
    for (size_t i = 1, limit = part.size(); i <= limit; i++) {
        auto A = part.substr(0, i);
        auto B = part.substr(i);
        bool valid_last_char = A.back() != '^' && A.back() != '$';
        bool no_adjoining_dots = !(A.back() == '.' && B.front() == '.');
        if (valid_last_char && no_adjoining_dots &&
                !std::equal(dots.crbegin(), dots.crend(), A.crbegin()))
        {
            for (const auto& q: rep_chars) {
                repetitions.insert(A + q + B);
            }
        }
    }
    return repetitions;
}


std::string join(
        const std::set<std::string> container, const std::string& delim)
{
    if (container.empty()) {
        return "";
    }   
    std::ostringstream joiner;
    std::string joined_string;
    std::copy(
            container.cbegin(), container.cend(),
            std::ostream_iterator<std::string>(joiner, delim.c_str()));
     joined_string = joiner.str();
    // The std::copy call above will append the delimiter to the end, so
    // we now remove it.
    joined_string.erase(
            joined_string.end() - delim.size(), joined_string.end());
    return joined_string;
}


std::set<std::string> create_parts_pool(const std::set<std::string>& winners)
{
    auto wholes = std::set<std::string>();
    for (const auto& winner: winners) {
        wholes.insert('^' + winner + '$');
    }

    auto parts = std::set<std::string>();
    for (const auto& whole: wholes) {
        for (const auto& subpart: subparts(whole)) {
            for (const auto& dotified: dotify(subpart)) {
                parts.insert(dotified);
            }
        }
    }

    auto winners_str = join(winners, "");
    auto charset = std::set<char>(winners_str.begin(), winners_str.end());
    auto chars = std::set<std::string>();
    for (const auto& c: charset) {
        chars.emplace(1, c);
    }


    auto pairs = std::set<std::string>();
    for (const auto& A: chars) {
        for (const auto& B: chars) {
            for (const auto& q: rep_chars) {
                pairs.insert(A + '.' + q + B);
            }
        }
    }

    auto reps = std::set<std::string>();
    for (const auto& part: parts) {
        for (const auto& repetition: repetitions(part)) {
            reps.insert(repetition);
        }
    }

    std::set<std::string> pool;
    for (auto set: {wholes, parts, chars, pairs, reps}) {
        std::set_union(
                pool.begin(), pool.end(), set.begin(), set.end(),
                std::inserter(pool, pool.begin()));
    }

    return pool;
}

regex_covers_t regex_covers(
        const std::set<std::string>& winners,
        const std::set<std::string>& losers)
{
    const auto& pool = create_parts_pool(winners);
    regex_covers_t covers;
    for (const auto& part: pool) {
        auto re = std::regex(part);
        bool matched_loser = false;
        for (const auto& loser: losers) {
            if (regex_search(loser, re)) {
                matched_loser = true;
                break;
            }
        }
        if (matched_loser) {
            continue;
        }

        std::set<std::string> matched_winners;
        for (const auto& winner: winners) {
            if (std::regex_search(winner, re)) {
                matched_winners.insert(winner);
            }
        }

        covers[part] = matched_winners;
    }

    return covers;
}


std::vector<std::string> create_sorted_parts_list(const regex_covers_t& covers)
{
    std::vector<std::string> sorted_parts;
    for (const covers_pair& pair: covers) {
        sorted_parts.push_back(pair.first);
    }
    auto sort_lambda = [&covers](const std::string& r1, const std::string& r2)
    {
        auto r1_pair = std::make_pair(
                -static_cast<int>(covers.at(r1).size()),
                r1.size());
        auto r2_pair = std::make_pair(
                -static_cast<int>(covers.at(r2).size()),
                r2.size());
        return r1_pair < r2_pair;
    };
    std::sort(sorted_parts.begin(), sorted_parts.end(), sort_lambda);

    return sorted_parts;
}


regex_covers_t eliminate_dominated(const regex_covers_t& covers)
{
    auto covers_copy = covers;
    regex_covers_t new_covers;

    const auto& sorted_parts = create_sorted_parts_list(covers);
    for (const auto& r: sorted_parts) {
        if (covers_copy.at(r).empty()) {
            // All remaining r must not cover anything
            break;
        }

        auto is_dominated_lambda = [&covers_copy, &r](const covers_pair& pair)
        {
            const auto& r2 = pair.first;
            bool is_superset = std::includes(
                    pair.second.cbegin(), pair.second.cend(),
                    covers_copy.at(r).cbegin(), covers_copy.at(r).cend());
            bool is_shorter = r2.size() <= r.size();
            return is_superset && is_shorter && r != r2;
        };
        bool is_dominated = std::any_of(
                covers_copy.cbegin(), covers_copy.cend(), is_dominated_lambda);
        if (!is_dominated) {
            new_covers[r] = covers_copy[r];
        } else {
            covers_copy.erase(covers_copy.find(r));
        }
    }

    return new_covers;
}

/* Driver code I've used for testing. */
int main()
{
    auto twelve_sons_of_jacob = std::set<std::string>{
            "reuben", "simeon", "levi", "judah", "dan", "naphthali", "gad",
            "asher", "issachar", "zebulun", "joseph", "benjamin"};
    auto twelve_disciples = std::set<std::string>{
            "peter", "andrew", "james", "john", "philip", "bartholomew",
            "judas", "judas iscariot", "simon", "thomas", "matthew"};
    for (const auto& part: eliminate_dominated(
            regex_covers(twelve_sons_of_jacob, twelve_disciples)))
    {
        std::cout << "'" << part.first << "': [";
        for (const auto& covered: part.second) {
            std::cout << "'" << covered << "', ";
        }
        std::cout << ']' << endl;
    }
    std::cout << endl;
}

For reference, here is Norvig’s original, much more readable Python code:

rep_chars = ('*', '+', '?')
cat = ''.join


def subparts(word, subpart_size=5):
    "Return a set of subparts of word, consecutive characters up to length 5."
    return set(word[i:i+1+s] for i in range(len(word)) for s in range(subpart_size))


def dotify(part):
    "Return all ways to replace a subset of chars in part with '.'."
    if part == '':
        return {''}  
    else:
        return {c+rest for rest in dotify(part[1:]) 
                for c in replacements(part[0])}

def replacements(char):
    "Return replacement characters for char (char + '.' unless char is special)."
    if (char == '^' or char == '$'):
        return char
    else:
        return char + '.'


def repetitions(part):
    """Return a set of strings derived by inserting a single repetition character ('+' or '*' or '?') 
    after each non-special character.  Avoid redundant repetition of dots."""
    splits = [(part[:i], part[i:]) for i in range(1, len(part)+1)]
    return {A + q + B for (A, B) in splits
            if not (A[-1] in '^$')
            and not A.endswith('..')
            and not (A.endswith('.') and B.startswith('.'))
            for q in rep_chars}


def regex_covers(winners, losers):
    """
    Generate regex components and return a dict of {regex: {winner...}}.
    Each regex matches at least one winner and no loser.
    """
    losers_str = 'n'.join(losers)
    wholes = {'^'+winner+'$' for winner in winners}
    parts = {d for w in wholes for p in subparts(w) for d in dotify(p)}
    chars = set(cat(winners))
    pairs = {A+'.'+q+B for A in chars for B in chars for q in rep_chars}
    reps = {r for p in parts for r in repetitions(p)}
    pool = wholes | parts | pairs | reps                         
    searchers = [re.compile(c, re.MULTILINE).search for c in pool]
    return {r: set(filter(searcher, winners)) 
            for (r, searcher) in zip(pool, searchers)
            if not searcher(losers_str)}


def eliminate_dominated(covers):
    """Given a dict of {regex: {winner...}}, make a new dict with only the regexes
    that are not dominated by any others. A regex r is dominated by r2 if r2 covers 
    a superset of the matches covered by r, and r2 is shorter."""
    newcovers = {}
    def signature(r): return (-len(covers[r]), len(r))
    for r in sorted(covers, key=signature):
        if not covers[r]: break # All remaining r must not cover anything
        # r goes in newcache if it is not dominated by any other regex
        if not any(covers[r2] >= covers[r] and len(r2) <= len(r) 
                   for r2 in newcovers):
            newcovers[r] = covers[r]
    return newcovers

Solution

Your eliminate_dominated function is not equivalent to the original. Try this version:

regex_covers_t eliminate_dominated(const regex_covers_t& covers)
{
    regex_covers_t new_covers;

    const auto& sorted_parts = create_sorted_parts_list(covers);
    for (const auto& r: sorted_parts) {
        if (covers.at(r).empty()) {
            // All remaining r must not cover anything
            break;
        }

        auto is_dominated_lambda =
            [&covers, &r](const covers_pair& pair) {
                const auto& r2 = pair.first;
                bool is_superset = std::includes(
                        pair.second.cbegin(), pair.second.cend(),
                        covers.at(r).cbegin(), covers.at(r).cend());
                bool is_shorter = r2.size() <= r.size();
                return is_superset && is_shorter;
            };

        bool is_dominated = std::any_of(
                new_covers.cbegin(), new_covers.cend(),
                is_dominated_lambda);

        if (!is_dominated) {
            new_covers[r] = get(covers,r);
        }
    }

    return new_covers;
}

With this fixed, the C++ version runs about 3 times faster than the Python version on my system using g++ -O2.

You might consider replacing std::set with a different data structure, such as bool[] (for the chars, with appropriate typecasting) or std::unordered_set. The problem with std::set is that it guarantees to be able to iterate through the set in sorted order.

Have you profiled it? Where is it spending its time?

Leave a Reply

Your email address will not be published. Required fields are marked *