Main Page | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Class Members | File Members

solver.cpp

Go to the documentation of this file.
#include "solver.hpp"
#include "math.h"

#include <vector>
00003 
00004 /*===========SourceFunction===============*/
00005 
00006 
00007 SourceFunction::SourceFunction(unsigned _N, double _H, 
00008         unsigned _M, double _T, unsigned _m1, unsigned _m2)
00009     : N(_N), M(_M), m1(_m1), m2(_m2), h(_H / double(_N)), tau(_T / double(_M))
00010 { }
00011 
00012 SourceFunctionGeneral::SourceFunctionGeneral(unsigned _N, double _H, 
00013     unsigned _M, double _T, unsigned _m1, unsigned _m2, double *_v)
00014     : SourceFunction(_N, _H, _M, _T, _m1, _m2)
00015 {
00016     unsigned i, k;
00017     v = new double[(N + 1) * (M + 1)];
00018     for(k = 0; k <= M; ++k) {
00019         for(i = 0; i <= N; ++i) {
00020             if((k < m1) || (m2 < k)) {
00021                 v[(N + 1) * k + i] = 0;
00022             } else {
00023                 v[(N + 1) * k + i] = _v[(N + 1) * k + i];
00024             }
00025         }
00026     }
00027 }
00028 
00029 SourceFunctionGeneral::~SourceFunctionGeneral()
00030 {
00031     delete [] v;
00032 }
00033 
00034 double SourceFunctionGeneral::operator()(unsigned xi, unsigned ti) const
00035 {
00036 #if HANDLEXCEPTION
00037     if((xi > N) || (ti > M)) {
00038         throw "assume xi <= N and ti <= M";
00039     }
00040 #endif
00041     return v[(N + 1) * ti + xi];
00042 }
00043 
00044 double SourceFunctionGeneral::SquareNorm() const
00045 {
00046     unsigned i, k;
00047     double res, tmp;
00048     res = 0;
00049     for(k = m1; k <= m2; ++k) {
00050         tmp = 0;
00051         for(i = 0; i <= N; ++i) {
00052             tmp += v[(N + 1) * k + i] * v[(N + 1) * k + i];
00053         }
00054         tmp -= 0.5 * (v[(N + 1) * k] + v[(N + 1) * k + N]);
00055         if((k == m1) || (k == m2)) {
00056             tmp *= 0.5;
00057         }
00058         res += tmp;
00059     }
00060     return tau * h * res;
00061 }
00062 
00063 void SourceFunctionGeneral::Update(double *q, double alpha, double dzeta)
00064 {
00065     unsigned i, k;
00066     for(k = m1; k <= m2; ++k) {
00067         for(i = 0; i <= N; ++i) {
00068             v[(N + 1) * k + i] -= dzeta * (alpha * 
00069                 v[(N + 1) * k + i] + q[(N + 1) * k + i]);
00070         }
00071     }
00072 }
00073 
00074 SourceFunctionFixTime::SourceFunctionFixTime(unsigned _N, double _H, 
00075         unsigned _M, double _T, unsigned _m1, unsigned _m2,
00076         double *_vt, double *_vx)
00077     : SourceFunction(_N, _H, _M, _T, _m1, _m2)
00078 {
00079     unsigned i;
00080     vx = new double[N + 1];
00081     vt = new double[M + 1];
00082     vt_scale = new double[M + 1];
00083     for(i = 0; i <= N; ++i) {
00084         vx[i] = _vx[i];
00085     }
00086     square_norm_vt = 0;
00087     for(i = m1; i <= m2; ++i) {
00088         square_norm_vt += _vt[i] * _vt[i];
00089     }
00090     square_norm_vt = tau * (square_norm_vt -
00091             0.5 * (_vt[m1] * _vt[m1] + _vt[m2] * _vt[m2]));
00092     for(i = 0; i <= M; ++i) {
00093         if((i < m1) || (m2 < i)) {
00094             vt[i] = 0;
00095             vt_scale[i] = 0;
00096         } else {
00097             vt[i] = _vt[i];
00098             vt_scale[i] = _vt[i] / square_norm_vt;
00099         }
00100     }
00101 }
00102 
00103 SourceFunctionFixTime::~SourceFunctionFixTime()
00104 {
00105     delete [] vx;
00106     delete [] vt;
00107     delete [] vt_scale;
00108 }
00109 
00110 double SourceFunctionFixTime::operator()(unsigned xi, unsigned ti) const
00111 {
00112 #if HANDLEXCEPTION
00113     if((xi > N) || (ti > M)) {
00114         throw "assume xi <= N and ti <= M";
00115     }
00116 #endif
00117     return vx[xi] * vt[ti];
00118 }
00119 
00120 double SourceFunctionFixTime::SquareNorm() const
00121 {
00122     unsigned i;
00123     double square_norm_vx;
00124     square_norm_vx = 0;
00125     for(i = 0; i <= N; ++i) {
00126         square_norm_vx += vx[i] * vx[i];
00127     }
00128     square_norm_vx = h * (square_norm_vx -
00129             0.5 * (vx[0] * vx[0] + vx[N] * vx[N]));
00130     return square_norm_vx * square_norm_vt;
00131 }
00132 
00133 void SourceFunctionFixTime::Update(double *q, double alpha, double dzeta)
00134 {
00135     double tmp;
00136     unsigned i, k;
00137     for(i = 0; i <= N; ++i) {
00138         tmp = 0;
00139         for(k = m1 + 1; k < m2; ++k) {
00140             tmp += vt_scale[k] * q[(N + 1) * k + i];
00141         }
00142         tmp += 0.5 * (vt_scale[m1] * q[(N + 1) * m1 + i] +
00143                 vt_scale[m2] * q[(N + 1) * m2 + i]);
00144         vx[i] -= dzeta * (alpha * vx[i] + tau * tmp);
00145     }
00146 }
00147 
00148 /*===========RightSide===============*/
00149 
00150 RightSideNoTime::RightSideNoTime(unsigned _N, const double *_arr)
00151     : N(_N), arr(_arr)
00152 { }
00153 
00154 void RightSideNoTime::Set(unsigned _N, const double *_arr)
00155 {
00156     N = _N;
00157     arr = _arr;
00158 }
00159 
00160 double RightSideNoTime::operator()(unsigned xi, unsigned ti) const
00161 {
00162 #if HANDLEXCEPTION
00163     if(xi > N) {
00164         throw "assume xi <= N";
00165     }
00166 #endif
00167     return arr[xi];
00168 }
00169 
00170 RightSideGeneral::RightSideGeneral(unsigned _N, unsigned _M, const double *_arr)
00171     : N(_N), M(_M), arr(_arr)
00172 { }
00173 
00174 void RightSideGeneral::Set(unsigned _N, unsigned _M, const double *_arr)
00175 {
00176     N = _N;
00177     M = _M;
00178     arr = _arr;
00179 }
00180 
00181 double RightSideGeneral::operator()(unsigned xi, unsigned ti) const
00182 {
00183 #if HANDLEXCEPTION
00184     if((xi > N) || (ti > M)) {
00185         throw "assume xi <= N and ti <= M";
00186     }
00187 #endif
00188     return arr[(N + 1) * ti + xi];
00189 }
00190 
00191 RightSideSourceFunction::RightSideSourceFunction(const SourceFunction& _v)
00192     : v(_v)
00193 { }
00194 
00195 double RightSideSourceFunction::operator()(unsigned xi, unsigned ti) const
00196 {
00197     return v(xi, ti);
00198 }
00199 
00200 /*===========StopCriteria===============*/
00201 
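// TODO: func_new should not be initialized to 1. After the first CalcFunc()
// this placeholder becomes func_old, so the first relative-change test in
// IsStop() compares against it instead of a real functional value.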
00203 
00204 StopCriteria::StopCriteria(unsigned max_iter, bool calc_func)
00205     : type(calc_func ? stop_max_iters_calc_func : stop_max_iters),
00206     cur_iter(0), func_old(0), func_new(1)
00207 {
00208     param.max_iter = max_iter;
00209 }
00210 
00211 StopCriteria::StopCriteria(double tol)
00212     : type(stop_func), cur_iter(0), func_old(0), func_new(1)
00213 {
00214     param.tol = tol;
00215 }
00216 
00217 void StopCriteria::Reset()
00218 {
00219     cur_iter = 0;
00220     func_old = 0;
00221     func_new = 1;
00222 }
00223 
00224 StopCriteria::stop_type StopCriteria::GetType() const
00225 {
00226     switch(type) {
00227     case stop_max_iters:
00228         return stop_max_iters;
00229     case stop_max_iters_calc_func:
00230         return stop_max_iters_calc_func;
00231     case stop_func:
00232         return stop_func;
00233     default:
00234         /* imposible */
00235         return stop_func;
00236     }
00237 }
00238 
00239 bool StopCriteria::NeedCalcFunc() const
00240 {
00241     return type != stop_max_iters;
00242 }
00243 
00244 bool StopCriteria::IsStop() const
00245 {
00246     if(type == stop_func) {
00247         return fabs(func_old - func_new) < param.tol * func_old;
00248     } else {
00249         return cur_iter >= param.max_iter;
00250     }
00251 }
00252 
00253 void StopCriteria::CalcFunc(unsigned N, double h, const double *qT,
00254         double alpha, double square_norm_v)
00255 {
00256     unsigned i;
00257     double square_norm_qT = 0;
00258     for(i = 0; i <= N; ++i) {
00259         square_norm_qT += qT[i] * qT[i];
00260     }
00261     square_norm_qT = h * (square_norm_qT - 0.5 * (qT[0] + qT[N]));
00262     func_old = func_new;
00263     func_new = alpha * square_norm_v + square_norm_qT;
00264 }
00265 
00266 
00267 /*===========SmoluchowskiSolver===============*/
00268 
00269 /* solve eq: dg/dt + Ag = rightside, g(t=0) = res(0 : N + 1) by Euler
00270  * g(t = k) = res((N + 1) * k : (N + 1) * (k + 1))
00271  * res is array with ((N + 1) * (M + 1)) elements  */
00272 void SmoluchowskiSolver::forward_problem(SmoluchowskiOperator& A, 
00273         double T, unsigned M, const RightSide& rightside, double *res)
00274 {
00275     unsigned i, k, N;
00276     double tau = T / double(M);
00277     N = A.GetSmolCalc().GetN();
00278     for(k = 0; k < M; ++k) {
00279         A.Apply(res + (N + 1) * k, res + (N + 1) * (k + 1));
00280         for(i = 0; i <= N; ++i) {
00281             res[(N + 1) * (k + 1) + i] = res[(N + 1) * k + i] + 
00282                 tau * (rightside(i, k) - 
00283                         res[(N + 1) * (k + 1) + i]);
00284         }
00285     }
00286 }
00287 
00288 /* solve eq: -dq/dt + A*q = 0, g(t=T) = res((N + 1) * M : (N + 1) * (M + 1))
00289  * by Euler; q(t = k) = res((N + 1) * k : (N + 1) * (k + 1))
00290  * res is array with ((N + 1) * (M + 1)) elements  */
00291 void SmoluchowskiSolver::adjoint_problem(SmoluchowskiLinearOperator& A, 
00292         double T, unsigned M, double *res)
00293 {
00294     unsigned i, k, N;
00295     double tau = T / double(M);
00296     N = A.GetSmolCalc().GetN();
00297     for(k = M; k > 0; --k) {
00298         A.ApplyAdjoint(res + (N + 1) * k, res + (N + 1) * (k - 1));
00299         for(i = 0; i <= N; ++i) {
00300             res[(N + 1) * (k - 1) + i] = res[(N + 1) * k + i] -
00301                 tau * res[(N + 1) * (k - 1) + i];
00302         }
00303     }
00304 }
00305 
00306 
00307 SmoluchowskiSolver::SmoluchowskiSolver(SmoluchowskiLinearOperator& _A)
00308     : A(_A)
00309 {
00310     unsigned N = A.GetSmolCalc().GetN();
00311     h_obs = new double[N + 1];
00312     calc_s();
00313 }
00314 
00315 SmoluchowskiSolver::~SmoluchowskiSolver()
00316 {
00317     delete [] s;
00318     delete [] h_obs;
00319 }
00320 
00321 void SmoluchowskiSolver::calc_s()
00322 {
00323     unsigned i, N;
00324     const double *a, *c0;
00325     SmoluchowskiCalc& smol_calc = A.GetSmolCalc();
00326     a = A.Get_a();
00327     c0 = A.Get_c0();
00328     N = smol_calc.GetN();
00329     s = new double[N + 1];
00330     smol_calc.calc_L4(c0, s);
00331     smol_calc.calc_L1_fix_f(c0, h_obs);// A already fix c0; h_obs is tmp
00332     for(i = 0; i <= N; ++i) {
00333         s[i] += 0.5 * h_obs[i] - a[i] * c0[i];
00334     }
00335 }
00336 
00337 const double *SmoluchowskiSolver::Get_s() const
00338 {
00339     return s;
00340 }
00341 
00342 unsigned SmoluchowskiSolver::GetN() const
00343 {
00344     return A.GetSmolCalc().GetN();
00345 }
00346 
00347 void SmoluchowskiSolver::inverse_problem(SourceFunction& v,
00348         const double *g_obs, double alpha, double dzeta,
00349         StopCriteria& stop_criteria, Callback* callback)
00350 {
00351     unsigned i, N, M;
00352     double T, h, *tmp_res;
00353     N = v.GetN();
00354     M = v.GetM();
00355     T = v.Get_tau() * M;
00356     h = A.Get_h();
00357     tmp_res = new double[(N + 1) * (M + 1)];
00358     for(i = 0; i <= N; ++i) {
00359         tmp_res[i] = 0;
00360     }
00361     /* calc g_1 */
00362     forward_problem(A, T, M, RightSideNoTime(N, s), tmp_res);
00363     /* calc h_obs = g_obs - g1(t = T) */
00364     for(i = 0; i <= N; ++i) {
00365         h_obs[i] = g_obs[i] - tmp_res[(N + 1) * M + i];
00366     }
00367     do {
00368         /* calc h */
00369         for(i = 0; i <= N; ++i) {
00370             tmp_res[i] = 0;
00371         }
00372         forward_problem(A, T, M, RightSideSourceFunction(v), tmp_res);
00373         /* calc q */
00374         for(i = 0; i <= N; ++i) {
00375             tmp_res[(N + 1) * M + i] -= h_obs[i];
00376         }
00377         if(stop_criteria.NeedCalcFunc()) {
00378             stop_criteria.CalcFunc(N, h, tmp_res + (N + 1) * M,
00379                     alpha, v.SquareNorm());
00380         }
00381         adjoint_problem(A, T, M, tmp_res);
00382         /* update v */
00383         v.Update(tmp_res, alpha, dzeta);
00384         if(callback) {
00385             callback->InverseProblemIter(stop_criteria);
00386         }
00387         stop_criteria.UpdateCurIter();
00388     } while (!stop_criteria.IsStop());
00389     delete [] tmp_res;
00390 }

Generated on Sun May 25 01:58:04 2025 for SmoluchowskiSolver by Doxygen