001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.hadoop.hdfs.web.oauth2; 020 021import com.squareup.okhttp.OkHttpClient; 022import com.squareup.okhttp.Request; 023import com.squareup.okhttp.RequestBody; 024import com.squareup.okhttp.Response; 025import com.squareup.okhttp.MediaType; 026import org.apache.hadoop.classification.InterfaceAudience; 027import org.apache.hadoop.classification.InterfaceStability; 028import org.apache.hadoop.conf.Configuration; 029import org.apache.hadoop.hdfs.web.URLConnectionFactory; 030import org.apache.hadoop.util.Timer; 031import org.apache.http.HttpStatus; 032import org.codehaus.jackson.map.ObjectMapper; 033import org.codehaus.jackson.map.ObjectReader; 034 035import java.io.IOException; 036import java.util.Map; 037import java.util.concurrent.TimeUnit; 038 039import static org.apache.hadoop.hdfs.web.oauth2.Utils.notNull; 040 041 042/** 043 * Obtain an access token via the credential-based OAuth2 workflow. 044 */ 045@InterfaceAudience.Public 046@InterfaceStability.Evolving 047public class AzureADClientCredentialBasedAccesTokenProvider 048 extends AccessTokenProvider { 049 private static final ObjectReader READER = 050 new ObjectMapper().reader(Map.class); 051 052 public static final String OAUTH_CREDENTIAL_KEY 053 = "dfs.webhdfs.oauth2.credential"; 054 055 public static final String AAD_RESOURCE_KEY 056 = "fs.adls.oauth2.resource"; 057 058 public static final String RESOURCE_PARAM_NAME 059 = "resource"; 060 061 private static final String OAUTH_CLIENT_ID_KEY 062 = "dfs.webhdfs.oauth2.client.id"; 063 064 private static final String OAUTH_REFRESH_URL_KEY 065 = "dfs.webhdfs.oauth2.refresh.url"; 066 067 068 public static final String ACCESS_TOKEN = "access_token"; 069 public static final String CLIENT_CREDENTIALS = "client_credentials"; 070 public static final String CLIENT_ID = "client_id"; 071 public static final String CLIENT_SECRET = "client_secret"; 072 public static final String EXPIRES_IN = "expires_in"; 073 public static final String GRANT_TYPE = "grant_type"; 074 public static final MediaType URLENCODED 075 = MediaType.parse("application/x-www-form-urlencoded; charset=utf-8"); 076 077 078 private AccessTokenTimer timer; 079 080 private String clientId; 081 082 private String refreshURL; 083 084 private String accessToken; 085 086 private String resource; 087 088 private String credential; 089 090 private boolean initialCredentialObtained = false; 091 092 AzureADClientCredentialBasedAccesTokenProvider() { 093 this.timer = new AccessTokenTimer(); 094 } 095 096 AzureADClientCredentialBasedAccesTokenProvider(Timer timer) { 097 this.timer = new AccessTokenTimer(timer); 098 } 099 100 @Override 101 public void setConf(Configuration conf) { 102 super.setConf(conf); 103 clientId = notNull(conf, OAUTH_CLIENT_ID_KEY); 104 refreshURL = notNull(conf, OAUTH_REFRESH_URL_KEY); 105 resource = notNull(conf, AAD_RESOURCE_KEY); 106 credential = notNull(conf, OAUTH_CREDENTIAL_KEY); 107 } 108 109 @Override 110 public String getAccessToken() throws IOException { 111 if(timer.shouldRefresh() || !initialCredentialObtained) { 112 refresh(); 113 initialCredentialObtained = true; 114 } 115 return accessToken; 116 } 117 118 void refresh() throws IOException { 119 try { 120 OkHttpClient client = new OkHttpClient(); 121 client.setConnectTimeout(URLConnectionFactory.DEFAULT_SOCKET_TIMEOUT, 122 TimeUnit.MILLISECONDS); 123 client.setReadTimeout(URLConnectionFactory.DEFAULT_SOCKET_TIMEOUT, 124 TimeUnit.MILLISECONDS); 125 126 String bodyString = Utils.postBody(CLIENT_SECRET, credential, 127 GRANT_TYPE, CLIENT_CREDENTIALS, 128 RESOURCE_PARAM_NAME, resource, 129 CLIENT_ID, clientId); 130 131 RequestBody body = RequestBody.create(URLENCODED, bodyString); 132 133 Request request = new Request.Builder() 134 .url(refreshURL) 135 .post(body) 136 .build(); 137 Response responseBody = client.newCall(request).execute(); 138 139 if (responseBody.code() != HttpStatus.SC_OK) { 140 throw new IllegalArgumentException("Received invalid http response: " 141 + responseBody.code() + ", text = " + responseBody.toString()); 142 } 143 144 Map<?, ?> response = READER.readValue(responseBody.body().string()); 145 146 String newExpiresIn = response.get(EXPIRES_IN).toString(); 147 timer.setExpiresIn(newExpiresIn); 148 149 accessToken = response.get(ACCESS_TOKEN).toString(); 150 151 } catch (Exception e) { 152 throw new IOException("Unable to obtain access token from credential", e); 153 } 154 } 155}