#include "solver.hpp"
#include "math.h"
#include <vector>
00003
00004
00005
00006
00007 SourceFunction::SourceFunction(unsigned _N, double _H,
00008 unsigned _M, double _T, unsigned _m1, unsigned _m2)
00009 : N(_N), M(_M), m1(_m1), m2(_m2), h(_H / double(_N)), tau(_T / double(_M))
00010 { }
00011
00012 SourceFunctionGeneral::SourceFunctionGeneral(unsigned _N, double _H,
00013 unsigned _M, double _T, unsigned _m1, unsigned _m2, double *_v)
00014 : SourceFunction(_N, _H, _M, _T, _m1, _m2)
00015 {
00016 unsigned i, k;
00017 v = new double[(N + 1) * (M + 1)];
00018 for(k = 0; k <= M; ++k) {
00019 for(i = 0; i <= N; ++i) {
00020 if((k < m1) || (m2 < k)) {
00021 v[(N + 1) * k + i] = 0;
00022 } else {
00023 v[(N + 1) * k + i] = _v[(N + 1) * k + i];
00024 }
00025 }
00026 }
00027 }
00028
00029 SourceFunctionGeneral::~SourceFunctionGeneral()
00030 {
00031 delete [] v;
00032 }
00033
00034 double SourceFunctionGeneral::operator()(unsigned xi, unsigned ti) const
00035 {
00036 #if HANDLEXCEPTION
00037 if((xi > N) || (ti > M)) {
00038 throw "assume xi <= N and ti <= M";
00039 }
00040 #endif
00041 return v[(N + 1) * ti + xi];
00042 }
00043
00044 double SourceFunctionGeneral::SquareNorm() const
00045 {
00046 unsigned i, k;
00047 double res, tmp;
00048 res = 0;
00049 for(k = m1; k <= m2; ++k) {
00050 tmp = 0;
00051 for(i = 0; i <= N; ++i) {
00052 tmp += v[(N + 1) * k + i] * v[(N + 1) * k + i];
00053 }
00054 tmp -= 0.5 * (v[(N + 1) * k] + v[(N + 1) * k + N]);
00055 if((k == m1) || (k == m2)) {
00056 tmp *= 0.5;
00057 }
00058 res += tmp;
00059 }
00060 return tau * h * res;
00061 }
00062
00063 void SourceFunctionGeneral::Update(double *q, double alpha, double dzeta)
00064 {
00065 unsigned i, k;
00066 for(k = m1; k <= m2; ++k) {
00067 for(i = 0; i <= N; ++i) {
00068 v[(N + 1) * k + i] -= dzeta * (alpha *
00069 v[(N + 1) * k + i] + q[(N + 1) * k + i]);
00070 }
00071 }
00072 }
00073
00074 SourceFunctionFixTime::SourceFunctionFixTime(unsigned _N, double _H,
00075 unsigned _M, double _T, unsigned _m1, unsigned _m2,
00076 double *_vt, double *_vx)
00077 : SourceFunction(_N, _H, _M, _T, _m1, _m2)
00078 {
00079 unsigned i;
00080 vx = new double[N + 1];
00081 vt = new double[M + 1];
00082 vt_scale = new double[M + 1];
00083 for(i = 0; i <= N; ++i) {
00084 vx[i] = _vx[i];
00085 }
00086 square_norm_vt = 0;
00087 for(i = m1; i <= m2; ++i) {
00088 square_norm_vt += _vt[i] * _vt[i];
00089 }
00090 square_norm_vt = tau * (square_norm_vt -
00091 0.5 * (_vt[m1] * _vt[m1] + _vt[m2] * _vt[m2]));
00092 for(i = 0; i <= M; ++i) {
00093 if((i < m1) || (m2 < i)) {
00094 vt[i] = 0;
00095 vt_scale[i] = 0;
00096 } else {
00097 vt[i] = _vt[i];
00098 vt_scale[i] = _vt[i] / square_norm_vt;
00099 }
00100 }
00101 }
00102
00103 SourceFunctionFixTime::~SourceFunctionFixTime()
00104 {
00105 delete [] vx;
00106 delete [] vt;
00107 delete [] vt_scale;
00108 }
00109
00110 double SourceFunctionFixTime::operator()(unsigned xi, unsigned ti) const
00111 {
00112 #if HANDLEXCEPTION
00113 if((xi > N) || (ti > M)) {
00114 throw "assume xi <= N and ti <= M";
00115 }
00116 #endif
00117 return vx[xi] * vt[ti];
00118 }
00119
00120 double SourceFunctionFixTime::SquareNorm() const
00121 {
00122 unsigned i;
00123 double square_norm_vx;
00124 square_norm_vx = 0;
00125 for(i = 0; i <= N; ++i) {
00126 square_norm_vx += vx[i] * vx[i];
00127 }
00128 square_norm_vx = h * (square_norm_vx -
00129 0.5 * (vx[0] * vx[0] + vx[N] * vx[N]));
00130 return square_norm_vx * square_norm_vt;
00131 }
00132
00133 void SourceFunctionFixTime::Update(double *q, double alpha, double dzeta)
00134 {
00135 double tmp;
00136 unsigned i, k;
00137 for(i = 0; i <= N; ++i) {
00138 tmp = 0;
00139 for(k = m1 + 1; k < m2; ++k) {
00140 tmp += vt_scale[k] * q[(N + 1) * k + i];
00141 }
00142 tmp += 0.5 * (vt_scale[m1] * q[(N + 1) * m1 + i] +
00143 vt_scale[m2] * q[(N + 1) * m2 + i]);
00144 vx[i] -= dzeta * (alpha * vx[i] + tau * tmp);
00145 }
00146 }
00147
00148
00149
00150 RightSideNoTime::RightSideNoTime(unsigned _N, const double *_arr)
00151 : N(_N), arr(_arr)
00152 { }
00153
00154 void RightSideNoTime::Set(unsigned _N, const double *_arr)
00155 {
00156 N = _N;
00157 arr = _arr;
00158 }
00159
00160 double RightSideNoTime::operator()(unsigned xi, unsigned ti) const
00161 {
00162 #if HANDLEXCEPTION
00163 if(xi > N) {
00164 throw "assume xi <= N";
00165 }
00166 #endif
00167 return arr[xi];
00168 }
00169
00170 RightSideGeneral::RightSideGeneral(unsigned _N, unsigned _M, const double *_arr)
00171 : N(_N), M(_M), arr(_arr)
00172 { }
00173
00174 void RightSideGeneral::Set(unsigned _N, unsigned _M, const double *_arr)
00175 {
00176 N = _N;
00177 M = _M;
00178 arr = _arr;
00179 }
00180
00181 double RightSideGeneral::operator()(unsigned xi, unsigned ti) const
00182 {
00183 #if HANDLEXCEPTION
00184 if((xi > N) || (ti > M)) {
00185 throw "assume xi <= N and ti <= M";
00186 }
00187 #endif
00188 return arr[(N + 1) * ti + xi];
00189 }
00190
00191 RightSideSourceFunction::RightSideSourceFunction(const SourceFunction& _v)
00192 : v(_v)
00193 { }
00194
00195 double RightSideSourceFunction::operator()(unsigned xi, unsigned ti) const
00196 {
00197 return v(xi, ti);
00198 }
00199
00200
00201
00202
00203
00204 StopCriteria::StopCriteria(unsigned max_iter, bool calc_func)
00205 : type(calc_func ? stop_max_iters_calc_func : stop_max_iters),
00206 cur_iter(0), func_old(0), func_new(1)
00207 {
00208 param.max_iter = max_iter;
00209 }
00210
00211 StopCriteria::StopCriteria(double tol)
00212 : type(stop_func), cur_iter(0), func_old(0), func_new(1)
00213 {
00214 param.tol = tol;
00215 }
00216
00217 void StopCriteria::Reset()
00218 {
00219 cur_iter = 0;
00220 func_old = 0;
00221 func_new = 1;
00222 }
00223
00224 StopCriteria::stop_type StopCriteria::GetType() const
00225 {
00226 switch(type) {
00227 case stop_max_iters:
00228 return stop_max_iters;
00229 case stop_max_iters_calc_func:
00230 return stop_max_iters_calc_func;
00231 case stop_func:
00232 return stop_func;
00233 default:
00234
00235 return stop_func;
00236 }
00237 }
00238
00239 bool StopCriteria::NeedCalcFunc() const
00240 {
00241 return type != stop_max_iters;
00242 }
00243
00244 bool StopCriteria::IsStop() const
00245 {
00246 if(type == stop_func) {
00247 return fabs(func_old - func_new) < param.tol * func_old;
00248 } else {
00249 return cur_iter >= param.max_iter;
00250 }
00251 }
00252
00253 void StopCriteria::CalcFunc(unsigned N, double h, const double *qT,
00254 double alpha, double square_norm_v)
00255 {
00256 unsigned i;
00257 double square_norm_qT = 0;
00258 for(i = 0; i <= N; ++i) {
00259 square_norm_qT += qT[i] * qT[i];
00260 }
00261 square_norm_qT = h * (square_norm_qT - 0.5 * (qT[0] + qT[N]));
00262 func_old = func_new;
00263 func_new = alpha * square_norm_v + square_norm_qT;
00264 }
00265
00266
00267
00268
00269
00270
00271
00272 void SmoluchowskiSolver::forward_problem(SmoluchowskiOperator& A,
00273 double T, unsigned M, const RightSide& rightside, double *res)
00274 {
00275 unsigned i, k, N;
00276 double tau = T / double(M);
00277 N = A.GetSmolCalc().GetN();
00278 for(k = 0; k < M; ++k) {
00279 A.Apply(res + (N + 1) * k, res + (N + 1) * (k + 1));
00280 for(i = 0; i <= N; ++i) {
00281 res[(N + 1) * (k + 1) + i] = res[(N + 1) * k + i] +
00282 tau * (rightside(i, k) -
00283 res[(N + 1) * (k + 1) + i]);
00284 }
00285 }
00286 }
00287
00288
00289
00290
00291 void SmoluchowskiSolver::adjoint_problem(SmoluchowskiLinearOperator& A,
00292 double T, unsigned M, double *res)
00293 {
00294 unsigned i, k, N;
00295 double tau = T / double(M);
00296 N = A.GetSmolCalc().GetN();
00297 for(k = M; k > 0; --k) {
00298 A.ApplyAdjoint(res + (N + 1) * k, res + (N + 1) * (k - 1));
00299 for(i = 0; i <= N; ++i) {
00300 res[(N + 1) * (k - 1) + i] = res[(N + 1) * k + i] -
00301 tau * res[(N + 1) * (k - 1) + i];
00302 }
00303 }
00304 }
00305
00306
00307 SmoluchowskiSolver::SmoluchowskiSolver(SmoluchowskiLinearOperator& _A)
00308 : A(_A)
00309 {
00310 unsigned N = A.GetSmolCalc().GetN();
00311 h_obs = new double[N + 1];
00312 calc_s();
00313 }
00314
00315 SmoluchowskiSolver::~SmoluchowskiSolver()
00316 {
00317 delete [] s;
00318 delete [] h_obs;
00319 }
00320
00321 void SmoluchowskiSolver::calc_s()
00322 {
00323 unsigned i, N;
00324 const double *a, *c0;
00325 SmoluchowskiCalc& smol_calc = A.GetSmolCalc();
00326 a = A.Get_a();
00327 c0 = A.Get_c0();
00328 N = smol_calc.GetN();
00329 s = new double[N + 1];
00330 smol_calc.calc_L4(c0, s);
00331 smol_calc.calc_L1_fix_f(c0, h_obs);
00332 for(i = 0; i <= N; ++i) {
00333 s[i] += 0.5 * h_obs[i] - a[i] * c0[i];
00334 }
00335 }
00336
00337 const double *SmoluchowskiSolver::Get_s() const
00338 {
00339 return s;
00340 }
00341
00342 unsigned SmoluchowskiSolver::GetN() const
00343 {
00344 return A.GetSmolCalc().GetN();
00345 }
00346
00347 void SmoluchowskiSolver::inverse_problem(SourceFunction& v,
00348 const double *g_obs, double alpha, double dzeta,
00349 StopCriteria& stop_criteria, Callback* callback)
00350 {
00351 unsigned i, N, M;
00352 double T, h, *tmp_res;
00353 N = v.GetN();
00354 M = v.GetM();
00355 T = v.Get_tau() * M;
00356 h = A.Get_h();
00357 tmp_res = new double[(N + 1) * (M + 1)];
00358 for(i = 0; i <= N; ++i) {
00359 tmp_res[i] = 0;
00360 }
00361
00362 forward_problem(A, T, M, RightSideNoTime(N, s), tmp_res);
00363
00364 for(i = 0; i <= N; ++i) {
00365 h_obs[i] = g_obs[i] - tmp_res[(N + 1) * M + i];
00366 }
00367 do {
00368
00369 for(i = 0; i <= N; ++i) {
00370 tmp_res[i] = 0;
00371 }
00372 forward_problem(A, T, M, RightSideSourceFunction(v), tmp_res);
00373
00374 for(i = 0; i <= N; ++i) {
00375 tmp_res[(N + 1) * M + i] -= h_obs[i];
00376 }
00377 if(stop_criteria.NeedCalcFunc()) {
00378 stop_criteria.CalcFunc(N, h, tmp_res + (N + 1) * M,
00379 alpha, v.SquareNorm());
00380 }
00381 adjoint_problem(A, T, M, tmp_res);
00382
00383 v.Update(tmp_res, alpha, dzeta);
00384 if(callback) {
00385 callback->InverseProblemIter(stop_criteria);
00386 }
00387 stop_criteria.UpdateCurIter();
00388 } while (!stop_criteria.IsStop());
00389 delete [] tmp_res;
00390 }