Efficient Reinforcement Learning – Rhythm Garg & Linden Li, Applied Compute

Channel: aiDotEngineer

Published at: 2025-12-09

YouTube video id: o15AaYl7Wu0

Source: https://www.youtube.com/watch?v=o15AaYl7Wu0

[music]
Hey everyone, it's great to meet you
all. Really great to be here today. My
name is Rhythm. This is my co-founder
Linden. Our third co-founder, Yash,
couldn't make it today, but we're all
very excited to be here. The three of us
were previously researchers at OpenAI,
and now we're bringing frontier AI
inside of enterprises at Applied Compute.
Today, we're going to be talking about
efficient reinforcement learning.
As some context on Applied Compute: we
help enterprises build their own
intelligence to power real work in their
company. We think a lot about how to
push AI beyond productivity into real
automations that deliver ROI that's
quantifiable for the company. Once we
build a system that's specialized to the
way that a company operates for a
particular use case, we deploy it with a
data flywheel so that it gets better
over time the more you use it. Picture
an in-house expert at a company that's
always at the forefront of their field.
Mechanically, RL is the tool that we use
to bring these out-of-distribution data
sets in distribution for today's models.
Yash, Linden, and I all worked on the RL
effort at OpenAI in its early days, and
we saw firsthand the power of RL in
maximizing these public benchmarks. Now
we're taking that a step further and
helping enterprises solve the problems
they care about most: their private
benchmarks.
So here's a very high-level overview of
how high-compute RL helps LLMs acquire
these reasoning and intelligence
capabilities. Let's say that you have a
data set of math problems and we pick
four of them for an RL training step.
Then we'll take an open-source model,
say one of the gpt-oss models or one of
the Llama models, and we have the model
attempt each of those four problems 100
times. Each of these 100 attempts is the
model thinking through how it would get
to the final answer and then ending with
the final answer itself, with many, many
reasoning tokens in its thinking
trajectory.
We can grade all of these answers
and when the model is correct, we can
bias the model's weights to reinforce
its thinking trace in that attempt. When
it's incorrect, we can discourage the
model from having that kind of behavior
again. So in this fashion, as we do more
and more training steps with
batches of four problems, 100 attempts
each, the model learns to reason and
solve math problems, and it becomes
really, really good at math. Of course,
at Applied Compute, we're not really
helping enterprises solve math problems,
but this is kind of the mechanism by
which we're able to teach the models to
get really, really good at tasks that
they care about.
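
To make the mechanics concrete, here is a minimal Python sketch of one such training step. It's an illustration under assumed names, not Applied Compute's actual stack: `model.sample`, `grade`, and `model.reinforce` are hypothetical stand-ins for the sampler, the grader, and the weight update.

```python
# Minimal sketch of one RL training step: sample many attempts per problem,
# grade them, and reinforce the thinking traces that led to correct answers.
# `model.sample`, `grade`, and `model.reinforce` are hypothetical stand-ins.
def rl_training_step(model, problems, attempts_per_problem=100):
    batch = []
    for problem in problems:                       # e.g. 4 math problems per step
        for _ in range(attempts_per_problem):      # 100 attempts each
            trace, answer = model.sample(problem)  # reasoning tokens + final answer
            reward = grade(problem, answer)        # 1.0 if correct, 0.0 otherwise
            batch.append((trace, reward))
    # Bias the weights toward traces with high reward, away from the rest.
    model.reinforce(batch)
```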
So, as we mentioned, the type of RL work
that we do at Applied Compute is
actually quite different from the labs'.
These are some real-life photos from the
labs, and a photo we took at the Applied
Compute office the other day. The labs
do these big training runs over several
weeks; we do more specialized runs.
And there are a couple of aspects of RL
training that are particularly important
to us.
We need our runs to be fast so that we
can train a model and deliver it to a
customer very quickly on the order of
days.
They have to be cheap so that our unit
costs work and we're able to scale the
business sustainably.
And importantly, and this is a point
that's easy to miss, we need our
estimates of how long these training
jobs will take to be very low variance,
because we don't want
to just be generally fast. We want to be
reliably fast when we work with
customers.
And so the research problem that is very
business-critical for us is: can we
build an RL stack efficient enough that,
in conjunction with our agent-building
platform, we are really able to scale up
this use-case-specific training motion?
So let's start with an inefficient form
of RL which is synchronous RL. In
synchronous RL sampling and training
happen in lock step. So there's some
simplifications here, but let's say
that we want to train on batches of
eight samples. That means we're going to
wait for all eight samples to finish
their completions before we
start training. And then we're going to
repeat this process again. As a result,
we have a lot of idle GPUs that are
waiting on that third straggler sample
to complete.
So in other words, in synchronous RL,
our step times are dictated by whatever
sample takes the longest time in order
to complete.
To illustrate why this is bad, we took
40 arithmetic problems, requested 32
samples for each of them with Qwen 30B,
and we measured how long it would take
for these samples to complete.
It turns out that 99% of the samples
completed in about 40 seconds. It took
another 80 seconds to get that last
percent of samples to complete. It
really has a long tail.
So, as you'd expect, if you look at the
throughput chart, the GPUs are doing a
lot of work at the beginning when all of
the sampling requests are launched, but
by the end, they're very very
underutilized because they're waiting on
those last samples to complete. The
technical term we use at Applied Compute
is that the GPUs are slacking. So
synchronous RL is not an efficient way
to use these GPUs.
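
As a rough back-of-the-envelope, here is a sketch (with made-up numbers, not the Qwen measurements above) of how the straggler dictates the synchronous step time, under the simplification of one in-flight sample per GPU:

```python
import random

# Hypothetical sample completion times in seconds; long-tailed, as described.
sample_times = [random.lognormvariate(3.0, 0.6) for _ in range(32)]

step_time = max(sample_times)             # everyone waits for the slowest sample
busy_gpu_seconds = sum(sample_times)      # useful sampling work actually done
idle_fraction = 1 - busy_gpu_seconds / (len(sample_times) * step_time)

print(f"step time ~ {step_time:.0f}s, sampler idle fraction ~ {idle_fraction:.0%}")
```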
In order to solve this problem, we need
to break the condition that sampling and
training need to happen in lock step. In
other words, we need to allow training
while we're sampling. This is called
asynchronous RL. And there are many
approaches to doing asynchronous RL. One
that we particularly like is pipeline RL,
from Piché et al.
We're going to make some
simplifications here, but in
asynchronous pipeline RL, we dedicate
some GPUs to sampling and some GPUs to
training. The sampling workers never
stop. They're constantly doing inference
with high batch size. As samples
complete, they get added to a queue for
training and the training workers pull a
batch from the queue to train on. After
a batch has been trained on, the
training workers propagate the new model
weights to all of the sampling workers
for what's called an in-flight weight
update. And this is really what
differentiates pipeline RL. The sampling
workers might be in the middle of a
sample, but their weights will still get
updated if a training step just
completed.
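
Here is a rough pseudocode sketch of that loop, assuming hypothetical `sampler` and `trainer` objects rather than any particular framework:

```python
import queue

sample_queue = queue.Queue()

def sampling_worker(sampler, problems):
    # Samplers never stop: keep generating at high batch size and push
    # finished samples onto the queue as they complete.
    for problem in problems:
        sample_queue.put(sampler.generate(problem))

def training_worker(trainer, samplers, batch_size):
    while True:
        batch = [sample_queue.get() for _ in range(batch_size)]
        trainer.train_step(batch)
        # In-flight weight update: push new weights to the samplers even if
        # they are mid-sample, so later tokens come from a newer policy.
        for sampler in samplers:
            sampler.load_weights(trainer.weights)
```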
As a result, we end up with samples
where multiple versions of the policy
contributed to generating them. In other
words, there are stale tokens in some of
these samples. Let's take a look at one
sample to make this a bit clearer.
As you can see, there are three versions
of the policy, at time steps t, t+1, and
t+2, that were used to generate this
sample, since there were two completed
train steps, and in turn two in-flight
weight updates, while this sample was
being generated.
So when this sample gets trained on in
the t+3 to t+4 training batch, we will
have some tokens that came from a policy
three steps behind, some that came from
a policy two steps behind, and those
last two tokens that came from a policy
that was one step behind.
Now, let's say that we only tolerate
staleness up to two. That means we're
not going to allow the in-flight weight
update after the t+1 to t+2 training
batch completes. And that means the
training workers are just going to be
idle, waiting for this sample to complete
before they can propagate that in-flight
weight update and start training on the
next batch. Because if they were to do
the in-flight weight update, that would
cause this sample to have staleness 3,
as we just saw.
And if we only tolerate staleness of one,
the training workers are going to be
idle for even longer,
which is bad. So as you increase how
much staleness you tolerate, you have
fewer idle GPUs in general. But as we all
know, there's no free lunch. This is the
standard policy gradient with an
importance ratio to adjust for the fact
that we're sampling from a policy at
time step t and training with the policy
at time step t+k, given that there's k
staleness. The importance ratio is what
makes this policy gradient unbiased, but
the variance of that ratio increases as
you increase staleness. And this is kind
of the big issue here, because now, with
a higher-variance importance ratio,
learning can become unstable and cause
divergence.
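
In symbols, the estimator being described is roughly the following (written out from the description; the notation is ours, with A the advantage and k the staleness):

```latex
\nabla_\theta J \;=\;
\mathbb{E}_{x \sim \pi_{\theta_t}}\!\left[
  \frac{\pi_{\theta_{t+k}}(x)}{\pi_{\theta_t}(x)}\,
  A(x)\,\nabla_\theta \log \pi_{\theta_{t+k}}(x)
\right]
```

The ratio corrects for sampling from the older policy while training the newer one, which is what keeps the gradient unbiased; its variance grows with k.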
The concrete trade-off is that we want a
lot of staleness for fast RL runs, but a lot
of staleness makes learning unstable,
which then requires innovating on the
algorithm and the science. And this is
one of the primary research problems
that we focus on here at Applied
Compute. And as I was talking about
earlier, it directly flows back into our
core business.
For the purpose of this talk, we're
going to focus on a simpler subproblem.
Let's assume that we have good science
and algorithmic innovations that allow
us to tolerate staleness up to some
fixed threshold, and we have some fixed
compute budget, as usually exists in the
world. What is the highest-throughput
way for us to do RL in this setting?
Cool. Thanks Rhythm.
So we posed this as a modeling problem
of our end-to-end system, which
admittedly is a little bit complicated
at first, but we did find that we can get
surprisingly far with some first-principles
systems modeling. As with any modeling
problem, let's figure out the cast of
characters that describe the system, and
then we'll think about how they all fit
together to model it. The first cast
member is some proxy of compute budget,
which in this case we take to be the
number of GPUs. In the synchronous
setting, like Rhythm just explained, all
the GPUs will either be used for training
or sampling, since they happen one after
the other. But in the asynchronous
setting it's a little bit trickier,
because we can choose to allocate that
pool of GPU compute as much as we want
for training or as much as we want for
sampling, and that leads to some design
decisions.
The next is the training batch size,
which is some proxy of the workload that
we have on the overall system, and this
is kind of an ML decision. In short, what
we have is a batch of problems, which is
a subset of our data set. Let's say we
have n math problems that we want to
train on, and for each of these problems
we're going to generate some number of
samples in parallel. So if the problems
are really difficult, we might sample
more to encourage some diversity in the
samples and let the model learn some
potentially divergent strategies.
The next thing we need is some proxy of
sampling throughput. To get some
intuition for what we should choose here
as a modeling decision, let's look at
how modern inference engines serve
requests. In GPU memory, we have the
model weights, the activations, and some
runtime state called the KV cache. And
given this trained model, we're going to
run the forward pass several times,
where each forward pass samples the next
token and then writes to the KV cache.
What this model shows is that a
principled estimate is to find some way
to measure the per-GPU latency of the
forward pass. And this ends up being a
pretty good choice in practice, because
from the systems angle, the inference
throughput we get is largely determined
by the batch size that we perform
sampling with. So what I've shown here
in the red square is a batch of tokens
that are all forwarded at the same time.
And this sampling forward pass needs to
be as large as possible to efficiently
utilize the GPUs, subject to the runtime
constraint that we don't actually run
out of memory in the KV cache.
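
As an illustration of that memory constraint, a rough estimate of the largest sampling batch that fits might look like the sketch below; all numbers and names here are hypothetical, not measurements from the talk.

```python
def max_sampling_batch(hbm_bytes, weight_bytes, kv_bytes_per_token, max_seq_len):
    # Whatever HBM is left after the weights is shared by the per-token KV cache.
    free_bytes = hbm_bytes - weight_bytes
    tokens_that_fit = free_bytes // kv_bytes_per_token
    return tokens_that_fit // max_seq_len   # concurrent sequences at full length

# Illustrative numbers only.
print(max_sampling_batch(hbm_bytes=80 * 10**9, weight_bytes=60 * 10**9,
                         kv_bytes_per_token=100_000, max_seq_len=8192))
```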
So what we can then do is fit a latency
curve as a function of batch size, and
that latency curve will look something
like this. You'll have some regime where
it's memory-bound, and as the batch size
increases it becomes compute-bound, and
there's some functional form below. To
explain the details of why we chose this,
what we have here is an equation based
on the roofline model from systems. At
lower batch sizes, which I've highlighted
in yellow here, we don't have that much
work to do, because there isn't that much
compute to do on the processor relative
to the number of parameters you need to
load at the same time. As a result, when
you add incremental work, it doesn't
really add that much latency to the
overall system, since the processor is
so fast at doing math that we're just
waiting on memory to stream parameters
in to the processor. But as the batch
sizes begin to get larger, we then get
bottlenecked by the processor, and the
more we add to our batch, the longer the
forward pass takes. And just for good
measure, we have this sigmoid here that
modulates the smooth transition at this
hinge point, to show that there's a
subtle transition from a memory-bound
computation to one that's more
compute-bound and bottlenecked by the
processor.
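
We don't reproduce the exact equation from the slide, but a roofline-style fit with a sigmoid-smoothed transition might look something like this sketch (the parameter names are our own):

```python
import math

def forward_pass_latency(batch_size, t_mem, t_per_token, b_star, width):
    """Roofline-style latency fit for one sampling forward pass.

    t_mem:       memory-bound floor (time to stream the weights from HBM)
    t_per_token: incremental compute time per token once compute-bound
    b_star:      batch size around which the memory->compute transition happens
    width:       how sharp that transition is
    """
    gate = 1.0 / (1.0 + math.exp(-(batch_size - b_star) / width))  # smooth hinge
    return t_mem + gate * t_per_token * batch_size
```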
The final cast member is some proxy of
training throughput, and we chose to
measure this on a per-GPU basis. In this
case the model takes in the training
batch size, the parameter we saw earlier,
and we typically do this by fitting a
proxy of our empirical workloads. The
units here are how many tokens per second
each training GPU processes, where it
needs to do the forward pass, the
backward pass, and some optimizer steps.
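
One way to fit that proxy empirically might look like this sketch, where `train_step` and `batch` are hypothetical stand-ins for a representative training step and workload:

```python
import time

def fit_training_throughput(train_step, batch, n_train_gpus, warmup=2, iters=10):
    # Tokens per second per training GPU, covering forward, backward, and optimizer.
    tokens_in_batch = sum(len(seq) for seq in batch)
    for _ in range(warmup):
        train_step(batch)                  # warm up kernels and allocators
    start = time.perf_counter()
    for _ in range(iters):
        train_step(batch)
    seconds_per_step = (time.perf_counter() - start) / iters
    return tokens_in_batch / (seconds_per_step * n_train_gpus)
```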
So given these cast members, we can then
begin modeling the system. The first idea
we had, although Rhythm suggested that
this might not be a great idea, is to
think about how to use a synchronous
setup. And this might be a good idea from
first principles, because we definitely
meet the staleness constraint, since we
don't train on stale data, and we always
use the entire GPU fleet for either
training or sampling, making efficient
use of the hardware. Let's think about
how to actually model this.
There are two things we need to know. We
need to know the batch size at which
generation runs. And we also need to
know the response length distribution to
figure out how our training workload's
going to work and also how long the
sampling's going to take. And so what
I'm showing here in this simulation is a
couple of engines. Each square is a
request being processed and they get
darker and darker as we make progress
throughout the batch. And as they finish
samples, they write to the queue. And on
the right hand side is a time series
metric, maybe something that you'd see
in Graphana if you're monitoring
production metrics. And what you can see
is that the batch size begins very high,
but it slowly goes down over time as it
eventually goes to zero and all the
samples complete. And we can finally run
an optimization step. After the step
completes, we run this in a loop and we
move on to the next step. As a result,
we have the following sampling procedure.
We do max-tokens inference forward
passes, where max tokens is the total
number of forward passes we do for the
longest request. We use the fitted
latency estimator to figure out how long
each forward pass will take, and the
response length distribution tells us
how many responses drop out of the batch
at each step. So what we're showing in
this video is the response length
distribution being fed into the latency
estimator. At training time, we can
compute the total number of tokens that
we just sampled in the batch and divide
by the total training throughput, which
is just the number of GPUs multiplied by
the per-GPU training throughput. And
what we have here is a simulation of
what this latency curve looks like: the
CDF of the response length distribution,
which tells us how many responses drop
out, on the left, and the latency curve
on the right. And this roughly tracks,
because as we add more GPUs, we'd expect
the latency per step to go down.
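
Putting those pieces together, the synchronous step-time estimate described here might be sketched as follows, reusing the fitted latency curve and assuming hypothetical inputs:

```python
def synchronous_step_time(response_lengths, latency_fit, n_gpus, train_tok_per_gpu_s):
    """Estimate one synchronous RL step: decode until the longest response
    finishes, then train on everything that was generated."""
    sampling_time = 0.0
    for step in range(max(response_lengths)):
        # Batch size at this decode step = responses still unfinished; the
        # response-length distribution tells us how many have dropped out.
        active = sum(1 for length in response_lengths if length > step)
        sampling_time += latency_fit(active)
    training_time = sum(response_lengths) / (n_gpus * train_tok_per_gpu_s)
    return sampling_time + training_time
```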
The next idea, given that the
synchronous setup might not be the most
principled choice, as Rhythm showed, is
an asynchronous setup. But it's not as
easy as just provisioning the compute
between training and inference, because
if we don't do this carefully, we might
run into the idle GPU problem again. To
show this, let's illustrate two extremes
of what this allocation problem looks
like. Let's first look at one end of the
spectrum, where we provision way too
many training GPUs and not that many
samplers. In this case, we're consuming
from the queue much faster than we're
actually producing into it, because the
sampling workers are producing work
significantly slower than we can
consume it. When the red square grays
out, it shows that those GPUs are idle.
And what this diagram should hopefully
illustrate is that for a lot of the time
we're actually not using the training
GPUs, and that has the same problem of
low GPU utilization as the synchronous
case shown earlier.
On the other end of the extreme, we can
provision way too many sampling GPUs, in
which case our production rate is way
faster than the rate at which we actually
consume samples. So here we've doubled
the number of sampling GPUs and halved
the number of training GPUs. As you can
see, they produce samples at a much more
rapid rate. But this index in each yellow
square, which is the staleness count of
each sample, goes up. As time moves on,
we get more and more stale, and the
samples get more and more transparent as
a result, and we learn less from each
individual sample. So
let's think about how we can actually
model this workload, then, to determine
an optimal async workload. In this case,
the picture looks a little bit different,
because in steady state the batch size
is relatively consistent, compared to
the synchronous setup where it kind of
goes down over time. On the right-hand
side here, we have the same time-series
metrics, but in this case it's a little
bit different, because the yellow squares
are always full: every time we complete
a sample, a new sample goes in, and we
can continue writing to the queue. And
so that batch size, with a little bit of
wiggle just for good measure, is pretty
consistent over the course of a run. Now,
obviously, the caveat here is that this
batch size will certainly go down as
response lengths go up, because we run
out of KV cache, but that's kind of a
separate story, and actually our model
accommodates that because we're
accounting for a response length
distribution.
We can then begin to figure out the
optimal layout, and there are two
constraints that we have to satisfy, now
that we know the generation batch size
is roughly consistent throughout the
course of a run. The first invariant we
need is that the production and
consumption rates are roughly equal. On
the left-hand side of this equality we
have the training throughput, which is
the number of training GPUs multiplied
by the per-GPU throughput, and on the
other side we have the number of
sampling GPUs multiplied by the sampling
throughput, which is just the batch size
divided by the latency of a forward pass
at that batch size. The next thing is
that, given Rhythm indicated that too
much staleness can be bad from an ML
perspective, we want to make sure that
our max theoretical staleness, or
simulated staleness, doesn't exceed what
our ML can handle. So here we have the
max staleness on the left, which is
equal to, on the top, how much time the
longest request took in the batch, which
is just the maximum number of tokens
multiplied by the amount of time each
token's forward pass takes, and on the
bottom, the length of a training step,
which is the training batch size
multiplied by the mean sequence length,
divided by the training throughput to
convert it into a time.
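
Written out in our own notation, reconstructed from the description, the two constraints look roughly like this, with N the GPU counts, T_train the per-GPU training throughput, l(.) the fitted forward-pass latency, B_gen the steady-state generation batch size, L_max and L_bar the max and mean response lengths, and B_train the training batch size:

```latex
% production rate ~= consumption rate
N_{\text{sample}} \cdot \frac{B_{\text{gen}}}{\ell(B_{\text{gen}})}
  \;\approx\;
N_{\text{train}} \cdot T_{\text{train}}

% simulated staleness must stay under the ML tolerance k
\frac{L_{\max} \cdot \ell(B_{\text{gen}})}
     {B_{\text{train}} \cdot \bar{L} \,/\, (N_{\text{train}} T_{\text{train}})}
  \;\le\; k
```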
So the simulation here will sweep
through multiple different values of the
number of training GPUs. Since we have a
fixed pool of compute, that then implies
a certain number of GPUs used for
sampling. For this number of sampling
GPUs, we can compute the steady-state
generation batch size that makes sure we
don't blow out of memory, subject to our
KV cache memory constraints, and also
such that we have maximum throughput on
the sampling side. And the final thing
is that we want to prune out all
simulations where the sampling
throughput brings us over the maximum
tolerable staleness. When we look at
that simulation, we can run it end to
end, similarly parameterized by the
response length distribution. We see
that this roughly simulates a 60%
speedup relative to our synchronous
baseline, assuming that the GPU compute
is optimally allocated between training
and sampling.
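
A sketch of that sweep, using the fitted estimators from before (hypothetical function names, simplified to a single loop):

```python
def sweep_layouts(total_gpus, latency_fit, train_tok_per_gpu_s, batch_for_memory,
                  max_tokens, mean_len, train_batch_size, staleness_tolerance):
    """Try every split of the GPU pool; keep the fastest layout that respects
    both the KV-cache memory limit and the staleness tolerance."""
    best = None
    for n_train in range(1, total_gpus):
        n_sample = total_gpus - n_train
        gen_batch = batch_for_memory(n_sample)                 # KV-cache constraint
        sample_tok_per_s = n_sample * gen_batch / latency_fit(gen_batch)
        train_tok_per_s = n_train * train_tok_per_gpu_s
        throughput = min(sample_tok_per_s, train_tok_per_s)    # steady-state balance
        # Simulated staleness: longest-request time over training-step time.
        train_step_s = train_batch_size * mean_len / train_tok_per_s
        staleness = max_tokens * latency_fit(gen_batch) / train_step_s
        if staleness <= staleness_tolerance and (best is None or throughput > best[0]):
            best = (throughput, n_train, n_sample)
    return best
```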
As a result, when we sweep layouts
within these constraints, this allows us
to limit staleness but also make sure
that our runs are running at maximal
throughput, without actually doing the
run itself. This gives us the ability to
simulate different workloads before
actually running them on GPUs, because
these runs can be fairly expensive. And
it allows us to answer scientific
questions from first principles, like
what the optimal configuration of our
GPU compute would be if we made response
lengths very long (because oftentimes,
when models learn via reinforcement
learning, they begin to think for much
longer), and also what empirical
throughputs we should target during our
performance optimization. So this has
been a really useful piece of simulation
technology that has informed a lot of
the systems and research design
decisions that we make. Cool. Thanks for
your time, and find us afterwards to jam
on some more RL research engineering
together. Thank you. [music]