{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MDI 720 : Statistiques\n", "## CD\n", "### *Joseph Salmon*\n", "\n", "This notebook reproduces the pictures for the course \"CD\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from functools import partial\n", "import numpy as np\n", "from os import mkdir, path\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt # for plots\n", "from matplotlib import rc\n", "from matplotlib.patches import Polygon, Circle\n", "\n", "# BEWARE: the prox_collection file is loaded in the Lasso course\n", "from prox_collection import l22_prox, l1_prox, l0_prox, scad_prox, mcp_prox, \\\n", " log_prox, sqrt_prox, enet_prox\n", "from prox_collection import l22_pen, l1_pen, l0_pen, \\\n", " scad_pen, mcp_pen, log_pen, sqrt_pen, enet_pen\n", " \n", "%matplotlib notebook" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dirname = \"../prebuiltimages/\"\n", "if not path.exists(dirname):\n", " mkdir(dirname)\n", "\n", "imageformat = '.pdf'\n", "rc('font', **{'family': 'sans-serif', 'sans-serif': ['Computer Modern Roman']})\n", "params = {'axes.labelsize': 12,\n", " 'font.size': 16,\n", " 'legend.fontsize': 16,\n", " 'text.usetex': True,\n", " 'figure.figsize': (8, 6)}\n", "plt.rcParams.update(params)\n", "plt.close(\"all\")\n", "\n", "sns.set_context(\"poster\")\n", "sns.set_palette(\"colorblind\")\n", "sns.set_style(\"white\")\n", "sns.axes_style()\n", "\n", "\n", "###############################################################################\n", "# display function:\n", "\n", "saving = False\n", "\n", "\n", "def my_saving_display(fig, dirname, filename, imageformat, saving=False):\n", " \"\"\"\"Saving with personal function.\"\"\"\n", " filename = filename.replace('.', 'pt') # remove \".\" to avoid floats issues\n", " if saving is True:\n", " dirname + filename + imageformat\n", " image_name = dirname + filename + imageformat\n", " fig.savefig(image_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "###############################################################################\n", "# plotting level set function\n", "\n", "def plotting_level_set(func, Y, X, name, precision=12):\n", " \"\"\" plotting level sets\"\"\"\n", " fig1 = plt.figure(figsize=(6, 6))\n", " plt.contourf(X, Y, func(X, Y), precision, alpha=.75, cmap=plt.cm.hot)\n", " plt.contour(X, Y, func(X, Y), precision, colors='black', linewidth=1)\n", " plt.show()\n", " my_saving_display(fig1, dirname, name, '.svg')\n", " return" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "###############################################################################\n", "# quadratic level set\n", "\n", "def funct_quad_bis(X, Y):\n", " \"\"\" quadratic function to be displayed\"\"\"\n", " return 0.5 * (3 * X ** 2 + 6 * Y ** 2 + 4 * (X * Y)) - 2 * X + 8 * Y\n", "Y, X = np.mgrid[-8:5:100j, -5:8:100j]\n", "name = \"quadractic_level_set\"\n", "plotting_level_set(funct_quad_bis, Y, X, name)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "###############################################################################\n", "# separable level set case 1\n", "\n", "def funct_separable(X, Y):\n", " \"\"\" separable function to be displayed\"\"\"\n", " return 100 * (np.abs(X) + np.abs(Y))\n", "\n", "Y, X = np.mgrid[-5:5:100j, -5:5:100j]\n", "name = \"separable_level_set\"\n", "plotting_level_set(funct_separable, Y, X, name, 12)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "###############################################################################\n", "# separable level set case 1\n", "\n", "\n", "def funct_separable_bis(X, Y):\n", " \"\"\" separable function to be displayed\"\"\"\n", " return 10 * (np.abs(X) + np.abs(Y)) + \\\n", " (0.5 * (3 * X ** 2 + 6 * Y ** 2 + 4 * (X * Y)) - 2 * X + 8 * Y)\n", "\n", "X, Y = np.mgrid[-5:5:100j, -5:5:100j]\n", "name = \"separable_level_set10\"\n", "plotting_level_set(funct_separable_bis, Y, X, name, 12)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "\n", "###############################################################################\n", "# Non convex level set with l_1/2 pseudo norm penalty\n", "\n", "\n", "def funct_non_cvx_sqrt(X, Y, threshold=3):\n", " \"\"\" non-cvx function to be displayed\"\"\"\n", " z = threshold * ((np.sqrt(np.abs(X)) + np.sqrt(np.abs(Y)))) + \\\n", " (X - 1) ** 2 / 2 + (Y - 1) ** 2 / 2\n", " return z\n", "\n", "threshold = 3\n", "func_non_cvx = partial(funct_non_cvx_sqrt, threshold=threshold)\n", "name = \"non_cvx_sqrt_level_set\"\n", "plotting_level_set(func_non_cvx, Y, X, name, 12)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "###############################################################################\n", "# Non convex level set with log penalty\n", "\n", "\n", "def funct_non_cvx_log(X, Y, threshold=1, eps=0.1):\n", " \"\"\" non-cvx function to be displayed\"\"\"\n", " z = threshold * np.log(1 + (np.abs(X)) / eps) + \\\n", " threshold * np.log(1 + (np.abs(Y)) / eps) + \\\n", " (X - 1) ** 2 / 2 + (Y - 1) ** 2 / 2\n", " return z\n", "\n", "threshold = 1\n", "eps = 0.1\n", "\n", "func_non_cvx = partial(funct_non_cvx_sqrt, threshold=threshold, eps=eps)\n", "name = \"non_cvx_log_level_set\"\n", "plotting_level_set(funct_non_cvx_log, Y, X, name, 12)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "###############################################################################\n", "# Non convex level set with MCP penalty, case with two local minima.\n", "\n", "\n", "def funct_non_cvx_mcp(X, Y, threshold=1, gamma=1.5):\n", " \"\"\" non-cvx function to be displayed\"\"\"\n", " z = mcp_pen(X, threshold, gamma) + mcp_pen(Y, threshold, gamma) + 0.1 * ((X - 2) ** 2 + (Y - 2) ** 2)\n", " return z\n", "\n", "\n", "threshold=0.94\n", "gamma=threshold * 1.2\n", "\n", "X, Y=np.mgrid[-0.5:3:100j, -0.5:3:100j]\n", "\n", "func_non_cvx=partial(funct_non_cvx_mcp, threshold = threshold, gamma = gamma)\n", "name=\"non_cvx_mcp_level_set\"\n", "plotting_level_set(func_non_cvx, Y, X, name, 20)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Plot the surface.\n", "from mpl_toolkits.mplot3d import Axes3D\n", "from matplotlib import cm\n", "fig = plt.figure()\n", "ax = fig.gca(projection='3d')\n", "surf = ax.plot_surface(X, Y, funct_non_cvx_mcp(X, Y, threshold=threshold, gamma=gamma), cmap=cm.coolwarm,\n", " linewidth=0, antialiased=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "##############################################################################\n", "# non separable level set\n", "\n", "\n", "def funct_non_sep(X, Y):\n", " \"\"\" non separable function to be displayed\"\"\"\n", " sqrt3 = np.sqrt(3)\n", " return (np.abs(sqrt3 * X + 1 + Y) + 2 * np.abs(sqrt3 * Y - X + 1))\n", "\n", "X, Y = np.mgrid[-5:5:100j, -5:5:100j]\n", "name = \"non_separable_level_set\"\n", "plotting_level_set(funct_non_sep, Y, X, name, 8)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "###############################################################################\n", "# ploting prox operators\n", "\n", "\n", "def plot_prox(x, threshold, prox, label, image_name, title):\n", " \"\"\"Function to plot and save prox operators.\"\"\"\n", " z = np.zeros(x.shape)\n", " for i, value in enumerate(np.nditer(x)):\n", " z[i] = prox(value, threshold)\n", "\n", " fig0 = plt.figure(figsize=(6, 6))\n", " ax1 = plt.subplot(111)\n", " ax1.plot(x, z, label=label)\n", " ax1.plot(x, x, 'k--', linewidth=1)\n", " plt.legend(loc=\"upper left\", fontsize=34)\n", " ax1.get_yaxis().set_ticks([])\n", " ax1.get_xaxis().set_ticks([])\n", " plt.title(title)\n", " my_saving_display(fig0, dirname, image_name, imageformat)\n", " return" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x = np.arange(-10, 10, step=0.01)\n", "\n", "# No penalty\n", "prox = l1_prox\n", "image_name = \"no_pen_orth_1d\"\n", "label = r\"$\\eta_{0}$\"\n", "plot_prox(x, 0, prox, label, image_name,'No penalty')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Log prox\n", "threshold = 4.5\n", "epsilon = .5\n", "label = r\"$\\eta_{\\rm {log},\\lambda,\\gamma}$\"\n", "image_name = \"log_orth_1d\"\n", "prox = partial(log_prox, epsilon=epsilon)\n", "plot_prox(x, threshold, prox, label, image_name, 'Log prox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# MCP prox\n", "threshold = 3\n", "gamma = 2.5\n", "label = r\"$\\eta_{\\rm {MCP},\\lambda,\\gamma}$\"\n", "image_name = \"mcp_orth_1d\"\n", "prox = partial(mcp_prox, gamma=gamma)\n", "plot_prox(x, threshold, prox, label, image_name, 'MCP prox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# SCAD prox\n", "label = r\"$\\eta_{\\rm {SCAD},\\lambda,\\gamma}$\"\n", "image_name = \"scad_orth_1d\"\n", "prox = partial(scad_prox, gamma=gamma)\n", "plot_prox(x, threshold, prox, label, image_name, 'SCAD prox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# L1 prox\n", "prox = l1_prox\n", "image_name = \"l1_orth_1d\"\n", "label = r\"$\\eta_{\\rm {ST},\\lambda}$\"\n", "plot_prox(x, threshold, prox, label, image_name, 'L1 prox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# l22 prox\n", "prox = l22_prox\n", "label = r\"$\\eta_{\\rm {Ridge},\\lambda}$\"\n", "image_name = \"l22_orth_1d\"\n", "plot_prox(x, threshold, prox, label, image_name, 'L22 prox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Enet prox\n", "beta = 1\n", "label = r\"$\\eta_{\\rm {Enet},\\lambda,\\gamma}$\"\n", "image_name = \"enet_orth_1d\"\n", "prox = partial(enet_prox, beta=beta)\n", "plot_prox(x, threshold, prox, label, image_name, 'Enet prox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Sqrt prox\n", "label = r\"$\\eta_{\\rm {sqrt},\\lambda}$\"\n", "image_name = \"sqrt_orth_1d\"\n", "prox = sqrt_prox\n", "plot_prox(x, threshold, prox, label, image_name, 'Sqrt prox')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# L0 prox\n", "threshold = 4.5\n", "label = r\"$\\eta_{\\rm {HT},\\lambda}$\"\n", "image_name = \"l0_orth_1d\"\n", "prox = l0_prox\n", "plot_prox(x, threshold, prox, label, image_name, 'L0 prox')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.0" } }, "nbformat": 4, "nbformat_minor": 2 }