Problem
I have the following code, which finds every possible pair of numbers that sum up to a given number:
lst = [int(input()) for i in range(int(input()))]
num = int(input())  # Amount to be matched
cnt = 0
lst = list(filter(lambda i: i <= num, lst))  # Remove numbers greater than `num`
for i in lst:
    for j in lst:
        if i + j == num:
            cnt += 1
print(int(cnt / 2))
For example, if I enter
5 #How many numbers
1 < #Start of number input
4 <
5
7
1 < #End of number input
5 #Amount to be matched
It will return 2 because there are two pairs whose sum equals 5: (1, 4) and (4, 1) (the numbers I marked with <).
The problem is that the complexity is O(n²), which will run slowly on large inputs. Is there a way to make this run faster?
Another example:
10 #How many numbers
46 #Start of number input
35
27
45
16
0 <
30 <
30 <
45
37 #End of number input
30 #Amount to be matched
The pairs will be (0, 30) and (0, 30), so it will return 2.
Solution
Removing numbers greater than the target affects the correctness of this program – they could be added to negative numbers to reach the target. For example, with a target of 5, the filter drops 7 and so misses the pair (-2, 7).
You could get a performance improvement by finding the difference between each element and the target, then looking to see if that difference is in the list. This doesn’t in itself reduce the computational complexity (it’s still O(n²)), but we can build on that: if the list is first sorted and we then use a binary search to test membership, we get to O(n log n); if we convert to a structure with fast lookup (such as a collections.Counter, which has amortized O(1) insertion and lookup), then we come down to O(n).
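A minimal sketch of the sorted/binary-search variant, using the standard bisect module. The function name count_pairs_sorted is made up for illustration. It counts unordered pairs of distinct positions: each element looks up how many copies of its complement exist, and the total is halved because every pair is seen from both ends.

```python
from bisect import bisect_left, bisect_right


def count_pairs_sorted(target, numbers):
    """Count unordered pairs of positions whose values sum to target (O(n log n))."""
    ordered = sorted(numbers)  # O(n log n)
    ordered_matches = 0
    for x in ordered:
        want = target - x
        # Two binary searches give the number of copies of `want`: O(log n)
        matches = bisect_right(ordered, want) - bisect_left(ordered, want)
        if want == x:
            matches -= 1  # an element must not be paired with itself
        ordered_matches += matches
    # Each unordered pair was counted once from each of its two elements
    return ordered_matches // 2
```

For instance, count_pairs_sorted(5, [1, 4, 5, 7, 1]) returns 2, matching the question's first example.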
If we have a Counter, then we can account for all combinations of that pair by multiplying one count by the other (but we’ll need to consider the special case that the number is exactly half the target).
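A minimal sketch of the Counter version. The function name count_pairs and its (target, numbers) signature are illustrative, not from the original. Each pair of distinct values is counted once by only adding when value < complement; a value that is exactly half the target pairs with its own copies, giving count * (count - 1) // 2 combinations.

```python
from collections import Counter


def count_pairs(target, numbers):
    """Count unordered pairs of positions whose values sum to target (O(n))."""
    counts = Counter(numbers)  # amortized O(1) per insertion
    total = 0
    for value, count in counts.items():
        complement = target - value
        if complement == value:
            # Special case: choose 2 of the `count` copies of this value
            total += count * (count - 1) // 2
        elif value < complement:
            # Every copy of `value` pairs with every copy of `complement`;
            # the value < complement check stops the pair being counted twice
            total += count * counts.get(complement, 0)
    return total
```

For instance, count_pairs(30, [46, 35, 27, 45, 16, 0, 30, 30, 45, 37]) returns 2, matching the question's second example.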
We could do with some automated tests. Consider importing the doctest module and using it. Some good test cases to include (target, list → expected number of pairs):
1, [] → 0
1, [1] → 0
1, [0, 1] → 1
0, [-1, 1] → 1
0, [0, 1] → 0
4, [1, 4, 3, 0] → 2
4, [1, 1, 3, 3] → 4
4, [2, 2, 2, 2] → 6
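As a sketch of how those cases could be written as doctests, here they are attached to a deliberately simple (and O(n²)) reference function built on itertools.combinations. The name count_pairs_reference is hypothetical. A slow but obviously correct version like this is handy for checking a faster one against it.

```python
import doctest
from itertools import combinations


def count_pairs_reference(target, numbers):
    """Count unordered pairs of distinct positions whose values sum to target.

    >>> count_pairs_reference(1, [])
    0
    >>> count_pairs_reference(1, [1])
    0
    >>> count_pairs_reference(1, [0, 1])
    1
    >>> count_pairs_reference(0, [-1, 1])
    1
    >>> count_pairs_reference(0, [0, 1])
    0
    >>> count_pairs_reference(4, [1, 4, 3, 0])
    2
    >>> count_pairs_reference(4, [1, 1, 3, 3])
    4
    >>> count_pairs_reference(4, [2, 2, 2, 2])
    6
    """
    # combinations() yields each pair of positions exactly once
    return sum(1 for a, b in combinations(numbers, 2) if a + b == target)


if __name__ == "__main__":
    doctest.testmod()  # silent when every example passes
```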
So far every solution given has been O(n²) or O(n log n), but there is an O(n) solution, which is sketched as follows:
- Get your input into a list, as you have done. Obviously this is O(n).
- Create an empty map from integers to integers. The map must have an O(1) insertion operation, an O(1) contains-key operation, an O(1) lookup operation and an O(n) “enumerate all the keys” operation. Commonly used map types in modern programming languages typically have these characteristics.
- Build a count of all the items in the input. That is, for each input item, check whether it is in the map. If it is not, add the pair (item, 1) to the map. If it is already in the map, look up the associated value and change the map so that it has the pair (item, value + 1). All those operations are O(1) and we do them n times, so this step is O(n).
- Now take the target, call it sum, and enumerate the pairs which add up to it. Enumerate the keys of the map. Suppose the key is k. Compute sum-k. Now there are two cases:
  - Case 1: if sum-k == k, check the map to see whether the value associated with k is 2 or greater. If it is, then we have a pair (k, k). Output it.
  - Case 2: if sum-k is not k, check whether sum-k is in the map. If it is, then we have a pair (k, sum-k). The same pair is found again when the enumeration reaches the key sum-k, so output it only when k < sum-k to avoid reporting it twice.
- The enumeration visits at most n keys, and each step is O(1), so this step is also O(n).
- And we’re done, with total cost O(n).
Now, can you implement this solution?
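One possible Python rendering of that sketch, using a built-in dict as the map (its insertion, membership and lookup operations are O(1) on average). The function name pairs_summing_to is hypothetical. It returns the distinct value pairs, as the sketch does, rather than a count.

```python
def pairs_summing_to(total, numbers):
    """Return each distinct pair of values (k, total - k) present in numbers."""
    counts = {}
    # Step 1: build the count map, O(1) work per item
    for item in numbers:
        counts[item] = counts.get(item, 0) + 1
    pairs = []
    # Step 2: enumerate the keys, O(1) work per key
    for k in counts:
        other = total - k
        if other == k:
            # Case 1: k pairs with itself only if it appears at least twice
            if counts[k] >= 2:
                pairs.append((k, k))
        elif other in counts and k < other:
            # Case 2: report each pair of distinct values once (when k < other)
            pairs.append((k, other))
    return pairs
```

For example, pairs_summing_to(5, [1, 4, 5, 7, 1]) returns [(1, 4)]. To get the count the question asks for instead, add counts[k] * counts[other] in case 2 and counts[k] * (counts[k] - 1) // 2 in case 1.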
Other minor suggestions:
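- Don’t leave your input() calls blank. Pass a prompt so that the user knows what they’re entering.
- The first time you initialize lst, you don’t need to build a list in memory; you can use a generator expression (parentheses instead of brackets). Keep in mind that a generator is lazy: its input() calls only run when it is consumed, so it must be consumed before the target is read, or the values will be read in the wrong order.
- The second time you initialize lst, it does not need to be mutable, so make it a tuple instead of a list.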
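A minimal sketch combining these suggestions (prompts, a generator expression, a tuple). The function name read_problem and its read parameter are hypothetical; read defaults to the built-in input and exists only so the function can be exercised without a keyboard. Wrapping the generator in tuple() consumes it immediately, so the numbers are still read before the target. It omits the filter on numbers greater than the target, since that filter breaks inputs containing negative numbers.

```python
def read_problem(read=input):
    """Read the count, the numbers and the target, prompting for each value.

    `read` is a hypothetical hook that defaults to the built-in input().
    """
    count = int(read("How many numbers? "))
    # tuple() consumes the generator expression right away, so all the
    # numbers are read before the target, preserving the input order.
    numbers = tuple(int(read(f"Number {i + 1} of {count}: ")) for i in range(count))
    target = int(read("Amount to be matched: "))
    return numbers, target
```

Interactively it would be called as numbers, target = read_problem().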
You can first sort the list; this can be done (in many ways) in O(n log n).
Once the list is sorted (smallest to largest, for example), the problem can be solved in O(n) as follows:
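lst.sort()  # O(n log n)
cnt = 0
head = 0
tail = len(lst) - 1
while head < tail:
    total = lst[head] + lst[tail]  # avoid shadowing the built-in sum
    if total == num:
        cnt += 1
        head += 1
        tail -= 1
    elif total > num:
        tail -= 1  # sum too large: move the right end left
    else:
        head += 1  # sum too small: move the left end right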
This results in an overall complexity of O(n log n), since the sort dominates the linear scan.
Note: this assumes that all elements are unique. It can be easily fixed depending on how you want to handle cases with duplicate elements.
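One way to handle duplicates, sketched under the assumption that we want the same count as the question (unordered pairs of positions). The function name count_pairs_two_pointer is made up for illustration. When the two ends sum to the target, it measures the run of equal values at each end and multiplies the run lengths; if both ends hold the same value, every element between them is equal and any two of them form a pair.

```python
def count_pairs_two_pointer(target, numbers):
    """Two-pointer pair count over a sorted copy, handling duplicate values."""
    ordered = sorted(numbers)  # O(n log n)
    head, tail = 0, len(ordered) - 1
    count = 0
    while head < tail:
        total = ordered[head] + ordered[tail]
        if total < target:
            head += 1
        elif total > target:
            tail -= 1
        elif ordered[head] == ordered[tail]:
            # Everything from head to tail is the same value (half the target)
            run = tail - head + 1
            count += run * (run - 1) // 2
            break
        else:
            # Skip the run of equal values at each end, counting its length;
            # the runs cannot overlap because the two end values differ
            left_value, right_value = ordered[head], ordered[tail]
            left_run = 0
            while ordered[head] == left_value:
                head += 1
                left_run += 1
            right_run = 0
            while ordered[tail] == right_value:
                tail -= 1
                right_run += 1
            # Every left copy pairs with every right copy
            count += left_run * right_run
    return count
```

The scan after sorting is still O(n), because each pointer only moves inward.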