Ocena modeli klasyfikacyjnych: miary i krzywa ROC

W jaki sposób możemy ocenić, że nasz model do klasyfikacji działa dobrze, albo lepiej niż inny model? Potrzebujemy go w pewien sposób ‘zmierzyć’, sprawdzić, jak dobrze radzi sobie z klasyfikacją danych. Zacznijmy od wyjaśnienia czym jest macierz błędów. Jest to tabela o wymiarach NxN, gdzie N reprezentuje liczbę klas w problemie klasyfikacji. Dla uproszczenia przyjmijmy klasyfikację binarną, czyli macierz 2×2. Macierz błędów jest bardzo przydatna do oceny jakości działania modelu klasyfikacyjnego, umożliwiając zrozumienie, które klasy są poprawnie rozpoznawane przez model, a które klasy są mylone.

Macierz błędów składa się z czterech głównych komponentów:

  1. True Positives (TP): Liczba przypadków, w których model poprawnie przewidział pozytywną klasę.
  2. True Negatives (TN): Liczba przypadków, w których model poprawnie przewidział negatywną klasę.
  3. False Positives (FP): Liczba przypadków, w których model błędnie przewidział pozytywną klasę. Oznacza to, że model mylnie stwierdził obecność obiektów danej klasy, gdy w rzeczywistości ich nie było. Często nazywane błędem pierwszego rodzaju.
  4. False Negatives (FN): Liczba przypadków, w których model błędnie przewidział negatywną klasę. Oznacza to, że model pomija obiekty danej klasy, które były obecne. Często nazywane błędem drugiego rodzaju.

Macierz błędów jest używana do obliczenia wielu metryk oceny jakości modelu, w tym:

1. Accuracy

Accuracy (dokładność) to jedna z najprostszych miar oceny wydajności modeli klasyfikacyjnych. Określa ona stosunek liczby poprawnych przewidywań do ogólnej liczby próbek, czyli procent poprawnie sklasyfikowanych przypadków w stosunku do ogólnej liczby przypadków. Wartość accuracy jest przydatna, gdy klasy w zbiorze danych są zrównoważone, jest często używana do oceny jakości modelu, ponieważ dostarcza jasny i zrozumiały sposób pomiaru jego skuteczności.

Accuracy = (TP + TN) / (TP + TN + FP + FN)​

2. Precission

Precission (precyzja) mierzy stosunek poprawnie sklasyfikowanych pozytywnych przypadków do wszystkich przypadków sklasyfikowanych jako pozytywne. Wysoka precyzja jest pożądana w sytuacjach, gdzie błędne pozytywne wyniki są kosztowne lub niebezpieczne. W dziedzinie bezpieczeństwa publicznego, na przykład w systemach antyterrorystycznych lub przeciwdziałania przestępczości, błędne pozytywne wyniki mogą prowadzić do fałszywych oskarżeń lub nadmiernego interweniowania. Jednak czasami może być konieczne znalezienie równowagi między precyzją a czułością (recall), która mierzy zdolność modelu do wykrywania wszystkich pozytywnych przypadków.

Precission = TP / (TP + FP)

3. Recall

Recall (czułość), zwany też True Positive Rate, to miara, która ocenia, ile spośród wszystkich rzeczywiście pozytywnych przypadków model jest w stanie wykryć poprawnie. Oznacza to, że recall mierzy zdolność modelu do unikania błędów typu II (False Negatives), czyli przypadków, które zostały błędnie sklasyfikowane jako negatywne, chociaż są pozytywne. Recall równy 1 oznacza, że model wykrywa wszystkie pozytywne przypadki w zbiorze danych bez błędów typu II.

Recall = TP / (TP + FN)

4. F1

Wskaźnik F1 jest średnią harmoniczną precission i recall i służy do oceny jakości klasyfikacji binarnej lub wieloklasowej. Pomaga w znalezieniu równowagi między dokładnością identyfikacji pozytywnych przypadków (precyzją) a zdolnością do wykrywania wszystkich rzeczywiście pozytywnych przypadków (recall). W systemach detekcji oszustw finansowych, wskaźnik F1 pomaga w ocenie jakości algorytmów, które muszą balansować między precyzją (unikanie fałszywych oskarżeń) a recall (wykrywanie wszystkich przypadków oszustw).

F1 = 2 * (Precission * Recall) / (Precission + Recall)

5. Krzywa ROC

Krzywa ROC to wykres przedstawiający relację między dwiema ważnymi miarami:

  1. True Positive Rate (TPR) (czułość): Oznacza zdolność modelu do poprawnej identyfikacji pozytywnych przypadków. Jest to to samo co recall, czyli
    TPR = TP / (TP + FN)
  2. False Positive Rate (FPR) (specyficzność): Oznacza odsetek błędnie sklasyfikowanych negatywnych przypadków do wszystkich rzeczywiście negatywnych przypadków.
    FPR = FP/(TN + FP)

Należy jeszcze przypomnieć, że modele klasyfikacyjne przypisują obserwacjom prawdopodobieństwa przynależności do klasy pozytywnej (klasa 1). Wyniki te mogą być przekształcane w klasy binarne na podstawie ustalonego progu odcięcia (domyślnie 0.5). Jeśli prawdopodobieństwo przekroczy próg, obserwacja zostanie sklasyfikowana jako pozytywna (1); w przeciwnym razie jako negatywna (0).

