// Extracted from: GATE 5.1-beta2 (build 3402), Machine_Learning plugin,
// src/gate/creole/ml/weka/Wrapper.java
 
/*
 *  Copyright (c) 1998-2005, The University of Sheffield.
 *
 *  This file is part of GATE (see http://gate.ac.uk/), and is free
 *  software, licenced under the GNU Library General Public License,
 *  Version 2, June 1991 (in the distribution as file licence.html,
 *  and also available at http://gate.ac.uk/gate/licence.html).
 *
 *  Valentin Tablan 21/11/2002
 *  
 *  Modified by: Mahesh Joshi
 *  The changes include uncommenting of the SaveDatasetAsArff function
 *  and related UI action.
 *  
 *  Also, the module now supports pure dataset accumulation,
 *  without mandating the presence of a classifier or an output
 *  dataset file.
 *
 *  $Id: Wrapper.java 7030 2005-11-12 14:17:39 +0000 (Sat, 12 Nov 2005) julien_nioche $
 *
 */
package gate.creole.ml.weka;

import java.io.*;
import java.util.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import javax.swing.JFileChooser;
import javax.swing.JOptionPane;
import org.jdom.Element;
import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.*;
import weka.filters.Filter;
import gate.ProcessingResource;
import gate.creole.ExecutionException;
import gate.creole.ResourceInstantiationException;
import gate.creole.ml.DatasetDefintion;
import gate.creole.ml.AdvancedMLEngine;
import gate.event.StatusListener;
import gate.gui.ActionsPublisher;
import gate.gui.MainFrame;
import gate.util.*;

/**
 * Wrapper class for the WEKA Machine Learning Engine.
 * 
 * @see <a href="http://www.cs.waikato.ac.nz/ml/weka/">WEKA homepage</a>
 */
