Package diffpy :: Package srrietveld :: Module project
[frames] | no frames]

Source Code for Module diffpy.srrietveld.project

   1  ############################################################################## 
   2  # 
   3  # diffpy.srrietveld by DANSE Diffraction group 
   4  #                   Simon J. L. Billinge 
   5  #                   (c) 2009 Trustees of the Columbia University 
   6  #                   in the City of New York.  All rights reserved. 
   7  # 
   8  # File coded by:    Jiwu Liu, Yingrui Shang 
   9  # 
  10  # See AUTHORS.txt for a list of people who contributed. 
  11  # See LICENSE.txt for license information. 
  12  # 
  13  ############################################################################## 
  14   
  15  """This module contains the definition of the Project class. A SrRietveld Project 
  16  object can be initiated from saved refinement project file, and the methods in  
  17  this module can be used to access and manipulate the data file.""" 
  18   
  19  __id__ = "$Id: project.py 6735 2011-08-27 20:05:44Z yshang $" 
  20   
  21   
  22  # Storage/Data 
  23  from diffpy.refinementdata.project import Project as ProjectData 
  24  from diffpy.srrietveld.refinement import Refinement 
  25  from diffpy.srrietveld.pattern import Pattern, ExcludedRegion 
  26  from diffpy.srrietveld.atom import Atom 
  27  from diffpy.refinementdata.dataerrors import DataError 
  28  from diffpy.srrietveld.phase import Phase 
  29  from diffpy.srrietveld.profile import Profile 
  30  from diffpy.refinementdata.refinable import Refinable 
  31  from diffpy.refinementdata.hdf5.objectlist import ObjectList 
  32  from diffpy.refinementdata.hdf5.object import Object 
  33  import diffpy.srrietveld.gui.srrguiglobals as GLOBALS 
  34  from diffpy.srrietveld.exceptions import SrrFileError, SrrIOError 
  35  import os, shutil, stat, zipfile, tempfile, math 
  36  import diffpy.srrietveld.utility as UTILS 
  37  from diffpy.srrietveld.objectinfo import ObjectInfo, setObjectInfo 
  38   
class Project(ProjectData):
    '''
    Project class defines a SrRietveld refinement project. It consists of a
    number of single refinements. A project can call different Rietveld engines
    (GSAS or FullProf) to run its refinements.

    A project object holds the list of refinements, the list of jobs
    (self.jobs) and the plot figures opened from it (self.plots).
    '''
