Perzeptron Algorithmus (Duale Form)

Der Perzeptron Algorithmus wird verwendet um Datenwolken voneinander zu trennen. Diese Datenwolken können z.B. Messwerte im Vektorraum sein. In der dualen Form ist es zusätzlich auch einfach möglich, die Daten in einen sogenannten “feature space” abzubilden. Dadurch können auch Datenwolken die vormals nicht linear trennbar waren, da sie sich zum Beispiel überlagert haben, voneinander getrennt werden.

Ich habe den Perzeptron Algorithmus wie er im Buch “An Introduction to Support Vector Machines and other Kernel-based Learning Methods” beschrieben wird, hier in Java umgesetzt. Das folgende Java Applet kann mit der linken / rechten Maustaste bedient werden. Mit der linken Maustaste setzt man rote Datenpunkte und mit der rechten Maustaste blaue Datenpunkte. Sobald man auf den Button “Start Perzeptron” klickt wird der Algorithmus ausgeführt und fittet eine Hyperebene zwischen die Datenwolken (Aber Achtung: der Perzeptron Algorithmus funktioniert hier nur korrekt wenn die Datenpunkte sich auch wirklich trennen lassen, ansonsten kommt eine Fehlermeldung). Diese Hyperebene wird hier jedoch zufällig zwischen den Datenwolken platziert. Damit diese Hyperebene optimal platziert werden kann, muss statt einem Perzeptron eine Support Vektor Machine angewandt werden. Ich habe für diesen Perzeptron Algorithmus 3 Kerne als Beispiel implementiert (linearer, polynomialer und gauss kern).

Der folgende Sourcecode ist nur als Beispiel für eine ganze einfache Implementierung eines Perzeptrons gedacht. Wer den kompletten Code für das Applet möchte, kann ihn hier als Eclipse Projekt herunterladen:

Perzeptron-Projekt

import java.applet.Applet;  
import java.awt.Color;  
import java.awt.Graphics;  
import java.awt.Graphics2D;  
import java.awt.event.MouseEvent;  
import java.awt.geom.Line2D;  
import java.awt.geom.Rectangle2D;  
import java.util.ArrayList;  
import java.util.List;

import javax.swing.JButton;  
import javax.swing.event.MouseInputAdapter;

public class DualPerzeptron2DApplet extends Applet {  
  public static class Label {
    public final static Double KLASSE_A = -1.0;
    public final static Double KLASSE_B = 1.0;
    private Double klasse;

    public Label(Double klasse) {
      this.klasse = klasse;
    }

    public void setKlasse(Double klasse) {
      this.klasse = klasse;
    }

    public Double getKlasse() {
      return this.klasse;
    }
  }

  public class Vektor2D {
    private Double x;
    private Double y;

    public Vektor2D(Double x, Double y) {
      this.x = x;
      this.y = y;
    }

    public Double getX() {
      return x;
    }

    public Double getY() {
      return y;
    }
  }

  private static final long serialVersionUID = -828153296140531761L;

  private List<Vektor2D> xList = new ArrayList<Vektor2D>();
  private List<Label> yList = new ArrayList<Label>();
  private List<Double> alphas = new ArrayList<Double>();
  private Double bias;
  private int l;
  private Double R;

  // hyperPlane besteht aus zwei Vektoren um die Hyperebene als Linie
  // zeichnen zu können
  private List<Vektor2D> hyperPlane = new ArrayList<Vektor2D>();

  public void init() {
    MouseInputAdapter mouseListener = new MouseInputAdapter() {
      public synchronized void mouseClicked(MouseEvent event) {
        if (event.getSource() instanceof JButton) {
          learn();
          createHyperplane();
          repaint();
          return;
        }

        Vektor2D xNew = new Vektor2D((double) event.getX(),
            (double) event.getY());
        Label yNew = null;

        if (event.getButton() == MouseEvent.BUTTON1)
          yNew = new Label(Label.KLASSE_A);
        else if (event.getButton() == MouseEvent.BUTTON3)
          yNew = new Label(Label.KLASSE_B);
        else
          return;

        xList.add(xNew);
        yList.add(yNew);
        alphas.add(0.0);
        repaint();
      }
    };

    setSize(400, 300);
    this.addMouseListener(mouseListener);

    JButton button = new JButton();
    button.setText("Start Perzeptron");
    button.addMouseListener(mouseListener);
    this.add(button);
  }

