Log in Help
Print
HomegatepluginsTagger_PennBiosrceduupenncistaggersvariation 〉 VariationTrainer.java
 
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package edu.upenn.cis.taggers.variation;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.ObjectInputStream;
import java.util.regex.Pattern;

import edu.umass.cs.mallet.base.fst.CRF4;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.SerialPipes;
import edu.umass.cs.mallet.base.pipe.TokenSequence2FeatureVectorSequence;
import edu.umass.cs.mallet.base.pipe.TokenSequenceLowercase;
import edu.umass.cs.mallet.base.pipe.iterator.LineGroupIterator;
import edu.umass.cs.mallet.base.pipe.tsf.FeaturesInWindow;
import edu.umass.cs.mallet.base.pipe.tsf.RegexMatches;
import edu.umass.cs.mallet.base.pipe.tsf.TokenText;
import edu.umass.cs.mallet.base.pipe.tsf.TokenTextCharNGrams;
import edu.umass.cs.mallet.base.pipe.tsf.TrieLexiconMembership;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.InstanceList;

import edu.upenn.cis.taggers.Model;

public class VariationTrainer
{
    int numEvaluations = 0;
    static int iterationsBetweenEvals = 16;

    private static String CAPS = "[A-Z]";
    private static String LOW = "[a-z]";
    private static String CAPSNUM = "[A-Z0-9]";
    private static String ALPHA = "[A-Za-z]";
    private static String ALPHANUM = "[A-Za-z0-9]";
    private static String PUNT = "[,\\.;:?!()]";
    private static String QUOTE = "[\"`']";
 
    public static void main (String[] args) throws FileNotFoundException, Exception
    {

	// train a new model
	if(args[0].equals("train")) {
  
	    Pipe p = new SerialPipes (new Pipe[] {
		new VariationSentence2TokenSequence (),
		new RegexMatches ("INITCAP", Pattern.compile (CAPS+".*")),
		new RegexMatches ("CAPITALIZED", Pattern.compile (CAPS+LOW+"*")),
		new RegexMatches ("ALLCAPS", Pattern.compile (CAPS+"+")),
		new RegexMatches ("MIXEDCAPS", Pattern.compile ("[A-Z][a-z]+[A-Z][A-Za-z]*")),
		new RegexMatches ("CONTAINSDIGITS", Pattern.compile (".*[0-9].*")),
		new RegexMatches ("SINGLEDIGITS", Pattern.compile ("[0-9]")),
		new RegexMatches ("DOUBLEDIGITS", Pattern.compile ("[0-9][0-9]")),
		new RegexMatches ("ALLDIGITS", Pattern.compile ("[0-9]+")),
		new RegexMatches ("NUMERICAL", Pattern.compile ("[-0-9]+[\\.,]+[0-9\\.,]+")),
		new RegexMatches ("ALPHNUMERIC", Pattern.compile ("[A-Za-z0-9]+")),
		new RegexMatches ("ROMAN", Pattern.compile ("[ivxdlcm]+|[IVXDLCM]+")),
		new RegexMatches ("MULTIDOTS", Pattern.compile ("\\.\\.+")),
		new RegexMatches ("ENDSINDOT", Pattern.compile ("[^\\.]+.*\\.")),
		new RegexMatches ("CONTAINSDASH", Pattern.compile (ALPHANUM+"+-"+ALPHANUM+"*")),
		new RegexMatches ("ACRO", Pattern.compile ("[A-Z][A-Z\\.]*\\.[A-Z\\.]*")),
		new RegexMatches ("LONELYINITIAL", Pattern.compile (CAPS+"\\.")),
		new RegexMatches ("SINGLECHAR", Pattern.compile (ALPHA)),
		new RegexMatches("CAPLETTER", Pattern.compile ("[A-Z]")),
		new RegexMatches ("PUNC", Pattern.compile (PUNT)),
		new RegexMatches ("QUOTE", Pattern.compile (QUOTE)),
		new TokenSequenceLowercase(),
		new RegexMatches ("ISTYPE1", Pattern.compile (".*ploidy|.*somy")),
    
		new TokenText ("WORD="),

		new TokenTextCharNGrams ("CHARNGRAM=", new int[] {2,3,4}),
    
		// List membership
		new TrieLexiconMembership("ISTYPE",new File("data/training/variation/lists/VARIATION.TYPES"),true),
		new VariationRegexTrieLexiconMembership("ISLOC",new File("data/training/variation/lists/LOCATION.PATTERNS"),true),
		new VariationRegexTrieLexiconMembership("ISSTATE",new File("data/training/variation/lists/STATE.PATTERNS"),true),
    
		new FeaturesInWindow("WINDOW=",-1,1,Pattern.compile("IS.*|WORD=.*|[A-Z]+"),true),
    
		new TokenSequence2FeatureVectorSequence (true, true)
	    });
   
	    InstanceList allData = new InstanceList (p);
	    allData.add (new LineGroupIterator
			 (new FileReader (new File (args[1])), Pattern.compile("^$"), true));
	    System.out.println ("Number of predicates in training data: "+p.getDataAlphabet().size());
   
	    Alphabet data = p.getDataAlphabet();   
	    data.stopGrowth();
	    // Print out all the target names  
	    Alphabet targets = p.getTargetAlphabet();

	    InstanceList testData = null;
	    if(!args[2].equals("null")) {
		testData = new InstanceList (p);
		testData.add (new LineGroupIterator (new FileReader
						     (new File (args[2])), Pattern.compile("^$"), true));
	    }

	    // Print out some feature information
	    System.out.println ("Number of features = "+p.getDataAlphabet().size());

	    VariationSegmentationEvaluator eval =
		new VariationSegmentationEvaluator (new String[] {"B-type", "B-location",
								  "B-state-original","B-state-altered"},
						    new String[] {"I-type", "I-location",
								  "I-state-original","I-state-altered"});
   
	    CRF4 crf = new CRF4(p,null);
   
	    crf.addFullyConnectedStatesForLabels();
	    crf.setGaussianPriorVariance (5.0);
   
	    crf.train (allData, null, testData, eval, 200);

	    crf.write(new File(args[3]));
	}
	else if(args[0].equals("test")) {
	    VariationSegmentationEvaluator eval =
		new VariationSegmentationEvaluator (new String[] {"B-type", "B-location",
								  "B-state-original","B-state-altered"},
						    new String[] {"I-type", "I-location",
								  "I-state-original","I-state-altered"});
   
	    //ObjectInputStream ois;
	    //ois = new ObjectInputStream(new FileInputStream(args[2]));
   
	    CRF4 crf = Model.loadAndRetrieveModel(args[2]);//(CRF4)ois.readObject();
	    crf.getInputAlphabet().stopGrowth();

	    SerialPipes p = (SerialPipes)crf.getInputPipe();
   
	    InstanceList allData = new InstanceList (p);
	    System.out.println(p.getDataAlphabet().size());
	    allData.add (new LineGroupIterator (new FileReader (new File (args[1])), Pattern.compile("^$"), true));
	    
	    crf.evaluate(eval,allData);
	    //eval.test(crf,allData,"Testing",null);
   
	}
	else {
	    System.out.println("Bad arguments.");
	    System.exit(0);
	}
  
    }
 
}