본문 바로가기

AI-ML/LLM

논문 읽어보기 - Learning to Reason as Action Abstractions with Scalable Mid-Training RL

728x90
https://arxiv.org/html/2509.25810v1 논문 읽기

Learning to Reason as Action Abstractions with Scalable Mid-Training RL: 종합 분석

1. 논문의 취지 (Research Purpose/Motivation)

본 논문은 대규모 언어 모델(LLM)의 성능을 극대화하기 위한 mid-training 단계의 이론적 기반을 최초로 제시하고자 한다. 현재 LLM 훈련은 pre-training, mid-training, post-training RL의 3단계 파이프라인으로 이루어지는데, mid-training은 경험적으로 효과적임이 입증되었지만 그 작동 원리와 post-training RL에 미치는 영향이 이론적으로 명확히 규명되지 않았다는 문제의식에서 출발한다. 저자들은 효과적인 mid-training이 두 가지 핵심 목표를 달성해야 한다고 주장한다.

  1. 전문가 시연(expert demonstrations)으로부터 모든 태스크에 충분한 유용한 행동(actions)의 집합을 추출하고,
  2. 이러한 행동들 중에서 RL 과정에서 효율적인 선택을 가능하게 해야 한다.

특히, 기존의 next-token prediction (NTP) 방식은 primitive action space(토큰 단위)에서 작동하여 방대한 탐색 공간과 긴 planning horizon으로 인해 비효율적이라는 인식이 연구의 동기가 되었다. 핵심 통찰은 행동 추상화(action abstractions) 공간에서 학습하는 것이 더 효율적이라는 것이다. 여러 시간 단계에 걸친 temporally-extended actions를 학습함으로써 결정 공간을 압축하고 유효 horizon을 단축시켜, pruning 효율성을 높이고 RL 수렴 속도를 가속화할 수 있다는 가설을 검증하고자 한다.

2. 접근 방향 (Approach/Methodology)

저자들은 이론적 분석과 알고리즘 설계, 실증 실험을 결합한 다층적 접근 방식을 취한다.

이론적 프레임워크

Regret Decomposition: Lemma 3.2를 통해 post-training RL의 regret을 두 항으로 분해한다:

  • Approximation error: 전체 행동 공간을 부분 공간 A'로 pruning할 때 발생하는 오류
  • Post-training RL error: pruned space 내에서 planning 시 발생하는 오류

이 분해는 mid-training의 두 가지 핵심 역할을 명확히 한다.

  • Pruning Efficiency (Theorem 3.5): ε-optimal action subset의 최소 크기 N*(ε,D)가 작을수록, 그리고 근사 오차 ε가 작을수록, 더 적은 전문가 시연으로 suboptimal actions를 효과적으로 제거할 수 있음을 증명한다. 필요한 샘플 수는 O((N*/ε²)log|A|/δ)에 비례한다.
  • RL Convergence Rate (Theorem 3.6): Value iteration 기반 분석을 통해, temporally-extended actions의 평균 duration τ가 클수록 ε-optimality 달성에 필요한 iteration 수가 Ω(log(1/ε) / (1-γ^τ))로 감소함을 보인다. 이는 action abstractions가 효과적 horizon을 단축시켜 RL 수렴을 가속화함을 의미한다.

알고리즘 설계

Temporal Variational Lower Bound (Theorem 4.1): NTP 목적함수에 대한 sequential ELBO를 유도한다.

L(θ,φ) = E_D[∑t log p_θ(a_t|s_t,z_t) - KL(q_φ(z_t|s_t,a{t+1:T})||p(z_t|s_t))]

이는 관측된 primitive actions a_t를 설명하는 숨겨진 latent sequence z_t를 발견하는 것을 목표로 한다.

Expectation-Maximization 절차:

  • E-step (식 4.1): 고정된 θ로 variational posterior q_φ를 최적화. 이는 T-horizon RL 문제로, reward는 전문가 행동의 log-likelihood이며 KL penalty가 추가됨
  • M-step (식 4.2): 고정된 φ로 bootstrapped latent trajectories에 대해 supervised fine-tuning 수행

Temporal Consistency (식 4.3): Latent prior를 τ 시간 단계 동안 동일한 latent를 유지하도록 설계.

