李宏毅2021年(春)深度学习 作业一
下载数据集
python
tr_path = 'covid.train.csv' # path to training data |
Downloading...
From: https://drive.google.com/uc?id=19CCyCgJrUxtvgZF53vnctJiOJ23T5mqF
To: /Users/baikal/machineLearning/lessonOne/covid.train.csv
100%|███████████████████████████████████████| 2.00M/2.00M [00:17<00:00, 115kB/s]
Downloading...
From: https://drive.google.com/uc?id=1CE240jLm2npU-tdz81-oVKEF3T2yfT1O
To: /Users/baikal/machineLearning/lessonOne/covid.test.csv
100%|█████████████████████████████████████████| 651k/651k [00:05<00:00, 125kB/s]
下载 Google Drive 上的文件需要使用下载工具 gdown,安装方法:pip install gdown
查看数据集
python
#下面三个包是新增的 |
python
# 读取训练数据前5行 |
AL | AK | AZ | AR | CA | CO | CT | FL | GA | ID | ... | restaurant.2 | spent_time.2 | large_event.2 | public_transit.2 | anxious.2 | depressed.2 | felt_isolated.2 | worried_become_ill.2 | worried_finances.2 | tested_positive.2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 23.812411 | 43.430423 | 16.151527 | 1.602635 | 15.409449 | 12.088688 | 16.702086 | 53.991549 | 43.604229 | 20.704935 |
1 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 23.682974 | 43.196313 | 16.123386 | 1.641863 | 15.230063 | 11.809047 | 16.506973 | 54.185521 | 42.665766 | 21.292911 |
2 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 23.593983 | 43.362200 | 16.159971 | 1.677523 | 15.717207 | 12.355918 | 16.273294 | 53.637069 | 42.972417 | 21.166656 |
3 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 22.576992 | 42.954574 | 15.544373 | 1.578030 | 15.295650 | 12.218123 | 16.045504 | 52.446223 | 42.907472 | 19.896607 |
4 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 22.091433 | 43.290957 | 15.214655 | 1.641667 | 14.778802 | 12.417256 | 16.134238 | 52.560315 | 43.321985 | 20.178428 |
5 rows × 94 columns
python
# 读取测试数据前5行 |
id | AL | AK | AZ | AR | CA | CO | CT | FL | GA | ... | shop.2 | restaurant.2 | spent_time.2 | large_event.2 | public_transit.2 | anxious.2 | depressed.2 | felt_isolated.2 | worried_become_ill.2 | worried_finances.2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 52.071090 | 8.624001 | 29.374792 | 5.391413 | 2.754804 | 19.695098 | 13.685645 | 24.747837 | 66.194950 | 44.873473 |
1 | 1 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 58.742461 | 21.720187 | 41.375784 | 9.450179 | 3.150088 | 22.075715 | 17.302077 | 23.559622 | 57.015009 | 38.372829 |
2 | 2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 59.109045 | 20.123959 | 40.072556 | 8.781522 | 2.888209 | 23.920870 | 18.342506 | 24.993341 | 55.291498 | 38.907257 |
3 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 55.442267 | 16.083529 | 36.977612 | 5.199286 | 2.575347 | 21.073800 | 12.087171 | 18.608723 | 67.036197 | 43.142779 |
4 | 4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 60.588783 | 19.503010 | 42.631236 | 11.549771 | 8.530551 | 15.896575 | 11.781634 | 15.065228 | 61.196518 | 43.574676 |
5 rows × 94 columns
python
# 查看有多少列特征 |
Index(['id', 'AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'FL', 'GA', 'ID', 'IL',
'IN', 'IA', 'KS', 'KY', 'LA', 'MD', 'MA', 'MI', 'MN', 'MS', 'MO', 'NE',
'NV', 'NJ', 'NM', 'NY', 'NC', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'TX',
'UT', 'VA', 'WA', 'WV', 'WI', 'cli', 'ili', 'hh_cmnty_cli',
'nohh_cmnty_cli', 'wearing_mask', 'travel_outside_state',
'work_outside_home', 'shop', 'restaurant', 'spent_time', 'large_event',
'public_transit', 'anxious', 'depressed', 'felt_isolated',
'worried_become_ill', 'worried_finances', 'tested_positive', 'cli.1',
'ili.1', 'hh_cmnty_cli.1', 'nohh_cmnty_cli.1', 'wearing_mask.1',
'travel_outside_state.1', 'work_outside_home.1', 'shop.1',
'restaurant.1', 'spent_time.1', 'large_event.1', 'public_transit.1',
'anxious.1', 'depressed.1', 'felt_isolated.1', 'worried_become_ill.1',
'worried_finances.1', 'tested_positive.1', 'cli.2', 'ili.2',
'hh_cmnty_cli.2', 'nohh_cmnty_cli.2', 'wearing_mask.2',
'travel_outside_state.2', 'work_outside_home.2', 'shop.2',
'restaurant.2', 'spent_time.2', 'large_event.2', 'public_transit.2',
'anxious.2', 'depressed.2', 'felt_isolated.2', 'worried_become_ill.2',
'worried_finances.2', 'tested_positive.2'],
dtype='object')
python
# id列用不到,去除 |
python
# 取特征列 |
Index(['AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'FL', 'GA', 'ID', 'IL', 'IN',
'IA', 'KS', 'KY', 'LA', 'MD', 'MA', 'MI', 'MN', 'MS', 'MO', 'NE', 'NV',
'NJ', 'NM', 'NY', 'NC', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'TX', 'UT',
'VA', 'WA', 'WV', 'WI', 'cli', 'ili', 'hh_cmnty_cli', 'nohh_cmnty_cli',
'wearing_mask', 'travel_outside_state', 'work_outside_home', 'shop',
'restaurant', 'spent_time', 'large_event', 'public_transit', 'anxious',
'depressed', 'felt_isolated', 'worried_become_ill', 'worried_finances',
'tested_positive', 'cli.1', 'ili.1', 'hh_cmnty_cli.1',
'nohh_cmnty_cli.1', 'wearing_mask.1', 'travel_outside_state.1',
'work_outside_home.1', 'shop.1', 'restaurant.1', 'spent_time.1',
'large_event.1', 'public_transit.1', 'anxious.1', 'depressed.1',
'felt_isolated.1', 'worried_become_ill.1', 'worried_finances.1',
'tested_positive.1', 'cli.2', 'ili.2', 'hh_cmnty_cli.2',
'nohh_cmnty_cli.2', 'wearing_mask.2', 'travel_outside_state.2',
'work_outside_home.2', 'shop.2', 'restaurant.2', 'spent_time.2',
'large_event.2', 'public_transit.2', 'anxious.2', 'depressed.2',
'felt_isolated.2', 'worried_become_ill.2', 'worried_finances.2',
'tested_positive.2'],
dtype='object')
python
# 看每列数据类型和大小 |
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2700 entries, 0 to 2699
Data columns (total 94 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 AL 2700 non-null float64
1 AK 2700 non-null float64
2 AZ 2700 non-null float64
3 AR 2700 non-null float64
4 CA 2700 non-null float64
5 CO 2700 non-null float64
6 CT 2700 non-null float64
7 FL 2700 non-null float64
8 GA 2700 non-null float64
9 ID 2700 non-null float64
10 IL 2700 non-null float64
11 IN 2700 non-null float64
12 IA 2700 non-null float64
13 KS 2700 non-null float64
14 KY 2700 non-null float64
15 LA 2700 non-null float64
16 MD 2700 non-null float64
17 MA 2700 non-null float64
18 MI 2700 non-null float64
19 MN 2700 non-null float64
20 MS 2700 non-null float64
21 MO 2700 non-null float64
22 NE 2700 non-null float64
23 NV 2700 non-null float64
24 NJ 2700 non-null float64
25 NM 2700 non-null float64
26 NY 2700 non-null float64
27 NC 2700 non-null float64
28 OH 2700 non-null float64
29 OK 2700 non-null float64
30 OR 2700 non-null float64
31 PA 2700 non-null float64
32 RI 2700 non-null float64
33 SC 2700 non-null float64
34 TX 2700 non-null float64
35 UT 2700 non-null float64
36 VA 2700 non-null float64
37 WA 2700 non-null float64
38 WV 2700 non-null float64
39 WI 2700 non-null float64
40 cli 2700 non-null float64
41 ili 2700 non-null float64
42 hh_cmnty_cli 2700 non-null float64
43 nohh_cmnty_cli 2700 non-null float64
44 wearing_mask 2700 non-null float64
45 travel_outside_state 2700 non-null float64
46 work_outside_home 2700 non-null float64
47 shop 2700 non-null float64
48 restaurant 2700 non-null float64
49 spent_time 2700 non-null float64
50 large_event 2700 non-null float64
51 public_transit 2700 non-null float64
52 anxious 2700 non-null float64
53 depressed 2700 non-null float64
54 felt_isolated 2700 non-null float64
55 worried_become_ill 2700 non-null float64
56 worried_finances 2700 non-null float64
57 tested_positive 2700 non-null float64
58 cli.1 2700 non-null float64
59 ili.1 2700 non-null float64
60 hh_cmnty_cli.1 2700 non-null float64
61 nohh_cmnty_cli.1 2700 non-null float64
62 wearing_mask.1 2700 non-null float64
63 travel_outside_state.1 2700 non-null float64
64 work_outside_home.1 2700 non-null float64
65 shop.1 2700 non-null float64
66 restaurant.1 2700 non-null float64
67 spent_time.1 2700 non-null float64
68 large_event.1 2700 non-null float64
69 public_transit.1 2700 non-null float64
70 anxious.1 2700 non-null float64
71 depressed.1 2700 non-null float64
72 felt_isolated.1 2700 non-null float64
73 worried_become_ill.1 2700 non-null float64
74 worried_finances.1 2700 non-null float64
75 tested_positive.1 2700 non-null float64
76 cli.2 2700 non-null float64
77 ili.2 2700 non-null float64
78 hh_cmnty_cli.2 2700 non-null float64
79 nohh_cmnty_cli.2 2700 non-null float64
80 wearing_mask.2 2700 non-null float64
81 travel_outside_state.2 2700 non-null float64
82 work_outside_home.2 2700 non-null float64
83 shop.2 2700 non-null float64
84 restaurant.2 2700 non-null float64
85 spent_time.2 2700 non-null float64
86 large_event.2 2700 non-null float64
87 public_transit.2 2700 non-null float64
88 anxious.2 2700 non-null float64
89 depressed.2 2700 non-null float64
90 felt_isolated.2 2700 non-null float64
91 worried_become_ill.2 2700 non-null float64
92 worried_finances.2 2700 non-null float64
93 tested_positive.2 2700 non-null float64
dtypes: float64(94)
memory usage: 1.9 MB
None
python
# WI列是states one-hot编码最后一列,取值为0或1,后面特征分析时需要把states特征删掉 |
39
python
# 从上面可以看出wi 列后面是cli, 所以列索引从40开始, 并查看这些数据分布 |
cli | ili | hh_cmnty_cli | nohh_cmnty_cli | wearing_mask | travel_outside_state | work_outside_home | shop | restaurant | spent_time | ... | restaurant.2 | spent_time.2 | large_event.2 | public_transit.2 | anxious.2 | depressed.2 | felt_isolated.2 | worried_become_ill.2 | worried_finances.2 | tested_positive.2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | ... | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 | 2700.000000 |
mean | 0.991587 | 1.016136 | 29.442496 | 24.323054 | 89.682322 | 8.894498 | 31.703307 | 55.277153 | 16.694342 | 36.283177 | ... | 16.578290 | 36.074941 | 10.257474 | 2.385735 | 18.067635 | 13.058828 | 19.243283 | 64.834307 | 44.568440 | 16.431280 |
std | 0.420296 | 0.423629 | 9.093738 | 8.446750 | 5.380027 | 3.404027 | 4.928902 | 4.525917 | 5.668479 | 6.675206 | ... | 5.651583 | 6.655166 | 4.686263 | 1.053147 | 2.250081 | 1.628589 | 2.708339 | 6.220087 | 5.232030 | 7.619354 |
min | 0.126321 | 0.132470 | 9.961640 | 6.857181 | 70.950912 | 1.252983 | 18.311941 | 43.220187 | 3.637414 | 21.485815 | ... | 3.637414 | 21.485815 | 2.118674 | 0.728770 | 12.980786 | 8.370536 | 13.400399 | 48.225603 | 33.113882 | 2.338708 |
25% | 0.673929 | 0.697515 | 23.203165 | 18.539153 | 86.309537 | 6.177754 | 28.247865 | 51.547206 | 13.311050 | 30.740931 | ... | 13.200532 | 30.606711 | 6.532543 | 1.714080 | 16.420485 | 11.914167 | 17.322912 | 59.782876 | 40.549987 | 10.327314 |
50% | 0.912747 | 0.940295 | 28.955738 | 23.819761 | 90.819435 | 8.288288 | 32.143140 | 55.257262 | 16.371699 | 36.267966 | ... | 16.227010 | 36.041389 | 9.700368 | 2.199521 | 17.684197 | 12.948749 | 18.760267 | 65.932259 | 43.997637 | 15.646480 |
75% | 1.266849 | 1.302040 | 36.109114 | 30.238061 | 93.937119 | 11.582209 | 35.387315 | 58.866130 | 21.396971 | 41.659971 | ... | 21.207162 | 41.508520 | 13.602566 | 2.730469 | 19.503419 | 14.214320 | 20.713638 | 69.719651 | 48.118283 | 22.535165 |
max | 2.597732 | 2.625885 | 56.832289 | 51.550450 | 98.087160 | 18.552325 | 42.359074 | 65.673889 | 28.488220 | 50.606465 | ... | 28.488220 | 50.606465 | 24.496711 | 8.162275 | 28.574091 | 18.715944 | 28.366270 | 77.701014 | 58.433600 | 40.959495 |
8 rows × 54 columns
python
# 查看测试集数据分布,并和训练集数据分布对比,两者特征之间数据分布差异不是很大 |
cli | ili | hh_cmnty_cli | nohh_cmnty_cli | wearing_mask | travel_outside_state | work_outside_home | shop | restaurant | spent_time | ... | shop.2 | restaurant.2 | spent_time.2 | large_event.2 | public_transit.2 | anxious.2 | depressed.2 | felt_isolated.2 | worried_become_ill.2 | worried_finances.2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | ... | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 | 893.000000 |
mean | 0.972457 | 0.991809 | 29.075682 | 24.018729 | 89.637506 | 9.001325 | 31.620607 | 55.422982 | 16.554387 | 36.371653 | ... | 55.268628 | 16.444916 | 36.165898 | 10.248975 | 2.369115 | 17.988147 | 12.993830 | 19.238723 | 64.619920 | 44.411505 |
std | 0.411997 | 0.415468 | 9.596290 | 8.988245 | 4.733549 | 3.655616 | 4.754570 | 4.366780 | 5.688802 | 6.203232 | ... | 4.350540 | 5.656828 | 6.192274 | 4.498845 | 1.114366 | 2.207022 | 1.713143 | 2.687435 | 5.685865 | 4.605268 |
min | 0.139558 | 0.159477 | 9.171315 | 6.014740 | 76.895278 | 2.062500 | 18.299198 | 44.062442 | 3.800684 | 21.487077 | ... | 44.671891 | 3.837441 | 21.338425 | 2.334655 | 0.873986 | 12.696977 | 8.462444 | 13.476209 | 50.212234 | 35.072577 |
25% | 0.673327 | 0.689367 | 21.831730 | 17.385490 | 86.587475 | 7.055039 | 28.755178 | 51.726987 | 13.314242 | 31.427591 | ... | 51.594301 | 13.391769 | 31.330469 | 6.802860 | 1.760374 | 16.406397 | 11.777101 | 17.197313 | 60.358203 | 40.910546 |
50% | 0.925230 | 0.936610 | 28.183014 | 23.035749 | 90.123133 | 8.773243 | 31.826385 | 55.750887 | 17.100556 | 36.692799 | ... | 55.490325 | 16.975410 | 36.213594 | 9.550393 | 2.146468 | 17.719760 | 12.805424 | 19.068658 | 65.148128 | 44.504010 |
75% | 1.251219 | 1.267463 | 36.813772 | 31.141866 | 93.387952 | 10.452262 | 35.184926 | 59.185350 | 20.919961 | 41.265159 | ... | 59.078475 | 20.584376 | 41.071035 | 13.372731 | 2.645314 | 19.423720 | 14.091551 | 21.205695 | 68.994309 | 47.172065 |
max | 2.488967 | 2.522263 | 53.184067 | 48.142433 | 97.843221 | 26.598752 | 42.887263 | 63.979007 | 27.438286 | 53.513289 | ... | 63.771097 | 27.362321 | 52.045373 | 23.305630 | 9.118302 | 27.003564 | 18.964157 | 26.007557 | 76.871053 | 56.442135 |
8 rows × 53 columns
python
# For plotting |
<matplotlib.collections.PathCollection at 0x16c331670>
python
plt.scatter(data_tr.loc[:, 'ili'], data_tr.loc[:, 'tested_positive.2']) |
<matplotlib.collections.PathCollection at 0x1380acf40>
python
# cli 和ili两者差不多,所以这两个特征用一个就行 |
<matplotlib.collections.PathCollection at 0x13811ca30>
python
#day1 目标值与day3目标值相关性,线性相关的 |
<matplotlib.collections.PathCollection at 0x13815b730>
python
# day2 目标值与day3目标值相关性,线性相关的 |
<matplotlib.collections.PathCollection at 0x1381ee190>
python
# 上面手动分析太累,还是利用corr方法自动分析 |
cli | ili | hh_cmnty_cli | nohh_cmnty_cli | wearing_mask | travel_outside_state | work_outside_home | shop | restaurant | spent_time | ... | restaurant.2 | spent_time.2 | large_event.2 | public_transit.2 | anxious.2 | depressed.2 | felt_isolated.2 | worried_become_ill.2 | worried_finances.2 | tested_positive.2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
cli | 1.000000 | 0.995735 | 0.893416 | 0.882322 | -0.107406 | -0.095964 | 0.087305 | -0.364165 | -0.143318 | -0.209020 | ... | -0.151291 | -0.222834 | -0.060308 | -0.374071 | 0.237135 | 0.081456 | 0.098345 | 0.228750 | 0.550564 | 0.838504 |
ili | 0.995735 | 1.000000 | 0.889729 | 0.878280 | -0.109015 | -0.106934 | 0.086355 | -0.357443 | -0.142082 | -0.207210 | ... | -0.150141 | -0.220942 | -0.061298 | -0.363873 | 0.245228 | 0.086229 | 0.104250 | 0.222909 | 0.544776 | 0.830527 |
hh_cmnty_cli | 0.893416 | 0.889729 | 1.000000 | 0.997225 | -0.035441 | -0.069595 | 0.079219 | -0.472746 | -0.247043 | -0.293775 | ... | -0.253615 | -0.300062 | -0.136937 | -0.433276 | 0.307581 | 0.181497 | 0.203577 | 0.350255 | 0.561942 | 0.879724 |
nohh_cmnty_cli | 0.882322 | 0.878280 | 0.997225 | 1.000000 | -0.046063 | -0.061914 | 0.097756 | -0.465374 | -0.238106 | -0.280916 | ... | -0.245265 | -0.287482 | -0.129474 | -0.424996 | 0.317836 | 0.188467 | 0.203599 | 0.345448 | 0.534711 | 0.869938 |
wearing_mask | -0.107406 | -0.109015 | -0.035441 | -0.046063 | 1.000000 | -0.220808 | -0.735649 | -0.691597 | -0.788714 | -0.807623 | ... | -0.785281 | -0.802659 | -0.889021 | 0.133487 | 0.204031 | -0.067720 | 0.427533 | 0.840528 | 0.340101 | -0.069531 |
travel_outside_state | -0.095964 | -0.106934 | -0.069595 | -0.061914 | -0.220808 | 1.000000 | 0.264107 | 0.256911 | 0.288473 | 0.349829 | ... | 0.288098 | 0.336937 | 0.319736 | -0.203611 | 0.001592 | 0.064425 | -0.370776 | -0.131961 | -0.093096 | -0.097303 |
work_outside_home | 0.087305 | 0.086355 | 0.079219 | 0.097756 | -0.735649 | 0.264107 | 1.000000 | 0.631958 | 0.743673 | 0.698047 | ... | 0.730349 | 0.705533 | 0.758575 | -0.110176 | 0.018259 | 0.075562 | -0.430307 | -0.652231 | -0.317717 | 0.034865 |
shop | -0.364165 | -0.357443 | -0.472746 | -0.465374 | -0.691597 | 0.256911 | 0.631958 | 1.000000 | 0.820916 | 0.819035 | ... | 0.811055 | 0.838358 | 0.787237 | 0.130046 | -0.228007 | -0.029168 | -0.496368 | -0.866789 | -0.475304 | -0.410430 |
restaurant | -0.143318 | -0.142082 | -0.247043 | -0.238106 | -0.788714 | 0.288473 | 0.743673 | 0.820916 | 1.000000 | 0.878576 | ... | 0.993358 | 0.876107 | 0.909089 | -0.046081 | -0.278715 | -0.074727 | -0.648631 | -0.832131 | -0.430842 | -0.157945 |
spent_time | -0.209020 | -0.207210 | -0.293775 | -0.280916 | -0.807623 | 0.349829 | 0.698047 | 0.819035 | 0.878576 | 1.000000 | ... | 0.875365 | 0.986713 | 0.912682 | -0.040623 | -0.169965 | 0.105281 | -0.517139 | -0.867460 | -0.522985 | -0.252125 |
large_event | -0.042033 | -0.043535 | -0.124151 | -0.116761 | -0.894970 | 0.324270 | 0.767305 | 0.781862 | 0.912449 | 0.918504 | ... | 0.910579 | 0.913814 | 0.993111 | -0.139139 | -0.215598 | 0.055579 | -0.565014 | -0.874083 | -0.372589 | -0.052473 |
public_transit | -0.367103 | -0.356652 | -0.432142 | -0.423773 | 0.131350 | -0.198308 | -0.110077 | 0.132385 | -0.043954 | -0.037282 | ... | -0.048799 | -0.035965 | -0.137080 | 0.982095 | -0.055799 | -0.167599 | 0.001697 | -0.046611 | -0.138801 | -0.448360 |
anxious | 0.273874 | 0.281974 | 0.336748 | 0.344074 | 0.232620 | -0.023175 | 0.013537 | -0.265503 | -0.312912 | -0.209830 | ... | -0.327660 | -0.218920 | -0.283515 | -0.054270 | 0.951196 | 0.539596 | 0.516252 | 0.280087 | 0.217988 | 0.173295 |
depressed | 0.098033 | 0.102715 | 0.184739 | 0.190062 | -0.070022 | 0.058548 | 0.075801 | -0.041607 | -0.074059 | 0.104628 | ... | -0.065903 | 0.113934 | 0.063086 | -0.165972 | 0.599423 | 0.953157 | 0.592656 | -0.055694 | 0.021274 | 0.037689 |
felt_isolated | 0.100928 | 0.107079 | 0.198176 | 0.197661 | 0.422058 | -0.376858 | -0.431247 | -0.491608 | -0.642316 | -0.511772 | ... | -0.633869 | -0.497951 | -0.544678 | 0.009742 | 0.526345 | 0.604416 | 0.978303 | 0.395606 | 0.128047 | 0.082182 |
worried_become_ill | 0.218502 | 0.212931 | 0.344457 | 0.340192 | 0.843990 | -0.136811 | -0.656085 | -0.864583 | -0.835101 | -0.870365 | ... | -0.831439 | -0.869933 | -0.872394 | -0.043575 | 0.251045 | -0.038421 | 0.419940 | 0.992976 | 0.490127 | 0.262211 |
worried_finances | 0.537608 | 0.532217 | 0.552431 | 0.524022 | 0.354130 | -0.096444 | -0.339975 | -0.489539 | -0.447892 | -0.536561 | ... | -0.451124 | -0.536959 | -0.397443 | -0.141140 | 0.152500 | 0.027382 | 0.144230 | 0.506907 | 0.988123 | 0.475462 |
tested_positive | 0.839122 | 0.829756 | 0.880187 | 0.869674 | -0.049350 | -0.113726 | 0.025780 | -0.427815 | -0.173726 | -0.275476 | ... | -0.174815 | -0.278257 | -0.083275 | -0.451809 | 0.132802 | 0.021773 | 0.090015 | 0.285052 | 0.495753 | 0.981165 |
cli.1 | 0.980379 | 0.977225 | 0.887944 | 0.877606 | -0.121569 | -0.091186 | 0.096755 | -0.348133 | -0.129772 | -0.189519 | ... | -0.138355 | -0.204750 | -0.044520 | -0.369066 | 0.254911 | 0.088243 | 0.093092 | 0.212721 | 0.540981 | 0.838224 |
ili.1 | 0.976171 | 0.980473 | 0.884020 | 0.873424 | -0.123680 | -0.102645 | 0.096343 | -0.340973 | -0.128114 | -0.187173 | ... | -0.136904 | -0.202412 | -0.045430 | -0.358447 | 0.263278 | 0.092825 | 0.099412 | 0.206378 | 0.534751 | 0.829200 |
hh_cmnty_cli.1 | 0.896211 | 0.892667 | 0.998356 | 0.996165 | -0.046423 | -0.063619 | 0.089934 | -0.462807 | -0.235459 | -0.280262 | ... | -0.242962 | -0.288664 | -0.125902 | -0.432234 | 0.319697 | 0.180703 | 0.195294 | 0.340223 | 0.556547 | 0.879438 |
nohh_cmnty_cli.1 | 0.885178 | 0.881292 | 0.995176 | 0.998259 | -0.056529 | -0.055823 | 0.107979 | -0.455990 | -0.226870 | -0.268086 | ... | -0.234893 | -0.276769 | -0.119138 | -0.423434 | 0.328817 | 0.186480 | 0.195257 | 0.336189 | 0.528994 | 0.869278 |
wearing_mask.1 | -0.101056 | -0.102606 | -0.030237 | -0.040738 | 0.998287 | -0.220397 | -0.732848 | -0.694338 | -0.789257 | -0.808963 | ... | -0.787873 | -0.806218 | -0.892712 | 0.130892 | 0.208011 | -0.071689 | 0.425830 | 0.843469 | 0.342057 | -0.065600 |
travel_outside_state.1 | -0.097092 | -0.107662 | -0.069270 | -0.062039 | -0.220442 | 0.995838 | 0.259748 | 0.261335 | 0.286921 | 0.352038 | ... | 0.287332 | 0.342678 | 0.321376 | -0.202010 | -0.003803 | 0.065990 | -0.372008 | -0.133520 | -0.090896 | -0.100407 |
work_outside_home.1 | 0.087080 | 0.085966 | 0.074972 | 0.093529 | -0.737554 | 0.268864 | 0.991471 | 0.616394 | 0.746680 | 0.697270 | ... | 0.737749 | 0.698691 | 0.761755 | -0.110337 | 0.025430 | 0.077455 | -0.429387 | -0.652367 | -0.325294 | 0.037930 |
shop.1 | -0.367850 | -0.361304 | -0.474799 | -0.467316 | -0.688627 | 0.252461 | 0.638500 | 0.991248 | 0.820264 | 0.808526 | ... | 0.815317 | 0.829363 | 0.785631 | 0.131734 | -0.232298 | -0.029772 | -0.492524 | -0.864694 | -0.478978 | -0.412705 |
restaurant.1 | -0.147491 | -0.146353 | -0.250349 | -0.241687 | -0.787245 | 0.288160 | 0.737725 | 0.816414 | 0.997496 | 0.877051 | ... | 0.997484 | 0.876508 | 0.911182 | -0.046082 | -0.285147 | -0.070812 | -0.645411 | -0.832047 | -0.433497 | -0.159121 |
spent_time.1 | -0.216168 | -0.214354 | -0.297071 | -0.284398 | -0.805468 | 0.343854 | 0.700146 | 0.828992 | 0.877660 | 0.995393 | ... | 0.876180 | 0.995383 | 0.916829 | -0.039405 | -0.175019 | 0.111992 | -0.513196 | -0.869427 | -0.523476 | -0.255714 |
large_event.1 | -0.051724 | -0.052961 | -0.130729 | -0.123252 | -0.892267 | 0.322149 | 0.762592 | 0.784996 | 0.911143 | 0.916514 | ... | 0.912015 | 0.917360 | 0.997449 | -0.137279 | -0.226287 | 0.061477 | -0.559791 | -0.875258 | -0.376785 | -0.058079 |
public_transit.1 | -0.371063 | -0.360574 | -0.432765 | -0.424445 | 0.132301 | -0.201241 | -0.109727 | 0.131371 | -0.044942 | -0.039224 | ... | -0.047397 | -0.036646 | -0.136200 | 0.991364 | -0.053367 | -0.165462 | 0.005809 | -0.047158 | -0.140425 | -0.449079 |
anxious.1 | 0.256712 | 0.264872 | 0.323053 | 0.331791 | 0.217574 | -0.011044 | 0.018079 | -0.246039 | -0.295416 | -0.189704 | ... | -0.309644 | -0.199497 | -0.260678 | -0.053189 | 0.980965 | 0.567651 | 0.519936 | 0.264619 | 0.201912 | 0.164537 |
depressed.1 | 0.088676 | 0.093371 | 0.182383 | 0.188544 | -0.069369 | 0.061782 | 0.075357 | -0.034364 | -0.073814 | 0.105809 | ... | -0.066076 | 0.117099 | 0.065712 | -0.164973 | 0.599952 | 0.978623 | 0.601982 | -0.054501 | 0.024123 | 0.033149 |
felt_isolated.1 | 0.099487 | 0.105446 | 0.201034 | 0.200843 | 0.424822 | -0.374146 | -0.430562 | -0.493842 | -0.645507 | -0.514850 | ... | -0.638136 | -0.503093 | -0.549472 | 0.009842 | 0.527040 | 0.608896 | 0.990446 | 0.401543 | 0.130005 | 0.081521 |
worried_become_ill.1 | 0.223326 | 0.217739 | 0.347562 | 0.343024 | 0.842499 | -0.134507 | -0.654251 | -0.865601 | -0.833903 | -0.869399 | ... | -0.832038 | -0.870442 | -0.874175 | -0.045239 | 0.250985 | -0.044886 | 0.414444 | 0.996878 | 0.492998 | 0.264816 |
worried_finances.1 | 0.543373 | 0.537874 | 0.557364 | 0.529514 | 0.347359 | -0.094679 | -0.328919 | -0.482534 | -0.439702 | -0.529935 | ... | -0.444050 | -0.531072 | -0.389929 | -0.141796 | 0.168931 | 0.026522 | 0.138286 | 0.501713 | 0.994864 | 0.480958 |
tested_positive.1 | 0.839929 | 0.831129 | 0.880416 | 0.870315 | -0.059477 | -0.105467 | 0.031094 | -0.419104 | -0.165959 | -0.264309 | ... | -0.167639 | -0.268959 | -0.073982 | -0.451397 | 0.143395 | 0.025272 | 0.085417 | 0.276338 | 0.491043 | 0.991012 |
cli.2 | 0.957059 | 0.954996 | 0.881768 | 0.872292 | -0.135146 | -0.086332 | 0.104981 | -0.331428 | -0.116415 | -0.170275 | ... | -0.124823 | -0.185582 | -0.027097 | -0.363815 | 0.270811 | 0.096270 | 0.087526 | 0.197407 | 0.532770 | 0.835751 |
ili.2 | 0.952707 | 0.956979 | 0.877550 | 0.867896 | -0.137841 | -0.097991 | 0.104965 | -0.323789 | -0.114323 | -0.167358 | ... | -0.123005 | -0.182693 | -0.027895 | -0.353209 | 0.279485 | 0.100997 | 0.094463 | 0.190436 | 0.526026 | 0.826075 |
hh_cmnty_cli.2 | 0.898067 | 0.894564 | 0.995396 | 0.993750 | -0.058149 | -0.057164 | 0.099741 | -0.452086 | -0.223203 | -0.265245 | ... | -0.231610 | -0.275863 | -0.113619 | -0.431142 | 0.330882 | 0.180963 | 0.186653 | 0.329330 | 0.550290 | 0.878218 |
nohh_cmnty_cli.2 | 0.887103 | 0.883263 | 0.991738 | 0.995093 | -0.067698 | -0.049281 | 0.117226 | -0.445815 | -0.215113 | -0.253751 | ... | -0.223909 | -0.264597 | -0.107674 | -0.421805 | 0.339048 | 0.185514 | 0.186517 | 0.326080 | 0.522506 | 0.867535 |
wearing_mask.2 | -0.094664 | -0.096315 | -0.025367 | -0.035759 | 0.995953 | -0.219423 | -0.729730 | -0.696457 | -0.788931 | -0.809003 | ... | -0.789539 | -0.808958 | -0.895733 | 0.128696 | 0.212752 | -0.075599 | 0.423325 | 0.845721 | 0.343891 | -0.062037 |
travel_outside_state.2 | -0.097903 | -0.107903 | -0.069043 | -0.062137 | -0.219916 | 0.989310 | 0.258430 | 0.266438 | 0.285380 | 0.352962 | ... | 0.286899 | 0.347804 | 0.322521 | -0.199731 | -0.007996 | 0.067252 | -0.372366 | -0.135255 | -0.089308 | -0.103868 |
work_outside_home.2 | 0.085913 | 0.084708 | 0.069933 | 0.088394 | -0.739112 | 0.275348 | 0.975017 | 0.599363 | 0.748185 | 0.700309 | ... | 0.743692 | 0.695202 | 0.765953 | -0.111477 | 0.028803 | 0.080485 | -0.428880 | -0.652395 | -0.333070 | 0.039304 |
shop.2 | -0.370197 | -0.363795 | -0.476538 | -0.469026 | -0.685437 | 0.249670 | 0.640972 | 0.977890 | 0.818073 | 0.800586 | ... | 0.818509 | 0.819755 | 0.783057 | 0.132409 | -0.237570 | -0.031062 | -0.488979 | -0.862711 | -0.482649 | -0.415130 |
restaurant.2 | -0.151291 | -0.150141 | -0.253615 | -0.245265 | -0.785281 | 0.288098 | 0.730349 | 0.811055 | 0.993358 | 0.875365 | ... | 1.000000 | 0.876542 | 0.912564 | -0.046246 | -0.292246 | -0.067040 | -0.641984 | -0.831868 | -0.435929 | -0.160181 |
spent_time.2 | -0.222834 | -0.220942 | -0.300062 | -0.287482 | -0.802659 | 0.336937 | 0.705533 | 0.838358 | 0.876107 | 0.986713 | ... | 0.876542 | 1.000000 | 0.918931 | -0.037616 | -0.180294 | 0.118125 | -0.507902 | -0.870630 | -0.524228 | -0.258956 |
large_event.2 | -0.060308 | -0.061298 | -0.136937 | -0.129474 | -0.889021 | 0.319736 | 0.758575 | 0.787237 | 0.909089 | 0.912682 | ... | 0.912564 | 0.918931 | 1.000000 | -0.135339 | -0.238586 | 0.066021 | -0.554675 | -0.875487 | -0.380926 | -0.063709 |
public_transit.2 | -0.374071 | -0.363873 | -0.433276 | -0.424996 | 0.133487 | -0.203611 | -0.110176 | 0.130046 | -0.046081 | -0.040623 | ... | -0.046246 | -0.037616 | -0.135339 | 1.000000 | -0.052253 | -0.164079 | 0.009571 | -0.047068 | -0.142098 | -0.450436 |
anxious.2 | 0.237135 | 0.245228 | 0.307581 | 0.317836 | 0.204031 | 0.001592 | 0.018259 | -0.228007 | -0.278715 | -0.169965 | ... | -0.292246 | -0.180294 | -0.238586 | -0.052253 | 1.000000 | 0.594797 | 0.525171 | 0.251509 | 0.184126 | 0.152903 |
depressed.2 | 0.081456 | 0.086229 | 0.181497 | 0.188467 | -0.067720 | 0.064425 | 0.075562 | -0.029168 | -0.074727 | 0.105281 | ... | -0.067040 | 0.118125 | 0.066021 | -0.164079 | 0.594797 | 1.000000 | 0.610310 | -0.051246 | 0.026621 | 0.029578 |
felt_isolated.2 | 0.098345 | 0.104250 | 0.203577 | 0.203599 | 0.427533 | -0.370776 | -0.430307 | -0.496368 | -0.648631 | -0.517139 | ... | -0.641984 | -0.507902 | -0.554675 | 0.009571 | 0.525171 | 0.610310 | 1.000000 | 0.407931 | 0.132465 | 0.081174 |
worried_become_ill.2 | 0.228750 | 0.222909 | 0.350255 | 0.345448 | 0.840528 | -0.131961 | -0.652231 | -0.866789 | -0.832131 | -0.867460 | ... | -0.831868 | -0.870630 | -0.875487 | -0.047068 | 0.251509 | -0.051246 | 0.407931 | 1.000000 | 0.495890 | 0.267610 |
worried_finances.2 | 0.550564 | 0.544776 | 0.561942 | 0.534711 | 0.340101 | -0.093096 | -0.317717 | -0.475304 | -0.430842 | -0.522985 | ... | -0.435929 | -0.524228 | -0.380926 | -0.142098 | 0.184126 | 0.026621 | 0.132465 | 0.495890 | 1.000000 | 0.485843 |
tested_positive.2 | 0.838504 | 0.830527 | 0.879724 | 0.869938 | -0.069531 | -0.097303 | 0.034865 | -0.410430 | -0.157945 | -0.252125 | ... | -0.160181 | -0.258956 | -0.063709 | -0.450436 | 0.152903 | 0.029578 | 0.081174 | 0.267610 | 0.485843 | 1.000000 |
54 rows × 54 columns
python
# 锁定上面相关性矩阵最后一列,也就是目标值列,每行是与其相关性大小 |
cli 0.838504
ili 0.830527
hh_cmnty_cli 0.879724
nohh_cmnty_cli 0.869938
wearing_mask -0.069531
travel_outside_state -0.097303
work_outside_home 0.034865
shop -0.410430
restaurant -0.157945
spent_time -0.252125
large_event -0.052473
public_transit -0.448360
anxious 0.173295
depressed 0.037689
felt_isolated 0.082182
worried_become_ill 0.262211
worried_finances 0.475462
tested_positive 0.981165
cli.1 0.838224
ili.1 0.829200
hh_cmnty_cli.1 0.879438
nohh_cmnty_cli.1 0.869278
wearing_mask.1 -0.065600
travel_outside_state.1 -0.100407
work_outside_home.1 0.037930
shop.1 -0.412705
restaurant.1 -0.159121
spent_time.1 -0.255714
large_event.1 -0.058079
public_transit.1 -0.449079
anxious.1 0.164537
depressed.1 0.033149
felt_isolated.1 0.081521
worried_become_ill.1 0.264816
worried_finances.1 0.480958
tested_positive.1 0.991012
cli.2 0.835751
ili.2 0.826075
hh_cmnty_cli.2 0.878218
nohh_cmnty_cli.2 0.867535
wearing_mask.2 -0.062037
travel_outside_state.2 -0.103868
work_outside_home.2 0.039304
shop.2 -0.415130
restaurant.2 -0.160181
spent_time.2 -0.258956
large_event.2 -0.063709
public_transit.2 -0.450436
anxious.2 0.152903
depressed.2 0.029578
felt_isolated.2 0.081174
worried_become_ill.2 0.267610
worried_finances.2 0.485843
tested_positive.2 1.000000
Name: tested_positive.2, dtype: float64
python
#在最后一列相关性数据中选择大于0.8的行,这个0.8是自己设的超参,大家可以根据实际情况调节 |
cli 0.838504
ili 0.830527
hh_cmnty_cli 0.879724
nohh_cmnty_cli 0.869938
tested_positive 0.981165
cli.1 0.838224
ili.1 0.829200
hh_cmnty_cli.1 0.879438
nohh_cmnty_cli.1 0.869278
tested_positive.1 0.991012
cli.2 0.835751
ili.2 0.826075
hh_cmnty_cli.2 0.878218
nohh_cmnty_cli.2 0.867535
tested_positive.2 1.000000
Name: tested_positive.2, dtype: float64
python
feature_cols = feature.index.tolist() #将选择特征名称拿出来 |
['cli',
'ili',
'hh_cmnty_cli',
'nohh_cmnty_cli',
'tested_positive',
'cli.1',
'ili.1',
'hh_cmnty_cli.1',
'nohh_cmnty_cli.1',
'tested_positive.1',
'cli.2',
'ili.2',
'hh_cmnty_cli.2',
'nohh_cmnty_cli.2']
python
# 获取该特征对应列索引编号,后续就可以用feats + feats_selected作为特征值 |
[40, 41, 42, 43, 57, 58, 59, 60, 61, 75, 76, 77, 78, 79]
导入包
python
# PyTorch |
导入工具
无需修改
python
def get_device(): |
预处理
我们有三种数据集:
- 训练集
- 验证集
- 测试集
数据集
COVID19Dataset完成以下操作:
- 读取.csv文件
- 提取特征
- 划分covid.train.csv为训练集和验证集
- 对特征进行标准化
提示:完成以下操作就有可能通过中等难度的分数线
python
import torch |
python
class COVID19Dataset(Dataset): |
python
def prep_dataloader(path, mode, batch_size, n_jobs=0, target_only=False, mu=None, std=None): #训练集不需要传mu,std, 所以默认值设置为None |
python
class NeuralNet(nn.Module): |
python
def train(tr_set, dv_set, model, config, device): |
python
def dev(dv_set, model, device): |
python
def test(tt_set, model, device): |
python
device = get_device() # get the current available device ('cpu' or 'cuda') |
python
tr_set, tr_mu, tr_std = prep_dataloader(tr_path, 'train', config['batch_size'], target_only=target_only) |
Finished reading the train set of COVID19 Dataset (1890 samples found, each dim = 54)
Finished reading the dev set of COVID19 Dataset (810 samples found, each dim = 54)
Finished reading the test set of COVID19 Dataset (893 samples found, each dim = 54)
python
model = NeuralNet(tr_set.dataset.dim).to(device) # Construct model and move to device |
python
model_loss, model_loss_record = train(tr_set, dv_set, model, config, device) |
Saving model (epoch = 1, loss = 17.9400)
Saving model (epoch = 2, loss = 17.7633)
Saving model (epoch = 3, loss = 17.5787)
Saving model (epoch = 4, loss = 17.3771)
……
Saving model (epoch = 581, loss = 0.9606)
Saving model (epoch = 594, loss = 0.9606)
Saving model (epoch = 598, loss = 0.9606)
Saving model (epoch = 599, loss = 0.9604)
Saving model (epoch = 600, loss = 0.9603)
Saving model (epoch = 621, loss = 0.9602)
Saving model (epoch = 706, loss = 0.9601)
Saving model (epoch = 741, loss = 0.9601)
Saving model (epoch = 781, loss = 0.9598)
Saving model (epoch = 786, loss = 0.9597)
Finished training after 987 epochs
python
plot_learning_curve(model_loss_record, title='deep model') |
python
dev(dv_set, model, device) #验证集损失 |
0.9599974950154623
python
del model |
python
def save_pred(preds, file): |
Saving results to commit.csv
提交结果: