Python – Why is the worst case for this function O(n²)?


I'm trying to teach myself how to calculate big-O notation for an arbitrary function. I found this function in a textbook. The book asserts that the function is O(n²). It gives an explanation as to why this is, but I'm struggling to follow it. I wonder if someone might be able to show me the math behind why this is so. Fundamentally, I understand that it is something less than O(n³), but I couldn't independently land on O(n²).

Suppose we are given three sequences of numbers, A, B, and C. We will
assume that no individual sequence contains duplicate values, but that
there may be some numbers that are in two or three of the sequences.
The three-way set disjointness problem is to determine if the
intersection of the three sequences is empty, namely, that there is no
element x such that x ∈ A, x ∈ B, and x ∈ C.
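
As a restatement of the problem only (this is my own sketch, not the book's algorithm), the specification in terms of Python sets would be:

def disjoint_spec(A, B, C):
    """Specification only: True iff no element is common to all three lists."""
    return not (set(A) & set(B) & set(C))

The interesting part, of course, is achieving this with plain loops and working out the cost.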

Incidentally, this is not a homework problem for me — that ship sailed years ago :) — just me trying to get smarter.

def disjoint(A, B, C):
    """Return True if there is no element common to all three lists."""
    for a in A:
        for b in B:
            if a == b:  # only check C if we found a match from A and B
                for c in C:
                    if a == c:  # (and thus a == b == c)
                        return False  # we found a common value
    return True  # if we reach this, sets are disjoint
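
For a quick sanity check (the example values are mine):

print(disjoint([1, 3, 5], [2, 3, 6], [4, 3, 9]))  # False: 3 appears in all three
print(disjoint([1, 3, 5], [2, 3, 6], [4, 7, 9]))  # True: no common element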

[Edit]
According to the textbook:

In the improved version, it is not simply that we save time if we get lucky. We claim that the worst-case running time for disjoint is O(n²).

The book's explanation, which I struggle to follow, is this:

To account for the overall running time, we examine the time spent executing each line of code. The management of the for loop over A requires O(n) time. The management of the for loop over B accounts for a total of O(n²) time, since that loop is executed n different times. The test a == b is evaluated O(n²) times. The rest of the time spent depends upon how many matching (a, b) pairs exist. As we have noted, there are at most n such pairs, and so the management of the loop over C, and the commands within the body of that loop, use at most O(n²) time. The total time spent is O(n²).
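
One way to make that line-by-line accounting concrete (a sketch of my own, not from the book) is to instrument the function so it counts how often the two critical lines execute:

def disjoint_counts(A, B, C):
    """Instrumented sketch: returns (a == b tests, C-loop iterations)."""
    ab_tests = c_iterations = 0
    for a in A:                        # loop management: O(n)
        for b in B:                    # runs n times per a: O(n^2) total
            ab_tests += 1              # a == b is tested n * n times
            if a == b:                 # at most n matching (a, b) pairs,
                for c in C:            # so this loop starts at most n times
                    c_iterations += 1  # and costs at most n * n iterations
                    if a == c:
                        return ab_tests, c_iterations
    return ab_tests, c_iterations

With A = B of length n and a C that shares nothing with them, ab_tests comes out as n² and c_iterations as n², exactly the two O(n²) terms the book adds up.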

(And to give proper credit …) The book is:
Data Structures and Algorithms in Python by Michael T. Goodrich et al., Wiley, p. 135

[Edit] As justification, below is the code before optimization:

def disjoint1(A, B, C):
    """Return True if there is no element common to all three lists."""
    for a in A:
        for b in B:
            for c in C:
                if a == b == c:
                    return False  # we found a common value
    return True  # if we reach this, sets are disjoint

In the above, you can clearly see that this is O(n³), because each loop must run to its fullest. The book would assert that in the simplified example (given first), the innermost loop contributes only O(n²) in total, so the cost works out to O(n) + O(n²) + O(n²), which ultimately yields O(n²).

While I cannot prove this is the case (hence the question), the reader can agree that the complexity of the simplified algorithm is at most that of the original.

[Edit] And to prove that the simplified version is quadratic:

import time

if __name__ == '__main__':
    for c in [100, 200, 300, 400, 500]:
        # get_random(c) is a helper (not shown) returning a list of c random numbers
        l1, l2, l3 = get_random(c), get_random(c), get_random(c)
        start = time.time()
        disjoint1(l1, l2, l3)  # original O(n^3) version
        print(time.time() - start)
        start = time.time()
        disjoint2(l1, l2, l3)  # simplified version
        print(time.time() - start)

Yields:

0.02684807777404785
0.00019478797912597656
0.19134306907653809
0.0007600784301757812
0.6405444145202637
0.0018095970153808594
1.4873297214508057
0.003167390823364258
2.953308343887329
0.004908084869384766

Since the second differences are roughly constant, the simplified function is indeed quadratic.
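
Another empirical check (my own sketch, not from the post) is to fit the slope of log(time) against log(n); a slope near 2 indicates quadratic growth:

import math

# Input sizes and the disjoint2 timings from the run above.
sizes = [100, 200, 300, 400, 500]
times = [0.00019478797912597656, 0.0007600784301757812,
         0.0018095970153808594, 0.003167390823364258,
         0.004908084869384766]

# Least-squares slope of log(t) versus log(n).
xs = [math.log(n) for n in sizes]
ys = [math.log(t) for t in times]
mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
slope = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
         / sum((x - mx) ** 2 for x in xs))
print(f"estimated exponent: {slope:.2f}")  # comes out close to 2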


[Edit] And yet even further proof:

If I assume the worst case (A = B ≠ C, so every a finds a matching b but the scan of C never hits),

if __name__ == '__main__':
    for c in [10, 20, 30, 40, 50]:
        l1, l2, l3 = range(0, c), range(0, c), range(5 * c, 6 * c)
        # instrumented variants returning the number of inner-loop
        # iterations instead of a boolean (cf. the counting sketch above)
        its1 = disjoint1(l1, l2, l3)
        its2 = disjoint2(l1, l2, l3)
        print(f"iterations1 = {its1}")
        print(f"iterations2 = {its2}")

yields:

iterations1 = 1000
iterations2 = 100
iterations1 = 8000
iterations2 = 400
iterations1 = 27000
iterations2 = 900
iterations1 = 64000
iterations2 = 1600
iterations1 = 125000
iterations2 = 2500

Using the second-difference test, the worst-case result is exactly quadratic: iterations1 grows as n³ (1000 = 10³, 8000 = 20³, and so on), while iterations2 grows as exactly n² (100 = 10², 400 = 20², and so on).


Best Answer

The book is indeed correct, and it provides a good argument. Note that timings are not a reliable indicator of algorithmic complexity. The timings might only consider a special data distribution, or the test cases might be too small: algorithmic complexity only describes how resource usage or runtime scales beyond some suitably large input size.

The book makes the argument that the complexity is O(n²) because the if a == b branch is entered at most n times. This is non-obvious because the loops are still written as nested. It becomes clearer if we restructure the code:

def disjoint(A, B, C):
  AB = (a
        for a in A
        for b in B
        if a == b)
  ABC = (a
         for a in AB
         for c in C
         if a == c)
  for a in ABC:
    return False
  return True

This variant uses generators to represent intermediate results.

  • In the generator AB, we will have at most n elements (because the input lists contain no duplicates, each a can match at most one b), and producing all of its elements takes O(n²) time.
  • Producing the elements of ABC involves a loop over the at most n elements of AB and over the n elements of C, so its algorithmic complexity is O(n²) as well.
  • These operations are not nested but happen one after the other, so the total complexity is O(n² + n²) = O(n²).

Because pairs of input lists can be checked sequentially, it follows that determining whether any fixed number of lists are disjoint can be done in O(n²) time.
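
For illustration, a sketch of that sequential pairing (the function name and the list-based filtering are my own):

def disjoint_many(*lists):
  """Sketch: True if no single element appears in every given list."""
  if not lists:
    return True
  candidates = list(lists[0])
  for other in lists[1:]:
    # One filtering pass costs O(len(candidates) * len(other)), and
    # candidates never grows, so each pass stays within O(n^2).
    candidates = [x for x in candidates if x in other]
    if not candidates:
      return True  # intersection already empty
  return False  # surviving candidates are common to all lists

Each pass is O(n²), so for a fixed number of lists the total remains O(n²).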

This analysis is imprecise because it assumes that all lists have the same length. More precisely, AB has at most length min(|A|, |B|) and producing it has complexity O(|A|•|B|). Producing ABC has complexity O(min(|A|, |B|)•|C|). The total complexity then depends on how the input lists are ordered. With |A| ≤ |B| ≤ |C| we get a total worst-case complexity of O(|A|•|C|).
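
In practice that suggests ordering the inputs by length before scanning; a one-line sketch, keeping the variable names from above:

A, B, C = sorted([A, B, C], key=len)  # shortest lists get paired first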

Note that efficiency wins are possible if the input containers allow for fast membership tests rather than having to iterate over all elements. This could be the case when they are sorted so that a binary search can be done, or when they are hash sets. Without explicit nested loops, this would look like:

for a in A:
  if a in B:  # might implicitly loop
    if a in C:  # might implicitly loop
      return False
return True

or in the generator-based version:

AB = (a for a in A if a in B)
ABC = (a for a in AB if a in C)
for a in ABC:
  return False
return True
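
And if it's acceptable to build hash sets up front (an assumption; the code above only iterates over the containers), membership tests become expected O(1) and the whole check drops to expected linear time:

def disjoint_sets(A, B, C):
  """Hash-set sketch: expected O(|A| + |B| + |C|) for hashable elements."""
  B_set, C_set = set(B), set(C)  # one-time O(|B|) and O(|C|) builds
  for a in A:
    if a in B_set and a in C_set:  # expected O(1) membership tests
      return False  # common element found
  return True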