-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcontribution_function_generator.py
More file actions
156 lines (118 loc) · 5.16 KB
/
contribution_function_generator.py
File metadata and controls
156 lines (118 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
--- Contribution Function Generator ---
This script reads the CSV files generated by data_load_better.py from the 'validation' split of the
configured output directory, generates the corresponding per-species contribution functions, and
saves them next to the input data. Use a config file to specify the output directory.
Example:
$ python contribution_function_generator.py --config <path_to_config_file>
"""
#---######################################################################
#--- Imports
#---######################################################################
import argparse
import yaml
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import xarray as xr
from taurex_utils import get_mols, full_contribution_array
import json
#---######################################################################
#--- Functions
#---######################################################################
def m_to_rJ(distance_m):
    """
    Express a length given in metres in units of the Jupiter radius (7.1492e7 m).
    """
    jupiter_radius_m = 7.1492e7
    return distance_m / jupiter_radius_m
def m_to_rS(distance_m):
    """
    Express a length given in metres in units of the solar radius (6.957e8 m).
    """
    solar_radius_m = 6.957e8
    return distance_m / solar_radius_m
def kg_to_Mj(mass_kg):
    """
    Express a mass given in kilograms in units of the Jupiter mass (1.898e27 kg).
    """
    jupiter_mass_kg = 1.898e27
    return mass_kg / jupiter_mass_kg
def save_checkpoint(checkpoint_path, current_index):
    """
    Save the current progress to a checkpoint file.

    Writes ``{"current_index": current_index}`` as JSON. The data is first
    written to ``<checkpoint_path>.tmp`` and then moved over the real file with
    os.replace. This way an interruption mid-write can never leave a truncated,
    unparsable checkpoint behind.

    Args:
        checkpoint_path: Destination path of the JSON checkpoint.
        current_index: Index to record. It is coerced with int(), so NumPy
            integer scalars (not JSON-serialisable) are accepted too.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump({"current_index": int(current_index)}, f)
    # Atomic on POSIX and Windows when source and destination share a filesystem.
    os.replace(tmp_path, checkpoint_path)
def load_checkpoint(checkpoint_path):
    """
    Return the index recorded in a checkpoint file.

    Falls back to 0 when the file does not exist, or when it has no
    "current_index" key.
    """
    if not os.path.exists(checkpoint_path):
        return 0
    with open(checkpoint_path, 'r') as handle:
        state = json.load(handle)
    return state.get("current_index", 0)
#---######################################################################
#--- Main-script helpers
#---######################################################################
def _contribution_csv_path(dir_path, species):
    """Return the CSV path that holds the contribution functions of `species`."""
    return os.path.join(dir_path, f'contributions_temp_{species}.csv')


def _save_contributions(dir_path, db_contribs):
    """Write each species' contribution array to its CSV in `dir_path`.

    Each file is first written under a temporary name and then moved into
    place with os.replace. An interruption mid-write therefore cannot leave a
    truncated CSV behind for a later resume to read.
    """
    for species, contrib in db_contribs.items():
        final_path = _contribution_csv_path(dir_path, species)
        tmp_path = final_path + '.tmp'
        np.savetxt(tmp_path, contrib, delimiter=',')
        os.replace(tmp_path, final_path)


def _load_partial_contributions(dir_path, mols, shape):
    """Reload contribution arrays written by an earlier, interrupted run.

    Returns a dict mapping species -> 2-D array. Returns None when any
    species' CSV is missing or its shape differs from `shape`; the caller
    should then start over rather than mix stale and fresh rows.
    """
    loaded = {}
    for species in mols:
        path = _contribution_csv_path(dir_path, species)
        if not os.path.exists(path):
            return None
        # ndmin=2 keeps a single-row file 2-D so the shape comparison is meaningful.
        contrib = np.loadtxt(path, delimiter=',', ndmin=2)
        if contrib.shape != tuple(shape):
            return None
        loaded[species] = contrib
    return loaded


#---######################################################################
#--- Main Script
#---######################################################################
if __name__ == "__main__":
    # NOTE(review): `log_utils` and `read_csv_data` are used below but are not
    # imported anywhere in this file as shown. Add the missing imports.

    # Parse arguments and load config
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to config file")
    args = parser.parse_args()
    with open(args.config, 'r') as f:
        # safe_load returns None for an empty file; treat that as "no settings".
        config = yaml.safe_load(f) or {}

    # Output directory and (optional) checkpoint interval come from the config.
    save_config = config.get("output", {})
    save_path = save_config.get("output_dir", "./output")
    checkpoint_every = max(1, int(save_config.get("checkpoint_every", 10)))
    log_utils.write_box(f"Editing data in: {save_path}", style='hashdash')

    # Only the validation split is processed.
    dir_path = os.path.join(save_path, 'validation')
    if not os.path.exists(dir_path):
        raise FileNotFoundError(f"Directory {dir_path} does not exist")

    # Load data
    data, labels, aux, aux_f = read_csv_data(dir_path, full_auxilliary=True)
    n_planets = data.shape[0]
    mols = get_mols()

    # Resume from a checkpoint if one exists.
    checkpoint_path = os.path.join(dir_path, "generation_checkpoint.json")
    start_index = load_checkpoint(checkpoint_path)

    # One output array per species, shaped like the input data.
    db_contribs = {species: np.zeros_like(data) for species in mols}
    if start_index > 0:
        # Restore rows computed before the interruption. Otherwise they would
        # remain zero and the final save would overwrite them on disk.
        partial = _load_partial_contributions(dir_path, mols, data.shape)
        if partial is None:
            print("Checkpoint found but partial contribution files are missing or "
                  "mismatched; restarting from the first planet.")
            start_index = 0
        else:
            for species in mols:
                db_contribs[species][...] = partial[species]

    for planet in tqdm(range(start_index, n_planets)):
        # labels: column 0 is the planet temperature, the rest are abundances.
        abundancies = labels[planet, 1:]
        planet_temp = labels[planet, 0]
        planet_radius = m_to_rJ(aux_f[planet, -2])
        planet_mass = kg_to_Mj(aux_f[planet, 4])
        star_radius = m_to_rS(aux_f[planet, 2])
        star_temp = aux_f[planet, 3]

        # Generate the contribution functions for every species. The species
        # list is taken from get_mols(), the same source that keys db_contribs.
        # NOTE(review): this replaces a hard-coded
        # ['H2O', 'CO2', 'CH4', 'CO', 'NH3']; confirm that the abundance
        # columns in `labels` follow get_mols() order.
        contribs = full_contribution_array(list(mols),
                                           abundancies,
                                           Rp=planet_radius,
                                           Tp=planet_temp,
                                           Mp=planet_mass,
                                           Rs=star_radius,
                                           Ts=star_temp)
        for species in mols:
            db_contribs[species][planet] = contribs[species][1]

        # Periodically persist progress. The checkpoint records the NEXT
        # planet to process, because row `planet` is already in the saved CSVs.
        if planet % checkpoint_every == 0:
            _save_contributions(dir_path, db_contribs)
            save_checkpoint(checkpoint_path, planet + 1)

    # Final save of all species.
    _save_contributions(dir_path, db_contribs)

    # Remove the checkpoint file after successful completion.
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    log_utils.write_box("Contribution Generation complete!", style='hashdash')