Introduction to Logistic Regression
Logistic Regression is a supervised learning algorithm used for classification problems. Unlike Linear Regression, which predicts continuous values, Logistic Regression predicts discrete outcomes — often binary (0 or 1).
Think of it like answering a Yes/No question with probability:
- Will the email be spam or not?
- Will the customer buy the product or not?
- Is this tumor malignant or benign?
Real-Life Example: Email Spam Classification
Suppose you're building a model to classify emails as "Spam" (1) or "Not Spam" (0). You collect features like:
- Number of capital words in subject
- Presence of suspicious keywords (e.g., "lottery", "offer")
- Email length
If you apply Linear Regression, it may predict values like 1.2 or -0.4 — but we need a clear binary class! That’s where Logistic Regression shines by converting outputs into probabilities between 0 and 1 using the Sigmoid Function.
The Sigmoid (Logistic) Function
The logistic regression model computes:
y = 1 / (1 + e^(-z))
Where:
- z = b0 + b1*x1 + b2*x2 + ... + bn*xn (the linear combination of the input features)
- y = predicted probability (0 ≤ y ≤ 1)
We then use a threshold (typically 0.5) to decide:
- If y ≥ 0.5 → class 1 (e.g., Spam)
- If y < 0.5 → class 0 (e.g., Not Spam)
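To make this concrete, here is a minimal sketch of the sigmoid in plain NumPy (the helper name sigmoid and the coefficients b0, b1 are illustrative, not taken from a fitted model):
import numpy as np

def sigmoid(z):
    # Squashes any real number z into a probability between 0 and 1
    return 1 / (1 + np.exp(-z))

# Illustrative coefficients: intercept b0 and weight b1 for one feature
b0, b1 = -4.0, 1.0
x1 = 6.0                      # e.g., 6 hours studied
z = b0 + b1 * x1              # the linear combination z
p = sigmoid(z)                # predicted probability, ~0.88 here
label = 1 if p >= 0.5 else 0  # apply the 0.5 threshold -> class 1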
🧠 Question: Why not just use Linear Regression and round the value?
Answer: Linear Regression's output is unbounded, so a value like 1.2 or -0.4 isn't a probability, and the fit is sensitive to outliers. Logistic Regression passes the output through the sigmoid function, mapping it to a value between 0 and 1, which is exactly what a classification task needs.
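You can see the difference by fitting both models on the same toy data and comparing their raw outputs (the numbers below are made up purely for illustration):
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

# Tiny made-up dataset: one feature, binary labels
X = np.array([[1], [2], [3], [8], [9], [10]])
y = np.array([0, 0, 0, 1, 1, 1])

lin = LinearRegression().fit(X, y)
log = LogisticRegression().fit(X, y)

X_new = np.array([[0], [12]])
print(lin.predict(X_new))              # falls outside [0, 1] (negative here, then > 1)
print(log.predict_proba(X_new)[:, 1])  # always valid probabilities in (0, 1)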
Logistic Regression in Python (Step-by-Step)
Problem: Predict if a student will pass an exam
Let’s say you have data on students: their "Hours Studied" and whether they passed (1) or failed (0). We'll build a logistic regression model to predict the outcome.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
# Step 1: Create Dataset
data = {
"Hours_Studied": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"Passed": [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
}
df = pd.DataFrame(data)
# Step 2: Split features and target
X = df[["Hours_Studied"]] # input feature
y = df["Passed"] # target label
# Step 3: Train/Test Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
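# Note: with just 10 rows this split is only illustrative; a real project
# needs far more data (stratify=y would also keep both classes in each split).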
# Step 4: Train the Logistic Regression model
model = LogisticRegression()
model.fit(X_train, y_train)
# Step 5: Predict on test data
y_pred = model.predict(X_test)
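# Note: predict() uses a 0.5 probability cutoff by default;
# call predict_proba() to see the underlying probabilities.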
# Step 6: Print performance metrics
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("Classification Report:\n", classification_report(y_test, y_pred))
# Step 7: Plot the sigmoid curve
x_range = np.linspace(0, 11, 100)
# Pass the same column name used in training to avoid a feature-name warning
y_prob = model.predict_proba(pd.DataFrame({"Hours_Studied": x_range}))[:, 1]
plt.scatter(df["Hours_Studied"], df["Passed"], color='red', label='Actual')
plt.plot(x_range, y_prob, color='blue', label='Sigmoid Curve')
plt.xlabel("Hours Studied")
plt.ylabel("Probability of Passing")
plt.title("Logistic Regression - Exam Pass Prediction")
plt.legend()
plt.grid(True)
plt.show()
Code Explanation
- LogisticRegression(): the scikit-learn model used here for binary classification.
- predict(): predicts whether a student passed (0 or 1).
- predict_proba(): gives the probability of the student passing.
- confusion_matrix and classification_report: help us understand how well the model performs.
- The final plot shows the actual data vs. the predicted sigmoid curve.
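Continuing with the model and imports from the code above, here's a quick check on one new student (the 4.5 hours value is made up for illustration):
# Probability estimates for a student who studied 4.5 hours
new_student = pd.DataFrame({"Hours_Studied": [4.5]})
print(model.predict_proba(new_student))  # [[P(fail), P(pass)]]
print(model.predict(new_student))        # 0 or 1 after the 0.5 threshold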
🧠 Question: What if my data has multiple features like 'Hours Studied', 'Attendance', 'Sleep Hours'?
Answer: Logistic Regression handles multiple features easily. Just include more columns in X = df[[...]]. Scikit-learn handles the rest internally!
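Here's a minimal sketch using the extra columns from the question (the Attendance and Sleep_Hours values are hypothetical):
import pandas as pd
from sklearn.linear_model import LogisticRegression

# Hypothetical data with three features
df = pd.DataFrame({
    "Hours_Studied": [2, 4, 6, 8, 1, 9, 3, 7],
    "Attendance":    [60, 70, 85, 90, 50, 95, 65, 80],  # percent
    "Sleep_Hours":   [5, 6, 7, 8, 4, 8, 6, 7],
    "Passed":        [0, 0, 1, 1, 0, 1, 0, 1],
})

X = df[["Hours_Studied", "Attendance", "Sleep_Hours"]]  # several input columns
y = df["Passed"]

model = LogisticRegression().fit(X, y)
print(model.coef_)  # one learned weight per feature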
Summary
- Use Logistic Regression for binary classification problems.
- It outputs probabilities using the sigmoid function.
- Apply a threshold to convert probabilities into class predictions (see the sketch after this list).
- Scikit-learn makes it very easy to implement!
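For example, if false positives are costly you can raise the cutoff above the default 0.5. A minimal sketch, assuming the model and X_test from the lesson's code:
# Apply a stricter 0.7 threshold instead of the default 0.5
probs = model.predict_proba(X_test)[:, 1]   # probability of class 1 (pass)
y_pred_strict = (probs >= 0.7).astype(int)  # predict 1 only when quite confident
print(y_pred_strict)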
What’s Next?
In the next lesson, we’ll explore another popular classification algorithm — K-Nearest Neighbors (KNN).