47 - def __init__ (self, path, mode='a', mainframe=None):
48 """ 49 Create a SrRietveld project from saved project data file. 50 51 @type path: file path string 52 @param path: the path of the project data file. If it does not exist, 53 an empty project will be created. 54 @type mode: one character string, possible value r, w, a 55 @param mode: the mode to open the project. 56 - r -- read, 57 - w -- write, 58 - a -- append 59 @type mainframe: the main frame object 60 @param mainframe: the pointer to the main window in the GUI object 61 set to be None in script mode 62 """ 63 ProjectData.__init__(self, path, mode) 64 65 self.mainframe = mainframe 66 self.plots = [] 67 self.jobs = [] 68 for fit in self.objects: 69 setObjectInfo(fit) 70 fit.alive = 0 71 72 # save the path of the hdf5 data file location, used to determine the 73 # relative paths of the raw files 74 self.basepath = os.path.dirname(path) 75 self.fullpath = path 76 77 self.updatePaths() 78 # indication that if the project data need to be saved to the disk 79 self.altered = True 80 81 return
82
83 - def addJob(self, job):
84 """ 85 Add a job to the job list of the project 86 87 @type job: a Job object 88 @param job: a refinement job to be added 89 @return: no return value 90 """ 91 if job not in self.jobs: 92 self.jobs.append(job) 93 94 return
95
96 - def addToZip(self, zipFilePath):
97 """ 98 Archive the project data file, and other associated files, such as the 99 histogram files, instrument files, and incident spectrum files into a 100 zip folder. 101 102 @type zipFilePath: zip file path string 103 @param zipFilePath: the full path to save the zip file 104 @return: no return value 105 """ 106 107 # the file can only be compressed (ZIP_DEFLATED) when the zlib is installed. 108 # TODO: the allowZip64 should be true if the zip file is larger than 2gb 109 # do not see the needs so far 110 try: 111 import zlib 112 zipflg = zipfile.ZIP_DEFLATED 113 except ImportError: 114 zipflg = zipfile.ZIP_STORED 115 116 zpf = zipfile.ZipFile(zipFilePath, 'w', zipflg) 117 118 # create a temp project file, change the file paths 119 tmpf, tmpfp = tempfile.mkstemp() 120 121 self.__export(tmpfp) 122 123 tmpproj = Project(tmpfp) 124 # change the mode of a new file 125 os.chmod(tmpfp, stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH) 126 tmpproj.basepath = os.path.dirname(tmpfp) 127 tmpproj.fullpath = tmpfp 128 # update the relative paths 129 tmpproj.__updateRelativePaths() 130 131 # get the files from teh pattern 132 fitList = tmpproj.listRefinements() 133 arcname = lambda pr, fp: os.path.join(pr, os.path.basename(fp)) 134 135 #store the arc names already saved 136 savedArcNames = [] 137 138 for fitid, fit in enumerate(fitList): 139 patterns = fit.getObject("Pattern") 140 for bid, pt in enumerate(patterns): 141 for param in ["Datafile", "MFIL", "Instrumentfile"]: 142 arcdir = os.path.join(param.lower(), fit.name) 143 144 ds = pt.get(param) 145 146 if ds: 147 dname = ds.name 148 absds = pt.get(dname + '_abspath') 149 # loop over all datasets in the refinement list 150 for index in fit.range(): 151 filepath = fit.findFileFromDataset(ds, index) 152 if filepath: 153 # only write the file when it's not saved before 154 if arcname(arcdir, filepath) not in savedArcNames: 155 zpf.write(filepath, arcname(arcdir, filepath)) 156 savedArcNames.append(arcname(arcdir, 
filepath)) 157 158 elif param == "Datafile": 159 df = fit.getDataFileFromPattern(index, bid) 160 if df: 161 tmpDir = tempfile.mkdtemp() 162 filepath = df.dump(tmpDir) 163 if arcname(param, filepath) not in savedArcNames: 164 zpf.write(filepath, arcname(param, filepath)) 165 savedArcNames.append(arcname(param, filepath)) 166 else: 167 __msg = "The file (%s) for index %s, bank id %s is not found.The project in zip file may not be able to run correctly. " % (param, str(fit.getFlatIndex(index)), str(bid)) 168 UTILS.printWarning(__msg) 169 170 # change the file paths in the copied project 171 if filepath: 172 ds[index] = arcname(arcdir, filepath) 173 if absds: 174 absds[index] = '' 175 # save the change 176 tmpproj.save() 177 # the archived project file name is set to be same to the zip file name 178 projectFileName = os.path.splitext(os.path.basename(zipFilePath))[0] + '.srr' 179 zpf.write(tmpfp, projectFileName) 180 181 zpf.close() 182 183 return
184
185 - def close(self):
186 """ 187 Close the opened project. 188 @return: no return value 189 """ 190 # save the data 191 for plot in self.plots: 192 if plot.alive: 193 plot.close() 194 195 ProjectData.close(self) 196 return
197
198 - def deleteJob(self, job):
199 """ 200 Delete job from the job list 201 202 @type job: a job object 203 @param job: the job to be deleted 204 @return: no return value 205 206 """ 207 try: 208 self.jobs.remove(job) 209 except: 210 if GLOBALS.isDebug: 211 UTILS.printDebugInfo() 212 else: 213 UTILS.printWarning("No such job in the job list") 214 215 return
216
217 - def exportEngineFit(self, refinementObj, index=None):
218 """Build an engine fit based on given fit and index. 219 220 @type refinementObj: a diffpy.refinementdata.Refinement object or its derivatives 221 @param refinementObj: a refinement usually contains a series of single refinements 222 The information will be read from this object to construct 223 a engine fit, which contains information for a single 224 step refinement 225 @type index: tuple or int 226 @param index: the index of the single refinement in the refinementObj 227 a tuple represents the index in a multidimensinal data 228 if the refinementObj is one dimensional, index can be 229 an integer 230 @return: an engine fit object 231 """ 232 enginefit = exportEngineObject(None, None, refinementObj, index) 233 234 if refinementObj.getEngineType() == 'gsas': 235 engineFileStr = str(refinementObj.get('enginefile')[index]) 236 237 238 enginefit.expfile.readString(engineFileStr) 239 240 # connect phase with pattern 241 for i, contribution in enumerate(enginefit.get("Contribution")): 242 mycontribution = refinementObj.getObject("Contribution").getObject(i) 243 patternindex = mycontribution.getAttr('patternindex') 244 phaseindex = mycontribution.getAttr('phaseindex') 245 if patternindex is None: 246 patternindex = 0 247 if phaseindex is None: 248 phaseindex =0 249 250 # if a project is saved in a 64 bit system and loaded in a 32bit system 251 # the int64 type has to converted to an ordinary integer 252 patternindex = int(patternindex) 253 phaseindex = int(phaseindex) 254 255 pattern = enginefit.get("Pattern", patternindex) 256 phase = enginefit.get("Phase", phaseindex) 257 contribution.setParentPattern(pattern) 258 contribution.setParentPhase(phase) 259 260 return enginefit
261
262 - def getRefinementByName(self, name):
263 ''' 264 Get the refinement object by its name 265 266 @type name: string 267 @param name: the refinement name to get 268 269 @return: the refinement object with the name, None if there is no such 270 refinement object. If there are duplicate names, the first 271 refinement object with this name will be returned, and the 272 program with emit a warning 273 ''' 274 275 rv = [None] 276 rv = [obj for obj in self.listRefinements() if obj.name == name] 277 if len(rv) > 1: 278 __msg = 'Duplicate refinement names, the first refinement is returned' 279 UTILS.printWarning(__msg) 280 281 return rv[0]
282
283 - def importDataFiles(self):
284 """ 285 Read the data files and instrument files into the pattern objects 286 287 @return: no return value 288 """ 289 # since the job will run somewhere else, pack all its datafiles. 290 for pattern in self.listObjects(Pattern, True): 291 datafiles = pattern.get('Datafile') 292 instruments = pattern.get('Instrumentfile') 293 uniquefiles = set() 294 for index in pattern.range(): 295 if datafiles and datafiles[index]: 296 uniquefiles.add(datafiles[index]) 297 if instruments and instruments[index]: 298 uniquefiles.add(instruments[index]) 299 for fname in uniquefiles: 300 pattern.importFile(fname) 301 302 return
303
304 - def importEngineFile(self, fullpath, refinementObjName=None):
305 """ 306 Import an engine file (EXP or pcr files). A refinement will be created 307 based on the data in the imported engine file. 308 309 @type fullpath: file path string 310 @param fullpath: the full file path of the engine file to be imported 311 @type refinementObjName: string 312 @param refinementObjName: the proposed name of the refinement object to be created. 313 If the refinementObjName is None, the engine file name will be used 314 @return: no return value 315 """ 316 import os.path 317 318 # if there is already one fit loaded, keep the path of the project, if 319 # this is the first fit loaded, change the path 320 if len(self.listRefinements()) == 0: 321 self.basepath = os.path.dirname(fullpath) 322 323 filename = os.path.basename(fullpath) 324 if not refinementObjName: 325 refinementObjName = os.path.splitext(filename)[0] 326 327 refinementObjName = self.verifyFitName(refinementObjName) 328 329 fileformat = os.path.splitext(filename)[1].lower() 330 if fileformat == ".exp": 331 from diffpy.pygsas.fitloader import loadFitFromEXP 332 enginefit = loadFitFromEXP(fullpath) 333 elif fileformat == ".pcr": 334 from diffpy.pyfullprof.fit import Fit as EngineFit 335 from diffpy.pyfullprof.pcrfilereader import ImportFitFromFullProf 336 importfile = ImportFitFromFullProf(fullpath) 337 enginefit = EngineFit(None) 338 importfile.ImportFile(enginefit) 339 else: 340 raise SrrFileError('Engine file name extension is not supported.' + \ 341 'Only EXP and pcr files are currently supported in SrRietveld.') 342 343 enginefit.fixAll() 344 fit = self.importEngineFit(enginefit, refinementObjName, fullpath) 345 346 self.altered = True 347 348 return fit
349
350 - def importEngineFit(self, enginefit, name, fullpath, index=None):
351 """ 352 Import an engine Fit object. It is provided for backward compatiblity, 353 namely when the fit is created using xml interface. 354 355 @type enginefit: an engine Fit object 356 @param enginefit: the engine fit, which contains all the information for 357 one refinement for that engine (GSAS or FullProf) 358 @type name: string 359 @param name: the name to be assigned to the new Fit object 360 @type fullpath: file path string 361 @param fullpath: the fullpath is used to locate the supplement files 362 (data, instrument, engine files) 363 @type index: 364 @param index: the index to a Fit in a Fit array ( not applicable if shape 365 is given ) 366 367 @return: the new Fit object 368 """ 369 370 # Try to find the files 371 # since all path in engine files are relative to the file itself, 372 # get the dir of the engine file first 373 engineFileDir = os.path.dirname(fullpath) 374 375 # fit is the storage of the new Fit. 376 fit = importEngineObject(self, enginefit, name, index=index) 377 378 # save the abs path in patterns 379 patterns = fit.getObject('Pattern') 380 for bankid, pattern in enumerate(patterns): 381 382 absPath = {} 383 for param in ['Datafile', 'MFIL', 'Instrumentfile']: 384 absPath[param] = None 385 ds = pattern.get(param) 386 if ds is None: 387 continue 388 p = str(pattern.get(param).first()) 389 if not p:# this parameter is not set 390 __message = 'The %s param is not set in the engine' % (param) 391 print __message 392 absPath[param] = "" 393 else: 394 absPath[param] = self.__getResourceFileAbsPath(p, engineFileDir) 395 # print warnings if file does not exist 396 if not os.path.exists(absPath[param]): # The file does not exist 397 398 paramDesc = '' 399 if param == 'Datafile': 400 paramDesc = 'data file' 401 elif param == 'MFIL': 402 paramDesc = 'incident spectrum' 403 elif param == 'Instrumentfile': 404 paramDesc = 'instrument file' 405 406 __warning = 'Can not find the ' + paramDesc + \ 407 ' for pattern %s, ' % (bankid + 1) + \ 408 'you 
may need to reload the file or input the right file path. ' 409 410 from diffpy.srrietveld.utility import printWarning 411 printWarning(__warning) 412 413 if absPath['Datafile'] is not None: 414 fit.loadDataFile(absPath['Datafile'], index = 0, bankid = bankid) 415 if absPath['Instrumentfile'] is not None: 416 fit.loadInstrumentFile(absPath['Instrumentfile'], index = 0, bankid = bankid) 417 if absPath['MFIL'] is not None: 418 fit.loadIncidentSpectrumFile(absPath['MFIL'], index = 0, bankid = bankid) 419 420 fit.saveEngineFile(fullpath) 421 #fit.reshape((1, ), ['Index', ], repeat=True) 422 fit.reshape((1, ), ['Index', ], repeat=True) 423 fit.set('Index', [1]) 424 fit.alive = 0 425 426 return fit
427
428 - def importProject(self, projPath):
429 """ 430 Import the refinements from another project file. 431 @type projPath: file path string 432 @param projPath: the file path of the project to be imported. 433 @return: no return value 434 """ 435 tmpfile, tmppath = tempfile.mkstemp() 436 # has to check the name of the refinement not the same to the 437 # existing refinement, so needs to copy the refinements to another 438 # place before copy 439 shutil.copy(projPath, tmppath) 440 proj = Project(tmppath) 441 442 for fit in proj.listRefinements(): 443 fit.rename(self.verifyFitName(fit.name)) 444 445 self.add(proj) 446 447 return
448
449 - def listParams(self, obj, recursively=False, excluded=[]):
450 """ 451 List all parameters in an object. 452 453 454 @type obj: a data object 455 @param obj: the object whose parameters to be listed 456 @type recursively: boolean 457 @param recursively: if True, the parameters in the sub-objects will 458 be included. False otherwise 459 @type excluded: a list of objects, whose child parameters should be 460 excluded in the result 461 @return: a list of parameters (Dataset object). 462 """ 463 params = [] 464 for name in obj.info.listParams(): 465 # to all the datasets 466 467 params.append(obj.get(name)) 468 469 # subobjects 470 if recursively: 471 for childobj in obj.listObjects(): 472 if childobj.name not in excluded: 473 # continue 474 #for grandchildobj in childobj.listObjects(Refinable, True): 475 params.extend(self.listParams(childobj, True)) 476 477 return params
478
479 - def listRefinementNames(self):
480 """Different from the methdo listRefinemetns, this function will return 481 a list of names of the refinements 482 @return: the names of the refinements in this project""" 483 484 return [refObj.name for refObj in self.listRefinements()]
485
486 - def plot(self, ydatasets, xdataset=None, xlabel=''):
487 """ 488 Make a plot of selected Dataset objects 489 @type ydatasets: a list of Dataset objects 490 @param ydatasets: contains the list of y values in the plot 491 @type xdataset: a Dataset object 492 @param xdataset: the x values in the plot. If xdataset is None 493 a list of consecutive integers will be used 494 @return: the plot figure object 495 """ 496 497 from diffpy.refinementdata.plot.figure import Figure 498 name = ','.join([dataset.name for dataset in ydatasets]) 499 figure = Figure(self.mainframe, name) 500 if xlabel: 501 figure.selection.xlabel = xlabel 502 503 # prepare the metadata, using all refinement metadata 504 datasets = ydatasets[:] 505 if xdataset: 506 datasets.append(xdataset) 507 fits = set() 508 for dataset in datasets: 509 fits.add(dataset.owner.findRefinement()) 510 511 metadata = [] 512 for fit in fits: 513 metadata.extend(fit.getMetaData()) 514 515 figure.plot(xdataset, ydatasets, metadata=metadata) 516 self.plots.append(figure) 517 figure.frame.updateIndexBox() 518 return figure
519
520 - def plotHistory(self, datasets):
521 """ 522 Make a plot of the historic values of selected parameters. 523 524 @type datasets: a list of Dataset objects 525 @param datasets: the data to be plotted 526 @return: no return value 527 """ 528 histories = [] 529 for dataset in datasets: 530 history = dataset.owner.getHistory(dataset.name) 531 if history is not None: 532 histories.append(history) 533 534 self.plot(histories, xlabel='step') 535 return
536
537 - def plotPatterns(self, objects):
538 """Make a quick plot of patterns. 539 540 objects -- a list of selected objects 541 """ 542 from diffpy.refinementdata.plot.patternfigure import PatternFigure 543 def _plotPattern(_pattern): 544 _figure = PatternFigure(self.mainframe, _pattern.path, figsize=(6, 4)) 545 _figure.selection.addMetadata(_pattern.findRefinement().getMetaData()) 546 _figure.plotPattern(_pattern) 547 _figure.frame.updateIndexBox() 548 self.plots.append(_figure)
549 550 551 for object in objects: 552 if isinstance(object, Pattern): 553 _plotPattern(object) 554 else: 555 patterns = object.listObjects(Pattern, True) 556 if patterns: 557 for pattern in patterns: 558 _plotPattern(pattern) 559 return
560
561 - def updateFit(self, fitrt):
562 """ 563 Update the fit data with an fit object. 564 565 @param fitrt: the run-time fit instance 566 @return: no return value 567 """ 568 # change steps. 569 fitrt.fit.steps[fitrt.index] = fitrt.step+1 570 571 def _updateValue(parpath): 572 # get the latest value from the refinement 573 _value = fitrt.enginefit.getByPath(parpath) 574 575 # get the historical storge of the parameter 576 _history = fitrt.fit.getHistoryByPath(parpath) 577 if _history is None: 578 _history = fitrt.fit.addHistoryByPath(parpath) 579 580 _history.update(fitrt.step, _value, fitrt.index) 581 582 # data is the up-to-date storage 583 _param = fitrt.fit.getByPath(parpath) 584 _param[fitrt.index] = _value 585 586 return 587 588 def _updateSigma(constraint): 589 # get the sigma, using the constraint's parname without index 590 _path = constraint.owner.path+'.'+constraint.parname 591 _sigma = fitrt.fit.getSigmaByPath(_path) 592 593 if _sigma is None: 594 _sigma = fitrt.fit.addSigmaByPath(_path) 595 596 if constraint.index is not None: 597 if fitrt.index is None: 598 _index = constraint.index 599 else: 600 if isinstance(fitrt.index, tuple): 601 _index = list(fitrt.index) 602 else: 603 _index = [fitrt.index,] 604 _index.append(constraint.index) 605 _index = tuple(_index) 606 _sigma[_index] = constraint.sigma 607 else: 608 _sigma[fitrt.index] = constraint.sigma 609 610 return 611 612 def _tolist(refl): 613 # put the reflection positiosn to a list of positions 614 _value = [] 615 for key in refl: 616 for dd in refl[key]: 617 for val in dd.values(): 618 _value.append(val[0]) 619 #_value = [x[0] for x in refl[0][0].values()] 620 _value.sort() 621 return _value 622 623 # get all constrained parameter paths, but count paramlist only once. 
624 pathlist = [ c.owner.path+'.'+c.parname for c in fitrt.enginefit.Refine.constraints] 625 pathlist = set(pathlist) 626 627 for path in pathlist: 628 _updateValue(path) 629 # update Uiso in fullprof or Biso in GSAS refinemetns 630 if path.endswith('Uiso') or path.endswith('Biso'): 631 if path.endswith('Uiso'): 632 _path = path.rsplit('.', 1)[0] + '.Biso' 633 elif path.endswith('Biso'): 634 _path = path.rsplit('.', 1)[0] + '.Uiso' 635 _updateValue(_path) 636 637 for constraint in fitrt.enginefit.Refine.constraints: 638 _updateSigma(constraint) 639 # update Uiso in fullprof or Biso in GSAS refinemetns 640 if constraint.parname in ['Biso', 'Uiso']: 641 if constraint.parname == 'Uiso': 642 _parname = 'Biso' 643 elif constraint.parname == 'Biso': 644 _parname = 'Uiso' 645 646 _path = constraint.owner.path + '.' + _parname 647 _sigma = fitrt.fit.getSigmaByPath(_path) 648 649 if _sigma is None: 650 _sigma = fitrt.fit.addSigmaByPath(_path) 651 652 if _parname == 'Biso': 653 _val = constraint.sigma * 8 * math.pow(math.pi, 2) 654 elif _parname == 'Uiso': 655 _val = constraint.sigma / ( 8 * math.pow(math.pi, 2)) 656 657 if constraint.index is not None: 658 if fitrt.index is None: 659 _index = constraint.index 660 else: 661 if isinstance(fitrt.index, tuple): 662 _index = list(fitrt.index) 663 else: 664 _index = [fitrt.index,] 665 _index.append(constraint.index) 666 _index = tuple(_index) 667 668 _sigma[_index] = _val 669 else: 670 _sigma[fitrt.index] = _val 671 672 673 _updateValue('Chi2') 674 675 for pattern in fitrt.enginefit.Pattern.get(): 676 _updateValue(pattern.path+'.Rp') 677 _updateValue(pattern.path+'.Rwp') 678 679 # update the histogram value saved in pattern 680 for dname in ['xobs', 'yobs', 'ycal', 'refl']: 681 mypattern = fitrt.fit.getByPath(pattern.path) 682 dataset = mypattern.get(dname) 683 if dname == 'refl': 684 data = _tolist(getattr(pattern, '_reflections')) 685 else: 686 data = getattr(pattern, '_'+dname) 687 if dataset is None: 688 mypattern.set(dname, 
data, repeat=True, grow=True) 689 else: 690 dataset[fitrt.index] = data 691 692 693 if self.mainframe: 694 from diffpy.refinementdata.plot import backend 695 backend.lock() 696 if fitrt.step > 0: 697 self.mainframe.updateStrategyPanel(fitrt) 698 self.updatePlot(fitrt.index) 699 backend.unlock() 700 return 701
702 - def updatePaths(self):
703 """ 704 Update the histogram and instrument file paths saved in the project 705 @return: no return value 706 """ 707 for fit in self.listRefinements(): 708 patterns = fit.getObject("Pattern") 709 for pt in patterns: 710 for param in ['Datafile', 'MFIL', 'Instrumentfile']: 711 ds = pt.get(param) 712 if ds: 713 for idx in fit.range(): 714 fit.findFileFromDataset(ds, idx) 715 return
716
717 - def updatePlot(self, index=None):
718 """ 719 Update all plots based the data in the project 720 721 @type index: tuple or int 722 @param index: the index of the curve in the plot 723 @return: no return value 724 """ 725 for plot in self.plots[:]: 726 if plot.alive: 727 plot.update(index) 728 else: 729 self.plots.remove(plot) 730 731 return
732
733 - def saveToDisk(self, fullpath):
734 """ 735 Save the project data to local drive. 736 737 @type fullpath: file path string 738 @param fullpath: the file path to save the project data 739 @return: no return value 740 """ 741 self.save() 742 try: 743 744 # if file names are same, remove the original one first 745 srcpath = self.fullpath 746 747 if self.fullpath == fullpath: 748 tmpfobj, tmpfpath = tempfile.mkstemp() 749 shutil.copy2(self.fullpath, tmpfpath) 750 os.remove(fullpath) 751 srcpath = tmpfpath 752 753 shutil.copy2(srcpath, fullpath) 754 # change the mode of a new file 755 os.chmod(fullpath, stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH) 756 self.basepath = os.path.dirname(fullpath) 757 self.fullpath = fullpath 758 self.__updateRelativePaths() 759 self.save() 760 761 except Exception,e: 762 if GLOBALS.isDebug: 763 import sys, traceback 764 exc_type, exc_value, exc_traceback = sys.exc_info() 765 traceback.print_exception(exc_type, exc_value, exc_traceback, 766 limit=2, file=sys.stdout) 767 else: 768 raise SrrIOError("Can not save file: " + str(e)) 769 770 return
771
772 - def verifyFitName(self, fitName):
773 """ 774 Compare the fitname with existing names of refienments in the project. 775 A valid name will be returned. 776 777 - All the I{dots (.)} will be replaced with I{underscore (_)} 778 - The same refinement names will be surfixed with numbers with brakets 779 such as I{Refinement}, I{Refinement (1)}, I{Refinement (2)} 780 @type fitName: string 781 @param fitName: refinement name to be verified 782 783 @return: If fitName is valid, return value 784 will be the same as fitName; if fitName conains invalid 785 characters or has duplications, a valid refinement name will 786 be returned 787 """ 788 nameid = 0 789 fitName = fitName.replace('.', '_') 790 for fit in self.listRefinements(): 791 if fit.name == fitName: 792 nameid += 1 793 if nameid != 0: 794 newname = fitName + "_%s"%(nameid) 795 else: 796 newname = fitName 797 798 return newname
799
800 - def __export(self, filePath):
801 """ 802 Export the project to another location, without changing the original file 803 This is different from saving as, since the original state, basepath, 804 etc. of the project is not changed. The original file will be copied to 805 a temp dir, and then the project will be saved to the export directory. 806 Then the protected old file will be saved back 807 808 @type filePath: file path string 809 @param filePath: the file path where the project will be exported 810 @return: no return value 811 """ 812 813 oldPath = self.fullpath 814 oldBase = self.basepath 815 oldAltered = self.altered 816 tmpfile, tmppath = tempfile.mkstemp() 817 818 shutil.copy(self.fullpath, tmppath) 819 820 self.saveToDisk(filePath) 821 822 shutil.copy(tmppath, oldPath) 823 try: 824 os.remove(tmppath) 825 except: 826 UTILS.printWarning("Can not delete the temp project file after archive the project") 827 828 self.fullpath = oldPath 829 self.basepath = oldBase 830 self.altered = oldAltered 831 self.__updateRelativePaths() 832 833 return
834
835 - def __updateRelativePaths(self, basepath = None):
836 """ 837 As long as the base path is changed, the relative paths saved in 838 patterns need to be updated. 839 840 @type basepath: folder path string 841 @param basepath: the basepath used to determine the relative path, use 842 project basepath if None 843 @return: no return value 844 845 """ 846 if basepath is None: basepath = self.basepath 847 848 fits = self.listRefinements() 849 for fit in fits: 850 patterns = fit.getObject("Pattern") 851 for pt in patterns: 852 for param in ['Datafile', 'MFIL', 'Instrumentfile']: 853 ds = pt.get(param) 854 if ds is not None: 855 for idx in fit.range(): 856 filepath = fit.findFileFromDataset(ds, idx) 857 if filepath is not None: 858 fit.saveFilePathToDataset(filepath, ds, 859 index = idx, 860 basepath = basepath) 861 862 return
863
864 - def __getResourceFileAbsPath(self, relFilePath, engineFileDir):
865 """ 866 Get the absolute file path. The filePath may be relative, to the engine file 867 directory, which may also be different from the current working dir 868 869 @type relFilePath: file path string 870 @param relFilePath: relative file path 871 @type engineFileDir: directory path string 872 @param engineFileDir: the directory where the engine file sits 873 @return: the absolute file path 874 """ 875 absFilePath = None 876 if not os.path.isabs(relFilePath): 877 absFilePath = os.path.join(engineFileDir, relFilePath) 878 else: 879 absFilePath = relFilePath 880 return absFilePath
881 882 # end class Project 883 884 ## Import/Export an engine object. 885 # The following functions Define the translation from Rietveld engine classes 886 # to RefinementData classes 887
def importEngineObject(myobjowner, engineobj, id, index=None):
    """
    Import an engine object, including its parameters and (recursively) its
    child objects, into the refinement data owned by myobjowner.

    @type myobjowner: a data object
    @param myobjowner: the owner of the refinement object to be created
    @type engineobj: an engine object
    @param engineobj: the engine object to be imported
    @type id: integer or string
    @param id: the name/index of the refinement object to be created/obtained
               in myobjowner. For an ObjectList owner it must equal the
               current length of the list.
    @type index: integer or tuple
    @param index: when modifying an existing object, the index of the values
                  to be modified
    @return: the created or updated local object, or None if engineobj is a
             Refine object (those are not imported)
    """
    engineclass = engineobj.__class__
    engineclassname = engineclass.__name__

    # Do not import refine and its children
    if engineclassname.startswith('Refine'):
        return None

    myobj = myobjowner.getObject(id)
    if myobj is None:
        # get the corresponding HDF5 storage class
        if engineclassname.startswith('Fit'):
            myobjclass = Refinement
        elif engineclassname.startswith('Pattern'):
            myobjclass = Pattern
        elif engineclassname.startswith('Phase'):
            myobjclass = Phase
        elif engineclassname.startswith('Profile'):
            myobjclass = Profile
        elif engineobj.name.startswith('Atom['):
            myobjclass = Atom
        elif engineobj.name.startswith('ExcludedRegion['):
            myobjclass = ExcludedRegion
        else:
            # if nothing matches, use the default storage type
            myobjclass = Refinable

        # create the HDF5 object
        if isinstance(myobjowner, ObjectList):
            # list members are appended in order, so id must be the next slot
            assert id == len(myobjowner)
            myobj = myobjowner.appendObject(myobjclass)
        else:
            myobj = myobjowner.addObject(id, myobjclass)

    # record the engine class ('module.Class') so that exportEngineObject can
    # rebuild it. engineobj's own class is the class a module import would
    # return, so Python-2-only __import__(..., -1) is not needed.
    myobj.setAttr('rietveldcls', engineclass.__module__ + '.' + engineclassname)
    myobj.info = ObjectInfo(engineclass)

    if engineclassname.startswith("Contribution"):
        myobj.setAttr("phaseindex", engineobj.getPhaseIndex())
        myobj.setAttr("patternindex", engineobj.getPatternIndex())

    def _importEngineParameter(infodict):
        for param in infodict:
            dataset = myobj.get(param)
            if dataset is not None:
                # existing parameter: only update the values at index
                dataset[index] = engineobj.get(param)
                continue

            info = infodict[param]
            #NOTE: since there are two engine classes for the same thing, we
            # check the info's class name instead of its type.
            grow = (info.__class__.__name__ == 'StringInfo' and not info.fixlen)
            dataset = myobj.set(param, engineobj.get(param), repeat=True, grow=grow)
            if dataset.name == 'BACK':
                # copy the owner's labels before adding 'Order', so the
                # owner's own label list is not modified in place
                # NOTE(review): assumes myobj.labels may return the stored list
                labels = list(myobj.labels) + ['Order']
                dataset.setLabels(labels)

    # add parameters and parameter lists
    _importEngineParameter(engineobj.ParamDict)
    _importEngineParameter(engineobj.ParamListDict)

    # add sub-objects, each a single object
    for childname in engineobj.ObjectDict:
        importEngineObject(myobj, engineobj.get(childname), childname, index=index)

    # add containers, each a list of objects
    for childname in engineobj.ObjectListDict:
        mycontainer = myobj.getObject(childname)
        if mycontainer is None:
            mycontainer = myobj.addObject(childname, ObjectList)
            mycontainer.info = ObjectInfo()
        enginecontainer = engineobj.get(childname)
        for i in range(len(enginecontainer)):
            importEngineObject(mycontainer, engineobj.get(childname, i), i, index=index)

    return myobj
def exportEngineObject(engineobjowner, key, myobj, index=None):
    """
    Export a saved HDF5Object to an engine object.

    The engine class is resolved from the 'rietveldcls' attribute stored on
    myobj (a "module.ClassName" string). It is instantiated with
    engineobjowner as its argument and, when an owner is given, registered on
    it under key. It is then filled with the parameter values stored in
    myobj, and myobj's child objects are exported into it recursively.

    @type engineobjowner: a data object
    @param engineobjowner: the owner of the engine object, or None
    @param key: the key passed to engineobjowner.set() for the new engine object
    @type myobj: a refinement object
    @param myobj: the source refinement object
    @type index: integer or tuple
    @param index: the index to the single value if the object is multiplexed

    @return: the newly created engine object
    @raise DataError: if the stored engine class cannot be resolved, or a
        stored parameter value has more than one dimension
    """
    import importlib
    import numpy

    # Bind before the try block so the error message below never raises
    # UnboundLocalError when getAttr() itself fails with a caught exception.
    engineclassname = None
    try:
        engineclassname = myobj.getAttr('rietveldcls')
        modulename, classname = engineclassname.rsplit('.', 1)
        # The stored module name is fully qualified, so an absolute import is
        # what is wanted; import_module works on Python 2.7 and 3.x alike.
        module = importlib.import_module(modulename)
        engineclass = getattr(module, classname)
    except (ImportError, AttributeError, ValueError):
        # NOTE: if a class is removed after an upgrade, it ends up here.
        raise DataError('Unknown object "%s" found.' % engineclassname)

    engineobj = engineclass(engineobjowner)
    if engineobjowner is not None:
        engineobjowner.set(key, engineobj)

    # Copy the stored parameter values onto the engine object; datasets the
    # engine class does not declare as parameters are ignored.
    for param in myobj.list():
        name = param.name
        if name not in engineobj.ParamDict and name not in engineobj.ParamListDict:
            continue

        if len(param.shape) == 0:
            # unrefined fit: only a single value is stored
            value = param.first()
        else:
            value = param[index]

        if not isinstance(value, numpy.ndarray):
            engineobj.set(name, value)
        elif value.ndim == 0:
            # unwrap a 0-d array into its scalar
            engineobj.set(name, value[()])
        elif value.ndim == 1:
            # pass each element to set() in turn
            for item in value:
                engineobj.set(name, item)
        else:
            raise DataError('Parameter "%s" has more than one dimension' % name)

    # Export child objects; every element of a container (ObjectList) is
    # exported under the container's own name.
    for myobjchild in myobj.objects:
        subkey = myobjchild.name
        if isinstance(myobjchild, ObjectList):
            for obj in myobjchild.objects:
                exportEngineObject(engineobj, subkey, obj, index)
        else:
            exportEngineObject(engineobj, subkey, myobjchild, index)

    return engineobj
1053
def copyObjectData(srcobj, destobj, srcindexlist, destindexlist):
    """
    Recursively copy attributes, metadata and indexed data between objects.

    srcobj -- the object to read from
    destobj -- the object to write into; datasets and child objects missing
               on it are created as needed
    srcindexlist -- indices to read from each source dataset
    destindexlist -- indices to write to in each destination dataset, paired
                     element-wise with srcindexlist

    return: destobj
    """
    # attributes, then metadata
    for attrname, attrvalue in srcobj.listAttrs(withValue=True):
        destobj.setAttr(attrname, attrvalue)
    destobj.copyMeta(srcobj)

    # datasets: create any that are absent on destobj, then copy each
    # (source index -> destination index) pair
    for srcdataset in srcobj.list():
        destdataset = destobj.get(srcdataset.name)
        if destdataset is None:
            destdataset = destobj.replicate(srcdataset, srcindexlist[0], None)
        for fromidx, toidx in zip(srcindexlist, destindexlist):
            destdataset[toidx] = srcdataset[fromidx]

    # child objects
    if isinstance(srcobj, ObjectList):
        # containers pair children by position, appending missing ones
        for position, srcchild in enumerate(srcobj.listObjects()):
            destchild = destobj.getObject(position)
            if destchild is None:
                destchild = destobj.appendObject(srcchild.__class__)
            copyObjectData(srcchild, destchild, srcindexlist, destindexlist)
    else:
        # plain objects pair children by name, adding missing ones
        for srcchild in srcobj.listObjects():
            destchild = destobj.getObject(srcchild.name)
            if destchild is None:
                destchild = destobj.addObject(srcchild.name, srcchild.__class__)
            copyObjectData(srcchild, destchild, srcindexlist, destindexlist)

    # Refinable objects also carry 'history' and 'sigma' companions on both
    # sides; copy those too, history first.
    if isinstance(srcobj, Refinable):
        for companion in ('history', 'sigma'):
            copyObjectData(getattr(srcobj, companion),
                           getattr(destobj, companion),
                           srcindexlist, destindexlist)

    return destobj
1104 # EOF 1105