和上篇一樣,這篇是把計算圓周率π的程式照樣修改,
執行速度也快了兩倍多,
原本十億位需要 1 小時 28 分鐘,縮短到 40 分鐘
速度變快的關鍵主要出在三處:
1. mpfr.div() 原本寫錯了,應該要把整數的 mpz 轉成浮點數 mpfr 再除比較快,
我卻寫錯了,直接拿 mpz 去除,不知為何這樣會讓運算速度變得超級慢。
2. 寫檔時 write_string() 裡面有做排版,而這個排版的迴圈是一個字元一個字元
處理,速度非常的慢,我改成一次處理 50 個字元一行,速度就快了 50 倍。
3. 進度條處理得不好,不需更新進度時應該儘快 return,我卻做了多餘的
數學運算,多做一次當然沒什麼,多做七千萬次就有影響效能了。
這個 Divide and Conquer 的寫法很適合 multi-processing,
以及進度條可以改用 tqdm module,這些建議都不錯,
不過請體諒我才剛學 Python 沒幾天,需要點時間消化 (汗)
這次程式碼也放在 https://ideone.com/6YO1zU 方便大家複製貼上
#!/usr/bin/env python3
#
# pi.py - Calculate Pi
#
import sys
import time
import math
import gmpy2
from gmpy2 import mpfr
from gmpy2 import mpz
#
# Global Variables
#
count = 0
total = 0
grad = 0
step = 0
#
# Show Progress
#
def progress_init(max):
global count, total, grad, step
total = max
count = 0
step = int(total / 1000)
grad = int(step / 2)
def progress():
global count, total, grad, step
if (count > grad):
grad += step
g = int(math.floor(72.5*count/total+0.5))
p = int(math.floor(1000.5*count/total+0.5))
msg = "H" * g + "-" * (72-g) + " " + str(p/10) + "%\r"
if (grad > total):
msg += "\n"
print(msg, sep="", end="", flush=True)
#
# Write digit string
#
def write_string(digit_string):
fd = open("pi-py.txt", mode="w")
fd.write(" pi = ")
fd.write(digit_string[0])
fd.write(".")
for c in range(1, len(digit_string), 50):
if (c != 1):
fd.write("\t")
fd.write(digit_string[c:c+50])
if ((c % 1000) == 951):
fd.write(" << ")
fd.write(str(c+49))
fd.write("\r\n")
elif ((c % 500) == 451):
fd.write(" <\r\n")
else:
fd.write("\r\n")
# Final new-line
fd.write("\r\n")
fd.close()
#
# Recursive funcion.
#
def s(a, b, max):
global count
m = math.ceil((a + b) / 2)
if (b - a == 1):
if (a == 0):
r = 120 # 6!
q = mpz(640320**3)
p = gmpy2.sub( gmpy2.mul(q, 13591409),
gmpy2.mul(r, 13591409+545140134) )
else:
r = mpz(8 * (a*6+1) * (a*6+3) * (a*6+5))
q = mpz((b*640320)**3)
if ((b%2) == 0):
p = gmpy2.mul(mpz(13591409 + b*545140134), r)
else:
p = gmpy2.mul(mpz(-13591409 - b*545140134), r)
else:
p1, q1, r1 = s(a, m, max)
p2, q2, r2 = s(m, b, max)
# Merge
p = gmpy2.add( gmpy2.mul(p1, q2), gmpy2.mul(p2, r1) )
q = gmpy2.mul(q1, q2)
if (b != max):
r = gmpy2.mul(r1, r2)
else:
r = 0
count += 1
progress()
return p, q, r
#
# Calculate e
#
def calc_pi(digits):
global total
d = digits+1
n_terms = math.ceil(d*math.log(10)/(3*math.log(53360)))
precision = math.ceil(d * math.log2(10)) + 4
print("d = ", d, ", n = ", n_terms, ", precision = ", precision, sep="")
print("gmpy2 version:", gmpy2.version())
print("MP version:", gmpy2.mp_version())
print("MPFR version:", gmpy2.mpfr_version())
max_precision = gmpy2.get_max_precision()
print("max_precision =", max_precision)
max_emax = gmpy2.get_emax_max()
print("max_emax =", max_emax)
if (max_precision < precision):
print("Error! Max precision is too small! Program terminated.")
return
gmpy2.get_context().precision = precision
gmpy2.get_context().emax = max_emax
print("Real precision = ", gmpy2.get_context().precision)
progress_init(n_terms * 2 - 1) # Initialize progress bar
# Recursion
start_time = time.monotonic_ns()
p, q, r = s(0, n_terms, n_terms)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Recursion:", elapsed, "seconds.")
start_time = time.monotonic_ns()
q = gmpy2.mul(q, 426880)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Multiply by 426880:", elapsed, "seconds.")
start_time = time.monotonic_ns()
pf = mpfr(p)
qf = mpfr(q)
ef = gmpy2.div(qf, pf)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Grand Division:", elapsed, "seconds.")
start_time = time.monotonic_ns()
ef = gmpy2.mul(ef, gmpy2.sqrt(10005))
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Multiply by sqrt(10005):", elapsed, "seconds.")
# Convert to decimal digits
start_time = time.monotonic_ns()
estr, exp, prec = mpfr.digits(ef)
estr = estr[0:d]
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Convert to decimal digits:", elapsed, "seconds.")
# Write to file
start_time = time.monotonic_ns()
write_string(estr)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Write to file:", elapsed, "seconds.")
#
# main program
#
if __name__ == '__main__':
argc = len(sys.argv)
if (argc >= 2):
digits = int(sys.argv[1])
else:
digits = 100000
calc_pi(digits)
# End of pi.py