快速幂算法

设计一个算法计算\(x^n\)的值。

根据定义最常见也最能瞬间想到的是如下的算法:

// 递归写法
public int pow1(int x, int n) {
  if (n == 0) return 1;
  if (n == 1) return x;
  return x * pow1(x, n - 1);
}
// 循环写法
public int pow2(int x, int n) {
  int y = 1;
  while (n) {
    y *= x;
    n--;
  }
  return y;
}

但上面的算法的时间复杂度是\(O(n)\)

下面采用快速幂算法来解决这个问题。

在解决它之前先来看一下原理:

\[x^{n}=x^{n-a}x^a
\]

所以我们可以对本身要求的\(x^n\)对半分来求,只求一半的数,然后乘以自己本身就可以达到\(x^n\)

但是会出现的情况就是,对半分的时候会出现小数的情况,所以一定要分奇数和偶数的情况。

如果n分半了之后是偶数,那就直接对半分,如果是奇数则在对半分之后还要乘以一个x。

所以可以有下面的规律:

f(x, n) = {
  f(x, n/2)*f(x, n/2),    // 当n为偶数
  x*f(x, n/2)*f(x, n/2)   // 当n为奇数
}

所以得出快速幂算法1:

public int qPow1(int x, int n) {
  if (n == 0) return 1;
  if (n == 1) return x;
  if (n % 2 == 1) return x*f(x, n/2)*f(x, n/2);
  return f(x, n/2)*f(x, n/2);
}

但是上面的算法明显没有任何增进,因为f(x, n/2)要算两次,那和之前的\(O(n)\)的算法没什么区别。所以使用一个中间变量去接受一下,就可以提高算法效率。

public int qPow1(int x, int n) {
  if (n == 0) return 1;
  if (n == 1) return x;
  int t = f(x, n/2)
  if (n % 2 == 1) return x * t * t;
  return t * t;
}

上面的快速幂算法还是比较好理解的,下面的快速幂算法就比较的炫技了我觉得,但是也就那样(原理还是上面的,只是不是对半分而已,而是根据进制数来分)。

下面采用二进制数来分。

假设我们要计算的是\(x^{10}\),那么10的二进制数是1010,所以有如下公式及变换:

\[x^{10}=x^{(10)_{10}}=x^{(1010)_2}=x^{1*2^3+0*2^2+1*2^1+0*2^0}=x^{8+2}=x^8x^2
\]

\[x^{10}=x^8x^2
\]

和上面第一种快速幂的算法类似,只不过上一种采用的分法是:

\[x^{10}=x^5x^5
\]

那么不管怎样分,最后肯定会被分到1,因为\(x^0=1\),其实上面的分法都隐藏了一个\(x^0\),即:

\[x^{10}=x^8x^2x^0
\]

所以状态是怎么转移的,即每一次迭代都是怎样变化的。初始化\(t=x^{2^0}=x^1=x\),那么下一代的变化是\(x^{2^1}\),它是由\(x^{2^0*2}\)变化而来,因为采用的是二进制。所以指数部分要想从\(2^0\)变换到\(2^1\)就需要乘以一个2.那也就是说,\(x\)变到\(x^2\).那么迭代变化过程就是\(t=t*t\).

\(x^{2^0*2}=(x^{2^0})^2\)

所以得到第二种快速幂算法代码:

public int qPow2(int x, int n) {
  int y = 1;
  int t = x;
  while (n > 0) {
    switch (n % 2) {
      case 1: y = y * t;  // 这里不要写break
      case 0: t = t * t;
    }
    n = n / 2;
  }
  return y;
}

那么这个采用二进制的方法分,当然也有三进制的,四进制的,五进制的等等。十六进制就不要搞了,因为不是进制越高就越快。

通过我对上面的二进制写法的快速幂就可以看出来我还会有其他进制的写法。那么下面就来看一下三进制的写法,然后四进制的就顺其自然就明白了。

那么三进制的推导过程也是和二进制的推导过程是类似的。假设计算的是\(x^{10}\).

\[x^{10}=x^{(10)_{10}}=x^{(101)_3}=x^{1*3^2+0+3^1+1*3^0}=x^9x
\]

所以从二进制的分法和三进制的分法可以看出,不管怎么分都是可以合起来达到10.只要能达到10的说明采用什么进制分法都是可以的。但并不是说采用的进制越高就越好。

那么初始化\(t=x^{3^0}=x\),那么下一代的变化是\(x^{3^1}\),它是由\(x^{3^0*3}\)变化而来,因为采用的是三进制。所以指数部分要想从\(3^0\)变换到\(3^1\)就需要乘以一个3.那也就是说,\(x\)变到\(x^3\).那么迭代变化过程就是\(t=t*t*t\).

\(x^{3^0*3}=(x^{3^0})^3\)

对比一下二进制和三进制的区别,所以四进制往后的就不需要我一个个推导了吧。

直接得出算法代码:

public int qPow3(int x, int n) {
  int y = 1;
  int t = x;
  while (n > 0) {
    switch (n % 3) {
      case 2: y = y * t;  // 这里不要写break
      case 1: y = y * t;  // 这里不要写break
      case 0: t = t * t * t;
    }
    n = n / 3;
  }
  return y;
}

直接得出四进制版本的代码:

public int qPow4(int x, int n) {
  int y = 1;
  int t = x;
  while (n > 0) {
    switch (n % 4) {
      case 3: y = y * t;  // 这里不要写break
      case 2: y = y * t;  // 这里不要写break
      case 1: y = y * t;  // 这里不要写break
      case 0: t = t * t * t * t;
    }
    n = n / 4;
  }
  return y;
}