  private void createHyperplane() {
    hyperPlane = new ArrayList<Vektor2D>();

    // x muss fix eingesetzt werden, wir wollen die 0 Stellen
    // nur für y=0 und y=400 bestimmen --> gesucht: f(y) = 0

    // der X Wert vom Gewichtsvektor und das Bias wird auf die rechte Seite
    // der Gleichung geholt (durch Multiplikation mit -1) und durch den Y Wert
    // des Gewichtsvektors geteilt. Das Ergebnis ist der Y-Wert der Hyperebene
    // für den gegebenen X-Wert (0,400).

    double xFirst = 0;
    Double result1 = ((getWeightVektor().getX() * xFirst + bias) * (-1))
        / getWeightVektor().getY();
    Vektor2D vektor1 = new Vektor2D((double) xFirst, result1);
    hyperPlane.add(vektor1);

    double xLast = 400;
    Double result2 = ((getWeightVektor().getX() * xLast + bias) * (-1))
        / getWeightVektor().getY();
    Vektor2D vektor2 = new Vektor2D((double) xLast, result2);
    hyperPlane.add(vektor2);
  }

  @Override
  public void paint(Graphics g) {
    Graphics2D g2 = (Graphics2D) g;
    super.paint(g2);

    // sämtliche Punkte für beide Klassen zeichnen

    for (int i = 0; i < xList.size(); i++) {
      Vektor2D xValue = xList.get(i);
      Label yLabel = yList.get(i);

      if (yLabel.getKlasse() == Label.KLASSE_A)
        g2.setColor(Color.red);
      else if (yLabel.getKlasse() == Label.KLASSE_B)
        g2.setColor(Color.green);

      g2.draw(new Rectangle2D.Double(xValue.getX(), xValue.getY(), 3, 3));
      g2.drawString(xValue.getX() + " / " + xValue.getY(),
          Math.round(xValue.getX()), Math.round(xValue.getY()));
    }

    // wenn noch keine Hyperebene bestimmt wurde, nichts tun. Ansonsten den
    // ersten
    // und den letzten Punkt der Hyperebene verwenden um diese zu zeichnen.
    if (hyperPlane.size() == 0)
      return;

    g2.setColor(Color.blue);

    Vektor2D first = hyperPlane.get(0);
    Vektor2D last = hyperPlane.get(1);

    g2.draw(new Line2D.Double(first.getX(), first.getY(), last.getX(), last
        .getY()));
  }

  private void learn() {
    // Perzeptron Lernphase
    bias = 0.0;
    R = getMaximumLengthOfAllVectors();
    l = xList.size();

    Boolean mistake = true;
    while (mistake == true) {
      mistake = false;

      for (int i = 0; i < l; i++) {
        Double yi = yList.get(i).getKlasse();
        Vektor2D xi = xList.get(i);
        Double alphai = alphas.get(i);

        if ((yi * (vektorSkalarProdukt(getWeightVektor(), xi) + bias)) <= 0) {
          alphas.set(i, (alphai + 1.0));
          bias = bias + yi * R * R;

          mistake = true;
        }
      }
    }
  }

  private Vektor2D getWeightVektor() {
    // gibt einen Gewichtsvektor als Linearkombination der Trainingspunkte
    // zurück (duale Form)

    Vektor2D weight = new Vektor2D(0.0, 0.0);
    for (int j = 0; j < l; j++) {
      Double alphaj = alphas.get(j);
      Vektor2D xj = xList.get(j);
      Double yj = yList.get(j).getKlasse();

      Vektor2D newWeight = vektorSkalarMultiplikation(alphaj,
          vektorSkalarMultiplikation(yj, xj));
      weight = vektorAddition(weight, newWeight);
    }

    return weight;
  }

  private Double vektorSkalarProdukt(Vektor2D v1, Vektor2D v2) {
    // das Skalarprodukt in einem Vektorraum
    return (v1.getX() * v2.getX()) + (v1.getY() * v2.getY());
  }

  private Vektor2D vektorAddition(Vektor2D v1, Vektor2D v2) {
    // die Vektoraddition in einem Vektorraum
    Double x = v1.getX() + v2.getX();
    Double y = v1.getY() + v2.getY();
    return new Vektor2D(x, y);
  }

  private Vektor2D vektorSkalarMultiplikation(Double skalar, Vektor2D v1) {
    // die Vektor-Skalar-Multiplikation in einem Vektorraum
    Double x = v1.getX() * skalar;
    Double y = v1.getY() * skalar;
    return new Vektor2D(x, y);
  }

  private Double getVektorLength(Vektor2D vektor) {
    // die Länge eines Vektors in einem Vektorraum (Satz des Pythagoras)
    return Math.sqrt(vektor.getX() * vektor.getX() + vektor.getY()
        * vektor.getY());
  }

  private Double getMaximumLengthOfAllVectors() {
    // ermittle den Vektor mit der größten Länge
    Double max = 0.0;
    for (Vektor2D vektor : xList) {

      Double vektorLength = getVektorLength(vektor);
      if (vektorLength > max)
        max = vektorLength;
    }

    return max;
  }
}
comments powered by Disqus