์ด๋ฒ section์์๋ ์ด๋ฅผ ๋ชจ๋ ํด๊ฒฐํ Infinitely Differentiable Monte-Carlo Estimator(DiCE)๋ฅผ ์๊ฐํฉ๋๋ค. ์ด๋ SCG์์ ์ด๋ค ์ฐจ์์ ๋ฏธ๋ถ๋ ์ ํํ๊ฒ ๊ณ์ฐํ ์ ์๋ ์ค์ฉ์ ์ธ ์๊ณ ๋ฆฌ์ฆ์
๋๋ค. ํน์ ์ฐจ์์ ๋ฏธ๋ถ์ ํ๊ธฐ ์ํด ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ 9.3.1์ ๋์จ ๋ฐฉ๋ฒ์ ์ฌ๊ท์ ์ผ๋ก ๊ณ์ ์ฌ์ฉํ๋ฉด ๋๊ฒ ์ง๋ง, ์ด๋ ๋ ๊ฐ์ง ๊ฒฐ์ ์ ๊ฐ์ง๊ณ ์์ต๋๋ค. ์ฒซ๋ฒ์งธ๋ก gradient๋ฅผ ์ด๋ ๊ฒ ์ ์ํ๋ ๊ฒ์ด auto-diff library์ ์ ์ฉํ๊ธฐ ํ๋ค๋ค๋ ์ ์
๋๋ค. ๋์งธ๋ก, ๋จ์ํ๊ฒ gradient estimator๋ฅผ ๊ตฌํ๋ฉด โฮธโf(x;ฮธ)๎ =g(x;ฮธ)์ด๊ธฐ ๋๋ฌธ์ ์ ๋๋ก ์
๋ฐ์ดํธ๋์ง ์์ต๋๋ค.
์์์ ์ ์์์ ์ ์ํ ๊ฒ๊ณผ ๊ฐ์ด L=E[โcโCโc]๋ฅผ SCG์์์ objective๋ก ์ ์ํ๊ณ ์์ํฉ๋๋ค. ์ด๋ ๋ชจ๋ ์์กด์ฑ์ ๋ง์กฑํ๋ gradient estimator๋ ๋ค์๊ณผ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
โฮธโL=E[โcโCโ(cโwโWcโโโฮธโlogp(wโฃDEPSwโ)+โฮธโc(DEPScโ))]โฏ(1)
Wcโ๋ stochastic nodes์ ์ํ๊ณ , cost nodes์ ์ํฅ์ ๋ผ์น๋ฉด์ ฮธ์ ์ํฅ์ ๋ฐ๋ ๋ชจ๋ node๋ฅผ ์๋ฏธํฉ๋๋ค.ancestors node์ ์ ์กฐ๊ฑดํ๋์๋ค๊ณ ๊ฐ์ ํ๊ณ ์ด์ ๋ถํฐ DEPS์ ํ๊ธฐ์ ๋ํด ์๋ตํด์ ํ๊ธฐํ๊ฒ ์ต๋๋ค.
์ด์ ๋ถํฐ ์๊ฐํ์ง๋ง, DiCE์์๋ ๋์ ์ฐจ์์ ๋ฏธ๋ถ์ ์ ํํ๊ฒ ํ๊ธฐ ์ํด MagicBox โก๋ผ๋ operator๋ฅผ ์ฌ์ฉํ๊ณ , input์ผ๋ก๋ stochastic nodes W, ๊ทธ๋ฆฌ๊ณ ์๋์ ๊ฐ์ ๋๊ฐ์ง ์ฑ์ง์ ๊ฐ์ง๊ณ ์์ต๋๋ค.
โก(W)โ1
โฮธโโก(W)=โก(W)โwโWโโฮธโlog(p(w;ฮธ))
์ฒซ๋ฒ์งธ ์ฑ์ง์ โ๋ ํ๊ฐํ๋ค(evaluates to)๋ผ๋ ์๋ฏธ๋ก ๋ชจ๋ gradient์ ๊ฐ์์ ์๋ฏธํ๋ full equality(=)์๋ ๋์กฐ์ ์
๋๋ค. auto-diff์์๋ ์ด๋ฅผ forward pass evaluation์ ์๋ฏธ๋ก ์ฌ์ฉํฉ๋๋ค.
๋๋ฒ์งธ ์ฑ์ง์ โก๋ฅผ ์ฌ์ฉํด์ sample์ด ์ด๋์ sampling๋๋์ง ๊ทธ ๋ถํฌ์ ๋ํ ์์กด์ฑ์ ๋ณด์
๋๋ค.(w์ ๋ํ ํ๋ฅ ํฉ ํํ๊ฐ ๋ฉ๋๋ค.) ๊ทธ๋ฆฌ๊ณ ๋ฏธ๋ถํ๋ฉด log likelihood trick์ ์ด์ฉํด logํํ๋ก ๋ํ๋ ๊ฒ์
๋๋ค. ์ด๋ ์ด ์ฑ์ง์ ๋ง์กฑํ๋ฉด ์ฒซ๋ฒ์งธ ์ฑ์ง์ ์ฝ๊ฒ ๋ง์กฑํ ์ ์์ต๋๋ค. (์ด ํ๋ฅ ํฉ์ด 1์ด๋ฏ๋ก)
๋๋ฒ์งธ ํน์ฑ์ ๋ง์กฑํ๋ค๋ฉด, L=E[โcโCโc]์ธ objective์ ๋ํด ๋ค์๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
Lโกโ=โcโCโโก(Wcโ)cย ย (โตโก(Wcโ)โ1)
์ด Lโกโ์ ๊ฐ์ง๊ณ ์ด๋ป๊ฒ ์ ํํ๊ฒ ๊ณ ์ฐจ๋ฏธ๋ถ์ ํ ์ ์๋์ง์ ๋ํด ์ฆ๋ช
ํด๋ณด๊ฒ ์ต๋๋ค.
Theroem1.ย ย ย ย ย ย E[โฮธnโLโกโ]โโฮธnโL,โnโ{0,1,2,โฏ}
๋ชจ๋ cost nodes cโC์ ๋ํด ๋ค์๊ณผ ๊ฐ์ด ์ ์ํ๊ฒ ์ต๋๋ค.
ย ย ย ย ย c0ย ย ย =ย ย ย ย ย ย cE[cn+1]=โฮธโE[cn]
์ฆ cn๋ objective E[c]์ n์ฐจ ๋ฏธ๋ถ๊ฐ์
๋๋ค.
๋ค์์ผ๋ก cโกnโ์ cnโก(Wcnโ)์ธ๋ฐ, magicbox operator์ ์ฒซ๋ฒ์งธ ํน์ฑ์ผ๋ก ์ธํด, โกWcnโ์ 1์ด ๋์ด, cโกnโโcn์์ ์์์ต๋๋ค. ์ด๋ฅผ ํตํด, cโกnโ๋ํ objective์ n๋ฒ์งธ ๋ฏธ๋ถ๊ฐ์ด๋ ๊ฐ๋ค๋ ์๋ฏธ๊ฐ ๋ฉ๋๋ค. ๊ทธ๋ ๋ค๋ฉด, ๋ง์ง๋ง์ผ๋ก โฮธโcโกnโ=cโกn+1โ์์ ๋ณด์ด๋ฉด n์ฐจ ๋ฏธ๋ถ ์ ์ฒด์ ๋ํด magicbox operator๋ก ๊ตฌํ ์ ์๊ณ , ๊ทธ๊ฒ์ด ์ค์ ๋ฏธ๋ถ๊ฐ๊ณผ ๊ฐ๋ค๋ ์๋ฏธ๊ฐ ๋ฉ๋๋ค.
โฮธโcโกnโ=โฮธโ(cnโก(Wcnโ))
=cnโฮธโโก(Wcnโ)+โฮธโ(Wcnโ)โกcn
=cnโก(Wcnโ)(โwโWcnโโโฮธโlog(p(w;ฮธ)))+โก(Wcnโ)โฮธโcn
=โก(Wcnโ)(โฮธโcn+cnโwโWcnโโโฮธโlog(p(w;ฮธ)))โฏ(9.4.4)
โก(Wcn+1โ)cn+1=cโกn+1โโฏ(9.4.5)
์ด ๋, (9.4.4)์์ (9.4.5)๋ก๊ฐ ๋, ๋๊ฐ์ง ํ
ํฌ๋์ด ํ์ํฉ๋๋ค. ์ฒซ๋ฒ์งธ๋ก, L=E[cn]์ ํํ๋ฅผ ๋ณธ๋ฌธ ์(1)ํํ๋ก ๋ณํํด ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค. ๊ทธ๋ ๊ฒ ๋๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
cn+1=โฮธโcn+cnโwโWcnโโโฮธโlogp(w;ฮธ)
์ด๋ฅผ ์์ธํ ๋ณด๋ฉด (9.4.4)์ ํํ๊ณผ ๊ฐ์์ ์ ์ ์์ต๋๋ค. ๋ ์งธ๋ก, Wcnโ๊ณผ Wcn+1โ์ ๊ฐ์ stochastic nodes๋ฅผ ๊ฐ๋ฆฌํค๊ณ ์์ ๊ฒ์ด๋ฏ๋ก, Wcnโ=Wcn+1โ์ด ์๋ช
ํฉ๋๋ค.