Multinomial Logit Example

This page mirrors examples/multinomial_logit_example.py.

1 Fit A Multiclass Logit Model

import numpy as np
from pprint import pprint

from crabbymetrics import MultinomialLogit

np.set_printoptions(precision=4, suppress=True)


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=1, keepdims=True)
    exps = np.exp(x)
    return exps / exps.sum(axis=1, keepdims=True)


rng = np.random.default_rng(3)
n = 1000
k = 3
c = 3
coef = np.array(
    [
        [1.0, -0.5, 0.2],
        [-0.7, 0.9, -0.4],
        [0.2, -0.3, 0.8],
    ]
)
intercept = np.array([0.3, -0.2, 0.0])

x = rng.normal(size=(n, k))
logits = x @ coef.T + intercept
probs = softmax(logits)
y = np.array([rng.choice(c, p=probs[i]) for i in range(n)], dtype=np.int32)

model = MultinomialLogit(alpha=1.0, max_iterations=200)
model.fit(x, y)

print("true intercept:", intercept)
print("true coef:", coef)
pprint(model.summary())

Output:

true intercept: [ 0.3 -0.2  0. ]
true coef: [[ 1.  -0.5  0.2]
 [-0.7  0.9 -0.4]
 [ 0.2 -0.3  0.8]]
{'coef': array([[ 0.2477,  0.8949, -0.5722,  0.0565],
       [-0.3414, -0.9238,  0.7843, -0.6401],
       [ 0.0938,  0.0289, -0.2121,  0.5837]]),
 'se': array([[5773.5047, 5773.5124, 5773.4669, 5773.4813],
       [5773.5047, 5773.5124, 5773.4669, 5773.4813],
       [5773.5047, 5773.5124, 5773.4669, 5773.4813]])}