p(z_t|s_t) = (1-ε)δ(z_t=z_{t-1}) + ε·Uniform(Z)

여기서 δ는 Dirac delta 함수이다. 이를 통해 latents가 temporally-extended action abstractions로 기능하도록 강제한다.

RA3 (Reasoning as Action Abstractions) 알고리즘

코드 생성 맥락에서 구체화하여, 두 가지 latent types를 정의한다:

  • z=<think>: 새로운 rationale 생성 시작 (\n# 형태의 주석)
  • z=<act>: Temporal consistency 유지 (\n만 생성)

Scalability: Temporal consistency 덕분에 전체 T 토큰에 대해 rollout을 생성하지 않고, <think>가 샘플링될 때만 새로운 rationale을 생성하므로 계산 비용을 크게 절감한다.

KL Regularization (Proposition 5.1): KL 항이 Bernoulli KL과 entropy로 분해되며, λ 하이퍼파라미터를 통해 불필요한 reasoning을 억제하는 threshold 역할을 한다. 실제 구현에서는 <think>에 고정된 penalty를 부여하는 reward shaping을 사용한다.

3. 해결하고자 한 문제 (Problem Statement)

본 논문이 해결하고자 하는 핵심 문제는 다음과 같다:

이론적 공백: Mid-training이 post-training RL에 미치는 영향에 대한 원리적 이해 부족. 기존 연구들은 initial RL policy의 성능이나 entropy 같은 간접적 지표에 의존했으며, 이는 downstream 성능 향상을 보장하지 못했다.

계산 비효율성: 전통적인 NTP 기반 mid-training은 primitive action space(개별 토큰)에서 작동하여:

  • 방대한 행동 공간 |A|로 인한 높은 pruning error
  • 긴 sequence length T로 인한 느린 RL convergence
  • 충분한 pruning을 위해 과도하게 많은 전문가 시연 필요

확장성 문제: 기존의 rationale 생성 방법들은 모든 시간 단계에서 rollout을 필요로 하여, mid-training 규모(수십억 토큰)의 데이터에 적용하기에는 계산적으로 실행 불가능했다.

실증적 목표: 코드 생성 태스크에서 mid-training 후 HumanEval, MBPP 등의 벤치마크에서 성능을 향상시키고, 후속 RLVR post-training의 수렴 속도와 점근적 성능을 개선하는 것이다.

논문은 이러한 문제들을 action abstractions 학습이라는 통합된 프레임워크로 해결하고자 한다. 즉, 시간적으로 확장된 고수준 "skills"를 발견하여 결정 공간을 압축하고 planning horizon을 단축시킴으로써, 이론적 효율성과 실용적 확장성을 동시에 달성한다.

4. 실험 방법 (Experimental Methods)

실험 설정

Base Models:

  • Qwen-2.5-1.5B
  • Llama-3.2-1B
  • Llama-3.1-8B 3개의 사전 훈련된 base models를 사용하여 다양한 모델 크기에서의 일반화 가능성 검증

Mid-Training Data:

  • OpenCoder의 Python 코드 continued pre-training corpus 사용
  • 5M개의 고품질 인터넷 코드 스니펫 (11.3M 토큰)
  • 200K개의 코드 전용 합성 스니펫 (2.2M 토큰)
  • 총 5.2M 스니펫, 13.5B 토큰

Action Granularity: Primitive actions를 코드 한 줄 단위로 정의. 특수 토큰 대신 실용적 구현:

  • <act> = \n (줄바꿈만)
  • <think> = \n# (주석 라인 시작)
  • Format reward를 통해 주석이 # 으로 시작하고 \n 으로 끝나도록 강제

RA3 하이퍼파라미터:

  • EM iteration: 총 20K gradient steps
    • 처음 5K steps: RL policy gradient (E-step)
    • 나머지 15K steps: Supervised fine-tuning (M-step)
  • RL warmup: 500 steps (KL penalty 없이, latent 추출에 집중)
  • T-horizon: 5-step truncated return (다음 5개 step 내 reward만 사용)
  • Maximum reasoning length: 최대 20 토큰
  • Sampling temperature: 1.5
  • Group size: 8 (advantage 계산용)
  • Entropy coefficient: 0.01
  • Penalty λ: 0.1
  • RL batch size: 32, learning rate: 5×10⁻⁶
  • M-step: batch size 256, learning rate: 2×10⁻⁵
  • Optimizer: AdamW

NTP Baseline: 동일한 데이터에 대해 표준 next-token prediction으로 훈련. M-step과 동일한 하이퍼파라미터 사용 (batch 256, lr 2×10⁻⁵)

구현 세부사항:

  • Asynchronous rollout with SGLang engine: 배치 추론의 idle time 최소화 (모든 샘플이 다음 turn으로 진행하기 전 대기하지 않음)
  • Multi-turn RL 구조: 코드 라인과 주석이 turns로 교대

평가 방법론

Mid-Training 평가:

  1. RL Training Reward: E-step에서 log-likelihood reward 최대화 성능 추적 (Figure 2)
  2. Cross-Entropy Loss: M-step에서 bootstrapped data에 대한 CE loss 측정 (Figure 4). 낮은 CE loss는 latents가 효과적으로 expert decisions를 설명함을 의미
  3. Qualitative Analysis: 훈련 전후 데이터 예시 비교 (Figure 3). RL이 추출한 high-level abstractions (e.g., "create dummy head", "BFS") 시각화
  4. 벤치마크 성능:
    • HumanEval: 164개 Python 프로그래밍 문제
    • MBPP: 500개 기본 Python 프로그래밍 문제
    • HumanEval+, MBPP+: 더 엄격한 테스트 케이스의 확장 버전
    • Pass@1, Pass@5 메트릭 사용 (BigCode evaluation harness)
    • EM iteration마다 평균 점수 추적 (Figure 5)

Post-Training RLVR 평가:

  1. RLVR 알고리즘: Group Relative Policy Optimization (GRPO) 사용
  2. Training Data: AReaL-boba-2-RL-Code 데이터셋
    • 50K 샘플
    • TACO, CodeContests, LiveCodeBench에서 필터링
    • Chat format 대신 function signature를 instruction 끝에 제공하여 base models도 직접 함수 body 완성 가능
  3. 평가 벤치마크:
    • HumanEval+, MBPP+
    • LiveCodeBench: 실시간 contamination-free 평가
    • Codeforces: 경쟁 프로그래밍 문제
  4. 메트릭:
    • Convergence speed: 학습 곡선의 기울기
    • Asymptotic performance: 최종 수렴 성능
    • Random seeds: 소형 모델 3개, 8B 모델 2개 독립 실행하여 통계적 신뢰성 확보
  5. 비교 대상:
    • Base models (mid-training 없음)
    • NTP mid-training 후 RLVR
    • RA3 mid-training 후 RLVR

Ablation Study:

λ (penalty parameter) 값 변화에 따른 영향 분석 (Figure 7):

  • Reasoning frequency: <think> 생성 비율
  • Reasoning length: 평균 rationale 길이
  • Cross-entropy loss
  • 평가 성능

5. 실험 결과 (Experimental Results)

Mid-Training 단계 결과

RL Reward 학습 (Figure 2):

  • 모든 base models에서 RL step 동안 reward가 빠르게 증가
  • 대부분의 데이터와 계산을 reasoning bootstrapping과 supervised fine-tuning에 할당 가능
  • 빠른 수렴은 expert demonstrations에서 유용한 latent structures를 효과적으로 발견함을 시사

Qualitative Analysis (Figure 3):

  • RL 후 모델이 transferable skills에 해당하는 high-level abstractions 추출
  • 예시 1: "create a dummy head" - 연결 리스트 문제의 일반적 패턴
  • 예시 2: "BFS (Breadth-First Search)" - 그래프 탐색 알고리즘
  • 이러한 abstractions는 여러 태스크에 걸쳐 재사용 가능한 지식을 인코딩

Cross-Entropy Loss (Figure 4):

  • RA3가 reasoning-bootstrapped data에서 학습 속도를 현저히 가속화
  • NTP 대비 더 낮은 CE loss 달성
  • 이는 hidden reasoning trajectories가 expert decisions를 더 쉽게 설명 가능하게 만듦을 의미
  • 가설 검증: mid-training data는 primitive actions만 제공하지만, 숨겨진 reasoning이 그 decisions를 guide함

벤치마크 성능 향상 (Figure 5, Table 1):

Llama-3.2-1B 결과:

  • Base: HumanEval p@1 18.9%, MBPP p@1 25.8%, 평균 23.3%
  • NTP: 평균 25.3% (base 대비 +2.0p)
  • RA3: 평균 29.8% (base 대비 +6.5p, NTP 대비 +4.5p)
  • HumanEval p@1: 25.0% (base 18.9%, NTP 21.3%)
  • MBPP+ p@1: 39.4% (base 31.5%, NTP 34.4%)

Qwen-2.5-1.5B 결과:

  • Base: 평균 37.9%
  • NTP: 평균 41.7% (+3.8p)
  • RA3: 평균 46.6% (+8.7p over base, +4.9p over NTP)
  • HumanEval p@1: 48.2% vs NTP 41.5% (+6.7p)
  • HumanEval+ p@1: 42.7% vs NTP 35.4% (+7.3p)

Llama-3.1-8B 결과:

  • Base: 평균 41.0%
  • NTP: 평균 47.7% (+6.7p)
  • RA3: 평균 48.9% (+7.9p over base, +1.2p over NTP)
  • 더 큰 모델에서도 일관된 개선 확인

전체 경향:

  • RA3가 모든 모델 크기에서 NTP를 일관되게 능가
  • 평균적으로 base 대비 8 points, NTP 대비 4 points 개선
  • 더 엄격한 테스트(HumanEval+, MBPP+)에서도 강건한 성능
  • EM iterations가 진행됨에 따라 성능이 지속적으로 향상
  • RA3가 더 적은 데이터로 NTP보다 높은 정확도 달성

Post-Training RLVR 결과 (Figure 6)

수렴 속도 (Convergence Speed):

  • RA3 mid-training 후 RLVR이 더 빠르게 수렴
  • 초기 학습 단계에서 NTP보다 가파른 성능 향상
  • 이는 Theorem 3.6의 예측과 일치: temporally-extended actions가 효과적 horizon 단축

점근적 성능 (Asymptotic Performance):

  • 모든 벤치마크에서 RA3 > NTP > Base의 명확한 순서
  • HumanEval+에서 가장 두드러진 차이
  • MBPP+, LiveCodeBench, Codeforces에서도 일관된 우위
  • Random seed에 걸쳐 안정적인 결과 (error bars 작음)

이론적 예측 검증:

  • Mid-training이 post-training RL 성능을 **실질적으로 형성(shapes)**함을 확인
  • Action abstractions가 두 가지 측면 모두에서 우수:
    1. 더 강한 policy prior (초기 성능)
    2. 더 빠른 RL convergence (학습 효율)

Ablation Study (Figure 7)

λ (Penalty) 값의 영향:

  1. λ = 0.03 (낮은 penalty):
    • Reasoning frequency: ~95% (거의 모든 라인에서 reasoning)
    • Reasoning length: 평균 15-18 토큰
    • 문제: NTP와 차별성 없음 - decision space나 horizon 감소 없음
    • 계산 오버헤드만 증가
  2. λ = 0.1 (default):
    • Reasoning frequency: ~15-20%
    • Reasoning length: 평균 8-12 토큰
    • 최적 균형: 효율성과 성능
    • CE loss 최저
    • 평가 성능 최고
  3. λ = 0.3 (높은 penalty):
    • Reasoning frequency: ~5% 미만
    • Reasoning length: 매우 짧음
    • 문제: RA3가 NTP로 퇴화
    • <think> 생성 시 infinite penalty에 가까움
    • 성능 저하

핵심 통찰:

  • λ가 <think> 생성의 threshold 역할: log-likelihood 개선이 λ 이상일 때만 reasoning 생성
  • 적절한 λ 선택이 중요: 불필요한 reasoning 억제하면서도 유용한 abstractions는 유지
  • Computational cost 조절 가능: reasoning frequency를 통해 추가 비용 관리
  • Proposition 5.1의 이론적 분석과 정확히 일치

통계적 유의성

  • 여러 random seeds에 걸친 일관된 결과
  • NTP 대비 RA3의 개선이 모든 모델과 벤치마크에서 재현됨
  • Error bars가 작아 결과의 신뢰성 높음
  • Pass@5 메트릭에서도 유사한 경향 (단순히 Pass@1의 우연이 아님)

6. 시사점 (Implications/Contributions)

이론적 기여

1. Mid-Training의 원리적 이해:

  • 최초의 formal analysis: Mid-training이 post-training RL을 어떻게 형성하는지에 대한 이론적 특성화
  • Regret decomposition (Lemma 3.2)을 통해 두 가지 독립적 역할 규명:
    • Pruning efficiency → policy prior 형성
    • RL convergence acceleration → online improvement 가능성
  • 이는 mid-training 알고리즘 설계에 원칙적 가이드 제공

2. Action Space Cardinality의 중요성 (Theorem 3.5):

  • Minimal ε-optimal action subset size N*(ε,D)가 작을수록 pruning이 효율적
  • 필요한 expert demonstrations 수가 O((N*/ε²)log|A|/δ)에 비례
  • 시사점: Primitive actions (토큰 단위)보다 action abstractions가 근본적으로 우월
  • Transferable skills는 여러 태스크에 걸쳐 재사용되므로 N*가 작음

3. Temporal Abstraction의 가치 (Theorem 3.6):

  • Duration τ가 클수록 value iteration convergence가 O(log(1/ε)/(1-γ^τ))로 가속
  • 직관적 설명: 각 Bellman backup이 τ steps를 한 번에 점프 → 유효 horizon 감소
  • Planning tractability 향상: 긴 horizon 태스크에서 특히 중요

4. 통합된 이론적 프레임워크:

  • Markov options, lifelong RL의 option transfer 개념을 LLM mid-training에 적용
  • 하지만 기존 연구와 달리 transferability보다는 LLM 훈련 파이프라인 설계에 초점
  • Multi-task RL 문헌과 언어 모델 훈련 문헌을 연결하는 교량 역할

방법론적 기여

1. Temporal Variational Lower Bound (Theorem 4.1):

  • NTP 목적함수의 새로운 변분 하한 유도
  • Sequential ELBO with latent trajectories
  • 이론적으로 원칙적이면서도 실용적 알고리즘으로 변환 가능

2. RA3 알고리즘의 확장성:

  • 핵심 혁신: Temporal consistency를 활용한 계산 비용 절감
  • 모든 timestep에서 rollout 필요 없이, <think> 샘플링 시만 rationale 생성
  • Mid-training 규모 (13.5B 토큰)에서 실행 가능
  • 기존 방법들 (Zhong et al., Dong et al.)은 수천 샘플에 국한, RA3는 수백만 샘플 처리

3. 자율적 Reasoning 결정:

  • 모델이 스스로 언제 reasoning이 필요한지 학습
  • Hand-crafted rules (Dong et al.의 entropy-based rule) 불필요
  • KL regularization을 통한 자연스러운 threshold 메커니즘
  • Reward shaping으로 실용적 구현

4. 특수 토큰 없는 실용적 설계:

  • \n과 \n#를 사용하여 기존 코드 syntax 유지
  • 추가 fine-tuning 없이 적용 가능
  • Format reward로 올바른 형식 보장
  • 실제 배포 시 간편함

실증적 기여

1. 일관된 성능 개선:

  • 3개 모델 (1B~8B), 6개 벤치마크에서 체계적 향상
  • Base 대비 평균 8 points, NTP 대비 4 points
  • Qwen-2.5-1.5B에서 HumanEval p@1: 48.2% (NTP 41.5%)
  • 통계적으로 robust한 결과

2. Post-Training RL 개선 검증:

  • Faster convergence와 higher asymptotic performance
  • 이론적 예측 (Theorem 3.6)의 실증적 확인
  • 실제 RLVR 파이프라인에서의 실용적 가치 입증

3. Qualitative Insights:

  • 모델이 학습한 action abstractions의 가시화
  • "Dummy head", "BFS" 같은 transferable patterns
  • 해석 가능성과 디버깅에 유용

한계점 (Limitations)

1. 도메인 특수성:

  • 실험이 Python 코드 생성에만 집중
  • 수학, 자연어 생성, 다른 프로그래밍 언어에 대한 일반화 검증 필요
  • 코드는 명확한 line 구조를 가져 action granularity 정의가 자연스러웠으나, 다른 도메인에서는 불명확할 수 있음

2. 이론적 가정:

  • Theorem 3.5는 ε-optimal action subsets의 존재를 가정
  • 모든 태스크에서 재사용 가능한 compact action set이 항상 존재하는지 불명확
  • Theorem 3.6은 value iteration 기반이나, 실제로는 policy gradient 사용

3. 계산 비용:

  • RA3가 NTP보다 추가 RL step 필요
  • Asynchronous rollout으로 완화했으나 여전히 오버헤드 존재
  • λ tuning이 필요하며, 최적값이 데이터셋과 모델에 따라 다를 수 있음

4. Latent 해석:

  • 학습된 latents가 항상 인간이 이해 가능한 reasoning은 아님
  • 일부 abstractions는 spurious correlations일 수 있음
  • Qualitative analysis가 제한적 (2개 예시만 제공)

5. Baseline 비교 제한:

  • 최신 distillation-based methods (frontier LLM의 reasoning 증류)와 직접 비교 없음
  • 다른 hierarchical RL 방법들과의 비교 부족
  • 다른 mid-training 알고리즘들과의 체계적 비교 필요

향후 연구 방향 (Future Work)

1. 다른 도메인으로의 확장:

  • 수학적 추론: 증명 단계를 action abstractions로
  • Multi-turn dialog: 대화 전략을 high-level actions로
  • Robotics: Physical skills hierarchy

2. 적응적 Abstraction 학습:

  • 태스크별로 다른 granularity의 abstractions
  • Hierarchical RL의 다단계 옵션 학습
  • Curriculum learning: 간단한 abstractions부터 점진적으로

3. 이론적 확장:

  • General policy gradient 알고리즘에 대한 convergence 분석
  • Approximation error bounds 개선
  • Multi-task transfer 보장에 대한 characterization

4. 실용적 개선:

  • λ의 자동 tuning 메커니즘
  • 더 효율적인 RL 알고리즘 (e.g., actor-critic)
  • Distributed training 최적화

5. 해석 가능성 연구:

  • 학습된 action abstractions의 체계적 분석
  • Transferability 측정 메트릭
  • Human evaluation of reasoning quality

분야에 대한 영향

LLM Training Pipeline 재정의:

  • Mid-training이 단순한 continued pre-training이 아니라 구조적 지식 추출 과정임을 확립
  • 향후 reasoning models 개발에 원칙적 가이드 제공
  • Pre-training → Mid-training → Post-training 파이프라인의 이론적 기반 강화

Reasoning Models의 새로운 패러다임:

  • O1, DeepSeek-R1 같은 reasoning models의 작동 원리 이해에 기여
  • Chain-of-thought가 단순한 prompting 기법을 넘어 action abstractions로 이해 가능
  • Future reasoning models의 설계에 영향

RL for LLMs 연구 활성화:

  • 이론적 분석이 부족했던 LLM RL 분야에 formal framework 제공
  • Self-supervised RL, hierarchical RL 등 관련 방향 연구 촉진
  • Verifiable reward를 넘어선 일반적 RL 적용 가능성 시사

실무적 가치:

  • 코드 생성 모델의 실질적 성능 향상
  • 계산 비용 대비 효율적인 mid-training 방법
  • 산업계에서 즉시 적용 가능한 알고리즘

최종 평가

본 논문은 이론적 엄밀함과 실용적 효과성을 모두 갖춘 우수한 연구이다. Mid-training의 역할을 최초로 formal하게 분석하고, 이를 기반으로 확장 가능한 알고리즘을 제안하며, 다양한 모델과 벤치마크에서 일관된 개선을 실증했다. 특히 action abstractions라는 통합된 프레임워크로 pruning efficiency와 RL convergence를 동시에 설명한 점이 탁월하다.

몇 가지 한계점(도메인 특수성, 제한적 baseline 비교)에도 불구하고, 본 연구는 LLM 훈련 파이프라인에 대한 이해를 크게 진전시켰으며, reasoning models 개발의 이론적 기반을 마련했다는 점에서 높은 학술적, 실용적 가치를 지닌다. 향후 다른 도메인으로의 확장과 더 깊은 이론적 탐구가 기대된다.

728x90