Efficient Reinforcement Learning – Rhythm Garg & Linden Li, Applied Compute
Channel: aiDotEngineer
Published at: 2025-12-09
YouTube video id: o15AaYl7Wu0
Source: https://www.youtube.com/watch?v=o15AaYl7Wu0
Hey everyone, it's great to meet you all. Really great to be here today. My name is Rhythm, and this is my co-founder Linden. Our third co-founder, Yash, couldn't make it today, but we're all very excited to be here. The three of us were previously researchers at OpenAI, and now we're bringing frontier AI inside of enterprises at Applied Compute. Today we're going to be talking about efficient reinforcement learning.

As some context on Applied Compute: we help enterprises build their own intelligence to power real work in their company. We think a lot about how to push AI beyond productivity into real automations that deliver quantitative ROI for the company. Once we build a system that's specialized to the way a company operates for a particular use case, we deploy it with a data flywheel so that it gets better the more you use it. Picture an in-house expert at a company that's always at the forefront of their field. Mechanically, RL is the tool we use to bring these out-of-distribution datasets in-distribution for today's models. Yash, Linden, and I all worked on the RL effort at OpenAI in its early days, and we saw firsthand the power of RL for maximizing public benchmarks. Now we're taking that a step further and helping enterprises solve the problems they care the most about: their private benchmarks.

Here's a very high-level overview of how high-compute RL helps LLMs acquire these reasoning and intelligence capabilities. Let's say you have a dataset of math problems and we pick four of them for an RL training step. We take an open-source model, say one of the gpt-oss models or one of the Llama models, and have it attempt each of those four problems 100 times. Each of those 100 attempts is the model thinking through how it would get to the final answer and then ending with the final answer itself, with many, many reasoning tokens in its thinking trajectory. We can grade all of these answers, and when the model is correct, we can bias the model's weights to reinforce its thinking trace in that attempt. When it's incorrect, we can discourage the model from producing that kind of behavior again. In this fashion, as we do more and more training steps with batches of four problems, 100 attempts each, the model learns to reason through and solve math problems, and it becomes really, really good at math. Of course, at Applied Compute we're not helping enterprises solve math problems, but this is the mechanism by which we teach models to get really, really good at the tasks they actually care about.
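To make that concrete, here's a minimal sketch of one such training step. The helpers `sample` and `grade`, the binary rewards, and the group-mean baseline are illustrative assumptions, not anything shown in the talk:

```python
import random

def sample(problem: dict) -> str:
    """Stand-in sampler: in reality the model emits many reasoning tokens
    in a thinking trajectory, then a final answer."""
    return "...reasoning... final answer: " + random.choice(["4", "5"])

def grade(problem: dict, attempt: str) -> bool:
    """Stand-in grader: checks the final answer against the known solution."""
    return attempt.endswith(problem["answer"])

def rl_step(problems: list, attempts_per_problem: int = 100):
    """One RL step: sample many attempts per problem, grade them, and emit
    (attempt, advantage) pairs; a positive advantage reinforces the thinking
    trace, a negative one discourages it."""
    updates = []
    for problem in problems:  # e.g. a batch of four math problems
        attempts = [sample(problem) for _ in range(attempts_per_problem)]
        rewards = [1.0 if grade(problem, a) else 0.0 for a in attempts]
        baseline = sum(rewards) / len(rewards)  # group-mean baseline (assumed)
        updates += [(a, r - baseline) for a, r in zip(attempts, rewards)]
    return updates

updates = rl_step([{"question": "2 + 2 = ?", "answer": "4"}], attempts_per_problem=8)
print(updates[:2])
```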
As we mentioned, the type of RL work we do at Applied Compute is actually quite different from the labs'. These are some real-life photos from the labs, and a photo we took at the Applied Compute office the other day. The labs do these big training runs over several weeks; we do more specialized runs. And there are a couple of aspects of RL training that are particularly important to us. We need our runs to be fast, so that we can train a model and deliver it to a customer very quickly, on the order of days. They have to be cheap, so that our unit costs work and we're able to scale the business sustainably. And importantly, and this is a point that's easy to miss, we need our estimates of how long these training jobs will take to be very low variance, because we don't want to be just generally fast; we want to be reliably fast when we work with customers. So the research problem that is very business-critical for us is: can we build an RL stack so efficient that, in conjunction with our agent-building platform, we're really able to scale up this use-case-specific training motion?

Let's start with an inefficient form of RL, which is synchronous RL. In synchronous RL, sampling and training happen in lockstep. There are some simplifications here, but let's say we want to train on batches of eight samples. That means we're going to wait for all eight samples to finish before we start training, and then we're going to repeat the process again. As a result, we have a lot of idle GPUs waiting on that third, straggler sample to complete. In other words, in synchronous RL, our step times are dictated by whichever sample takes the longest to complete. To illustrate why this is bad, we took 40 arithmetic problems, requested 32 samples for each of them with Qwen 30B, and measured how long it took for these samples to complete. It turns out that 99% of the samples completed in about 40 seconds; it took another 80 seconds to get that last percent of samples to complete. It really has a long tail. So, as you'd expect, if you look at the throughput chart, the GPUs are doing a lot of work at the beginning when all of the sampling requests are launched, but by the end they're very underutilized because they're waiting on those last samples to complete. The technical term we use at Applied Compute is that the GPUs are slacking. So synchronous RL is not an efficient way to use these GPUs.

To solve this problem, we need to break the condition that sampling and training happen in lockstep. In other words, we need to allow training while we're sampling. This is called asynchronous RL, and there are many approaches to it. One that we particularly like is PipelineRL, from Piché et al. We're going to make some simplifications here, but in asynchronous PipelineRL, we dedicate some GPUs to sampling and some GPUs to training. The sampling workers never stop; they're constantly doing inference at high batch size. As samples complete, they get added to a queue for training, and the training workers pull a batch from the queue to train on. After a batch has been trained on, the training workers propagate the new model weights to all of the sampling workers in what's called an in-flight weight update. This is really what differentiates PipelineRL: the sampling workers might be in the middle of a sample, but their weights will still get updated if a training step just completed. As a result, we end up with samples that multiple versions of the policy contributed to generating. In other words, there are stale tokens in some of these samples. Let's take a look at one sample to make this a bit more clear. As you can see, there are three versions of the policy, at time steps t, t+1, and t+2, that were used to generate this sample, since two training steps completed, and in turn two in-flight weight updates happened, while this sample was being generated.
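As a toy illustration of how one sample picks up tokens from multiple policy versions under in-flight weight updates (the timing constants here are invented, not measurements from the talk):

```python
def policy_versions(num_tokens: int, token_time: float = 1.0,
                    train_step_time: float = 40.0) -> list:
    """Tag each generated token with the policy version that produced it,
    assuming an in-flight weight update lands every completed training step."""
    return [int(i * token_time // train_step_time) for i in range(num_tokens)]

# A 100-token sample that spans two completed training steps: tokens 0-39
# come from policy t, tokens 40-79 from t+1, and tokens 80-99 from t+2.
versions = policy_versions(100)
print(sorted(set(versions)))  # -> [0, 1, 2]
```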
So when this sample gets trained on in the t+3 to t+4 training batch, we'll have some tokens that came from a policy three steps behind, some that came from a policy two steps behind, and those last two tokens that came from a policy one step behind. Now, let's say we only tolerate staleness up to two. That means we're not going to allow the in-flight weight update after the t+1 to t+2 training batch completes, and that means the training workers are just going to sit idle, waiting for this sample to complete before they can propagate that in-flight weight update and start training on the next batch. Because if they were to do the in-flight weight update, that would cause this sample to have staleness three, as we just saw. And if we only tolerate staleness one, the training workers are going to be idle for even longer, which is bad. So as you increase how much staleness you tolerate, you have fewer idle GPUs in general.

But as we all know, there's no free lunch. This is the standard policy gradient with an importance ratio to adjust for the fact that we're sampling from a policy at time step t and training with the policy at time step t+k, given k steps of staleness. The importance ratio is what makes this policy gradient unbiased. But the variance of that ratio increases as you increase staleness, and that's the big issue here, because with a higher-variance importance ratio, learning can become unstable and cause divergence.
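For reference, the estimator being described has roughly this standard off-policy form (the notation is assumed, since the slide itself isn't in the transcript):

```latex
% Policy gradient with an importance ratio: tokens are sampled from the stale
% policy \pi_{\theta_t}, but we are training the current policy \pi_{\theta_{t+k}}.
\nabla_{\theta} J \;=\;
\mathbb{E}_{a \sim \pi_{\theta_t}(\cdot \mid s)}\!\left[
  \frac{\pi_{\theta_{t+k}}(a \mid s)}{\pi_{\theta_t}(a \mid s)}
  \, A(s, a)\,
  \nabla_{\theta} \log \pi_{\theta_{t+k}}(a \mid s)
\right]
% The ratio keeps the estimator unbiased, but its variance grows as
% \pi_{\theta_{t+k}} drifts away from \pi_{\theta_t}, i.e. as staleness k grows.
```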
The concrete trade-off is that we want a lot of staleness for fast RL runs, but a lot of staleness makes learning unstable, which then requires innovating on the algorithm and the science. This is one of the primary research problems we focus on at Applied Compute, and as I was saying earlier, it flows directly back into our core business. For the purpose of this talk, we're going to focus on a simpler subproblem. Let's assume we have good science and algorithmic innovations that allow us to tolerate staleness up to some fixed threshold, and we have some fixed compute budget, as usually exists in the world. What is the highest-throughput way for us to do RL in this setting?

Cool. Thanks, Rhythm. So we posed this as a modeling problem over our end-to-end system, which admittedly is a little complicated at first, but we found that we can get surprisingly far with some first-principles systems modeling. As with any modeling problem, let's figure out the cast of characters that describe the system, and then we'll think about how they all fit together.

The first cast member is some proxy for compute budget, which in this case is the number of GPUs. In the synchronous setting, like Rhythm just explained, all the GPUs are used for either training or sampling, since those happen one after the other. But in the asynchronous setting it's a little trickier, because we can choose how much of that pool of GPU compute to allocate to training versus sampling, and that leads to some design decisions.

The next is the training batch size, which is some proxy for the workload on the overall system. This is really an ML decision, but in short, we have a batch of problems, which is a subset of our dataset. Say we have n math problems that we want to train on, and for each of these problems we sample some number of completions in parallel. If the problems are really difficult, we might sample more, to encourage some diversity in the samples and to encourage the model to learn some potentially divergent strategies.

The next thing we need is some proxy for sampling throughput. To get some intuition for what we should choose here as a modeling decision, let's look at how modern inference engines serve requests. In GPU memory, we have the model weights, the activations, and some runtime state called the KV cache. Given this trained model, we run the forward pass several times, where each forward pass samples the next token and then writes to the KV cache. What this suggests is that a principled estimate is to find some way to measure the per-GPU latency of the forward pass. This ends up being a pretty good choice in practice, because from the systems angle, inference throughput is largely determined by the batch size we perform sampling with. What I've shown here in the red square is a batch of tokens that are all forwarded at the same time, and this sampling forward pass needs to be as large as possible to efficiently utilize the GPUs, subject to the runtime constraint that we don't run out of KV cache memory. So what we can do is fit a latency curve as a function of batch size, and that curve will look something like this: there's a regime where it's memory-bound, and as batch size increases it becomes compute-bound, with the functional form below.

To explain why we chose this form: the equation here is based on the roofline model from systems. At lower batch sizes, highlighted in yellow here, there isn't that much compute for the processor to do, but there are a lot of parameters to load at the same time. As a result, adding incremental work doesn't add much latency to the overall system, since the processor is so fast at doing math that we're just waiting on memory to stream parameters in to the processor. But as batch sizes get larger, we become bottlenecked by the processor, and the more we add to our batch, the longer the forward pass takes. And just for good measure, we have a sigmoid that modulates the smooth transition at the hinge point, to capture the subtle transition from a memory-bound computation to one that's compute-bound and bottlenecked by the processor.
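Here's a small sketch of that functional form; the constants are invented for illustration, and the exact parameterization of the fitted curve isn't shown in the talk:

```python
import math

# Invented constants for illustration only.
MEM_BOUND_LATENCY = 0.010  # seconds per forward pass while memory-bound
COMPUTE_SLOPE = 2e-4       # marginal seconds per extra request once compute-bound
HINGE_BATCH = 64           # batch size where the roofline transition happens
SHARPNESS = 0.1            # how smooth the memory->compute transition is

def forward_pass_latency(batch_size: float) -> float:
    """Roofline-style latency fit: flat while streaming weights dominates,
    linear in batch size once the processor is the bottleneck, with a sigmoid
    smoothing the hinge between the two regimes."""
    gate = 1.0 / (1.0 + math.exp(-SHARPNESS * (batch_size - HINGE_BATCH)))
    return MEM_BOUND_LATENCY + gate * COMPUTE_SLOPE * max(batch_size - HINGE_BATCH, 0.0)

for b in (1, 16, 64, 256, 1024):
    print(f"batch {b:5d}: {forward_pass_latency(b) * 1e3:.2f} ms / forward pass")
```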
The final cast member is some proxy for training throughput, and we chose to measure this on a per-GPU basis. In this case the model takes in the training batch size, the parameter we saw earlier, and we typically fit it against a proxy of our empirical workloads. The unit here is how many tokens per second each training GPU processes, covering the forward pass, the backward pass, and the optimizer step.

Given these four cast members, we can begin modeling the system. The first idea we had, although Rhythm suggested it might not be a great one, is a synchronous setup. This might seem like a good idea from first principles, because we definitely meet the staleness constraint, since we never train on stale data, and we always use the entire GPU fleet for either training or sampling, making efficient use of the hardware. Let's think about how to actually model this. There are two things we need to know: the batch size at which generation runs, and the response length distribution, to figure out what our training workload looks like and how long sampling will take. What I'm showing here in this simulation is a couple of engines. Each square is a request being processed, and they get darker as we make progress through the batch. As requests finish samples, they write to the queue. On the right-hand side is a time-series metric, maybe something you'd see in Grafana if you're monitoring production metrics. You can see that the batch size begins very high but slowly goes down over time, until it eventually hits zero and all the samples complete. Then we can finally run an optimization step. After the step completes, we run this in a loop and move on to the next step.

So we get the following sampling procedure. We do max-tokens forward passes, where max tokens is the number of forward passes needed for the longest request. We use the fitted latency estimator to figure out how long each forward pass takes, and the response length distribution tells us how many responses drop out of the batch at each step. What we're showing in this video is exactly that: the response length distribution being fed into the latency estimator. At training time, we compute the total number of tokens we just sampled in the batch and divide by the total training throughput, which is just the number of GPUs multiplied by the per-GPU training throughput. What we have here is a simulation of what this latency curve looks like: the CDF of the response length distribution, which tells us how many responses drop out, on the left, and the latency curve on the right. And this roughly tracks, because as we add more GPUs, we'd expect the latency per step to go down.
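A sketch of that synchronous step-time model might look like the following; the walk over the sorted response lengths and all the example numbers are my reconstruction, not Applied Compute's code:

```python
import random

def sync_step_time(response_lengths, num_gpus, per_gpu_train_tps, latency_fn):
    """Estimate one synchronous RL step: sampling time comes from walking the
    sorted response lengths (the batch shrinks as shorter requests finish);
    training time is total sampled tokens over total training throughput."""
    lengths = sorted(response_lengths)
    sampling_time, prev_len, active = 0.0, 0, len(lengths)
    for length in lengths:
        # While `active` requests remain, each forward pass processes a batch
        # of that size; the fitted latency curve prices each pass.
        sampling_time += (length - prev_len) * latency_fn(active)
        prev_len, active = length, active - 1
    train_time = sum(lengths) / (num_gpus * per_gpu_train_tps)
    return sampling_time + train_time

# Toy latency fit (same shape as the roofline sketch above) and a long-tailed
# batch of 40 problems x 32 samples each; all numbers are invented.
latency_fn = lambda b: 0.010 + 2e-4 * max(b - 64, 0)
random.seed(0)
lengths = [min(int(random.expovariate(1 / 400)) + 50, 4000) for _ in range(1280)]
print(f"{sync_step_time(lengths, 8, 5000.0, latency_fn):.1f} s per synchronous step")
```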
The next idea, given that the synchronous setup might not be the most principled choice, as Rhythm showed, is an asynchronous setup. But it's not as easy as just provisioning the compute between training and inference, because if we don't do this carefully, we might run into the idle-GPU problem again. To show this, let's illustrate the two extremes of this allocation problem. Let's first look at one end of the spectrum, where we provision way too many training GPUs and not that many samplers. In this case, we're consuming from the queue much faster than we're producing into it, because the sampling workers produce work significantly slower than the training workers can consume it. When the red squares gray out, it shows they're idle. What this diagram should hopefully illustrate is that for a lot of the time we're not using that compute, which is the same low-GPU-utilization problem as in the synchronous case shown earlier. On the other end of the extreme, we can provision way too many sampling GPUs, in which case our production rate is way faster than the rate we consume at.

So here we've doubled the number of overall sampling GPUs and halved the number of training GPUs. As you can see, they produce samples at a much more rapid rate, but the index in each yellow square, which is the staleness count of each sample, goes up. As time moves on, we get more and more stale, and the samples are drawn as more and more transparent, because we learn less from each individual sample.

So let's think about how to model this workload, to determine an optimal async layout. In this case the picture looks a little different, because in steady state the batch size is relatively consistent, compared to the synchronous setup where it decays over time. On the right-hand side here we have the same time-series metrics, but now the yellow squares are always full, because every time we complete a sample, a new sample takes its place and we keep writing to the queue. So that batch size, with a little bit of wiggle for good measure, is pretty consistent over the course of a run. Now, the caveat is that this batch size will certainly go down as response lengths go up, because we run out of KV cache, but that's a separate story, and our model actually accommodates it, because we're accounting for a response length distribution.

We can then figure out the optimal layout, and there are two constraints we have to satisfy now that we know the generation batch size is roughly consistent throughout a run. The first invariant is that the production and consumption rates are roughly equal. On one side of this equality we have the training throughput, which is the number of training GPUs multiplied by the per-GPU throughput, and on the other we have the number of sampling GPUs multiplied by the sampling throughput, which is just the batch size divided by the latency of a forward pass at that batch size. The second is that, as Rhythm indicated, too much staleness is bad from an ML perspective, so we want to make sure our maximum theoretical staleness in the simulation doesn't exceed what our ML can handle. Here the max staleness is, on top, how much time the longest request in the batch takes, which is the maximum number of tokens multiplied by the time each forward pass takes, and on the bottom, the length of a training step, which is the training batch size multiplied by the mean sequence length, divided by the training throughput.

The simulation then sweeps through multiple values for the number of training GPUs, and since we have a fixed pool of compute, that implies a certain number of GPUs used for sampling. For that number of sampling GPUs, we can compute the steady-state generation batch size that keeps us within our KV cache memory constraints while giving us maximum throughput on the sampling side. Finally, we prune out all the layouts whose implied staleness exceeds the maximum we can tolerate. With that simulation in hand, we can run it end to end, parameterized by the response length distribution.
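Here's a rough sketch of what that sweep could look like; the balance tolerance, parameter names, and all numeric values are assumptions for illustration, not Applied Compute's actual implementation:

```python
def sweep_layouts(total_gpus, train_batch_size, mean_len, max_len,
                  per_gpu_train_tps, latency_fn, steady_batch,
                  max_staleness, balance_tol=0.10):
    """Sweep splits of a fixed GPU pool between training and sampling, keeping
    layouts where (1) sample production and consumption roughly balance and
    (2) the implied staleness stays within what the ML can handle."""
    feasible = []
    for train_gpus in range(1, total_gpus):
        sample_gpus = total_gpus - train_gpus
        consume_tps = train_gpus * per_gpu_train_tps              # tokens/s trained
        produce_tps = sample_gpus * steady_batch / latency_fn(steady_batch)
        if abs(produce_tps - consume_tps) > balance_tol * consume_tps:
            continue  # invariant 1: producers and consumers keep pace
        longest_request_s = max_len * latency_fn(steady_batch)    # longest sample
        train_step_s = train_batch_size * mean_len / consume_tps  # one train step
        staleness = longest_request_s / train_step_s
        if staleness <= max_staleness:                            # invariant 2
            feasible.append((train_gpus, sample_gpus, round(staleness, 2)))
    return feasible

# All values invented for illustration.
latency_fn = lambda b: 0.010 + 2e-4 * max(b - 64, 0)
print(sweep_layouts(total_gpus=64, train_batch_size=2048, mean_len=800,
                    max_len=2000, per_gpu_train_tps=500.0,
                    latency_fn=latency_fn, steady_batch=256, max_staleness=2.0))
```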
We see that this simulates roughly a 60% speedup relative to our synchronous baseline, assuming the GPU compute is optimally allocated between training and sampling. As a result, when we sweep layouts within these constraints, we can limit staleness while making sure our runs operate at maximal throughput, without actually doing the run itself. This lets us simulate different workloads before running them on GPUs, because these runs can be fairly expensive. And it lets us answer scientific questions from first principles, like what the optimal configuration of our GPU compute would be if response lengths got very long, because oftentimes when models learn via reinforcement learning they begin to think for much longer, and what empirical throughputs we should target during our performance optimization. So this simulation has been a really useful piece of technology, and it has informed a lot of the systems and research design decisions that we make.

Cool. Thanks for your time, and find us afterwards to jam on some more RL research engineering together. Thank you.