Training my Relationship Extraction Model

In this tutorial, we will show how to create and train your own Relationship Extraction model using REEL. The running example consists of an implementation of the kernel-based technique described in:

A Shortest Path Dependency Kernel for Relation Extraction. Razvan Bunescu and Raymond J. Mooney. In Proceedings of the Joint Conference on Human Language Technology / Empirical Methods in Natural Language Processing (HLT/EMNLP), 2005.
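
Before diving into the code, here is the intuition behind the technique. The kernel represents each relation instance by the shortest path between the two entities in the dependency graph of the sentence. Every token on the path is expanded into a small feature vector (the word itself, its part-of-speech tag, a generalized part-of-speech tag and, where applicable, its entity type), and every edge is represented by its direction ("->" or "<-"). The similarity between two instances whose paths have the same length is the product, position by position, of the number of features the two paths share; paths of different lengths have similarity 0. For the two paths used as an example in the paper

    protesters -> seized <- stations
    troops     -> raided <- churches

"protesters" shares three features with "troops" (NNS, Noun, PERSON), "seized" shares two with "raided" (VBD, Verb), "stations" shares three with "churches" (NNS, Noun, FACILITY), and the two edge positions match exactly, so the kernel value is 3 x 1 x 2 x 1 x 3 = 18.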

The first step in creating your Relationship Extraction program is to create a Core. In general, a core defines the representation of the data as well as how that data will be processed. Since our example is a kernel-based technique, we will implement a specific type of core, a kernel:

package edu.columbia.cs.ref.model.core.impl;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import opennlp.tools.util.InvalidFormatException;
import edu.columbia.cs.ref.algorithm.feature.generation.FeatureGenerator;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.EntityBasedChunkingFG;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.OpenNLPTokenizationFG;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.StanfordNLPDependencyGraphFG;
import edu.columbia.cs.ref.model.CandidateSentence;
import edu.columbia.cs.ref.model.core.Kernel;
import edu.columbia.cs.ref.model.core.structure.OperableStructure;
import edu.columbia.cs.ref.model.core.structure.impl.TaggedGraph;
import edu.columbia.cs.utils.Pair;
import edu.columbia.cs.utils.SimpleGraphNode;
import edu.columbia.cs.utils.TokenInformation;

//In order to implement a kernel-based technique your core must extend from the
//class Kernel
public class ShortestPathKernel extends Kernel {
    //This method is used to produce an empty structure that is compatible with this kernel
    //In the case of the ShortestPathKernel we must use a TaggedGraph (You may need to create
    //your own structure for your kernel)
    @Override
    public OperableStructure createOperableStructure(CandidateSentence sent) {
        return new TaggedGraph(sent);
    }

    //Some kernels need mandatory features. This method is responsible for indicating them.
    //In the case of the ShortestPathKernel the mandatory feature is a dependency parser
    @Override
    protected List<FeatureGenerator> createMandatoryFeatureGenerators() {
        List<FeatureGenerator> fg = new ArrayList<FeatureGenerator>();

        try {
            fg.add(new StanfordNLPDependencyGraphFG("/path/to/dependency/model.bin",
                    new EntityBasedChunkingFG(new OpenNLPTokenizationFG("/path/to/tokenizer/model.bin"))));
        } catch (InvalidFormatException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(1);
        }
        
        return fg;
    }
    
    //Main method that computes the kernel between the two structures. For more
    //details about the shortest path kernel please refer to our javadoc for this class
    //and the paper "A Shortest Path Dependency Kernel for Relation Extraction" by
    //Bunescu and Mooney.
    @Override
    public double evaluate(OperableStructure s1, OperableStructure s2) {
        return normalizedKernel((TaggedGraph)s1, (TaggedGraph)s2);
    }

    public double normalizedKernel(TaggedGraph s1, TaggedGraph s2) 
    {
        // Self-similarities, needed for normalization
        double k1 = kernel(s1, s1);
        double k2 = kernel(s2, s2);
        
        double k = kernel(s1, s2);
        if (k == 0)
            return 0;
        
        // normalize: K(s1,s2) / sqrt(K(s1,s1) * K(s2,s2))
        return k / Math.sqrt (k1 * k2);                
    }
    
    public double kernel(TaggedGraph s1, TaggedGraph s2){
        List<Pair<Pair<Integer,Integer>, String>> path1 = s1.getGraph().getShortestPathEdges();
        List<Pair<Pair<Integer,Integer>, String>> path2 = s2.getGraph().getShortestPathEdges();
        
        // Paths of different lengths (or empty/overly long paths) have similarity 0
        if(path1.size()!=path2.size() || path1.size()==0 || path1.size()>10){
            return 0;
        }
        
        List<String[]> spath1 = getCorrectSequence(s1, path1);
        List<String[]> spath2 = getCorrectSequence(s2, path2);
        
        // The kernel value is the product, position by position, of the number
        // of features the two paths have in common
        int size=spath1.size();
        double result=1;
        for(int i=0;i<size;i++){
            result*=count(spath1.get(i),spath2.get(i));
        }
        
        return result;
    }
    
    public List<String[]> getCorrectSequence(TaggedGraph s, List<Pair<Pair<Integer,Integer>, String>> path){
        List<String[]> sPath = new ArrayList<String[]>();
        SimpleGraphNode<TokenInformation>[] nodes=s.getGraph().getNodes();
        int entity1=-1;
        int entity2=-1;
        
        for(Pair<Pair<Integer,Integer>, String> edge : path){
            Pair<Integer,Integer> p = edge.a();
            
            TokenInformation origin = nodes[p.a()].getLabel();
            if(origin.isEntity1()){
                entity1=p.a();
            }else if(origin.isEntity2()){
                entity2=p.a();
            }
            
            TokenInformation destiny = nodes[p.b()].getLabel();
            if(destiny.isEntity1()){
                entity1=p.b();
            }else if(destiny.isEntity2()){
                entity2=p.b();
            }
        }
        
        // Walk the shortest path from entity 1 to entity 2, alternating token
        // feature vectors and edge-direction markers
        int currentToken=entity1;
        sPath.add(createTokenFeatureVector(nodes[currentToken].getLabel()));
        Set<Pair<Pair<Integer,Integer>, String>> processedEdges = new HashSet<Pair<Pair<Integer,Integer>, String>>();
        while(currentToken!=entity2){
            Pair<Integer,Integer> p=null;
            for(Pair<Pair<Integer,Integer>, String> edge : path){
                if(!processedEdges.contains(edge)){
                    p = edge.a();
                    if(p.a()==currentToken || p.b()==currentToken){
                        processedEdges.add(edge);
                        break;
                    }
                }
            }
            if(p.a()==currentToken){
                sPath.add(createEdgeFeatureVector(true));
                currentToken=p.b();
            }else{
                sPath.add(createEdgeFeatureVector(false));
                currentToken=p.a();
            }
            sPath.add(createTokenFeatureVector(nodes[currentToken].getLabel()));
        }
        
        return sPath;
    }
    
    private String[] createTokenFeatureVector(TokenInformation t){
        List<String> feats=t.getOwnFeatures();
        String[] result=new String[feats.size()];

        result = feats.toArray(result);
        
        return result;
    }
    
    private String[] createEdgeFeatureVector(boolean isPos){
        if(isPos){
            return new String[]{"->"};
        }else{
            return new String[]{"<-"};
        }
    }
    
    public int count(String[] s1, String[] s2){
        // Number of positions at which the two feature vectors agree; use the
        // shorter length since entity tokens may carry an extra entity-type feature
        int len = Math.min(s1.length, s2.length);
        int result=0;
        for(int i=0; i<len; i++){
            if(s1[i].equals(s2[i])){
                result++;
            }
        }
        return result;
    }

}
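
With the kernel in place, the classification engine used in the next step compares any two relation candidates by calling evaluate. A minimal sketch of a direct comparison, assuming struct1 and struct2 are two OperableStructures that have already been populated by StructureGenerator (as done in the training code below):

    ShortestPathKernel kernel = new ShortestPathKernel();
    // evaluate delegates to normalizedKernel, so the result lies in [0, 1]
    double similarity = kernel.evaluate(struct1, struct2);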


The next step in creating your own Relationship Extraction program is to train a model using the previously created Core. The following code shows how to train a model that extracts a relationship between an organization and a person (ORG-AFF) using the shortest-path kernel core. Notice that we assume that the training data is the ACE 2005 dataset, which is available in the LDC Catalog. If you want to use your own data, you may need to create your own loader (you can see how in the tutorial for that task here).

import java.io.File;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

import edu.columbia.cs.ref.algorithm.CandidatesGenerator;
import edu.columbia.cs.ref.algorithm.StructureGenerator;
import edu.columbia.cs.ref.algorithm.feature.generation.FeatureGenerator;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.EntityBasedChunkingFG;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.GenericPartOfSpeechFG;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.OpenNLPPartOfSpeechFG;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.OpenNLPTokenizationFG;
import edu.columbia.cs.ref.algorithm.feature.generation.impl.SpansToStringsConvertionFG;
import edu.columbia.cs.ref.engine.Engine;
import edu.columbia.cs.ref.engine.impl.JLibSVMBinaryEngine;
import edu.columbia.cs.ref.model.CandidateSentence;
import edu.columbia.cs.ref.model.Dataset;
import edu.columbia.cs.ref.model.Document;
import edu.columbia.cs.ref.model.StructureConfiguration;
import edu.columbia.cs.ref.model.constraint.role.impl.EntityTypeConstraint;
import edu.columbia.cs.ref.model.core.impl.ShortestPathKernel;
import edu.columbia.cs.ref.model.core.structure.OperableStructure;
import edu.columbia.cs.ref.model.feature.impl.SequenceFS;
import edu.columbia.cs.ref.model.re.Model;
import edu.columbia.cs.ref.model.relationship.RelationshipType;
import edu.columbia.cs.ref.tool.document.splitter.impl.OpenNLPMESplitter;
import edu.columbia.cs.ref.tool.loader.document.impl.ace2005.ACE2005Loader;
import edu.columbia.cs.utils.Span;


public class TrainREModel {
    public static void main(String[] args) throws IOException{
        // ORG-AFF relates a person (Arg-1) to an organization (Arg-2)
        RelationshipType relationshipType = new RelationshipType("ORG-AFF","Arg-1","Arg-2");
        relationshipType.setConstraints(new EntityTypeConstraint("ORG"), "Arg-2");
        relationshipType.setConstraints(new EntityTypeConstraint("PER"), "Arg-1");
        
        Set<RelationshipType> relationshipTypes = new HashSet<RelationshipType>();
        relationshipTypes.add(relationshipType);
        
        OpenNLPMESplitter splitter = new OpenNLPMESplitter("/path/to/sentence/splitter/model.bin");
        CandidatesGenerator generator = new CandidatesGenerator(splitter);
        
        ACE2005Loader l = new ACE2005Loader(relationshipTypes);
        File ACEDir = new File("/path/to/training/data/trainingData/");
        Dataset ace2005 = new Dataset(l,ACEDir,false);
                
        Set<CandidateSentence> candidates = new HashSet<CandidateSentence>();
        for(Document d : ace2005){
            candidates.addAll(generator.generateCandidates(d, relationshipTypes));    
        }
        
        // The configuration combines the core (our kernel) with additional
        // feature generators that fill in the token feature vectors
        StructureConfiguration conf = new StructureConfiguration(new ShortestPathKernel());
        FeatureGenerator<SequenceFS<Span>> tokenizer = new OpenNLPTokenizationFG("/path/to/tokenizer/model.bin");
        FeatureGenerator<SequenceFS<Span>> fgChunk = new EntityBasedChunkingFG(tokenizer);
        FeatureGenerator<SequenceFS<String>> fgChunkString = new SpansToStringsConvertionFG(fgChunk);
        FeatureGenerator<SequenceFS<String>> fgPOS = new OpenNLPPartOfSpeechFG("/path/to/pos/model.bin",fgChunkString);
        FeatureGenerator<SequenceFS<String>> fgGPOS = new GenericPartOfSpeechFG(fgPOS);
        conf.addFeatureGenerator(fgPOS);
        conf.addFeatureGenerator(fgGPOS);
        
        Set<OperableStructure> trainingData = StructureGenerator.generateStructures(candidates, conf);
        
        Engine classificationEngine = new JLibSVMBinaryEngine(conf, relationshipTypes);
        Model svmModel = classificationEngine.train(trainingData);
        
        // Serialize the trained model so that it can be used for extraction later
        edu.columbia.cs.ref.tool.io.SerializationHelper.write("/path/to/model/ORG-AFFModel.svm", svmModel);
    }
}
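
Once training finishes, the serialized model can be loaded back and applied to new documents. A minimal sketch, assuming SerializationHelper offers a read counterpart to the write call used above (check the REEL javadoc for the exact name and signature); note that new text must go through the same candidate-generation and structure-generation pipeline as the training data before the model can classify it:

    // Hypothetical read counterpart to the SerializationHelper.write call above
    Model svmModel = (Model) edu.columbia.cs.ref.tool.io.SerializationHelper
            .read("/path/to/model/ORG-AFFModel.svm");
    // Generate candidates and structures for the new documents exactly as in
    // the training code (same relationshipTypes, same StructureConfiguration),
    // then ask the model to classify each resulting OperableStructure.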