Reinforcement Learning Notes

February 2, 2026

RL 是 post-training 里比较重要的一个部分。开个坑,记录下 RL 相关的基础知识。

跑个 demo 看下 RL 的基本概念

RL 本质是在 Actor 和 Env 的持续交互中,通过 action 和 state/reward 变化来不断迭代 policy/critic 等模型。

下面的例子用一个左右移动的场景(目标是让一个小球从 0 移动到 2)来举例(下面代码基本都有注释,highlight 几个点):

  • reward model: 这个场景下 reward 会比较简单,直接通过 state 来判断。为了引入 critic values,这里在 state=2 的 reward 是最大的,但是故意在 state=1 的 reward 设置为最小。所以如果 actor 只关心 reward,那永远走不到 2。
    • reward 分为确定性的 reward 和非确定性的 reward。如何定义好 reward(尤其是过程中的 reward)很关键;
  • critic model:为了避免 actor 只看 reward(短期收益),引入 critic values(长期价值),长期价值本质上是未来所有潜在 reward 的压缩。所以每个 step 的真实价值应该是 reward + gamma * critic-value。
    • critic value 也分两种情况,一种是可枚举的,比如这个场景,我可以遍历所有的情况,然后不断更新 critic value;另一种是不可枚举的,预估的;
  • policy 最后输出的是一个 logprobs,(1, T),T 是 action 的维度,然后加入一定的随机性,这里和 LLM 的 temperature 是类似的。
  • Training 训练过程
    • rollout:actor 每走一步或者几步,产生的 trajectory 可以认为是一次 rollout,包括 state/action/logprobs/reward/critic-values 等多个信息。
    • reward model 一般不参与 RL 训练。
    • critic model 更新:
      • 计算一个 advantage,公式是 adv = step["reward"] + gamma * V[s_next] - V[s]
        • adv > 0:说明 V[s] 估计小了;adv < 0:说明 V[s] 估计大了;
        • 用这个 delta 值去训模型
    • policy model 更新:
      • 对于 policy model,adv 的值表示这个行为应该被鼓励还是抑制,从而影响下次输出的 logprobs
  • 这是一个最简单的 demo,reward/critic/policy 都直接用 python dict 来表示,实际的训练过程,根据不同的方法论会更加的复杂...
import numpy as np
import random

# 环境的描述(-2  2 共有 5 个点)
states = [-2, -1, 0, 1, 2]
# 动作集合(左,不动,右)
actions = [-1, 0, 1]

# 执行 action
def transition(s, a):
    return max(-2, min(2, s + a))

# reward model 定义
def reward(s):
    if s == 2:
        return 10
    if s == 1:
        return -5
    return -1

# policy logits: state -> action logits
policy_logits = {
    s: np.zeros(len(actions)) for s in states
}

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

# policy model, 
def policy(state):
    probs = softmax(policy_logits[state])
    return probs

V = {s: 0.0 for s in states}
gamma = 0.9

trajectory = []

# 一次 rollout
state = 0
for t in range(100):
    probs = policy(state)
    action_idx = np.random.choice(len(actions), p=probs)
    action = actions[action_idx]

    logprob = np.log(probs[action_idx] + 1e-8)

    next_state = transition(state, action)
    r = reward(next_state)

    trajectory.append({
        "state": state,
        "action": action,
        "action_idx": action_idx,
        "reward": r,
        "logprob": logprob
    })

    state = next_state


# 计算 advantage, 作为 policy model 的输入
advantages = []
for step in trajectory:
    s = step["state"]
    a = step["action"]
    s_next = transition(s, a)
    td_target = step["reward"] + gamma * V[s_next]
    advantage = td_target - V[s]
    advantages.append(advantage)

# 利用 advantages 更新 policy model
lr = 0.1
for step, adv in zip(trajectory, advantages):
    s = step["state"]
    a_idx = step["action_idx"]
    logp = step["logprob"]

    # policy gradient:  logπ(a|s) * advantage
    policy_logits[s][a_idx] += lr * adv

