Sunday, September 27, 2015

Sudoku Solver using Backtracking

Java Code:

package com.sourabh.second;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SudokuSolver {
    //Brute force back-tracking based Sudoku Solver Program
    class Square {
        int startRow;
        int endRow;
        int startCol;
        int endCol;
    }
   
    public static void main(String[] args) {
        SudokuSolver solver = new SudokuSolver();
        Integer[][] matrix = new Integer[][]{
            {5,3,0,0,7,0,0,0,0},
            {6,0,0,1,9,5,0,0,0},
            {0,9,8,0,0,0,0,6,0},
            {8,0,0,0,6,0,0,0,3},
            {4,0,0,8,0,3,0,0,1},
            {7,0,0,0,2,0,0,0,6},
            {0,6,0,0,0,0,2,8,0},
            {0,0,0,4,1,9,0,0,5},
            {0,0,0,0,8,0,0,7,9}
            };
        matrix = solver.solveSudoku(matrix);
        //solver.printMatrix(matrix);
    }
   
    void printMatrix(Integer [][] matrix) {
        System.out.println();
        for(int i=0; i<matrix.length; i++) {
            for(int j=0; j<matrix.length; j++) {
                System.out.print(matrix[i][j] + ",");
            }
            System.out.println();
        }
    }
   
    public Integer[][] solveSudoku(Integer [][] matrix) {
        boolean complete = true;
        int i=0;
        int j=0;
        for(i=0;i<matrix.length;i++) {
            for(j=0;j<matrix[i].length;j++) {
                if(matrix[i][j] == 0) {
                    complete = false;
                    break;
                }
            }
            if(!complete) {
                break;
            }
        }
        if(complete) {
            printMatrix(matrix);
            return matrix;
        }
        //Find the square boundary
        Square currentSquare = getSquareBoundary(i, j);
        //Find in this row
        Integer[] possibleRowNumbers = getPossibleNumbersInRow(matrix, i);
        //Find in this column
        Integer[] possibleColNumbers = getPossibleNumbersInColumn(matrix, j);
        //Find in this square
        Integer[] possibleSquareNumbers = getPossibleNumbersInSquare(matrix,
                currentSquare.startRow, currentSquare.endRow, currentSquare.startCol, currentSquare.endCol);
        Integer[] possibleNumbers = getIntersectionOfArrays(possibleRowNumbers, possibleColNumbers, possibleSquareNumbers);
        boolean incomplete = false;
        for(int k=0;k<possibleNumbers.length;k++) {
            matrix[i][j] = possibleNumbers[k];
            matrix = solveSudoku(matrix);
            //printMatrix(matrix);
            for(int row=0;row<matrix.length;row++) {
                for(int col=0;col<matrix[row].length;col++) {
                    if(matrix[row][col] == 0) {
                        incomplete = true;
                        break;
                    }
                }
            }
            if(incomplete) {
                matrix[i][j] = 0;
            }
            else {
                break;
            }
        }
        return matrix;
    }
   
    Integer[] getPossibleNumbersInRow(Integer[][] matrix, int row) {
        List<Integer> possibleNumbers = new ArrayList<>(Arrays.asList(1,2,3,4,5,6,7,8,9));
        for(int i=0; i<matrix[row].length; i++) {
            if(matrix[row][i] != 0) {
                possibleNumbers.remove(matrix[row][i]);
            }
        }
        Integer[] array = new Integer[possibleNumbers.size()];
        return possibleNumbers.toArray(array);
    }
   
    Integer[] getPossibleNumbersInColumn(Integer[][] matrix, int col) {
        List<Integer> possibleNumbers = new ArrayList<>(Arrays.asList(1,2,3,4,5,6,7,8,9));
        for(int i=0; i<matrix.length; i++) {
            if(matrix[i][col] != 0) {
                possibleNumbers.remove(matrix[i][col]);
            }
        }
        Integer[] array = new Integer[possibleNumbers.size()];
        return possibleNumbers.toArray(array);
    }
   
    Integer[] getPossibleNumbersInSquare(Integer[][] matrix, int startRow, int endRow,
            int startCol, int endCol) {
        List<Integer> possibleNumbers = new ArrayList<>(Arrays.asList(1,2,3,4,5,6,7,8,9));
        for(int i=startRow; i<=endRow; i++) {
            for(int j=startCol; j<=endCol; j++) {
                if(matrix[i][j] != 0) {
                    possibleNumbers.remove(matrix[i][j]);
                }
            }
        }
        Integer[] array = new Integer[possibleNumbers.size()];
        return possibleNumbers.toArray(array);
    }

    Integer[] getIntersectionOfArrays(Integer[] ... arrays) {
        Integer[] intersection = getIntersectionOfTwoArrays(arrays[0], arrays[1]);
        //Arrays are already sorted, so applying merge on 2 arrays at a time.
        for(int i=2;i<arrays.length;i++) {
            intersection = getIntersectionOfTwoArrays(intersection, arrays[i]);
        }
        return intersection;
    }
   
    Integer[] getIntersectionOfTwoArrays(Integer[] array1, Integer[] array2) {
        List<Integer> intersectionList = new ArrayList<>();
        int a=0;
        int b=0;
        while(a<array1.length && b<array2.length) {
            if(array1[a] < array2[b]) {
                a++;
            }
            else if(array1[a] > array2[b]) {
                b++;
            }
            else {
                intersectionList.add(array1[a]);
                a++;
                b++;
            }
        }
        Integer[] array = new Integer[intersectionList.size()];
        return intersectionList.toArray(array);
    }

    Square getSquareBoundary(int row, int col) {
        Square square = new Square();
        if(row < 3) {
            square.startRow = 0;
            square.endRow = 2;
        }
        else if(row < 6) {
            square.startRow = 3;
            square.endRow = 5;
        }
        else if(row < 9) {
            square.startRow = 6;
            square.endRow = 8;
        }
       
        if(col < 3) {
            square.startCol = 0;
            square.endCol = 2;
        }
        else if(col < 6) {
            square.startCol = 3;
            square.endCol = 5;
        }
        else if(col < 9) {
            square.startCol = 6;
            square.endCol = 8;
        }
        return square;
    }
}