Krzywa ROC jest wykresem TPR na osi Y i FPR na osi X w zakresie różnych wartości progów odcięcia (tzw. treshold). Każdy próg odcięcia w zakresie [0, 1] generuje jedną parę (FPR, TPR), która reprezentuje punkt na krzywej ROC.

Przekątna na krzywej ROC (niebieska przerywana linia) to linia odniesienia, która reprezentuje wyniki losowego klasyfikatora, jakby klasyfikator działał na zasadzie czystego zgadywania lub losowego klasyfikowania próbek. Jeśli klasyfikator jest skuteczny, to jego krzywa ROC powinna znajdować się jak najdalej od linii odniesienia, czyli w górnym lewym rogu wykresu.

Na podstawie wykresu można obliczyć AUC-ROC (Area Under the ROC Curve). Jest to miara jakości klasyfikatora i jak sama nazwa wskazuje mierzy obszar pod krzywą ROC. Im większy obszar pod krzywą, tym lepsza ogólna jakość klasyfikatora. AUC-ROC jest obliczane jako całka numeryczna pod krzywą ROC.

Dla narysowanej krzywej ROC można też znaleźć punkt odcięcia (treshold). Idealnie byłoby znajdować się w lewym górnym rogu, gdzie TPR = 1 i FPR = 0. Jednak przesuwając się po krzywej ROC widzimy, że wraz ze wzrostem TPR wzrasta też niepożądany FPR, dlatego ważne jest zastanowienie się i znalezienie kompromisu. Jednym z przykładem, w jaki sposób odczytać punkt odcięcia jest znalezienie takiego punktu, dla którego różnica TPR – FPR będzie największa.

Poniżej przykładowy kod tworzący krzywą ROC i znajdujący punkt odcięcia.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Przykładowe dane 
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

# Podział danych na zbiór treningowy i testowy
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Utworzenie i dopasowanie modelu (tu używamy regresji logistycznej jako przykład)
model = LogisticRegression()
model.fit(X_train, y_train)

# Pobranie prawdopodobieństw przewidywanych przez model
y_scores = model.predict_proba(X_test)[:, 1]

# Obliczenie krzywej ROC
fpr, tpr, thresholds = roc_curve(y_test, y_scores)
roc_auc = auc(fpr, tpr)

# Narysowanie krzywej ROC
plt.figure(figsize=(4, 4))

plt.plot(fpr, tpr, color='darkorange', lw=2, label='Krzywa ROC')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])

# Dodanie etykiet i tytułu
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Krzywa ROC')

# Dodanie legendy w innym miejscu (lewy górny róg)
plt.legend(loc='upper left')

# Wyświetlenie wykresu
plt.show()
fpr, tpr, thresholds = roc_curve(y_test, y_scores)
m = np.argmax(tpr - fpr)
cut_off = thresholds[m]
print(cut_off)

Zobacz także:

  • Piotr Szymański

    Kategoria:

    Hejka! Zapraszam na skrót z minionych dwóch tygodni, który przyswoić możecie przy ciepłej herbatce w te mroczne, szare dni. W opublikowanym przez Google 14 listopada ostrzeżeniu wskazano kilka najważniejszych rodzajów oszustw internetowych. Uwagę zwrócono między na niebezpieczne techniki ataków typu cloaking, które nabierają nowego wymiaru dzięki wykorzystaniu sztucznej inteligencji. Cloaking polega na ukrywaniu przed użytkownikiem […]
  • Piotr Szymański

    Kategoria:

    Hejka po dłuższej przerwie! Zaczynamy świeżym tematem. Raptem kilkanaście godzin temu do użytkowników trafiła, zapowiedziana 25 lipca, funkcja SearchGPT od OpenAI, umożliwiająca, w przeciwieństwie do tradycyjnych modeli językowych, na integrację z internetem w czasie rzeczywistym. SearchGPT ma dostęp do aktualnych informacji z sieci, co pozwala na udzielanie odpowiedzi opartych na najnowszych danych. Ponadto SearchGPT dostarcza […]
  • Piotr Szymański

    Kategoria:

    Hejson! Dzisiejsza konsumpcja mediów ma to do siebie, że odbywa się na 5-6 calowym ekranie telefonu. Ma też to do siebie, że zanim zdjęcie dotrze do Ciebie, to przejdzie przez 6 konwersacji na jedynym słusznym messengerze, zatem zostanie 6-cio krotnie skompresowane. W międzyczasie, jak będziecie mieli pecha, to jakiś wujek zrobi screena, zamiast zapisać zdjęcie […]
  • Piotr Szymański

    Kategoria:

    Hej! Robimy bardzo dużo zdjęć, a co za tym idzie – wiele z nich jest niechlujnych, z zabałagnionym tłem. Możemy jednak chcieć wykorzystać je do pochwalenia się naszym ryjkiem na jakimś publicznym profilu, gdyż np. naturalne, miękkie światło korzystnie eksponuje naszą facjatę. Podejścia mogą być dwa – albo zdecydujemy się na blur bądź zupełne usunięcie […]
  • Piotr Szymański

    Kategoria:

    Strzałeczka. Nvidia przejęła OctoAI, startup specjalizujący się w optymalizacji modeli uczenia maszynowego. To już piąta akwizycja Nvidii w 2024 roku, co czyni aktualnie nam panujący rok rekordowym pod względem liczby przejęć. OctoAI, założone w 2019 roku przez Luisa Ceze, skupiło się na tworzeniu oprogramowania zwiększającego wydajność modeli uczenia maszynowego na różnych platformach sprzętowych. Oprogramowanie OctoAI […]