N-Queens / C++ & PARDS版を高速化
Erlang版とかに比べて遅すぎると思ったので、高速化にトライ。
元々、Fleng/KLIC/Erlang版をストレートに移植したので、リストに頼りすぎたデータ構造になっていたが、普通C/C++ではもう少し配列を使うよなぁ、ということで、ワーキングエリアとして、配列を含むClassを定義して、リストの代わりにすることに。答えを返すデータ構造は長さが分からないのでリストのままだけど。
それから、newを使わずスタック上にデータを積むようにしてみた。
すると、サイズが13で、
- 逐次版:2.1秒
- 並列版:1.1秒
まで高速化 (^^
配列サイズは#defineで決めているんだけど、このサイズの大小で結構時間が変わる。キャッシュに乗るかどうかなどで変わってるのかなぁ?
本気でN-Queensの高速化をやると、まだまだ追いつめられそうなんだけど、アルゴリズムとデータ構造の見直しが必要そうだし、他の言語との比較に意味がなくなりそうなので、とりあえずこの辺でストップ。
長いけど、ソースコードは下記の通り:
#include "Sync.h" #define MAX_SIZE 13 class Work { public: int cand[MAX_SIZE]; int pos; Work(){pos = 0;} int pop(){pos--; return cand[pos];} void push(int val){cand[pos] = val; pos++;} }; class AnsList{ public: int ans[MAX_SIZE]; AnsList* next; static void* operator new(size_t size){ void* ptr = pards_shmalloc(size); return ptr; } static void operator delete(void* obj){ pards_shmfree(obj); return; } }; AnsList* appendAnsList(AnsList* a, AnsList* b){ if(a == 0) return b; else { AnsList* tmp = a; while(tmp->next != 0) tmp = tmp->next; tmp->next = b; return a; } } void WorkToAnsList(Work* w, AnsList* a){ for(int i = 0; i < w->pos; i++) a->ans[i] = w->cand[i]; } void append(Work* a1, Work* a2, Work* a3){ int i, j; for(i = 0; i < a1->pos; i++) a3->cand[i] = a1-> cand[i]; for(j = 0; j < a2->pos; j++) a3->cand[i+j] = a2->cand[j]; a3->pos = i+j; } void generate(Work* a, int n){ int i; for(i = 0; i < n; i++) a->cand[i] = i+1; a->pos = i; } void qudist(Work* plu, Work* ls, Work* lp, Sync<AnsList*> ans){ if(plu->pos == 0) { if(ls->pos == 0) { AnsList* res = new AnsList; WorkToAnsList(lp,res); res->next = 0; ans.write(res); return; } else { ans.write(0); return; } } else { AnsList* check(int, int, Work*, Work*, Work*); int p = plu->pop(); Work* lu = plu; Work lr; append(lu,ls,&lr); Work lp2 = *lp; ls->push(p); Sync<AnsList*> syncans2; SPAWN(qudist(lu, ls, lp, syncans2)); AnsList* ans1 = check(p, 1, &lr, &lp2, lp); AnsList* ans2 = syncans2.read(); ans.write(appendAnsList(ans1, ans2)); } } AnsList* qu(Work* plu, Work* ls, Work* lp){ if(plu->pos == 0) { if(ls->pos == 0) { AnsList* res = new AnsList; WorkToAnsList(lp,res); res->next = 0; return res; } else { return 0; } } else { AnsList* check(int, int, Work*, Work*, Work*); int p = plu->pop(); Work* lu = plu; Work lr; append(lu,ls,&lr); Work lp2 = *lp; AnsList* ans1 = check(p, 1, &lr, &lp2, lp); ls->push(p); AnsList* ans2 = qu(lu, ls, lp); return appendAnsList(ans1, ans2); } } AnsList* check(int p, int d, Work* l, Work* qlp0, Work* lp){ while(1){ if(qlp0->pos == 0) { Work plp = *lp; plp.push(p); Work w; return qu(l,&w,&plp); } else { int q = qlp0->pop(); if(q + d == p || q - d == p) return 0; else{ d = d+1; continue; } } } } int main(int argc, char* argv[]) { if(argc < 2) { printf("qu SIZE\n"); exit(0); } int num = atoi(argv[1]); if(num > MAX_SIZE) { printf("SIZE is greater than MAX_SIZE (%d).\n", MAX_SIZE); exit(0); } pards_init(); struct timeval t1, t2; struct timezone tz; int i; AnsList* res; printf("serial version\n"); gettimeofday(&t1,&tz); Work gen1; generate(&gen1, num); Work w1, w2; res = qu(&gen1,&w1,&w2); gettimeofday(&t2,&tz); for(i = 0; res != 0; res = res->next, i++){ #ifdef DEBUG for(int j = 0; j < num; j++) printf("%d ", res->ans[j]); printf("\n"); #endif } printf("num = %d\n",i); printf("elapsed time = %f sec\n", t2.tv_sec-t1.tv_sec + (t2.tv_usec - t1.tv_usec)/1000000.0); printf("parallel version\n"); gettimeofday(&t1,&tz); Work gen2; generate(&gen2, num); Work w3, w4; Sync<AnsList*> ans; qudist(&gen2,&w3,&w4,ans); res = ans.read(); gettimeofday(&t2,&tz); for(i = 0; res != 0; res = res->next, i++){ #ifdef DEBUG for(int j = 0; j < num; j++) printf("%d ", res->ans[j]); printf("\n"); #endif } printf("num = %d\n",i); printf("elapsed time = %f sec\n", t2.tv_sec-t1.tv_sec + (t2.tv_usec - t1.tv_usec)/1000000.0); pards_finalize(); }