/* * Copyright (c) 1998-2005, The University of Sheffield. * * This file is part of GATE (see http://gate.ac.uk/), and is free * software, licenced under the GNU Library General Public License, * Version 2, June 1991 (in the distribution as file licence.html, * and also available at http://gate.ac.uk/gate/licence.html). * * Valentin Tablan 21/11/2002 * * Modified by: Mahesh Joshi * The changes include uncommenting of the SaveDatasetAsArff function * and related UI action. * * Also, the module now supports pure dataset accumulation, * without mandating the presence of a classifier or an output * dataset file. * * $Id: Wrapper.java 7030 2005-11-12 14:17:39 +0000 (Sat, 12 Nov 2005) julien_nioche $ * */ package gate.creole.ml.weka; import java.io.*; import java.util.*; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import javax.swing.JFileChooser; import javax.swing.JOptionPane; import org.jdom.Element; import weka.classifiers.Classifier; import weka.classifiers.UpdateableClassifier; import weka.classifiers.meta.FilteredClassifier; import weka.core.*; import weka.filters.Filter; import gate.ProcessingResource; import gate.creole.ExecutionException; import gate.creole.ResourceInstantiationException; import gate.creole.ml.DatasetDefintion; import gate.creole.ml.AdvancedMLEngine; import gate.event.StatusListener; import gate.gui.ActionsPublisher; import gate.gui.MainFrame; import gate.util.*; /** * Wrapper class for the WEKA Machine Learning Engine. * * @see <a href="http://www.cs.waikato.ac.nz/ml/weka/">WEKA homepage</a> */ public class Wrapper implements AdvancedMLEngine, ActionsPublisher { public Wrapper() { actionsList = new ArrayList(); actionsList.add(new LoadModelAction()); actionsList.add(new SaveModelAction()); actionsList.add(null); actionsList.add(new LoadDatasetFromArffAction()); actionsList.add(new SaveDatasetAsArffAction()); } /** * No clean up is needed for this wrapper, so this is just added because its * in the interface. */ public void cleanUp() { } public void setOptions(Element optionsElem) { this.optionsElement = optionsElem; } /** * Some wrappers allow batch classification, but this one doesn't, so if it's * ever called just inform the user about this by throwing an exception. * * @param instances * This parameter is not used. * @return Nothing is ever returned - an exception is always thrown. * @throws ExecutionException */ public List batchClassifyInstances(java.util.List instances) throws ExecutionException { throw new ExecutionException("The Weka wrapper does not support " + "batch classification. Remove the " + "<BATCH-MODE-CLASSIFICATION/> entry " + "from the XML configuration file and " + "try again."); } public void addTrainingInstance(List attributeValues) throws ExecutionException { Instance instance = buildInstance(attributeValues); addTrainingInstance(instance); } protected void addTrainingInstance(Instance instance) throws ExecutionException { if(classifier != null){ if(classifier instanceof UpdateableClassifier){ // the classifier can learn on the fly; we need to update it try{ ((UpdateableClassifier)classifier).updateClassifier(instance); }catch(Exception e){ throw new GateRuntimeException("Could not update updateable classifier! Problem was:\n" + e.toString()); } }else{ // the classifier is not updatebale; we need to mark the dataset as // changed dataset.add(instance); datasetChanged = true; } } if(datasetFile != null){ // write the new instance to the file try{ FileWriter fw = new FileWriter(datasetFile, true); fw.write(instance.toString() + "\n"); fw.flush(); fw.close(); }catch(IOException ioe){ throw new ExecutionException(ioe); } } // if only accumulating dataset, // create a cumulative dataset definition, unless // cleared by user if(onlyAccumulateDataset == true){ dataset.add(instance); } } /** * Constructs an instance valid for the current dataset from a list of * attribute values. * * @param attributeValues * the values for the attributes. * @return an {@link weka.core.Instance} value. */ protected Instance buildInstance(List attributeValues) throws ExecutionException { // sanity check if(attributeValues.size() != datasetDefinition.getAttributes().size()){ throw new ExecutionException( "The number of attributes provided is wrong for this dataset!"); } double[] values = new double[datasetDefinition.getAttributes().size()]; int index = 0; Iterator attrIter = datasetDefinition.getAttributes().iterator(); Iterator valuesIter = attributeValues.iterator(); Instance instance = new Instance(attributeValues.size()); instance.setDataset(dataset); while(attrIter.hasNext()){ gate.creole.ml.Attribute attr = (gate.creole.ml.Attribute)attrIter.next(); String value = (String)valuesIter.next(); if(value == null){ instance.setMissing(index); }else{ if(attr.getFeature() == null){ // boolean attribute ->the value should already be true/false instance.setValue(index, value); }else{ // nominal, numeric or string attribute if(attr.getValues() != null){ // nominal or string if(attr.getValues().isEmpty()){ // string attribute instance.setValue(index, value); }else{ // nominal attribute if(attr.getValues().contains(value)){ instance.setValue(index, value); }else{ Out.prln("Warning: invalid value: \"" + value + "\" for attribute " + attr.getName() + " was ignored!"); instance.setMissing(index); } } }else{ // numeric attribute try{ double db = Double.parseDouble(value); instance.setValue(index, db); }catch(Exception e){ Out.prln("Warning: invalid numeric value: \"" + value + "\" for attribute " + attr.getName() + " was ignored!"); instance.setMissing(index); } } } } index++; } return instance; } public void setDatasetDefinition(DatasetDefintion definition) { this.datasetDefinition = definition; } public Object classifyInstance(List attributeValues) throws ExecutionException { Instance instance = buildInstance(attributeValues); // double result; try{ if(classifier instanceof UpdateableClassifier){ return convertAttributeValue(classifier.classifyInstance(instance)); }else{ if(datasetChanged){ if(sListener != null) sListener.statusChanged("[Re]building model..."); classifier.buildClassifier(dataset); datasetChanged = false; if(sListener != null) sListener.statusChanged(""); } if(confidenceThreshold > 0 && dataset.classAttribute().type() == weka.core.Attribute.NOMINAL){ // confidence set; use probability distribution double[] distribution = null; try{ distribution = classifier.distributionForInstance(instance); }catch(Exception e){ // if the classifier cannot return a distribution it will throw // a java.lang.Exception throw new ExecutionException(e); } List res = new ArrayList(); for(int i = 0; i < distribution.length; i++){ if(distribution[i] >= confidenceThreshold){ res.add(dataset.classAttribute().value(i)); } } return res; }else{ // confidence not set; use simple classification return convertAttributeValue(classifier.classifyInstance(instance)); } } }catch(Exception e){ throw new ExecutionException(e); } } protected Object convertAttributeValue(double value) { gate.creole.ml.Attribute classAttr = datasetDefinition.getClassAttribute(); List classValues = classAttr.getValues(); if(classValues != null && !classValues.isEmpty()){ // nominal attribute return dataset.attribute(datasetDefinition.getClassIndex()).value((int)value); }else{ if(classAttr.getFeature() == null){ // boolean attribute return dataset.attribute(datasetDefinition.getClassIndex()).value((int)value); }else{ // numeric attribute return new Double(value); } } } /** * Initialises the classifier and prepares for running. * * @throws GateException */ public void init() throws GateException { // onlyAccumulateDataset is false by default onlyAccumulateDataset = false; // see if we can shout about what we're doing sListener = null; Map listeners = MainFrame.getListeners(); if(listeners != null){ sListener = (StatusListener)listeners.get("gate.event.StatusListener"); } // find the classifier to be used if(sListener != null) sListener.statusChanged("Initialising classifier..."); Element classifierElem = optionsElement.getChild("CLASSIFIER"); if(classifierElem == null){ Out.prln("Warning (WEKA ML engine): no classifier selected;" + " dataset collection only!"); classifier = null; }else{ String classifierClassName = classifierElem.getTextTrim(); // get the options for the classiffier String optionsString = null; if(sListener != null) sListener.statusChanged("Setting classifier options..."); Element classifierOptionsElem = optionsElement.getChild("CLASSIFIER-OPTIONS"); if(classifierOptionsElem != null){ optionsString = classifierOptionsElem.getTextTrim(); } // new style overrides old style org.jdom.Attribute optionsAttribute = classifierElem.getAttribute("OPTIONS"); if(optionsAttribute != null){ optionsString = optionsAttribute.getValue().trim(); } String[] options = parseOptions(optionsString); try{ classifier = Classifier.forName(classifierClassName, options); }catch(Exception e){ throw new GateException(e); } // if we have any filters apply them to the classifer List filterElems = optionsElement.getChildren("FILTER"); if(filterElems != null && filterElems.size() > 0){ Iterator elemIter = filterElems.iterator(); while(elemIter.hasNext()){ Element filterElem = (Element)elemIter.next(); String filterClassName = filterElem.getTextTrim(); String filterOptionsString = ""; org.jdom.Attribute optionsAttr = filterElem.getAttribute("OPTIONS"); if(optionsAttr != null){ filterOptionsString = optionsAttr.getValue().trim(); } // create the new filter try{ Class filterClass = Class.forName(filterClassName); if(!Filter.class.isAssignableFrom(filterClass)){ throw new ResourceInstantiationException(filterClassName + " is not a " + Filter.class.getName() + "!"); } Filter aFilter = (Filter)filterClass.newInstance(); // apply the options to the filter if(filterOptionsString != null && filterOptionsString.length() > 0){ if(!(aFilter instanceof OptionHandler)){ throw new ResourceInstantiationException(filterClassName + " cannot handle options!"); } options = parseOptions(filterOptionsString); ((OptionHandler)aFilter).setOptions(options); } // apply the filter to the classifier FilteredClassifier Fclassifier = new FilteredClassifier(); Fclassifier.setClassifier(classifier); Fclassifier.setFilter(aFilter); classifier = Fclassifier; }catch(Exception e){ throw new ResourceInstantiationException(e); } } } Element anElement = optionsElement.getChild("CONFIDENCE-THRESHOLD"); if(anElement != null){ try{ confidenceThreshold = Double.parseDouble(anElement.getTextTrim()); }catch(Exception e){ throw new GateException("Could not parse confidence threshold value: " + anElement.getTextTrim() + "!"); } // in the new version of WEKA all classifiers might be distribution // classifiers - there is no way to distinguish between them // if(!(classifier instanceof DistributionClassifier)){ // throw new GateException( // "Cannot use confidence threshold with classifier: " + // classifier.getClass().getName() + "!"); // } } } // find the file to be used for dataset Element datafileElem = optionsElement.getChild("DATASET-FILE"); if(datafileElem != null){ datasetFile = new File(datafileElem.getTextTrim()); try{ Out.prln("Warning (WEKA ML engine): writing dataset as ARFF to " + datasetFile.getCanonicalPath()); }catch(IOException ioe){ throw new ResourceInstantiationException(ioe); } }else{ if(classifier == null){ // both classifier and datasetFile are null // default to only accumulating dataset. onlyAccumulateDataset = true; Out.prln("Warning: Neither classifier or dataset file are specified in the " + "definition!\nThis only accumulates dataset " + "elements internally.\nMake sure to use the SaveDatasetAsArff" + "action or function to save the dataset to a file."); } } // initialise the dataset if(sListener != null) sListener.statusChanged("Initialising dataset..."); FastVector attributes = new FastVector(); weka.core.Attribute classAttribute; Iterator attIter = datasetDefinition.getAttributes().iterator(); while(attIter.hasNext()){ gate.creole.ml.Attribute aGateAttr = (gate.creole.ml.Attribute)attIter.next(); weka.core.Attribute aWekaAttribute = null; if(aGateAttr.getValues() != null){ // nominal or String attribute if(!aGateAttr.getValues().isEmpty()){ // nominal attribute FastVector attrValues = new FastVector(aGateAttr.getValues().size()); Iterator valIter = aGateAttr.getValues().iterator(); while(valIter.hasNext()){ attrValues.addElement(valIter.next()); } aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), attrValues); }else{ // VALUES element present but no values defined -> String attribute aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), (FastVector)null); } }else{ if(aGateAttr.getFeature() == null){ // boolean attribute ([lack of] presence of an annotation) FastVector attrValues = new FastVector(2); attrValues.addElement("true"); attrValues.addElement("false"); aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(), attrValues); }else{ // feature is not null but no values provided -> numeric attribute aWekaAttribute = new weka.core.Attribute(aGateAttr.getName()); } } if(aGateAttr.isClass()) classAttribute = aWekaAttribute; attributes.addElement(aWekaAttribute); } dataset = new Instances("Weka ML Engine Dataset", attributes, 0); dataset.setClassIndex(datasetDefinition.getClassIndex()); // write the head of the datafile if(datasetFile != null){ try{ FileWriter fw = new FileWriter(datasetFile); fw.write(dataset.toString()); fw.flush(); fw.close(); }catch(IOException ioe){ throw new ResourceInstantiationException(ioe); } } if(classifier != null && classifier instanceof UpdateableClassifier){ try{ classifier.buildClassifier(dataset); }catch(Exception e){ throw new ResourceInstantiationException(e); } } if(sListener != null) sListener.statusChanged(""); } protected String[] parseOptions(String optionsString) { String[] options = null; if(optionsString == null || optionsString.length() == 0){ options = new String[]{}; }else{ List optionsList = new ArrayList(); StringTokenizer strTok = new StringTokenizer(optionsString, " ", false); while(strTok.hasMoreTokens()){ optionsList.add(strTok.nextToken()); } options = (String[])optionsList.toArray(new String[optionsList.size()]); } return options; } /** * Loads the state of this engine from previously saved data. * * @param is */ public void load(InputStream is) throws IOException { if(sListener != null) sListener.statusChanged("Loading model..."); ObjectInputStream ois = new ObjectInputStream(is); try{ classifier = (Classifier)ois.readObject(); dataset = (Instances)ois.readObject(); datasetDefinition = (DatasetDefintion)ois.readObject(); datasetChanged = ois.readBoolean(); confidenceThreshold = ois.readDouble(); }catch(ClassNotFoundException cnfe){ throw new GateRuntimeException(cnfe.toString()); } ois.close(); if(sListener != null) sListener.statusChanged(""); } /** * Saves the state of the engine for reuse at a later time. * * @param os */ public void save(OutputStream os) throws IOException { if(sListener != null) sListener.statusChanged("Saving model..."); ObjectOutputStream oos = new ObjectOutputStream(os); oos.writeObject(classifier); oos.writeObject(dataset); oos.writeObject(datasetDefinition); oos.writeBoolean(datasetChanged); oos.writeDouble(confidenceThreshold); oos.flush(); oos.close(); if(sListener != null) sListener.statusChanged(""); } /** * Gets the list of actions that can be performed on this resource. * * @return a List of Action objects (or null values) */ public List getActions() { return actionsList; } /** * Registers the PR using the engine with the engine itself. * * @param pr * the processing resource that owns this engine. */ public void setOwnerPR(ProcessingResource pr) { this.owner = pr; } public DatasetDefintion getDatasetDefinition() { return datasetDefinition; } public void saveDatasetAsARFF(FileWriter writer) { try{ writer.write(dataset.toString()); writer.flush(); }catch(IOException ioe){ throw new GateRuntimeException(ioe.getMessage()); } } public void loadDatasetFromArff(FileReader reader) throws IOException, ExecutionException, Exception { Instances newDataset = new Instances(reader); if(!dataset.equalHeaders(newDataset)) throw new ExecutionException("Loaded dataset incompatible with the one " + " in the definition!"); Enumeration instEnum = newDataset.enumerateInstances(); while(instEnum.hasMoreElements()){ addTrainingInstance((Instance)instEnum.nextElement()); } } public boolean supportsBatchMode() { return false; } protected class SaveDatasetAsArffAction extends javax.swing.AbstractAction { public SaveDatasetAsArffAction() { super("Save dataset as ARFF"); putValue(SHORT_DESCRIPTION, "Saves the dataset to a file in ARFF format"); } public void actionPerformed(java.awt.event.ActionEvent evt) { Runnable runnable = new Runnable() { public void run() { JFileChooser fileChooser = MainFrame.getFileChooser(); fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter()); fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); fileChooser.setMultiSelectionEnabled(false); if(fileChooser.showSaveDialog(null) == JFileChooser.APPROVE_OPTION){ File file = fileChooser.getSelectedFile(); try{ MainFrame.lockGUI("Saving dataset..."); FileWriter fw = new FileWriter(file.getCanonicalPath(), false); saveDatasetAsARFF(fw); fw.close(); }catch(IOException ioe){ JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + ioe.toString(), "Gate", JOptionPane.ERROR_MESSAGE); ioe.printStackTrace(Err.getPrintWriter()); }finally{ MainFrame.unlockGUI(); } } } }; Thread thread = new Thread(runnable, "DatasetSaver(ARFF)"); thread.setPriority(Thread.MIN_PRIORITY); thread.start(); } } protected class LoadDatasetFromArffAction extends javax.swing.AbstractAction { public LoadDatasetFromArffAction() { super("Load data from ARFF"); putValue(SHORT_DESCRIPTION, "Loads training data from a file in ARFF format and " + "appends it to the current dataset."); } public void actionPerformed(java.awt.event.ActionEvent evt) { Runnable runnable = new Runnable() { public void run() { JFileChooser fileChooser = MainFrame.getFileChooser(); fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter()); fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); fileChooser.setMultiSelectionEnabled(false); if(fileChooser.showOpenDialog(null) == JFileChooser.APPROVE_OPTION){ File file = fileChooser.getSelectedFile(); try{ MainFrame.lockGUI("Loading dataset..."); FileReader reader = new FileReader(file.getCanonicalPath()); loadDatasetFromArff(reader); reader.close(); }catch(Exception e){ JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + e.toString(), "GATE", JOptionPane.ERROR_MESSAGE); e.printStackTrace(Err.getPrintWriter()); }finally{ MainFrame.unlockGUI(); } } } }; Thread thread = new Thread(runnable, "DatasetSaver(ARFF)"); thread.setPriority(Thread.MIN_PRIORITY); thread.start(); } } protected class SaveModelAction extends javax.swing.AbstractAction { public SaveModelAction() { super("Save model"); putValue(SHORT_DESCRIPTION, "Saves the ML model to a file"); } public void actionPerformed(java.awt.event.ActionEvent evt) { Runnable runnable = new Runnable() { public void run() { JFileChooser fileChooser = MainFrame.getFileChooser(); fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter()); fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); fileChooser.setMultiSelectionEnabled(false); if(fileChooser.showSaveDialog(null) == JFileChooser.APPROVE_OPTION){ File file = fileChooser.getSelectedFile(); try{ MainFrame.lockGUI("Saving ML model..."); save(new GZIPOutputStream(new FileOutputStream(file.getCanonicalPath(), false))); }catch(IOException ioe){ JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + ioe.toString(), "GATE", JOptionPane.ERROR_MESSAGE); ioe.printStackTrace(Err.getPrintWriter()); }finally{ MainFrame.unlockGUI(); } } } }; Thread thread = new Thread(runnable, "ModelSaver(serialisation)"); thread.setPriority(Thread.MIN_PRIORITY); thread.start(); } } protected class LoadModelAction extends javax.swing.AbstractAction { public LoadModelAction() { super("Load model"); putValue(SHORT_DESCRIPTION, "Loads a ML model from a file"); } public void actionPerformed(java.awt.event.ActionEvent evt) { Runnable runnable = new Runnable() { public void run() { JFileChooser fileChooser = MainFrame.getFileChooser(); fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter()); fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY); fileChooser.setMultiSelectionEnabled(false); if(fileChooser.showOpenDialog(null) == JFileChooser.APPROVE_OPTION){ File file = fileChooser.getSelectedFile(); try{ MainFrame.lockGUI("Loading model..."); load(new GZIPInputStream(new FileInputStream(file))); }catch(IOException ioe){ JOptionPane.showMessageDialog(MainFrame.getInstance(), "Error!\n" + ioe.toString(), "GATE", JOptionPane.ERROR_MESSAGE); ioe.printStackTrace(Err.getPrintWriter()); }finally{ MainFrame.unlockGUI(); } } } }; Thread thread = new Thread(runnable, "ModelLoader(serialisation)"); thread.setPriority(Thread.MIN_PRIORITY); thread.start(); } } protected DatasetDefintion datasetDefinition; double confidenceThreshold = 0; /** * The WEKA classifier used by this wrapper */ protected Classifier classifier; /** * The dataset used for training */ protected Instances dataset; /** * The JDom element contaning the options fro this wrapper. */ protected Element optionsElement; /** * Marks whether the dataset was changed since the last time the classifier * was built. */ protected boolean datasetChanged = false; protected File datasetFile; protected List actionsList; protected ProcessingResource owner; protected StatusListener sListener; /** * This variable is set when the ML configuration file has neither the * classifier nor the output dataset file option. User is responsible for * explicitly saving the dataset using the now uncommented SaveDatasetAsArff * function / action. */ protected boolean onlyAccumulateDataset; }