tuffy.infer
Class MRF

java.lang.Object
  extended by tuffy.infer.MRF

public class MRF
extends java.lang.Object

In-memory data structure representing an MRF.


Nested Class Summary
static class MRF.INIT_STRATEGY
           
private  class MRF.myInt
           
 
Field Summary
protected  java.util.HashMap<java.lang.Integer,java.util.ArrayList<GClause>> adj
          Index from GAtom ID to the list of GClauses that contain it.
 java.util.HashMap<java.lang.Integer,GAtom> atoms
          Map from GAtom ID to GAtom object.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseNiNjViolationTallies
          This map records the tallies for calculating E(v_i*v_j).
 java.util.ArrayList<GClause> clauses
          Array of all GClause objects in this MRF.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseSatTallies
          This map records the total number of satisfactions of each clause.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseSquareVioTallies
          This map records the running sum of squared violation counts for each clause.
 java.util.HashMap<java.lang.String,java.lang.Long> clauseVioTallies
          This map records the total number of violations of each clause.
private  java.util.HashSet<java.lang.Integer> coreAtoms
           
protected  java.util.HashSet<java.lang.Integer> dirtyAtoms
          Atoms that have been flipped since the low-cost truth values were last saved.
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfNiNjViolation
          This map records the expectation of E(v_i*v_j).
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSatisfication
          This map records the expectation of #satisfaction for each clause.
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSquareViolation
          This map records the expectation of squared #violation for each clause.
 java.util.HashMap<java.lang.String,java.lang.Double> expectationOfViolation
          This map records the expectation of #violation for each clause.
 long inferOps
           
protected  MRF.INIT_STRATEGY initStrategy
           
 KeyBlock keyBlock
           
 double lowCost
          Lowest cost ever seen.
private  MarkovLogicNetwork mln
          The MLN object.
private  int nClauseVioTallies
          Number of iterations of tallies.
(package private)  int numClausesInCut
           
(package private)  int numCriticalNodesLocal
           
 boolean ownsAllAtoms
           
(package private)  int partID
           
private  java.util.Random rand
           
protected  boolean sampleSatMode
          The flag indicating whether MCSAT is running WalkSAT or SampleSAT.
protected  int totalAlive
          Number of GClauses that are selected, and therefore must be satisfied by the next SampleSAT invocation of MC-SAT.
protected  double totalCost
          The total cost of this MRF under the current truth assignment of atoms.
protected  HashArray<GClause> unsat
          Array of GClauses unsatisfied under the current truth assignment of atoms.
private  boolean usingBlocks
           
(package private)  double weightClausesInCut
           
 
Constructor Summary
MRF(MarkovLogicNetwork mln)
          Default constructor.
MRF(MarkovLogicNetwork mln, int partID, java.util.HashMap<java.lang.Integer,GAtom> gatoms)
           
 
Method Summary
 void addAtom(int aid)
          Add an atom into this MRF.
private  void adjustAtomClauseRelation(java.util.ArrayList<GClause> tlfac, java.util.ArrayList<GClause> flfac, int picked)
           
private  void assignAllFalseTruthValues()
          Set all atoms to false.
private  void assignGreedyTruthValues()
          Assign initial truth values according to ad hoc heuristic statistics.
private  void assignRandomTruthValues()
          Set random atom truth values.
 void auditClauseViolations()
          Attribute ground-clause violations to their first-order (fo) clauses.
protected  void buildIndices()
          Build the literal-to-clauses index.
protected  double calcCosts()
          Compute total cost and per-atom delta cost.
private  double calcCostsForWalkSAT(java.util.HashSet<java.lang.Integer> needToBeReset)
           
 void calcExpViolation()
          Calculate the different expectations by filling the expectation-related HashMaps of this class.
private  void calcNSAT(GClause f)
          Calculate the number of true literals in a clause.
 void discard()
          Discard all data structures, in the hope of facilitating faster garbage collection.
protected  void enableAllClauses()
          Reset all clauses to be alive.
protected  void fixAtom(int aid, boolean t)
          Fix the truth value of an atom.
private  java.util.HashSet<java.lang.Integer> getAtomNeighbors(int aid)
          For research experiments! Get all neighboring atoms of one atom.
 java.util.HashSet<java.lang.Integer> getCoreAtoms()
           
 double getCost()
           
private  java.util.ArrayList<GAtom> getFlipSequence(GAtom a)
           
 MRF.INIT_STRATEGY getInitStrategy()
           
 MarkovLogicNetwork getMLN()
           
 void inferSweepSAT(int nTries, int nSteps)
          Deprecated.  
 void inferWalkSAT(int nTries, int nSteps)
          Run WalkSAT.
private  void inferWalkSATwithBlocks(int nTries, int nSteps)
          Run WalkSAT with blocks.
private  void inferWalkSATwithoutBlocking(int nTries, int nSteps)
          Run WalkSAT without blocks.
 void initMRF()
          Initialize the state of the MRF.
 void invalidateLowCost()
          Reset low-cost to infinity.
protected  boolean isAlwaysTrue(GClause gc)
          Test whether a clause stays true no matter how the flippable atoms are flipped.
protected  boolean isTrueLit(int lit)
          Check if a given literal is true under current truth assignment.
private  void maintainKeyConstraints()
           
 void mcsat(int numSamples, int numFlips)
          Execute the MC-SAT algorithm.
protected  boolean ownsAtom(int aid)
          Test if a given atom is "owned" by this MRF.
private  void performMCSatStep(int numFlips)
          Perform one sample of MC-SAT.
 double recalcCost()
          Recalculate total cost.
 void restoreLowTruth()
          Assign the recorded low-cost truth values to current truth values.
protected  int retainOnlyHardClauses()
          Kill soft clauses.
protected  int retainSomeGoodClauses()
          Retain a subset of currently satisfied clauses, according to the sampling method of MC-SAT.
protected  boolean sampleSAT(int nSteps)
          SampleSAT (with WalkSAT inside), used to uniformly sample a zero-cost world.
protected  void saveLowTruth(double cost)
          If the current truth values have the lowest cost seen so far, save them.
private  void saveTruthAsLow()
           
 void setInitStrategy(MRF.INIT_STRATEGY strategy)
           
private  java.util.ArrayList<MRF> split(int np)
          For research experiments! Split the MRF into multiple pieces by agglomerative clustering.
protected  boolean testChance(double p)
          Flip a biased coin.
private  void testKeyConstraints()
           
protected  void unfixAllAtoms()
          Unfix all atoms.
private  void unitPropagation()
          Try to satisfy as many clauses as possible with unit propagation.
 void updateAtomMarginalProbs(int numSamples)
           
private  void updateAtomTruthTallies()
          For each atom, increment its truth tally by one if it's currently true.
 void updateClauseVoiTallies()
          Update the number of violations of a clause.
 void updateClauseWeights(java.util.HashMap<java.lang.String,java.lang.Double> currentWeight)
          Change the weights of GClauses based on the updated weights of their Clauses.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

initStrategy

protected MRF.INIT_STRATEGY initStrategy

keyBlock

public KeyBlock keyBlock

atoms

public java.util.HashMap<java.lang.Integer,GAtom> atoms
Map from GAtom ID to GAtom object.


ownsAllAtoms

public boolean ownsAllAtoms

usingBlocks

private boolean usingBlocks

coreAtoms

private java.util.HashSet<java.lang.Integer> coreAtoms

clauses

public java.util.ArrayList<GClause> clauses
Array of all GClause objects in this MRF.


unsat

protected HashArray<GClause> unsat
Array of GClauses unsatisfied under the current truth assignment of atoms.


adj

protected java.util.HashMap<java.lang.Integer,java.util.ArrayList<GClause>> adj
Index from GAtom ID to the list of GClauses that contain it.


numCriticalNodesLocal

int numCriticalNodesLocal

numClausesInCut

int numClausesInCut

weightClausesInCut

double weightClausesInCut

dirtyAtoms

protected java.util.HashSet<java.lang.Integer> dirtyAtoms
Atoms that have been flipped since the low-cost truth values were last saved.


totalCost

protected double totalCost
The total cost of this MRF under the current truth assignment of atoms.


lowCost

public double lowCost
Lowest cost ever seen.


totalAlive

protected int totalAlive
Number of GClauses that are selected, and therefore must be satisfied by the next SampleSAT invocation of MC-SAT.


sampleSatMode

protected boolean sampleSatMode
The flag indicating whether MCSAT is running WalkSAT or SampleSAT.


partID

int partID

inferOps

public long inferOps

mln

private MarkovLogicNetwork mln
The MLN object.


rand

private java.util.Random rand

expectationOfViolation

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfViolation
This map records the expectation of #violation for each clause. This is filled by MCSAT#calcExpViolation().


expectationOfSquareViolation

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSquareViolation
This map records the expectation of squared #violation for each clause. This is filled by MCSAT#calcExpViolation().


clauseNiNjViolationTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseNiNjViolationTallies
This map records the tallies for calculating E(v_i*v_j).


expectationOfNiNjViolation

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfNiNjViolation
This map records the expectation of E(v_i*v_j). This is filled by MCSAT#calcExpViolation().


expectationOfSatisfication

public java.util.HashMap<java.lang.String,java.lang.Double> expectationOfSatisfication
This map records the expectation of #satisfaction for each clause. This is filled by MCSAT#calcExpViolation().


clauseVioTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseVioTallies
This map records the total number of violations of each clause. Dividing this number by MCSAT#nClauseVioTallies gives the estimated expectation of #violation.


clauseSquareVioTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseSquareVioTallies
This map records the running sum of squared violation counts for each clause. Dividing this number by MCSAT#nClauseVioTallies gives the estimated expectation of squared #violation.


clauseSatTallies

public java.util.HashMap<java.lang.String,java.lang.Long> clauseSatTallies
This map records the total number of satisfactions of each clause.


nClauseVioTallies

private int nClauseVioTallies
Number of iterations of tallies.

Constructor Detail

MRF

public MRF(MarkovLogicNetwork mln)
Default constructor. It performs no substantial initialization.


MRF

public MRF(MarkovLogicNetwork mln,
           int partID,
           java.util.HashMap<java.lang.Integer,GAtom> gatoms)
Parameters:
mln - the MLN object
partID - id of this MRF
gatoms - ground atoms

Method Detail

getInitStrategy

public MRF.INIT_STRATEGY getInitStrategy()

setInitStrategy

public void setInitStrategy(MRF.INIT_STRATEGY strategy)

getCost

public double getCost()

getMLN

public MarkovLogicNetwork getMLN()

invalidateLowCost

public void invalidateLowCost()
Reset low-cost to infinity.


split

private java.util.ArrayList<MRF> split(int np)
For research experiments! Split the MRF into multiple pieces by agglomerative clustering. Each piece contains up to 2/np of the total atoms.

Parameters:
np - number of pieces
Returns:
the pieces, each as an individual MRF

getAtomNeighbors

private java.util.HashSet<java.lang.Integer> getAtomNeighbors(int aid)
For research experiments! Get all neighboring atoms of one atom.

Parameters:
aid - id of the core atom
Returns:
ids of the neighboring atoms

discard

public void discard()
Discard all data structures, in the hope of facilitating faster garbage collection.


addAtom

public void addAtom(int aid)
Add an atom into this MRF.

Parameters:
aid - id of the atom

ownsAtom

protected boolean ownsAtom(int aid)
Test if a given atom is "owned" by this MRF. An atom may not belong to this MRF if this MRF represents a partition of a component that has multiple partitions.

Parameters:
aid - id of the atom

saveLowTruth

protected void saveLowTruth(double cost)
If the current truth values have the lowest cost seen so far, save them.

Parameters:
cost - the current cost
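A minimal sketch of the save/restore contract described by saveLowTruth, invalidateLowCost and restoreLowTruth. LowTruthTracker and its map-based state are hypothetical; the real class also tracks dirtyAtoms, presumably so it only copies atoms flipped since the last save, whereas this sketch copies the whole assignment.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical tracker for the lowest-cost truth assignment seen so far. */
final class LowTruthTracker {
    private double lowCost = Double.POSITIVE_INFINITY;
    private Map<Integer, Boolean> lowTruth = new HashMap<>();

    /** Mirrors invalidateLowCost(): forget the previous best cost. */
    void invalidateLowCost() {
        lowCost = Double.POSITIVE_INFINITY;
    }

    /** Saves a copy of the current truths only if they are strictly cheaper than the best so far. */
    void saveLowTruth(Map<Integer, Boolean> current, double cost) {
        if (cost < lowCost) {
            lowCost = cost;
            lowTruth = new HashMap<>(current);
        }
    }

    /** Mirrors restoreLowTruth(): copy the recorded best truths back into the current state. */
    void restoreLowTruth(Map<Integer, Boolean> current) {
        current.putAll(lowTruth);
    }

    double getLowCost() {
        return lowCost;
    }
}
```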

saveTruthAsLow

private void saveTruthAsLow()

isTrueLit

protected boolean isTrueLit(int lit)
Check if a given literal is true under current truth assignment.

Parameters:
lit - the literal represented as an integer
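The page does not say how a literal is packed into an int. A common convention (used by the DIMACS CNF format) is a signed atom ID: +aid is the positive literal and -aid its negation. The sketch below assumes that encoding and a hypothetical Map of atom truth values; it is not Tuffy's code.

```java
import java.util.Map;

/** Hypothetical literal check; assumes a DIMACS-style signed-integer encoding. */
final class Literals {
    private Literals() {}

    /**
     * Assumed encoding: +aid means "atom aid is true", -aid means "atom aid is false".
     * Atom IDs are therefore assumed to be nonzero, and every atom must be present in truth.
     */
    static boolean isTrueLit(Map<Integer, Boolean> truth, int lit) {
        boolean atomTrue = truth.get(Math.abs(lit));
        return lit > 0 ? atomTrue : !atomTrue;
    }
}
```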

isAlwaysTrue

protected boolean isAlwaysTrue(GClause gc)
Test whether a clause stays true no matter how the flippable atoms are flipped.

Parameters:
gc - the clause
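One way to read "always true": the clause contains a true literal on an atom that cannot be flipped (for example one fixed via fixAtom), or it contains both a literal and its negation. The sketch assumes signed-integer literals and a hypothetical set of fixed atom IDs; it illustrates the condition, not Tuffy's exact test.

```java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical check; literals are signed atom IDs, fixed holds non-flippable atoms. */
final class ClauseChecks {
    private ClauseChecks() {}

    static boolean isAlwaysTrue(int[] lits, Map<Integer, Boolean> truth, Set<Integer> fixed) {
        Set<Integer> seen = new HashSet<>();
        for (int lit : lits) {
            int aid = Math.abs(lit);
            // A true literal on an atom that cannot be flipped satisfies the clause permanently.
            if (fixed.contains(aid) && truth.get(aid) == (lit > 0)) {
                return true;
            }
            // A clause containing both a and !a is a tautology.
            if (seen.contains(-lit)) {
                return true;
            }
            seen.add(lit);
        }
        return false;
    }
}
```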

fixAtom

protected void fixAtom(int aid,
                       boolean t)
Fix the truth value of an atom.

Parameters:
aid - id of the atom
t - truth value to be fixed

retainSomeGoodClauses

protected int retainSomeGoodClauses()
Retain a subset of currently satisfied clauses, according to the sampling method of MC-SAT.

Returns:
the number of retained clauses
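The standard MC-SAT selection rule keeps each currently satisfied clause of weight w with probability 1 - exp(-w), so hard clauses (infinite weight) are always kept. A minimal sketch assuming non-negative weights and hypothetical array inputs; MC-SAT handles negative-weight clauses through their negations, which is omitted here. The sketch returns the retained indices; their count corresponds to the documented return value.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical MC-SAT clause selection over parallel arrays. */
final class McSatSelection {
    private McSatSelection() {}

    /**
     * satisfied[i] says whether clause i holds in the current state; weights[i] is assumed
     * non-negative, and hard clauses may use Double.POSITIVE_INFINITY (1 - exp(-inf) == 1).
     */
    static List<Integer> retainSomeGoodClauses(boolean[] satisfied, double[] weights, Random rand) {
        List<Integer> alive = new ArrayList<>();
        for (int i = 0; i < weights.length; i++) {
            // Only satisfied clauses are candidates; each survives with probability 1 - e^{-w}.
            if (satisfied[i] && rand.nextDouble() < 1.0 - Math.exp(-weights[i])) {
                alive.add(i);
            }
        }
        return alive;
    }
}
```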

unfixAllAtoms

protected void unfixAllAtoms()
Unfix all atoms.


restoreLowTruth

public void restoreLowTruth()
Assign the recorded low-cost truth values to current truth values.


enableAllClauses

protected void enableAllClauses()
Reset all clauses to be alive.


buildIndices

protected void buildIndices()
Build the literal-to-clauses index. Used by WalkSAT.
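A minimal sketch of such an index, keyed by atom ID as in the adj field so that WalkSAT can find every clause affected by a flip. Clauses are modeled here as hypothetical arrays of signed-integer literals and identified by their position in the list.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical atom-to-clauses index builder. */
final class ClauseIndex {
    private ClauseIndex() {}

    /** Maps each atom ID to the indices of the clauses that mention it, in either polarity. */
    static Map<Integer, List<Integer>> buildIndex(List<int[]> clauses) {
        Map<Integer, List<Integer>> adj = new HashMap<>();
        for (int ci = 0; ci < clauses.size(); ci++) {
            for (int lit : clauses.get(ci)) {
                adj.computeIfAbsent(Math.abs(lit), k -> new ArrayList<>()).add(ci);
            }
        }
        return adj;
    }
}
```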


testChance

protected boolean testChance(double p)
Flip a biased coin: return true with probability p.

Parameters:
p - probability of returning true
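A minimal sketch of the coin flip, assuming it draws from a java.util.Random such as the class's private rand field; CoinFlip is a hypothetical standalone wrapper, not Tuffy code.

```java
import java.util.Random;

/** Hypothetical standalone helper mirroring the documented contract of testChance. */
final class CoinFlip {
    private CoinFlip() {}

    /** Returns true with probability p; nextDouble() is uniform on [0.0, 1.0). */
    static boolean testChance(Random rand, double p) {
        return rand.nextDouble() < p;
    }
}
```

Because nextDouble() never returns 1.0, p = 1 always yields true and p = 0 never does.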

testKeyConstraints

private void testKeyConstraints()

inferWalkSAT

public void inferWalkSAT(int nTries,
                         int nSteps)
Run WalkSAT.

Parameters:
nTries - number of tries
nSteps - number of steps per try
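A compact, non-incremental MaxWalkSAT sketch showing what nTries (random restarts) and nSteps (flips per restart) control. Everything here is a hypothetical illustration: WalkSatSketch, WClause, the noise parameter, positive clause weights, and cost defined as the total weight of unsatisfied clauses. It rescans all clauses at every step for clarity, whereas the real class presumably relies on incremental structures such as unsat and adj.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical MaxWalkSAT sketch; literals are signed atom IDs. */
final class WalkSatSketch {
    static final class WClause {
        final int[] lits;
        final double weight;
        WClause(int[] lits, double weight) { this.lits = lits; this.weight = weight; }
    }

    static boolean satisfied(WClause c, Map<Integer, Boolean> truth) {
        for (int lit : c.lits) {
            if (truth.get(Math.abs(lit)) == (lit > 0)) return true;
        }
        return false;
    }

    /** Total cost = sum of weights of unsatisfied clauses (positive weights assumed). */
    static double cost(List<WClause> clauses, Map<Integer, Boolean> truth) {
        double total = 0.0;
        for (WClause c : clauses) if (!satisfied(c, truth)) total += c.weight;
        return total;
    }

    /** Runs nTries restarts of nSteps flips each; returns the lowest-cost assignment seen. */
    static Map<Integer, Boolean> walkSat(List<WClause> clauses, List<Integer> atoms,
                                         int nTries, int nSteps, double noise, Random rand) {
        Map<Integer, Boolean> best = null;
        double bestCost = Double.POSITIVE_INFINITY;
        for (int t = 0; t < nTries; t++) {
            Map<Integer, Boolean> truth = new HashMap<>();
            for (int aid : atoms) truth.put(aid, rand.nextBoolean());  // random restart
            for (int s = 0; s <= nSteps; s++) {
                double cur = cost(clauses, truth);
                if (cur < bestCost) {
                    bestCost = cur;
                    best = new HashMap<>(truth);
                }
                if (cur == 0.0 || s == nSteps) break;
                List<WClause> unsat = new ArrayList<>();
                for (WClause c : clauses) if (!satisfied(c, truth)) unsat.add(c);
                WClause chosen = unsat.get(rand.nextInt(unsat.size()));
                int pick;
                if (rand.nextDouble() < noise) {
                    // Random-walk move: flip any atom of the chosen unsatisfied clause.
                    pick = Math.abs(chosen.lits[rand.nextInt(chosen.lits.length)]);
                } else {
                    // Greedy move: flip the atom of the chosen clause that minimizes total cost.
                    pick = 0;
                    double pickCost = Double.POSITIVE_INFINITY;
                    for (int lit : chosen.lits) {
                        int aid = Math.abs(lit);
                        truth.put(aid, !truth.get(aid));
                        double after = cost(clauses, truth);
                        truth.put(aid, !truth.get(aid));  // undo the trial flip
                        if (after < pickCost) { pickCost = after; pick = aid; }
                    }
                }
                truth.put(pick, !truth.get(pick));
            }
        }
        return best;
    }
}
```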

inferWalkSATwithoutBlocking

private void inferWalkSATwithoutBlocking(int nTries,
                                         int nSteps)
Run WalkSAT without blocks.

Parameters:
nTries - number of tries
nSteps - number of steps per try

inferWalkSATwithBlocks

private void inferWalkSATwithBlocks(int nTries,
                                    int nSteps)
Run WalkSAT with blocks.

Parameters:
nTries - number of tries
nSteps - number of steps per try

adjustAtomClauseRelation

private void adjustAtomClauseRelation(java.util.ArrayList<GClause> tlfac,
                                      java.util.ArrayList<GClause> flfac,
                                      int picked)

inferSweepSAT

public void inferSweepSAT(int nTries,
                          int nSteps)
Deprecated. 

Run SweepSAT for MAP inference.

Parameters:
nTries - number of tries
nSteps - number of steps per try

initMRF

public void initMRF()
Initialize the state of the MRF.


maintainKeyConstraints

private void maintainKeyConstraints()

assignAllFalseTruthValues

private void assignAllFalseTruthValues()
Set all atoms to false.


assignRandomTruthValues

private void assignRandomTruthValues()
Set random atom truth values.


assignGreedyTruthValues

private void assignGreedyTruthValues()
Assign initial truth values according to ad hoc heuristic statistics.


calcNSAT

private void calcNSAT(GClause f)
Calculate the number of true literals in a clause.

Parameters:
f - the ground clause
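A minimal sketch of counting true literals, assuming a clause is a hypothetical array of signed-integer literals (+aid positive, -aid negated) and truth values live in a Map. Tracking this count per clause is what lets a local-search solver tell violated clauses (count 0) from critically satisfied ones (count 1).

```java
import java.util.Map;

/** Hypothetical sketch of counting true literals in a ground clause. */
final class ClauseCounter {
    private ClauseCounter() {}

    static int calcNSAT(int[] lits, Map<Integer, Boolean> truth) {
        int nsat = 0;
        for (int lit : lits) {
            boolean atomTrue = truth.get(Math.abs(lit));
            // A literal is true when its polarity matches the atom's truth value.
            if (atomTrue == (lit > 0)) {
                nsat++;
            }
        }
        return nsat;
    }
}
```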

getFlipSequence

private java.util.ArrayList<GAtom> getFlipSequence(GAtom a)

calcCosts

protected double calcCosts()
Compute total cost and per-atom delta cost. The delta cost of an atom is the change in the total cost if this atom is flipped.

Returns:
total cost
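The delta cost can be computed in one pass with the usual make/break argument: a violated clause is fixed by flipping any of its atoms (delta decreases by its weight), and a clause with exactly one true literal becomes violated if that atom flips (delta increases by its weight); clauses with two or more true literals are unaffected by any single flip. The sketch assumes positive weights, no atom repeated within a clause, and hypothetical CostCalculator and WClause types.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical total-cost and delta-cost computation; literals are signed atom IDs. */
final class CostCalculator {
    private CostCalculator() {}

    static final class WClause {
        final int[] lits;
        final double weight;
        WClause(int[] lits, double weight) { this.lits = lits; this.weight = weight; }
    }

    /** Fills delta with (cost after flipping atom) - (current cost); returns the current cost. */
    static double calcCosts(List<WClause> clauses, Map<Integer, Boolean> truth,
                            Map<Integer, Double> delta) {
        delta.clear();
        double total = 0.0;
        for (WClause c : clauses) {
            int nsat = 0;
            int trueAtom = 0;
            for (int lit : c.lits) {
                int aid = Math.abs(lit);
                if (truth.get(aid) == (lit > 0)) {
                    nsat++;
                    trueAtom = aid;
                }
            }
            if (nsat == 0) {
                // Violated: flipping any of its atoms satisfies it and saves its weight.
                total += c.weight;
                for (int lit : c.lits) delta.merge(Math.abs(lit), -c.weight, Double::sum);
            } else if (nsat == 1) {
                // Critically satisfied: flipping its only true atom would violate it.
                delta.merge(trueAtom, c.weight, Double::sum);
            }
        }
        return total;
    }
}
```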

calcCostsForWalkSAT

private double calcCostsForWalkSAT(java.util.HashSet<java.lang.Integer> needToBeReset)

auditClauseViolations

public void auditClauseViolations()
Attribute ground-clause violations to their first-order (fo) clauses. Statistics are recorded on a per-fo-clause basis.

See Also:
Stats.reportMostViolatedClauses(tuffy.infer.MRF, int)
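A minimal sketch of per-fo-clause auditing: count violated groundings by the ID of the first-order clause they came from, then rank the fo-clauses by count, the kind of data a "most violated clauses" report needs. ViolationAudit, Ground and the String clause IDs are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical aggregation of ground-clause violations per first-order clause. */
final class ViolationAudit {
    private ViolationAudit() {}

    /** A ground clause tagged with the ID of its first-order clause. */
    static final class Ground {
        final String foClauseId;
        final boolean violated;
        Ground(String foClauseId, boolean violated) {
            this.foClauseId = foClauseId;
            this.violated = violated;
        }
    }

    /** Returns fo-clause IDs with their violation counts, most violated first. */
    static List<Map.Entry<String, Integer>> audit(List<Ground> groundClauses) {
        Map<String, Integer> counts = new HashMap<>();
        for (Ground g : groundClauses) {
            if (g.violated) counts.merge(g.foClauseId, 1, Integer::sum);
        }
        List<Map.Entry<String, Integer>> ranked = new ArrayList<>(counts.entrySet());
        ranked.sort(Map.Entry.<String, Integer>comparingByValue().reversed());
        return ranked;
    }
}
```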

recalcCost

public double recalcCost()
Recalculate total cost.

Returns:
updated total cost

getCoreAtoms

public java.util.HashSet<java.lang.Integer> getCoreAtoms()

retainOnlyHardClauses

protected int retainOnlyHardClauses()
Kill soft clauses.

Returns:
the number of hard clauses

sampleSAT

protected boolean sampleSAT(int nSteps)
SampleSAT (with WalkSAT inside), used to uniformly sample a zero-cost world. WalkSAT is used as a SAT solver to find the first (quasi-)zero-cost world, and simulated annealing (SA) moves are then performed stochastically to wander among such worlds.

Parameters:
nSteps - number of steps
Returns:
true iff a zero-cost world was reached

updateAtomMarginalProbs

public void updateAtomMarginalProbs(int numSamples)

updateAtomTruthTallies

private void updateAtomTruthTallies()
For each atom, increment its truth tally by one if it's currently true.
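Truth tallies turn MC-SAT samples into marginal estimates: after numSamples samples, an atom's marginal probability is approximately its tally divided by numSamples. The sketch assumes that is what updateAtomMarginalProbs computes (its description is empty on this page); MarginalEstimator and its methods are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical per-atom truth tallies and marginal estimates. */
final class MarginalEstimator {
    private final Map<Integer, Integer> truthTallies = new HashMap<>();

    /** Mirrors updateAtomTruthTallies(): +1 for every atom that is true in the current sample. */
    void updateAtomTruthTallies(Map<Integer, Boolean> truth) {
        for (Map.Entry<Integer, Boolean> e : truth.entrySet()) {
            if (e.getValue()) {
                truthTallies.merge(e.getKey(), 1, Integer::sum);
            }
        }
    }

    /** Assumed estimator: P(atom is true) is approximately tally / numSamples. */
    Map<Integer, Double> marginalProbs(Iterable<Integer> atomIds, int numSamples) {
        Map<Integer, Double> probs = new HashMap<>();
        for (int aid : atomIds) {
            probs.put(aid, truthTallies.getOrDefault(aid, 0) / (double) numSamples);
        }
        return probs;
    }
}
```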


mcsat

public void mcsat(int numSamples,
                  int numFlips)
Execute the MC-SAT algorithm.

Parameters:
numSamples - number of MC-SAT samples
numFlips - number of SampleSAT steps in each iteration

updateClauseVoiTallies

public void updateClauseVoiTallies()
Update the number of violations of each clause. For each GClause, the tally increases by at most 1 per MC-SAT iteration. For a Clause, the tally can increase by more, because several GClauses may be associated with it.


calcExpViolation

public void calcExpViolation()
Calculate the different expectations by filling the expectation-related HashMaps of this class.
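Following the field descriptions (dividing clauseVioTallies or clauseSquareVioTallies by nClauseVioTallies gives the estimated expectation), a minimal sketch of the first two moments. ViolationStats and its members are hypothetical stand-ins for those fields; the E(v_i*v_j) tallies are omitted.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in for the violation tallies and their expectations. */
final class ViolationStats {
    final Map<String, Long> vioTallies = new HashMap<>();
    final Map<String, Long> squareVioTallies = new HashMap<>();
    int nTallies = 0;

    /** Records one sample: violations.get(id) is the number of violated groundings of clause id. */
    void updateTallies(Map<String, Integer> violations) {
        for (Map.Entry<String, Integer> e : violations.entrySet()) {
            long v = e.getValue();
            vioTallies.merge(e.getKey(), v, Long::sum);
            squareVioTallies.merge(e.getKey(), v * v, Long::sum);
        }
        nTallies++;
    }

    /** Returns {E[v], E[v^2]} per clause, each estimated as tally / number of samples. */
    Map<String, double[]> calcExpViolation() {
        Map<String, double[]> exp = new HashMap<>();
        for (String id : vioTallies.keySet()) {
            exp.put(id, new double[] {
                vioTallies.get(id) / (double) nTallies,
                squareVioTallies.get(id) / (double) nTallies});
        }
        return exp;
    }
}
```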


updateClauseWeights

public void updateClauseWeights(java.util.HashMap<java.lang.String,java.lang.Double> currentWeight)
Change the weights of GClauses based on the updated weights of their Clauses. MC-SAT will use the new weights. This method also recomputes the flip cost of each atom and the set of unsatisfied GClauses.

Parameters:
currentWeight - The weight of clauses to be flushed in this MCSAT instance.

performMCSatStep

private void performMCSatStep(int numFlips)
Perform one sample of MC-SAT.

Parameters:
numFlips - number of sampleSAT flips

unitPropagation

private void unitPropagation()
Try to satisfy as many clauses as possible with unit propagation. Used as a preprocessing step of SampleSAT, which tries to uniformly sample among all zero-cost worlds.
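A minimal sketch of classic unit propagation over a partial assignment: whenever a clause has no true literal and exactly one unassigned literal, that literal is forced true, repeating until nothing changes. The clause representation (arrays of signed atom IDs) is hypothetical, and unlike the documented method, which aims to satisfy as many clauses as possible, this sketch simply stops and reports false at the first conflict.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical unit propagation; an atom absent from the map is unassigned. */
final class UnitPropagation {
    private UnitPropagation() {}

    /** Extends assignment in place; returns false if some clause has all literals false. */
    static boolean propagate(List<int[]> clauses, Map<Integer, Boolean> assignment) {
        boolean changed = true;
        while (changed) {
            changed = false;
            for (int[] clause : clauses) {
                int unassignedLit = 0;
                int unassignedCount = 0;
                boolean satisfied = false;
                for (int lit : clause) {
                    Boolean v = assignment.get(Math.abs(lit));
                    if (v == null) {
                        unassignedCount++;
                        unassignedLit = lit;
                    } else if (v == (lit > 0)) {
                        satisfied = true;
                        break;
                    }
                }
                if (satisfied) continue;
                if (unassignedCount == 0) return false;  // conflict: every literal is false
                if (unassignedCount == 1) {
                    // Unit clause: the remaining literal must be true.
                    assignment.put(Math.abs(unassignedLit), unassignedLit > 0);
                    changed = true;
                }
            }
        }
        return true;
    }
}
```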