%matplotlib inline
from pylab import rcParams
rcParams['figure.figsize'] = (9, 6) # Figure size for inline display
import numpy as np
from scipy.fftpack import ifftshift, fftshift, fft2
import warnings
warnings.filterwarnings('ignore')
from pynx.utils.pattern import siemens_star
# This will import the base CDI class and all relevant operators, in CUDA (default) or OpenCL
from pynx.cdi import *
# Optional explicit GPU selection. This is only useful on multi-GPU, multi-user (e.g. Amazon)
# machines: pick a device number (0..15) to avoid resource conflicts with other participants.
# Leave the flag False to use the default processing unit.
select_gpu_device = False
gpu_device_number = 0  # Change this to your GPU number
if select_gpu_device:
    from pynx.processing_unit.cuda_device import has_cuda
    if has_cuda:  # use CUDA
        from pynx.cdi.cu_operator import default_processing_unit_cdi
        from pynx.processing_unit.cuda_device import cu_drv
        default_processing_unit_cdi.init_cuda(cu_device=cu_drv.Device(gpu_device_number))
    else:  # fall back to OpenCL
        import pyopencl as cl
        from pynx.cdi.cl_operator import default_processing_unit_cdi
        # NOTE(review): OpenCL platform index 1 is hard-coded; it was only needed on Amazon to
        # select one of multiple GPUs — adjust it for your machine.
        cl_devices = cl.get_platforms()[1].get_devices()
        default_processing_unit_cdi.init_cl(cl_device=cl_devices[gpu_device_number])
# Test on a simulated pattern (2D)
n = 512
# Siemens-star object
obj0 = siemens_star(n, nb_rays=18, r_max=60, nb_rings=3)
# Start from a slightly loose disc support: radius 65 pixels (the star is built with r_max=60),
# centred on pixel n // 2, which is where fftshift/ifftshift place the array origin.
# TODO(review): confirm siemens_star centres the object on the same pixel.
x, y = np.meshgrid(np.arange(n, dtype=np.float32) - n // 2,
                   np.arange(n, dtype=np.float32) - n // 2)
r = np.sqrt(x ** 2 + y ** 2)
support = r < 65
# Far-field intensity of the centred object: ifftshift moves the centre to [0, 0] before the FFT,
# and fftshift moves the zero frequency back to the centre. For even n (as here) both shifts are
# identical; this order is the one that stays correct for odd n.
iobs = abs(fftshift(fft2(ifftshift(obj0.astype(np.complex64))))) ** 2
# Rescale to a total of 1e10 counts and draw Poisson-noisy intensities
iobs = np.random.poisson(iobs * 1e10 / iobs.sum())
# All zeros: no pixel is masked (presumably 0 marks a valid pixel — confirm with PyNX docs)
mask = np.zeros_like(iobs, dtype=np.int16)
# Create CDI object; the arrays must be fft-shifted (origin at [0, 0]), hence ifftshift
cdi = CDI(ifftshift(iobs), obj=None, support=ifftshift(support), mask=mask,
          pixel_size_object=1e-8, lambdaz=1e-10)
# No explicit transfer to the processing unit is done here; presumably the operators applied
# below move the data as needed (TODO: confirm against the PyNX documentation).
%matplotlib inline
# Do 100 cycles of RAAR, displaying object at the end
cdi = ShowCDI(fig_num=1) * RAAR()**100 * cdi
# Activate live display
%matplotlib notebook
# 10 outer iterations of RAAR + ER + ML, each ending with a support update
for i in range(10):
    # Support update operator; its smoothing width decays from 2.0 (i=0) towards 0.25
    s = 0.25+1.75*np.exp(-i/4)
    sup = SupportUpdate(threshold_relative=0.4, smooth_width=s, force_shrink=False)
    # Do 40 cycles of RAAR (2 x 20), displaying the object every 20 cycles
    cdi = (ShowCDI(fig_num=1) * RAAR() ** 20) ** 2 * cdi
    # Do 20 cycles of ER, displaying the object at the end
    cdi = ShowCDI(fig_num=1) * ER() ** 20 * cdi
    # Do 20 cycles of ML (nb_cycle=20), displaying the object at the end
    cdi = ShowCDI(fig_num=1) * ML(reg_fac=1e-2, nb_cycle=20) * cdi
    # Update the support, then display the object
    cdi = ShowCDI(fig_num=1) * sup * cdi
# Evaluate the log-likelihood: Fourier-transform the object, compute LLK, then transform back.
# Assign the result like every other operator application, rather than relying on in-place updates.
cdi = IFT() * LLK() * FT() * cdi
# Normalised LLK. NOTE(review): this reads private attributes (_llk_offset, _llk_norm);
# presumably _llk_norm is the number of data points, given the "LLK/nbpoint" label — confirm.
print("LLK/nbpoint = %8.3f" % ((cdi.llk - cdi._llk_offset) / cdi._llk_norm))