-
Notifications
You must be signed in to change notification settings - Fork 244
Expand file tree
/
Copy pathclient.py
More file actions
195 lines (167 loc) · 6.29 KB
/
client.py
File metadata and controls
195 lines (167 loc) · 6.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Client-side script for Feature Election example.
This script demonstrates how to set up client data for the
FeatureElectionExecutor from nvflare.app_opt.feature_election.
"""
import logging
import re
from typing import Optional
from prepare_data import load_client_data
from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.feature_election.executor import FeatureElectionExecutor
logger = logging.getLogger(__name__)
def get_executor(
    client_id: int,
    num_clients: int,
    fs_method: str = "lasso",
    eval_metric: str = "f1",
    data_root: Optional[str] = None,
    split_strategy: str = "stratified",
    n_samples: int = 1000,
    n_features: int = 100,
    n_informative: int = 20,
    n_redundant: int = 30,
    n_repeated: Optional[int] = None,
) -> FeatureElectionExecutor:
    """
    Create and configure a FeatureElectionExecutor with data.

    Args:
        client_id: Client identifier (0 to num_clients-1)
        num_clients: Total number of clients
        fs_method: Feature selection method
        eval_metric: Evaluation metric ('f1' or 'accuracy')
        data_root: Optional path to pre-generated data
        split_strategy: Data splitting strategy
        n_samples: Samples per client for synthetic data
        n_features: Number of features
        n_informative: Number of informative features
        n_redundant: Number of redundant features
        n_repeated: Optional number of repeated features. When None (the
            default) the argument is not forwarded at all, so
            load_client_data applies its own default.

    Returns:
        Configured FeatureElectionExecutor
    """
    # Create executor (this helper always uses the "feature_election" task name)
    executor = FeatureElectionExecutor(
        fs_method=fs_method,
        eval_metric=eval_metric,
        task_name="feature_election",
    )

    # Collect the loader arguments; n_repeated is only included when the caller
    # set it explicitly, which keeps the default call identical to before.
    load_kwargs = dict(
        client_id=client_id,
        num_clients=num_clients,
        data_root=data_root,
        split_strategy=split_strategy,
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=n_redundant,
    )
    if n_repeated is not None:
        load_kwargs["n_repeated"] = n_repeated

    # Load data for this client
    X_train, y_train, X_val, y_val, feature_names = load_client_data(**load_kwargs)

    # Set data on executor
    executor.set_data(
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        feature_names=feature_names,
    )

    logger.info(
        f"Client {client_id} executor configured: "
        f"{X_train.shape[0]} train, {X_val.shape[0]} val, "
        f"{X_train.shape[1]} features, method={fs_method}"
    )
    return executor
class SyntheticDataExecutor(FeatureElectionExecutor):
    """
    FeatureElectionExecutor with built-in synthetic data loading.

    This executor automatically loads synthetic data based on
    client_id extracted from the FL context. Loading is deferred until
    the first call to execute() and happens at most once per instance.

    Args:
        fs_method: Feature selection method
        eval_metric: Evaluation metric
        num_clients: Total number of clients in federation
        split_strategy: Data splitting strategy
        n_samples: Samples per client
        n_features: Number of features
        n_informative: Number of informative features
        n_redundant: Number of redundant features
        n_repeated: Number of repeated features
        task_name: Task name passed through to FeatureElectionExecutor
    """

    def __init__(
        self,
        fs_method: str = "lasso",
        eval_metric: str = "f1",
        num_clients: int = 3,
        split_strategy: str = "stratified",
        n_samples: int = 1000,
        n_features: int = 100,
        n_informative: int = 20,
        n_redundant: int = 30,
        n_repeated: int = 10,
        task_name: str = "feature_election",
    ):
        super().__init__(
            fs_method=fs_method,
            eval_metric=eval_metric,
            task_name=task_name,
        )
        # Data-generation settings are only stored here; they are handed to
        # load_client_data on the first execute() call.
        self.num_clients = num_clients
        self.split_strategy = split_strategy
        self.n_samples = n_samples
        self.n_features = n_features
        self.n_informative = n_informative
        self.n_redundant = n_redundant
        self.n_repeated = n_repeated
        self._data_loaded = False

    def _site_name_to_client_id(self, site_name) -> int:
        """
        Convert a site name into a zero-based client index.

        Names starting with "site-" use the token after the first dash
        ("site-1" -> 0). Any other name uses the first run of digits it
        contains, minus one; a name without digits maps to 0.

        Args:
            site_name: Identity name reported by the FL context

        Returns:
            Client index in the range [0, num_clients - 1]

        Raises:
            ValueError: If site_name is not a string, the numeric part cannot
                be converted to int, or the index is out of range
        """
        if not isinstance(site_name, str):
            # Surface a missing/odd identity as a parse failure instead of an
            # AttributeError from .startswith().
            raise ValueError(f"identity name must be a string, got {type(site_name).__name__}")

        if site_name.startswith("site-"):
            client_id = int(site_name.split("-")[1]) - 1
        else:
            match = re.search(r"\d+", site_name)
            client_id = int(match.group()) - 1 if match else 0

        # Validate range
        if not (0 <= client_id < self.num_clients):
            raise ValueError(
                f"Extracted client_id {client_id} from '{site_name}' is out of range [0, {self.num_clients - 1}]"
            )
        return client_id

    def _load_data_if_needed(self, fl_ctx: FLContext) -> None:
        """Load data based on client identity from FL context (no-op once loaded)."""
        if self._data_loaded:
            return

        # Extract client ID from site name
        site_name = fl_ctx.get_identity_name()
        try:
            client_id = self._site_name_to_client_id(site_name)
        except (ValueError, IndexError) as e:
            # Best effort: log the problem and continue with client_id=0
            # rather than failing the task.
            logger.error(f"Failed to parse client_id from '{site_name}': {e}. Defaulting to client_id=0")
            client_id = 0

        # Load data using the parsed ID
        X_train, y_train, X_val, y_val, feature_names = load_client_data(
            client_id=client_id,
            num_clients=self.num_clients,
            split_strategy=self.split_strategy,
            n_samples=self.n_samples,
            n_features=self.n_features,
            n_informative=self.n_informative,
            n_redundant=self.n_redundant,
            n_repeated=self.n_repeated,
        )
        self.set_data(X_train, y_train, X_val, y_val, feature_names)

        # The flag is set only after set_data succeeds, so a failed load is
        # attempted again on the next call.
        self._data_loaded = True
        logger.info(f"Successfully loaded synthetic data for {site_name} (client_id={client_id})")

    def execute(self, task_name, shareable, fl_ctx, abort_signal):
        """Override execute to ensure data is loaded before processing."""
        self._load_data_if_needed(fl_ctx)
        return super().execute(task_name, shareable, fl_ctx, abort_signal)