# 利用 advantages 更新 critic model
alpha = 0.1
for step, adv in zip(trajectory, advantages):
    s = step["state"]
    V[s] += alpha * adv

for i, step in enumerate(trajectory):
    print(
        f"t={i}, s={step['state']}, a={step['action']}, "
        f"r={step['reward']}, logp={step['logprob']:.3f}, "
        f"adv={advantages[i]:.3f}"
    )

rl-gpt

  • 除了上面的 demo 之外,还有一个 reference model,目的是约束,放置模型训歪。不过我现在还没理解这个是否是必要的,如果是为了约束,那在 reward 和 critic model 里也能起到类似的作用,多加一个模型反而让整体复杂度和稳定性风险又高一个量级。
    • Reference 模型可能来自于初始化的 policy model,为了防止 policy model 跑的太偏;(利用 KL 散度来定义 loss)

PPO vs DPO vs GRPO

PPO, DPO & GRPO: Reinforcement Learning Techniques for Training LLMs:这篇文章讲得不错。

  • PPO(Proximal Policy Optimization): 指的是 policy 的 update 要尽可能贴近 pretrain model
    • 整体思路和上面比较类似,有个不同是在 policy model 更新时,会有一个 clip ratio 的概念,比如 clip=0.1,那每次最多只能 update 10% 的 logprobs 的变化;(或者是 model weights???)
    • advantage 计算:A(s,a) = Q(s,a) - V(s),这里的 Q(s,a) 是从 transition(s,a) 之后的所有 reward 的平均结果。
  • DPO(Direct Preference Optimization): 直接用比较的方式来做优化
    • 没有 reward model,数据集是用户的偏好比较(类似现在 ChatGPT 上弹出两个 answer 你来比较)
    • policy model 的 loss 是两者差值,如果用户觉得 A 比 B 好,模型推的是 B,那就把 A 的 logprobs 加起来,B 的降下去;(利用的是 KL 散度的公式)
    • 有 reference model,作为基线
  • GRPO(Group Relative Policy Optimization): 在一个 Group 内用 DPO 的思想来优化 policy
    • 将一个 group 里的 rank 后,两两组成 pair 走 DPO 思想,但区别是,DPO 里是基于 reference model(off-policy) 来比较;GRPO 里是基于自身的 policy 来做两两比较;

从演进方向上来看,

  • PPO 是最符合直觉的,也是最贴近传统 RL 思想的。但是整体结构过于复杂,需要依赖一个额外的 reward model;
  • DPO,在 PPO 的基础上简化掉了 reward model,但是需要一个比较好的人为标记数据集;
  • GRPO 在 DPO 的基础上又做了简化,通过 rank 的方式让数据的利用程度更高了;比如 rank 4 条数据就产生了 (3+2+1)=6 个 pair;另外也不再需要一个 reference model 作为 baseline 了;

OpenRLHF / Slime / VeRL

Slime

在了解完 RL 的基础概念之后,slime 的主流程代码相对来说不算复杂(但实际上涉及的能力非常非常多)。以 train.py 为入口:

  1. 利用 Ray 构建 placement group,支持 rollout 推理和 actor 训练两种 colocate 和非 colocate 的情况(对应的是 weights 等是否需要 update across groups),利用 ray 的 placementgroup 能力提前占用资源
  2. TrainGroup 初始化,这里把 critic 和 actor 分到一个 process group,用 1 和 0 表示 (process group 的 world_size=2),因为之前做 placement group 的时候已经按序来排资源了,所以这里保证了 critic 和 actor 尽可能在一个 ip 下的不同 rank:
