pip install "nutpie[pymc]"
DINA model in PyMC
- 2026-03-13
- jw
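This notebook first implements the DINA model.
- DINA (Deterministic Input, Noisy AND gate). Assumes that an examinee must have mastered all of the attributes an item requires (AND) to answer it correctly. (A multiplicative relation among attributes, i.e. the interaction-effect part.) Suited to cognitive tests.
- DINO (Deterministic Input, Noisy OR gate). Assumes that mastering any one of the attributes an item requires is enough to answer it correctly (OR). (An additive relation among attributes, i.e. the main-effect part.) Suited to non-cognitive tests.
- GDINA. A saturated model that includes all possible main effects and interaction effects. The data decide which effects matter; no structure is imposed in advance.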
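Considerations
GDINA is more flexible, but it also has more parameters per item. The researcher therefore has to weigh the number of parameters against the sample size when choosing between a saturated model (GDINA, LCDM, ...) and a reduced model (DINA, DINO, ...). GDINA needs $2^{K_j}$ parameters per item, where $K_j$ is the number of attributes that item $j$ measures. DINA needs only 2 parameters per item (slip, guess).
Take 20 items and 5 attributes as an example:
- DINA. $20 \times 2 = 40$ parameters.
- GDINA. At the low end, every item measures only 1 attribute (although this raises identifiability problems): $20 \times 2^1 = 40$ parameters. At the high end, every item measures all 5 attributes (which also raises identifiability problems): $20 \times 2^5 = 640$ parameters. A typical design falls somewhere between 40 and 640 parameters, which is still a very large number. Even with a well-designed Q-matrix, a sufficiently large sample is needed for good estimates.
In addition, estimation for GDINA is more complex. DINA is one of the relatively easy and stable options.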
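The parameter arithmetic above can be sketched in a few lines. The function name `count_item_parameters` and the two Q-matrices are hypothetical illustrations, and the counts cover item parameters only (not the mixing proportions).

```python
import numpy as np

def count_item_parameters(Q):
    """Return (DINA, saturated GDINA) item-parameter counts for a 0/1 Q-matrix."""
    Q = np.asarray(Q)
    k_j = Q.sum(axis=1)               # K_j: number of attributes item j measures
    n_dina = 2 * Q.shape[0]           # slip and guess for every item
    n_gdina = int((2 ** k_j).sum())   # 2^{K_j} parameters for item j
    return n_dina, n_gdina

# Hypothetical Q-matrices for 20 items and 5 attributes
Q_simple = np.zeros((20, 5), dtype=int)
Q_simple[:, 0] = 1                    # every item measures exactly one attribute
Q_full = np.ones((20, 5), dtype=int)  # every item measures all five attributes

print(count_item_parameters(Q_simple))  # (40, 40)
print(count_item_parameters(Q_full))    # (40, 640)
```

Passing a real Q-matrix to the same function gives the count for a mixed design, which lands between the two extremes.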
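DCM and mixture modeling
Diagnostic classification models (DCMs), which include the DINA, DINO, and GDINA models discussed here, can be viewed as "confirmatory" (or restricted) classification models: the number of classes is fixed in advance at $2^K$, where $K$ is the number of attributes. With 3 attributes, for example, there are $2^3=8$ mastery profiles, from {0,0,0}, {0,0,1}, ..., to {1,1,1}, although some profiles may turn out to be empty. See the end of this document for details.
By contrast, in "exploratory" methods such as latent class analysis (LCA) (which can be viewed as a GMM with the latent variance fixed at 0), k-means (which assigns classes by distance), or other mixture models (for example, the Gaussian mixture model), the number of classes is usually "explored": the researcher decides how many classes to use through model comparison.
DCMs therefore differ from other classification models. An ordinary classification model simply groups people according to the observed variables, whereas the more important goal of a DCM is to assign people to predefined mastery profiles and so to learn whether each student has mastered each attribute (or skill, and so on). This is what makes cognitive diagnosis possible.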
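As a small illustration of the fixed class structure, the $2^K$ mastery profiles can be enumerated directly. This sketch uses only the standard library, and the variable names are illustrative.

```python
import itertools

K = 3  # number of attributes
profiles = list(itertools.product([0, 1], repeat=K))
labels = ["".join(map(str, p)) for p in profiles]

print(len(profiles))  # 8
print(labels)         # ['000', '001', '010', '011', '100', '101', '110', '111']
```

The class count doubles with every added attribute, so $K = 10$ already implies 1024 profiles.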
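MCMC
EM estimation is also possible, but it needs to be checked separately.
In a Bayesian framework the statistical model can be written down intuitively, and parameter estimates are obtained by automated sampling. Because DCMs involve discrete latent variables, they have usually been estimated with a Gibbs sampler (for example JAGS; see the references below).
For an implementation in Python, however, the more mature Bayesian software is PyMC. PyMC samples with HMC/NUTS, which is faster but does not handle discrete latent variables well, so they need special treatment: the latent classes are marginalized out with logsumexp (see below).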
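A toy sketch of the marginalization idea: the discrete class is summed out of the likelihood, which leaves a density that depends only on continuous parameters. The two-class setup and all numbers below are hypothetical, and `logsumexp` is a small NumPy helper written for this example rather than the SciPy or PyTensor function.

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    a = np.asarray(a, dtype=float)
    a_max = a.max(axis=axis, keepdims=True)
    out = a_max + np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True))
    return out.squeeze(axis=axis)

# Hypothetical toy setup: 2 latent classes, 3 items, 2 examinees
pi = np.array([0.4, 0.6])            # class proportions
P = np.array([[0.2, 0.9],            # P[j, l] = P(Y_j = 1 | class l)
              [0.1, 0.8],
              [0.3, 0.7]])
Y = np.array([[1, 0, 1],
              [0, 0, 0]])

# log p(Y_i, class l) = log pi_l + sum_j [Y_ij log P_jl + (1 - Y_ij) log(1 - P_jl)]
log_joint = np.log(pi) + Y @ np.log(P) + (1 - Y) @ np.log(1 - P)   # (N, L)
log_marginal = logsumexp(log_joint, axis=1)                         # (N,)

# The same quantity computed naively, for comparison
naive = np.array([
    np.log(sum(pi[l] * np.prod(P[:, l] ** y * (1 - P[:, l]) ** (1 - y))
               for l in range(len(pi))))
    for y in Y
])

# Class membership is recovered afterwards from the same log terms
posterior = np.exp(log_joint - log_marginal[:, None])               # (N, L)
```

Working in log space avoids underflow when the product runs over many items.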
Ref
(Bayesian-Gibbs)
- Zhan, P., Jiao, H., Man, K., & Wang, L. (2019). Using JAGS for Bayesian cognitive diagnosis modeling: A tutorial. Journal of Educational and Behavioral Statistics, 44(4), 473-503.
(EM-DINA)
- de la Torre, J. (2009). DINA model and parameter estimation: A didactic. Journal of Educational and Behavioral Statistics, 34(1), 115-130.
(EM-GDINA)
- de la Torre, J. (2011). The generalized DINA model framework. Psychometrika, 76(2), 179-199.
"""
DINA + PyMC Bayesian estimation
===============================
Model: Deterministic Input, Noisy And Gate (DINA)
- Two parameters per item: slip (s_j) and guess (g_j)
- Ideal response: η_ij = ∏_k α_ik^q_jk (all required attributes must be mastered)
- P(Y_ij=1 | α_i) = (1 - s_j)^η_ij * g_j^(1 - η_ij)
Simulated data: 1500 examinees, 20 items, 3 attributes
"""
import numpy as np
import pandas as pd
import itertools
import warnings
warnings.filterwarnings("ignore")
import pymc as pm
import pytensor.tensor as pt
from scipy.special import logsumexp
from scipy.stats import pearsonr
import nutpie
# ──────────────────────────────────────────────────────────────
# PART 1: Simulate data
# ──────────────────────────────────────────────────────────────
np.random.seed(1234)
N, J, K = 1500, 20, 3
all_alpha = np.array(list(itertools.product([0, 1], repeat=K)))  # (8, 3)
L = len(all_alpha)  # 2^K = 8 attribute profiles
# Q-matrix: which attributes each item measures
Q = np.array([
    [1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0],  # items 1-4: attribute 1 only
    [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0],  # items 5-8: attribute 2 only
    [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1],  # items 9-12: attribute 3 only
    [1, 1, 0], [1, 1, 0], [1, 1, 0],             # items 13-15: attributes 1+2
    [1, 0, 1], [1, 0, 1], [1, 0, 1],             # items 16-18: attributes 1+3
    [1, 1, 1], [1, 1, 1],                        # items 19-20: all attributes
])
# True mixing proportions of the attribute profiles
pi_true = np.array([0.05, 0.08, 0.08, 0.12, 0.10, 0.12, 0.15, 0.30])
# True DINA parameters: slip and guess
# slip: probability of a wrong answer despite mastering all required attributes
# guess: probability of a correct answer without mastering them
slip_true = np.array([0.10, 0.15, 0.10, 0.15,   # items 1-4
                      0.10, 0.15, 0.10, 0.15,   # items 5-8
                      0.10, 0.15, 0.10, 0.15,   # items 9-12
                      0.10, 0.15, 0.10,         # items 13-15
                      0.10, 0.15, 0.10,         # items 16-18
                      0.10, 0.15])              # items 19-20
guess_true = np.array([0.15, 0.20, 0.10, 0.25,  # items 1-4
                       0.15, 0.20, 0.10, 0.25,  # items 5-8
                       0.15, 0.20, 0.10, 0.25,  # items 9-12
                       0.10, 0.15, 0.10,        # items 13-15
                       0.10, 0.15, 0.10,        # items 16-18
                       0.10, 0.15])             # items 19-20
# Ideal response of every attribute profile on every item: η_lj = ∏_k α_lk^q_jk
# eta_true[l, j] = 1 iff attribute profile l has mastered every attribute item j requires
eta_true = np.array([
    [int(np.all(all_alpha[l] >= Q[j])) for j in range(J)]
    for l in range(L)
])  # (L, J)
# DINA success probability: P(Y=1|α_l) = (1-s_j)^η_lj * g_j^(1-η_lj)
P_true = np.zeros((J, L))
for j in range(J):
    for l in range(L):
        if eta_true[l, j] == 1:
            P_true[j, l] = 1 - slip_true[j]   # mastered -> high success probability
        else:
            P_true[j, l] = guess_true[j]      # not mastered -> low probability of guessing correctly
# Simulate the examinees' attribute profiles
alpha_idx = np.random.choice(L, size=N, p=pi_true)
Alpha_true = all_alpha[alpha_idx]
# Simulate the response data
Y = np.array([
    [np.random.binomial(1, P_true[j, alpha_idx[i]]) for j in range(J)]
    for i in range(N)
])
print(f"Simulated data ready: Y={Y.shape}, Q={Q.shape}")
print(f"slip range: {slip_true.min():.2f} ~ {slip_true.max():.2f}")
print(f"guess range: {guess_true.min():.2f} ~ {guess_true.max():.2f}")
Simulated data ready: Y=(1500, 20), Q=(20, 3)
slip range: 0.10 ~ 0.15
guess range: 0.10 ~ 0.25
# ──────────────────────────────────────────────────────────────
# PART 2: Precompute the η matrix (outside PyMC)
# ──────────────────────────────────────────────────────────────
# η_lj = 1 iff attribute profile l has mastered every attribute item j requires
# This is a fixed structural matrix determined by the Q-matrix; it is not estimated
# eta[l, j]: ideal response of attribute profile l on item j
eta = np.array([
    [int(np.all(all_alpha[l] >= Q[j])) for j in range(J)]
    for l in range(L)
], dtype=np.float32)  # (L, J)
print("η matrix (first 4 items × all 8 attribute profiles):")
df_eta = pd.DataFrame(eta[:, :4],
                      index=["".join(map(str, a)) for a in all_alpha],
                      columns=[f"Q{j+1}" for j in range(4)])
print(df_eta.to_string())
print("\n(η=1 means the attribute profile has mastered every attribute the item requires)")
η matrix (first 4 items × all 8 attribute profiles):
      Q1   Q2   Q3   Q4
000  0.0  0.0  0.0  0.0
001  0.0  0.0  0.0  0.0
010  0.0  0.0  0.0  0.0
011  0.0  0.0  0.0  0.0
100  1.0  1.0  1.0  1.0
101  1.0  1.0  1.0  1.0
110  1.0  1.0  1.0  1.0
111  1.0  1.0  1.0  1.0

(η=1 means the attribute profile has mastered every attribute the item requires)
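The nested loops used for η and for the success-probability matrix can also be written with NumPy broadcasting. The sketch below is an equivalent formulation on a small hypothetical Q-matrix with made-up slip and guess values, not a replacement for the notebook's code.

```python
import itertools
import numpy as np

K = 3
all_alpha = np.array(list(itertools.product([0, 1], repeat=K)))   # (L, K)
Q = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 0], [1, 1, 1]])        # small hypothetical Q-matrix

# Broadcast (L, 1, K) against (1, J, K): eta is 1 when no required attribute is missing
eta_vec = (all_alpha[:, None, :] >= Q[None, :, :]).all(axis=2).astype(float)   # (L, J)

# Reference: the explicit double loop
eta_loop = np.array([[float(np.all(a >= q)) for q in Q] for a in all_alpha])
assert np.array_equal(eta_vec, eta_loop)

# DINA success probabilities: 1 - slip where eta = 1, guess otherwise
slip = np.array([0.10, 0.15, 0.10, 0.15])    # hypothetical values
guess = np.array([0.15, 0.20, 0.10, 0.15])   # hypothetical values
P = np.where(eta_vec.T == 1, (1 - slip)[:, None], guess[:, None])   # (J, L)
```

With 8 profiles and 20 items the loops are fast enough; the broadcast form mainly helps when $K$ grows and the number of profiles doubles with each attribute.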
VI estimation is faster here, but parameter recovery is worse. MCMC is slower (about 1-3 minutes on Colab; it may run faster on a local machine), but parameter recovery is better.
# ──────────────────────────────────────────────────────────────
# PART 3: PyMC DINA (MCMC estimation)
# ──────────────────────────────────────────────────────────────
#
# DINA model:
#   P(Y_ij=1 | α_i) = (1 - s_j)^η_ij * g_j^(1 - η_ij)
#
# Prior settings (see de la Torre 2009; Junker & Sijtsma 2001):
#   slip_j  ~ Beta(1, 4)            -> favors small values (slip is usually small)
#   guess_j ~ Beta(1, 4)            -> favors small values (guess is usually small)
#   π       ~ Dirichlet(1, ..., 1)  -> uniform prior
#
# Identifiability constraint: s_j + g_j < 1 (controlled implicitly through the prior shape)
eta_pt = pt.constant(eta)                    # (L, J)
Y_pt = pt.constant(Y.astype(np.float32))     # (N, J)
with pm.Model() as dina_model:
    # ── Priors ──────────────────────────────────────────────
    # Mixing proportions of the attribute profiles
    pi = pm.Dirichlet("pi", a=np.ones(L))                    # (L,)
    # slip and guess parameters for each item
    # The Beta(1, 4) prior pulls the parameters toward 0, in line with the DINA identifiability condition
    slip = pm.Beta("slip", alpha=1, beta=4, shape=(J,))      # (J,)
    guess = pm.Beta("guess", alpha=1, beta=4, shape=(J,))    # (J,)
    # ── DINA success-probability matrix P_mat (J, L) ────────
    # P_mat[j, l] = (1 - slip[j])^η[l,j] * guess[j]^(1 - η[l,j])
    # Computed in log space:
    # log P = η[l,j] * log(1-slip[j]) + (1-η[l,j]) * log(guess[j])
    log_1_minus_slip = pt.log(1 - slip + 1e-8)   # (J,)
    log_guess = pt.log(guess + 1e-8)             # (J,)
    # eta_pt: (L, J) -> transpose to (J, L) for broadcasting
    eta_T = eta_pt.T                             # (J, L)
    # log P_mat[j, l] = η[l,j]*log(1-s_j) + (1-η[l,j])*log(g_j)
    log_P_mat = (
        eta_T * log_1_minus_slip[:, None] +
        (1 - eta_T) * log_guess[:, None]
    )  # (J, L)
    P_mat = pt.exp(log_P_mat)                    # (J, L)
    # ── Mixture likelihood: marginalize over the L attribute profiles ──
    # log p(Y_i | π, s, g) = log Σ_l π_l * ∏_j P_mat[j,l]^Y_ij * (1-P_mat[j,l])^(1-Y_ij)
    log_likes = pt.stack([
        pt.log(pi[l] + 1e-8) +
        pt.dot(Y_pt, pt.log(P_mat[:, l] + 1e-8)) +
        pt.dot(1 - Y_pt, pt.log(1 - P_mat[:, l] + 1e-8))
        for l in range(L)
    ], axis=1)  # (N, L)
    pm.Potential("obs", pt.logsumexp(log_likes, axis=1).sum())
    # ── VI estimation (ADVI), disabled ──────────────────────
    # approx = pm.fit(20000,
    #                 # method="advi",
    #                 method="fullrank_advi",
    #                 callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
    # trace = approx.sample(2000)
    #
    # trace = pm.sample(cores=2, target_accept=0.85, init="advi+adapt_diag")
# -- nutpie
compiled = nutpie.compile_pymc_model(dina_model)
trace = nutpie.sample(compiled, chains=2, cores=2, target_accept=0.85)
Sampler Progress
Total Chains: 2
Active Chains: 0
Finished Chains: 2
Sampling for a minute
Estimated Time to Completion: now
| Progress | Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|---|
| | 1400 | 0 | 0.44 | 7 |
| | 1400 | 0 | 0.35 | 15 |
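The progress table reports divergences and step size per chain. A complementary check is whether the chains agree with each other. The sketch below is a plain split R-hat written in NumPy on synthetic draws (the arrays `good` and `bad` are made up); it is not the rank-normalized variant that ArviZ reports in `az.summary`, and it is shown only to illustrate what the statistic measures.

```python
import numpy as np

def split_rhat(draws):
    """Basic split R-hat for one scalar parameter; draws has shape (chains, draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain in two so that drift within a chain also shows up
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()       # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)     # between-chain variance
    var_plus = (n - 1) / n * W + B / n          # pooled variance estimate
    return float(np.sqrt(var_plus / W))

rng = np.random.default_rng(0)
good = rng.normal(size=(2, 1000))               # two chains exploring the same distribution
bad = good + np.array([[0.0], [3.0]])           # second chain stuck in a different region

print(f"agreeing chains:    {split_rhat(good):.3f}")
print(f"disagreeing chains: {split_rhat(bad):.3f}")
```

Values close to 1 indicate that the chains agree; with a real trace the function would be applied to each scalar parameter, for example `trace.posterior["slip"].values[:, :, j]`.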
# ──────────────────────────────────────────────────────────────
# PART 4: Basic parameter recovery results
# ──────────────────────────────────────────────────────────────
# Extract posterior means
slip_est = trace.posterior["slip"].values.reshape(-1, J).mean(axis=0)    # (J,)
guess_est = trace.posterior["guess"].values.reshape(-1, J).mean(axis=0)  # (J,)
pi_est = trace.posterior["pi"].values.reshape(-1, L).mean(axis=0)        # (L,)
# Pearson r between true values and posterior means
r_slip, _ = pearsonr(slip_true, slip_est)
r_guess, _ = pearsonr(guess_true, guess_est)
r_pi, _ = pearsonr(pi_true, pi_est)
print("Parameter recovery (Pearson r):")
print(f"  slip  r = {r_slip:.4f}")
print(f"  guess r = {r_guess:.4f}")
print(f"  π     r = {r_pi:.4f}")
Parameter recovery (Pearson r):
  slip  r = 0.8583
  guess r = 0.9777
  π     r = 0.9914
# ──────────────────────────────────────────────────────────────
# PART 5: Parameter estimate tables
# ──────────────────────────────────────────────────────────────
profile_labels = ["".join(map(str, a)) for a in all_alpha]
# ── 5a. slip and guess: true vs. estimated ────────────────────
slip_sd = trace.posterior["slip"].values.reshape(-1, J).std(axis=0)
guess_sd = trace.posterior["guess"].values.reshape(-1, J).std(axis=0)
df_sg = pd.DataFrame({
    "Item": [f"Q{j+1}" for j in range(J)],
    "slip true": slip_true.round(3),
    "slip est": slip_est.round(3),
    "slip SD": slip_sd.round(3),
    "slip bias": (slip_est - slip_true).round(3),
    "guess true": guess_true.round(3),
    "guess est": guess_est.round(3),
    "guess SD": guess_sd.round(3),
    "guess bias": (guess_est - guess_true).round(3),
})
print("\n" + "="*70)
print("[DINA parameters: slip & guess, true vs. estimated]")
print("="*70)
print(df_sg.to_string(index=False))
# Mean absolute bias (MAB)
mab_slip = np.abs(slip_est - slip_true).mean()
mab_guess = np.abs(guess_est - guess_true).mean()
print(f"\n  slip MAB = {mab_slip:.4f}")
print(f"  guess MAB = {mab_guess:.4f}")
======================================================================
[DINA parameters: slip & guess, true vs. estimated]
======================================================================
Item  slip true  slip est  slip SD  slip bias  guess true  guess est  guess SD  guess bias
  Q1       0.10     0.110    0.010      0.010        0.15      0.167     0.017       0.017
  Q2       0.15     0.170    0.012      0.020        0.20      0.205     0.019       0.005
  Q3       0.10     0.112    0.010      0.012        0.10      0.128     0.015       0.028
  Q4       0.15     0.147    0.011     -0.003        0.25      0.254     0.020       0.004
  Q5       0.10     0.100    0.010     -0.000        0.15      0.165     0.018       0.015
  Q6       0.15     0.152    0.012      0.002        0.20      0.192     0.018      -0.008
  Q7       0.10     0.101    0.010      0.001        0.10      0.120     0.016       0.020
  Q8       0.15     0.144    0.012     -0.006        0.25      0.248     0.020      -0.002
  Q9       0.10     0.105    0.011      0.005        0.15      0.147     0.016      -0.003
 Q10       0.15     0.140    0.012     -0.010        0.20      0.199     0.017      -0.001
 Q11       0.10     0.088    0.010     -0.012        0.10      0.105     0.013       0.005
 Q12       0.15     0.141    0.011     -0.009        0.25      0.258     0.019       0.008
 Q13       0.10     0.105    0.012      0.005        0.10      0.100     0.011       0.000
 Q14       0.15     0.127    0.013     -0.023        0.15      0.135     0.012      -0.015
 Q15       0.10     0.091    0.011     -0.009        0.10      0.112     0.012       0.012
 Q16       0.10     0.118    0.013      0.018        0.10      0.087     0.010      -0.013
 Q17       0.15     0.141    0.014     -0.009        0.15      0.151     0.012       0.001
 Q18       0.10     0.080    0.010     -0.020        0.10      0.100     0.010       0.000
 Q19       0.10     0.131    0.015      0.031        0.10      0.087     0.009      -0.013
 Q20       0.15     0.166    0.017      0.016        0.15      0.162     0.011       0.012

  slip MAB = 0.0109
  guess MAB = 0.0092
# ── 5b. Rebuild the P matrix from slip/guess ──────────────────
# P_est[j, l] = (1 - slip_est[j])^η[l,j] * guess_est[j]^(1 - η[l,j])
P_est = np.zeros((J, L))
for j in range(J):
    for l in range(L):
        if eta[l, j] == 1:
            P_est[j, l] = 1 - slip_est[j]
        else:
            P_est[j, l] = guess_est[j]
df_P_true = pd.DataFrame(P_true,
                         index=[f"Q{j+1}" for j in range(J)],
                         columns=profile_labels)
df_P_est = pd.DataFrame(P_est,
                        index=[f"Q{j+1}" for j in range(J)],
                        columns=profile_labels)
print("\n" + "="*60)
print("[Item success probabilities P: true values]")
print("  (each item takes only 2 distinct values: guess vs 1-slip)")
print("="*60)
print(df_P_true.round(3).to_string())
print("\n" + "="*60)
print("[Item success probabilities P: PyMC DINA estimates (posterior means)]")
print("="*60)
print(df_P_est.round(3).to_string())
df_P_diff = df_P_est - df_P_true
mab = df_P_diff.abs().values.mean()
rmse = np.sqrt((df_P_diff.values**2).mean())
print(f"\n  P matrix MAB = {mab:.4f}")
print(f"  P matrix RMSE = {rmse:.4f}")
============================================================
[Item success probabilities P: true values]
(each item takes only 2 distinct values: guess vs 1-slip)
============================================================
000 001 010 011 100 101 110 111
Q1 0.15 0.15 0.15 0.15 0.90 0.90 0.90 0.90
Q2 0.20 0.20 0.20 0.20 0.85 0.85 0.85 0.85
Q3 0.10 0.10 0.10 0.10 0.90 0.90 0.90 0.90
Q4 0.25 0.25 0.25 0.25 0.85 0.85 0.85 0.85
Q5 0.15 0.15 0.90 0.90 0.15 0.15 0.90 0.90
Q6 0.20 0.20 0.85 0.85 0.20 0.20 0.85 0.85
Q7 0.10 0.10 0.90 0.90 0.10 0.10 0.90 0.90
Q8 0.25 0.25 0.85 0.85 0.25 0.25 0.85 0.85
Q9 0.15 0.90 0.15 0.90 0.15 0.90 0.15 0.90
Q10 0.20 0.85 0.20 0.85 0.20 0.85 0.20 0.85
Q11 0.10 0.90 0.10 0.90 0.10 0.90 0.10 0.90
Q12 0.25 0.85 0.25 0.85 0.25 0.85 0.25 0.85
Q13 0.10 0.10 0.10 0.10 0.10 0.10 0.90 0.90
Q14 0.15 0.15 0.15 0.15 0.15 0.15 0.85 0.85
Q15 0.10 0.10 0.10 0.10 0.10 0.10 0.90 0.90
Q16 0.10 0.10 0.10 0.10 0.10 0.90 0.10 0.90
Q17 0.15 0.15 0.15 0.15 0.15 0.85 0.15 0.85
Q18 0.10 0.10 0.10 0.10 0.10 0.90 0.10 0.90
Q19 0.10 0.10 0.10 0.10 0.10 0.10 0.10 0.90
Q20 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.85
============================================================
[Item success probabilities P: PyMC DINA estimates (posterior means)]
============================================================
000 001 010 011 100 101 110 111
Q1 0.167 0.167 0.167 0.167 0.890 0.890 0.890 0.890
Q2 0.205 0.205 0.205 0.205 0.830 0.830 0.830 0.830
Q3 0.128 0.128 0.128 0.128 0.888 0.888 0.888 0.888
Q4 0.254 0.254 0.254 0.254 0.853 0.853 0.853 0.853
Q5 0.165 0.165 0.900 0.900 0.165 0.165 0.900 0.900
Q6 0.192 0.192 0.848 0.848 0.192 0.192 0.848 0.848
Q7 0.120 0.120 0.899 0.899 0.120 0.120 0.899 0.899
Q8 0.248 0.248 0.856 0.856 0.248 0.248 0.856 0.856
Q9 0.147 0.895 0.147 0.895 0.147 0.895 0.147 0.895
Q10 0.199 0.860 0.199 0.860 0.199 0.860 0.199 0.860
Q11 0.105 0.912 0.105 0.912 0.105 0.912 0.105 0.912
Q12 0.258 0.859 0.258 0.859 0.258 0.859 0.258 0.859
Q13 0.100 0.100 0.100 0.100 0.100 0.100 0.895 0.895
Q14 0.135 0.135 0.135 0.135 0.135 0.135 0.873 0.873
Q15 0.112 0.112 0.112 0.112 0.112 0.112 0.909 0.909
Q16 0.087 0.087 0.087 0.087 0.087 0.882 0.087 0.882
Q17 0.151 0.151 0.151 0.151 0.151 0.859 0.151 0.859
Q18 0.100 0.100 0.100 0.100 0.100 0.920 0.100 0.920
Q19 0.087 0.087 0.087 0.087 0.087 0.087 0.087 0.869
Q20 0.162 0.162 0.162 0.162 0.162 0.162 0.162 0.834
P matrix MAB = 0.0092
P matrix RMSE = 0.0116
# ── 5c. π: true vs. estimated (with 95% interval) ─────────────
pi_samples = trace.posterior["pi"].values.reshape(-1, L)
# 2.5th and 97.5th percentiles give an equal-tailed 95% credible interval (not an HDI)
pi_ci_low = np.percentile(pi_samples, 2.5, axis=0)
pi_ci_high = np.percentile(pi_samples, 97.5, axis=0)
df_pi = pd.DataFrame({
    "Profile": profile_labels,
    "true π": pi_true.round(4),
    "est π": pi_est.round(4),
    "SD": pi_samples.std(axis=0).round(4),
    "95% CI lower": pi_ci_low.round(4),
    "95% CI upper": pi_ci_high.round(4),
    "bias": (pi_est - pi_true).round(4),
})
print("\n" + "="*60)
print("[Attribute profile proportions π: true vs. estimated]")
print("="*60)
print(df_pi.to_string(index=False))
============================================================
[Attribute profile proportions π: true vs. estimated]
============================================================
Profile  true π   est π      SD  95% CI lower  95% CI upper     bias
    000    0.05  0.0606  0.0070        0.0475        0.0746   0.0106
    001    0.08  0.0837  0.0080        0.0692        0.1005   0.0037
    010    0.08  0.0880  0.0083        0.0726        0.1049   0.0080
    011    0.12  0.1000  0.0085        0.0839        0.1169  -0.0200
    100    0.10  0.0993  0.0081        0.0834        0.1152  -0.0007
    101    0.12  0.1184  0.0083        0.1031        0.1349  -0.0016
    110    0.15  0.1420  0.0092        0.1247        0.1608  -0.0080
    111    0.30  0.3081  0.0117        0.2849        0.3314   0.0081
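The interval bounds in this notebook come from `np.percentile`, which yields an equal-tailed interval. A highest density interval (HDI) is instead the narrowest interval holding the requested posterior mass, and the two differ when the posterior is skewed. The sketch below computes both on synthetic right-skewed draws; the `hdi` helper and the Beta(2, 30) sample are illustrative assumptions, and the helper presumes a unimodal posterior.

```python
import numpy as np

def hdi(samples, prob=0.95):
    """Narrowest interval containing `prob` of the samples (unimodal case)."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    m = int(np.ceil(prob * n))             # number of samples the interval must cover
    widths = x[m - 1:] - x[:n - m + 1]     # width of every window of m consecutive samples
    i = int(np.argmin(widths))
    return x[i], x[i + m - 1]

rng = np.random.default_rng(1)
samples = rng.beta(2, 30, size=4000)       # hypothetical right-skewed posterior draws

eq_low, eq_high = np.percentile(samples, [2.5, 97.5])
hdi_low, hdi_high = hdi(samples)

print(f"equal-tailed: [{eq_low:.4f}, {eq_high:.4f}]  width {eq_high - eq_low:.4f}")
print(f"HDI:          [{hdi_low:.4f}, {hdi_high:.4f}]  width {hdi_high - hdi_low:.4f}")
```

For roughly symmetric posteriors the two intervals nearly coincide; for parameters close to a boundary, such as small slip or guess values, they can differ noticeably.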
# ── 5d. slip/guess posterior quantiles (95% interval) ─────────
slip_samples = trace.posterior["slip"].values.reshape(-1, J)
guess_samples = trace.posterior["guess"].values.reshape(-1, J)
# Equal-tailed 95% credible intervals from the 2.5th and 97.5th percentiles (not HDIs)
slip_ci_low = np.percentile(slip_samples, 2.5, axis=0)
slip_ci_high = np.percentile(slip_samples, 97.5, axis=0)
guess_ci_low = np.percentile(guess_samples, 2.5, axis=0)
guess_ci_high = np.percentile(guess_samples, 97.5, axis=0)
df_ci = pd.DataFrame({
    "Item": [f"Q{j+1}" for j in range(J)],
    "slip true": slip_true.round(3),
    "slip est": slip_est.round(3),
    "slip 2.5%": slip_ci_low.round(3),
    "slip 97.5%": slip_ci_high.round(3),
    "guess true": guess_true.round(3),
    "guess est": guess_est.round(3),
    "guess 2.5%": guess_ci_low.round(3),
    "guess 97.5%": guess_ci_high.round(3),
})
print("\n" + "="*80)
print("[slip & guess: posterior means and 95% equal-tailed intervals]")
print("="*80)
print(df_ci.to_string(index=False))
================================================================================
[slip & guess: posterior means and 95% equal-tailed intervals]
================================================================================
Item  slip true  slip est  slip 2.5%  slip 97.5%  guess true  guess est  guess 2.5%  guess 97.5%
  Q1       0.10     0.110      0.090       0.131        0.15      0.167       0.134        0.201
  Q2       0.15     0.170      0.147       0.196        0.20      0.205       0.169        0.243
  Q3       0.10     0.112      0.092       0.132        0.10      0.128       0.098        0.158
  Q4       0.15     0.147      0.124       0.170        0.25      0.254       0.214        0.295
  Q5       0.10     0.100      0.081       0.120        0.15      0.165       0.132        0.204
  Q6       0.15     0.152      0.129       0.177        0.20      0.192       0.158        0.230
  Q7       0.10     0.101      0.082       0.121        0.10      0.120       0.091        0.150
  Q8       0.15     0.144      0.121       0.166        0.25      0.248       0.210        0.287
  Q9       0.10     0.105      0.086       0.127        0.15      0.147       0.117        0.180
 Q10       0.15     0.140      0.117       0.165        0.20      0.199       0.167        0.234
 Q11       0.10     0.088      0.070       0.108        0.10      0.105       0.080        0.132
 Q12       0.15     0.141      0.120       0.163        0.25      0.258       0.223        0.295
 Q13       0.10     0.105      0.082       0.130        0.10      0.100       0.081        0.122
 Q14       0.15     0.127      0.102       0.155        0.15      0.135       0.112        0.160
 Q15       0.10     0.091      0.069       0.115        0.10      0.112       0.091        0.135
 Q16       0.10     0.118      0.094       0.145        0.10      0.087       0.069        0.107
 Q17       0.15     0.141      0.115       0.170        0.15      0.151       0.128        0.176
 Q18       0.10     0.080      0.061       0.102        0.10      0.100       0.081        0.122
 Q19       0.10     0.131      0.100       0.163        0.10      0.087       0.070        0.105
 Q20       0.15     0.166      0.133       0.202        0.15      0.162       0.142        0.185
# ──────────────────────────────────────────────────────────────
# PART 6: Individual classification results (MAP and EAP)
# ──────────────────────────────────────────────────────────────
# ── 6a. Posterior γ: each examinee's posterior probability of each attribute profile ──
# Forward (plug-in) computation using the estimated slip_est, guess_est, and pi_est
log_gamma = np.zeros((N, L))
for l in range(L):
    # log P(Y_i | α_l)
    log_p_given_l = np.zeros(N)
    for j in range(J):
        if eta[l, j] == 1:
            p_jl = 1 - slip_est[j]
        else:
            p_jl = guess_est[j]
        log_p_given_l += (
            Y[:, j] * np.log(p_jl + 1e-10) +
            (1 - Y[:, j]) * np.log(1 - p_jl + 1e-10)
        )
    log_gamma[:, l] = np.log(pi_est[l] + 1e-10) + log_p_given_l
log_norm = logsumexp(log_gamma, axis=1, keepdims=True)
gamma = np.exp(log_gamma - log_norm)  # (N, L) posterior probabilities
# ── 6b. MAP classification ────────────────────────────────────
map_idx = np.argmax(gamma, axis=1)
map_profiles = all_alpha[map_idx]
# ── 6c. EAP: expected mastery probability of each attribute ───
eap = gamma @ all_alpha  # (N, K)
# ── 6d. Assemble the classification results DataFrame ─────────
df_class = pd.DataFrame({
    "true_profile": ["".join(map(str, all_alpha[i])) for i in alpha_idx],
    "MAP_profile": ["".join(map(str, p)) for p in map_profiles],
    "correct": alpha_idx == map_idx,
    "MAP_prob": gamma[np.arange(N), map_idx].round(3),
    "EAP_α1": eap[:, 0].round(3),
    "EAP_α2": eap[:, 1].round(3),
    "EAP_α3": eap[:, 2].round(3),
})
for l, label in enumerate(profile_labels):
    df_class[f"P({label})"] = gamma[:, l].round(3)
print("\n" + "="*60)
print("[Individual classification results (first 15 examinees)]")
print("="*60)
print(df_class.head(15).to_string(index=True))
# ── 6e. Overall classification accuracy ───────────────────────
acc = df_class["correct"].mean()
print(f"\n  MAP overall classification accuracy = {acc:.4f} ({acc*100:.1f}%)")
# ── 6f. Classification accuracy by attribute profile ──────────
print("\n  Classification accuracy by attribute profile:")
summary = df_class.groupby("true_profile").agg(
    n=("correct", "count"),
    n_correct=("correct", "sum"),
    accuracy=("correct", "mean"),
    mean_MAP_prob=("MAP_prob", "mean"),
).round(3)
print(summary.to_string())
# ── 6g. Confusion matrix ──────────────────────────────────────
conf = pd.crosstab(
    df_class["true_profile"],
    df_class["MAP_profile"],
    rownames=["True"], colnames=["Predicted"]
)
print("\n" + "="*60)
print("[Confusion matrix (rows = true, columns = predicted)]")
print("="*60)
print(conf.to_string())
============================================================
[Individual classification results (first 15 examinees)]
============================================================
true_profile MAP_profile correct MAP_prob EAP_α1 EAP_α2 EAP_α3 P(000) P(001) P(010) P(011) P(100) P(101) P(110) P(111)
0 010 010 True 0.972 0.000 0.980 0.008 0.019 0.000 0.972 0.008 0.000 0.000 0.0 0.0
1 110 110 True 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.0 0.0
2 101 101 True 0.984 0.991 0.000 0.993 0.000 0.009 0.000 0.000 0.007 0.984 0.0 0.0
3 111 111 True 1.000 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.0 1.0
4 111 111 True 1.000 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.0 1.0
5 011 011 True 0.906 0.000 0.999 0.908 0.000 0.001 0.092 0.906 0.000 0.000 0.0 0.0
6 011 011 True 0.941 0.000 0.942 0.999 0.000 0.058 0.001 0.941 0.000 0.000 0.0 0.0
7 111 111 True 1.000 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.0 1.0
8 111 111 True 1.000 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.0 1.0
9 111 111 True 1.000 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.0 1.0
10 100 100 True 0.951 0.951 0.000 0.002 0.047 0.002 0.000 0.000 0.951 0.000 0.0 0.0
11 101 101 True 1.000 1.000 0.000 1.000 0.000 0.000 0.000 0.000 0.000 1.000 0.0 0.0
12 110 110 True 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.0 0.0
13 111 111 True 1.000 1.000 1.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.0 1.0
14 100 100 True 0.979 0.979 0.000 0.000 0.020 0.000 0.000 0.000 0.979 0.000 0.0 0.0
MAP overall classification accuracy = 0.9560 (95.6%)
Classification accuracy by attribute profile:
n n_correct accuracy mean_MAP_prob
true_profile
000 88 77 0.875 0.866
001 127 112 0.882 0.883
010 127 119 0.937 0.922
011 153 135 0.882 0.910
100 145 139 0.959 0.938
101 180 176 0.978 0.969
110 218 214 0.982 0.972
111 462 462 1.000 0.996
============================================================
[Confusion matrix (rows = true, columns = predicted)]
============================================================
Predicted 000 001 010 011 100 101 110 111
True
000 77 2 3 0 5 1 0 0
001 7 112 0 6 1 1 0 0
010 4 0 119 2 0 0 2 0
011 0 11 6 135 1 0 0 0
100 2 1 0 0 139 2 1 0
101 0 2 0 0 0 176 0 2
110 0 1 2 0 1 0 214 0
111 0 0 0 0 0 0 0 462
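The confusion matrix also supports attribute-level accuracy, which the notebook does not report. The sketch below copies the counts from the table above and derives both the pattern-level and the per-attribute agreement rates; the variable names are illustrative.

```python
import itertools
import numpy as np

profiles = np.array(list(itertools.product([0, 1], repeat=3)))   # rows 000 ... 111
# Counts copied from the confusion matrix (rows = true, columns = predicted)
conf = np.array([
    [77,   2,   3,   0,   5,   1,   0,   0],
    [ 7, 112,   0,   6,   1,   1,   0,   0],
    [ 4,   0, 119,   2,   0,   0,   2,   0],
    [ 0,  11,   6, 135,   1,   0,   0,   0],
    [ 2,   1,   0,   0, 139,   2,   1,   0],
    [ 0,   2,   0,   0,   0, 176,   0,   2],
    [ 0,   1,   2,   0,   1,   0, 214,   0],
    [ 0,   0,   0,   0,   0,   0,   0, 462],
])
n_total = conf.sum()

# Pattern-level accuracy: the whole profile must match
pattern_acc = np.trace(conf) / n_total

# match[t, p, k] is True when true profile t and predicted profile p agree on attribute k
match = profiles[:, None, :] == profiles[None, :, :]
attr_acc = (conf[:, :, None] * match).sum(axis=(0, 1)) / n_total   # (K,)

print(f"pattern-level accuracy:   {pattern_acc:.4f}")
print("attribute-level accuracy:", attr_acc.round(4))
```

Attribute-level accuracy is never below pattern-level accuracy, because a correctly classified profile is correct on every attribute while a misclassified one may still match on some of them.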