hw7-regression's Introduction

HW 7: Logistic regression

In this assignment, you'll implement a classifier using logistic regression, optimized with gradient descent.

Overview

In class, we went over an implementation of linear regression using gradient descent. For this homework, you will implement a logistic regression model using the same framework. Logistic regression is useful for binary classification because the sigmoid function outputs a value between 0 and 1, which can be read as the predicted probability of the positive class.
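As a minimal sketch of the idea, the snippet below applies a sigmoid to a linear combination of features. It uses only the Python standard library, and the names `sigmoid` and `predict_proba` are illustrative assumptions, not the method names or signatures used in this repository.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function, written to avoid overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # z < 0, so exp(z) cannot overflow
    return ez / (1.0 + ez)


def predict_proba(weights, features):
    """Hypothetical helper: probability of class 1 for one sample."""
    z = sum(w * x for w, x in zip(weights, features))
    return sigmoid(z)


# A linear score of zero maps to a probability of exactly 0.5.
print(predict_proba([0.5, -0.25], [2.0, 4.0]))
```

Thresholding this probability (commonly at 0.5) turns it into a class label.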

In this repository, you are given a set of simulated medical record data from patients with small cell and non-small cell lung cancers. Your goal is to apply a logistic regression classifier to this dataset, predicting whether a patient has small cell or non-small cell lung cancer based on features of their medical record prior to diagnosis.

Logistic regression

As stated above, logistic regression involves using a sigmoid function to model the data. Just like in linear regression, we will define a loss function to keep track of how well the model performs. But instead of mean-squared error, you will implement the binary cross entropy loss function. This loss is small when the predicted y is close to the true label (1 or 0) and grows sharply when the model is confidently wrong. Here are some resources to get you started: [1], [2], [3].
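For reference, here is one possible sketch of binary cross entropy averaged over a set of samples. The function name, the argument order, and the `eps` clipping constant are assumptions made for illustration; the repository's `loss_function` method may be organized differently.

```python
import math


def binary_cross_entropy(y_true, y_pred, eps=1e-12):
    """Mean binary cross entropy over paired labels and probabilities.

    eps is an assumed clipping constant that keeps log() away from zero.
    """
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip into (0, 1)
        total += y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return -total / len(y_true)


labels = [1, 0]
print(binary_cross_entropy(labels, [0.9, 0.1]))  # confident and correct
print(binary_cross_entropy(labels, [0.1, 0.9]))  # confident and wrong
```

The second call returns a much larger loss than the first, which is the behavior that makes this loss a good fit for probabilistic classifiers.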

Dataset

You will find the full dataset in data/nsclc.csv. Class labels are encoded in the NSCLC column of the dataset, with 1 = NSCLC and 0 = small cell. A set of features has been pre-selected for you to use in your model during testing (see main.py), but you are encouraged to submit unit tests that look at different features. The full list of features can be found in logreg/utils.py.

Tasks + Grading

  • Algorithm Implementation (6):
    • Complete the make_prediction method (2)
    • Complete the loss_function method (2)
    • Complete the calculate_gradient method (2)
  • Unit tests (3):
    • Unit test for test_prediction
    • Unit test for test_loss_function
    • Unit test for test_gradient
    • Unit test for test_training
  • Code Readability (1)
  • Extra credit (0.5)
    • Github actions and workflows (up to 0.5)
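For the gradient task, a useful fact is that the gradient of the mean binary cross entropy with respect to the weights is the average of (prediction - label) times the feature vector. The sketch below shows this in plain Python; `bce_gradient` and its signature are hypothetical and are not the repository's `calculate_gradient` method.

```python
import math


def sigmoid(z):
    """Logistic function, written to avoid overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def bce_gradient(X, y, w):
    """Gradient of mean binary cross entropy with respect to weights w.

    X is a list of feature rows, y a list of 0/1 labels (assumed layout).
    """
    n = len(X)
    grad = [0.0] * len(w)
    for row, label in zip(X, y):
        p = sigmoid(sum(wj * xj for wj, xj in zip(w, row)))
        for j, xj in enumerate(row):
            grad[j] += (p - label) * xj / n
    return grad


# Two samples; the first column of ones acts as a bias term.
X_small = [[1.0, 2.0], [1.0, -1.0]]
y_small = [1, 0]
print(bce_gradient(X_small, y_small, [0.0, 0.0]))
```

A hand-computed case like this one (all-zero weights, so every prediction is 0.5) is a convenient starting point for a unit test of the gradient.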

Getting started

Fork this repository to your own GitHub account. Work on the codebase locally and commit changes to your forked repository.

You will need the following packages:

Additional notes

If you find that your model doesn't converge, try tuning the hyperparameters. A learning rate that is too high or a batch size that is too large can sometimes make the model unstable (e.g. the loss function goes to infinity). If you're interested, scikit-learn also has some built-in toy datasets that you can use for testing.
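One way to sanity-check hyperparameters is to train on a tiny dataset where the right answer is obvious and watch the loss history. The sketch below is a self-contained, full-batch gradient descent loop in plain Python; the data, the `train` function, and the default `learning_rate` and `n_steps` values are all assumptions for illustration, not the settings used in `main.py`.

```python
import math


def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def predict(X, w):
    """Predicted probability of class 1 for every row of X."""
    return [sigmoid(sum(wj * xj for wj, xj in zip(w, row))) for row in X]


def bce(y, p, eps=1e-12):
    """Mean binary cross entropy with clipped probabilities."""
    total = 0.0
    for yi, pi in zip(y, p):
        pi = min(max(pi, eps), 1.0 - eps)
        total += yi * math.log(pi) + (1 - yi) * math.log(1.0 - pi)
    return -total / len(y)


def train(X, y, learning_rate=0.1, n_steps=200):
    """Full-batch gradient descent starting from all-zero weights."""
    w = [0.0] * len(X[0])
    history = []
    for _ in range(n_steps):
        p = predict(X, w)
        history.append(bce(y, p))  # loss before this update
        grad = [
            sum((pi - yi) * row[j] for pi, yi, row in zip(p, y, X)) / len(X)
            for j in range(len(w))
        ]
        w = [wj - learning_rate * gj for wj, gj in zip(w, grad)]
    return w, history


# Made-up toy data: bias column plus one feature that separates the classes.
X_toy = [[1.0, -2.0], [1.0, -1.0], [1.0, 1.0], [1.0, 2.0]]
y_toy = [0, 0, 1, 1]
weights, losses = train(X_toy, y_toy)
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

If the loss history rises or oscillates instead of falling on data this simple, the learning rate is a likely culprit.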

We're applying a pretty simple model to a relatively complex problem here, so you should expect your classifier to perform decently but not amazingly. It's also possible for a given optimization run to get stuck in a local minimum depending on the initialization. With that said, if your implementation is correct and you have found reasonable hyperparameters, you should almost always at least do better than chance.

Contributors

zackmawaldi, khanu263, hchen725
