package gate.creole.ml.weka;

import java.io.*;
import java.util.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import javax.swing.JFileChooser;
import javax.swing.JOptionPane;

import org.jdom.Element;

import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.*;
import weka.filters.Filter;

import gate.ProcessingResource;
import gate.creole.ExecutionException;
import gate.creole.ResourceInstantiationException;
import gate.creole.ml.DatasetDefintion;
import gate.creole.ml.MLEngine;
import gate.event.StatusListener;
import gate.gui.ActionsPublisher;
import gate.gui.MainFrame;
import gate.util.*;

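/**
 * Wrapper for the WEKA machine learning engine. It translates between the
 * GATE dataset definition and WEKA {@link weka.core.Instances}, trains or
 * updates a WEKA classifier and uses it to classify new instances. It also
 * publishes GUI actions for loading/saving models and loading ARFF data.
 */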
public class Wrapper implements MLEngine, ActionsPublisher {

  public Wrapper() {
    actionsList = new ArrayList();
    actionsList.add(new LoadModelAction());
    actionsList.add(new SaveModelAction());
    // a null entry is used as a separator in the published actions menu
    actionsList.add(null);
    actionsList.add(new LoadDatasetFromArffAction());
  }

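  /**
   * Called when the engine is no longer needed; this wrapper holds no
   * external resources, so there is nothing to clean up.
   */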
  public void cleanUp() {
  }

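  /**
   * Stores the &lt;OPTIONS&gt; element from the XML configuration file; it is
   * parsed later, in {@link #init()}.
   */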
  public void setOptions(Element optionsElem) {
    this.optionsElement = optionsElem;
  }

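  /**
   * Batch classification is not supported by this wrapper; calling this
   * method always throws an {@link ExecutionException}.
   */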
  public List batchClassifyInstances(java.util.List instances)
      throws ExecutionException {
    throw new ExecutionException("The Weka wrapper does not support " +
                                 "batch classification. Remove the " +
                                 "<BATCH-MODE-CLASSIFICATION/> entry " +
                                 "from the XML configuration file and " +
                                 "try again.");
  }

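  /**
   * Adds a new training instance, given as a list of attribute values, to
   * the dataset.
   */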
  public void addTrainingInstance(List attributeValues)
      throws ExecutionException {
    Instance instance = buildInstance(attributeValues);
    addTrainingInstance(instance);
  }

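  /**
   * Adds a WEKA instance to the training data: updateable classifiers are
   * updated directly, otherwise the instance is appended to the in-memory
   * dataset. If a dataset file is configured the instance is also appended
   * to it in ARFF format.
   */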
  protected void addTrainingInstance(Instance instance)
      throws ExecutionException {
    if(classifier != null){
      if(classifier instanceof UpdateableClassifier){
        try{
          ((UpdateableClassifier)classifier).updateClassifier(instance);
        }catch(Exception e){
          throw new GateRuntimeException(
              "Could not update updateable classifier! Problem was:\n" +
              e.toString());
        }
      }else{
        dataset.add(instance);
        datasetChanged = true;
      }
    }
    if(datasetFile != null){
      try{
        FileWriter fw = new FileWriter(datasetFile, true);
        fw.write(instance.toString() + "\n");
        fw.flush();
        fw.close();
      }catch(IOException ioe){
        throw new ExecutionException(ioe);
      }
    }
  }

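  /**
   * Builds a WEKA {@link Instance} from a list of attribute values given as
   * Strings. Null values and values that are not legal for their attribute
   * are set as missing; numeric attributes are parsed as doubles.
   */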
  protected Instance buildInstance(List attributeValues)
      throws ExecutionException {
    if(attributeValues.size() != datasetDefinition.getAttributes().size()){
      throw new ExecutionException(
          "The number of attributes provided is wrong for this dataset!");
    }

    double[] values = new double[datasetDefinition.getAttributes().size()];
    int index = 0;
    Iterator attrIter = datasetDefinition.getAttributes().iterator();
    Iterator valuesIter = attributeValues.iterator();

    Instance instance = new Instance(attributeValues.size());
    instance.setDataset(dataset);

    while(attrIter.hasNext()){
      gate.creole.ml.Attribute attr = (gate.creole.ml.Attribute)attrIter.next();
      String value = (String)valuesIter.next();
      if(value == null){
        instance.setMissing(index);
      }else{
        if(attr.getFeature() == null){
          // boolean attribute: the value is "true" or "false"
          instance.setValue(index, value);
        }else{
          if(attr.getValues() != null){
            if(attr.getValues().isEmpty()){
              // string attribute: any value is accepted
              instance.setValue(index, value);
            }else{
              // nominal attribute with a fixed set of values
              if(attr.getValues().contains(value)){
                instance.setValue(index, value);
              }else{
                Out.prln("Warning: invalid value: \"" + value +
                         "\" for attribute " + attr.getName() + " was ignored!");
                instance.setMissing(index);
              }
            }
          }else{
            // numeric attribute
            try{
              double db = Double.parseDouble(value);
              instance.setValue(index, db);
            }catch(Exception e){
              Out.prln("Warning: invalid numeric value: \"" + value +
                       "\" for attribute " + attr.getName() + " was ignored!");
              instance.setMissing(index);
            }
          }
        }
      }
      index++;
    }
    return instance;
  }

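  /**
   * Sets the dataset definition that describes the attributes used by this
   * engine.
   */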
  public void setDatasetDefinition(DatasetDefintion definition) {
    this.datasetDefinition = definition;
  }

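  /**
   * Classifies a new instance given as a list of attribute values. For
   * non-updateable classifiers the model is (re)built first if the dataset
   * has changed. If a confidence threshold is set and the class attribute is
   * nominal, a List of all class values whose probability is at least the
   * threshold is returned; otherwise the single predicted value is returned.
   */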
  public Object classifyInstance(List attributeValues)
      throws ExecutionException {
    Instance instance = buildInstance(attributeValues);

    try{
      if(classifier instanceof UpdateableClassifier){
        return convertAttributeValue(classifier.classifyInstance(instance));
      }else{
        if(datasetChanged){
          if(sListener != null) sListener.statusChanged("[Re]building model...");
          classifier.buildClassifier(dataset);
          datasetChanged = false;
          if(sListener != null) sListener.statusChanged("");
        }

        if(confidenceThreshold > 0 &&
           dataset.classAttribute().type() == weka.core.Attribute.NOMINAL){
          // return all class values with a probability above the threshold
          double[] distribution = null;
          try{
            distribution = classifier.distributionForInstance(instance);
          }catch(Exception e){
            throw new ExecutionException(e);
          }

          List res = new ArrayList();
          for(int i = 0; i < distribution.length; i++){
            if(distribution[i] >= confidenceThreshold){
              res.add(dataset.classAttribute().value(i));
            }
          }
          return res;
        }else{
          return convertAttributeValue(classifier.classifyInstance(instance));
        }
      }
    }catch(Exception e){
      throw new ExecutionException(e);
    }
  }

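  /**
   * Converts a numeric classification result into the value to be returned:
   * the corresponding nominal value for nominal or boolean class attributes,
   * or a Double for numeric ones.
   */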
  protected Object convertAttributeValue(double value){
    gate.creole.ml.Attribute classAttr = datasetDefinition.getClassAttribute();
    List classValues = classAttr.getValues();
    if(classValues != null && !classValues.isEmpty()){
      return dataset.attribute(datasetDefinition.getClassIndex()).
             value((int)value);
    }else{
      if(classAttr.getFeature() == null){
        return dataset.attribute(datasetDefinition.getClassIndex()).
               value((int)value);
      }else{
        return new Double(value);
      }
    }
  }

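  /**
   * Initialises the engine from the &lt;OPTIONS&gt; element of the
   * configuration: creates the classifier (optionally wrapped in WEKA
   * filters), reads the confidence threshold and dataset file settings, and
   * builds the empty WEKA dataset from the dataset definition.
   */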
  public void init() throws GateException{
    // find a status listener so progress can be reported in the GUI
    sListener = null;
    Map listeners = MainFrame.getListeners();
    if(listeners != null){
      sListener = (StatusListener)listeners.get("gate.event.StatusListener");
    }

    if(sListener != null) sListener.statusChanged("Initialising classifier...");
    Element classifierElem = optionsElement.getChild("CLASSIFIER");
    if(classifierElem == null){
      Out.prln("Warning (WEKA ML engine): no classifier selected;" +
               " dataset collection only!");
      classifier = null;
    }else{
      String classifierClassName = classifierElem.getTextTrim();

      // classifier options can come from a <CLASSIFIER-OPTIONS> element or
      // from an OPTIONS attribute on the <CLASSIFIER> element; the attribute
      // takes precedence
      String optionsString = null;
      if(sListener != null) sListener.statusChanged("Setting classifier options...");
      Element classifierOptionsElem = optionsElement.getChild("CLASSIFIER-OPTIONS");
      if(classifierOptionsElem != null){
        optionsString = classifierOptionsElem.getTextTrim();
      }

      org.jdom.Attribute optionsAttribute = classifierElem.getAttribute("OPTIONS");
      if(optionsAttribute != null){
        optionsString = optionsAttribute.getValue().trim();
      }
      String[] options = parseOptions(optionsString);

      try{
        classifier = Classifier.forName(classifierClassName, options);
      }catch(Exception e){
        throw new GateException(e);
      }

      // optional filters: each <FILTER> element wraps the classifier in a
      // FilteredClassifier
      List filterElems = optionsElement.getChildren("FILTER");
      if(filterElems != null && filterElems.size() > 0){
        Iterator elemIter = filterElems.iterator();
        while(elemIter.hasNext()){
          Element filterElem = (Element)elemIter.next();
          String filterClassName = filterElem.getTextTrim();
          String filterOptionsString = "";
          org.jdom.Attribute optionsAttr = filterElem.getAttribute("OPTIONS");
          if(optionsAttr != null){
            filterOptionsString = optionsAttr.getValue().trim();
          }
          try{
            Class filterClass = Class.forName(filterClassName);
            if(!Filter.class.isAssignableFrom(filterClass)){
              throw new ResourceInstantiationException(
                  filterClassName + " is not a " + Filter.class.getName() + "!");
            }
            Filter aFilter = (Filter)filterClass.newInstance();
            if(filterOptionsString != null && filterOptionsString.length() > 0){
              if(!(aFilter instanceof OptionHandler)){
                throw new ResourceInstantiationException(
                    filterClassName + " cannot handle options!");
              }
              options = parseOptions(filterOptionsString);
              ((OptionHandler)aFilter).setOptions(options);
            }
            classifier = new FilteredClassifier(classifier, aFilter);
          }catch(Exception e){
            throw new ResourceInstantiationException(e);
          }
        }
      }

      // optional confidence threshold, used for nominal class attributes
      Element anElement = optionsElement.getChild("CONFIDENCE-THRESHOLD");
      if(anElement != null){
        try{
          confidenceThreshold = Double.parseDouble(anElement.getTextTrim());
        }catch(Exception e){
          throw new GateException(
              "Could not parse confidence threshold value: " +
              anElement.getTextTrim() + "!");
        }
      }
    }

    // if a <DATASET-FILE> is specified, all training instances are also
    // written to that file in ARFF format
    Element datafileElem = optionsElement.getChild("DATASET-FILE");
    if(datafileElem != null){
      datasetFile = new File(datafileElem.getTextTrim());
      try{
        Out.prln("Warning (WEKA ML engine): writing dataset as ARFF to " +
                 datasetFile.getCanonicalPath());
      }catch(IOException ioe){
        throw new ResourceInstantiationException(ioe);
      }
    }else{
      if(classifier == null){
        throw new ResourceInstantiationException(
            "Neither a classifier nor a dataset file is specified in the " +
            "definition!\nRunning this PR this way would be pointless!");
      }
    }

    if(sListener != null) sListener.statusChanged("Initialising dataset...");
    // build the WEKA attributes from the GATE dataset definition
    FastVector attributes = new FastVector();
    weka.core.Attribute classAttribute;
    Iterator attIter = datasetDefinition.getAttributes().iterator();
    while(attIter.hasNext()){
      gate.creole.ml.Attribute aGateAttr =
          (gate.creole.ml.Attribute)attIter.next();
      weka.core.Attribute aWekaAttribute = null;
      if(aGateAttr.getValues() != null){
        if(!aGateAttr.getValues().isEmpty()){
          // nominal attribute with a fixed set of values
          FastVector attrValues = new FastVector(aGateAttr.getValues().size());
          Iterator valIter = aGateAttr.getValues().iterator();
          while(valIter.hasNext()){
            attrValues.addElement(valIter.next());
          }
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
                                                   attrValues);
        }else{
          // string attribute: values are not restricted
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
                                                   (FastVector)null);
        }
      }else{
        if(aGateAttr.getFeature() == null){
          // boolean attribute: nominal with the values "true" and "false"
          FastVector attrValues = new FastVector(2);
          attrValues.addElement("true");
          attrValues.addElement("false");
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
                                                   attrValues);
        }else{
          // numeric attribute
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName());
        }
      }
      if(aGateAttr.isClass()) classAttribute = aWekaAttribute;
      attributes.addElement(aWekaAttribute);
    }

    dataset = new Instances("Weka ML Engine Dataset", attributes, 0);
    dataset.setClassIndex(datasetDefinition.getClassIndex());

    if(datasetFile != null){
      // write the ARFF header for the (still empty) dataset
      try{
        FileWriter fw = new FileWriter(datasetFile);
        fw.write(dataset.toString());
        fw.flush();
        fw.close();
      }catch(IOException ioe){
        throw new ResourceInstantiationException(ioe);
      }
    }

    if(classifier != null && classifier instanceof UpdateableClassifier){
      // updateable classifiers are built once on the empty dataset so they
      // can then be updated incrementally as instances arrive
      try{
        classifier.buildClassifier(dataset);
      }catch(Exception e){
        throw new ResourceInstantiationException(e);
      }
    }
    if(sListener != null) sListener.statusChanged("");
  }

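  /**
   * Splits a WEKA options string on spaces into the String[] form expected
   * by WEKA; returns an empty array for null or empty input.
   */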
  protected String[] parseOptions(String optionsString){
    String[] options = null;
    if(optionsString == null || optionsString.length() == 0){
      options = new String[]{};
    }else{
      List optionsList = new ArrayList();
      StringTokenizer strTok =
          new StringTokenizer(optionsString, " ", false);
      while(strTok.hasMoreTokens()){
        optionsList.add(strTok.nextToken());
      }
      options = (String[])optionsList.toArray(new String[optionsList.size()]);
    }
    return options;
  }

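  /**
   * Loads the state of the engine (classifier, dataset, dataset definition,
   * flags) from the given input stream using Java serialisation. The stream
   * is closed by this method.
   */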
  public void load(InputStream is) throws IOException{
    if(sListener != null) sListener.statusChanged("Loading model...");
    ObjectInputStream ois = new ObjectInputStream(is);
    try{
      classifier = (Classifier)ois.readObject();
      dataset = (Instances)ois.readObject();
      datasetDefinition = (DatasetDefintion)ois.readObject();
      datasetChanged = ois.readBoolean();
      confidenceThreshold = ois.readDouble();
    }catch(ClassNotFoundException cnfe){
      throw new GateRuntimeException(cnfe.toString());
    }
    ois.close();
    if(sListener != null) sListener.statusChanged("");
  }

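  /**
   * Saves the state of the engine (classifier, dataset, dataset definition,
   * flags) to the given output stream using Java serialisation. The stream
   * is closed by this method.
   */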
  public void save(OutputStream os) throws IOException{
    if(sListener != null) sListener.statusChanged("Saving model...");
    ObjectOutputStream oos = new ObjectOutputStream(os);
    oos.writeObject(classifier);
    oos.writeObject(dataset);
    oos.writeObject(datasetDefinition);
    oos.writeBoolean(datasetChanged);
    oos.writeDouble(confidenceThreshold);
    oos.flush();
    oos.close();
    if(sListener != null) sListener.statusChanged("");
  }

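  /**
   * Returns the list of GUI actions published by this engine (load/save
   * model, load ARFF data).
   */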
  public List getActions(){
    return actionsList;
  }

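  /**
   * Registers the processing resource that owns this engine.
   */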
  public void setOwnerPR(ProcessingResource pr){
    this.owner = pr;
  }

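  /**
   * Returns the dataset definition used by this engine.
   */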
  public DatasetDefintion getDatasetDefinition() {
    return datasetDefinition;
  }

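  /**
   * Loads training instances from an ARFF file and appends them to the
   * current dataset. The ARFF headers must be compatible with the dataset
   * built from the dataset definition, otherwise an ExecutionException is
   * thrown.
   */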
  public void loadDatasetFromArff(FileReader reader) throws IOException,
                                                            ExecutionException,
                                                            Exception{
    Instances newDataset = new Instances(reader);
    if(!dataset.equalHeaders(newDataset))
      throw new ExecutionException("Loaded dataset incompatible with the one " +
                                   "in the definition!");
    Enumeration instEnum = newDataset.enumerateInstances();
    while(instEnum.hasMoreElements()){
      addTrainingInstance((Instance)instEnum.nextElement());
    }
  }

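  /**
   * GUI action that lets the user select an ARFF file and appends its
   * instances to the current dataset; the work is done on a separate,
   * low-priority thread while the GUI is locked.
   */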
  protected class LoadDatasetFromArffAction extends javax.swing.AbstractAction{
    public LoadDatasetFromArffAction(){
      super("Load data from ARFF");
      putValue(SHORT_DESCRIPTION,
               "Loads training data from a file in ARFF format and " +
               "appends it to the current dataset.");
    }

    public void actionPerformed(java.awt.event.ActionEvent evt){
      Runnable runnable = new Runnable(){
        public void run(){
          JFileChooser fileChooser = MainFrame.getFileChooser();
          fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
          fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
          fileChooser.setMultiSelectionEnabled(false);
          if(fileChooser.showOpenDialog(null) == JFileChooser.APPROVE_OPTION){
            File file = fileChooser.getSelectedFile();
            try{
              MainFrame.lockGUI("Loading dataset...");
              FileReader reader = new FileReader(file.getCanonicalPath());
              loadDatasetFromArff(reader);
              reader.close();
            }catch(Exception e){
              JOptionPane.showMessageDialog(null,
                  "Error!\n" + e.toString(),
                  "GATE", JOptionPane.ERROR_MESSAGE);
              e.printStackTrace(Err.getPrintWriter());
            }finally{
              MainFrame.unlockGUI();
            }
          }
        }
      };

      Thread thread = new Thread(runnable, "DatasetLoader(ARFF)");
      thread.setPriority(Thread.MIN_PRIORITY);
      thread.start();
    }
  }

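  /**
   * GUI action that serialises the current model (and dataset) to a
   * GZIP-compressed file chosen by the user.
   */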
  protected class SaveModelAction extends javax.swing.AbstractAction{
    public SaveModelAction(){
      super("Save model");
      putValue(SHORT_DESCRIPTION, "Saves the ML model to a file");
    }

    public void actionPerformed(java.awt.event.ActionEvent evt){
      Runnable runnable = new Runnable(){
        public void run(){
          JFileChooser fileChooser = MainFrame.getFileChooser();
          fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
          fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
          fileChooser.setMultiSelectionEnabled(false);
          if(fileChooser.showSaveDialog(null) == JFileChooser.APPROVE_OPTION){
            File file = fileChooser.getSelectedFile();
            try{
              MainFrame.lockGUI("Saving ML model...");
              save(new GZIPOutputStream(
                  new FileOutputStream(file.getCanonicalPath(), false)));
            }catch(IOException ioe){
              JOptionPane.showMessageDialog(null,
                  "Error!\n" + ioe.toString(),
                  "GATE", JOptionPane.ERROR_MESSAGE);
              ioe.printStackTrace(Err.getPrintWriter());
            }finally{
              MainFrame.unlockGUI();
            }
          }
        }
      };
      Thread thread = new Thread(runnable, "ModelSaver(serialisation)");
      thread.setPriority(Thread.MIN_PRIORITY);
      thread.start();
    }
  }

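  /**
   * GUI action that loads a previously saved, GZIP-compressed model from a
   * file chosen by the user.
   */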
  protected class LoadModelAction extends javax.swing.AbstractAction{
    public LoadModelAction(){
      super("Load model");
      putValue(SHORT_DESCRIPTION, "Loads a ML model from a file");
    }

    public void actionPerformed(java.awt.event.ActionEvent evt){
      Runnable runnable = new Runnable(){
        public void run(){
          JFileChooser fileChooser = MainFrame.getFileChooser();
          fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
          fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
          fileChooser.setMultiSelectionEnabled(false);
          if(fileChooser.showOpenDialog(null) == JFileChooser.APPROVE_OPTION){
            File file = fileChooser.getSelectedFile();
            try{
              MainFrame.lockGUI("Loading model...");
              load(new GZIPInputStream(new FileInputStream(file)));
            }catch(IOException ioe){
              JOptionPane.showMessageDialog(null,
                  "Error!\n" + ioe.toString(),
                  "GATE", JOptionPane.ERROR_MESSAGE);
              ioe.printStackTrace(Err.getPrintWriter());
            }finally{
              MainFrame.unlockGUI();
            }
          }
        }
      };
      Thread thread = new Thread(runnable, "ModelLoader(serialisation)");
      thread.setPriority(Thread.MIN_PRIORITY);
      thread.start();
    }
  }

  /** The dataset definition describing the attributes used by this engine. */
  protected DatasetDefintion datasetDefinition;

  /**
   * Confidence threshold used for nominal class attributes; 0 means disabled
   * (a single prediction is returned).
   */
  double confidenceThreshold = 0;

  /** The WEKA classifier; null when only dataset collection is performed. */
  protected Classifier classifier;

  /** The WEKA dataset holding the accumulated training instances. */
  protected Instances dataset;

  /** The &lt;OPTIONS&gt; element from the XML configuration file. */
  protected Element optionsElement;

  /**
   * Set to true when training instances have been added since the model was
   * last built; triggers a rebuild before the next classification.
   */
  protected boolean datasetChanged = false;

  /** Optional file to which the dataset is also written in ARFF format. */
  protected File datasetFile;

  /** The actions published in the GUI for this engine. */
  protected List actionsList;

  /** The processing resource that owns this engine. */
  protected ProcessingResource owner;

  /** Status listener used to report progress in the GUI; may be null. */
  protected StatusListener sListener;
}