import * as React from 'react'
  /* @jsx mdx */
import { mdx } from '@mdx-js/react';
/* @jsxRuntime classic */

/* @jsx mdx */

import DefaultLayout from "/home/runner/work/myedibleenso.github.io/myedibleenso.github.io/src/components/BasicLayout.js";
// Front matter exported by this page (empty object).
export const _frontmatter = {};

// Props spread onto the page layout when MDXContent renders it.
const layoutProps = { _frontmatter };

// Alias for the layout component that wraps the rendered MDX content.
const MDXLayout = DefaultLayout;
export default function MDXContent({
  components,
  ...props
}) {
  return <MDXLayout {...layoutProps} {...props} components={components} mdxType="MDXLayout">


    <h1 {...{
      "id": "overview",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#overview",
        "aria-label": "overview permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Overview`}</h1>
    <p>{`In this tutorial, we'll walk through how to implement, train, and evaluate a neural network using PyTorch.`}</p>
    <h1 {...{
      "id": "objectives",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#objectives",
        "aria-label": "objectives permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Objectives`}</h1>
    <ul>
      <li parentName="ul">{`define a simple neural network in PyTorch as a sequence of layers/operations`}</li>
      <li parentName="ul">{`train a simple neural network using gradient descent`}</li>
      <li parentName="ul">{`evaluate the performance of a neural network classifier written in PyTorch`}</li>
    </ul>
    <h1 {...{
      "id": "getting-started",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#getting-started",
        "aria-label": "getting started permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Getting started`}</h1>
    <p>{`In order to follow along, `}<a parentName="p" {...{
        "href": "/courses/general/configuring-your-development-environment"
      }}>{`ensure you've configured your development environment`}</a>{` with a `}<a parentName="p" {...{
        "href": "/tutorials/ubuntu-install-docker"
      }}>{`working docker installation`}</a>{`.`}</p>
    <h1 {...{
      "id": "introduction",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#introduction",
        "aria-label": "introduction permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Introduction`}</h1>
    <p>{`We'll walk through the steps involved in defining, training, and evaluating a simple form of neural network composed of a single neuron: the binomial logistic regression classifier.`}</p>
    <p>{`You can execute the code snippets that follow in the provided docker image:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "bash"
    }}><pre parentName="div" {...{
        "className": "language-bash"
      }}><code parentName="pre" {...{
          "className": "language-bash"
        }}><span parentName="code" {...{
            "className": "token function"
          }}>{`docker`}</span>{` run -it -p `}<span parentName="code" {...{
            "className": "token number"
          }}>{`8888`}</span>{`:9999 uazhlt/ling-582-playground:latest`}</code></pre></div>
    <p>{`Open your browser to `}<a parentName="p" {...{
        "href": "localhost:8888",
        "target": "_self",
        "rel": "nofollow"
      }}><code parentName="a" {...{
          "className": "language-text"
        }}>{`localhost:8888`}</code></a>{` and create a new notebook (`}<strong parentName="p">{`New`}</strong>{` `}<span parentName="p" {...{
        "className": "math math-inline"
      }}><span parentName="span" {...{
          "className": "katex"
        }}><span parentName="span" {...{
            "className": "katex-mathml"
          }}><math parentName="span" {...{
              "xmlns": "http://www.w3.org/1998/Math/MathML"
            }}><semantics parentName="math"><mrow parentName="semantics"><mo parentName="mrow">{`→`}</mo></mrow><annotation parentName="semantics" {...{
                  "encoding": "application/x-tex"
                }}>{`\\rightarrow`}</annotation></semantics></math></span><span parentName="span" {...{
            "className": "katex-html",
            "aria-hidden": "true"
          }}><span parentName="span" {...{
              "className": "base"
            }}><span parentName="span" {...{
                "className": "strut",
                "style": {
                  "height": "0.3669em"
                }
              }}></span><span parentName="span" {...{
                "className": "mrel"
              }}>{`→`}</span></span></span></span></span>{` `}<strong parentName="p">{`Python 3 (ipykernel)`}</strong>{`).  `}</p>
    <div {...{
      "className": "admonition admonition-note alert alert--secondary"
    }}><div parentName="div" {...{
        "className": "admonition-heading"
      }}><h5 parentName="div"><span parentName="h5" {...{
            "className": "admonition-icon"
          }}><svg parentName="span" {...{
              "xmlns": "http://www.w3.org/2000/svg",
              "width": "14",
              "height": "16",
              "viewBox": "0 0 14 16"
            }}><path parentName="svg" {...{
                "fillRule": "evenodd",
                "d": "M6.3 5.69a.942.942 0 0 1-.28-.7c0-.28.09-.52.28-.7.19-.18.42-.28.7-.28.28 0 .52.09.7.28.18.19.28.42.28.7 0 .28-.09.52-.28.7a1 1 0 0 1-.7.3c-.28 0-.52-.11-.7-.3zM8 7.99c-.02-.25-.11-.48-.31-.69-.2-.19-.42-.3-.69-.31H6c-.27.02-.48.13-.69.31-.2.2-.3.44-.31.69h1v3c.02.27.11.5.31.69.2.2.42.31.69.31h1c.27 0 .48-.11.69-.31.2-.19.3-.42.31-.69H8V7.98v.01zM7 2.3c-3.14 0-5.7 2.54-5.7 5.68 0 3.14 2.56 5.7 5.7 5.7s5.7-2.55 5.7-5.7c0-3.15-2.56-5.69-5.7-5.69v.01zM7 .98c3.86 0 7 3.14 7 7s-3.14 7-7 7-7-3.12-7-7 3.14-7 7-7z"
              }}></path></svg></span>{`NOTE `}</h5></div><div parentName="div" {...{
        "className": "admonition-content"
      }}><p parentName="div">{`We didn't use a bind mount in the command above, but you can download your notebook once you've finished using the `}<strong parentName="p">{`File`}</strong>{` menu.`}</p></div></div>
    <p>{`Let's import PyTorch and define a value for our random seed to keep our example reproducible:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token keyword"
          }}>{`from`}</span>{` typing `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`import`}</span>{` `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`*`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`from`}</span>{` torch `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`import`}</span>{` nn
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`import`}</span>{` torch

SEED `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`42`}</span></code></pre></div>
    <h1 {...{
      "id": "generating-synthetic-data",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#generating-synthetic-data",
        "aria-label": "generating synthetic data permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Generating synthetic data`}</h1>
    <p>{`Since we're going to be writing a classifier, we'll need some data.  To keep things simple, let's generate some synthetic data.  We'll use a utility from `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`scikit-learn`}</code>{`:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token keyword"
          }}>{`from`}</span>{` sklearn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`datasets `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`import`}</span>{` make_blobs `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`as`}</span>{` datagen
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`from`}</span>{` sklearn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`model_selection `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`import`}</span>{` train_test_split

`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`X`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{` `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` datagen`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# how many datapoints to generate`}</span>{`
    n_samples`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`100`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` 
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# the dimensionality of each row in our X matrix`}</span>{`
    n_features`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`2`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` 
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# the number of classes to generate`}</span>{`
    centers`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`2`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# for reproducibility`}</span>{`
    random_state`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{`SEED
`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`We've generated an `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`X`}</code>{` matrix and corresponding `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`y`}</code>{` array with two classes.  Our dataset contains only 100 points.  Each row (datapoint) in our `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`X`}</code>{` matrix has two columns (features). `}</p>
    <p>{`Let's split our data into two partitions: `}<strong parentName="p">{`train`}</strong>{` and `}<strong parentName="p">{`test`}</strong>{`:`}</p>
    <div {...{
      "className": "admonition admonition-info alert alert--info"
    }}><div parentName="div" {...{
        "className": "admonition-heading"
      }}><h5 parentName="div"><span parentName="h5" {...{
            "className": "admonition-icon"
          }}><svg parentName="span" {...{
              "xmlns": "http://www.w3.org/2000/svg",
              "width": "14",
              "height": "16",
              "viewBox": "0 0 14 16"
            }}><path parentName="svg" {...{
                "fillRule": "evenodd",
                "d": "M7 2.3c3.14 0 5.7 2.56 5.7 5.7s-2.56 5.7-5.7 5.7A5.71 5.71 0 0 1 1.3 8c0-3.14 2.56-5.7 5.7-5.7zM7 1C3.14 1 0 4.14 0 8s3.14 7 7 7 7-3.14 7-7-3.14-7-7-7zm1 3H6v5h2V4zm0 6H6v2h2v-2z"
              }}></path></svg></span>{`What about validation? `}</h5></div><div parentName="div" {...{
        "className": "admonition-content"
      }}><p parentName="div">{`Normally, we'd also want to create a validation partition.  I encourage you to revisit this step later with that change in mind.`}</p></div></div>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`partitioned_data `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` train_test_split`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`
    X`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# set aside 20% for test`}</span>{`
    test_size`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`.2`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{`
    random_state`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{`SEED`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{`
    shuffle`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token boolean"
          }}>{`True`}</span>{`
`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p><code parentName="p" {...{
        "className": "language-text"
      }}>{`partitioned_data`}</code>{` contains four elements (`}<code parentName="p" {...{
        "className": "language-text"
      }}>{`X_train, X_test, y_train, y_test`}</code>{`).`}</p>
    <p>{`In order to use our partitioned `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`X`}</code>{` and `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`y`}</code>{` with PyTorch, we'll need to convert each partition to a `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`torch.Tensor`}</code>{`:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token comment"
          }}>{`# NOTE: we need to convert our numpy arrays to PyTorch Tensors`}</span>{`
`}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# why .float()? see https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-double-but-got-scalar-type-float-for-argument-2-weight/38961/14`}</span>{`
X_train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` X_test`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y_train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y_test `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` `}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span>{`
  torch`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`from_numpy`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`partition`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span><span parentName="code" {...{
            "className": "token builtin"
          }}>{`float`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{` \\
    `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`for`}</span>{` partition `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`in`}</span>{` partitioned_data
`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span>{`

`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`X_train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`dtype`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`We've set aside 80% of our data for training, so we expect `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`X_train`}</code>{` to contain 80 elements.  Let's verify:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token comment"
          }}>{`# if we reserved 20% of our 100 datapoints for test, `}</span>{`
`}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# we should have 80 datapoints in our train partition`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`assert`}</span>{` `}<span parentName="code" {...{
            "className": "token builtin"
          }}>{`len`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`X_train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{` `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`==`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`80`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` `}<span parentName="code" {...{
            "className": "token string-interpolation"
          }}><span parentName="span" {...{
              "className": "token string"
            }}>{`f"X_train should have 80 elements, but `}</span><span parentName="span" {...{
              "className": "token interpolation"
            }}><span parentName="span" {...{
                "className": "token punctuation"
              }}>{`{`}</span><span parentName="span" {...{
                "className": "token builtin"
              }}>{`len`}</span><span parentName="span" {...{
                "className": "token punctuation"
              }}>{`(`}</span>{`X_train`}<span parentName="span" {...{
                "className": "token punctuation"
              }}>{`)`}</span><span parentName="span" {...{
                "className": "token punctuation"
              }}>{`}`}</span></span><span parentName="span" {...{
              "className": "token string"
            }}>{` found"`}</span></span></code></pre></div>
    <p>{`Let's take a quick look at our synthetic data to make sure there aren't any surprises.`}</p>
    <h2 {...{
      "id": "plot-our-data",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h2" {...{
        "href": "#plot-our-data",
        "aria-label": "plot our data permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Plot our data`}</h2>
    <p>{`We'll first load our data into a `}<a parentName="p" {...{
        "href": "https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html",
        "target": "_self",
        "rel": "nofollow"
      }}>{`Pandas `}<code parentName="a" {...{
          "className": "language-text"
        }}>{`DataFrame`}</code></a>{` to make it a bit easier to manipulate:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token keyword"
          }}>{`import`}</span>{` pandas `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`as`}</span>{` pd

df `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` pd`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`concat`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span>{`pd`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`DataFrame`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`X_train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` columns`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"x1"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` `}<span parentName="code" {...{
            "className": "token string"
          }}>{`"x2"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` pd`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`DataFrame`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`y_train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` columns`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"y"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` axis`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`1`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
df`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`head`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`You should see something like the following:`}</p>
    <table>
      <thead parentName="table">
        <tr parentName="thead">
          <th parentName="tr" {...{
            "align": "right"
          }}></th>
          <th parentName="tr" {...{
            "align": "right"
          }}>{`x1`}</th>
          <th parentName="tr" {...{
            "align": "right"
          }}>{`x2`}</th>
          <th parentName="tr" {...{
            "align": null
          }}>{`y`}</th>
        </tr>
      </thead>
      <tbody parentName="table">
        <tr parentName="tbody">
          <td parentName="tr" {...{
            "align": "right"
          }}>{`0`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`5.265546`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`1.116012`}</td>
          <td parentName="tr" {...{
            "align": null
          }}>{`1.0`}</td>
        </tr>
        <tr parentName="tbody">
          <td parentName="tr" {...{
            "align": "right"
          }}>{`1`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`4.605167`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`0.804492`}</td>
          <td parentName="tr" {...{
            "align": null
          }}>{`1.0`}</td>
        </tr>
        <tr parentName="tbody">
          <td parentName="tr" {...{
            "align": "right"
          }}>{`2`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`4.562777`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`2.314322`}</td>
          <td parentName="tr" {...{
            "align": null
          }}>{`1.0`}</td>
        </tr>
        <tr parentName="tbody">
          <td parentName="tr" {...{
            "align": "right"
          }}>{`3`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`3.665197`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`2.760254`}</td>
          <td parentName="tr" {...{
            "align": null
          }}>{`1.0`}</td>
        </tr>
        <tr parentName="tbody">
          <td parentName="tr" {...{
            "align": "right"
          }}>{`4`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`4.890372`}</td>
          <td parentName="tr" {...{
            "align": "right"
          }}>{`2.319618`}</td>
          <td parentName="tr" {...{
            "align": null
          }}>{`1.0`}</td>
        </tr>
      </tbody>
    </table>
    <p>{`So far so good.`}</p>
    <p>{`Let's plot (nothing fancy):`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`markers `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` df`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`y`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span><span parentName="code" {...{
            "className": "token builtin"
          }}>{`apply`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token keyword"
          }}>{`lambda`}</span>{` y`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{` `}<span parentName="code" {...{
            "className": "token string"
          }}>{`"o"`}</span>{` `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`if`}</span>{` y `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`==`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`1`}</span>{` `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`else`}</span>{` `}<span parentName="code" {...{
            "className": "token string"
          }}>{`"x"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
df`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`plot`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`scatter`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`x`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"x1"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"x2"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p><figure parentName="p" {...{
        "className": "gatsby-resp-image-figure",
        "style": {}
      }}>{`
    `}<span parentName="figure" {...{
          "className": "gatsby-resp-image-wrapper",
          "style": {
            "position": "relative",
            "display": "block",
            "marginLeft": "auto",
            "marginRight": "auto",
            "maxWidth": "382px"
          }
        }}>{`
      `}<a parentName="span" {...{
            "className": "gatsby-resp-image-link",
            "href": "/static/f8463a6d3a2f3707037500e9dcd4586d/77edc/synthetic-data.png",
            "style": {
              "display": "block"
            },
            "target": "_blank",
            "rel": "noopener"
          }}>{`
    `}<span parentName="a" {...{
              "className": "gatsby-resp-image-background-image",
              "style": {
                "paddingBottom": "68.91891891891892%",
                "position": "relative",
                "bottom": "0",
                "left": "0",
                "backgroundImage": "url('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAOCAYAAAAvxDzwAAAACXBIWXMAAAsTAAALEwEAmpwYAAACcElEQVQ4y42Tz0tUURTHT4QEgZRvLMKIyHDpnxAU4cKNLkRajC5iJApXrnUhbkJognauIjJyatFOyKiFA4nBMFOUNiZOM+P89P3y/bj3Prvv3hPvNepkOPmF8+7hvPs+73vOuw/u3X8AT5+/aKtWKo9N03xjGPpCoVJP5Mv1hKppCdMwEoauJ/T/hGmar8vl8jNAROi7eaPLctySwD/iQmJOJWE4no8nFWMMQ+Dd6J0O03K+BaCiTvber9f4p5zGczsO112P+0JwlJLLFoGIkhBiQmxsDG7d7uukxN3gArFiUrGa0/DdWhXrFjt4u5TyWGeycZNSasPo6CiMjETPE0KyYbu+8Pe4wILmYs1iKBqg4BnZlDfX/gICwKmBgYGI67rZxq5wlB730RfyH3dHoYHEEeDZoaGhjn2HUkrRCtDslPsCg26CcnAhhNgwPDwMg4ODyr5D2XB4XIvNawBUHQ81xzt0OD09DdHo4Qz3gS0+wAFsx2bhnGsWkx4X6DFqw8TExFGHvmyhYIf3y5dfS7tydUuTH77XZLpgCMfjWKgbFkxNTUEsFjtHKf1x0gPsS8Sa5eF61UaLcdxSXeQS0SHUhfHxcejv71dUVf3CGBOEEIe0kOu6dI9Ropo2ydcNUlZ3yUZJJbZLPMPQ69DT0wPt7e1nksmkkkqlLk5OTl7PZrOdi4uLl2ZmZro1TVPm5+cvx+Pxq8ViMTI3N3fl5cKrLq26HXkSn722vbl2gZayFx/NPuxOpTMR6O3tBUVRIJlMQj6fD3/FTCYDlNIwT6fTsLy8HByvNgA4vbT0Nqxvrn0OV6P8E9CqhPnHlRX4DdUUswCXAO4HAAAAAElFTkSuQmCC')",
                "backgroundSize": "cover",
                "display": "block"
              }
            }}></span>{`
  `}<img parentName="a" {...{
              "className": "gatsby-resp-image-image",
              "alt": "\"A quick look at our synthetic data.\"",
              "title": "A quick look at our synthetic data.",
              "src": "/static/f8463a6d3a2f3707037500e9dcd4586d/77edc/synthetic-data.png",
              "srcSet": ["/static/f8463a6d3a2f3707037500e9dcd4586d/12f09/synthetic-data.png 148w", "/static/f8463a6d3a2f3707037500e9dcd4586d/e4a3f/synthetic-data.png 295w", "/static/f8463a6d3a2f3707037500e9dcd4586d/77edc/synthetic-data.png 382w"],
              "sizes": "(max-width: 382px) 100vw, 382px",
              "style": {
                "width": "100%",
                "height": "100%",
                "margin": "0",
                "verticalAlign": "middle",
                "position": "absolute",
                "top": "0",
                "left": "0"
              },
              "loading": "lazy",
              "decoding": "async"
            }}></img>{`
  `}</a>{`
    `}</span>{`
    `}<figcaption parentName="figure" {...{
          "className": "gatsby-resp-image-figcaption"
        }}>{`A quick look at our synthetic data.`}</figcaption>{`
  `}</figure></p>
    <p>{`It looks like binomial logistic regression won't have trouble learning this data.  Can you tell why?`}</p>
    <p>{`Now that our data is ready, let's define our simple neural network (just one neuron)...`}</p>
    <h1 {...{
      "id": "defining-our-logistic-regression-classifier",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#defining-our-logistic-regression-classifier",
        "aria-label": "defining our logistic regression classifier permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Defining our logistic regression classifier`}</h1>
    <p>{`There are actually several ways you can define a network in PyTorch.  I'll show you one way and point you towards another.  We'll define a class that extends PyTorch's `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`nn.Module`}</code>{`, and set a couple of attributes that hold the components of our network:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token keyword"
          }}>{`class`}</span>{` `}<span parentName="code" {...{
            "className": "token class-name"
          }}>{`LogisticRegression`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`nn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`Module`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`
    
    `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`def`}</span>{` `}<span parentName="code" {...{
            "className": "token function"
          }}>{`__init__`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`self`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` input_size`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{` `}<span parentName="code" {...{
            "className": "token builtin"
          }}>{`int`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`
        `}<span parentName="code" {...{
            "className": "token builtin"
          }}>{`super`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`__init__`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
        `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# xW^T + b`}</span>{`
        self`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`linear `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` nn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`Linear`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`in_features`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{`input_size`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` out_features`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`1`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` bias`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token boolean"
          }}>{`True`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
        self`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`sigmoid `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` nn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`Sigmoid`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    
    `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`def`}</span>{` `}<span parentName="code" {...{
            "className": "token function"
          }}>{`forward`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`self`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` x`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`
        z `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` self`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`linear`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`x`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
        `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`return`}</span>{` self`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`sigmoid`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`z`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`Our logistic regression defines `}<span parentName="p" {...{
        "className": "math math-inline"
      }}><span parentName="span" {...{
          "className": "katex"
        }}><span parentName="span" {...{
            "className": "katex-mathml"
          }}><math parentName="span" {...{
              "xmlns": "http://www.w3.org/1998/Math/MathML"
            }}><semantics parentName="math"><mrow parentName="semantics"><mtext parentName="mrow" {...{
                    "mathvariant": "bold"
                  }}>{`x`}</mtext><msup parentName="mrow"><mtext parentName="msup" {...{
                      "mathvariant": "bold"
                    }}>{`W`}</mtext><mtext parentName="msup">{`T`}</mtext></msup><mo parentName="mrow">{`+`}</mo><mi parentName="mrow">{`b`}</mi></mrow><annotation parentName="semantics" {...{
                  "encoding": "application/x-tex"
                }}>{`\\textbf{x}\\textbf{W}^{\\text{T}} + b`}</annotation></semantics></math></span><span parentName="span" {...{
            "className": "katex-html",
            "aria-hidden": "true"
          }}><span parentName="span" {...{
              "className": "base"
            }}><span parentName="span" {...{
                "className": "strut",
                "style": {
                  "height": "1.0007em",
                  "verticalAlign": "-0.0833em"
                }
              }}></span><span parentName="span" {...{
                "className": "mord text"
              }}><span parentName="span" {...{
                  "className": "mord textbf"
                }}>{`x`}</span></span><span parentName="span" {...{
                "className": "mord"
              }}><span parentName="span" {...{
                  "className": "mord text"
                }}><span parentName="span" {...{
                    "className": "mord textbf"
                  }}>{`W`}</span></span><span parentName="span" {...{
                  "className": "msupsub"
                }}><span parentName="span" {...{
                    "className": "vlist-t"
                  }}><span parentName="span" {...{
                      "className": "vlist-r"
                    }}><span parentName="span" {...{
                        "className": "vlist",
                        "style": {
                          "height": "0.9173em"
                        }
                      }}><span parentName="span" {...{
                          "style": {
                            "top": "-3.139em",
                            "marginRight": "0.05em"
                          }
                        }}><span parentName="span" {...{
                            "className": "pstrut",
                            "style": {
                              "height": "2.7em"
                            }
                          }}></span><span parentName="span" {...{
                            "className": "sizing reset-size6 size3 mtight"
                          }}><span parentName="span" {...{
                              "className": "mord mtight"
                            }}><span parentName="span" {...{
                                "className": "mord text mtight"
                              }}><span parentName="span" {...{
                                  "className": "mord mtight"
                                }}>{`T`}</span></span></span></span></span></span></span></span></span></span><span parentName="span" {...{
                "className": "mspace",
                "style": {
                  "marginRight": "0.2222em"
                }
              }}></span><span parentName="span" {...{
                "className": "mbin"
              }}>{`+`}</span><span parentName="span" {...{
                "className": "mspace",
                "style": {
                  "marginRight": "0.2222em"
                }
              }}></span></span><span parentName="span" {...{
              "className": "base"
            }}><span parentName="span" {...{
                "className": "strut",
                "style": {
                  "height": "0.6944em"
                }
              }}></span><span parentName="span" {...{
                "className": "mord mathnormal"
              }}>{`b`}</span></span></span></span></span>{`.  `}<a parentName="p" {...{
        "href": "https://pytorch.org/docs/stable/nn.html",
        "target": "_self",
        "rel": "nofollow"
      }}><code parentName="a" {...{
          "className": "language-text"
        }}>{`nn.Linear`}</code></a>{` is one of the Lego-like building blocks we use to construct a network in PyTorch.  `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`nn.Linear`}</code>{` is a linear layer that performs the linear transformation `}<span parentName="p" {...{
        "className": "math math-inline"
      }}><span parentName="span" {...{
          "className": "katex"
        }}><span parentName="span" {...{
            "className": "katex-mathml"
          }}><math parentName="span" {...{
              "xmlns": "http://www.w3.org/1998/Math/MathML"
            }}><semantics parentName="math"><mrow parentName="semantics"><mtext parentName="mrow" {...{
                    "mathvariant": "bold"
                  }}>{`x`}</mtext><msup parentName="mrow"><mtext parentName="msup" {...{
                      "mathvariant": "bold"
                    }}>{`W`}</mtext><mtext parentName="msup">{`T`}</mtext></msup><mo parentName="mrow">{`+`}</mo><mi parentName="mrow">{`b`}</mi></mrow><annotation parentName="semantics" {...{
                  "encoding": "application/x-tex"
                }}>{`\\textbf{x}\\textbf{W}^{\\text{T}} + b`}</annotation></semantics></math></span><span parentName="span" {...{
            "className": "katex-html",
            "aria-hidden": "true"
          }}><span parentName="span" {...{
              "className": "base"
            }}><span parentName="span" {...{
                "className": "strut",
                "style": {
                  "height": "1.0007em",
                  "verticalAlign": "-0.0833em"
                }
              }}></span><span parentName="span" {...{
                "className": "mord text"
              }}><span parentName="span" {...{
                  "className": "mord textbf"
                }}>{`x`}</span></span><span parentName="span" {...{
                "className": "mord"
              }}><span parentName="span" {...{
                  "className": "mord text"
                }}><span parentName="span" {...{
                    "className": "mord textbf"
                  }}>{`W`}</span></span><span parentName="span" {...{
                  "className": "msupsub"
                }}><span parentName="span" {...{
                    "className": "vlist-t"
                  }}><span parentName="span" {...{
                      "className": "vlist-r"
                    }}><span parentName="span" {...{
                        "className": "vlist",
                        "style": {
                          "height": "0.9173em"
                        }
                      }}><span parentName="span" {...{
                          "style": {
                            "top": "-3.139em",
                            "marginRight": "0.05em"
                          }
                        }}><span parentName="span" {...{
                            "className": "pstrut",
                            "style": {
                              "height": "2.7em"
                            }
                          }}></span><span parentName="span" {...{
                            "className": "sizing reset-size6 size3 mtight"
                          }}><span parentName="span" {...{
                              "className": "mord mtight"
                            }}><span parentName="span" {...{
                                "className": "mord text mtight"
                              }}><span parentName="span" {...{
                                  "className": "mord mtight"
                                }}>{`T`}</span></span></span></span></span></span></span></span></span></span><span parentName="span" {...{
                "className": "mspace",
                "style": {
                  "marginRight": "0.2222em"
                }
              }}></span><span parentName="span" {...{
                "className": "mbin"
              }}>{`+`}</span><span parentName="span" {...{
                "className": "mspace",
                "style": {
                  "marginRight": "0.2222em"
                }
              }}></span></span><span parentName="span" {...{
              "className": "base"
            }}><span parentName="span" {...{
                "className": "strut",
                "style": {
                  "height": "0.6944em"
                }
              }}></span><span parentName="span" {...{
                "className": "mord mathnormal"
              }}>{`b`}</span></span></span></span></span>{`. `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`nn.Linear`}</code>{` includes our weights and an optional bias.  All we need to do is define how many incoming connections we have and how many neurons we want in this layer.  For logistic regression, we just want a single neuron.  `}<a parentName="p" {...{
        "href": "https://pytorch.org/docs/stable/nn.html",
        "target": "_self",
        "rel": "nofollow"
      }}>{`You can peruse the PyTorch docs to learn about some of the other available layers in the `}<code parentName="a" {...{
          "className": "language-text"
        }}>{`nn`}</code>{` package`}</a>{`.  `}</p>
    <p>{`Our next component is our non-linear activation function, the sigmoid.`}</p>
    <p>{`Any class that extends `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`nn.Module`}</code>{` must implement a `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`forward()`}</code>{`.  The `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`forward()`}</code>{` method defines the forward pass of our network (input `}<span parentName="p" {...{
        "className": "math math-inline"
      }}><span parentName="span" {...{
          "className": "katex"
        }}><span parentName="span" {...{
            "className": "katex-mathml"
          }}><math parentName="span" {...{
              "xmlns": "http://www.w3.org/1998/Math/MathML"
            }}><semantics parentName="math"><mrow parentName="semantics"><mo parentName="mrow">{`→`}</mo></mrow><annotation parentName="semantics" {...{
                  "encoding": "application/x-tex"
                }}>{`\\rightarrow`}</annotation></semantics></math></span><span parentName="span" {...{
            "className": "katex-html",
            "aria-hidden": "true"
          }}><span parentName="span" {...{
              "className": "base"
            }}><span parentName="span" {...{
                "className": "strut",
                "style": {
                  "height": "0.3669em"
                }
              }}></span><span parentName="span" {...{
                "className": "mrel"
              }}>{`→`}</span></span></span></span></span>{` output).  Through the "mathemagic" of autodifferentiation (PyTorch's implementation is called `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`Autograd`}</code>{`), PyTorch can figure out how to efficiently compute our backward pass for us.  Talk about convenient!`}</p>
    <p>{`Now that we've defined our network, let's train it using the data we generated earlier ...`}</p>
    <h2 {...{
      "id": "train",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h2" {...{
        "href": "#train",
        "aria-label": "train permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Train`}</h2>
    <p>{`First, let's set our random seed to keep things reproducible.  Then we'll create an instance of our `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`LogisticRegression`}</code>{` class that expects an input of two features:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token comment"
          }}>{`# for reproducibility`}</span>{`
torch`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`manual_seed`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`SEED`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
`}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# each datapoint in our X has two features`}</span>{`
model `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` LogisticRegression`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`input_size`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`2`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`We're almost ready to train.  We need to define a loss function, an optimizer, our learning rate (i.e., how large each step should be when adjusting our parameters), and decide on a number of epochs (the number of complete passes over the training data).  We'll use the familiar binary cross entropy cost function for our loss and stochastic gradient descent as our optimizer.`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`EPOCHS `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`50`}</span>{`
`}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# we'll use binary cross entropy as loss function`}</span>{`
loss_fn `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` torch`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`nn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`BCELoss`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
optimizer `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` torch`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`optim`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`SGD`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# we want to optimize ("fit") **all** of our model's parameters.`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# in the case of logistic regression, that's \\mathbf{w} and b`}</span>{`
    model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`parameters`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# our learning rate`}</span>{`
    lr `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`0.01`}</span>{`
`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`Since our training data is so small (only 80 datapoints) and we have so few parameters, we'll just make a single update per complete pass over the training data.  Normally, we'd split things into mini-batches.  In a later tutorial, we'll look at how to do this easily by defining a PyTorch `}<a parentName="p" {...{
        "href": "https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader",
        "target": "_self",
        "rel": "nofollow"
      }}>{`DataLoader`}</a>{`.  We'll also keep track of our loss for each epoch, so that we can visualize the network's learning process:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token comment"
          }}>{`# put our model into training mode (i.e., tell it to calculate gradients)`}</span>{`
model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
loss_tracker `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` `}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token builtin"
          }}>{`list`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`parameters`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`for`}</span>{` i `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`in`}</span>{` `}<span parentName="code" {...{
            "className": "token builtin"
          }}>{`range`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`1`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` EPOCHS `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`+`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`1`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# IMPORTANT: clear out our gradients from the last epoch.`}</span>{`
    model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`zero_grad`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    y_pred `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`X_train`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    y_true `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` y_train
    loss `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` loss_fn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`y_pred`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`squeeze`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y_true`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token string-interpolation"
          }}><span parentName="span" {...{
              "className": "token string"
            }}>{`f"Epoch `}</span><span parentName="span" {...{
              "className": "token interpolation"
            }}><span parentName="span" {...{
                "className": "token punctuation"
              }}>{`{`}</span>{`i`}<span parentName="span" {...{
                "className": "token punctuation"
              }}>{`}`}</span></span><span parentName="span" {...{
              "className": "token string"
            }}>{` loss: `}</span><span parentName="span" {...{
              "className": "token interpolation"
            }}><span parentName="span" {...{
                "className": "token punctuation"
              }}>{`{`}</span>{`loss`}<span parentName="span" {...{
                "className": "token punctuation"
              }}>{`}`}</span></span><span parentName="span" {...{
              "className": "token string"
            }}>{`"`}</span></span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# backward pass`}</span>{`
    loss`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`backward`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# actually update our parameters`}</span>{`
    optimizer`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`step`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    loss_tracker`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`append`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`i`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` loss`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`item`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`Quite a bit is happening in that block above.  Let's review the crucial pieces:`}</p>
    <ol>
      <li parentName="ol">{`For each epoch, we ensure our gradient is reset to zero, rather than accumulating over many updates.  `}</li>
      <li parentName="ol">{`We then make our predictions with `}<code parentName="li" {...{
          "className": "language-text"
        }}>{`model(X_train)`}</code>{` and calculate our loss (`}<code parentName="li" {...{
          "className": "language-text"
        }}>{`loss = loss_fn(y_pred.squeeze(), y_true)`}</code>{`).`}</li>
      <li parentName="ol">{`Next, we calculate our gradients (`}<code parentName="li" {...{
          "className": "language-text"
        }}>{`loss.backward()`}</code>{`).  `}</li>
      <li parentName="ol">{`Finally, we update our parameters (`}<code parentName="li" {...{
          "className": "language-text"
        }}>{`optimizer.step()`}</code>{`).`}</li>
    </ol>
    <h3 {...{
      "id": "plotting-our-loss",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h3" {...{
        "href": "#plotting-our-loss",
        "aria-label": "plotting our loss permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Plotting our loss`}</h3>
    <p>{`Let's take a quick look at the learning curve of our model:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`loss_df `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` pd`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`DataFrame`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`loss_tracker`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` columns`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"epoch"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` `}<span parentName="code" {...{
            "className": "token string"
          }}>{`"loss"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
loss_df`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`plot`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`line`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`x`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"epoch"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"loss"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p><figure parentName="p" {...{
        "className": "gatsby-resp-image-figure",
        "style": {}
      }}>{`
    `}<span parentName="figure" {...{
          "className": "gatsby-resp-image-wrapper",
          "style": {
            "position": "relative",
            "display": "block",
            "marginLeft": "auto",
            "marginRight": "auto",
            "maxWidth": "378px"
          }
        }}>{`
      `}<a parentName="span" {...{
            "className": "gatsby-resp-image-link",
            "href": "/static/e668d1e240463bf99a56989e35106f18/f0991/learning-curve.png",
            "style": {
              "display": "block"
            },
            "target": "_blank",
            "rel": "noopener"
          }}>{`
    `}<span parentName="a" {...{
              "className": "gatsby-resp-image-background-image",
              "style": {
                "paddingBottom": "69.5945945945946%",
                "position": "relative",
                "bottom": "0",
                "left": "0",
                "backgroundImage": "url('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAOCAYAAAAvxDzwAAAACXBIWXMAAAsTAAALEwEAmpwYAAACA0lEQVQ4y52Uz2sTQRTHB1sPehCK7aGoFCko5JgQ/4b8M1YPEb2I4EGj0IMHb0ZrCTG/mpJLFSleWigIgQSKEBGyFZNNttnddDe7Ozuz+57MEGNTAm598GXfzL758J2ZxxBEJJVyeeXoV+eDaRpFXdcLUWUYekHta4Wf3V5eUZRaq9V6TAghczu16p1jy0U/xHMH5QGaLkNdH6CqqgcS+KlWjQ8sxx96PEBEDgCRhQgcAaiAj0ajzxK4Xa3GPUoDzWGICCL+6exPjfgCgDCCtm3vCuB8ZWsrznw/0F2OPg/hVOF/Aec2NjeTnPnBiIU49Fgk4Bn4NLBYLCZ83w/E8r5N4ayD8wIvVCoVeYZicuD44LBg1rYiA+dzuVySjoE8CEG16NRZRdC0w1KplKCUcrEeEUKLcujZFIJQ2jolmC3ZPn+BF/P5fJIxNrUnjwOKNjI9LiXG4v6D2YJxH36RwGw2mzBN03Zd13XG4bmuI0b6iS3VNSxHOR46R0La0FE0U+ZiTtFMu2NY2OtrOwJ4KRaLXclkMiuNRmOp3W5fTafTt96+27jGQ1hYf/Z09fX6i5v9H9+Wyu/f3Hjy6MFtpncXHVVZfHj/buxg9+MyIi68evl89d7a2nUBJKlUiohHotlsknq9LnOh/b29SU4IWSaEXBb5SbdNOt8PJ/8Ov+5P6n4DZ4e3o2Yxt1gAAAAASUVORK5CYII=')",
                "backgroundSize": "cover",
                "display": "block"
              }
            }}></span>{`
  `}<img parentName="a" {...{
              "className": "gatsby-resp-image-image",
              "alt": "\"Our learning curve over 50 epochs.\"",
              "title": "Our learning curve over 50 epochs. Would fewer epochs have been sufficient for this dataset?",
              "src": "/static/e668d1e240463bf99a56989e35106f18/f0991/learning-curve.png",
              "srcSet": ["/static/e668d1e240463bf99a56989e35106f18/12f09/learning-curve.png 148w", "/static/e668d1e240463bf99a56989e35106f18/e4a3f/learning-curve.png 295w", "/static/e668d1e240463bf99a56989e35106f18/f0991/learning-curve.png 378w"],
              "sizes": "(max-width: 378px) 100vw, 378px",
              "style": {
                "width": "100%",
                "height": "100%",
                "margin": "0",
                "verticalAlign": "middle",
                "position": "absolute",
                "top": "0",
                "left": "0"
              },
              "loading": "lazy",
              "decoding": "async"
            }}></img>{`
  `}</a>{`
    `}</span>{`
    `}<figcaption parentName="figure" {...{
          "className": "gatsby-resp-image-figcaption"
        }}>{`Our learning curve over 50 epochs. Would fewer epochs have been sufficient for this dataset?`}</figcaption>{`
  `}</figure></p>
    <p>{`Not too shabby. Now, let's evaluate our model on the held-out test data...`}</p>
    <h2 {...{
      "id": "evaluate",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h2" {...{
        "href": "#evaluate",
        "aria-label": "evaluate permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Evaluate`}</h2>
    <p>{`Let's collect our predictions on the test data into a `}<a parentName="p" {...{
        "href": "https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html",
        "target": "_self",
        "rel": "nofollow"
      }}>{`Pandas `}<code parentName="a" {...{
          "className": "language-text"
        }}>{`DataFrame`}</code></a>{` and then feed that data to scikit-learn's `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`classification_report`}</code>{`:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token comment"
          }}>{`# convert continuous-valued predictions into a hard 1 or 0.`}</span>{`
binarize `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`lambda`}</span>{` preds`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{` `}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`preds `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`>`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`0.5`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span><span parentName="code" {...{
            "className": "token builtin"
          }}>{`long`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`

`}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# switch to evaluation mode; torch.no_grad() below turns off gradient calculation`}</span>{`
model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span><span parentName="code" {...{
            "className": "token builtin"
          }}>{`eval`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`with`}</span>{` torch`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`no_grad`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`
    y_pred `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`X_test`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    y_true `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` y_test`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span><span parentName="code" {...{
            "className": "token builtin"
          }}>{`long`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    pred_df `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` pd`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`DataFrame`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`{`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`'y_true'`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`y_true`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`'y_pred'`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`binarize`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`y_pred`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`flatten`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`}`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`pred_df`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token keyword"
          }}>{`from`}</span>{` sklearn`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`metrics `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`import`}</span>{` classification_report

y_hat `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` binarize`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`y_pred`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`flatten`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
report `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` classification_report`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`y_true`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{`y_true`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`,`}</span>{` y_pred`}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{`y_hat`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`

`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`report`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <h3 {...{
      "id": "whats-up-with-eval-and-with-torchno_grad",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h3" {...{
        "href": "#whats-up-with-eval-and-with-torchno_grad",
        "aria-label": "whats up with eval and with torchno_grad permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`what's up with `}<code parentName="h3" {...{
        "className": "language-text"
      }}>{`.eval()`}</code>{` and `}<code parentName="h3" {...{
        "className": "language-text"
      }}>{`with torch.no_grad()`}</code>{`?`}</h3>
    <p>{`We've trained our model and fit our parameters.  We no longer need to make any adjustments to them. The `}<a parentName="p" {...{
        "href": "https://pytorch.org/docs/stable/generated/torch.no_grad.html#no-grad",
        "target": "_self",
        "rel": "nofollow"
      }}><code parentName="a" {...{
          "className": "language-text"
          }}>{`with torch.no_grad()`}</code>{` context manager`}</a>{` tells PyTorch that no gradients need to be computed for the block that follows.  Within that block, results are computed as if all of our parameters had `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`requires_grad=False`}</code>{`.  Let's take a look...`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"with gradient calculations off:"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`with`}</span>{` torch`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`no_grad`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span>{`
    `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`linear`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`bias`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
    `}<span parentName="code" {...{
            "className": "token comment"
          }}>{`# let's add two ...`}</span>{`
    res `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`linear`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`bias `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`+`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`2`}</span>{`
    `}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`res`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}><span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token string"
          }}>{`"with gradient calculations on:"`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`linear`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`bias`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span>{`
res `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`=`}</span>{` model`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`linear`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`bias `}<span parentName="code" {...{
            "className": "token operator"
          }}>{`+`}</span>{` `}<span parentName="code" {...{
            "className": "token number"
          }}>{`2`}</span>{`
`}<span parentName="code" {...{
            "className": "token keyword"
          }}>{`print`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`res`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`What do you notice?`}</p>
    <div {...{
      "className": "admonition admonition-info alert alert--info"
    }}><div parentName="div" {...{
        "className": "admonition-heading"
      }}><h5 parentName="div"><span parentName="h5" {...{
            "className": "admonition-icon"
          }}><svg parentName="span" {...{
              "xmlns": "http://www.w3.org/2000/svg",
              "width": "14",
              "height": "16",
              "viewBox": "0 0 14 16"
            }}><path parentName="svg" {...{
                "fillRule": "evenodd",
                "d": "M7 2.3c3.14 0 5.7 2.56 5.7 5.7s-2.56 5.7-5.7 5.7A5.71 5.71 0 0 1 1.3 8c0-3.14 2.56-5.7 5.7-5.7zM7 1C3.14 1 0 4.14 0 8s3.14 7 7 7 7-3.14 7-7-3.14-7-7-7zm1 3H6v5h2V4zm0 6H6v2h2v-2z"
              }}></path></svg></span>{`NOTE`}</h5></div><div parentName="div" {...{
        "className": "admonition-content"
      }}><p parentName="div"><a parentName="p" {...{
            "href": "https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval",
            "target": "_self",
            "rel": "nofollow"
          }}><code parentName="a" {...{
              "className": "language-text"
            }}>{`.eval()`}</code></a>{` can alter the behavior of model components like regularizers (e.g., Dropout) at inference time.  In our particular model, it has no effect, but it's good practice to use it whenever you're `}<strong parentName="p">{`not`}</strong>{` training/fitting parameters.`}</p></div></div>
    <h3 {...{
      "id": "why-do-we-call-binarize-here",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h3" {...{
        "href": "#why-do-we-call-binarize-here",
        "aria-label": "why do we call binarize here permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`why do we call `}<code parentName="h3" {...{
        "className": "language-text"
      }}>{`binarize()`}</code>{` here?`}</h3>
    <p><code parentName="p" {...{
        "className": "language-text"
      }}>{`y_pred`}</code>{`'s values range between 0 and 1 (recall this is the output of the sigmoid). Let's take a look at the first few items in `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`y_pred`}</code>{` ...`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`y_pred`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`4`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span></code></pre></div>
    <p>{`If we want hard predictions, we need to map these probabilities to 1 or 0. We'll treat anything greater than .5 as class 1 and anything .5 or less as class 0.`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`binarize`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span>{`y_pred`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`[`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`:`}</span><span parentName="code" {...{
            "className": "token number"
          }}>{`4`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`]`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <h3 {...{
      "id": "why-do-we-need-flatten",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h3" {...{
        "href": "#why-do-we-need-flatten",
        "aria-label": "why do we need flatten permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`why do we need `}<code parentName="h3" {...{
        "className": "language-text"
      }}>{`.flatten()`}</code>{`?`}</h3>
    <p>{`Let's take a look at the shape of `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`y_pred`}</code>{` ...`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`y_pred`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`size`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <p>{`20 rows x 1 column.  Pandas is expecting a vector (`}<a parentName="p" {...{
        "href": "https://en.wikipedia.org/wiki/Tensor#Examples",
        "target": "_self",
        "rel": "nofollow"
      }}>{`rank-1 tensor`}</a>{`) for its column.  `}<code parentName="p" {...{
        "className": "language-text"
      }}>{`.flatten()`}</code>{` does exactly this:`}</p>
    <div {...{
      "className": "gatsby-highlight",
      "data-language": "python"
    }}><pre parentName="div" {...{
        "className": "language-python"
      }}><code parentName="pre" {...{
          "className": "language-python"
        }}>{`y_pred`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`flatten`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`.`}</span>{`size`}<span parentName="code" {...{
            "className": "token punctuation"
          }}>{`(`}</span><span parentName="code" {...{
            "className": "token punctuation"
          }}>{`)`}</span></code></pre></div>
    <h1 {...{
      "id": "practice",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#practice",
        "aria-label": "practice permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Practice`}</h1>
    <p>{`You now know the basics of defining, training, and evaluating a neural network using PyTorch. Apply what you've learned by attempting the following challenges:`}</p>
    <ul>
      <li parentName="ul">
        <p parentName="li">{`PyTorch provides several ways of defining neural networks. Rewrite `}<code parentName="p" {...{
            "className": "language-text"
          }}>{`LogisticRegression`}</code>{` using `}<a parentName="p" {...{
            "href": "https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html",
            "target": "_self",
            "rel": "nofollow"
          }}><code parentName="a" {...{
              "className": "language-text"
            }}>{`torch.nn.Sequential`}</code></a>{` and retrain your network.   `}</p>
      </li>
      <li parentName="ul">
        <p parentName="li">{`Define an MLP with two hidden layers of 4 units each and a final sigmoid activation. Train this MLP on your data. Is the learning curve any different?`}</p>
      </li>
      <li parentName="ul">
        <p parentName="li">{`Adjust the data generation procedure to produce 500 samples.  Partition this data into 3 slices: train (80%), dev (10%), and test (10%).`}</p>
        <ul parentName="li">
          <li parentName="ul">{`Retrain your network on this data, but record your performance on the validation data at the end of each epoch. Plot this curve.`}
            <div parentName="li" {...{
              "className": "admonition admonition-note alert alert--secondary"
            }}><div parentName="div" {...{
                "className": "admonition-heading"
              }}><h5 parentName="div"><span parentName="h5" {...{
                    "className": "admonition-icon"
                  }}><svg parentName="span" {...{
                      "xmlns": "http://www.w3.org/2000/svg",
                      "width": "14",
                      "height": "16",
                      "viewBox": "0 0 14 16"
                    }}><path parentName="svg" {...{
                        "fillRule": "evenodd",
                        "d": "M6.3 5.69a.942.942 0 0 1-.28-.7c0-.28.09-.52.28-.7.19-.18.42-.28.7-.28.28 0 .52.09.7.28.18.19.28.42.28.7 0 .28-.09.52-.28.7a1 1 0 0 1-.7.3c-.28 0-.52-.11-.7-.3zM8 7.99c-.02-.25-.11-.48-.31-.69-.2-.19-.42-.3-.69-.31H6c-.27.02-.48.13-.69.31-.2.2-.3.44-.31.69h1v3c.02.27.11.5.31.69.2.2.42.31.69.31h1c.27 0 .48-.11.69-.31.2-.19.3-.42.31-.69H8V7.98v.01zM7 2.3c-3.14 0-5.7 2.54-5.7 5.68 0 3.14 2.56 5.7 5.7 5.7s5.7-2.55 5.7-5.7c0-3.15-2.56-5.69-5.7-5.69v.01zM7 .98c3.86 0 7 3.14 7 7s-3.14 7-7 7-7-3.12-7-7 3.14-7 7-7z"
                      }}></path></svg></span>{`NOTE `}</h5></div><div parentName="div" {...{
                "className": "admonition-content"
              }}><p parentName="div">{`You'll want to avoid calculating the gradients when processing the validation data.`}</p></div></div>
          </li>
        </ul>
      </li>
    </ul>
    <h1 {...{
      "id": "next-steps",
      "style": {
        "position": "relative"
      }
    }}><a parentName="h1" {...{
        "href": "#next-steps",
        "aria-label": "next steps permalink",
        "className": "md-header before"
      }}><svg parentName="a" {...{
          "aria-hidden": "true",
          "height": "20",
          "version": "1.1",
          "viewBox": "0 0 16 16",
          "width": "20"
        }}><path parentName="svg" {...{
            "fillRule": "evenodd",
            "d": "M4 9h1v1H4c-1.5 0-3-1.69-3-3.5S2.55 3 4 3h4c1.45 0 3 1.69 3 3.5 0 1.41-.91 2.72-2 3.25V8.59c.58-.45 1-1.27 1-2.09C10 5.22 8.98 4 8 4H4c-.98 0-2 1.22-2 2.5S3 9 4 9zm9-3h-1v1h1c1 0 2 1.22 2 2.5S13.98 12 13 12H9c-.98 0-2-1.22-2-2.5 0-.83.42-1.64 1-2.09V6.25c-1.09.53-2 1.84-2 3.25C6 11.31 7.55 13 9 13h4c1.45 0 3-1.69 3-3.5S14.5 6 13 6z"
          }}></path></svg></a>{`Next steps`}</h1>
    <p>{`If you would like to deepen your understanding of PyTorch, try working through the `}<a parentName="p" {...{
        "href": "https://pytorch.org/tutorials/beginner/introyt.html",
        "target": "_self",
        "rel": "nofollow"
      }}>{`PyTorch intro tutorial series (official)`}</a>{`.`}</p>
    <div {...{
      "className": "admonition admonition-note alert alert--secondary"
    }}><div parentName="div" {...{
        "className": "admonition-heading"
      }}><h5 parentName="div"><span parentName="h5" {...{
            "className": "admonition-icon"
          }}><svg parentName="span" {...{
              "xmlns": "http://www.w3.org/2000/svg",
              "width": "14",
              "height": "16",
              "viewBox": "0 0 14 16"
            }}><path parentName="svg" {...{
                "fillRule": "evenodd",
                "d": "M6.3 5.69a.942.942 0 0 1-.28-.7c0-.28.09-.52.28-.7.19-.18.42-.28.7-.28.28 0 .52.09.7.28.18.19.28.42.28.7 0 .28-.09.52-.28.7a1 1 0 0 1-.7.3c-.28 0-.52-.11-.7-.3zM8 7.99c-.02-.25-.11-.48-.31-.69-.2-.19-.42-.3-.69-.31H6c-.27.02-.48.13-.69.31-.2.2-.3.44-.31.69h1v3c.02.27.11.5.31.69.2.2.42.31.69.31h1c.27 0 .48-.11.69-.31.2-.19.3-.42.31-.69H8V7.98v.01zM7 2.3c-3.14 0-5.7 2.54-5.7 5.68 0 3.14 2.56 5.7 5.7 5.7s5.7-2.55 5.7-5.7c0-3.15-2.56-5.69-5.7-5.69v.01zM7 .98c3.86 0 7 3.14 7 7s-3.14 7-7 7-7-3.12-7-7 3.14-7 7-7z"
              }}></path></svg></span>{`NOTE `}</h5></div><div parentName="div" {...{
        "className": "admonition-content"
      }}><p parentName="div">{`You can use the `}<code parentName="p" {...{
            "className": "language-text"
          }}>{`uazhlt/ling-582-playground:latest`}</code>{` docker image to follow along.`}</p></div></div>
    {
      /* ## References */
    }
    {
      /* ## Footnotes */
    }

    </MDXLayout>;
}
// Tag the compiled component with the `isMDXComponent` marker flag.
// NOTE(review): presumably read by @mdx-js tooling to recognize compiled MDX
// components — confirm against the @mdx-js runtime in use.
Object.assign(MDXContent, { isMDXComponent: true });
      