Scala上的素数筛法:朴素埃拉托色尼筛

更新于 2024-05-24 03:55

引言

我大概在读小学五六年级的时候,从学校图书馆里借到了一本叫做《数学奇观》的科普书,书上介绍了一种叫做“埃拉托色芬”的寻找素数的方法。看完以后兴趣盎 然,就动手照做了一番。

我先找来一张A4大小的纸,用铅笔和直尺在上面画出了1000个小格子,依次填入数字 1-1000. 然后,先把1划掉。跳过2, 从4开始,每隔两个数划掉一个,这样就把 所有偶数划掉了。接着跳过3, 从6开始,每隔3个数划掉一个,跳过5, 从10开始,每隔5个数划掉一个。这样不断地进行,一直划到31的倍数, 即跳过31, 从62开始, 每隔31个数划掉一个。

这样结束以后,再把剩下没有划掉的数依次抄下来,就得到了1000以内的所有素数。我仔细数了几遍,确认无误,终于得到了一个伟大的科研成果:1000以内一共有 168个素数。

这就是最朴素的埃拉托色尼筛法。

当时人菜瘾大,成就感爆棚。同时又激动地想,要是有一种可以自动划掉数字的办法,岂不是可以找出更大范围内的素数?直到很多年以后,我逐渐理解了计算机这 种存在,而且学会了写代码,这件事情才成为可能。

由此说来,这一系列寻找素数的文章,最早可以说是基于童年的一种痴念。

最简单的实现

我们现在可以轻而易举地用Scala写出这种最朴素的筛法:

object PrimesEratosthenesSieve {

  def primes(n: Int): List[Int] = {
    if (n <= 7) List(2, 3, 5, 7).takeWhile(_ <= n)
    else {
      val q  = math.sqrt(n.toDouble).toInt
      val qs = primes(q)
      qs ++ (q + 1 to n).filter { x => qs.forall(x % _ != 0) }.toList
    }
  }
}

跑一下看看:

object PrimesEratosthenesSieveApp extends App {
  PrimesEratosthenesSieve.println(primes(1000))
  PrimesEratosthenesSieve.println(primes(1000).length)
}

运行结果:

168
List(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997)

简直不费吹灰之力。

性能测试

我们来测试一下这个函数的性能。首先,从 wikipedia 上查到我们需要的素数个数,做一些准备工作:

package simple

import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, State}

@State(Scope.Benchmark)
class PrimeState {
  @Param(Array("1000", "10000", "100000", "1000000", "10000000", "100000000", "1000000000")) var n: Int = 0
}

trait PrimesBenchmark {
  protected final val p1e3 = 1000
  protected final val p1e4 = 10000
  protected final val p1e5 = 100000
  protected final val p1e6 = 1000000
  protected final val p1e7 = 10000000
  protected final val p1e8 = 100000000
  protected final val p1e9 = 1000000000

  protected final def primesMap: Map[Int, Int] = Map(
    p1e3 -> 168,
    p1e4 -> 1229,
    p1e5 -> 9592,
    p1e6 -> 78498,
    p1e7 -> 664579,
    p1e8 -> 5761455,
    p1e9 -> 50847534
  )

  def bench(n: Int, f: Int => Int): Unit = assert(primesMap.get(n).contains(f(n)))
}

定义测试类:

class EratosthenesSieveBenchmark extends PrimesBenchmark {
  @Benchmark def primes1(state: PrimeState): Unit = bench(state.n, PrimesEratosthenesSieve.primes1(_).length)
}

开始测试:

$ JAVA_OPTS="-Xmx8G -Xms8G -XX:+UseG1GC" sbt "bench/Jmh/run -bm avgt -i 3 -wi 1 -f 1 -t 1 simple.EratosthenesSieveBenchmark"

参数解释:

  • 堆内存,固定 8GB
  • 垃圾回收算法 G1GC
  • 测试模式 avgt, 应该设为平均运行时间,而不是吞吐
  • Blackhole mode: full + dont-inline hint
  • Warmup: 1 iterations, 10 s each
  • Measurement: 3 iterations, 10 s each
  • Threads: 1 thread, will synchronize iterations

硬件环境:

  • CPU: AMD Ryzen 9 3950X, 16 核 32 线程,Base Clock 3.5GHz, Max. Boost Clock Up to 4.7GHz
  • 散热器:猫头鹰 D15 风冷
  • 内存:64GB 3600MHz
  • SSD:三星 983ZET, 960GB

