|
Wrapper |
|
/*
 *  Copyright (c) 1998-2001, The University of Sheffield.
 *
 *  This file is part of GATE (see http://gate.ac.uk/), and is free
 *  software, licenced under the GNU Library General Public License,
 *  Version 2, June 1991 (in the distribution as file licence.html,
 *  and also available at http://gate.ac.uk/gate/licence.html).
 *
 *  Valentin Tablan 21/11/2002
 *
 *  $Id: Wrapper.java,v 1.6 2003/05/23 09:52:09 valyt Exp $
 *
 */
package gate.creole.ml.weka;

import java.util.*;
import java.io.*;
import javax.swing.*;
import java.util.zip.*;

import org.jdom.Element;

import weka.core.*;
import weka.classifiers.*;
import weka.filters.*;

import gate.creole.ml.*;
import gate.*;
import gate.creole.*;
import gate.util.*;
import gate.event.*;
import gate.gui.*;

/**
 * Wrapper class for the WEKA Machine Learning Engine.
 * <p>
 * The wrapper collects training instances into a WEKA dataset, (re)builds the
 * configured classifier when needed and uses it to classify new instances.
 * It also publishes GUI actions for loading/saving the model and for exporting
 * the dataset in ARFF format.
 *
 * @see <a href="http://www.cs.waikato.ac.nz/ml/weka/">WEKA</a>
 */

public class Wrapper implements MLEngine, ActionsPublisher {

  /**
   * Creates a new wrapper and registers the GUI actions it publishes
   * (load model, save model, save dataset as ARFF).
   */
  public Wrapper() {
    actionsList = new ArrayList();
    actionsList.add(new LoadModelAction());
    actionsList.add(new SaveModelAction());
    actionsList.add(new SaveDatasetAsArffAction());
  }

  /**
   * Stores the JDom element holding the configuration for this engine.
   * The element is only interpreted later, by {@link #init()}.
   * @param optionsElem the options element.
   */
  public void setOptions(Element optionsElem) {
    this.optionsElement = optionsElem;
  }

  /**
   * Adds a new training instance to the dataset. If the current classifier is
   * an {@link UpdateableClassifier} it is updated on the spot; otherwise the
   * dataset is marked as changed so that the model gets rebuilt before the
   * next classification.
   * @param attributeValues the values for the attributes, one String (or
   * null for a missing value) per attribute of the dataset definition.
   * @throws ExecutionException if the instance cannot be built.
   */
  public void addTrainingInstance(List attributeValues)
              throws ExecutionException{
    Instance instance = buildInstance(attributeValues);
    dataset.add(instance);
    if(classifier == null){
      //dataset collection only; nothing to update
      return;
    }
    if(classifier instanceof UpdateableClassifier){
      //the classifier can learn on the fly; we need to update it
      try{
        ((UpdateableClassifier)classifier).updateClassifier(instance);
      }catch(Exception e){
        throw new GateRuntimeException(
          "Could not update updateable classifier! Problem was:\n" +
          e.toString());
      }
    }else{
      //the classifier is not updateable; we need to mark the dataset as
      //changed so the model is rebuilt before the next classification
      datasetChanged = true;
    }
  }

  /**
   * Constructs an instance valid for the current dataset from a list of
   * attribute values.
   * <p>
   * Null values are recorded as missing. Values that are not valid for a
   * nominal attribute, or that cannot be parsed for a numeric attribute, are
   * reported with a warning and recorded as missing.
   * @param attributeValues the values for the attributes.
   * @return an {@link weka.core.Instance} value.
   * @throws ExecutionException if the number of values provided differs from
   * the number of attributes in the dataset definition.
   */
  protected Instance buildInstance(List attributeValues)
            throws ExecutionException{
    //sanity check
    if(attributeValues.size() != datasetDefinition.getAttributes().size()){
      throw new ExecutionException(
        "The number of attributes provided is wrong for this dataset!");
    }

    Instance instance = new Instance(attributeValues.size());
    //the instance needs to know its dataset in order to resolve String values
    instance.setDataset(dataset);

    int index = 0;
    Iterator attrIter = datasetDefinition.getAttributes().iterator();
    Iterator valuesIter = attributeValues.iterator();
    while(attrIter.hasNext()){
      gate.creole.ml.Attribute attr = (gate.creole.ml.Attribute)attrIter.next();
      String value = (String)valuesIter.next();
      if(value == null){
        instance.setMissing(index);
      }else if(attr.getFeature() == null){
        //boolean attribute ->the value should already be true/false
        instance.setValue(index, value);
      }else if(attr.getValues() == null){
        //numeric attribute
        try{
          double db = Double.parseDouble(value);
          instance.setValue(index, db);
        }catch(Exception e){
          Out.prln("Warning: invalid numeric value: \"" + value +
                   "\" for attribute " + attr.getName() + " was ignored!");
          instance.setMissing(index);
        }
      }else if(attr.getValues().isEmpty() ||
               attr.getValues().contains(value)){
        //string attribute (no values listed) or valid nominal value
        instance.setValue(index, value);
      }else{
        //nominal attribute with a value outside the permitted set
        Out.prln("Warning: invalid value: \"" + value +
                 "\" for attribute " + attr.getName() + " was ignored!");
        instance.setMissing(index);
      }
      index ++;
    }
    return instance;
  }

  /**
   * Sets the definition of the dataset used by this engine.
   * @param definition the dataset definition.
   */
  public void setDatasetDefinition(DatasetDefintion definition) {
    this.datasetDefinition = definition;
  }

  /**
   * Classifies a new instance.
   * <p>
   * For non-updateable classifiers the model is rebuilt first if the dataset
   * has changed since the last build. If a confidence threshold is set and the
   * class attribute is nominal, the result is the list of all class values
   * whose probability reaches the threshold; otherwise it is the single
   * predicted value (see {@link #convertAttributeValue(double)}).
   * @param attributeValues the values for the attributes.
   * @return a List of class values, or a single class value.
   * @throws ExecutionException if the instance cannot be built, if no
   * classifier is available, or if WEKA reports an error.
   */
  public Object classifyInstance(List attributeValues)
         throws ExecutionException {
    Instance instance = buildInstance(attributeValues);
    if(classifier == null){
      //init() allows running without a classifier (dataset collection only);
      //report that explicitly rather than through a NullPointerException
      throw new ExecutionException(
        "Cannot classify: no classifier is available (the WEKA ML engine " +
        "is running in dataset collection mode)!");
    }

    try{
      if(classifier instanceof UpdateableClassifier){
        //updateable classifiers are always up to date
        return convertAttributeValue(classifier.classifyInstance(instance));
      }

      if(datasetChanged){
        fireStatusChanged("[Re]building model...");
        classifier.buildClassifier(dataset);
        datasetChanged = false;
        fireStatusChanged("");
      }

      if(confidenceThreshold > 0 &&
         dataset.classAttribute().type() == weka.core.Attribute.NOMINAL){
        //confidence set; use probability distribution
        double[] distribution = ((DistributionClassifier)classifier).
                                distributionForInstance(instance);
        List res = new ArrayList();
        for(int i = 0; i < distribution.length; i++){
          if(distribution[i] >= confidenceThreshold){
            res.add(dataset.classAttribute().value(i));
          }
        }
        return res;
      }

      //confidence not set; use simple classification
      return convertAttributeValue(classifier.classifyInstance(instance));
    }catch(Exception e){
      throw new ExecutionException(e);
    }
  }

  /**
   * Converts the numeric value returned by the classifier into the value of
   * the class attribute: the corresponding String for nominal and boolean
   * class attributes, or a {@link Double} otherwise.
   * @param value the value returned by the classifier.
   * @return a String or a Double.
   */
  protected Object convertAttributeValue(double value){
    gate.creole.ml.Attribute classAttr = datasetDefinition.getClassAttribute();
    List classValues = classAttr.getValues();
    boolean nominal = classValues != null && !classValues.isEmpty();
    boolean bool = classAttr.getFeature() == null;
    if(nominal || bool){
      //the classifier returned the index of the class value
      return dataset.attribute(datasetDefinition.getClassIndex()).
                     value((int)value);
    }
    //numeric attribute
    return new Double(value);
  }

  /**
   * Initialises the classifier and prepares for running.
   * <p>
   * Reads the classifier, its options, any filters and the confidence
   * threshold from the options element, then creates an empty dataset that
   * matches the dataset definition.
   * @throws GateException if the options or the dataset definition were not
   * set, or if the classifier or the filters cannot be created.
   */
  public void init() throws GateException{
    //see if we can shout about what we're doing
    sListener = null;
    Map listeners = MainFrame.getListeners();
    if(listeners != null){
      sListener = (StatusListener)listeners.get("gate.event.StatusListener");
    }

    if(optionsElement == null){
      throw new GateException(
        "No options were provided for the WEKA ML engine!");
    }
    if(datasetDefinition == null){
      throw new GateException(
        "No dataset definition was provided for the WEKA ML engine!");
    }

    //forget any threshold left over from a previous init() or load(); it is
    //only valid together with the classifier it was validated against
    confidenceThreshold = 0;

    //find the classifier to be used
    fireStatusChanged("Initialising classifier...");
    Element classifierElem = optionsElement.getChild("CLASSIFIER");
    if(classifierElem == null){
      Out.prln("Warning (WEKA ML engine): no classifier selected;" +
               " dataset collection only!");
      classifier = null;
    }else{
      classifier = createClassifier(classifierElem);
      //if we have any filters apply them to the classifier
      classifier = applyFilters(classifier);
      readConfidenceThreshold();
    }

    //initialise the dataset
    fireStatusChanged("Initialising dataset...");
    dataset = createDataset();

    if(classifier != null && classifier instanceof UpdateableClassifier){
      //updateable classifiers are built once, on the empty dataset, and then
      //updated with every new training instance
      try{
        classifier.buildClassifier(dataset);
      }catch(Exception e){
        throw new ResourceInstantiationException(e);
      }
    }
    fireStatusChanged("");
  }

  /**
   * Creates the classifier named by the CLASSIFIER element, using the options
   * from its OPTIONS attribute or, failing that, from the old style
   * CLASSIFIER-OPTIONS element.
   */
  private Classifier createClassifier(Element classifierElem)
          throws GateException{
    String classifierClassName = classifierElem.getTextTrim();

    //get the options for the classifier
    String optionsString = null;
    fireStatusChanged("Setting classifier options...");
    Element classifierOptionsElem =
      optionsElement.getChild("CLASSIFIER-OPTIONS");
    if(classifierOptionsElem != null){
      optionsString = classifierOptionsElem.getTextTrim();
    }

    //new style overrides old style
    org.jdom.Attribute optionsAttribute =
      classifierElem.getAttribute("OPTIONS");
    if(optionsAttribute != null){
      optionsString = optionsAttribute.getValue().trim();
    }

    try{
      return Classifier.forName(classifierClassName,
                                parseOptions(optionsString));
    }catch(Exception e){
      throw new GateException(e);
    }
  }

  /**
   * Wraps the given classifier in one FilteredClassifier for each FILTER
   * element found in the options, in document order.
   */
  private Classifier applyFilters(Classifier baseClassifier)
          throws ResourceInstantiationException{
    Classifier result = baseClassifier;
    List filterElems = optionsElement.getChildren("FILTER");
    if(filterElems == null) return result;

    Iterator elemIter = filterElems.iterator();
    while(elemIter.hasNext()){
      Element filterElem = (Element)elemIter.next();
      String filterClassName = filterElem.getTextTrim();
      String filterOptionsString = "";
      org.jdom.Attribute optionsAttr = filterElem.getAttribute("OPTIONS");
      if(optionsAttr != null){
        filterOptionsString = optionsAttr.getValue().trim();
      }
      //create the new filter
      try{
        Class filterClass = Class.forName(filterClassName);
        if(!Filter.class.isAssignableFrom(filterClass)){
          throw new ResourceInstantiationException(
            filterClassName + " is not a " + Filter.class.getName() + "!");
        }
        Filter aFilter = (Filter)filterClass.newInstance();
        //apply the options to the filter
        if(filterOptionsString.length() > 0){
          if(!(aFilter instanceof OptionHandler)){
            throw new ResourceInstantiationException(
              filterClassName + " cannot handle options!");
          }
          ((OptionHandler)aFilter).setOptions(
            parseOptions(filterOptionsString));
        }
        //apply the filter to the classifier
        result = new FilteredClassifier(result, aFilter);
      }catch(Exception e){
        throw new ResourceInstantiationException(e);
      }
    }
    return result;
  }

  /**
   * Reads the optional CONFIDENCE-THRESHOLD element and checks that the
   * current classifier is able to provide probability distributions.
   */
  private void readConfidenceThreshold() throws GateException{
    Element thresholdElem = optionsElement.getChild("CONFIDENCE-THRESHOLD");
    if(thresholdElem == null) return;

    try{
      confidenceThreshold = Double.parseDouble(thresholdElem.getTextTrim());
    }catch(Exception e){
      throw new GateException(
        "Could not parse confidence threshold value: " +
        thresholdElem.getTextTrim() + "!");
    }
    if(!(classifier instanceof DistributionClassifier)){
      throw new GateException(
        "Cannot use confidence threshold with classifier: " +
        classifier.getClass().getName() + "!");
    }
  }

  /**
   * Creates an empty WEKA dataset with one attribute for each attribute in
   * the dataset definition and with the class index set accordingly.
   */
  private Instances createDataset(){
    FastVector attributes = new FastVector();
    Iterator attIter = datasetDefinition.getAttributes().iterator();
    while(attIter.hasNext()){
      gate.creole.ml.Attribute aGateAttr =
        (gate.creole.ml.Attribute)attIter.next();
      weka.core.Attribute aWekaAttribute = null;
      if(aGateAttr.getValues() != null){
        //nominal or String attribute
        if(!aGateAttr.getValues().isEmpty()){
          //nominal attribute
          FastVector attrValues = new FastVector(aGateAttr.getValues().size());
          Iterator valIter = aGateAttr.getValues().iterator();
          while(valIter.hasNext()){
            attrValues.addElement(valIter.next());
          }
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
                                                   attrValues);
        }else{
          //VALUES element present but no values defined -> String attribute
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
                                                   null);
        }
      }else{
        if(aGateAttr.getFeature() == null){
          //boolean attribute ([lack of] presence of an annotation)
          FastVector attrValues = new FastVector(2);
          attrValues.addElement("true");
          attrValues.addElement("false");
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName(),
                                                   attrValues);
        }else{
          //feature is not null but no values provided -> numeric attribute
          aWekaAttribute = new weka.core.Attribute(aGateAttr.getName());
        }
      }
      attributes.addElement(aWekaAttribute);
    }

    Instances newDataset =
      new Instances("Weka ML Engine Dataset", attributes, 0);
    newDataset.setClassIndex(datasetDefinition.getClassIndex());
    return newDataset;
  }

  /**
   * Splits an options string into its space separated tokens.
   * @param optionsString the options; may be null or empty.
   * @return an array with the tokens; empty (never null) if there are none.
   */
  protected String[] parseOptions(String optionsString){
    if(optionsString == null || optionsString.length() == 0){
      return new String[]{};
    }
    List optionsList = new ArrayList();
    StringTokenizer strTok =
      new StringTokenizer(optionsString , " ", false);
    while(strTok.hasMoreTokens()){
      optionsList.add(strTok.nextToken());
    }
    return (String[])optionsList.toArray(new String[optionsList.size()]);
  }

  /**
   * Loads the state of this engine from previously saved data, as written by
   * {@link #save(OutputStream)}.
   * <p>
   * The state of the engine is only replaced if all the data was read
   * successfully. The stream is always closed by this method.
   * @param is the stream to read from.
   * @throws IOException if the data cannot be read.
   */
  public void load(InputStream is) throws IOException{
    fireStatusChanged("Loading model...");
    try{
      ObjectInputStream ois = null;
      try{
        ois = new ObjectInputStream(is);
        //read everything into locals first so that a failure half way
        //through does not leave the engine in an inconsistent state
        Classifier newClassifier = (Classifier)ois.readObject();
        Instances newDataset = (Instances)ois.readObject();
        DatasetDefintion newDefinition = (DatasetDefintion)ois.readObject();
        boolean newDatasetChanged = ois.readBoolean();
        double newThreshold = ois.readDouble();

        classifier = newClassifier;
        dataset = newDataset;
        datasetDefinition = newDefinition;
        datasetChanged = newDatasetChanged;
        confidenceThreshold = newThreshold;
      }catch(ClassNotFoundException cnfe){
        throw new GateRuntimeException(cnfe.toString());
      }finally{
        //if the object stream could not be created close the raw stream
        if(ois != null){
          ois.close();
        }else{
          is.close();
        }
      }
    }finally{
      fireStatusChanged("");
    }
  }

  /**
   * Saves the state of the engine for reuse at a later time. The classifier,
   * the dataset, the dataset definition, the dataset changed flag and the
   * confidence threshold are written, in this order. The stream is always
   * closed by this method.
   * @param os the stream to write to.
   * @throws IOException if the data cannot be written.
   */
  public void save(OutputStream os) throws IOException{
    fireStatusChanged("Saving model...");
    try{
      ObjectOutputStream oos = null;
      try{
        oos = new ObjectOutputStream(os);
        oos.writeObject(classifier);
        oos.writeObject(dataset);
        oos.writeObject(datasetDefinition);
        oos.writeBoolean(datasetChanged);
        oos.writeDouble(confidenceThreshold);
        oos.flush();
      }finally{
        //if the object stream could not be created close the raw stream
        if(oos != null){
          oos.close();
        }else{
          os.close();
        }
      }
    }finally{
      fireStatusChanged("");
    }
  }

  /**
   * Gets the list of actions that can be performed on this resource.
   * @return a List of Action objects (or null values)
   */
  public List getActions(){
    return actionsList;
  }

  /**
   * Registers the PR using the engine with the engine itself.
   * @param pr the processing resource that owns this engine.
   */
  public void setOwnerPR(ProcessingResource pr){
    this.owner = pr;
  }

  /**
   * Gets the definition of the dataset used by this engine.
   * @return the dataset definition.
   */
  public DatasetDefintion getDatasetDefinition() {
    return datasetDefinition;
  }

  /**
   * Sends a status message to the status listener, if there is one.
   */
  private void fireStatusChanged(String message){
    if(sListener != null) sListener.statusChanged(message);
  }

  /**
   * Asks the user to select a single file using the shared file chooser.
   * @param forSaving whether to show a save dialog rather than an open one.
   * @return the selected file or null if the user cancelled the dialog.
   */
  private File chooseFile(boolean forSaving){
    JFileChooser fileChooser = MainFrame.getFileChooser();
    fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
    fileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);
    fileChooser.setMultiSelectionEnabled(false);
    int answer = forSaving ? fileChooser.showSaveDialog(null) :
                             fileChooser.showOpenDialog(null);
    if(answer == JFileChooser.APPROVE_OPTION){
      return fileChooser.getSelectedFile();
    }
    return null;
  }

  /**
   * Shows an I/O error to the user and prints its stack trace to the GATE
   * error stream.
   */
  private void reportIOError(IOException ioe){
    JOptionPane.showMessageDialog(null,
                    "Error!\n"+
                     ioe.toString(),
                     "Gate", JOptionPane.ERROR_MESSAGE);
    ioe.printStackTrace(Err.getPrintWriter());
  }

  /**
   * Runs a task on a new, minimum priority thread.
   */
  private void runInBackground(Runnable task, String threadName){
    Thread thread = new Thread(task, threadName);
    thread.setPriority(Thread.MIN_PRIORITY);
    thread.start();
  }


  /**
   * Action that writes the textual form of the dataset to a file selected by
   * the user.
   */
  protected class SaveDatasetAsArffAction extends javax.swing.AbstractAction{
    public SaveDatasetAsArffAction(){
      super("Save dataset as ARFF");
      putValue(SHORT_DESCRIPTION, "Saves the dataset to a file in ARFF format");
    }

    public void actionPerformed(java.awt.event.ActionEvent evt){
      Runnable runnable = new Runnable(){
        public void run(){
          File file = chooseFile(true);
          if(file == null) return;
          try{
            MainFrame.lockGUI("Saving dataset...");
            FileWriter fw = new FileWriter(file.getCanonicalPath(), false);
            try{
              fw.write(dataset.toString());
              fw.flush();
            }finally{
              //release the file even if writing failed
              fw.close();
            }
          }catch(IOException ioe){
            reportIOError(ioe);
          }finally{
            MainFrame.unlockGUI();
          }
        }
      };

      runInBackground(runnable, "DatasetSaver(ARFF)");
    }
  }


  /**
   * Action that saves the state of the engine, GZIP compressed, to a file
   * selected by the user.
   */
  protected class SaveModelAction extends javax.swing.AbstractAction{
    public SaveModelAction(){
      super("Save model");
      putValue(SHORT_DESCRIPTION, "Saves the ML model to a file");
    }

    public void actionPerformed(java.awt.event.ActionEvent evt){
      Runnable runnable = new Runnable(){
        public void run(){
          File file = chooseFile(true);
          if(file == null) return;
          try{
            MainFrame.lockGUI("Saving ML model...");
            //save() takes care of closing the stream
            save(new GZIPOutputStream(
                 new FileOutputStream(file.getCanonicalPath(), false)));
          }catch(IOException ioe){
            reportIOError(ioe);
          }finally{
            MainFrame.unlockGUI();
          }
        }
      };
      runInBackground(runnable, "ModelSaver(serialisation)");
    }
  }

  /**
   * Action that loads the state of the engine from a GZIP compressed file
   * selected by the user.
   */
  protected class LoadModelAction extends javax.swing.AbstractAction{
    public LoadModelAction(){
      super("Load model");
      putValue(SHORT_DESCRIPTION, "Loads a ML model from a file");
    }

    public void actionPerformed(java.awt.event.ActionEvent evt){
      Runnable runnable = new Runnable(){
        public void run(){
          File file = chooseFile(false);
          if(file == null) return;
          try{
            MainFrame.lockGUI("Loading model...");
            //load() takes care of closing the stream
            load(new GZIPInputStream(new FileInputStream(file)));
          }catch(IOException ioe){
            reportIOError(ioe);
          }finally{
            MainFrame.unlockGUI();
          }
        }
      };
      runInBackground(runnable, "ModelLoader(serialisation)");
    }
  }


  /**
   * The definition of the dataset used by this engine.
   */
  protected DatasetDefintion datasetDefinition;

  /**
   * The minimum probability a class value needs in order to be returned by
   * {@link #classifyInstance(List)}; values that are not greater than zero
   * disable the use of probability distributions.
   */
  double confidenceThreshold = 0;

  /**
   * The WEKA classifier used by this wrapper
   */
  protected Classifier classifier;

  /**
   * The dataset used for training
   */
  protected Instances dataset;

  /**
   * The JDom element containing the options for this wrapper.
   */
  protected Element optionsElement;

  /**
   * Marks whether the dataset was changed since the last time the classifier
   * was built.
   */
  protected boolean datasetChanged = false;

  /**
   * The actions published by this engine.
   */
  protected List actionsList;

  /**
   * The processing resource that owns this engine.
   */
  protected ProcessingResource owner;

  /**
   * The listener that gets notified about the progress; may be null.
   */
  protected StatusListener sListener;
}
|
Wrapper |
|