[EN] Prime numbers and speed
In an ICPC-style contest, 3-person teams compete in programming challenges, and they are allowed to bring a "notebook" into the contest. A notebook is a printable document that teams usually fill with implementations of common algorithms, so that they can pull them out when needed during a contest.
The other day, my ICPC team and I participated in one such contest (though an unofficial one) and we got stuck on a problem. Never mind the details, but the solution called for finding all the primes among about a million numbers, each less than a trillion (that's 10 to the 12th power). Accordingly, our code looked something like this:
// ...
using ull = uint64_t;

int main() {
    // ...
    vector<ull> numbers;
    populate(numbers);
    vector<ull> primes;
    for (ull x : numbers)
        if (is_prime(x)) primes.push_back(x);
    // ...
}
The only missing piece was fast primality testing. But since we already had that algorithm in our notebook, we just copied it into our solution.
I've attached the code below, but the specifics are not important. Also, I recognize the awfulness of it, but keep in mind that this is library code that we just import into our solution, not something that needs any maintenance or modification at all. It's mostly optimized for code size because, in official contests, notebooks must be printed on dead trees and all code must be typed in by hand.
// multiplies 64-bit numbers modulo m using egyptian multiplication
ull mulmod(ull a, ull b, ull m) {
    if(a<(1ULL<<32)&&b<(1ULL<<32))return a*b%m;
    ull r=0;a%=m;b%=m;
    while(b){if(b&1){r+=a;if(r>=m)r-=m;}b>>=1;a<<=1;if(a>=m)a-=m;}
    return r;
}

// Does miller-rabin witness checking for 64-bit numbers
bool maybe_prime(ull n, ull a, ull d, int s) {
    // ... some complicated stuff that uses mulmod ...
}

// rabin-miller primality testing for 64-bit numbers
bool is_prime(ull n){
    if(n<2)return false;
    int s=0;ull d=n-1;while(!(d&1))d>>=1,++s;
    static const int bases[]={2,3,5,7,11,13,17,19,23,29,31,37};
    for(int a:bases)if(!maybe_prime(n,a,d,s))return false;
    return true;
}
Up to this point, one of my teammates was working on this and the other two of us were collaborating on a different problem. When he ran it, he realized there was no way our code would be within the time limit: it took about 20 seconds to run on our local machine. Since I'm the one with the most experience micro-optimizing code, he called me in and I started working on it.
At first, I looked through his code. There were no quadratic algorithms, not many unnecessary allocations, no weird memory access patterns, no blatant unnecessary work. All in all, it looked fast. Nonetheless, I tried to optimize what little there was. I eliminated the 20 or so allocations that I could, fused a few loops, removed a few branches, etc.
Of course, this was a big mistake on my part. There were no tell-tale signs of inefficiency, and I didn't have any data pointing to that part of the code being the culprit, so the chance of it paying off was very slim. And, indeed, it did not pay off: after optimizing for around 25 minutes, the program still took 20 seconds to run.
At this point, I realized I was just muddying things up, so I reverted to the un-"optimized" version. After calming myself for a minute or two (even if unofficial, you still want to give everything you have during a contest, so it can be quite nerve-racking), I took a second look and realized that there was only one part of the program I hadn't looked at: the primality testing.
Now, this was a pretty scary realization. I mean, you've seen what the code looks like; you really don't want to have to optimize competitive programming library code in the middle of a contest. I hesitated for a few seconds, but dove in nonetheless.
I started by looking at is_prime, which does some lightweight computation, then calls maybe_prime several times.
// rabin-miller primality testing for 64-bit numbers
bool is_prime(ull n){
    if(n<2)return false;
    int s=0;ull d=n-1;while(!(d&1))d>>=1,++s;
    static const int bases[]={2,3,5,7,11,13,17,19,23,29,31,37};
    for(int a:bases)if(!maybe_prime(n,a,d,s))return false;
    return true;
}
Now, I don't know how maybe_prime works but, at a glance, it was clear that it was some pretty heavy stuff. It performed several calls to mulmod (over 60 in some cases), used the modulo operator, and had tons of branches.
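(To give an idea of the flavor: maybe_prime is a Miller-Rabin witness check, and a typical implementation, sketched below in a readable style, is a modular exponentiation, i.e., a loop of mulmod calls, followed by up to s extra squarings. This is not our notebook's actual code, just the standard shape of the algorithm; it relies on the mulmod and ull definitions above.)

// Sketch of a standard Miller-Rabin witness check (not our notebook's actual code).
// Precondition: n - 1 == d * 2^s with d odd, as computed by is_prime.
// Returns false only when base a proves that n is composite.
bool maybe_prime(ull n, ull a, ull d, int s) {
    a %= n;
    if (a == 0) return true;          // base is a multiple of n: no information
    ull x = 1, p = a;
    for (ull e = d; e; e >>= 1) {     // x = a^d mod n, by square-and-multiply
        if (e & 1) x = mulmod(x, p, n);
        p = mulmod(p, p, n);
    }
    if (x == 1 || x == n - 1) return true;
    for (int i = 1; i < s; ++i) {     // keep squaring, looking for n - 1
        x = mulmod(x, x, n);
        if (x == n - 1) return true;
    }
    return false;                     // a is a witness: n is composite
}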
So, my first idea was to reduce the number of calls to maybe_prime by adding some early exits to is_prime, using a few iterations of the simplest primality testing algorithm there is: trial division (i.e., checking whether the number is divisible by some small primes). For some reason I didn't apply this optimization right away, and kept looking.
My next realization was that a single call to is_prime could result in several hundred calls to mulmod. Since is_prime would be called around a million times, this comes out to potentially a billion calls to mulmod. Therefore, optimizing mulmod could be worthwhile. (In hindsight, this is an extreme overestimation, but that's the kind of back-of-the-envelope math one does during a contest.)
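(Spelled out, the estimate is: 12 bases per number, times 60-odd mulmod calls per base, is roughly 700 mulmod calls per is_prime; a million numbers times 700 is about 7 × 10^8 calls, which rounds up to a billion in contest math.)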
No need to read it, but note that the code for mulmod is pretty complicated, considering it just multiplies two numbers:
// multiplies 64-bit numbers modulo m using egyptian multiplication
ull mulmod(ull a, ull b, ull m) {
    if(a<(1ULL<<32)&&b<(1ULL<<32))return a*b%m;
    ull r=0;a%=m;b%=m;
    while(b){if(b&1){r+=a;if(r>=m)r-=m;}b>>=1;a<<=1;if(a>=m)a-=m;}
    return r;
}
You probably know that the result of multiplying two 64-bit numbers always fits in 128 bits. You might also know that GCC comes with 128-bit integers built right in. What you might not know is that most contests compile their code using GCC, and that using compiler-specific language extensions is 100% allowed. So, it turns out that this code can be replaced with the following (with some caveats, but it's OK in this case):
ull mulmod(ull a, ull b, ull m) { return __int128(a) * __int128(b) % m; }
Now, the function is just a 128-bit multiplication and a 128-bit modulo operation. If the hardware the contest runs on supports these, that's just two instructions, and extremely fast. If they're emulated by GCC, it's still going to be way faster than our crude implementation. (Post-contest research places GCC's emulated version at ~100 cycles; an optimistic estimate for ours is ~600 cycles.)
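(If you're worried about those caveats, a quick differential test settles it for a given input range. Here's a hypothetical one, not something we ran during the contest, comparing the loop version against the __int128 version on moduli up to 10^12, the range we cared about; the names mulmod_old and mulmod_new are mine:)

#include <cassert>
#include <cstdint>
#include <random>
using ull = std::uint64_t;

// the notebook's loop version, renamed for the comparison
ull mulmod_old(ull a, ull b, ull m) {
    if(a<(1ULL<<32)&&b<(1ULL<<32))return a*b%m;
    ull r=0;a%=m;b%=m;
    while(b){if(b&1){r+=a;if(r>=m)r-=m;}b>>=1;a<<=1;if(a>=m)a-=m;}
    return r;
}

// the __int128 version, renamed for the comparison
ull mulmod_new(ull a, ull b, ull m) { return __int128(a) * __int128(b) % m; }

int main() {
    std::mt19937_64 rng(42);
    for (int i = 0; i < 1000000; ++i) {
        ull m = rng() % 1000000000000ULL + 2;   // moduli up to ~10^12
        ull a = rng() % m, b = rng() % m;
        assert(mulmod_old(a, b, m) == mulmod_new(a, b, m));
    }
}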
With this change done, I compiled, ran, and timed, and, finally, got an improvement: This version of the code runs in about 6 seconds! From a baseline of 20 seconds, that's more than a 3x improvement!
But our journey is not over yet, because time limits are usually around one or two seconds, and we're still way over that.
After this success, my more mathematically skilled teammates suggested that we might not need that many calls to maybe_prime in the is_prime function, because it was designed to test any 64-bit integer (roughly up to 10 quintillion), but we only needed it to work up to a trillion. In particular, checking the bases only up to 17 should be sufficient. We tried this optimization, but couldn't get it to work even after 15 or so minutes. I tried again after the contest and it worked fine. I'm still not sure what the problem was.
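For reference, the version I got working after the contest is just the same function with a shorter bases array. The first seven prime bases are known to give a deterministic test for every n below 341,550,071,728,321 (about 3.4 × 10^14), comfortably above our 10^12 bound:

// rabin-miller primality testing, reduced to the bases we actually need
bool is_prime(ull n){
    if(n<2)return false;
    int s=0;ull d=n-1;while(!(d&1))d>>=1,++s;
    static const int bases[]={2,3,5,7,11,13,17};
    for(int a:bases)if(!maybe_prime(n,a,d,s))return false;
    return true;
}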
Then, I implemented my first idea: add an early exit using trial division. The code is as simple as could be: before any of the expensive calls to maybe_prime, just check whether the number is a multiple of a few small primes. This should get rid of a very large percentage of composites, lowering the total number of calls to the heavier test.
// rabin-miller primality testing for 64-bit numbers
bool is_prime(ull n){
    if(n<2)return false;
    static const int early_bases[]={2,3,5,7,11,13,17};    // This line is new
    for(int a:early_bases)if(n == a)return true;          // This line is new
    for(int a:early_bases)if(n % a == 0)return false;     // This line is new
    int s=0;ull d=n-1;while(!(d&1))d>>=1,++s;
    static const int bases[]={2,3,5,7,11,13,17,19,23,29,31,37};
    for(int a:bases)if(!maybe_prime(n,a,d,s))return false;
    return true;
}
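(To put a rough number on "a very large percentage": if the inputs behaved like uniformly random integers, the fraction surviving trial division by the primes up to 17 would be (1/2)·(2/3)·(4/5)·(6/7)·(10/11)·(12/13)·(16/17) ≈ 0.18, so about 82% of the numbers would never reach maybe_prime at all. That's back-of-the-envelope math worked out after the fact, not something we computed during the contest.)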
Again, I compiled, ran, and timed, and found another massive improvement! This time, the code ran in about 0.7 seconds! That's an 8.6x speedup over the previous version, and a 28.6x improvement overall! This version would surely be fast enough for the contest.
We submitted it, and the answer we got could not have been more unexpected: Verdict: Wrong Answer. This meant our code was incorrect. Debugging this was a real head-scratcher. I immediately thought that the bug was in one of the changes I had made to the primality testing (after all, it is really scary competitive programming library code), so I spent some time on that. Then, the teammate who had written the rest of the code suggested that the bug might be in the code that populates the numbers vector.
I was skeptical, since he rarely writes bugs in that kind of code, but he was adamant that we take a look at it. After a few minutes of looking at it, he found the bug, we patched it up, submitted the fixed version, and finally, finally, got our long-sought goal: Verdict: Accepted.
Here is what I think I learned:
- Library code is just like other code. Treat it as such. There is no need to be afraid.
- There are two ways to optimize code: make slow functions faster, or make fewer calls to them.
- Only optimize code if you know it's slow. This is good advice, but in this case, something more obvious needs to be said:
- Don't optimize code if you know it's fast. In this case, I lost half an hour optimizing code that I was already convinced was fast.
The lessons are focused on optimization, and what to optimize. I'm sure there is something to be said about debugging, and what to debug, but I haven't figured it out yet. If you have any tips, feel free to message me.
About me
My name is Sebastian. I'm a competitive programmer (ICPC World Finalist) and coach. I also enjoy learning about programming language theory and computer graphics.