1   /*
2    *  Copyright (c) 1998-2001, The University of Sheffield.
3    *
4    *  This file is part of GATE (see http://gate.ac.uk/), and is free
5    *  software, licenced under the GNU Library General Public License,
6    *  Version 2, June 1991 (in the distribution as file licence.html,
7    *  and also available at http://gate.ac.uk/gate/licence.html).
8    *
9    *  Valentin Tablan 21/11/2002
10   *
11   *  $Id: Wrapper.java,v 1.6 2003/05/23 09:52:09 valyt Exp $
12   *
13   */
14  package gate.creole.ml.weka;
15  
16  import java.util.*;
17  import java.io.*;
18  import javax.swing.*;
19  import java.util.zip.*;
20  
21  import org.jdom.Element;
22  
23  import weka.core.*;
24  import weka.classifiers.*;
25  import weka.filters.*;
26  
27  import gate.creole.ml.*;
28  import gate.*;
29  import gate.creole.*;
30  import gate.util.*;
31  import gate.event.*;
32  import gate.gui.*;
33  
/**
 * Wrapper class for the WEKA Machine Learning Engine.
 * @see <a href="http://www.cs.waikato.ac.nz/ml/weka/">WEKA home page</a>
 */
38  
39  public class Wrapper implements MLEngine, ActionsPublisher {
40  
41    public Wrapper() {
42      actionsList = new ArrayList();
43      actionsList.add(new LoadModelAction());
44      actionsList.add(new SaveModelAction());
45      actionsList.add(new SaveDatasetAsArffAction());
46    }
47  
48    public void setOptions(Element optionsElem) {
49      this.optionsElement = optionsElem;
50    }
51  
52    public void addTrainingInstance(List attributeValues)
53                throws ExecutionException{
54      Instance instance = buildInstance(attributeValues);
55      dataset.add(instance);
56      if(classifier != null){
57        if(classifier instanceof UpdateableClassifier){
58          //the classifier can learn on the fly; we need to update it
59          try{
60            ((UpdateableClassifier)classifier).updateClassifier(instance);
61          }catch(Exception e){
62            throw new GateRuntimeException(
63              "Could not update updateable classifier! Problem was:\n" +
64              e.toString());
65          }
66        }else{
67          //the classifier is not updatebale; we need to mark the dataset as changed
68          datasetChanged = true;
69        }
70      }
71    }
72  
73    /**
74     * Constructs an instance valid for the current dataset from a list of
75     * attribute values.
76     * @param attributeValues the values for the attributes.
77     * @return an {@link weka.core.Instance} value.
78     */
79    protected Instance buildInstance(List attributeValues)
80              throws ExecutionException{
81      //sanity check
82      if(attributeValues.size() != datasetDefinition.getAttributes().size()){
83        throw new ExecutionException(
84          "The number of attributes provided is wrong for this dataset!");
85      }
86  
87      double[] values = new double[datasetDefinition.getAttributes().size()];
88      int index = 0;
89      Iterator attrIter = datasetDefinition.getAttributes().iterator();
90      Iterator valuesIter = attributeValues.iterator();
91  
92      Instance instance = new Instance(attributeValues.size());
93      instance.setDataset(dataset);
94  
95      while(attrIter.hasNext()){
96        gate.creole.ml.Attribute attr = (gate.creole.ml.Attribute)attrIter.next();
97        String value = (String)valuesIter.next();
98        if(value == null){
99          instance.setMissing(index);
100       }else{
101         if(attr.getFeature() == null){
102           //boolean attribute ->the value should already be true/false
103           instance.setValue(index, value);
104         }else{
105           //nominal, numeric or string attribute
106           if(attr.getValues() != null){
107             //nominal or string
108             if(attr.getValues().isEmpty()){
109               //string attribute
110               instance.setValue(index, value);
111             }else{
112               //nominal attribute
113               if(attr.getValues().contains(value)){
114                 instance.setValue(index, value);
115               }else{
116                 Out.prln("Warning: invalid value: \"" + value +
117                          "\" for attribute " + attr.getName() + " was ignored!");
118                 instance.setMissing(index);
119               }
120             }
121           }else{
122             //numeric attribute
123             try{
124               double db = Double.parseDouble(value);
125               instance.setValue(index, db);
126             }catch(Exception e){
127               Out.prln("Warning: invalid numeric value: \"" + value +
128                        "\" for attribute " + attr.getName() + " was ignored!");
129               instance.setMissing(index);
130             }
131           }
132         }
133       }
134       index ++;
135     }
136     return instance;
137   }
138 
139   public void setDatasetDefinition(DatasetDefintion definition) {
140     this.datasetDefinition = definition;
141   }
142 
143   public Object classifyInstance(List attributeValues)
144          throws ExecutionException {
145     Instance instance = buildInstance(attributeValues);
146 //    double result;
147 
148     try{
149       if(classifier instanceof UpdateableClassifier){
150         return convertAttributeValue(classifier.classifyInstance(instance));
151       }else{
152         if(datasetChanged){
153           if(sListener != null) sListener.statusChanged("[Re]building model...");
154           classifier.buildClassifier(dataset);
155           datasetChanged = false;
156           if(sListener != null) sListener.statusChanged("");
157         }
158 
159         if(confidenceThreshold > 0 &&
160            dataset.classAttribute().type() == weka.core.Attribute.NOMINAL){
161           //confidence set; use probability distribution
162 
163           double[] distribution = null;
164           distribution = ((DistributionClassifier)classifier).
165                                   distributionForInstance(instance);
166 
167           List res = new ArrayList();
168           for(int i = 0; i < distribution.length; i++){
169             if(distribution[i] >= confidenceThreshold){
170               res.add(dataset.classAttribute().value(i));
171             }
172           }
173           return res;
174 
175         }else{
176           //confidence not set; use simple classification
177           return convertAttributeValue(classifier.classifyInstance(instance));
178         }
179       }
180     }catch(Exception e){
181       throw new ExecutionException(e);
182     }
183   }
184 
185   protected Object convertAttributeValue(double value){
186     gate.creole.ml.Attribute classAttr = datasetDefinition.getClassAttribute();
187     List classValues = classAttr.getValues();
188     if(classValues != null && !classValues.isEmpty()){
189       //nominal attribute
190       return dataset.attribute(datasetDefinition.getClassIndex()).
191                      value((int)value);
192     }else{
193       if(classAttr.getFeature() == null){
194         //boolean attribute
195         return dataset.attribute(datasetDefinition.getClassIndex()).
196                        value((int)value);
197       }else{
198         //numeric attribute
199         return new Double(value);
200       }
201     }
202   }
203   /**
204    * Initialises the classifier and prepares for running.
205    * @throws GateException
206    */
207   public void init() throws GateException{
208     //see if we can shout about what we're doing
209     sListener = null;
210     Map listeners = MainFrame.getListeners();
211     if(listeners != null){
212       sListener = (StatusListener)listeners.get("gate.event.StatusListener");
213     }
214 
215     //find the classifier to be used
216     if(sListener != null) sListener.statusChanged("Initialising classifier...");
217     Element classifierElem = optionsElement.getChild("CLASSIFIER");
218     if(classifierElem == null){
219       Out.prln("Warning (WEKA ML engine): no classifier selected;" +
220                " dataset collection only!");
221       classifier = null;
222     }else{
223       String classifierClassName = classifierElem.getTextTrim();
224 
225 
226       //get the options for the classiffier
227       String optionsString = null;
228       if(sListener != null) sListener.statusChanged("Setting classifier options...");
229       Element classifierOptionsElem = optionsElement.getChild("CLASSIFIER-OPTIONS");
230       if(classifierOptionsElem != null){
231         optionsString = classifierOptionsElem.getTextTrim();
232       }
233 
234       //new style overrides old style
235       org.jdom.Attribute optionsAttribute = classifierElem.getAttribute("OPTIONS");
236       if(optionsAttribute != null){
237         optionsString = optionsAttribute.getValue().trim();
238       }
239       String[] options = parseOptions(optionsString);
240 
241       try{
242         classifier = Classifier.forName(classifierClassName, options);
243       }catch(Exception e){
244         throw new GateException(e);
245       }
246 
247       //if we have any filters apply them to the classifer
248       List filterElems = optionsElement.getChildren("FILTER");
249       if(filterElems != null && filterElems.size() > 0){
250         Iterator elemIter = filterElems.iterator();
251         while(elemIter.hasNext()){
252           Element filterElem = (Element)elemIter.next();
253           String filterClassName = filterElem.getTextTrim();
254           String filterOptionsString = "";
255           org.jdom.Attribute optionsAttr = filterElem.getAttribute("OPTIONS");
256           if(optionsAttr != null){
257             filterOptionsString = optionsAttr.getValue().trim();
258           }
259           //create the new filter
260           try{
261             Class filterClass = Class.forName(filterClassName);
262             if(!Filter.class.isAssignableFrom(filterClass)){
263               throw new ResourceInstantiationException(
264                 filterClassName + " is not a " + Filter.class.getName() + "!");
265             }
266             Filter aFilter = (Filter)filterClass.newInstance();
267             //apply the options to the filter
268             if(filterOptionsString != null && filterOptionsString.length() > 0){
269               if(!(aFilter instanceof OptionHandler)){
270                 throw new ResourceInstantiationException(
271                   filterClassName + " cannot handle options!");
272               }
273               options = parseOptions(filterOptionsString);
274               ((OptionHandler)aFilter).setOptions(options);
275             }
276             //apply the filter to the classifier
277             classifier = new FilteredClassifier(classifier, aFilter);
278           }catch(Exception e){
279             throw new ResourceInstantiationException(e);
280           }
281         }
282       }
283 
284       Element anElement = optionsElement.getChild("CONFIDENCE-THRESHOLD");
285       if(anElement != null){
286         try{
287           confidenceThreshold = Double.parseDouble(anElement.getTextTrim());
288         }catch(Exception e){
289           throw new GateException(
290             "Could not parse confidence threshold value: " +
291             anElement.getTextTrim() + "!");
292         }
293         if(!(classifier instanceof DistributionClassifier)){
294           throw new GateException(
295             "Cannot use confidence threshold with classifier: " +
296             classifier.getClass().getName() + "!");
297         }
298       }
299 
300     }
301 
302     //initialise the dataset
303     if(sListener != null) sListener.statusChanged("Initialising dataset...");
304     FastVector attributes = new FastVector();
305     weka.core.Attribute classAttribute;
306     Iterator attIter = datasetDefinition.getAttributes().iterator();
307     while(attIter.hasNext()){
308       gate.creole.ml.Attribute aGateAttr =
309         (gate.creole.ml.Attribute)attIter.next();
310       weka.core.Attribute aWekaAttribute = null;
311       if(aGateAttr.getValues() != null){
312         //nominal or String attribute
313         if(!aGateAttr.getValues().isEmpty()){
314           //nominal attribute
315           FastVector attrValues = new FastVector(aGateAttr.getValues().size());
316           Iterator valIter = aGateAttr.getValues().iterator();
317           while(valIter.hasNext()){
318             attrValues.addElement(valIter.next());
319           }
320           aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
321                                                    attrValues);
322         }else{
323           //VALUES element present but no values defined -> String attribute
324           aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
325                                                    null);
326         }
327       }else{
328         if(aGateAttr.getFeature() == null){
329           //boolean attribute ([lack of] presence of an annotation)
330           FastVector attrValues = new FastVector(2);
331           attrValues.addElement("true");
332           attrValues.addElement("false");
333           aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
334                                                    attrValues);
335         }else{
336           //feature is not null but no values provided -> numeric attribute
337           aWekaAttribute = new weka.core.Attribute(aGateAttr.getName());
338         }
339       }
340       if(aGateAttr.isClass()) classAttribute = aWekaAttribute;
341       attributes.addElement(aWekaAttribute);
342     }
343 
344     dataset = new Instances("Weka ML Engine Dataset", attributes, 0);
345     dataset.setClassIndex(datasetDefinition.getClassIndex());
346 
347     if(classifier != null && classifier instanceof UpdateableClassifier){
348       try{
349         classifier.buildClassifier(dataset);
350       }catch(Exception e){
351         throw new ResourceInstantiationException(e);
352       }
353     }
354     if(sListener != null) sListener.statusChanged("");
355   }
356 
357   protected String[] parseOptions(String optionsString){
358     String[] options = null;
359     if(optionsString == null || optionsString.length() == 0){
360       options = new String[]{};
361     }else{
362       List optionsList = new ArrayList();
363       StringTokenizer strTok =
364         new StringTokenizer(optionsString , " ", false);
365       while(strTok.hasMoreTokens()){
366         optionsList.add(strTok.nextToken());
367       }
368       options = (String[])optionsList.toArray(new String[optionsList.size()]);
369     }
370     return options;
371   }
372 
373   /**
374    * Loads the state of this engine from previously saved data.
375    * @param is
376    */
377   public void load(InputStream is) throws IOException{
378     if(sListener != null) sListener.statusChanged("Loading model...");
379     ObjectInputStream ois = new ObjectInputStream(is);
380     try{
381       classifier = (Classifier)ois.readObject();
382       dataset = (Instances)ois.readObject();
383       datasetDefinition = (DatasetDefintion)ois.readObject();
384       datasetChanged = ois.readBoolean();
385       confidenceThreshold = ois.readDouble();
386     }catch(ClassNotFoundException cnfe){
387       throw new GateRuntimeException(cnfe.toString());
388     }
389     ois.close();
390     if(sListener != null) sListener.statusChanged("");
391   }
392 
393   /**
394    * Saves the state of the engine for reuse at a later time.
395    * @param os
396    */
397   public void save(OutputStream os) throws IOException{
398     if(sListener != null) sListener.statusChanged("Saving model...");
399     ObjectOutputStream oos = new ObjectOutputStream(os);
400     oos.writeObject(classifier);
401     oos.writeObject(dataset);
402     oos.writeObject(datasetDefinition);
403     oos.writeBoolean(datasetChanged);
404     oos.writeDouble(confidenceThreshold);
405     oos.flush();
406     oos.close();
407     if(sListener != null) sListener.statusChanged("");
408   }
409 
410   /**
411    * Gets the list of actions that can be performed on this resource.
412    * @return a List of Action objects (or null values)
413    */
414   public List getActions(){
415     return actionsList;
416   }
417 
418   /**
419    * Registers the PR using the engine with the engine itself.
420    * @param pr the processing resource that owns this engine.
421    */
422   public void setOwnerPR(ProcessingResource pr){
423     this.owner = pr;
424   }
425   public DatasetDefintion getDatasetDefinition() {
426     return datasetDefinition;
427   }
428 
429 
430   protected class SaveDatasetAsArffAction extends javax.swing.AbstractAction{
431     public SaveDatasetAsArffAction(){
432       super("Save dataset as ARFF");
433       putValue(SHORT_DESCRIPTION, "Saves the dataset to a file in ARFF format");
434     }
435 
436     public void actionPerformed(java.awt.event.ActionEvent evt){
437       Runnable runnable = new Runnable(){
438         public void run(){
439           JFileChooser fileChooser = MainFrame.getFileChooser();
440           fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
441           fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY);
442           fileChooser.setMultiSelectionEnabled(false);
443           if(fileChooser.showSaveDialog(null) == fileChooser.APPROVE_OPTION){
444             File file = fileChooser.getSelectedFile();
445             try{
446               MainFrame.lockGUI("Saving dataset...");
447               FileWriter fw = new FileWriter(file.getCanonicalPath(), false);
448               fw.write(dataset.toString());
449               fw.flush();
450               fw.close();
451             }catch(IOException ioe){
452               JOptionPane.showMessageDialog(null,
453                               "Error!\n"+
454                                ioe.toString(),
455                                "Gate", JOptionPane.ERROR_MESSAGE);
456               ioe.printStackTrace(Err.getPrintWriter());
457             }finally{
458               MainFrame.unlockGUI();
459             }
460           }
461         }
462       };
463 
464       Thread thread = new Thread(runnable, "DatasetSaver(ARFF)");
465       thread.setPriority(Thread.MIN_PRIORITY);
466       thread.start();
467     }
468   }
469 
470 
471   protected class SaveModelAction extends javax.swing.AbstractAction{
472     public SaveModelAction(){
473       super("Save model");
474       putValue(SHORT_DESCRIPTION, "Saves the ML model to a file");
475     }
476 
477     public void actionPerformed(java.awt.event.ActionEvent evt){
478       Runnable runnable = new Runnable(){
479         public void run(){
480           JFileChooser fileChooser = MainFrame.getFileChooser();
481           fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
482           fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY);
483           fileChooser.setMultiSelectionEnabled(false);
484           if(fileChooser.showSaveDialog(null) == fileChooser.APPROVE_OPTION){
485             File file = fileChooser.getSelectedFile();
486             try{
487               MainFrame.lockGUI("Saving ML model...");
488               save(new GZIPOutputStream(
489                    new FileOutputStream(file.getCanonicalPath(), false)));
490             }catch(IOException ioe){
491               JOptionPane.showMessageDialog(null,
492                               "Error!\n"+
493                                ioe.toString(),
494                                "Gate", JOptionPane.ERROR_MESSAGE);
495               ioe.printStackTrace(Err.getPrintWriter());
496             }finally{
497               MainFrame.unlockGUI();
498             }
499           }
500         }
501       };
502       Thread thread = new Thread(runnable, "ModelSaver(serialisation)");
503       thread.setPriority(Thread.MIN_PRIORITY);
504       thread.start();
505     }
506   }
507 
508   protected class LoadModelAction extends javax.swing.AbstractAction{
509     public LoadModelAction(){
510       super("Load model");
511       putValue(SHORT_DESCRIPTION, "Loads a ML model from a file");
512     }
513 
514     public void actionPerformed(java.awt.event.ActionEvent evt){
515       Runnable runnable = new Runnable(){
516         public void run(){
517           JFileChooser fileChooser = MainFrame.getFileChooser();
518           fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
519           fileChooser.setFileSelectionMode(fileChooser.FILES_ONLY);
520           fileChooser.setMultiSelectionEnabled(false);
521           if(fileChooser.showOpenDialog(null) == fileChooser.APPROVE_OPTION){
522             File file = fileChooser.getSelectedFile();
523             try{
524               MainFrame.lockGUI("Loading model...");
525               load(new GZIPInputStream(new FileInputStream(file)));
526             }catch(IOException ioe){
527               JOptionPane.showMessageDialog(null,
528                               "Error!\n"+
529                                ioe.toString(),
530                                "Gate", JOptionPane.ERROR_MESSAGE);
531               ioe.printStackTrace(Err.getPrintWriter());
532             }finally{
533               MainFrame.unlockGUI();
534             }
535           }
536         }
537       };
538       Thread thread = new Thread(runnable, "ModelLoader(serialisation)");
539       thread.setPriority(Thread.MIN_PRIORITY);
540       thread.start();
541     }
542   }
543 
544 
545 
  /**
   * The definition of the dataset (its attributes and class attribute) that
   * instances and the WEKA dataset are built from.
   */
  protected DatasetDefintion datasetDefinition;

  /**
   * The minimum probability a class value needs in order to be returned when
   * classifying with a nominal class attribute. Only used when greater than
   * 0; otherwise simple classification is performed.
   */
  double confidenceThreshold = 0;

  /**
   * The WEKA classifier used by this wrapper; <code>null</code> when the
   * engine is only collecting the dataset.
   */
  protected Classifier classifier;

  /**
   * The dataset used for training
   */
  protected Instances dataset;

  /**
   * The JDOM element containing the options for this wrapper.
   */
  protected Element optionsElement;

  /**
   * Marks whether the dataset was changed since the last time the classifier
   * was built.
   */
  protected boolean datasetChanged = false;

  /**
   * The actions published by this wrapper; populated in the constructor.
   */
  protected List actionsList;

  /**
   * The processing resource that owns this engine.
   */
  protected ProcessingResource owner;

  /**
   * The listener used for reporting status messages; may be
   * <code>null</code>, in which case no messages are reported.
   */
  protected StatusListener sListener;
576 }