diff --git a/tasks/01-melbourne.py b/tasks/01-melbourne.py index 4e21d37..40c4e51 100644 --- a/tasks/01-melbourne.py +++ b/tasks/01-melbourne.py @@ -21,13 +21,20 @@ import seaborn as sns # %% Data data = pd.read_csv("../data/melb_data.csv").dropna() # Ein Outlier, blöder Arsch +# TODO: remove outlier from actual data, not just diagram ax = sns.scatterplot(x=data['BuildingArea'], y=data['Price']) ax.set(xlim=(0, 1000)) # %% linear regression -X = data['BuildingArea'] -Y = data['Price'] +X = [] +Y = [] +for _, row in data.iterrows(): + X.append([1]+ [row['BuildingArea']]) + Y.append(row['Price']) +X = np.array(X) +Y = np.array(Y) # aber das ist noch nicht die fertige eingabe, da fehlt die konstante 1! # und mit Y ist auch irgendwas :( -# w_ana = np.linalg.solve(X.T @ X , X.T @ Y) +w_ana = np.linalg.solve(X.T @ X , X.T @ Y) +w_ana