1
0
mirror of https://github.com/ciromattia/kcc synced 2025-12-13 09:46:25 +00:00

High performance improvements by using rfft2 instead of fft2

This commit is contained in:
Its-my-right
2025-07-12 16:39:21 +02:00
committed by Alex Xu
parent cf047ecf6f
commit f6d10337d8

View File

@@ -1,31 +1,26 @@
import numpy as np import numpy as np
from PIL import Image from PIL import Image
def fourier_transform_image(img): def fourier_transform_image(img):
""" """
Performs a 2D Fourier transform on a PIL image. Memory-optimized version that modifies the array in place when possible.
Args:
img: PIL Image (can be color or grayscale)
Returns:
fft_result: Complex result of the 2D FFT
""" """
# Convert PIL image to NumPy array # Convert with minimal copy
img_array = np.array(img) img_array = np.asarray(img, dtype=np.float32)
# Perform 2D Fourier transform # Use rfft2 if the image is real to save memory
fft_result = np.fft.fft2(img_array) # and computation time (approximately 2x faster)
fft_result = np.fft.rfft2(img_array)
return fft_result return fft_result
def attenuate_diagonal_frequencies(fft_spectrum, freq_threshold=0.3, target_angle=135, def attenuate_diagonal_frequencies(fft_spectrum, freq_threshold=0.3, target_angle=135,
angle_tolerance=15, attenuation_factor=0.1): angle_tolerance=15, attenuation_factor=0.1):
""" """
Attenuates specific frequencies in the Fourier domain (optimized version). Attenuates specific frequencies in the Fourier domain (optimized version for rfft2).
Args: Args:
fft_spectrum: Result of 2D Fourier transform (non-centered) fft_spectrum: Result of 2D real Fourier transform (from rfft2)
freq_threshold: Frequency threshold in cycles/pixel (default: 0.3, theoretical max: 0.5) freq_threshold: Frequency threshold in cycles/pixel (default: 0.3, theoretical max: 0.5)
target_angle: Target angle in degrees (default: 135) target_angle: Target angle in degrees (default: 135)
angle_tolerance: Angular tolerance in degrees (default: 15) angle_tolerance: Angular tolerance in degrees (default: 15)
@@ -35,12 +30,14 @@ def attenuate_diagonal_frequencies(fft_spectrum, freq_threshold=0.3, target_angl
np.ndarray: Modified FFT with applied attenuation (same format as input) np.ndarray: Modified FFT with applied attenuation (same format as input)
""" """
# Get dimensions # Get dimensions of the rfft2 result
height, width = fft_spectrum.shape height, width_rfft = fft_spectrum.shape
# For rfft2, the original width is (width_rfft - 1) * 2
width_original = (width_rfft - 1) * 2
# Create frequency grids in an optimized way # Create frequency grids for rfft2 format
freq_y = np.fft.fftfreq(height, d=1.0) freq_y = np.fft.fftfreq(height, d=1.0)
freq_x = np.fft.fftfreq(width, d=1.0) freq_x = np.fft.rfftfreq(width_original, d=1.0) # Use rfftfreq for the X dimension
# Use broadcasting to create grids without meshgrid (more efficient) # Use broadcasting to create grids without meshgrid (more efficient)
freq_y_grid = freq_y.reshape(-1, 1) # Column freq_y_grid = freq_y.reshape(-1, 1) # Column
@@ -65,7 +62,8 @@ def attenuate_diagonal_frequencies(fft_spectrum, freq_threshold=0.3, target_angl
angles_deg = np.rad2deg(angles_rad) % 360 angles_deg = np.rad2deg(angles_rad) % 360
# Optimize angular condition # Optimize angular condition
# Calculate both target angles at once # For rfft2, we only process angles in the positive half-plane of X
# So we only calculate the main angle, not its opposite
target_angle_2 = (target_angle + 180) % 360 target_angle_2 = (target_angle + 180) % 360
# Create angular conditions in a vectorized way # Create angular conditions in a vectorized way
@@ -96,22 +94,20 @@ def attenuate_diagonal_frequencies(fft_spectrum, freq_threshold=0.3, target_angl
# General case: partial attenuation # General case: partial attenuation
fft_spectrum[combined_condition] *= attenuation_factor fft_spectrum[combined_condition] *= attenuation_factor
return fft_spectrum return fft_spectrum
def inverse_fourier_transform_image(fft_spectrum): def inverse_fourier_transform_image(fft_spectrum):
""" """
Performs an inverse Fourier transform to reconstruct a PIL image. Performs an optimized inverse Fourier transform to reconstruct a PIL image.
Args: Args:
fft_spectrum: Fourier transform result (complex array) fft_spectrum: Fourier transform result (complex array from rfft2)
original_shape: Original image shape (height, width) for proper cropping
Returns: Returns:
PIL.Image: Reconstructed image PIL.Image: Reconstructed image
""" """
# Perform inverse Fourier transform # Perform inverse Fourier transform
img_reconstructed = np.fft.ifft2(fft_spectrum) img_reconstructed = np.fft.irfft2(fft_spectrum)
# Take real part (eliminate imaginary artifacts due to numerical errors)
img_reconstructed = np.real(img_reconstructed)
# Normalize values between 0 and 255 # Normalize values between 0 and 255
img_reconstructed = np.clip(img_reconstructed, 0, 255) img_reconstructed = np.clip(img_reconstructed, 0, 255)
@@ -122,10 +118,8 @@ def inverse_fourier_transform_image(fft_spectrum):
return pil_image return pil_image
import numpy as np
def erase_rainbow_artifacts(img): def erase_rainbow_artifacts(img):
fft_spectrum = fourier_transform_image(img) fft_spectrum = fourier_transform_image(img)
clean_spectrum = attenuate_diagonal_frequencies(fft_spectrum) clean_spectrum = attenuate_diagonal_frequencies(fft_spectrum)
clean_image = inverse_fourier_transform_image(clean_spectrum) clean_image = inverse_fourier_transform_image(clean_spectrum)
return clean_image return clean_image