{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MDI 720 : Statistiques\n", "## Lasso\n", "### *Joseph Salmon*\n", "\n", "This notebook reproduces the pictures for the course \"Lasso_fr\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt # for plots\n", "from matplotlib import rc\n", "from os import mkdir, path\n", "from functools import partial # functions that act on or return other function\n", "from functions_Lasso import LSLassoCV, PredictionError, \\\n", " ScenarioEquiCor, ridge_path, refitting, my_nonzeros\n", "from sklearn.linear_model import LassoCV, RidgeCV, ElasticNetCV, Lasso, \\\n", " lasso_path\n", "from sklearn.linear_model import enet_path\n", "from matplotlib.patches import Polygon, Circle\n", "from prox_collection import l22_objective, l1_objective, l0_objective, \\\n", " scad_objective, mcp_objective, log_objective, sqrt_objective, \\\n", " enet_objective, l22_pen, l1_pen, l0_pen, scad_pen,\\\n", " mcp_pen, log_pen, sqrt_pen, enet_pen\n", "\n", "np.random.seed(seed=666)\n", "\n", "%matplotlib notebook" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "dirname = \"../prebuiltimages/\"\n", "if not path.exists(dirname):\n", " mkdir(dirname)\n", "\n", "imageformat = '.pdf'\n", "rc('font', **{'family': 'sans-serif', 'sans-serif': ['Computer Modern Roman']})\n", "params = {'axes.labelsize': 12,\n", " 'font.size': 16,\n", " 'legend.fontsize': 16,\n", " 'text.usetex': True,\n", " 'figure.figsize': (8, 6)}\n", "plt.rcParams.update(params)\n", "plt.close(\"all\")\n", "\n", "sns.set_context(\"poster\")\n", "sns.set_palette(\"colorblind\")\n", "sns.set_style(\"white\")\n", "sns.axes_style()\n", "\n", "\n", "###############################################################################\n", "# display function:\n", "\n", 
"saving = False\n", "\n", "\n", "def my_saving_display(fig, dirname, filename, imageformat, saving=False):\n", " \"\"\"\"Saving with personal function.\"\"\"\n", " filename = filename.replace('.', 'pt') # remove \".\" to avoid floats issues\n", " if saving is True:\n", " dirname + filename + imageformat\n", " image_name = dirname + filename + imageformat\n", " fig.savefig(image_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "\n", "###############################################################################\n", "# Motivation for Lasso: projection on balls / squares:\n", "\n", "def funct_quadratic(X, Y):\n", " \"\"\" quadratic function to be displayed\"\"\"\n", " X1 = 1.5\n", " X2 = 1.9\n", " theta = -np.pi / 3 - 0.1\n", " c = np.cos(theta)\n", " s = np.sin(theta)\n", " elong = 0.3\n", " return ((c * (X - X1) + s * (Y - X2)) ** 2 +\n", " elong * (-s * (X - X1) + c * (Y - X2)) ** 2)\n", "\n", "spacing = funct_quadratic(0, 1) / 6 * np.arange(70) # for level lines\n", "Y, X = np.mgrid[-1.2:2.1:100j, -1.2:2.1:100j]\n", "\n", "\n", "# Projection over a ball\n", "fig3 = plt.figure(figsize=(6, 6))\n", "ax = plt.subplot(111)\n", "ax.get_yaxis().set_ticks([])\n", "ax.get_xaxis().set_ticks([])\n", "plt.axhline(0, color='white', zorder=4)\n", "plt.axvline(0, color='white', zorder=4)\n", "pt = 0.37\n", "plt.plot(pt, np.sqrt(1 - pt ** 2), '.', markersize=20,\n", " color=(1, 0., 0.), zorder=5)\n", "plt.contour(X, Y, funct_quadratic(X, Y), spacing, colors='k', linewidths=1.2)\n", "plt.contourf(X, Y, funct_quadratic(X, Y), spacing, alpha=.75, cmap=plt.cm.hot)\n", "circle1 = Circle((0, 0), 1, color=(0, 178. / 255, 236. 
###############################################################################
# Lasso: coefficient path on a fixed grid, then choice of lambda by CV

n_features = 40
n_samples = 60

y, theta_true, X = ScenarioEquiCor(n_samples=n_samples,
                                   n_features=n_features,
                                   sig_noise=1,
                                   rho=0.5, s=5, normalize=True)

alpha_max = 1e1
eps = 1e-3
n_alphas = 50  # grid size
# log-spaced grid going from alpha_max down to alpha_max * eps
alphas = np.logspace(np.log10(alpha_max), np.log10(alpha_max * eps),
                     num=n_alphas)
_, theta_lasso, _ = lasso_path(X, y, alphas=alphas, fit_intercept=False,
                               return_models=False)

# plot lasso path
fig1 = plt.figure(figsize=(12, 8))
# Create the Axes first and title *that* Axes: calling plt.title() before
# add_subplot() puts the title on an implicit Axes that is not the one drawn on.
ax1 = fig1.add_subplot(111)
ax1.set_title("Lasso path: " + r"$p={0}, n={1} $".format(n_features,
              n_samples), fontsize=16)
ax1.plot(alphas, np.transpose(theta_lasso), linewidth=3)
ax1.set_xscale('log')
ax1.set_xlabel(r"$\lambda$")
ax1.set_ylabel("Coefficient value")
ax1.set_ylim([-1, 3])
sns.despine()
plt.show()
my_saving_display(fig1, dirname, "Lasso_path", imageformat, saving=saving)

# nb of folds for CV
CV = 5
clf = LassoCV(alphas=alphas, fit_intercept=False, normalize=False, cv=CV)
clf.fit(X, y)

coef_lasso = clf.coef_
alpha_CV = clf.alpha_
# Closest grid point rather than exact float equality, which would raise an
# IndexError if the selected alpha differed from the grid by a rounding error.
index_lasso = int(np.argmin(np.abs(alphas - alpha_CV)))

# plot lasso path with CV choice
# ('k', lowercase: uppercase single-letter colors are not accepted by recent
# Matplotlib releases)
ax1.axvline(alpha_CV, color='k', linestyle='-', linewidth=3)
ax1.annotate(r"$CV={0}$".format(CV), xy=(1.2 * alpha_CV, +0.8),
             xycoords='data', xytext=(0, 80), textcoords='offset points',
             fontsize=18)
my_saving_display(fig1, dirname, "Lasso_path_CV", imageformat, saving=saving)


# Support and support size:
support_lasso_cv = my_nonzeros(coef_lasso)
print('Sparsity level for LassoCV: ' + str(support_lasso_cv.shape[0]))
"index_lslasso = index_lslasso[0]\n", "ax1.axvline(alpha_LSCV, color='K', linestyle='-', linewidth=3)\n", "plt.annotate(r\"$CV={0}$\".format(CV), xy=(1.2 * alphas[index_lslasso], 1.6),\n", " xycoords='data', xytext=(-80, 80), textcoords='offset points',\n", " fontsize=18)\n", "plt.show()\n", "my_saving_display(fig1, dirname, \"LSLasso_path_CV\", imageformat)\n", "\n", "\n", "# Support and support size:\n", "support_lslasso_cv = my_nonzeros(coef_lslasso)\n", "print('Sparsity level for LSLassoCV: ' + str(support_lslasso_cv.shape[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "###############################################################################\n", "# signal with Lasso and LSLasso: illustrating the Lasso bias\n", "\n", "fig5 = plt.figure(figsize=(10, 6))\n", "ax5 = fig5.add_subplot(111)\n", "ax5.plot(theta_true, 'k', label=\"True signal\")\n", "ax5.set_ylim([0, 1.4])\n", "ax5.plot(theta_lasso[:, index_lslasso], '--', label=\"Lasso\")\n", "plt.title(\" Signal estimation: \" + r\"$p={0}, n={1} $\".format(n_features,\n", " n_samples), fontsize=16)\n", "plt.legend()\n", "sns.despine()\n", "plt.show()\n", "my_saving_display(fig5, dirname, \"Estimation_signal\", imageformat)\n", "\n", "ax5.plot(theta_lslasso[:, index_lslasso], ':', label=\"LSLasso\")\n", "plt.legend()\n", "plt.show()\n", "my_saving_display(fig5, dirname, \"Estimation_signal_withLS\", imageformat)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "###############################################################################\n", "# Elastic Net\n", "\n", "def enet_plot(l1_ratio):\n", " \"\"\"Function plotting enet_path for some tuning parameter.\"\"\"\n", " _, theta_enet, _ = enet_path(X, y, alphas=alphas, fit_intercept=False,\n", " l1_ratio=l1_ratio, return_models=False)\n", " fig1 = plt.figure(figsize=(12, 8))\n", " plt.title(\"Enet path: \" + r\"$p={0}, 
n={1} $\".format(n_features,\n", " n_samples), fontsize=16)\n", " ax1 = fig1.add_subplot(111)\n", " ax1.plot(alphas, np.transpose(theta_enet), linewidth=3)\n", " ax1.set_xscale('log')\n", " ax1.set_xlabel(r\"$\\lambda$\")\n", " ax1.set_ylabel(\"Coefficient value\")\n", " ax1.set_ylim([-1, 2])\n", " sns.despine()\n", " plt.show()\n", " filename = \"Enet_path\" + str(l1_ratio)\n", " filename = filename.replace(\".\", \"\")\n", " my_saving_display(fig1, dirname, filename, imageformat)\n", " return theta_enet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet1 = enet_plot(1.00)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet099 = enet_plot(0.99)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet095 = enet_plot(0.95)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet090 = enet_plot(0.90)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet075 = enet_plot(0.75)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet05 = enet_plot(0.50)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet025 = enet_plot(0.25)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet010 = enet_plot(0.10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet005 = enet_plot(0.05)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "theta_enet001 = 
###############################################################################
# Ridge

alpha_max = 1e4
eps = 1e-9
alphas_ridge = np.logspace(np.log10(alpha_max), np.log10(alpha_max * eps),
                           num=n_alphas)
theta_ridge = ridge_path(X, y, alphas_ridge)
# Cross-validate over the ridge grid (the one the path below is drawn on),
# not over the Lasso grid `alphas`.
# NOTE(review): this changes the selected alpha compared with the previous
# RidgeCV(alphas=alphas, ...) call -- confirm the Lasso grid was not intended.
clf_ridge = RidgeCV(alphas=alphas_ridge, fit_intercept=False, normalize=False,
                    cv=CV)
clf_ridge.fit(X, y)

fig1 = plt.figure(figsize=(12, 8))
ax1 = fig1.add_subplot(111)
ax1.set_title("Ridge path: " + r"$p={0}, n={1} $".format(n_features,
              n_samples), fontsize=16)
ax1.plot(alphas_ridge, np.transpose(theta_ridge), linewidth=3)
ax1.set_xscale('log')
ax1.set_xlabel(r"$\lambda$")
ax1.set_ylabel("Coefficient value")
ax1.set_ylim([-1, 3])
sns.despine()
plt.show()
my_saving_display(fig1, dirname, "Ridge_path", imageformat, saving=saving)

# plot ridge path with CV choice
# ('k', lowercase: uppercase single-letter colors are not accepted by recent
# Matplotlib releases)
ax1.axvline(clf_ridge.alpha_, color='k', linestyle='-', linewidth=3)
ax1.annotate(r"$CV={0}$".format(CV), xy=(1.2 * clf_ridge.alpha_, 1.8),
             xycoords='data', xytext=(-80, 80), textcoords='offset points',
             fontsize=18)
plt.show()
my_saving_display(fig1, dirname, "Ridge_path_CV", imageformat, saving=saving)

###############################################################################
# plot prediction error

fig1 = plt.figure(figsize=(12, 8))
ax1 = fig1.add_subplot(111)
ax1.set_title("Prediction Error: " + r"$p={0}, n={1} $".format(n_features,
              n_samples), fontsize=16)
ax1.plot(alphas, PredictionError(X, theta_lasso, theta_true), linewidth=3,
         label="Lasso")
ax1.plot(alphas, PredictionError(X, theta_lslasso, theta_true), linewidth=3,
         label="LSLasso")

ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlabel(r"$\lambda$")
ax1.set_ylabel("Prediction Error")

# errors at the two cross-validated grid points
LassoCVMSE = PredictionError(X, theta_lasso[:, index_lasso], theta_true)
LSLassoCVMSE = PredictionError(X, theta_lslasso[:, index_lslasso], theta_true)

current_palette = sns.color_palette()

ax1.scatter(alpha_CV, LassoCVMSE, color=current_palette[0], linewidth=10)
ax1.scatter(alpha_LSCV, LSLassoCVMSE, color=current_palette[1], linewidth=10)

ax1.axvline(alpha_CV, linestyle='--', color=current_palette[0], linewidth=3)
ax1.axvline(alpha_LSCV, linestyle='--', color=current_palette[1], linewidth=3)

ax1.annotate('CV-Lasso', xy=(alpha_CV, LassoCVMSE), xycoords='data',
             color=current_palette[0], xytext=(-74, -50),
             textcoords='offset points', fontsize=18)
ax1.annotate('CV-LSLasso', xy=(alpha_LSCV, LSLassoCVMSE), xycoords='data',
             color=current_palette[1], xytext=(20, -20),
             textcoords='offset points', fontsize=18)
ax1.annotate('CV-Ridge', xy=(alphas[0] - 6, 0.1), xycoords='data',
             color='k', xytext=(0, 80), textcoords='offset points',
             fontsize=18)

ax1.set_xlim(alphas[-1], alphas[0])
# the ridge error does not depend on the Lasso lambda: horizontal reference
ax1.axhline(PredictionError(X, clf_ridge.coef_, theta_true), linestyle='-.',
            color='k', linewidth=3)
sns.despine()
plt.show()
my_saving_display(fig1, dirname, "various_path_prediction_error", imageformat,
                  saving=saving)
def sqr_abs(w):
    """Square root of the absolute value: the adaptive lasso penalty.

    Parameters
    ----------
    w : float or ndarray

    Returns
    -------
    float or ndarray
        ``sqrt(|w|)``, element-wise.
    """
    return np.sqrt(np.abs(w))


def sqr_abs_prime(w):
    """Derivative of ``sqr_abs`` in ``|w|``, used as reweighting factor.

    Computes ``1 / (2 * sqrt(|w|) + eps)`` element-wise, where ``eps`` is the
    float machine epsilon: the added ``eps`` keeps the value finite (although
    very large) at ``w == 0``.

    Parameters
    ----------
    w : float or ndarray

    Returns
    -------
    float or ndarray
    """
    return 1. / (2. * np.sqrt(np.abs(w)) + np.finfo(float).eps)


def primal_obj(w):
    """Objective function to optimize in adaptive lasso.

    Least-squares term scaled by ``1 / (2 * n_samples)`` plus ``alpha`` times
    the summed ``sqr_abs`` penalty. Reads the module-level ``X``, ``y``,
    ``n_samples`` and ``alpha``.

    Parameters
    ----------
    w : ndarray
        Coefficient vector, multiplied with ``X`` through ``np.dot``.

    Returns
    -------
    float
    """
    residual = y - np.dot(X, w)
    return 1. / (2 * n_samples) * np.sum(residual ** 2) \
        + alpha * np.sum(sqr_abs(w))


def plot_pen(x, threshold, pen, image_name, title, label=None):
    """Plot one penalty function and hand the figure to ``my_saving_display``.

    Parameters
    ----------
    x : ndarray
        Abscissae at which the penalty is evaluated.
    threshold : float
        Second argument passed to ``pen``.
    pen : callable
        Penalty function, called as ``pen(x, threshold)``.
    image_name : str
        Base name of the output file; the suffix "pen" is appended.
    title : str
        Title of the figure.
    label : str, optional
        Label attached to the plotted curve. When omitted, the module-level
        variable ``label`` is used if it exists (this is how existing cells
        pass it), otherwise the curve gets no label.

    Returns
    -------
    None
    """
    if label is None:
        # Backward compatibility with cells that set a global `label` right
        # before calling this function.
        label = globals().get("label")
    values = pen(x, threshold)
    fig0 = plt.figure(figsize=(6, 6))
    ax1 = plt.subplot(111)
    ax1.plot(x, values, label=label)
    ax1.get_yaxis().set_ticks([])
    ax1.get_xaxis().set_ticks([])
    ax1.set_ylim(-0.1, np.max(values) * 1.05)
    ax1.set_xlim(-10, 10)
    ax1.set_title(title)
    plt.show()
    # Forward the notebook-wide flag so that switching `saving` on actually
    # writes these figures, as the final cell of the notebook already does.
    my_saving_display(fig0, dirname, image_name + "pen", imageformat,
                      saving=saving)
"cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# L22 \n", "label = r\"$\\eta_{\\rm {Ridge},\\lambda}$\"\n", "image_name = \"l22_orth_1d\"\n", "pen_l22 = l22_pen\n", "plot_pen(x, threshold, pen_l22, image_name,'L22')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Enet\n", "beta = 1\n", "label = r\"$\\eta_{\\rm {Enet},\\lambda,\\gamma}$\"\n", "image_name = \"enet_orth_1d\"\n", "pen_enet = partial(enet_pen, beta=beta)\n", "plot_pen(x, threshold, pen_enet, image_name,'Enet')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Sqrt\n", "label = r\"$\\eta_{\\rm {sqrt},\\lambda}$\"\n", "image_name = \"sqrt_orth_1d\"\n", "pen_sqrt = sqrt_pen\n", "plot_pen(x, threshold, pen_sqrt, image_name,'Sqrt')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# L0\n", "threshold = 4.5\n", "label = r\"$\\eta_{\\rm {HT},\\lambda}$\"\n", "image_name = \"l0_orth_1d\"\n", "pen_l0 = l0_pen\n", "plot_pen(x, threshold, pen_l0, image_name,'L0')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "###############################################################################\n", "# ploting penalty functions altogether\n", "sns.set_palette(\"Paired\", 10)\n", "# sns.set_palette(\"colorblind\")\n", "\n", "fig0 = plt.figure(figsize=(6, 6))\n", "ax1 = plt.subplot(111)\n", "\n", "ax1.plot(x, pen_l0(x, threshold), label='l0')\n", "ax1.plot(x, pen_sqrt(x, threshold), label='sqrt')\n", "ax1.plot(x, pen_l22(x, threshold), label='l22')\n", "ax1.plot(x, pen_enet(x, threshold), label='enet')\n", "ax1.plot(x, pen_log(x, threshold), label='log')\n", "ax1.plot(x, pen_mcp(x, threshold), label='mcp')\n", "ax1.plot(x, pen_scad(x, threshold), label='scad')\n", "ax1.plot(x, 
pen_l1(x, threshold), label='l1')\n", "\n", "ax1.set_ylim(-0.1, 40)\n", "ax1.set_xlim(-10, 10)\n", "ax1.get_yaxis().set_ticks([])\n", "ax1.get_xaxis().set_ticks([])\n", "plt.legend(loc=\"upper center\", fontsize=14)\n", "\n", "plt.show()\n", "my_saving_display(fig0, dirname, \"penalties\", imageformat, saving=saving)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.0" } }, "nbformat": 4, "nbformat_minor": 2 }