pyensmallen + JAX

import jax
import jax.numpy as jnp
import numpy as np
import pyensmallen
import time
import optax

# Set random seed for reproducibility (all data below come from NumPy's global RNG)
np.random.seed(0)
# JAX PRNG key; nothing shown in this notebook consumes it.
# NOTE(review): drop it if no other cell uses it.
key = jax.random.PRNGKey(0)
# Set the parameters
K = 4  # number of classes
D = 10  # number of features
N = 10_000  # number of samples

# Generate true coefficients (K categories, last category is reference with zeros)
true_coeffs = np.random.randn(D, K)
true_coeffs[:, -1] = 0  # Set last category coefficients to zero

# Generate features
X = np.random.randn(N, D)

# Generate probabilities and labels
logits = X @ true_coeffs
# Subtract the per-row maximum before exponentiating so the softmax cannot
# overflow for large scores; the probabilities are mathematically unchanged.
exp_shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp_shifted / exp_shifted.sum(axis=1, keepdims=True)
# One categorical draw per row. The per-row np.random.choice loop is kept
# (rather than a vectorized sampler) so the draws follow the same RNG sequence.
y = np.array([np.random.choice(K, p=p) for p in probs])

# Convert data to JAX arrays
X_jax = jax.device_put(X)
y_jax = jax.device_put(y)
# Define the multinomial logistic regression model
def multinomial_logit(params, X):
    """Return per-class log-probabilities for a multinomial logit model.

    The last class is the reference category. Its coefficients are pinned to
    zero, so only ``n_features * (n_classes - 1)`` free parameters exist.

    Args:
        params: Free coefficients. Either flat with length
            ``X.shape[1] * (n_classes - 1)``, or already shaped
            ``(X.shape[1], n_classes - 1)``.
        X: Feature matrix of shape ``(n_samples, n_features)``.

    Returns:
        Array of shape ``(n_samples, n_classes)`` holding log-softmax scores.
        These are log-probabilities, not raw logits.
    """
    n_features = X.shape[1]
    # Infer the layout from the inputs instead of the module-level D and K,
    # so the model works for any number of features and classes.
    free = jnp.reshape(params, (n_features, -1))
    reference = jnp.zeros((n_features, 1), dtype=free.dtype)
    full_params = jnp.concatenate([free, reference], axis=1)
    return jax.nn.log_softmax(X @ full_params, axis=1)


# Loss: average negative log-likelihood of the observed labels
def loss(params, X, y):
    """Mean negative log-likelihood of labels ``y`` under the model.

    Args:
        params: Free coefficients, in whatever form ``multinomial_logit`` accepts.
        X: Feature matrix, one row per sample.
        y: Integer class label for each row of ``X``.

    Returns:
        Scalar: the negated mean of each sample's log-probability for its
        observed class.
    """
    log_probs = multinomial_logit(params, X)
    n_samples = y.shape[0]
    # Pick, for every row, the log-probability of that row's observed label.
    observed = log_probs[jnp.arange(n_samples), y]
    return -observed.mean()


# Reverse-mode autodiff gradient of `loss` w.r.t. its first argument (params).
grad_loss = jax.grad(loss, argnums=0)


# Define the objective function for pyensmallen.
# A single compiled evaluator returns the loss and its gradient in one
# forward/backward pass, instead of separate `loss` and `grad_loss` calls.
_loss_and_grad = jax.jit(jax.value_and_grad(loss))


def objective(params, gradient, X, y):
    """Evaluate the loss and write its gradient in place, for pyensmallen.

    Args:
        params: Current parameter vector supplied by the optimizer.
        gradient: Pre-allocated array, overwritten in place with the gradient
            (flattened to match ``params``).
        X: Feature matrix. Used as given; the old version silently read the
            module-level ``X_jax`` instead.
        y: Integer class labels. Used as given; the old version silently
            read ``y_jax`` instead.

    Returns:
        The loss value as a Python float.
    """
    loss_value, grad = _loss_and_grad(jax.device_put(params), X, y)
    # The optimizer reads the gradient back from this buffer, so fill it in place.
    gradient[:] = np.asarray(grad).reshape(-1)
    return float(loss_value)


