import numpy as np


def train_test_split(X, y, test_ratio=0.2, seed=None):
    """Randomly split X and y into training and test subsets.

    X and y are indexed with the same random permutation, so every sample
    stays paired with its label.

    Args:
        X: Feature container that supports ``len()`` and indexing with an
            integer index array (presumably a NumPy ndarray whose first axis
            is the sample axis).
        y: Label container indexable the same way; must satisfy
            ``len(y) == len(X)``.
        test_ratio: Fraction of samples placed in the test set, in ``[0, 1]``.
            The test size is ``int(len(X) * test_ratio)`` (truncated).
        seed: Optional seed. When not ``None`` (``0`` included), NumPy's
            global RNG is seeded with it before shuffling, so the split is
            reproducible.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``.

    Raises:
        ValueError: If ``len(X) != len(y)`` or ``test_ratio`` is not in
            ``[0, 1]``.
    """
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio!r}")

    # Compare against None explicitly: a plain truthiness test would
    # silently ignore seed=0.
    # NOTE: this seeds NumPy's *global* RNG (kept for backward
    # compatibility), which also affects later np.random calls by the caller.
    if seed is not None:
        np.random.seed(seed)

    # One shared permutation keeps each row of X aligned with its label in y.
    shuffle_indexes = np.random.permutation(len(X))

    test_size = int(len(X) * test_ratio)
    test_indexes = shuffle_indexes[:test_size]
    train_indexes = shuffle_indexes[test_size:]

    X_train = X[train_indexes]
    y_train = y[train_indexes]
    X_test = X[test_indexes]
    y_test = y[test_indexes]

    return X_train, y_train, X_test, y_test