def connect(self, critic_group):
    return ray.get(
        [
            actor.connect_actor_critic.remote(critic)
            for actor, critic in zip(self._actor_handlers, critic_group._actor_handlers, strict=False)
        ]
    )
  1. 初始化 actor/critic model
  2. 开始训练,先生成数据 generate_rollout,这里如果是 SGLang 来生成的话,用了很多异步操作(为了加速,更简单的实现要看 sft_rollout,就是直接读 data_buffer 了),生成 rollout 和 rm,自定义了一些 rm type 和实现;返回是一个 list[list[Sample]];再做一下 reward 的 normalization 和其他小的处理,产生 train_data;再做一下 ddp;
while state.remaining_batch_size < target_data_size:
    # get samples from the buffer and submit the generation requests.
    samples = data_source(args.over_sampling_batch_size)
    state.submit_generate_tasks(samples)

# wait for the generation to finish
done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)
  1. 每个 train actor 去拿数据(这里数据拿的有点低效,为啥不直接把 shard 逻辑下推到 datasource 呢,要先 put 到 ray memory store 之后再拷贝到每个 shard?)
  2. 做 critic model 的前向,(这里如果涉及到 context_parallel 会比较麻烦),对 reward 有一些处理后调用 megatron 的前向;ciritic values 的结果 broadcast 到同一个 process group 的 actor;拿到结果后 train critic model;
for value in values:
    handles.append(dist.broadcast(value, src=1, group=group, async_op=True))

if args.kl_coef != 0 or args.use_kl_loss:
    if not log_probs:
        log_probs = [torch.empty_like(value) for value in values]
    if not ref_log_probs:
        ref_log_probs = [torch.empty_like(value) for value in values]
    for ref_log_prob, log_prob in zip(ref_log_probs, log_probs, strict=False):
        handles.append(dist.broadcast(log_prob, src=0, group=group, async_op=True))
        handles.append(dist.broadcast(ref_log_prob, src=0, group=group, async_op=True))
  1. actor 的逻辑也很类似。

接下来看一些细节的设计。

Router

在初始化 RolloutManager 的时候,如果是 SGLang,会初始化一个 Router,Router 初始化后会用来记录 SGLang Server 的信息,提供 health check 等基础能力,另外对外提供 3 个接口:

def _setup_routes(self):
    """Setup all the HTTP routes"""
    
    #  SGLangRolloutEngine 初始化的时候进行注册(之前提到过 SGLangRolloutEngine 是一组 Ray Actor,独立部署)
    self.app.post("/add_worker")(self.add_worker)
    self.app.get("/list_workers")(self.list_workers)

    # 这里 router支持了一个 middleware 能力,在这个方法中主要是利用 radix tree 来做 token cache
    self.app.post("/retrieve_from_text")(self.retrieve_from_text)

    # 所有其他请求直接 route 给对应的 sglang server
    self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy)

SGLang Rollout

  • 在生成 rollout data 的过程中用了非常多的协程和异步,所以用GenerateState来做管理;
  • 直接调用 SGLang Server 的 generate 方法
  • rollout 后生成 reward

Speculative Decoding

  • 了解了 Speculative Decoding 的基本概念,即大模型 self decoding 比较慢的话,用一个小模型来做 decoding 然后大模型做一次 forward 来 verify 小模型连续 T 次预测的结果;
  • Slime 里的具体实现有空再看看,涉及到一些 megatron 的改造;

FSDP Backend

  • 除了 megatron backend 之外,还支持了 torch 的 FSDP;
  • 模型初始化的时候,手动让 rank0 做 dist.broadcast(state_dict, src=0),防止其他 rank 重新 load model;
  • pack_sequencees: 拿到一个 batch rollout 数据后,pack 一下降低后续 padding 成本;
  • 通过 dist.all_gather_object 来回收 metrics

weights updater & offload

// TODO

Multi Turn Rollout

适用 Agent 和 robotics 场景 // TODO

Quantitization

量化是解决显存和带宽问题的常见手段,同时也会引入 trian/infer 不一致的问题,所以会有很多种组合玩法。 // TODO

OpenTinker(RL-as-a-Service)