-
Notifications
You must be signed in to change notification settings - Fork 85
/
test.cpp
133 lines (116 loc) · 3.91 KB
/
test.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/*
Copyright (C) 2010,2011 Wei Dong <[email protected]>. All Rights Reserved.
DISTRIBUTION OF THIS PROGRAM IN EITHER BINARY OR SOURCE CODE FORM MUST BE
PERMITTED BY THE AUTHOR.
*/
// Scalar type used for the components of the data and query matrices in this
// program. Defaults to float; define KGRAPH_VALUE_TYPE before this point
// (e.g. on the compiler command line) to build the test for another type.
#ifndef KGRAPH_VALUE_TYPE
#define KGRAPH_VALUE_TYPE float
#endif
#include <cctype>
#include <type_traits>
#include <iostream>
#include <boost/timer/timer.hpp>
#include <boost/program_options.hpp>
#include <kgraph.h>
#include <kgraph-data.h>
using namespace std;
using namespace boost;
using namespace kgraph;
namespace po = boost::program_options;
// Component type of the data/query matrices, taken from KGRAPH_VALUE_TYPE.
using value_type = KGRAPH_VALUE_TYPE;
int main (int argc, char *argv[]) {
string input_path;
string query_path;
string output_path;
string eval_path;
unsigned K, P;
po::options_description desc_visible("General options");
desc_visible.add_options()
("help,h", "produce help message.")
("data", po::value(&input_path), "input path")
("query", po::value(&query_path), "query path")
("eval", po::value(&eval_path), "eval path")
(",K", po::value(&K)->default_value(default_K), "")
(",P", po::value(&P)->default_value(default_P), "")
;
po::options_description desc("Allowed options");
desc.add(desc_visible);
po::positional_options_description p;
p.add("data", 1);
p.add("query", 1);
po::variables_map vm;
po::store(po::command_line_parser(argc, argv).options(desc).positional(p).run(), vm);
po::notify(vm);
if (vm.count("help") || vm.count("data") == 0 || vm.count("query") == 0) {
cout << "search <data> <index> <query> [output]" << endl;
cout << desc_visible << endl;
return 0;
}
if (P < K) {
P = K;
}
Matrix<value_type> data;
Matrix<value_type> query;
Matrix<unsigned> result;
data.load_lshkit(input_path);
query.load_lshkit(query_path);
unsigned dim = data.dim();
VectorOracle<Matrix<value_type>, value_type const*> oracle(data,
[dim](value_type const *a, value_type const *b)
{
float r = 0;
for (unsigned i = 0; i < dim; ++i) {
float v = float(a[i]) - (b[i]);
r += v * v;
}
return r;
});
float recall = 0;
float cost = 0;
float time = 0;
result.resize(query.size(), K);
KGraph::SearchParams params;
params.K = K;
params.P = P;
KGraph *kgraph = KGraph::create();
{
KGraph::IndexParams params;
kgraph->build(oracle, params, NULL);
}
boost::timer::auto_cpu_timer timer;
cerr << "Searching..." << endl;
#pragma omp parallel for reduction(+:cost)
for (unsigned i = 0; i < query.size(); ++i) {
KGraph::SearchInfo info;
kgraph->search(oracle.query(query[i]), params, result[i], &info);
cost += info.cost;
}
cost /= query.size();
time = timer.elapsed().wall / 1e9;
delete kgraph;
if (eval_path.size()) {
Matrix<unsigned> gs;
gs.load_lshkit(eval_path);
BOOST_VERIFY(gs.dim() >= K);
BOOST_VERIFY(gs.size() >= query.size());
kgraph::Matrix<float> gs_dist(query.size(), K);
kgraph::Matrix<float> result_dist(query.size(), K);
#pragma omp parallel for
for (unsigned i = 0; i < query.size(); ++i) {
auto const Q = oracle.query(query[i]);
float *gs_dist_row = gs_dist[i];
float *result_dist_row = result_dist[i];
unsigned const *gs_row = gs[i];
unsigned const *result_row = result[i];
for (unsigned k = 0; k < K; ++k) {
gs_dist_row[k] = Q(gs_row[k]);
result_dist_row[k] = Q(result_row[k]);
}
sort(gs_dist_row, gs_dist_row + K);
sort(result_dist_row, result_dist_row + K);
}
recall = AverageRecall(gs_dist, result_dist, K);
}
cout << "Time: " << time << " Recall: " << recall << " Cost: " << cost << endl;
return 0;
}