Unit Tests:

package com.sourabh.second;

import static org.junit.Assert.*;

import org.junit.Assert;
import org.junit.Test;

import com.sourabh.second.SudokuSolver.Square;

public class SudokuSolverTests {

    @Test
    public void testSolveSudoku3() {
        Integer[][] matrix = new Integer[][]{
            {0,0,0,1,0,0,0,0,7},
            {0,0,6,0,7,0,0,5,0},
            {0,9,0,0,0,3,6,0,0},
            {3,0,0,0,0,9,5,0,0},
            {0,1,0,0,2,0,0,3,0},
            {0,0,8,4,0,0,0,0,9},
            {0,0,9,7,0,0,0,4,0},
            {0,5,0,0,4,0,3,0,0},
            {2,0,0,0,0,1,0,0,0}
            };
        SudokuSolver solver = new SudokuSolver();
        solver.solveSudoku(matrix);
    }
   
    //@Test
    public void testSolveSudoku2() {
        Integer[][] matrix = new Integer[][]{
            {6,0,4,0,0,8,5,3,0},
            {9,0,0,0,3,6,1,0,0},
            {0,0,0,9,0,0,0,0,0},
            {4,9,2,0,0,0,0,0,5},
            {0,0,6,0,0,0,4,0,0},
            {8,0,0,0,0,0,9,2,7},
            {0,0,0,0,0,2,0,0,0},
            {0,0,9,8,7,0,0,0,1},
            {0,1,7,4,0,0,8,0,2}
            };
        SudokuSolver solver = new SudokuSolver();
        solver.solveSudoku(matrix);
    }
   
