如何用一条SQL计算24点

《数据库编程大赛:一条SQL计算扑克牌24点》

最近公司举办了SQL编程大赛,用一条SQL来计算24点的公式。
因为先看到了题目,也尝试着完成了一下,但没有参赛,只在这里写一下我的解题思路。

24点游戏大家都玩过,现在也是笔试中常考的编程题目,输入4个10以内整数,给出经过加、减、乘、除运算后得出24的计算公式,通常的思路是用回溯方法遍历所有可能性:

  1. 在输入的4个整数中选择两个数,经过运算后得到结果。
  2. 用第一步的结果取代选中的两个数,再从剩下3个数中选择两个数,继续运算后得到结果。
  3. 最后只剩下两个数,运算后判断结果是否等于24,若结果不等于24则继续重复上一个步骤。

因为加法和乘法满足交换律,减法和除法运算时要考虑运算顺序,所以运算时增加被减和除以两种运算,由加、减、乘、除四则运算改为了加、减、乘、除、被减、被除这六种运算,遍历计算公式数目为:

C42 * 6 * C32 * 6 * C22 * 6 = 3888
即是说通过回溯的方法想要找到4个数得出24的计算公式最多只需要遍历3888次。

SQL解题思路

(比赛题目)
再看本次比赛的题目,要求用一条SQL查询出表中随机给出4个整数的24点计算公式,测试表中的数据超过1W条,假如遍历每行数据的所有计算公式,计算结果会超过千万行再做过滤聚合,这显然无法满足比赛的性能要求。
仔细分析运算过程会发现因为运算结合律、交换律、数字顺序等特殊性,遍历过程中有大量的重复计算,例如:

  1. ( 1, 2, 3, 4 )和( 4, 3, 2, 1 ),都可以通过公式( 1 + 3 ) * ( 2 + 4 )计算得到24,所以相同的4个数只需要计算一次即可。
  2. ( 1, 2, 3, 4 )组合中,( 1 + 3 ) + ( 2 + 4 ) 和 ( 1 + 2 ) + ( 3 + 4 )计算过程重复。

简化的办法就是从如何减少重复计算开始入手,4个数不考虑排列顺序,只有715个组合,具体算法如下:

  1. 四个数不同:C104 = 210
  2. 两个数相同,两个数不同:C103 * C31 = 360
  3. 两个数相同,另外两个数也相同:C102 = 45
  4. 三个数相同:C102 * C21 = 90
  5. 四个数相同:C101 = 10

( 210 + 360 + 45 + 90 + 10 ) = 715,因此4个数能够得出24的组合不超过715种,所以假如计算出所有能够得到24的组合和计算公式,表数量一定少于715行,解题思路可以将计算结果和cards表做join匹配,得到对应计算公式作为result,这样可以避免直接对cards表中的数据计算得出公式。
(join匹配)
通过SQL的解法需要先找到4个数中能够计算得出24点的所有组合,再计算组合的hash值与cards中每行记录的hash值匹配,最终输出数字和计算公式。

计算Hash

简单来说就是要找到一种令hash( 1, 2, 3, 4 ) = hash( 4, 3, 2, 1 )的计算方法,让相同数字组合计算出的hash值相等,作为最终cards表join连接的条件。
hash可以有很多种计算方法,例如可以将4个排序后再拼接字符串,或者计算10n将数移位到十进制对应位后再相加。
(移位计算hash)
本文采用了移位的方法,对每个数进行( 1 << ( n * 2 ))的左移操作后再相加得到hash值,与十进制的计算思路相似,移位操作的效率更高,左移 n * 2 位后相当于采用四进制,只有4个数相同时才会发生进位,但不会产生hash碰撞的情况。

计算24点组合公式

总结分析文章开始提到的回溯遍历方法,可以发现4个数进行3次运算,其实只有以下两种计算路径:

  1. (A op B) op (C op D):4个数两两运算后再做运算。
  2. ((A op B) op C) op D:两个数运算后再顺序与第三个数和第四个数做运算。

因此SQL的计算24点组合公式的思路是算出两种计算路径的24点公式,再拼接后去重得到所有组合和计算公式。

(计算路径)

一、计算( A op B )

先计算出两个数运算,是因为两种计算路径中会有三处用到,可以将结果作为临时表,避免重复运算。
将两个数和四个操作符做笛卡尔积就可以得到运算结果,这里有一点优化可以在加法和乘法时只计算A >= B的情况减少部分重复结果。

二、计算( A op B ) op ( C op D )

将第一步两数的运算结果再和操作符做笛卡尔积即可以得到所有的运算结果。
这里仍然可以只计算( A op B ) >= ( C op D )的情况,加法和乘法因为交换律减少部分重复计算,减法和除法则因为小数减/除大数结果肯定小于24,不需要计算。

三、计算(( A op B ) op C ) op D

因为结合律,((A op B) op C) op D 有部分组合在第二步(A op B) op (C op D)中已经计算过。
例如 (( 1 + 2) + 3 ) + 4 和 ( 1 + 2 ) + ( 3 + 4 )属于计算重复,(( 2 + 3 ) * 3 ) / 4 和 ( 2 + 3 ) * ( 3 / 4 )属于计算重复。
所以((A op B) op2 C) op3 D只需要计算op2和op3不同时为 +- 或 */ 的场景,考虑减法和除法存在顺序性,可以列举为以下20个计算公式。

(公式说明)

在20个计算公式中只有15个可以计算得出24,其中有6个是重复计算,所以((A op B) op C) op D只需要计算9个计算公式即可。

四、结果判断

在做除法运算时会转成浮点数计算,可能会精度问题影响最终结果,常见的处理方法是将精度误差计算在内,例如判断最终结果并不是等于24,而是(>23.999999 AND <24.000001)。
本文采用转为分数来避免除法运算,将每个数看作分母是1,遇到加减法时先通分再加减,遇到乘法时分子与分子相乘,分母与分母相乘,遇到除法时乘以被除数的倒数。最终结果用( 分子 = 24 * 分母 )来判断。

性能测试

测试环境:MySQL 8.2.0,4C32G
测试结果:

cards表查询耗时
1万行0.57秒
10万行0.60秒
100万行0.97秒
1000万行4.75秒

测试SQL:

WITH RECURSIVE numbers (n) AS (
    SELECT 10
    UNION ALL
    SELECT n - 1 FROM numbers WHERE n > 1
  ),
  operators (op) AS (
    SELECT '+'
    UNION ALL
    SELECT '-'
    UNION ALL
    SELECT '*'
    UNION ALL
    SELECT '/'
  ),
  calculations (hash, res, exp, den) AS (
    /* calculate (A op B) , hash, result, expression, denominator */
    SELECT (1 << (A.n * 2)) + (1 << (B.n * 2)),
      CASE op.op 
        WHEN '+' THEN A.n + B.n
        WHEN '-' THEN A.n - B.n
        WHEN '*' THEN A.n * B.n
        ELSE A.n
      END,
      CONCAT('(', A.n, op.op, B.n, ')'),
      CASE op.op
        WHEN '/' THEN B.n
        ELSE 1
      END
    FROM numbers AS A, numbers AS B, operators AS op
    WHERE A.n >= B.n OR op.op IN ('-', '/')
  ),
  result_set (hash, res, exp, den) AS (
    /* (A op B) op (C op D) */
    SELECT
      c1.hash + c2.hash,
      CASE op.op
        WHEN '+' THEN c1.res * c2.den + c2.res * c1.den
        WHEN '-' THEN c1.res * c2.den - c2.res * c1.den
        WHEN '*' THEN c1.res * c2.res
        ELSE c1.res * c2.den
      END,
      CONCAT(c1.exp, op.op, c2.exp),
      CASE op.op
        WHEN '/' THEN c1.den * c2.res
        ELSE c1.den * c2.den
      END
    FROM calculations AS c1, calculations AS c2, operators AS op
    WHERE c1.res * c2.den >= c2.res * c1.den
    UNION ALL
    /* ((A op B) op2 C) op3 D */
    SELECT
      calc.hash + (1 << (C.n * 2)) + (1 << (D.n * 2)),
      CASE
        WHEN sence = 1 THEN (calc.res + C.n * calc.den) * D.n
        WHEN sence = 2 THEN calc.res * C.n - D.n * calc.den
        WHEN sence = 3 THEN (calc.res - C.n * calc.den) * D.n
        WHEN sence = 4 THEN calc.res * C.n + D.n * calc.den
        WHEN sence = 5 THEN (calc.res - C.n * calc.den)
        WHEN sence = 6 THEN (calc.res + C.n * calc.den)
        WHEN sence = 7 THEN (C.n * calc.den - calc.res) * D.n
        WHEN sence > 7 THEN D.n * calc.den
      END,
      CASE
        WHEN sence < 7 THEN CONCAT('(', calc.exp, op.op2, C.n, ')', op.op3, D.n)
        WHEN sence = 7 THEN CONCAT('(', C.n, '-', calc.exp, ')*', D.n)
        WHEN sence = 8 THEN CONCAT(D.n, '/(', calc.exp, '-', C.n, ')')
        WHEN sence = 9 THEN CONCAT(D.n, '/(', C.n, '-', calc.exp, ')')
      END,
      CASE
        WHEN sence < 5 THEN calc.den
        WHEN sence = 5 THEN calc.den * D.n
        WHEN sence = 6 THEN calc.den * D.n
        WHEN sence = 7 THEN calc.den
        WHEN sence = 8 THEN calc.res - C.n * calc.den
        WHEN sence = 9 THEN C.n * calc.den - calc.res
      END
    FROM calculations AS calc, numbers AS C, numbers AS D, (SELECT '+' AS op2, '*' AS op3, 1 AS sence
                                                                UNION ALL
                                                                SELECT '*' AS op2, '-' AS op3, 2 AS sence
                                                                UNION ALL
                                                                SELECT '-' AS op2, '*' AS op3, 3 AS sence
                                                                UNION ALL
                                                                SELECT '*' AS op2, '+' AS op3, 4 AS sence
                                                                UNION ALL
                                                                SELECT '-' AS op2, '/' AS op3, 5 AS sence
                                                                UNION ALL
                                                                SELECT '+' AS op2, '/' AS op3, 6 AS sence
                                                                UNION ALL
                                                                SELECT '-' AS op2, '*' AS op3, 7 AS sence
                                                                UNION ALL
                                                                SELECT '-' AS op2, '/' AS op3, 8 AS sence
                                                                UNION ALL
                                                                SELECT '-' AS op2, '/' AS op3, 9 AS sence) AS op
  ),
  result_table (hash, exp) AS (
    SELECT hash, any_value(exp) FROM result_set WHERE den <> 0 AND res = 24 * den GROUP BY hash
  )
SELECT id, c1, c2, c3, c4, result FROM `poker24`.`cards` LEFT JOIN (SELECT hash+0 AS hash, exp AS result from result_table) t ON ((1 << (c1 * 2)) + (1 << (c2 * 2)) + (1 << (c3 * 2)) + (1 << (c4 * 2))) = t.hash;