软件环境:

  • OS: Linux NixOS
  • JVM: JDK 17.0.1, OpenJDK 64-Bit Server VM, 17.0.1+12-nixos
  • Scala: 3.1.1
  • Sbt: 1.6.2
  • sbt-jmh: 0.4.3

测试结果:

[info] Benchmark                                  (n)  Mode  Cnt     Score    Error  Units
[info] EratosthenesSieveBenchmark.primes1        1000  avgt    3    ≈ 10⁻⁵            s/op
[info] EratosthenesSieveBenchmark.primes1       10000  avgt    3    ≈ 10⁻⁴            s/op
[info] EratosthenesSieveBenchmark.primes1      100000  avgt    3     0.004 ±  0.001   s/op
[info] EratosthenesSieveBenchmark.primes1     1000000  avgt    3     0.086 ±  0.005   s/op
[info] EratosthenesSieveBenchmark.primes1    10000000  avgt    3     1.822 ±  0.017   s/op
[info] EratosthenesSieveBenchmark.primes1   100000000  avgt    3    42.306 ±  2.556   s/op
[info] EratosthenesSieveBenchmark.primes1  1000000000  avgt    3  1027.604 ± 20.288   s/op

可见,当 nn 等于 1 亿的时候,需要 42 秒,等于 10 亿的时候需要 1027 秒(17分钟)。

复杂度分析

众所周知埃拉托色尼筛的时间复杂度是 O(nlnln(n))O(n\ln{\ln{(n)}}), 但这个简单粗糙的实现显然要低效得多。至于具体的复杂度估算,我现在也不会(目测并不那 么简单)。但是我们可以简单地加上点调试代码,以便和后面的改进算法对比:

def primes1c(n: Int): List[Int] = {
  if (n <= 7) List(2, 3, 5, 7).takeWhile(_ <= n)
  else {
    val q  = math.sqrt(n.toDouble).toInt
    val qs = primes1c(q)
    var c = 0L
    val r = qs ++ (q + 1 to n).filter { x =>
      qs.forall { q =>
        c += 1L
        x % q != 0
      }
    }.toList
    println(s"n = $n, c = $c")
    r
  }
}

println(println(primes1c(100000000).length))
// n = 100000000, c = 8576796314

改进 1:使用 BitSet

容易想到的,可以改用一个 BitSet 来存储筛选的结果,这样整个算法过程会更接近当年手写素数的场景,遍历的数字会更少,占用的内存也会更小。同时,我们把函数的返回值从 List 改为 Iterator.

def primes2(n: Int): Iterator[Int] = {
    if (n <= 7) Iterator(2, 3, 5, 7).takeWhile(_ <= n)
    else {
      val q  = math.sqrt(n.toDouble).toInt
      val qs = primes2(q)
      val bs = collection.mutable.BitSet.empty
      qs.foreach { q =>
        (2 to n / q).foreach { p =>
          bs.addOne(p * q)
        }
      }
      Iterator.range(2, n).filterNot(bs.apply)
    }
  }

性能测试

class EratosthenesSieveBenchmark extends PrimesBenchmark {
  @Benchmark def primes2(state: PrimeState): Unit = bench(state.n, PrimesEratosthenesSieve.primes2(_).length)
}

结果:

[info] EratosthenesSieveBenchmark.primes2        1000  avgt    3  ≈ 10⁻⁵            s/op
[info] EratosthenesSieveBenchmark.primes2       10000  avgt    3  ≈ 10⁻⁴            s/op
[info] EratosthenesSieveBenchmark.primes2      100000  avgt    3   0.001 ±  0.001   s/op
[info] EratosthenesSieveBenchmark.primes2     1000000  avgt    3   0.014 ±  0.023   s/op
[info] EratosthenesSieveBenchmark.primes2    10000000  avgt    3   0.109 ±  0.009   s/op
[info] EratosthenesSieveBenchmark.primes2   100000000  avgt    3   1.245 ±  0.382   s/op
[info] EratosthenesSieveBenchmark.primes2  1000000000  avgt    3  15.494 ± 17.753   s/op

nn 等于 1 亿的时候,只需要 1.2 秒,等于 10 亿的时候只需要 16 秒。没想到看似低效的手工“划掉”,居然比现在比较流行的算法快 60 倍。

复杂度分析

下面来分析一下这个筛法的复杂度。显而易见,应该是

