Come utilizzare la propagazione dell'etichetta Scikit-Learn su dati strutturati grafici?


11

Come parte della mia ricerca, sono interessato a eseguire la propagazione delle etichette su un grafico. Sono particolarmente interessato a questi due metodi:

Ho visto che scikit-learn offre un modello per farlo. Tuttavia, questo modello dovrebbe essere applicato su dati strutturati vettoriali ( cioè punti dati).

Il modello crea una matrice di affinità dai punti dati usando un kernel, quindi esegue l'algoritmo sulla matrice costruita. Vorrei poter inserire direttamente la matrice di adiacenza del mio grafico al posto della matrice di somiglianza.

Qualche idea su come raggiungerlo? Oppure conosci qualche libreria Python che ti permetterà di eseguire la propagazione delle etichette direttamente su dati strutturati graficamente per i due suddetti metodi?

Grazie in anticipo per il vostro aiuto!


Hai controllato il codice sorgente di Scikit-learn per vedere cosa fa dopo aver calcolato la matrice di affinità? Forse potrebbe "copiare" il codice dopo quella parte per applicarlo direttamente alla matrice di adiacenza.
Tasos,

Grazie per il tuo commento! Quindi, in realtà, questo è quello che sto facendo attualmente, ma alcune parti del codice che devo modificare per soddisfare le mie esigenze sono in qualche modo criptiche. Temo che riscrivere quelle parti causerà errori. Speravo che esistesse un metodo più semplice.
Thibaud Martinez,

1
Il codice sorgente su github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/… - afferma che le implementazioni dovrebbero sovrascrivere il metodo _build_graph. Quindi, nativamente, dovresti provare a creare una classe derivata che accetta una matrice pre-calcolata.
Mikalai,

Risposte:


2

Rispondere alla mia domanda qui, poiché spero che possa essere utile ad alcuni lettori.

Scikit-learn è progettato principalmente per gestire dati strutturati vettoriali. Quindi, se si desidera eseguire la propagazione dell'etichetta / la diffusione dell'etichetta su dati strutturati in grafici, probabilmente è meglio reimplementare il metodo da soli piuttosto che utilizzare l'interfaccia di Scikit.

Ecco un'implementazione di Label Propagation e Label Spreading in PyTorch.

I due metodi in generale seguono gli stessi passaggi algoritmici, con variazioni su come viene normalizzata la matrice di adiacenza e su come vengono propagate le etichette ad ogni passaggio. Creiamo quindi una classe di base per i nostri due modelli.

from abc import abstractmethod
import torch

class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """
    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None 
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()

            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break

            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices

Il modello accetta come input la matrice di adiacenza del grafico e le etichette dei nodi. Le etichette hanno la forma di un vettore di un numero intero che indica il numero di classe di ciascun nodo con un -1 nella posizione dei nodi senza etichetta.

L'algoritmo di propagazione dell'etichetta è presentato di seguito.

W: matrice di adiacenza del grafico Calcola la matrice dei gradi diagonale D di DioioΣjWioj Inizializzare Y^(0)(y1,...,yl,0,0,...,0) iterate  1. Y^(t+1)D-1WY^(t) 2. Y^l(t+1)Yl fino alla convergenza verso Y^() Punto etichetta Xio dal segno di y^io()

Da Xiaojin Zhu e Zoubin Ghahramani. Imparare da dati etichettati e senza etichetta con la propagazione delle etichette. Rapporto tecnico CMU-CALD-02-107, Carnegie Mellon University, 2002

Otteniamo la seguente implementazione.

class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)

L'algoritmo di diffusione dell'etichetta è:

W: matrice di adiacenza del grafico Calcola la matrice dei gradi diagonale D di DioioΣjWioj Calcola il grafico normalizzato Laplacian LD-1/2WD-1/2 Inizializzare Y^(0)(y1,...,yl,0,0,...,0) Scegli un parametro α[0,1) iterate Y^(t+1)αLY^(t)+(1-α)Y^(0) fino alla convergenza verso Y^() Punto etichetta Xio dal segno di y^io()

Da Dengyong Zhou, Olivier Bousquet, Thomas Navin Lal, Jason Weston, Bernhard Schoelkopf. Imparare con coerenza locale e globale (2004)

L'implementazione è, quindi, la seguente.

class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)

Proviamo ora i nostri modelli di propagazione su dati sintetici. Per fare ciò, scegliamo di usare un grafico cavernicolo .

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()

# Create labels
labels = np.full(n_cliques * size_cliques, -1.)

# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3

# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()

# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]

plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)

ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")

pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)

# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)

plt.show()

I modelli implementati funzionano correttamente e consentono di rilevare le comunità nel grafico.

Propagazione delle etichette e implementazione delle etichette testate su un grafico cavernicolo

Nota: i metodi di propagazione presentati sono pensati per essere utilizzati su grafici non indirizzati.

Il codice è disponibile come taccuino interattivo Jupyter qui .

Raccoglitore

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.