# Direct Sampler by SMLE for 2-way Tables
# By Shuhei Mano, June 2, 2026
# Used in Alg. Stat. 16: 175-199 (2025)

# define table

nr <- 4 # number of rows
nc <- 5 # number of columns
sc <- 2 # counts in each cell (sample size is nr*nc*sc)

# define odds ratios: every cell is 1 except the top-left 2x2 corner,
# which is [3 2; 2 1]; sod is their total, used to normalise od

od <- matrix(1, nrow = nr, ncol = nc)
od[1:2, 1:2] <- rbind(c(3, 2),
                      c(2, 1))
sod <- sum(od)

# construct A matrix (constraint matrix for margins)
#
# Cells are flattened row by row, so cell k lies in table row
# (k - 1) %/% nc + 1 and table column (k - 1) %% nc + 1. Rows 1..nr of `a`
# are 0/1 indicators of the cells in each table row; rows nr+1..nr+nc are
# indicators of the cells in each table column. Built inside local() so no
# helper variables or loop indices leak into the global environment.

a <- local({
  cell <- seq_len(nr * nc) - 1
  row_ind <- outer(seq_len(nr), cell %/% nc + 1, "==")
  col_ind <- outer(seq_len(nc), cell %% nc + 1, "==")
  rbind(row_ind, col_ind) * 1 # logical -> 0/1 double matrix
})

# chi2: Pearson chi-square statistic of the first nr x nc cells of x against
# the uniform expectation of sc counts per cell

chi2 <- function(x) {
  # Flatten row by row and fold the per-cell terms left to right, i.e. the
  # same accumulation order as a row-major double loop.
  cells <- as.vector(t(x[seq_len(nr), seq_len(nc), drop = FALSE]))
  Reduce(`+`, (cells - sc)^2 / sc, 0)
}

# memoization cache for ipsDR
# (its parent is the empty environment, so a key lookup never falls through
# to global or base bindings)

ipsDR_cache <- new.env(parent = emptyenv(), hash = TRUE)

# IPS by DR72 (Darroch & Ratcliff generalized iterative scaling)
#
# Rescales the cell probabilities `ip` until the expected margins
# sum(r) * (a %*% ip) agree with the target margins c(r, c). Uses the global
# constraint matrix `a` (one column per cell, one row per margin: row
# margins first, then column margins).
#
# Arguments:
#   ip    - starting cell probabilities, in the cell order of the columns of `a`
#   r, c  - target row and column margins
#   conv  - tolerance per positive margin: stop once sum(|av - be|) is below
#           conv times the number of positive margins
#   itmax - maximum number of iterations
# Returns the rescaled probabilities. If the tolerance is not met within
# itmax iterations, ip[1] is overwritten with the sentinel 999.
ipsDR <- function(ip, r, c, conv = 0.1, itmax = 20) {
  be <- c(r, c)     # target margins (rows, then columns)
  ns <- sum(r)      # current total count
  cb <- sum(be > 0) # number of positive margins
  ta <- t(a)        # loop invariant: one row per cell, one column per margin
  converged <- FALSE

  for (l in seq_len(itmax)) {
    # expected margins under the current probabilities
    av <- ns * colSums(ta * ip)

    for (j in seq_len(ncol(a))) {
      val <- 1
      for (i in seq_len(nrow(a))) {
        # margins with zero expectation are skipped; a margin whose target
        # is 0 but whose expectation is positive drives its cells to 0
        if (av[i] > 0 && a[i, j] > 0) {
          val <- val * (be[i] / av[i])^(a[i, j])
        }
      }
      # square root = GIS exponent 1/2, assuming every cell lies in exactly
      # one row margin and one column margin (column sums of `a` equal 2)
      ip[j] <- ip[j] * sqrt(val)
      # ip[j] <- ip[j] * val # DS40 for log-linear model
    }

    # convergence is judged on the margins computed before this update
    if (sum(abs(av - be)) < cb * conv) {
      converged <- TRUE
      break
    }
  }

  # decide on the explicit flag rather than the final loop index, so a fit
  # that converges on the last allowed iteration is not reported as a failure
  if (!converged) ip[1] <- 999 # does not converge
  ip
}

# memoized wrapper for ipsDR
#
# Returns ipsDR(ip, r, c), caching each result in ipsDR_cache. The cache key
# covers the starting probabilities `ip` as well as the margins: keying on
# the margins alone would hand back a stale result whenever ipsDR is called
# with different starting probabilities for the same margins.
ipsDR_memo <- function(ip, r, c) {
  # "%.17g" prints enough digits to identify each double exactly
  key <- paste(paste(sprintf("%.17g", as.double(ip)), collapse = ","),
               paste(r, collapse = ","),
               paste(c, collapse = ","),
               sep = "|")

  if (exists(key, envir = ipsDR_cache, inherits = FALSE)) {
    return(get(key, envir = ipsDR_cache, inherits = FALSE))
  }

  res <- ipsDR(ip, r, c)
  assign(key, res, envir = ipsDR_cache)
  res
}

# clear cache: remove every memoized ipsDR result held in ipsDR_cache

clear_ipsDR_cache <- function() {
  cached_keys <- ls(envir = ipsDR_cache)
  rm(list = cached_keys, envir = ipsDR_cache)
}

# Draw one table with row sums rs and column sums cs, one count at a time.
# Before every pick, the odds-ratio probabilities od/sod are refitted by
# ipsDR to the margins still remaining. Returns the nr x nc matrix of counts,
# or NULL as soon as ipsDR signals non-convergence (p[1] == 999).
ds_sample_table <- function(rs, cs) {
  brs <- rs # remaining row sums
  bcs <- cs # remaining column sums
  u <- matrix(0, nrow = nr, ncol = nc)
  # flatten od row by row: cell k is (row (k-1) %/% nc + 1,
  # column (k-1) %% nc + 1), the same cell order as the columns of `a`
  p0 <- as.vector(t(od / sod))

  for (i in seq_len(nr * nc * sc)) {
    p <- ipsDR(p0, brs, bcs) # IPS (swap in ipsDR_memo to memoize)
    if (p[1] == 999) {
      return(NULL) # does not converge
    }

    k <- sample(1:(nr * nc), 1, prob = p)
    ri <- (k - 1) %/% nc + 1
    ci <- (k - 1) %% nc + 1
    u[ri, ci] <- u[ri, ci] + 1
    brs[ri] <- brs[ri] - 1
    bcs[ci] <- bcs[ci] - 1
  }
  u
}

# main
#
# Direct sampler: draws `trial` tables whose row sums are sc*nc and column
# sums are sc*nr (see ds_sample_table). Draws that fail because ipsDR does
# not converge are counted and retried, up to trial*10 attempts in total.
# If the attempt budget is not exhausted, plots a histogram of the chi2
# statistics of the accepted tables. Always prints the failure count and the
# elapsed time.
#
# Arguments:
#   trial - number of tables to sample
#   seed  - seed passed to set.seed() for reproducibility
ds <- function(trial, seed) {
  cmax <- 0           # failure counter
  maxat <- trial * 10 # max number of attempts

  set.seed(seed)

  rs <- rep(sc * nc, nr) # rowsum
  cs <- rep(sc * nr, nc) # colsum

  ch <- rep(0, trial)    # chi2 of each accepted table

  start <- proc.time()

  at <- 0
  j <- 0
  while (j < trial) {
    at <- at + 1
    if (at > maxat) {
      print("WARNING: max number of attempts reached.")
      break
    }

    u <- ds_sample_table(rs, cs)
    if (is.null(u)) {
      cmax <- cmax + 1
    } else {
      # advance the index before storing: ch[0] <- x is a silent no-op in R,
      # which would drop the first table and leave ch[trial] at 0
      j <- j + 1
      ch[j] <- chi2(u)
    }
  }

  finish <- proc.time()

  if (at <= maxat && trial > 0) {
    # 18 equal bins on [0, 70]; append bins of the same width when a
    # statistic exceeds 70, since hist() stops if breaks do not span the data
    breaks <- seq(0, 70, length.out = 19)
    width <- breaks[2] - breaks[1]
    if (max(ch) > 70) {
      breaks <- c(breaks, 70 + width * seq_len(ceiling((max(ch) - 70) / width)))
    }
    hist(ch, breaks = breaks, main = "Histogram of chi2", xlab = "chi2")
  }
  print(sprintf("number of fails %d", cmax))
  print(sprintf("time elapsed %f", as.numeric(finish[3] - start[3])))
}