    //@Test
    public void testSolveSudoku() {
        Integer[][] matrix = new Integer[][]{
            {1,0,0,6,0,0,2,0,0},
            {0,2,0,0,0,0,9,0,5},
            {0,0,0,7,0,0,0,1,6},
            {0,9,4,0,7,6,0,0,0},
            {0,1,0,0,9,0,0,2,0},
            {0,0,0,3,1,0,4,9,0},
            {2,3,0,0,0,8,0,0,0},
            {4,0,6,0,0,0,0,5,0},
            {0,0,1,0,0,7,0,0,2}
            };
        SudokuSolver solver = new SudokuSolver();
        solver.solveSudoku(matrix);
    }

    @Test
    public void testAssertArrayEquals() {
        Assert.assertArrayEquals(new Integer[] {1,2,3}, new Integer[] {1,2,3});
    }
   
    @Test
    public void testGetPossibleNumbersInRow() {
        Integer[][] matrix = {{1,2,0,0,0}, {0,0,0,3,4}, {5,0,0,0,6}};
        SudokuSolver solver = new SudokuSolver();
        int row = 0;
        Integer[] result = solver.getPossibleNumbersInRow(matrix, row);
        Assert.assertArrayEquals(new Integer[] {3,4,5,6,7,8,9}, result);
        row = 1;
        result = solver.getPossibleNumbersInRow(matrix, row);
        Assert.assertArrayEquals(new Integer[] {1,2,5,6,7,8,9}, result);
        row = 2;
        result = solver.getPossibleNumbersInRow(matrix, row);
        Assert.assertArrayEquals(new Integer[] {1,2,3,4,7,8,9}, result);
    }

    @Test
    public void testGetPossibleNumbersInColumn() {
        Integer[][] matrix = {{1,1}, {0,0}, {0,2}, {3,0}};
        SudokuSolver solver = new SudokuSolver();
        int col = 0;
        Integer[] result = solver.getPossibleNumbersInColumn(matrix, col);
        Assert.assertArrayEquals(new Integer[] {2,4,5,6,7,8,9}, result);
        col = 1;
        result = solver.getPossibleNumbersInColumn(matrix, col);
        Assert.assertArrayEquals(new Integer[] {3,4,5,6,7,8,9}, result);
    }

    @Test
    public void testGetPossibleNumbersInSquare() {
        Integer[][] matrix = {{1,0,0}, {0,2,0}, {0,0,3}};
        SudokuSolver solver = new SudokuSolver();
        Integer[] result = solver.getPossibleNumbersInSquare(matrix, 0, 2, 0, 2);
        Assert.assertArrayEquals(new Integer[] {4,5,6,7,8,9}, result);
    }
   
    @Test
    public void testGetIntersectionOfArrays() {
        Integer[] array1 = new Integer[]{0,0,1};
        Integer[] array2 = new Integer[]{0,0,2};
        Integer[] array3 = new Integer[]{0,0,3};
        SudokuSolver solver = new SudokuSolver();
        Integer[] result = solver.getIntersectionOfArrays(array1, array2, array3);
        Assert.assertArrayEquals(new Integer[] {0,0}, result);
    }
   
    @Test
    public void testGetIntersectionOfTwoArrays() {
        Integer[] array1 = new Integer[]{0,0,1};
        Integer[] array2 = new Integer[]{0,0,2};
        SudokuSolver solver = new SudokuSolver();
        Integer[] result = solver.getIntersectionOfTwoArrays(array1, array2);
        Assert.assertArrayEquals(new Integer[] {0,0}, result);
    }
   
    @Test
    public void testGetSquareBoundary() {
        SudokuSolver solver = new SudokuSolver();
        Square square = solver.getSquareBoundary(1, 1);
        Assert.assertEquals(square.startRow, 0);
        Assert.assertEquals(square.endRow, 2);
        Assert.assertEquals(square.startCol, 0);
        Assert.assertEquals(square.endCol, 2);
    }
}

2 comments: