1
14
15 package gate.creole.ml.maxent;
16
17 import gate.creole.ml.*;
18 import gate.util.GateException;
19 import gate.creole.ExecutionException;
20 import java.util.List;
21 import java.util.Iterator;
22
23
27 public class MaxentWrapper
28 implements MLEngine, gate.gui.ActionsPublisher {
29
30 boolean DEBUG=false;
31
32
41 public MaxentWrapper() {
42 actionsList = new java.util.ArrayList();
43 actionsList.add(new LoadModelAction());
44 actionsList.add(new SaveModelAction());
45 actionsList.add(null);
46 }
47
48
52 public void cleanUp() {
53 }
54
55
63 public List batchClassifyInstances(java.util.List instances)
64 throws ExecutionException {
65 throw new ExecutionException("The Maxent wrapper does not support "+
66 "batch classification. Remove the "+
67 "<BATCH-MODE-CLASSIFICATION/> entry "+
68 "from the XML configuration file and "+
69 "try again.");
70 }
71
72
78 public void setOptions(org.jdom.Element optionsElem) {
79 this.optionsElement = optionsElem;
80 }
81
82
88 private void extractAndCheckOptions() throws gate.creole.
89 ResourceInstantiationException {
90 setCutoff(optionsElement);
91 setConfidenceThreshold(optionsElement);
92 setVerbose(optionsElement);
93 setIterations(optionsElement);
94 setSmoothing(optionsElement);
95 setSmoothingObservation(optionsElement);
96 }
97
98
102 private void setVerbose(org.jdom.Element optionsElem) {
103 if (optionsElem.getChild("VERBOSE") == null) {
104 verbose = false;
105 }
106 else {
107 verbose = true;
108 }
109 }
110
111
115 private void setSmoothing(org.jdom.Element optionsElem) {
116 if (optionsElem.getChild("SMOOTHING") == null) {
117 smoothing = false;
118 }
119 else {
120 smoothing = true;
121 }
122 }
123
124
128 private void setSmoothingObservation(org.jdom.Element optionsElem) throws
129 gate.creole.ResourceInstantiationException {
130 String smoothingObservationString
131 = optionsElem.getChildTextTrim("SMOOTHING-OBSERVATION");
132 if (smoothingObservationString != null) {
133 try {
134 smoothingObservation = Double.parseDouble(smoothingObservationString);
135 }
136 catch (NumberFormatException e) {
137 throw new gate.creole.ResourceInstantiationException("Unable to parse " +
138 "<SMOOTHING-OBSERVATION> value in maxent configuration file.");
139 }
140 }
141 else {
142 smoothingObservation = 0.0;
143 }
144 }
145
146
150 private void setConfidenceThreshold(org.jdom.Element optionsElem) throws gate.
151 creole.ResourceInstantiationException {
152 String confidenceThresholdString
153 = optionsElem.getChildTextTrim("CONFIDENCE-THRESHOLD");
154 if (confidenceThresholdString != null) {
155 try {
156 confidenceThreshold = Double.parseDouble(confidenceThresholdString);
157 }
158 catch (NumberFormatException e) {
159 throw new gate.creole.ResourceInstantiationException("Unable to parse " +
160 "<CONFIDENCE-THRESHOLD> value in maxent configuration file.");
161 }
162 if (confidenceThreshold < 0.0 || confidenceThreshold > 1) {
163 throw new gate.creole.ResourceInstantiationException(
164 "<CONFIDENCE-THRESHOLD> in maxent configuration"
165 + " file must be set to a value between 0 and 1."
166 + " (It is a probability.)");
167 }
168 }
169 else {
170 confidenceThreshold = 0.0;
171 }
172 }
173
174
178 private void setCutoff(org.jdom.Element optionsElem) throws gate.creole.
179 ResourceInstantiationException {
180 String cutoffString = optionsElem.getChildTextTrim("CUT-OFF");
181 if (cutoffString != null) {
182 try {
183 cutoff = Integer.parseInt(cutoffString);
184 }
185 catch (NumberFormatException e) {
186 throw new gate.creole.ResourceInstantiationException(
187 "Unable to parse <CUT-OFF> value in maxent " +
188 "configuration file. It must be an integer.");
189 }
190 }
191 else {
192 cutoff = 0;
193 }
194 }
195
196
201 private void setIterations(org.jdom.Element optionsElem) throws gate.creole.
202 ResourceInstantiationException {
203 String iterationsString = optionsElem.getChildTextTrim("ITERATIONS");
204 if (iterationsString != null) {
205 try {
206 iterations = Integer.parseInt(iterationsString);
207 }
208 catch (NumberFormatException e) {
209 throw new gate.creole.ResourceInstantiationException(
210 "Unable to parse <ITERATIONS> value in maxent " +
211 "configuration file. It must be an integer.");
212 }
213 }
214 else {
215 iterations = 0;
216 }
217 }
218
219
227 public void addTrainingInstance(List attributeValues) {
228 markIndicesOnFeatures(attributeValues);
229 trainingData.add(attributeValues);
230 datasetChanged = true;
231 }
232
233
243 void markIndicesOnFeatures(List attributeValues) {
244 for (int i=0; i<attributeValues.size(); ++i) {
245 if (i != datasetDefinition.getClassIndex())
247 attributeValues.set(i, i+":"+(String)attributeValues.get(i));
248 }
249 }
250
251
258 public void setDatasetDefinition(DatasetDefintion definition) {
259 this.datasetDefinition = definition;
260 }
261
262
268 private void checkDatasetDefinition() throws gate.creole.
269 ResourceInstantiationException {
270 List attributes = datasetDefinition.getAttributes();
273 Iterator attributeIterator = attributes.iterator();
274 while (attributeIterator.hasNext()) {
275 gate.creole.ml.Attribute currentAttribute
276 = (gate.creole.ml.Attribute) attributeIterator.next();
277 if (currentAttribute.semanticType() != gate.creole.ml.Attribute.BOOLEAN) {
278 if (currentAttribute.semanticType() != gate.creole.ml.Attribute.NOMINAL
279 || !currentAttribute.isClass()) {
280 throw new gate.creole.ResourceInstantiationException(
281 "Error in maxent configuration file. All " +
282 "attributes except the <CLASS/> attribute " +
283 "must be boolean, and the <CLASS/> attribute" +
284 " must be boolean or nominal");
285 }
286 }
287 }
288 }
289
290
295 private void initialiseAndTrainClassifier() {
296 opennlp.maxent.GIS.PRINT_MESSAGES = verbose;
297 opennlp.maxent.GIS.SMOOTHING = smoothing;
298 opennlp.maxent.GIS.SMOOTHING_OBSERVATION = smoothingObservation;
299
300 if (DEBUG) {
302 System.out.println("Number of training instances: "+trainingData.size());
303 System.out.println("Class index: "+datasetDefinition.getClassIndex());
304 System.out.println("Iterations: "+iterations);
305 System.out.println("Cutoff: "+cutoff);
306 System.out.println("Confidence threshold: "+confidenceThreshold);
307 System.out.println("Verbose: "+verbose);
308 System.out.println("Smoothing: "+smoothing);
309 System.out.println("Smoothing observation: "+smoothingObservation);
310
311 System.out.println("");
312 System.out.println("TRAINING DATA\n");
313 System.out.println(trainingData);
314 }
315 maxentClassifier = opennlp.maxent.GIS.trainModel(
316 new GateEventStream(trainingData, datasetDefinition.getClassIndex()),
317 iterations, cutoff);
318 }
319
320
337 public Object classifyInstance(List attributeValues) throws
338 ExecutionException {
339 if (maxentClassifier == null || datasetChanged)
343 initialiseAndTrainClassifier();
344 datasetChanged=false;
347
348 markIndicesOnFeatures(attributeValues);
351
352 attributeValues.remove(datasetDefinition.getClassIndex());
356
357 if (confidenceThreshold == 0) { return maxentClassifier.
361 getBestOutcome(maxentClassifier.eval(
362 (String[])attributeValues.toArray(new String[0])));
363 }
364 else { double[] outcomeProbabilities = maxentClassifier.eval(
366 (String[]) attributeValues.toArray(new String[0]));
367
368 List allOutcomesOverThreshold = new java.util.ArrayList();
369 for (int i = 0; i < outcomeProbabilities.length; i++) {
370 if (outcomeProbabilities[i] >= confidenceThreshold) {
371 allOutcomesOverThreshold.add(maxentClassifier.getOutcome(i));
372 }
373 }
374 return allOutcomesOverThreshold;
375 }
376 }
378
385 public void init() throws GateException {
386 sListener = null;
388 java.util.Map listeners = gate.gui.MainFrame.getListeners();
389 if (listeners != null) {
390 sListener = (gate.event.StatusListener)
391 listeners.get("gate.event.StatusListener");
392 }
393
394 if (sListener != null) {
395 sListener.statusChanged("Setting classifier options...");
396 }
397 extractAndCheckOptions();
398
399 if (sListener != null) {
400 sListener.statusChanged("Checking dataset definition...");
401 }
402 checkDatasetDefinition();
403
404
408 if (sListener != null) {
410 sListener.statusChanged("Initialising dataset...");
411
412 }
413 trainingData = new java.util.ArrayList();
414
415 if (sListener != null) {
416 sListener.statusChanged("");
417 }
418 }
420
424 public void load(java.io.InputStream is) throws java.io.IOException {
425 if (sListener != null) {
426 sListener.statusChanged("Loading model...");
427
428 }
429 java.io.ObjectInputStream ois = new java.io.ObjectInputStream(is);
430
431 try {
432 maxentClassifier = (opennlp.maxent.MaxentModel) ois.readObject();
433 trainingData = (java.util.List) ois.readObject();
434 datasetDefinition = (DatasetDefintion) ois.readObject();
435 datasetChanged = ois.readBoolean();
436
437 cutoff = ois.readInt();
438 confidenceThreshold = ois.readDouble();
439 iterations = ois.readInt();
440 verbose = ois.readBoolean();
441 smoothing = ois.readBoolean();
442 smoothingObservation = ois.readDouble();
443 }
444 catch (ClassNotFoundException cnfe) {
445 throw new gate.util.GateRuntimeException(cnfe.toString());
446 }
447 ois.close();
448
449 if (sListener != null) {
450 sListener.statusChanged("");
451 }
452 }
453
454
458 public void save(java.io.OutputStream os) throws java.io.IOException {
459 if (sListener != null) {
460 sListener.statusChanged("Saving model...");
461
462 }
463 java.io.ObjectOutputStream oos = new java.io.ObjectOutputStream(os);
464
465 oos.writeObject(maxentClassifier);
466 oos.writeObject(trainingData);
467 oos.writeObject(datasetDefinition);
468 oos.writeBoolean(datasetChanged);
469
470 oos.writeInt(cutoff);
471 oos.writeDouble(confidenceThreshold);
472 oos.writeInt(iterations);
473 oos.writeBoolean(verbose);
474 oos.writeBoolean(smoothing);
475 oos.writeDouble(smoothingObservation);
476
477 oos.flush();
478 oos.close();
479
480 if (sListener != null) {
481 sListener.statusChanged("");
482 }
483 }
484
485
489 public java.util.List getActions() {
490 return actionsList;
491 }
492
493
497 public void setOwnerPR(gate.ProcessingResource pr) {
498 this.owner = pr;
499 }
500
501 public DatasetDefintion getDatasetDefinition() {
502 return datasetDefinition;
503 }
504
505
508 protected class SaveModelAction
509 extends javax.swing.AbstractAction {
510 public SaveModelAction() {
511 super("Save model");
512 putValue(SHORT_DESCRIPTION, "Saves the ML model to a file");
513 }
514
515
521 public void actionPerformed(java.awt.event.ActionEvent evt) {
522 Runnable runnable = new Runnable() {
523 public void run() {
524 javax.swing.JFileChooser fileChooser
525 = gate.gui.MainFrame.getFileChooser();
526 fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
527 fileChooser.setFileSelectionMode(javax.swing.JFileChooser.FILES_ONLY);
528 fileChooser.setMultiSelectionEnabled(false);
529 if (fileChooser.showSaveDialog(null)
530 == javax.swing.JFileChooser.APPROVE_OPTION) {
531 java.io.File file = fileChooser.getSelectedFile();
532 try {
533 gate.gui.MainFrame.lockGUI("Saving ML model...");
534 save(new java.util.zip.GZIPOutputStream(
535 new java.io.FileOutputStream(
536 file.getCanonicalPath(), false)));
537 }
538 catch (java.io.IOException ioe) {
539 javax.swing.JOptionPane.showMessageDialog(null,
540 "Error!\n" +
541 ioe.toString(),
542 "GATE", javax.swing.JOptionPane.ERROR_MESSAGE);
543 ioe.printStackTrace(gate.util.Err.getPrintWriter());
544 }
545 finally {
546 gate.gui.MainFrame.unlockGUI();
547 }
548 }
549 }
550 };
551 Thread thread = new Thread(runnable, "ModelSaver(serialisation)");
552 thread.setPriority(Thread.MIN_PRIORITY);
553 thread.start();
554 }
555 }
556
557
562 protected class LoadModelAction
563 extends javax.swing.AbstractAction {
564 public LoadModelAction() {
565 super("Load model");
566 putValue(SHORT_DESCRIPTION, "Loads a ML model from a file");
567 }
568
569
575 public void actionPerformed(java.awt.event.ActionEvent evt) {
576 Runnable runnable = new Runnable() {
577 public void run() {
578 javax.swing.JFileChooser fileChooser
579 = gate.gui.MainFrame.getFileChooser();
580 fileChooser.setFileFilter(fileChooser.getAcceptAllFileFilter());
581 fileChooser.setFileSelectionMode(javax.swing.JFileChooser.FILES_ONLY);
582 fileChooser.setMultiSelectionEnabled(false);
583 if (fileChooser.showOpenDialog(null)
584 == javax.swing.JFileChooser.APPROVE_OPTION) {
585 java.io.File file = fileChooser.getSelectedFile();
586 try {
587 gate.gui.MainFrame.lockGUI("Loading model...");
588 load(new java.util.zip.GZIPInputStream(
589 new java.io.FileInputStream(file)));
590 }
591 catch (java.io.IOException ioe) {
592 javax.swing.JOptionPane.showMessageDialog(null,
593 "Error!\n" +
594 ioe.toString(),
595 "GATE", javax.swing.JOptionPane.ERROR_MESSAGE);
596 ioe.printStackTrace(gate.util.Err.getPrintWriter());
597 }
598 finally {
599 gate.gui.MainFrame.unlockGUI();
600 }
601 }
602 }
603 };
604 Thread thread = new Thread(runnable, "ModelLoader(serialisation)");
605 thread.setPriority(Thread.MIN_PRIORITY);
606 thread.start();
607 }
608 }
609
610 protected gate.creole.ml.DatasetDefintion datasetDefinition;
611
612
615 protected opennlp.maxent.MaxentModel maxentClassifier;
616
617
623 protected List trainingData;
624
625
628 protected org.jdom.Element optionsElement;
629
630
634 protected boolean datasetChanged = false;
635
636
640 protected List actionsList;
641
642 protected gate.ProcessingResource owner;
643
644 protected gate.event.StatusListener sListener;
645
646
652 protected int cutoff = 0;
653 protected double confidenceThreshold = 0;
654 protected int iterations = 10;
655 protected boolean verbose = false;
656 protected boolean smoothing = false;
657 protected double smoothingObservation = 0.1;
658
659 }