# --- L-BFGS via pyensmallen, driven by the JAX loss/gradient ---
start_time = time.time()
optimizer = pyensmallen.L_BFGS()
# Random flat starting point; the Optax run below reuses it.
initial_params = np.random.randn(D * (K - 1))


def _ens_objective(params, gradient):
    # Bind the training data so the optimizer sees a (params, gradient) callback.
    return objective(params, gradient, X_jax, y_jax)


result_ens = optimizer.optimize(_ens_objective, initial_params)
ens_time = time.time() - start_time
# Re-attach the zero reference-class column to get a (D, K) coefficient matrix.
free_coeffs_ens = result_ens.reshape(D, K - 1)
estimated_coeffs_ens = np.hstack([free_coeffs_ens, np.zeros((D, 1))])

JAX

# --- First-order baseline: Adam via Optax ---
start_time = time.time()
# Same starting point as the L-BFGS run, reshaped to (D, K - 1).
initial_params = jnp.asarray(initial_params.reshape(D, K - 1))

# Adam with a fixed step size of 0.01.
optimizer = optax.adam(0.01)
opt_state = optimizer.init(initial_params)


@jax.jit
def step(params, opt_state, X, y):
    """Apply one optimizer update to ``params``.

    Reads the module-level ``optimizer`` when the function is traced by
    ``jax.jit``.

    Returns:
        ``(new_params, new_opt_state, loss)``, where ``loss`` is evaluated at
        the incoming ``params`` (before the update).
    """
    value_and_grad_fn = jax.value_and_grad(loss)
    current_loss, gradients = value_and_grad_fn(params, X, y)
    updates, new_opt_state = optimizer.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, current_loss


# Run a fixed number of Adam steps; per-step loss values are discarded.
n_adam_steps = 2000
params = initial_params
for _step in range(n_adam_steps):
    params, opt_state, _ = step(params, opt_state, X_jax, y_jax)

# Re-attach the zero reference-class column to get a (D, K) coefficient matrix.
estimated_coeffs_jax = jnp.column_stack([params, jnp.zeros((D, 1))])
# JAX dispatches work asynchronously. Wait for the result to actually be
# computed, so the timing covers the optimization and not just the dispatch.
estimated_coeffs_jax.block_until_ready()
jax_time = time.time() - start_time

Comparison

