Problem
This is just a classic implementation of a trie using python 3.6, with types. I am a PHP developer, so these are my first attempts with Python.
I know I can improve the readability of my code, it’s a bit messy, but I am struggling finding the right way of doing it. I’m sure I am missing a lot of the Pythonic ways in here 🙂
from typing import Optional
class Trie:
def __init__(self):
self.root = Node(char=None, is_word=False)
def add(self, word: str) -> None:
current = self.root
is_end_of_word = False
for i in range(0, len(word)):
char = word[i]
if i == len(word) - 1:
is_end_of_word = True
if not current.children:
node_to_insert = Node(char, is_end_of_word)
current.add_child(node_to_insert)
current = node_to_insert
continue
if char not in current.children:
node_to_insert = Node(char, is_end_of_word)
current.add_child(node_to_insert)
current = node_to_insert
else:
current = current.children[char]
def contains(self, word: str) -> bool:
current = self.root
for char in word:
if not current.children:
return False
if char in current.children:
current = current.children[char]
continue
else:
return False
if not current.is_word:
return False
return True
def remove(self, word: str) -> None:
current = self.root
for i in range(0, len(word)):
char = word[i]
is_end_of_word = False
if i == len(word) - 1:
is_end_of_word = True
if char in current.children:
if current.children[char].is_word and is_end_of_word:
current.children[char].is_word = False
return
current = current.children[char]
else:
return
def retrieve_all_words(self) -> list:
return self._retrieve_all_words(self.root, '', [])
def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
for child in current.children.values():
word = word + child.char
if child.is_word:
words.append(word)
if child.children:
self._retrieve_all_words(child, word, words)
word = word[:-1]
continue
word = word[:-1]
return words
class Node:
def __init__(self, char: Optional[str], is_word: bool):
self.char = char
self.is_word = is_word
self.children = {}
def add_child(self, node: 'Node'):
self.children[node.char] = node
Solution
Your code looks very good and the type annotation is a really nice touch
Also, your code is well organised and easy to understand. Some documentation could be a good idea (but it’s better to have no doc rather than bad doc). Also, you could consider writing unit tests.
Let’s try to see what could be improved/made more Pythonic.
Loop like a native
I highly recommand Ned Batchelder's excellent talk "loop like a native"
. One of the idea is that whenever you are using for i in range(len(iterable))
, you are probably doing it wrong. In your case, you could use for i, char in enumerate(word):
.
End of word
The end-of-word detection could be done in a single statement: is_end_of_word = i == len(word) - 1
. Also, you can get rid of the definition before the loop and even in the loops, sometimes, you could define it only behind the if char in current.children:
because you use it only there.
Reorganise the logic
Sometimes, you check if something is empty and then if it contains a particular element. This can be factorised out:
Also, simplifying the code if (cond) return False else return True
, you could rewrite contains
:
def contains(self, word: str) -> bool:
current = self.root
for char in word:
if char not in current.children:
return False
current = current.children[char]
return current.is_word
Then in _retrieve_all_words
, you can get rid of continue
by using a simple else
which makes things more explicit. Then, you can actually factorise out the common code at the end of the 2 branches and get the more simple:
if child.is_word:
words.append(word)
if child.children:
self._retrieve_all_words(child, word, words)
word = word[:-1]
Finally, you can use +=
to simplify word = word + child.char
into word += child.char
.
At this stage, the code looks like:
from typing import Optional
class Trie:
def __init__(self):
self.root = Node(char=None, is_word=False)
def add(self, word: str) -> None:
current = self.root
for i, char in enumerate(word):
if char in current.children:
current = current.children[char]
else:
is_end_of_word = i == len(word) - 1
node_to_insert = Node(char, is_end_of_word)
current.add_child(node_to_insert)
current = node_to_insert
def contains(self, word: str) -> bool:
current = self.root
for char in word:
if char not in current.children:
return False
current = current.children[char]
return current.is_word
def remove(self, word: str) -> None:
current = self.root
for i, char in enumerate(word):
if char not in current.children:
return
is_end_of_word = i == len(word) - 1
if current.children[char].is_word and is_end_of_word:
current.children[char].is_word = False
return
current = current.children[char]
def retrieve_all_words(self) -> list:
return self._retrieve_all_words(self.root, '', [])
def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
for child in current.children.values():
word += child.char
if child.is_word:
words.append(word)
if child.children:
self._retrieve_all_words(child, word, words)
word = word[:-1]
return words
class Node:
def __init__(self, char: Optional[str], is_word: bool):
self.char = char
self.is_word = is_word
self.children = {}
def add_child(self, node: 'Node'):
self.children[node.char] = node
t = Trie()
t.add("toto")
t.add("tutu")
t.add("foobar")
print(t)
print(t.retrieve_all_words())
print(t.contains("tot"))
print(t.contains("toto"))
print(t.contains("totot"))
Storing things in a different format
Instead of maintaining a is_word
attribute, maybe you could add store a sentinel like None
to say that we have a full word here. This could be written like this:
from typing import Optional
class Trie:
def __init__(self):
self.root = Node(char=None)
def add(self, word: str) -> None:
current = self.root
for char in list(word) + [None]:
if char in current.children:
current = current.children[char]
else:
node_to_insert = Node(char)
current.add_child(node_to_insert)
current = node_to_insert
def contains(self, word: str) -> bool:
current = self.root
for char in list(word) + [None]:
if char not in current.children:
return False
current = current.children[char]
return True
def remove(self, word: str) -> None:
current = self.root
for char in list(word) + [None]:
if char not in current.children:
return
elif char is None:
del current.children[char]
else:
current = current.children[char]
def retrieve_all_words(self) -> list:
return self._retrieve_all_words(self.root, '', [])
def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
for child in current.children.values():
if child.char is None:
words.append(word)
else:
word += child.char
if child.children:
self._retrieve_all_words(child, word, words)
word = word[:-1]
return words
class Node:
def __init__(self, char: Optional[str]):
self.char = char
self.children = {}
def add_child(self, node: 'Node'):
self.children[node.char] = node
t = Trie()
t.add("toto")
t.add("tutu")
t.add("foobar")
print(t)
print(t.retrieve_all_words())
print(t.contains("tot"))
print(t.contains("toto"))
print(t.contains("totot"))
t.remove("toto")
print(t.retrieve_all_words())
print(t.contains("toto"))
Back to word = word + child.char
Something I should have mentionned earlier but spotted on the late.
In _retrieve_all_words
, you add a char to word
only to remove it afterward. It is much clearer (and more efficient?) to just write word + char
in the places where you actually need to word with the added character.
Then, the code becomes:
def _retrieve_all_words(self, current: 'Node', word: str, words: list) -> list:
for child in current.children.values():
if child.char is None:
words.append(word)
elif child.children:
self._retrieve_all_words(child, word + child.char, words)
return words
this is a small thing, but the API you provide isn’t very pythonic. rather than using a contains
method, you should overload in
. similarly, rather than defining return_all_words
, you should define iteration, so you can loop directly through it, and then the list conversion will just be list(tried)
. these may seem insignificant, but this type of consistency is what makes python feel nice to use.