TDD approach for Advent of Code challenge: infinite spiral on a grid


Problem

I started Advent of Code 2017 to learn Python and test-driven development (TDD). I wrote a solution for Day 3, Part 2, and would like some feedback.

Specifically:

  • Do I apply TDD correctly?
  • Are there OOP practices I am missing?
  • Are there Python-specific issues in my code?

I explain the nomenclature I used (rings, chains, etc.) at the bottom of the post.

Here is the test file:

# Just run this file with 'python 2_tests.py'
import unittest
from day3_2 import Grid

class GridTest(unittest.TestCase):

    def setUp(self):
        self.grid = Grid(rings=4)

    def test_get_coord_by_id(self):
        self.assertEqual(self.grid.get_coord_by_i(0), [0, 0])
        self.assertEqual(self.grid.get_coord_by_i(17), [-1, -2])        

    def test_first_element(self):
        """
        This tests that the first element in the Ulam spiral is 1
        """
        self.assertEqual(self.grid.get_elem_by_i(0), 1)

    def test_get_elem_by_i(self):
        self.assertEqual(self.grid.get_elem_by_i(1), 1)
        self.assertEqual(self.grid.get_elem_by_i(2), 2)
        self.assertEqual(self.grid.get_elem_by_i(3), 4)
        self.assertEqual(self.grid.get_elem_by_i(4), 5)
        self.assertEqual(self.grid.get_elem_by_i(5), 10)
        self.assertEqual(self.grid.get_elem_by_i(6), 11)
        self.assertEqual(self.grid.get_elem_by_i(7), 23)

    def test_get_chain(self):
        self.assertEqual(self.grid.get_chain(1), [1])
        self.assertEqual(self.grid.get_chain(5), [1, 1, 2, 4, 5])        

    def test_get_first_greater_than(self):
        self.assertEqual(self.grid.get_first_greater_than(8), 10)
        self.assertEqual(self.grid.get_first_greater_than(122), 133)
        self.assertEqual(self.grid.get_first_greater_than(1234), 1968)
        self.assertEqual(self.grid.get_first_greater_than(99999999), None)

if __name__ == "__main__":
    unittest.main()

And this is the main code:

class Grid:

    def __init__(self, rings):
        self.rings = rings
        self.n_elems = (2*rings - 1) ** 2
        self.chain = [0 for i in range(self.n_elems)]
        self.make_coords()
        self.fill_chain()

    def fill_chain(self):
        self.chain[0] = 1
        for i in range(1, self.n_elems):
            coord = self.get_coord_by_i(i)
            new_val = self.fill_this_chain_elem(coord)
            # print("I'ma fill number " + str(i) + " with " + str(new_val))
            self.chain[i] = new_val

    def make_coords(self):
        """
        Coordinates are in [row, col] format and start at the center
        of the chain at [0, 0]. For a grid with 4 rings (i.e. 49 elements),
        the coordinates then go from [-3, -3] to [3, 3]
        """
        self.coords = [[0, 0]]
        direction = "right"
        for i in range(1, self.n_elems):
            if direction == "right":
                x = self.coords[i-1][0]
                y = self.coords[i-1][1] + 1
                if x + 1 == y:
                    # switch walking direction from right to up in these
                    # "lower right" corners of the grid
                    direction = "up"

            elif direction == "up":
                x = self.coords[i-1][0] - 1
                y = self.coords[i-1][1]
                if -x == y:
                    direction = "left"

            elif direction == "left":
                x = self.coords[i-1][0]
                y = self.coords[i-1][1] - 1
                if x == y:
                    direction = "down"

            elif direction == "down":
                x = self.coords[i-1][0] + 1
                y = self.coords[i-1][1] 
                if -x == y:
                    direction = "right"

            self.coords.append([x, y])

    def get_coord_by_i(self, i):
        return self.coords[i]

    def get_neighboring_coords(self, coord):
        return [
            [coord[0]+dx, coord[1]+dy]
            for dx in range(-1, 2)
            for dy in range(-1, 2)
            if (not (dx==0 and dy==0))  # so that the field is not its own neighbor
            and (not abs(coord[0]+dx) >= self.rings)  # here we check that the fields are not outside of the outermost ring
            and (not abs(coord[1]+dy) >= self.rings)
        ]

    def fill_this_chain_elem(self, coord):
        neighbors = self.get_neighboring_coords(coord)
        total = 0
        for neighbor_coord in neighbors:
            this_val = self.get_elem_by_coord(neighbor_coord)
            # print("I got val=" + str(this_val) + " from coord " + str(neighbor_coord))
            total += this_val
        return total

    def get_chain(self, length):
        return self.chain[0:length]

    def get_elem_by_i(self, i):  # actually, only used for test code
        return self.chain[i]

    def get_elem_by_coord(self, coord):
        for i in range(self.n_elems):
            c_i = self.get_coord_by_i(i)
            if coord[0] == c_i[0] and coord[1] == c_i[1]:
                return self.chain[i]

    def get_first_greater_than(self, num):
        for i in range(self.n_elems):
            if self.chain[i] > num:
                return self.chain[i]
        return None

if __name__ == "__main__":    
    grid = Grid(rings=5)

    # for i in range(20):
    #     print("Coord " + str(i) + ": " + str(grid.get_coord_by_i(i)))

    # print(grid.get_elem_by_coord([-1, -2]))

    print(grid.get_chain(grid.n_elems))
    print(grid.get_first_greater_than(368078))

Nomenclature:

The goal is to build an Ulam-spiral-style grid and fill it in spiral order, where each new cell gets the sum of all its already-filled neighbors, including the diagonal ones.

  • I create the spiral by specifying the number of rings, counting the central element as the first ring: 2 rings give 3*3 = 9 elements, 3 rings give 5*5 = 25 elements, and in general n rings give (2n - 1)**2 elements.
  • The spiral can be viewed as a chain, with element 0 in the center, then element 1, etc.
  • Alternatively, you can view it as a grid of x and y coordinates, with the first element at [0, 0], the next element at [1, 0], the third at [1, -1], etc. (I just realized the y axis is flipped in my representation; in the code's [row, col] format these are [0, 0], [0, 1] and [-1, 1].) This view is useful for finding the neighbors of a chain element.
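
The same spiral order can also be produced without a direction string, by walking right 1, up 1, left 2, down 2, right 3, up 3, and so on. The following is a minimal sketch of that alternative technique, not the author's make_coords(): the names spiral_coords() and first_coords() are hypothetical, and it assumes the code's [row, col] convention, where "up" decreases the row index.

```python
from itertools import count, islice

# (d_row, d_col) steps in the [row, col] convention of make_coords():
# "up" decreases the row index, "right" increases the column index.
RIGHT, UP, LEFT, DOWN = (0, 1), (-1, 0), (0, -1), (1, 0)


def spiral_coords():
    """Yield [row, col] pairs in spiral order, starting at the centre.

    Run lengths go 1, 1, 2, 2, 3, 3, ...: odd runs walk right then up,
    even runs walk left then down.
    """
    row, col = 0, 0
    yield [row, col]
    for run in count(1):
        directions = (RIGHT, UP) if run % 2 == 1 else (LEFT, DOWN)
        for d_row, d_col in directions:
            for _ in range(run):
                row += d_row
                col += d_col
                yield [row, col]


def first_coords(rings):
    """Hypothetical helper: the (2*rings - 1)**2 coordinates of a grid."""
    return list(islice(spiral_coords(), (2 * rings - 1) ** 2))


coords = first_coords(rings=4)
print(len(coords))   # 49
print(coords[17])    # [-1, -2], as in test_get_coord_by_id
print(coords[-1])    # [3, 3], the lower-right corner of the outer ring
```

Keeping the walk in an unbounded generator separates "how the spiral moves" from "how big the grid is"; islice() then decides how many coordinates to take.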

Solution

Some thoughts:

  • I don’t think you need a dedicated method for getting an element out of the chain by index. Instead of the get_elem_by_i() method, why not access the elements directly via .chain[i]?
  • You can use the built-in sum() function in the fill_this_chain_elem() method:

    def fill_this_chain_elem(self, coord):
        neighbors = self.get_neighboring_coords(coord)
    
        return sum(self.get_elem_by_coord(neighbor_coord)
                   for neighbor_coord in neighbors)
    
  • get_first_greater_than() can use next() with a default value:

    def get_first_greater_than(self, num):
        return next((value for value in self.chain if value > num), None)
    
  • Replace the unused i variable with the conventional underscore (or simply write [0] * self.n_elems):

    self.chain = [0 for _ in range(self.n_elems)]
    
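  • Watch the spacing around operators: PEP 8 asks for spaces around comparison operators, e.g. dx == 0 and dy == 0 instead of dx==0 and dy==0.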

  • You can omit the 0 in the slice self.chain[0:length] and write self.chain[:length].
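
A different design, sketched below as an assumption rather than a drop-in replacement for Grid, stores the values in a dict keyed by (row, col) tuples. That turns the linear search in get_elem_by_coord() into a constant-time lookup, and an unbounded generator removes the need to choose the number of rings in advance. The names spiral_sums() and first_greater_than() are hypothetical, and the walk uses a different technique from make_coords(): it turns as soon as the cell in the next direction is still empty.

```python
# (d_row, d_col) steps in turning order: right, up, left, down
# (row index grows downwards, as in make_coords()).
DIRECTIONS = [(0, 1), (-1, 0), (0, -1), (1, 0)]


def spiral_sums():
    """Yield the Day 3, Part 2 values in spiral order, without a size limit."""
    values = {(0, 0): 1}  # (row, col) -> value; replaces the linear scan
    yield 1
    row, col, d = 0, 0, 0
    while True:
        d_row, d_col = DIRECTIONS[d]
        row, col = row + d_row, col + d_col
        # Sum all already-filled neighbours; unfilled cells count as 0.
        total = sum(values.get((row + dr, col + dc), 0)
                    for dr in (-1, 0, 1)
                    for dc in (-1, 0, 1)
                    if (dr, dc) != (0, 0))
        values[(row, col)] = total
        yield total
        # Turn as soon as the cell in the next direction is still empty.
        t_row, t_col = DIRECTIONS[(d + 1) % len(DIRECTIONS)]
        if (row + t_row, col + t_col) not in values:
            d = (d + 1) % len(DIRECTIONS)


def first_greater_than(num):
    """Return the first spiral value strictly greater than num."""
    return next(value for value in spiral_sums() if value > num)


print(first_greater_than(8))     # 10
print(first_greater_than(122))   # 133
print(first_greater_than(1234))  # 1968
```

Because the generator is unbounded, first_greater_than() needs no None fallback: the values grow without limit, so a larger one always turns up.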
