1
14 package gate.creole.ml;
15
16 import java.util.*;
17
18 import org.jdom.Element;
19 import org.jdom.JDOMException;
20 import org.jdom.input.SAXBuilder;
21
22 import gate.*;
23 import gate.creole.*;
24 import gate.gui.ActionsPublisher;
25 import gate.util.*;
26
27
31
32 public class MachineLearningPR extends AbstractLanguageAnalyser
33 implements gate.gui.ActionsPublisher{
34
/**
 * Creates the processing resource and sets up the list of actions it
 * publishes. The list starts out holding a single {@code null} entry.
 * NOTE(review): the null entry is presumably a separator placeholder for the
 * GUI - confirm against the ActionsPublisher contract.
 */
public MachineLearningPR(){
  actionList = new ArrayList(Collections.singletonList(null));
}
39
40
46 public void cleanup() {
47 super.cleanup();
50
51 if (engine!=null) {
54 engine.cleanUp();
55 }
56 }
57
58
59 public Resource init() throws ResourceInstantiationException {
60 if(configFileURL == null){
61 throw new ResourceInstantiationException(
62 "No configuration file provided!");
63 }
64
65 org.jdom.Document jdomDoc;
66 SAXBuilder saxBuilder = new SAXBuilder(false);
67 try {
68 try{
69 jdomDoc = saxBuilder.build(configFileURL);
70 }catch(JDOMException jde){
71 throw new ResourceInstantiationException(jde);
72 }
73 } catch (java.io.IOException ex) {
74 throw new ResourceInstantiationException(ex);
75 }
76
77 Element rootElement = jdomDoc.getRootElement();
79 if(!rootElement.getName().equals("ML-CONFIG"))
80 throw new ResourceInstantiationException(
81 "Root element of dataset defintion file is \"" + rootElement.getName() +
82 "\" instead of \"ML-CONFIG\"!");
83
84 Element datasetElement = rootElement.getChild("DATASET");
86 if(datasetElement == null) throw new ResourceInstantiationException(
87 "No dataset definition provided in the configuration file!");
88 try{
89 datasetDefinition = new DatasetDefintion(datasetElement);
90 }catch(GateException ge){
91 throw new ResourceInstantiationException(ge);
92 }
93
94 Element engineElement = rootElement.getChild("ENGINE");
96 if(engineElement == null) throw new ResourceInstantiationException(
97 "No engine option provided in the configuration file!");
98 Element engineClassElement = engineElement.getChild("WRAPPER");
99 if(engineClassElement == null) throw new ResourceInstantiationException(
100 "No ML engine class provided!");
101 String engineClassName = engineClassElement.getTextTrim();
102 try{
103 Class engineClass = Class.forName(engineClassName);
104 engine = (MLEngine)engineClass.newInstance();
105 }catch(ClassNotFoundException cnfe){
106 throw new ResourceInstantiationException(
107 "ML engine class:" + engineClassName + "not found!");
108 }catch(IllegalAccessException iae){
109 throw new ResourceInstantiationException(iae);
110 }catch(InstantiationException ie){
111 throw new ResourceInstantiationException(ie);
112 }
113
114 if (engineElement.getChild("BATCH-MODE-CLASSIFICATION") == null) {
116 batchModeClassification = false;
117 } else {
118 batchModeClassification = true;
119 }
120
121 engine.setDatasetDefinition(datasetDefinition);
122 engine.setOptions(engineElement.getChild("OPTIONS"));
123 engine.setOwnerPR(this);
124 try{
125 engine.init();
126 }catch(GateException ge){
127 throw new ResourceInstantiationException(ge);
128 }
129
130 return this;
131 }
133
134
137 public void execute() throws ExecutionException {
138 interrupted = false;
139 if (document == null) {
141 throw new ExecutionException(
142 "No document provided!"
143 );
144 }
145
146 if (inputASName == null ||
147 inputASName.equals(""))
148 annotationSet = document.getAnnotations();
149 else
150 annotationSet = document.getAnnotations(inputASName);
151
152 if (training.booleanValue()) {
153 fireStatusChanged(
154 "Collecting training data from " + document.getName() + "...");
155 }
156 else {
157 fireStatusChanged(
158 "Applying ML model to " + document.getName() + "...");
159 }
160 fireProgressChanged(0);
161 AnnotationSet anns = annotationSet.
162 get(datasetDefinition.getInstanceType());
163 annotations = (anns == null || anns.isEmpty()) ?
164 new ArrayList() : new ArrayList(anns);
165 Collections.sort(annotations, new OffsetComparator());
166 Iterator annotationIter = annotations.iterator();
167 int index = 0;
168 int size = annotations.size();
169
170 cache = new Cache();
172
173 if (!batchModeClassification || training.booleanValue()) {
174 while (annotationIter.hasNext()) {
178 Annotation instanceAnn = (Annotation) annotationIter.next();
179 List attributeValues = new ArrayList(datasetDefinition.
180 getAttributes().size());
181 Iterator attrIter = datasetDefinition.getAttributes().iterator();
183 while (attrIter.hasNext()) {
184 Attribute attr = (Attribute) attrIter.next();
185 if (attr.isClass && !training.booleanValue()) {
186 attributeValues.add(null);
188 }
189 else {
190 attributeValues.add(cache.getAttributeValue(index, attr));
191 }
192 }
193
194 if (training.booleanValue()) {
195 engine.addTrainingInstance(attributeValues);
196 }
197 else {
198 Object result = engine.classifyInstance(attributeValues);
199 if (result instanceof Collection) {
200 Iterator resIter = ( (Collection) result).iterator();
201 while (resIter.hasNext())
202 updateDocument(resIter.next(), index);
203 }
204 else {
205 updateDocument(result, index);
206 }
207 }
208
209 cache.shift();
210 if (index % 10 == 0) {
212 fireProgressChanged(index * 100 / size);
213 if (isInterrupted())
214 throw new ExecutionInterruptedException();
215 }
216 index++;
217 }
218
219 }
220 else {
221
226 List instancesToBeClassified = new ArrayList();
228
229 while (annotationIter.hasNext()) {
230 Annotation instanceAnn = (Annotation) annotationIter.next();
231 List attributeValues = new ArrayList(datasetDefinition.
232 getAttributes().size());
233 Iterator attrIter = datasetDefinition.getAttributes().iterator();
235 while (attrIter.hasNext()) {
236 Attribute attr = (Attribute) attrIter.next();
237 if (attr.isClass) {
238 attributeValues.add(null);
240 }
241 else {
242 attributeValues.add(cache.getAttributeValue(index, attr));
243 }
244 }
245
246 instancesToBeClassified.add(attributeValues);
249
250 cache.shift();
251
252 index++;
253 }
254
255 List classificationResults = engine.batchClassifyInstances(
258 instancesToBeClassified);
259
260
263 index = 0;
265 Iterator resultsIterator = classificationResults.iterator();
266 while (resultsIterator.hasNext()) {
267
268 Object result = resultsIterator.next();
269 if (result instanceof Collection) {
270 Iterator resIter = ( (Collection) result).iterator();
271 while (resIter.hasNext())
272 updateDocument(resIter.next(), index);
273 }
274 else {
275 updateDocument(result, index);
276 }
277
278 index++;
280 }
281 }
282 annotations = null;
283 }
285
286 protected void updateDocument(Object classificationResult, int instanceIndex){
287 Attribute classAttr = datasetDefinition.getClassAttribute();
289 String type = classAttr.getType();
290 String feature = classAttr.getFeature();
291 List classValues = classAttr.getValues();
292 FeatureMap features = Factory.newFeatureMap();
293 boolean shouldCreateAnnotation = true;
294 if(classValues != null && !classValues.isEmpty()){
295 String featureValue = (String)classificationResult;
298 features.put(feature, featureValue);
299 }else{
300 if(feature == null){
301 shouldCreateAnnotation = classificationResult.equals("true");
303 }else{
304 String featureValue = classificationResult.toString();
306 features.put(feature, featureValue);
307 }
308 }
309
310 if(shouldCreateAnnotation){
311 int coveredInstanceIndex = instanceIndex + classAttr.getPosition();
313 if(coveredInstanceIndex >= 0 &&
314 coveredInstanceIndex < annotations.size()){
315 Annotation coveredInstance = (Annotation)annotations.
316 get(coveredInstanceIndex);
317 annotationSet.add(coveredInstance.getStartNode(),
318 coveredInstance.getEndNode(),
319 type, features);
320 }
321 }
322 }
323
324
325
329 public List getActions(){
330 List result = new ArrayList();
331 result.addAll(actionList);
332 if(engine instanceof ActionsPublisher){
333 result.addAll(((ActionsPublisher)engine).getActions());
334 }
335 return result;
336 }
337
338 protected class Cache{
339 public Cache(){
340 int forwardCacheSize = 0;
342 int backwardCacheSize = 0;
343 Iterator attrIter = datasetDefinition.getAttributes().iterator();
344 while(attrIter.hasNext()){
345 Attribute anAttribute = (Attribute)attrIter.next();
346 if(anAttribute.getPosition() > 0){
347 if(anAttribute.getPosition() > forwardCacheSize){
349 forwardCacheSize = anAttribute.getPosition();
350 }
351 }else if(anAttribute.getPosition() < 0){
352 if(-anAttribute.getPosition() > backwardCacheSize){
354 backwardCacheSize = -anAttribute.getPosition();
355 }
356 }
357 }
358 forwardCache = new ArrayList(forwardCacheSize);
360 for(int i =0; i < forwardCacheSize; i++) forwardCache.add(null);
361 backwardCache = new ArrayList(backwardCacheSize);
362 for(int i =0; i < backwardCacheSize; i++) backwardCache.add(null);
363 }
364
365
372 public String getAttributeValue(int instanceIndex, Attribute attribute){
373 int actualPosition = instanceIndex + attribute.getPosition();
375 if(actualPosition < 0 || actualPosition >= annotations.size()) return null;
376
377 if(attribute.getPosition() == 0){
379 if(currentAttributes == null) currentAttributes = new HashMap();
381 return getValue(attribute, instanceIndex, currentAttributes);
382 }else if(attribute.getPosition() > 0){
383 Map attributesMap = (Map)forwardCache.get(attribute.getPosition() - 1);
385 if(attributesMap == null){
386 attributesMap = new HashMap();
387 forwardCache.set(attribute.getPosition() - 1, attributesMap);
388 }
389 return getValue(attribute, actualPosition, attributesMap);
390 }else if(attribute.getPosition() < 0){
391 Map attributesMap = (Map)backwardCache.get(-attribute.getPosition() - 1);
393 if(attributesMap == null){
394 attributesMap = new HashMap();
395 backwardCache.set(-attribute.getPosition() - 1, attributesMap);
396 }
397 return getValue(attribute, actualPosition, attributesMap);
398 }
399 throw new LuckyException(
401 "Attribute position is neither 0, nor negative nor positive!");
402 }
403
404
408 public void shift(){
409 if(backwardCache.isEmpty()){
410 }else{
413 backwardCache.remove(backwardCache.size() - 1);
414 backwardCache.add(0, currentAttributes);
415 }
416 if(forwardCache.isEmpty()){
417 if(currentAttributes != null) currentAttributes.clear();
419 }else{
420 currentAttributes = (Map) forwardCache.remove(0);
421 forwardCache.add(null);
422 }
423 }
424
425
436 protected String getValue(Attribute attribute,
437 int instanceIndex,
438 Map cache){
439 String value = null;
440 String annType = attribute.getType();
441 String featureName = attribute.getFeature();
442 Map typeData = (Map)cache.get(annType);
443 if(typeData != null){
444 if(featureName == null){
445 value = (String)typeData.get(null);
447 }else{
448 value = (String)typeData.get(featureName);
449 }
450 }else{
451 Annotation instanceAnnot = (Annotation)annotations.get(instanceIndex);
454 AnnotationSet coverSubset = annotationSet.get(
455 annType,
456 instanceAnnot.getStartNode().getOffset(),
457 instanceAnnot.getEndNode().getOffset());
458 typeData = new HashMap();
459 cache.put(annType, typeData);
460 if(coverSubset == null || coverSubset.isEmpty()){
461 typeData.put(null, "false");
463 if(featureName == null) value = "false";
464 else value = null;
465 }else{
466 typeData.putAll(((Annotation)coverSubset.iterator().next()).
467 getFeatures());
468 typeData.put(null, "true");
469 if(featureName == null) value = "true";
470 else value = (String)typeData.get(featureName);
471 }
472 }
473 return value;
474 }
475
476
489 protected List forwardCache;
490
491
504 protected List backwardCache;
505
506
518 protected Map currentAttributes;
519
520 }
521
522
523 public void setInputASName(String inputASName) {
524 this.inputASName = inputASName;
525 }
526 public String getInputASName() {
527 return inputASName;
528 }
529 public java.net.URL getConfigFileURL() {
530 return configFileURL;
531 }
532 public void setConfigFileURL(java.net.URL configFileURL) {
533 this.configFileURL = configFileURL;
534 }
535 public void setTraining(Boolean training) {
536 this.training = training;
537 }
538 public Boolean getTraining() {
539 return training;
540 }
541 public MLEngine getEngine() {
542 return engine;
543 }
544 public void setEngine(MLEngine engine) {
545 this.engine = engine;
546 }
547
548 private java.net.URL configFileURL;
549 protected DatasetDefintion datasetDefinition;
550
551 protected MLEngine engine;
552
553 protected String inputASName;
554
555 protected AnnotationSet annotationSet;
556
557 protected List annotations;
558
559 protected List actionList;
560
561 protected Cache cache;
562 private Boolean training;
563
564
568 protected boolean batchModeClassification;
569 }