直接得出五进制版本的代码:

public int qPow5(int x, int n) {
  int y = 1;
  int t = x;
  while (n > 0) {
    switch (n % 5) {
      case 4: y = y * t;  // 这里不要写break
      case 3: y = y * t;  // 这里不要写break
      case 2: y = y * t;  // 这里不要写break
      case 1: y = y * t;  // 这里不要写break
      case 0: t = t * t * t * t;
    }
    n = n / 5;
  }
  return y;
}

以此类推。。。就不写了。

可以去测试一下:

public class Main {
        public static void main(String[] args) {
                System.out.println("计算2的10次方:");
                int x = 2, n = 10;
                Solution s = new Solution();
                showMessage("pow1: ",  s.pow1(x, n));
                showMessage("pow2: ",  s.pow2(x, n));
                showMessage("qPow1: ", s.qPow1(x, n));
                showMessage("qPow2: ", s.qPow2(x, n));
                showMessage("qPow3: ", s.qPow3(x, n));
                showMessage("qPow4: ", s.qPow4(x, n));
                showMessage("qPow5: ", s.qPow5(x, n));
        }
        public static void showMessage(String str, int result) {
                System.out.println("---------------------");
                System.out.println(str + result);
                System.out.println("---------------------");
        }
}

class Solution {
        public int pow1(int x, int n) {
                if (n == 0) return 1;
                if (n == 1) return x;
                return x * pow1(x, n - 1);
        }
        public int pow2(int x, int n) {
                int y = 1;
                while (n > 0) {
                        y = y * x;
                        n--;
                }
                return y;
        }
        public int qPow1(int x, int n) {
                if (n == 0) return 1;
                if (n == 1) return x;
                int t = qPow1(x, n / 2);
                if (n % 2 == 1) return x * t * t;
                return t * t;
        }
        public int qPow2(int x, int n) {
                int y = 1;
                int t = x;
                while (n > 0) {
                        switch (n % 2) {
                                case 1: y = y * t; // 这里不要写break
                                case 0: t = t * t;
                        }
                        n = n / 2;
                }
                return y;
        }
        public int qPow3(int x, int n) {
                int y = 1;
                int t = x;
                while (n > 0) {
                        switch (n % 3) {
                                case 2: y = y * t;  // 这里不要写break
                                case 1: y = y * t;  // 这里不要写break
                                case 0: t = t * t * t;
                        }
                        n = n / 3;
                }
                return y;
        }
        public int qPow4(int x, int n) {
                int y = 1;
                int t = x;
                while (n > 0) {
                        switch (n % 4) {
                                case 3: y = y * t;  // 这里不要写break
                                case 2: y = y * t;  // 这里不要写break
                                case 1: y = y * t;  // 这里不要写break
                                case 0: t = t * t * t * t;
                        }
                        n = n / 4;
                }
                return y;
        }
        public int qPow5(int x, int n) {
                int y = 1;
                int t = x;
                while (n > 0) {
                        switch (n % 5) {
                                case 4: y = y * t;  // 这里不要写break
                                case 3: y = y * t;  // 这里不要写break
                                case 2: y = y * t;  // 这里不要写break
                                case 1: y = y * t;  // 这里不要写break
                                case 0: t = t * t * t * t * t;
                        }
                        n = n / 5;
                }
                return y;
        }
}

终端输出:

计算2的10次方:
---------------------
pow1: 1024
---------------------
---------------------
pow2: 1024
---------------------
---------------------
qPow1: 1024
---------------------
---------------------
qPow2: 1024
---------------------
---------------------
qPow3: 1024
---------------------
---------------------
qPow4: 1024
---------------------
---------------------
qPow5: 1024
---------------------

然后来分析一下它们的执行效率。

一般的\(O(n)\)就不说了,肯定是比\(O(log_2n)\)差的。主要是看是不是进制越高就越好?先给出答案,并不一定是。看着qPow5好像可以更快收到答案,但我们忘了看\(t=t*t*t*t*t\)这段代码和它上面的那一坨。然后再放大一点看,如果我采用的是十进制的写法。那会发现状态转移是\(t=t*t*t...(10个t)\)那和一般的pow有什么区别?所以,从这一点可以看出来并不是进制越高就越好。就好像是用\(x^8x^2\)\(x^2x^8\)比效率一样。都是一样的嘛,如果外层循环少了,那里面的乘法就多了。所以得找个平衡的点。那这个平衡的点一般也就是对半的时候(并不是所有情况都是),所以我们折腾了那么久又回到了二进制的版本。因为计算机底层是二进制,所以我们就采用二进制的版本,然后再采用代码上的语法优化,这样应该是更好一点,因为其它进制的版本可优化的点并不多。下面给出二进制优化的版本。

public int qPow2(int x, int n) {
  int y = 1;
  int t = x;
  while (n > 0) {
    if (n % 2 == 1) y *= t;
    t *= t;
    n >>= 1;
  }
  return y;
}

因为Java语法本身的原因在做位运算的时候不能像C/C++一样可以用非零当作真。所以if (n % 2 == 1) y *= t;并不改变。但如果是C/C++的话可以采用下面的版本:

int qPow2(int x, int n) {
  int y = 1;
  while (n) {
    if (n & 1) y *= t;
    t *= t;
    n >>= 1;
  }
  return y;
}

快速幂算法就先到这里结束了。