import jax
import jax.numpy as jnp
import numpy as np
import pyensmallen
import time
import optax
# Reproducibility: seed NumPy (used for data generation) and JAX.
# Bug fix: the original line had the markdown header "pyensmallen + jax"
# fused onto the end of PRNGKey(0), which is a syntax error.
np.random.seed(0)
key = jax.random.PRNGKey(0)

# Problem size.
K = 4  # number of classes
D = 10  # number of features
N = 10_000  # number of samples

# True coefficients (D x K). The last class is the reference with all-zero
# weights, which makes the multinomial-logit parameterization identifiable.
true_coeffs = np.random.randn(D, K)
true_coeffs[:, -1] = 0  # Set last category coefficients to zero

# Design matrix.
X = np.random.randn(N, D)

# Class probabilities via row-wise softmax of the true logits, then sample
# one label per row from its categorical distribution.
logits = X @ true_coeffs
probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
y = np.array([np.random.choice(K, p=p) for p in probs])

# Move the data onto the default JAX device once, up front.
X_jax = jax.device_put(X)
y_jax = jax.device_put(y)
def multinomial_logit(params, X):
    """Return per-sample log-probabilities: log-softmax of X @ B.

    `params` holds the free coefficients for the first K - 1 classes; the
    reference class gets an all-zero column appended before the softmax,
    matching the identification convention used for `true_coeffs`.
    """
    free = params.reshape(D, K - 1)
    beta = jnp.column_stack([free, jnp.zeros((D, 1))])
    return jax.nn.log_softmax(X @ beta, axis=1)
# Loss: mean negative log-likelihood of the observed labels.
def loss(params, X, y):
    """Mean negative log-likelihood for multinomial logistic regression."""
    log_probs = multinomial_logit(params, X)
    # Pick each row's log-probability at its observed label.
    picked = jnp.take_along_axis(log_probs, y[:, None], axis=1)
    return -jnp.mean(picked)

# Gradient of the loss w.r.t. params via JAX autodiff.
grad_loss = jax.grad(loss)
# Objective in the callback form pyensmallen expects: f(params, gradient) -> float,
# writing the gradient into `gradient` in place.
def objective(params, gradient, X, y):
    """Evaluate loss and in-place gradient for pyensmallen.

    Args:
        params: flat NumPy array of the D * (K - 1) free coefficients.
        gradient: preallocated NumPy array; overwritten with d(loss)/d(params).
        X, y: design matrix and labels (JAX arrays).

    Returns:
        The loss as a Python float.
    """
    params_jax = jax.device_put(params.reshape(D, K - 1))
    # Bug fix: use the X/y actually passed in; the original ignored these
    # parameters and silently read the module-level globals X_jax/y_jax.
    loss_value = loss(params_jax, X, y)
    grad = grad_loss(params_jax, X, y)
    gradient[:] = np.array(grad).flatten()
    return float(loss_value)
# --- Optimization with pyensmallen (L-BFGS) ---
# Bug fix: the original last line had the markdown header "Jax" fused onto
# the end of the np.zeros call, which is a syntax error.
start_time = time.time()
optimizer = pyensmallen.L_BFGS()
initial_params = np.random.randn(D * (K - 1))
result_ens = optimizer.optimize(
    lambda params, gradient: objective(params, gradient, X_jax, y_jax), initial_params
)
ens_time = time.time() - start_time
# Rebuild the full D x K coefficient matrix (reference-class column = zeros).
estimated_coeffs_ens = np.column_stack([result_ens.reshape(D, K - 1), np.zeros((D, 1))])
# JAX optimization with Optax
# Time the Optax run end-to-end (setup plus the training loop below).
start_time = time.time()
# Warm-start from the same random point as pyensmallen, reshaped to (D, K - 1).
initial_params = jnp.array(initial_params.reshape(D, K - 1))
# Define the Optax optimizer (using Adam as an example)
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(initial_params)
@jax.jit
def step(params, opt_state, X, y):
    """One jitted Adam step: returns (new_params, new_opt_state, loss)."""
    loss_value, grads = jax.value_and_grad(loss)(params, X, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss_value
# Run 2000 fixed Adam steps (no early stopping), then rebuild the full matrix.
# Bug fix: the original last line had the markdown header "comparison" fused
# onto `start_time`, which is a syntax error.
params = initial_params
for _ in range(2000):
    params, opt_state, _ = step(params, opt_state, X_jax, y_jax)
estimated_coeffs_jax = jnp.column_stack([params, jnp.zeros((D, 1))])
jax_time = time.time() - start_time
# --- Coefficient comparison ---
# Bug fix: this section was a notebook cell with its printed OUTPUT (the
# `(40,)` shape and the 40x3 array dump) pasted into the source, producing
# syntax errors. The code below is the cell's actual computation; columns
# are: true coefficients, pyensmallen estimate, JAX/Optax estimate, each
# flattened to length D * K = 40.
comparison = np.c_[
    true_coeffs.reshape(-1),
    estimated_coeffs_ens.reshape(-1),
    estimated_coeffs_jax.reshape(-1),
]
print(comparison)
# Compare results: wall-clock time and parameter-recovery error (MAE against
# the true coefficients) for both optimizers.
# Bug fix: the original had the cell's printed output ("Pyensmallen
# optimization time: 0.85...") pasted onto the last print line, which is a
# syntax error.
print("Pyensmallen optimization time:", ens_time)
print("JAX optimization time:", jax_time)
mae_ens = np.mean(np.abs(true_coeffs - estimated_coeffs_ens))
mae_jax = np.mean(np.abs(true_coeffs - estimated_coeffs_jax))
print("\nPyensmallen Mean Absolute Error:", mae_ens)
print("JAX Mean Absolute Error:", mae_jax)
def predict(coeffs, X):
    """Return the highest-scoring class index for each row of X."""
    scores = X @ coeffs
    return np.argmax(scores, axis=1)
# Training-set accuracy (same data used for fitting) for both estimates.
accuracy_ens = np.mean(predict(estimated_coeffs_ens, X) == y)
accuracy_jax = np.mean(predict(estimated_coeffs_jax, X) == y)
print("\nPyensmallen Accuracy:", accuracy_ens)
print("JAX Accuracy:", accuracy_jax)
# Final training loss; drop the reference-class column because `loss`
# expects only the D * (K - 1) free coefficients.
final_loss_ens = loss(jax.device_put(estimated_coeffs_ens[:, :-1]), X_jax, y_jax)
final_loss_jax = loss(estimated_coeffs_jax[:, :-1], X_jax, y_jax)
print("\nPyensmallen Final Loss:", final_loss_ens)
print("JAX Final Loss:", final_loss_jax)
# Example output from a full run (pasted notebook output, kept as a comment):
#   Pyensmallen Accuracy: 0.7667
#   JAX Accuracy: 0.76669997
#   Pyensmallen Final Loss: 0.5785731153707537
#   JAX Final Loss: 0.5785740845996223