# Number of entries in the flattened (D, K) coefficient matrix, including the zero reference column.
true_coeffs.reshape(-1).shape
(40,)
# Columns: true, pyensmallen (L-BFGS), and Optax (Adam) coefficients. Rows are the (D, K) matrices flattened in row-major order, so every K-th row is the zero reference class.
np.c_[true_coeffs.reshape(-1), estimated_coeffs_ens.reshape(-1), estimated_coeffs_jax.reshape(-1)]
array([[ 1.76405235,  1.8157463 ,  1.8109074 ],
       [ 0.40015721,  0.42656764,  0.42548736],
       [ 0.97873798,  0.98794635,  0.98618829],
       [ 0.        ,  0.        ,  0.        ],
       [ 1.86755799,  1.8628787 ,  1.8578708 ],
       [-0.97727788, -0.99718906, -0.99704143],
       [ 0.95008842,  0.91213478,  0.91011453],
       [ 0.        ,  0.        ,  0.        ],
       [-0.10321885, -0.09043247, -0.09009397],
       [ 0.4105985 ,  0.40371903,  0.40340355],
       [ 0.14404357,  0.1855751 ,  0.18550714],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.76103773,  0.80953626,  0.8079425 ],
       [ 0.12167502,  0.16927964,  0.16907167],
       [ 0.44386323,  0.42807023,  0.42753084],
       [ 0.        ,  0.        ,  0.        ],
       [ 1.49407907,  1.58021527,  1.57713197],
       [-0.20515826, -0.13471932, -0.13484907],
       [ 0.3130677 ,  0.33451144,  0.33372573],
       [ 0.        ,  0.        ,  0.        ],
       [-2.55298982, -2.57082937, -2.5637134 ],
       [ 0.6536186 ,  0.69179848,  0.6927646 ],
       [ 0.8644362 ,  0.86859013,  0.86924212],
       [ 0.        ,  0.        ,  0.        ],
       [ 2.26975462,  2.34909469,  2.34216164],
       [-1.45436567, -1.47021206, -1.47031831],
       [ 0.04575852, -0.01799869, -0.01972243],
       [ 0.        ,  0.        ,  0.        ],
       [ 1.53277921,  1.63790857,  1.63511849],
       [ 1.46935877,  1.56287992,  1.56128166],
       [ 0.15494743,  0.16244727,  0.16245115],
       [ 0.        ,  0.        ,  0.        ],
       [-0.88778575, -0.86998273, -0.86919838],
       [-1.98079647, -1.94830402, -1.94675255],
       [-0.34791215, -0.34480436, -0.34499855],
       [ 0.        ,  0.        ,  0.        ],
       [ 1.23029068,  1.25129001,  1.24912269],
       [ 1.20237985,  1.17328927,  1.17191307],
       [-0.38732682, -0.41164246, -0.41112207],
       [ 0.        ,  0.        ,  0.        ]])
# Compare results
print("Pyensmallen optimization time:", ens_time)
print("JAX optimization time:", jax_time)

# The reference-class column is pinned to zero in both the truth and the
# estimates. Including it would only dilute the error by (K - 1) / K, so the
# MAE is taken over the D * (K - 1) free coefficients only.
free_true = true_coeffs[:, :-1]
mae_ens = float(np.mean(np.abs(free_true - estimated_coeffs_ens[:, :-1])))
# Convert the JAX result to NumPy so both errors are plain Python floats.
mae_jax = float(np.mean(np.abs(free_true - np.asarray(estimated_coeffs_jax)[:, :-1])))

print("\nPyensmallen Mean Absolute Error:", mae_ens)
print("JAX Mean Absolute Error:", mae_jax)
Pyensmallen optimization time: 0.8505103588104248
JAX optimization time: 1.1517069339752197

Pyensmallen Mean Absolute Error: 0.026351136330546497
JAX Mean Absolute Error: 0.02586046071016848
def predict(coeffs, X):
    """Return the index of the highest-scoring class for each row of ``X``.

    Softmax is monotone, so taking the argmax of the linear scores gives the
    same class as taking the argmax of the probabilities.
    """
    scores = X @ coeffs
    predicted = np.argmax(scores, axis=1)
    return predicted


# In-sample accuracy of the argmax predictions. float() normalizes the NumPy
# and JAX scalars, so identical accuracies print identically (a JAX float32
# scalar otherwise prints as 0.76669997 where NumPy prints 0.7667).
accuracy_ens = float(np.mean(predict(estimated_coeffs_ens, X) == y))
accuracy_jax = float(np.mean(predict(estimated_coeffs_jax, X) == y))

print("\nPyensmallen Accuracy:", accuracy_ens)
print("JAX Accuracy:", accuracy_jax)

# Training loss at each solution, evaluated on the free (non-reference) columns.
final_loss_ens = float(loss(jax.device_put(estimated_coeffs_ens[:, :-1]), X_jax, y_jax))
final_loss_jax = float(loss(estimated_coeffs_jax[:, :-1], X_jax, y_jax))

print("\nPyensmallen Final Loss:", final_loss_ens)
print("JAX Final Loss:", final_loss_jax)

Pyensmallen Accuracy: 0.7667
JAX Accuracy: 0.76669997

Pyensmallen Final Loss: 0.5785731153707537
JAX Final Loss: 0.5785740845996223