Rejection methods on GPUs

I've previously stated (ahem, in public at the Thalesian GPU workshop) that rejection methods are not that great on a GPU. However, then Mike Giles said that he was using rejection for generating gamma distributed numbers, and apparently someone asked him why, given I had said they don't perform well on GPUs (I was talking to someone outside, so missed this part of his presentation).

The short answer is that rejection is pretty much the only game in town when it comes to gamma variates (see Note 1), so you have to use rejection no matter the performance cost. The longer answer is that you can reduce the cost of rejection quite a lot in a GPU by turning the warp-based parallelism from a disadvantage into a (sort-of) advantage.

What do I have against rejection methods?

Err, nothing. Some of my best friends are rejection methods! My only point is that on a GPU the warp-based parallelism changes the costs of rejection methods quite significantly. At the workshop I was making a slightly different point about methods with a fast/slow probabilistic path, like Ziggurat, so I should deal with the more general case here.

Assume that we are generating samples from some distribution with a set of parameters. For the purposes of this analysis it doesn't really matter what the parameters are; we'll just treat them as an opaque tuple, which may be empty. For rejection, we need two stochastic functions, i.e. functions containing some internal source of randomness (although see Note 2). The functions are:

float g(params_t parms) : generates a candidate sample
bool a(float x, params_t parms) : decides whether to accept the candidate x

To generate a random number, we use these functions together:
float RNG(params_t parms){
  float x;
  do{ x=g(parms); }while(!a(x,parms)); // retry until a candidate is accepted
  return x;
}
This is a reasonable characterisation of basic rejection - specific methods have multiple levels of accept/reject, or different paths, but I'm trying to tackle the basic problem here.

We will assume that the randomness in both a and g is IID - there is no memory or correlation between calls to the functions. We can then state that Pr[ a(g(parms),parms) = true | parms ] = p; so given a specific set of parameters, the probability of accepting a given sample is a fixed constant p (obviously 0<=p<=1). This probability p directly affects the performance of RNG, as it determines how many times we expect to go round the loop.

In a scalar processor (i.e. a CPU), the number of loop iterations follows a geometric distribution with parameter p, so the average number of loop iterations per call to RNG is 1/p.

However, in a vector processor with warp-based parallelism (i.e. a GPU), the number of loop iterations is different. The fundamental property of warp processors is that if one thread in a warp takes a branch, all threads in the warp take the branch, at least in terms of performance (if you don't know what I'm talking about, go and read up on CUDA and GPUs). So in a rejection method, if any thread in a warp fails the acceptance test, then they all have to do another iteration. So given a warp size of K, the probability that all threads generate a valid sample (i.e. a(g())=true) on a given iteration is p^K.

This graph shows the relationship between the probability of one thread succeeding (i.e. p) on the x-axis, versus the probability of all threads in a warp of size K succeeding:

If we look at the situation for a real GPU, i.e. K=32, then even very large values of p result in very low probabilities of all K threads succeeding. You may think this all looks rather bad, but please note that this is only the probability of generating a random number in a single iteration.
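To put some numbers on this, here is a minimal Python sketch (not part of the original analysis; the function name prob_all_accept is made up for illustration) that evaluates p^K for a few acceptance probabilities at K=32.

```python
def prob_all_accept(p: float, warp_size: int = 32) -> float:
    """Probability that every thread in a warp accepts on the same iteration."""
    return p ** warp_size


# Even a fairly high per-thread acceptance probability gives a low
# chance that the whole warp is finished after a single iteration.
for p in (0.5, 0.8, 0.9, 0.99):
    print(f"p={p:4.2f}  p^32={prob_all_accept(p):.3e}")
```

The K=32 default reflects the warp size discussed above; other warp sizes can be passed as the second argument.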

Calculating the average performance

To characterise the performance we need to work out the average number of loop iterations, rather than the probability of success on the first loop iteration (which as we saw is rather low). Please bear in mind we are talking about performance here, so the control-flow of each thread is not important, we have to focus on the control-flow of the warp.

An initial thought might be that the warp just follows a different geometric distribution, for example with probability of success per iteration of p^K. Thankfully, this isn't the case, otherwise we would be in serious trouble: there is thread-level memory between the different warp-level iterations, so the probability of exiting the warp-level loop changes after each iteration.

It should be obvious that the probability of successfully exiting the warp-level loop increases monotonically with iteration count - at each iteration the number of threads still left to generate a random number either stays the same or decreases. So we can characterise the distribution of loop iterations recursively, by looking at how many iterations are required for different numbers of remaining threads.

Let r be the number of threads that still need to generate a number. We will define a discrete probability distribution L_r, which describes the number of warp-level loop iterations needed for r threads to all generate a random number. From a performance point of view we are interested in the distribution of L_K, and in particular in E[L_K], as this describes the average number of loop iterations.

Trivially, L_0=0; if there are no threads that need to generate a number, we don't have to do any loop iterations. We have already described the scalar case, L_1, which just follows a geometric distribution. However, calculating the distribution of L_r with r>1 is more complicated, so we'll have to do more work. Note that I am replicating known results for the maximum of K geometrics here, as it is a useful illustration of how to handle a more complicated case later on.
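The known result for the maximum of K geometrics can be evaluated directly, which makes a handy cross-check for the numbers that follow. Since the threads are independent, Pr[L_K <= n] = (1-(1-p)^n)^K, and summing the tail probabilities Pr[L_K > n] over n>=0 gives E[L_K]. A minimal Python sketch, assuming 0<p<=1 (the function name and the truncation tolerance are my own choices for illustration):

```python
def expected_warp_iterations(p: float, warp_size: int = 32, tol: float = 1e-12) -> float:
    """E[L_K] for the maximum of K independent geometric(p) variables.

    Uses E[L_K] = sum over n >= 0 of Pr[L_K > n], with
    Pr[L_K > n] = 1 - (1 - (1-p)^n)^K, truncated once a term drops below tol.
    """
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    q = 1.0 - p
    total = 0.0
    n = 0
    while True:
        tail = 1.0 - (1.0 - q ** n) ** warp_size  # Pr[L_K > n]
        if tail < tol:
            break
        total += tail
        n += 1
    return total


print(expected_warp_iterations(0.8, warp_size=1))   # scalar case, close to 1/p
print(expected_warp_iterations(0.8, warp_size=32))  # warp of 32 threads
```

With warp_size=1 this collapses to the scalar geometric case, so it should agree with 1/p up to the truncation tolerance.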

If we want to calculate L_r, we have to look at what can happen in each loop iteration. If there are r threads trying to generate a random number, then after one loop iteration we can end up in r+1 different states: anywhere from 0 to r of those threads may still need a number.

If you think about it you can see that the number of threads that succeed just follows a binomial distribution with parameters r and p - we are performing r independent trials and the probability of success of each one is p.

We can now form a lattice of states, and the transition probabilities between them:

So the loop starts at the top left node, and each horizontal step is one loop iteration. All the edges leaving one node are given probabilities from the binomial distribution we just described. The bottom node indicates that the loop has finished, and there is only one edge leaving it - once all threads have generated their number, executing another iteration achieves nothing. We can view the complete set of transitions as a stochastic matrix T, with each column containing the outward probabilities from each node, and the lack of an arc meaning a probability of zero.

Given this matrix, we can work out the probability of moving from one state to another. If we define a column vector s, containing the probability of being in each state, then we can calculate the probability of each state after one iteration by doing the matrix vector multiplication T*s. Before the first iteration we know that r=K (every thread still needs a number), so s will be all zeros except for the element for that state (which is 1). From this initial state we can walk the probabilities forwards, and find the probability of being in each state after a given number of loop iterations. Looking at the probability for r=0 after each iteration gives the CDF of L_K.
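The walk through the state probabilities can be sketched in a few lines of Python. This is only an illustration under the assumptions stated above (IID acceptance with probability p, states indexed by the number r of threads still waiting), and the function names are hypothetical rather than taken from any code of mine:

```python
from math import comb


def transition_matrix(p, K):
    """Column-stochastic T: T[r_next][r] = Pr[r_next threads still waiting | r waiting]."""
    q = 1.0 - p
    T = [[0.0] * (K + 1) for _ in range(K + 1)]
    for r in range(K + 1):
        for r_next in range(r + 1):
            # r_next of the r waiting threads are rejected again
            T[r_next][r] = comb(r, r_next) * q ** r_next * p ** (r - r_next)
    return T


def loop_count_cdf(p, K=32, tol=1e-12, max_iter=100000):
    """Return [Pr[L_K <= 1], Pr[L_K <= 2], ...] by repeatedly applying s = T*s."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    T = transition_matrix(p, K)
    s = [0.0] * (K + 1)
    s[K] = 1.0  # before the first iteration all K threads are waiting
    cdf = []
    for _ in range(max_iter):
        s = [sum(T[i][j] * s[j] for j in range(K + 1)) for i in range(K + 1)]
        cdf.append(s[0])  # probability of having reached the absorbing state r=0
        if 1.0 - s[0] <= tol:
            break
    return cdf


def expected_iterations(cdf):
    """E[L_K] = sum over n >= 0 of Pr[L_K > n], where Pr[L_K > 0] = 1."""
    return 1.0 + sum(1.0 - c for c in cdf)


cdf = loop_count_cdf(0.8, K=32)
print("Pr[all done after 1 iteration] =", cdf[0])
print("E[L_K] =", expected_iterations(cdf))
```

The first CDF entry is just p^K, which ties this back to the single-iteration probability discussed earlier.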

Calculating the probabilities, we see the expected loop count looks like this:

Usually in rejection methods we have a probability which is relatively close to 1, so here's a close-up of that region:

So with a rejection method that has a relatively low p like 0.8, the cost of warp-level parallelism is that a warp of 32 threads needs about 3 loop iterations on average, rather than the 1/p = 1.25 we would expect for a single thread.

Pre-caching variates

One way of reducing this overhead is by remembering the basic relationship of threads and warps: if one thread in a warp takes a branch, then all threads take the overhead of that branch. In our rejection method that means that all the threads that have already generated a number are still doing all the instructions for any following iterations, they just aren't writing back to the registers afterwards. So during the second and third iteration, lots of threads are (sort of) generating "ghost" numbers - they are actually calculating the random number that they will generate on the next call to RNG, and then completely forgetting about it.

We can attempt to improve the efficiency of generation by actually taking advantage of all this free work. If lots of threads are generating the answer to the next call during the current call, we can just let them actually do the work, and save that value. So the idea is to "pre-cache" the answer to the next call during the current one. Then during the next call to RNG we can check whether the same parameters are being used, and so use the cached variate (if there is one). The random number generator with pre-caching looks something like this:

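float RNG(params_t parms)
{
  // Per-thread state that persists between calls
  static params_t cachedParms;
  static bool cachedValid=false;
  static float cachedSample;
  bool currentValid=false;
  float currentSample;
  // A cached sample is only usable if it was generated with the same parameters
  currentValid = cachedValid && (cachedParms==parms);
  if(currentValid){
    currentSample=cachedSample; cachedValid=false;
  }
  do{
    float sample=g(parms);
    bool valid=a(sample, parms);
    if(valid){
      if(!currentValid){
        currentSample=sample; currentValid=true;
      }else if(!cachedValid){
        // Already have a number for this call, so save this one for the next call
        cachedSample=sample; cachedParms=parms; cachedValid=true;
      }
    }
    // All threads in warp vote on whether they are finished
  }while( !__all(currentValid) );
  return currentSample;
}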
If you've never come across it before, the __all intrinsic allows all the threads within a warp to vote on something - basically it does a horizontal "and" across all the threads. There is also an equivalent __any intrinsic which does a horizontal "or", so we could have used De Morgan's laws and made the loop condition "while(__any(!currentValid))". (Recent CUDA toolkits replace these with __all_sync and __any_sync, which take an explicit thread mask as their first argument.)

Adding the caching adds two types of overhead. First there is the need for additional registers to hold the persistent state between calls. The exact number of registers will depend on the size of the distribution's parameter set (i.e. params_t) and the properties of the distribution. For example, for a distribution with no parameters and a half-open output range you could get away with one register, using an out-of-range value of the cached sample to mark it as invalid: if outputs are in the range [a,infinity), you could use (a-1) to indicate an invalid sample.

The cost of the additional registers shouldn't be under-estimated, as it increases the register pressure on the whole kernel. You have in effect permanently removed some registers from the pool available to the compiler, which may seriously affect performance. For example, the compiler might have to start spilling registers to local memory (which lives in slow off-chip device memory), which will seriously hurt you.

The second type of overhead is due to the extra instructions needed to manage the cached values. A very informal scan of the code suggests an overhead of about 16 instructions, assuming the conditionals are done using predication rather than branches. This isn't too bad - they are all "nice" instructions, so no memory accesses, and nothing that will cause conflicts. However, whether caching is worth it depends on how expensive a(g()) is, and how often you get a cache hit.

It should be clear that, if we ignore the extra instructions and think in terms of loop iterations, the cached version always has the same or better performance than the original version. If the distribution parameters are being varied on a per-call basis then there will never be a cache hit, so it just reduces to the uncached case.
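A quick way to sanity-check this claim is a Monte Carlo simulation of a single warp, counting warp-level loop iterations with and without the cache. The sketch below is only a model, not GPU code: it assumes every call uses the same parameters, that each thread accepts independently with probability p, and that the loop body always runs at least once per call. The function name and defaults are made up for illustration.

```python
import random


def simulate_mean_iterations(p, warp_size=32, calls=5000, cached=True, seed=1):
    """Average number of warp-level loop iterations per call to RNG."""
    rng = random.Random(seed)
    has_cache = [False] * warp_size  # persists between calls, one flag per thread
    total_iterations = 0
    for _ in range(calls):
        # On entry a cached sample (if any) becomes the thread's current sample.
        has_current = has_cache[:] if cached else [False] * warp_size
        has_cache = [False] * warp_size
        while True:  # models do { ... } while(!all(currentValid))
            total_iterations += 1
            for t in range(warp_size):
                if rng.random() < p:  # this thread's candidate was accepted
                    if not has_current[t]:
                        has_current[t] = True
                    elif cached and not has_cache[t]:
                        has_cache[t] = True  # "ghost" sample kept for the next call
            if all(has_current):
                break
    return total_iterations / calls


if __name__ == "__main__":
    print("uncached:", simulate_mean_iterations(0.8, cached=False))
    print("cached:  ", simulate_mean_iterations(0.8, cached=True))
```

This only counts loop iterations, so it says nothing about the register and instruction overheads discussed above.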

Performance estimation of cached version

So can we augment the uncached performance model to capture the caching as well? Yeah, sure, why not. We now have two state variables to track, which we'll call r for the number of threads that haven't got a current number yet, and c for the number of threads which have both a current and a cached number (I'm not sure why I inverted the state variables like that, but never mind).

So we have a set of states (r,c), where 0<=r<=K, 0<=c<=(K-r). Any state with r=0 is an absorbing state, i.e. the loop has finished. I'm not going to try drawing a lattice for this version, as it gets too dense.

On the very first call to RNG, we start from the state (r=K,c=0), i.e. all threads still need to get a current number, and so none of them have a cached number. From this state between 0 and K numbers will be generated on the next iteration, so we have r'=binomial(K,1-p), and c'=0. This is just the same as the non-cached version.

However, once at least one thread has a current sample (i.e. r<K) then there is now a chance that caching may happen. If c<(K-r) then there are some threads that can now generate a cached sample (if c=(K-r) then all threads that could have a cached sample already do have one). This set of (K-r-c) threads forms another independent binomial(K-r-c,p) distribution. So we now have two things happening: r'=binomial(r,1-p), and c'=c+binomial(K-r-c,p). But, because these two things are independent, we have to consider every valid output combination of (r',c'), and multiply out the probabilities from the two distributions.

We can build the transition matrix T for this process quite easily - it's a bit tedious and easy to mess up, but still simple. See calculate_cached_distribution_v2.m for matlab/octave code that will do it for you.

However, in the cached case, there is an additional issue, in that there is persistence between each call to RNG(). In the uncached case we always knew that the initial state was r=K, but this is not true for the cached case.

What we need is to work out the steady-state distribution of cached values after each call to RNG() completes, and then use that to provide the initial caching state at the input of RNG(). We can map the output state of the RNG by noting that a cached sample at the output becomes that thread's current sample at the next input, so states (r=0,c) at the output map to states (r=K-c,c=0) at the input, and we can define some function w(s)->s that converts an output distribution into an input one.

To find the steady-state probabilities, we'll first repeatedly power the matrix T until it stops changing, i.e. find T^*=lim(T^n) as n->infinity. This gives a matrix that jumps from any initial distribution of states straight to the distribution of states after an infinite number of loop iterations.

We can then find the steady-state distribution of states between calls to RNG by repeatedly applying T^*, i.e. defining an initial state s_0=(r=K,c=0) then applying s_{i+1}=T^* w(s_i) until |s_{i+1}-s_i| drops below a threshold. In principle this might not be a stable recurrence, but it seems to work quite well here.

Once we know the steady-state between calls, we can use the same process we used for the uncached case to calculate the distribution of L_K. See recurse_cached_distribution.m for code that can calculate L_K or E(L_K) for various values of p and K.
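For readers without MATLAB/Octave to hand, here is a rough Python sketch of the same idea. It is not a port of the .m files: instead of powering T it pushes a probability distribution over the (r,c) states forwards one loop iteration at a time, which is a swapped-in but equivalent way of reaching the absorbing states. It assumes every call uses the same parameters, that a cached sample is consumed on entry (so a call that starts with m cached threads begins in state (r=K-m,c=0)), and that the loop body runs at least once per call. All function names are hypothetical.

```python
from math import comb


def binom_pmf(n, k, prob):
    """Pr[Binomial(n, prob) = k]."""
    return comb(n, k) * prob ** k * (1.0 - prob) ** (n - k)


def step(dist, p, K):
    """Apply one warp-level loop iteration to a distribution over states (r, c)."""
    out = {}
    for (r, c), weight in dist.items():
        free = K - r - c  # threads with a current sample but an empty cache
        for r_next in range(r + 1):
            pr = binom_pmf(r, r_next, 1.0 - p)  # r_next waiting threads are rejected again
            for gained in range(free + 1):
                pc = binom_pmf(free, gained, p)  # threads that fill their cache
                key = (r_next, c + gained)
                out[key] = out.get(key, 0.0) + weight * pr * pc
    return out


def one_call(cached_in, p, K, tol=1e-12):
    """cached_in maps m (threads holding a cached sample on entry) to its probability.

    Returns (expected loop iterations, distribution of cached counts on exit).
    """
    dist = {(K - m, 0): w for m, w in cached_in.items()}
    cached_out = {}
    expected, running = 0.0, 1.0
    while running > tol:
        expected += running  # every warp that has not exited executes this iteration
        still_running = {}
        for (r, c), w in step(dist, p, K).items():
            if r == 0:
                cached_out[c] = cached_out.get(c, 0.0) + w  # loop exits with c cached
            else:
                still_running[(r, c)] = w
        dist = still_running
        running = sum(dist.values())
    total = sum(cached_out.values())
    return expected, {m: w / total for m, w in cached_out.items()}


def steady_state_expected_iterations(p, K, tol=1e-10, max_calls=1000):
    """Iterate call after call until the between-call cache distribution settles."""
    cached = {0: 1.0}  # very first call: nobody has a cached sample
    for _ in range(max_calls):
        _, new_cached = one_call(cached, p, K)
        change = sum(abs(new_cached.get(m, 0.0) - cached.get(m, 0.0)) for m in range(K + 1))
        cached = new_cached
        if change < tol:
            break
    expected, _ = one_call(cached, p, K)
    return expected


if __name__ == "__main__":
    first_call, _ = one_call({0: 1.0}, 0.8, 8)
    print("uncached E[L_K], K=8:", first_call)
    print("cached   E[L_K], K=8:", steady_state_expected_iterations(0.8, 8))
```

The very first call (empty cache) behaves exactly like the uncached version, so one_call({0: 1.0}, p, K) doubles as the uncached baseline. Pure Python is slow for K=32, so the usage example sticks to a small warp size.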

If we look at the expected number of loop iterations for the cached version, then we find a graph that looks like this:

This isn't terribly useful by itself, so it makes more sense to look at the ratio of uncached to cached iterations:

So we find that up until p becomes very close to 1, the uncached version takes on average 60% more iterations than the cached version, so to a rough approximation we could argue that the uncached version takes 60% longer to generate each random number. The actual answer is more complicated than that, and should be derived through benchmarking of course.

So the three key concerns I haven't addressed here are:

Practical use of caching in rejection methods depends on all three, so its use is related to the specific application, distribution, and rejection algorithm.

Note 1: Is rejection the only solution for gamma variates?

This is not strictly true, but for practical purposes it is. For integer a you can swap a probabilistic algorithm for a deterministic one, but it takes O(a) time to generate each variate. If you are using a specific a for a huge number of variates you can also use some kind of segmented polynomial or rational approximation. But, if we assume a might change on a per-call basis, rejection is the only viable solution I am aware of.

Note 2 : Do both functions have to be stochastic?

Often both of these functions contain some kind of internal source of randomness (i.e. a call to some underlying random number generator, like a uniform or Gaussian RNG). If g does not contain randomness then you are just generating a Bernoulli variate in a very complicated manner. If a does not contain randomness then you still have a rejection method, but where you are just masking off parts of the range (for example, this happens in the Polar method). I leave it to your imagination what happens if neither contains a source of randomness.
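As an illustration of the deterministic-a case, here is a small Python sketch of the Marsaglia polar method for Gaussian variates. The candidate generator draws a point uniformly in the square [-1,1]x[-1,1], and the acceptance test simply checks whether the point lies inside the unit circle, so it contains no randomness of its own; the acceptance probability is pi/4, or roughly 0.785. The function name is my own.

```python
import math
import random


def polar_gaussian_pair(rng):
    """Return two independent standard Gaussian samples via the polar method."""
    while True:
        # g: the stochastic part, a uniform point in the square [-1, 1] x [-1, 1]
        u = rng.uniform(-1.0, 1.0)
        v = rng.uniform(-1.0, 1.0)
        s = u * u + v * v
        # a: deterministic test that masks off everything outside the unit circle
        if 0.0 < s < 1.0:
            f = math.sqrt(-2.0 * math.log(s) / s)
            return u * f, v * f


rng = random.Random(12345)
print(polar_gaussian_pair(rng))
```

On a GPU the same warp-level argument applies: with p around 0.785 and K=32, it is very unlikely that every thread accepts on the first iteration.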