n(12+13+15+17+111+)=np prime1p n(\cfrac{1}{2} + \cfrac{1}{3} + \cfrac{1}{5} + \cfrac{1}{7} + \cfrac{1}{11} + \ldots) = n \sum_{p\ prime}^{}\cfrac{1}{p}

根据 wiki,

pn1plnln(n+1)lnπ26 \sum_{p \leq n}^{}\cfrac{1}{p} \sim \ln{\ln{(n+1)}} - \ln{\cfrac{\pi ^ 2}{6}}

因此朴素埃拉托色尼筛的算法复杂度应该是 O(nlnln(n))O(n\ln{\ln{(n)}}) .

简单验证一下:

def primes2c(n: Int): Iterator[Int] = {
  if (n <= 7) Iterator(2, 3, 5, 7).takeWhile(_ <= n)
  else {
    val q  = math.sqrt(n.toDouble).toInt
    val qs = primes2(q)
    val bs = collection.mutable.BitSet.empty
    var c = 0L
    qs.foreach { q =>
      (2 to n / q).foreach { j =>
        c += 1L
        bs.addOne(q * j)
      }
    }
    println(s"n = $n, c = $c")
    Iterator.range(2, n).filterNot(bs.apply)
  }
}

println(primes2c(100000000).length)
// n = 100000000, c = 248304142

可见 n=108n = 10^{8} 时, primes2 循环次数为 248304142248304142, 而 primes185767963148576796314, 差距为 34.534.5 倍。

108lnln(108)=29134739810^{8}\ln{\ln{(10^{8})}} = 291347398, 与实际的循环次数差距仅有 1717%.

改进 2: 使用反转的 BitSet

primes2 最后输出结果,使用了一个 Range (2 to n) 加上 BitSet 来逐个过滤,会有不小的性能浪费。如果我们反过来,直接先填满 BitSet, 然后把 合数从中删掉,最后再直接用 BitSet 的迭代器输出,会怎么样呢?马上试试:

def primes3(n: Int): Iterator[Int] = {
  if (n <= 7) Iterator(2, 3, 5, 7).takeWhile(_ <= n)
  else {
    val q  = math.sqrt(n.toDouble).toInt
    val qs = primes3(q)
    val bs = collection.mutable.BitSet.fromSpecific(2 to n)
    qs.foreach { q =>
      (2 to n / q).foreach { p => bs.remove(p * q) }
    }
    bs.iterator
  }
}

测试:

class EratosthenesSieveBenchmark extends PrimesBenchmark {
  @Benchmark def primes3(state: PrimeState): Unit = bench(state.n, PrimesEratosthenesSieve.primes3(_).length)
}

结果:

[info] Benchmark                                  (n)  Mode  Cnt   Score    Error  Units
[info] EratosthenesSieveBenchmark.primes3        1000  avgt    3  ≈ 10⁻⁵            s/op
[info] EratosthenesSieveBenchmark.primes3       10000  avgt    3  ≈ 10⁻⁴            s/op
[info] EratosthenesSieveBenchmark.primes3      100000  avgt    3  ≈ 10⁻³            s/op
[info] EratosthenesSieveBenchmark.primes3     1000000  avgt    3   0.004 ±  0.001   s/op
[info] EratosthenesSieveBenchmark.primes3    10000000  avgt    3   0.048 ±  0.008   s/op
[info] EratosthenesSieveBenchmark.primes3   100000000  avgt    3   0.524 ±  0.286   s/op
[info] EratosthenesSieveBenchmark.primes3  1000000000  avgt    3   8.781 ±  1.521   s/op

结果比 primes2 有接近翻倍的性能提升。

至此,我们得到了朴素埃拉托色尼筛法的最佳实现:

package simple.primes

import scala.collection.mutable

object PrimesEratosthenesSieve extends App {

  def primes(n: Int): List[Int] = {
    if (n <= 7) List(2, 3, 5, 7).takeWhile(_ <= n)
    else {
      val q  = math.sqrt(n.toDouble).toInt
      val qs = primes2(q)
      val bs = mutable.BitSet.empty
      qs.foreach { q =>
        (2 to n / q).foreach { j =>
          bs.addOne(q * j)
        }
      }
      (2 to n).filterNot(bs.apply).toList
      }
  }
}

我们将会在下一篇里实现更高效的素数筛法。

附,我后来从 孔夫子旧书网 买回来的二手书: