1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 """This module contains the definition of the Project class. A SrRietveld Project
16 object can be initiated from saved refinement project file, and the methods in
17 this module can be used to access and manipulate the data file."""
18
19 __id__ = "$Id: project.py 6735 2011-08-27 20:05:44Z yshang $"
20
21
22
23 from diffpy.refinementdata.project import Project as ProjectData
24 from diffpy.srrietveld.refinement import Refinement
25 from diffpy.srrietveld.pattern import Pattern, ExcludedRegion
26 from diffpy.srrietveld.atom import Atom
27 from diffpy.refinementdata.dataerrors import DataError
28 from diffpy.srrietveld.phase import Phase
29 from diffpy.srrietveld.profile import Profile
30 from diffpy.refinementdata.refinable import Refinable
31 from diffpy.refinementdata.hdf5.objectlist import ObjectList
32 from diffpy.refinementdata.hdf5.object import Object
33 import diffpy.srrietveld.gui.srrguiglobals as GLOBALS
34 from diffpy.srrietveld.exceptions import SrrFileError, SrrIOError
35 import os, shutil, stat, zipfile, tempfile, math
36 import diffpy.srrietveld.utility as UTILS
37 from diffpy.srrietveld.objectinfo import ObjectInfo, setObjectInfo
38
40 '''
41 Project class defines a SrRietveld refinement project. It consists of a
42 number of single refinements. A project can call different Rietveld engines
43 to run its refinements.
44
45 A project object contains the list of refinements, and the list of jobs.
46 '''
def __init__(self, path, mode='a', mainframe=None):
    """
    Open a SrRietveld project backed by a saved project data file.

    @type path: file path string
    @param path: the path of the project data file. If it does not exist,
                 an empty project will be created.
    @type mode: one character string, possible value r, w, a
    @param mode: the mode used to open the project.
                 - r -- read,
                 - w -- write,
                 - a -- append
    @type mainframe: the main frame object
    @param mainframe: the main window of the GUI; leave it as None when the
                      project is used in script mode
    """
    # The base class opens the data file; self.objects is read right after.
    ProjectData.__init__(self, path, mode)

    self.mainframe = mainframe
    self.plots, self.jobs = [], []

    # Attach object info to every stored refinement and flag it as not running.
    for refinement in self.objects:
        setObjectInfo(refinement)
        refinement.alive = 0

    # Record where the project lives before the stored file paths are refreshed.
    self.basepath = os.path.dirname(path)
    self.fullpath = path
    self.updatePaths()

    self.altered = True
82
84 """
85 Add a job to the job list of the project
86
87 @type job: a Job object
88 @param job: a refinement job to be added
89 @return: no return value
90 """
91 if job not in self.jobs:
92 self.jobs.append(job)
93
94 return
95
97 """
98 Archive the project data file, and other associated files, such as the
99 histogram files, instrument files, and incident spectrum files into a
100 zip folder.
101
102 @type zipFilePath: zip file path string
103 @param zipFilePath: the full path to save the zip file
104 @return: no return value
105 """
106
107
108
109
110 try:
111 import zlib
112 zipflg = zipfile.ZIP_DEFLATED
113 except ImportError:
114 zipflg = zipfile.ZIP_STORED
115
116 zpf = zipfile.ZipFile(zipFilePath, 'w', zipflg)
117
118
119 tmpf, tmpfp = tempfile.mkstemp()
120
121 self.__export(tmpfp)
122
123 tmpproj = Project(tmpfp)
124
125 os.chmod(tmpfp, stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH)
126 tmpproj.basepath = os.path.dirname(tmpfp)
127 tmpproj.fullpath = tmpfp
128
129 tmpproj.__updateRelativePaths()
130
131
132 fitList = tmpproj.listRefinements()
133 arcname = lambda pr, fp: os.path.join(pr, os.path.basename(fp))
134
135
136 savedArcNames = []
137
138 for fitid, fit in enumerate(fitList):
139 patterns = fit.getObject("Pattern")
140 for bid, pt in enumerate(patterns):
141 for param in ["Datafile", "MFIL", "Instrumentfile"]:
142 arcdir = os.path.join(param.lower(), fit.name)
143
144 ds = pt.get(param)
145
146 if ds:
147 dname = ds.name
148 absds = pt.get(dname + '_abspath')
149
150 for index in fit.range():
151 filepath = fit.findFileFromDataset(ds, index)
152 if filepath:
153
154 if arcname(arcdir, filepath) not in savedArcNames:
155 zpf.write(filepath, arcname(arcdir, filepath))
156 savedArcNames.append(arcname(arcdir, filepath))
157
158 elif param == "Datafile":
159 df = fit.getDataFileFromPattern(index, bid)
160 if df:
161 tmpDir = tempfile.mkdtemp()
162 filepath = df.dump(tmpDir)
163 if arcname(param, filepath) not in savedArcNames:
164 zpf.write(filepath, arcname(param, filepath))
165 savedArcNames.append(arcname(param, filepath))
166 else:
167 __msg = "The file (%s) for index %s, bank id %s is not found.The project in zip file may not be able to run correctly. " % (param, str(fit.getFlatIndex(index)), str(bid))
168 UTILS.printWarning(__msg)
169
170
171 if filepath:
172 ds[index] = arcname(arcdir, filepath)
173 if absds:
174 absds[index] = ''
175
176 tmpproj.save()
177
178 projectFileName = os.path.splitext(os.path.basename(zipFilePath))[0] + '.srr'
179 zpf.write(tmpfp, projectFileName)
180
181 zpf.close()
182
183 return
184
186 """
187 Close the opened project.
188 @return: no return value
189 """
190
191 for plot in self.plots:
192 if plot.alive:
193 plot.close()
194
195 ProjectData.close(self)
196 return
197
199 """
200 Delete job from the job list
201
202 @type job: a job object
203 @param job: the job to be deleted
204 @return: no return value
205
206 """
207 try:
208 self.jobs.remove(job)
209 except:
210 if GLOBALS.isDebug:
211 UTILS.printDebugInfo()
212 else:
213 UTILS.printWarning("No such job in the job list")
214
215 return
216
218 """Build an engine fit based on given fit and index.
219
220 @type refinementObj: a diffpy.refinementdata.Refinement object or its derivatives
221 @param refinementObj: a refinement usually contains a series of single refinements
222 The information will be read from this object to construct
223 a engine fit, which contains information for a single
224 step refinement
225 @type index: tuple or int
226 @param index: the index of the single refinement in the refinementObj
227 a tuple represents the index in multidimensional data
228 if the refinementObj is one dimensional, index can be
229 an integer
230 @return: an engine fit object
231 """
232 enginefit = exportEngineObject(None, None, refinementObj, index)
233
234 if refinementObj.getEngineType() == 'gsas':
235 engineFileStr = str(refinementObj.get('enginefile')[index])
236
237
238 enginefit.expfile.readString(engineFileStr)
239
240
241 for i, contribution in enumerate(enginefit.get("Contribution")):
242 mycontribution = refinementObj.getObject("Contribution").getObject(i)
243 patternindex = mycontribution.getAttr('patternindex')
244 phaseindex = mycontribution.getAttr('phaseindex')
245 if patternindex is None:
246 patternindex = 0
247 if phaseindex is None:
248 phaseindex =0
249
250
251
252 patternindex = int(patternindex)
253 phaseindex = int(phaseindex)
254
255 pattern = enginefit.get("Pattern", patternindex)
256 phase = enginefit.get("Phase", phaseindex)
257 contribution.setParentPattern(pattern)
258 contribution.setParentPhase(phase)
259
260 return enginefit
261
263 '''
264 Get the refinement object by its name
265
266 @type name: string
267 @param name: the refinement name to get
268
269 @return: the refinement object with the name, None if there is no such
270 refinement object. If there are duplicate names, the first
271 refinement object with this name will be returned, and the
272 program will emit a warning
273 '''
274
275 rv = [None]
276 rv = [obj for obj in self.listRefinements() if obj.name == name]
277 if len(rv) > 1:
278 __msg = 'Duplicate refinement names, the first refinement is returned'
279 UTILS.printWarning(__msg)
280
281 return rv[0]
282
284 """
285 Read the data files and instrument files into the pattern objects
286
287 @return: no return value
288 """
289
290 for pattern in self.listObjects(Pattern, True):
291 datafiles = pattern.get('Datafile')
292 instruments = pattern.get('Instrumentfile')
293 uniquefiles = set()
294 for index in pattern.range():
295 if datafiles and datafiles[index]:
296 uniquefiles.add(datafiles[index])
297 if instruments and instruments[index]:
298 uniquefiles.add(instruments[index])
299 for fname in uniquefiles:
300 pattern.importFile(fname)
301
302 return
303
305 """
306 Import an engine file (EXP or pcr files). A refinement will be created
307 based on the data in the imported engine file.
308
309 @type fullpath: file path string
310 @param fullpath: the full file path of the engine file to be imported
311 @type refinementObjName: string
312 @param refinementObjName: the proposed name of the refinement object to be created.
313 If the refinementObjName is None, the engine file name will be used
314 @return: no return value
315 """
316 import os.path
317
318
319
320 if len(self.listRefinements()) == 0:
321 self.basepath = os.path.dirname(fullpath)
322
323 filename = os.path.basename(fullpath)
324 if not refinementObjName:
325 refinementObjName = os.path.splitext(filename)[0]
326
327 refinementObjName = self.verifyFitName(refinementObjName)
328
329 fileformat = os.path.splitext(filename)[1].lower()
330 if fileformat == ".exp":
331 from diffpy.pygsas.fitloader import loadFitFromEXP
332 enginefit = loadFitFromEXP(fullpath)
333 elif fileformat == ".pcr":
334 from diffpy.pyfullprof.fit import Fit as EngineFit
335 from diffpy.pyfullprof.pcrfilereader import ImportFitFromFullProf
336 importfile = ImportFitFromFullProf(fullpath)
337 enginefit = EngineFit(None)
338 importfile.ImportFile(enginefit)
339 else:
340 raise SrrFileError('Engine file name extension is not supported.' + \
341 'Only EXP and pcr files are currently supported in SrRietveld.')
342
343 enginefit.fixAll()
344 fit = self.importEngineFit(enginefit, refinementObjName, fullpath)
345
346 self.altered = True
347
348 return fit
349
351 """
352 Import an engine Fit object. It is provided for backward compatibility,
353 namely when the fit is created using xml interface.
354
355 @type enginefit: an engine Fit object
356 @param enginefit: the engine fit, which contains all the information for
357 one refinement for that engine (GSAS or FullProf)
358 @type name: string
359 @param name: the name to be assigned to the new Fit object
360 @type fullpath: file path string
361 @param fullpath: the fullpath is used to locate the supplement files
362 (data, instrument, engine files)
363 @type index:
364 @param index: the index to a Fit in a Fit array ( not applicable if shape
365 is given )
366
367 @return: the new Fit object
368 """
369
370
371
372
373 engineFileDir = os.path.dirname(fullpath)
374
375
376 fit = importEngineObject(self, enginefit, name, index=index)
377
378
379 patterns = fit.getObject('Pattern')
380 for bankid, pattern in enumerate(patterns):
381
382 absPath = {}
383 for param in ['Datafile', 'MFIL', 'Instrumentfile']:
384 absPath[param] = None
385 ds = pattern.get(param)
386 if ds is None:
387 continue
388 p = str(pattern.get(param).first())
389 if not p:
390 __message = 'The %s param is not set in the engine' % (param)
391 print __message
392 absPath[param] = ""
393 else:
394 absPath[param] = self.__getResourceFileAbsPath(p, engineFileDir)
395
396 if not os.path.exists(absPath[param]):
397
398 paramDesc = ''
399 if param == 'Datafile':
400 paramDesc = 'data file'
401 elif param == 'MFIL':
402 paramDesc = 'incident spectrum'
403 elif param == 'Instrumentfile':
404 paramDesc = 'instrument file'
405
406 __warning = 'Can not find the ' + paramDesc + \
407 ' for pattern %s, ' % (bankid + 1) + \
408 'you may need to reload the file or input the right file path. '
409
410 from diffpy.srrietveld.utility import printWarning
411 printWarning(__warning)
412
413 if absPath['Datafile'] is not None:
414 fit.loadDataFile(absPath['Datafile'], index = 0, bankid = bankid)
415 if absPath['Instrumentfile'] is not None:
416 fit.loadInstrumentFile(absPath['Instrumentfile'], index = 0, bankid = bankid)
417 if absPath['MFIL'] is not None:
418 fit.loadIncidentSpectrumFile(absPath['MFIL'], index = 0, bankid = bankid)
419
420 fit.saveEngineFile(fullpath)
421
422 fit.reshape((1, ), ['Index', ], repeat=True)
423 fit.set('Index', [1])
424 fit.alive = 0
425
426 return fit
427
429 """
430 Import the refinements from another project file.
431 @type projPath: file path string
432 @param projPath: the file path of the project to be imported.
433 @return: no return value
434 """
435 tmpfile, tmppath = tempfile.mkstemp()
436
437
438
439 shutil.copy(projPath, tmppath)
440 proj = Project(tmppath)
441
442 for fit in proj.listRefinements():
443 fit.rename(self.verifyFitName(fit.name))
444
445 self.add(proj)
446
447 return
448
def listParams(self, obj, recursively=False, excluded=None):
    """
    List all parameters in an object.

    @type obj: a data object
    @param obj: the object whose parameters are to be listed
    @type recursively: boolean
    @param recursively: if True, the parameters of the sub-objects are
                        included as well; otherwise only the parameters
                        owned directly by obj are returned
    @type excluded: a collection of object names, or None
    @param excluded: names of sub-objects to skip. A skipped sub-object
                     contributes neither its own parameters nor those of
                     its descendants. The exclusion applies at every
                     depth of the recursion. None means nothing is
                     excluded.
    @return: a list of parameters (Dataset object).
    """
    # None replaces the former mutable [] default, which was shared
    # between calls.
    if excluded is None:
        excluded = ()

    params = [obj.get(name) for name in obj.info.listParams()]

    if recursively:
        for childobj in obj.listObjects():
            if childobj.name in excluded:
                continue
            # Pass the exclusion list down; previously it was dropped
            # here, so it only filtered the direct children of obj.
            params.extend(self.listParams(childobj, True, excluded))

    return params
478
480 """Different from the method listRefinements, this function will return
481 a list of names of the refinements
482 @return: the names of the refinements in this project"""
483
484 return [refObj.name for refObj in self.listRefinements()]
485
def plot(self, ydatasets, xdataset=None, xlabel=''):
    """
    Open a new figure showing the selected Dataset objects.

    @type ydatasets: a list of Dataset objects
    @param ydatasets: the y values of the plot; the figure name is the
                      comma-joined names of these datasets
    @type xdataset: a Dataset object
    @param xdataset: the x values in the plot. If xdataset is None
                     a list of consecutive integers will be used
    @type xlabel: string
    @param xlabel: when non-empty, stored as the x label of the figure
                   selection
    @return: the plot figure object
    """
    from diffpy.refinementdata.plot.figure import Figure

    title = ','.join([ds.name for ds in ydatasets])
    figure = Figure(self.mainframe, title)
    if xlabel:
        figure.selection.xlabel = xlabel

    # Collect the refinement behind each plotted dataset (x included) once.
    involved = ydatasets[:]
    if xdataset:
        involved.append(xdataset)
    refinements = set(ds.owner.findRefinement() for ds in involved)

    # The metadata of those refinements is handed to the figure.
    metadata = []
    for refinement in refinements:
        metadata.extend(refinement.getMetaData())

    figure.plot(xdataset, ydatasets, metadata=metadata)
    self.plots.append(figure)
    figure.frame.updateIndexBox()
    return figure
519
def plotHistory(self, datasets):
    """
    Plot the recorded history of the selected parameters.

    Each dataset's history is looked up through its owner. Datasets
    without a history are left out of the plot, and the x axis is
    labelled 'step'.

    @type datasets: a list of Dataset objects
    @param datasets: the data to be plotted
    @return: no return value
    """
    # Lazy lookup keeps it to one getHistory call per dataset, in order.
    looked_up = (ds.owner.getHistory(ds.name) for ds in datasets)
    histories = [hist for hist in looked_up if hist is not None]

    self.plot(histories, xlabel='step')
536
538 """Make a quick plot of patterns.
539
540 objects -- a list of selected objects
541 """
542 from diffpy.refinementdata.plot.patternfigure import PatternFigure
543 def _plotPattern(_pattern):
544 _figure = PatternFigure(self.mainframe, _pattern.path, figsize=(6, 4))
545 _figure.selection.addMetadata(_pattern.findRefinement().getMetaData())
546 _figure.plotPattern(_pattern)
547 _figure.frame.updateIndexBox()
548 self.plots.append(_figure)
549
550
551 for object in objects:
552 if isinstance(object, Pattern):
553 _plotPattern(object)
554 else:
555 patterns = object.listObjects(Pattern, True)
556 if patterns:
557 for pattern in patterns:
558 _plotPattern(pattern)
559 return
560
562 """
563 Update the fit data with a fit object.
564
565 @param fitrt: the run-time fit instance
566 @return: no return value
567 """
568
569 fitrt.fit.steps[fitrt.index] = fitrt.step+1
570
571 def _updateValue(parpath):
572
573 _value = fitrt.enginefit.getByPath(parpath)
574
575
576 _history = fitrt.fit.getHistoryByPath(parpath)
577 if _history is None:
578 _history = fitrt.fit.addHistoryByPath(parpath)
579
580 _history.update(fitrt.step, _value, fitrt.index)
581
582
583 _param = fitrt.fit.getByPath(parpath)
584 _param[fitrt.index] = _value
585
586 return
587
588 def _updateSigma(constraint):
589
590 _path = constraint.owner.path+'.'+constraint.parname
591 _sigma = fitrt.fit.getSigmaByPath(_path)
592
593 if _sigma is None:
594 _sigma = fitrt.fit.addSigmaByPath(_path)
595
596 if constraint.index is not None:
597 if fitrt.index is None:
598 _index = constraint.index
599 else:
600 if isinstance(fitrt.index, tuple):
601 _index = list(fitrt.index)
602 else:
603 _index = [fitrt.index,]
604 _index.append(constraint.index)
605 _index = tuple(_index)
606 _sigma[_index] = constraint.sigma
607 else:
608 _sigma[fitrt.index] = constraint.sigma
609
610 return
611
612 def _tolist(refl):
613
614 _value = []
615 for key in refl:
616 for dd in refl[key]:
617 for val in dd.values():
618 _value.append(val[0])
619
620 _value.sort()
621 return _value
622
623
624 pathlist = [ c.owner.path+'.'+c.parname for c in fitrt.enginefit.Refine.constraints]
625 pathlist = set(pathlist)
626
627 for path in pathlist:
628 _updateValue(path)
629
630 if path.endswith('Uiso') or path.endswith('Biso'):
631 if path.endswith('Uiso'):
632 _path = path.rsplit('.', 1)[0] + '.Biso'
633 elif path.endswith('Biso'):
634 _path = path.rsplit('.', 1)[0] + '.Uiso'
635 _updateValue(_path)
636
637 for constraint in fitrt.enginefit.Refine.constraints:
638 _updateSigma(constraint)
639
640 if constraint.parname in ['Biso', 'Uiso']:
641 if constraint.parname == 'Uiso':
642 _parname = 'Biso'
643 elif constraint.parname == 'Biso':
644 _parname = 'Uiso'
645
646 _path = constraint.owner.path + '.' + _parname
647 _sigma = fitrt.fit.getSigmaByPath(_path)
648
649 if _sigma is None:
650 _sigma = fitrt.fit.addSigmaByPath(_path)
651
652 if _parname == 'Biso':
653 _val = constraint.sigma * 8 * math.pow(math.pi, 2)
654 elif _parname == 'Uiso':
655 _val = constraint.sigma / ( 8 * math.pow(math.pi, 2))
656
657 if constraint.index is not None:
658 if fitrt.index is None:
659 _index = constraint.index
660 else:
661 if isinstance(fitrt.index, tuple):
662 _index = list(fitrt.index)
663 else:
664 _index = [fitrt.index,]
665 _index.append(constraint.index)
666 _index = tuple(_index)
667
668 _sigma[_index] = _val
669 else:
670 _sigma[fitrt.index] = _val
671
672
673 _updateValue('Chi2')
674
675 for pattern in fitrt.enginefit.Pattern.get():
676 _updateValue(pattern.path+'.Rp')
677 _updateValue(pattern.path+'.Rwp')
678
679
680 for dname in ['xobs', 'yobs', 'ycal', 'refl']:
681 mypattern = fitrt.fit.getByPath(pattern.path)
682 dataset = mypattern.get(dname)
683 if dname == 'refl':
684 data = _tolist(getattr(pattern, '_reflections'))
685 else:
686 data = getattr(pattern, '_'+dname)
687 if dataset is None:
688 mypattern.set(dname, data, repeat=True, grow=True)
689 else:
690 dataset[fitrt.index] = data
691
692
693 if self.mainframe:
694 from diffpy.refinementdata.plot import backend
695 backend.lock()
696 if fitrt.step > 0:
697 self.mainframe.updateStrategyPanel(fitrt)
698 self.updatePlot(fitrt.index)
699 backend.unlock()
700 return
701
703 """
704 Update the histogram and instrument file paths saved in the project
705 @return: no return value
706 """
707 for fit in self.listRefinements():
708 patterns = fit.getObject("Pattern")
709 for pt in patterns:
710 for param in ['Datafile', 'MFIL', 'Instrumentfile']:
711 ds = pt.get(param)
712 if ds:
713 for idx in fit.range():
714 fit.findFileFromDataset(ds, idx)
715 return
716
718 """
719 Update all plots based on the data in the project
720
721 @type index: tuple or int
722 @param index: the index of the curve in the plot
723 @return: no return value
724 """
725 for plot in self.plots[:]:
726 if plot.alive:
727 plot.update(index)
728 else:
729 self.plots.remove(plot)
730
731 return
732
734 """
735 Save the project data to local drive.
736
737 @type fullpath: file path string
738 @param fullpath: the file path to save the project data
739 @return: no return value
740 """
741 self.save()
742 try:
743
744
745 srcpath = self.fullpath
746
747 if self.fullpath == fullpath:
748 tmpfobj, tmpfpath = tempfile.mkstemp()
749 shutil.copy2(self.fullpath, tmpfpath)
750 os.remove(fullpath)
751 srcpath = tmpfpath
752
753 shutil.copy2(srcpath, fullpath)
754
755 os.chmod(fullpath, stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH)
756 self.basepath = os.path.dirname(fullpath)
757 self.fullpath = fullpath
758 self.__updateRelativePaths()
759 self.save()
760
761 except Exception,e:
762 if GLOBALS.isDebug:
763 import sys, traceback
764 exc_type, exc_value, exc_traceback = sys.exc_info()
765 traceback.print_exception(exc_type, exc_value, exc_traceback,
766 limit=2, file=sys.stdout)
767 else:
768 raise SrrIOError("Can not save file: " + str(e))
769
770 return
771
773 """
774 Compare the fitname with existing names of refinements in the project.
775 A valid name will be returned.
776
777 - All the I{dots (.)} will be replaced with I{underscore (_)}
778 - A name that already exists is suffixed with an underscore and a number,
779 such as I{Refinement}, I{Refinement_1}, I{Refinement_2}
780 @type fitName: string
781 @param fitName: refinement name to be verified
782
783 @return: If fitName is valid, return value
784 will be the same as fitName; if fitName contains invalid
785 characters or has duplications, a valid refinement name will
786 be returned
787 """
788 nameid = 0
789 fitName = fitName.replace('.', '_')
790 for fit in self.listRefinements():
791 if fit.name == fitName:
792 nameid += 1
793 if nameid != 0:
794 newname = fitName + "_%s"%(nameid)
795 else:
796 newname = fitName
797
798 return newname
799
801 """
802 Export the project to another location, without changing the original file
803 This is different from saving as, since the original state, basepath,
804 etc. of the project is not changed. The original file will be copied to
805 a temp dir, and then the project will be saved to the export directory.
806 Then the protected old file will be saved back
807
808 @type filePath: file path string
809 @param filePath: the file path where the project will be exported
810 @return: no return value
811 """
812
813 oldPath = self.fullpath
814 oldBase = self.basepath
815 oldAltered = self.altered
816 tmpfile, tmppath = tempfile.mkstemp()
817
818 shutil.copy(self.fullpath, tmppath)
819
820 self.saveToDisk(filePath)
821
822 shutil.copy(tmppath, oldPath)
823 try:
824 os.remove(tmppath)
825 except:
826 UTILS.printWarning("Can not delete the temp project file after archive the project")
827
828 self.fullpath = oldPath
829 self.basepath = oldBase
830 self.altered = oldAltered
831 self.__updateRelativePaths()
832
833 return
834
836 """
837 As long as the base path is changed, the relative paths saved in
838 patterns need to be updated.
839
840 @type basepath: folder path string
841 @param basepath: the basepath used to determine the relative path, use
842 project basepath if None
843 @return: no return value
844
845 """
846 if basepath is None: basepath = self.basepath
847
848 fits = self.listRefinements()
849 for fit in fits:
850 patterns = fit.getObject("Pattern")
851 for pt in patterns:
852 for param in ['Datafile', 'MFIL', 'Instrumentfile']:
853 ds = pt.get(param)
854 if ds is not None:
855 for idx in fit.range():
856 filepath = fit.findFileFromDataset(ds, idx)
857 if filepath is not None:
858 fit.saveFilePathToDataset(filepath, ds,
859 index = idx,
860 basepath = basepath)
861
862 return
863
865 """
866 Get the absolute file path. The filePath may be relative, to the engine file
867 directory, which may also be different from the current working dir
868
869 @type relFilePath: file path string
870 @param relFilePath: relative file path
871 @type engineFileDir: directory path string
872 @param engineFileDir: the directory where the engine file sits
873 @return: the absolute file path
874 """
875 absFilePath = None
876 if not os.path.isabs(relFilePath):
877 absFilePath = os.path.join(engineFileDir, relFilePath)
878 else:
879 absFilePath = relFilePath
880 return absFilePath
881
882
883
884
885
886
887
889 """
890 Import an engine object into myobject
891
892 @type myobjowner: an data object
893 @param myobjowner: the owner of the refinement object to be created
894 @type engineobj: an engine fit object
895 @param engineobj: the engine object to be imported
896 @type id: integer or string
897 @param id: the name/index of the refinement object to be created/obtained in myobjowner
898 @type index: index or tuple
899 @param index: when modify an object, indicate the index to the value to be modified
900
901 return: the newly created local object
902 """
903
904 if engineobj.__class__.__name__.startswith('Refine'):
905 return None
906
907
908 myobj = myobjowner.getObject(id)
909 if myobj is None:
910
911 objclassname = engineobj.__class__.__name__
912 if objclassname.startswith('Fit'):
913 myobjclass = Refinement
914 elif objclassname.startswith('Pattern'):
915 myobjclass = Pattern
916 elif objclassname.startswith('Phase'):
917 myobjclass = Phase
918 elif objclassname.startswith('Profile'):
919 myobjclass = Profile
920 elif engineobj.name.startswith('Atom['):
921 myobjclass = Atom
922 elif engineobj.name.startswith('ExcludedRegion['):
923 myobjclass = ExcludedRegion
924 else:
925
926 myobjclass = Refinable
927
928
929 if isinstance(myobjowner, ObjectList):
930 assert(id==len(myobjowner))
931 myobj = myobjowner.appendObject(myobjclass)
932 else:
933 myobj = myobjowner.addObject(id, myobjclass)
934
935 myobj.setAttr('rietveldcls', engineobj.__class__.__module__ + '.'
936 + engineobj.__class__.__name__)
937 engineobjclass = getattr(__import__(engineobj.__class__.__module__,
938 globals(), locals(),
939 [engineobj.__class__.__name__], -1),
940 engineobj.__class__.__name__)
941 myobj.info = ObjectInfo(engineobjclass)
942
943 if engineobj.__class__.__name__.startswith("Contribution"):
944 myobj.setAttr("phaseindex", engineobj.getPhaseIndex())
945 myobj.setAttr("patternindex", engineobj.getPatternIndex())
946
947 def _importEngineParameter(infodict):
948 for param in infodict:
949
950 dataset = myobj.get(param)
951
952 if dataset is None:
953 info = infodict[param]
954 grow = False
955
956
957 if info.__class__.__name__ == 'StringInfo' and not info.fixlen:
958 grow = True
959
960 dataset = myobj.set(param, engineobj.get(param), repeat=True, grow=grow)
961 if dataset.name == 'BACK':
962 labels = myobj.labels
963 labels.append('Order')
964 dataset.setLabels(labels)
965 else:
966 dataset[index] = engineobj.get(param)
967
968
969 _importEngineParameter(engineobj.ParamDict)
970
971
972 _importEngineParameter(engineobj.ParamListDict)
973
974
975 for childname in engineobj.ObjectDict:
976 engineobjchild = engineobj.get(childname)
977 importEngineObject(myobj, engineobjchild, childname, index=index)
978
979
980 for childname in engineobj.ObjectListDict:
981 mycontainer = myobj.getObject(childname)
982 if mycontainer is None:
983 mycontainer = myobj.addObject(childname, ObjectList)
984 mycontainer.info = ObjectInfo()
985 enginecontainer = engineobj.get(childname)
986 for i in range(len(enginecontainer)):
987 engineobjchild = engineobj.get(childname, i)
988 importEngineObject(mycontainer, engineobjchild, i, index=index)
989
990 return myobj
991
992
994 """
995 Export a saved HDF5Object to an engine object.
996
997 @type engineobjowner: a data object
998 @param engineobjowner: the owner of the engine object
999
1000 @type myobj: a refinement object
1001 @param myobj: the source refinement object
1002 @type index: integer or tuple
1003 @param index: the index to the single value if the object is multiplexed
1004
1005 @return: the newly created engine object
1006 """
1007 try:
1008 engineclassname = myobj.getAttr('rietveldcls')
1009 modulename, classname = engineclassname.rsplit('.', 1)
1010 module = __import__(modulename, globals(), locals(), [classname], -1)
1011 engineclass = getattr(module, classname)
1012 except (ImportError, AttributeError, ValueError):
1013
1014 raise DataError('Unknown object "%s" found.'%engineclassname)
1015
1016 engineobj = engineclass(engineobjowner)
1017
1018 if engineobjowner is not None:
1019 engineobjowner.set(key, engineobj)
1020
1021
1022 import numpy
1023 for param in myobj.list():
1024 if param.name not in engineobj.ParamDict and param.name not in engineobj.ParamListDict:
1025 continue
1026 if len(param.shape) == 0:
1027 value = param.first()
1028 else:
1029 value = param[index]
1030
1031 if isinstance(value, numpy.ndarray):
1032 if value.ndim == 1:
1033 for x in value:
1034 engineobj.set(param.name, x)
1035 elif value.ndim == 0:
1036 engineobj.set(param.name, value[()])
1037 else:
1038 raise DataError('Parameter "%s" has more than one dimensions' %param)
1039 continue
1040 engineobj.set(param.name, value)
1041
1042
1043 for myobjchild in myobj.objects:
1044 subkey = myobjchild.name
1045 if isinstance(myobjchild, ObjectList):
1046
1047 for obj in myobjchild.objects:
1048 exportEngineObject(engineobj, subkey, obj, index)
1049 else:
1050 exportEngineObject(engineobj, subkey, myobjchild, index)
1051
1052 return engineobj
1053
1055 """
1056 Copy data to another object.
1057
1058 srcobj -- the source object
1059 destobj -- the target parent object
1060 srcindexlist -- a list of indices to read data from srcobj
1061 destindexlist -- a list of indices to write data to destobj
1062 """
1063
1064 for attr, val in srcobj.listAttrs(withValue=True):
1065 destobj.setAttr(attr, val)
1066
1067 destobj.copyMeta(srcobj)
1068
1069
1070 for srcdata in srcobj.list():
1071 destdata = destobj.get(srcdata.name)
1072 if destdata is None:
1073
1074 destdata = destobj.replicate(srcdata, srcindexlist[0], None)
1075
1076 for srcindex, destindex in zip(srcindexlist, destindexlist):
1077 destdata[destindex] = srcdata[srcindex]
1078
1079 if isinstance(srcobj, ObjectList):
1080
1081 for i, childobj in enumerate(srcobj.listObjects()):
1082 destchildobj = destobj.getObject(i)
1083 if destchildobj is None:
1084
1085 destchildobj = destobj.appendObject(childobj.__class__)
1086
1087
1088 copyObjectData(childobj, destchildobj, srcindexlist, destindexlist)
1089 else:
1090
1091 for childobj in srcobj.listObjects():
1092 destchildobj = destobj.getObject(childobj.name)
1093 if destchildobj is None:
1094 destchildobj = destobj.addObject(childobj.name, childobj.__class__)
1095
1096 copyObjectData(childobj, destchildobj, srcindexlist, destindexlist)
1097
1098 if isinstance(srcobj, Refinable):
1099
1100 copyObjectData(srcobj.history, destobj.history, srcindexlist, destindexlist)
1101 copyObjectData(srcobj.sigma, destobj.sigma, srcindexlist, destindexlist)
1102
1103 return destobj
1104
1105