Skip to content

CS 285: Lecture 19, Control as Inference

영상링크: https://youtu.be/MzVlYYGtg0M

Optimal control as a Model of Human Behavior

Optimal control is a mathematical framework for computing control policies that optimize a given objective.

  • r(st,at) 함수를 찾아서 데이터를 설명하려고함. 기본적 세팅은 다음과 같음.
a1,,aT=argmaxa1,,aTt=1Tr(st,at)st+1=f(st,at)π=argmaxπEst+1p(st+1|st,at),atπ(at|st)[r(st,at)]atπ(at|st)
  • 하지만 이는 적용하기 쉽지 않음. 우리는 목표에 최적화된 행동을 항상 하지 않음. 같은 목표라도 직진하는 경우도 있고, 멀리 돌아갈 때도 있음(결국에 목적에 도달함). 현 프레임워크에서 이를 설정하거나 설명하기 어려움. 어떤 행동이 최적화된 행동인지에 대한 변수자체가 없음.
  • 그래서 이를 설명하기 위해 binary 변수 O 도입함. 최적화인 행동일 때 1, 그렇지 아니할 때 0의 값을 가짐.

HeadImg

p(O1:T)=exp(r(st,at))p(τ|O1:T)=p(τ,O1:T)p(O1:T)p(τ)texp(r(st,at))=p(τ)exp(tr(st,at))
  • 이렇게 함으로써 suboptimal behavior를 모델링 할 수 있고, inference 알고리즘을 적용하여 control과 planning 문제를 해결할 수 있음. 그리고 stochastic behavior가 왜 선호되는지 설명할 수 있음. 이는 exploration과 transfer learning을 아는데 도움이 됨.
  • 그러면 어떻게 inference할까?

  • Backward Messeage 계산: β(st,at)=p(O1:T|st,at)

  • Policy 계산: π(at|st,O1:T)
  • Forward Message 계산: α(st)=p(st|O1:t1)

Control as Inference

Backward Messages

β(st,at)=p(Ot:T|st,at)=p(Ot:T,st+1|st,at)dst+1=p(Ot+1:T|st+1)p(st+1|st,at)p(Ot|st,at)dst+1
  • p(st+1|st,at)는 transition dynamics이고, p(Ot|st,at)는 observation likelihood임.
  • 그러면 제일 앞에 있는 p(Ot+1:T|st+1)를 풀어서 쓰면 다음과 같음.
p(Ot+1:T|st+1)=p(O1:T|st+1,at+1)p(at+1|st+1)dat+1
  • 여기서 p(O1:T|st+1,at+1) 를 다시 β(st+1,at+1)로 바꿔서 쓸수 있음.
  • 그리고 p(at+1|st+1)action prior이라고 하는데 (policy는 아님) 우선은 uniform distribution으로 가정함. 왜냐면 아무도 어떤 행동을 할지 모르기 때문임. 게다가 수학적으로 해당 항을 지울 수 있음1.
  • 따라서 Backward message passing은 다음과 같이 진행된다.

Backward Message Passing

for  t=T1,,1:β(st,at)=p(Ot|st,at)Est+1p(st+1|st,at)[βt+1(st+1)]β(st)=Eatp(at|st)[β(st,at)]
  • Vt(st)=logβ(st), Qt(st,at)=logβ(st,at) 라고 재정의 하자. 따라서 Vt(st)=logexp(Qt(st,at))dat로 쓸 수 있으며, 이는 soft value function이라고 함.
  • Qt(st,at)가 커짐에 다라서 Vt(st) 도 커짐. Vt(st)maxatQ(st,at).
  • Qt(st,at)=r(st,at)+logE[exp(Vt+1(st+1))]
    • deterministic transition: Qt(st,at)=r(st,at)+Vt+1(st+1)
    • sthocastic case는 차후에 다룸
  • Log domain에서 알고리즘을 다시 쓰면 다음과 같음.

Backward Message Passing(Log Domain)

for  t=T1,,1:Qt(st,at)=r(st,at)+logE[exp(Vt+1(st+1))]Vt(st)=logexp(Qt(st,at))dat

Policy Computation

p(at|st,O1:T)=π(at|st)=p(a|st,Ot:T)=p(at,st|Ot:T)p(st|Ot:T)=p(at,st|Ot:T)p(at,st)/p(Ot:T)p(Ot:T|st)p(st)/p(Ot:T)=p(at,st|Ot:T)p(Ot:T|st)p(at,st)p(st)=βt(st,at)βt(st)p(at|st)
  • p(at|st)는 action prior이라 무시하고, policy π(at|st)=βt(st,at)βt(st)를 얻을 수 있음.
  • Log domain에서 π(at|st)=exp(Qt(st,at)Vt(st))=exp(At(st,at)) 로 쓸 수 있음.
  • Temperature α 를 도입하면 π(at|st)=exp(1αQt(st,at)1αVt(st))=exp(1αAt(st,at))
  • α 가 0으로 갈 수록 deterministic policy가 되고 greedy policy에 가까움.

Forward Messages

α(st)=p(st|O1:t1)=p(st,st1,at1|O1:t1)dat1dst1=p(st|st1,at1,O1:t1)p(at1|st1,O1:t1)p(st1|O1:t1)dat1dst1=p(st|st1,at1)p(at1|st1,O1:t1)p(st1|O1:t1)dat1dst1
  • p(st|st1,at1,O1:t1) 에서 O1:t1st1at1에 의존하지 않으므로 생략 가능.
  • α1(s1)=p(s1) 보통 알고 시작함.
  • p(st|O1:T) 를 계산하고 싶으면?
p(st|O1:T)=p(st,O1:T)p(O1:T)=p(Ot:T|st)p(st,O1:t1)p(O1:T)βt(st)αt(st)

Message Intersection

HeadImg

  • 예를 들어 그림과 같이 목표지점에 공을 가져다 놓는 Task가 있다.
  • Backward message는 목표지점에 도달하기 위한 상태의 확률을 나타내고, Forward message는 처음 상태에서 목표지점에 도달되는 상태를 표현한다(높은 reward 와 함께).
  • 그리고 두 메시지의 곱은 목표지점에 도달하기 위한 상태의 확률을 나타낸다.

Control as Variational Inference

  • Inference Problem: p(s1:T,a1:T|O1:T)
  • Marginalizaing과 conditioning을 통해 목적 policy p(at|st,O1:T) 를 계산하고 싶음. "높은 리워드가 주어졌을 때, action probability가 어떻게 되는가?"
  • q(s1:T,a1:T) 분포로 p(s1:T,a1:T|O1:T) 를 근사하면 어떻까(단 dynamics p(st+1|st,at) 하에서)?
  • x=O1:T, z=(s1:T,a1:T) 로 두어서 p(z|x)q(z) 로 근사해보자!

과정1

Algorithms for RL as Inference

Q-Learning with soft optimality

HeadImg

Benefits of Soft Optimality

  • exploration을 향상시키고 entropy collapse를 방지할 수 있음
  • policies를 더 쉽게 fine-tuning 할 수 있음
  • 더 로버스트함. 더 많은 state를 커버하기 때문
Was this page helpful?

Comments