Problem
For a programming challenge I was doing, I needed to traverse a list of lists. While writing the standard 2D traversal function, I thought: why not generalise it to arbitrarily nested lists? So I did. Here’s my best effort:
My Code
```python
def rTraverse(lst, f=lambda x: x):
    """Traverse a list and any of its (non-self-referential) sublists recursively."""
    for item in lst:
        if isinstance(item, list):
            if "..." in str(item):  # Skip self-referential lists.
                continue
            rTraverse(item, f)  # Traverse the sublist.
        else:
            f(item)
```
Solution
The way you keep track of lists which you are currently exploring is a bit inefficient. You basically have a stack which contains the first argument of the function for each recursive call. This exact information is also stored on the execution stack of the Python interpreter, so it’s duplicated.
A more efficient approach would be to abandon recursion and use a loop instead, since you are already maintaining a stack. That way, you also save some space by not duplicating the constant (non-changing) arguments `f` and `seen` on the execution stack.
Another, even more significant way to improve efficiency comes from observing that the check `if lst[idx] not in seen` takes time linear in the size of `seen`, although it can be done in constant time.
1. Removing recursion
We can start with the existing stack of lists and extend it as necessary. Since we are basically doing a depth-first search, when we descend to a deeper level we need to remember the list we are currently iterating over, as well as our position within that list (so we can later return and continue where we left off). This can be done, for example, with two lists, one containing the lists and the other containing the positions. Or we can have just one list containing 2-tuples of the form `(list, position)`.
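Here is a minimal sketch of this idea. It assumes a single stack of `[list, position]` pairs, using lists rather than tuples so the saved position can be updated in place. `traverse_iterative` is a hypothetical name. Its cycle check still scans the stack item by item (by identity), which is exactly the linear-time check discussed next.

```python
def traverse_iterative(lst, f=lambda x: x):
    """Depth-first, in-order traversal without recursion (a sketch).

    Each stack entry is a mutable [list, position] pair; the position records
    where to resume in that list after a deeper sublist has been finished.
    Self-referential sublists are skipped.
    """
    stack = [[lst, 0]]
    while stack:
        top = stack[-1]
        curr_list, pos = top
        if pos == len(curr_list):
            stack.pop()  # done with this list: return to the previous level
            continue
        top[1] = pos + 1  # save where to continue later
        item = curr_list[pos]
        if isinstance(item, list):
            # Linear-time cycle check: scan the stack, comparing by identity.
            if not any(item is frame[0] for frame in stack):
                stack.append([item, 0])  # descend depth-first
        else:
            curr_list[pos] = f(item)


data = [1, [2, [3]], 4]
traverse_iterative(data, lambda x: x * 10)
print(data)  # [10, [20, [30]], 40]
```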
2. Faster cyclic reference detection
Instead of searching through the stack item by item, we can use a `set`, as you suggested in a comment. This allows us to detect cyclic references in constant time (on average), and the whole issue is settled.
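A rough timing sketch with the standard `timeit` module makes the difference concrete. It uses integers, since lists themselves cannot be put in a set. The function name and sizes are arbitrary, and absolute numbers will vary by machine. The list check should slow down as `n` grows, while the set check should stay roughly flat.

```python
import timeit


def membership_timings(n, repeats=100):
    """Time `x in container` for a list and a set holding n integers.

    The probe value n is absent, so the list check has to scan all n items,
    while the set check is a single hash lookup on average.
    """
    as_list = list(range(n))
    as_set = set(as_list)
    t_list = timeit.timeit(lambda: n in as_list, number=repeats)
    t_set = timeit.timeit(lambda: n in as_set, number=repeats)
    return t_list, t_set


if __name__ == "__main__":
    for n in (10, 1_000, 100_000):
        t_list, t_set = membership_timings(n)
        print(f"n={n:>7}: list {t_list:.6f}s  set {t_set:.6f}s")
```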
An even better approach would be to combine the position tracking with fast membership checking. We could have a list of lists (the stack) and a dictionary mapping each list to our position within that list. A list is then a key in the dictionary if and only if it is in the stack. Adding, updating and deleting entries in a hashed dictionary are all (average-case) constant-time operations, and so is checking whether a key is present. So we achieve constant-time cycle detection, much like we would with a `set`, but in a more elegant way.
One “problem” here may be this: if the same list is in the stack twice, with two different positions, which position do we put in the dictionary? Or can we somehow store both? The answer is that we don’t care, because this never happens. Since we don’t want to visit the same list twice on our path down the tree, the stack never contains the same list twice. The only remaining problem is that Python doesn’t allow us to use `list` objects as keys in a `dict`.
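For illustration, here is a small sketch of what that restriction looks like in practice. `try_list_as_key` is a made-up helper. The exact error message differs between Python versions, so the code only checks that the message mentions an unhashable type.

```python
def try_list_as_key():
    """Return the TypeError raised when a list is used as a dict key."""
    positions = {}
    current = [1, 2, 3]
    try:
        positions[current] = 0  # lists define no hash, so this fails
    except TypeError as exc:
        return exc
    return None


err = try_list_as_key()
print(type(err).__name__)  # TypeError
```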
3. Using lists as keys in a dictionary
Since `list` is not a hashable type, it cannot be used as a key in a `dict` (which is a hashed dictionary). However, the solution is quite simple: we use `id(lst)` instead of `lst` itself as the key. We can do this since we only care about identity, not equality. As a side note, this is another case in which your program behaves incorrectly: it compares lists with the items in `seen` based on equality, not identity (`if lst[idx] not in seen: …`). Consider the following code:
```python
>>> a = []
>>> a.append(a)
>>> b = [a]
>>> a == b
True
>>> id(a) == id(b)
False
>>> a in [b]
True
```
What should happen if you call `rTraverse(b)`? I suppose you would want to traverse both `a` and `b`.
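Here is a small sketch of the `id`-based dictionary, reusing `a` and `b` from the session above: the two lists compare equal, yet they get separate entries. One caveat: `id()` values are only guaranteed to be unique among objects that exist at the same time. That holds during a traversal, because every list on the stack is still referenced by the structure being traversed. The `positions` name and the stored values are illustrative.

```python
a = []
a.append(a)  # a contains itself
b = [a]      # b == a, but b is a different object

positions = {}        # maps id(list) -> saved position within that list
positions[id(b)] = 1  # e.g. exploring b, already past index 0
positions[id(a)] = 0  # ...and having just descended into a

print(a == b)              # True: equality cannot tell them apart
print(len(positions))      # 2: identity keeps them apart
print(id(a) in positions)  # True
```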
Refactored code
The code contains a (very basic) implementation of a custom data structure that facilitates the approach proposed above. This data structure is then used in the modified `traverse` function.
```python
class CustomStack(list):
    """
    Our custom data structure to facilitate position tracking
    and provide constant-time membership checking for cycle detection.
    This data structure serves as the stack, hence it inherits from `list`.
    It has an `item_pos` attribute which serves as the dictionary
    (keyed by `id()`, because lists are not hashable).
    """
    def __init__(self, *args):
        # each argument is an (item, pos) pair
        super().__init__(item for item, pos in args)
        self.item_pos = {id(item): pos for item, pos in args}

    def append(self, item, pos):
        if item not in self:
            super().append(item)
            self.item_pos[id(item)] = pos

    def pop(self):
        del self.item_pos[id(self.top)]
        return super().pop()

    def __contains__(self, item):
        # identity-based, constant-time membership check
        return id(item) in self.item_pos

    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        pos = self.item_pos[id(item)]
        return item, pos

    @property
    def top(self):
        return super().__getitem__(-1)

    def setpos(self, pos):
        """
        This function allows us to update the saved position within the list
        which is currently being explored (= is last in the stack).
        """
        self.item_pos[id(self.top)] = pos


def traverse(lst, f=lambda x: x):
    stk = CustomStack((lst, 0))  # initial position in the list is 0
    while stk:
        curr_list, curr_pos = stk[-1]  # don't pop the top of the stack yet
        for idx in range(curr_pos, len(curr_list)):  # continue from the saved position
            item = curr_list[idx]
            if isinstance(item, list):
                if item not in stk:
                    stk.setpos(idx + 1)  # update the current position to restore it later
                    stk.append(item, 0)  # push the new list onto the stack
                    break  # we are going depth-first into the new, deeper list
                # otherwise it is a cyclic reference: leave it untouched
            else:
                curr_list[idx] = f(item)
        else:
            # we did not break out of the `for` loop. that means we're done
            # with this list and we are returning to the previous level.
            stk.pop()
```
Recursive version
Since the code got a bit too long for such a simple task, we can try going back to recursion and using a `set`, as you proposed. As with the dictionary, we have to store `id(lst)` instead of `lst`, because items in a `set` must be hashable. Considering AJNeufeld’s remarks on thread and exception safety, we can use a dummy default value for the `seen` parameter. If `seen` has this dummy value (as opposed to a “real” value, e.g. a `set` instance), we create a `set` instance as a local variable in the function’s stack frame. When calling the function recursively, we pass a reference to this local variable, so no thread-safety or exception-safety issues arise.
```python
def traverse_simple(lst, f=lambda x: x, seen=None):
    if seen is None:
        seen = set()  # a fresh set for every top-level call
    seen.add(id(lst))
    for i, item in enumerate(lst):
        if isinstance(item, list):
            if id(item) not in seen:  # skip cyclic references
                traverse_simple(item, f, seen)
        else:
            lst[i] = f(item)
    seen.remove(id(lst))
```
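The dummy default matters, and a few lines show why. A mutable default such as `seen=set()` is evaluated once, when the `def` statement runs. It is then shared by every call that omits the argument, so state from one top-level call leaks into the next. The `None` sentinel gives each top-level call its own fresh set. Both functions below are hypothetical and exist only to show the difference.

```python
def count_seen_shared(lst, seen=set()):
    # The default set is created once and reused by every call.
    seen.add(id(lst))
    return len(seen)


def count_seen_fresh(lst, seen=None):
    # A new set is created for every top-level call.
    if seen is None:
        seen = set()
    seen.add(id(lst))
    return len(seen)


x, y = [1], [2]  # keep both alive so their ids are distinct
print(count_seen_shared(x), count_seen_shared(y))  # 1 2: state leaked between calls
print(count_seen_fresh(x), count_seen_fresh(y))    # 1 1: calls are independent
```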
By using `id(lst)`, we can not only use a set and therefore detect cyclic references in constant time, but also compare lists by identity instead of equality. This means `traverse_simple` handles inputs on which your original function fails, such as this one:
```python
>>> a = []
>>> b = [a]
>>> a.append(b)
```
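Here is a quick sketch of why equality-based membership fails on this input. `b in [a]` compares `b == a`, which compares their elements `a == b`, and so on, until Python raises `RecursionError`. The identity check on `id()` values never looks inside the lists.

```python
a = []
b = [a]
a.append(b)  # a and b now contain each other

try:
    _ = b in [a]  # equality-based: b == a -> a == b -> b == a -> ...
    outcome = "no error"
except RecursionError:
    outcome = "RecursionError"

print(outcome)           # RecursionError
print(id(b) in {id(a)})  # False: identity-based, constant-time check
```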
Well, one failure case I didn’t consider is that sublists containing strings with `...` in them are mistaken for self-referential lists and skipped. Furthermore, converting a list to a string is an expensive operation, since `str()` walks the entire sublist at every level of the recursion. After consulting with people outside SE, I was able to refine the algorithm further:
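The `...` heuristic relies on how `str()` renders a self-referential list: the repeated part is printed as `[...]`. The sketch below reproduces the question’s `rTraverse` so the example runs on its own. It uses made-up data to show both that rendering and the false positive, where an ordinary string containing `...` causes its whole sublist to be skipped.

```python
def rTraverse(lst, f=lambda x: x):
    # The question's approach: skip any sublist whose str() contains "...".
    for item in lst:
        if isinstance(item, list):
            if "..." in str(item):
                continue
            rTraverse(item, f)
        else:
            f(item)


cyclic = [1]
cyclic.append(cyclic)
print(str(cyclic))  # [1, [...]]

visited = []
rTraverse([0, ["wait...", 2]], visited.append)
print(visited)  # [0]: the whole sublist was skipped, including the 2
```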
Refactored Version
```python
def rTraverse2(lst, f, seen=None):
    """
    * A faster version of `rTraverse()` that uses a set instead of a stack to track which lists have been seen.
    * Searching a set is O(1) in the average case, while searching a list is O(n), so `rTraverse2()` runs faster.
    * Params:
        * `lst`: the list to be traversed.
        * `f`: the function to be applied to the list items.
        * `seen`: a set of the `id()`s of every list already scheduled for traversal
          (used to skip self-referential sublists).
    * Return:
        * None: the list is modified in place.
    * Caveats:
        * The function no longer traverses in order.
        * A sublist that occurs more than once is only traversed once.
    * Credits:
        * The main insight(s) for the algorithm come from "raylu" on the [programming discord](https://discord.gg/010z0Kw1A9ql5c1Qe)
    """
    if seen is None:
        seen = {id(lst)}  # initialise the set; the root list counts as seen
    toRecurse = []  # the sublists to be recursed on
    for idx in range(len(lst)):
        if isinstance(lst[idx], list):
            if id(lst[idx]) not in seen:  # traverse only non-self-referential sublists
                toRecurse.append(lst[idx])  # schedule the sublist for recursion
                seen.add(id(lst[idx]))  # mark it now, so a repeated sibling is not scheduled twice
        else:
            lst[idx] = f(lst[idx])
    for item in toRecurse:  # traverse all sublists in `toRecurse`
        rTraverse2(item, f, seen)
```