public class Wrapper implements AdvancedMLEngine, ActionsPublisher {
  public Wrapper() {
    actionsList = new ArrayList();
    actionsList.add(new LoadModelAction());
    actionsList.add(new SaveModelAction());
    actionsList.add(null);
    actionsList.add(new LoadDatasetFromArffAction());
    actionsList.add(new SaveDatasetAsArffAction());
  }
  /**
   * No clean up is needed for this wrapper, so this is just added because its
   * in the interface.
   */
  public void cleanUp() {
  }
  /**
   * Stores the JDOM options element; it is read later by {@link #init()}.
   */
  public void setOptions(Element optionsElem) {
    this.optionsElement = optionsElem;
  }
  /**
   * Some wrappers allow batch classification, but this one doesn't, so if it's
   * ever called just inform the user about this by throwing an exception.
   * 
   * @param instances
   *          This parameter is not used.
   * @return Nothing is ever returned - an exception is always thrown.
   * @throws ExecutionException
   */
  public List batchClassifyInstances(java.util.List instances) throws ExecutionException {
    throw new ExecutionException("The Weka wrapper does not support " + "batch classification. Remove the "
            + "<BATCH-MODE-CLASSIFICATION/> entry " + "from the XML configuration file and " + "try again.");
  }
  /**
   * Converts a list of attribute values into a Weka {@link Instance} and adds
   * it as a training instance.
   * 
   * @param attributeValues the values for the attributes of the new instance.
   * @throws ExecutionException if the instance cannot be built or stored.
   */
  public void addTrainingInstance(List attributeValues) throws ExecutionException {
    Instance instance = buildInstance(attributeValues);
    addTrainingInstance(instance);
  }
  /**
   * Adds one training instance: updates the classifier on the fly when it is
   * updateable, otherwise accumulates it in the in-memory dataset; also
   * appends it to the dataset file when one is configured.
   * 
   * @param instance the instance to add.
   * @throws ExecutionException if writing to the dataset file fails.
   */
  protected void addTrainingInstance(Instance instance) throws ExecutionException {
    if(classifier != null){
      if(classifier instanceof UpdateableClassifier){
        // the classifier can learn on the fly; we need to update it
        try{
          ((UpdateableClassifier)classifier).updateClassifier(instance);
        }catch(Exception e){
          throw new GateRuntimeException("Could not update updateable classifier! Problem was:\n" + e.toString());
        }
      }else{
        // the classifier is not updateable; we need to mark the dataset as
        // changed so the model gets rebuilt before the next classification
        dataset.add(instance);
        datasetChanged = true;
      }
    }
    if(datasetFile != null){
      // append the new instance to the ARFF file; close the writer in a
      // finally block so the file handle is not leaked if the write fails
      FileWriter fw = null;
      try{
        fw = new FileWriter(datasetFile, true);
        fw.write(instance.toString() + "\n");
        fw.flush();
      }catch(IOException ioe){
        throw new ExecutionException(ioe);
      }finally{
        if(fw != null){
          try{
            fw.close();
          }catch(IOException ioe){
            // ignore - the data has already been flushed
          }
        }
      }
    }
    // if only accumulating dataset,
    // create a cumulative dataset definition, unless
    // cleared by user
    if(onlyAccumulateDataset == true){
      dataset.add(instance);
    }
  }
  /**
   * Constructs an instance valid for the current dataset from a list of
   * attribute values.
   * 
   * @param attributeValues
   *          the values for the attributes.
   * @return an {@link weka.core.Instance} value.
   * @throws ExecutionException if the number of values does not match the
   *           dataset definition.
   */
  protected Instance buildInstance(List attributeValues) throws ExecutionException {
    // sanity check
    if(attributeValues.size() != datasetDefinition.getAttributes().size()){ throw new ExecutionException(
            "The number of attributes provided is wrong for this dataset!"); }
    int index = 0;
    Iterator attrIter = datasetDefinition.getAttributes().iterator();
    Iterator valuesIter = attributeValues.iterator();
    Instance instance = new Instance(attributeValues.size());
    instance.setDataset(dataset);
    while(attrIter.hasNext()){
      gate.creole.ml.Attribute attr = (gate.creole.ml.Attribute)attrIter.next();
      String value = (String)valuesIter.next();
      if(value == null){
        instance.setMissing(index);
      }else{
        if(attr.getFeature() == null){
          // boolean attribute ->the value should already be true/false
          instance.setValue(index, value);
        }else{
          // nominal, numeric or string attribute
          if(attr.getValues() != null){
            // nominal or string
            if(attr.getValues().isEmpty()){
              // string attribute
              instance.setValue(index, value);
            }else{
              // nominal attribute; values outside the declared set are
              // treated as missing rather than causing a failure
              if(attr.getValues().contains(value)){
                instance.setValue(index, value);
              }else{
                Out.prln("Warning: invalid value: \"" + value + "\" for attribute " + attr.getName() + " was ignored!");
                instance.setMissing(index);
              }
            }
          }else{
            // numeric attribute; unparseable values become missing values
            try{
              double db = Double.parseDouble(value);
              instance.setValue(index, db);
            }catch(Exception e){
              Out.prln("Warning: invalid numeric value: \"" + value + "\" for attribute " + attr.getName()
                      + " was ignored!");
              instance.setMissing(index);
            }
          }
        }
      }
      index++;
    }
    return instance;
  }
  public void setDatasetDefinition(DatasetDefintion definition) {
    this.datasetDefinition = definition;
  }
  /**
   * Classifies one instance built from the given attribute values. For
   * non-updateable classifiers the model is rebuilt first if the dataset has
   * changed. When a confidence threshold is set and the class attribute is
   * nominal, returns a List of all class labels whose probability reaches the
   * threshold; otherwise returns the single converted class value.
   * 
   * @param attributeValues the values for the attributes of the instance.
   * @return the classification result (a String, Double or List, depending on
   *         the class attribute and the confidence threshold).
   * @throws ExecutionException if classification fails.
   */
  public Object classifyInstance(List attributeValues) throws ExecutionException {
    Instance instance = buildInstance(attributeValues);
    try{
      if(classifier instanceof UpdateableClassifier){
        return convertAttributeValue(classifier.classifyInstance(instance));
      }else{
        if(datasetChanged){
          if(sListener != null) sListener.statusChanged("[Re]building model...");
          classifier.buildClassifier(dataset);
          datasetChanged = false;
          if(sListener != null) sListener.statusChanged("");
        }
        if(confidenceThreshold > 0 && dataset.classAttribute().type() == weka.core.Attribute.NOMINAL){
          // confidence set; use probability distribution
          double[] distribution = null;
          try{
            distribution = classifier.distributionForInstance(instance);
          }catch(Exception e){
            // if the classifier cannot return a distribution it will throw
            // a java.lang.Exception
            throw new ExecutionException(e);
          }
          List res = new ArrayList();
          for(int i = 0; i < distribution.length; i++){
            if(distribution[i] >= confidenceThreshold){
              res.add(dataset.classAttribute().value(i));
            }
          }
          return res;
        }else{
          // confidence not set; use simple classification
          return convertAttributeValue(classifier.classifyInstance(instance));
        }
      }
    }catch(Exception e){
      throw new ExecutionException(e);
    }
  }
  /**
   * Converts the numeric classification result returned by Weka into the
   * appropriate value: the label for nominal/boolean class attributes, or a
   * Double for numeric ones.
   * 
   * @param value the raw value returned by the classifier.
   * @return a String label or a Double, depending on the class attribute.
   */
  protected Object convertAttributeValue(double value) {
    gate.creole.ml.Attribute classAttr = datasetDefinition.getClassAttribute();
    List classValues = classAttr.getValues();
    if(classValues != null && !classValues.isEmpty()){
      // nominal attribute
      return dataset.attribute(datasetDefinition.getClassIndex()).value((int)value);
    }else{
      if(classAttr.getFeature() == null){
        // boolean attribute
        return dataset.attribute(datasetDefinition.getClassIndex()).value((int)value);
      }else{
        // numeric attribute
        return new Double(value);
      }
    }
  }
  /**
   * Initialises the classifier and prepares for running.
   * 
   * @throws GateException
   */
  public void init() throws GateException {
    // onlyAccumulateDataset is false by default
    onlyAccumulateDataset = false;
    // see if we can shout about what we're doing
    sListener = null;
    Map listeners = MainFrame.getListeners();
    if(listeners != null){
      sListener = (StatusListener)listeners.get("gate.event.StatusListener");
    }
    // find the classifier to be used
    if(sListener != null) sListener.statusChanged("Initialising classifier...");
    Element classifierElem = optionsElement.getChild("CLASSIFIER");
    if(classifierElem == null){
      Out.prln("Warning (WEKA ML engine): no classifier selected;" + " dataset collection only!");
      classifier = null;
    }else{
      String classifierClassName = classifierElem.getTextTrim();
      // get the options for the classifier
      String optionsString = null;
      if(sListener != null) sListener.statusChanged("Setting classifier options...");
      Element classifierOptionsElem = optionsElement.getChild("CLASSIFIER-OPTIONS");
      if(classifierOptionsElem != null){
        optionsString = classifierOptionsElem.getTextTrim();
      }
      // new style (OPTIONS attribute) overrides old style (child element)
      org.jdom.Attribute optionsAttribute = classifierElem.getAttribute("OPTIONS");
      if(optionsAttribute != null){
        optionsString = optionsAttribute.getValue().trim();
      }
      String[] options = parseOptions(optionsString);
      try{
        classifier = Classifier.forName(classifierClassName, options);
      }catch(Exception e){
        throw new GateException(e);
      }
      // if we have any filters apply them to the classifier
      List filterElems = optionsElement.getChildren("FILTER");
      if(filterElems != null && filterElems.size() > 0){
        Iterator elemIter = filterElems.iterator();
        while(elemIter.hasNext()){
          Element filterElem = (Element)elemIter.next();
          String filterClassName = filterElem.getTextTrim();
          String filterOptionsString = "";
          org.jdom.Attribute optionsAttr = filterElem.getAttribute("OPTIONS");
          if(optionsAttr != null){
            filterOptionsString = optionsAttr.getValue().trim();
          }
          // create the new filter
          try{
            Class filterClass = Class.forName(filterClassName);
            if(!Filter.class.isAssignableFrom(filterClass)){ throw new ResourceInstantiationException(filterClassName
                    + " is not a " + Filter.class.getName() + "!"); }
            Filter aFilter = (Filter)filterClass.newInstance();
            // apply the options to the filter
            if(filterOptionsString != null && filterOptionsString.length() > 0){
              if(!(aFilter instanceof OptionHandler)){ throw new ResourceInstantiationException(filterClassName
                      + " cannot handle options!"); }
              options = parseOptions(filterOptionsString);
              ((OptionHandler)aFilter).setOptions(options);
            }
            // apply the filter to the classifier by wrapping it; repeated
            // FILTER elements nest, so filters are applied outermost-last
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setClassifier(classifier);
            filteredClassifier.setFilter(aFilter);
            classifier = filteredClassifier;
          }catch(Exception e){
            throw new ResourceInstantiationException(e);
          }
        }
      }
      Element anElement = optionsElement.getChild("CONFIDENCE-THRESHOLD");
      if(anElement != null){
        try{
          confidenceThreshold = Double.parseDouble(anElement.getTextTrim());
        }catch(Exception e){
          throw new GateException("Could not parse confidence threshold value: " + anElement.getTextTrim() + "!");
        }
        // in the new version of WEKA all classifiers might be distribution
        // classifiers - there is no way to distinguish between them
        // if(!(classifier instanceof DistributionClassifier)){
        // throw new GateException(
        // "Cannot use confidence threshold with classifier: " +
        // classifier.getClass().getName() + "!");
        // }
      }
    }
    // find the file to be used for dataset
    Element datafileElem = optionsElement.getChild("DATASET-FILE");
    if(datafileElem != null){
      datasetFile = new File(datafileElem.getTextTrim());
      try{
        Out.prln("Warning (WEKA ML engine): writing dataset as ARFF to " + datasetFile.getCanonicalPath());
      }catch(IOException ioe){
        throw new ResourceInstantiationException(ioe);
      }
    }else{
      if(classifier == null){
        // both classifier and datasetFile are null
        // default to only accumulating dataset.
        onlyAccumulateDataset = true;
        Out.prln("Warning: Neither classifier or dataset file are specified in the "
                + "definition!\nThis only accumulates dataset "
                + "elements internally.\nMake sure to use the SaveDatasetAsArff"
                + "action or function to save the dataset to a file.");
      }
    }
    // initialise the dataset: build one Weka attribute for each attribute in
    // the GATE dataset definition
    if(sListener != null) sListener.statusChanged("Initialising dataset...");
    FastVector attributes = new FastVector();
    Iterator attIter = datasetDefinition.getAttributes().iterator();
    while(attIter.hasNext()){
      gate.creole.ml.Attribute aGateAttr = (gate.creole.ml.Attribute)attIter.next();
      weka.core.Attribute aWekaAttribute = null;
      if(aGateAttr.getValues() != null){
        // nominal or String attribute
        if(!aGateAttr.getValues().isEmpty()){
          // nominal attribute
          FastVector attrValues = new FastVector(aGateAttr.getValues().size());
          Iterator valIter = aGateAttr.getValues().iterator();
          while(valIter.hasNext()){
            attrValues.addElement(valIter.next());
          }
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), attrValues);
        }else{
          // VALUES element present but no values defined -> String attribute
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), (FastVector)null);
        }
      }else{
        if(aGateAttr.getFeature() == null){
          // boolean attribute ([lack of] presence of an annotation)
          FastVector attrValues = new FastVector(2);
          attrValues.addElement("true");
          attrValues.addElement("false");
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), attrValues);
        }else{
          // feature is not null but no values provided -> numeric attribute
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName());
        }
      }
      attributes.addElement(aWekaAttribute);
    }
    dataset = new Instances("Weka ML Engine Dataset", attributes, 0);
    dataset.setClassIndex(datasetDefinition.getClassIndex());
    // write the head of the datafile; close the writer even on failure
    if(datasetFile != null){
      FileWriter fw = null;
      try{
        fw = new FileWriter(datasetFile);
        fw.write(dataset.toString());
        fw.flush();
      }catch(IOException ioe){
        throw new ResourceInstantiationException(ioe);
      }finally{
        if(fw != null){
          try{
            fw.close();
          }catch(IOException ioe){
            // ignore - the data has already been flushed
          }
        }
      }
    }
    if(classifier != null && classifier instanceof UpdateableClassifier){
      try{
        classifier.buildClassifier(dataset);
      }catch(Exception e){
        throw new ResourceInstantiationException(e);
      }
    }
    if(sListener != null) sListener.statusChanged("");
  }
  /**
   * Splits a whitespace-separated options string into an array of options.
   * 
   * @param optionsString the raw options string; may be null or empty.
   * @return the individual options; never null.
   */
  protected String[] parseOptions(String optionsString) {
    String[] options = null;
    if(optionsString == null || optionsString.length() == 0){
      options = new String[]{};
    }else{
      List optionsList = new ArrayList();
      StringTokenizer strTok = new StringTokenizer(optionsString, " ", false);
      while(strTok.hasMoreTokens()){
        optionsList.add(strTok.nextToken());
      }
      options = (String[])optionsList.toArray(new String[optionsList.size()]);
    }
    return options;
  }
  /**
   * Loads the state of this engine from previously saved data.
   * 
   * @param is the stream to read the serialised state from; it is closed by
   *          this method.
   * @throws IOException if reading fails.
   */
  public void load(InputStream is) throws IOException {
    if(sListener != null) sListener.statusChanged("Loading model...");
    ObjectInputStream ois = new ObjectInputStream(is);
    try{
      classifier = (Classifier)ois.readObject();
      dataset = (Instances)ois.readObject();
      datasetDefinition = (DatasetDefintion)ois.readObject();
      datasetChanged = ois.readBoolean();
      confidenceThreshold = ois.readDouble();
    }catch(ClassNotFoundException cnfe){
      throw new GateRuntimeException(cnfe.toString());
    }finally{
      // close even when deserialisation fails so the stream is not leaked
      ois.close();
    }
    if(sListener != null) sListener.statusChanged("");
  }
  /**
   * Saves the state of the engine for reuse at a later time.
   * 
   * @param os the stream to write the serialised state to; it is closed by
   *          this method.
   * @throws IOException if writing fails.
   */
  public void save(OutputStream os) throws IOException {
    if(sListener != null) sListener.statusChanged("Saving model...");
    ObjectOutputStream oos = new ObjectOutputStream(os);
    try{
      oos.writeObject(classifier);
      oos.writeObject(dataset);
      oos.writeObject(datasetDefinition);
      oos.writeBoolean(datasetChanged);
      oos.writeDouble(confidenceThreshold);
      oos.flush();
    }finally{
      // close even when serialisation fails so the stream is not leaked
      oos.close();
    }
    if(sListener != null) sListener.statusChanged("");
  }
  /**
   * Gets the list of actions that can be performed on this resource.
   * 
   * @return a List of Action objects (or null values)
   */
  public List getActions() {
    return actionsList;
  }
  /**
   * Registers the PR using the engine with the engine itself.
   * 
   * @param pr
   *          the processing resource that owns this engine.
   */
  public void setOwnerPR(ProcessingResource pr) {
    this.owner = pr;
  }
  public DatasetDefintion getDatasetDefinition() {
    return datasetDefinition;
  }
  /**
   * Writes the current in-memory dataset to the given writer in ARFF format.
   * 
   * @param writer the destination writer; it is flushed but not closed.
   */
  public void saveDatasetAsARFF(FileWriter writer) {
    try{
      writer.write(dataset.toString());
      writer.flush();
    }catch(IOException ioe){
      throw new GateRuntimeException(ioe.getMessage());
    }
  }
  /**
   * Reads instances from an ARFF source and appends them to the current
   * dataset as training instances. The ARFF header must match the current
   * dataset definition.
   * 
   * @param reader the ARFF source; it is not closed by this method.
   * @throws ExecutionException if the loaded dataset's headers are
   *           incompatible with the current dataset.
   */
  public void loadDatasetFromArff(FileReader reader) throws IOException, ExecutionException, Exception {
    Instances newDataset = new Instances(reader);
    if(!dataset.equalHeaders(newDataset))
      throw new ExecutionException("Loaded dataset incompatible with the one " + " in the definition!");
    Enumeration instEnum = newDataset.enumerateInstances();
    while(instEnum.hasMoreElements()){
      addTrainingInstance((Instance)instEnum.nextElement());
    }
  }
  public boolean supportsBatchMode() {
    return false;
  }

  protected class SaveDatasetAsArffAction extends javax.swing.AbstractAction {
    public SaveDatasetAsArffAction() {
      super("Save dataset as ARFF");
      putValue(SHORT_DESCRIPTION, "Saves the dataset to a file in ARFF format");
    }
    public void actionPerformed(java.awt.event.ActionEvent evt) {
      Runnable runnable = new Runnable() {
        public void run() {
          JFileChooser fileChooser = MainFrame.getFileChooser();
          fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
          fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
          fileChooser.setMultiSelectionEnabled(false);
          if(fileChooser.showSaveDialog(null) == JFileChooser.APPROVE_OPTION){
            File file = fileChooser.getSelectedFile();
            FileWriter fw = null;
            try{
              MainFrame.lockGUI("Saving dataset...");
              fw = new FileWriter(file.getCanonicalPath(), false);
              saveDatasetAsARFF(fw);
            }catch(IOException ioe){
              JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + ioe.toString(), "Gate", JOptionPane.ERROR_MESSAGE);
              ioe.printStackTrace(Err.getPrintWriter());
            }finally{
              // close the writer even when saving fails, then unlock the GUI
              if(fw != null){
                try{
                  fw.close();
                }catch(IOException ioe){
                  // ignore - nothing more can be done at this point
                }
              }
              MainFrame.unlockGUI();
            }
          }
        }
      };
      Thread thread = new Thread(runnable, "DatasetSaver(ARFF)");
      thread.setPriority(Thread.MIN_PRIORITY);
      thread.start();
    }
  }
  protected class LoadDatasetFromArffAction extends javax.swing.AbstractAction {
    public LoadDatasetFromArffAction() {
      super("Load data from ARFF");
      putValue(SHORT_DESCRIPTION, "Loads training data from a file in ARFF format and "
              + "appends it to the current dataset.");
    }
    public void actionPerformed(java.awt.event.ActionEvent evt) {
      Runnable runnable = new Runnable() {
        public void run() {
          JFileChooser fileChooser = MainFrame.getFileChooser();
          fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
          fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
          fileChooser.setMultiSelectionEnabled(false);
          if(fileChooser.showOpenDialog(null) == JFileChooser.APPROVE_OPTION){
            File file = fileChooser.getSelectedFile();
            FileReader reader = null;
            try{
              MainFrame.lockGUI("Loading dataset...");
              reader = new FileReader(file.getCanonicalPath());
              loadDatasetFromArff(reader);
            }catch(Exception e){
              JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + e.toString(), "GATE", JOptionPane.ERROR_MESSAGE);
              e.printStackTrace(Err.getPrintWriter());
            }finally{
              // close the reader even when loading fails, then unlock the GUI
              if(reader != null){
                try{
                  reader.close();
                }catch(IOException ioe){
                  // ignore - nothing more can be done at this point
                }
              }
              MainFrame.unlockGUI();
            }
          }
        }
      };
      // note: this thread loads data (the original name "DatasetSaver(ARFF)"
      // was a copy-paste slip)
      Thread thread = new Thread(runnable, "DatasetLoader(ARFF)");
      thread.setPriority(Thread.MIN_PRIORITY);
      thread.start();
    }
  }
  protected class SaveModelAction extends javax.swing.AbstractAction {
    public SaveModelAction() {
      super("Save model");
      putValue(SHORT_DESCRIPTION, "Saves the ML model to a file");
    }
    public void actionPerformed(java.awt.event.ActionEvent evt) {
      Runnable runnable = new Runnable() {
        public void run() {
          JFileChooser fileChooser = MainFrame.getFileChooser();
          fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
          fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
          fileChooser.setMultiSelectionEnabled(false);
          if(fileChooser.showSaveDialog(null) == JFileChooser.APPROVE_OPTION){
            File file = fileChooser.getSelectedFile();
            try{
              MainFrame.lockGUI("Saving ML model...");
              // save() closes the stream itself
              save(new GZIPOutputStream(new FileOutputStream(file.getCanonicalPath(), false)));
            }catch(IOException ioe){
              JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + ioe.toString(), "GATE", JOptionPane.ERROR_MESSAGE);
              ioe.printStackTrace(Err.getPrintWriter());
            }finally{
              MainFrame.unlockGUI();
            }
          }
        }
      };
      Thread thread = new Thread(runnable, "ModelSaver(serialisation)");
      thread.setPriority(Thread.MIN_PRIORITY);
      thread.start();
    }
  }
  protected class LoadModelAction extends javax.swing.AbstractAction {
    public LoadModelAction() {
      super("Load model");
      putValue(SHORT_DESCRIPTION, "Loads a ML model from a file");
    }
    public void actionPerformed(java.awt.event.ActionEvent evt) {
      Runnable runnable = new Runnable() {
        public void run() {
          JFileChooser fileChooser = MainFrame.getFileChooser();
          fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
          fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
          fileChooser.setMultiSelectionEnabled(false);
          if(fileChooser.showOpenDialog(null) == JFileChooser.APPROVE_OPTION){
            File file = fileChooser.getSelectedFile();
            try{
              MainFrame.lockGUI("Loading model...");
              // load() closes the stream itself
              load(new GZIPInputStream(new FileInputStream(file)));
            }catch(IOException ioe){
              JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + ioe.toString(), "GATE", JOptionPane.ERROR_MESSAGE);
              ioe.printStackTrace(Err.getPrintWriter());
            }finally{
              MainFrame.unlockGUI();
            }
          }
        }
      };
      Thread thread = new Thread(runnable, "ModelLoader(serialisation)");
      thread.setPriority(Thread.MIN_PRIORITY);
      thread.start();
    }
  }

  /** The GATE dataset definition describing the attributes. */
  protected DatasetDefintion datasetDefinition;
  /** Minimum class probability for a label to be returned; 0 disables it. */
  double confidenceThreshold = 0;
  /**
   * The WEKA classifier used by this wrapper
   */
  protected Classifier classifier;
  /**
   * The dataset used for training
   */
  protected Instances dataset;
  /**
   * The JDom element containing the options for this wrapper.
   */
  protected Element optionsElement;
  /**
   * Marks whether the dataset was changed since the last time the classifier
   * was built.
   */
  protected boolean datasetChanged = false;
  /** The file the dataset is appended to in ARFF format; may be null. */
  protected File datasetFile;
  /** Actions exposed in the GUI; null entries are separators. */
  protected List actionsList;
  /** The processing resource that owns this engine. */
  protected ProcessingResource owner;
  /** Listener used to report status messages in the GUI; may be null. */
  protected StatusListener sListener;
  /**
   * This variable is set when the ML configuration file has neither the
   * classifier nor the output dataset file option. User is responsible for
   * explicitly saving the dataset using the now uncommented SaveDatasetAsArff
   * function / action.
   */
  protected boolean onlyAccumulateDataset;
}