Trie implementation with Python 3

Posted on

Problem

This is just a classic implementation of a trie using python 3.6, with types. I am a PHP developer, so these are my first attempts with Python.

I know I can improve the readability of my code, it’s a bit messy, but I am struggling finding the right way of doing it. I’m sure I am missing a lot of the Pythonic ways in here 🙂

from typing import Optional


class Trie:
    def __init__(self):
        self.root = Node(char=None, is_word=False)

    def add(self, word: str) -> None:
        current = self.root
        is_end_of_word = False

        for i in range(0, len(word)):
            char = word[i]
            if i == len(word) - 1:
                is_end_of_word = True

            if not current.children:
                node_to_insert = Node(char, is_end_of_word)
                current.add_child(node_to_insert)
                current = node_to_insert
                continue

            if char not in current.children:
                node_to_insert = Node(char, is_end_of_word)
                current.add_child(node_to_insert)
                current = node_to_insert
            else:
                current = current.children[char]

    def contains(self, word: str) -> bool:
        current = self.root

        for char in word:
            if not current.children:
                return False

            if char in current.children:
                current = current.children[char]
                continue
            else:
                return False

        if not current.is_word:
            return False

        return True

    def remove(self, word: str) -> None:
        current = self.root

        for i in range(0, len(word)):
            char = word[i]
            is_end_of_word = False

            if i == len(word) - 1:
                is_end_of_word = True

            if char in current.children:
                if current.children[char].is_word and is_end_of_word:
                    current.children[char].is_word = False
                    return
                current = current.children[char]
            else:
                return

    def retrieve_all_words(self) -> list:
        return self._retrieve_all_words(self.root, '', [])

    def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
        for child in current.children.values():
            word = word + child.char
            if child.is_word:
                words.append(word)

            if child.children:
                self._retrieve_all_words(child, word, words)
                word = word[:-1]
                continue

            word = word[:-1]

        return words


class Node:
    def __init__(self, char: Optional[str], is_word: bool):
        self.char = char
        self.is_word = is_word
        self.children = {}

    def add_child(self, node: 'Node'):
        self.children[node.char] = node

Solution

Your code looks very good and the type annotation is a really nice touch
Also, your code is well organised and easy to understand. Some documentation could be a good idea (but it’s better to have no doc rather than bad doc). Also, you could consider writing unit tests.

Let’s try to see what could be improved/made more Pythonic.

Loop like a native

I highly recommand Ned Batchelder's excellent talk "loop like a native". One of the idea is that whenever you are using for i in range(len(iterable)), you are probably doing it wrong. In your case, you could use for i, char in enumerate(word):.

End of word

The end-of-word detection could be done in a single statement: is_end_of_word = i == len(word) - 1. Also, you can get rid of the definition before the loop and even in the loops, sometimes, you could define it only behind the if char in current.children: because you use it only there.

Reorganise the logic

Sometimes, you check if something is empty and then if it contains a particular element. This can be factorised out:

Also, simplifying the code if (cond) return False else return True, you could rewrite contains:

def contains(self, word: str) -> bool:
    current = self.root
    for char in word:
        if char not in current.children:
            return False
        current = current.children[char]
    return current.is_word

Then in _retrieve_all_words, you can get rid of continue by using a simple else which makes things more explicit. Then, you can actually factorise out the common code at the end of the 2 branches and get the more simple:

        if child.is_word:
            words.append(word)
        if child.children:
            self._retrieve_all_words(child, word, words)
        word = word[:-1]

Finally, you can use += to simplify word = word + child.char into word += child.char.

At this stage, the code looks like:

from typing import Optional

class Trie:
    def __init__(self):
        self.root = Node(char=None, is_word=False)

    def add(self, word: str) -> None:
        current = self.root

        for i, char in enumerate(word):
            if char in current.children:
                current = current.children[char]
            else:
                is_end_of_word = i == len(word) - 1 
                node_to_insert = Node(char, is_end_of_word)
                current.add_child(node_to_insert)
                current = node_to_insert

    def contains(self, word: str) -> bool:
        current = self.root
        for char in word:
            if char not in current.children:
                return False
            current = current.children[char]
        return current.is_word

    def remove(self, word: str) -> None:
        current = self.root
        for i, char in enumerate(word):
            if char not in current.children:
                return
            is_end_of_word = i == len(word) - 1 
            if current.children[char].is_word and is_end_of_word:
                current.children[char].is_word = False
                return
            current = current.children[char]

    def retrieve_all_words(self) -> list:
        return self._retrieve_all_words(self.root, '', [])

    def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
        for child in current.children.values():
            word += child.char
            if child.is_word:
                words.append(word)
            if child.children:
                self._retrieve_all_words(child, word, words)
            word = word[:-1]
        return words


class Node:
    def __init__(self, char: Optional[str], is_word: bool):
        self.char = char
        self.is_word = is_word
        self.children = {}

    def add_child(self, node: 'Node'):
        self.children[node.char] = node

t = Trie()
t.add("toto")
t.add("tutu")
t.add("foobar")
print(t)
print(t.retrieve_all_words())
print(t.contains("tot"))
print(t.contains("toto"))
print(t.contains("totot"))

Storing things in a different format

Instead of maintaining a is_word attribute, maybe you could add store a sentinel like None to say that we have a full word here. This could be written like this:

from typing import Optional

class Trie:
    def __init__(self):
        self.root = Node(char=None)

    def add(self, word: str) -> None:
        current = self.root

        for char in list(word) + [None]:
            if char in current.children:
                current = current.children[char]
            else:
                node_to_insert = Node(char)
                current.add_child(node_to_insert)
                current = node_to_insert

    def contains(self, word: str) -> bool:
        current = self.root
        for char in list(word) + [None]:
            if char not in current.children:
                return False
            current = current.children[char]
        return True 

    def remove(self, word: str) -> None:
        current = self.root
        for char in list(word) + [None]:
            if char not in current.children:
                return
            elif char is None:
                del current.children[char]
            else:
                current = current.children[char]

    def retrieve_all_words(self) -> list:
        return self._retrieve_all_words(self.root, '', [])

    def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
        for child in current.children.values():
            if child.char is None:
                words.append(word)
            else:
                word += child.char
                if child.children:
                    self._retrieve_all_words(child, word, words)
                word = word[:-1]
        return words


class Node:
    def __init__(self, char: Optional[str]):
        self.char = char
        self.children = {}

    def add_child(self, node: 'Node'):
        self.children[node.char] = node

t = Trie()
t.add("toto")
t.add("tutu")
t.add("foobar")
print(t)
print(t.retrieve_all_words())
print(t.contains("tot"))
print(t.contains("toto"))
print(t.contains("totot"))
t.remove("toto")
print(t.retrieve_all_words())
print(t.contains("toto"))

Back to word = word + child.char

Something I should have mentionned earlier but spotted on the late.

In _retrieve_all_words, you add a char to word only to remove it afterward. It is much clearer (and more efficient?) to just write word + char in the places where you actually need to word with the added character.

Then, the code becomes:

def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
    for child in current.children.values():
        if child.char is None:
            words.append(word)
        elif child.children:
            self._retrieve_all_words(child, word + child.char, words)
    return words

this is a small thing, but the API you provide isn’t very pythonic. rather than using a containsmethod, you should overload in. similarly, rather than defining return_all_words, you should define iteration, so you can loop directly through it, and then the list conversion will just be list(tried). these may seem insignificant, but this type of consistency is what makes python feel nice to use.

Leave a Reply

Your email address will not be published. Required fields are marked *