N-Queens / C++ & PARDS版を高速化

Erlang版とかに比べて遅すぎると思ったので、高速化にトライ。
元々、Fleng/KLIC/Erlang版をストレートに移植したので、リストに頼りすぎたデータ構造になっていたが、普通C/C++ではもう少し配列を使うよなぁ、ということで、ワーキングエリアとして、配列を含むClassを定義して、リストの代わりにすることに。答えを返すデータ構造は長さが分からないのでリストのままだけど。
それから、newを使わずスタック上にデータを積むようにしてみた。
すると、サイズが13で、

  • 逐次版:2.1秒
  • 並列版:1.1秒

まで高速化 (^^

配列サイズは#defineで決めているんだけど、このサイズの大小で結構時間が変わる。キャッシュに乗るかどうかなどで変わってるのかなぁ?
本気でN-Queensの高速化をやると、まだまだ追いつめられそうなんだけど、アルゴリズムとデータ構造の見直しが必要そうだし、他の言語との比較に意味がなくなりそうなので、とりあえずこの辺でストップ。

長いけど、ソースコードは下記の通り:

#include "Sync.h"

#define MAX_SIZE 13

class Work {
public:
  int cand[MAX_SIZE];
  int pos;
  Work(){pos = 0;}
  int pop(){pos--; return cand[pos];}
  void push(int val){cand[pos] = val; pos++;}
};

class AnsList{
public:
  int ans[MAX_SIZE];
  AnsList* next;
  static void* operator new(size_t size){
    void* ptr = pards_shmalloc(size);
    return ptr;
  }
  static void operator delete(void* obj){
    pards_shmfree(obj);
    return;
  }
};

AnsList* appendAnsList(AnsList* a, AnsList* b){
  if(a == 0) return b;
  else {
    AnsList* tmp = a;
    while(tmp->next != 0) tmp = tmp->next;
    tmp->next = b;
    return a;
  }
}

void WorkToAnsList(Work* w, AnsList* a){
  for(int i = 0; i < w->pos; i++)
    a->ans[i] = w->cand[i];
}

void append(Work* a1, Work* a2, Work* a3){
  int i, j;
  for(i = 0; i < a1->pos; i++)
    a3->cand[i] = a1-> cand[i];
  for(j = 0; j < a2->pos; j++)
    a3->cand[i+j] = a2->cand[j];
  a3->pos = i+j;
}

void generate(Work* a, int n){
  int i;
  for(i = 0; i < n; i++)
    a->cand[i] = i+1;
  a->pos = i;
}

void qudist(Work* plu, Work* ls, Work* lp, Sync<AnsList*> ans){
  if(plu->pos == 0) {
    if(ls->pos == 0) {
      AnsList* res = new AnsList;
      WorkToAnsList(lp,res);
      res->next = 0;
      ans.write(res);
      return;
    } else {
      ans.write(0);
      return;
    }
  } else {
    AnsList* check(int, int, Work*, Work*, Work*);

    int p = plu->pop();
    Work* lu = plu;
    Work lr;
    append(lu,ls,&lr);
    Work lp2 = *lp;

    ls->push(p);

    Sync<AnsList*> syncans2;
    SPAWN(qudist(lu, ls, lp, syncans2));
   
    AnsList* ans1 = check(p, 1, &lr, &lp2, lp);
    AnsList* ans2 = syncans2.read();

    ans.write(appendAnsList(ans1, ans2));
  }
}

AnsList* qu(Work* plu, Work* ls, Work* lp){
  if(plu->pos == 0) {
    if(ls->pos == 0) {
      AnsList* res = new AnsList;
      WorkToAnsList(lp,res);
      res->next = 0;
      return res;
    } else {
      return 0;
    }
  } else {
    AnsList* check(int, int, Work*, Work*, Work*);

    int p = plu->pop();
    Work* lu = plu;
    Work lr;
    append(lu,ls,&lr);
    Work lp2 = *lp;
    
    AnsList* ans1 = check(p, 1, &lr, &lp2, lp);
    ls->push(p);
    AnsList* ans2 = qu(lu, ls, lp);
    return appendAnsList(ans1, ans2);
  }
}

AnsList* check(int p, int d, Work* l, Work* qlp0, Work* lp){
  while(1){
    if(qlp0->pos == 0) {
      Work plp = *lp;
      plp.push(p);
      Work w;
      return qu(l,&w,&plp);
    } else {
      int q = qlp0->pop();
      if(q + d == p || q - d == p)
	return 0;
      else{
	d = d+1;
	continue;
      }
    }
  }
}

int main(int argc, char* argv[])
{
  if(argc < 2) {
    printf("qu SIZE\n");
    exit(0);
  }

  int num = atoi(argv[1]);

  if(num > MAX_SIZE) {
    printf("SIZE is greater than MAX_SIZE (%d).\n", MAX_SIZE);
    exit(0);
  }

  pards_init();

  struct timeval t1, t2;
  struct timezone tz;
  int i;

  AnsList* res;

  printf("serial version\n");
  gettimeofday(&t1,&tz);
  Work gen1;
  generate(&gen1, num);
  Work w1, w2;
  res = qu(&gen1,&w1,&w2);  
  gettimeofday(&t2,&tz);

  for(i = 0; res != 0; res = res->next, i++){
#ifdef DEBUG
    for(int j = 0; j < num; j++)
      printf("%d ", res->ans[j]);
    printf("\n");
#endif
  }
  printf("num = %d\n",i);
  printf("elapsed time = %f sec\n",
	 t2.tv_sec-t1.tv_sec + (t2.tv_usec - t1.tv_usec)/1000000.0);

  printf("parallel version\n");
  gettimeofday(&t1,&tz);
  Work gen2;
  generate(&gen2, num);
  Work w3, w4;
  Sync<AnsList*> ans;
  qudist(&gen2,&w3,&w4,ans);
  res = ans.read();
  gettimeofday(&t2,&tz);
  
  for(i = 0; res != 0; res = res->next, i++){
#ifdef DEBUG
    for(int j = 0; j < num; j++)
      printf("%d ", res->ans[j]);
    printf("\n");
#endif
  }
  printf("num = %d\n",i);
  printf("elapsed time = %f sec\n",
	 t2.tv_sec-t1.tv_sec + (t2.tv_usec - t1.tv_usec)/1000000.0);

  pards_finalize();
}