#ifdef MATLAB_MEX_FILE #include #endif #include "GCoptimization.h" #include "LinkedBlockList.h" #include #include #include #include // will leave this one just for the laughs :) //#define olga_assert(expr) assert(!(expr)) // Choose reasonably high-precision timer (sub-millisec resolution if possible). #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN #define VC_EXTRALEAN #define NOMINMAX #include extern "C" gcoclock_t GCO_CLOCKS_PER_SEC = 0; extern "C" inline gcoclock_t gcoclock() // TODO: not thread safe; separate begin/end so that end doesn't have to check for query frequency { gcoclock_t result = 0; if (GCO_CLOCKS_PER_SEC == 0) QueryPerformanceFrequency((LARGE_INTEGER*)&GCO_CLOCKS_PER_SEC); QueryPerformanceCounter((LARGE_INTEGER*)&result); return result; } #else extern "C" gcoclock_t GCO_CLOCKS_PER_SEC = CLOCKS_PER_SEC; extern "C" gcoclock_t gcoclock() { return clock(); } #endif #ifdef MATLAB_MEX_FILE extern "C" bool utIsInterruptPending(); static void flushnow() { // Don't flush to frequently, for overall speed. static gcoclock_t prevclock = 0; gcoclock_t now = gcoclock(); if (now - prevclock > GCO_CLOCKS_PER_SEC/5) { prevclock = now; mexEvalString("drawnow;"); } } #define INDEX0 1 // print 1-based label and site indices for MATLAB #else inline static bool utIsInterruptPending() { return false; } static void flushnow() { } #define INDEX0 0 // print 0-based label and site indices #endif // Singly-linked list helper functions; works on any struct with a 'next' member. template void slist_clear(T*& head) { while (head) { T* temp = head; head = head->next; delete temp; } } template void slist_prepend(T*& head, T* val) { val->next = head; head = val; } void GCException::Report() { printf("\n%s\n",message); exit(0); } ///////////////////////////////////////////////////////////////////////////////////////////////// // First we have functions for the base class ///////////////////////////////////////////////////////////////////////////////////////////////// // Constructor for base class GCoptimization::GCoptimization(SiteID nSites, LabelID nLabels) : m_num_labels(nLabels) , m_num_sites(nSites) , m_datacostIndividual(0) , m_smoothcostIndividual(0) , m_labelcostsAll(0) , m_labelcostsByLabel(0) , m_labelcostCount(0) , m_smoothcostFn(0) , m_datacostFn(0) , m_numNeighborsTotal(0) , m_queryActiveSitesExpansion(&GCoptimization::queryActiveSitesExpansion) , m_setupDataCostsSwap(0) , m_setupDataCostsExpansion(0) , m_setupSmoothCostsSwap(0) , m_setupSmoothCostsExpansion(0) , m_applyNewLabeling(0) , m_updateLabelingDataCosts(0) , m_giveSmoothEnergyInternal(0) , m_solveSpecialCases(&GCoptimization::solveSpecialCases) , m_datacostFnDelete(0) , m_smoothcostFnDelete(0) , m_random_label_order(false) , m_verbosity(0) , m_labelingInfoDirty(true) , m_lookupSiteVar(new SiteID[nSites]) , m_labeling(new LabelID[nSites]) , m_labelTable(new LabelID[nLabels]) , m_labelingDataCosts(new EnergyTermType[nSites]) , m_labelCounts(new SiteID[nLabels]) , m_activeLabelCounts(new SiteID[m_num_labels]) , m_stepsThisCycle(0) , m_stepsThisCycleTotal(0) { if ( nLabels <= 1 ) handleError("Number of labels must be >= 2"); if ( nSites <= 0 ) handleError("Number of sites must be >= 1"); if ( !m_lookupSiteVar || !m_labelTable || !m_labeling ){ if (m_lookupSiteVar) delete [] m_lookupSiteVar; if (m_labelTable) delete [] m_labelTable; if (m_labeling) delete [] m_labeling; if (m_labelingDataCosts) delete [] m_labelingDataCosts; if (m_labelCounts) delete [] m_labelCounts; handleError("Not enough memory."); } memset(m_labeling, 0, m_num_sites*sizeof(LabelID)); memset(m_lookupSiteVar,-1,m_num_sites*sizeof(SiteID)); setLabelOrder(false); specializeSmoothCostFunctor(SmoothCostFnPotts()); } //------------------------------------------------------------------- GCoptimization::~GCoptimization() { delete [] m_labelTable; delete [] m_lookupSiteVar; delete [] m_labeling; delete [] m_labelingDataCosts; delete [] m_labelCounts; delete [] m_activeLabelCounts; if (m_datacostFnDelete) m_datacostFnDelete(m_datacostFn); if (m_smoothcostFnDelete) m_smoothcostFnDelete(m_smoothcostFn); if (m_datacostIndividual) delete [] m_datacostIndividual; if (m_smoothcostIndividual) delete [] m_smoothcostIndividual; // Delete label cost bookkeeping structures // slist_clear(m_labelcostsAll); if (m_labelcostsByLabel) { for ( LabelID i = 0; i < m_num_labels; ++i ) slist_clear(m_labelcostsByLabel[i]); delete [] m_labelcostsByLabel; } } //------------------------------------------------------------------- template <> GCoptimization::SiteID GCoptimization::queryActiveSitesExpansion(LabelID alpha_label,SiteID *activeSites) { return ((DataCostFnSparse*)m_datacostFn)->queryActiveSitesExpansion(alpha_label,m_labeling,activeSites); } //------------------------------------------------------------------- template <> void GCoptimization::setupDataCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites) { DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn; DataCostFnSparse::iterator dciter = dc->begin(alpha_label); for ( SiteID i = 0; i < size; ++i ) { SiteID site = activeSites[i]; while ( dciter.site() != site ) ++dciter; addterm1_checked(e,i,dciter.cost(),m_labelingDataCosts[site]); } } //------------------------------------------------------------------- template <> void GCoptimization::applyNewLabeling(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label) { DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn; DataCostFnSparse::iterator dciter = dc->begin(alpha_label); for ( SiteID i = 0; i < size; i++ ) { if ( e->get_var(i) == 0 ) { SiteID site = activeSites[i]; LabelID prev = m_labeling[site]; m_labeling[site] = alpha_label; m_labelCounts[alpha_label]++; m_labelCounts[prev]--; while ( dciter.site() != site ) ++dciter; m_labelingDataCosts[site] = dciter.cost(); } } m_labelingInfoDirty = true; updateLabelingInfo(false,true,false); // labels have changed, so update necessary labeling info } //------------------------------------------------------------------- template void GCoptimization::specializeDataCostFunctor(const UserFunctor f) { if ( m_datacostFnDelete ) m_datacostFnDelete(m_datacostFn); if ( m_datacostIndividual ) { delete [] m_datacostIndividual; m_datacostIndividual = 0; } m_datacostFn = new UserFunctor(f); m_datacostFnDelete = &GCoptimization::deleteFunctor; m_queryActiveSitesExpansion = &GCoptimization::queryActiveSitesExpansion; m_setupDataCostsExpansion = &GCoptimization::setupDataCostsExpansion; m_setupDataCostsSwap = &GCoptimization::setupDataCostsSwap; m_applyNewLabeling = &GCoptimization::applyNewLabeling; m_updateLabelingDataCosts = &GCoptimization::updateLabelingDataCosts; m_solveSpecialCases = &GCoptimization::solveSpecialCases; } template void GCoptimization::specializeSmoothCostFunctor(const UserFunctor f) { if ( m_smoothcostFnDelete ) m_smoothcostFnDelete(m_smoothcostFn); if ( m_smoothcostIndividual ) { delete [] m_smoothcostIndividual; m_smoothcostIndividual = 0; } m_smoothcostFn = new UserFunctor(f); m_smoothcostFnDelete = &GCoptimization::deleteFunctor; m_giveSmoothEnergyInternal = &GCoptimization::giveSmoothEnergyInternal; m_setupSmoothCostsExpansion = &GCoptimization::setupSmoothCostsExpansion; m_setupSmoothCostsSwap = &GCoptimization::setupSmoothCostsSwap; } //------------------------------------------------------------------- template GCoptimization::EnergyType GCoptimization::giveSmoothEnergyInternal() { EnergyType eng = (EnergyType) 0; SiteID i,numN,*nPointer,nSite,n; EnergyTermType *weights; SmoothCostT* sc = (SmoothCostT*) m_smoothcostFn; for ( i = 0; i < m_num_sites; i++ ) { giveNeighborInfo(i,&numN,&nPointer,&weights); for ( n = 0; n < numN; n++ ) { nSite = nPointer[n]; if ( nSite < i ) eng += weights[n]*(sc->compute(i,nSite,m_labeling[i],m_labeling[nSite])); } } return eng; } //------------------------------------------------------------------- OLGA_INLINE void GCoptimization::addterm1_checked(EnergyT* e, VarID i, EnergyTermType e0, EnergyTermType e1) { if ( e0 > GCO_MAX_ENERGYTERM || e1 > GCO_MAX_ENERGYTERM ) handleError("Data cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); m_beforeExpansionEnergy += e1; e->add_term1(i,e0,e1); } OLGA_INLINE void GCoptimization::addterm1_checked(EnergyT* e, VarID i, EnergyTermType e0, EnergyTermType e1, EnergyTermType w) { if ( e0 > GCO_MAX_ENERGYTERM || e1 > GCO_MAX_ENERGYTERM ) handleError("Smooth cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); if ( w > GCO_MAX_ENERGYTERM ) handleError("Smoothness weight was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); m_beforeExpansionEnergy += e1*w; e->add_term1(i,e0*w,e1*w); } OLGA_INLINE void GCoptimization::addterm2_checked(EnergyT* e, VarID i, VarID j, EnergyTermType e00, EnergyTermType e01, EnergyTermType e10, EnergyTermType e11, EnergyTermType w) { if ( e00 > GCO_MAX_ENERGYTERM || e11 > GCO_MAX_ENERGYTERM || e01 > GCO_MAX_ENERGYTERM || e10 > GCO_MAX_ENERGYTERM ) handleError("Smooth cost term was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); if ( w > GCO_MAX_ENERGYTERM ) handleError("Smoothness weight was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); // Inside energy/maxflow code the submodularity check is performed as an assertion, // but is optimized out. We check it in release builds as well. if ( e00+e11 > e01+e10 ) handleError("Non-submodular expansion term detected; smooth costs must be a metric for expansion"); m_beforeExpansionEnergy += e11*w; e->add_term2(i,j,e00*w,e01*w,e10*w,e11*w); } //------------------------------------------------------------------ template GCoptimization::SiteID GCoptimization::queryActiveSitesExpansion(LabelID alpha_label,SiteID *activeSites) { SiteID size = 0; for ( SiteID i = 0; i < m_num_sites; i++ ) if ( m_labeling[i] != alpha_label ) activeSites[size++] = i; return size; } //------------------------------------------------------------------- template void GCoptimization::setupDataCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites) { DataCostT* dc = (DataCostT*)m_datacostFn; for ( SiteID i = 0; i < size; ++i ) addterm1_checked(e,i,dc->compute(activeSites[i],alpha_label),m_labelingDataCosts[activeSites[i]]); } //------------------------------------------------------------------- template void GCoptimization::setupSmoothCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites) { SiteID i,nSite,site,n,nNum,*nPointer; EnergyTermType *weights; SmoothCostT* sc = (SmoothCostT*)m_smoothcostFn; for ( i = size - 1; i >= 0; i-- ) { site = activeSites[i]; giveNeighborInfo(site,&nNum,&nPointer,&weights); for ( n = 0; n < nNum; n++ ) { nSite = nPointer[n]; if ( m_lookupSiteVar[nSite] == -1 ) addterm1_checked(e,i,sc->compute(site,nSite,alpha_label,m_labeling[nSite]), sc->compute(site,nSite,m_labeling[site],m_labeling[nSite]),weights[n]); else if ( nSite < site ) { addterm2_checked(e,i,m_lookupSiteVar[nSite], sc->compute(site,nSite,alpha_label,alpha_label), sc->compute(site,nSite,alpha_label,m_labeling[nSite]), sc->compute(site,nSite,m_labeling[site],alpha_label), sc->compute(site,nSite,m_labeling[site],m_labeling[nSite]),weights[n]); } } } } //----------------------------------------------------------------------------------- template void GCoptimization::setupDataCostsSwap(SiteID size, LabelID alpha_label, LabelID beta_label, EnergyT *e,SiteID *activeSites ) { DataCostT* dc = (DataCostT*)m_datacostFn; for ( SiteID i = 0; i < size; i++ ) { e->add_term1(i,dc->compute(activeSites[i],alpha_label), dc->compute(activeSites[i],beta_label) ); } } //------------------------------------------------------------------- template void GCoptimization::setupSmoothCostsSwap(SiteID size, LabelID alpha_label,LabelID beta_label, EnergyT *e,SiteID *activeSites ) { SiteID i,nSite,site,n,nNum,*nPointer; EnergyTermType *weights; SmoothCostT* sc = (SmoothCostT*)m_smoothcostFn; for ( i = size - 1; i >= 0; i-- ) { site = activeSites[i]; giveNeighborInfo(site,&nNum,&nPointer,&weights); for ( n = 0; n < nNum; n++ ) { nSite = nPointer[n]; if ( m_lookupSiteVar[nSite] == -1 ) addterm1_checked(e,i,sc->compute(site,nSite,alpha_label,m_labeling[nSite]), sc->compute(site,nSite,beta_label, m_labeling[nSite]),weights[n]); else if ( nSite < site ) { addterm2_checked(e,i,m_lookupSiteVar[nSite], sc->compute(site,nSite,alpha_label,alpha_label), sc->compute(site,nSite,alpha_label,beta_label), sc->compute(site,nSite,beta_label,alpha_label), sc->compute(site,nSite,beta_label,beta_label),weights[n]); } } } } //----------------------------------------------------------------------------------- template void GCoptimization::applyNewLabeling(EnergyT *e,SiteID *activeSites,SiteID size,LabelID alpha_label) { DataCostT* dc = (DataCostT*)m_datacostFn; for ( SiteID i = 0; i < size; i++ ) { if ( e->get_var(i) == 0 ) { SiteID site = activeSites[i]; LabelID prev = m_labeling[site]; m_labeling[site] = alpha_label; m_labelCounts[alpha_label]++; m_labelCounts[prev]--; m_labelingDataCosts[site] = dc->compute(site,alpha_label); } } m_labelingInfoDirty = true; updateLabelingInfo(false,true,false); // labels have changed, so update necessary labeling info } //----------------------------------------------------------------------------------- template void GCoptimization::updateLabelingDataCosts() { DataCostT* dc = (DataCostT*)m_datacostFn; for (int i = 0; i < m_num_sites; ++i) m_labelingDataCosts[i] = dc->compute(i,m_labeling[i]); } //----------------------------------------------------------------------------------- template bool GCoptimization::solveSpecialCases(EnergyType& energy) { finalizeNeighbors(); DataCostT* dc = (DataCostT*)m_datacostFn; bool sc = m_numNeighborsTotal != 0; bool lc = m_labelcostsAll != 0; if ( !dc && !sc && !lc ) { energy = 0; return true; } if ( dc && !sc && !lc ) { // Special case: No label costs, so return trivial solution energy = 0; for ( SiteID i = 0; i < m_num_sites; ++i ) { LabelID minCostLabel = 0; EnergyTermType minCost = dc->compute(i, 0); for ( LabelID l = 1; l < m_num_labels; ++l ) { EnergyTermType lcost = dc->compute(i, l); if ( lcost < minCost ) { minCostLabel = l; minCost = lcost; } } if ( minCostLabel > GCO_MAX_ENERGYTERM ) handleError("Data cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); m_labeling[i] = minCostLabel; energy += minCost; } m_labelingInfoDirty = true; updateLabelingInfo(); return true; } if ( !dc && !sc && lc ) { // Special case: No data costs, so return trivial solution LabelID minLabel = 0; EnergyType minLabelCost = GCO_MAX_ENERGYTERM*(EnergyType)m_num_labels; for ( LabelID l = 0; l < m_num_labels; ++l ) { EnergyType lcsum = 0; for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next ) lcsum += lci->node->cost; if ( lcsum < minLabelCost ) { minLabel = l; minLabelCost = lcsum; } } for ( SiteID i = 0; i < m_num_sites; ++i ) m_labeling[i] = minLabel; energy = minLabelCost; m_labelingInfoDirty = true; updateLabelingInfo(); return true; } if ( dc && !sc && lc ) { LabelCost* lc; for ( lc = m_labelcostsAll; lc; lc = lc->next ) if ( lc->numLabels > 1) break; if ( !lc ) { // Special case: Data costs and per-label costs energy = solveGreedy(); return true; } } // Otherwise, use full-blown expansion/swap return false; } template <> class GCoptimization::GreedyIter { public: GreedyIter(DataCostFnSparse& dc, SiteID) : m_dc(dc), m_label(0), m_labelend(0) { } OLGA_INLINE void start(const LabelID* labels, LabelID labelCount=1) { m_label = labels; m_labelend = labels + labelCount; if (labelCount > 0) { m_site = m_dc.begin(*labels); m_siteend = m_dc.end(*labels); while (m_site == m_siteend) { if (++m_label == m_labelend) break; m_site = m_dc.begin(*m_label); m_siteend = m_dc.end(*m_label); } } } OLGA_INLINE SiteID site() const { return m_site.site(); } OLGA_INLINE SiteID label() const { return *m_label; } OLGA_INLINE bool done() const { return m_label >= m_labelend; } OLGA_INLINE GreedyIter& operator++() { // The inner loop is over sites, not labels, because sparse data costs // are stored as consecutive [sparse] SiteIDs with respect to each label. if (++m_site == m_siteend) { while (++m_label < m_labelend) { m_site = m_dc.begin(*m_label); m_siteend = m_dc.end(*m_label); if (m_site != m_siteend) break; } } return *this; } OLGA_INLINE EnergyTermType compute() const { return m_site.cost(); } OLGA_INLINE SiteID feasibleSites() const { return (SiteID)(m_siteend - m_site); } private: DataCostFnSparse::iterator m_site; DataCostFnSparse::iterator m_siteend; DataCostFnSparse& m_dc; const LabelID* m_label; const LabelID* m_labelend; }; template GCoptimization::EnergyType GCoptimization::solveGreedy() { printStatus1("starting greedy algorithm (1 cycle only)"); m_stepsThisCycle = m_stepsThisCycleTotal = 0; EnergyType estart = compute_energy(); EnergyType efinal = 0; LabelID* oldLabeling = m_labeling; m_labeling = new LabelID[m_num_sites]; EnergyType* e = new EnergyType[m_num_labels]; LabelID* order = new LabelID[m_num_labels]; // order[0..activeCount-1] contains the activated labels so far try { gcoclock_t ticks0all = gcoclock(); gcoclock_t ticks0 = gcoclock(); // clear active flags for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next) lc->active = false; DataCostT* dc = (DataCostT*)m_datacostFn; GreedyIter iter(*dc,m_num_sites); LabelID alpha = 0; // Treat first iteration as special case. // Ignore current labeling and just find the greedy initial label. for ( LabelID l = 0; l < m_num_labels; ++l ) { e[l] = 0; for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next ) e[l] += lci->node->cost; iter.start(&l); e[l] += (EnergyType)(m_num_sites - iter.feasibleSites()) * GCO_MAX_ENERGYTERM; // pre-add GCO_MAX_ENERGYTERM for all infeasible sites for (; !iter.done(); ++iter) { EnergyTermType dataCost = iter.compute(); if ( dataCost > GCO_MAX_ENERGYTERM ) handleError("Data cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); e[l] += dataCost; if ( e[l] > e[alpha] ) // break out early if this will definitely break; // not be a good label to start from } if ( e[l] < e[alpha] ) // choose alpha with minimum energy e[alpha] alpha = l; } for ( SiteID i = 0; i < m_num_sites; ++i ) { m_labeling[i] = alpha; m_labelingDataCosts[i] = dc->compute(i,alpha); } for ( LabelCostIter* lci = m_labelcostsByLabel[alpha]; lci; lci = lci->next ) lci->node->active = true; // List of labels in the order that they were expanded upon (order[0] first, order[1] second, ...) for ( LabelID l = 0; l < m_num_labels; ++l ) order[l] = l; order[alpha] = 0; order[0] = alpha; printStatus2(alpha,-1,m_num_sites,ticks0); // Greedily expand remaining labels for ( LabelID alpha_count = 1; alpha_count <= m_num_labels; ++alpha_count) { checkInterrupt(); ticks0 = gcoclock(); // Energy e[l] for expanding on label 'l' starts at e[alpha] + new labelcosts for introducing l LabelID alpha_prev = alpha; for ( LabelID li = alpha_count; li < m_num_labels; ++li ) { LabelID l = order[li]; e[l] = e[alpha_prev]; for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next ) if ( !lci->node->active ) e[l] += lci->node->cost; } // Loop over all sites and all remaining labels to calculate energy drop. for ( iter.start(&order[alpha_count],m_num_labels-alpha_count); !iter.done(); ++iter ) { EnergyTermType dc_l = iter.compute(); EnergyTermType dc_i = m_labelingDataCosts[iter.site()]; EnergyTermType delta_i = dc_l - dc_i; if ( delta_i < 0 ) e[iter.label()] += delta_i; } // Choose the next alpha based on lowest resulting energy LabelID alpha_index = alpha_count-1; for ( LabelID li = alpha_count; li < m_num_labels; ++li ) { LabelID l = order[li]; if ( e[l] < e[alpha] ) { alpha = l; alpha_index = li; } } if ( alpha == alpha_prev ) break; // Append alpha to the list of activated labels LabelID temp = order[alpha_count]; order[alpha_count] = order[alpha_index]; order[alpha_index] = temp; // Apply the new labeling, updating m_labelingDataCosts and active labelcosts as necessary iter.start(&alpha); SiteID size = iter.feasibleSites(); for ( ; !iter.done(); ++iter ) { EnergyTermType dc_l = iter.compute(); EnergyTermType dc_i = m_labelingDataCosts[iter.site()]; EnergyTermType delta_i = dc_l - dc_i; if ( delta_i < 0 ) { m_labeling[iter.site()] = alpha; m_labelingDataCosts[iter.site()] = dc_l; } } for ( LabelCostIter* lci = m_labelcostsByLabel[alpha]; lci; lci = lci->next ) lci->node->active = true; printStatus2(alpha,-1,size,ticks0); } efinal = e[alpha]; if ( efinal < estart ) { // Greedy succeeded in lowering energy compared to initial labeling delete [] oldLabeling; m_labelingInfoDirty = true; updateLabelingInfo(true,false,false); // update m_labelCounts only; m_labelingDataCosts and active labelcosts should be up to date printStatus1(1,false,ticks0all); } else { // Greedy failed to find a lower energy, so revert everything efinal = estart; delete [] m_labeling; m_labeling = oldLabeling; m_labelingInfoDirty = true; updateLabelingInfo(); // put all labeling info back the way it was printStatus1(1,false,ticks0all); } delete [] order; delete [] e; } catch (...) { delete [] order; delete [] e; throw; } return efinal; } //------------------------------------------------------------------ void GCoptimization::setDataCost(DataCostFn fn) { specializeDataCostFunctor(DataCostFnFromFunction(fn)); m_labelingInfoDirty = true; } //------------------------------------------------------------------ void GCoptimization::setDataCost(DataCostFnExtra fn, void *extraData) { specializeDataCostFunctor(DataCostFnFromFunctionExtra(fn, extraData)); m_labelingInfoDirty = true; } //------------------------------------------------------------------- void GCoptimization::setDataCost(EnergyTermType *dataArray) { specializeDataCostFunctor(DataCostFnFromArray(dataArray, m_num_labels)); m_labelingInfoDirty = true; } //------------------------------------------------------------------- void GCoptimization::setDataCost(SiteID s, LabelID l, EnergyTermType e) { if ( !m_datacostIndividual ) { EnergyTermType* table = new EnergyTermType[m_num_sites*m_num_labels]; memset(table, 0, m_num_sites*m_num_labels*sizeof(EnergyTermType)); specializeDataCostFunctor(DataCostFnFromArray(table, m_num_labels)); m_datacostIndividual = table; m_labelingInfoDirty = true; } m_datacostIndividual[s*m_num_labels + l] = e; if ( m_labeling[s] == l ) m_labelingInfoDirty = true; // m_labelingDataCosts is dirty } //------------------------------------------------------------------- void GCoptimization::setDataCostFunctor(DataCostFunctor* f) { if ( m_datacostFnDelete ) m_datacostFnDelete(m_datacostFn); if ( m_datacostIndividual ) { delete [] m_datacostIndividual; m_datacostIndividual = 0; } m_datacostFn = f; m_datacostFnDelete = 0; m_queryActiveSitesExpansion = &GCoptimization::queryActiveSitesExpansion; m_setupDataCostsExpansion = &GCoptimization::setupDataCostsExpansion; m_setupDataCostsSwap = &GCoptimization::setupDataCostsSwap; m_applyNewLabeling = &GCoptimization::applyNewLabeling; m_updateLabelingDataCosts = &GCoptimization::updateLabelingDataCosts; m_solveSpecialCases = &GCoptimization::solveSpecialCases; m_labelingInfoDirty = true; } //------------------------------------------------------------------- void GCoptimization::setDataCost(LabelID l, SparseDataCost *costs, SiteID count) { if ( !m_datacostFn ) specializeDataCostFunctor(DataCostFnSparse(numSites(),numLabels())); else if ( m_queryActiveSitesExpansion != (SiteID (GCoptimization::*)(LabelID,SiteID*))&GCoptimization::queryActiveSitesExpansion ) handleError("Cannot apply sparse data costs after dense data costs have been used."); m_labelingInfoDirty = true; DataCostFnSparse* dc = (DataCostFnSparse*)m_datacostFn; dc->set(l,costs,count); } //------------------------------------------------------------------- void GCoptimization::setSmoothCost(SmoothCostFn fn) { specializeSmoothCostFunctor(SmoothCostFnFromFunction(fn)); } //------------------------------------------------------------------- void GCoptimization::setSmoothCost(SmoothCostFnExtra fn, void* extraData) { specializeSmoothCostFunctor(SmoothCostFnFromFunctionExtra(fn, extraData)); } //------------------------------------------------------------------- void GCoptimization::setSmoothCost(EnergyTermType *smoothArray) { specializeSmoothCostFunctor(SmoothCostFnFromArray(smoothArray, m_num_labels)); } //------------------------------------------------------------------- void GCoptimization::setSmoothCost(LabelID l1, LabelID l2, EnergyTermType e){ if ( !m_smoothcostIndividual ) { EnergyTermType* table = new EnergyTermType[m_num_labels*m_num_labels]; memset(table, 0, m_num_labels*m_num_labels*sizeof(EnergyTermType)); specializeSmoothCostFunctor(SmoothCostFnFromArray(table, m_num_labels)); m_smoothcostIndividual = table; } m_smoothcostIndividual[l1*m_num_labels + l2] = e; } //------------------------------------------------------------------- void GCoptimization::setSmoothCostFunctor(SmoothCostFunctor* f) { if ( m_smoothcostFnDelete ) m_smoothcostFnDelete(m_smoothcostFn); if ( m_smoothcostIndividual ) { delete [] m_smoothcostIndividual; m_smoothcostIndividual = 0; } m_smoothcostFn = f; m_smoothcostFnDelete = 0; m_giveSmoothEnergyInternal = &GCoptimization::giveSmoothEnergyInternal; m_setupSmoothCostsExpansion = &GCoptimization::setupSmoothCostsExpansion; m_setupSmoothCostsSwap = &GCoptimization::setupSmoothCostsSwap; } //------------------------------------------------------------------- void GCoptimization::setLabelCost(EnergyTermType cost) { EnergyTermType* lc = new EnergyTermType[m_num_labels]; for ( LabelID i = 0; i < m_num_labels; ++i ) lc[i] = cost; setLabelCost(lc); delete [] lc; } //------------------------------------------------------------------- void GCoptimization::setLabelCost(EnergyTermType *costArray) { for ( LabelID i = 0; i < m_num_labels; ++i ) setLabelSubsetCost(&i, 1, costArray[i]); } //------------------------------------------------------------------- void GCoptimization::setLabelSubsetCost(LabelID* labels, LabelID numLabels, EnergyTermType cost) { if ( cost < 0 ) handleError("Label costs must be non-negative."); if ( cost > GCO_MAX_ENERGYTERM ) handleError("Label cost was larger than GCO_MAX_ENERGYTERM; danger of integer overflow."); for ( LabelID i = 0; i < numLabels; ++i) if ( labels[i] < 0 || labels[i] >= m_num_labels ) handleError("Invalid label id was found in label subset list."); if ( !m_labelcostsByLabel ) { m_labelcostsByLabel = new LabelCostIter*[m_num_labels]; memset(m_labelcostsByLabel, 0, m_num_labels*sizeof(void*)); } // If this particular subset already has a cost, simply replace it. for ( LabelCostIter* lci = m_labelcostsByLabel[labels[0]]; lci; lci = lci->next ) { if ( numLabels == lci->node->numLabels ) { if ( !memcmp(labels, lci->node->labels, numLabels*sizeof(LabelID)) ) { // This label subset already exists, so just update the cost and return lci->node->cost = cost; return; } } } if (cost == 0) return; // Create a new LabelCost entry and add it to the appropriate lists m_labelcostCount++; LabelCost* lc = new LabelCost; lc->cost = cost; lc->active = false; lc->aux = -1; lc->numLabels = numLabels; lc->labels = new LabelID[numLabels]; memcpy(lc->labels, labels, numLabels*sizeof(LabelID)); slist_prepend(m_labelcostsAll, lc); for ( LabelID i = 0; i < numLabels; ++i ) { LabelCostIter* lci = new LabelCostIter; lci->node = lc; slist_prepend(m_labelcostsByLabel[labels[i]], lci); } } //------------------------------------------------------------------- void GCoptimization::whatLabel(SiteID start, SiteID count, LabelID* labeling) { assert(start >= 0 && start+count <= m_num_sites); memcpy(labeling, m_labeling+start, count*sizeof(LabelID)); } //------------------------------------------------------------------- GCoptimization::EnergyType GCoptimization::giveSmoothEnergy() { finalizeNeighbors(); if ( m_giveSmoothEnergyInternal ) return( (this->*m_giveSmoothEnergyInternal)()); return 0; } //------------------------------------------------------------------- GCoptimization::EnergyType GCoptimization::giveDataEnergy() { updateLabelingInfo(); EnergyType energy = 0; for ( SiteID i = 0; i < m_num_sites; i++ ) energy += m_labelingDataCosts[i]; return energy; } GCoptimization::EnergyType GCoptimization::giveLabelEnergy() { updateLabelingInfo(); EnergyType energy = 0; for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next) if ( lc->active ) energy += lc->cost; return energy; } //------------------------------------------------------------------- GCoptimization::EnergyType GCoptimization::compute_energy() { return giveDataEnergy() + giveSmoothEnergy() + giveLabelEnergy(); } //------------------------------------------------------------------- void GCoptimization::permuteLabelTable() { if ( !m_random_label_order ) return; for ( LabelID i = 0; i < m_num_labels; i++ ) { LabelID j = i + (rand() % (m_num_labels-i)); LabelID temp = m_labelTable[i]; m_labelTable[i] = m_labelTable[j]; m_labelTable[j] = temp; } } //------------------------------------------------------------------- GCoptimization::EnergyType GCoptimization::expansion(int max_num_iterations) { EnergyType new_energy, old_energy; if ( (this->*m_solveSpecialCases)(new_energy) ) return new_energy; permuteLabelTable(); updateLabelingInfo(); try { if ( max_num_iterations == -1 ) { // Strategic expansion loop focuses on labels that successfuly reduced the energy printStatus1("starting alpha-expansion w/ adaptive cycles"); std::vector queueSizes; queueSizes.push_back(m_num_labels); int cycle = 1; LabelID next = 0; do { gcoclock_t ticks0 = gcoclock(); m_stepsThisCycle = 0; // Make a pass over the unchecked labels in the current queue, i.e. m_labelTable[next..queueSize-1] LabelID queueSize = queueSizes.back(); LabelID start = next; m_stepsThisCycleTotal = queueSize - start; do { if ( !alpha_expansion(m_labelTable[next]) ) std::swap(m_labelTable[next],m_labelTable[--queueSize]); // don't put this label in a new queue else ++next; // keep this label for the next (smaller) queue m_stepsThisCycle++; } while ( next < queueSize ); if ( next == start ) // No expansion was successful, so try more labels from the previous queue { next = queueSizes.back(); queueSizes.pop_back(); } else if ( queueSize < queueSizes.back()/2 ) // Some expansions were successful, so focus on them in a new queue { next = 0; queueSizes.push_back(queueSize); } else next = 0; // All expansions were successful, so do another complete sweep printStatus1(cycle++,false,ticks0); } while ( !queueSizes.empty() ); new_energy = compute_energy(); } else { // Standard expansion loop sweeps over all labels each cycle printStatus1("starting alpha-expansion w/ standard cycles"); new_energy = compute_energy(); old_energy = new_energy+1; for ( int cycle = 1; cycle <= max_num_iterations; cycle++ ) { gcoclock_t ticks0 = gcoclock(); old_energy = new_energy; new_energy = oneExpansionIteration(); printStatus1(cycle,false,ticks0); if ( new_energy == old_energy ) break; permuteLabelTable(); } } } catch (...) { m_stepsThisCycle = m_stepsThisCycleTotal = 0; throw; } m_stepsThisCycle = m_stepsThisCycleTotal = 0; // set so that alpha_expansion() knows it's no inside expansion() if called externally return new_energy; } //------------------------------------------------------------------- void GCoptimization::setLabelOrder(bool isRandom) { m_random_label_order = isRandom; for ( LabelID i = 0; i < m_num_labels; i++ ) m_labelTable[i] = i; } //------------------------------------------------------------------- void GCoptimization::setLabelOrder(const LabelID* order, LabelID size) { if ( size > m_num_labels ) handleError("setLabelOrder receieved too many labels"); for ( LabelID i = 0; i < size; ++i ) if ( order[i] < 0 || order[i] >= m_num_labels ) handleError("Invalid label id in setLabelOrder"); m_random_label_order = false; memcpy(m_labelTable,order,size*sizeof(LabelID)); memset(m_labelTable+size,-1,(m_num_labels-size)*sizeof(LabelID)); } //------------------------------------------------------------------ void GCoptimization::handleError(const char *message) { throw GCException(message); } //------------------------------------------------------------------ void GCoptimization::checkInterrupt() { if ( utIsInterruptPending() ) throw GCException("Interrupted."); } //-------------------------------------------------------------------// // METHODS for EXPANSION MOVES // //-------------------------------------------------------------------// GCoptimization::EnergyType GCoptimization::setupLabelCostsExpansion(SiteID size,LabelID alpha_label,EnergyT *e,SiteID *activeSites) { EnergyType alphaCostCorrection = 0; if ( !m_labelcostsAll ) return alphaCostCorrection; const SiteID DISABLE = -2; const SiteID UNINIT = -1; for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next ) lc->aux = UNINIT; // Skip higher-order costs that include alpha_label or any label used // outside the activeSites, since they cannot be eliminated by the expansion. if ( m_queryActiveSitesExpansion == (SiteID (GCoptimization::*)(LabelID,SiteID*))&GCoptimization::queryActiveSitesExpansion ) { // For sparse data costs, things are more complicated, because we must ensure that // no label cost for a fixed (non-active) non-alpha label is encoded in the graph. memset(m_activeLabelCounts,0,m_num_labels*sizeof(SiteID)); for ( SiteID i = 0; i < size; ++i ) m_activeLabelCounts[m_labeling[activeSites[i]]]++; for ( LabelID l = 0; l < m_num_labels; ++l ) { if ( m_activeLabelCounts[l] != m_labelCounts[l] ) { for ( LabelCostIter* lcj = m_labelcostsByLabel[l]; lcj; lcj = lcj->next ) lcj->node->aux = DISABLE; } } } for ( LabelCostIter* lci = m_labelcostsByLabel[alpha_label]; lci; lci = lci->next ) lci->node->aux = DISABLE; // Since we're explicitly omitting the alpha_label label costs from the binary energy, // calculate what it would have been, so that we can potentially reject the expansion afterwards. if ( !m_labelCounts[alpha_label] ) { for ( LabelCostIter* lci = m_labelcostsByLabel[alpha_label]; lci; lci = lci->next ) if ( !lci->node->active ) alphaCostCorrection += lci->node->cost; } // Add edges to the graph, including auxiliary vertices as needed for ( SiteID i = 0; i < size; i++ ) { LabelID label_i = m_labeling[activeSites[i]]; for ( LabelCostIter* lci = m_labelcostsByLabel[label_i]; lci; lci = lci->next ) { LabelCost* lc = lci->node; if ( lc->aux == DISABLE ) continue; // Add auxiliary variable if necessary, and add pairwise potential if ( lc->aux == UNINIT ) { lc->aux = e->add_variable(); e->add_term1(lc->aux,0,lc->cost); m_beforeExpansionEnergy += lc->cost; } e->add_term2(i,lc->aux,0,0,lc->cost,0); } } return alphaCostCorrection; } //------------------------------------------------------------------- void GCoptimization::updateLabelingInfo(bool updateCounts, bool updateActive, bool updateCosts) { if ( !m_labelingInfoDirty ) return; m_labelingInfoDirty = false; if ( m_labelcostsAll ) { if ( updateCounts ) { memset(m_labelCounts,0,m_num_labels*sizeof(SiteID)); for ( SiteID i = 0; i < m_num_sites; ++i ) m_labelCounts[m_labeling[i]]++; } if ( updateActive ) { for ( LabelCost* lc = m_labelcostsAll; lc; lc = lc->next ) lc->active = false; EnergyType energy = 0; for ( LabelID l = 0; l < m_num_labels; ++l ) if ( m_labelCounts[l] ) for ( LabelCostIter* lci = m_labelcostsByLabel[l]; lci; lci = lci->next ) lci->node->active = true; } } if ( updateCosts ) { if (m_updateLabelingDataCosts) (this->*m_updateLabelingDataCosts)(); else memset(m_labelingDataCosts,0,m_num_sites*sizeof(EnergyTermType)); } } //------------------------------------------------------------------- // Sets up the binary expansion energy, optimizes it, and updates the current labeling. // bool GCoptimization::alpha_expansion(LabelID alpha_label) { if (alpha_label < 0) return false; // label was disabled due to setLabelOrder on subset of labels finalizeNeighbors(); gcoclock_t ticks0 = gcoclock(); if ( m_stepsThisCycleTotal == 0 ) m_labelingInfoDirty = true; // if not inside expansion(), assume data cost function could have changed since last expansion updateLabelingInfo(); // Determine list of active sites for this expansion move SiteID size = 0; SiteID *activeSites = new SiteID[m_num_sites]; EnergyType afterExpansionEnergy = 0; try { // Get list of active sites based on alpha and current labeling if ( m_queryActiveSitesExpansion ) size = (this->*m_queryActiveSitesExpansion)(alpha_label,activeSites); if ( size == 0 ) // Nothing to do { delete [] activeSites; printStatus2(alpha_label,-1,size,ticks0); return false; } // Initialise reverse-lookup so that non-active neighbours can be identified // while constructing the graph for ( SiteID i = 0; i < size; i++ ) m_lookupSiteVar[activeSites[i]] = i; // Create binary variables for each remaining site, add the data costs, // and compute the smooth costs between variables. EnergyT e(size+m_labelcostCount, // poor guess at number of pairwise terms needed :( m_numNeighborsTotal+(m_labelcostCount?size+m_labelcostCount : 0), (void(*)(char*))handleError); e.add_variable(size); m_beforeExpansionEnergy = 0; if ( m_setupDataCostsExpansion ) (this->*m_setupDataCostsExpansion )(size,alpha_label,&e,activeSites); if ( m_setupSmoothCostsExpansion ) (this->*m_setupSmoothCostsExpansion)(size,alpha_label,&e,activeSites); EnergyType alphaCorrection = setupLabelCostsExpansion(size,alpha_label,&e,activeSites); checkInterrupt(); afterExpansionEnergy = e.minimize() + alphaCorrection; checkInterrupt(); if ( afterExpansionEnergy < m_beforeExpansionEnergy ) (this->*m_applyNewLabeling)(&e,activeSites,size,alpha_label); for ( SiteID i = 0; i < size; i++ ) m_lookupSiteVar[activeSites[i]] = -1; // restore m_lookupSite to all -1s printStatus2(alpha_label,-1,size,ticks0); } catch (...) { delete [] activeSites; throw; } delete [] activeSites; return afterExpansionEnergy < m_beforeExpansionEnergy; } //------------------------------------------------------------------- GCoptimization::EnergyType GCoptimization::oneExpansionIteration() { permuteLabelTable(); m_stepsThisCycle = 0; m_stepsThisCycleTotal = m_num_labels; // Each cycle is exactly one pass over the labels for (LabelID next = 0; next < m_num_labels; next++, m_stepsThisCycle++ ) alpha_expansion(m_labelTable[next]); return compute_energy(); } //-------------------------------------------------------------------// // METHODS for SWAP MOVES // //-------------------------------------------------------------------// GCoptimization::EnergyType GCoptimization::swap(int max_num_iterations) { EnergyType new_energy,old_energy; if ( (this->*m_solveSpecialCases)(new_energy) ) return new_energy; new_energy = compute_energy(); old_energy = new_energy+1; printStatus1("starting alpha/beta-swap"); if ( max_num_iterations == -1 ) max_num_iterations = 10000000; int curr_cycle = 1; m_stepsThisCycleTotal = (m_num_labels*(m_num_labels-1))/2; //try //{ while ( old_energy > new_energy && curr_cycle <= max_num_iterations) { gcoclock_t ticks0 = gcoclock(); old_energy = new_energy; new_energy = oneSwapIteration(); printStatus1(curr_cycle,true,ticks0); curr_cycle++; } //} //catch (...) //{ // m_stepsThisCycle = m_stepsThisCycleTotal = 0; // throw; //} m_stepsThisCycle = m_stepsThisCycleTotal = 0; return(new_energy); } //-------------------------------------------------------------------------------- GCoptimization::EnergyType GCoptimization::oneSwapIteration() { LabelID next,next1; permuteLabelTable(); m_stepsThisCycle = 0; for (next = 0; next < m_num_labels; next++ ) for (next1 = m_num_labels - 1; next1 >= 0; next1-- ) if ( m_labelTable[next] < m_labelTable[next1] ) { alpha_beta_swap(m_labelTable[next],m_labelTable[next1]); m_stepsThisCycle++; } return(compute_energy()); } //--------------------------------------------------------------------------------- void GCoptimization::alpha_beta_swap(LabelID alpha_label, LabelID beta_label) { assert( alpha_label >= 0 && alpha_label < m_num_labels && beta_label >= 0 && beta_label < m_num_labels); if ( m_labelcostsAll ) handleError("Label costs only implemented for alpha-expansion."); finalizeNeighbors(); gcoclock_t ticks0 = gcoclock(); // Determine the list of active sites for this swap move SiteID size = 0; SiteID *activeSites = new SiteID[m_num_sites]; //try //{ for ( SiteID i = 0; i < m_num_sites; i++ ) { if ( m_labeling[i] == alpha_label || m_labeling[i] == beta_label ) { activeSites[size] = i; m_lookupSiteVar[i] = size; size++; } } if ( size == 0 ) { delete [] activeSites; printStatus2(alpha_label,beta_label,size,ticks0); return; } // Create binary variables for each remaining site, add the data costs, // and compute the smooth costs between variables. EnergyT e(size,m_numNeighborsTotal,(void(*)(char*))handleError); e.add_variable(size); if ( m_setupDataCostsSwap ) (this->*m_setupDataCostsSwap )(size,alpha_label,beta_label,&e,activeSites); if ( m_setupSmoothCostsSwap ) (this->*m_setupSmoothCostsSwap)(size,alpha_label,beta_label,&e,activeSites); checkInterrupt(); e.minimize(); checkInterrupt(); // Apply the new labeling for ( SiteID i = 0; i < size; i++ ) { m_labeling[activeSites[i]] = (e.get_var(i) == 0) ? alpha_label : beta_label; m_lookupSiteVar[activeSites[i]] = -1; // restore lookupSiteVar to all -1s } m_labelingInfoDirty = true; //} //catch (...) //{ // delete [] activeSites; // throw; //} delete [] activeSites; printStatus2(alpha_label,beta_label,size,ticks0); } ////////////////////////////////////////////////////////////////////////////////////////////////// // Functions for the GCoptimizationGridGraph, derived from GCoptimization //////////////////////////////////////////////////////////////////////////////////////////////////// GCoptimizationGridGraph::GCoptimizationGridGraph(SiteID width, SiteID height,LabelID num_labels) :GCoptimization(width*height,num_labels) { assert( (width > 1) && (height > 1) && (num_labels > 1 )); m_weightedGraph = 0; for (int i = 0; i < 4; i ++ ) m_unityWeights[i] = 1; m_width = width; m_height = height; m_numNeighbors = new SiteID[m_num_sites]; m_neighbors = new SiteID[4*m_num_sites]; SiteID indexes[4] = {-1,1,-m_width,m_width}; SiteID indexesL[3] = {1,-m_width,m_width}; SiteID indexesR[3] = {-1,-m_width,m_width}; SiteID indexesU[3] = {1,-1,m_width}; SiteID indexesD[3] = {1,-1,-m_width}; SiteID indexesUL[2] = {1,m_width}; SiteID indexesUR[2] = {-1,m_width}; SiteID indexesDL[2] = {1,-m_width}; SiteID indexesDR[2] = {-1,-m_width}; setupNeighbData(1,m_height-1,1,m_width-1,4,indexes); setupNeighbData(1,m_height-1,0,1,3,indexesL); setupNeighbData(1,m_height-1,m_width-1,m_width,3,indexesR); setupNeighbData(0,1,1,width-1,3,indexesU); setupNeighbData(m_height-1,m_height,1,m_width-1,3,indexesD); setupNeighbData(0,1,0,1,2,indexesUL); setupNeighbData(0,1,m_width-1,m_width,2,indexesUR); setupNeighbData(m_height-1,m_height,0,1,2,indexesDL); setupNeighbData(m_height-1,m_height,m_width-1,m_width,2,indexesDR); } //------------------------------------------------------------------- GCoptimizationGridGraph::~GCoptimizationGridGraph() { delete [] m_numNeighbors; if ( m_neighbors ) delete [] m_neighbors; if (m_weightedGraph) delete [] m_neighborsWeights; } //------------------------------------------------------------------- void GCoptimizationGridGraph::setupNeighbData(SiteID startY,SiteID endY,SiteID startX, SiteID endX,SiteID maxInd,SiteID *indexes) { SiteID x,y,pix; SiteID n; for ( y = startY; y < endY; y++ ) for ( x = startX; x < endX; x++ ) { pix = x+y*m_width; m_numNeighbors[pix] = maxInd; m_numNeighborsTotal += maxInd; for (n = 0; n < maxInd; n++ ) m_neighbors[pix*4+n] = pix+indexes[n]; } } //------------------------------------------------------------------- void GCoptimizationGridGraph::finalizeNeighbors() { } //------------------------------------------------------------------- void GCoptimizationGridGraph::setSmoothCostVH(EnergyTermType *smoothArray, EnergyTermType *vCosts, EnergyTermType *hCosts) { setSmoothCost(smoothArray); m_weightedGraph = 1; computeNeighborWeights(vCosts,hCosts); } //------------------------------------------------------------------- void GCoptimizationGridGraph::giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights) { *numSites = m_numNeighbors[site]; *neighbors = &m_neighbors[site*4]; if (m_weightedGraph) *weights = &m_neighborsWeights[site*4]; else *weights = m_unityWeights; } //------------------------------------------------------------------- void GCoptimizationGridGraph::computeNeighborWeights(EnergyTermType *vCosts,EnergyTermType *hCosts) { SiteID i,n,nSite; GCoptimization::EnergyTermType weight; m_neighborsWeights = new EnergyTermType[m_num_sites*4]; for ( i = 0; i < m_num_sites; i++ ) { for ( n = 0; n < m_numNeighbors[i]; n++ ) { nSite = m_neighbors[4*i+n]; if ( i-nSite == 1 ) weight = hCosts[nSite]; else if (i-nSite == -1 ) weight = hCosts[i]; else if ( i-nSite == m_width ) weight = vCosts[nSite]; else if (i-nSite == -m_width ) weight = vCosts[i]; m_neighborsWeights[i*4+n] = weight; } } } //////////////////////////////////////////////////////////////////////////////////////////////// // Functions for the GCoptimizationGeneralGraph, derived from GCoptimization //////////////////////////////////////////////////////////////////////////////////////////////////// GCoptimizationGeneralGraph::GCoptimizationGeneralGraph(SiteID num_sites,LabelID num_labels):GCoptimization(num_sites,num_labels) { assert( num_sites > 1 && num_labels > 1 ); m_neighborsIndexes = 0; m_neighborsWeights = 0; m_numNeighbors = 0; m_neighbors = 0; m_needTodeleteNeighbors = true; m_needToFinishSettingNeighbors = true; } //------------------------------------------------------------------ GCoptimizationGeneralGraph::~GCoptimizationGeneralGraph() { if ( m_neighbors ) delete [] m_neighbors; if ( m_numNeighbors && m_needTodeleteNeighbors ) { for ( SiteID i = 0; i < m_num_sites; i++ ) { if (m_numNeighbors[i] != 0 ) { delete [] m_neighborsIndexes[i]; delete [] m_neighborsWeights[i]; } } delete [] m_numNeighbors; delete [] m_neighborsIndexes; delete [] m_neighborsWeights; } } //------------------------------------------------------------------ void GCoptimizationGeneralGraph::finalizeNeighbors() { if ( !m_needToFinishSettingNeighbors ) return; m_needToFinishSettingNeighbors = false; Neighbor *tmp; SiteID i,site,count; EnergyTermType *tempWeights = new EnergyTermType[m_num_sites]; SiteID *tempIndexes = new SiteID[m_num_sites]; if ( !tempWeights || !tempIndexes ) handleError("Not enough memory"); m_numNeighbors = new SiteID[m_num_sites]; m_neighborsIndexes = new SiteID*[m_num_sites]; m_neighborsWeights = new EnergyTermType*[m_num_sites]; if ( !m_numNeighbors || !m_neighborsIndexes || !m_neighborsWeights ) handleError("Not enough memory."); for ( site = 0; site < m_num_sites; site++ ) { if ( m_neighbors && !m_neighbors[site].isEmpty() ) { m_neighbors[site].setCursorFront(); count = 0; while ( m_neighbors[site].hasNext() ) { tmp = (Neighbor *) (m_neighbors[site].next()); tempIndexes[count] = tmp->to_node; tempWeights[count] = tmp->weight; delete tmp; count++; } m_numNeighbors[site] = count; m_numNeighborsTotal += count; m_neighborsIndexes[site] = new SiteID[count]; m_neighborsWeights[site] = new EnergyTermType[count]; if ( !m_neighborsIndexes[site] || !m_neighborsWeights[site] ) handleError("Not enough memory."); for ( i = 0; i < count; i++ ) { m_neighborsIndexes[site][i] = tempIndexes[i]; m_neighborsWeights[site][i] = tempWeights[i]; } } else m_numNeighbors[site] = 0; } delete [] tempIndexes; delete [] tempWeights; if (m_neighbors) { delete [] m_neighbors; m_neighbors = 0; } } //------------------------------------------------------------------------------ void GCoptimizationGeneralGraph::giveNeighborInfo(SiteID site, SiteID *numSites, SiteID **neighbors, EnergyTermType **weights) { if (m_numNeighbors) { (*numSites) = m_numNeighbors[site]; (*neighbors) = m_neighborsIndexes[site]; (*weights) = m_neighborsWeights[site]; } else { *numSites = 0; *neighbors = 0; *weights = 0; } } //------------------------------------------------------------------ void GCoptimizationGeneralGraph::setNeighbors(SiteID site1, SiteID site2, EnergyTermType weight) { assert( site1 < m_num_sites && site1 >= 0 && site2 < m_num_sites && site2 >= 0); if ( m_needToFinishSettingNeighbors == false ) handleError("Already set up neighborhood system."); if ( !m_neighbors ) { m_neighbors = new LinkedBlockList[m_num_sites]; if ( !m_neighbors ) handleError("Not enough memory."); } Neighbor *temp1 = (Neighbor *) new Neighbor; Neighbor *temp2 = (Neighbor *) new Neighbor; temp1->weight = weight; temp1->to_node = site2; temp2->weight = weight; temp2->to_node = site1; m_neighbors[site1].addFront(temp1); m_neighbors[site2].addFront(temp2); } //------------------------------------------------------------------ void GCoptimizationGeneralGraph::setAllNeighbors(SiteID *numNeighbors,SiteID **neighborsIndexes, EnergyTermType **neighborsWeights) { m_needTodeleteNeighbors = false; m_needToFinishSettingNeighbors = false; if ( m_numNeighborsTotal > 0 ) handleError("Already set up neighborhood system."); m_numNeighbors = numNeighbors; m_numNeighborsTotal = 0; for (int site = 0; site < m_num_sites; site++ ) m_numNeighborsTotal += m_numNeighbors[site]; m_neighborsIndexes = neighborsIndexes; m_neighborsWeights = neighborsWeights; } //------------------------------------------------------------------ // boring status messages void GCoptimization::printStatus1(const char* extraMsg) { if ( m_verbosity < 1 ) return; if ( extraMsg ) printf("gco>> %s\n",extraMsg); printf("gco>> initial energy: \tE=%lld (E=%lld+%lld+%lld)\n",(long long)compute_energy(), (long long)giveDataEnergy(), (long long)giveSmoothEnergy(), (long long)giveLabelEnergy()); flushnow(); } void GCoptimization::printStatus1(int cycle, bool isSwap, gcoclock_t ticks0) { if ( m_verbosity < 1 ) return; gcoclock_t ticks1 = gcoclock(); printf("gco>> after cycle %2d: \tE=%lld (E=%lld+%lld+%lld);",cycle,(long long)compute_energy(), (long long)giveDataEnergy(),(long long)giveSmoothEnergy(),(long long)giveLabelEnergy()); if ( m_stepsThisCycleTotal > 0 ) printf(isSwap ? " \t%d swaps(s);" : " \t%d expansions(s);",m_stepsThisCycleTotal); if ( m_verbosity == 1 ) { // Don't print time if time is already printed at finer scale, since printing // itself takes time (esp in MATLAB) and makes time useless at this level int ms = (int)(1000*(ticks1 - ticks0) / GCO_CLOCKS_PER_SEC); printf(" \t%d ms",ms); } printf("\n"); flushnow(); } void GCoptimization::printStatus2(int alpha, int beta, int numVars, gcoclock_t ticks0) { if ( m_verbosity < 2 ) return; int microsec = (int)(1000000*(gcoclock() - ticks0) / GCO_CLOCKS_PER_SEC); if ( beta >= 0 ) printf("gco>> after swap(%d,%d):",alpha+INDEX0,beta+INDEX0); else printf("gco>> after expansion(%d):",alpha+INDEX0); printf(" \tE=%lld (E=%lld+%lld+%lld);\t %lld vars;", (long long)compute_energy(),(long long)giveDataEnergy(), (long long)giveSmoothEnergy(),(long long)giveLabelEnergy(),(long long)numVars); if ( m_stepsThisCycleTotal > 0 ) printf(" \t(%d of %d);",m_stepsThisCycle+1,m_stepsThisCycleTotal); printf(microsec > 100 ? "\t %.2f ms\n" : "\t %.3f ms\n",(double)microsec/1000.0); flushnow(); } //------------------------------------------------------------------- // DataCostFnSparse methods //------------------------------------------------------------------- GCoptimization::DataCostFnSparse::DataCostFnSparse(SiteID num_sites, LabelID num_labels) : m_num_sites(num_sites) , m_num_labels(num_labels) , m_buckets_per_label((m_num_sites + cSitesPerBucket-1)/cSitesPerBucket) , m_buckets(0) { } GCoptimization::DataCostFnSparse::DataCostFnSparse(const DataCostFnSparse& src) : m_num_sites(src.m_num_sites) , m_num_labels(src.m_num_labels) , m_buckets_per_label(src.m_buckets_per_label) , m_buckets(0) { assert(!src.m_buckets); // not implemented } GCoptimization::DataCostFnSparse::~DataCostFnSparse() { if (m_buckets) { for (LabelID l = 0; l < m_num_labels; ++l) if (m_buckets[l*m_buckets_per_label].begin) delete [] m_buckets[l*m_buckets_per_label].begin; delete [] m_buckets; } } void GCoptimization::DataCostFnSparse::set(LabelID l, const SparseDataCost* costs, SiteID count) { // Create the bucket if necessary, and copy all the costs // if (!m_buckets) { m_buckets = new DataCostBucket[m_num_labels*m_buckets_per_label]; memset(m_buckets, 0, m_num_labels*m_buckets_per_label*sizeof(DataCostBucket)); } DataCostBucket* b = &m_buckets[l*m_buckets_per_label]; if (b->begin) delete [] b->begin; SparseDataCost* next = new SparseDataCost[count]; memcpy(next,costs,count*sizeof(SparseDataCost)); // // Scan the list of costs and remember pointers to delimit the 'buckets', i.e. where // ranges of SiteIDs lie along the array. Buckets can be empty (begin == end). // const SparseDataCost* end = next+count; SiteID prev_site = -1; for (int i = 0; i < m_buckets_per_label; ++i) { b[i].begin = b[i].predict = next; SiteID end_site = (i+1)*cSitesPerBucket; while (next < end && next->site < end_site) { if (next->site < 0 || next->site >= m_num_sites) throw GCException("Invalid site id given for sparse data cost; must be within range."); if (next->site <= prev_site) throw GCException("Sparse data costs must be sorted in increasing order of SiteID"); prev_site = next->site; ++next; } b[i].end = next; } } GCoptimization::EnergyTermType GCoptimization::DataCostFnSparse::search(DataCostBucket& b, SiteID s) { // Perform binary search for requested SiteID // const SparseDataCost* L = b.begin; const SparseDataCost* R = b.end-1; if ( R - L == m_num_sites ) return b.begin[s].cost; // special case: this particular label is actually dense do { const SparseDataCost* mid = (const SparseDataCost*)((((size_t)L+(size_t)R) >> 1) & cDataCostPtrMask); if (s < mid->site) R = mid-1; // eliminate upper range else if (mid->site < s) L = mid+1; // eliminate lower range else { b.predict = mid+1; return mid->cost; // found it! } } while (R - L > cLinearSearchSize); // Finish off with linear search over the remaining elements // do { if (L->site >= s) { if (L->site == s) { b.predict = L+1; return L->cost; } break; } } while (++L <= R); b.predict = L; return GCO_MAX_ENERGYTERM; // the site belongs to this bucket but with no cost specified } OLGA_INLINE GCoptimization::EnergyTermType GCoptimization::DataCostFnSparse::compute(SiteID s, LabelID l) { DataCostBucket& b = m_buckets[l*m_buckets_per_label + (s >> cLogSitesPerBucket)]; if (b.begin == b.end) return GCO_MAX_ENERGYTERM; if (b.predict < b.end) { // Check for correct prediction if (b.predict->site == s) return (b.predict++)->cost; // predict++ for next time // If the requested site is missing from the site ids near 'predict' // then we know it doesn't exist in the bucket, so return INF if (b.predict->site > s && b.predict > b.begin && (b.predict-1)->site < s) return GCO_MAX_ENERGYTERM; } if ( (size_t)b.end - (size_t)b.begin == cSitesPerBucket*sizeof(SparseDataCost) ) return b.begin[s-b.begin->site].cost; // special case: this particular bucket is actually dense! return search(b,s); } GCoptimization::SiteID GCoptimization::DataCostFnSparse::queryActiveSitesExpansion(LabelID alpha_label, const LabelID* labeling, SiteID* activeSites) { const SparseDataCost* next = m_buckets[alpha_label*m_buckets_per_label].begin; const SparseDataCost* end = m_buckets[alpha_label*m_buckets_per_label + m_buckets_per_label-1].end; SiteID count = 0; for (; next < end; ++next) { if ( labeling[next->site] != alpha_label ) activeSites[count++] = next->site; } return count; }