9.4 Correct Gradient Estimators with DiCE
Last updated
Was this helpful?
Last updated
Was this helpful?
์ด๋ฒ section์์๋ ์ด๋ฅผ ๋ชจ๋ ํด๊ฒฐํ Infinitely Differentiable Monte-Carlo Estimator(DiCE)๋ฅผ ์๊ฐํฉ๋๋ค. ์ด๋ SCG์์ ์ด๋ค ์ฐจ์์ ๋ฏธ๋ถ๋ ์ ํํ๊ฒ ๊ณ์ฐํ ์ ์๋ ์ค์ฉ์ ์ธ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ํน์ ์ฐจ์์ ๋ฏธ๋ถ์ ํ๊ธฐ ์ํด ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ 9.3.1์ ๋์จ ๋ฐฉ๋ฒ์ ์ฌ๊ท์ ์ผ๋ก ๊ณ์ ์ฌ์ฉํ๋ฉด ๋๊ฒ ์ง๋ง, ์ด๋ ๋ ๊ฐ์ง ๊ฒฐ์ ์ ๊ฐ์ง๊ณ ์์ต๋๋ค. ์ฒซ๋ฒ์งธ๋ก gradient๋ฅผ ์ด๋ ๊ฒ ์ ์ํ๋ ๊ฒ์ด auto-diff library์ ์ ์ฉํ๊ธฐ ํ๋ค๋ค๋ ์ ์ ๋๋ค. ๋์งธ๋ก, ๋จ์ํ๊ฒ gradient estimator๋ฅผ ๊ตฌํ๋ฉด ์ด๊ธฐ ๋๋ฌธ์ ์ ๋๋ก ์ ๋ฐ์ดํธ๋์ง ์์ต๋๋ค.
์์์ ์ ์์์ ์ ์ํ ๊ฒ๊ณผ ๊ฐ์ด ๋ฅผ SCG์์์ objective๋ก ์ ์ํ๊ณ ์์ํฉ๋๋ค. ์ด๋ ๋ชจ๋ ์์กด์ฑ์ ๋ง์กฑํ๋ gradient estimator๋ ๋ค์๊ณผ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
๋ stochastic nodes์ ์ํ๊ณ , cost nodes์ ์ํฅ์ ๋ผ์น๋ฉด์ ์ ์ํฅ์ ๋ฐ๋ ๋ชจ๋ node๋ฅผ ์๋ฏธํฉ๋๋ค.ancestors node์ ์ ์กฐ๊ฑดํ๋์๋ค๊ณ ๊ฐ์ ํ๊ณ ์ด์ ๋ถํฐ DEPS์ ํ๊ธฐ์ ๋ํด ์๋ตํด์ ํ๊ธฐํ๊ฒ ์ต๋๋ค.
์ด์ ๋ถํฐ ์๊ฐํ์ง๋ง, DiCE์์๋ ๋์ ์ฐจ์์ ๋ฏธ๋ถ์ ์ ํํ๊ฒ ํ๊ธฐ ์ํด MagicBox ๋ผ๋ operator๋ฅผ ์ฌ์ฉํ๊ณ , input์ผ๋ก๋ stochastic nodes , ๊ทธ๋ฆฌ๊ณ ์๋์ ๊ฐ์ ๋๊ฐ์ง ์ฑ์ง์ ๊ฐ์ง๊ณ ์์ต๋๋ค.
์ฒซ๋ฒ์งธ ์ฑ์ง์ ๋ ํ๊ฐํ๋ค(evaluates to)๋ผ๋ ์๋ฏธ๋ก ๋ชจ๋ gradient์ ๊ฐ์์ ์๋ฏธํ๋ full equality(=)์๋ ๋์กฐ์ ์ ๋๋ค. auto-diff์์๋ ์ด๋ฅผ forward pass evaluation์ ์๋ฏธ๋ก ์ฌ์ฉํฉ๋๋ค.
๋๋ฒ์งธ ์ฑ์ง์ ๋ฅผ ์ฌ์ฉํด์ sample์ด ์ด๋์ sampling๋๋์ง ๊ทธ ๋ถํฌ์ ๋ํ ์์กด์ฑ์ ๋ณด์ ๋๋ค.(์ ๋ํ ํ๋ฅ ํฉ ํํ๊ฐ ๋ฉ๋๋ค.) ๊ทธ๋ฆฌ๊ณ ๋ฏธ๋ถํ๋ฉด log likelihood trick์ ์ด์ฉํด logํํ๋ก ๋ํ๋ ๊ฒ์ ๋๋ค. ์ด๋ ์ด ์ฑ์ง์ ๋ง์กฑํ๋ฉด ์ฒซ๋ฒ์งธ ์ฑ์ง์ ์ฝ๊ฒ ๋ง์กฑํ ์ ์์ต๋๋ค. (์ด ํ๋ฅ ํฉ์ด 1์ด๋ฏ๋ก)
๋๋ฒ์งธ ํน์ฑ์ ๋ง์กฑํ๋ค๋ฉด, ์ธ objective์ ๋ํด ๋ค์๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
์ด ์ ๊ฐ์ง๊ณ ์ด๋ป๊ฒ ์ ํํ๊ฒ ๊ณ ์ฐจ๋ฏธ๋ถ์ ํ ์ ์๋์ง์ ๋ํด ์ฆ๋ช ํด๋ณด๊ฒ ์ต๋๋ค.
๋ชจ๋ cost nodes ์ ๋ํด ๋ค์๊ณผ ๊ฐ์ด ์ ์ํ๊ฒ ์ต๋๋ค.
์ฆ ๋ objective ์ n์ฐจ ๋ฏธ๋ถ๊ฐ์ ๋๋ค.
๋ค์์ผ๋ก ์ ์ธ๋ฐ, magicbox operator์ ์ฒซ๋ฒ์งธ ํน์ฑ์ผ๋ก ์ธํด, ์ 1์ด ๋์ด, ์์ ์์์ต๋๋ค. ์ด๋ฅผ ํตํด, ๋ํ objective์ n๋ฒ์งธ ๋ฏธ๋ถ๊ฐ์ด๋ ๊ฐ๋ค๋ ์๋ฏธ๊ฐ ๋ฉ๋๋ค. ๊ทธ๋ ๋ค๋ฉด, ๋ง์ง๋ง์ผ๋ก ์์ ๋ณด์ด๋ฉด n์ฐจ ๋ฏธ๋ถ ์ ์ฒด์ ๋ํด magicbox operator๋ก ๊ตฌํ ์ ์๊ณ , ๊ทธ๊ฒ์ด ์ค์ ๋ฏธ๋ถ๊ฐ๊ณผ ๊ฐ๋ค๋ ์๋ฏธ๊ฐ ๋ฉ๋๋ค.
์ด ๋, (9.4.4)์์ (9.4.5)๋ก๊ฐ ๋, ๋๊ฐ์ง ํ ํฌ๋์ด ํ์ํฉ๋๋ค. ์ฒซ๋ฒ์งธ๋ก, ์ ํํ๋ฅผ ๋ณธ๋ฌธ ์(1)ํํ๋ก ๋ณํํด ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. ๊ทธ๋ ๊ฒ ๋๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
์ด๋ฅผ ์์ธํ ๋ณด๋ฉด (9.4.4)์ ํํ๊ณผ ๊ฐ์์ ์ ์ ์์ต๋๋ค. ๋ ์งธ๋ก, ๊ณผ ์ ๊ฐ์ stochastic nodes๋ฅผ ๊ฐ๋ฆฌํค๊ณ ์์ ๊ฒ์ด๋ฏ๋ก, ์ด ์๋ช ํฉ๋๋ค.