Problem
I am trying to solve a coding challenge:
Given N integers A1, A2, …, AN, count the number of triplets (x, y, z) (with 1 ≤ x < y < z ≤ N) such that at least one of the following is true:
Ax = Ay × Az, and/or
Ay = Ax × Az, and/or
Az = Ax × Ay
Sample case 1
5 2 4 6 3 1
In Sample Case #1, the only triplet satisfying the condition given in the problem statement is (2, 4, 5). The triplet is valid since the second, fourth, and fifth integers are 2, 6, and 3, and 2 × 3 = 6. Thus the answer here is 1.
Sample case 2
2 4 8 16 32 64
The six triplets satisfying the condition given in the problem statement are: (1, 2, 3), (1, 3, 4), (1, 4, 5), (1, 5, 6), (2, 3, 5), (2, 4, 6). So the answer here is 6.
My code in Python:
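import itertools

l = [int(v) for v in input().split()]  # the N integers, space-separated
count = 0
for t in itertools.combinations(l, 3):
    # A triplet counts if any one value is the product of the other two
    if t[0]*t[1] == t[2] or t[1]*t[2] == t[0] or t[0]*t[2] == t[1]:
        count += 1
print(count)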
This is the naive approach: generate every 3-element combination and check the condition. It works fine for small inputs, but the running time grows quickly as the input size increases. For example, I assume that for the input 1, 2, 3, 6, 8 the combinations generated include (2,3,6) and (2,3,8). Since (2,3,6) satisfies the condition, checking (2,3,8) is unnecessary and could be avoided. How can I modify my code to take advantage of this observation?
Solution
Your combinations(…, 3) loop makes your algorithm O(N^3).
It’s easy to improve it to O(N^2). The question is, essentially: for every pair of entries, how many occurrences of their product are in the list? So, build an indexes data structure to help you find, in O(1) time, where the product might be located.
from collections import defaultdict
from itertools import combinations

a = [int(ai) for ai in input('Input: ').split()]

# Map each value to the set of indexes where it occurs
indexes = defaultdict(set)
for i, ai in enumerate(a):
    indexes[ai].add(i)

triplets = set()
for x, y in combinations(range(len(a)), 2):
    # Every other index holding the product completes a triplet
    for z in indexes[a[x] * a[y]].difference([x, y]):
        triplets.add(tuple(sorted((x, y, z))))
print(len(triplets))
Here, I’ve chosen to stick closer to the notation used in the challenge itself, with a being the list, and x, y, z as the indexes of entries (but 0-based rather than 1-based).
Your answer is a faithful implementation of the problem description. Unfortunately, that makes it O(N^3).
200_success has described an O(N^2) solution to this problem; I’m thinking we can do a little better, perhaps not as good as O(N log N), but maybe close.
We’ve not been asked to find the tuples (x, y, z), x < y < z, for which A[x], A[y], A[z] represent (in any order) multiplier, multiplicand, and product; we’ve been asked to find only the number of such tuples. This means we can change the order of values in A[], reindex the A[] array, or even change the representation of A[] to a list of (value, count) pairs, and still obtain the same answer.
The first step should be to sort the A[] values, which is O(N log N). This guarantees A[x] ≤ A[y] ≤ A[z] when x < y < z, and we now only need to look for tuples where A[x]*A[y]=A[z].
We can loop for x in range(N-2) and for y in range(x+1, N-1), and find the product’s indexes in the manner suggested by 200_success, but now we can break out of the inner loop when A[x]*A[y] > A_max, and break the outer loop when A[x]*A[x+1] > A_max, which should reduce the number of iterations significantly.
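A minimal sketch of the sorted loops with early termination might look like the following. It makes two assumptions that are not stated above: every value is a positive integer (zeros would break the "only look for A[x]*A[y]=A[z]" simplification), and the product is located with bisect on the sorted list rather than with the index dictionary from the earlier answer. The function name count_triplets_sorted is made up for this illustration.

```python
from bisect import bisect_left, bisect_right


def count_triplets_sorted(values):
    """Count triplets x < y < z in the sorted list with a[x]*a[y] == a[z].

    Assumes every value is a positive integer.
    """
    a = sorted(values)
    n = len(a)
    if n < 3:
        return 0
    a_max = a[-1]
    count = 0
    for x in range(n - 2):
        # Smallest product available for this x is already too large
        if a[x] * a[x + 1] > a_max:
            break
        for y in range(x + 1, n - 1):
            product = a[x] * a[y]
            if product > a_max:
                break
            # Number of occurrences of the product at positions after y
            lo = bisect_left(a, product, y + 1)
            hi = bisect_right(a, product, y + 1)
            count += hi - lo
    return count


print(count_triplets_sorted([5, 2, 4, 6, 3, 1]))       # sample case 1
print(count_triplets_sorted([2, 4, 8, 16, 32, 64]))    # sample case 2
```

For the two sample cases this prints 1 and 6.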
But again we really don’t care what the indices are which correspond to the product. We only care about how many there are. So after sorting A[], we can count the number of occurrences of each unique value, C[], and eliminate the duplicates from A[]. If C[p] > 0 for p = A[x] * A[y], then the number of tuples for that multiplier, multiplicand, product combination is C[A[x]] * C[A[y]] * C[p], if A[x] ≠ 1. When A[x]=1, then p=A[y] for all x<y, and the number of tuples for these combinations is C[1] * C[p] * (C[p]-1)/2. In addition, there are C[1] * (C[1]-1) * (C[1]-2)/6 combinations of 1*1=1. Lastly, we need to count any combinations of bases and their squares: p = A[x] * A[x], which is C[p] * C[A[x]] * (C[A[x]]-1)/2.
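The four cases above can be collected into a single expression. This is only a restatement, assuming all values are positive integers and writing c_v for C[v], the number of occurrences of the value v (with c_v = 0 when v is absent):

```latex
T \;=\; \binom{c_1}{3}
  \;+\; c_1 \sum_{v > 1} \binom{c_v}{2}
  \;+\; \sum_{v > 1} \binom{c_v}{2}\, c_{v^2}
  \;+\; \sum_{1 < u < v} c_u \, c_v \, c_{uv}
```

The terms are, in order: the (1, 1, 1) triplets, the (1, v, v) triplets, the (v, v, v^2) triplets, and the triplets built from two distinct values greater than 1 and their product.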
from collections import Counter

def tuple_products(*A):
    c = Counter(A)              # Record # of duplicates
    A = sorted(c)               # Sorted list of unique values
    largest = A[-1]             # For early loop termination
    triplets = 0                # Number of product triplets

    # Handle (1, A[y], A[y]) triplets first
    if A[0] == 1:
        c1 = c[1]               # Number of 1's
        # Number of (1, 1, 1) triplets
        triplets += c1 * (c1-1) * (c1-2) // 6
        A = A[1:]               # Remove 1 from A list and
        del c[1]                # from count dictionary
        # Number of (1, A[y], A[y]) triplets (A[y] != 1)
        triplets += c1 * sum(cy * (cy-1) for cy in c.values()) // 2

    # Handle (A[x], A[y], A[z]) triplets (A[x] != 1)
    for x, ax in enumerate(A[:-1]):
        # Break outer loop if beyond possible products
        square = ax*ax
        if square > largest:
            break
        # Number of (A[x], A[x], A[x]^2) triplets
        cx = c[ax]
        triplets += cx * (cx-1) * c[square] // 2
        # Handle (A[x], A[y], A[z]) triplets
        for ay in A[x+1:-1]:
            # Break inner loop when beyond possible products
            product = ax*ay
            if product > largest:
                break
            # Number of (A[x], A[y], A[z]) triplets
            triplets += cx * c[ay] * c[product]
    print(triplets)

tuple_products(4,4,4,4,16)
tuple_products(5,2,4,6,3,1)
tuple_products(2,4,8,16,32,64)
tuple_products(1,1,1,1)
tuple_products(8,1,1,4,2,1,4,1,2)
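One way to sanity-check the counts printed by the five calls above is to compare them with a brute-force reference that tests every 3-combination directly, as in the question. The helper name brute_force_triplets is hypothetical and introduced only for this sketch.

```python
from itertools import combinations


def brute_force_triplets(values):
    """Reference count: check all three conditions for every 3-combination."""
    return sum(
        1
        for p, q, r in combinations(values, 3)
        if p * q == r or q * r == p or p * r == q
    )


samples = [
    (4, 4, 4, 4, 16),
    (5, 2, 4, 6, 3, 1),
    (2, 4, 8, 16, 32, 64),
    (1, 1, 1, 1),
    (8, 1, 1, 4, 2, 1, 4, 1, 2),
]
for sample in samples:
    print(sample, brute_force_triplets(sample))
```

The reference counts for these inputs are 6, 1, 6, 4 and 18, which is what the tuple_products calls should print as well.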
We are still O(N^2), but we’ve made N smaller by eliminating any duplicate numbers, and eliminated many of the N^2 combinations by sorting and breaking out of the loops early. Can we do any better? I think so. Here’s why:
Consider the numbers 3, 5, 7, 8, 9, 10, 11, 14, 36, 42, 45. The outer loop would start at x=0, A[x]=3. The inner loop would start at y=1, A[y]=5, with the product product=3*5=15. Since A[] is sorted, we can perform a bisection search for the 15 and find the next number higher than 15, which is 36. Dividing 36/A[x] gives us 12; performing a bisection search and getting the next higher number retrieves 14. Multiplying A[x]*14 gives us 42, which can actually be found in the A[] list. Advancing past 14 takes us to 36, and 3*36 is greater than the maximum, so we can break out of the inner loop. Using this bisection search technique has skipped over the numbers 7, 8, 9, 10, and 11. I think this means we’ve got log N numbers in each inner loop, and a log N bisection search, giving an inner loop complexity of O(log^2 N), which combined with the outer loop gives O(N log^2 N). I think.
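The skipping idea could be sketched roughly as follows. This is an illustration under assumptions, not a tuned implementation: it works on the sorted unique values greater than 1 (so the 1s, squares, and duplicate counts handled earlier are left out), it only reports which (multiplier, multiplicand, product) combinations exist, and the function name product_pairs_by_bisection is invented here. It makes no claim about the resulting complexity.

```python
from bisect import bisect_left


def product_pairs_by_bisection(values):
    """Return (ax, ay, ax*ay) for distinct ax < ay whose product is present.

    Only values greater than 1 are considered; duplicates are ignored.
    """
    a = sorted({v for v in values if v > 1})
    n = len(a)
    found = []
    for x, ax in enumerate(a):
        y = x + 1
        while y < n:
            product = ax * a[y]
            if product > a[-1]:
                break
            # First position after y whose value is >= product
            p = bisect_left(a, product, y + 1)
            if a[p] == product:
                found.append((ax, a[y], product))
                y += 1
            else:
                # a[p] is the next value that could be a product of ax;
                # jump to the first multiplicand >= ceil(a[p] / ax)
                target = -(-a[p] // ax)
                y = bisect_left(a, target, y + 1)
    return found


print(product_pairs_by_bisection([3, 5, 7, 8, 9, 10, 11, 14, 36, 42, 45]))
```

For the example list this reports (3, 14, 42) and (5, 9, 45). With A[x]=3 the inner loop visits only 5 and 14 before the product exceeds the maximum, matching the walk-through above.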