{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SD-TSIA204 : Linear models\n",
    "## Least squares definition: 2 variables and 3D visualization\n",
    "### *Joseph Salmon*\n",
    "\n",
    "This notebook reproduces the pictures for the course \"LeastSquare_Def\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import statsmodels.api as sm\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import seaborn as sns\n",
    "from os import mkdir, path\n",
    "# Kept for its import side effect: it registers the '3d' projection on\n",
    "# older matplotlib releases, even though the name itself is not used below.\n",
    "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401\n",
    "# interaction mode better for 3D\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Output directory for the saved figures (created on first run).\n",
    "dirname = \"../srcimages/\"\n",
    "if not path.exists(dirname):\n",
    "    mkdir(dirname)\n",
    "\n",
    "imageformat = '.pdf'\n",
    "rc('font', **{'family': 'sans-serif', 'sans-serif': ['Computer Modern Roman']})\n",
    "params = {'axes.labelsize': 12,\n",
    "          'font.size': 12,\n",
    "          'legend.fontsize': 12,\n",
    "          'xtick.labelsize': 10,\n",
    "          'ytick.labelsize': 10,\n",
    "          'text.usetex': True,\n",
    "          'figure.figsize': (8, 6)}\n",
    "plt.rcParams.update(params)\n",
    "plt.close(\"all\")\n",
    "\n",
    "# sns.set_context(\"poster\")\n",
    "sns.set_palette(\"colorblind\")\n",
    "sns.set_style({'legend.frameon': True})\n",
    "color_blind_list = sns.color_palette(\"colorblind\", 8)\n",
    "my_orange = color_blind_list[2]\n",
    "my_green = color_blind_list[1]\n",
    "\n",
    "###############################################################################\n",
    "# display function:\n",
    "\n",
    "# Set to True to write every figure to `dirname`; read by my_saving_display.\n",
    "saving = False\n",
    "\n",
    "\n",
    "def my_saving_display(fig, dirname, filename, imageformat):\n",
    "    \"\"\"Save `fig` to disk when the module-level `saving` flag is enabled.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    fig : matplotlib figure\n",
    "        Figure to save.\n",
    "    dirname : str\n",
    "        Output directory prefix. It is concatenated as-is with `filename`,\n",
    "        so it must end with a path separator.\n",
    "    filename : str\n",
    "        Base name of the image. Every '.' is replaced by 'pt' so that a\n",
    "        float embedded in the name is not mistaken for a file extension.\n",
    "    imageformat : str\n",
    "        File extension, including the leading dot (e.g. '.pdf').\n",
    "    \"\"\"\n",
    "    filename = filename.replace('.', 'pt')\n",
    "    if saving:\n",
    "        image_name = dirname + filename + imageformat\n",
    "        fig.savefig(image_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3D case drawing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.close(\"all\")\n",
    "\n",
    "# Load data\n",
    "url = 'https://vincentarelbundock.github.io/Rdatasets/csv/datasets/trees.csv'\n",
    "dat3 = pd.read_csv(url)\n",
    "\n",
    "# Fit regression model\n",
    "X = dat3[['Girth', 'Height']]\n",
    "X = sm.add_constant(X)\n",
    "y = dat3['Volume']\n",
    "results = sm.OLS(y, X).fit().params\n",
    "\n",
    "# Evaluate the fitted plane on a regular (Girth, Height) grid.\n",
    "XX = np.arange(8, 22, 0.5)\n",
    "YY = np.arange(64, 90, 0.5)\n",
    "xx, yy = np.meshgrid(XX, YY)\n",
    "# `results` is a Series labelled by regressor name; index it by label because\n",
    "# positional access through [] is deprecated for label-indexed Series.\n",
    "zz = results['const'] + results['Girth'] * xx + results['Height'] * yy\n",
    "\n",
    "\n",
    "fig = plt.figure()\n",
    "# Axes3D(fig) no longer attaches itself to the figure on recent matplotlib;\n",
    "# add_subplot works on both old and new releases.\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "ax.set_xlabel('Girth')\n",
    "ax.set_ylabel('Height')\n",
    "ax.set_zlabel('Volume')\n",
    "ax.set_zlim(5, 80)\n",
    "ax.plot(X['Girth'], X['Height'], y, 'o')\n",
    "ax.plot_wireframe(xx, yy, zz, rstride=10, cstride=10, alpha=0.3)\n",
    "ax.plot_surface(xx, yy, zz, alpha=0.3)\n",
    "my_saving_display(fig, dirname, \"tree_data_plot_regression\", imageformat)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non trivial minima : 3D visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "sns.set_style(\"white\")\n",
    "\n",
    "# Surface z = (x - y)^2 on [-1, 1) x [-1, 1): it is zero along the whole\n",
    "# line x = y.\n",
    "XX = np.arange(-1, 1, 0.05)\n",
    "YY = XX\n",
    "xx, yy = np.meshgrid(XX, YY)\n",
    "zz = (xx - yy) ** 2\n",
    "\n",
    "\n",
    "def new_bare_3d_axes(elev=20., azim=50):\n",
    "    \"\"\"Create a figure holding one 3D axes with frame and ticks hidden.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    elev, azim : float\n",
    "        Initial elevation and azimuth, forwarded to `ax.view_init`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (fig, ax) : the new figure and its 3D axes.\n",
    "    \"\"\"\n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "    ax.view_init(elev=elev, azim=azim)\n",
    "    ax.set_xlabel('$x$')\n",
    "    ax.set_ylabel('$y$')\n",
    "    ax.set_zlabel('$z$')\n",
    "\n",
    "    ax.set_axis_off()\n",
    "\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_zticks([])\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "def draw_surface(ax, xx, yy, zz):\n",
    "    \"\"\"Draw the half-transparent surface (xx, yy, zz) on `ax` and return it.\"\"\"\n",
    "    return ax.plot_surface(xx, yy, zz, rstride=2, cstride=2,\n",
    "                           antialiased=False, alpha=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig, ax = new_bare_3d_axes(azim=50)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "my_saving_display(fig, dirname, \"CN0_2d_non_trivial1\", imageformat)\n",
    "\n",
    "# Rotate the same axes and draw the surface again before saving the second\n",
    "# view, so this image contains two overlapping translucent layers.\n",
    "ax.view_init(elev=20., azim=90)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "my_saving_display(fig, dirname, \"CN0_2d_non_trivial2\", imageformat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# The surface is drawn once before and once after the rotation, so the saved\n",
    "# image contains two overlapping translucent layers.\n",
    "fig, ax = new_bare_3d_axes(azim=50)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "ax.view_init(elev=20., azim=130)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "my_saving_display(fig, dirname, \"CN0_2d_non_trivial3\", imageformat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# The surface is drawn once before and once after the rotation, so the saved\n",
    "# image contains two overlapping translucent layers.\n",
    "fig, ax = new_bare_3d_axes(azim=50)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "ax.view_init(elev=20., azim=170)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "my_saving_display(fig, dirname, \"CN0_2d_non_trivial4\", imageformat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# The surface is drawn once before and once after the rotation, so the saved\n",
    "# image contains two overlapping translucent layers.\n",
    "fig, ax = new_bare_3d_axes(azim=50)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "ax.view_init(elev=20., azim=210)\n",
    "surf = draw_surface(ax, xx, yy, zz)\n",
    "my_saving_display(fig, dirname, \"CN0_2d_non_trivial5\", imageformat)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}